__eq__ Python 中的命令执行

__eq__ order enforcement in Python

稍微长一点的问题,足以说明背景...

假设有一个内置的 class A:

class A:
    def __init__(self, a=None):
        self.a = a
    def __eq__(self, other):
        return self.a == other.a

预计这样比较:

a1, a2 = A(1), A(2)
a1 == a2  # False

出于某种原因,该团队在其之上引入了一个包装器(该代码示例实际上并未包装 A 以简化代码复杂性。

class WrapperA:
    def __init__(self, a=None):
        self.pa = a
    def __eq__(self, other):
        return self.pa == other.pa

同样,预计这样比较:

wa1, wa2 = WrapperA(1), WrapperA(2)
wa1 == wa2  # False

虽然预期使用 AWrapperA,但问题是某些代码库包含这两种用法,因此以下比较失败:

a, wa = A(), WrapperA()
wa == a  # AttributeError
a == wa  # AttributeError

一个已知的解决方案是修改__eq__:

对于wa == a

class WrapperA:
    def __init__(self, a=None):
        self.pa = a
    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

对于a == wa

class A:
    def __init__(self, a=None):
        self.a = a
    def __eq__(self, other):
        if isinstance(other, WrapperA):
            return self.a == other.pa
        return self.a == other.a

需要修改 WrapperA。对于A,因为是内置的东西,所以有两种解决方案:

  1. 使用 setattr 扩展 A 以支持 WrapperA。
setattr(A, '__eq__', eq_that_supports_WrapperA)
  1. 强制开发人员只比较 wa == a(然后不关心 a == wa)。

第一个选项显然是重复实现的丑陋,第二个选项给开发人员带来了不必要的“惊喜”。所以我的问题是,是否有一种优雅的方法可以在内部通过 Python 实现将 a == wa 的任何用法替换为 wa == a

我不太喜欢这整件事,因为我认为包装内置函数并使用不同的属性名称会导致意想不到的结果,但无论如何,这对你有用

import inspect


class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        return self.a == other.a


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        return self.pa == other.pa

    def __getattribute__(self, item):
        # Figure out who tried to get the attribute
        # If the item requested was 'a', check if A's __eq__ method called us,
        # in that case return pa instead
        caller = inspect.stack()[1]
        if item == 'a' and getattr(caller, 'function') == '__eq__' and isinstance(caller.frame.f_locals.get('self'), A):
            return super(WrapperA, self).__getattribute__('pa')
        return super(WrapperA, self).__getattribute__(item)

a = A(5)
wrap_a = WrapperA(5)

print(a == wrap_a)
print(wrap_a == a)

wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')

输出:

True
True
False
False
wrap_a.pa=7

引用MisterMiyagi在问题下的评论:

Note that == is generally expected to work across all types. A.__eq__ requiring other to be an A is actually a bug that should be fixed. It should at the very least return NotImplemented when it cannot make a decision

这很重要,而不仅仅是风格问题。事实上,根据 the documentation:

When a binary (or in-place) method returns NotImplemented the interpreter will try the reflected operation on the other type.

因此,如果您只是应用 MisterMiyagi 的评论并修复 __eq__ 的逻辑,您会发现您的代码已经可以正常工作了:

class A:
    def __init__(self, a=None):
        self.a = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.a == other.a
        return NotImplemented


class WrapperA:
    def __init__(self, a=None):
        self.pa = a

    def __eq__(self, other):
        if isinstance(other, A):
            return self.pa == other.a
        elif isinstance(other, WrapperA):
            return self.pa == other.pa
        return NotImplemented

# Trying it
a = A(5)
wrap_a = WrapperA(5)

print(a == wrap_a)
print(wrap_a == a)

wrap_a.pa = 7
print(a == wrap_a)
print(wrap_a == a)
print(f'{wrap_a.pa=}')

产量:

True
True
False
False
wrap_a.pa=7

在幕后,a == wrap_a 首先调用 A.__eq__,returns NotImplemented。 Python 然后自动尝试 WrapperA.__eq__

类似于 Ron Serruya 的回答:

这使用 __getattr__ 而不是 __getattribute__,其中第一个仅在第二个引发 AttributeError 或显式调用它时调用它 (ref)。这意味着如果包装器没有实现 __eq__ 并且等式应该 对底层数据结构执行(存储在 class A 的对象中), 一个工作示例由:

class A(object):
  def __init__(self, internal_data=None):
    self._internal_data = internal_data

  def __eq__(self, other):
    return self._internal_data == other._internal_data

class WrapperA(object):
  def __init__(self, a_object: A):
    self._a = a_object

  def __getattr__(self, attribute):
    if attribute != '_a':  # This is neccessary to prevent recursive calls
      return getattr(self._a, attribute)

a1 = A(internal_data=1)
a2 = A(internal_data=2)

wa1 = WrapperA(a1)
wa2 = WrapperA(a2)    

print(
    a1 == a1,
    a1 == a2,
    wa1 == wa1,
    a1 == wa1,
    a2 == wa2,
    wa1 == a1)

>>> True False True True True True