Pytorch:批处理中每个图像的文件特定操作

Pytorch: File-specific action for each image in the batch

我有一个图像数据集,每个图像都有一个附加属性“channel_no”。每个图像都应该根据其 channel_no:

用 nn 层进行处理
 images with channel_no=1 have to be processed with layer1
 images with channel_no=2 have to be processed with layer2
 images with channel_no=3 have to be processed with layer3
etc...

问题在于,当批次包含多张图像时,forward() 函数以批次图像作为输入获取火炬张量,并且每张图像具有不同的channel_no。所以不清楚如何分别处理每张图片

这是批处理只有 1 张图像时的代码:

class Net(nn.Module):
    def __init__ (self, weight):
        super(Net, self).__init__()

        self.layer1 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer3 = nn.Linear(hidden_sizes[0], hidden_sizes[1])

        self.outp = nn.Linear(hidden_sizes[1], output_size)
        
    def forward(self, x, channel_no):
        channel_no = channel_no[0] #extract channel_no from the batch list

        x = x.view(-1,hidden_sizes[0])

        if channel_no == 1: x = F.relu(self.layer1(x))
        if channel_no == 2: x = F.relu(self.layer2(x))
        if channel_no == 3: x = F.relu(self.layer3(x))

        x = torch.sigmoid(self.outp(x))

        return x    

是否可以使用 > 1 的批量大小分别处理每个图像?

要单独处理图像,您可能需要单独的张量。我不确定是否有快速的方法来做到这一点,但您可以在批量维度中拆分张量以获得单独的图像张量,然后遍历它们以按通道号对它们进行排序。然后将每组具有相同通道号的图像连接成一个新的张量,并对该张量进行特殊处理。