Python numpy 从形状为 3,3 的输入矩阵计算出形状为 3,3,3 的矩阵

Python numpy computing out matrix with shape 3,3,3 from input matrecies with shape 3,3

我目前正在 python 中仅使用 numpy 构建神经网络。

这是问题区域的布局:

我有一个数组保存列中输入神经元的值,行代表不同的训练数据点。它的形状是 3, 3:

in_a
array([['t_1a_1', 't_1a_2', 't_1a_3'],
       ['t_2a_1', 't_2a_2', 't_2a_3'],
       ['t_3a_1', 't_3a_2', 't_3a_3']], dtype='<U6')

然后我有一个权重数组,其中列是到输出 1、2 和 3 的连接,行是从 1、2 和 3 开始的连接。它的形状也是 3, 3:

in_w
array([['w_11', 'w_12', 'w_13'],
       ['w_21', 'w_22', 'w_23'],
       ['w_31', 'w_32', 'w_33']], dtype='<U4')

现在我想计算一个形状为 3, 3, 3 的矩阵。它看起来如下:

out
array([[['t_1*a_1*w_11', 't_1*a_1*w_12', 't_1*a_1*w_13'],
        ['t_1*a_2*w_21', 't_1*a_2*w_22', 't_1*a_2*w_23'],
        ['t_1*a_3*w_31', 't_1*a_2*w_32', 't_1*a_2*w_33']],

       [['t_2*a_1*w_11', 't_2*a_1*w_12', 't_2*a_1*w_13'],
        ['t_2*a_2*w_21', 't_2*a_2*w_22', 't_2*a_2*w_23'],
        ['t_2*a_3*w_31', 't_2*a_2*w_32', 't_2*a_2*w_33']],

       [['t_3*a_1*w_11', 't_3*a_1*w_12', 't_3*a_1*w_13'],
        ['t_3*a_2*w_21', 't_3*a_2*w_22', 't_3*a_2*w_23'],
        ['t_3*a_3*w_31', 't_3*a_2*w_32', 't_3*a_3*w_33']]], dtype='<U12')

我尝试了 numpy.dot,简单的 * 乘法,@ 组合,但没有任何效果。我认为解决方案可能是 numpy.einsum 或 numpy.tensordot,但我无法理解它们。有谁知道如何根据输入矩阵计算输出矩阵或者可以推荐一种方法和解释?谢谢你的帮助

你只需要

in_a[...,None] * in_w

如果您考虑一下 in_a 的形状 (training_sets, input_neurons)in_w (input_neurons, output_neurons)。你的输出似乎是

的逐元素乘法
    (T, I)       # in_a
*      (I, O)    # in_w

让我们来演示一下吧

class Variable:
    def __init__(self, name):
        self.name = name

    def __mul__(self, other):
        if not isinstance(other, Variable):
            raise ValueError
        return Variable(f'{self.name}*{other.name}')

    def __repr__(self):
        return self.name


def generate_array(fmt, rows, columns):
    return np.array([[Variable(fmt.format(i, j)) for j in range(1, columns+1)] 
                     for i in range(1, rows+1)])

in_a = generate_array('t_{}a_{}', 3, 3)
in_w = generate_array('ww_{}{}', 3, 3)
print(in_a[...,None] * in_w)

打印

[[[t_1a_1*ww_11 t_1a_1*ww_12 t_1a_1*ww_13]
  [t_1a_2*ww_21 t_1a_2*ww_22 t_1a_2*ww_23]
  [t_1a_3*ww_31 t_1a_3*ww_32 t_1a_3*ww_33]]

 [[t_2a_1*ww_11 t_2a_1*ww_12 t_2a_1*ww_13]
  [t_2a_2*ww_21 t_2a_2*ww_22 t_2a_2*ww_23]
  [t_2a_3*ww_31 t_2a_3*ww_32 t_2a_3*ww_33]]

 [[t_3a_1*ww_11 t_3a_1*ww_12 t_3a_1*ww_13]
  [t_3a_2*ww_21 t_3a_2*ww_22 t_3a_2*ww_23]
  [t_3a_3*ww_31 t_3a_3*ww_32 t_3a_3*ww_33]]]