Vectorize/optimize sum over permutation

cmangla on 27 Aug 2021
Commented: cmangla on 31 Aug 2021
In the code below:
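p = zeros(N, M, M);
for i = 1:N
    for j = 1:M
        % Row i: place the permuted entries j..M along the 3rd dimension
        p(i, j, j:M) = Psi(i, Pi(i, j:M));
    end
end
p = sum(p, 3);  % collapse the 3rd dimension, leaving an N-by-M matrix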
Psi contains only positive real values. Each row of Pi contains a permutation of the column indices 1:M for the corresponding row of Psi. N is large, approximately 2500, while M is small (M = 4). This fragment is from an inner loop in my code.
It turns out that the assignment in the inner loop is the performance bottleneck. I tried doing the summation there rather than in the final line. My measurements were hasty, but that seemed to be slower than summing once in the final line.
How can I vectorize these loops? Or is there a different way to optimize this?
Note that since I am summing along the 3rd dimension in the final line, the order of the values in p along the 3rd dimension does not matter.
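To make the intent of the loops concrete, here is a tiny worked example. The 2-by-3 values of Psi and Pi are made up purely for illustration; only the loop structure comes from the question. Each entry p(i, j) ends up being the sum of the permuted row entries from position j through M.

```matlab
% Made-up illustrative inputs (not from the original post)
Psi = [1 2 3; 4 5 6];
Pi  = [3 1 2; 2 3 1];   % each row is a permutation of 1:3
[N, M] = size(Psi);

p = zeros(N, M, M);
for i = 1:N
    for j = 1:M
        p(i, j, j:M) = Psi(i, Pi(i, j:M));
    end
end
p = sum(p, 3);

% Row 1 of Psi permuted is [3 1 2], so its suffix sums are [6 3 2].
% Row 2 of Psi permuted is [5 6 4], so its suffix sums are [15 10 4].
disp(p)
```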
1 Comment
cmangla on 27 Aug 2021
I've come up with a vectorisation, but it's clumsy:
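p = zeros(N, M);
Psi_lin = Psi(:);
for j = 1:M
    Pi_j = Pi(:, j:M);                  % permuted column indices j..M for every row
    Pi_width = M - j + 1;
    Ri = repmat((1:N)', 1, Pi_width);   % matching row index for each entry of Pi_j
    Psi_rows = Ri(:);
    Psi_cols = Pi_j(:);
    Psi_ind = sub2ind([N, M], Psi_rows, Psi_cols);   % linear indices into Psi
    Psi_j = reshape(Psi_lin(Psi_ind), [N, Pi_width]);
    p(:, j) = sum(Psi_j, 2);
end
p = rho + log(p);   % rho comes from the surrounding code (not shown)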
I'd love to see suggestions of a cleaner approach.
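One possibly cleaner variant, offered as a sketch rather than a definitive answer: gather the whole permuted matrix with a single linear-indexing step, then take a reverse cumulative sum along each row. The test inputs, the names rows, Psi_perm, p_vec and p_ref, and the tolerance are assumptions introduced here for illustration. It relies on the 'reverse' direction option of cumsum.

```matlab
% Hypothetical small inputs for checking the idea
N = 6; M = 4;
rng(0);
Psi = rand(N, M) + 0.1;            % positive real values
[~, Pi] = sort(rand(N, M), 2);     % each row is a random permutation of 1:M

% One gather for all rows: Psi_perm(i, k) = Psi(i, Pi(i, k))
rows = repmat((1:N)', 1, M);
Psi_perm = Psi(sub2ind([N, M], rows, Pi));

% Suffix sums along each row: p_vec(i, j) = sum(Psi_perm(i, j:M))
p_vec = cumsum(Psi_perm, 2, 'reverse');

% Reference result from the original double loop
p_ref = zeros(N, M, M);
for i = 1:N
    for j = 1:M
        p_ref(i, j, j:M) = Psi(i, Pi(i, j:M));
    end
end
p_ref = sum(p_ref, 3);

% The two agree up to floating-point rounding (summation order differs)
max_diff = max(abs(p_vec(:) - p_ref(:)));
```

This removes the loop over j entirely; whether it is actually faster for N around 2500 and M = 4 would need to be measured in the real code.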


Answers (1)

Kumar Pallav on 31 Aug 2021
I have tried a different approach to this problem, which reduces the complexity by vectorizing the summation, as shown in the code below. Hope it helps!
p = zeros(N, M);
% Permute each row of 'Psi' according to the permutation matrix 'Pi'
% (note: this overwrites Psi in place)
for i = 1:N
    Psi(i, :) = Psi(i, Pi(i, :));
end
% Perform the summation according to the logic
for j = 1:M
    p(:, j) = sum(Psi(:, j:M), 2);   % first column of p is the sum of columns 1:M of Psi,
end                                   % second column is the sum of columns 2:M, and so on
1 Comment
cmangla on 31 Aug 2021
I might try this out, but I'm worried the first loop will turn out to be a bottleneck, much like the loops in my original code. Even there, I'm not doing any computation in the loops, yet that is where the bottleneck is, apparently because of the memory access pattern. I suspect it comes from the lack of memory locality in each iteration of the loop over 1:N.
Your first loop can also be vectorised using the linear indexing trick I posted in my comment. In my code it seems to have a massive impact: it is more than 6x faster.
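For anyone who wants to check this on their own machine, here is a rough timing sketch comparing the row-by-row permutation loop with the linear-indexing version. The random inputs, the repetition count reps, and the names Q_loop, Q_lin, t_loop and t_lin are assumptions made for illustration. Timings will vary with hardware and MATLAB release, so no particular speed-up is claimed here.

```matlab
% Hypothetical inputs at the sizes mentioned in the question
N = 2500; M = 4;
rng(1);
Psi = rand(N, M);
[~, Pi] = sort(rand(N, M), 2);     % each row is a random permutation of 1:M
reps = 200;                        % arbitrary repetition count

% Version 1: permute one row per iteration
tic;
for r = 1:reps
    Q_loop = zeros(N, M);
    for i = 1:N
        Q_loop(i, :) = Psi(i, Pi(i, :));
    end
end
t_loop = toc / reps;

% Version 2: a single gather using linear indices
rows = repmat((1:N)', 1, M);
tic;
for r = 1:reps
    Q_lin = Psi(sub2ind([N, M], rows, Pi));
end
t_lin = toc / reps;

fprintf('loop: %.3g s, linear indexing: %.3g s\n', t_loop, t_lin);
```

Both versions only copy values, so Q_loop and Q_lin should be identical element for element.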


Release

R2020b
