如何正确地重新标记 TensorFlow 数据集?
How to properly relabel a TensorFlow dataset?
我目前正在使用 TensorFlow 处理 CIFAR10 数据集。
由于各种原因,我需要按预定义规则更改标签,例如。标签为 4 的每个示例都应更改为 3,或者每个标签为 1 的示例应更改为 6。
我试过以下方法:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')
def relabel_map(l):
return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]
ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))
for ex in ds_train.take(1):
plt.imshow(np.array(ex[0], dtype=np.uint8))
plt.show()
print(ex[1])
当我尝试 运行 时,在带有 for ex in ds_train.take(1):
的行出现以下错误:
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
我的python版本是3.8.12,TensorFlow版本是2.7.0。
PS:也许我可以通过转换为 one-hot 并用矩阵对其进行转换来完成此转换,但这在代码中看起来不那么直接。
我建议您使用 tf.lookup.StaticHashTable
:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')
table = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int64),
values=tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8], dtype=tf.int64),
),
default_value= tf.constant(0, dtype=tf.int64)
)
def relabel_map(example):
example['label'] = table.lookup(example['label'])
return example
ds_train = ds_train.map(relabel_map)
for ex in ds_train.take(1):
plt.imshow(np.array(ex['image'], dtype=np.uint8))
plt.show()
print(ex['label'])
tf.Tensor(5, shape=(), dtype=int64)
我目前正在使用 TensorFlow 处理 CIFAR10 数据集。 由于各种原因,我需要按预定义规则更改标签,例如。标签为 4 的每个示例都应更改为 3,或者每个标签为 1 的示例应更改为 6。
我试过以下方法:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')
def relabel_map(l):
return {0: 0, 1: 6, 2: 1, 3: 2, 4: 3, 5: 4, 6: 9, 7: 5, 8: 7, 9: 8}[l]
ds_train = ds_train.map(lambda example: (example['image'], tf.py_function(relabel_map, [example['label']], [tf.int64])))
for ex in ds_train.take(1):
plt.imshow(np.array(ex[0], dtype=np.uint8))
plt.show()
print(ex[1])
当我尝试 运行 时,在带有 for ex in ds_train.take(1):
的行出现以下错误:
TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key.
我的python版本是3.8.12,TensorFlow版本是2.7.0。
PS:也许我可以通过转换为 one-hot 并用矩阵对其进行转换来完成此转换,但这在代码中看起来不那么直接。
我建议您使用 tf.lookup.StaticHashTable
:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
builder = tfds.image.Cifar10()
builder.download_and_prepare()
ds_train: tf.data.Dataset = builder.as_dataset(split='train')
table = tf.lookup.StaticHashTable(
initializer=tf.lookup.KeyValueTensorInitializer(
keys=tf.constant([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=tf.int64),
values=tf.constant([0, 6, 1, 2, 3, 4, 9, 5, 7, 8], dtype=tf.int64),
),
default_value= tf.constant(0, dtype=tf.int64)
)
def relabel_map(example):
example['label'] = table.lookup(example['label'])
return example
ds_train = ds_train.map(relabel_map)
for ex in ds_train.take(1):
plt.imshow(np.array(ex['image'], dtype=np.uint8))
plt.show()
print(ex['label'])
tf.Tensor(5, shape=(), dtype=int64)