
Creating extensions using NumPy and SciPy

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.
   • This calls into NumPy as part of its implementation.
2. Create a neural network layer that has learnable weights.
   • This calls into SciPy as part of its implementation.
import torch


Parameter-less example

This layer does not do anything particularly useful, and it is not mathematically correct: its backward pass is not the true gradient of its forward pass.

Layer Implementation

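from numpy.fft import rfft2, irfft2
from torch.autograd import Function


class BadFFTFunction(Function):

    @staticmethod
    def forward(ctx, input):
        # detach so the tensor can be converted to a NumPy array
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.detach().numpy()
        result = irfft2(numpy_go)
        return torch.as_tensor(result, dtype=grad_output.dtype)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an nn.Module class

def incorrect_fft(input):
    return BadFFTFunction.apply(input)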
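The claim that this layer is not mathematically correct can be checked numerically. The sketch below is an illustration rather than part of the original tutorial: it repeats the layer's logic so it runs on its own, then uses `torch.autograd.gradcheck`, which compares a Function's backward pass against finite differences of its forward pass and expects double-precision inputs. The seed, input size, and tolerances are arbitrary choices.

```python
import numpy as np
import torch
from torch.autograd import Function, gradcheck


class BadFFTFunction(Function):
    # Same logic as the layer above: forward computes |rfft2(x)|,
    # but backward simply applies irfft2 to the incoming gradient.
    @staticmethod
    def forward(ctx, input):
        result = np.abs(np.fft.rfft2(input.detach().numpy()))
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        result = np.fft.irfft2(grad_output.detach().numpy())
        return torch.as_tensor(result, dtype=grad_output.dtype)


torch.manual_seed(0)
# gradcheck needs float64 inputs for a reliable finite-difference comparison
x = torch.randn(8, 8, dtype=torch.float64, requires_grad=True)
# raise_exception=False makes gradcheck return a bool instead of raising
ok = gradcheck(BadFFTFunction.apply, (x,), eps=1e-6, atol=1e-4,
               raise_exception=False)
print(ok)  # False: the hand-written backward is not the true gradient
```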


Example usage of the created layer:

input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
print(input)


Out:

tensor([[ 3.6501,  2.4913,  2.4807,  6.1489, 14.1820],
[ 8.1184,  4.9159,  3.4977,  2.8809,  5.6381],
[16.5501,  2.7887,  6.0598,  9.6402,  5.2785],
[ 2.7207, 10.4355,  1.6398, 13.7756,  2.1615],
[ 1.9000,  4.0814,  6.7236,  1.2540,  0.2131],
[ 2.7207,  9.7629, 11.4593, 12.3608,  2.1615],
[16.5501,  6.7424, 12.5088, 11.1904,  5.2785],
[ 8.1184,  2.5428,  9.7204, 11.9310,  5.6381]],
tensor([[-0.3332,  0.5516, -0.0298,  0.2007, -1.0682, -0.0172, -0.7434, -2.4018],
[-1.0572,  0.4485,  0.2664,  0.1740,  0.0141,  0.6235,  1.2641,  0.6440],
[ 0.0600,  1.1582, -0.6999,  0.6557,  1.5535,  1.4846, -0.6100,  1.3682],
[-0.9021,  0.6801,  0.7459, -0.2814, -1.2220,  0.2134, -1.0915,  0.4185],
[-1.4876, -1.1652, -0.1899, -1.8256, -0.3000, -0.0350, -1.6568,  0.8676],
[ 0.7612, -0.3704,  0.4488,  1.4821, -2.0133, -0.1012, -0.2749, -2.0584],
[ 1.4229, -0.9126, -0.2474,  1.4955,  0.0021, -0.2107, -0.5520,  0.8907],
[ 0.4420,  0.2636, -1.0631, -1.2706,  0.9102,  0.0214, -1.2648,  2.2741]],
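The first printed tensor has 8 rows but only 5 columns. A minimal NumPy sketch of why, assuming an 8 x 8 real input as in the example above: `rfft2` applies a real FFT along the last axis, which keeps only the non-negative frequencies, n // 2 + 1 = 5 of them for n = 8. By default `irfft2` assumes the original last axis had length 2 * (5 - 1) = 8, which is why the layer's backward produces an 8 x 8 gradient that matches the input.

```python
import numpy as np

x = np.random.default_rng(0).standard_normal((8, 8))

# Real FFT along the last axis keeps 8 // 2 + 1 = 5 frequency bins
spectrum = np.fft.rfft2(x)
print(spectrum.shape)  # (8, 5)

# irfft2 restores an 8 x 8 real array; for even lengths the round trip is exact
restored = np.fft.irfft2(spectrum)
print(restored.shape)  # (8, 8)
print(np.allclose(restored, x))  # True
```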


Parametrized example

This implements a layer with learnable weights.

It implements cross-correlation with a learnable kernel.

In the deep learning literature, this operation is confusingly referred to as convolution.
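To make the naming issue concrete, the sketch below checks two facts with SciPy and PyTorch: cross-correlation equals convolution with the kernel flipped along both axes, and `torch.nn.functional.conv2d` (PyTorch's "convolution") computes the cross-correlation. The 10 x 10 input and 3 x 3 kernel are arbitrary choices for illustration.

```python
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 10))
w = rng.standard_normal((3, 3))

# Cross-correlation slides the kernel over the input as-is ...
corr = correlate2d(x, w, mode='valid')
# ... whereas true convolution first flips the kernel along both axes
conv_flipped = convolve2d(x, w[::-1, ::-1], mode='valid')
print(np.allclose(corr, conv_flipped))  # True

# F.conv2d expects (batch, channels, height, width) tensors
torch_out = F.conv2d(torch.from_numpy(x)[None, None],
                     torch.from_numpy(w)[None, None])
print(np.allclose(corr, torch_out[0, 0].numpy()))  # True
```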

The backward pass computes the gradients with respect to both the input and the filter.
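Writing the forward pass out index by index shows where the two gradients come from. Here $x$ is the input, $w$ the filter, $y$ the output of the valid cross-correlation, and $g_{i,j} = \partial L / \partial y_{i,j}$ the incoming gradient for some scalar loss $L$; indices start at 0, and any term whose index falls outside its array is taken as zero.

```latex
y_{i,j} = \sum_{a,b} x_{i+a,\,j+b}\, w_{a,b}

\frac{\partial L}{\partial x_{p,q}} = \sum_{i,j} g_{i,j}\, \frac{\partial y_{i,j}}{\partial x_{p,q}} = \sum_{i,j} g_{i,j}\, w_{p-i,\,q-j}

\frac{\partial L}{\partial w_{a,b}} = \sum_{i,j} g_{i,j}\, \frac{\partial y_{i,j}}{\partial w_{a,b}} = \sum_{i,j} g_{i,j}\, x_{i+a,\,j+b}
```

In SciPy terms, the input gradient is `convolve2d(g, w, mode='full')` and the filter gradient is `correlate2d(x, g, mode='valid')`.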

Implementation:

Please note that this implementation serves as an illustration, and we did not verify its correctness.

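from scipy.signal import convolve2d, correlate2d
from torch.autograd import Function
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        input, filter = input.detach(), filter.detach()  # detach so we can cast to NumPy
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter = ctx.saved_tensors
        # gradient wrt the input: "full" convolution of grad_output with the filter
        grad_input = convolve2d(grad_output.numpy(), filter.numpy(), mode='full')
        # gradient wrt the filter: "valid" cross-correlation of the input with grad_output
        grad_filter = correlate2d(input.numpy(), grad_output.numpy(), mode='valid')
        return (torch.as_tensor(grad_input, dtype=input.dtype),
                torch.as_tensor(grad_filter, dtype=filter.dtype))


class ScipyConv2d(Module):

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)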


Example usage:

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)


Out:

[Parameter containing:
tensor([[-2.3718, -0.3494,  1.6504],
[ 0.5133,  0.6631,  1.3076],
tensor([[ -2.4874,  -0.0737,   5.3549,  -0.0783,  -0.4158,  -3.5112,   7.9663,
6.4007],
[ -0.1666,   0.4403,   0.0210,  -1.8895,   4.0234,  -2.6265,  -3.1844,
-0.4761],
[ -5.7182,   0.3143,  -0.9324,   0.0886,  -1.4209,   3.5986,   2.7599,
-0.7782],
[  3.0651,   1.7201,  -4.8526,  -3.9738,   1.5550,   2.8246,   2.4622,
-3.8836],
[ -4.1502,  -3.6368,  -3.5543,  -0.5330,   3.1335,   8.6079,  -0.6952,
-4.9075],
[ -2.5105,  -1.2910,   0.9143,   1.5850,  -3.8070,  -5.1687,  -3.6606,
0.1077],
[ -1.9499,   0.1882,  -0.7216,  -3.3668,   9.2076,   6.3375,   2.7028,
4.7159],
[  6.3560,   1.7081,  -6.4317,  -1.7194,   4.1965,  -2.8422, -10.2885,
tensor([[ 0.8247,  0.1782,  0.5291,  5.7989, -2.6241,  3.0677,  0.7841, -1.2099,
0.4242, -0.1436],
[ 1.5484, -1.3916,  5.8290, -1.0568,  0.7348, -5.0618, -2.1580, -2.7158,
-0.0336,  0.0686],
[ 1.4320,  0.4210,  5.8871, -5.9086, -6.6954,  5.1975, -3.9309,  5.5487,
2.3694,  0.0294],
[-3.6414,  5.5913, -5.8681, -4.2472, -1.3091,  2.1656, 11.6295, -0.4403,
-1.0767, -1.7386],
[-5.8986,  1.9896, -0.3923, -5.5039,  4.9920,  0.8166,  0.2139,  0.1244,
-3.0334,  1.6831],
[-5.4207, -0.1731, -4.7614,  0.8801, -2.6112,  2.2662, -5.8488, -3.7419,
1.6701, -0.3737],
[ 1.4507,  0.3673, -7.0348,  2.4159,  1.6810,  9.8278, -5.4261, -0.6347,
0.8506,  1.7189],
[ 5.6446,  1.3771, -2.5447, -2.5770,  2.0065,  0.8168,  1.2361, -0.7975,
-1.5884, -0.5331],
[ 0.4765,  1.3488,  2.7189, -1.8958,  2.2903, -0.3011, -6.3699,  3.6232,
-1.1421,  0.9611],
[-0.6719,  1.3100,  2.9186,  2.6405,  1.2540, -4.1039,  0.0925, -0.0965,
0.9177, -0.3529]])
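Because the implementation above is marked as unverified, `torch.autograd.gradcheck` offers one way to test a backward pass against finite differences. The sketch below is a standalone version of `ScipyConv2dFunction` whose backward uses a full convolution for the input gradient and a valid cross-correlation for the filter gradient; the double-precision inputs, the 10 x 10 and 3 x 3 shapes, the seed, and the tolerances are assumptions chosen for the check.

```python
import torch
from scipy.signal import convolve2d, correlate2d
from torch.autograd import Function, gradcheck


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        # detach so both tensors can be converted to NumPy arrays
        input, filter = input.detach(), filter.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        g = grad_output.detach().numpy()
        # dL/dinput: full convolution of the output gradient with the filter
        grad_input = convolve2d(g, filter.numpy(), mode='full')
        # dL/dfilter: valid cross-correlation of the input with the output gradient
        grad_filter = correlate2d(input.numpy(), g, mode='valid')
        return (torch.as_tensor(grad_input, dtype=input.dtype),
                torch.as_tensor(grad_filter, dtype=filter.dtype))


torch.manual_seed(0)
# gradcheck needs float64 inputs for its finite-difference comparison
x = torch.randn(10, 10, dtype=torch.float64, requires_grad=True)
w = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
ok = gradcheck(ScipyConv2dFunction.apply, (x, w), eps=1e-6, atol=1e-4)
print(ok)  # True
```

A passing check covers both gradients at once, since `gradcheck` perturbs every input that has `requires_grad=True`.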


Total running time of the script: ( 0 minutes 0.006 seconds)
