使用 Tfdata 的多输入 Keras 模型

Multi Input Keras Model with Tfdata

我有一个模型,我正在尝试使用 8 个输入进行设置。前 7 个是长度为 1 的 ID,每个 ID 都被送入嵌入层,这些输出与一组 4 个数字变量连接。

所以在模型定义中包括:

input_A = keras.Input(shape=(1,))
input_B = keras.Input(shape=(1,))
input_C = keras.Input(shape=(1,))    
input_D = keras.Input(shape=(1,))    
input_E = keras.Input(shape=(1,)) 
input_F = keras.Input(shape=(1,))
input_G = keras.Input(shape=(1,))

input_nums = keras.Input(shape=(4,))

embed_A = keras.layers.Embedding(1223 + 1, 50)(input_A)
embed_B = keras.layers.Embedding(50 + 1, 25)(input_B)
embed_C = keras.layers.Embedding(1259 + 1, 50)(input_C)
embed_D = keras.layers.Embedding(3995 + 1, 50)(input_D)
embed_E = keras.layers.Embedding(2040 + 1, 50)(input_E)
embed_F = keras.layers.Embedding(174 + 1, 50)(input_F)
embed_G = keras.layers.Embedding(227 + 1, 50)(input_G)

embed_A = keras.layers.Flatten()(embed_A)
embed_B = keras.layers.Flatten()(embed_B)
embed_C = keras.layers.Flatten()(embed_C)
embed_D = keras.layers.Flatten()(embed_D)
embed_E = keras.layers.Flatten()(embed_E)
embed_F = keras.layers.Flatten()(embed_F)
embed_G = keras.layers.Flatten()(embed_G)

x = keras.layers.concatenate([embed_A,embed_B,embed_C,embed_D,embed_E,embed_F,embed_G,input_nums])

然后构建模型:

模型=keras.Model(输入=[input_A,input_B,input_C,input_D,input_E,input_F, input_G, input_nums], 输出 = [out])

在 tfdataset 映射函数中,我尝试像这样构造输入数据,但拟合模型会产生错误:

# keras needs:  Should return a tuple of either (inputs, targets) or (inputs, targets, sample_weights).
return  (
        (example["A"],example["B"],example["C"],
         example["D"],example["E"],example["F"],
         example["G"],
         
         (example[‘num_A'],example[' num_B '],example[' num_C'],example[' num_D '])
         ),
    
         label)


ValueError: Layer model expects 8 input(s), but it received 11 input tensors

如何设置 tfdataset 的映射函数以使用此模型?

我发现它作为地图函数中的 return 工作:

return   (example["A"],example["B"],example["C"],
         example["D"],example["E"],example["F"],
         example["G"],
         
         [example[‘num_A'],example[' num_B '],example[' num_C'],example[' num_D ']]
         ),
    
         label