
Efficient Computation of Matrix Gradient

Shreyas Bharadwaj on 9 Apr 2024
Commented: Shreyas Bharadwaj on 9 Apr 2024
I am trying to compute the gradient of a matrix-valued function. I have computed the element-wise gradient and have verified numerically that it is correct (for my purposes of gradient descent).
My MATLAB implementation of the above gradient is:
for p = 1:N
    for q = 1:N
        gradX(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
which I have also verified is correct numerically.
However, my issue is that N = 750, so this computation is extremely slow and impractical for gradient descent: on my desktop with 32 GB RAM and an Intel Xeon 3.7 GHz processor, one iteration takes around 10-15 minutes. I expect to need several hundred iterations for convergence.
I was wondering if there is any obvious way I am missing to speed up or parallelize it. I have tried parfor but have not had any luck.
Thank you and I very much appreciate any suggestions.
2 Comments
Bruno Luong on 9 Apr 2024
Edited: Bruno Luong on 9 Apr 2024
What is the typical size of w (or AXB)?
By the way, the first obvious optimization is to precompute the product w .* AXB once, outside the loops.
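The suggestion above can be sketched as follows. This is only an illustration under stated assumptions: the small size N = 20, the random test data, and the names C, gradX_ref, gradX_hoisted and relErr are not from the original post. It also uses the identity conj(a) * conj(b) = conj(a * b) to take a single conjugate per iteration.

```matlab
% Illustrative small problem (assumption; the thread uses N = 750).
N = 20;
rng(0);
w   = rand(N);
AXB = rand(N) + 1i*rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);

% Reference: the loop from the question, w .* AXB recomputed every time.
gradX_ref = zeros(N);
for p = 1:N
    for q = 1:N
        gradX_ref(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end

% Hoisted version: w .* AXB does not depend on p or q, so form it once.
C = w .* AXB;
gradX_hoisted = zeros(N);
for p = 1:N
    for q = 1:N
        gradX_hoisted(p,q) = sum(C .* conj(A(:,p) * B(q,:)), 'all');
    end
end

% The two results should agree up to floating-point rounding.
relErr = norm(gradX_hoisted - gradX_ref, 'fro') / norm(gradX_ref, 'fro');
```

This removes one N x N element-wise product per (p,q) pair, but each entry still costs on the order of N^2 operations.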
Shreyas Bharadwaj on 9 Apr 2024
Edited: Shreyas Bharadwaj on 9 Apr 2024
Thank you, I will do that. All matrices, including w, are of size N x N, i.e. 750 x 750.


Accepted Answer

Bruno Luong on 9 Apr 2024
The best method:
N = 200; % 750
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1 = toc
t1 = 15.1666
% Method 3
tic
C = w .* AXB;
gradX = A' * C * B';
t2 = toc
t2 = 0.0049
err = norm(gradX(:)-gradX_1(:),'inf') / norm(gradX_1(:))
err = 2.4063e-17
fprintf('New code version 3 is %g faster\n', t1/t2)
New code version 3 is 3088.92 faster
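To see why Method 3 reproduces the double loop, it may help to write the loop body in index form. In the sketch below, G stands for gradX and A^H for MATLAB's A' (conjugate transpose); these symbols are notation introduced here, while C = w .* AXB is taken from the code above.

```latex
\begin{aligned}
C_{ij} &= w_{ij}\,(AXB)_{ij}, \\
G_{pq} &= \sum_{i=1}^{N}\sum_{j=1}^{N} w_{ij}\,\overline{A_{ip}}\,\overline{B_{qj}}\,(AXB)_{ij}
        = \sum_{i=1}^{N}\sum_{j=1}^{N} (A^{H})_{pi}\,C_{ij}\,(B^{H})_{jq}
        = \left(A^{H} C\, B^{H}\right)_{pq}.
\end{aligned}
```

Each entry of the loop version costs on the order of N^2 operations, so the full gradient costs on the order of N^4, whereas the two matrix products cost on the order of N^3 and run in optimized library code.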
1 Comment
Shreyas Bharadwaj on 9 Apr 2024
Thank you very much! This is exactly what I was looking for.


More Answers (1)

Bruno Luong on 9 Apr 2024
I propose the following, with timings for N = 200:
N = 200; % 750
gradX_1 = zeros(N,N);
w = rand(N,N);
AXB = rand(N,N)+1i*rand(N);
A = rand(N,N)+1i*rand(N);
B = rand(N,N)+1i*rand(N);
tic
for p = 1:N
    for q = 1:N
        gradX_1(p,q) = sum(w .* (conj(A(:,p)) * conj(B(q,:))) .* (AXB), 'all');
    end
end
t1 = toc
t1 = 6.6905
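gradX = zeros(N,N);
tic
C = w .* AXB;              % hoisted out of the loops
C = reshape(C,1,[]);       % 1 x N^2 row vector
for p = 1:N
    Ap = A(:,p);
    for q = 1:N
        AB = Ap * B(q,:);
        AB = reshape(AB,1,[]);
        gradX(p,q) = C * AB';   % ' conjugates, giving sum(C .* conj(AB))
    end
end
t2 = toc
t2 = 1.0750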
fprintf('New code version 1 is %g faster\n', t1/t2)
New code version 1 is 6.22383 faster
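One possible intermediate step between the double loop above and a single matrix product is to remove only the inner loop over q and fill one row of the gradient per iteration. The sketch below is an illustration, not one of the author's timed versions; the small size N = 50, the random data, and the names gradX_row, gradX_ref and relErr are assumptions.

```matlab
% Illustrative small problem (assumption; the thread uses N = 200 and N = 750).
N = 50;
rng(0);
w   = rand(N);
AXB = rand(N) + 1i*rand(N);
A   = rand(N) + 1i*rand(N);
B   = rand(N) + 1i*rand(N);

C = w .* AXB;

% Reference: the full matrix-product form.
gradX_ref = A' * C * B';

% Row-at-a-time form: row p of A' * C * B' is (A(:,p)' * C) * B'.
gradX_row = zeros(N);
for p = 1:N
    gradX_row(p,:) = (A(:,p)' * C) * B';
end

% Both forms should agree up to floating-point rounding.
relErr = norm(gradX_row - gradX_ref, 'fro') / norm(gradX_ref, 'fro');
```

Stacking the N row computations into one expression gives A' * C * B' directly, so the loop-free form is usually the one to prefer.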

