tf.io.decode_raw return张量如何使它成为字节或字符串

tf.io.decode_raw return tensor how to make it bytes or string

我为此苦苦挣扎了一段时间。我搜索了堆栈并检查了 tf2 医生很多次。指出了一种解决方案,但是 我不明白为什么我的解决方案不起作用。

在我的例子中,我在 tfrecords 中存储了一个二进制字符串(即字节)。 如果我通过 as_numpy_list 遍历数据集或直接调用 numpy() 在每个项目上,我都可以取回二进制字符串。 在迭代数据集时,它确实有效。

我不确定 map() 传递给 test_callback 的到底是什么。 我看到没有方法,也没有 属性 numpy,类型也一样 tf.io.decode_raw return。 (它是Tensor,但它也没有numpy)

基本上我需要一个二进制字符串,通过我的解析它 x = decoder.FromString(y) 然后把它传给我的编码器 这会将 x 二进制字符串转换为张量。

def test_callback(example_proto):

    # I tried to figure out. can I use bytes?decode 
    # directly and what is the most optimal solution.

    parsed_features = tf.io.decode_raw(example_proto, out_type=tf.uint8)
    # tf.io.decoder returns tensor with N bytes.

    x = creator.FromString(parsed_features.numpy)
    encoded_seq = midi_encoder.encode(x)
    return encoded_seq

raw_dataset = tf.data.TFRecordDataset(filenames=["main.tfrecord"])
raw_dataset = raw_dataset.map(test_callback)

谢谢大家。

我找到了一种解决方案,但我希望看到更多建议。

def test_callback(example_proto):
    from_string = creator.FromString(example_proto.numpy())
    encoded_seq = encoder.encoder(from_string)
    return encoded_seq

raw_dataset = tf.data.TFRecordDataset(filenames=["main.tfrecord"])
raw_dataset = raw_dataset.map(lambda x: tf.py_function(test_callback, [x], [tf.int64]))

我的理解是 tf.py_function 会降低性能。

谢谢