# Efficient Computation of Matrix Gradient

Commented: Shreyas Bharadwaj on April 9, 2024
Hi,
I am trying to compute the gradient of a matrix-valued function. I have computed the element-wise gradient, whose $(p,q)$ entry is $\sum_{i,j} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}$, and have verified that it is correct numerically (for my purposes of gradient descent).
My MATLAB implementation of the above gradient is:
```matlab
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
```
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes. I expect to need several hundred iterations for convergence.
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
##### 2 Comments
Bruno Luong on April 9, 2024
Edited: Bruno Luong on April 9, 2024
What is a typical size of w (or AXB)?
By the way, the first obvious optimization is to multiply w with AXB once, before the loops, instead of repeating that product in every iteration.
Edited: Shreyas Bharadwaj on April 9, 2024
Thank you, I will do that. All matrices, including w, are of size N x N i.e. 750 x 750.
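The pre-multiplication suggested above can be sketched as follows. This is a minimal illustration, not the poster's actual code: the size `N = 20` and the random stand-in for `AXB` are assumptions chosen so the loop runs quickly, and the result is preallocated for good measure.

```matlab
% Small illustrative size (assumption); the question uses N = 750.
N   = 20;
w   = rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);
AXB = rand(N) + 1i*rand(N);   % random stand-in for the product A*X*B

C = w .* AXB;                 % computed once, outside the loops
gradX = zeros(N);             % preallocate the result
for p = 1:N
    for q = 1:N
        % conj(A(:,p) * B(q,:)) equals conj(A(:,p)) * conj(B(q,:))
        gradX(p,q) = sum(conj(A(:,p) * B(q,:)) .* C, 'all');
    end
end
```

This removes one N x N element-wise product per iteration, but the loop still forms N^2 outer products of size N x N, so it only trims a constant factor.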


### Accepted Answer

Bruno Luong on April 9, 2024
The best method:
```matlab
N = 200; % 750
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);

% Original double loop, timed as the reference
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
```
```
t1 = 15.1666
```
```matlab
% Method 3
tic
C = w .* AXB;
gradX = A' * C * B';
t2=toc
```
```
t2 = 0.0049
err = 2.4063e-17
```
```matlab
fprintf('New code version 3 is %g faster\n', t1/t2)
```
```
New code version 3 is 3088.92 faster
```
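Why `A' * C * B'` reproduces the double loop can be seen by writing out one entry. The notation below is introduced here for illustration: $G$ stands for `gradX`, $\circ$ is the element-wise product, and $^{H}$ is the conjugate transpose (MATLAB's `'`).

```latex
\begin{aligned}
C &= w \circ (AXB), \qquad C_{ij} = w_{ij}\,(AXB)_{ij},\\
G_{pq} &= \sum_{i,j} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}
        = \sum_{i}\sum_{j} \overline{A_{ip}}\;C_{ij}\;\overline{B_{qj}}\\
       &= \sum_{i}\sum_{j} (A^{H})_{pi}\,C_{ij}\,(B^{H})_{jq}
        = \left(A^{H}\,C\,B^{H}\right)_{pq}.
\end{aligned}
```

The loop forms $N^2$ outer products of size $N \times N$, which is on the order of $N^4$ operations, while the two matrix products cost on the order of $N^3$ and run in optimized library code. That is consistent with the speedup measured above.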
##### 1 Comment
Thank you very much! This is exactly what I was looking for.


### More Answers (1)

Bruno Luong on April 9, 2024
I propose the following, with timing tested for N = 200:
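```matlab
N = 200; % 750
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);

% Original double loop, timed as the reference
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1=toc
```
```
t1 = 6.6905
```
```matlab
% Version 1: hoist w .* AXB out of the loops and use a row-by-column product
tic
C = w .* AXB;
C = reshape(C,1,[]);
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:);
        AB = reshape(AB,1,[]);
        % AB' is the conjugate transpose, so C * AB' sums C .* conj(AB)
        gradX(p,q) = C * AB';
    end
end
t2=toc
```
```
t2 = 1.0750
```
```matlab
fprintf('New code version 1 is %g faster\n', t1/t2)
```
```
New code version 1 is 6.22383 faster
```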
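Any faster version can be checked against the original double loop before it is used in gradient descent. The sketch below compares the matrix form `A' * C * B'` with the loop. The small size `N = 30`, the fixed seed, and the relative Frobenius-norm error are assumptions made for this illustration, not necessarily the error measure used in the thread.

```matlab
rng(0);                       % fixed seed for reproducibility (assumption)
N   = 30;                     % small size so the reference loop is quick
w   = rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);
AXB = rand(N) + 1i*rand(N);   % random stand-in for the product A*X*B

% Reference: the original double loop
gradLoop = zeros(N);
for p = 1:N
    for q = 1:N
        gradLoop(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* AXB, 'all');
    end
end

% Matrix form: ' is the conjugate transpose
gradMat = A' * (w .* AXB) * B';

% Relative error in the Frobenius norm
relErr = norm(gradMat - gradLoop, 'fro') / norm(gradLoop, 'fro');
fprintf('relative error = %g\n', relErr);
```

A relative error near machine precision indicates that the two computations agree up to floating-point rounding.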
Thank you!



Release: R2023b
