如何在 TensorFlow 中更改 csv 文件的 dtype?

How do I change the dtype in TensorFlow for a csv file?

这是我正在尝试的代码 运行-

import tensorflow as tf
import numpy as np
import input_data

filename_queue = tf.train.string_input_producer(["cs-training.csv"])

reader = tf.TextLineReader()
key, value = reader.read(filename_queue)

record_defaults = [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]]
col1, col2, col3, col4, col5, col6, col7, col8, col9, col10, col11 = tf.decode_csv(
    value, record_defaults=record_defaults)
features = tf.concat(0, [col2, col3, col4, col5, col6, col7, col8, col9, col10, col11])

with tf.Session() as sess:
  # Start populating the filename queue.
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(coord=coord)

  for i in range(1200):
    # Retrieve a single instance:
    print i
    example, label = sess.run([features, col1])
    try:
        print example, label
    except:
        pass

  coord.request_stop()
  coord.join(threads)

此代码return下面的错误。

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-23-e42fe2609a15> in <module>()
      7     # Retrieve a single instance:
      8     print i
----> 9     example, label = sess.run([features, col1])
     10     try:
     11         print example, label

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in run(self, fetches, feed_dict)
    343 
    344     # Run request and get response.
--> 345     results = self._do_run(target_list, unique_fetch_targets, feed_dict_string)
    346 
    347     # User may have fetched the same tensor multiple times, but we

/root/anaconda/lib/python2.7/site-packages/tensorflow/python/client/session.pyc in _do_run(self, target_list, fetch_list, feed_dict)
    417         # pylint: disable=protected-access
    418         raise errors._make_specific_exception(node_def, op, e.error_message,
--> 419                                               e.code)
    420         # pylint: enable=protected-access
    421       raise e_type, e_value, e_traceback

InvalidArgumentError: Field 1 in record 0 is not a valid int32: 0.766126609

它后面有很多我认为与问题无关的信息。显然,问题是我提供给程序的很多数据都不是 int32 数据类型。它主要是浮点数。我尝试了一些改变数据类型的方法,比如在 tf.decode_csvtf.concat 中显式设置 dtype=float 参数。都没有用。这是一个无效的论点。最重要的是,我不知道这段代码是否会真正对数据进行预测。我希望它预测 col1 是 1 还是 0,并且我在代码中看不到任何暗示它实际上会做出该预测的内容。也许我会把这个问题留到另一个线程。非常感谢任何帮助!

更改 dtype 的答案就是像这样更改默认值-

record_defaults = [[1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.], [1.]]

执行此操作后,如果打印出 col1,您将收到此消息。

Tensor("DecodeCSV_43:0", shape=TensorShape([]), dtype=float32)

但是您会 运行 进入另一个错误, 总结一下答案,解决方法是像这样将 tf.concat 更改为 tf.pack

features = tf.pack([col2, col3, col4, col5, col6, col7, col8, col9, col10, col11])

tf.decode_csv() 的界面有点棘手。每列的 dtyperecord_defaults 参数的相应元素确定。代码中 record_defaults 的值被解释为每个列的类型为 tf.int32,这会在遇到浮点数据时导致错误。

假设您有以下 CSV 数据,其中包含三个整数列,后跟一个浮点数列:

4, 8, 9, 4.5
2, 5, 1, 3.7
2, 2, 2, 0.1

假设所有列都是 必需的,您将构建 record_defaults 如下:

value = ...

record_defaults = [tf.constant([], dtype=tf.int32),    # Column 0
                   tf.constant([], dtype=tf.int32),    # Column 1
                   tf.constant([], dtype=tf.int32),    # Column 2
                   tf.constant([], dtype=tf.float32)]  # Column 3

col0, col1, col2, col3 = tf.decode_csv(value, record_defaults=record_defauts)

assert col0.dtype == tf.int32
assert col1.dtype == tf.int32
assert col2.dtype == tf.int32
assert col3.dtype == tf.float32

record_defaults 中的空值表示该值是必需的。或者,如果(例如)第 2 列允许有缺失值,您可以定义 record_defaults 如下:

record_defaults = [tf.constant([], dtype=tf.int32),     # Column 0
                   tf.constant([], dtype=tf.int32),     # Column 1
                   tf.constant([0], dtype=tf.int32),    # Column 2
                   tf.constant([], dtype=tf.float32)]   # Column 3

您问题的第二部分涉及如何构建和训练一个模型,该模型根据输入数据预测其中一列的值。目前,该程序没有:它只是将列连接成一个张量,称为 features。您将需要定义和训练一个模型来解释该数据。最简单的此类方法之一是线性回归,您可能会发现有关 linear regression in TensorFlow 的本教程适用于您的问题。