使用带有 zarr 的 PyTorch IterableDataset 的内存泄漏问题
Memory leak issue using PyTorch IterableDataset with zarr
我正在尝试在 IterableDataset
上构建一个 pytorch
项目,使用 zarr
作为存储后端。
class Data(IterableDataset):
def __init__(self, path, start=None, end=None):
super(Data, self).__init__()
store = zarr.DirectoryStore(path)
self.array = zarr.open(store, mode='r')
if start is None:
start = 0
if end is None:
end = self.array.shape[0]
assert end > start
self.start = start
self.end = end
def __iter__(self):
return islice(self.array, self.start, self.end)
这适用于小型测试数据集,但一旦我移动到我的实际数据集 (480 000 000 x 290),我 运行 就会发生内存泄漏。我已经尝试定期注销 python 堆,因为一切都变慢了,但我看不到任何异常增加的大小,所以我使用的库 (pympler
) 实际上没有捕捉到内存泄漏。
我有点不知所措,所以如果有人知道如何进一步调试它,我们将不胜感激。
交叉发布于 PyTorch Forums。
原来我的验证例程有问题:
with torch.no_grad():
for batch in tqdm(testloader, **params):
x = batch[:, 1:].to(device)
y = batch[:, 0].unsqueeze(0).T
y_test_pred = torch.sigmoid(sxnet(x))
y_pred_tag = torch.round(y_test_pred)
y_pred_list.append(y_pred_tag.cpu().numpy())
y_list.append(y.numpy())
我原以为我很清楚 运行 将我的结果附加到列表中的麻烦,但问题是 .numpy
的结果是一个数组数组(因为原始数据类型是 1xn 张量)。
在 numpy 数组上添加 .flatten()
解决了这个问题,现在 RAM 消耗与我最初配置的一样。
我正在尝试在 IterableDataset
上构建一个 pytorch
项目,使用 zarr
作为存储后端。
class Data(IterableDataset):
def __init__(self, path, start=None, end=None):
super(Data, self).__init__()
store = zarr.DirectoryStore(path)
self.array = zarr.open(store, mode='r')
if start is None:
start = 0
if end is None:
end = self.array.shape[0]
assert end > start
self.start = start
self.end = end
def __iter__(self):
return islice(self.array, self.start, self.end)
这适用于小型测试数据集,但一旦我移动到我的实际数据集 (480 000 000 x 290),我 运行 就会发生内存泄漏。我已经尝试定期注销 python 堆,因为一切都变慢了,但我看不到任何异常增加的大小,所以我使用的库 (pympler
) 实际上没有捕捉到内存泄漏。
我有点不知所措,所以如果有人知道如何进一步调试它,我们将不胜感激。
交叉发布于 PyTorch Forums。
原来我的验证例程有问题:
with torch.no_grad():
for batch in tqdm(testloader, **params):
x = batch[:, 1:].to(device)
y = batch[:, 0].unsqueeze(0).T
y_test_pred = torch.sigmoid(sxnet(x))
y_pred_tag = torch.round(y_test_pred)
y_pred_list.append(y_pred_tag.cpu().numpy())
y_list.append(y.numpy())
我原以为我很清楚 运行 将我的结果附加到列表中的麻烦,但问题是 .numpy
的结果是一个数组数组(因为原始数据类型是 1xn 张量)。
在 numpy 数组上添加 .flatten()
解决了这个问题,现在 RAM 消耗与我最初配置的一样。