How to speed up calculations between adjacent values in an array?

Charles Roux on 3 Oct 2022
Commented: J. Alex Lee on 6 Oct 2022
I'm implementing the velocity-Verlet algorithm (see below) for masses in a chain. The equation that I need to solve iteratively and that is giving me pain goes as follows:
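The equation itself did not survive in this copy of the post. Judging only from the inner for-loop quoted further down in the question, the force on interior mass $i$ presumably has the form below. This is a reconstruction from that code, not the original image, and the noise and damping terms applied to the two end masses are left out.

```latex
F_i = k\,(x_{i-1} - 2x_i + x_{i+1})
      + \beta\left[(x_{i-1} - x_i)^3 + (x_{i+1} - x_i)^3\right],
\qquad a_i = \frac{F_i}{m_i}, \qquad i = 2, \dots, N+1
```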
The problem is that the code I wrote takes way too long to run for the desired number of time steps. I was hoping someone could show me a more efficient way to do this. Here is the relevant part of the code.
% arrays of positions, velocities, and accelerations
x = zeros(N + 2, 1);
v = zeros(N + 2, 1);
a = zeros(N + 2, 1);
v(2) = Tl;
v(N+1) = Tr;
xil0 = 2*lambda.*kB*Tl;
xir0 = 2*lambda*kB*Tr;
%array of masses
m = 1 - deltaM + (2*deltaM)*rand(N + 2, 1);
m(1) = inf;
m(N+2) = inf;
newa = zeros(N + 2, 1);
xl = zeros(N+2,1);
xr = zeros(N+2,1);
fl = zeros(N+2, 1);
%fstep = zeros(N+2, 1);
for step = 1:maxStep
    % calculating x(t + delta_t)
    x = x + v.*tstep + 0.5.*a.*tstep.^2;
    % left and right neighbours of each interior mass
    xl(2:N+1) = x(1:N);
    xr(2:N+1) = x(3:N+2);
    %%%% the most time expensive part
    fl = k.*xl + beta.*(xl - x).^3;
    fstep = fl + k.*(- 2*x + xr) + beta.*((xr - x).^3);
    %%%%
    xiL = xil0*randn;
    xiR = xir0*randn;
    fstep(2) = fstep(2) + xiL - lambda.*v(2);
    fstep(N+1) = fstep(N+1) + xiR - lambda.*v(N+1);
    newa = fstep./m;
    % calculating v(t + delta_t)
    v = v + 0.5.*(a + newa).*tstep;
    fl = fl.*v;
    a = newa;
end
I was able to make the code run a bit faster by removing an inner for-loop that went like this:
for i = 2:N+1
    % calculating a(t + delta_t)
    fl(i) = k*(newx(i-1) - 2*newx(i) + newx(i+1))...
        + beta.*((newx(i-1) - newx(i))^3 + (newx(i+1) - newx(i)).^3);
end
and instead "shifting" the x array and performing element-wise operations with the shifted copies, i.e.
xl(2:N+1) = x(1:N);
xr(2:N+1) = x(3:N+2);
The overall performance is still far too slow (~6 seconds for 10^5 iterations) for me to run it for the desired 10^9ish iterations. If this problem is caused by poor pre-allocation, as I've read in other threads, I'm unsure how to manage it beyond what I've already posted.
Thank you
Edit: here is the profiler confirming my suspicions
3 Comments
David Goodmanson on 4 Oct 2022
Edited: David Goodmanson on 4 Oct 2022
Hi Charles,
dpb's code will help. In addition, in the line
fl = k.*xl + beta.*(xl - x).^3;
you have already computed all the cubes of differences of nearest neighbors. The line
fstep = fl + k.*(- 2*x + xr) + beta.*((xr - x).^3);
computes them all over again. With a shift of indices and a minus sign, (xr - x).^3 can be obtained from (xl - x).^3. Then there are some end effects to take care of. But the index shift might improve things by a factor of somewhat less than 2. Overall, if these comments were to gain a speed increase of 5, then at your ~6 seconds per 1e5 iterations, 1e9 iterations is still going to take around three and a half hours.
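The index shift described above can be sketched as follows. This is only an illustration under assumptions: the values of N, k and beta are hypothetical, and the names d, bond, f and fref are introduced here rather than taken from the question. The idea is to compute each spring's extension once with diff, cube it once, and obtain the net force on each interior mass as the difference between its right and left springs.

```matlab
% Hypothetical parameter values, chosen only for this illustration
N = 8; k = 1.0; beta = 0.5;
rng(0)
x = [0; rand(N, 1); 0];            % end points held at x(1) = x(N+2) = 0

% Extension of each spring: d(j) = x(j+1) - x(j), N+1 springs in total
d = diff(x);
% Force carried by each spring; every cube is evaluated exactly once
bond = k.*d + beta.*(d.*d.*d);
% Net force on interior mass i = (right spring) - (left spring)
f = zeros(N + 2, 1);
f(2:N+1) = diff(bond);

% Reference: the element-by-element loop from the question
fref = zeros(N + 2, 1);
for i = 2:N+1
    fref(i) = k*(x(i-1) - 2*x(i) + x(i+1)) ...
        + beta*((x(i-1) - x(i))^3 + (x(i+1) - x(i))^3);
end
maxErr = max(abs(f - fref));       % expected to be at round-off level
```

Because only f(2:N+1) is written, the first and last entries stay at zero, which is one way of handling the end effects.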
J. Alex Lee on 6 Oct 2022
I was surprised at how dominant the effect of avoiding ^ is! You can try different combinations of the size of x and the number of iterations, because it might scale slightly differently, but it looks like using dpb's * instead of ^ gives you a ~20x speed boost, with the indexing not as dominant, so David's re-use estimate of <2x is spot on.
rng(0)
M = 1000;
N = 100000;
x = rand(N+2,1);
xl = zeros(N+2,1);
xr = zeros(N+2,1);
% original
tic
for k = 1:M
    xl(2:N+1) = x(1:N);
    xr(2:N+1) = x(3:N+2);
    d1 = (xl - x).^3;
    d2 = (xr - x).^3;
end
toc
Elapsed time is 8.885521 seconds.
% dpb's avoid ^
tic
for k = 1:M
    xl(2:N+1) = x(1:N);
    xr(2:N+1) = x(3:N+2);
    d1 = (xl - x).*(xl - x).*(xl - x);
    d2 = (xr - x).*(xr - x).*(xr - x);
end
toc
Elapsed time is 0.384202 seconds.
% David's index shifting (but keep ^)
tic
for k = 1:M
    d0 = diff(x).^3;
    d1B = -d0(1:N);
    d2B = d0(2:N+1);
end
toc
Elapsed time is 4.481809 seconds.
% both improvements
tic
for k = 1:M
    dx = diff(x);
    d0 = dx.*dx.*dx;
    d1B = -d0(1:N);
    d2B = d0(2:N+1);
end
toc
Elapsed time is 0.269772 seconds.
isequal(d1(2:N+1),d1B)
ans = logical
1
isequal(d2(2:N+1),d2B)
ans = logical
1
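To check how the ^ versus * gap scales with array size, something like the following sketch could be used. It relies on MATLAB's timeit rather than tic/toc. The sizes in the loop are arbitrary choices, and the variable names are assumptions made for this example. Timings will differ between machines and MATLAB releases, so no expected numbers are given.

```matlab
rng(0)
sizes = [1e3 1e4 1e5];                 % arbitrary array sizes to compare
ratio = zeros(size(sizes));
maxDiff = zeros(size(sizes));
for s = 1:numel(sizes)
    x = rand(sizes(s) + 2, 1);
    dx = diff(x);
    cubePower = @() dx.^3;             % element-wise power
    cubeTimes = @() dx.*dx.*dx;        % repeated multiplication
    tPower = timeit(cubePower);
    tTimes = timeit(cubeTimes);
    ratio(s) = tPower/tTimes;
    % the two results should agree to round-off
    maxDiff(s) = max(abs(cubePower() - cubeTimes()));
    fprintf('n = %g: power %.3g s, times %.3g s, ratio %.1f\n', ...
        sizes(s), tPower, tTimes, ratio(s));
end
```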
Also, I didn't scrutinize the actual algorithm, but at first glance it seems that even in the original, the end effects David mentions aren't taken care of... you get junk on the first and last elements of fl, don't you? Did you mean:
fstep(1) = fstep(2) + xiL - lambda.*v(2);
fstep(N+2) = fstep(N+1) + xiR - lambda.*v(N+1);
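For what it's worth, here is a rough sketch of one way to fold both improvements into the time-stepping loop and compare it against the shifted-array version from the question. Everything numeric here (N, k, beta, lambda, kB, Tl, Tr, deltaM, tstep, maxStep) is a hypothetical value picked for illustration. The random kicks are drawn up front into an array called noise (an assumption of this sketch) so that both runs see the same numbers, and the line fl = fl.*v from the question is left out because it does not feed back into x or v. Since m(1) and m(N+2) are inf, dividing any finite force by them gives zero acceleration, so the end entries of fstep do not affect the motion in either version.

```matlab
% All parameter values are hypothetical, for illustration only
N = 50; k = 1; beta = 1; lambda = 1; kB = 1;
Tl = 1.2; Tr = 0.8; deltaM = 0.2; tstep = 0.01; maxStep = 200;
rng(1)
m = 1 - deltaM + (2*deltaM)*rand(N + 2, 1);
m(1) = inf; m(N+2) = inf;              % end masses never move
xil0 = 2*lambda*kB*Tl; xir0 = 2*lambda*kB*Tr;
noise = randn(maxStep, 2);             % pre-drawn so both runs share the same kicks

% Version 1: shifted copies and .^3, as in the question
x = zeros(N+2,1); v = zeros(N+2,1); a = zeros(N+2,1);
v(2) = Tl; v(N+1) = Tr;
xl = zeros(N+2,1); xr = zeros(N+2,1);
for step = 1:maxStep
    x = x + v.*tstep + 0.5.*a.*tstep.^2;
    xl(2:N+1) = x(1:N);
    xr(2:N+1) = x(3:N+2);
    fstep = k.*xl + beta.*(xl - x).^3 + k.*(-2*x + xr) + beta.*(xr - x).^3;
    fstep(2) = fstep(2) + xil0*noise(step,1) - lambda.*v(2);
    fstep(N+1) = fstep(N+1) + xir0*noise(step,2) - lambda.*v(N+1);
    newa = fstep./m;
    v = v + 0.5.*(a + newa).*tstep;
    a = newa;
end
x1 = x; v1 = v;

% Version 2: diff-based spring forces, cubes by multiplication
x = zeros(N+2,1); v = zeros(N+2,1); a = zeros(N+2,1);
v(2) = Tl; v(N+1) = Tr;
fstep = zeros(N+2,1);                  % end entries stay zero throughout
for step = 1:maxStep
    x = x + v.*tstep + 0.5.*a.*tstep.^2;
    d = diff(x);                       % spring extensions
    bond = k.*d + beta.*(d.*d.*d);     % force carried by each spring
    fstep(2:N+1) = diff(bond);         % right spring minus left spring
    fstep(2) = fstep(2) + xil0*noise(step,1) - lambda.*v(2);
    fstep(N+1) = fstep(N+1) + xir0*noise(step,2) - lambda.*v(N+1);
    newa = fstep./m;
    v = v + 0.5.*(a + newa).*tstep;
    a = newa;
end
maxDx = max(abs(x - x1));              % should be at round-off level
maxDv = max(abs(v - v1));
```

Whether this is faster on a given machine is best checked with the profiler or timeit, as the ratios reported above came from one particular setup.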
