在 tf.keras.layers.Embedding 中,为什么知道字典的大小很重要?
In tf.keras.layers.Embedding, why it is important to know the size of dictionary?
同题,在tf.keras.layers.Embedding中,为什么知道字典的大小作为输入维度很重要?
在这样的设置下,张量的dimensions/shapes如下:
- 输入张量的大小为
[batch_size, max_time_steps]
,因此该张量的每个元素的值都在 0 to vocab_size-1
. 范围内
- 然后,来自输入张量的每个值都通过具有形状
[vocab_size, embedding_size]
的嵌入层。嵌入层的输出形状为 [batch_size, max_time_steps, embedding_size]
.
- 那么,在典型的seq2seq场景中,这个
3D
张量就是循环神经网络的输入。
- ...
以下是 Tensorflow 中的实现方式,以便您更好地理解:
inputs = tf.placeholder(shape=(batch_size, max_time_steps), ...)
embeddings = tf.Variable(shape=(vocab_size, embedding_size], ...)
inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
现在,嵌入查找 table 的输出具有 [batch_size, max_time_steps, embedding_size]
形状。
因为在内部,嵌入层只不过是一个大小为 vocab_size x embedding_size
的矩阵。这是一个简单的查找 table:该矩阵的行 n
存储单词 n
.
的向量
所以,如果你有,例如1000 个不同的单词,您的嵌入层需要知道这个数字才能存储 1000 个向量(作为矩阵)。
不要将层的内部存储与其输入或输出形状混淆。
输入形状是 (batch_size, sequence_length)
,其中每个条目都是 [0, vocab_size[
范围内的整数。对于这些整数中的每一个,该层将 return 内部矩阵的相应行(这是一个大小为 embedding_size
的向量),因此输出形状变为 (batch_size, sequence_length, embedding_size)
.
同题,在tf.keras.layers.Embedding中,为什么知道字典的大小作为输入维度很重要?
在这样的设置下,张量的dimensions/shapes如下:
- 输入张量的大小为
[batch_size, max_time_steps]
,因此该张量的每个元素的值都在0 to vocab_size-1
. 范围内
- 然后,来自输入张量的每个值都通过具有形状
[vocab_size, embedding_size]
的嵌入层。嵌入层的输出形状为[batch_size, max_time_steps, embedding_size]
. - 那么,在典型的seq2seq场景中,这个
3D
张量就是循环神经网络的输入。 - ...
以下是 Tensorflow 中的实现方式,以便您更好地理解:
inputs = tf.placeholder(shape=(batch_size, max_time_steps), ...)
embeddings = tf.Variable(shape=(vocab_size, embedding_size], ...)
inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
现在,嵌入查找 table 的输出具有 [batch_size, max_time_steps, embedding_size]
形状。
因为在内部,嵌入层只不过是一个大小为 vocab_size x embedding_size
的矩阵。这是一个简单的查找 table:该矩阵的行 n
存储单词 n
.
所以,如果你有,例如1000 个不同的单词,您的嵌入层需要知道这个数字才能存储 1000 个向量(作为矩阵)。
不要将层的内部存储与其输入或输出形状混淆。
输入形状是 (batch_size, sequence_length)
,其中每个条目都是 [0, vocab_size[
范围内的整数。对于这些整数中的每一个,该层将 return 内部矩阵的相应行(这是一个大小为 embedding_size
的向量),因此输出形状变为 (batch_size, sequence_length, embedding_size)
.