在 PyTorch 中提前停止
early stopping in PyTorch
我尝试实现提前停止功能以避免我的神经网络模型过度拟合。我很确定逻辑没问题,但由于某种原因,它不起作用。
我希望当验证损失在连续若干个时期内都大于训练损失时,提前停止函数返回 True。但它始终返回 False,即使验证损失已经变得比训练损失大得多。请问您能看出问题出在哪里吗?
提前停止功能
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    """Return True once the validation loss has exceeded the training loss
    by more than `min_delta` on at least `tolerance` calls (epochs).

    Bug fix: the original set `counter = 0` at the top of every call, so the
    count restarted on each epoch and could never reach `tolerance`. The
    count is now persisted between calls as a function attribute. The
    function also returns an explicit False (the original fell through and
    returned None).
    """
    if (validation_loss - train_loss) > min_delta:
        # Persist the count across calls; the original local variable
        # was re-created (and zeroed) on every invocation.
        early_stopping.counter = getattr(early_stopping, "counter", 0) + 1
        if early_stopping.counter >= tolerance:
            return True
    return False
在训练期间调用函数
# Per-epoch training loop: train, validate, then test the early-stopping
# condition. NOTE(review): this calls the stateless `early_stopping`
# function, whose counter restarts on every call — it can never accumulate
# `tolerance` bad epochs, which is why it never fires.
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation — gradients are not needed while evaluating
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)
    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance = 20):
        print("We are at epoch:", i)
        break
编辑:
训练和验证损失:
编辑 2:
def train_validate(model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    """Train `model` for up to `epochs` epochs with per-epoch validation,
    stopping early when validation loss exceeds training loss for too long.

    Returns a tuple (train_loss, validation_loss) of per-epoch loss lists.
    """
    preds = []
    train_loss = []
    validation_loss = []
    min_delta = 5
    # Bug fix: the tracker must be created ONCE, before the epoch loop.
    # The original constructed a fresh EarlyStopping inside the loop, so its
    # counter was reset every epoch and `early_stop` could never become True.
    # Also reuse the local `min_delta` instead of repeating the literal 5.
    early_stopping = EarlyStopping(tolerance=2, min_delta=min_delta)
    for e in range(epochs):
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)
        # validation — no gradient tracking needed during evaluation
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
            validation_loss.append(epoch_validate_loss)
        # early stopping
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break
    return train_loss, validation_loss
