如何 tf.cast tensorflow 数据集中的字段

How to tf.cast a field within a tensorflow Dataset

我有一个 tf.data.Dataset 看起来像这样:

<BatchDataset shapes: ((None, 256, 256, 3), (None,)), types: (tf.float32, tf.int32)>

第二个元素(如果索引为零则为第一个)对应于一个标签。我想将第二项(标签)投射到 tf.uint8.

如何在处理td.data.Dataset时使用tf.cast


类似问题

非常相似,但不适用于 tf.data.Dataset


复制

来自Image classification from scratch

curl -O https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_5340.zip
unzip kagglecatsanddogs_5340.zip

然后在 Python 中 tensorflow~=2.4:

import tensorflow as tf

ds = tf.keras.preprocessing.image_dataset_from_directory(
    "PetImages", batch_size=32
)
print(ds)

map 函数可能有帮助

a = tf.data.Dataset.from_tensor_slices(np.empty((2,5,3)))
b = tf.data.Dataset.range(5, 8)
c = tf.data.Dataset.zip((a,b))
d = c.batch(1)
d
<BatchDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int64)>

# change the dtype of the 2nd arg in the batch from int64 to int8
e = d.map(lambda x,y:(x,tf.cast(y, tf.int8))) 
<MapDataset shapes: ((None, 5, 3), (None,)), types: (tf.float64, tf.int8)>