Creating extensions using numpy and scipy

Author: Adam Paszke

In this tutorial, we shall go through two tasks:

  1. Create a neural network layer with no parameters.

    • This calls into NumPy as part of its implementation
  2. Create a neural network layer that has learnable weights.

    • This calls into SciPy as part of its implementation

import torch
from torch.autograd import Function
from torch.autograd import Variable

Parameter-less example

This layer doesn’t do anything useful or mathematically correct.

It is aptly named BadFFTFunction.

Layer Implementation

from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):

    @staticmethod
    def forward(ctx, input):
        # detach so the tensor can be converted to a NumPy array
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go)
        return torch.as_tensor(result, dtype=grad_output.dtype)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an nn.Module class


def incorrect_fft(input):
    # custom Functions are invoked through .apply, not by instantiation
    return BadFFTFunction.apply(input)
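The shapes in the output below follow from NumPy's real FFT. rfft2 keeps only the non-negative frequencies along the last axis, so an 8×8 input becomes an 8×5 spectrum, and irfft2 maps it back to 8×8. The NumPy-only sketch below (the seed and variable names are illustrative choices) also shows why this backward is not a true inverse: once abs() discards the phase, irfft2 no longer recovers the input.

```python
import numpy as np
from numpy.fft import rfft2, irfft2

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))

spectrum = rfft2(x)
print(spectrum.shape)          # (8, 5): only non-negative frequencies on the last axis
print(irfft2(spectrum).shape)  # (8, 8)

# The round-trip on the complex spectrum recovers the input...
print(np.allclose(irfft2(spectrum), x))          # True
# ...but the layer keeps only the magnitude, so the round-trip breaks
print(np.allclose(irfft2(np.abs(spectrum)), x))  # False
```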

Example usage of the created layer:

input = Variable(torch.randn(8, 8), requires_grad=True)
result = incorrect_fft(input)
print(result.data)
result.backward(torch.randn(result.size()))
print(input.grad)

Out:

9.3850   3.7278  12.1252   4.7440   0.1007
  8.7267  10.0111   7.8526   3.9947   5.8996
  2.2608   6.1943   4.1151   4.2829   1.9711
 13.3399   1.4309   3.4576   6.0777   2.9095
  2.2723   5.1269   1.0994   6.9922   0.8258
 13.3399  11.7054   3.6514   7.9590   2.9095
  2.2608  10.0309   1.3558   2.1418   1.9711
  8.7267   7.2774  10.2759   6.5953   5.8996
[torch.FloatTensor of size 8x5]

Variable containing:
-0.1448 -0.2484 -0.0526  0.0304 -0.2208  0.0304 -0.0526 -0.2484
-0.0952  0.1012 -0.1489 -0.2706 -0.1588 -0.0548 -0.1102 -0.0034
-0.1208  0.1651  0.0928 -0.0030 -0.1559 -0.1216  0.1992 -0.0146
-0.2294 -0.0393 -0.1855  0.0452 -0.0132 -0.2218  0.2052  0.0554
-0.2641  0.0647  0.0580  0.0371  0.0597  0.0371  0.0580  0.0647
-0.2294  0.0554  0.2052 -0.2218 -0.0132  0.0452 -0.1855 -0.0393
-0.1208 -0.0146  0.1992 -0.1216 -0.1559 -0.0030  0.0928  0.1651
-0.0952 -0.0034 -0.1102 -0.0548 -0.1588 -0.2706 -0.1489  0.1012
[torch.FloatTensor of size 8x8]
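To confirm that BadFFTFunction's backward is not the true gradient, you can use torch.autograd.gradcheck, which compares the analytical gradients against finite differences. The sketch below assumes a recent PyTorch, in which Function subclasses use static forward/backward methods and are invoked through .apply. The class is repeated so the snippet runs on its own. It uses double precision because gradcheck relies on it for reliable finite differences.

```python
import torch
from torch.autograd import Function, gradcheck
from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    @staticmethod
    def forward(ctx, input):
        result = abs(rfft2(input.detach().numpy()))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        # irfft2 is not the adjoint of |rfft2(.)|, so this gradient is wrong
        result = irfft2(grad_output.detach().numpy())
        return torch.as_tensor(result, dtype=grad_output.dtype)


torch.manual_seed(0)
x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
ok = gradcheck(BadFFTFunction.apply, (x,), raise_exception=False)
print("Are the gradients correct:", ok)
```

With raise_exception=False, gradcheck returns False instead of raising when the analytical and numerical Jacobians disagree, which is the expected result here.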

Parametrized example

This implements a layer with learnable weights.

It implements cross-correlation with a learnable kernel.

In the deep learning literature, this operation is confusingly referred to as convolution.
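The two operations differ only by a kernel flip. The following SciPy sketch uses arbitrary sizes and an arbitrary seed. It checks that correlate2d gives the same result as convolve2d applied to the kernel flipped along both axes.

```python
import numpy as np
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
image = rng.standard_normal((10, 10))
kernel = rng.standard_normal((3, 3))

# cross-correlation slides the kernel over the image as-is
corr = correlate2d(image, kernel, mode='valid')
# true convolution flips the kernel in both axes before sliding it
conv = convolve2d(image, kernel, mode='valid')
conv_flipped = convolve2d(image, kernel[::-1, ::-1], mode='valid')

print(corr.shape)                       # (8, 8)
print(np.allclose(corr, conv_flipped))  # True
print(np.allclose(corr, conv))          # False for a non-symmetric kernel
```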

The backward pass computes the gradients with respect to the input and with respect to the filter.
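Both gradients follow from the chain rule. Write the output y as a valid cross-correlation of the input x with the kernel k, treat k as zero outside its kh × kw support, and let g be the upstream gradient. The standard derivation, sketched here, gives:

```latex
y_{ij} = \sum_{a,b} x_{i+a,\,j+b}\,k_{ab},
\qquad g_{ij} = \frac{\partial L}{\partial y_{ij}}

\frac{\partial L}{\partial k_{ab}} = \sum_{i,j} g_{ij}\,x_{i+a,\,j+b}
\quad\text{(valid cross-correlation of } x \text{ with } g\text{)}

\frac{\partial L}{\partial x_{pq}} = \sum_{i,j} g_{ij}\,k_{p-i,\,q-j}
\quad\text{(full convolution of } g \text{ with } k\text{)}
```

In SciPy terms, these are correlate2d(x, g, mode='valid') and convolve2d(g, k, mode='full').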

Implementation:

Please note that this implementation serves as an illustration, and we did not verify its correctness.

from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


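class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        # detach so the tensors can be converted to NumPy arrays
        input, filter = input.detach(), filter.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_output = grad_output.detach().numpy()
        # gradient wrt the input: full convolution of grad_output with the filter
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # gradient wrt the filter: valid cross-correlation of input with grad_output
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')

        return torch.as_tensor(grad_input, dtype=input.dtype), \
            torch.as_tensor(grad_filter, dtype=filter.dtype)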


class ScipyConv2d(Module):

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)
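The filter is registered as a Parameter, so a standard optimizer can train ScipyConv2d like any other module. The sketch below fits the kernel to a randomly chosen target kernel with plain SGD. The target, learning rate, and step count are arbitrary illustration choices. The backward uses the chain-rule gradients (full convolution for the input, valid cross-correlation for the filter), and the snippet assumes a recent PyTorch with static-method Functions.

```python
import torch
from torch.autograd import Function
from torch.nn import Module, Parameter
from scipy.signal import convolve2d, correlate2d


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        input, filter = input.detach(), filter.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        g = grad_output.detach().numpy()
        grad_input = convolve2d(g, filter.numpy(), mode='full')
        grad_filter = correlate2d(input.numpy(), g, mode='valid')
        return (torch.as_tensor(grad_input, dtype=input.dtype),
                torch.as_tensor(grad_filter, dtype=filter.dtype))


class ScipyConv2d(Module):
    def __init__(self, kh, kw):
        super().__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)


torch.manual_seed(0)
x = torch.randn(10, 10)
true_filter = torch.randn(3, 3)  # hypothetical target kernel to recover
target = ScipyConv2dFunction.apply(x, true_filter)

module = ScipyConv2d(3, 3)
optimizer = torch.optim.SGD(module.parameters(), lr=0.05)
losses = []
for step in range(200):
    optimizer.zero_grad()
    loss = ((module(x) - target) ** 2).mean()  # mean squared error
    loss.backward()  # routes through ScipyConv2dFunction.backward
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```

The loss is a convex quadratic in the filter, and the step size is small, so the loss should drop steadily toward zero.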

Example usage:

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)

Out:

[Parameter containing:
 1.0754 -0.8066  0.3304
-1.2231 -0.2035  0.4939
-0.7975 -0.5881 -0.4792
[torch.FloatTensor of size 3x3]
]
Variable containing:
 0.7439  0.2808 -0.5860 -2.8316 -0.2953 -3.6276 -0.2887 -2.4730
-0.5076  4.2934 -1.0637  2.4279 -3.4973  1.7451 -1.5061 -0.4581
-1.5620 -2.0292 -0.6322 -0.7476  1.3275 -1.3496 -4.6143 -0.3132
 1.1628  0.2356  0.7270 -1.8374  2.3684 -0.8792  1.6460 -3.7633
-1.3106  0.1731  4.7157  1.4189 -0.4938  0.1348 -1.0437 -4.3917
 0.0328  3.8527  0.3729 -0.3629  1.0639 -1.6496 -2.5545 -0.6386
 1.0335  3.0539 -1.2026  0.1993 -2.2277  0.9936 -1.1667 -3.0316
-0.0735  2.0411 -0.1119 -1.1178  0.9447 -2.3621  1.1529 -0.5529
[torch.FloatTensor of size 8x8]

Variable containing:
-0.8229  0.2347  1.2491  2.3353 -3.1271  0.7935  0.7964 -1.6882  1.0310  0.7973
-0.9049  2.5531 -0.9611  2.5009  1.8931 -1.4156  1.4328 -1.0260  2.3507  2.0734
 1.2155 -0.7011  3.0197  0.7824  1.0602 -2.9660  2.6185  1.9948  0.1988  1.8599
-1.6361  1.1401  1.0264 -6.7814  2.7394 -0.2772 -2.1926  2.6752  0.8478  1.3467
 2.1865 -1.7675 -3.1038  1.3559 -0.4572  3.3876 -3.6443  1.0721  2.3393  0.6073
-1.1338 -2.0506  3.8598 -0.1274 -0.6803 -1.5970 -1.0469 -1.5129  4.6992  1.8922
 1.1923 -0.4203  0.1477 -1.6515 -1.7188  3.3253  1.6627  0.0663  0.5617  0.9708
-2.0164  1.7991 -1.1626  2.6752 -0.2630  0.5074  0.2911 -1.7214 -0.0022  1.3144
 1.3249  0.4139  0.8088  1.1148 -0.2780  1.1887  1.0866  1.2730  0.0558  0.2604
-0.4282 -0.6577  0.3834 -0.3729 -0.0176 -0.4416 -0.0738  0.1783 -0.4358  0.3744
[torch.FloatTensor of size 10x10]
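The implementation above is marked as unverified, and gradcheck offers a quick numerical test of it. The sketch below uses the chain-rule gradient formulas: full convolution for the input and valid cross-correlation for the filter. It assumes a recent PyTorch and checks both the input gradient and the filter gradient in double precision. The tolerances follow common gradcheck usage and are not tuned.

```python
import torch
from torch.autograd import Function, gradcheck
from scipy.signal import convolve2d, correlate2d


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        input, filter = input.detach(), filter.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        g = grad_output.detach().numpy()
        grad_input = convolve2d(g, filter.numpy(), mode='full')
        grad_filter = correlate2d(input.numpy(), g, mode='valid')
        return (torch.as_tensor(grad_input, dtype=input.dtype),
                torch.as_tensor(grad_filter, dtype=filter.dtype))


torch.manual_seed(0)
# gradcheck compares against finite differences, so use double precision
x = torch.randn(10, 10, dtype=torch.double, requires_grad=True)
k = torch.randn(3, 3, dtype=torch.double, requires_grad=True)
ok = gradcheck(ScipyConv2dFunction.apply, (x, k), eps=1e-6, atol=1e-4)
print("Are the gradients correct:", ok)
```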
