使用 pytorch 绘制 sin(x) 的导数

Plot derivatives of sin(x) using pytorch

我不确定为什么我的代码没有绘制 cos(x)(是的,我知道 pytorch 有 cos(x) 函数)

import math
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import random

x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(x)
x.grad == torch.cos(x) # assert x.grad same as cos(x)
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)

您需要将上游梯度(等于您的情况下的所有梯度)而不是 x 作为 y.backward().

的输入

因此

import math
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-math.pi, math.pi, 5000, requires_grad=True)
y = torch.sin(x)
y.backward(torch.ones_like(x))
plt.plot(x.detach().numpy(), y.detach().numpy(), label='sin(x)')
plt.plot(x.detach().numpy(), x.grad.detach().numpy(), label='cos(x)') # print derivative of sin(x)
plt.show()