通过减少for循环的数量来加快速度



我目前正在使用Matlab进行编码。众所周知,for循环很慢,而Matlab做矩阵乘法非常有效。不幸的是,我的代码中充满了for循环。

我的代码类似于:

function FInt = ComputeF(A, B, C, D, E, F, G, H)
%A is a real column vector of size [Na, 1].
%B is a real column vector of size [Nb, 1].
%C is a real column vector of size [Nc, 1].
%D is a real column vector of size [Nd, 1].
%E, F, G are real column vectors of the same size as A.
%H is a real column vector of the same size as A.
%This function evaluates FInt, a tensor of size [Na, Nb, Nc, Nd]. 
%Recording the correct dimensions and initializing FInt
Na = size(A, 1);
Nb = size(B, 1);
Nc = size(C, 1);
Nd = size(D, 1);
FInt = zeros(Na, Nb, Nc, Nd);
%Computing the tensor FInt
for na = 1:Na
for nc=1:Nc
for nd=1:Nd
%Calculating intermediate values
S1 = -((B(:) - C(nc) + E(na)) ./ (2 * sin(D(nd) ./ 2))).^2;
S2 = (B(:) + C(nc) + F(na)) ./ (2 .* cos(D(nd) ./ 2));
S3 = (B(:) + C(nc) + G(na)) ./ (2 .* cos(D(nd) ./ 2));
S4 = H(na) ./ cos(D(nd) ./ 2);
%Calculating the integrand FInt
FInt(na, nc, :, nd) = exp(S1) .* (sinh(S2 + 1i * S4) + conj(sinh(S3 + 1i * S4)));
end
end
end
end

正如您所看到的,我已经尝试通过将:用于向量B来对过程进行矢量化,从而至少提高了一点计算速度。(为什么是B?通常它是最长的矢量(。

我的问题是,数量依赖于太多的索引,以至于我不知道如何正确地将其矢量化。

在numpy中,有一个概念被正式称为广播。MATLAB在R2016b中引入了这一概念。它被称为";矢量化""膨胀";,有时;广播";在MATLAB社区中。这个想法是,如果你把一堆数组的维度排成一行,你就可以扩展单元维度来匹配完整的维度。这里有一个关于这个主题的好资源:https://blogs.mathworks.com/loren/2016/10/24/matlab-arithmetic-expands-in-r2016b/.

如果你想让结果的大小为[Na, Nb, Nc, Nd],你可以让所有数组的大小都合适,用数组填充缺失的维度:

A = reshape(A, Na, 1, 1, 1);
B = reshape(B, 1, Nb, 1, 1);
C = reshape(C, 1, 1, Nc, 1);
D = reshape(D, 1, 1, 1, Nd);
E = reshape(E, Na, 1, 1, 1);
F = reshape(F, Na, 1, 1, 1);
G = reshape(G, Na, 1, 1, 1);
H = reshape(H, Na, 1, 1, 1);

现在,您可以直接对这些阵列执行矢量化操作,而不会产生歧义:

S1 = -((B - C + E) ./ (2 * sin(D ./ 2))).^2;
S2 = (B + C + F) ./ (2 .* cos(D ./ 2));
S3 = (B + C + G) ./ (2 .* cos(D ./ 2));
S4 = H ./ cos(D ./ 2);
%Calculating the integrand F
FInt = exp(S1) .* (sinh(S2 + 1i * S4) + conj(sinh(S3 + 1i * S4)));

请注意,这里删除了所有显式循环。中间阵列的大小取决于其输入的大小:

size(S1) == [Na, Nb, Nc, Nd]
size(S2) == [Na, Nb, Nc, Nd]
size(S3) == [Na, Nb, Nc, Nd]
size(S4) == [Na, 1, 1, Nd]

您不需要预先分配输出,因为它是由输入的大小自动产生的。

相关内容

  • 没有找到相关文章

最新更新