Computing mean by group (need help with speed!)

John, 8 Sep 2015
Answered: Steven Lord, 14 May 2022
Hi everyone,
I need to demean the columns of matrix Xw (200000 x 24) by group id. To do this, I need to compute means of the columns of Xw for each group identified by the vector id (200000 x 1). I have written the following code, which is the fastest I could do in light of this very similar post:
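%DEMEANING
[~, ~, indx_j] = unique(id);                       % map each id to a group index 1..nGroups
for i = 1:size(Xw, 2)
    bb = accumarray(indx_j, Xw(:,i), [], @mean);   % mean of column i within each group
    Xw(:,i) = Xw(:,i) - bb(indx_j);                % subtract each row's group mean
end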
This is faster than the code suggested at the post linked above. Importantly, the matrix Xw is rather sparse, and only contains zeros and ones (it is a matrix of dummy variables).
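For reference, the operation written out. The symbols $g(i)$ (group of row $i$) and $n_g$ (number of rows in group $g$) are notation introduced here, not from the original post:

```latex
\tilde{X}_{ij} = X_{ij} - \bar{X}_{g(i)\,j},
\qquad
\bar{X}_{gj} = \frac{1}{n_g} \sum_{k:\, g(k) = g} X_{kj}
```

Because Xw holds only 0/1 dummies, $\bar{X}_{gj}$ is simply the share of rows in group $g$ that have a 1 in column $j$.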
My question: Is there any way to speed up this process further? It is quite time consuming as it stands, and given that this itself is inside of a loop, it is slowing everything else down. Please help!! Creative solutions welcome.
Thanks!
Matthew Eicholtz, 8 Sep 2015
"It is quite time consuming as it stands"...how much time are we talking? I ran your version on my machine and it took around 300 ms. I understand this becomes a problem if it is inside another loop, but I'm just curious what your benchmark is.

Accepted Answer

Cedric, 8 Sep 2015 (edited 9 Sep 2015)
Along the lines of Andrei's and Matt's answers, and based on my comment (under Andrei's answer):
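[ii,jj] = ndgrid( id, 1:size( Xw, 2 )) ;            % group and column subscript for every element
iijj = [ii(:), jj(:)] ;
sums = accumarray( iijj, Xw(:) ) ;                    % per-group, per-column sums
cnts = accumarray( iijj, ones( numel( Xw ), 1 )) ;    % per-group, per-column counts
means = sums ./ cnts ;
Xw = Xw - means(id, :) ;                              % subtract each row's group means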
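The code above only needs group sums and group sizes, and the size of a group is the same for every column, so one count per group is enough. Another option, not proposed or timed in this thread, is to encode the grouping as a sparse indicator matrix and get all group sums with a single matrix product. A minimal sketch, assuming `id` holds positive integer labels and using fake data (`S`, `nG` and `Xw_dm` are names introduced here):

```matlab
% Hypothetical alternative: group means via a sparse indicator matrix.
% Assumes id contains positive integer group labels.
n  = 200000;
Xw = double(rand(n, 24) > 0.9);        % fake 0/1 dummy-variable matrix
id = randi(20, n, 1);                  % fake group labels 1..20

nG = max(id);
S  = sparse(id, (1:n)', 1, nG, n);     % S(g,i) = 1 when row i belongs to group g
sums  = S * Xw;                        % nG-by-24 column sums per group
cnts  = full(sum(S, 2));               % nG-by-1 number of rows per group
means = bsxfun(@rdivide, sums, cnts);  % group means (NaN only for empty groups)
Xw_dm = Xw - means(id, :);             % subtract each row's group means
```

Empty groups produce NaN rows in `means`, but those rows are never indexed by `id`. Whether this beats the accumarray version depends on the machine and the number of groups, so it would need to be timed.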
Kelly's solution is still the most efficient on my (old) laptop for small numbers of groups (here 20). OP = original post, KK = Kelly, AB = Andrei, ME = Matthew, CW = this answer:
Time OP = 0.562232s
Time KK = 0.239906s
Time AB = 0.664438s
Time ME = 0.602663s
Time CW = 0.258647s
My variant of Andrei/Matt's solutions is slightly better with larger numbers of groups (here 1000):
Time OP = 1.318130s
Time KK = 0.342277s
Time AB = 1.426827s
Time ME = 1.355189s
Time CW = 0.279261s
Here is the run time as a function of the number of groups, still on my rather old laptop:
[Figure: run time of each method vs. number of groups]
We see that my solution's run time is nearly flat in the number of groups, and it crosses Kelly's somewhere between 100 and 300 groups.
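The benchmark script itself is not shown in the thread. A minimal sketch of how such a sweep over the number of groups could be measured, timing only the accumarray sum/count variant with tic/toc; the data size, group counts and repetition count are assumptions, and MATLAB's `timeit` would be an alternative to the manual repetitions:

```matlab
% Timing sketch on fake 0/1 data for several assumed numbers of groups.
n = 200000;
nCols = 24;
Xw = double(rand(n, nCols) > 0.9);        % fake dummy-variable matrix
nGroupsList = [20 100 300 1000];          % assumed sweep values
nRep = 5;                                 % repetitions per setting
tMed = zeros(size(nGroupsList));

for k = 1:numel(nGroupsList)
    id = randi(nGroupsList(k), n, 1);     % fake group labels 1..nGroups
    t = zeros(nRep, 1);
    for r = 1:nRep
        tic;
        [ii, jj] = ndgrid(id, 1:nCols);
        iijj  = [ii(:), jj(:)];
        sums  = accumarray(iijj, Xw(:));
        cnts  = accumarray(iijj, ones(numel(Xw), 1));
        means = sums ./ cnts;
        Xw_dm = Xw - means(id, :);
        t(r) = toc;
    end
    tMed(k) = median(t);                  % median is less sensitive to outliers
    fprintf('%5d groups: %.4f s\n', nGroupsList(k), tMed(k));
end
```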
John, 9 Sep 2015
Cedric - thanks for all of the thought you put into this. Indeed, for my purposes (200000+ observations) your solution is the fastest. I think this turned out to be a really productive thread....
Cedric, 9 Sep 2015 (edited 10 Sep 2015)
My pleasure. Yes, I like these threads; there is always something to learn from every post (in fact, I usually print them to PDF).

More Answers (5)

Kelly Kearney, 8 Sep 2015 (edited 8 Sep 2015)
This example uses my aggregate function. It's basically a wrapper around accumarray, but it lets you apply a function to multiple columns at once without calling accumarray again for each column.
% Some fake data
n = 200000;
Xw = rand(n,20);
id = ceil(rand(n,1) * 10);
% Group and subtract mean
[idg, Xwg] = aggregate(id, Xw, @(x) bsxfun(@minus, x, mean(x,1)));
% Put back in original order
[~, srt] = aggregate(id, (1:n)');
[~, isrt] = sort(cat(1, srt{:}));
Xwg = cat(1, Xwg{:});
Xwg = Xwg(isrt,:);
Perhaps I should add the index retrieval and re-sorting step to the function itself; I'll add that to my to-do list. When I timed this, the reordering was the most time-consuming part, so it may be possible to speed that up a bit more.
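One possible tweak for the reordering step, not timed in the thread: the second `sort` only serves to invert a permutation, and a permutation can be inverted in linear time by scattering the grouped rows straight back with indexed assignment. The sketch below does not call `aggregate`; it assumes that `cat(1, srt{:})` lists the original row indices in grouped order, and uses `sort(id)` as a stand-in for that grouping:

```matlab
% Stand-in for aggregate's output: rows of Xw rearranged into grouped order.
n  = 200000;
Xw = rand(n, 20);
id = randi(10, n, 1);
[~, p] = sort(id);        % p(r) = original row index of the r-th grouped row
Xwg = Xw(p, :);           % grouped data (stand-in for cat(1, Xwg{:}))

% Option 1 (as in the answer): invert the permutation with a second sort
[~, isrt] = sort(p);
back1 = Xwg(isrt, :);

% Option 2: write each grouped row straight back to its original position
back2 = zeros(size(Xwg));
back2(p, :) = Xwg;
```

Both options restore the original row order; option 2 replaces an O(n log n) sort with a single O(n) indexed assignment.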
EDIT:
Returning the indices turned out to be a quick and easy change to the function, and it eliminates the second call to aggregate in my example. I've uploaded the change to GitHub, so you'll want to clone the code from there; MATLAB Central may not pick up the update until the end of the day. New example:
[idg, Xwg, idx] = aggregate(id, Xw, @(x) bsxfun(@minus, x, mean(x,1)));
[~, isrt] = sort(cat(1, idx{:}));
Xwg = cat(1, Xwg{:});
Xwg = Xwg(isrt,:);
This version takes about 60% of the time of your original code.
John, 8 Sep 2015
Thanks Kelly. I'm embedding this in my code and testing it out. I'll let you know how it works...
John, 8 Sep 2015
Kelly - so far, your function has worked the best. It has cut my runtime in half. Thank you SO much. You have no idea how helpful this is.
I'm going to wait another few hours, and if nobody has posted quicker code by then (I doubt they will), I'll accept your answer. Thanks again!


Guillaume, 8 Sep 2015
I have no idea whether it would be faster, but my suggestion would be to split Xw into one submatrix per ID. You can then compute the means and subtract them from all the columns at once:
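Xwid = arrayfun(@(cid) Xw(id == cid, :), unique(id), 'UniformOutput', false); % split Xw into one submatrix per group
Xwid = cellfun(@(xw) bsxfun(@minus, xw, mean(xw, 1)), Xwid, 'UniformOutput', false); % subtract the column means within each submatrix
% Xwid{k} now holds the demeaned rows of the k-th group, in grouped (not original) row order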
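The cell-based version above leaves each group's demeaned rows in a separate cell, so they still have to be put back in the original order. A minimal sketch of the same per-group idea that writes the results straight into place with logical masks (fake data; `out` and `gids` are names introduced here). Each mask scans all n rows, so the cost grows with the number of groups, which is consistent with this approach being slow:

```matlab
% Per-group demeaning with logical masks; the row order of Xw is preserved.
n  = 1000;
Xw = double(rand(n, 24) > 0.9);     % fake dummy-variable matrix
id = randi(5, n, 1);                % fake group labels

out  = Xw;                          % output keeps the original row order
gids = unique(id);
for k = 1:numel(gids)
    rows = (id == gids(k));         % logical mask selecting group k
    sub  = Xw(rows, :);
    out(rows, :) = bsxfun(@minus, sub, mean(sub, 1));   % subtract column means of group k
end
```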
John, 8 Sep 2015
Thanks for the suggestion! Unfortunately, this slows down the code quite a bit. Any other suggestions out there?


Andrei Bobrov, 8 Sep 2015 (edited 8 Sep 2015)
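[~,~,c]=unique(id);                                 % c: group index 1..nGroups for each row
[ii,jj] = ndgrid(c,1:size(Xw,2));                   % (group, column) subscript for every element
bb = accumarray([ii(:),jj(:)], Xw(:), [], @mean);   % nGroups-by-nCols matrix of group means
out = Xw - bb(c,:);                                 % subtract each row's group means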
Matthew Eicholtz, 8 Sep 2015
This is the fastest version for me, although there is no need to call unique when id already contains positive integers. Try this instead:
[ii,jj] = ndgrid(id,1:size(Xw,2));
bb = accumarray([ii(:),jj(:)], Xw(:), [], @mean);
out = Xw - bb(id,:);
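Calling unique does matter when the labels in id are not positive integers, because accumarray subscripts must be; it also keeps bb small when the labels are large, sparsely used numbers. A small worked example with made-up labels:

```matlab
% Labels that are not valid accumarray subscripts (negative, non-integer).
id = [10; -3; 10; 7.5; -3];
Xw = [1 0; 0 1; 1 1; 0 0; 1 1];

[~, ~, c] = unique(id);          % c = [3; 1; 3; 2; 1], groups numbered 1..3
[ii, jj] = ndgrid(c, 1:size(Xw, 2));
bb  = accumarray([ii(:), jj(:)], Xw(:), [], @mean);   % 3-by-2 group means
out = Xw - bb(c, :);
% out = [0 -0.5; -0.5 0; 0 0.5; 0 0; 0.5 0]
```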
John, 9 Sep 2015
Thanks guys. Huge help.


Vahidreza Jahanmard, 14 May 2022
The following might work better for your problem:
[Groups,~] = findgroups(id);
bb = splitapply(@(x) mean(x,1), Xw, Groups);   % group means; mean(x,1) also averages down the columns for one-row groups
If you also want to handle NaN values:
bb = splitapply(@(x) mean(x,1,'omitnan'), Xw, Groups);
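The splitapply lines only compute the group means; the demeaning still needs a subtraction indexed by the group numbers. A minimal sketch of the full operation with fake data (findgroups and splitapply require R2015b or later). Here each group's row of means is wrapped in a cell so that splitapply returns one scalar cell per group, which is then stacked with `vertcat`:

```matlab
% Full group demeaning with findgroups/splitapply on fake data.
n  = 200000;
Xw = double(rand(n, 24) > 0.9);                % fake dummy-variable matrix
id = randi(20, n, 1);                          % fake group labels

G  = findgroups(id);                           % G(i) = group number 1..nGroups of row i
c  = splitapply(@(x) {mean(x, 1)}, Xw, G);     % one cell per group holding a 1-by-24 row of means
bb = vertcat(c{:});                            % nGroups-by-24 matrix of group means
Xw_dm = Xw - bb(G, :);                         % subtract each row's group means
```

For NaN handling, `mean(x, 1, 'omitnan')` can be used inside the anonymous function instead.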

Steven Lord, 14 May 2022
Since this question was asked, we introduced the grouptransform function (in release R2018b). I believe using the 'meancenter' method, or specifying the method as a function handle that calls detrend, may make this a one-line operation.
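A minimal sketch of the table-based route with grouptransform and the 'meancenter' method, assuming R2018b or later and fake data. The variable names Xw1, ..., Xw24 come from array2table's defaults, and `T`, `Tdm` and `Xw_dm` are names introduced here:

```matlab
% Group-wise mean centering via grouptransform on a table (R2018b+ assumed).
n  = 1000;
Xw = double(rand(n, 24) > 0.9);                 % fake dummy-variable matrix
id = randi(20, n, 1);                           % fake group labels

T = array2table(Xw);                            % variables Xw1, ..., Xw24
T.id = id;                                      % add the grouping variable as column 25
Tdm = grouptransform(T, 'id', 'meancenter');    % center each non-grouping variable within its group
Xw_dm = Tdm{:, 1:size(Xw, 2)};                  % back to a numeric n-by-24 matrix
```

Converting to and from a table adds overhead, so whether this is faster than the accumarray versions for 200000 rows would have to be measured.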
