tensorflow.js 的估计参数的标准误差
Standard error of estimated parameters with tensorflow.js
我正在使用 tensorflow.js 获取指数回归的参数:
y(x) = c0*e^(kx)
我附上代码:
x = tf.tensor1d(x);
y = tf.tensor1d(y);
const c0 = tf.scalar().variable();
const k = tf.scalar(this.regression_parameters.array[index][0]).variable();
// y = c0*e^(k*x)
const fun = (x) => x.mul(k).exp().mul(c0);
const cost = (pred, label) => pred.sub(label).square().mean();
const learning_rate = 0.1;
const optimizer = tf.train.adagrad(learning_rate);
// Train the model.
for (let i = 0; i < 500; i++) {
optimizer.minimize(() => cost(fun(x), y));
}
非常符合实验信号。但是,我需要报告估计值(c0 和 k)的标准误差,因为它在 SciPy 中由 curve_fit() 给出。我想知道这是否可以用 tensorflow.js 完成。如果没有,是否还有其他 JavaScript 库可能有用?谢谢!
对于任何感兴趣的人,我最终自己估计指数函数的 Hessian 矩阵以计算 tensorflow.js 优化后参数的标准误差。我遵循了以下代码:
https://www.cpp.edu/~pbsiegel/javascript/curvefitchi.html
如下:
function hess_exp_errors(c0, k, x, y){
var sum1=0,sum2=0,sum3=0,sum4=0,sum5=0, expon1,
he11, he12, he22, dett, dof, k_err, c0_error;
for (i=0;i<x.length;i++){
expon1=Math.exp(k*x[i]);
sum1+=expon1*expon1;
sum2+=y[i]*x[i]*x[i]*expon1;
sum3+=x[i]*x[i]*expon1*expon1;
sum4+=y[i]*x[i]*expon1;
sum5+=x[i]*expon1*expon1;
}
he11=4*c0*c0*sum3-2*c0*sum2; he22=2*sum1;
he12=4*c0*sum5-2*sum4; dett=he11*he22-he12*he12;
c0_err=Math.sqrt(he11/dett); k_err=Math.sqrt(he22/dett);
return [c0_err, k_err];
};
我正在使用 tensorflow.js 获取指数回归的参数:
y(x) = c0*e^(kx)
我附上代码:
x = tf.tensor1d(x);
y = tf.tensor1d(y);
const c0 = tf.scalar().variable();
const k = tf.scalar(this.regression_parameters.array[index][0]).variable();
// y = c0*e^(k*x)
const fun = (x) => x.mul(k).exp().mul(c0);
const cost = (pred, label) => pred.sub(label).square().mean();
const learning_rate = 0.1;
const optimizer = tf.train.adagrad(learning_rate);
// Train the model.
for (let i = 0; i < 500; i++) {
optimizer.minimize(() => cost(fun(x), y));
}
非常符合实验信号。但是,我需要报告估计值(c0 和 k)的标准误差,因为它在 SciPy 中由 curve_fit() 给出。我想知道这是否可以用 tensorflow.js 完成。如果没有,是否还有其他 JavaScript 库可能有用?谢谢!
对于任何感兴趣的人,我最终自己估计指数函数的 Hessian 矩阵以计算 tensorflow.js 优化后参数的标准误差。我遵循了以下代码: https://www.cpp.edu/~pbsiegel/javascript/curvefitchi.html
如下:
function hess_exp_errors(c0, k, x, y){
var sum1=0,sum2=0,sum3=0,sum4=0,sum5=0, expon1,
he11, he12, he22, dett, dof, k_err, c0_error;
for (i=0;i<x.length;i++){
expon1=Math.exp(k*x[i]);
sum1+=expon1*expon1;
sum2+=y[i]*x[i]*x[i]*expon1;
sum3+=x[i]*x[i]*expon1*expon1;
sum4+=y[i]*x[i]*expon1;
sum5+=x[i]*expon1*expon1;
}
he11=4*c0*c0*sum3-2*c0*sum2; he22=2*sum1;
he12=4*c0*sum5-2*sum4; dett=he11*he22-he12*he12;
c0_err=Math.sqrt(he11/dett); k_err=Math.sqrt(he22/dett);
return [c0_err, k_err];
};