使用 Dataset api 返回特定元素
Returning specific elements with Dataset api
我写了一个 tfrecord 文件,里面有图片和它们 labels.Then 我可以使用
来获取它们
def parserTrain(record):
keys_to_features = {
"image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
# Perform additional preprocessing on the parsed data.
image = tf.image.decode_jpeg(parsed["image_raw"])
image = tf.reshape(image, [256, 256, 3])
image = tf.transpose(image, perm=[2, 0, 1]) # channels first
image = tf.truediv(image, 255.0)
label = tf.cast(parsed["label"], tf.int32)
return {"image": image}, label
# Set up training input function.
def train_input_fn():
"""Prepare data for training."""
train_tfrecord = 'Dataset/train_images.tfrecords'
dataset = tf.data.TFRecordDataset(train_tfrecord)
dataset = dataset.map(parserTrain)
在那之后我想过滤掉一些可能使用这样的例子:
def f(x):
return x[1] == 1
ds1 = dataset.filter(f)
但我得到这个错误:
TypeError: f() takes 1 positional argument but 2 were given
因此,假设您有一个数据集(例如 TFRecordDataset
),您可以按如下方式过滤示例:
dataset = tf.data.TFRecordDataset(filenames=files)
dataset = dataset.filter(lambda example: example["value"] == value and example["label"] == label)
dataset = ...
在我找到答案后回复我的问题。
元组数据集的过滤函数的正确语法如下:
def f(im, label):
return tf.equal(label, 1)
ds1 = dataset.filter(f)
我写了一个 tfrecord 文件,里面有图片和它们 labels.Then 我可以使用
来获取它们 def parserTrain(record):
keys_to_features = {
"image_raw": tf.FixedLenFeature((), tf.string, default_value=""),
"label": tf.FixedLenFeature((), tf.int64,
default_value=tf.zeros([], dtype=tf.int64)),
}
parsed = tf.parse_single_example(record, keys_to_features)
# Perform additional preprocessing on the parsed data.
image = tf.image.decode_jpeg(parsed["image_raw"])
image = tf.reshape(image, [256, 256, 3])
image = tf.transpose(image, perm=[2, 0, 1]) # channels first
image = tf.truediv(image, 255.0)
label = tf.cast(parsed["label"], tf.int32)
return {"image": image}, label
# Set up training input function.
def train_input_fn():
"""Prepare data for training."""
train_tfrecord = 'Dataset/train_images.tfrecords'
dataset = tf.data.TFRecordDataset(train_tfrecord)
dataset = dataset.map(parserTrain)
在那之后我想过滤掉一些可能使用这样的例子:
def f(x):
return x[1] == 1
ds1 = dataset.filter(f)
但我得到这个错误:
TypeError: f() takes 1 positional argument but 2 were given
因此,假设您有一个数据集(例如 TFRecordDataset
),您可以按如下方式过滤示例:
dataset = tf.data.TFRecordDataset(filenames=files)
dataset = dataset.filter(lambda example: example["value"] == value and example["label"] == label)
dataset = ...
在我找到答案后回复我的问题。 元组数据集的过滤函数的正确语法如下:
def f(im, label):
return tf.equal(label, 1)
ds1 = dataset.filter(f)