Applying vectorization techniques to speed up the performance of dividing a 3D matrix by a 2D matrix

I'm working on removing a for loop in my Matlab code to improve performance. My original code has one for loop (from j=1:Nx) that is harmful to performance (in my production code, this for loop is processed over 20 million times if I test large simulations). I am curious if I can remove this for loop through vectorization, repmat, or a similar technique. My original Matlab implementation is given below.
clc; clear all;
% Test Data
% I'm trying to remove the for loop for j in the code below
N = 10;
M = 10;
Nx = 32; % Ny=Nx=Nz
Nz = 32;
Ny = 32;
Fnmhat = rand(Nx,Nz+1);
Jnmhat = rand(Nx,1);
xi_n_m_hat = rand(Nx,N+1,M+1);
Uhat = zeros(Nx,Nz+1);
Uhat_2 = zeros(Nx,Nz+1);
identy = eye(Ny+1,Ny+1);
p = rand(Nx,1);
gammap = rand(Nx,1);
D = rand(Nx+1,Ny+1);
D2 = rand(Nx+1,Ny+1);
D_start = D(1,:);
D_end = D(end,:);
gamma = 1.5;
alpha = 0; % this could be non-zero
ntests = 100;
% Original Code (Partially vectorized)
tic
for n=0:N
    for m=0:M
        b = Fnmhat.';
        alphaalpha = 1.0;
        betabeta = 0.0; % this could be non-zero
        gammagamma = gamma*gamma - p.^2 - 2*alpha.*p; % size (Nx,1)
        d_min = 1.0;
        n_min = 0.0; % this could be non-zero
        r_min = xi_n_m_hat(:,n+1,m+1);
        d_max = -1i.*gammap;
        n_max = 1.0;
        r_max = Jnmhat;
        A = alphaalpha*D2 + betabeta*D + permute(gammagamma,[3,2,1]).*identy;
        A(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
        b(end,:) = r_min;
        A(end,end,:) = A(end,end,:) + d_min;
        A(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
        A(1,1,:) = A(1,1,:) + permute(d_max,[2,3,1]);
        b(1,:) = r_max;
        % Non-vectorized code - can this part be vectorized?
        for j=1:Nx
            utilde = linsolve(A(:,:,j),b(:,j)); % A\b
            Uhat(j,:) = utilde.';
        end
    end
end
toc
Here is my attempt at vectorizing the code (and removing the for loop for j).
% Same test data as original code
% New Code (completely vectorized but incorrect)
tic
for n=0:N
    for m=0:M
        b = Fnmhat.';
        alphaalpha = 1.0;
        betabeta = 0.0; % this could be non-zero
        gammagamma = gamma*gamma - p.^2 - 2*alpha.*p; % size (Nx,1)
        d_min = 1.0;
        n_min = 0.0; % this could be non-zero
        r_min = xi_n_m_hat(:,n+1,m+1);
        d_max = -1i.*gammap;
        n_max = 1.0;
        r_max = Jnmhat;
        A2 = alphaalpha*D2 + betabeta*D + permute(gammagamma,[3,2,1]).*identy;
        A2(end,:,:) = repmat(n_min*D_end,[1,1,Nx]);
        b(end,:) = r_min;
        A2(end,end,:) = A2(end,end,:) + d_min;
        A2(1,:,:) = repmat(n_max*D_start,[1,1,Nx]);
        A2(1,1,:) = A2(1,1,:) + permute(d_max,[2,3,1]);
        b(1,:) = r_max;
        % Non-vectorized code - can this part be vectorized?
        %for j=1:Nx
        %    utilde_2 = linsolve(A2(:,:,j),b(:,j)); % A2\b
        %    Uhat_2(j,:) = utilde_2.';
        %end
        % My attempt - this doesn't work since I don't loop through the index j
        % in repmat
        utilde_2 = squeeze(repmat(linsolve(A2(:,:,Nx),b(:,Nx)),[1,1,Nx]));
        utilde_2 = utilde_2(:,1);
        Uhat_2 = squeeze(repmat(utilde_2',[1,1,Nx]));
        Uhat_2 = Uhat_2';
    end
end
toc
diff = norm(Uhat - Uhat_2,inf); % is 0 if correct
I'm curious if repmat (or a different builtin Matlab function) can speed up this part of the code:
for j=1:Nx
    utilde = linsolve(A(:,:,j),b(:,j)); % A\b
    Uhat(j,:) = utilde.';
end
Is the for loop for j absolutely necessary or can it be removed?

Accepted Answer

Bruno Luong on 29 Jul 2021
If you have a C compiler, the fastest methods are perhaps mmx and MultipleQR, both available on the File Exchange (FEX).
  Comments: 9
Bruno Luong on 30 Jul 2021
MMX and MultipleQRSolve parallelize the loop over pages, whereas the MATLAB for-loop relies on the parallelization inside the linsolve algorithm itself.
That's why the matrix size matters, and possibly the number of physical processor cores.
MultipleQRSolve is less efficient for large matrices because my implementation of QR is less efficient than the LAPACK function called by MMX and by the native MATLAB for-loop.
Matthew Kehoe on 31 Jul 2021
I think that MMX can be optimized if both A and B are not complex doubles. In my data, B is not a complex double so it may be possible to speed up the MMX calculation. Here is how I would implement the three different methods in my real Matlab code.
% These parameters mimic the real data in my code
m=33;
n=33;
p=1;
q=32;
ntests = 10000;
% My code calculates Ac and Br before going into the for loop
Ac = rand(m,n,q)+1i*rand(m,n,q); % A is a complex double of size (33,33,32)
Br = rand(m,q); % B is a (real) double of size (33,32)
% Before I decide to use a for loop/mmx/MultipleQRSolve my code
% "understands" that A is a complex double of size (33,33,32) and B is a
% (real) double of size (33,32). I don't need to calculate what A or B are inside
% the for loop. I only reshape B inside MMX and MultipleQRSolve because I
% have to for the divides operation.
% Here is how I would write the three functions below in my "real" code.
% for-loop
tic
for ii=1:ntests
    z1 = zeros(q,m);
    for j=1:q
        % This is how my code currently computes A\b
        utilde = linsolve(Ac(:,:,j),Br(:,j)); % A\b
        z1(j,:) = utilde.';
    end
end
toc % Elapsed time is 14.231135 seconds.
% mmx
tic
for ii=1:ntests
    Bnew = reshape(Br,m,1,q); % Make Br size(33,1,32) to apply MMX
    Ar = real(Ac);
    Ai = imag(Ac);
    Br = real(Bnew);
    Bi = imag(Bnew); % is zero as b is a real double
    % z_1 = Ar+Ai*i
    % z_2 = Br+Bi*i
    % z_1/z_2 = [(Ar*Br + Ai*Bi) + 1i*(Ai*Br - Ar*Bi)]/(Br^2 + Bi^2);
    % Since Bi == 0, this is simplified to
    % z_1/z_2 = [(Ar*Br) + 1i*(Ai*Br)]/(Br^2);
    % I think that this makes the code below
    %AA = [Ar,-Ai;Ai,Ar];
    %BB = [Br;Bi];
    %zz = mmx('backslash', AA, BB);
    %z2=zz(1:n,:,:)+1i*zz(n+1:end,:,:);
    % Into the faster version
    Num = mmx('mult', Ar, Br);
    Num = Num + 1i*mmx('mult', Ai, Br);
    Den = Br.^2;
    z2 = mmx('backslash',Num,Den);
    z2 = permute(z2,[3 1 2]);
end
toc % Elapsed time is 2.441799 seconds.
% MultipleQRSolve
tic
for ii=1:ntests
    Bnew_2 = reshape(Br,m,1,q); % Make Br size(33,1,32) to apply MultipleQRSolve
    z3 = MultipleQRSolve(Ac,Bnew_2);
    z3 = permute(z3,[3 1 2]);
end
toc % Elapsed time is 25.991396 seconds.
diff = norm(z1-z2,inf); % Not zero since my code for z_1/z_2 isn't correct.
diff2 = norm(z1-z3,inf);
If the code for
AA = [Ar,-Ai;Ai,Ar];
BB = [Br;Bi];
zz = mmx('backslash', AA, BB);
z2=zz(1:n,:,:)+1i*zz(n+1:end,:,:);
isn't needed (as B is not a complex double) then MMX would "beat" the for loop. Thanks for all of your help with this question (and for writing the MultipleQRSolve).
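One way to check whether the block system is still needed when B is real is to compare it against plain backslash on a single page. The sketch below uses only built-in MATLAB (no mmx) and made-up random data in place of Ac and Br. It suggests that the real-valued embedding [Ar,-Ai;Ai,Ar] is still required: the solution of a complex system is generally complex even for a real right-hand side, so only the lower half of BB becomes zero.

```matlab
% Sketch with stand-in data (not the production matrices from this thread).
rng(0);                           % reproducible random data
n  = 33;
Ac = rand(n) + 1i*rand(n);        % complex system matrix (one page)
Br = rand(n,1);                   % real right-hand side

x  = Ac\Br;                       % reference complex solve

% Real-valued embedding of (Ar + 1i*Ai)*(xr + 1i*xi) = Br + 1i*0
Ar = real(Ac);
Ai = imag(Ac);
AA = [Ar, -Ai; Ai, Ar];           % size (2n,2n), real
BB = [Br; zeros(n,1)];            % imaginary part of B is zero
zz = AA\BB;                       % real solve
x2 = zz(1:n) + 1i*zz(n+1:end);    % reassemble complex solution

relerr   = norm(x - x2)/norm(x);  % expected to be near machine precision
imagpart = norm(imag(x));         % nonzero: x is complex although Br is real
```

The same AA and BB layout carries over page by page to the mmx('backslash', AA, BB) call quoted above.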


More Answers (2)

Matt J on 29 Jul 2021
Edited: Matt J on 29 Jul 2021
Another idea.
clc; clear all;
% Test Data
% I'm trying to remove the for loop for j in the code below
N = 10;
M = 10;
Nx = 32; % Ny=Nx=Nz
Nz = 32;
Ny = 32;
AA=kron(speye(Nx),ones(Nx+1));
map=logical(AA);
% Original Code (Partially vectorized)
tic
for n=0:N
    for m=0:M
        ....
        %Vectorized code
        AA(map)=A(:);
        Uhat=reshape(AA\b(:),Nx+1,Nx).';
    end
end
toc
  Comments: 5
Bruno Luong on 29 Jul 2021
AA(map)=A2(:)
Well, I know someone who is wondering why (:) can be much slower than reshape. ;-)
Matt J on 29 Jul 2021
Yeah, I didn't see that A was complex-valued. So,
AA(map)=reshape(A,[],1);
Uhat=reshape(AA\b(:),Nx+1,Nx).';
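For reference, a self-contained sketch of the block-diagonal idea, using stand-in random data in place of the A and b from the question. It builds the sparse matrix directly with sparse(i,j,v) instead of the logical-index assignment above (a substitution made for this illustration) and checks the result against the for-loop.

```matlab
% Sketch with stand-in data; sizes follow the question (33x33 pages, 32 pages).
rng(0);
Nx = 32;                                   % number of pages
n  = 33;                                   % Ny+1, size of each square page
A  = rand(n,n,Nx) + 1i*rand(n,n,Nx);       % stand-in for the 3D matrix A
b  = rand(n,Nx);                           % stand-in for b = Fnmhat.'

% Reference: one dense solve per page
Uref = zeros(Nx,n);
for j = 1:Nx
    Uref(j,:) = (A(:,:,j)\b(:,j)).';
end

% Block-diagonal sparse system: page p occupies rows/cols n*p+(1:n)
[r,c,p] = ndgrid(1:n, 1:n, 0:Nx-1);        % same element order as A(:)
AA = sparse(r(:)+n*p(:), c(:)+n*p(:), A(:), n*Nx, n*Nx);

% One sparse solve, then unstack the solution into (Nx, n)
Uhat = reshape(AA\b(:), n, Nx).';

relerr = norm(Uref - Uhat, inf)/norm(Uref, inf);
```

Whether this beats the loop depends on the page size and page count, so it is worth timing both on the real data.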



Matt J on 29 Jul 2021
Edited: Matt J on 29 Jul 2021
On the GPU (i.e. if A and b are gpuArrays), the for-loop can be removed:
Uhat = permute( pagefun(@mldivide,A,reshape(b,[],1,Nx)) ,[3,1,2]); % (Ny+1,1,Nx) -> (Nx,Ny+1)
  Comments: 1
Matthew Kehoe on 29 Jul 2021
This approach requires the Parallel Computing Toolbox. I will investigate getting this toolbox. Is there another approach that doesn't require a separate toolbox?
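One toolbox-free option, assuming a newer release than the R2020a listed for this question: base MATLAB provides pagemldivide from R2022a onward, which solves A(:,:,j)\B(:,:,j) for every page without the Parallel Computing Toolbox. A minimal sketch with stand-in random data (not the matrices built in the question):

```matlab
% Sketch only: pagemldivide requires R2022a or later.
rng(0);
Nx = 32;                                   % number of pages
n  = 33;                                   % Ny+1, size of each square page
A  = rand(n,n,Nx) + 1i*rand(n,n,Nx);       % stand-in for the 3D matrix A
b  = rand(n,Nx);                           % stand-in for b = Fnmhat.'

% Reference: the original per-page loop
Uref = zeros(Nx,n);
for j = 1:Nx
    Uref(j,:) = (A(:,:,j)\b(:,j)).';
end

% Page-wise solve: column b(:,j) becomes page j of a (n,1,Nx) array
X    = pagemldivide(A, reshape(b,[],1,Nx));    % size (n,1,Nx)
Uhat = permute(X,[3,1,2]);                     % size (Nx,n), matches the loop

relerr = norm(Uref - Uhat, inf)/norm(Uref, inf);
```

On releases before R2022a this function is not available, so the for-loop, mmx, or the sparse block-diagonal approach remain the alternatives.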


Release

R2020a