TensorFlow 中的双线性张量积

Bilinear Tensor Product in TensorFlow

我正在重新实现 this paper,关键操作是双线性张量积。我几乎不知道那是什么意思,但这篇论文有一个漂亮的小图形,我明白了。

关键操作是e_1 * W * e_2,我想知道在tensorflow中如何实现,因为剩下的应该很简单。

基本上,给定 3D 张量 W,将其切成矩阵,对于第 j 个切片(矩阵),在每一边乘以 e_1e_2,得到一个标量,它是结果向量中的第 j 个条目(此操作的输出)。

所以我想做一个e_1的乘积,一个d维向量,W,d x d x k张量,和e_2,另一个d维向量。这个产品是否可以像现在这样在 TensorFlow 中简洁地表达,或者我是否必须以某种方式定义自己的 op?

早期编辑

为什么这些张量相乘不起作用,是否有某种方法可以更明确地定义它以便它起作用?

>>> import tensorflow as tf
>>> tf.InteractiveSession()
>>> a = tf.ones([3, 3, 3])
>>> a.eval()
array([[[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]],

       [[ 1.,  1.,  1.],
        [ 1.,  1.,  1.],
        [ 1.,  1.,  1.]]], dtype=float32)
>>> b = tf.ones([3, 1, 1])
>>> b.eval()
array([[[ 1.]],

       [[ 1.]],

       [[ 1.]]], dtype=float32)
>>> 

错误信息是

ValueError: Shapes TensorShape([Dimension(3), Dimension(3), Dimension(3)]) and TensorShape([Dimension(None), Dimension(None)]) must have the same rank

目前

事实证明,将两个 3D 张量相乘对 tf.matmul 也不起作用,所以 tf.batch_matmul 可以。 tf.batch_matmul 也会做 3D 张量和矩阵。然后我尝试了 3D 和矢量:

ValueError: Dimensions Dimension(3) and Dimension(1) are not compatible

您可以通过简单的重塑来做到这一点。对于两个矩阵乘法中的第一个,您有 k*d 个长度为 d 的向量进行点积。

这应该很接近:

temp = tf.matmul(E1,tf.reshape(Wddk,[d,d*k]))
result = tf.matmul(E2,tf.reshape(temp,[d,k]))

您可以在 W 和 e2 之间执行 3 阶张量和向量乘法,生成二维数组,然后将结果与 e1 相乘。下面的函数利用张量积和张量收缩来定义这个乘积(例如 W * e3)

import sympy as sp

def tensor3_vector_product(T, v):
    """Implements a product of a rank 3 tensor (3D array) with a 
       vector using tensor product and tensor contraction.

    Parameters
    ----------
    T: sp.Array of dimensions n x m x k

    v: sp.Array of dimensions k x 1

    Returns
    -------
    A: sp.Array of dimensions n x m

    """
    assert(T.rank() == 3)
    # reshape v to ensure a 1D vector so that contraction do 
    # not contain x 1 dimension
    v.reshape(v.shape[0], )
    p = sp.tensorproduct(T, v)
    return sp.tensorcontraction(p, (2, 3))

您可以使用 ref 中提供的示例来验证此乘法。上面的函数收缩了第二个和第三个轴,在你的情况下我认为你应该收缩 (1, 2) 因为 W 被定义为 d x d x k 而不是我的情况下的 k x d x d。