Fast matrix multiplication with diagonal matrices

Let W be a large, sparse matrix. Let D1 and D2 be diagonal matrices of the same size. I would like to calculate L = D1*W*D2. However, these matrices are large enough that matrix multiplication is very expensive. I would like to speed up the calculation of L.
I know that computing L can be sped up by utilizing the fact that D1 and D2 are diagonal. For example, I know that I can compute D1*W as follows.
diagD1 = diag(D1); % diagonal of the matrix D1.
D1W = W.*diagD1; % Equivalent to multiplying the ith row of W by D1(i,i). Yields D1*W.
My question is whether there is a similar exploitation of the diagonality of D2 that will allow me to avoid matrix multiplication to compute L.
Thank you.
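Written entrywise, the product in question reduces to a scaling of each entry of W. The following is a short derivation, assuming L = D1*W*D2 and using only the fact that the off-diagonal entries of D1 and D2 are zero:

```latex
L_{ij} = \sum_{k}\sum_{l} (D_1)_{ik}\, W_{kl}\, (D_2)_{lj}
       = (D_1)_{ii}\, W_{ij}\, (D_2)_{jj}
```

So L has the same sparsity pattern as W (apart from entries that become zero), and each stored entry needs only two scalar multiplications.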

6 Comments

Hi Samuel,
are you sure about this? diagD1 is a column vector. So D1W = W*diagD1 is a column vector, not a matrix.
Samuel L. Polk on 24 Feb 2021
Edited: Samuel L. Polk on 25 Feb 2021
@David Goodmanson, I agree that D1W = W*diagD1 is a vector, but I was looking at D1W = W.*diagD1. Note the dot multiplication. This notation multiplies the ith row of W by the ith element of diagD1.
Yes, I was mistaken. However, the transpose gets it done.
w = rand(5,5)
d = diag(rand(5,5))
A = w.*d
B = w.*d'
A./w % each row is multiplied by the same element of d
B./w % each col is multiplied by the same element of d
Thank you!
It doesn't appear to be beneficial to avoid matrix multiplication, though I am surprised that the elementwise approach is so slow (in R2020b).
N=1e5;
W=sprand(N,N,100/N);
d=rand(N,1);
D=spdiags(d,0,N,N);
tic
L1=D*W*D;
toc
Elapsed time is 0.572682 seconds.
tic;
L2=(d.*W).*d.';
toc
Elapsed time is 13.663614 seconds.
(diag(D))'.*A works for right multiplication, i.e.
A*D = (diag(D))'.*A
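The row-versus-column scaling discussed in these comments can also be checked outside MATLAB. Below is a minimal C sketch, assuming a small dense n-by-n matrix stored in row-major order; the names `scale_rows` and `scale_cols` are hypothetical helpers introduced only for illustration.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical helper: multiply row i of the n-by-n row-major matrix a
   by d[i]. This is the dense analogue of diag(d) * A. */
void scale_rows(double *a, const double *d, size_t n)
{
    assert(a != NULL && d != NULL);
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < n; j++) {
            a[i * n + j] *= d[i];
        }
    }
}

/* Hypothetical helper: multiply column j of the n-by-n row-major matrix a
   by d[j]. This is the dense analogue of A * diag(d). */
void scale_cols(double *a, const double *d, size_t n)
{
    assert(a != NULL && d != NULL);
    for (size_t i = 0; i < n; i++) {
        for (size_t j = 0; j < n; j++) {
            a[i * n + j] *= d[j];
        }
    }
}
```

Calling `scale_rows(a, d1, n)` followed by `scale_cols(a, d2, n)` leaves `a` holding diag(d1) * A * diag(d2) without forming any matrix product.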

Accepted Answer

James Tursa on 25 Feb 2021
Edited: James Tursa on 25 Feb 2021
Here is a mex routine to do this calculation. It relies on inputting the diagonal matrices as full vectors of the diagonal elements. It does not check for underflow to 0 for the calculations. A robust production version of this code would check for this and clean the sparse result of 0 entries, but I did not include that code here. It also does not check for inf or NaN entries. This could be made faster with parallel code such as OpenMP, but I didn't do that either.
/* File: spdmd.c */
/* Compile: mex spdmd.c */
/* Syntax C = spdmd(D1,M,D2) */
/* Does C = D1 * M * D2 */
/* where M = double real sparse NxN matrix */
/* D1 = double real N element full vector representing diagonal NxN matrix */
/* D2 = double real N element full vector representing diagonal NxN matrix */
/* C = double real sparse NxN matrix */
/* Programmer: James Tursa */
/* Date: 2/24/2021 */
/* Includes ----------------------------------------------------------- */
#include "mex.h"
#include <string.h>
/* Gateway ------------------------------------------------------------ */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize m, n, j, nrow;
    double *Mpr, *D1pr, *D2pr, *Cpr;
    mwIndex *Mir, *Mjc, *Cir, *Cjc;
    /* Argument checks */
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs");
    }
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsSparse(prhs[1]) || mxIsComplex(prhs[1]) ) {
        mexErrMsgTxt("2nd argument must be real double sparse matrix");
    }
    if( !mxIsDouble(prhs[0]) || mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ||
        mxGetNumberOfDimensions(prhs[0]) != 2 || (mxGetM(prhs[0]) != 1 && mxGetN(prhs[0]) != 1) ) {
        mexErrMsgTxt("1st argument must be real double full vector");
    }
    if( !mxIsDouble(prhs[2]) || mxIsSparse(prhs[2]) || mxIsComplex(prhs[2]) ||
        mxGetNumberOfDimensions(prhs[2]) != 2 || (mxGetM(prhs[2]) != 1 && mxGetN(prhs[2]) != 1) ) {
        mexErrMsgTxt("3rd argument must be real double full vector");
    }
    m = mxGetM(prhs[1]);
    n = mxGetN(prhs[1]);
    if( m != n || mxGetNumberOfElements(prhs[0]) != n || mxGetNumberOfElements(prhs[2]) != n ) {
        mexErrMsgTxt("Matrix dimensions must agree.");
    }
    /* Sparse info */
    Mir = mxGetIr(prhs[1]);
    Mjc = mxGetJc(prhs[1]);
    /* Create output with the same number of stored entries as M */
    plhs[0] = mxCreateSparse( m, n, Mjc[n], mxREAL);
    /* Get data pointers */
    Mpr = (double *) mxGetData(prhs[1]);
    D1pr = (double *) mxGetData(prhs[0]);
    D2pr = (double *) mxGetData(prhs[2]);
    Cpr = (double *) mxGetData(plhs[0]);
    Cir = mxGetIr(plhs[0]);
    Cjc = mxGetJc(plhs[0]);
    /* Fill in sparse indexing (C has the same sparsity pattern as M) */
    memcpy(Cjc, Mjc, (n+1) * sizeof(mwIndex));
    memcpy(Cir, Mir, Cjc[n] * sizeof(mwIndex));
    /* Calculate result: C(i,j) = D1(i) * M(i,j) * D2(j) */
    for( j=0; j<n; j++ ) {
        nrow = Mjc[j+1] - Mjc[j]; /* Number of row elements for this column */
        while( nrow-- ) {
            *Cpr++ = *Mpr++ * (D2pr[j] * D1pr[*Cir++]);
        }
    }
}
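For readers who want to try the core loop without a MATLAB installation, here is a minimal standalone sketch of the same idea in plain C. It is not part of the routine above: the function names `csc_scale` and `csc_drop_zeros` are hypothetical, `size_t` stands in for `mwIndex`, and the matrix is assumed to be n-by-n in compressed sparse column form with arrays laid out like those returned by `mxGetJc` and `mxGetIr`. The second function illustrates one possible way to do the zero cleanup that the routine above deliberately omits.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical standalone version of the scaling loop.
   Column j owns the stored entries k with jc[j] <= k < jc[j+1];
   ir[k] is the row index and val[k] the value of entry k.
   Scales val in place so it holds the values of diag(d1) * M * diag(d2). */
void csc_scale(size_t n, const size_t *jc, const size_t *ir, double *val,
               const double *d1, const double *d2)
{
    assert(jc != NULL && ir != NULL && val != NULL);
    assert(d1 != NULL && d2 != NULL);
    for (size_t j = 0; j < n; j++) {
        for (size_t k = jc[j]; k < jc[j + 1]; k++) {
            val[k] *= d1[ir[k]] * d2[j];
        }
    }
}

/* Hypothetical cleanup step: remove stored entries that are exactly zero
   (for example after underflow), compacting ir and val in place and
   rewriting jc. Returns the new number of stored entries. */
size_t csc_drop_zeros(size_t n, size_t *jc, size_t *ir, double *val)
{
    assert(jc != NULL && ir != NULL && val != NULL);
    size_t out = 0;        /* next free slot in the compacted arrays */
    size_t start = jc[0];  /* old start of the current column */
    for (size_t j = 0; j < n; j++) {
        size_t end = jc[j + 1];  /* old end, read before it is overwritten */
        jc[j] = out;
        for (size_t k = start; k < end; k++) {
            if (val[k] != 0.0) {
                ir[out] = ir[k];
                val[out] = val[k];
                out++;
            }
        }
        start = end;
    }
    jc[n] = out;
    return out;
}
```

Because each stored entry is touched once, the cost is proportional to the number of nonzeros rather than to the matrix dimension squared. This sketch does not handle inf or NaN entries either.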

3 Comments

This ran about 4 times faster for me on a large array.
FYI, I needed to add the line
#include "string.h"
for this to work.
James Tursa on 25 Feb 2021
Edited: James Tursa on 25 Feb 2021
Fixed the include. Thanks. The speed gain, if any, will depend greatly on the actual sizes and sparsity involved.


Asked: 24 Feb 2021
Commented: 28 Jul 2022
