How to calculate sum(A .* (B * C), 'all') [Ed. actually sum(A .* log(B * C), 'all')] efficiently when A is sparse and B*C is full and large?

I have three matrices, A of size [J,I], B of size [J,K], C of size [K,I]. A is a sparse matrix with more than 90% zeros, while B and C are positive double matrices. The typical values are J=1e5, I=1e4, K=50.
The problem is that B*C creates a full matrix of size [J,I], which leads to redundant memory usage because what I need is merely the elements (B * C)(find(A)). My current constraint is that I don't have enough memory for a full matrix of size [J,I]. I wonder if there's a smart way to avoid such unnecessary memory usage for calculating this specific expression?
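A rough count for the typical sizes quoted above (J = 1e5, I = 1e4, K = 50, at most 10% nonzeros in A) shows how much of the full product is wasted. This is a back-of-envelope estimate that assumes 8-byte doubles and ignores the index storage of the sparse matrix.

```latex
\begin{aligned}
\text{full } BC &: \; JI = 10^{5}\cdot 10^{4} = 10^{9}\ \text{doubles} \approx 8\ \text{GB},
  && JIK = 5\times 10^{10}\ \text{multiply-adds} \\
\text{only at } \operatorname{find}(A) &: \; \operatorname{nnz}(A) < 0.1\,JI = 10^{8}\ \text{doubles} < 0.8\ \text{GB},
  && \operatorname{nnz}(A)\,K < 5\times 10^{9}\ \text{multiply-adds}
\end{aligned}
```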
I have tried converting B into a tall array, but an error like "tall arrays are not allowed to contain sparse data" appears when .* is evaluated. I also tried converting A into a tall array using tall(full(A)), but that's not reasonable because I need to restore A as a full matrix first, and A is in fact not "tall" at all. Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in part (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage.
Thanks in advance!

Accepted Answer

James Tursa, 1 November 2021 (edited 9 November 2021)
Here is the straightforward mex code (i.e., no parallel sections) if you want to try it out. It computes the result directly in a loop without the need for large temporary memory allocations and data copying. You will need a supported C compiler installed. To compile it use the following at the command line:
mex sABC.c -R2018a
If you have an earlier version of MATLAB you can omit the -R2018a option.
To run it simply call as noted:
A = whatever   % sparse real double, J x I
B = whatever   % full real double, J x K
C = whatever   % full real double, K x I
sABC(A,B,C)
The C source code:
/* File sABC.c
 * sABC(A,B,C) returns sum(A.*log(B*C),'all')
 *
 * A = sparse real double MxN
 * B = full real double MxK
 * C = full real double KxN
 *
 * Programmer: James Tursa
 * Date: 10/31/2021
 */
#include "mex.h"
#include <math.h>

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double dot, result = 0.0;
    mwSize M, K, N;
    mwSize j, k, nrow;
    double *A, *B, *C, *b, *c;
    mwIndex *Air, *Ajc;

    /* Argument checks */
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs.");
    }
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs.");
    }
    if( !mxIsDouble(prhs[0]) || !mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ) {
        mexErrMsgTxt("A must be real sparse double.");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]) ||
        mxIsSparse(prhs[1]) || mxIsSparse(prhs[2]) ||
        mxIsComplex(prhs[1]) || mxIsComplex(prhs[2]) ) {
        mexErrMsgTxt("B and C must be real full double matrices.");
    }
    if( mxGetNumberOfDimensions(prhs[1]) != 2 || mxGetNumberOfDimensions(prhs[2]) != 2 ) {
        mexErrMsgTxt("B and C must be 2D.");
    }
    M = mxGetM(prhs[0]);
    N = mxGetN(prhs[0]);
    K = mxGetN(prhs[1]);
    if( M != mxGetM(prhs[1]) ||
        N != mxGetN(prhs[2]) ||
        K != mxGetM(prhs[2]) ) {
        mexErrMsgTxt("Dimensions not compatible.");
    }

    /* Calculate result, simple loop no parallel code */
    Air = mxGetIr(prhs[0]);
    Ajc = mxGetJc(prhs[0]);
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(prhs[2]);
    for( j=0; j<N; j++ ) {
        nrow = Ajc[j+1] - Ajc[j];  /* Number of row elements for this column */
        while( nrow-- ) {
            b = B + *Air++;        /* B row pointer */
            c = C + j*K;           /* C column pointer */
            dot = 0.0;
            for( k=0; k<K; k++ ) { /* dot product of B row and C column */
                dot += (*b) * (*c);
                b += M;
                c++;
            }
            result += *A++ * log(dot);  /* Accumulate in result */
        }
    }
    plhs[0] = mxCreateDoubleScalar(result);
}
Wenyu Zhang, 1 November 2021
Wow, it's amazing that this mex code, even without parallel computing, is faster than all my other attempts! Thank you for providing such a sample!


More Answers (4)

Matt J, 1 November 2021
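One more way:
[kj,ki,a]=find(A);               % row indices, column indices, values of the nonzeros
C=C.';                           % transpose so C(ki,n) picks row n of the original C
K=size(B,2);
accum=0;
for n=1:K
    accum=accum+ B(kj,n).*C(ki,n);   % accumulate (B*C)(find(A)) one k at a time
end
result=log(accum).'*a;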

Matt J, 29 October 2021 (edited 30 October 2021)
For the original expression without the log, sum(A.*(B*C),'all'), use the equivalent expression:
sum((B.'*A).*C,'all')
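Expanding the matrix product shows why this rewrite is valid for the expression without the log, and why it does not carry over once the log is applied (indices follow the question: A is J x I, B is J x K, C is K x I).

```latex
\begin{gather}
\sum_{j=1}^{J}\sum_{i=1}^{I} A_{ji}\,(BC)_{ji}
  = \sum_{j,i} A_{ji} \sum_{k=1}^{K} B_{jk} C_{ki}
  = \sum_{k,i} \Big( \sum_{j} B_{jk} A_{ji} \Big) C_{ki}
  = \sum_{k,i} \big(B^{\mathsf T} A\big)_{ki}\, C_{ki} \\
\sum_{j,i} A_{ji} \log\Big( \sum_{k} B_{jk} C_{ki} \Big)
  \quad \text{(the sum over } k \text{ sits inside the log, so it cannot be moved out)}
\end{gather}
```

Since B.'*A is only K x I (50 x 1e4 here) and is computed from the sparse A, the first form never materializes a J x I matrix.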
Matt J, 1 November 2021 (edited 2 November 2021)
Quoting Wenyu Zhang: "May I ask if there's any equivalent expression of sum(v .* (B * C), 'all') when v is a 1xI dense vector?"
sum(B,1)*C*v.'
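The same identity can be evaluated without forming B*C by accumulating, one k at a time, the column sum of B and the k-th entry of C*v.'. A sketch in C, with a hypothetical function name sum_v_bc and column-major storage assumed:

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: sum(v .* (B*C), 'all') computed as sum(B,1) * C * v.'
 * B is J x K, C is K x I (column-major), v has length I. */
double sum_v_bc(size_t J, size_t K, size_t I,
                const double *B, const double *C, const double *v)
{
    double result = 0.0;
    size_t j, k, i;
    assert(B != NULL && C != NULL && v != NULL);
    for (k = 0; k < K; k++) {
        double colsum = 0.0;  /* k-th entry of sum(B,1) */
        double cv = 0.0;      /* k-th entry of C*v.' */
        for (j = 0; j < J; j++) colsum += B[j + k * J];
        for (i = 0; i < I; i++) cv += C[k + i * K] * v[i];
        result += colsum * cv;
    }
    return result;
}
```

With the sizes from the question this is about J*K + K*I = 5.5e6 multiply-adds, instead of the J*I*K = 5e10 needed for B*C.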
Wenyu Zhang, 2 November 2021
Thank you very much! It's not hard to prove that sum(B,1)*C*v' is equivalent to sum(v.*(B*C),'all'). Your expression is not only more elegant but also faster than mine.



Matt J, 30 October 2021
Quoting Wenyu Zhang: "Yes I'm still calculating sum(A .* (log(B*C)),'all')."
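I would probably just break C down into a small number of chunks and loop, e.g.,
Cr=reshape(C,K,I/10,10);                  % 10 column blocks of C; assumes I is divisible by 10
Acell=mat2cell(A,J,ones(1,10)*I/10);      % matching column blocks of A
mysum=0;
for n=1:10
    mysum=mysum+sum( Acell{n}.*log(B*Cr(:,:,n)) ,'all');   % only a J x I/10 block of B*C at a time
end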
Quoting Wenyu Zhang: "Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in part (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage."
I'm not sure why you conclude this is not efficient, but regardless, I don't think you're going to be able to avoid it (in the case where you have the log operation in there) unless there is some particular structure to the sparsity pattern in A that you haven't told us about.
It's important to remember that there is a lot of parallel computation happening in a matrix multiplication. When parallel computation is involved, the number of computations isn't necessarily the thing that dominates performance.
Wenyu Zhang, 1 November 2021
In general, my I is not exactly 1e4. Is there an elegant way to reshape the matrix when I/10 is not an integer? The solution I could think of is to make Cr a cell like Acell.
Matt J, 1 November 2021 (edited 2 November 2021)
Quoting Wenyu Zhang: "The solution I could think of is to make Cr a cell like Acell."
Yes, that would be the way. You can use mat2tiles from the File Exchange:
Acell=mat2tiles(A,[inf,1e3]);
Ccell=mat2tiles(C,[inf,1e3]);



James Tursa, 29 October 2021
You could use this loop to avoid the memory usage, but it will run slowly because of the data copying going on in the background for the values, row, and column extractions from the variables. This extra data copying could be avoided in a mex routine if you really needed to recover that speed.
[row,col,v] = find(A);
mysum = 0;
for k=1:numel(v)
    % computes sum(A.*(B*C),'all'); for the log version use v(k)*log(B(row(k),:)*C(:,col(k)))
    mysum = mysum + v(k)*(B(row(k),:)*C(:,col(k)));
end
James Tursa, 31 October 2021
Do you have a supported C/C++ compiler installed? The code for this would be pretty straightforward.
Wenyu Zhang, 31 October 2021
Yes, I have a supported C++ compiler, but I don't have much basic knowledge of mex routines, so it may take me some time to get started.


Category: Performance and Memory
Release: R2019a
