如何降低 O(N^2) C-index 函数的时间复杂度?

How to reduce the time complexity of my O(N^2) C-index function?

我有以下函数(在 Matlab 中),它将计算一组给定的预测和观察值的一致性指数:

function civalue = CI(predval)
% FUNCTION civalue = CI(predval)
%
% DESCRIPTION: 
% - This function will calculate the concordance index. Not suitable for
% big vectors. O(n^2) time function. 
%
% INPUTS: 
% 'predval' a n-by-2 matrix, where the first column consists of the
% prediction values and the second column the actual label values. 
%
% OUTPUT: 
% 'civalue' the CI-value.

N = 0;
hSum = 0;

for i = 1:size(predval, 1)

    yi_pred = predval(i, 1);
    yi_val = predval(i, 2);
    for j = i+1:size(predval, 1)
        yj_pred = predval(j, 1);
        yj_val = predval(j, 2);
        if yi_val ~= yj_val
            N = N + 1;
            if  (yi_pred < yj_pred && yi_val < yj_val) || (yi_pred > yj_pred && yi_val > yj_val) % Order correct
                hSum = hSum + 1;
            elseif (yi_pred < yj_pred && yi_val > yj_val) || (yi_pred > yj_pred && yi_val < yj_val) % Order opposite 
                hSum = hSum + 0;
            elseif yi_pred == yj_pred % Random
                hSum = hSum + 0.5;
            end
        end
    end

end

civalue = hSum / N;

我的函数的时间复杂度为 O(N^2)。代码的想法是在数据点之间进行成对比较。有什么想法可以降低代码的时间复杂度吗?

CI 值或 C 指数背后的想法是衡量预测模型能够将数据点排列成正确顺序的程度。你给这个函数的是一组观测值 X 和它们对应的预测值 Y。该函数将对具有不同观测值的数据点进行排名比较,因为它们显然具有排名。

例如,假设某个变量有两个观测值,例如股票价格:P1 = 5$, P2 = 7$

现在我们创建一个模型来尝试预测股票价格。假设我们建立了模型并测试了它预测股票价格的能力,对于两个数据点 P1、P2,它预测值 Y1 = 5.5$ 和 Y2 = 8$。

现在您可以看到模型得到的 ORDER 是正确的,P1 < P2 && Y1 < Y2 但不是绝对值。当我们需要在一组备选方案之间进行选择时,这很有用,例如我应该买哪只股票最能增值等。

感谢大家的帮助!如果您需要更多信息等,请告诉我。:)

以下是我自己的实现与 Martin 的实现之间的比较:

您可以通过向量化内部循环显着缩短 运行 时间。下面的代码可以进一步优化(以牺牲易读性为代价)。在我的机器上,使用随机输入,代码 运行s 快了大约 50 倍并产生相同的结果。 (随机输入可能是一个糟糕的测试用例,因为 == 分支永远不会执行)

N = 0;
hSum = 0;
for i = 1:size(predval, 1)

    yi_pred = predval(i, 1);
    yi_val = predval(i, 2);
    yj_pred = predval(i+1:end,1);
    yj_val = predval(i+1:end,2);
    idxs = yi_val ~= yj_val;
    N = N + sum(idxs);

    yj_pred = yj_pred(idxs); % redefined to make the next lines prettier
    yj_val = yj_val(idxs); 
    hSum = hSum + sum((yi_pred < yj_pred & yi_val < yj_val) | ...
        (yi_pred > yj_pred & yi_val > yj_val)); % Order correct
    hSum = hSum + 0.5*sum(yi_pred == yj_pred); % Order random
end

虽然函数的复杂度仍然是 O(n^2)。

假设您的最终目标是提高 运行 时间性能,并且如果您有足够的记忆力来 运行 矢量化方法,这可能是其中之一 -

%// Column arrays
c1 = predval(:,1);
c2 = predval(:,2);

%// Get logical arrays of IF conditional statements in the original code
start_cond = bsxfun(@ne,c2,c2.')               %//'# starting condition

%// Rest of the three IF conditionals
case1 = bsxfun(@lt,c1,c1.') & bsxfun(@lt,c2,c2.') | ...
    bsxfun(@gt,c1,c1.') & bsxfun(@gt,c2,c2.')  %//'
case2 = bsxfun(@lt,c1,c1.') & bsxfun(@gt,c2,c2.') | ...
    bsxfun(@gt,c1,c1.') & bsxfun(@lt,c2,c2.')  %//'
case3 = bsxfun(@eq,c1,c1.')                    %//'

%// Get the counts for different cases and finally get the output sum
w1 = start_cond & case1
w2 = start_cond & ~case1 & ~case2 & case3
hSum = sum(w1(:))./2 + sum(w2(:))./4