张量行矩阵的点积

Dot Product of Tensor row matrices

我有一个张量,我需要以矢量化方式对其所有行矩阵进行点积:
a = np.zeros((3,4,2,2))+1
这是一个 3x4 张量,元素是 2x2 矩阵。我需要对每行中的 2x2 矩阵进行点积。
结果应该是一个 3x1 矩阵,其中包含一个 2x2 矩阵,其中填充了 8s
我试过了

a = np.zeros((3,4,2,2))+1
np.prod(a, axis= 1) 

但它只给出元素乘积:

array([[[1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.]],

       [[1., 1.],
        [1., 1.]]])

我需要一个向量化函数,而不是 for 循环。 如果有人有使用 NumPy 或 Scipy 的解决方案,我将不胜感激,因为 TensorFlow 是一个需要包含的巨大依赖项。

怎么样

def np_multi_matmul(tensor: np.ndarray, axis: int) -> np.ndarray:
    arrays = np.split(tensor, tensor.shape[axis], axis = axis)
    return functools.reduce(lambda x, y: np.matmul(y, x), arrays)

编辑:首先,沿要减少的轴拆分数组。然后一次计算每两个矩阵的matmul。只要数组的其他维度相同,matmul 将忽略除最后两个维度之外的所有维度并计算它们的矩阵乘法结果。