Unable to reload class inherited from nn.Module with importlib

I am trying to reload a class with importlib in a Jupyter notebook, but I get an error saying that it is not a module.

Class code (feed_forward_neural_network.py)

import torch
import torch.nn as nn


class FeedForwardNeuralNetwork(nn.Module):
    def __init__(self, input_size, layers_data, random_seed=42):
        super().__init__()
        torch.manual_seed(random_seed)
        # layers_data is a list of (output_size, activation) pairs,
        # so the number of dense layers is configurable.
        self.layers = nn.ModuleList()
        for size, activation in layers_data:
            self.layers.append(nn.Linear(input_size, size, bias=False))
            # xavier_uniform (no underscore) is deprecated; use the in-place variant.
            nn.init.xavier_uniform_(self.layers[-1].weight)
            self.layers.append(activation)
            input_size = size

    def forward(self, x):
        # Flatten every dimension except the batch dimension.
        x = torch.flatten(x, start_dim=1)
        for layer in self.layers:
            x = layer(x)
        return x

Reload code

import importlib
from feed_forward_neural_network import FeedForwardNeuralNetwork

print((type(nn.Module)))
print(type(FeedForwardNeuralNetwork))
importlib.reload(FeedForwardNeuralNetwork)

Error

<class 'type'>
<class 'type'>
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [53], in <module>
      4 print((type(nn.Module)))
      5 print(type(FeedForwardNeuralNetwork))
----> 6 importlib.reload(FeedForwardNeuralNetwork)

File /usr/lib/python3.9/importlib/__init__.py:140, in reload(module)
    134 """Reload the module and return it.
    135 
    136 The module must have been successfully imported before.
    137 
    138 """
    139 if not module or not isinstance(module, types.ModuleType):
--> 140     raise TypeError("reload() argument must be a module")
    141 try:
    142     name = module.__spec__.name

TypeError: reload() argument must be a module
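The traceback shows the check that fails: reload() first tests isinstance(module, types.ModuleType), and a class is an instance of type (as the two print calls above show), not a module. A minimal sketch of the same check, using the standard-library class textwrap.TextWrapper as a stand-in for FeedForwardNeuralNetwork so it runs without PyTorch:

```python
import importlib
import sys
import types
from textwrap import TextWrapper  # stand-in for FeedForwardNeuralNetwork

# A plain class is an instance of `type`, not of `types.ModuleType`.
print(type(TextWrapper))                          # <class 'type'>
print(isinstance(TextWrapper, types.ModuleType))  # False

# Passing the class itself fails with the same TypeError as in the question.
try:
    importlib.reload(TextWrapper)
except TypeError as exc:
    print("TypeError:", exc)

# The class records the name of the module that defined it; sys.modules maps
# that name to the module object, which is what reload() accepts.
module = sys.modules[TextWrapper.__module__]
print(module.__name__)                            # textwrap
print(isinstance(module, types.ModuleType))       # True
```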

Environment details

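Answer

The argument to reload() must be a module, but FeedForwardNeuralNetwork is a class. Look up the module it was defined in through its __module__ attribute and sys.modules, and reload that module instead: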

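import sys
importlib.reload(sys.modules[FeedForwardNeuralNetwork.__module__])

In the notebook the call returns, and therefore displays, the reloaded module:

<module 'feed_forward_neural_network' from '/home/harshit/Downloads/feed_forward_neural_network/feed_forward_neural_network.py'>

Note that reload() re-executes the module but does not update names that were already bound with from feed_forward_neural_network import FeedForwardNeuralNetwork. Run that import again (or take the class from the returned module) so the name points at the new class.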
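A self-contained sketch of the whole edit, reload, rebind cycle, runnable without PyTorch. It writes a hypothetical module demo_model.py (standing in for feed_forward_neural_network.py) to a temporary directory, imports a class from it, edits the file, reloads the module through the class's __module__, and rebinds the name. Setting sys.dont_write_bytecode is an assumption added only for this script: a cached .pyc is validated by the source's modification time (in whole seconds) and size, so skipping bytecode guarantees the quick edit below is picked up.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Make reload() always compile the edited source instead of a cached .pyc.
sys.dont_write_bytecode = True

# Hypothetical module standing in for feed_forward_neural_network.py.
tmp_dir = Path(tempfile.mkdtemp())
module_path = tmp_dir / "demo_model.py"
module_path.write_text("class Model:\n    version = 1\n")
sys.path.insert(0, str(tmp_dir))
importlib.invalidate_caches()

from demo_model import Model  # same pattern as the notebook's import

# Simulate editing the file (e.g. in another editor tab).
module_path.write_text("class Model:\n    version = 2  # edited\n")

module = importlib.reload(sys.modules[Model.__module__])
print(Model.version)         # 1: the old name still points at the old class
print(module.Model.version)  # 2: the reloaded module holds the new class

Model = module.Model         # rebind the name after every reload
print(Model.version)         # 2
```

Objects created before the reload (for example a model instance) keep referring to the old class, so they need to be recreated as well.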