Tensorflow 中的错误切片索引张量

Bad slice index Tensor in Tensorflow

这里是tensorflow中张量的定义:

word_weight   = tf.get_variable("word_weight", [word_num])
x_index = tf.placeholder(tf.int32, [None, sentence_length, 1])  

当我尝试时: word_weight[0]word_weight[1] 或其他,有效,我可以得到结果。但是当我尝试 word_weight[x_index[0,0,0]] 时,出现错误:

TypeError: Bad slice index Tensor("modle/RNN/Squeeze_1:0", shape=(), dtype=int32) of type <class 'tensorflow.python.framework.ops.Tensor'>

TensorFlow 对 Tensors 的下标运算符 (__getitem__) 的实现是 tf.slice 函数的语法糖。下标运算符实现支持 Python 整数、列表、元组和 slice 作为下标的类型。正如您所发现的,不支持 Tensor 本身作为下标。但是,您可以直接使用 tf.slice 函数来达到您的目的:

word_num = 100
sentence_length = 10
word_weight   = tf.get_variable("word_weight", [word_num])
x_index = tf.placeholder(tf.int32, [None, sentence_length, 1])  
ind = x_index[0, 0, 0:1]
_ = tf.slice(word_weight, ind, [1])