在 pytorch "torchmetrics" 中使用 Dice 指标:dice_score() 缺少 2 个必需的位置参数:'preds' 和 'target'

Using Dice metric in pytorch "torchmetrics" : dice_score() missing 2 required positional arguments: 'preds' and 'target'

我正在尝试使用 pytorch“torchmetrics”中的 Dice 指标。我找到了一个使用准确度指标的例子。如下所示:

from torchmetrics.classification import Accuracy

train_accuracy = Accuracy()
valid_accuracy = Accuracy()

for epoch in range(epochs):
    for x, y in train_data:
        y_hat = model(x)

        # training step accuracy
        batch_acc = train_accuracy(y_hat, y)
        print(f"Accuracy of batch{i} is {batch_acc}")

    for x, y in valid_data:
        y_hat = model(x)
        valid_accuracy.update(y_hat, y)

    # total accuracy over all training batches
    total_train_accuracy = train_accuracy.compute()

    # total accuracy over all validation batches
    total_valid_accuracy = valid_accuracy.compute()

    print(f"Training acc for epoch {epoch}: {total_train_accuracy}")
    print(f"Validation acc for epoch {epoch}: {total_valid_accuracy}")

    # Reset metric states after each epoch
    train_accuracy.reset()
    valid_accuracy.reset() 

但是,当我将“Accuracy()”替换为“Dice_score()”时。如下所示:

from torchmetrics.functional import dice_score

train_accuracy =dice_score()
valid_accuracy =dice_score()

我遇到以下错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-43-726045592283> in <module>
      3 from torchmetrics.functional import dice_score
      4 
----> 5 train_accuracy_2 =dice_score()# Accuracy()
      6 valid_accuracy_2 =dice_score()# Accuracy()
      7 

TypeError: dice_score() missing 2 required positional arguments: 'preds' and 'target' 

是否有使用 "Dice" 来自 "torchmetrics"

指标的示例

torchmetrics.classification.dice_score 是 Dice 分数的功能接口。这意味着它是一个无状态函数,需要基本事实和预测。似乎没有骰子分数的模块接口,就像准确度一样。

torchmetrics.classification.Accuracy is a class that maintains state. Under the hood, it uses the functional interface, which is torchmetrics.functional.accuracy.

这不会以任何方式强制执行,但通常 类 以 CamelCase 命名,函数以 snake_case.

命名