Improve speed of linear interpolation in nested loops
조회 수: 9 (최근 30일)
이전 댓글 표시
I have to do 1-dimensional linear interpolation many times within 4 nested loops. My X-grid is sorted so I can use interp1q but the code is still slow for my purposes. I managed to do a simple vectorization that eliminates the innermost loop (so I have only 3 loops instead of 4) and it's much faster, but unfortunately still not fast enough for my problem. Any suggestions on how to improve speed? Thanks
Below is an MWE (please bear in mind that in my real problem the grids are larger).
% Setup for the minimal working example: reset the session, choose the grid
% sizes, and generate random placeholder data for the policy arrays.
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55; % grid size for k
b_min = -100;
b_max = 300;
% Generate fake data
% Reset the random number generator so the fake data are the same on every run.
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
% b_gridp holds one b-grid per k index, stored column-wise: size (nb,nk).
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
% (previous version, identical columns, kept for reference)
%b_gridp(:,k_c) = linspace(b_min,b_max,nb)'; %EDITED HERE
% Each column is uniformly spaced, with both endpoints pulled inward by a
% random amount in (0,1); the grid therefore spans slightly less than
% [b_min,b_max], so some values of pol_debt can fall outside it.
b_gridp(:,k_c) = linspace(b_min+rand,b_max-rand,nb)';
end
%% Slow, not vectorized code
% Baseline: one scalar interpolation per (xp_c,k_c,b_c,x_c) combination,
% timed with tic/toc. The result is stored in stay_arr, size (nx,nk,nb,nx).
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx % index into the 2nd dimension of pol_exitp (presumably next-period x)
% bnext and knext_ind do not depend on xp_c; they are recomputed on every
% inner iteration.
bnext = pol_debt(k_c,b_c,x_c); % query point for the interpolation
knext_ind = pol_kp_ind(k_c,b_c,x_c); % selects the grid column and the page of pol_exitp
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
% Clamp to [0,1].
% NOTE(review): interp1q returns NaN for query points outside the grid, and
% max(NaN,0) evaluates to 0, so such points end up with dexit = 0 (stay = 1)
% -- confirm this is the intended treatment of out-of-grid values.
dexit = min(max(dexit_inter,0),1); % scalar
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
%% This is faster, but still not fast enough!
% Same computation with the innermost loop removed: interp1q is given the
% whole (nb,nx) slice of pol_exitp, so one call interpolates all nx columns
% at the single query point bnext.
tic
stay_arr2 = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
bnext = pol_debt(k_c,b_c,x_c); % query point for the interpolation
knext_ind = pol_kp_ind(k_c,b_c,x_c); % selects the grid column and the page of pol_exitp
pol_exit_bx = pol_exitp(:,:,knext_ind); % dim: (nb,nx)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % dim: (1,nx), one value per column
% Clamp to [0,1].
% NOTE(review): interp1q returns NaN for query points outside the grid, and
% max(NaN,0) evaluates to 0, so such points end up with dexit = 0 (stay = 1)
% -- confirm this is the intended treatment of out-of-grid values.
dexit = min(max(dexit_inter,0),1); % dim: (1,nx)
stay_arr2(:,k_c,b_c,x_c) = 1-dexit;
end %k_c
end %b_c
end %x_c
toc
% Largest absolute difference between the two results (displayed, since the
% statement has no trailing semicolon).
err = max(abs(stay_arr-stay_arr2),[],'all')
댓글 수: 5
채택된 답변
Bruno Luong
2022년 7월 25일
편집: Bruno Luong
2022년 7월 25일
This seems to work
% Setup: reset the session, choose the grid sizes, and generate random
% placeholder data for the policy arrays.
clear
clc
close all
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55; % grid size for k
b_min = -100;
b_max = 300;
% Generate fake data
% Reset the random number generator so the fake data are the same on every run.
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
% b_gridp holds one b-grid per k index, stored column-wise: size (nb,nk).
b_gridp = zeros(nb,nk);
for k_c =1:nk
% NOTE: in this script every column of b_gridp is the SAME uniform grid on
% [b_min,b_max]; the vectorized code further down uses only column 1 and
% therefore relies on this.
b_gridp(:,k_c) = linspace(b_min,b_max,nb)';
end
disp('start code')
% Reference implementation: one scalar interpolation per
% (xp_c,k_c,b_c,x_c) combination, timed with tic/toc. The result stay_arr,
% size (nx,nk,nb,nx), is what the vectorized code is checked against.
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx % index into the 2nd dimension of pol_exitp (presumably next-period x)
% bnext and knext_ind do not depend on xp_c; they are recomputed on every
% inner iteration.
bnext = pol_debt(k_c,b_c,x_c); % query point for the interpolation
knext_ind = pol_kp_ind(k_c,b_c,x_c); % selects the grid column and the page of pol_exitp
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar, clamped to [0,1]
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
%% Full vectorized code
% Replaces all four loops with array operations: locate every query point in
% the grid once, build linear indices of the bracketing entries of pol_exitp,
% and blend them with the interpolation weight.
tic
% Column 1 is used as the grid for every k; this is valid only because all
% columns of b_gridp are identical in this script.
bgridcommon = b_gridp(:,1);
% Interpolating the grid against 1:nb maps each debt value to its
% real-valued position in the grid.
Y = interp1(bgridcommon,(1:nb)',pol_debt); % nk x nb x nx
% Clamp so that the left node I and the right node I+1 are both valid row
% indices (1..nb). Author's note: not needed if the data never overflow the grid.
Yt = max(min(Y,nb-1),1);
I = floor(Yt); % nk x nb x nx, row index of the left bracketing node
W = Y-I; % weight on the right node (uses the unclamped Y)
% Pair every left-node index with every column index 1..nx.
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
% Replicate the page index along a new 4th dimension and reshape it to
% line up element by element with I and J.
K = repmat(pol_kp_ind,[1 1 1 nx]);
K = reshape(K,size(I));
% Linear index into pol_exitp of element (I,J,K).
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
% rhsilin+1 addresses row I+1 of the same column and page, because the first
% dimension is contiguous in a linear index. W (nk x nb x nx) is implicitly
% expanded along the 4th dimension.
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
% Move the 4th dimension to the front to match the layout of stay_arr.
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1); % clamp to [0,1]
stay_arr2 = 1-dexit;
toc
% Infinity norm of the difference, i.e. the largest absolute deviation from
% the loop result (displayed, since the statement has no trailing semicolon).
err = norm(stay_arr2(:)-stay_arr(:),Inf)
댓글 수: 4
Bruno Luong
2022년 7월 26일
편집: Bruno Luong
2022년 7월 26일
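Sorry, please disregard my comment above about the loop. The bin interval is not the first index. Here is the corrected code, which works for bin vectors that vary with k (it assumes each column of b_gridp is uniformly spaced, as produced by linspace).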
% Setup: choose the grid sizes and generate random placeholder data for the
% policy arrays.
nx = 40; % grid size for x
nb = 45; % grid size for b
nk = 55; % grid size for k
b_min = -100;
b_max = 300;
% Generate fake data
% Reset the random number generator so the fake data are the same on every run.
rng('default')
pol_debt = b_min+(b_max-b_min)*rand(nk,nb,nx); % in [b_min,b_max]
pol_kp_ind = randi([1,nk],nk,nb,nx); % integers in {1,2,..,nk}
pol_exitp = rand(nb,nx,nk); % in [0,1]
% b_gridp holds one b-grid per k index, stored column-wise: size (nb,nk).
b_gridp = zeros(nb,nk);
for k_c =1:nk
% in general, the columns of b_gridp are *not* equal to each other
% Each column is uniformly spaced, with both endpoints pushed outward by a
% random amount in (0,1); every grid therefore covers all of [b_min,b_max],
% and no value of pol_debt falls outside it.
b_gridp(:,k_c) = linspace(b_min-rand(),b_max+rand(),nb)';
end
disp('start code')
% Reference implementation: one scalar interpolation per
% (xp_c,k_c,b_c,x_c) combination, timed with tic/toc. The result stay_arr,
% size (nx,nk,nb,nx), is what the vectorized code is checked against.
tic
stay_arr = zeros(nx,nk,nb,nx);
for x_c = 1:nx % current x
for b_c = 1:nb % current debt
for k_c = 1:nk % current capital
for xp_c = 1:nx % index into the 2nd dimension of pol_exitp (presumably next-period x)
% bnext and knext_ind do not depend on xp_c; they are recomputed on every
% inner iteration.
bnext = pol_debt(k_c,b_c,x_c); % query point for the interpolation
knext_ind = pol_kp_ind(k_c,b_c,x_c); % selects the grid column and the page of pol_exitp
pol_exit_bx = pol_exitp(:,xp_c,knext_ind); % dim: (nb,1)
dexit_inter = interp1q(b_gridp(:,knext_ind),pol_exit_bx,bnext); % scalar
dexit = min(max(dexit_inter,0),1); % scalar, clamped to [0,1]
stay_arr(xp_c,k_c,b_c,x_c) = 1-dexit; % scalar
end
end %k_c
end %b_c
end %x_c
toc
%% Full vectorized code
% Replaces all four loops with array operations, allowing a different grid
% for each k index: compute every query point's position in its own grid in
% closed form, build linear indices of the bracketing entries of pol_exitp,
% and blend them with the interpolation weight.
tic
K = pol_kp_ind; % nk x nb x nx, grid column / page of pol_exitp per state
% First and last grid point of the column selected for each state.
bminK = reshape(b_gridp(1,K),size(K));
bmaxK = reshape(b_gridp(nb,K),size(K));
% Real-valued position of each debt value in its grid, computed from the two
% endpoints only. This is exact only when each column of b_gridp is uniformly
% spaced (true here, since the columns come from linspace).
Y = 1 + (nb-1) * (pol_debt - bminK) ./ (bmaxK-bminK);
% Clamp so that the left node I and the right node I+1 are both valid row
% indices (1..nb). Author's note: not needed if the data never overflow the grid.
Yt = max(min(Y,nb-1),1);
I = floor(Yt); % nk x nb x nx, row index of the left bracketing node
W = Y-I; % weight on the right node (uses the unclamped Y)
% Pair every left-node index with every column index 1..nx.
[I,J]=ndgrid(I,1:nx); % (nk x nb x nx) x nx
% Replicate the page index along a new 4th dimension and reshape it to
% line up element by element with I and J.
K = reshape(repmat(K,[1 1 1 nx]),size(I));
% Linear index into pol_exitp of element (I,J,K).
rhsilin = sub2ind(size(pol_exitp),I,J,K); % (nk x nb x nx) x nx;
rhsilin = reshape(rhsilin, [nk,nb,nx,nx]);
% rhsilin+1 addresses row I+1 of the same column and page, because the first
% dimension is contiguous in a linear index. W (nk x nb x nx) is implicitly
% expanded along the 4th dimension.
dexit_inter = (1-W).*pol_exitp(rhsilin) + W.*pol_exitp(rhsilin+1);
% Move the 4th dimension to the front to match the layout of stay_arr.
dexit_inter = permute(dexit_inter, [4 1 2 3]); % [nx,nk,nb,nx]
dexit = min(max(dexit_inter,0),1); % clamp to [0,1]
stay_arr2 = 1-dexit;
toc
% Infinity norm of the difference, i.e. the largest absolute deviation from
% the loop result (displayed, since the statement has no trailing semicolon).
err = norm(stay_arr2(:)-stay_arr(:),Inf)
추가 답변 (0개)
참고 항목
카테고리
Help Center 및 File Exchange에서 Performance and Memory에 대해 자세히 알아보기
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!