你的实现的问题是:每次调用 early_stopping() 时,计数器都会被重新初始化为 0。
下面是使用面向对象方法的可行解决方案:用 __init__() 保存状态,用 __call__() 在每个时期更新状态:
class EarlyStopping():
    """Track the validation/training loss gap across epochs.

    Sets `early_stop` to True once the validation loss has exceeded the
    training loss by more than `min_delta` on at least `tolerance` epochs
    (counted cumulatively — the counter is never reset).
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        # Count each epoch whose generalization gap is too large; the
        # counter only grows, so `early_stop` stays True once triggered.
        gap = validation_loss - train_loss
        if gap > self.min_delta:
            self.counter += 1
            self.early_stop = self.counter >= self.tolerance
像这样调用它:
# Create the stateful tracker ONCE, outside the epoch loop, so its counter
# persists across epochs — this is the fix for the original function version.
early_stopping = EarlyStopping(tolerance=5, min_delta=10)
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation — gradients are not needed while evaluating
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)
    # early stopping
    early_stopping(epoch_train_loss, epoch_validate_loss)
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break
示例:
# Demonstration on fixed loss histories: stop after two epochs whose
# validation loss exceeds training loss by more than 5.
early_stopping = EarlyStopping(tolerance=2, min_delta=5)
train_loss = [
    642.14990234,
    601.29278564,
    561.98400879,
    530.01501465,
    497.1098938,
    466.92709351,
    438.2364502,
    413.76028442,
    391.5090332,
    370.79074097,
]
validate_loss = [
    509.13619995,
    497.3125,
    506.17315674,
    497.68960571,
    505.69918823,
    459.78610229,
    480.25592041,
    418.08630371,
    446.42675781,
    372.09902954,
]
# Walk the two histories in lockstep instead of indexing by position.
for epoch, (t_loss, v_loss) in enumerate(zip(train_loss, validate_loss)):
    early_stopping(t_loss, v_loss)
    print(f"loss: {t_loss} : {v_loss}")
    if early_stopping.early_stop:
        print("We are at epoch:", epoch)
        break
输出:
loss: 642.14990234 : 509.13619995
loss: 601.29278564 : 497.3125
loss: 561.98400879 : 506.17315674
loss: 530.01501465 : 497.68960571
loss: 497.1098938 : 505.69918823
loss: 466.92709351 : 459.78610229
loss: 438.2364502 : 480.25592041
We are at epoch: 6
我尝试实现提前停止功能以避免我的神经网络模型过度拟合。我很确定逻辑没问题,但由于某种原因,它不起作用。我希望当验证损失在连续若干个时期内都大于训练损失时,提前停止函数返回 True。但它始终返回 False,即使验证损失已经变得比训练损失大得多。请问您能看出问题出在哪里吗?
提前停止功能
def early_stopping(train_loss, validation_loss, min_delta, tolerance):
    """Return True once the validation loss has exceeded the training loss
    by more than `min_delta` on at least `tolerance` calls (epochs).

    Bug fix: the original set `counter = 0` at the top of every call, so the
    count restarted on each epoch and could never reach `tolerance`. The
    count is now persisted between calls as a function attribute. The
    function also returns an explicit False (the original fell through and
    returned None).
    """
    if (validation_loss - train_loss) > min_delta:
        # Persist the count across calls; the original local variable
        # was re-created (and zeroed) on every invocation.
        early_stopping.counter = getattr(early_stopping, "counter", 0) + 1
        if early_stopping.counter >= tolerance:
            return True
    return False
在训练期间调用函数
# Per-epoch training loop: train, validate, then test the early-stopping
# condition. NOTE(review): this calls the stateless `early_stopping`
# function, whose counter restarts on every call — it can never accumulate
# `tolerance` bad epochs, which is why it never fires.
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation — gradients are not needed while evaluating
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)
    # early stopping
    if early_stopping(epoch_train_loss, epoch_validate_loss, min_delta=10, tolerance = 20):
        print("We are at epoch:", i)
        break
编辑:
训练和验证损失:
编辑 2:
def train_validate(model, train_dataloader, validate_dataloader, loss_func, optimiser, device, epochs):
    """Train `model` for up to `epochs` epochs with per-epoch validation,
    stopping early when validation loss exceeds training loss for too long.

    Returns a tuple (train_loss, validation_loss) of per-epoch loss lists.
    """
    preds = []
    train_loss = []
    validation_loss = []
    min_delta = 5
    # Bug fix: the tracker must be created ONCE, before the epoch loop.
    # The original constructed a fresh EarlyStopping inside the loop, so its
    # counter was reset every epoch and `early_stop` could never become True.
    # Also reuse the local `min_delta` instead of repeating the literal 5.
    early_stopping = EarlyStopping(tolerance=2, min_delta=min_delta)
    for e in range(epochs):
        print(f"Epoch {e+1}")
        epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
        train_loss.append(epoch_train_loss)
        # validation — no gradient tracking needed during evaluation
        with torch.no_grad():
            epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
            validation_loss.append(epoch_validate_loss)
        # early stopping
        early_stopping(epoch_train_loss, epoch_validate_loss)
        if early_stopping.early_stop:
            print("We are at epoch:", e)
            break
    return train_loss, validation_loss
你的实现的问题是:每次调用 early_stopping() 时,计数器都会被重新初始化为 0。
下面是使用面向对象方法的可行解决方案:用 __init__() 保存状态,用 __call__() 在每个时期更新状态:
class EarlyStopping():
    """Track the validation/training loss gap across epochs.

    Sets `early_stop` to True once the validation loss has exceeded the
    training loss by more than `min_delta` on at least `tolerance` epochs
    (counted cumulatively — the counter is never reset).
    """

    def __init__(self, tolerance=5, min_delta=0):
        self.tolerance = tolerance
        self.min_delta = min_delta
        self.counter = 0
        self.early_stop = False

    def __call__(self, train_loss, validation_loss):
        # Count each epoch whose generalization gap is too large; the
        # counter only grows, so `early_stop` stays True once triggered.
        gap = validation_loss - train_loss
        if gap > self.min_delta:
            self.counter += 1
            self.early_stop = self.counter >= self.tolerance
像这样调用它:
# Create the stateful tracker ONCE, outside the epoch loop, so its counter
# persists across epochs — this is the fix for the original function version.
early_stopping = EarlyStopping(tolerance=5, min_delta=10)
for i in range(epochs):
    print(f"Epoch {i+1}")
    epoch_train_loss, pred = train_one_epoch(model, train_dataloader, loss_func, optimiser, device)
    train_loss.append(epoch_train_loss)
    # validation — gradients are not needed while evaluating
    with torch.no_grad():
        epoch_validate_loss = validate_one_epoch(model, validate_dataloader, loss_func, device)
        validation_loss.append(epoch_validate_loss)
    # early stopping
    early_stopping(epoch_train_loss, epoch_validate_loss)
    if early_stopping.early_stop:
        print("We are at epoch:", i)
        break
示例:
# Demonstration on fixed loss histories: stop after two epochs whose
# validation loss exceeds training loss by more than 5.
early_stopping = EarlyStopping(tolerance=2, min_delta=5)
train_loss = [
    642.14990234,
    601.29278564,
    561.98400879,
    530.01501465,
    497.1098938,
    466.92709351,
    438.2364502,
    413.76028442,
    391.5090332,
    370.79074097,
]
validate_loss = [
    509.13619995,
    497.3125,
    506.17315674,
    497.68960571,
    505.69918823,
    459.78610229,
    480.25592041,
    418.08630371,
    446.42675781,
    372.09902954,
]
# Walk the two histories in lockstep instead of indexing by position.
for epoch, (t_loss, v_loss) in enumerate(zip(train_loss, validate_loss)):
    early_stopping(t_loss, v_loss)
    print(f"loss: {t_loss} : {v_loss}")
    if early_stopping.early_stop:
        print("We are at epoch:", epoch)
        break
输出:
loss: 642.14990234 : 509.13619995
loss: 601.29278564 : 497.3125
loss: 561.98400879 : 506.17315674
loss: 530.01501465 : 497.68960571
loss: 497.1098938 : 505.69918823
loss: 466.92709351 : 459.78610229
loss: 438.2364502 : 480.25592041
We are at epoch: 6