如何在 Tensorflow 中使用 tf.datasets 和迭代器
How to use tf.datasets with iterator in Tensorflow
我正在尝试使用 tf.data.TextLineDataset 从 csv 文件中读取数据,将数据集分片到多个工作节点,然后创建一个迭代器来迭代它们以分批提供数据。我使用了 TensorFlow (https://www.tensorflow.org/programmers_guide/datasets) tf.datasets 上的程序员指南。
运行 tf session 中的代码时出现以下错误:
*** tensorflow.python.framework.errors_impl.NotFoundError: Date,Open,High,Low,Last,Close,Total Trade Quantity,Turnover,close_pct_change_1d,KAMA7-KAMA30,KAMA15-KAMA30,HT_QUAD,TURNOVER,BOP,MFI,MINUS_DI,ROCP,STOCH_SLOWK,NATR,EMA7-EMA30-1d,DX-1d,PPO-1d,NATR-1d,HT_INPHASOR-2d,day_0,day_1,day_2,day_3; No such file or directory
[[Node: IteratorGetNext_5 = IteratorGetNext[output_shapes=[[], [], [], [], [], ..., [], [], [], [], []], output_types=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_8)]]
现在,"Date"、"Open"、"High"等是我要加载的数据集中的列。因此,我知道该错误与加载数据集无关。
加载数据集时,我使用 tf.data.TextLineDataset(file).skip(1)
但根据错误,它似乎没有跳过我的数据集的第一行(列标题)。
有人知道这个错误是从哪里来的吗?有人对此有解决办法吗?
请参阅以下代码进行说明:
def create_pipeline(bs, nr, ep):
def _X_parse_csv(file):
record_defaults=[[0]]*20
splits = tf.decode_csv(file, record_defaults)
input = splits
return input
def _y_parse_csv(file):
record_defaults=[[0]]*20
splits = tf.decode_csv(file, record_defaults)
label = splits[0]
return label
# Dataset for input data
file = tf.gfile.Glob("./NSEOIL.csv")
num_workers = 1 # for testing; simulate 1 node for sharding below
task_index = 0
ds_file = tf.data.TextLineDataset(file)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers
ds = ds.shard(num_workers, task_index).repeat(ep)
X_train = ds.map(_X_parse_csv)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(2))) #remove CSV headers + shift forward 1 day
ds = ds.shard(num_workers, task_index).repeat(ep)
y_train = ds.map(_y_parse_csv)
X_iterator = X_train.make_initializable_iterator()
y_iterator = y_train.make_initializable_iterator()
return X_iterator, y_iterator
这两行似乎是问题的根源:
ds_file = tf.data.TextLineDataset(file)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers
第一行根据 file
中命名的一个或多个文件的行创建数据集。然后,第二行为 ds_file
中的每个元素创建一个数据集,将每个元素(来自 file
的一行文本)视为另一个文件名。当 file
的第一行似乎是 CSV header 时,您看到的 NotFoundError
被视为文件名。
修复相对简单,幸运的是,因为您可以使用 Dataset.list_files()
创建与您的 glob 匹配的文件数据集,然后 Dataset.flat_map()
将对文件名进行操作:
# Create a dataset of filenames.
ds_file = tf.data.Dataset.list_files("./NSEOIL.csv")
# For each filename in `ds_file`, read the lines from that file (skipping the
# header).
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1)))
我正在尝试使用 tf.data.TextLineDataset 从 csv 文件中读取数据,将数据集分片到多个工作节点,然后创建一个迭代器来迭代它们以分批提供数据。我使用了 TensorFlow (https://www.tensorflow.org/programmers_guide/datasets) tf.datasets 上的程序员指南。 运行 tf session 中的代码时出现以下错误:
*** tensorflow.python.framework.errors_impl.NotFoundError: Date,Open,High,Low,Last,Close,Total Trade Quantity,Turnover,close_pct_change_1d,KAMA7-KAMA30,KAMA15-KAMA30,HT_QUAD,TURNOVER,BOP,MFI,MINUS_DI,ROCP,STOCH_SLOWK,NATR,EMA7-EMA30-1d,DX-1d,PPO-1d,NATR-1d,HT_INPHASOR-2d,day_0,day_1,day_2,day_3; No such file or directory
[[Node: IteratorGetNext_5 = IteratorGetNext[output_shapes=[[], [], [], [], [], ..., [], [], [], [], []], output_types=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, ..., DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_8)]]
现在,"Date"、"Open"、"High"等是我要加载的数据集中的列。因此,我知道该错误与加载数据集无关。
加载数据集时,我使用 tf.data.TextLineDataset(file).skip(1)
但根据错误,它似乎没有跳过我的数据集的第一行(列标题)。
有人知道这个错误是从哪里来的吗?有人对此有解决办法吗?
请参阅以下代码进行说明:
def create_pipeline(bs, nr, ep):
def _X_parse_csv(file):
record_defaults=[[0]]*20
splits = tf.decode_csv(file, record_defaults)
input = splits
return input
def _y_parse_csv(file):
record_defaults=[[0]]*20
splits = tf.decode_csv(file, record_defaults)
label = splits[0]
return label
# Dataset for input data
file = tf.gfile.Glob("./NSEOIL.csv")
num_workers = 1 # for testing; simulate 1 node for sharding below
task_index = 0
ds_file = tf.data.TextLineDataset(file)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers
ds = ds.shard(num_workers, task_index).repeat(ep)
X_train = ds.map(_X_parse_csv)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(2))) #remove CSV headers + shift forward 1 day
ds = ds.shard(num_workers, task_index).repeat(ep)
y_train = ds.map(_y_parse_csv)
X_iterator = X_train.make_initializable_iterator()
y_iterator = y_train.make_initializable_iterator()
return X_iterator, y_iterator
这两行似乎是问题的根源:
ds_file = tf.data.TextLineDataset(file)
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1))) #remove CSV headers
第一行根据 file
中命名的一个或多个文件的行创建数据集。然后,第二行为 ds_file
中的每个元素创建一个数据集,将每个元素(来自 file
的一行文本)视为另一个文件名。当 file
的第一行似乎是 CSV header 时,您看到的 NotFoundError
被视为文件名。
修复相对简单,幸运的是,因为您可以使用 Dataset.list_files()
创建与您的 glob 匹配的文件数据集,然后 Dataset.flat_map()
将对文件名进行操作:
# Create a dataset of filenames.
ds_file = tf.data.Dataset.list_files("./NSEOIL.csv")
# For each filename in `ds_file`, read the lines from that file (skipping the
# header).
ds = ds_file.flat_map(lambda file: (tf.data.TextLineDataset(file).skip(1)))