Fast matrix multiplication with diagonal matrices
Let $W$ be a large, sparse matrix, and let $D_1$ and $D_2$ be diagonal matrices of the same size. I would like to calculate $L = D_1 W D_2$. However, these matrices are large enough that the matrix multiplications are very expensive, so I would like to speed up the calculation of $L$.
I know that computing $L$ can be sped up by exploiting the fact that $D_1$ and $D_2$ are diagonal. For example, I know that I can compute $D_1 W$ as follows:
diagD1 = diag(D1);  % diagonal of the matrix D1.
D1W = W.*diagD1;    % Multiplies the ith row of W by D1(i,i) (implicit expansion). Yields D1*W.
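At the storage level, MATLAB keeps sparse matrices in compressed sparse column (CSC) form, and the row scaling above amounts to multiplying each stored value by the diagonal entry for its row index. A minimal C sketch of that idea, assuming a hypothetical `csc_matrix` struct whose `jc`/`ir`/`pr` fields play the roles of MATLAB's column pointers, row indices, and values (this is an illustration, not how MATLAB implements `.*` internally):

```c
#include <stddef.h>

/* Hypothetical minimal CSC layout: column j's nonzeros are stored at
   positions jc[j] .. jc[j+1]-1 of ir (row indices) and pr (values). */
typedef struct {
    size_t n;      /* number of columns */
    size_t *jc;    /* column pointers, length n+1 */
    size_t *ir;    /* row indices, length jc[n] */
    double *pr;    /* values, length jc[n] */
} csc_matrix;

/* In place W <- D1*W, with d1 holding the diagonal of D1.
   Each stored entry W(i,j) is multiplied by d1[i], i.e. row i is scaled
   by D1(i,i). Only stored nonzeros are touched, so the sparsity pattern
   (jc and ir) is unchanged. */
void scale_rows(csc_matrix *W, const double *d1)
{
    size_t nnz = W->jc[W->n];
    for (size_t k = 0; k < nnz; k++) {
        W->pr[k] *= d1[W->ir[k]];
    }
}
```

Because only the values array changes, the loop costs one multiplication per stored nonzero.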
My question is whether the diagonality of $D_2$ can be exploited in a similar way, so that I can compute $L$ without any matrix multiplication.
Thank you.
Accepted Answer
James Tursa on 25 Feb 2021 (edited by James Tursa on 25 Feb 2021)
Here is a MEX routine that does this calculation. It expects the diagonal matrices to be passed as full vectors of their diagonal elements (e.g. full(diag(D1))). It does not check whether any of the products underflow to 0; a robust production version would check for this and remove the resulting 0 entries from the sparse result, but I did not include that code here. It also does not check for Inf or NaN entries. This could be made faster with parallel code such as OpenMP, but I didn't do that either.
/* File:  spdmd.c                                                                   */
/* Compile:  mex spdmd.c                                                            */
/* Syntax  C = spdmd(D1,M,D2)                                                       */
/* Does  C = D1 * M * D2                                                            */
/* where M  = double real sparse NxN matrix                                         */
/*       D1 = double real N element full vector representing diagonal NxN matrix    */
/*       D2 = double real N element full vector representing diagonal NxN matrix    */
/*       C  = double real sparse NxN matrix                                         */
/* Programmer:  James Tursa                                                         */
/* Date: 2/24/2021                                                                  */
/* Includes ----------------------------------------------------------- */
#include "mex.h"
#include <string.h>
/* Gateway ------------------------------------------------------------ */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize m, n, j, nrow;
    double *Mpr, *D1pr, *D2pr, *Cpr;
    mwIndex *Mir, *Mjc, *Cir, *Cjc;
/* Argument checks */
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs");
    }
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs");
    }
    if (!mxIsDouble(prhs[1]) || !mxIsSparse(prhs[1]) || mxIsComplex(prhs[1])) {
        mexErrMsgTxt("2nd argument must be real double sparse matrix");
    }
    if( !mxIsDouble(prhs[0]) || mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ||
        mxGetNumberOfDimensions(prhs[0]) != 2 || (mxGetM(prhs[0]) != 1 && mxGetN(prhs[0]) != 1)) {
        mexErrMsgTxt("1st argument must be real double full vector");
    }
    if (!mxIsDouble(prhs[2]) || mxIsSparse(prhs[2]) || mxIsComplex(prhs[2]) ||
        mxGetNumberOfDimensions(prhs[2]) != 2 || (mxGetM(prhs[2]) != 1 && mxGetN(prhs[2]) != 1)) {
        mexErrMsgTxt("3rd argument must be real double full vector");
    }
    m = mxGetM(prhs[1]);
    n = mxGetN(prhs[1]);
    if (m != n || mxGetNumberOfElements(prhs[0]) != n || mxGetNumberOfElements(prhs[2]) != n) {
        mexErrMsgTxt("Matrix dimensions must agree.");
    }
/* Sparse info */
    Mir = mxGetIr(prhs[1]);
    Mjc = mxGetJc(prhs[1]);
/* Create output */
    plhs[0] = mxCreateSparse( m, n, Mjc[n], mxREAL);
/* Get data pointers */
    Mpr = (double *) mxGetData(prhs[1]);
    D1pr = (double *) mxGetData(prhs[0]);
    D2pr = (double *) mxGetData(prhs[2]);
    Cpr = (double *) mxGetData(plhs[0]);
    Cir = mxGetIr(plhs[0]);
    Cjc = mxGetJc(plhs[0]);
/* Fill in sparse indexing */
    memcpy(Cjc, Mjc, (n+1) * sizeof(mwIndex));
    memcpy(Cir, Mir, Cjc[n] * sizeof(mwIndex));
/* Calculate result */
    for( j=0; j<n; j++ ) {
        nrow = Mjc[j+1] - Mjc[j];  /* Number of stored nonzeros in column j */
        while( nrow-- ) {
            /* C(i,j) = D1(i) * M(i,j) * D2(j), where i = *Cir */
            *Cpr++ = *Mpr++ * (D2pr[j] * D1pr[*Cir++]);
        }
    }
}
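To experiment with the arithmetic outside MATLAB, here is a MEX-free sketch of the same loop on a hypothetical `csc_matrix` struct (the names `csc_matrix`, `spdmd_values`, and `drop_zeros` are illustrative, not part of any MATLAB API). It also sketches the two extensions the answer mentions but leaves out: columns are split across threads with an OpenMP pragma (compiled only when OpenMP is enabled), and `drop_zeros` removes entries that came out exactly 0.0 by compacting the row-index and value arrays in place.

```c
#include <stddef.h>

/* Hypothetical CSC layout (same roles as MATLAB's Jc/Ir/Pr arrays):
   column j's entries live at positions jc[j] .. jc[j+1]-1. */
typedef struct {
    size_t n;      /* matrix is n x n */
    size_t *jc;    /* column pointers, length n+1 */
    size_t *ir;    /* row indices, length jc[n] */
    double *pr;    /* values, length jc[n] */
} csc_matrix;

/* Writes the values of C = D1*M*D2 into cpr (length jc[n]); C shares M's
   jc/ir pattern, which the caller copies separately (as the MEX file does
   with memcpy). Entries are addressed by position k instead of advancing
   pointers, so each column can be processed independently. */
void spdmd_values(const csc_matrix *M, const double *d1, const double *d2,
                  double *cpr)
{
#ifdef _OPENMP
    #pragma omp parallel for
#endif
    for (long j = 0; j < (long)M->n; j++) {
        for (size_t k = M->jc[j]; k < M->jc[j + 1]; k++) {
            cpr[k] = M->pr[k] * (d2[j] * d1[M->ir[k]]);
        }
    }
}

/* Removes stored entries equal to 0.0 (e.g. from underflow or a zero on
   either diagonal), compacting ir/pr in place and rewriting jc.
   Returns the new number of stored nonzeros. */
size_t drop_zeros(csc_matrix *C)
{
    size_t dst = 0;             /* next write position */
    size_t start = C->jc[0];    /* original start of the current column */
    for (size_t j = 0; j < C->n; j++) {
        size_t end = C->jc[j + 1];   /* original end of column j */
        for (size_t k = start; k < end; k++) {
            if (C->pr[k] != 0.0) {
                C->ir[dst] = C->ir[k];
                C->pr[dst] = C->pr[k];
                dst++;
            }
        }
        start = end;            /* read jc[j+1] before overwriting it */
        C->jc[j + 1] = dst;     /* new end of column j */
    }
    return dst;
}
```

In the MEX routine, a pass like `drop_zeros` would run on the output's Ir/Jc/Pr arrays after the main loop; whether the parallel loop pays off depends on the matrix size and the number of threads.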
3 Comments
James Tursa on 25 Feb 2021 (edited by James Tursa on 25 Feb 2021)
Fixed the include. Thanks. The speed gain, if any, will depend greatly on the actual sizes and sparsity involved.
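For reference, the routine rests on the entrywise identity below, writing $d^{(1)}$ and $d^{(2)}$ for the diagonals of $D_1$ and $D_2$. It shows that $L$ has the same sparsity pattern as $W$ (apart from entries that become exactly zero), and that the loop performs two multiplications per stored nonzero plus one pass over the $n+1$ column pointers:

$$
L_{ij} = (D_1 W D_2)_{ij} = d^{(1)}_i \, W_{ij} \, d^{(2)}_j,
\qquad
\text{work} = O\big(n + \operatorname{nnz}(W)\big).
$$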
z cy on 28 Jul 2022
Hi, I have a question; can you help me solve it? Thanks! https://ww2.mathworks.cn/matlabcentral/answers/1769470-how-to-reduce-running-time-of-diagonal-matrix-multiplication-with-full-matrix-in-matlab