AllenNLP: How to know which index of the output tensor corresponds to which class

I am using allennlp 2.1 and I would like to pass class weights to the PyTorch cross-entropy loss function that I use.

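from typing import List, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models.heads import Head
from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure


@Head.register('model_head_two_layers')
class ModelHeadTwoLayers(Head):

    default_predictor = 'head_predictor'

    def __init__(self, vocab: Vocabulary, input_dim: int, output_dim: int, dropout: float = 0.0,
                 class_weights: Optional[List[float]] = None):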
        super().__init__(vocab=vocab)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers = torch.nn.Sequential(
            torch.nn.Dropout(dropout),
            torch.nn.Linear(self.input_dim, self.input_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(self.input_dim, output_dim)
        )
        self.metrics = {
            'accuracy': CategoricalAccuracy(),
            'f1_macro': FBetaMeasure(average='macro')
        }
        if class_weights:
            # class_weights[i] is the weight applied to the class at output index i.
            self.class_weights = torch.FloatTensor(class_weights)
            self.cross_ent = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            self.cross_ent = torch.nn.CrossEntropyLoss()

In the configuration file, I pass the class weights as follows:

"heads": {
    "task_name": {
        "type": "model_head_two_layers",
        "input_dim": embedding_dim,
        "output_dim": 4,
        "dropout": dropout,
        "class_weights": [0.25, 0.90, 0.91, 0.94]
    }
}

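To get the class weights in the right order, I need to know which index of the output tensor corresponds to which class. As far as I can tell, the only way to find out is to first train a model without class weights, then go into the model's vocabulary directory and check the order in which the class names were written to the labels file.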
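The manual check described above can be sketched in a few lines of Python. This is only an illustration under assumptions: it presumes the label namespace is non-padded (AllenNLP's default for namespaces ending in `labels`), so that the label on line i of the file has index i, and the file location `vocabulary/labels.txt` inside the serialization directory is assumed rather than taken from the question. The helper name `read_label_indices` is hypothetical.

```python
from pathlib import Path
from typing import Dict


def read_label_indices(labels_file: Path) -> Dict[int, str]:
    """Map each 0-based line number of a label vocabulary file to its label.

    Assumes a non-padded namespace, where line order equals index order.
    """
    labels = Path(labels_file).read_text(encoding="utf-8").splitlines()
    return {index: label for index, label in enumerate(labels)}


# Hypothetical usage (the path is an assumption, adjust to your setup):
# mapping = read_label_indices(Path("serialization_dir/vocabulary/labels.txt"))
# for index, label in mapping.items():
#     print(index, label)
```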

While this seems to work... is there an easier way to get the mapping without having to train a model first?

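You can use the allennlp build-vocab command to generate the vocabulary without training a model. But I think the better solution here is to pass class_weights to your model as a mapping from label -> weight, and then build the weight array in the __init__ function. Something like this: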

class ModelHeadTwoLayers(Head):
    def __init__(
        self,
        vocab: Vocabulary,
        input_dim: int,
        output_dim: int,
        dropout: float = 0.0,
        class_weights: Optional[Dict[str, float]] = None,
        label_namespace: str = "labels",
    ):
        super().__init__(vocab=vocab)

        # ...

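        if class_weights:
            # One slot per label; assumes class_weights has an entry for every label in the namespace.
            weights: List[float] = [0.0] * len(class_weights)
            for label, weight in class_weights.items():
                # Look up the output index that the vocabulary assigned to this label.
                label_idx = self.vocab.get_token_index(label, namespace=label_namespace)
                weights[label_idx] = weight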
            self.class_weights = torch.FloatTensor(weights)
            self.cross_ent = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            self.cross_ent = torch.nn.CrossEntropyLoss()
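
To see why the mapping form does not depend on the order in which the weights are written, here is a minimal sketch in plain Python. A plain dict stands in for the vocabulary's label-to-index mapping (in AllenNLP this could be obtained with `vocab.get_token_to_index_vocabulary("labels")`). The label names and the helper `build_weight_list` are hypothetical, and the up-front check for missing or unknown labels is an addition that is not part of the code above.

```python
from typing import Dict, List


def build_weight_list(
    label_to_index: Dict[str, int], class_weights: Dict[str, float]
) -> List[float]:
    """Order per-label weights by the index each label has in the vocabulary."""
    missing = set(label_to_index) - set(class_weights)
    unknown = set(class_weights) - set(label_to_index)
    if missing or unknown:
        raise ValueError(
            f"labels without a weight: {sorted(missing)}; "
            f"weights for unknown labels: {sorted(unknown)}"
        )
    weights = [0.0] * len(label_to_index)
    for label, weight in class_weights.items():
        weights[label_to_index[label]] = weight
    return weights


# Hypothetical labels; the index order is whatever the vocabulary assigned.
label_to_index = {"neutral": 0, "positive": 1, "negative": 2, "mixed": 3}
# The weights can be written in any order because they are keyed by label.
class_weights = {"positive": 0.90, "neutral": 0.25, "mixed": 0.94, "negative": 0.91}

weights = build_weight_list(label_to_index, class_weights)
print(weights)  # [0.25, 0.9, 0.91, 0.94]
```

The resulting list can then be turned into a tensor and passed as `weight` to `torch.nn.CrossEntropyLoss`, as in the `__init__` above.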