绘制 pytorch 函数的等高线

Plotting contour of pytorch functions

我有一个类似下面的函数:

x0 = torch.tensor([1.5, .1])
Q = torch.tensor([[10.0, 6.0],
                  [9.0, 8.0]])

def f1(x):
    z = x - x0
    Qz = z @ Q
    return 0.5 * Qz@z

如何获得等高线图? x 是二维张量。 我在使用 meshgrid 时弄乱了某个地方。

首先,让我们让您的代码使用 batch 个 2d 点,即形状 nx2x :

def f1(x):
  z = x - x0  # z of shape n-2
  Qz = z @ Q  # Qz of shape n-2
  return 0.5 * (Qz * z).sum(dim=-1)  # we want output of size n and not n-n

现在我们可以创建一个我们想要绘制的网格 f1(x):

grid = torch.stack(torch.meshgrid(torch.linspace(-20., 20., 100), torch.linspace(-20., 20., 100), indexing='xy'))
# convert the grid to a batch of 2d points:
grid = grid.reshape(2, -1).T
# get the output on all points in the grid
out = f1(grid)
# plot
plt.matshow(out.detach().numpy().reshape(100,100))

您将获得: