TensorFlow decode_csv 形状错误
TensorFlow decode_csv shape error
我使用 tf.data.TextLineDataset
读入 *.csv
文件并在其上应用 map
:
dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
解析函数 parse_record_fn
如下所示:
def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)
# do something
return image, label
但是在解析函数 tf.decode_csv
处有一个 ValueError
:
ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [1], [], [].
我的 *.csv
文件示例:
/data/1.png, 5
/data/2.png, 7
问题:
- 哪里出错了?
shapes: [1], [], []
是什么意思?
复制
此错误可在以下代码中重现:
import tensorflow as tf
import os
def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)
# do something
return image, label
with tf.Session() as sess:
csv_path = './labels.txt'
dataset = tf.data.TextLineDataset(csv_path)
dataset = dataset.map(lambda value: parse_record(value, True))
sess.run(dataset)
查看 tf.decode_csv
的文档,其中提到了默认记录:
record_defaults: A list of Tensor objects with specific types.
Acceptable types are float32, float64, int32, int64, string. One
tensor per column of the input record, with either a scalar default
value for that column or empty if the column is required.
我相信您遇到的错误源于您定义张量的方式 default_record
。您的 default_record
当然是张量对象(或可转换为张量的对象)的列表,但我认为错误消息表明它们应该是 rank-1 张量,而不是像您的情况那样是 rank-0 张量。
您可以通过将默认记录设为 1 阶张量来解决此问题。请参阅以下玩具示例:
import tensorflow as tf
my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now
decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
d = sess.run(decoded_1)
print(d)
# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)
最后一行产生的错误很熟悉:
ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV_1' (op:
'DecodeCSV') with input shapes: [], [], [].
消息中的输入形状,三个括号[]
,指的是[的输入参数records
、record_defaults
、field_delim
的形状=11=]。在您的情况下,这些形状中的第一个是 [1]
,因为您输入了 [raw_record]
。我同意这个案例的信息不是很有用...
我使用 tf.data.TextLineDataset
读入 *.csv
文件并在其上应用 map
:
dataset = tf.data.TextLineDataset(os.path.join(data_dir, subset, 'label.txt'))
dataset = dataset.map(lambda value: parse_record_fn(value, is_training),
num_parallel_calls=num_parallel_calls)
解析函数 parse_record_fn
如下所示:
def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)
# do something
return image, label
但是在解析函数 tf.decode_csv
处有一个 ValueError
:
ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV' (op: 'DecodeCSV') with input shapes: [1], [], [].
我的 *.csv
文件示例:
/data/1.png, 5
/data/2.png, 7
问题:
- 哪里出错了?
shapes: [1], [], []
是什么意思?
复制
此错误可在以下代码中重现:
import tensorflow as tf
import os
def parse_record(raw_record, is_training):
default_record = ["./", -1]
filename, label = tf.decode_csv([raw_record], default_record)
# do something
return image, label
with tf.Session() as sess:
csv_path = './labels.txt'
dataset = tf.data.TextLineDataset(csv_path)
dataset = dataset.map(lambda value: parse_record(value, True))
sess.run(dataset)
查看 tf.decode_csv
的文档,其中提到了默认记录:
record_defaults: A list of Tensor objects with specific types. Acceptable types are float32, float64, int32, int64, string. One tensor per column of the input record, with either a scalar default value for that column or empty if the column is required.
我相信您遇到的错误源于您定义张量的方式 default_record
。您的 default_record
当然是张量对象(或可转换为张量的对象)的列表,但我认为错误消息表明它们应该是 rank-1 张量,而不是像您的情况那样是 rank-0 张量。
您可以通过将默认记录设为 1 阶张量来解决此问题。请参阅以下玩具示例:
import tensorflow as tf
my_line = 'filename.png, 10'
default_record_1 = [['./'], [-1]] # do this!
default_record_2 = ['./', -1] # this is what you do now
decoded_1 = tf.decode_csv(my_line, default_record_1)
with tf.Session() as sess:
d = sess.run(decoded_1)
print(d)
# This will cause an error
decoded_2 = tf.decode_csv(my_line, default_record_2)
最后一行产生的错误很熟悉:
ValueError: Shape must be rank 1 but is rank 0 for 'DecodeCSV_1' (op: 'DecodeCSV') with input shapes: [], [], [].
消息中的输入形状,三个括号[]
,指的是[的输入参数records
、record_defaults
、field_delim
的形状=11=]。在您的情况下,这些形状中的第一个是 [1]
,因为您输入了 [raw_record]
。我同意这个案例的信息不是很有用...