How to implement a Mean Squared Error (MSE) metric for NNI (Neural Network Intelligence) in PyTorch?

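I'm fairly new to PyTorch, having used Keras for years. Now I want to run a network architecture search (NAS) based on DARTS: Differentiable Architecture Search (see https://nni.readthedocs.io/en/stable/NAS/DARTS.html), which is built on PyTorch.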

All the available examples use accuracy as the metric, but I need to compute MSE. Here is one of the available examples:

DartsTrainer(model,
             loss=criterion,
             metrics=lambda output, target: accuracy(output, target, topk=(1,)),
             optimizer=optim,
             num_epochs=args.epochs,
             dataset_train=dataset_train,
             dataset_valid=dataset_valid,
             batch_size=args.batch_size,
             log_frequency=args.log_frequency,
             unrolled=args.unrolled,
             callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])

# where the accuracy is defined in a separate function:

def accuracy(output, target, topk=(1,)):
    # Computes the precision@k for the specified values of k
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res

From what I've seen, computing metrics in PyTorch is more involved than in Keras. Can anyone help?

As a first attempt, I wrote this code:

def accuracy_mse(output, target):
    batch_size = target.size(0)
    
    diff = torch.square(output.t()-target)/batch_size
    diff = diff.sum()

    res = dict()

    res["acc_mse"] = diff
    return res    

It seems to work, but I'm not 100% sure...
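
One way to check a metric like this is to run it on a tiny batch where the answer is known and to look at the shape of the difference. The sketch below uses made-up values and assumes that both output and target have shape (N, 1), which is not stated above. Under that assumption, transposing output makes the subtraction broadcast to (N, N) instead of staying element-wise.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch of 4 regression outputs and targets, both shaped (4, 1).
output = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
target = torch.tensor([[1.0], [2.0], [3.0], [6.0]])

# output.t() has shape (1, 4); subtracting a (4, 1) target broadcasts to (4, 4).
with_transpose = output.t() - target
# Without the transpose the subtraction stays element-wise.
without_transpose = output - target

print(with_transpose.shape)     # torch.Size([4, 4])
print(without_transpose.shape)  # torch.Size([4, 1])

# Compare the hand-written formula against the built-in loss.
manual = torch.square(without_transpose).sum() / target.size(0)
reference = F.mse_loss(output, target)
print(manual.item(), reference.item())
```

If the two printed values agree on a few such batches, the hand-written metric is at least consistent with the built-in loss for that tensor layout.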

In the end I found that the problem was the transpose (.t()), so the final code is:

def accuracy_mse(output, target):
    """Computes the MSE."""
    batch_size = target.size(0)

    # Sum of squared errors divided by the batch size (no transpose).
    diff = torch.square(output - target) / batch_size
    diff = diff.sum()

    res = dict()
    res["mse"] = diff
    return res
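
For comparison, here is a minimal sketch of the same metric built on torch.nn.functional.mse_loss. The name mse_metric and the sample values are hypothetical. It assumes, as in the accuracy example, that the metrics callable should return a dict of plain Python numbers, and it flattens both tensors so that an (N, 1) output and an (N,) target do not broadcast against each other.

```python
import torch
import torch.nn.functional as F


def mse_metric(output, target):
    """Return the batch mean squared error as a dict of Python floats."""
    # Flatten both tensors so shapes such as (N, 1) and (N,) line up
    # instead of silently broadcasting to (N, N).
    output = output.detach().reshape(-1).float()
    target = target.detach().reshape(-1).float()
    return {"mse": F.mse_loss(output, target).item()}


# Made-up values: the errors are 0, 0 and 2, so the MSE is 4 / 3.
output = torch.tensor([[1.0], [2.0], [5.0]])
target = torch.tensor([1.0, 2.0, 3.0])
result = mse_metric(output, target)
print(result)
```

Note that mse_loss averages over all elements, so for a model with several outputs per sample it gives a different number than dividing the summed squared error by batch_size alone.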