如何在 Torch nn 包中禁用 omp?
How to disable omp in Torch nn package?
具体来说,我希望 nn.LogSoftMax
在输入张量较小时不使用 omp。我有一个小脚本来测试 运行 时间。
require 'nn'
my_lsm = function(t)
o = torch.zeros((#t)[1])
sum = 0.0
for i = 1,(#t)[1] do
o[i] = torch.exp(t[i])
sum = sum + o[i]
end
o = o / sum
return torch.log(o)
end
ii=torch.randn(arg[1])
m=nn.LogSoftMax()
timer = torch.Timer()
timer:stop()
timer:reset()
timer:resume()
my_lsm(ii)
print(timer:time().real)
timer:stop()
timer:reset()
timer:resume()
m:forward(ii)
print(timer:time().real)
如果 arg[1]
是 10,那么我的基本 log softmax 函数 运行 快得多:
0.00021696090698242
0.033425092697144
但是一旦 arg[1]
是 10,000,000,omp 真的很有帮助:
29.561321973801
0.11547803878784
所以我怀疑omp开销很高。如果我的代码必须使用小输入多次调用 log softmax(说张量大小只有 3),这将花费太多时间。有没有办法在某些情况下(但并非总是如此)手动禁用 omp 的使用?
Is there a way to manually disable omp usage in some cases (but not always)?
如果你真的想这样做,一种可能性是像这样使用 torch.setnumthreads
and torch.getnumthreads
:
local nth = torch.getnumthreads()
torch.setnumthreads(1)
-- do something
torch.setnumthreads(nth)
所以你可以猴子补丁 nn.LogSoftMax
如下:
nn.LogSoftMax.updateOutput = function(self, input)
local nth = torch.getnumthreads()
torch.setnumthreads(1)
local out = input.nn.LogSoftMax_updateOutput(self, input)
torch.setnumthreads(nth)
return out
end
具体来说,我希望 nn.LogSoftMax
在输入张量较小时不使用 omp。我有一个小脚本来测试 运行 时间。
require 'nn'
my_lsm = function(t)
o = torch.zeros((#t)[1])
sum = 0.0
for i = 1,(#t)[1] do
o[i] = torch.exp(t[i])
sum = sum + o[i]
end
o = o / sum
return torch.log(o)
end
ii=torch.randn(arg[1])
m=nn.LogSoftMax()
timer = torch.Timer()
timer:stop()
timer:reset()
timer:resume()
my_lsm(ii)
print(timer:time().real)
timer:stop()
timer:reset()
timer:resume()
m:forward(ii)
print(timer:time().real)
如果 arg[1]
是 10,那么我的基本 log softmax 函数 运行 快得多:
0.00021696090698242
0.033425092697144
但是一旦 arg[1]
是 10,000,000,omp 真的很有帮助:
29.561321973801
0.11547803878784
所以我怀疑omp开销很高。如果我的代码必须使用小输入多次调用 log softmax(说张量大小只有 3),这将花费太多时间。有没有办法在某些情况下(但并非总是如此)手动禁用 omp 的使用?
Is there a way to manually disable omp usage in some cases (but not always)?
如果你真的想这样做,一种可能性是像这样使用 torch.setnumthreads
and torch.getnumthreads
:
local nth = torch.getnumthreads()
torch.setnumthreads(1)
-- do something
torch.setnumthreads(nth)
所以你可以猴子补丁 nn.LogSoftMax
如下:
nn.LogSoftMax.updateOutput = function(self, input)
local nth = torch.getnumthreads()
torch.setnumthreads(1)
local out = input.nn.LogSoftMax_updateOutput(self, input)
torch.setnumthreads(nth)
return out
end