带 Tensorflow 转置的矩阵乘法

Matrix multiplication with transpose with Tensorflow

尝试在 TF

中完成 MatrixMultiplication
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
tf.matmul(a1,a1,transpose_b=True)

这工作得很好,但是如果我 transpose 输入 a1 像下面这样手动,我会得到一个错误:

tf.matmul(a1,tf.transpose(a1))

错误:

InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [5,4,64] vs. [64,4,5] [Op:BatchMatMulV2]

文档:

transpose_b: If True, b is transposed before multiplication.

所以我不明白其中的区别,任何建议都会有所帮助。

要计算批量矩阵乘法,您需要确保 3D 张量具有以下格式。检查3-D张量矩阵乘法。

[batchs, h`, w`] @ [batchs, w``, h``]

所以,在你上面的例子中,它应该是

import tensorflow as tf

a1 = tf.constant(tf.random.normal(shape=(5,4,64)))

a1.shape, tf.transpose(a1, perm=[0, 2, 1]).shape
(TensorShape([5, 4, 64]), TensorShape([5, 64, 4]))

# swape the height and width - not batch axis 
tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1])).shape
TensorShape([5, 4, 4])