使用pytorch的不平衡数据的焦点损失
focal loss for imbalanced data using pytorch
我想使用 pytorch 对多类不平衡数据使用焦点损失。我搜索了并尝试使用此代码,但出现错误
class_weights=tf.constant([0.21, 0.45, 0.4, 0.46, 0.48, 0.49])
loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')
并在训练函数中使用它
preds = model(sent_id, mask, labels)
# compu25te the validation loss between actual and predicted values
alpha=0.25
gamma=2
ce_loss = loss_fn(preds, labels)
pt = torch.exp(-ce_loss)
focal_loss = (alpha * (1-pt)**gamma * ce_loss).mean()
错误是
TypeError: cannot assign 'tensorflow.python.framework.ops.EagerTensor' object to buffer 'weight' (torch Tensor or None required)
在这一行
loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')
您正在混合使用 tensorflow 和 pytorch 对象。
尝试:
class_weights=torch.tensor([0.21, ...], requires_grad=False)
我想使用 pytorch 对多类不平衡数据使用焦点损失。我搜索了并尝试使用此代码,但出现错误
class_weights=tf.constant([0.21, 0.45, 0.4, 0.46, 0.48, 0.49])
loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')
并在训练函数中使用它
preds = model(sent_id, mask, labels)
# compu25te the validation loss between actual and predicted values
alpha=0.25
gamma=2
ce_loss = loss_fn(preds, labels)
pt = torch.exp(-ce_loss)
focal_loss = (alpha * (1-pt)**gamma * ce_loss).mean()
错误是
TypeError: cannot assign 'tensorflow.python.framework.ops.EagerTensor' object to buffer 'weight' (torch Tensor or None required)
在这一行
loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')
您正在混合使用 tensorflow 和 pytorch 对象。
尝试:
class_weights=torch.tensor([0.21, ...], requires_grad=False)