使用 TFRecordDataset 时如何从 TF2 图表中释放内存

How to release memory from TF2 graphs when using TFRecordDataset

如有错误请见谅。我来自 PyTorch 背景,但我需要使用 TFRecordDataset 才能从 TFRecord 中读取。目前,这看起来像下面这样:

class TFRecordReader:
    def __iter__(self):
        dataset = tf.data.TFRecordDataset(
            [self.tfrecord_path], compression_type="GZIP"
        )
        dataset = dataset.map(self._parse_example)
        self._tfr_iter = iter(dataset)

    def __next__(self):
        return next(self._tfr_iter)

但是,我需要为每个 PyTorch worker 创建多个 TFRecordReader 来进行批处理平衡。这导致我每个 GPU 的每个工作人员有 4 TFRecordDataset(4 个要平衡的桶),所以我最终在内存中有 4 * 4 * 4 = 64 TFRecordDataset。我有足够的 CPU 内存来执行此操作,但问题是内存没有从 TFRecordDataset 中释放,因为内存在训练过程中不断增加。我认为问题在于计算图不断增长(每次读取新的 TFRecord 都会为其创建新的 TFRecordDataset),但从未发布。

如何确保在完成对单个 TFRecord 的迭代后释放 TFRecordDataset 使用的内存?

我试过了:

def __iter__(self)
    with tf.Graph().as_default() as g:
        dataset = tf.data.TFRecordDataset(
            [self.tfrecord_path], compression_type="GZIP"
        )
        dataset = dataset.map(self._parse_example)

        tf.compat.v1.enable_eager_execution()
        self._tfr_iter = iter(dataset)

        while True:
            try:
                example_dict = next(
                    self._tfr_iter
                )
   # ...

但是,我得到一个错误:

RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.

如果有任何关于如何确保内存不会持续增长的建议,我将不胜感激。我正在使用 Tensorflow 2.5 作为参考。

原来问题出在将 PyTorch 分析器与 PyTorch Lightning 结合使用。问题不在于 Tensorflow。

查看相关问题here