How should I append an element to each sequence with tf.data.Dataset?

I want tf.data.Dataset to produce sequence data with char2int['EOS'] appended to the end of each sequence. The code I wrote is as follows:

import tensorflow as tf 

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            yield [char2int[x] for x in text] # transform char to int
    return gen

def get_dataset(list_of_text, char2int):
    gen = _get_generator(list_of_text, char2int)
    dataset = tf.data.Dataset.from_generator(gen, (tf.int32), tf.TensorShape([None]))

    dataset = dataset.map(lambda seq: seq+[char2int['EOS']])  # append EOS to the end of line

    data_iter = dataset.make_initializable_iterator()

    return dataset, data_iter

char2int = {'EOS':1, 'a':2, 'b':3, 'c':4}
list_of_text = ['aaa', 'abc'] # the sequence data

with tf.Graph().as_default():
    dataset, data_iter = get_dataset(list_of_text, char2int)
    with tf.Session() as sess:
        sess.run(data_iter.initializer)
        tt1 = sess.run(data_iter.get_next())
        tt2 = sess.run(data_iter.get_next())
        print(tt1)  # got [3 3 3] but I want [2 2 2 1]
        print(tt2)  # got [3 4 5] but I want [2 3 4 1]

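But I don't get the result I want: the map performs element-wise addition on each sequence instead of appending. How can I fix this? Thanks.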

In your map function you are adding 1 (the value of char2int['EOS']) to every element rather than concatenating it. You can change _get_generator to:

def _get_generator(list_of_text, char2int):
    def gen():
        for text in list_of_text:
            # transform char to int, then append EOS as a Python list concatenation
            yield [char2int[x] for x in text] + [char2int['EOS']]
    return gen

and remove the dataset.map call.
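The reason this works is that inside the generator both operands are still plain Python lists, so + concatenates them before TensorFlow ever sees the data. A minimal sketch that checks this without TensorFlow (the name make_generator is hypothetical, used only for illustration):

```python
def make_generator(list_of_text, char2int):
    """Return a generator function yielding each text as ints followed by EOS."""
    def gen():
        for text in list_of_text:
            # Both operands are Python lists here, so + concatenates.
            yield [char2int[ch] for ch in text] + [char2int['EOS']]
    return gen


char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
sequences = list(make_generator(['aaa', 'abc'], char2int)())
print(sequences)  # [[2, 2, 2, 1], [2, 3, 4, 1]]
```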

As Vijay points out in his answer, inside Dataset.map() the + operator on a tf.Tensor of type tf.int32 performs addition rather than concatenation. To concatenate an additional symbol onto the end of the sequence, use tf.concat() instead:

dataset = dataset.map(lambda seq: tf.concat([seq, [char2int['EOS']]], axis=0))
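For readers on a newer TensorFlow, the same idea can be sketched end to end as below. This assumes TensorFlow 2.4 or later with eager execution (so the output_signature argument is available and no tf.Session is needed), which differs from the TF 1.x setup in the question. The names eos and result are introduced here only for illustration.

```python
import tensorflow as tf

char2int = {'EOS': 1, 'a': 2, 'b': 3, 'c': 4}
list_of_text = ['aaa', 'abc']


def gen():
    for text in list_of_text:
        yield [char2int[ch] for ch in text]  # transform char to int


# Assumption: TF >= 2.4, where from_generator accepts output_signature.
dataset = tf.data.Dataset.from_generator(
    gen,
    output_signature=tf.TensorSpec(shape=(None,), dtype=tf.int32),
)

# A rank-1 tensor holding the EOS id, so it can be concatenated along axis 0.
eos = tf.constant([char2int['EOS']], dtype=tf.int32)
dataset = dataset.map(lambda seq: tf.concat([seq, eos], axis=0))

# In eager mode the dataset can be iterated directly.
result = [seq.numpy().tolist() for seq in dataset]
print(result)
```

Wrapping the EOS id in a one-element tensor keeps both inputs to tf.concat at rank 1, which is what concatenation along axis 0 requires.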