mxnet 无法推断类型

mxnet failed to infer type

我的 mxnet 代码 - 由一系列复杂的连接和切片组成,引发了以下错误:

Error in operator concat0: [03:03:51] src/operator/./concat-inl.h:211: Not enough information to infer type in Concat.

我不确定如何解释它或提供哪些信息来帮助调试它。 Concat0 是操作的一部分:

# Define take_column function as transpose(take(transpose(x), i))

for i in range(47):
    y_hat_lt = take_column(y_hat,
                mx.sym.concat(mx.sym.slice(some_indices, begin=i, end=i+1), self.label_dim + mx.sym.slice(some_indices, begin=i, end=i+1), dim=0))

这里some_indices是一个变量,我固定为一个列表。请告诉我!

MXNet 似乎无法推断输出的形状。您是否为变量 some_indices 指定了形状?

例如some_indices = mx.sym.var('indices', 形状=(1,1))

如果你能粘贴一个最小的可重现代码就太好了:)

不是采用转置,而是在轴之间交换解决了这个问题。

def ttake( x, i ):
    """ Take from axis 1 instead of 0.
    """
    a = mx.sym.swapaxes(x, dim1=0, dim2=1)
    return mx.sym.flatten( mx.sym.transpose( mx.sym.take( a , i ) ) )