How to improve the speed of calculating a trace in a script?
Hi all,
In my project I have to calculate the trace of some matrix products. The following script demonstrates the task:
clear; clc;
% total number of tests.
nTest = 500;
% part 1. generate nTest*2 random matrices.
nd = 1000;
nt = 100;
dis = cell(nTest, 2);
dis = cellfun(@(v) rand(nd, nt), dis, 'un', 0);
% part 2. compute a truncated SVD of each matrix,
% keeping only the first nRem singular vectors and values.
nRem = 50;
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(dis{isvd, jsvd}, 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end
% part 3. for each pair of SVD results, compute a trace to obtain disTrans.
% disTrans is non-symmetric, so jtr needs to start from 1.
disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTrans(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
I ran the profiler to find out which part is the slowest, and it turns out to be the trace calculation in part 3, because of the large number of calls. Unfortunately, in my project the number of trace calculations is also very large. Does anyone have an idea how to speed up the trace calculation? The profile is shown here:
Many thanks!
Steven Lord
April 3, 2018
Note that built-in functions like svd, cellfun, rand, etc. don't show up in your Profiler report but they do take time. If you open the entry for testuiTujSortImprove I believe you'll see a line in the Children table where the time spent in built-in functions will be reported, and given how many times you called svd in particular I think that will account for the "missing" time.
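A rough way to check this outside the Profiler is to time one economy-size SVD and scale it by the 2*nTest calls in part 2. This is a sketch, not a precise measurement: nReps = 5 is an arbitrary choice, tic/toc is used for simplicity (MATLAB's timeit is a more robust alternative), and the estimate assumes every svd call costs about the same.

```matlab
% Rough estimate of the time spent in the built-in svd during part 2.
% Sizes follow the question's script; nReps is an arbitrary choice.
nd = 1000;
nt = 100;
nTest = 500;
nReps = 5;

A = rand(nd, nt);            % one representative input matrix
tic;
for k = 1:nReps
    [u, s, v] = svd(A, 0);   % economy-size SVD with all three outputs, as in part 2
end
tOneSvd = toc / nReps;       % average time per call

tAllSvd = 2 * nTest * tOneSvd;   % part 2 calls svd 2*nTest times
fprintf('one svd: %.4f s, estimated part 2 total: %.1f s\n', tOneSvd, tAllSvd);
```

If the estimated total is close to the "missing" time in the Profiler report, that time is being spent in svd.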
Accepted Answer
Christine Tobler
April 3, 2018
You can make the trace computation faster as follows. Currently, the inputs are two truncated SVDs, A1 = U1 * S1 * V1' and A2 = U2 * S2 * V2', and you are computing
trace(A1'*A2)
correct? After substituting A1 and A2, you can use the cyclic property of the trace, trace(A*B) == trace(B*A) (note that in general trace(A*B*C) ~= trace(A*C*B): only cyclic permutations are allowed, see Wikipedia).
This means you can rearrange the product as follows (using S1' == S1, since S1 is diagonal):
trace(V1*S1'*U1'*U2*S2*V2') == trace( (V2'*V1) * S1 * (U1'*U2) * S2)
Make sure the parentheses are placed exactly like this, so that all other operations are on nRem-by-nRem matrices.
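The rearrangement can be checked numerically for a single pair. This sketch builds two truncated SVDs with the question's sizes and compares the original expression with the rearranged one; the two values should agree up to rounding error.

```matlab
% Check the rearranged trace for one pair of truncated SVDs.
% Sizes follow the question's script.
nd = 1000;
nt = 100;
nRem = 50;

[U1, S1, V1] = svd(rand(nd, nt), 0);
[U2, S2, V2] = svd(rand(nd, nt), 0);
U1 = U1(:, 1:nRem);  S1 = S1(1:nRem, 1:nRem);  V1 = V1(:, 1:nRem);
U2 = U2(:, 1:nRem);  S2 = S2(1:nRem, 1:nRem);  V2 = V2(:, 1:nRem);

% Original form: the product is nt-by-nt before the trace is taken.
tOrig = trace(V1 * S1' * U1' * U2 * S2 * V2');

% Rearranged form: V2'*V1 and U1'*U2 are nRem-by-nRem, and so is every later product.
tNew = trace((V2' * V1) * S1 * (U1' * U2) * S2);

relErr = abs(tOrig - tNew) / abs(tOrig);
fprintf('original %.10g, rearranged %.10g, relative difference %.2g\n', tOrig, tNew, relErr);
```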
By the way, you can also rewrite trace(A'*B) as sum(sum(A.*B)), but I'm not sure if this will give you a speedup for this case.
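A small check of the sum(sum(A.*B)) identity, plus the same idea applied to the final nRem-by-nRem step: for square P and Q, trace(P*Q) equals sum(sum(P.' .* Q)), so the last matrix product does not have to be formed. The matrices M3, M1, S1 and S2 below are random stand-ins (hypothetical, not taken from real SVDs), and whether this is actually faster at nRem = 50 is something to measure rather than assume.

```matlab
% trace(A'*B) equals sum(sum(A.*B)) for any A and B of the same size.
A = rand(200, 100);
B = rand(200, 100);
d1 = abs(trace(A' * B) - sum(sum(A .* B))) / abs(trace(A' * B));

% For square P and Q, trace(P*Q) equals sum(sum(P.' .* Q)), so the last
% nRem-by-nRem product need not be formed. M3, M1, S1 and S2 are random
% stand-ins (hypothetical), not the real V2'*V1, U1'*U2 and singular values.
nRem = 50;
M3 = rand(nRem);
M1 = rand(nRem);
S1 = diag(rand(nRem, 1));
S2 = diag(rand(nRem, 1));
tChain = trace(M3 * S1 * M1 * S2);
tSum = sum(sum((M3 * S1).' .* (M1 * S2)));
d2 = abs(tChain - tSum) / abs(tChain);
fprintf('relative differences: %.2g and %.2g\n', d1, d2);
```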
Christine Tobler
April 4, 2018
Trace is a very quick operation compared with the matrix multiplications, so most of the speed-up probably comes from those multiplications. The trace call itself should also become a bit faster, because it now acts on nRem-by-nRem matrices instead of nt-by-nt matrices.
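A rough multiply-add count for one (itr, jtr) pair supports this. It assumes that multiplying a p-by-q matrix by a q-by-s matrix costs pqs multiply-adds, that MATLAB evaluates the chain from left to right, and the question's sizes r = nRem = 50, n_d = nd = 1000, n_t = nt = 100. The original terms are, in order, V1*S1', then *U1', *U2, *S2 and *V2'; the rearranged terms are V2'*V1, *S1, U1'*U2, the product of the two, and *S2.

```latex
\begin{aligned}
\text{original:}\quad & n_t r^2 + n_t r n_d + n_t n_d r + n_t r^2 + n_t^2 r \\
&= 250{,}000 + 5{,}000{,}000 + 5{,}000{,}000 + 250{,}000 + 500{,}000 = 11{,}000{,}000,\\
\text{rearranged:}\quad & r^2 n_t + r^3 + r^2 n_d + r^3 + r^3 \\
&= 250{,}000 + 125{,}000 + 2{,}500{,}000 + 125{,}000 + 125{,}000 = 3{,}125{,}000.
\end{aligned}
```

That is roughly 3.5 times fewer multiply-adds per pair; the largest remaining term is U1'*U2.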
You can try splitting the expression into separate lines:
M3 = u2{3}' * u1{3};
M1 = u1{1}' * u2{1};
M = M3 * u1{2} * M1 * u2{2};
trace(M);
This way, the profiler will tell you how much time is spent in each line, and you can see how much time the trace call takes.
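Putting this together, part 3 can be rewritten with the split form and compared against the original loop. This is a sketch: nTest is reduced to 20 (an arbitrary choice) so both loops finish quickly, the tic/toc timings are only indicative, and the two disTrans matrices should agree up to rounding error.

```matlab
% Part 3 of the question's script, original vs. rearranged trace.
% nTest is reduced to 20 (arbitrary) so the comparison runs quickly;
% the other sizes match the question.
nTest = 20;
nd = 1000;
nt = 100;
nRem = 50;

% Parts 1 and 2: truncated SVDs stored as {U, S, V}, as in the question.
disSVD = cell(nTest, 2);
for k = 1:numel(disSVD)
    [u, s, v] = svd(rand(nd, nt), 0);
    disSVD{k} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
end

% Original part 3: each product is nt-by-nt before the trace is taken.
tic;
disTransOrig = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTransOrig(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
tOrigLoop = toc;

% Rearranged part 3: after the two projections, everything is nRem-by-nRem.
tic;
disTransNew = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        M3 = u2{3}' * u1{3};   % V2' * V1
        M1 = u1{1}' * u2{1};   % U1' * U2
        disTransNew(itr, jtr) = trace(M3 * u1{2} * M1 * u2{2});
    end
end
tNewLoop = toc;

maxRelDiff = max(abs(disTransOrig(:) - disTransNew(:))) / max(abs(disTransOrig(:)));
fprintf('original %.3f s, rearranged %.3f s, max relative difference %.2g\n', ...
    tOrigLoop, tNewLoop, maxRelDiff);
```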