使用 bsxfun 加速 Matlab 嵌套 for 循环
Accelerate Matlab nested for loop with bsxfun
我有一个图 n x n
图 W
描述为它的邻接矩阵和每个节点的组标签(整数)的 n
向量。
我需要为每对组计算 c
组中的节点和 d
组中的节点之间的 links(边)数。为此,我写了一个嵌套的 for loop
,但我确信这不是计算代码中称为 mcd
的矩阵的最快方法,即计算边数的矩阵c
和 d
组。
是否可以通过 bsxfun
使此操作更快?
function mcd = interlinks(W,ci)
%// W is the adjacency matrix of a simple undirected graph
%// ci are the group labels of every node in the graph, can be from 1 to |C|
n = length(W); %// number of nodes in the graph
m = sum(nonzeros(triu(W))); %// number of edges in the graph
ncomms = length(unique(ci)); %// number of groups of ci
mcd = zeros(ncomms); %// this is the matrix that counts the number of edges between group c and group d, twice the number of it if c==d
for c=1:ncomms
nodesc = find(ci==c); %// nodes in group c
for d=1:ncomms
nodesd = find(ci==d); %// nodes in group d
M = W(nodesc,nodesd); %// submatrix of edges between c and d
mcd(c,d) = sum(sum(M)); %// count of edges between c and d
end
end
%// Divide diagonal half because counted twice
mcd(1:ncomms+1:ncomms*ncomms)=mcd(1:ncomms+1:ncomms*ncomms)/2;
例如图中的邻接矩阵是
W=[0 1 1 0 0 0;
1 0 1 1 0 0;
1 1 0 0 1 1;
0 1 0 0 1 0;
0 0 1 1 0 1;
0 0 1 0 1 0];
组标签向量为ci=[ 1 1 1 2 2 3]
,结果矩阵mcd
为:
mcd=[3 2 1;
2 1 1;
1 1 0];
这意味着例如第 1 组有 3 个 link 与它自己,2 个 link 与第 2 组和 1 个 link 与第 3 组。
IIUC 并假设 ci
作为排序数组,看来您基本上是在进行块求和,但块大小不规则。因此,您可以使用沿行和列使用 cumsum
的方法,然后在 ci
中的移位位置进行微分,这基本上会为您提供分块求和。
实现看起来像这样 -
%// Get cumulative sums row-wise and column-wise
csums = cumsum(cumsum(W,1),2)
%/ Get IDs of shifts and thus get cumsums at those positions
[~,idx] = unique(ci) %// OR find(diff([ci numel(ci)]))
csums_indexed = csums(idx,idx)
%// Get the blockwise summations by differentiations on csums at shifts
col1 = diff(csums_indexed(:,1),[],1)
row1 = diff(csums_indexed(1,:),[],2)
rest2D = diff(diff(csums_indexed,[],2),[],1)
out = [[csums_indexed(1,1) ; col1] [row1 ; rest2D]]
如果你不反对 mex 函数,你可以使用我下面的代码。
测试代码
n = 2000;
n_labels = 800;
W = rand(n, n);
W = W * W' > .5; % generate symmetric adjacency matrix of logicals
Wd = double(W);
ci = floor(rand(n, 1) * n_labels ) + 1; % generate ids from 1 to 251
[C, IA, IC] = unique(ci);
disp(sprintf('base avg fun time = %g ',timeit(@() interlinks(W, IC))));
disp(sprintf('mex avg fun time = %g ',timeit(@() interlink_mex(W, IC))));
%note this function requires symmetric (function from @aarbelle)
disp(sprintf('bsx avg fun time = %g ',timeit(@() interlinks_bsx(Wd, IC'))));
x1 = interlinks(W, IC);
x2 = interlink_mex(W, IC);
x3 = interlinks_bsx(Wd, IC');
disp(sprintf('norm(x1 - x2) = %g', norm(x1 - x2)));
disp(sprintf('norm(x1 - x3) = %g', norm(x1 - x3)));
检测结果
使用这些设置的测试结果:
base avg fun time = 4.94275
mex avg fun time = 0.0373092
bsx avg fun time = 0.126406
norm(x1 - x2) = 0
norm(x1 - x3) = 0
基本上,对于小 n_labels
,bsx 函数表现非常好,但您可以将其设置得足够大,以便 mex 函数更快。
c++代码
将其放入interlink_mex.cpp
之类的文件中并使用mex interlink_mex.cpp
进行编译。你的机器上需要一个 C++ 编译器等...
#include "mex.h"
#include "matrix.h"
#include <math.h>
// Author: Matthew Gunn
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if(nrhs != 2)
mexErrMsgTxt("Invalid number of inputs. Shoudl be 2 input argument.");
if(nlhs != 1)
mexErrMsgTxt("Invalid number of outputs. Should be 1 output arguments.");
if(!mxIsLogical(prhs[0])) {
mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
}
if(!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Second argument should be an array of type double");
}
const mxArray *W = prhs[0];
const mxArray *ci = prhs[1];
size_t W_m = mxGetM(W);
size_t W_n = mxGetN(W);
if(W_m != W_n)
mexErrMsgTxt("Rows and columns of W are not equal");
// size_t ci_m = mxGetM(ci);
size_t ci_n = mxGetNumberOfElements(ci);
mxLogical *W_data = mxGetLogicals(W);
// double *W_data = mxGetPr(W);
double *ci_data = mxGetPr(ci);
size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
size_t ncomms = 0;
double intpart;
for(size_t i = 0; i < ci_n; i++) {
double x = ci_data[i];
if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");
}
size_t xx = (size_t) x;
if(xx > ncomms)
ncomms = xx;
ci_data_size_t[i] = xx - 1;
}
mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
double *mcd_data = mxGetPr(mcd);
for(size_t i = 0; i < W_n; i++) {
size_t ii = ci_data_size_t[i];
for(size_t j = 0; j < W_n; j++) {
size_t jj = ci_data_size_t[j];
mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
}
}
for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
mcd_data[i]/=2; //divide by 2
mxFree(ci_data_size_t);
plhs[0] = mcd;
}
这个怎么样?
C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;
我想这就是你想要的。
我有一个图 n x n
图 W
描述为它的邻接矩阵和每个节点的组标签(整数)的 n
向量。
我需要为每对组计算 c
组中的节点和 d
组中的节点之间的 links(边)数。为此,我写了一个嵌套的 for loop
,但我确信这不是计算代码中称为 mcd
的矩阵的最快方法,即计算边数的矩阵c
和 d
组。
是否可以通过 bsxfun
使此操作更快?
function mcd = interlinks(W,ci)
%// W is the adjacency matrix of a simple undirected graph
%// ci are the group labels of every node in the graph, can be from 1 to |C|
n = length(W); %// number of nodes in the graph
m = sum(nonzeros(triu(W))); %// number of edges in the graph
ncomms = length(unique(ci)); %// number of groups of ci
mcd = zeros(ncomms); %// this is the matrix that counts the number of edges between group c and group d, twice the number of it if c==d
for c=1:ncomms
nodesc = find(ci==c); %// nodes in group c
for d=1:ncomms
nodesd = find(ci==d); %// nodes in group d
M = W(nodesc,nodesd); %// submatrix of edges between c and d
mcd(c,d) = sum(sum(M)); %// count of edges between c and d
end
end
%// Divide diagonal half because counted twice
mcd(1:ncomms+1:ncomms*ncomms)=mcd(1:ncomms+1:ncomms*ncomms)/2;
例如图中的邻接矩阵是
W=[0 1 1 0 0 0;
1 0 1 1 0 0;
1 1 0 0 1 1;
0 1 0 0 1 0;
0 0 1 1 0 1;
0 0 1 0 1 0];
组标签向量为ci=[ 1 1 1 2 2 3]
,结果矩阵mcd
为:
mcd=[3 2 1;
2 1 1;
1 1 0];
这意味着例如第 1 组有 3 个 link 与它自己,2 个 link 与第 2 组和 1 个 link 与第 3 组。
IIUC 并假设 ci
作为排序数组,看来您基本上是在进行块求和,但块大小不规则。因此,您可以使用沿行和列使用 cumsum
的方法,然后在 ci
中的移位位置进行微分,这基本上会为您提供分块求和。
实现看起来像这样 -
%// Get cumulative sums row-wise and column-wise
csums = cumsum(cumsum(W,1),2)
%/ Get IDs of shifts and thus get cumsums at those positions
[~,idx] = unique(ci) %// OR find(diff([ci numel(ci)]))
csums_indexed = csums(idx,idx)
%// Get the blockwise summations by differentiations on csums at shifts
col1 = diff(csums_indexed(:,1),[],1)
row1 = diff(csums_indexed(1,:),[],2)
rest2D = diff(diff(csums_indexed,[],2),[],1)
out = [[csums_indexed(1,1) ; col1] [row1 ; rest2D]]
如果你不反对 mex 函数,你可以使用我下面的代码。
测试代码
n = 2000;
n_labels = 800;
W = rand(n, n);
W = W * W' > .5; % generate symmetric adjacency matrix of logicals
Wd = double(W);
ci = floor(rand(n, 1) * n_labels ) + 1; % generate ids from 1 to 251
[C, IA, IC] = unique(ci);
disp(sprintf('base avg fun time = %g ',timeit(@() interlinks(W, IC))));
disp(sprintf('mex avg fun time = %g ',timeit(@() interlink_mex(W, IC))));
%note this function requires symmetric (function from @aarbelle)
disp(sprintf('bsx avg fun time = %g ',timeit(@() interlinks_bsx(Wd, IC'))));
x1 = interlinks(W, IC);
x2 = interlink_mex(W, IC);
x3 = interlinks_bsx(Wd, IC');
disp(sprintf('norm(x1 - x2) = %g', norm(x1 - x2)));
disp(sprintf('norm(x1 - x3) = %g', norm(x1 - x3)));
检测结果
使用这些设置的测试结果:
base avg fun time = 4.94275
mex avg fun time = 0.0373092
bsx avg fun time = 0.126406
norm(x1 - x2) = 0
norm(x1 - x3) = 0
基本上,对于小 n_labels
,bsx 函数表现非常好,但您可以将其设置得足够大,以便 mex 函数更快。
c++代码
将其放入interlink_mex.cpp
之类的文件中并使用mex interlink_mex.cpp
进行编译。你的机器上需要一个 C++ 编译器等...
#include "mex.h"
#include "matrix.h"
#include <math.h>
// Author: Matthew Gunn
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[]) {
if(nrhs != 2)
mexErrMsgTxt("Invalid number of inputs. Shoudl be 2 input argument.");
if(nlhs != 1)
mexErrMsgTxt("Invalid number of outputs. Should be 1 output arguments.");
if(!mxIsLogical(prhs[0])) {
mexErrMsgTxt("First argument should be a logical array (i.e. type logical)");
}
if(!mxIsDouble(prhs[1])) {
mexErrMsgTxt("Second argument should be an array of type double");
}
const mxArray *W = prhs[0];
const mxArray *ci = prhs[1];
size_t W_m = mxGetM(W);
size_t W_n = mxGetN(W);
if(W_m != W_n)
mexErrMsgTxt("Rows and columns of W are not equal");
// size_t ci_m = mxGetM(ci);
size_t ci_n = mxGetNumberOfElements(ci);
mxLogical *W_data = mxGetLogicals(W);
// double *W_data = mxGetPr(W);
double *ci_data = mxGetPr(ci);
size_t *ci_data_size_t = (size_t*) mxCalloc(ci_n, sizeof(size_t));
size_t ncomms = 0;
double intpart;
for(size_t i = 0; i < ci_n; i++) {
double x = ci_data[i];
if(x < 1 || x > 65536 || modf(x, &intpart) != 0.0) {
mexErrMsgTxt("Input ci is not all integers from 1 to a maximum value of 65536 (can edit source code to change this)");
}
size_t xx = (size_t) x;
if(xx > ncomms)
ncomms = xx;
ci_data_size_t[i] = xx - 1;
}
mxArray *mcd = mxCreateDoubleMatrix(ncomms, ncomms, mxREAL);
double *mcd_data = mxGetPr(mcd);
for(size_t i = 0; i < W_n; i++) {
size_t ii = ci_data_size_t[i];
for(size_t j = 0; j < W_n; j++) {
size_t jj = ci_data_size_t[j];
mcd_data[ii + jj * ncomms] += (W_data[i + j * W_m] != 0);
}
}
for(size_t i = 0; i < ncomms * ncomms; i+= ncomms + 1) //go along diagonal
mcd_data[i]/=2; //divide by 2
mxFree(ci_data_size_t);
plhs[0] = mcd;
}
这个怎么样?
C = bsxfun(@eq, ci,unique(ci)');
mcd = C*W*C'
mcd(logical(eye(size(mcd)))) = mcd(logical(eye(size(mcd))))./2;
我想这就是你想要的。