如何用 Python 数据类覆盖 OOP 继承中的属性

How to overwrite attributes in OOP inheritance with Python dataclass

目前,我有一些代码看起来像这样,删除了不相关的方法。

import math
import numpy as np
from decimal import Decimal
from dataclasses import dataclass, field
from typing import Optional, List

@dataclass
class A:
    S0: int
    K: int
    r: float = 0.05
    T: int = 1
    N: int = 2
    StockTrees: List[float] = field(init=False, default_factory=list)
    pu: Optional[float] = 0
    pd: Optional[float] = 0
    div: Optional[float] = 0
    sigma: Optional[float] = 0
    is_put: Optional[bool] = field(default=False)
    is_american: Optional[bool] = field(default=False)
    is_call: Optional[bool] = field(init=False)
    is_european: Optional[bool] = field(init=False)
    
    def __post_init__(self):
        self.is_call = not self.is_put
        self.is_european = not self.is_american
        
    @property
    def dt(self):
        return self.T/float(self.N)
    
    @property
    def df(self):
        return math.exp(-(self.r - self.div) * self.dt)

@dataclass
class B(A):

    u: float = field(init=False)
    d: float = field(init=False)
    qu: float = field(init=False)
    qd: float = field(init=False)
    
    def __post_init__(self):
        super().__post_init__()
        self.u = 1 + self.pu
        self.d = 1 - self.pd
        self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu
    
    
@dataclass
class C(B):
    def __post_init__(self):
        super().__post_init__()
        self.u = math.exp(self.sigma * math.sqrt(self.dt))
        self.d = 1/self.u
        self.qu = (math.exp((self.r - self.div)*self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu

基本上,我有一个 class A ,它定义了它的所有子 classes 将共享的一些属性,所以它只是真正意味着通过实例化来初始化它的子 classes 及其属性将由其子 classes 继承。子 class B 是一个进行某些计算的过程,该计算由 C 继承,后者进行相同计算的变体。 C基本上继承了B的所有方法,唯一不同的是self.uself.d的计算方式不同。

可以通过使用需要参数 pupdB 计算或需要参数 [=25] 的 C 计算来 运行 代码=],如下

if __name__ == "__main__":
    
    am_option = B(50, 52, r=0.05, T=2, N=2, pu=0.2, pd=0.2, is_put=True, is_american=True)
    print(f"{am_option.sigma = }")
    print(f"{am_option.pu = }")
    print(f"{am_option.pd = }")
    print(f"{am_option.qu = }")
    print(f"{am_option.qd = }")
    
    eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
    print(f"{am_option.sigma = }")
    print(f"{am_option.pu = }")
    print(f"{am_option.pd = }")
    print(f"{am_option.qu = }")
    print(f"{am_option.qd = }")

给出输出

am_option.pu = 0.2
am_option.pd = 0.2
am_option.qu = 0.6281777409400603
am_option.qd = 0.3718222590599397
Traceback (most recent call last):
  File "/home/dazza/option_pricer/test.py", line 136, in <module>
    eu_option2 = C(50, 52, r=0.05, T=2, N=2, sigma=0.3, is_put=True)
  File "<string>", line 15, in __init__
  File "/home/dazza/option_pricer/test.py", line 109, in __post_init__
    super().__post_init__()
  File "/home/dazza/option_pricer/test.py", line 55, in __post_init__
    self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
ZeroDivisionError: float division by zero

所以实例化 B 工作正常,因为它成功计算了值 pupdquqd。但是,当 C 的实例化无法计算 qu 时,我的问题就来了,因为 pupd 默认情况下为零,使其除以 0.

我的问题:如何解决这个问题,以便 C 继承 A 的所有属性初始化(包括 __post_init__)和 B 的所有方法,以及同时有它的计算self.u = math.exp(self.sigma * math.sqrt(self.dt))self.d = 1/self.u覆盖self.u = 1 + self.puself.d = 1 - self.pdB,以及保持 self.quself.qd 相同?(BC 相同)

Python 支持多重继承。您可以在 B 之前从 A 继承,这意味着任何重叠的方法都将从 A 中获取(例如 __post_init__)。您在 class C 中编写的任何代码都将覆盖从 AB 继承的代码。如果您需要更好地控制哪些方法来自哪些 class,您始终可以在 C 中定义方法并调用 AB 的函数(例如A.dt(self)).

class C(A, B):
    ...

另一项编辑: 我刚刚看到 AC 中初始化了一些你想要的东西。因为 C 的父级现在是 A(如果你使用我上面的代码),你可以在 super().__post_init__() 行中添加回 C__post_init__以便它调用 A__post_init__。 如果这不起作用,您可以随时将 A.__post_init__(self) 放在 C__post_init__ 中。

定义另一个方法来初始化 ud,这样你就可以覆盖 Bthat 部分而不用覆盖 quqd 已定义。

@dataclass
class B(A):

    u: float = field(init=False)
    d: float = field(init=False)
    qu: float = field(init=False)
    qd: float = field(init=False)
    
    def __post_init__(self):
        super().__post_init__()
        self._define_u_and_d()
        self.qu = (math.exp((self.r - self.div) * self.dt) - self.d)/(self.u - self.d)
        self.qd = 1 - self.qu

    def _define_u_and_d(self):
        self.u = 1 + self.pu
        self.d = 1 - self.pd



@dataclass
class C(B):
    def _define_u_and_d(self):
        self.u = math.exp(self.sigma * math.sqrt(self.dt))
        self.d = 1/self.u