如何在Python2中实现中缀算子矩阵乘法?

How to implement infix operator matrix multiplication in Python 2?

我已经使用 python 3.7 几个月了,但最近不得不转向 python 2.7。由于我正在开发科学代码,因此我严重依赖使用中缀运算符 @ 来乘以 nd 数组。此运算符是在 python 3.5 中引入的(请参阅 here),因此,我无法在我的新设置中使用它。

显而易见的解决方案是将所有 M1 @ M2 替换为 numpy.matmul(M1, M2),这严重限制了我的代码的可读性。

我看到了这个 hack,它定义了一个中缀 class,允许通过重载 orror 运算符来创建自定义运算符。我的问题是:如何使用这个技巧使中缀 |at| 运算符像 @ 一样工作?

我试过的是:

import numpy as np

class Infix:
    def __init__(self, function):
        self.function = function
    def __ror__(self, other):      
        return Infix(lambda x, self=self, other=other: self.function(other, x))
    def __or__(self, other):
        return self.function(other)
    def __call__(self, value1, value2):
        return self.function(value1, value2)

# Matrix multiplication
at = Infix(lambda x,y: np.matmul(x,y))

M1 = np.ones((2,3))
M2 = np.ones((3,4))

print(M1 |at| M2)

当我执行这段代码时,我得到:

ValueError: operands could not be broadcast together with shapes (2,3) (3,4) 

我想我知道什么是行不通的。当我只看M1|at时,我可以看到它是一个2*3的函数数组:

array([[<__main__.Infix object at 0x7faa1c0d6da0>,
        <__main__.Infix object at 0x7faa1c0d6860>,
        <__main__.Infix object at 0x7faa1c0d6828>],
       [<__main__.Infix object at 0x7faa1c0d6f60>,
        <__main__.Infix object at 0x7faa1c0d61d0>,
        <__main__.Infix object at 0x7faa1c0d64e0>]], dtype=object)

这不是我所期望的,因为我希望我的代码将这个二维数组视为一个整体,而不是逐个元素...

有人知道我应该做什么吗?

PS:我也考虑过使用this answer,但我必须避免使用外部模块。

我找到了问题的解决方案 here

正如评论中所建议的那样,理想的解决方法是使用 Python 3.x 或使用 numpy.matmul,但这段代码似乎有效,甚至具有正确的优先级:

import numpy as np

class Infix(np.ndarray):
    def __new__(cls, function):
        obj = np.ndarray.__new__(cls, 0)
        obj.function = function
        return obj
    def __array_finalize__(self, obj):
        if obj is None: return
        self.function = getattr(obj, 'function', None)
    def __rmul__(self, other):
        return Infix(lambda x, self=self, other=other: self.function(other, x))
    def __mul__(self, other):
        return self.function(other)
    def __call__(self, value1, value2):
        return self.function(value1, value2)

at = Infix(np.matmul)

M1 = np.ones((2,3))
M2 = np.ones((3,4))
M3 = np.ones((2,4))

print(M1 *at* M2)
print(M3 + M1 *at* M2)