如何为 pytorch multi-class 问题设置交叉熵损失目标
How to set target in cross entropy loss for pytorch multi-class problem
问题陈述:我有一张图片,图片的一个像素只能属于Band5','Band6', 'Band7'
之一(详情见下文)。因此,我有一个 pytorch multi-class 问题,但我无法理解如何设置需要 [batch, w, h]
形式的目标
我的数据加载器return 两个值:
x = chips.loc[:, :, :, self.input_bands]
y = chips.loc[:, :, :, self.output_bands]
x = x.transpose('chip','channel','x','y')
y_ohe = y.transpose('chip','channel','x','y')
此外,我定义了:
input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4'] # input classes
output_bands = ['Band5','Band6', 'Band7'] #target classes
model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)
loss_new = nn.CrossEntropyLoss()
在我的训练函数中:
#get values from dataloader
X = normalize_zero_to_one(X) #input
y = normalize_zero_to_one(y) #target
images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
masks = Variable(torch.from_numpy(y)).to(device)
optim.zero_grad()
outputs = model(images)
loss = loss_new(outputs, masks) # (preds, target)
loss.backward()
optim.step() # Update weights
我知道目标(此处masks
)应该是[batch_size, w, h]
。但是,目前 [batch_size, channels, w, h]
.
我读了很多帖子,包括 1, 2,他们说 the target should only contain the target class indices
。我不明白如何连接三个 classes 的索引并仍然将目标设置为 [batch_size, w, h]
.
现在,我收到错误:
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4
据我所知,我不需要进行任何一次热编码。我在网上找到的类似错误和解释在这里:'
任何帮助将不胜感激!谢谢。
如果我没理解错的话,你目前的 "target" 是 [batch_size, channels, w, h]
和 channels==3
因为你有三个可能的目标。
您的目标中的 值 代表什么?您基本上每个像素都有一个 3 向量目标 - 这些是预期的 class 概率吗?它们 "one-hot-vectors" 是否表示正确的 "band"?
如果是这样,您只需沿目标通道维度取 argmax
即可获得目标索引:
proper_target = torch.argmax(masks, dim=1) # make sure keepdim=False
loss = loss_new(outputs, proper_target)
问题陈述:我有一张图片,图片的一个像素只能属于Band5','Band6', 'Band7'
之一(详情见下文)。因此,我有一个 pytorch multi-class 问题,但我无法理解如何设置需要 [batch, w, h]
我的数据加载器return 两个值:
x = chips.loc[:, :, :, self.input_bands]
y = chips.loc[:, :, :, self.output_bands]
x = x.transpose('chip','channel','x','y')
y_ohe = y.transpose('chip','channel','x','y')
此外,我定义了:
input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4'] # input classes
output_bands = ['Band5','Band6', 'Band7'] #target classes
model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)
loss_new = nn.CrossEntropyLoss()
在我的训练函数中:
#get values from dataloader
X = normalize_zero_to_one(X) #input
y = normalize_zero_to_one(y) #target
images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]
masks = Variable(torch.from_numpy(y)).to(device)
optim.zero_grad()
outputs = model(images)
loss = loss_new(outputs, masks) # (preds, target)
loss.backward()
optim.step() # Update weights
我知道目标(此处masks
)应该是[batch_size, w, h]
。但是,目前 [batch_size, channels, w, h]
.
我读了很多帖子,包括 1, 2,他们说 the target should only contain the target class indices
。我不明白如何连接三个 classes 的索引并仍然将目标设置为 [batch_size, w, h]
.
现在,我收到错误:
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4
据我所知,我不需要进行任何一次热编码。我在网上找到的类似错误和解释在这里:'
任何帮助将不胜感激!谢谢。
如果我没理解错的话,你目前的 "target" 是 [batch_size, channels, w, h]
和 channels==3
因为你有三个可能的目标。
您的目标中的 值 代表什么?您基本上每个像素都有一个 3 向量目标 - 这些是预期的 class 概率吗?它们 "one-hot-vectors" 是否表示正确的 "band"?
如果是这样,您只需沿目标通道维度取 argmax
即可获得目标索引:
proper_target = torch.argmax(masks, dim=1) # make sure keepdim=False
loss = loss_new(outputs, proper_target)