将多个文件输入到 Tensorflow 数据集中
Input multiple files into Tensorflow dataset
我有以下 input_fn.
def input_fn(filenames, batch_size):
# Create a dataset containing the text lines.
dataset = tf.data.TextLineDataset(filenames).skip(1)
# Parse each line.
dataset = dataset.map(_parse_line)
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(10000).repeat().batch(batch_size)
# Return the dataset.
return dataset
如果 filenames=['file1.csv']
或 filenames=['file2.csv']
效果很好。如果 filenames=['file1.csv', 'file2.csv']
,它会给我一个错误。在 Tensorflow documentation 中,它说 filenames
是一个包含一个或多个文件名的 tf.string
张量。我应该如何导入多个文件?
错误如下。它似乎忽略了上面 input_fn
中的 .skip(1)
:
InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
[[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
您使用 tf.data.TextLineDataset
的想法是正确的。但是,您当前的实现所做的是在其文件名输入张量中生成每个文件的每一行,第一个文件的第一个除外。您跳过第一行的方式现在只会影响第一个文件中的第一行。在第二个文件中,第一行没有被跳过。
基于每个文件名 Datasets guide, you should adapt your code to first create a regular Dataset
from the filenames, then run flat_map
上的示例,使用 TextLineDataset
读取它,同时跳过第一行:
d = tf.data.Dataset.from_tensor_slices(filenames)
# get dataset from each file, skipping first line of each file
d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
d = d.map(_parse_line) # And whatever else you need to do
此处,flat_map
通过读取文件内容并跳过第一行,从原始数据集的每个元素创建一个新数据集。
我有以下 input_fn.
def input_fn(filenames, batch_size):
# Create a dataset containing the text lines.
dataset = tf.data.TextLineDataset(filenames).skip(1)
# Parse each line.
dataset = dataset.map(_parse_line)
# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(10000).repeat().batch(batch_size)
# Return the dataset.
return dataset
如果 filenames=['file1.csv']
或 filenames=['file2.csv']
效果很好。如果 filenames=['file1.csv', 'file2.csv']
,它会给我一个错误。在 Tensorflow documentation 中,它说 filenames
是一个包含一个或多个文件名的 tf.string
张量。我应该如何导入多个文件?
错误如下。它似乎忽略了上面 input_fn
中的 .skip(1)
:
InvalidArgumentError: Field 0 in record 0 is not a valid int32: row_id
[[Node: DecodeCSV = DecodeCSV[OUT_TYPE=[DT_INT32, DT_INT32, DT_FLOAT, DT_INT32, DT_FLOAT, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], field_delim=",", na_value="", use_quote_delim=true](arg0, DecodeCSV/record_defaults_0, DecodeCSV/record_defaults_1, DecodeCSV/record_defaults_2, DecodeCSV/record_defaults_3, DecodeCSV/record_defaults_4, DecodeCSV/record_defaults_5, DecodeCSV/record_defaults_6, DecodeCSV/record_defaults_7, DecodeCSV/record_defaults_8, DecodeCSV/record_defaults_9, DecodeCSV/record_defaults_10, DecodeCSV/record_defaults_11, DecodeCSV/record_defaults_12, DecodeCSV/record_defaults_13, DecodeCSV/record_defaults_14, DecodeCSV/record_defaults_15, DecodeCSV/record_defaults_16, DecodeCSV/record_defaults_17, DecodeCSV/record_defaults_18)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?], [?], [?], [?], [?], ..., [?], [?], [?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT32, DT_STRING, DT_STRING, ..., DT_INT32, DT_FLOAT, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator)]]
您使用 tf.data.TextLineDataset
的想法是正确的。但是,您当前的实现所做的是在其文件名输入张量中生成每个文件的每一行,第一个文件的第一个除外。您跳过第一行的方式现在只会影响第一个文件中的第一行。在第二个文件中,第一行没有被跳过。
基于每个文件名 Datasets guide, you should adapt your code to first create a regular Dataset
from the filenames, then run flat_map
上的示例,使用 TextLineDataset
读取它,同时跳过第一行:
d = tf.data.Dataset.from_tensor_slices(filenames)
# get dataset from each file, skipping first line of each file
d = d.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1))
d = d.map(_parse_line) # And whatever else you need to do
此处,flat_map
通过读取文件内容并跳过第一行,从原始数据集的每个元素创建一个新数据集。