逻辑回归不泛化

Logistic regression not generalizing

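根据 Andrew Ng 在 Coursera 上的逻辑回归讲座,可以使用以下更新表达式来最小化下面的成本函数:

$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log h_\theta\left(x^{(i)}\right) + \left(1-y^{(i)}\right)\log\left(1-h_\theta\left(x^{(i)}\right)\right)\right],\qquad h_\theta(x) = \frac{1}{1+e^{-\theta^{T}x}}$$

$$\theta_j := \theta_j - \alpha\sum_{i=1}^{m}\left(h_\theta\left(x^{(i)}\right) - y^{(i)}\right)x_j^{(i)}$$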

在约 150 个样本上运行该更新函数数百次后,我得到了以下模式,尽管每次迭代后成本似乎都按预期下降:

圆圈是我用来训练的样本,输入特征是每个点的 (x, y) 坐标,颜色是目标标签。红色或黄色背景表示模型对该 (x, y) 输入的预测分类(红色 = 0,黄色 = 1)。

问题


训练方法

// A single pass/epoch of batch gradient descent on the logistic loss.
//
// NOTE(review): relies on `samples`, `labels`, `sig` and `sum` being defined
// elsewhere. Because `samples`/`labels` are declared with `const`, this must run
// after those declarations in the same script (temporal dead zone otherwise).

const lr = 0.003;
let params = [0.5, 0.5, 0.5];

// Predicted probability per sample: sigmoid of the dot product with params.
const scores = samples.map(sample => sig(sum(sample, params)));
// Signed error per sample: prediction minus target label.
const errors = scores.map((score, i) => score - labels[i][0]);

// Move each weight against its gradient component, sum over rows of error * feature.
// No 1/m averaging is applied, so the step grows with the number of samples; with
// raw coordinates in the hundreds as features the step is likely to overshoot.
params = params.map((param, col) => {
  return param - lr * errors.reduce((acc, error, row) => {
    return acc + error * samples[row][col];
  }, 0);
});

样本训练数据

// Excerpt of the training inputs (the question mentions ~150 samples; 10 are shown).
// Each row is [1, x, y]: a constant 1 that pairs with params[0] as the bias term,
// followed by the point's (x, y) coordinates, which are in the hundreds (not rescaled).
const samples = [
  [1, 142, 78],
  [1, 108, 182],
  [1, 396, 47],
  [1, 66,  102],
  [1, 165, 116],
  [1, 8,   106],
  [1, 245, 119],
  [1, 302, 17],
  [1, 96,  38],
  [1, 201, 132],
];

// Target class per sample (0 = red, 1 = yellow), each wrapped in a one-element
// row so that it is read as labels[i][0].
// NOTE(review): only 9 labels are listed for the 10 rows in `samples`; with this
// excerpt labels[9] is undefined and `labels[i][0]` in the training pass would
// throw a TypeError. The missing label needs to be restored from the real data.
const labels = [
  [0],
  [1],
  [0],
  [0],
  [1],
  [1],
  [1],
  [0],
  [1],
];

编辑

这是一个 JSBin:https://jsbin.com/jinole/edit?html,js,output

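你的问题纯粹是数值上的:你直接实现了逻辑损失,因此函数 J 需要对每个点的线性组合 z = θᵀx 取指数。同时,你的数据数值很大——x/y 坐标都在几百的量级,所以 |z| 可以达到数百。在这种情况下,sigmoid 在浮点数下会饱和为恰好 0 或 1(exp(400) 本身约为 5.2e173,并不是 NaN)。随后 Math.log(0) 得到 -Infinity,而 0 * -Infinity 在 JS 中结果为 NaN;同时梯度步长也会过大,因此您的整个代码无法收敛。您需要做的只是把点缩放到 [0,4] × [0,2](x × y)的矩形中,而不是 [0,400] × [0,200](正如下面的示例代码把坐标除以 100 那样),它就会工作得很好。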

例如:

/**
 * Dot product of a feature row and a weight vector.
 *
 * @param {number[]} x - feature values; the result iterates over x's length.
 * @param {number[]} w - weights, indexed in parallel with `x`.
 * @returns {number} the sum of x[i] * w[i] (0 when `x` is empty).
 */
function sum(x, w) {
  const products = x.map((value, i) => value * w[i]);
  return products.reduce((total, product) => total + product, 0);
}

/**
 * Logistic sigmoid: 1 / (1 + e^(-z)).
 *
 * For large |z| the floating-point result saturates to exactly 0 or 1.
 *
 * @param {number} z - linear score.
 * @returns {number} a value in [0, 1].
 */
function sig(z) {
  const expNegZ = Math.exp(-z);
  return 1 / (expNegZ + 1);
}

/**
 * Mean binary cross-entropy (logistic loss) over a batch:
 *   J = -(1/m) * sum over i of [ y_i * log(s_i) + (1 - y_i) * log(1 - s_i) ]
 *
 * @param {number[]} scores - predicted probabilities, one per sample.
 * @param {number[][]} labels - labels[i][0] is the target for scores[i].
 * @returns {number} the mean loss; NaN for an empty batch (0 / 0). Infinity is
 *   possible when a saturated prediction is exactly wrong.
 */
function cost(scores, labels) {
  const total = scores.reduce((acc, score, i) => {
    const y = labels[i][0];
    // Skip zero-weight terms: when the sigmoid saturates to exactly 0 or 1,
    // Math.log returns -Infinity and 0 * -Infinity is NaN, which would poison
    // the whole sum even though that term contributes nothing.
    const positiveTerm = y === 0 ? 0 : y * Math.log(score);
    const negativeTerm = y === 1 ? 0 : (1 - y) * Math.log(1 - score);
    // Accumulate: the previous version returned only the current term,
    // so the result reflected just the last sample.
    return acc + positiveTerm + negativeTerm;
  }, 0);
  return -(1 / scores.length) * total;
}

/**
 * Clear the entire drawing surface of a 2D canvas context.
 *
 * @param {CanvasRenderingContext2D} ctx - context whose backing canvas is cleared.
 */
function clear(ctx) {
  // Read the size from the context's own canvas instead of hard-coding 400x200,
  // so this stays correct if the canvas dimensions change.
  const { width, height } = ctx.canvas;
  ctx.clearRect(0, 0, width, height);
}

/**
 * Draw each labelled point as a 4x4 square, scaling its coordinates by 100.
 *
 * Points with a positive label are drawn blue (#3c5cff), all others magenta
 * (#f956ff). The square's top-left corner is clamped to be non-negative.
 *
 * @param {CanvasRenderingContext2D} ctx - target context.
 * @param {Array<[number, number, number]>} points - [x, y, label] triples.
 */
function render(ctx, points) {
  points.forEach(([px, py, label]) => {
    ctx.fillStyle = label > 0 ? '#3c5cff' : '#f956ff';
    const left = Math.max(0, px * 100 - 2);
    const top = Math.max(0, py * 100 - 2);
    ctx.fillRect(left, top, 4, 4);
  });
}

/**
 * Paint the model's decision surface pixel by pixel.
 *
 * Each canvas pixel (x, y) is mapped to model input [1, x / 100, y / 100]; pixels
 * whose predicted probability is below 0.5 are painted #b22438 (class 0), the
 * rest #fff9b6 (class 1).
 *
 * @param {CanvasRenderingContext2D} ctx - target context; its canvas size sets the area painted.
 * @param {number[]} params - model weights [bias, wx, wy].
 */
function renderEach(ctx, params) {
  // Use the real canvas size rather than the hard-coded 400x200.
  const { width, height } = ctx.canvas;
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const probability = sig(sum([1, x / 100, y / 100], params));
      ctx.fillStyle = probability < 0.5 ? '#b22438' : '#fff9b6';
      ctx.fillRect(x, y, 1, 1);
    }
  }
}

/**
 * Run one batch gradient-descent epoch, then schedule the next one via setTimeout.
 *
 * Every 100th epoch, and always on the final epoch, it writes a status line into
 * the #log paragraph (created on first use) and redraws the decision surface and
 * the sample points.
 *
 * NOTE(review): reads the globals `labels`, `ctx` and `points` instead of taking
 * them as parameters.
 *
 * @param {number[][]} samples - rows of features, bias column first.
 * @param {number[]} params - current weights, same length as each sample row.
 * @param {number} learningRate - step size applied to the summed gradient.
 * @param {number|null} lastCost - cost from the previous epoch, or null on the first call.
 * @param {number} cycle - zero-based index of this epoch.
 * @param {number} maxCycles - total number of epochs to run (cycles 0 .. maxCycles - 1);
 *   the initial call always runs at least one epoch.
 */
function doEpoch(samples, params, learningRate, lastCost, cycle, maxCycles) {
  const scores = samples.map(sample => sig(sum(sample, params)));
  const errors = scores.map((score, i) => score - labels[i][0]);

  let p = document.getElementById('log');
  if (!p) {
    p = document.createElement('p');
    p.setAttribute('id', 'log');
    document.body.appendChild(p);
  }

  // Batch step: each weight moves against the sum over rows of error * feature.
  // No 1/m averaging is applied, so the step scales with the number of samples.
  const nextParams = params.map((param, col) =>
    param - learningRate * errors.reduce((acc, error, row) => acc + error * samples[row][col], 0)
  );

  // J is the cost of the weights *before* this step (scores came from `params`).
  const J = cost(scores, labels);
  const previousCost = lastCost === null ? J : lastCost;
  const isLastEpoch = cycle + 1 >= maxCycles;

  if (cycle % 100 === 0 || isLastEpoch) {
    p.textContent = `Epoch = ${cycle}, Cost = ${J} (${J - previousCost}), Params = ${JSON.stringify(nextParams, null, 2)}`;
    clear(ctx);
    renderEach(ctx, nextParams);
    render(ctx, points);
  }

  // Stop after exactly maxCycles epochs.
  if (!isLastEpoch) {
    setTimeout(() => {
      doEpoch(samples, nextParams, learningRate, J, cycle + 1, maxCycles);
    }, 10);
  }
}

// Create the 400x200 drawing canvas, attach it to the page and grab its 2D context.
const canvas = document.createElement('canvas');
canvas.width = 400;
canvas.height = 200;
document.body.appendChild(canvas);
const ctx = canvas.getContext('2d');

// Generate 500 random [x, y, label] points. Coordinates are whole canvas pixels
// divided by 100, so x is in [0, canvas.width / 100) and y in [0, canvas.height / 100).
// The label is 1 when y <= lineY / 100 and 0 otherwise.
const lineY = 150;
const points = [];
for (let i = 0; i < 500; i++) {
  // Math.floor truncates the number directly. parseInt would first convert it to a
  // string and misparse tiny values written in exponent form (e.g. "5e-7" -> 5).
  const x = Math.floor(Math.random() * canvas.width) / 100;
  const y = Math.floor(Math.random() * canvas.height) / 100;
  points.push([x, y, Number(y <= lineY / 100)]);
}

render(ctx, points);

// Split the points into feature rows [x, y] and one-element label rows [label].
const samples = points.map(point => [point[0], point[1]]);
const labels = points.map(point => [point[2]]);

console.log('Samples', JSON.stringify(samples.slice(0, 10)));
console.log('Labels', JSON.stringify(labels.slice(0, 10)));

// Initial weights: bias weight 1, then one random weight per feature.
const params = [1, ...samples[0].map(() => Math.random())];
// Prepend a constant 1 to every row so params[0] acts as the bias term.
const withBias = samples.map(sample => [1, ...sample]);

const epochs = 100000;
const learningRate = 0.01;
const lastCost = null;

doEpoch(withBias, params, learningRate, lastCost, 0, epochs);
/* Page chrome: light grey background, no default spacing, and monospace text
   (inherited by the status paragraph the script appends to the body). */
body {
  background: #eee;
  padding: 0;
  margin: 0;
  font-family: monospace;
}

/* Stretch the 400x200 canvas to the page width; `pixelated` keeps the
   upscaled pixels sharp instead of smoothing them. */
canvas {
  background: #fff;
  width: 100%;
  image-rendering: pixelated;
}
<!-- NOTE(review): not referenced by the visible script, which appends its canvas and #log paragraph directly to document.body. -->
<div id="plot-app"></div>