How to calculate sum(A .* (B * C), 'all') [Ed. actually sum(A .* log(B * C), 'all')] efficiently when A is sparse and B*C is full and large?
I have three matrices, A of size [J,I], B of size [J,K], C of size [K,I]. A is a sparse matrix with more than 90% zeros, while B and C are positive double matrices. The typical values are J=1e5, I=1e4, K=50.
The problem is that B*C creates a full matrix of size [J,I], which leads to redundant memory usage because what I need is merely the elements (B * C)(find(A)). My current constraint is that I don't have enough memory for a full matrix of size [J,I]. I wonder if there's a smart way to avoid such unnecessary memory usage for calculating this specific expression?
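Written out entrywise, with $A_{ji}$ the $(j,i)$ entry of A, the quantity only needs B*C at the nonzero positions of A. The typical sizes above show how much is at stake (this is simple arithmetic on the stated sizes, not a measurement):

$$
S = \sum_{(j,i)\,:\,A_{ji}\neq 0} A_{ji}\,\log\Big(\sum_{k=1}^{K} B_{jk}\,C_{ki}\Big)
$$

$$
J I = 10^5\cdot 10^4 = 10^9 \text{ entries} \;\Rightarrow\; 8\cdot 10^9 \text{ bytes} \approx 8\ \text{GB for a full } B C,
\qquad \operatorname{nnz}(A) \le 0.1\,J I = 10^8,
$$

$$
K\cdot\operatorname{nnz}(A) \le 50\cdot 10^8 = 5\cdot 10^9 \text{ multiply-adds, versus } J I K = 5\cdot 10^{10} \text{ for the full product.}
$$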
I have tried making B a tall array, but an error like "tall arrays are not allowed to contain sparse data" appears when .* is evaluated. I also tried making A a tall array using tall(full(A)), but that's not reasonable because I need to convert A to a full matrix first, and A is in fact not "tall" at all. Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage.
Thanks in advance!
Accepted Answer
James Tursa
1 Nov 2021
Edited: James Tursa
9 Nov 2021
Here is the straightforward mex code (i.e., no parallel sections) if you want to try it out. It computes the result directly in a loop without the need for large temporary memory allocations and data copying. You will need a supported C compiler installed. To compile it use the following at the command line:
mex sABC.c -R2018a
If you have an earlier version of MATLAB you can omit the -R2018a option.
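To run it, simply call it as shown (the small sizes below are only for a quick check against the direct formula):
J = 1000; I = 500; K = 50;   % small test sizes, not the real problem sizes
A = sprand(J,I,0.1);         % sparse, about 90% zeros, positive values
B = rand(J,K);               % positive full matrix
C = rand(K,I);               % positive full matrix
s = sABC(A,B,C)
s_ref = sum(A.*log(B*C),'all')   % direct formula, forms the full J-by-I B*C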
The C source code:
/* File sABC.c
 * sABC(A,B,C) returns sum(A.*log(B*C),'all')
 *
 * A = sparse real double MxN
 * B = full real double MxK
 * C = full real double KxN
 *
 * Programmer: James Tursa
 * Date: 10/31/2021
 */
#include "mex.h"
#include <math.h>
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double dot, result = 0.0;
    mwSize M, K, N;
    mwSize j, k, nrow;
    double *A, *B, *C, *b, *c;
    mwIndex *Air, *Ajc;
    /* Argument checks */
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs.");
    }
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs.");
    }
    if( !mxIsDouble(prhs[0]) || !mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ) {
        mexErrMsgTxt("A must be real sparse double.");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsDouble(prhs[2]) ||
         mxIsSparse(prhs[1]) ||  mxIsSparse(prhs[2]) ||
         mxIsComplex(prhs[1]) || mxIsComplex(prhs[2]) ) {
        mexErrMsgTxt("B and C must be real full double matrices.");
    }
    if( mxGetNumberOfDimensions(prhs[1]) != 2 || mxGetNumberOfDimensions(prhs[2]) != 2 ) {
        mexErrMsgTxt("B and C must be 2D.");
    }
    M = mxGetM(prhs[0]);
    N = mxGetN(prhs[0]);
    K = mxGetN(prhs[1]);
    if( M != mxGetM(prhs[1]) ||
        N != mxGetN(prhs[2]) ||
        K != mxGetM(prhs[2]) ) {
        mexErrMsgTxt("Dimensions not compatible.");
    }
    /* Calculate result, simple loop no parallel code */
    Air = mxGetIr(prhs[0]);
    Ajc = mxGetJc(prhs[0]);
    A = (double *) mxGetData(prhs[0]);
    B = (double *) mxGetData(prhs[1]);
    C = (double *) mxGetData(prhs[2]);
    for( j=0; j<N; j++ ) {
        nrow = Ajc[j+1] - Ajc[j]; /* Number of row elements for this column */
        while( nrow-- ) {
            b = B + *Air++; /* B row pointer */
            c = C + j*K;    /* C column pointer */
            dot = 0.0;
            for( k=0; k<K; k++ ) { /* dot product of B row and C column */
                dot += (*b) * (*c);
                b += M;
                c++;
            }
            result += *A++ * log(dot); /* Accumulate in result */
        }
    }
    plhs[0] = mxCreateDoubleScalar(result);
}
More Answers (4)
Matt J
30 Oct 2021
Edited: Matt J
30 Oct 2021
"Yes, I'm still calculating sum(A .* (log(B*C)),'all')."
I would probably just break C down into a small number of chunks and loop, e.g.,
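nChunks = 10;                                      % I must be divisible by nChunks
Cr = reshape(C, K, I/nChunks, nChunks);            % column blocks of C as pages
Acell = mat2cell(A, J, ones(1,nChunks)*I/nChunks); % matching column blocks of A
mysum = 0;
for n = 1:nChunks
    % only a J-by-(I/nChunks) full block of B*C exists at any one time
    mysum = mysum + sum(Acell{n} .* log(B*Cr(:,:,n)), 'all');
end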
"Another way I tried to reduce memory usage is to divide A and C into blocks and calculate this expression in parts (using a for-loop). However, this is not efficient, and does not reach the goal of removing redundant memory usage."
I'm not sure why you conclude this is not efficient, but regardless, I don't think you're going to be able to avoid it (in the case where you have the log operation in there) unless there is some particular structure to the sparsity pattern in A that you haven't told us about.
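The log is presumably what rules out a cheap reformulation. Without it, the sum can be reordered so that B*C is never formed at all, as this short derivation (using the same indices as the question) shows:

$$
\sum_{j,i} A_{ji}\sum_{k} B_{jk}\,C_{ki}
= \sum_{k,i}\Big(\sum_{j} B_{jk}\,A_{ji}\Big) C_{ki}
= \sum_{k,i} \big(B^{\mathsf T} A\big)_{ki}\, C_{ki},
$$

which in MATLAB is sum((B.'*A).*C,'all'), where B.'*A is only K-by-I. With the log in place, $\log\big(\sum_k B_{jk} C_{ki}\big)$ does not split over $k$, so each needed entry of B*C has to be formed.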
It's important to remember that there is a lot of parallel computation happening in a matrix multiplication. When parallel computation is involved, the number of computations isn't necessarily the thing that dominates performance.
James Tursa
29 Oct 2021
Edited: James Tursa
29 Oct 2021
You could use this loop to avoid the memory usage, but it will run slowly because of the data copying going on in the background for the values, row, and column extractions from the variables. This extra data copying could be avoided in a mex routine if you really needed to recover that speed.
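[row,col,v] = find(A);   % row/column indices and values of the nonzeros of A
mysum = 0;
for k=1:numel(v)
    % (B*C)(row(k),col(k)) is row row(k) of B times column col(k) of C
    mysum = mysum + v(k)*(B(row(k),:)*C(:,col(k)));
    % for sum(A.*log(B*C),'all') use v(k)*log(B(row(k),:)*C(:,col(k))) instead
end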
Comments: 4
James Tursa
31 Oct 2021
Do you have a supported C/C++ compiler installed? The code for this would be pretty straightforward.