Pixel RNN PyTorch Implementation

I am trying to implement Pixel RNN in PyTorch, but I can't seem to find any documentation on it. The main parts of Pixel RNN are the Row LSTM and the Diagonal BiLSTM, so I am looking for some code for these algorithms to better understand what they are doing. In particular, I am confused about how these algorithms compute one row or one diagonal at a time, respectively. Any help would be greatly appreciated.

Summary

Here is a partial implementation that is a work in progress:

https://github.com/carpedm20/pixel-rnn-tensorflow

And here is a description of the Row LSTM and Diagonal BiLSTM from Google DeepMind:

https://towardsdatascience.com/summary-of-pixelrnn-by-google-deepmind-7-min-read-938d9871d6d9


Row LSTM

From the linked DeepMind blog post:

The hidden state of a pixel, shown in red in the diagram below, is based on the "memory" of the triangular group of three pixels before it. Because they are in a "row", we can compute them in parallel, speeding up computation. We sacrifice some context information (using more history or memory) in exchange for this parallel computation and faster training.

The actual implementation relies on several other optimizations and is quite involved. From the original paper:

The computation proceeds as follows. An LSTM layer has an input-to-state component and a recurrent state-to-state component that together determine the four gates inside the LSTM core. To enhance parallelization in the Row LSTM the input-to-state component is first computed for the entire two-dimensional input map; for this a k × 1 convolution is used to follow the row-wise orientation of the LSTM itself. The convolution is masked to include only the valid context (see Section 3.4) and produces a tensor of size 4h × n × n, representing the four gate vectors for each position in the input map, where h is the number of output feature maps. To compute one step of the state-to-state component of the LSTM layer, one is given the previous hidden and cell states h_{i−1} and c_{i−1}, each of size h × n × 1. The new hidden and cell states h_i, c_i are obtained as follows:
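[o_i, f_i, i_i, g_i] = σ(K^ss ⊛ h_{i−1} + K^is ⊛ x_i)

c_i = f_i ⊙ c_{i−1} + i_i ⊙ g_i

h_i = o_i ⊙ tanh(c_i)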

where x_i of size h × n × 1 is row i of the input map, ⊛ represents the convolution operation and ⊙ the elementwise multiplication. The weights K^ss and K^is are the kernel weights for the state-to-state and the input-to-state components, where the latter is precomputed as described above. In the case of the output, forget and input gates o_i, f_i and i_i, the activation σ is the logistic sigmoid function, whereas for the content gate g_i, σ is the tanh function. Each step computes at once the new state for an entire row of the input map.
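As a rough illustration of that recurrence, here is a minimal PyTorch sketch. The module name, the default kernel size k = 3, and the zero initial state are my own choices, and the masking of the input-to-state convolution described in Section 3.4 of the paper is omitted:

```python
import torch
import torch.nn as nn


class RowLSTM(nn.Module):
    """Sketch of a Row LSTM layer: the input-to-state term is precomputed
    for the whole map with a row-wise convolution, then rows are processed
    top to bottom, each step applying another 1-D convolution to the
    previous row's hidden state."""

    def __init__(self, in_channels, hidden_channels, k=3):
        super().__init__()
        self.h = hidden_channels
        pad = k // 2
        # Input-to-state: produces the four gate pre-activations (4h channels)
        # for every position of the input map in one shot.
        # NOTE: the causal masking from Section 3.4 is omitted here.
        self.input_to_state = nn.Conv2d(in_channels, 4 * hidden_channels,
                                        kernel_size=(1, k), padding=(0, pad))
        # State-to-state: 1-D convolution along the row, applied to the
        # previous row's hidden state.
        self.state_to_state = nn.Conv1d(hidden_channels, 4 * hidden_channels,
                                        kernel_size=k, padding=pad)

    def forward(self, x):
        # x: (B, C_in, H, W)
        B, _, H, W = x.shape
        gates_is = self.input_to_state(x)       # (B, 4h, H, W), precomputed
        h_prev = x.new_zeros(B, self.h, W)       # previous row's hidden state
        c_prev = x.new_zeros(B, self.h, W)       # previous row's cell state
        rows = []
        for i in range(H):                       # one step per image row
            gates = gates_is[:, :, i, :] + self.state_to_state(h_prev)
            o, f, i_gate, g = torch.chunk(gates, 4, dim=1)
            o, f, i_gate = torch.sigmoid(o), torch.sigmoid(f), torch.sigmoid(i_gate)
            g = torch.tanh(g)
            c_prev = f * c_prev + i_gate * g
            h_prev = o * torch.tanh(c_prev)
            rows.append(h_prev)
        return torch.stack(rows, dim=2)          # (B, h, H, W)
```

The point to notice is that the expensive input-to-state term is a single convolution over the entire map, so the only sequential work left is the cheap row-to-row recurrence.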

Diagonal BLSTM

The Diagonal BLSTM was developed to leverage the speedup of parallelization without sacrificing as much context information. A node in the DBLSTM looks to its left and above it; because those nodes have in turn also looked to their left and above, the conditional probability of a given node depends, in a sense, on all of its ancestors. Otherwise, the architectures are very similar. From the DeepMind blog: