Plot derivatives of sin(x) using pytorch
I'm not sure why my code doesn't plot cos(x) (yes, I know PyTorch has a cos(x) function):
import math
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(x)
x.grad == torch.cos(x) # assert x.grad same as cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)
You need to pass the upstream gradient (which in your case is all ones) as the input to y.backward(), instead of x.
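Why x gives the wrong curve: for a non-scalar y, y.backward(v) computes the vector-Jacobian product vᵀJ. Because sin is applied elementwise, J is diagonal with cos(x) on the diagonal, so x.grad becomes v * cos(x). Passing v = x therefore plots x * cos(x), while v = ones gives cos(x). The sketch below checks both cases; the variable names are illustrative, and retain_graph=True is only there because the same graph is traversed twice.

```python
import math

import torch

# Same setup as in the question: sin(x) sampled on [-pi, pi].
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)

# Case 1: use x itself as the upstream gradient (a detached copy, same values).
# retain_graph=True keeps the graph alive for the second backward pass below.
y.backward(x.detach(), retain_graph=True)
grad_with_x = x.grad.clone()

# Clear the accumulated gradient before the next backward pass.
x.grad = None

# Case 2: use a tensor of ones as the upstream gradient.
y.backward(torch.ones_like(x))
grad_with_ones = x.grad.clone()

expected_cos = torch.cos(x).detach()
print(torch.allclose(grad_with_x, x.detach() * expected_cos))  # True: x * cos(x)
print(torch.allclose(grad_with_ones, expected_cos))            # True: cos(x)
```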
So the corrected code is:
import math
import torch
import matplotlib.pyplot as plt
x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(torch.ones_like(x))  # upstream gradient of ones, so x.grad = cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)')  # plot derivative of sin(x)
plt.legend()  # show the labels passed to plt.plot
plt.show()
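Two equivalent ways to get the same gradient, shown as a sketch: reducing y to a scalar with y.sum() before calling backward() (the gradient of a sum with respect to each element is 1, so this is the same as passing ones), or calling torch.autograd.grad, which returns the gradient instead of accumulating it into x.grad. Also note that the line x.grad == torch.cos(x) in the question only builds an elementwise boolean tensor and asserts nothing; torch.allclose gives a single bool you can assert on.

```python
import math

import torch

x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)

# Option 1: reduce to a scalar first; d(sum(y))/dy is a tensor of ones.
torch.sin(x).sum().backward()
grad_from_sum = x.grad.clone()

# Option 2: ask autograd for the gradient directly (x.grad is not modified).
(grad_from_autograd,) = torch.autograd.grad(
    torch.sin(x), x, grad_outputs=torch.ones_like(x)
)

# Compare against cos(x) with a tolerance-based check instead of `==`.
expected = torch.cos(x).detach()
assert torch.allclose(grad_from_sum, expected)
assert torch.allclose(grad_from_autograd, expected)
print("both gradients match cos(x)")
```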