具有 2x2 输入的双向 GRU

Bidirectional GRU with 2x2 inputs

我正在构建一个网络,它将字符串拆分为单词,将单词拆分为字符,嵌入每个字符,然后通过将字符聚合为单词并将单词聚合为字符串来计算该字符串的向量表示。聚合是用双向 gru 层进行的,注意。
为了测试这个东西,假设我对这个字符串中的 5 个单词和 5 个字符感兴趣。在这种情况下,我的转换是:

["Some string"] -> ["Some","strin","","",""] -> 
["Some_","string","_____","_____","_____"] where _ is the padding symbol ) -> 
[[1,2,3,4,0],[1,5,6,7,8],[0,0,0,0,0],[0,0,0,0,0],[0,0,0,0,0]] (shape 5x5)

接下来我有一个嵌入层,它将每个字符变成一个长度为 6 的嵌入向量。所以我的特征变成了 5x5x6 矩阵。然后我将这个输出传递给双向 gru 层并执行一些其他操作,我相信在这种情况下并不重要。

问题是当我 运行 它带有一个迭代器时,比如

for string in strings:
    output = model(string)

它似乎工作得很好(字符串是一个由 5x5 切片创建的 tf 数据集),所以它是一堆 5 x 5 矩阵。

然而,当我转到训练或在数据集级别使用预测等函数工作时,模型失败了:

model.predict(strings.batch(1))
ValueError: Input 0 of layer bidirectional is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: (None, 5, 5, 6)

据我从文档中了解到,双向层将 3d 张量作为输入:[batch, timesteps, feature],因此在这种情况下,我的输入形状应如下所示:[batch_size,timesteps ,(5,5,6)]

所以问题是我应该对输入数据应用哪种转换来获得这种形状?

对于双向输入层,如果您使用的是 GRU,请使用 return_sequences=True,以获得 3 维输出。由于 GRU 输出是 2D,return_sequences 将为您提供 3D 输出。对于堆叠双向层输入应为 3D 形状。

示例代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
model = keras.Sequential()

model.add(
    layers.Bidirectional(layers.GRU(64, return_sequences=True), input_shape=(5, 10))
)
model.add(layers.Bidirectional(layers.GRU(32)))
model.add(layers.Dense(10))

model.summary()

输出

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
bidirectional_3 (Bidirection (None, 5, 128)            38400     
_________________________________________________________________
bidirectional_4 (Bidirection (None, 64)                41216     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                650       
=================================================================
Total params: 80,266
Trainable params: 80,266
Non-trainable params: 0
___________________________