带 Tensorflow 转置的矩阵乘法
Matrix multiplication with transpose with Tensorflow
尝试在 TF
中完成 MatrixMultiplication
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
tf.matmul(a1,a1,transpose_b=True)
这工作得很好,但是如果我 transpose
输入 a1
像下面这样手动,我会得到一个错误:
tf.matmul(a1,tf.transpose(a1))
错误:
InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [5,4,64] vs. [64,4,5] [Op:BatchMatMulV2]
文档:
transpose_b: If True, b is transposed before multiplication.
所以我不明白其中的区别,任何建议都会有所帮助。
要计算批量矩阵乘法,您需要确保 3D 张量具有以下格式。检查3-D
张量矩阵乘法。
[batchs, h`, w`] @ [batchs, w``, h``]
所以,在你上面的例子中,它应该是
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
a1.shape, tf.transpose(a1, perm=[0, 2, 1]).shape
(TensorShape([5, 4, 64]), TensorShape([5, 64, 4]))
# swape the height and width - not batch axis
tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1])).shape
TensorShape([5, 4, 4])
尝试在 TF
MatrixMultiplication
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
tf.matmul(a1,a1,transpose_b=True)
这工作得很好,但是如果我 transpose
输入 a1
像下面这样手动,我会得到一个错误:
tf.matmul(a1,tf.transpose(a1))
错误:
InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [5,4,64] vs. [64,4,5] [Op:BatchMatMulV2]
文档:
transpose_b: If True, b is transposed before multiplication.
所以我不明白其中的区别,任何建议都会有所帮助。
要计算批量矩阵乘法,您需要确保 3D 张量具有以下格式。检查3-D
张量矩阵乘法。
[batchs, h`, w`] @ [batchs, w``, h``]
所以,在你上面的例子中,它应该是
import tensorflow as tf
a1 = tf.constant(tf.random.normal(shape=(5,4,64)))
a1.shape, tf.transpose(a1, perm=[0, 2, 1]).shape
(TensorShape([5, 4, 64]), TensorShape([5, 64, 4]))
# swape the height and width - not batch axis
tf.matmul(a1, tf.transpose(a1, perm=[0, 2, 1])).shape
TensorShape([5, 4, 4])