使用 MNIST 样本可视化每个数字的 10x10 网格
visualize 10x10 grid of each digit using MNIST samples
我正在尝试从 MNIST 数据集中绘制一个 10x10 的样本网格:每个数字(0–9)各取 10 个样本。代码如下:
导入库:
import sklearn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_openml
加载数字数据:
# Load MNIST as plain NumPy arrays.
# as_frame=False matters: since scikit-learn 0.24 the default (as_frame='auto')
# returns a pandas DataFrame/Series. Row indexing such as X[Y == d][j] would
# then select a *column* instead of a sample.
# X: one flattened image per row; Y: the class label of each row.
# NOTE(review): the labels appear to be strings (object dtype) -- confirm.
# cache=False forces a fresh download on every run.
X, Y = fetch_openml(name='mnist_784', return_X_y=True, cache=False, as_frame=False)
绘制网格:
def P1(num_examples=10):
    """Plot `num_examples` samples of every digit class in a single grid.

    This is the code from the question and, as written, it fails: see the
    comment on the `np.nditer` loop below.

    Args:
        num_examples: number of samples drawn per digit, i.e. grid columns.

    Reads the module-level `X` (one flattened image per row) and `Y`
    (labels) loaded by `fetch_openml`; returns None.
    """
    plt.rc('image', cmap='Greys')  # grayscale colormap for every imshow below
    # NOTE(review): dpi=X.shape[1] ties the resolution to the pixel count
    # (784 dpi for mnist_784), producing a very large figure -- confirm intended.
    plt.figure(figsize=(num_examples,len(np.unique(Y))), dpi=X.shape[1])
    # For each digit (from 0 to 9)
    # BUG: this line raises the reported "REFS_OK flag was not enabled" error.
    # np.unique(Y) is an object-dtype array (the labels appear to be strings),
    # and np.nditer rejects object operands unless flags=['refs_ok'] is given.
    # A plain `for row, label in enumerate(np.unique(Y))` loop avoids nditer.
    for i in np.nditer(np.unique(Y)):
        # Create a ndarray with the features of "num_examples" examples of digit "i"
        features = X[Y == i][:num_examples]
        # For each of the "num_examples" examples
        for j in range(num_examples):
            # Create subplot (from 1 to "num_digits"*"num_examples" of each digit)
            # NOTE(review): this index arithmetic needs `i` to be an integer row
            # number; with string labels `i * num_examples` is not a number.
            plt.subplot(len(np.unique(Y)), num_examples, i * num_examples + j + 1)
            plt.subplots_adjust(wspace=0, hspace=0)
            # Hide tickmarks and scale
            ax = plt.gca()
            # ax.set_axis_off() # Also hide axes (frame)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            # Plot the corresponding digit (reshaped to square matrix/image)
            dim = int(np.sqrt(X.shape[1]))
            digit = features[j].reshape((dim,dim))
            plt.imshow(digit)
P1(10)
但是,运行时出现以下错误:"Iterator operand or requested dtype holds references, but the REFS_OK flag was not enabled"
谁能帮我解决这个问题?
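此错误来自 `np.nditer`,而这里根本不需要它。`np.unique(Y)` 返回的是由字符串标签组成的 object 类型数组,`np.nditer` 默认拒绝 object 数组(除非传入 `flags=['refs_ok']`)。直接用普通 `for` 循环配合 `enumerate` 遍历即可。此外,建议使用 `plt.subplots` 返回的 `ax` 对象进行绘图,而不是 MATLAB 风格的 `plt` 调用: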
# Plot a grid of MNIST samples: one row per digit class, M examples per row.
# np.asarray keeps the indexing below positional even if X/Y are pandas
# objects (fetch_openml's default return type in newer scikit-learn).
images = np.asarray(X)
labels = np.asarray(Y)
digits = np.unique(labels)  # sorted class labels; plain iteration, no np.nditer
M = 10  # number of examples shown per digit
dim = int(np.sqrt(images.shape[1]))  # side of the square image (28 for mnist_784)
# squeeze=False keeps `axs` 2-D even for a single row/column; figsize scales
# with the grid and equals (20, 20) for 10 digits x 10 examples.
fig, axs = plt.subplots(len(digits), M, figsize=(2 * M, 2 * len(digits)),
                        squeeze=False)
for i, d in enumerate(digits):
    # Row indices of the first M samples of digit d. Computed once per digit
    # instead of once per subplot, and without copying every matching image.
    idx = np.flatnonzero(labels == d)[:M]
    for j in range(M):
        ax = axs[i, j]
        ax.axis('off')
        # Leave the cell blank when a class has fewer than M samples.
        if j < len(idx):
            ax.imshow(images[idx[j]].reshape((dim, dim)), cmap='Greys')
我正在尝试从 MNIST 数据集中绘制一个 10x10 的样本网格:每个数字(0–9)各取 10 个样本。代码如下:
导入库:
import sklearn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_openml
加载数字数据:
# Load MNIST as plain NumPy arrays.
# as_frame=False matters: since scikit-learn 0.24 the default (as_frame='auto')
# returns a pandas DataFrame/Series. Row indexing such as X[Y == d][j] would
# then select a *column* instead of a sample.
# X: one flattened image per row; Y: the class label of each row.
# NOTE(review): the labels appear to be strings (object dtype) -- confirm.
# cache=False forces a fresh download on every run.
X, Y = fetch_openml(name='mnist_784', return_X_y=True, cache=False, as_frame=False)
绘制网格:
def P1(num_examples=10):
    """Plot `num_examples` samples of every digit class in a single grid.

    This is the code from the question and, as written, it fails: see the
    comment on the `np.nditer` loop below.

    Args:
        num_examples: number of samples drawn per digit, i.e. grid columns.

    Reads the module-level `X` (one flattened image per row) and `Y`
    (labels) loaded by `fetch_openml`; returns None.
    """
    plt.rc('image', cmap='Greys')  # grayscale colormap for every imshow below
    # NOTE(review): dpi=X.shape[1] ties the resolution to the pixel count
    # (784 dpi for mnist_784), producing a very large figure -- confirm intended.
    plt.figure(figsize=(num_examples,len(np.unique(Y))), dpi=X.shape[1])
    # For each digit (from 0 to 9)
    # BUG: this line raises the reported "REFS_OK flag was not enabled" error.
    # np.unique(Y) is an object-dtype array (the labels appear to be strings),
    # and np.nditer rejects object operands unless flags=['refs_ok'] is given.
    # A plain `for row, label in enumerate(np.unique(Y))` loop avoids nditer.
    for i in np.nditer(np.unique(Y)):
        # Create a ndarray with the features of "num_examples" examples of digit "i"
        features = X[Y == i][:num_examples]
        # For each of the "num_examples" examples
        for j in range(num_examples):
            # Create subplot (from 1 to "num_digits"*"num_examples" of each digit)
            # NOTE(review): this index arithmetic needs `i` to be an integer row
            # number; with string labels `i * num_examples` is not a number.
            plt.subplot(len(np.unique(Y)), num_examples, i * num_examples + j + 1)
            plt.subplots_adjust(wspace=0, hspace=0)
            # Hide tickmarks and scale
            ax = plt.gca()
            # ax.set_axis_off() # Also hide axes (frame)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            # Plot the corresponding digit (reshaped to square matrix/image)
            dim = int(np.sqrt(X.shape[1]))
            digit = features[j].reshape((dim,dim))
            plt.imshow(digit)
P1(10)
但是,运行时出现以下错误:"Iterator operand or requested dtype holds references, but the REFS_OK flag was not enabled"
谁能帮我解决这个问题?
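此错误来自 `np.nditer`,而这里根本不需要它。`np.unique(Y)` 返回的是由字符串标签组成的 object 类型数组,`np.nditer` 默认拒绝 object 数组(除非传入 `flags=['refs_ok']`)。直接用普通 `for` 循环配合 `enumerate` 遍历即可。此外,建议使用 `plt.subplots` 返回的 `ax` 对象进行绘图,而不是 MATLAB 风格的 `plt` 调用: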
# Plot a grid of MNIST samples: one row per digit class, M examples per row.
# np.asarray keeps the indexing below positional even if X/Y are pandas
# objects (fetch_openml's default return type in newer scikit-learn).
images = np.asarray(X)
labels = np.asarray(Y)
digits = np.unique(labels)  # sorted class labels; plain iteration, no np.nditer
M = 10  # number of examples shown per digit
dim = int(np.sqrt(images.shape[1]))  # side of the square image (28 for mnist_784)
# squeeze=False keeps `axs` 2-D even for a single row/column; figsize scales
# with the grid and equals (20, 20) for 10 digits x 10 examples.
fig, axs = plt.subplots(len(digits), M, figsize=(2 * M, 2 * len(digits)),
                        squeeze=False)
for i, d in enumerate(digits):
    # Row indices of the first M samples of digit d. Computed once per digit
    # instead of once per subplot, and without copying every matching image.
    idx = np.flatnonzero(labels == d)[:M]
    for j in range(M):
        ax = axs[i, j]
        ax.axis('off')
        # Leave the cell blank when a class has fewer than M samples.
        if j < len(idx):
            ax.imshow(images[idx[j]].reshape((dim, dim)), cmap='Greys')