将维度为 (3,50) 的嵌入层连接到 lstm

connect embedding layer with dimension (3,50) to lstm

如何将维度为 (3,50) 的嵌入层连接到 lstm?

形状为 (3, 50) 的数组作为输入送入 “layer_i_emb”:它包含三个时间步,每个时间步是一个长度为 50 的数组,其中存储的是产品标识符。

在添加 reshape 之前,我也尝试过直接连接,同样不行。Embedding 会增加一个维度,而 LSTM 不接受这个额外的维度。如果必须把张量转换成 TensorFlow 张量再手动处理,那就太麻烦了。

# Original (failing) model from the question, kept as-is to reproduce the error.
# NOTE(review): Input, Embedding, Reshape, LSTM, Flatten, EMBEDDING_DIM,
# us_it_count and MAX_FEATURES are imported/defined outside this snippet.

# Input: 3 time steps, each an array of 50 product ids.
layer_i_inp = Input(shape = (3,50), name = 'item')
# Embedding maps every id to a vector of size EMBEDDING_DIM*2, so the output
# gains a trailing axis: (batch, 3, 50, EMBEDDING_DIM*2).
# presumably us_it_count[0] is the largest item id - confirm against the data.
layer_i_emb = Embedding(output_dim = EMBEDDING_DIM*2,
                        input_dim = us_it_count[0]+1,
                        input_length = (3,50),
                        name = 'item_embedding')(layer_i_inp) 

# This target shape equals the embedding's per-sample shape, so the Reshape
# changes nothing: the tensor is still 4D including the batch axis.
layer_i_emb = Reshape([3,50, EMBEDDING_DIM*2])(layer_i_emb)

# LSTM expects 3D input (batch, timesteps, features); receiving the 4D tensor
# above is what makes this first LSTM call fail.
layer_i_emb = LSTM(MAX_FEATURES, dropout = 0.4, recurrent_dropout = 0.4, return_sequences = True)(layer_i_emb)
layer_i_emb = LSTM(MAX_FEATURES, dropout = 0.4, recurrent_dropout = 0.4, return_sequences = True)(layer_i_emb)
layer_i_emb = LSTM(MAX_FEATURES, dropout = 0.4, recurrent_dropout = 0.4)(layer_i_emb)

# The last LSTM has return_sequences left at its default (False), so its
# output is already (batch, MAX_FEATURES) and Flatten leaves it unchanged.
layer_i_emb = Flatten()(layer_i_emb)

问题在于 Embedding 层输出的是 4D 张量 (batch, 3, 50, EMBEDDING_DIM),即不计批次维度为 3D;而 LSTM 层需要 3D 输入 (batch, timesteps, features),即不计批次维度为 2D。您可以尝试以下几个选项:

选项 1

import tensorflow as tf

# Option 1: embed each order separately, run it through its own stack of
# LSTMs, then concatenate the per-order summaries before the classifier.

samples = 100
orders = 3
product_ids_per_order = 50
max_product_id = 120

# Random toy data: `samples` rows of `orders` x `product_ids_per_order` ids
# drawn from [0, max_product_id), plus binary labels.
data = tf.random.uniform((samples, orders, product_ids_per_order), maxval=max_product_id, dtype=tf.int32)
Y = tf.random.uniform((samples,), maxval=2, dtype=tf.int32)

EMBEDDING_DIM = 5

item_input = tf.keras.layers.Input(shape=(orders, product_ids_per_order), name='item')
# One embedding table shared by all orders; input_dim must exceed the largest
# id, hence max_product_id + 1. `input_length` is optional and omitted here.
embedding_layer = tf.keras.layers.Embedding(
    max_product_id + 1,
    output_dim=EMBEDDING_DIM,
    name='item_embedding')

outputs = []
for i in range(orders):
    # Order i: (batch, 50) ids -> (batch, 50, EMBEDDING_DIM), i.e. one
    # EMBEDDING_DIM-sized vector per product id. This 3D tensor is the
    # (batch, timesteps, features) shape an LSTM expects.
    tensor = embedding_layer(item_input[:, i, :])
    layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4, return_sequences=True)(tensor)
    layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4, return_sequences=True)(layer_i_emb)
    layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4)(layer_i_emb)
    outputs.append(layer_i_emb)  # (batch, 32) summary of order i

