Jupyter notebook 在使用 matplotlib 绘制张量 (Tensorflow 2) 时卡住了(CPU 达到 100%)

Jupyter notebook stuck at plotting a tensor (Tensorflow 2) using matplotlib (CPU going 100%)

我正在观看视频教程并使用 matplotlib 在 Jupyter Notebook 中绘制 tensorflow 张量。 电池卡住了,一个 CPU 达到了 100%。

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

import matplotlib.pyplot as plt
# %matplotlib inline

normal = tfd.Normal(loc=0, scale=1)
n = 10000
z = normal.sample(n)

scale = 4.5
shift = 7
scale_and_shift = tfb.Chain([tfb.Shift(shift), tfb.Scale(scale)])
x = scale_and_shift.forward(z)
plt.hist(z, bins=60, density=True)
plt.show()

在我正在学习的教程中,它运行顺利。但是在我的尝试中它卡住了。 为什么它对我的情况不起作用?有什么包需要安装吗? z 是一个张量,可以直接画出来吗?

我注意到如果我使用 plt.hist(z.numpy(), bins=60, density=True) 它会起作用。但仍然想知道为什么直接绘制 z 在我的环境中不起作用。

首先,这不是丢包问题,一般直接画张量是没有问题的。例如,您可以毫无问题地完成 plt.plot(z)。如果你检查 plt.hist documentation 你可以看到这个:

The return value is a tuple (n, bins, patches) or ([n0, n1, ...], bins, [patches0, patches1, ...]) if the input contains multiple data.

在您的情况下,由于您正在绘制分布中的样本值(z 是 [10000] 形张量),因此您属于第二类([n0, n1, ... ]) 而你实际上创建了 10000 个重叠的直方图。为了清楚起见,请参阅 s=4

这就是为什么您的 CPU 超载了 - 试图一次创建 10000 个直方图(这也是为什么每条线的颜色不同 --> 不同的直方图)。

当您使用 tensor.numpy() 时,它会将 Tensor 转换为 NumPy 数组,并能够正确创建直方图。

PS 不确定为什么它在教程中有效。也许 trey 使用的是不同的包版本 - 但 z.numpy() 应该会给你相同的结果