Keras 似乎忽略了我的 batch_size 并尝试将所有数据放入 GPU 内存中

Keras seem to ignore my batch_size and tries to fit all data in GPU memory

我有一个带有 2 个输入数组和 2 个输出数组的 RNN,大小为:

input1 = (339679, 90, 15)
input2 =(339679, 90, 27)
output1 = 339679,2
output2 = 339679,16

创建 RNN LSTM 的代码是(我只展示两个 RNN 之一,另一个是相同的但有 16 个输出并从 input2 获取输入大小):

inputs = Input(shape=(n_in1.shape[1], n_in1.shape[2]), name='inputs')                   
lstmA1 = LSTM(1024, return_sequences=True, name="lstmA1") (inputs)                    
lstmA2 = LSTM(512//1, return_sequences=True, name="lstmA2")(lstmA1)                   
lstmA3 = LSTM(512//2, return_sequences=True, name="lstmA3")(lstmA2)                   
lstmA4 = LSTM(512//2, return_sequences=True, name="lstmA4")(lstmA3)                   
lstmA5 = LSTM(512//4, return_sequences=False, name="lstmA5")(lstmA4)                  
denseA1 = DenseBND(lstmA5, 512//1, "denseA1", None, False, 0.2)                       
denseA2 = DenseBND(denseA1, 512//4, "denseA2", None, False, 0.2)                      
denseA3 = DenseBND(denseA2, 512//8, "denseA3", None, False, 0.2)                      
outputsA = Dense(2, name="outputsA")(denseA3)  

这里,n_in1是我之前描述的input1,所以给出的shape是90,15

DenseBND 只是一个函数,returns 具有 BatchNormalization 和 dropout 的密集层。在这种情况下,BatchNormalization 为 False,激活函数为 None,Dropout 为 0.2 所以它只是 returns 具有线性激活函数和 20% Dropout 的密集层。

最后,训练它的线:

model.fit( {'inputsA': np.array(n_input1), 'inputsB': np.array(n_input2)},
      {'outputsA': np.array(n_output1), 'outputsB': np.array(n_output2)},
      validation_split=0.1, epochs=1000, batch_size=256, 
      callbacks=callbacks_list)

可以看到validation_split是0.1,batch_size是256

然而,当我尝试训练它时,出现以下错误:

ResourceExhaustedError: OOM when allocating tensor with shape[335376,90,27] and type float on /job:

如您所见,它似乎试图将整个数据集放入 GPU 内存中,而不是逐批处理。我曾将 batch_size 设置为 1,但此错误仍然存​​在。第一个数字 335376 是我数据集的 90%(这个数字与上面的不同,上面的那个有效,这个无效)。

它不应该尝试分配形状为 256,90,27 的张量吗?

不,keras 不会忽略您的批量大小。

您正在尝试创建维度过大的 numpy 数组。

'inputsB': np.array(n_input2) 这分配了一个非常大的 numpy 数组,因此即使在训练开始之前,由于内存有限,这种 numpy 转换也是不可能的。

您需要使用不会立即将完整数据加载到内存中的数据生成器。

参考:https://keras.io/api/preprocessing/