Armadillo (C++) 中的快速数组置换(广义张量转置)
Fast array permutation (generalised tensor transpose) in Armadillo (C++)
我有一个项目涉及很多 3D 数组排列 (arma::Cube<cx_double>
)。特别是,所需的排列是通过切片交换列。在 Matlab 中,这是通过 permute(cube,[1,3,2])
有效计算的,在 Python 中通过 numpy.transpose(cube,axis=[0,2,1])
有效计算。
遗憾的是 Armadillo 本身没有 permute
功能。我尝试过不同的方法,但与 Matlab 相比,它们都相当慢。 我想知道在 Armadillo 中置换(相当大的)立方体的更快方法是什么。使用 gprof
分析代码,大部分时间花在我在下面尝试的置换函数上,而在 Matlab 中,对于同一个移植项目,大部分时间花在 SVD 或 QR 矩阵分解(重塑和置换在 matlab 中速度很快)。
我想了解在犰狳中进行这种排列的最快方法是什么,以及为什么有些方法比其他方法效果更好。
选项 1:原始排列(最快的选项)(有更快的方法吗?)
将输入立方体按元素分配给输出立方体。
template <typename T>
static Cube<T> permute (Cube<T>& cube){
uword D1=cube.n_rows;
uword D2=cube.n_cols;
uword D3=cube.n_slices;
Cube<T> output(D1,D3,D2);
for (uword s = 0; s < D3; ++s){
for (uword c = 0; c < D2; ++c){
for (uword r = 0; r < D1; ++r){
output.at(r, s, c) = cube.at(r, c, s);
// output[ D1*D3*c + D1*s+ r ] = cube[ D1*D2*s + D1*c + r ];
}
}
}
return output;
}
选项 2:填充切片(很慢)
用不连续的 subcube
视图填充输出立方体的切片。
template <typename T>
static Cube<T> permute (Cube<T>& cube_in){
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
Cube<T> output;
output.zeros(D1, D3, D2);
for (uword c=0; c<D2; ++c) {
output.slice(c) = cube_in.subcube( span(0,D1-1),span(c),span(0,D3-1) );
}
return output;
}
选项 3:转置层(比原始排列慢但相当)
我们可以迭代输入立方体的层(固定行)并转置它们。
template <typename T>
static Cube<T> permute (Cube<T>& cube_in){
// in a cube, permute {1,3,2} (permute slices by columns)
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
if(D3 > D2){
cube_in.resize(D1,D3,D3);
} else if (D2 > D3) {
cube_in.resize(D1,D2,D2);
}
for (uword r=0; r<D1; ++r) {
static cmat layer = cmat(cube_in.rows(r,r));
inplace_strans(layer);
cube_in.rows(r,r)=layer;
}
cube_in.resize(D1,D3,D2);
return cube_in;
}
选项 4:查找 table
通过读取向量中的索引获得非连续访问。
template <typename T>
arma::Cube<T> permuteCS (arma::Cube<T> cube_in){
// in a cube, permute {1,3,2} (permute slices by columns)
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
cx_vec onedcube = cube_in.elem(gen_trans_idx(cube_in));
return arma::Cube<T>(onedcube.memptr(), D1, D3, D2, true ) ;
}
其中 gen_trans_idx
是一个生成置换立方体索引的函数:
template <typename T>
uvec gen_trans_idx(Cube<T>& cube){
uword D1 = cube.n_rows;
uword D2 = cube.n_cols;
uword D3 = cube.n_slices;
uvec perm132(D1*D2*D3);
uword ii = 0;
for (int c = 0; c < D2; ++c){
for (int s = 0; s < D3; ++s){
for (int r = 0; r < D1; ++r){
perm132.at(ii) = sub2ind(size(cube), r, c, s);
ii=ii+1;
}}}
return perm132;
}
理想情况下,如果事先确定了多维数据集维度,则可以预先计算这些查找 table。
选项5(就地转置)非常慢,内存效率高
// Option: In-place transpose
template <typename T>
arma::Cube<T> permuteCS (arma::Cube<T> cube_in, uvec permlist ){
T* Qpoint = cube_in.memptr(); // pointer to first element of cube_in
uvec updateidx = find(permlist - arma::linspace<uvec>(0,cube_in.n_elem-1,cube_in.n_elem)); // index of elements that change position in memory
uvec skiplist(updateidx.n_elem,fill::zeros);
uword rr = 0; // aux index for updatelix
for(uword jj=0;jj<updateidx.n_elem;++jj){
if(any(updateidx[jj] == skiplist)){ // if element jj has already been updated
// do nothing
} else {
uword scope = updateidx[jj];
T target = *(Qpoint+permlist[scope]); // store the value of the target element
while(any(scope==skiplist)-1){ // while wareyou has not been updated
T local = *(Qpoint+scope); // store local value
*(Qpoint+scope) = target;
skiplist[rr]=scope;
++rr;
uvec wareyou = find(permlist==scope); // find where the local value will appear
scope = wareyou[0];
target = local;
}
}
}
cube_in.reshape(cube_in.n_rows,cube_in.n_slices,cube_in.n_cols);
return cub
e_in;
}
此代码作为我对 memcpy
黑客攻击的评论的补充。另外不要忘记尝试添加 const reference
以防止复制对象。
template <typename T>
static Cube<T> permute(const Cube<T> &cube){
const uword D1 = cube.n_rows;
const uword D2 = cube.n_cols;
const uword D3 = cube.n_slices;
const uword D1_mul_D3 = D1 * D3;
const Cube<T> output(D1, D3, D2);
const T * from = cube.memptr();
T *to = output.memptr();
for (uword s = 0; s < D3; ++s){
T *to_tmp = to + D1 * s;
for (uword c = 0; c < D2; ++c){
memcpy(to_tmp, from, D1 * sizeof(*from));
from += D1;
to_tmp += D1_mul_D3;
}
}
return output;
}
我有一个项目涉及很多 3D 数组排列 (arma::Cube<cx_double>
)。特别是,所需的排列是通过切片交换列。在 Matlab 中,这是通过 permute(cube,[1,3,2])
有效计算的,在 Python 中通过 numpy.transpose(cube,axis=[0,2,1])
有效计算。
遗憾的是 Armadillo 本身没有 permute
功能。我尝试过不同的方法,但与 Matlab 相比,它们都相当慢。 我想知道在 Armadillo 中置换(相当大的)立方体的更快方法是什么。使用 gprof
分析代码,大部分时间花在我在下面尝试的置换函数上,而在 Matlab 中,对于同一个移植项目,大部分时间花在 SVD 或 QR 矩阵分解(重塑和置换在 matlab 中速度很快)。
我想了解在犰狳中进行这种排列的最快方法是什么,以及为什么有些方法比其他方法效果更好。
选项 1:原始排列(最快的选项)(有更快的方法吗?)
将输入立方体按元素分配给输出立方体。
template <typename T>
static Cube<T> permute (Cube<T>& cube){
uword D1=cube.n_rows;
uword D2=cube.n_cols;
uword D3=cube.n_slices;
Cube<T> output(D1,D3,D2);
for (uword s = 0; s < D3; ++s){
for (uword c = 0; c < D2; ++c){
for (uword r = 0; r < D1; ++r){
output.at(r, s, c) = cube.at(r, c, s);
// output[ D1*D3*c + D1*s+ r ] = cube[ D1*D2*s + D1*c + r ];
}
}
}
return output;
}
选项 2:填充切片(很慢)
用不连续的 subcube
视图填充输出立方体的切片。
template <typename T>
static Cube<T> permute (Cube<T>& cube_in){
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
Cube<T> output;
output.zeros(D1, D3, D2);
for (uword c=0; c<D2; ++c) {
output.slice(c) = cube_in.subcube( span(0,D1-1),span(c),span(0,D3-1) );
}
return output;
}
选项 3:转置层(比原始排列慢但相当)
我们可以迭代输入立方体的层(固定行)并转置它们。
template <typename T>
static Cube<T> permute (Cube<T>& cube_in){
// in a cube, permute {1,3,2} (permute slices by columns)
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
if(D3 > D2){
cube_in.resize(D1,D3,D3);
} else if (D2 > D3) {
cube_in.resize(D1,D2,D2);
}
for (uword r=0; r<D1; ++r) {
static cmat layer = cmat(cube_in.rows(r,r));
inplace_strans(layer);
cube_in.rows(r,r)=layer;
}
cube_in.resize(D1,D3,D2);
return cube_in;
}
选项 4:查找 table 通过读取向量中的索引获得非连续访问。
template <typename T>
arma::Cube<T> permuteCS (arma::Cube<T> cube_in){
// in a cube, permute {1,3,2} (permute slices by columns)
uword D1 = cube_in.n_rows;
uword D2 = cube_in.n_cols;
uword D3 = cube_in.n_slices;
cx_vec onedcube = cube_in.elem(gen_trans_idx(cube_in));
return arma::Cube<T>(onedcube.memptr(), D1, D3, D2, true ) ;
}
其中 gen_trans_idx
是一个生成置换立方体索引的函数:
template <typename T>
uvec gen_trans_idx(Cube<T>& cube){
uword D1 = cube.n_rows;
uword D2 = cube.n_cols;
uword D3 = cube.n_slices;
uvec perm132(D1*D2*D3);
uword ii = 0;
for (int c = 0; c < D2; ++c){
for (int s = 0; s < D3; ++s){
for (int r = 0; r < D1; ++r){
perm132.at(ii) = sub2ind(size(cube), r, c, s);
ii=ii+1;
}}}
return perm132;
}
理想情况下,如果事先确定了多维数据集维度,则可以预先计算这些查找 table。
选项5(就地转置)非常慢,内存效率高
// Option: In-place transpose
template <typename T>
arma::Cube<T> permuteCS (arma::Cube<T> cube_in, uvec permlist ){
T* Qpoint = cube_in.memptr(); // pointer to first element of cube_in
uvec updateidx = find(permlist - arma::linspace<uvec>(0,cube_in.n_elem-1,cube_in.n_elem)); // index of elements that change position in memory
uvec skiplist(updateidx.n_elem,fill::zeros);
uword rr = 0; // aux index for updatelix
for(uword jj=0;jj<updateidx.n_elem;++jj){
if(any(updateidx[jj] == skiplist)){ // if element jj has already been updated
// do nothing
} else {
uword scope = updateidx[jj];
T target = *(Qpoint+permlist[scope]); // store the value of the target element
while(any(scope==skiplist)-1){ // while wareyou has not been updated
T local = *(Qpoint+scope); // store local value
*(Qpoint+scope) = target;
skiplist[rr]=scope;
++rr;
uvec wareyou = find(permlist==scope); // find where the local value will appear
scope = wareyou[0];
target = local;
}
}
}
cube_in.reshape(cube_in.n_rows,cube_in.n_slices,cube_in.n_cols);
return cub
e_in;
}
此代码作为我对 memcpy
黑客攻击的评论的补充。另外不要忘记尝试添加 const reference
以防止复制对象。
template <typename T>
static Cube<T> permute(const Cube<T> &cube){
const uword D1 = cube.n_rows;
const uword D2 = cube.n_cols;
const uword D3 = cube.n_slices;
const uword D1_mul_D3 = D1 * D3;
const Cube<T> output(D1, D3, D2);
const T * from = cube.memptr();
T *to = output.memptr();
for (uword s = 0; s < D3; ++s){
T *to_tmp = to + D1 * s;
for (uword c = 0; c < D2; ++c){
memcpy(to_tmp, from, D1 * sizeof(*from));
from += D1;
to_tmp += D1_mul_D3;
}
}
return output;
}