mod.predict 给出的列比预期的多

mod.predict gives more columns than expected

我在 IRIS 数据集上使用 MXNet,它有 4 个特征,它将花分类为 -'setosa'、'versicolor'、'virginica'。我的训练数据有 89 行。我的标签数据是 89 列的行向量。我将花名编码为数字 -0,1,2,因为 mx.io.NDArrayIter 似乎不接受带有字符串值的 numpy ndarray。然后我尝试使用

进行预测

re = mod.predict(test_iter)

我得到一个形状为 14 * 10 的结果。 当我只有 3 个标签时为什么会得到 10 列以及如何将这些结果映射到我的标签。预测结果如下图:

[[ 0.11760861 0.12082944 0.1207106 0.09154381 0.09155304 0.09155869 0.09154817 0.09155204 0.09154914 0.09154641] [ 0.1176083 0.12082954 0.12071151 0.09154379 0.09155323 0.09155825 0.0915481 0.09155164 0.09154923 0.09154641] [ 0.11760829 0.1208293 0.12071083 0.09154385 0.09155313 0.09155875 0.09154838 0.09155186 0.09154932 0.09154625] [ 0.11760861 0.12082901 0.12071037 0.09154388 0.09155303 0.09155875 0.09154829 0.09155209 0.09154959 0.09154641] [ 0.11760896 0.12082863 0.12070955 0.09154405 0.09155299 0.09155875 0.09154839 0.09155225 0.09154996 0.09154646] [ 0.1176089 0.1208287 0.1207095 0.09154407 0.09155297 0.09155882 0.09154844 0.09155232 0.09154989 0.0915464 ] [ 0.11760896 0.12082864 0.12070941 0.09154408 0.09155297 0.09155882 0.09154844 0.09155234 0.09154993 0.09154642] [ 0.1176088 0.12082874 0.12070983 0.09154399 0.09155302 0.09155872 0.09154837 0.09155215 0.09154984 0.09154641] [ 0.11760852 0.12082904 0.12071032 0.09154394 0.09155304 0.09155876 0.09154835 0.09155209 0.09154959 0.09154631] [ 0.11760963 0.12082832 0.12070873 0.09154428 0.09155257 0.09155893 0.09154856 0.09155177 0.09155051 0.09154671] [ 0.11760966 0.12082829 0.12070868 0.09154429 0.09155258 0.09155892 0.09154858 0.0915518 0.09155052 0.09154672] [ 0.11760949 0.1208282 0.12070852 0.09154446 0.09155259 0.09155893 0.09154854 0.09155205 0.0915506 0.09154666] [ 0.11760952 0.12082817 0.12070853 0.0915444 0.09155261 0.09155891 0.09154853 0.09155206 0.09155057 0.09154668] [ 0.1176096 0.1208283 0.12070892 0.09154423 0.09155267 0.09155882 0.09154859 0.09155172 0.09155044 0.09154676]]

使用"y = mod.predict(val_iter,num_batch=1)"代替"y = mod.predict(val_iter)",则只能得到一个批次标签。例如,如果您 batch_size 是 10,那么您将只会得到 10 个标签。