使用 TPU 在 Tensorflow 中加载 CSV 文件时出现问题

Problem loading CSV files in Tensorflow with TPU

我正在尝试在 Tensorflow (V2.4.1) 中加载 CSV 文件。我正在使用 tf.data.experimental.make_csv_dataset,虽然它在执行函数时不会引发任何错误,但在尝试迭代数据集时出现错误。

我在启用了 TPU 加速的 Kaggle 笔记本中运行这段代码。如果在 CPU 或 GPU 环境中执行相同的代码,一切正常。

# Look up the GCS path of the Kaggle dataset and point at its training CSV.
GCS_PATH = KaggleDatasets().get_gcs_path('mydsname')
fpath = GCS_PATH + '/train.csv'

# Build a batched CSV dataset: 64 rows per batch, 3 epochs, no shuffling.
# Only the 'sentence' (string) and 'label' (float32) columns are read, and
# 'label' is split off as the label of each batch.
train_ds = tf.data.experimental.make_csv_dataset(
    file_pattern=fpath,
    batch_size=64,
    column_defaults=[tf.string, tf.float32],
    select_columns=['sentence', 'label'],
    label_name='label',
    shuffle=False,
    num_epochs=3,
)

# The reported error surfaces here, on iteration, not when the dataset
# object is created above.
for item in train_ds.take(1):
    print(item)

在此之前,我还复制粘贴了用于激活 Google Cloud SDK 的代码:

# Activate the Google Cloud SDK credential for TensorFlow via Kaggle's
# secrets helper.
# NOTE(review): according to the resolution at the end of this post, this
# snippet must run AFTER the notebook has connected to the TPU; running it
# before the TPU connection is what led to "Can't read header of file".
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
# Fetch the gcloud credential, then register it with TensorFlow.
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

这是我遇到的错误:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    736         # Fast path for the case `self._structure` is not a nested structure.
--> 737         return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
    738       except AttributeError:

AttributeError: 'tuple' object has no attribute '_from_compatible_tensor_list'

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/context.py in execution_mode(mode)
   2112       ctx.executor = executor_new
-> 2113       yield
   2114     finally:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    738       except AttributeError:
--> 739         return structure.from_compatible_tensor_list(self._element_spec, ret)
    740 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in from_compatible_tensor_list(element_spec, tensor_list)
    243       lambda spec, value: spec._from_compatible_tensor_list(value),
--> 244       element_spec, tensor_list)
    245 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in _from_tensor_list_helper(decode_fn, element_spec, tensor_list)
    218     value = tensor_list[i:i + num_flat_values]
--> 219     flat_ret.append(decode_fn(component_spec, value))
    220     i += num_flat_values

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/util/structure.py in <lambda>(spec, value)
    242   return _from_tensor_list_helper(
--> 243       lambda spec, value: spec._from_compatible_tensor_list(value),
    244       element_spec, tensor_list)

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/tensor_spec.py in _from_compatible_tensor_list(self, tensor_list)
    176     assert len(tensor_list) == 1
--> 177     tensor_list[0].set_shape(self._shape)
    178     return tensor_list[0]

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in set_shape(self, shape)
   1213   def set_shape(self, shape):
-> 1214     if not self.shape.is_compatible_with(shape):
   1215       raise ValueError(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in shape(self)
   1174         # `EagerTensor`, in C.
-> 1175         self._tensor_shape = tensor_shape.TensorShape(self._shape_tuple())
   1176       except core._NotOkStatusException as e:

InvalidArgumentError: Can't read header of file

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-12-935e50497dbb> in <module>
----> 1 for e in train_ds.take(1):
      2     pass

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in __next__(self)
    745   def __next__(self):
    746     try:
--> 747       return self._next_internal()
    748     except errors.OutOfRangeError:
    749       raise StopIteration

/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py in _next_internal(self)
    737         return self._element_spec._from_compatible_tensor_list(ret)  # pylint: disable=protected-access
    738       except AttributeError:
--> 739         return structure.from_compatible_tensor_list(self._element_spec, ret)
    740 
    741   @property

/opt/conda/lib/python3.7/contextlib.py in __exit__(self, type, value, traceback)
    128                 value = type()
    129             try:
--> 130                 self.gen.throw(type, value, traceback)
    131             except StopIteration as exc:
    132                 # Suppress StopIteration *unless* it's the same exception that

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/context.py in execution_mode(mode)
   2114     finally:
   2115       ctx.executor = executor_old
-> 2116       executor_new.wait()
   2117 
   2118 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/executor.py in wait(self)
     67   def wait(self):
     68     """Waits for ops dispatched in this executor to finish."""
---> 69     pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
     70 
     71   def clear_error(self):

InvalidArgumentError: Can't read header of file

fpath 似乎是正确的,因为如果我更改它的值,那么 make_csv_dataset 会引发不同的错误。

是否有人知道可能导致错误的原因?

我找到了问题的根源。我是在连接到 TPU 之前执行了激活 Google Cloud SDK 的代码。正如相关帖子中所述,必须先连接到 TPU,然后再激活 SDK。