在 pytorch "torchmetrics" 中使用 Dice 指标:dice_score() 缺少 2 个必需的位置参数:'preds' 和 'target'
Using Dice metric in pytorch "torchmetrics" : dice_score() missing 2 required positional arguments: 'preds' and 'target'
我正在尝试使用 pytorch“torchmetrics”中的 Dice 指标。我找到了一个使用准确度指标的例子。如下所示:
from torchmetrics.classification import Accuracy
train_accuracy = Accuracy()
valid_accuracy = Accuracy()
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)
# training step accuracy
batch_acc = train_accuracy(y_hat, y)
print(f"Accuracy of batch{i} is {batch_acc}")
for x, y in valid_data:
y_hat = model(x)
valid_accuracy.update(y_hat, y)
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")
# Reset metric states after each epoch
train_accuracy.reset()
valid_accuracy.reset()
但是,当我将“Accuracy()”替换为“Dice_score()”时。如下所示:
from torchmetrics.functional import dice_score
train_accuracy =dice_score()
valid_accuracy =dice_score()
我遇到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
3 from torchmetrics.functional import dice_score
4
----> 5 train_accuracy_2 =dice_score()# Accuracy()
6 valid_accuracy_2 =dice_score()# Accuracy()
7
TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target'
是否有使用 "Dice" 来自 "torchmetrics"
指标的示例
torchmetrics.classification.dice_score
是 Dice 分数的功能接口。这意味着它是一个无状态函数,需要基本事实和预测。似乎没有骰子分数的模块接口,就像准确度一样。
torchmetrics.classification.Accuracy
is a class that maintains state. Under the hood, it uses the functional interface, which is torchmetrics.functional.accuracy
.
这不会以任何方式强制执行,但通常 类 以 CamelCase 命名,函数以 snake_case.
命名
我正在尝试使用 pytorch“torchmetrics”中的 Dice 指标。我找到了一个使用准确度指标的例子。如下所示:
from torchmetrics.classification import Accuracy
train_accuracy = Accuracy()
valid_accuracy = Accuracy()
for epoch in range(epochs):
for x, y in train_data:
y_hat = model(x)
# training step accuracy
batch_acc = train_accuracy(y_hat, y)
print(f"Accuracy of batch{i} is {batch_acc}")
for x, y in valid_data:
y_hat = model(x)
valid_accuracy.update(y_hat, y)
# total accuracy over all training batches
total_train_accuracy = train_accuracy.compute()
# total accuracy over all validation batches
total_valid_accuracy = valid_accuracy.compute()
print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")
# Reset metric states after each epoch
train_accuracy.reset()
valid_accuracy.reset()
但是,当我将“Accuracy()”替换为“Dice_score()”时。如下所示:
from torchmetrics.functional import dice_score
train_accuracy =dice_score()
valid_accuracy =dice_score()
我遇到以下错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
3 from torchmetrics.functional import dice_score
4
----> 5 train_accuracy_2 =dice_score()# Accuracy()
6 valid_accuracy_2 =dice_score()# Accuracy()
7
TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target'
是否有使用 "Dice" 来自 "torchmetrics"
指标的示例torchmetrics.classification.dice_score
是 Dice 分数的功能接口。这意味着它是一个无状态函数,需要基本事实和预测。似乎没有骰子分数的模块接口,就像准确度一样。
torchmetrics.classification.Accuracy
is a class that maintains state. Under the hood, it uses the functional interface, which is torchmetrics.functional.accuracy
.
这不会以任何方式强制执行,但通常 类 以 CamelCase 命名,函数以 snake_case.
命名