具有元组输入的自定义 TensorFlow RNN 单元

Custom TensorFlow RNN Cell with Tuple Input

我正在尝试在 TensorFlow 中创建一个接受元组作为输入的自定义 RNN 单元,但我 运行 遇到了父 class BasicLSTMCell 的问题要求输入是二维的:

# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)

我怎样才能绕过这个限制?我无法在 call() 方法中添加处理元组的逻辑,因为执行永远不会到达该方法 - 维度检查会引发错误。

其实我也发现了这个问题。 tensorflow 平台存在一个错误。可以通过更改recurrent.py文件中的get_step_input_shape函数来解决。只需将 [0] 添加到此行的末尾:nest.map_structure(get_input_spec, input_shape))