PyTorch torch.no_grad() 与 requires_grad=False

PyTorch torch.no_grad() versus requires_grad=False

我正在关注 PyTorch tutorial,它使用来自 Huggingface T运行sformers 库的 BERT NLP 模型(特征提取器)。梯度更新有两段相互关联的代码没看懂

(1) torch.no_grad()

本教程有一个 class,其中 forward() 函数围绕对 BERT 特征提取器的调用创建了一个 torch.no_grad() 块,如下所示:

bert = BertModel.from_pretrained('bert-base-uncased')

class BERTGRUSentiment(nn.Module):
    
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        
    def forward(self, text):
        with torch.no_grad():
            embedded = self.bert(text)[0]

(2) param.requires_grad = False

同一教程中还有另一部分冻结了 BERT 参数。

for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

我什么时候需要 (1) and/or (2)?

此外,我 运行 所有四种组合并发现:

   with torch.no_grad   requires_grad = False  Parameters  Ran
   ------------------   ---------------------  ----------  ---
a. Yes                  Yes                      3M        Successfully
b. Yes                  No                     112M        Successfully
c. No                   Yes                      3M        Successfully
d. No                   No                     112M        CUDA out of memory

有人可以解释一下这是怎么回事吗? 为什么我得到 CUDA out of memory 的 (d) 而不是 (b)?两者都有 112M 可学习参数。

这是一个较早的讨论,多年来略有变化(主要是由于 with torch.no_grad() 作为一种模式的目的。可以找到一个很好的答案来回答你的问题 .
但是,由于原始问题有很大不同,我不会将其标记为重复,特别是由于第二部分是关于内存的。

给出no_grad的初步解释here:

with torch.no_grad() is a context manager and is used to prevent calculating gradients [...].

requires_grad另一方面使用

to freeze part of your model and train the rest [...].

再次来源.

本质上,使用 requires_grad 你只是禁用网络的一部分,而 no_grad 根本不会存储 any 梯度,因为你是可能将其用于推理而不是训练。
为了分析您的参数组合的行为,让我们调查正在发生的事情:

  • a)b) 根本不存储任何梯度,这意味着无论参数数量多少,您都可以使用更多的内存,因为您没有保留它们潜在的向后传球。
  • c) 必须存储前向传播以供以后反向传播,但是,只存储了有限数量的参数(300 万),这使得这仍然是可管理的。
  • d),然而,需要为所有 1.12 亿 参数存储正向传递 ,这会导致你 运行 内存不足。