R 神经网络对线性函数 f(x)=19x 采取大量步骤

R neuralnet taking huge number of steps for linear function f(x)=19x

场景是这样的:-

我 运行 以下代码创建神经网络(使用 R 中的 neuralnet 包)来逼近函数 f(x)=x^2:-

set.seed(2016);
rm(list=ls());
# Prepare Training Data
attribute<-as.data.frame(sample(seq(-2,2,length=50), 50 , replace=FALSE ),ncol=1);
response<-attribute^2;
data <- cbind(attribute,response);
colnames(data)<- c("attribute","response");

# Create DNN
fit<-neuralnet(response~attribute,data=data,hidden = c(3,3),threshold=0.01);
fit$result.matrix;

这工作得很好,收敛于 3191 步。现在我对代码做了一点小改动——我改变了被逼近的函数。我没有使用二次函数,而是使用了一个非常简单的线性函数 f(x)=2x。这也很好用,然后我调整了 x 的系数并进行了多次运行,例如

f(x) = 2x
f(x) = 3x
.
.
f(x) = 19x

到目前为止它运行良好。但我注意到的一件事是,收敛所需的步骤数量从 2 倍急剧增加到 19 倍。例如 19x 的步数是惊人的 84099。奇怪的是,DNN 仅对线性函数采取了如此多的步骤来收敛,而对于二次函数 f(x)=x^2,它仅花费了 3191 步。

所以当我将函数更改为 f(x)=20x 时,它可能需要更多步骤,所以我收到以下警告:-

> set.seed(2016);
> rm(list=ls());
> # Prepare Training Data
> attribute<-as.data.frame(sample(seq(-2,2,length=50), 50 , replace=FALSE ),ncol=1);
> response<-attribute*20;
> data <- cbind(attribute,response);
> colnames(data)<- c("attribute","response");
> 
> # Create DNN
> fit<-neuralnet(response~attribute,data=data,hidden = c(3,3),threshold=0.01);
Warning message:
algorithm did not converge in 1 of 1 repetition(s) within the stepmax 
> fit$result.matrix;

所以我想我可以调整默认的 stepmax 参数并增加步数。但真正的问题是——为什么仅仅为了这样一个简单的线性函数就需要这么多步骤?

我认为这里最重要的是您需要扩展数据。随着你的值变大,神经网络需要覆盖更大范围的结果(这可以使用更复杂的技术来处理,比如不同形式的 momentum 但这不是这里的重点)。如果您只是像这样缩放数据:

maxs <- apply(data, 2, max) 
mins <- apply(data, 2, min)

raw.sc <- scale(data, center = mins, scale = maxs - mins)
scaled <- as.data.frame(raw.sc)

您现在可以更快地创建神经网络。注意 - 线性函数不需要多层。在这里,我用一个单层、一个节点的网络来演示这一点。这对 'deep learning'.

来说不是真正的问题
set.seed(123)
# Create NN
fit<-neuralnet(response~attribute,data=scaled,hidden = c(1),threshold=0.01);


# make sure to use normalized input and de-normalize the output
pr.nn <- compute(fit, scaled$attribute)
pr.nn_ <- pr.nn$net.result*attr(raw.sc, 'scaled:scale')[2] + attr(raw.sc, 'scaled:center')[2]

在这种情况下,它收敛于 1349 步。然后你计算一些指标,比如 MSE。

# MSE
sum((data$response - pr.nn_)^2)/nrow(data)