Is there a way to optimize/vectorize these loops over elements in a 3D array, without requiring significantly more memory?

I am looking for some help speeding up some Matlab code.

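A minimal example is shown below. The code performs some calculations on a 3D matrix defined on [x,y,z] coordinates. According to the profiler, the inner loop over ind is the time-consuming part, so I am wondering whether this loop can be optimized, or removed/vectorized altogether.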

Nx = 8; % Number of grid points
Ny = 6;
Nz = 4;
Ntot = Nx*Ny*Nz;

xvals = rand(1,Nx); % Create grid vectors
yvals = rand(1,Ny);
zvals = rand(1,Nz);

input_vec = rand(Ny,Nx,Nz); % Generate a dummy 3D matrix ( meshgrid convention, [y,x,z] )
input_vec = reshape( permute(input_vec,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x

C1 = 5; % Loop counters
C2 = 6;
C3 = 7;

output_vec = zeros(Ntot,1); % Preallocate
temp_vec = zeros(Ntot,1);

for cnt1 = 1:C1
    for cnt2 = 1:C2
        for cnt3 = 1:C3
            
            factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
            factor2 = yvals*cnt2;
            factor3 = zvals*cnt3;
            
            for ind = 1:Ntot % Loop over every grid point
                j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
                j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
                j3 = mod( (ind-1), Nz ) + 1;
                temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
            end
            output_vec = output_vec + temp_vec;
        end
    end
end

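In my actual application the number of points is more like 1024x1024x512, so I am trying to avoid large meshgrid-format variables (which contain a lot of repeated information) in order to keep the memory requirements down. That is why the 3D array is unwrapped to 1D in the code above. For example, one solution might be to precompute all the j1,j2,j3 values, e.g.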

j1 = 1:Nx;
j2 = 1:Ny;
j3 = 1:Nz;
[J1,J2,J3] = meshgrid(j1,j2,j3);
J1 = reshape( permute(J1,[3,1,2]) , [Ntot 1]); % Unwrap to 1D, so z cycles fastest, then y, then x 
J2 = reshape( permute(J2,[3,1,2]) , [Ntot 1]); 
J3 = reshape( permute(J3,[3,1,2]) , [Ntot 1]); 

But this needs far more RAM than computing the individual j values from ind each time.
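For a rough sense of scale, here is a back-of-the-envelope estimate. It assumes every array is stored as double (8 bytes per element); the variable names are only illustrative.

```matlab
% Rough memory estimate for the full-size grid (assumes double precision)
Nx = 1024; Ny = 1024; Nz = 512;
Ntot = Nx*Ny*Nz;                              % 2^29 grid points
bytes_per_double = 8;                         % assumption: class double

one_array_GiB    = Ntot*bytes_per_double/2^30;  % a single Ntot-by-1 array
index_arrays_GiB = 3*one_array_GiB;             % J1, J2 and J3 together

fprintf('one array: %g GiB, J1+J2+J3: %g GiB\n', one_array_GiB, index_arrays_GiB);
```

So the three precomputed index arrays alone would cost three times as much as input_vec itself.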

Can anyone help with a better/faster (but still memory-efficient) way to do this? Thanks.

The following:

factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
factor2 = yvals*cnt2;
factor3 = zvals*cnt3;
for ind = 1:Ntot % Loop over every grid point
   j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
   j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
   j3 = mod( (ind-1), Nz ) + 1;
   temp_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
end

can be written as (untested):

factor1 = xvals*cnt1; % Calculate some vectors which depend on cnt variables
factor2 = yvals*cnt2;
factor3 = zvals*cnt3;
ind = 1:Ntot;
j1 = floor( floor((ind-1) / Nz) / Ny) + 1; % +1 and -1 's account for Matlab's [1] indexing
j2 = mod( floor((ind-1)/Nz) , Ny ) + 1;
j3 = mod( (ind-1), Nz ) + 1;
temp_vec = input_vec(:) .* ( factor1(j1).*factor2(j2).*factor3(j3) ).'; % transpose the row of factors so temp_vec stays Ntot-by-1, like output_vec

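(In particular, not indexing input_vec with ind may make a big difference, although I think this is a special case that is optimized in recent versions of MATLAB.)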
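Since the rewrite is untested, a quick check against the original loop seems worthwhile. The sketch below does this for a single (cnt1,cnt2,cnt3) triple on the small grid from the question. The cnt values and the names ref_vec, scale and max_err are arbitrary choices for illustration. Note that the indexed factors come out as a 1-by-Ntot row, so they are transposed before multiplying with the Ntot-by-1 input_vec.

```matlab
% Sketch: compare the index-vector version with the original loop
Nx = 8; Ny = 6; Nz = 4; Ntot = Nx*Ny*Nz;
xvals = rand(1,Nx); yvals = rand(1,Ny); zvals = rand(1,Nz);
input_vec = rand(Ntot,1);              % already unwrapped, z cycles fastest
cnt1 = 2; cnt2 = 3; cnt3 = 4;          % arbitrary test values

factor1 = xvals*cnt1;
factor2 = yvals*cnt2;
factor3 = zvals*cnt3;

% Reference: the original element-by-element loop
ref_vec = zeros(Ntot,1);
for ind = 1:Ntot
    j1 = floor( floor((ind-1)/Nz) / Ny ) + 1;
    j2 = mod( floor((ind-1)/Nz), Ny ) + 1;
    j3 = mod( ind-1, Nz ) + 1;
    ref_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
end

% Vectorized: compute all indices at once
ind = 1:Ntot;
j1 = floor( floor((ind-1)/Nz) / Ny ) + 1;
j2 = mod( floor((ind-1)/Nz), Ny ) + 1;
j3 = mod( ind-1, Nz ) + 1;
scale = factor1(j1).*factor2(j2).*factor3(j3);   % 1-by-Ntot row vector
temp_vec = input_vec .* scale.';                 % Ntot-by-1 column vector

max_err = max(abs(temp_vec - ref_vec));          % should be at round-off level
fprintf('max abs difference: %g\n', max_err);
```

Keep in mind that j1, j2, j3 and scale are each as large as input_vec, which is exactly the memory cost the question is trying to avoid.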

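But we are still creating large intermediate arrays. You should be able to simplify this to the following (again untested, and possibly buggy). Note that it relies on implicit expansion (MATLAB R2016b or later) and assumes input_vec is kept as the original Ny-by-Nx-by-Nz array (meshgrid convention) rather than unwrapped to 1D: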

factor1 = xvals*cnt1;                  % horizontal array
factor2 = (yvals*cnt2).';              % vertical array
factor3 = permute(zvals*cnt3,[1,3,2]); % array along 3rd dimension
temp_vec = input_vec .* factor1 .* factor2 .* factor3;
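As a sanity check, the sketch below compares this broadcast form with the original loop for one (cnt1,cnt2,cnt3) triple. It assumes R2016b or later for implicit expansion. The names input_3d, temp_3d, g1, g2, g3 and temp_zyx are illustrative and not part of the original code. It also tries a variant that keeps the z-fastest unwrapped vector from the question and only reshapes it to Nz-by-Ny-by-Nx, which does not reorder any elements, so no permute is needed.

```matlab
% Sketch: check the broadcast version against the original loop
Nx = 8; Ny = 6; Nz = 4; Ntot = Nx*Ny*Nz;
xvals = rand(1,Nx); yvals = rand(1,Ny); zvals = rand(1,Nz);
input_3d  = rand(Ny,Nx,Nz);                                  % meshgrid convention [y,x,z]
input_vec = reshape( permute(input_3d,[3,1,2]), [Ntot 1] );  % unwrapped, z fastest
cnt1 = 2; cnt2 = 3; cnt3 = 4;                                % arbitrary test values

% Reference: the original element-by-element loop
factor1 = xvals*cnt1; factor2 = yvals*cnt2; factor3 = zvals*cnt3;
ref_vec = zeros(Ntot,1);
for ind = 1:Ntot
    j1 = floor( floor((ind-1)/Nz) / Ny ) + 1;
    j2 = mod( floor((ind-1)/Nz), Ny ) + 1;
    j3 = mod( ind-1, Nz ) + 1;
    ref_vec(ind) = input_vec(ind) * factor1(j1)*factor2(j2)*factor3(j3);
end

% Variant A: broadcast over the Ny-by-Nx-by-Nz array
f1 = xvals*cnt1;                       % 1-by-Nx, varies along dimension 2 (x)
f2 = (yvals*cnt2).';                   % Ny-by-1, varies along dimension 1 (y)
f3 = permute(zvals*cnt3,[1,3,2]);      % 1-by-1-by-Nz, varies along dimension 3 (z)
temp_3d = input_3d .* f1 .* f2 .* f3;
temp_A  = reshape( permute(temp_3d,[3,1,2]), [Ntot 1] );  % unwrap only to compare

% Variant B: keep the z-fastest vector, view it as Nz-by-Ny-by-Nx
g3 = (zvals*cnt3).';                   % Nz-by-1, varies along dimension 1 (z)
g2 = yvals*cnt2;                       % 1-by-Ny, varies along dimension 2 (y)
g1 = permute(xvals*cnt1,[1,3,2]);      % 1-by-1-by-Nx, varies along dimension 3 (x)
temp_zyx = reshape(input_vec,[Nz,Ny,Nx]) .* g3 .* g2 .* g1;
temp_B   = temp_zyx(:);                % back to Ntot-by-1, same ordering as input_vec

err_A = max(abs(temp_A - ref_vec));
err_B = max(abs(temp_B - ref_vec));
fprintf('max abs difference: A = %g, B = %g\n', err_A, err_B);
```

In both variants the only large temporary is the product array itself; the factor arrays hold just Nx, Ny and Nz elements.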