[pytorch] Caveats for custom activation functions
Published: 2019-06-13


Recently I have been experimenting with custom activation functions in PyTorch: how do you use your own activation function in a PyTorch model?

If the custom activation function is differentiable everywhere, you can simply write it as a plain Python function and call it; PyTorch's autograd will differentiate it automatically.
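For example, a softplus-style activation works out of the box (a minimal sketch; the function name my_softplus is mine, not from the original post):

```python
import torch

def my_softplus(x):
    # Differentiable everywhere, so autograd can chain the derivatives
    # of log and exp automatically -- no custom backward needed.
    return torch.log(1 + torch.exp(x))

x = torch.randn(3, requires_grad=True)
my_softplus(x).sum().backward()
print(x.grad)  # equals sigmoid(x), the analytic derivative of softplus
```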

If the custom activation function is not differentiable everywhere -- for example a piecewise-differentiable function like ReLU -- you need to write a class that subclasses torch.autograd.Function and define the forward and backward passes yourself.

PyTorch provides a tutorial on defining new autograd Functions; it uses ReLU as the example to show what must be implemented in forward and backward.

```python
import torch


class MyReLU(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0")  # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # To apply our Function, we use Function.apply method. We alias this as 'relu'.
    relu = MyReLU.apply

    # Forward pass: compute predicted y using operations; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()
```

But what happens if, instead of using the proper method above, we define the ReLU directly as a plain function?

Here we compare the experimental results of the MyReLU above against a directly defined function, no_back:

```python
def no_back(x):
    return x * (x > 0).float()
```
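Even defined this way, autograd can still backpropagate through no_back: the comparison (x > 0).float() yields a tensor that does not require gradients, so it is treated as a constant c, and the gradient of x * c with respect to x is simply c -- exactly ReLU's subgradient. A quick check (a sketch):

```python
import torch

x = torch.tensor([-1.0, 0.0, 2.0], requires_grad=True)
no_back(x).sum().backward()
print(x.grad)  # tensor([0., 0., 1.]) -- the same values MyReLU's backward produces
```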

The comparison code:

```python
from copy import deepcopy  # needed for deepcopy below (import missing in the original snippet)

N, D_in, H, D_out = 2, 3, 4, 5

# Create random Tensors to hold input and outputs.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for weights.
origin_w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
origin_w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-3

def myReLU(func, x, y, origin_w1, origin_w2, learning_rate, N=2, D_in=3, H=4, D_out=5):
    w1 = deepcopy(origin_w1)
    w2 = deepcopy(origin_w2)
    for t in range(5):
        # Forward pass: compute predicted y using operations; we compute
        # ReLU using our custom autograd operation.
        y_pred = func(x.mm(w1)).mm(w2)

        # Compute and print loss
        loss = (y_pred - y).pow(2).sum()
        print("------", t, loss.item(), "------------")

        # Use autograd to compute the backward pass.
        loss.backward()

        # Update weights using gradient descent
        with torch.no_grad():
            print('w1 = ')
            print(w1)
            print('---------------------')
            print("x.mm(w1) = ")
            print(x.mm(w1))
            print('---------------------')
            print('func(x.mm(w1))')
            print(func(x.mm(w1)))
            print('---------------------')
            print("w1.grad:", w1.grad)
            # print("w2.grad:", w2.grad)
            print('---------------------')

            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad

            # Manually zero the gradients after updating weights
            w1.grad.zero_()
            w2.grad.zero_()
            print('========================')
            print()

myReLU(func=MyReLU.apply, x=x, y=y, origin_w1=origin_w1, origin_w2=origin_w2,
       learning_rate=learning_rate, N=2, D_in=3, H=4, D_out=5)
print('============')
print('============')
print('============')
myReLU(func=no_back, x=x, y=y, origin_w1=origin_w1, origin_w2=origin_w2,
       learning_rate=learning_rate, N=2, D_in=3, H=4, D_out=5)
```

The results using MyReLU.apply:

```
------ 0 20.18220329284668 ------------
w1 = 
tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
        [ 0.7425, -0.6309,  0.3268, -1.5072],
        [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
        [-0.7692, -1.8556, -0.7085, -0.9849]])
---------------------
func(x.mm(w1))
tensor([[0.0000, 1.0135, 0.0000, 1.8834],
        [0.0000, 0.0000, 0.0000, 0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
        [  0.0000,  -4.4962,   0.0000, -16.9378],
        [  0.0000,  -0.2401,   0.0000,  -0.9043]])
---------------------
========================

------ 1 19.546737670898438 ------------
w1 = 
tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
        [ 0.7425, -0.6265,  0.3268, -1.4903],
        [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
        [-0.7692, -1.8574, -0.7085, -0.9915]])
---------------------
func(x.mm(w1))
tensor([[0.0000, 1.0078, 0.0000, 1.8618],
        [0.0000, 0.0000, 0.0000, 0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
        [  0.0000,  -4.3446,   0.0000, -16.4493],
        [  0.0000,  -0.2320,   0.0000,  -0.8782]])
---------------------
========================

------ 2 18.94647789001465 ------------
w1 = 
tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
        [ 0.7425, -0.6221,  0.3268, -1.4738],
        [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
        [-0.7692, -1.8591, -0.7085, -0.9978]])
---------------------
func(x.mm(w1))
tensor([[0.0000, 1.0023, 0.0000, 1.8409],
        [0.0000, 0.0000, 0.0000, 0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
        [  0.0000,  -4.2009,   0.0000, -15.9835],
        [  0.0000,  -0.2243,   0.0000,  -0.8534]])
---------------------
========================

------ 3 18.378826141357422 ------------
w1 = 
tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
        [ 0.7425, -0.6179,  0.3268, -1.4578],
        [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
        [-0.7692, -1.8607, -0.7085, -1.0040]])
---------------------
func(x.mm(w1))
tensor([[0.0000, 0.9969, 0.0000, 1.8206],
        [0.0000, 0.0000, 0.0000, 0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
        [  0.0000,  -4.0644,   0.0000, -15.5391],
        [  0.0000,  -0.2170,   0.0000,  -0.8296]])
---------------------
========================

------ 4 17.841421127319336 ------------
w1 = 
tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
        [ 0.7425, -0.6138,  0.3268, -1.4423],
        [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
        [-0.7692, -1.8623, -0.7085, -1.0100]])
---------------------
func(x.mm(w1))
tensor([[0.0000, 0.9918, 0.0000, 1.8008],
        [0.0000, 0.0000, 0.0000, 0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
        [  0.0000,  -3.9346,   0.0000, -15.1145],
        [  0.0000,  -0.2101,   0.0000,  -0.8070]])
---------------------
========================
```

The results using no_back:

```
------ 0 20.18220329284668 ------------
w1 = 
tensor([[ 0.7070,  2.5772,  0.7987,  2.2287],
        [ 0.7425, -0.6309,  0.3268, -1.5072],
        [ 0.6930, -2.6128,  0.1949,  0.8819]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0135, -0.4164,  1.8834],
        [-0.7692, -1.8556, -0.7085, -0.9849]])
---------------------
func(x.mm(w1))
tensor([[-0.0000, 1.0135, -0.0000, 1.8834],
        [-0.0000, -0.0000, -0.0000, -0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0499,   0.0000,   0.1881],
        [  0.0000,  -4.4962,   0.0000, -16.9378],
        [  0.0000,  -0.2401,   0.0000,  -0.9043]])
---------------------
========================

------ 1 19.546737670898438 ------------
w1 = 
tensor([[ 0.7070,  2.5772,  0.7987,  2.2285],
        [ 0.7425, -0.6265,  0.3268, -1.4903],
        [ 0.6930, -2.6126,  0.1949,  0.8828]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0078, -0.4164,  1.8618],
        [-0.7692, -1.8574, -0.7085, -0.9915]])
---------------------
func(x.mm(w1))
tensor([[-0.0000, 1.0078, -0.0000, 1.8618],
        [-0.0000, -0.0000, -0.0000, -0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0483,   0.0000,   0.1827],
        [  0.0000,  -4.3446,   0.0000, -16.4493],
        [  0.0000,  -0.2320,   0.0000,  -0.8782]])
---------------------
========================

------ 2 18.94647789001465 ------------
w1 = 
tensor([[ 0.7070,  2.5771,  0.7987,  2.2283],
        [ 0.7425, -0.6221,  0.3268, -1.4738],
        [ 0.6930, -2.6123,  0.1949,  0.8837]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  1.0023, -0.4164,  1.8409],
        [-0.7692, -1.8591, -0.7085, -0.9978]])
---------------------
func(x.mm(w1))
tensor([[-0.0000, 1.0023, -0.0000, 1.8409],
        [-0.0000, -0.0000, -0.0000, -0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0467,   0.0000,   0.1775],
        [  0.0000,  -4.2009,   0.0000, -15.9835],
        [  0.0000,  -0.2243,   0.0000,  -0.8534]])
---------------------
========================

------ 3 18.378826141357422 ------------
w1 = 
tensor([[ 0.7070,  2.5771,  0.7987,  2.2281],
        [ 0.7425, -0.6179,  0.3268, -1.4578],
        [ 0.6930, -2.6121,  0.1949,  0.8846]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  0.9969, -0.4164,  1.8206],
        [-0.7692, -1.8607, -0.7085, -1.0040]])
---------------------
func(x.mm(w1))
tensor([[-0.0000, 0.9969, -0.0000, 1.8206],
        [-0.0000, -0.0000, -0.0000, -0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0451,   0.0000,   0.1726],
        [  0.0000,  -4.0644,   0.0000, -15.5391],
        [  0.0000,  -0.2170,   0.0000,  -0.8296]])
---------------------
========================

------ 4 17.841421127319336 ------------
w1 = 
tensor([[ 0.7070,  2.5770,  0.7987,  2.2280],
        [ 0.7425, -0.6138,  0.3268, -1.4423],
        [ 0.6930, -2.6119,  0.1949,  0.8854]], requires_grad=True)
---------------------
x.mm(w1) = 
tensor([[-0.9788,  0.9918, -0.4164,  1.8008],
        [-0.7692, -1.8623, -0.7085, -1.0100]])
---------------------
func(x.mm(w1))
tensor([[-0.0000, 0.9918, -0.0000, 1.8008],
        [-0.0000, -0.0000, -0.0000, -0.0000]])
---------------------
w1.grad: tensor([[  0.0000,   0.0437,   0.0000,   0.1679],
        [  0.0000,  -3.9346,   0.0000, -15.1145],
        [  0.0000,  -0.2101,   0.0000,  -0.8070]])
---------------------
========================
```

Comparing the two runs, the gradients, the updated weights, and the loss values are all numerically identical. Does this mean that for a non-differentiable function, defining it directly as a plain function achieves the same result as properly defining the forward and backward passes?

But notice one detail: in the MyReLU.apply results, zeros are printed as 0.0000, while in the no_back results they are printed as -0.0000.
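The sign difference comes from how each version produces its zeros: for negative inputs, no_back computes x * 0.0, and IEEE 754 defines a negative number times +0.0 to be -0.0, whereas clamp(min=0) returns the bound +0.0 itself. A quick check:

```python
>>> import torch
>>> -1.0 * 0.0                # negative * +0.0 -> -0.0 under IEEE 754
-0.0
>>> x = torch.tensor([-1.0, 2.0])
>>> x * (x > 0).float()       # no_back: negative entries become -0.0
tensor([-0., 2.])
>>> x.clamp(min=0)            # MyReLU's forward: zeros are +0.0
tensor([0., 2.])
```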

What is the difference between 0.0000 and -0.0000?

See the related answer on Stack Overflow and the Wikipedia article on signed zero.

In Python the two are clearly distinct objects, but they compare as numerically equal:

-0.0 == +0.0 == 0

Making them compare equal is a deliberate design choice in Python, intended to avoid introducing bugs into code.

```python
>>> a = 3.4
>>> b = 4.4
>>> c = -0.0
>>> d = +0.0
>>> a*c
-0.0
>>> b*d
0.0
>>> a*c == b*d
True
```
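Even though == cannot tell them apart, the sign bit is still observable, for example through math.copysign or the string representation:

```python
>>> import math
>>> math.copysign(1.0, -0.0)  # copies the sign bit, equality notwithstanding
-1.0
>>> math.copysign(1.0, +0.0)
1.0
>>> str(-0.0)
'-0.0'
```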

Although they appear interchangeable in use, their internal encodings in the computer are not the same.

In the 1+7-bit sign-and-magnitude representation for integers, negative zero is the binary code 10000000. In 8-bit ones' complement it is the binary code 11111111, while two's complement has no notion of negative zero at all. In the IEEE 754 binary floating-point standard, a number whose exponent and mantissa are zero and whose sign bit is one is negative zero.

In the decimal floating-point encodings, which represent floating-point numbers in base ten, negative zero is represented by any valid exponent, an all-zero coefficient, and a sign bit of one.

(Wikipedia)
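The binary floating-point part of this is easy to verify by inspecting the raw bits (single precision shown; only the sign bit differs):

```python
>>> import struct
>>> struct.pack('>f', 0.0).hex()    # all 32 bits are zero
'00000000'
>>> struct.pack('>f', -0.0).hex()   # only the sign bit is set
'80000000'
```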

In numerical analysis, -0 is often treated as the limit approaching 0 from the negative side, and +0 as the limit approaching 0 from the positive side; the two are numerically equal, but certain operations can produce different results for them.

For example, divmod preserves the sign:

```python
>>> divmod(-0.0, 100)
(-0.0, 0.0)
>>> divmod(+0.0, 100)
(0.0, 0.0)
```

Another example is atan2:

atan2(+0, +0) = +0;

atan2(+0, −0) = +π;  (y approaches 0 from the positive side while x approaches 0 from the negative side, so this can be viewed as a point on the negative x-axis approached from the second quadrant, giving an angle θ = π)

atan2(−0, +0) = −0;  (a point on the positive x-axis approached from the fourth quadrant, giving an angle θ = −0)

atan2(−0, −0) = −π.

Verifying this in code:

```python
>>> import math
>>> math.atan2(0.0, 0.0) == math.atan2(-0.0, 0.0)
True
>>> math.atan2(0.0, -0.0) == math.atan2(-0.0, -0.0)
False
```

So although forcing a non-differentiable function directly through PyTorch's autograd produced the same numerical results as the proper definition in the experiment above, the appearance of -0.0000 is a hint that the program has a latent bug; to be rigorous, the function should still be defined the standard way, as in MyReLU.
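As a final sanity check for a properly defined Function, torch.autograd.gradcheck compares the analytic backward against numerical finite differences (a sketch; gradcheck expects double precision, and since ReLU is not differentiable at exactly 0, the inputs are kept away from the kink):

```python
import torch
from torch.autograd import gradcheck

inp = torch.randn(4, 5, dtype=torch.double)
inp = inp.sign() * (inp.abs() + 0.1)   # keep entries away from the kink at 0
inp.requires_grad_(True)
print(gradcheck(MyReLU.apply, (inp,), eps=1e-6, atol=1e-4))  # True if consistent
```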

Reposted from: https://www.cnblogs.com/shiyublog/p/11201461.html
