Fast matrix multiplication with diagonal matrices
22 views (last 30 days)
Samuel L. Polk
on 24 Feb 2021
Commented: z cy on 28 Jul 2022
Let W be a large, sparse matrix, and let D1 and D2 be diagonal matrices of the same size. I would like to calculate L = D1*W*D2. However, these matrices are large enough that matrix multiplication is very expensive, and I would like to speed up the calculation of L.
I know that computing L can be sped up by exploiting the fact that D1 and D2 are diagonal. For example, I know that I can compute D1*W as follows.
diagD1 = diag(D1); % diagonal of the matrix D1, as a column vector
D1W = W.*diagD1;   % equivalent to multiplying the ith row of W by D1(i,i); yields D1*W
My question is whether there is a similar exploitation of the diagonality of D2 that will allow me to avoid matrix multiplication when computing L.
Thank you.
6 comments
saskia leary
on 20 Mar 2022
(diag(D))'.*A works for right multiplication - i.e.
A*D = (diag(D))'.*A
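Combining the left- and right-multiplication broadcasts gives L without any matrix-matrix product (a minimal sketch, not from the original comment, assuming W, D1 and D2 from the question are in the workspace):
diagD1 = diag(D1);            % column vector: diagonal of D1
diagD2 = diag(D2);            % column vector: diagonal of D2
L = diagD1 .* W .* diagD2.';  % scales row i by D1(i,i) and column j by D2(j,j), i.e. D1*W*D2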
Accepted Answer
James Tursa
on 25 Feb 2021
Edited: James Tursa on 25 Feb 2021
Here is a mex routine to do this calculation. It relies on inputting the diagonal matrices as full vectors of the diagonal elements. It does not check for underflow to 0 for the calculations. A robust production version of this code would check for this and clean the sparse result of 0 entries, but I did not include that code here. It also does not check for inf or NaN entries. This could be made faster with parallel code such as OpenMP, but I didn't do that either.
/* File: spdmd.c */
/* Compile: mex spdmd.c */
/* Syntax C = spdmd(D1,M,D2) */
/* Does C = D1 * M * D2 */
/* where M = double real sparse NxN matrix */
/* D1 = double real N element full vector representing diagonal NxN matrix */
/* D2 = double real N element full vector representing diagonal NxN matrix */
/* C = double real sparse NxN matrix */
/* Programmer: James Tursa */
/* Date: 2/24/2021 */
/* Includes ----------------------------------------------------------- */
#include "mex.h"
#include <string.h>
/* Gateway ------------------------------------------------------------ */
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    mwSize m, n, j, nrow;
    double *Mpr, *D1pr, *D2pr, *Cpr;
    mwIndex *Mir, *Mjc, *Cir, *Cjc;
    /* Argument checks */
    if( nlhs > 1 ) {
        mexErrMsgTxt("Too many outputs");
    }
    if( nrhs != 3 ) {
        mexErrMsgTxt("Need exactly three inputs");
    }
    if( !mxIsDouble(prhs[1]) || !mxIsSparse(prhs[1]) || mxIsComplex(prhs[1]) ) {
        mexErrMsgTxt("2nd argument must be real double sparse matrix");
    }
    if( !mxIsDouble(prhs[0]) || mxIsSparse(prhs[0]) || mxIsComplex(prhs[0]) ||
        mxGetNumberOfDimensions(prhs[0]) != 2 || (mxGetM(prhs[0]) != 1 && mxGetN(prhs[0]) != 1) ) {
        mexErrMsgTxt("1st argument must be real double full vector");
    }
    if( !mxIsDouble(prhs[2]) || mxIsSparse(prhs[2]) || mxIsComplex(prhs[2]) ||
        mxGetNumberOfDimensions(prhs[2]) != 2 || (mxGetM(prhs[2]) != 1 && mxGetN(prhs[2]) != 1) ) {
        mexErrMsgTxt("3rd argument must be real double full vector");
    }
    m = mxGetM(prhs[1]);
    n = mxGetN(prhs[1]);
    if( m != n || mxGetNumberOfElements(prhs[0]) != n || mxGetNumberOfElements(prhs[2]) != n ) {
        mexErrMsgTxt("Matrix dimensions must agree.");
    }
    /* Sparse info */
    Mir = mxGetIr(prhs[1]);
    Mjc = mxGetJc(prhs[1]);
    /* Create output */
    plhs[0] = mxCreateSparse( m, n, Mjc[n], mxREAL );
    /* Get data pointers */
    Mpr  = (double *) mxGetData(prhs[1]);
    D1pr = (double *) mxGetData(prhs[0]);
    D2pr = (double *) mxGetData(prhs[2]);
    Cpr  = (double *) mxGetData(plhs[0]);
    Cir  = mxGetIr(plhs[0]);
    Cjc  = mxGetJc(plhs[0]);
    /* Fill in sparse indexing */
    memcpy(Cjc, Mjc, (n+1) * sizeof(mwIndex));
    memcpy(Cir, Mir, Cjc[n] * sizeof(mwIndex));
    /* Calculate result */
    for( j=0; j<n; j++ ) {
        nrow = Mjc[j+1] - Mjc[j]; /* Number of row elements for this column */
        while( nrow-- ) {
            *Cpr++ = *Mpr++ * (D2pr[j] * D1pr[*Cir++]);
        }
    }
}
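Usage would look something like the following (a sketch, not part of the original answer; the sizes and variable names are only placeholders):
mex spdmd.c                       % compile once
n  = 1e4;
W  = sprand(n, n, 1e-3);          % sparse test matrix
d1 = rand(n, 1);                  % diagonal of D1 as a full vector
d2 = rand(n, 1);                  % diagonal of D2 as a full vector
L  = spdmd(d1, W, d2);            % computes D1*W*D2
norm(L - d1 .* W .* d2.', 'fro')  % compare against the broadcasting formulation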
3 comments
James Tursa
on 25 Feb 2021
Edited: James Tursa on 25 Feb 2021
Fixed the include. Thanks. The speed gain, if any, will depend greatly on the actual sizes and sparsity involved.
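One way to see the difference on a given problem is a rough timing check (a sketch, assuming spdmd.c has been compiled as above; the size and density are placeholders):
n  = 5e4;
W  = sprand(n, n, 1e-4);          % sparse test matrix
d1 = rand(n, 1);                  % diagonal of D1 as a full vector
d2 = rand(n, 1);                  % diagonal of D2 as a full vector
D1 = spdiags(d1, 0, n, n);        % sparse diagonal matrices for the baseline
D2 = spdiags(d2, 0, n, n);
timeit(@() D1 * W * D2)           % baseline: sparse matrix products
timeit(@() d1 .* W .* d2.')       % broadcasting on the diagonal vectors
timeit(@() spdmd(d1, W, d2))      % mex routine from this answer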
z cy
on 28 Jul 2022
Hi, I have a question, can you help me solve it? Thanks! https://ww2.mathworks.cn/matlabcentral/answers/1769470-how-to-reduce-running-time-of-diagonal-matrix-multiplication-with-full-matrix-in-matlab
More Answers (0)