为什么火炬修剪实际上没有删除过滤器或权重?

Why doesn't torch pruning actually remove filters or weights?

我使用一种架构并尝试通过修剪来稀疏它。我写了修剪函数,这是其中之一:

def prune_model_l1_unstructured(model, layer_type, proportion):
    for module in model.modules():
        if isinstance(module, layer_type):
            prune.l1_unstructured(module, 'weight', proportion)
            prune.remove(module, 'weight')
    return model

# prune model
prune_model_l1_unstructured(model, nn.Conv2d, 0.5)

它会修剪一些权重(将它们更改为零)。但是 prune.remove 只是删除原始权重并保留零。参数总量仍然相同(我检查过)。模型的文件 (model.pt) 大小也一样。并且模型的“速度”在它之后仍然相同。我也尝试了全局剪枝和结构化 L1 剪枝,结果是一样的。那么这如何帮助改善模型的性能时间呢?为什么不删除权重以及如何删除修剪的连接?

TLDR; PyTorch prune 的功能只是作为一个权重掩码,仅此而已。使用 torch.nn.utils.prune.
不会节省内存

正如 torch.nn.utils.prune.remove 的文档页面所述:

Removes the pruning reparameterization from a module and the pruning method from the forward hook.

实际上,这意味着它将从参数中删除 prune.l1_unstructured 添加的掩码。作为副作用,删除修剪将意味着在先前屏蔽的值上有零,但这些值不会保留为 0s。最后,与不使用相比,PyTorch prune 只会占用更多内存。所以这实际上不是您正在寻找的功能。

您可以在 this comment 上阅读更多内容。


这是一个例子:

>>> module = nn.Linear(10,3)
>>> prune.l1_unstructured(module, name='weight', amount=0.3)

屏蔽了权重参数:

>>> module.weight
tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
          0.0401,  0.1098],
        [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
          0.0764, -0.2569],
        [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
          0.1616,  0.1095]], grad_fn=<MulBackward0>)

这是面具:

>>> module.weight_mask
tensor([[0., 0., 1., 1., 0., 0., 0., 1., 1., 1.],
        [1., 0., 1., 1., 0., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 0., 1., 1., 1., 1.]])

请注意,在应用 prune.remove 时,删除了修剪。并且,掩码值保持为零,但 "unfrozen":

>>> prune.remove(module, 'weight')

>>> module.weight
tensor([[-0.0000,  0.0000, -0.1397, -0.0942,  0.0000,  0.0000,  0.0000, -0.1452,
          0.0401,  0.1098],
        [ 0.2909, -0.0000,  0.2871,  0.1725,  0.0000,  0.0587,  0.0795, -0.1253,
          0.0764, -0.2569],
        [ 0.0000, -0.3054, -0.2722,  0.2414,  0.1737, -0.0000, -0.2825,  0.0685,
          0.1616,  0.1095]], grad_fn=<MulBackward0>)

面具不见了:

>>> hasattr(module, 'weight_mask')
False