PyTorch TensorBoard add_graph() 字典输入错误
PyTorch TensorBoard add_graph() dictionary input error
将 PyTorch 字典数据集 传递给 TensorBoard add_graph(model, data)
.
的正确方法是什么
可能看起来与 Question1
, Qeustion2
and Question3
相似,但是找不到正确的字典数据集处理方式。
错误信息
Dictionary inputs to traced functions must have consistent type. Found Tensor and List[str]
Error occurs, No graph saved
下面是我的项目的匿名脚本。
train.py
from torch.utils.tensorboard import SummaryWriter
from models import CustomModel
from datasets import CustomDataset
writer = SummaryWriter()
# Dataset
dataset = CustomDataset(params ...)
train_dataset = [dataset[i] for i in range(0, k)]
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Model & TensorBoard
model = CustomModel(params....)
writer.add_graph(model, next(iter(train_dataloader))) # ---- HERE ----
datasets.py
class CustomDataset(Dataset):
def __init__(self, ...):
...
self.x_sequences = pad_sequence(x_sequences, batch_first=True, padding_value=0)
self.y_label = torch.LongTensor(label_list)
...
def __len__(self):
return len(self.y_label)
def __getitem__(self, index):
...
return {
"x_categoricals": self.x_categoricals[index],
"x_sequences": self.x_sequences[index],
"y_label": self.y_label[index],
"info": self.info[index],
}
错误消息告诉您字典的条目必须都是同一类型,但在您的情况下,您似乎在一个条目中有一个 Tensor
,但在一个条目中有 list
个字符串另一个条目。您必须确保所有条目都具有相同的类型。
将 PyTorch 字典数据集 传递给 TensorBoard add_graph(model, data)
.
可能看起来与 Question1
, Qeustion2
and Question3
相似,但是找不到正确的字典数据集处理方式。
错误信息
Dictionary inputs to traced functions must have consistent type. Found Tensor and List[str]
Error occurs, No graph saved
下面是我的项目的匿名脚本。
train.py
from torch.utils.tensorboard import SummaryWriter
from models import CustomModel
from datasets import CustomDataset
writer = SummaryWriter()
# Dataset
dataset = CustomDataset(params ...)
train_dataset = [dataset[i] for i in range(0, k)]
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Model & TensorBoard
model = CustomModel(params....)
writer.add_graph(model, next(iter(train_dataloader))) # ---- HERE ----
datasets.py
class CustomDataset(Dataset):
def __init__(self, ...):
...
self.x_sequences = pad_sequence(x_sequences, batch_first=True, padding_value=0)
self.y_label = torch.LongTensor(label_list)
...
def __len__(self):
return len(self.y_label)
def __getitem__(self, index):
...
return {
"x_categoricals": self.x_categoricals[index],
"x_sequences": self.x_sequences[index],
"y_label": self.y_label[index],
"info": self.info[index],
}
错误消息告诉您字典的条目必须都是同一类型,但在您的情况下,您似乎在一个条目中有一个 Tensor
,但在一个条目中有 list
个字符串另一个条目。您必须确保所有条目都具有相同的类型。