使用 JSON 文件设置 TF 数据集对象:解析 JSON 时出错
Setting up TF Datatset Object with JSON files: Error while parsing JSON
所以我正在尝试为模型输入设置我的 tensorflow 数据集对象。 X 是一系列图像(.png 文件),Y 是保存在 json 个文件中的一系列列表。
在我打印下面代码底部的一些数据集元素之前,一切似乎都运行良好。我想确保它正常工作,但我收到一条错误消息:
Error while parsing JSON: : Root element must be a message. [[{{node DecodeJSONExample}}]]{Op:IteratorGetNext]
一些数据说明:
json 文件包含表示 3D 中的点的各种长度和值的一维列表 space。它们看起来像这样:
[.28, -.39, .48, 1, 55, 88]
图片为原始格式,分辨率为 (1080, 1920, 3)。我希望能有所帮助。如果需要更多信息,请告诉我。
对我做错了什么有什么想法吗?
def build_dataset():
for root, dirs, files in os.walk(directory):
for filename in files:
path = os.path.join(root, filename)
if path.endswith('.png'):
x.append(path)
if path.endswith('.json'):
y.append(path)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
return dataset.shuffle(buffer_size=len(x))
def read_data(x_img, y_model):
img_data = tf.io.read_file(x_img)
img = tf.io.decode_png(img_data)
model_data = tf.io.read_file(y_model)
model = tf.io.decode_json_example(model_data)
return img, model
def prepare_data(img, models):
return img/255, models
train_ds = build_dataset()
train_ds = train_ds.map(read_data)
train_ds = train_ds.map(prepare_data)
train_ds = train_ds.batch(64)
for x, y in train_ds:
print(x, y)
由于您使用的是列表而不是正确的 JSONs,我建议修复 JSON 文件。例如:
{ "data" : [.28, -.39, .48, 1, 55, 88] }
或者,如果您无法更改文件,只需读取每个文件并将其解析为张量,而无需任何 JSON 实用程序,例如 tf.io.decode_json_example
:
import tensorflow as tf
def read_data(x_img, y_model):
img_data = tf.io.read_file(x_img)
img = tf.io.decode_png(img_data)
model_data = tf.io.read_file(y_model)
return img, tf.strings.to_number(tf.strings.split(tf.strings.regex_replace(tf.strings.strip(model_data), '[\[\],]', '')))
def prepare_data(img, models):
return img/255, models
train_ds = tf.data.Dataset.from_tensor_slices((['/content/result_image.png', '/content/result_image1.png'],
['/content/test.json', '/content/test2.json']))
train_ds = train_ds.map(read_data)
train_ds = train_ds.map(prepare_data)
train_ds = train_ds.batch(64)
for x, y in train_ds:
print(x.shape, y)
(2, 100, 100, 3) tf.Tensor(
[[ 0.28 -0.39 0.48 1. 55. 88. ]
[ 0.28 -0.39 0.48 1. 55. 88. ]], shape=(2, 6), dtype=float32)
所以我正在尝试为模型输入设置我的 tensorflow 数据集对象。 X 是一系列图像(.png 文件),Y 是保存在 json 个文件中的一系列列表。
在我打印下面代码底部的一些数据集元素之前,一切似乎都运行良好。我想确保它正常工作,但我收到一条错误消息:
Error while parsing JSON: : Root element must be a message. [[{{node DecodeJSONExample}}]]{Op:IteratorGetNext]
一些数据说明:
json 文件包含表示 3D 中的点的各种长度和值的一维列表 space。它们看起来像这样:
[.28, -.39, .48, 1, 55, 88]
图片为原始格式,分辨率为 (1080, 1920, 3)。我希望能有所帮助。如果需要更多信息,请告诉我。
对我做错了什么有什么想法吗?
def build_dataset():
for root, dirs, files in os.walk(directory):
for filename in files:
path = os.path.join(root, filename)
if path.endswith('.png'):
x.append(path)
if path.endswith('.json'):
y.append(path)
dataset = tf.data.Dataset.from_tensor_slices((x, y))
return dataset.shuffle(buffer_size=len(x))
def read_data(x_img, y_model):
img_data = tf.io.read_file(x_img)
img = tf.io.decode_png(img_data)
model_data = tf.io.read_file(y_model)
model = tf.io.decode_json_example(model_data)
return img, model
def prepare_data(img, models):
return img/255, models
train_ds = build_dataset()
train_ds = train_ds.map(read_data)
train_ds = train_ds.map(prepare_data)
train_ds = train_ds.batch(64)
for x, y in train_ds:
print(x, y)
由于您使用的是列表而不是正确的 JSONs,我建议修复 JSON 文件。例如:
{ "data" : [.28, -.39, .48, 1, 55, 88] }
或者,如果您无法更改文件,只需读取每个文件并将其解析为张量,而无需任何 JSON 实用程序,例如 tf.io.decode_json_example
:
import tensorflow as tf
def read_data(x_img, y_model):
img_data = tf.io.read_file(x_img)
img = tf.io.decode_png(img_data)
model_data = tf.io.read_file(y_model)
return img, tf.strings.to_number(tf.strings.split(tf.strings.regex_replace(tf.strings.strip(model_data), '[\[\],]', '')))
def prepare_data(img, models):
return img/255, models
train_ds = tf.data.Dataset.from_tensor_slices((['/content/result_image.png', '/content/result_image1.png'],
['/content/test.json', '/content/test2.json']))
train_ds = train_ds.map(read_data)
train_ds = train_ds.map(prepare_data)
train_ds = train_ds.batch(64)
for x, y in train_ds:
print(x.shape, y)
(2, 100, 100, 3) tf.Tensor(
[[ 0.28 -0.39 0.48 1. 55. 88. ]
[ 0.28 -0.39 0.48 1. 55. 88. ]], shape=(2, 6), dtype=float32)