自定义迭代器不起作用的示例

example for custom iterator not working

我按照此处描述的创建自定义迭代器的说明和示例进行操作:http://mxnet.io/tutorials/basic/data.html

以下代码产生一个 ValueError:

mod.fit(data_iter, num_epoch=5)

ValueError: Shape of labels 0 does not match shape of predictions 1

我的问题:

我在 mac 上使用 jupyter,所有东西都是新安装的,包括 python... 我还直接使用 python 进行了测试:

Python 3.6.1 |Anaconda custom (x86_64)| (default, May 11 2017, 13:04:09) 
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)] on darwin

代码:

import mxnet as mx
import os
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import tarfile

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)


import numpy as np
# Build a toy dataset: 100 samples with 3 random float features each, and
# 100 integer labels drawn from the range [0, 10).
data = np.random.rand(100,3)
label = np.random.randint(0, 10, (100,))
# Wrap the in-memory arrays in an iterator that yields batches of 30 samples.
data_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=30)
# Walk through one pass and show each batch's data, label and pad fields.
# 100 is not a multiple of 30; the expected output listed below reports
# pad=20 for the final batch.
for batch in data_iter:
    print([batch.data, batch.label, batch.pad])

#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], [<NDArray 30 @cpu(0)>], 20]


# Save `data` to a CSV file first, then read it back through CSVIter.
np.savetxt('data.csv', data, delimiter=',')
# data_shape=(3,) matches the 3 columns written per row above.
data_iter = mx.io.CSVIter(data_csv='data.csv', data_shape=(3,), batch_size=30)
# No label file is passed here, so only the data and pad fields are printed.
for batch in data_iter:
    print([batch.data, batch.pad])

#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 0]
#[[<NDArray 30x3 @cpu(0)>], 20]

class SimpleIter(mx.io.DataIter):
    """Custom data iterator that synthesizes a fixed number of batches.

    On every call to ``next`` each generator in ``data_gen`` / ``label_gen``
    is invoked with its matching shape, and the results are wrapped in an
    ``mx.io.DataBatch``.

    Parameters
    ----------
    data_names : sequence of str
        Names of the data inputs (e.g. ``['data']``).
    data_shapes : sequence of tuple
        One shape per entry in ``data_names``.
    data_gen : sequence of callable
        One callable per data input; called as ``gen(shape)`` and expected to
        return something ``mx.nd.array`` accepts.
    label_names : sequence of str
        Names of the label inputs (e.g. ``['softmax_label']``).
    label_shapes : sequence of tuple
        One shape per entry in ``label_names``.
    label_gen : sequence of callable
        One callable per label input; called as ``gen(shape)``.
    num_batches : int, optional
        Number of batches produced before ``StopIteration`` is raised
        (i.e. the length of one epoch). Defaults to 10.
    """

    def __init__(self, data_names, data_shapes, data_gen,
                 label_names, label_shapes, label_gen, num_batches=10):
        # Two-argument super() keeps the class usable on Python 2 and 3.
        super(SimpleIter, self).__init__()
        # Materialize the (name, shape) pairs as lists. On Python 3 ``zip``
        # returns a one-shot iterator: after the first consumer reads
        # ``provide_data`` / ``provide_label`` it would be exhausted, and
        # ``next`` would then build empty data/label lists, which surfaces as
        # "Shape of labels 0 does not match shape of predictions 1".
        self._provide_data = list(zip(data_names, data_shapes))
        self._provide_label = list(zip(label_names, label_shapes))
        self.num_batches = num_batches
        self.data_gen = data_gen
        self.label_gen = label_gen
        self.cur_batch = 0

    def __iter__(self):
        """Return the iterator itself."""
        return self

    def reset(self):
        """Rewind to the first batch so another pass can be made."""
        self.cur_batch = 0

    def __next__(self):
        """Python 3 iterator protocol; delegates to ``next``."""
        return self.next()

    @property
    def provide_data(self):
        """List of ``(name, shape)`` pairs describing the data inputs."""
        return self._provide_data

    @property
    def provide_label(self):
        """List of ``(name, shape)`` pairs describing the label inputs."""
        return self._provide_label

    def next(self):
        """Generate and return the next batch.

        Returns
        -------
        mx.io.DataBatch
            Batch holding one NDArray per data input and per label input.

        Raises
        ------
        StopIteration
            Once ``num_batches`` batches have been produced since the last
            ``reset``.
        """
        if self.cur_batch >= self.num_batches:
            raise StopIteration
        self.cur_batch += 1
        # Pair every (name, shape) description with its generator and let the
        # generator produce an array of that shape.
        data = [mx.nd.array(gen(shape))
                for (_, shape), gen in zip(self._provide_data, self.data_gen)]
        label = [mx.nd.array(gen(shape))
                 for (_, shape), gen in zip(self._provide_label, self.label_gen)]
        return mx.io.DataBatch(data, label)


import mxnet as mx
# Define a small two-layer perceptron symbolically.
num_classes = 10
# Input placeholder; the name 'data' appears in list_arguments() below.
net = mx.sym.Variable('data')
# Hidden layer: 64 units followed by a ReLU activation.
net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=64)
net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
# Output layer: one unit per class, fed into a softmax output.
net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=num_classes)
net = mx.sym.SoftmaxOutput(data=net, name='softmax')
# Show the symbol's argument names and output names (expected values are
# listed in the comments below).
print(net.list_arguments())
print(net.list_outputs())

#['data', 'fc1_weight', 'fc1_bias', 'fc2_weight', 'fc2_bias', 'softmax_label']
#['softmax_output']



import logging
# Set the root logger to INFO level.
logging.basicConfig(level=logging.INFO)

# n is used as the leading dimension of both the data and label shapes.
n = 32
# Random-data iterator: one data input named 'data' of shape (n, 100) filled
# with uniform values in [-1, 1), and one label input named 'softmax_label'
# of shape (n,) filled with random class ids in [0, num_classes).
data_iter = SimpleIter(['data'], [(n, 100)],
                  [lambda s: np.random.uniform(-1, 1, s)],
                  ['softmax_label'], [(n,)],
                  [lambda s: np.random.randint(0, num_classes, s)])

# Wrap the symbol in a Module and train it for 5 epochs on the iterator.
mod = mx.mod.Module(symbol=net)
mod.fit(data_iter, num_epoch=5)

错误:

ValueError                                Traceback (most recent call last)
 <ipython-input-57-6ceb7dd11508> in <module>()
         9 
        10 mod = mx.mod.Module(symbol=net)
        ---> 11 mod.fit(data_iter, num_epoch=5)/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/base_module.py in fit(self, train_data, eval_data, eval_metric, epoch_end_callback, batch_end_callback, kvstore, optimizer, optimizer_params, eval_end_callback, eval_batch_end_callback, initializer, arg_params, aux_params, allow_missing, force_rebind, force_init, begin_epoch, num_epoch, validation_metric, monitor)
            493                     end_of_batch = True
            494 
        --> 495                 self.update_metric(eval_metric, data_batch.label)
            496 
            497                 if monitor is not None:
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/module.py in update_metric(self, eval_metric, labels)
            678             Typically ``data_batch.label``.
            679         """
        --> 680         self._exec_group.update_metric(eval_metric, labels)
            681 
            682     def _sync_params_from_devices(self):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/module/executor_group.py in update_metric(self, eval_metric, labels)
            561             labels_ = OrderedDict(zip(self.label_names, labels_slice))
            562             preds = OrderedDict(zip(self.output_names, texec.outputs))
        --> 563             eval_metric.update_dict(labels_, preds)
            564 
            565     def _bind_ith_exec(self, i, data_shapes, label_shapes, shared_group):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update_dict(self, label, pred)
             89             label = label.values()
             90 
        ---> 91         self.update(label, pred)
             92 
             93     def update(self, labels, preds):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in update(self, labels, preds)
            369             Predicted values.
            370         """
        --> 371         check_label_shapes(labels, preds)
            372 
            373         for label, pred_label in zip(labels, preds):
/Users/bernd/anaconda/lib/python3.6/site-packages/mxnet/metric.py in check_label_shapes(labels, preds, shape)
             22     if label_shape != pred_shape:
             23         raise ValueError("Shape of labels {} does not match shape of "
        ---> 24                          "predictions {}".format(label_shape, pred_shape))
             25 
             26 
ValueError: Shape of labels 0 does not match shape of predictions 1

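这是一个由 Python 版本差异引起的问题。在 Python 2.7 下,上述所有代码都可以正常运行;而在 Python 3.x 下就会出错,并且错误消息对定位原因几乎没有帮助。根本原因在于:Python 3 中 `zip()` 返回的是只能遍历一次的迭代器,而不是列表。`SimpleIter.__init__` 里的 `self._provide_data = zip(...)` 和 `self._provide_label = zip(...)` 在第一次被读取之后就已耗尽,之后 `next()` 生成的 data 和 label 列表都是空的,于是出现 "Shape of labels 0 does not match shape of predictions 1"。把这两处改为 `list(zip(...))` 即可在 Python 3 下正常运行。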