指定轴的 numpy `diagflat`

numpy `diagflat` for specified axis

我想将 ndarray 的轴扩展为 diagflat

的对角矩阵

例如

In [ ]: import numpy as np

In [ ]: example = np.random.random((200, 5))
In [ ]: example.shape
Out[ ]: (200, 5)

我正在寻找的是:

In [ ]: np.diagflat(example, axis=-1).shape
Out[ ]: (200, 5, 5)

diagflat 但是有 axis 参数。例如,我的想法是简单地插入一个新轴并将其与单位矩阵相乘。

In [ ]: Id = np.eye(example.shape[-1])
In [ ]: (example[..., np.newaxis] @ Id).shape
ValueError: shapes (200,5,1) and (5,5) not aligned: 1 (dim 2) != 5 (dim 0)

但这会引发一个错误,显然广播不适用于矩阵乘法。有没有优雅的解决方案,还是我必须手动创建和填充数组?

只需执行:

example[..., np.newaxis] * Id