无法理解 tf.nn.raw_rnn

Unable to understand tf.nn.raw_rnn

tf.nn.raw_rnnofficial documentation 中,当 loop_fn 第一次为 运行 时,我们将发射结构作为 loop_fn 的第三个输出。

稍后emit_structure用于复制tf.zeros_like(emit_structure)emit = tf.where(finished, tf.zeros_like(emit_structure), emit)完成的minibatch条目。

我对 google 的部分缺乏理解或糟糕的文档是:emit 结构是 None 所以 tf.where(finished, tf.zeros_like(emit_structure), emit) 会像 tf.zeros_like(None) 那样抛出 ValueError所以。有人可以填写我在这里遗漏的内容吗?

是的,文档在这个地方相当混乱。如果您查看 tf.nn.raw_rnn 的内部结构,那里的关键术语是 "in pseudo-code",因此文档中的示例不准确。

确切的源代码如下所示(可能因您的 tensorflow 版本而异):

if emit_structure is not None:
  flat_emit_structure = nest.flatten(emit_structure)
  flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else
                    array_ops.shape(emit) for emit in flat_emit_structure]
  flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure]
else:
  emit_structure = cell.output_size
  flat_emit_size = nest.flatten(emit_structure)
  flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size)

所以它处理了 emit_structure is None 的情况,只取值 cell.output_size。这就是为什么什么都不会坏的原因。