Torch out of memory in thread when using torch.serialize twice
I'm trying to add a parallel dataloader to torch-dataframe in order to add torchnet compatibility. I've used the tnt.ParallelDatasetIterator and changed it so that (a minimal sketch of the round trip follows the list):
- the basic batch is loaded outside the threads
- the batch is serialized and sent to the thread
- in the thread the batch is deserialized and the batch data is converted to tensors
- the tensors are returned in a table with input and target keys in order to match the tnt.Engine setup.
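Here is a minimal, self-contained sketch of that serialize/deserialize round trip with the threads package (the batch contents and pool size are made up for illustration, not taken from torch-dataframe):

local torch = require 'torch'
local threads = require 'threads'

-- two workers; each loads torch so deserialization works there
local pool = threads.Threads(2, function() require 'torch' end)

local batch = {data = torch.rand(4, 3), label = torch.Tensor{1, 2, 1, 2}}
local serialized = torch.serialize(batch)  -- a plain string, safe to hand across threads

pool:addjob(
  function(payload)
    local b = torch.deserialize(payload)  -- rebuild the batch inside the worker
    return b.data:sum()                   -- stand-in for the expensive to_tensor step
  end,
  function(sum) print('batch sum:', sum) end,
  serialized
)

pool:synchronize()
pool:terminate()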
The problem occurs the second time enqueue is called, with the error: .../torch_distro/install/bin/luajit: not enough memory. I'm currently only working with mnist and an adapted mnist-example. The enqueue loop now looks like this (with the debugging memory output):
-- `samplePlaceholder` stands in for samples which have been
-- filtered out by the `filter` function
local samplePlaceholder = {}

-- The enqueue does the main loop
local idx = 1
local function enqueue()
  while idx <= size and threads:acceptsjob() do
    local batch, reset = self.dataset:get_batch(batch_size)

    if (reset) then
      idx = size + 1
    else
      idx = idx + 1
    end

    if (batch) then
      local serialized_batch = torch.serialize(batch)

      -- In the parallel section only the to_tensor is run in parallel,
      -- as this should be the computationally expensive operation
      threads:addjob(
        function(argList)
          io.stderr:write("\n Start")
          io.stderr:write("\n 1: " .. tostring(collectgarbage("count")))
          local origIdx, serialized_batch, samplePlaceholder = unpack(argList)
          io.stderr:write("\n 2: " .. tostring(collectgarbage("count")))

          local batch = torch.deserialize(serialized_batch)
          serialized_batch = nil

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 3: " .. tostring(collectgarbage("count")))

          batch = transform(batch)
          io.stderr:write("\n 4: " .. tostring(collectgarbage("count")))

          local sample = samplePlaceholder
          if (filter(batch)) then
            sample = {}
            sample.input, sample.target = batch:to_tensor()
          end
          io.stderr:write("\n 5: " .. tostring(collectgarbage("count")))

          collectgarbage()
          collectgarbage()
          io.stderr:write("\n 6: " .. tostring(collectgarbage("count")))
          io.stderr:write("\n End \n")

          return {
            sample,
            origIdx
          }
        end,
        function(argList)
          sample, sampleOrigIdx = unpack(argList)
        end,
        {idx, serialized_batch, samplePlaceholder}
      )
    end
  end
end
I've sprinkled collectgarbage calls around and also tried removing any objects that aren't needed. The memory output is rather straightforward:
Start
1: 374840.87695312
2: 374840.94433594
3: 372023.79101562
4: 372023.85839844
5: 372075.41308594
6: 372023.73632812
End
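As an aside on the paired collectgarbage calls above: collectgarbage("count") reports the Lua heap in kilobytes, and a single full collection may not reclaim objects with finalizers, which only get freed on the following cycle; hence calling it twice. A tiny illustration:

local before = collectgarbage("count")  -- Lua heap size in KB
collectgarbage()  -- first pass: runs finalizers
collectgarbage()  -- second pass: frees what finalization kept alive
print(("freed %.1f KB"):format(before - collectgarbage("count")))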
The function that loops the enqueue is the non-ordered iterator function, and it is trivial (the memory error is thrown at the second enqueue):
iterFunction = function()
  while threads:hasjob() do
    enqueue()
    threads:dojob()
    if threads:haserror() then
      threads:synchronize()
    end
    enqueue()

    if table.exact_length(sample) > 0 then
      return sample
    end
  end
end
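For context on the hand-off above: addjob schedules the first function on a worker thread, while dojob runs one finished job's end-callback back on the main thread. A stripped-down sketch of that mechanism (values illustrative):

local threads = require 'threads'
local pool = threads.Threads(2)

local result
pool:addjob(
  function() return 21 * 2 end,  -- runs in a worker thread
  function(x) result = x end     -- runs on the main thread inside dojob()
)
while pool:hasjob() do pool:dojob() end
print(result)  -- 42
pool:terminate()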
So the problem was torch.serialize, where a function defined in the setup coupled the entire dataset to that function. When adding:

serialized_batch = nil
collectgarbage()
collectgarbage()

the problem was resolved. I further wanted to know what was taking up so much space, and the culprit turned out to be that I had defined the function in an environment with a large dataset that got intertwined with the function, massively increasing its size. Here is the original definition of the data:
local mnist = require 'mnist'
local dataset = mnist[mode .. 'dataset']()

-- PROBLEMATIC LINE BELOW --
local ext_resource = dataset.data:reshape(
  dataset.data:size(1),
  dataset.data:size(2) * dataset.data:size(3)):double()

-- Create a Dataframe with the label. The actual images will be loaded
-- as an external resource
local df = Dataframe(
  Df_Dict{
    label = dataset.label:totable(),
    row_id = torch.range(1, dataset.data:size(1)):totable()
  })

-- Since the mnist package already has taken care of the data
-- splitting we create a single subsetter
df:create_subsets{
  subsets = Df_Dict{core = 1},
  class_args = Df_Tbl({
    batch_args = Df_Tbl({
      label = Df_Array("label"),
      data = function(row)
        return ext_resource[row.row_id]
      end
    })
  })
}
It turns out that removing the highlighted line brings the memory usage down from 358 MB to 0.0008 MB! The code that I used for testing the performance was:
local mem = {}
table.insert(mem, collectgarbage("count"))

local ser_data = torch.serialize(batch.dataset)
table.insert(mem, collectgarbage("count"))

local ser_retriever = torch.serialize(batch.batchframe_defaults.data)
table.insert(mem, collectgarbage("count"))

local ser_raw_retriever = torch.serialize(function(row)
  return ext_resource[row.row_id]
end)
table.insert(mem, collectgarbage("count"))

local serialized_batch = torch.serialize(batch)
table.insert(mem, collectgarbage("count"))

for i=2,#mem do
  print(i-1, (mem[i] - mem[i-1])/1024)
end
which initially produced the output:
1 0.0082607269287109
2 358.23344707489
3 0.0017471313476562
4 358.90182781219
and after the fix:
1 0.0094480514526367
2 0.00080204010009766
3 0.00090408325195312
4 0.010146141052246
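These numbers are consistent with torch.serialize writing out a closure's upvalues together with its bytecode, so any large local captured by the function gets serialized with it. A minimal demonstration of the effect (sizes illustrative, chosen to mimic mnist):

local torch = require 'torch'

local big = torch.DoubleTensor(60000, 784):zero()  -- roughly 358 MB of doubles
local fat = function(i) return big[i] end           -- captures `big` as an upvalue
local lean = function(i) return i * 2 end           -- captures nothing

print(#torch.serialize(fat) / 1024^2)   -- hundreds of MB
print(#torch.serialize(lean) / 1024^2)  -- effectively zero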
I tried using setfenv for the function but it didn't solve the issue; setfenv only swaps a function's global environment, while ext_resource is a local captured as an upvalue and therefore travels with the closure regardless. There is still a performance penalty of sending the serialized data to the threads, but the main problem is resolved, and without the expensive data retriever the functions are considerably smaller.
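As a closing, purely hypothetical sketch (make_retriever is not part of torch-dataframe): building the retriever in a small helper keeps its only upvalue the raw mnist byte tensor rather than the pre-materialized double copy, which should shrink the serialized closure roughly eightfold:

-- Illustrative only: derive each row on demand instead of closing over
-- the full 60000 x 784 double tensor
local function make_retriever(raw)
  return function(row)
    return raw[row.row_id]:reshape(raw:size(2) * raw:size(3)):double()
  end
end

data = make_retriever(dataset.data)

The closure still carries the raw ByteTensor, but it avoids the eight-byte-per-pixel double copy that dominated the 358 MB measurement.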