# (batch, orders * 32)
output = tf.keras.layers.Concatenate(axis=1)(outputs)
# Classify from the concatenation of all orders. Feeding `layer_i_emb` here
# would use only the last order and leave the other LSTM stacks (and the
# Concatenate) disconnected from the model.
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)
model = tf.keras.Model(item_input, output)
model.compile(optimizer='adam', loss=tf.keras.losses.BinaryCrossentropy())
model.fit(data, Y)
4/4 [==============================] - 15s 1s/step - loss: 0.6926

选项 2

import tensorflow as tf

# Option 2: flatten each order's embeddings into one feature vector, so every
# order becomes a single LSTM time step.

samples = 100
orders = 3
product_ids_per_order = 50
max_product_id = 120

EMBEDDING_DIM = 5

item_input = tf.keras.layers.Input(shape=(orders, product_ids_per_order), name='item')
# `input_length` is omitted: a scalar value (50) does not describe this
# two-axis (orders, ids) input, and the argument is optional anyway.
embedding_layer = tf.keras.layers.Embedding(
    max_product_id + 1,
    output_dim=EMBEDDING_DIM,
    name='item_embedding')

# (batch, orders, ids) -> (batch, orders, ids, EMBEDDING_DIM): one extra axis,
# i.e. a 4D tensor, which an LSTM cannot consume directly.
embedded = embedding_layer(item_input)
# Merge the last two axes -> (batch, orders, ids * EMBEDDING_DIM), a 3D
# (batch, timesteps, features) tensor. This equals embedding each order
# separately, flattening it, and concatenating along axis 1. Doing it with a
# single Keras layer avoids applying raw TF ops (e.g. tf.expand_dims) to
# symbolic Keras tensors, which Keras 3 does not allow.
embedding_inputs = tf.keras.layers.Reshape((orders, product_ids_per_order * EMBEDDING_DIM))(embedded)
layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4, return_sequences=True)(embedding_inputs)
layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4, return_sequences=True)(layer_i_emb)
layer_i_emb = tf.keras.layers.LSTM(32, dropout=0.4, recurrent_dropout=0.4)(layer_i_emb)
output = tf.keras.layers.Dense(1, activation='sigmoid')(layer_i_emb)
model = tf.keras.Model(item_input, output)
# summary() prints the table itself and returns None; wrapping it in print()
# would only add a stray "None" line.
model.summary()
Model: "model_11"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 item (InputLayer)              [(None, 3, 50)]      0           []                               
                                                                                                  
 tf.__operators__.getitem_41 (S  (None, 50)          0           ['item[0][0]']                   
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_42 (S  (None, 50)          0           ['item[0][0]']                   
 licingOpLambda)                                                                                  
                                                                                                  
 tf.__operators__.getitem_43 (S  (None, 50)          0           ['item[0][0]']                   
 licingOpLambda)                                                                                  
                                                                                                  
 item_embedding (Embedding)     (None, 50, 5)        605         ['tf.__operators__.getitem_41[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_42[0][
                                                                 0]',                             
                                                                  'tf.__operators__.getitem_43[0][
                                                                 0]']                             
                                                                                                  
 reshape_10 (Reshape)           (None, 250)          0           ['item_embedding[0][0]']         
                                                                                                  
 reshape_11 (Reshape)           (None, 250)          0           ['item_embedding[1][0]']         
                                                                                                  
 reshape_12 (Reshape)           (None, 250)          0           ['item_embedding[2][0]']         
                                                                                                  
 tf.expand_dims_9 (TFOpLambda)  (None, 1, 250)       0           ['reshape_10[0][0]']             
                                                                                                  
 tf.expand_dims_10 (TFOpLambda)  (None, 1, 250)      0           ['reshape_11[0][0]']             
                                                                                                  
 tf.expand_dims_11 (TFOpLambda)  (None, 1, 250)      0           ['reshape_12[0][0]']             
                                                                                                  
 concatenate_13 (Concatenate)   (None, 3, 250)       0           ['tf.expand_dims_9[0][0]',       
                                                                  'tf.expand_dims_10[0][0]',      
                                                                  'tf.expand_dims_11[0][0]']      
                                                                                                  
 lstm_34 (LSTM)                 (None, 3, 32)        36224       ['concatenate_13[0][0]']         
                                                                                                  
 lstm_35 (LSTM)                 (None, 3, 32)        8320        ['lstm_34[0][0]']                
                                                                                                  
 lstm_36 (LSTM)                 (None, 32)           8320        ['lstm_35[0][0]']                
                                                                                                  
 dense_11 (Dense)               (None, 1)            33          ['lstm_36[0][0]']                
                                                                                                  
==================================================================================================
Total params: 53,502
Trainable params: 53,502
Non-trainable params: 0
__________________________________________________________________________________________________
None