使用 TFRecordDataset 时如何从 TF2 图表中释放内存
How to release memory from TF2 graphs when using TFRecordDataset
如有错误请见谅。我来自 PyTorch 背景,但我需要使用 TFRecordDataset
才能从 TFRecord
中读取。目前,这看起来像下面这样:
class TFRecordReader:
def __iter__(self):
dataset = tf.data.TFRecordDataset(
[self.tfrecord_path], compression_type="GZIP"
)
dataset = dataset.map(self._parse_example)
self._tfr_iter = iter(dataset)
def __next__(self):
return next(self._tfr_iter)
但是,我需要为每个 PyTorch worker 创建多个 TFRecordReader
来进行批处理平衡。这导致我每个 GPU 的每个工作人员有 4 TFRecordDataset
(4 个要平衡的桶),所以我最终在内存中有 4 * 4 * 4 = 64 TFRecordDataset
。我有足够的 CPU 内存来执行此操作,但问题是内存没有从 TFRecordDataset
中释放,因为内存在训练过程中不断增加。我认为问题在于计算图不断增长(每次读取新的 TFRecord
都会为其创建新的 TFRecordDataset
),但从未发布。
如何确保在完成对单个 TFRecord
的迭代后释放 TFRecordDataset
使用的内存?
我试过了:
def __iter__(self)
with tf.Graph().as_default() as g:
dataset = tf.data.TFRecordDataset(
[self.tfrecord_path], compression_type="GZIP"
)
dataset = dataset.map(self._parse_example)
tf.compat.v1.enable_eager_execution()
self._tfr_iter = iter(dataset)
while True:
try:
example_dict = next(
self._tfr_iter
)
# ...
但是,我得到一个错误:
RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
如果有任何关于如何确保内存不会持续增长的建议,我将不胜感激。我正在使用 Tensorflow 2.5 作为参考。
原来问题出在将 PyTorch 分析器与 PyTorch Lightning 结合使用。问题不在于 Tensorflow。
查看相关问题here
如有错误请见谅。我来自 PyTorch 背景,但我需要使用 TFRecordDataset
才能从 TFRecord
中读取。目前,这看起来像下面这样:
class TFRecordReader:
def __iter__(self):
dataset = tf.data.TFRecordDataset(
[self.tfrecord_path], compression_type="GZIP"
)
dataset = dataset.map(self._parse_example)
self._tfr_iter = iter(dataset)
def __next__(self):
return next(self._tfr_iter)
但是,我需要为每个 PyTorch worker 创建多个 TFRecordReader
来进行批处理平衡。这导致我每个 GPU 的每个工作人员有 4 TFRecordDataset
(4 个要平衡的桶),所以我最终在内存中有 4 * 4 * 4 = 64 TFRecordDataset
。我有足够的 CPU 内存来执行此操作,但问题是内存没有从 TFRecordDataset
中释放,因为内存在训练过程中不断增加。我认为问题在于计算图不断增长(每次读取新的 TFRecord
都会为其创建新的 TFRecordDataset
),但从未发布。
如何确保在完成对单个 TFRecord
的迭代后释放 TFRecordDataset
使用的内存?
我试过了:
def __iter__(self)
with tf.Graph().as_default() as g:
dataset = tf.data.TFRecordDataset(
[self.tfrecord_path], compression_type="GZIP"
)
dataset = dataset.map(self._parse_example)
tf.compat.v1.enable_eager_execution()
self._tfr_iter = iter(dataset)
while True:
try:
example_dict = next(
self._tfr_iter
)
# ...
但是,我得到一个错误:
RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
如果有任何关于如何确保内存不会持续增长的建议,我将不胜感激。我正在使用 Tensorflow 2.5 作为参考。
原来问题出在将 PyTorch 分析器与 PyTorch Lightning 结合使用。问题不在于 Tensorflow。
查看相关问题here