将图像大小作为第二个输入添加到现有 PyTorch 模型
Adding image size as second input to existing PyTorch model
我在 PyTorch 中使用预训练的 torchvision
模型并将学习迁移到 class 验证我自己的数据集。这工作正常,但我认为我可以进一步提高我的 classification 性能。我们的图像有不同的尺寸,所有的图像都被调整大小以适应我的模型的输入(例如 224x224 像素)。
不过,原图大小往往会说很多class这张图所属的。所以我认为这可能有助于模型将原始图像维度作为第二个输入添加到模型中。
目前我在 PyTorch 中构建我的模型是这样的:
model = resnet50(pretrained=True) # Could be another base model as well
for module, param in zip(model.modules(), model.parameters()):
if isinstance(module, nn.BatchNorm2d):
param.requires_grad = False
model.fc = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(256, num_classes),
)
现在我如何向该模型添加另一个(二维?)输入,以便我可以将原始图像的 x 和 y 维度提供给模型?此外,哪里 最有意义 - 直接进入模型的“开始”,或者更好的“中间”某处?
将数据注入模型的一种方法是直接注入线性层。
这会有不影响转换层的缺点。
请注意,我注入到最后一层,但这可以进入任何层。
model.start = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.25),
)
model.end = nn.Sequential(
nn.Linear(256 + 2, num_classes),
)
你的 forward
应该是(伪代码)类似于
def forward(x):
x1 = model.start(x)
mid = torch.concatenate([x, extra_2d_data])
x2 = model.end(mid)
return x2
另见 this
我在 PyTorch 中使用预训练的 torchvision
模型并将学习迁移到 class 验证我自己的数据集。这工作正常,但我认为我可以进一步提高我的 classification 性能。我们的图像有不同的尺寸,所有的图像都被调整大小以适应我的模型的输入(例如 224x224 像素)。
不过,原图大小往往会说很多class这张图所属的。所以我认为这可能有助于模型将原始图像维度作为第二个输入添加到模型中。
目前我在 PyTorch 中构建我的模型是这样的:
model = resnet50(pretrained=True) # Could be another base model as well
for module, param in zip(model.modules(), model.parameters()):
if isinstance(module, nn.BatchNorm2d):
param.requires_grad = False
model.fc = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(256, num_classes),
)
现在我如何向该模型添加另一个(二维?)输入,以便我可以将原始图像的 x 和 y 维度提供给模型?此外,哪里 最有意义 - 直接进入模型的“开始”,或者更好的“中间”某处?
将数据注入模型的一种方法是直接注入线性层。
这会有不影响转换层的缺点。
请注意,我注入到最后一层,但这可以进入任何层。
model.start = nn.Sequential(
nn.Linear(2048, 512),
nn.ReLU(),
nn.Dropout(0.25),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.25),
)
model.end = nn.Sequential(
nn.Linear(256 + 2, num_classes),
)
你的 forward
应该是(伪代码)类似于
def forward(x):
x1 = model.start(x)
mid = torch.concatenate([x, extra_2d_data])
x2 = model.end(mid)
return x2
另见 this