sklearn 的 MLPClassifier 用于 MNIST 数字分类任务的输入和输出层有多少个节点

How many nodes in input and output layers of sklearn's MLPClassifier for MNIST digits classification task

我正在按照 https://scikit-learn.org/stable/auto_examples/neural_networks/plot_mnist_filters.html#sphx-glr-auto-examples-neural-networks-plot-mnist-filters-py 上的示例进行操作,并试图弄清楚我对示例中输入层和输出层中节点数量的理解是否正确。所需代码如下:

import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.neural_network import MLPClassifier

print(__doc__)

# Load data from https://www.openml.org/d/554
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
X = X / 255.

# rescale the data, use the traditional train/test split
X_train, X_test = X[:60000], X[60000:]
y_train, y_test = y[:60000], y[60000:]

mlp = MLPClassifier(hidden_layer_sizes=(50,), max_iter=10, alpha=1e-4,
                    solver='sgd', verbose=10, random_state=1,
                    learning_rate_init=.1)

mlp.fit(X_train, y_train)
score = mlp.score(X_test, y_test)

根据https://dudeperf3ct.github.io/mlp/mnist/2018/10/08/Force-of-Multi-Layer-Perceptron/,该示例说明输入层有 784 个节点(我假设来自数据的 shape)和输出层有 10 个节点,每个数字 1 个.

上面代码中的 MLPClassifier 也是这样吗?

谢谢你,如果有一些澄清就更好了!

您的理解是正确的。 MNIST 数字数据的图像大小为 28x28,展平为 784,输出大小为 10(从 0 到 9 的每个数字一个)。 MLPClassifier 根据 Fit 方法中提供的数据隐式设计输入和输出层。

您的 NN 配置将如下所示: 输入:200 x 784 隐藏层:784 x 50(特征尺寸:200 x 50) 输出层:50 x 10(特征尺寸:200 x 10)

MLPClassifier 中的批量大小默认为 200,因为训练数据大小为 60000