调用在基础 class 上定义的函数时,如何获取 subclass return 类型?

How can I to get subclass return types when calling functions defined on the base class?

我正在尝试在 Python 中编写一个 class 层次结构,以便子 classes 可以覆盖方法 predict 以获得更窄的 return 类型本身就是父 return 类型的子 class 类型。当我实例化 subclass 的实例并调用 predict 时,这似乎工作正常; returned 值具有预期的窄类型。但是,当我调用在基 class (predict_batch) 上定义的另一个函数时,它本身调用 predict,窄 return 类型丢失。

一些上下文:我的程序必须支持使用两种类型的图像分割模型,“实例”和“语义”。这两个模型的输出非常不同,所以我想用对称的 class 层次结构来存储它们的输出(即 BaseResultInstResultSemResult)。这将允许一些客户端代码在不需要知道使用了哪种特定类型的模型时通过使用 BaseResults 来通用。

这是一个玩具代码示例:

from abc import ABC, abstractmethod
from typing import List

from overrides import overrides

##################
# Result classes #
##################


class BaseResult(ABC):
    """Abstract container class for result of image segmentation"""

    pass


class InstResult(BaseResult):
    """Stores the result of instance segmentation"""

    pass


class SemResult(BaseResult):
    """Stores the result of semantic segmentation"""

    pass


#################
# Model classes #
#################


class BaseModel(ABC):
    def predict_batch(self, images: List) -> List[BaseResult]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> BaseResult:
        raise NotImplementedError()


class InstanceSegModel(BaseModel):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()


########
# main #
########

# placeholder for illustration 
images = [None, None, None]

model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[BaseResult]

在上面的代码中,我希望 batch_result 的类型为 List[InstResult]

在运行时,none 这很重要,我的代码执行得很好。但是我的编辑器 (VS Code) 中的静态类型检查器 (Pylance) 不喜欢客户端代码如何假定 batch_result 是更窄的类型。我只能想到这两种可能的解决方案,但对我来说都不干净:

  1. 使用 typing 模块中的 cast 函数
  2. 覆盖子class中的predict_batch,即使逻辑没有改变

您可以将 generics 和继承一起用于 override/narrow 父项中的注释 class

from typing import List, Generic, TypeVar

T = TypeVar('T')


class BaseModel(ABC, Generic[T]):
    def predict_batch(self, images: List) -> List[T]:
        return [self.predict(img) for img in images]

    @abstractmethod
    def predict(self, image) -> T:
        raise NotImplementedError()


class InstanceSegModel(BaseModel[InstResult]):
    """performs instance segmentation on images"""

    @overrides
    def predict(self, image) -> InstResult:
        return InstResult()


class SemanticSegModel(BaseModel[SemResult]):
    """performs semantic segmentation on images"""

    @overrides
    def predict(self, image) -> SemResult:
        return SemResult()


images = [None, None, None]

model = InstanceSegModel()
single_result = model.predict(images[0])  # has type InstResult
batch_result = model.predict_batch(images)  # has type List[InstResult]