Theano MiniBatch Iterator not working

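I wrote a minibatch iterator to get predictions from my neural network. However, I ran some tests and found a bug.

Basically:

If batch_size > number of inputs: error

I wrote a script that reproduces the problem in my code. Here it is: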

import numpy as np

def minibatch_iterator_predictor(inputs, batch_size):
    assert len(inputs) > 0

    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]


def test(x, batch_size):
    prediction = np.empty((x.shape[0], 2), dtype=np.float32)

    index = 0
    for batch in minibatch_iterator_predictor(inputs=x, batch_size=batch_size):
        inputs = batch

        # y = self.predict_function(inputs)
        y = inputs

        prediction[index * batch_size:batch_size * (index + 1), :] = y[:]
        index += 1
    return prediction

######################################
#TEST SCRIPT
######################################

#Input
arr = np.zeros(shape=(10, 2))

arr[0] = [1, 0]
arr[1] = [2, 0]
arr[2] = [3, 0]
arr[3] = [4, 0]
arr[4] = [5, 0]
arr[5] = [6, 0]
arr[6] = [7, 0]
arr[7] = [8, 0]
arr[8] = [9, 0]
arr[9] = [10, 0]

###############################################

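batch_size = 5
print("\nBatch_size  %d" % batch_size)
r = test(x=arr, batch_size=batch_size)

# Debug: print every predicted row
for k in range(r.shape[0]):
    print(str(k) + " : " + str(r[k]))

# Assert: the prediction must have one row per input
assert arr.shape[0] == r.shape[0]

# Compare each predicted row with the corresponding input row
for k in range(0, r.shape[0]):
    print(r[k] == arr[k])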

Here are the test results.

For batch_size = 10:

Batch_size  10
0 : [ 1.  0.]
1 : [ 2.  0.]
2 : [ 3.  0.]
3 : [ 4.  0.]
4 : [ 5.  0.]
5 : [ 6.  0.]
6 : [ 7.  0.]
7 : [ 8.  0.]
8 : [ 9.  0.]
9 : [ 10.   0.]

For batch_size = 11:

0 : [  1.13876845e-37   0.00000000e+00]
1 : [  1.14048027e-37   0.00000000e+00]
2 : [  1.14048745e-37   0.00000000e+00]
3 : [  9.65151604e-38   0.00000000e+00]
4 : [  1.14002468e-37   0.00000000e+00]
5 : [  1.14340036e-37   0.00000000e+00]
6 : [  1.14343264e-37   0.00000000e+00]
7 : [  8.02794698e-38   0.00000000e+00]
8 : [  8.02794698e-38   0.00000000e+00]
9 : [  8.02794698e-38   0.00000000e+00]

For batch_size = 12:

0 : [  1.13876845e-37   0.00000000e+00]
1 : [  1.14048027e-37   0.00000000e+00]
2 : [  1.14048745e-37   0.00000000e+00]
3 : [  9.65151604e-38   0.00000000e+00]
4 : [  1.14002468e-37   0.00000000e+00]
5 : [  1.14340036e-37   0.00000000e+00]
6 : [  1.14343264e-37   0.00000000e+00]
7 : [  8.10141537e-38   0.00000000e+00]
8 : [  8.10141537e-38   0.00000000e+00]
9 : [  8.10141537e-38   0.00000000e+00]

How can I fix this?

Please try to be more specific in your question. What exactly are you trying to fix?

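No error is raised. When the batch size is larger than the number of inputs, the function minibatch_iterator_predictor produces an empty iterator, so the loop for batch in minibatch_iterator_predictor(inputs=x, batch_size=batch_size) never executes.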
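To see why the loop body never runs, it helps to evaluate the range bounds by hand. The following is a minimal sketch, assuming 10 inputs as in the question's test script; the helper name batch_starts is made up for illustration and is not part of the original code.

```python
def batch_starts(n_inputs, batch_size):
    """Return the start indices the question's iterator would visit."""
    return list(range(0, n_inputs - batch_size + 1, batch_size))

print(batch_starts(10, 5))   # [0, 5]
print(batch_starts(10, 10))  # [0]
print(batch_starts(10, 11))  # []  (stop is 0, so the range is empty)
print(batch_starts(10, 12))  # []  (stop is -1)
```

As soon as batch_size exceeds the number of inputs, the stop value drops to 0 or below, so no start index is generated and nothing is yielded.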

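When batch_size is larger than the number of inputs, all you get back is the array exactly as it was allocated: prediction = np.empty((x.shape[0], 2), dtype=np.float32). np.empty does not initialize its memory, which is why you see arbitrary near-zero values instead of your inputs.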
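Because np.empty leaves the memory uninitialized, unwritten rows are hard to tell apart from real predictions. One way to make them obvious while debugging is to pre-fill the output with NaN. This is only a sketch: np.full and np.isnan are standard NumPy functions, but the helper name make_prediction_buffer is hypothetical and not part of the original code.

```python
import numpy as np

def make_prediction_buffer(n_rows, n_cols=2):
    """Allocate an output array in which every cell starts as NaN."""
    return np.full((n_rows, n_cols), np.nan, dtype=np.float32)

buf = make_prediction_buffer(10)

# No batch has been written yet, so every row is still entirely NaN.
unfilled = np.isnan(buf).all(axis=1)
print(int(unfilled.sum()))  # 10
```

After the prediction loop, any row that is still NaN was never written by a batch.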

What you can do is cap batch_size at the number of inputs:

def minibatch_iterator_predictor(inputs, batch_size):
    assert len(inputs) > 0
    if batch_size > len(inputs):
        batch_size = len(inputs)

    for start_idx in range(0, len(inputs) - batch_size + 1, batch_size):
        excerpt = slice(start_idx, start_idx + batch_size)
        yield inputs[excerpt]
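Note that capping batch_size fixes the case above, but the loop bounds still skip a trailing partial batch whenever len(inputs) is not a multiple of batch_size (for example, 10 inputs with batch_size = 3 leaves the last row unwritten). A possible variant, offered as a sketch rather than as part of the original answer, yields the start offset together with each batch so the caller can write results at the right rows. The names minibatch_iterator_with_offsets and predict_all are hypothetical, and the identity assignment stands in for the real predict function.

```python
import numpy as np

def minibatch_iterator_with_offsets(inputs, batch_size):
    """Yield (start_idx, batch) pairs, including a shorter final batch."""
    assert len(inputs) > 0
    assert batch_size > 0
    for start_idx in range(0, len(inputs), batch_size):
        # Slicing past the end is safe: the slice is simply truncated.
        yield start_idx, inputs[start_idx:start_idx + batch_size]

def predict_all(x, batch_size):
    """Collect per-batch outputs into one array with a row per input."""
    prediction = np.empty((x.shape[0], 2), dtype=np.float32)
    for start_idx, batch in minibatch_iterator_with_offsets(x, batch_size):
        y = batch  # placeholder for the real predict function
        prediction[start_idx:start_idx + len(batch), :] = y
    return prediction

# Same input as the question: rows [1, 0] through [10, 0].
arr = np.zeros((10, 2))
arr[:, 0] = np.arange(1, 11)

for bs in (3, 5, 10, 11, 12):
    assert np.array_equal(predict_all(arr, bs), arr)
```

Writing by start_idx instead of index * batch_size means the final, shorter batch lands in the correct rows without special handling.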