How could this code be vectorized?

Fredrik P on 18 Jan 2024
Answered: Voss on 18 Jan 2024
How can I vectorize this code? I suspect that it can be solved using ndgrid and sub2ind, as my last question was, but I just don't get how it should be done. (I'm aware that the vectorized code might very well be slower than the loop-based code. Right now, I'm just hoping to learn how it should be done.)
close; clear; clc;
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11, C21, C31] = version1(A1, A2, A3, B, n1, n2, n3);
% [C12, C22, C32] = version2(A1, A2, A3, B, n1, n2, n3);
% disp([isequal(C11, C12), isequal(C21, C22), isequal(C31, C32)]);
disp(timeit(@() version1(A1, A2, A3, B, n1, n2, n3)));
4.4875e-04
% disp(timeit(@() version2(A1, A2, A3, B, n1, n2, n3)));
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end

Accepted Answer

Cris LaPierre on 18 Jan 2024
Edited: Cris LaPierre on 18 Jan 2024
I'd just have your max function return the linear index instead of the index along the third dimension, then use that to index A1, A2, and A3 directly.
% your current solution
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11, C21, C31] = version1(A1, A2, A3, B, n1, n2, n3);
% Function definition moved to the bottom of the script
Here is how you would do it using linear indexing:
[~, B1] = max(A3, [], 3,'linear');
C111 = A1(B1);
C211 = A2(B1);
C311 = A3(B1);
Now compare the results to see if they are the same.
isequal(C11, C111)
ans = logical
1
isequal(C21, C211)
ans = logical
1
isequal(C31, C311)
ans = logical
1
As you can see, the results are equal.
% function moved to bottom of script
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end
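
To see why indexing with the linear index works, it may help to write out what the 'linear' flag returns. The following is only a sketch, not part of the original answer: for an array of size [n1, n2, n3], element (j1, j2, j3) sits at linear index j1 + (j2-1)*n1 + (j3-1)*n1*n2 because MATLAB stores arrays column-major. The name idxManual is hypothetical, and building it this way assumes implicit expansion (R2016b or later).

```matlab
n1 = 400; n2 = 9; n3 = 2;
A3 = rand(n1, n2, n3);

% Index along dimension 3, and the corresponding linear index into A3
[~, B]  = max(A3, [], 3);
[~, B1] = max(A3, [], 3, 'linear');

% Build the same linear index by hand (column-major layout):
% row offset + column offset + page offset
idxManual = (1:n1).' + n1*(0:n2-1) + n1*n2*(B - 1);

% Both indices should pick out the same elements
assert(isequal(idxManual, B1));
assert(isequal(A3(idxManual), max(A3, [], 3)));
```

Since A1, A2, and A3 all have the same size, the one index array can be reused for all three, which is what the answer above relies on.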

More Answers (1)

Voss on 18 Jan 2024
n1 = 400;
n2 = 9;
n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);
[~, B] = max(A3, [], 3);
[C11_v1, C21_v1, C31_v1] = version1(A1, A2, A3, B, n1, n2, n3);
[C11_v2, C21_v2, C31_v2] = version2(A1, A2, A3, B, n1, n2, n3);
% compare the results
isequal(C11_v1,C11_v2)
ans = logical
1
isequal(C21_v1,C21_v2)
ans = logical
1
isequal(C31_v1,C31_v2)
ans = logical
1
function [C1, C2, C3] = version2(A1, A2, A3, B, n1, n2, n3)
    % Row and column subscripts for every element of the n1-by-n2 result
    [RR,CC] = ndgrid(1:n1,1:n2);
    % Combine them with the page subscript B into linear indices
    idx = sub2ind([n1,n2,n3],RR,CC,B);
    C1 = A1(idx);
    C2 = A2(idx);
    C3 = A3(idx);
end
function [C1, C2, C3] = version1(A1, A2, A3, B, n1, n2, n3)
    C1 = nan(n1, n2);
    C2 = C1;
    C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end
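
Since the question also asks about speed, here is a rough sketch of how the three approaches from this thread could be timed side by side. It is not from either answer: the function names viaLoop, viaSub2ind, and viaLinearMax are hypothetical, and each one includes the max call so that the comparison covers the same work. Timings will vary by machine and release, so none are shown.

```matlab
n1 = 400; n2 = 9; n3 = 2;
A1 = rand(n1, n2, n3);
A2 = rand(n1, n2, n3);
A3 = rand(n1, n2, n3);

% Check that all three variants agree before timing them
[L1, L2, L3] = viaLoop(A1, A2, A3);
[S1, S2, S3] = viaSub2ind(A1, A2, A3);
[M1, M2, M3] = viaLinearMax(A1, A2, A3);
assert(isequal(L1, S1, M1) && isequal(L2, S2, M2) && isequal(L3, S3, M3));

% timeit(f, 3) calls f with three output arguments
fprintf('loop:          %.3e s\n', timeit(@() viaLoop(A1, A2, A3), 3));
fprintf('sub2ind:       %.3e s\n', timeit(@() viaSub2ind(A1, A2, A3), 3));
fprintf('max, linear:   %.3e s\n', timeit(@() viaLinearMax(A1, A2, A3), 3));

function [C1, C2, C3] = viaLoop(A1, A2, A3)
    % Loop-based version, as in the question
    [~, B] = max(A3, [], 3);
    [n1, n2] = size(B);
    C1 = nan(n1, n2); C2 = C1; C3 = C1;
    for j2 = 1:n2
        for j1 = 1:n1
            C1(j1, j2) = A1(j1, j2, B(j1, j2));
            C2(j1, j2) = A2(j1, j2, B(j1, j2));
            C3(j1, j2) = A3(j1, j2, B(j1, j2));
        end
    end
end

function [C1, C2, C3] = viaSub2ind(A1, A2, A3)
    % ndgrid + sub2ind version
    [~, B] = max(A3, [], 3);
    [RR, CC] = ndgrid(1:size(B, 1), 1:size(B, 2));
    idx = sub2ind(size(A3), RR, CC, B);
    C1 = A1(idx); C2 = A2(idx); C3 = A3(idx);
end

function [C1, C2, C3] = viaLinearMax(A1, A2, A3)
    % max returns the linear index directly
    [~, idx] = max(A3, [], 3, 'linear');
    C1 = A1(idx); C2 = A2(idx); C3 = A3(idx);
end
```

For arrays this small the differences may be within measurement noise, which matches the caveat in the question that the vectorized code is not guaranteed to be faster.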

Release

R2023b
