如果我们使用索引矩阵,是否需要在 Theano 中使用 flatten 和 reshape?
Do we need to use flatten and reshape in Theano if we use a matrix of indexes?
我试图理解 Theano implementation of LSTM(目前 link 出于某种原因无法正常工作,但我希望它能尽快恢复)。
在代码中我看到以下部分:
emb = tparams['Wemb'][x.flatten()].reshape([n_timesteps,
n_samples,
options['dim_proj']])
为了做到"context independent"我重写了如下:
e = W[x.flatten()]].reshape([n1, n2, n3])
其中 x
的维度是 (n1, n2)
,W
的维度是 (N, n3)
。
因此,我的假设是可以重写代码以使其更短。特别是我们可以写:
e = W[x]
或者,如果我们使用原来的表示法应该是:
emb = tparams['Wemb'][x]
我说得对吗?
为了提供更多上下文,x
是一个包含表示单词的整数的二维数组(例如 27 表示 "word number 27")。我的记法中的W
(或者原记法中的tparams['Wemb']
)是一个二维矩阵,每一行对应一个词。因此,它是一个词嵌入矩阵 (Word2Vec),将每个词映射到一个实值向量。
是的,你是对的。
W[x.flatten()]]
为您提供由 x
的值定义的 W
的行(即单词)。所以结果是shape = (n1*n2,n3)
。我们称其为 "list of words"(不是 python 列表,只是一个普通的语音列表)。
然后重塑为您提供所需的大小,其中单词列表被细分为 n1
页 n2
个单词。
您可以使用 W[x]
实现相同的效果,因为 x
的 n2
行中的每一行都会为您提供结果的 n1
页之一。
这是一个示例程序,显示两个表达式是等价的:
import numpy as np
N = 4
n3 = 5
W = np.arange(n3*N).reshape((N,n3))
print("W = \n", W)
n1 = 2
n2 = 3
x = np.random.randint(low=0, high=N,size=(n1,n2))
print("\nx = \n", x)
print("\ne = \n", W[x.flatten()].reshape([n1, n2, n3]))
print("\nalternativeE = \n", W[x])
我试图理解 Theano implementation of LSTM(目前 link 出于某种原因无法正常工作,但我希望它能尽快恢复)。
在代码中我看到以下部分:
emb = tparams['Wemb'][x.flatten()].reshape([n_timesteps,
n_samples,
options['dim_proj']])
为了做到"context independent"我重写了如下:
e = W[x.flatten()]].reshape([n1, n2, n3])
其中 x
的维度是 (n1, n2)
,W
的维度是 (N, n3)
。
因此,我的假设是可以重写代码以使其更短。特别是我们可以写:
e = W[x]
或者,如果我们使用原来的表示法应该是:
emb = tparams['Wemb'][x]
我说得对吗?
为了提供更多上下文,x
是一个包含表示单词的整数的二维数组(例如 27 表示 "word number 27")。我的记法中的W
(或者原记法中的tparams['Wemb']
)是一个二维矩阵,每一行对应一个词。因此,它是一个词嵌入矩阵 (Word2Vec),将每个词映射到一个实值向量。
是的,你是对的。
W[x.flatten()]]
为您提供由 x
的值定义的 W
的行(即单词)。所以结果是shape = (n1*n2,n3)
。我们称其为 "list of words"(不是 python 列表,只是一个普通的语音列表)。
然后重塑为您提供所需的大小,其中单词列表被细分为 n1
页 n2
个单词。
您可以使用 W[x]
实现相同的效果,因为 x
的 n2
行中的每一行都会为您提供结果的 n1
页之一。
这是一个示例程序,显示两个表达式是等价的:
import numpy as np
N = 4
n3 = 5
W = np.arange(n3*N).reshape((N,n3))
print("W = \n", W)
n1 = 2
n2 = 3
x = np.random.randint(low=0, high=N,size=(n1,n2))
print("\nx = \n", x)
print("\ne = \n", W[x.flatten()].reshape([n1, n2, n3]))
print("\nalternativeE = \n", W[x])