Python 继承和关键字参数

Python inheritance and keyword arguments

我正在为 pytorch 变压器编写一个包装器。为了简单起见,我将包括一个最小的例子。一个 class Parent,这将是 class for classes My_BERT(Parent)[= 的抽象 class 41=] 和 My_GPT2(父级)。因为model_Bertmodel_gpt2的LM模型都包含在pytorch中,它们有很多类似的功能,因此我想通过在 Partent.

中编码其他相同的函数来最小化代码冗余

My_bertMy_gpt2 与模型初始化基本不同,一个参数传递给模型,但是 99% 的函数以相同的方式使用这两种模型。

问题在于接受不同参数的函数 "model":

  • 对于model_Bert它被定义为模型(input_ids,masekd_lm_labels)
  • 对于model_gpt2它被定义为model(input_ids, labels)

最小代码示例:

class Parent():
    """ My own class that is an abstract class for My_bert and My_gpt2 """
    def __init__(self):
        pass

    def fancy_arithmetic(self, text):
        print("do_fancy_stuff_that_works_identically_for_both_models(text=text)")

    def compute_model(self, text):
        return self.model(input_ids=text, masked_lm_labels=text) #this line works for My_Bert
        #return self.model(input_ids=text, labels=text) #I'd need this line for My_gpt2

class My_bert(Parent): 
    """ My own My_bert class that is initialized with BERT pytorch 
    model (here model_bert), and uses methods from Parent """
    def __init__(self):
        self.model = model_bert()

class My_gpt2(Parent):
    """ My own My_gpt2 class that is initialized with gpt2 pytorch model (here model_gpt2), and uses methods from Parent """
    def __init__(self):
        self.model = model_gpt2()

class model_gpt2:
    """ This class mocks pytorch transformers gpt2 model, thus I'm writing just bunch of code that allows you run this example"""
    def __init__(self):
        pass

    def __call__(self,*input, **kwargs):
        return self.model( *input, **kwargs)

    def model(self, input_ids, labels):
        print("gpt2")

class model_bert:
    """ This class mocks pytorch transformers bert model"""
    def __init__(self):
        pass

    def __call__(self, *input, **kwargs):
        self.model(*input, **kwargs)

    def model(self, input_ids, masked_lm_labels):
        print("bert")


foo = My_bert()
foo.compute_model("bar")  # this works
bar = My_gpt2()
#bar.compute_model("rawr") #this does not work.

我知道我可以覆盖 My_bertMy_gpt2 中的 Parent::compute_model 函数 classes.

但是由于两个 "model" 方法非常相似,我想知道是否有办法说: " 我将传递给你三个参数,你可以使用你知道的那些"

def compute_model(self, text):
    return self.model(input_ids=text, masked_lm_labels=text, labels=text) # ignore the arguments you dont know

*args**kwargs 应该解决您 运行 遇到的问题。

在您的代码中,您将修改 compute_model 以采用任意参数

def compute_model(self, *args, **kwargs):
    return self.model(*args, **kwargs)

现在参数将由 model 方法在不同的 类

上定义

通过此更改,以下内容应该起作用:

foo = My_bert()
foo.compute_model("bar", "baz")
bar = My_gpt2()
bar.compute_model("rawr", "baz")

如果您不熟悉 args 和 kwargs,它们允许您将任意参数传递给函数。 args 将采用未命名参数并按照它们被接收到函数的顺序传递它们 kwargs 或关键字参数采用命名参数并将它们传递给正确的参数。因此以下内容也将起作用:

foo = My_bert()
foo.compute_model(input_ids="bar", masked_lm_labels="baz") 
bar = My_gpt2()
bar.compute_model(input_ids="rawr", labels="baz") 

注意名称 args 和 kwargs 是没有意义的,你可以给它们起任何名字,但典型的约定是 args 和 kwargs