loss.backward() 是要在每个样本上还是在每个批次上调用?
Is loss.backward() meant to be called on each sample or on each batch?
我有一个包含不同大小特征的训练数据集。我理解这在网络架构方面的含义,并相应地设计了我的网络来处理这些异构形状。但是,当谈到我的训练循环时,我对 optimizer.zero_grad()
、loss.backward()
和 optimizer.step()
.
的 order/placement 感到困惑
由于特征大小不等,我无法同时对一批特征进行前向传递。因此,我的训练循环手动循环批次样本,如下所示:
for epoch in range(NUM_EPOCHS):
for bidx, batch in enumerate(train_loader):
optimizer.zero_grad()
batch_loss = 0
for sample in batch:
feature1 = sample['feature1']
feature2 = sample['feature2']
label1 = sample['label1']
label2 = sample['label2']
pred_l1, pred_l2 = model(feature1, feature2)
sample_loss = compute_loss(label1, pred_l1)
sample_loss += compute_loss(label2, pred_l2)
sample_loss.backward() # CHOICE 1
batch_loss += sample_loss.item()
# batch_loss.backward() # CHOICE 2
optimizer.step()
我想知道在这里对每个 sample_loss
调用向后调用是否有意义,优化器步骤调用每个 BATCH_SIZE
个样本(选择 1)。我认为,另一种方法是向后调用 batch_loss
(选择 2),我不确定哪个是正确的选择。
微分是一个 linear 运算,因此理论上,您是先微分不同的损失并添加它们的导数,还是先添加损失然后计算它们的总和的导数,这并不重要。
因此出于实际目的,它们都应该导致相同的结果(忽略通常的浮点问题)。
您的内存要求和计算速度可能会略有不同(我猜第二个版本可能会稍快一些。),但这很难预测,但您可以通过对两个版本进行计时来轻松发现.
我有一个包含不同大小特征的训练数据集。我理解这在网络架构方面的含义,并相应地设计了我的网络来处理这些异构形状。但是,当谈到我的训练循环时,我对 optimizer.zero_grad()
、loss.backward()
和 optimizer.step()
.
由于特征大小不等,我无法同时对一批特征进行前向传递。因此,我的训练循环手动循环批次样本,如下所示:
for epoch in range(NUM_EPOCHS):
for bidx, batch in enumerate(train_loader):
optimizer.zero_grad()
batch_loss = 0
for sample in batch:
feature1 = sample['feature1']
feature2 = sample['feature2']
label1 = sample['label1']
label2 = sample['label2']
pred_l1, pred_l2 = model(feature1, feature2)
sample_loss = compute_loss(label1, pred_l1)
sample_loss += compute_loss(label2, pred_l2)
sample_loss.backward() # CHOICE 1
batch_loss += sample_loss.item()
# batch_loss.backward() # CHOICE 2
optimizer.step()
我想知道在这里对每个 sample_loss
调用向后调用是否有意义,优化器步骤调用每个 BATCH_SIZE
个样本(选择 1)。我认为,另一种方法是向后调用 batch_loss
(选择 2),我不确定哪个是正确的选择。
微分是一个 linear 运算,因此理论上,您是先微分不同的损失并添加它们的导数,还是先添加损失然后计算它们的总和的导数,这并不重要。
因此出于实际目的,它们都应该导致相同的结果(忽略通常的浮点问题)。
您的内存要求和计算速度可能会略有不同(我猜第二个版本可能会稍快一些。),但这很难预测,但您可以通过对两个版本进行计时来轻松发现.