DECODE_RAW TensorSliceDataset

DECODE_RAW the TensorSliceDataset

我正在复制 TTS 模型,Deep Voice 3。 数据集是 LJSpeech-1.1。我找到了一个 github 存储库 (https://github.com/Kyubyong/deepvoice3),但它是在我使用 TF 2.0 的早期 tensorflow 版本中编写的。 在数据处理中,我需要对 TensorSliceDataset 的输出应用 decode_raw 函数。 但是,我无法将 decode_raw 函数应用于输出。 所以,我的问题是如何将 decode_raw 应用于 TensorSliceDataset 的输出?

我已经将文本转换成维数为(13066,)的张量。 在原始 repo 中,他使用 tf.train.slice_input_producer。 对于 TF 2.0,我使用 tf.data.Dataset.from_tensor_slices 将该张量转换为 TensorSliceDataset。 之后,我无法将 decode_raw 应用于 TensorSliceDataset。下面是代码

# old TF code
texts, mels, dones, mags = tf.train.slice_input_producer([texts, mels, dones, mags], shuffle = True)
# TF 2.0 code
texts = tf.convert_to_tensor(texts)
texts = tf.data.Dataset.from_tensor_slices(texts)
texts = tf.io.decode_raw(texts, tf.int32) # (None,)

您需要对数据集对象应用解析函数。 而不是这一行

texts = tf.io.decode_raw(texts, tf.int32) # (None,)`

使用

texts = texts.map(lambda x: tf.io.decode_raw(x, tf.int32))