Question about the g attribute in fast.ai course lesson 8

In lesson 8 of the fast.ai 2019 course, the backward pass uses a strange g attribute, and I checked that torch.Tensor has no such attribute. When I try to print the value of inp.g/out.g in the __call__ method I get AttributeError: 'Tensor' object has no attribute 'g', yet in backward I can read inp.g/out.g before any assignment to them. How does this g attribute work?
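
The behavior is easy to reproduce outside the course code. A minimal sketch, assuming only that PyTorch is installed (the tensor shape is made up):

import torch

t = torch.randn(3)
try:
    print(t.g)            # fails: Tensor has no 'g' attribute out of the box
except AttributeError as e:
    print(e)              # 'Tensor' object has no attribute 'g'

t.g = torch.zeros(3)      # yet plain Python attribute assignment succeeds
print(t.g)                # and afterwards t.g reads back fine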

class Lin():
    def __init__(self, w, b):
        self.w, self.b = w, b

    def __call__(self, inp):
        print('in lin call')
        self.inp = inp
        self.out = inp@self.w + self.b
        try:
            print('out.g', self.out.g)
        except Exception as e:
            print('out.g does not exist yet')
        return self.out

    def backward(self):
        print('out.g', self.out.g)
        # out = inp @ w + b, so by the chain rule:
        self.inp.g = self.out.g @ self.w.t()                                  # dL/dinp = dL/dout @ w.T
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)  # dL/dw   = inp.T @ dL/dout
        self.b.g = self.out.g.sum(0)                                          # dL/db   = dL/dout summed over the batch
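
These formulas can be sanity-checked against autograd. A minimal sketch, assuming only PyTorch (all shapes are made up for the check):

import torch

inp = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 2, requires_grad=True)
b = torch.randn(2, requires_grad=True)

out = inp @ w + b
out_g = torch.ones_like(out)   # stand-in for the upstream gradient out.g
out.backward(out_g)

print(torch.allclose(inp.grad, out_g @ w.t()))                                  # True
print(torch.allclose(w.grad, (inp.unsqueeze(-1) * out_g.unsqueeze(1)).sum(0)))  # True
print(torch.allclose(b.grad, out_g.sum(0)))                                     # True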

link to full code from the course

- Update -

I was able to figure out that the self.out.g value is exactly the same as self.inp.g of the cost function Mse, but I still couldn't work out how that value gets passed on to the last linear layer.

class Mse():
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze() - targ).pow(2).mean()
        return self.out

    def backward(self):
        # loss = ((inp.squeeze() - targ) ** 2).mean(), so dloss/dinp = 2 * (inp - targ) / N
        self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) \
                        / self.targ.shape[0]
        print('in mse backward', self.inp.g)
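
The same kind of autograd check works here (again a sketch with made-up shapes):

import torch

pred = torch.randn(5, 1, requires_grad=True)
targ = torch.randn(5)

loss = (pred.squeeze() - targ).pow(2).mean()
loss.backward()

manual_g = 2. * (pred.squeeze() - targ).unsqueeze(-1) / targ.shape[0]
print(torch.allclose(pred.grad, manual_g))  # True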

class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1, b1), Relu(), Lin(w2, b2)]  # Relu is defined in the full course code linked above
        self.loss = Mse()

    def __call__(self, x, targ):
        for l in self.layers:
            x = l(x)
        return self.loss(x, targ)

    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers):
            l.backward()

Basically this comes down to how Python assignment works: names are references to objects, much like C pointers. After tracing the variables with id(variable_name), I was able to figure out where the g attribute comes from.

# ... in Model.__call__ (forward pass) ...
    x = layer(x)  # the linear layer returns self.out, and that same object is bound to x

# ...
    return self.loss(x, targ)  # x is still the same object (same id) the last layer returned

# ========

# ... in Model.backward (backward pass) ...
    self.loss.backward()  # this is where the loss sets .g on its input, self.inp

# ... in Lin.backward ...
    self.inp.g = self.out.g @ self.w.t()
    # self.out here IS the loss's self.inp (one object, two names),
    # so self.out.g is exactly the g that the loss's backward just set
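
To make the aliasing visible end to end, here is a runnable sketch using the classes above (the Relu stand-in and all shapes are my own assumptions, not the course code):

import torch

class Relu():  # minimal stand-in so the demo runs; the course version differs
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)
        return self.out
    def backward(self):
        self.inp.g = (self.inp > 0).float() * self.out.g

w1, b1 = torch.randn(3, 4), torch.zeros(4)
w2, b2 = torch.randn(4, 1), torch.zeros(1)
model = Model(w1, b1, w2, b2)

x, targ = torch.randn(5, 3), torch.randn(5)
loss = model(x, targ)

last_lin = model.layers[-1]
print(model.loss.inp is last_lin.out)      # True: one tensor, two names

model.loss.backward()                      # Mse sets .g on its inp ...
print(last_lin.out.g is model.loss.inp.g)  # True: ... and that g is visible as last_lin.out.g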