三重加权总和
Triple weighted sum
我试图对某个加权和进行矢量化,但不知道该怎么做。我在下面创建了一个简单的最小工作示例。我想解决方案涉及 bsxfun 或 reshape 和 kronecker 产品,但我仍然没有设法让它工作。
rng(1);
N = 200;
T1 = 5;
T2 = 7;
T3 = 10;
A = rand(N,T1,T2,T3);
w1 = rand(T1,1);
w2 = rand(T2,1);
w3 = rand(T3,1);
B = zeros(N,1);
for i = 1:N
for j1=1:T1
for j2=1:T2
for j3=1:T3
B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3);
end
end
end
end
A = B;
对于二维情况,有一个聪明的答案。
您可以使用额外的乘法来修改之前答案的 w1 * w2'
网格,然后再乘以 w3
。然后您可以再次使用矩阵乘法与 "flattened" 版本的 A
.
相乘
W = reshape(w1 * w2.', [], 1) * w3.';
B = reshape(A, size(A, 1), []) * W(:);
您可以将权重的创建包装到它自己的函数中,并使其可推广到 N
权重。由于这使用递归,N
限于您当前的递归限制(默认为 500)。
function W = createWeights(W, varargin)
if numel(varargin) > 0
W = createWeights(W(:) * varargin{1}(:).', varargin{2:end});
end
end
并将其用于:
W = createWeights(w1, w2, w3);
B = reshape(A, size(A, 1), []) * W(:);
更新
使用@CKT 的部分非常好的建议kron
,我们可以稍微修改createWeights
。
function W = createWeights(W, varargin)
if numel(varargin) > 0
W = createWeights(kron(varargin{1}, W), varargin{2:end});
end
end
这是其背后的逻辑:
ww1 = repmat (permute (w1, [4, 1, 2, 3]), [N, 1, T2, T3]);
ww2 = repmat (permute (w2, [3, 4, 1, 2]), [N, T1, 1, T3]);
ww3 = repmat (permute (w3, [2, 3, 4, 1]), [N, T1, T2, 1 ]);
B = ww1 .* ww2 .* ww3 .* A;
B = sum (B(:,:), 2)
您可以通过首先在适当的维度中创建 w1
、w2
和 w3
来避免 permute
。您也可以使用 bsxfun
而不是 repmat
来获得额外的性能,我只是在这里展示逻辑,repmat
更容易理解。
编辑: 任意输入维度的通用版本:
Dims = {N, T1, T2, T3}; % add T4, T5, T6, etc as appropriate
Params = cell (1, length (Dims));
Params{1} = rand (Dims{:});
for n = 2 : length (Dims)
DimSubscripts = ones (1, length (Dims)); DimSubscripts(n) = Dims{n};
RepSubscripts = [Dims{:}]; RepSubscripts(n) = 1;
Params{n} = repmat (rand (DimSubscripts), RepSubscripts);
end
B = times (Params{:});
B = sum (B(:,:), 2)
同样,除非您创建了一些函数来构造 Kronecker 乘积向量,否则您无法将其很好地推广到 N-D,但是
A = reshape(A, N, []) * kron(w3, kron(w2, w1));
如果我们无论如何都要走函数路线,并且比 elegance/brevity 更看重性能,那么考虑一下:
function B = weightReduce(A, varargin)
B = A;
for i = length(varargin):-1:1
N = length(varargin{i});
B = reshape(B, [], N) * varargin{i};
end
end
这是我看到的性能对比:
tic;
for i = 1:10000
W = createWeights(w1,w2,w3);
B = reshape(A, size(A,1), [])*W(:);
end
toc
Elapsed time is 0.920821 seconds.
tic;
for i = 1:10000
B2 = weightReduce(A, w1, w2, w3);
end
toc
Elapsed time is 0.484470 seconds.
我试图对某个加权和进行矢量化,但不知道该怎么做。我在下面创建了一个简单的最小工作示例。我想解决方案涉及 bsxfun 或 reshape 和 kronecker 产品,但我仍然没有设法让它工作。
rng(1);
N = 200;
T1 = 5;
T2 = 7;
T3 = 10;
A = rand(N,T1,T2,T3);
w1 = rand(T1,1);
w2 = rand(T2,1);
w3 = rand(T3,1);
B = zeros(N,1);
for i = 1:N
for j1=1:T1
for j2=1:T2
for j3=1:T3
B(i) = B(i) + w1(j1) * w2(j2) * w3(j3) * A(i,j1,j2,j3);
end
end
end
end
A = B;
对于二维情况,有一个聪明的答案
您可以使用额外的乘法来修改之前答案的 w1 * w2'
网格,然后再乘以 w3
。然后您可以再次使用矩阵乘法与 "flattened" 版本的 A
.
W = reshape(w1 * w2.', [], 1) * w3.';
B = reshape(A, size(A, 1), []) * W(:);
您可以将权重的创建包装到它自己的函数中,并使其可推广到 N
权重。由于这使用递归,N
限于您当前的递归限制(默认为 500)。
function W = createWeights(W, varargin)
if numel(varargin) > 0
W = createWeights(W(:) * varargin{1}(:).', varargin{2:end});
end
end
并将其用于:
W = createWeights(w1, w2, w3);
B = reshape(A, size(A, 1), []) * W(:);
更新
使用@CKT 的部分非常好的建议kron
,我们可以稍微修改createWeights
。
function W = createWeights(W, varargin)
if numel(varargin) > 0
W = createWeights(kron(varargin{1}, W), varargin{2:end});
end
end
这是其背后的逻辑:
ww1 = repmat (permute (w1, [4, 1, 2, 3]), [N, 1, T2, T3]);
ww2 = repmat (permute (w2, [3, 4, 1, 2]), [N, T1, 1, T3]);
ww3 = repmat (permute (w3, [2, 3, 4, 1]), [N, T1, T2, 1 ]);
B = ww1 .* ww2 .* ww3 .* A;
B = sum (B(:,:), 2)
您可以通过首先在适当的维度中创建 w1
、w2
和 w3
来避免 permute
。您也可以使用 bsxfun
而不是 repmat
来获得额外的性能,我只是在这里展示逻辑,repmat
更容易理解。
编辑: 任意输入维度的通用版本:
Dims = {N, T1, T2, T3}; % add T4, T5, T6, etc as appropriate
Params = cell (1, length (Dims));
Params{1} = rand (Dims{:});
for n = 2 : length (Dims)
DimSubscripts = ones (1, length (Dims)); DimSubscripts(n) = Dims{n};
RepSubscripts = [Dims{:}]; RepSubscripts(n) = 1;
Params{n} = repmat (rand (DimSubscripts), RepSubscripts);
end
B = times (Params{:});
B = sum (B(:,:), 2)
同样,除非您创建了一些函数来构造 Kronecker 乘积向量,否则您无法将其很好地推广到 N-D,但是
A = reshape(A, N, []) * kron(w3, kron(w2, w1));
如果我们无论如何都要走函数路线,并且比 elegance/brevity 更看重性能,那么考虑一下:
function B = weightReduce(A, varargin)
B = A;
for i = length(varargin):-1:1
N = length(varargin{i});
B = reshape(B, [], N) * varargin{i};
end
end
这是我看到的性能对比:
tic;
for i = 1:10000
W = createWeights(w1,w2,w3);
B = reshape(A, size(A,1), [])*W(:);
end
toc
Elapsed time is 0.920821 seconds.
tic;
for i = 1:10000
B2 = weightReduce(A, w1, w2, w3);
end
toc
Elapsed time is 0.484470 seconds.