为 numpy 数组重载“==”运算符

Overloading "==" operator for numpy arrays

我正在 Python 中定义一个需要检查的函数

if a==b:
  do.stuff()

原则上,ab 可以是 numpy 数组或整数,我希望我的实现能够对此进行健壮。但是,要检查 numpy 数组的相等性,需要使用 all() 附加布尔值,这将在 ab 为整数时破坏代码。

是否有一种简单的方法来编写相等性测试的代码,以便无论 ab 是整数还是 numpy 数组,它都能正常工作?

这个对数组和整数(数字)都有效的方法怎么样:

if np.array_equal(a,b):
    do.stuff()