如何使用矩阵乘法在tensorflow1.10中实现covnert (Batch, a, b) tensor * (Batch, b) tensor into (Batch, a) tensor

how to use matrix multiplication to implement that covnert (Batch, a, b) tensor * (Batch, b) tensor into (Batch, a) tensor in tensorflow1.10

例如,

# Batch = 5, a = 25, b = 2
# tensor t1 shape: (Batch, a, b)
# tensor t2 shape: (Batch, b)
# tensor res shape: (Batch, a)
print(t1)
<tf.Tensor: id=466, shape=(2, 25, 2), dtype=int32, numpy=
array([[[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]],

       [[ 1, 26],
        [ 2, 27],
        [ 3, 28],
        [ 4, 29],
        [ 5, 30],
        [ 6, 31],
        [ 7, 32],
        [ 8, 33],
        [ 9, 34],
        [10, 35],
        [11, 36],
        [12, 37],
        [13, 38],
        [14, 39],
        [15, 40],
        [16, 41],
        [17, 42],
        [18, 43],
        [19, 44],
        [20, 45],
        [21, 46],
        [22, 47],
        [23, 48],
        [24, 49],
        [25, 50]]], dtype=int32)>

print(t2)
<tf.Tensor: id=410, shape=(2, 2), dtype=int32, numpy=
array([[1, 0],
       [1, 0]], dtype=int32)>

# after matrix multiplication
print(res)
<tf.Tensor: id=474, shape=(2, 25), dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25]], dtype=int32)>

我想的方法是像以前一样用矩阵乘法只保留一部分,但是我很难实现。
如果不介意谁能帮帮我?

开始于

import tensorflow as tf # 2.6.0
Batch, a, b = 5, 25, 2
t1 = tf.random.normal((Batch, a, b))
t2 = tf.random.normal((Batch, b));

如果我理解正确,你想要 t3[b,i] = sum(t2[b,i,j], t2[b,j]) 这可以使用 Einstein summation

以直接的方式描述
t3 = tf.einsum('bij,bj->bi', t1, t2)
assert t3.shape == (Batch, a)

这也可以用reduce_sum写成

t3 = tf.reduce_sum(t1 * t2[:,None,:], axis=2)
assert t3.shape == (Batch, a)

或者如果你想用矩阵乘法运算符来做它也是可能的,但在那种情况下你必须将 t2 解压缩到 (Batch, b, 1) (t2[:,:,None]), 并将结果从 (Batch, b, 1) 压缩回 (Batch, b) (t3[:,:,0]).

t3 = t1 @ t2[:,:,None]
assert t3[:,:,0].shape == (Batch, a)