PyTorch TensorBoard add_graph() 字典输入错误

PyTorch TensorBoard add_graph() dictionary input error

将 PyTorch 字典数据集 传递给 TensorBoard add_graph(model, data).

的正确方法是什么

可能看起来与 Question1, Qeustion2 and Question3 相似,但是找不到正确的字典数据集处理方式。

错误信息

Dictionary inputs to traced functions must have consistent type. Found Tensor and List[str]
Error occurs, No graph saved

下面是我的项目的匿名脚本。

train.py

from torch.utils.tensorboard import SummaryWriter
from models import CustomModel
from datasets import CustomDataset
writer = SummaryWriter()

# Dataset
dataset = CustomDataset(params ...)
train_dataset = [dataset[i] for i in range(0, k)]
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Model & TensorBoard
model = CustomModel(params....)
writer.add_graph(model, next(iter(train_dataloader)))  # ---- HERE ----

datasets.py

class CustomDataset(Dataset):
    def __init__(self, ...):
        ...
        self.x_sequences = pad_sequence(x_sequences, batch_first=True, padding_value=0)
        self.y_label = torch.LongTensor(label_list)
        ...

    def __len__(self):
        return len(self.y_label)
        
    def __getitem__(self, index):
        ...
        return {
            "x_categoricals": self.x_categoricals[index],
            "x_sequences": self.x_sequences[index],
            "y_label": self.y_label[index],
            "info": self.info[index],
        }

错误消息告诉您字典的条目必须都是同一类型,但在您的情况下,您似乎在一个条目中有一个 Tensor,但在一个条目中有 list 个字符串另一个条目。您必须确保所有条目都具有相同的类型。