MATLAB Answers

How to improve speed of calculating trace in a script?

Hi all,
In my project I have to calculate the trace of some matrix products. The following script demonstrates the problem:
clear; clc;
% number of total tests.
nTest = 500;
% part 1. generate nTest*2 random matrices.
nd = 1000;
nt = 100;
dis = cell(nTest, 2);
dis = cellfun(@(v) rand(nd, nt), dis, 'UniformOutput', false);
% part 2. perform truncated-SVD on each matrix,
% only leave nRem singular vectors and values.
nRem = 50;
disSVD = cell(nTest, 2);
for isvd = 1:nTest
    for jsvd = 1:2
        [u, s, v] = svd(dis{isvd, jsvd}, 0);
        disSVD{isvd, jsvd} = {u(:, 1:nRem), s(1:nRem, 1:nRem), v(:, 1:nRem)};
    end
end
% part 3. for each SVD result, perform trace to obtain disTrans. disTrans is
% non-symmetric, thus jtr needs to start from 1.
disTrans = zeros(nTest);
for itr = 1:nTest
    u1 = disSVD{itr, 1};
    for jtr = 1:nTest
        u2 = disSVD{jtr, 2};
        disTrans(itr, jtr) = ...
            trace(u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}');
    end
end
I ran the profiler to find out which part is the slowest. It turns out to be the trace calculation in part 3, because it is executed so many times (nTest^2 = 250000). Unfortunately, in my project the number of trace calculations is also very large. Any ideas on how to speed up calculating the trace? The profile is shown here: [profiler screenshot not reproduced]
Many thanks!

  Comments: 8 (5 earlier comments not shown)
Xiaohan Du 3 Apr 2018
Sorry, I'm still confused by the profiler. My question is:
The self-times of 'testuiTujSortImprove', 'trace', and 'rand(nd, nt)' are 61.062 s, 2.533 s, and 1.001 s, respectively, and all the remaining functions cost less than 0.013 s each. I think what costs time in the main function is the trace calculations and the generation of random matrices. If these two only cost 2.533 s and 1.001 s, why does the whole script cost 61.062 s, which is far more than 2.533 + 1.001?
Steven Lord 3 Apr 2018
Note that built-in functions like svd, cellfun, rand, etc. don't show up in your Profiler report but they do take time. If you open the entry for testuiTujSortImprove I believe you'll see a line in the Children table where the time spent in built-in functions will be reported, and given how many times you called svd in particular I think that will account for the "missing" time.
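One way to cross-check how much time a built-in such as svd takes, independently of the Profiler report, is to time a single call and scale by the number of calls. The following is only a rough sketch: the variable names tSvd and estTotal are made up for illustration, and the matrix is random data of the same size as in the question, so the estimate is approximate.

```matlab
% Sketch: time one economy-size SVD directly and scale by the number of calls.
nd = 1000; nt = 100;   % matrix size from the question
nTest = 500;           % the script calls svd nTest*2 times
A = rand(nd, nt);      % random stand-in for one of the matrices in dis

% timeit runs the handle several times and returns a typical time in seconds.
% The second argument asks for 3 outputs, matching [u, s, v] = svd(...).
tSvd = timeit(@() svd(A, 0), 3);

% Rough estimate of the total time spent in svd over the whole script.
estTotal = tSvd * nTest * 2;
fprintf('one svd: %.4f s, estimated total: %.2f s\n', tSvd, estTotal);
```

If estTotal is small compared with the self-time of the main function, the time is being spent in other built-in operations on that function's lines.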
Xiaohan Du 4 Apr 2018
I got it. Opening the entry for testuiTujSortImprove shows: [profiler screenshot not reproduced]
It seems that the operations inside trace, i.e.
u1{3} * u1{2}' * u1{1}' * u2{1} * u2{2} * u2{3}'
account for most of the time.


Accepted Answer

Christine Tobler 3 Apr 2018
You can make the trace operator work faster as follows: Currently, the input is two truncated SVDs, A1 = U1 * S1 * V1' and A2 = U2 * S2 * V2', and you are computing
trace(A1'*A2)
correct? After inserting A1 and A2, you can use the property of trace that trace(A*B) == trace(B*A). Note that only cyclic permutations are allowed: in general trace(A*B*C) ~= trace(A*C*B) (see Wikipedia).
So this means that you can rearrange
trace(V1*S1'*U1'*U2*S2*V2') == trace( (V2'*V1) * S1 * (U1'*U2) * S2)
Here S1' has been replaced by S1, because S1 is a real diagonal matrix. Make sure that the parentheses are set like this, so that every intermediate result is an nRem-by-nRem matrix.
By the way, for real matrices you can also rewrite trace(A'*B) as sum(sum(A.*B)), but I'm not sure if this will give you a speedup in this case.
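The rearrangement can be checked numerically on one pair of truncated SVDs. This is a minimal sketch, assuming real random matrices of the same sizes as in the question; the names tOrig, tFast, tSum and the relErr variables are made up for illustration.

```matlab
% Sketch: check that the rearranged trace matches the original expression.
rng(0);                        % fixed seed, for repeatability only
nd = 1000; nt = 100; nRem = 50;

% Two truncated SVDs, built the same way as in the question.
[U1, S1, V1] = svd(rand(nd, nt), 0);
[U2, S2, V2] = svd(rand(nd, nt), 0);
U1 = U1(:, 1:nRem); S1 = S1(1:nRem, 1:nRem); V1 = V1(:, 1:nRem);
U2 = U2(:, 1:nRem); S2 = S2(1:nRem, 1:nRem); V2 = V2(:, 1:nRem);

% Original form: the product inside trace is nt-by-nt, and one of the
% intermediate products is nt-by-nd.
tOrig = trace(V1 * S1' * U1' * U2 * S2 * V2');

% Rearranged form: every intermediate product is nRem-by-nRem.
tFast = trace((V2' * V1) * S1 * (U1' * U2) * S2);

% Element-wise form, valid for real matrices. It needs the full nd-by-nt
% matrices A1 and A2, so it is shown for comparison only.
A1 = U1 * S1 * V1';
A2 = U2 * S2 * V2';
tSum = sum(sum(A1 .* A2));

% Relative differences; these should be at round-off level.
relErrFast = abs(tOrig - tFast) / abs(tOrig);
relErrSum  = abs(tOrig - tSum)  / abs(tOrig);
fprintf('relative differences: %.2e (rearranged), %.2e (sum)\n', relErrFast, relErrSum);
```

In the double loop from the question, the rearranged form would replace the expression inside trace, with u1{1}, u1{2}, u1{3} playing the roles of U1, S1, V1 and likewise for u2.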

  Comments: 3

Xiaohan Du 4 Apr 2018
Hi Chris,
Thanks for the reply. I do have a further speedup now with your rearrangement: [profiler screenshot not reproduced]
I wonder what caused the speedup. My understanding is that calculating the trace is very cheap: 250000 calls only cost 2.402 s. What costs time is the operations inside the trace, i.e.
V1*S1'*U1'*U2*S2*V2'
So to improve speed, this is what needs to be optimised. Your suggestion results in nRem-by-nRem operations; is this why it's faster? I thought this would speed up the trace calculation itself, not what's inside trace.
Christine Tobler 4 Apr 2018
Trace is a very quick operation compared with the matrix multiplications, so the larger part of the speed-up probably comes from those multiplications. The trace call itself should also become a bit faster, because it now acts on nRem-by-nRem matrices instead of nt-by-nt matrices.
You can try taking the function calls apart:
M3 = u2{3}' * u1{3};
M1 = u1{1}' * u2{1};
M = M3 * u1{2} * M1 * u2{2};
trace(M);
This way, the profiler will tell you how much time is spent in each line, and you can see how much time the trace call takes.
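As an alternative to reading the per-line times in the profiler, each step can also be timed on its own with timeit. This is just a sketch: the cells u1 and u2 are filled with random stand-in data of the right sizes (not real SVD factors), and the timing variable names are made up for illustration.

```matlab
% Sketch: time each step of the rearranged computation separately.
nd = 1000; nt = 100; nRem = 50;

% Stand-ins for one pair of cells from disSVD: {U, S, V} with the same
% sizes as in the question, filled with random data.
u1 = {rand(nd, nRem), diag(rand(nRem, 1)), rand(nt, nRem)};
u2 = {rand(nd, nRem), diag(rand(nRem, 1)), rand(nt, nRem)};

% The steps from the comment above.
M3 = u2{3}' * u1{3};
M1 = u1{1}' * u2{1};
M  = M3 * u1{2} * M1 * u2{2};

% timeit calls each handle several times and returns a typical time in seconds.
tM3    = timeit(@() u2{3}' * u1{3});
tM1    = timeit(@() u1{1}' * u2{1});
tM     = timeit(@() M3 * u1{2} * M1 * u2{2});
tTrace = timeit(@() trace(M));

fprintf('M3: %.2e s, M1: %.2e s, M: %.2e s, trace: %.2e s\n', ...
    tM3, tM1, tM, tTrace);
```

For very fast operations timeit may warn that the measurement is inaccurate, so the numbers are best treated as an order-of-magnitude comparison between the steps.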

