如何训练 LSTM 进行最简单的函数识别
How to train LSTM for a simplest function recognition
我正在学习 LSTM 网络并决定尝试综合测试。我希望 LSTM 网络由一些点 (x,y) 来区分三个基本功能:
- 行:y = k*x + b
- 抛物线:y = k*x^2 + b
- sqrt: y = k*sqrt(x) + b
我正在使用 lua + 手电筒。
数据集完全是虚拟的 - 它是在 'dataset' 对象上即时创建的。当训练周期要求另一个小批量样本时,函数 mt.__index returns 样本,动态创建。它随机选择三个描述的函数中的一个,并为它们选择一些随机点。
想法是 LSTM 网络会学习一些特征来识别最后的点属于哪种函数。
包括完整而简单的源脚本:
require "torch"
require "nn"
require "rnn"
-- hyper-parameters
batchSize = 8
rho = 5 -- sequence length
hiddenSize = 100
outputSize = 3
lr = 0.001
-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.size = function (self)
return 1000
end
local mt = {}
mt.__index = function (self, i)
local class = math.random(3)
local t = torch.Tensor(3):zero()
t[class] = 1
local targets = {}
for i = 1,batchSize do table.insert(targets, class) end
local inputs = {}
local k = math.random()
local b = math.random()*5
-- Line
if class == 1 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Parabola
elseif class == 2 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Sqrt
else
for i = 1,batchSize do
local x = math.random()*5 + 5
local y = k*math.sqrt(x) + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
end
return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)
-- Initialize random number generator
math.randomseed( os.time() )
-- build simple recurrent neural network
local model = nn.Sequencer(
nn.Sequential()
:add( nn.LSTM(2, hiddenSize, rho) )
:add( nn.Linear(hiddenSize, outputSize) )
:add( nn.LogSoftMax() )
)
print(model)
-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
-- training
model:training()
local epoch = 1
while true do
print ("Epoch "..tostring(epoch).." started")
for iteration = 1, dataset:size() do
-- 1. Load minibatch of samples
local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
local inputs = sample[1]
local targets = sample[2]
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
local err = criterion:forward(outputs, targets)
print(string.format("Epoch %d Iteration %d Error = %f", epoch, iteration, err))
-- 3. Backward sequence through model(i.e. backprop through time)
local gradOutputs = criterion:backward(outputs, targets)
-- Sequencer handles the backwardThroughTime internally
model:backward(inputs, gradOutputs)
model:updateParameters(lr)
model:zeroGradParameters()
end -- for dataset
epoch = epoch + 1
end -- while epoch
问题是:网络不收敛。
你能分享我做错了什么吗?
这种做法是完全错误的。由于很多原因,以这种方式学习 LSTM 并不会学到你想要的东西。我将说明其中两个:
让我们假设您从 (-1, 1)
统一绘制 x
。然后函数 |x|
和 0.5x + 0.5
将为您提供与 y
完全相同的分布。这表明您使用的方法不是最好的函数识别方法。
LSTM 的关键是它的内存允许您在输入之间存储信息。它与独立绘制点序列完全相反(您在脚本中所做的)。在你的方法中学习的每一个记忆相关性都可能只是虚假的。
我决定post我自己的答案,因为我解决了这个问题并收到了很好的效果。
首先是关于 LSTM 对这类任务的适用性。如前所述,LSTM 很适合处理时间序列。您也可以将直线、抛物线和平方根视为一种时间函数。所以 LSTM 在这里完全适用。假设您正在接收实验结果,一次一个向量,并且您想了解哪种函数可以描述您的系列?
有人可能会争辩说,在上面的代码中,我们总是得到具有固定数量点的提要 NN(即 batch_size)。那么为什么要使用 LSTM?也许尝试使用一些线性或卷积网络?
好吧,别忘了 - 这是综合测试。在现实生活中的应用程序中,您可能会向 NN 提供大量数据点,并期望它能够识别函数形式。
例如在下面的代码中我们训练 NN一次有8个点(batch_size),但是当我们test NN我们只用了4个点(test_size).
我们得到了很好的结果:大约 1000 次迭代后,NN 给出了大约 99% 的正确答案。
但是一层NN不是魔术师。如果我们在每次迭代中改变函数的形式,它就无法学习任何特征。 IE。在原始代码中,k 和 b 在每次请求时都会更改为 dataset。我们应该做的是在启动时生成它们,不要更改。
所以下面的工作代码:
require "torch"
require "nn"
require "rnn"
-- Initialize random number generator
math.randomseed( os.time() )
-- hyper-parameters
batch_size = 8
test_size = 4
rho = 5 -- sequence length
hidden_size = 100
output_size = 3
learning_rate = 0.001
-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.k = math.random()
dataset.b = math.random()*5
dataset.size = function (self)
return 1000
end
local mt = {}
mt.__index = function (self, i)
local class = math.random(3)
local t = torch.Tensor(3):zero()
t[class] = 1
local targets = {}
for i = 1,batch_size do table.insert(targets, class) end
local inputs = {}
local k = self.k
local b = self.b
-- Line
if class == 1 then
for i = 1,batch_size do
local x = math.random()*10 + 5
local y = k*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Parabola
elseif class == 2 then
for i = 1,batch_size do
local x = math.random()*10 + 5
local y = k*x*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Sqrt
else
for i = 1,batch_size do
local x = math.random()*5 + 5
local y = k*math.sqrt(x) + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
end
return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)
-- build simple recurrent neural network
local model = nn.Sequencer(
nn.Sequential()
:add( nn.LSTM(2, hidden_size, rho) )
:add( nn.Linear(hidden_size, output_size) )
:add( nn.LogSoftMax() )
)
print(model)
-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
local epoch = 1
local err = 0
local pos = 0
local N = math.floor( dataset:size() * 0.1 )
while true do
print ("Epoch "..tostring(epoch).." started")
-- training
model:training()
for iteration = 1, dataset:size() do
-- 1. Load minibatch of samples
local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
local inputs = sample[1]
local targets = sample[2]
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
local _err = criterion:forward(outputs, targets)
print(string.format("Epoch %d (pos=%f) Iteration %d Error = %f", epoch, pos, iteration, _err))
-- 3. Backward sequence through model(i.e. backprop through time)
local gradOutputs = criterion:backward(outputs, targets)
-- Sequencer handles the backwardThroughTime internally
model:backward(inputs, gradOutputs)
model:updateParameters(learning_rate)
model:zeroGradParameters()
end -- for training
-- Testing
model:evaluate()
err = 0
pos = 0
for iteration = 1, N do
-- 1. Load minibatch of samples
local sample = dataset[ math.random(dataset:size()) ]
local inputs = sample[1]
local targets = sample[2]
-- Drop last points to reduce to test_size
for i = #inputs, test_size, -1 do
inputs[i] = nil
targets[i] = nil
end
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
err = err + criterion:forward(outputs, targets)
local p = 0
for i = 1, #outputs do
local _, oi = torch.max(outputs[i], 1)
if oi[1] == targets[i] then p = p + 1 end
end
pos = pos + p/#outputs
end -- for testing
err = err / N
pos = pos / N
print(string.format("Epoch %d testing results: pos=%f err=%f", epoch, pos, err))
if (pos > 0.95) then break end
epoch = epoch + 1
end -- while epoch
我正在学习 LSTM 网络并决定尝试综合测试。我希望 LSTM 网络由一些点 (x,y) 来区分三个基本功能:
- 行:y = k*x + b
- 抛物线:y = k*x^2 + b
- sqrt: y = k*sqrt(x) + b
我正在使用 lua + 手电筒。
数据集完全是虚拟的 - 它是在 'dataset' 对象上即时创建的。当训练周期要求另一个小批量样本时,函数 mt.__index returns 样本,动态创建。它随机选择三个描述的函数中的一个,并为它们选择一些随机点。
想法是 LSTM 网络会学习一些特征来识别最后的点属于哪种函数。
包括完整而简单的源脚本:
require "torch"
require "nn"
require "rnn"
-- hyper-parameters
batchSize = 8
rho = 5 -- sequence length
hiddenSize = 100
outputSize = 3
lr = 0.001
-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.size = function (self)
return 1000
end
local mt = {}
mt.__index = function (self, i)
local class = math.random(3)
local t = torch.Tensor(3):zero()
t[class] = 1
local targets = {}
for i = 1,batchSize do table.insert(targets, class) end
local inputs = {}
local k = math.random()
local b = math.random()*5
-- Line
if class == 1 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Parabola
elseif class == 2 then
for i = 1,batchSize do
local x = math.random()*10 + 5
local y = k*x*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Sqrt
else
for i = 1,batchSize do
local x = math.random()*5 + 5
local y = k*math.sqrt(x) + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
end
return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)
-- Initialize random number generator
math.randomseed( os.time() )
-- build simple recurrent neural network
local model = nn.Sequencer(
nn.Sequential()
:add( nn.LSTM(2, hiddenSize, rho) )
:add( nn.Linear(hiddenSize, outputSize) )
:add( nn.LogSoftMax() )
)
print(model)
-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
-- training
model:training()
local epoch = 1
while true do
print ("Epoch "..tostring(epoch).." started")
for iteration = 1, dataset:size() do
-- 1. Load minibatch of samples
local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
local inputs = sample[1]
local targets = sample[2]
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
local err = criterion:forward(outputs, targets)
print(string.format("Epoch %d Iteration %d Error = %f", epoch, iteration, err))
-- 3. Backward sequence through model(i.e. backprop through time)
local gradOutputs = criterion:backward(outputs, targets)
-- Sequencer handles the backwardThroughTime internally
model:backward(inputs, gradOutputs)
model:updateParameters(lr)
model:zeroGradParameters()
end -- for dataset
epoch = epoch + 1
end -- while epoch
问题是:网络不收敛。 你能分享我做错了什么吗?
这种做法是完全错误的。由于很多原因,以这种方式学习 LSTM 并不会学到你想要的东西。我将说明其中两个:
让我们假设您从
(-1, 1)
统一绘制x
。然后函数|x|
和0.5x + 0.5
将为您提供与y
完全相同的分布。这表明您使用的方法不是最好的函数识别方法。LSTM 的关键是它的内存允许您在输入之间存储信息。它与独立绘制点序列完全相反(您在脚本中所做的)。在你的方法中学习的每一个记忆相关性都可能只是虚假的。
我决定post我自己的答案,因为我解决了这个问题并收到了很好的效果。
首先是关于 LSTM 对这类任务的适用性。如前所述,LSTM 很适合处理时间序列。您也可以将直线、抛物线和平方根视为一种时间函数。所以 LSTM 在这里完全适用。假设您正在接收实验结果,一次一个向量,并且您想了解哪种函数可以描述您的系列?
有人可能会争辩说,在上面的代码中,我们总是得到具有固定数量点的提要 NN(即 batch_size)。那么为什么要使用 LSTM?也许尝试使用一些线性或卷积网络?
好吧,别忘了 - 这是综合测试。在现实生活中的应用程序中,您可能会向 NN 提供大量数据点,并期望它能够识别函数形式。
例如在下面的代码中我们训练 NN一次有8个点(batch_size),但是当我们test NN我们只用了4个点(test_size).
我们得到了很好的结果:大约 1000 次迭代后,NN 给出了大约 99% 的正确答案。
但是一层NN不是魔术师。如果我们在每次迭代中改变函数的形式,它就无法学习任何特征。 IE。在原始代码中,k 和 b 在每次请求时都会更改为 dataset。我们应该做的是在启动时生成它们,不要更改。
所以下面的工作代码:
require "torch"
require "nn"
require "rnn"
-- Initialize random number generator
math.randomseed( os.time() )
-- hyper-parameters
batch_size = 8
test_size = 4
rho = 5 -- sequence length
hidden_size = 100
output_size = 3
learning_rate = 0.001
-- Initialize synthetic dataset
-- dataset[index] returns table of the form: {inputs, targets}
-- where inputs is a set of points (x,y) of a randomly selected function: line, parabola, sqrt
-- and targets is a set of corresponding class of a function (1=line, 2=parabola, 3=sqrt)
local dataset = {}
dataset.k = math.random()
dataset.b = math.random()*5
dataset.size = function (self)
return 1000
end
local mt = {}
mt.__index = function (self, i)
local class = math.random(3)
local t = torch.Tensor(3):zero()
t[class] = 1
local targets = {}
for i = 1,batch_size do table.insert(targets, class) end
local inputs = {}
local k = self.k
local b = self.b
-- Line
if class == 1 then
for i = 1,batch_size do
local x = math.random()*10 + 5
local y = k*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Parabola
elseif class == 2 then
for i = 1,batch_size do
local x = math.random()*10 + 5
local y = k*x*x + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
-- Sqrt
else
for i = 1,batch_size do
local x = math.random()*5 + 5
local y = k*math.sqrt(x) + b
input = torch.Tensor(2)
input[1] = x
input[2] = y
table.insert(inputs, input)
end
end
return { inputs, targets }
end -- dataset.__index meta function
setmetatable(dataset, mt)
-- build simple recurrent neural network
local model = nn.Sequencer(
nn.Sequential()
:add( nn.LSTM(2, hidden_size, rho) )
:add( nn.Linear(hidden_size, output_size) )
:add( nn.LogSoftMax() )
)
print(model)
-- build criterion
local criterion = nn.SequencerCriterion( nn.ClassNLLCriterion() )
local epoch = 1
local err = 0
local pos = 0
local N = math.floor( dataset:size() * 0.1 )
while true do
print ("Epoch "..tostring(epoch).." started")
-- training
model:training()
for iteration = 1, dataset:size() do
-- 1. Load minibatch of samples
local sample = dataset[iteration] -- pick random sample (dataset always returns random set)
local inputs = sample[1]
local targets = sample[2]
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
local _err = criterion:forward(outputs, targets)
print(string.format("Epoch %d (pos=%f) Iteration %d Error = %f", epoch, pos, iteration, _err))
-- 3. Backward sequence through model(i.e. backprop through time)
local gradOutputs = criterion:backward(outputs, targets)
-- Sequencer handles the backwardThroughTime internally
model:backward(inputs, gradOutputs)
model:updateParameters(learning_rate)
model:zeroGradParameters()
end -- for training
-- Testing
model:evaluate()
err = 0
pos = 0
for iteration = 1, N do
-- 1. Load minibatch of samples
local sample = dataset[ math.random(dataset:size()) ]
local inputs = sample[1]
local targets = sample[2]
-- Drop last points to reduce to test_size
for i = #inputs, test_size, -1 do
inputs[i] = nil
targets[i] = nil
end
-- 2. Perform forward run and calculate error
local outputs = model:forward(inputs)
err = err + criterion:forward(outputs, targets)
local p = 0
for i = 1, #outputs do
local _, oi = torch.max(outputs[i], 1)
if oi[1] == targets[i] then p = p + 1 end
end
pos = pos + p/#outputs
end -- for testing
err = err / N
pos = pos / N
print(string.format("Epoch %d testing results: pos=%f err=%f", epoch, pos, err))
if (pos > 0.95) then break end
epoch = epoch + 1
end -- while epoch