Keras backend.repeat_elements 不工作?

Keras backend.repeat_elements not working?

我试图在这里创建一个简单的矩阵,对我的批次中的每个样本重复。

这是矩阵:

balanceMatrix = np.array([[[5,10,10],[1,1,1],[1,1,1]]])
print(balanceMatrix.shape)

balanceMatrix = K.constant(balanceMatrix)
print(K.shape(balanceMatrix).eval())

到目前为止一切顺利,我得到了预期的矩阵形状 (1,3,3)。 现在我希望对批次中的每个样本重复一次(比如 60000 个样本)。从 keras documentation,我应该做的就是:

balanceMatrix = K.repeat_elements(balanceMatrix, 60000,axis=0)
print(K.shape(balanceMatrix).eval())

但这会引发以下错误,我无法简单地理解:

IndexError                                Traceback (most recent call last)
<ipython-input-28-4356baf13de8> in <module>()
     20 balanceMatrix = K.constant(balanceMatrix)
     21 print(K.shape(balanceMatrix).eval())
---> 22 balanceMatrix = K.repeat_elements(balanceMatrix, 60000,axis=0)
     23 print(K.shape(balanceMatrix).eval())
     24 

c:\users\ut65\appdata\local\programs\python\python35\lib\site-packages\keras\backend\theano_backend.py in repeat_elements(x, rep, axis)
    743     if hasattr(x, '_keras_shape'):
    744         y._keras_shape = list(x._keras_shape)
--> 745         repeat_dim = x._keras_shape[axis]
    746         if repeat_dim is not None:
    747                 y._keras_shape[axis] = repeat_dim * rep

IndexError: tuple index out of range

这是怎么回事?? 我知道,我可以先使用 np.repeat(balanceMatrix,60000,axis=0) 然后创建 keras 张量,但是 keras 选项不应该也起作用吗?

我相信 K.variable 会在这里有所帮助:

balanceMatrix = K.variable(value=balanceMatrix)