Fast implementation of max-plus matrix multiplication

Davide Zorzenon on 13 Apr 2021
Commented: Seth Younger on 14 Sep 2021
I am trying to implement a fast max-plus algebra multiplication of square matrices in MATLAB. The max-plus product of two n×n matrices A and B is the matrix C such that $C_{ij} = \max_{k} (A_{ik} + B_{kj})$.
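As a small worked instance of the rule $C_{ij} = \max_k (A_{ik} + B_{kj})$, here is the product of two 2×2 matrices. The numbers are chosen purely for illustration and do not come from the timing tests:

```latex
A = \begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}, \quad
B = \begin{pmatrix} 0 & -1 \\ 2 & 1 \end{pmatrix}
\quad\Longrightarrow\quad
C = \begin{pmatrix}
\max(1+0,\, 2+2) & \max(1-1,\, 2+1) \\
\max(3+0,\, 4+2) & \max(3-1,\, 4+1)
\end{pmatrix}
= \begin{pmatrix} 4 & 3 \\ 6 & 5 \end{pmatrix}
```

Compared with the ordinary matrix product, the sum over k is replaced by a maximum and each product by a sum.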
A naive implementation of the max-plus multiplication in MATLAB is:
function C = mp_prod(A,B)
n = size(A,1);
C = -inf(n);
for i = 1:n
    for j = 1:n
        for k = 1:n
            C(i,j) = max(C(i,j), A(i,k) + B(k,j));
        end
    end
end
However, this is not particularly fast. The fastest implementation I could come up with is:
function C = fast_mp_prod(A,B)
n = size(A,1);
C = -inf(n);
A = transpose(A);
for i = 1:n
    % A(:,i) is row i of the original A as a column; implicit expansion adds it to every column of B
    C(i,:) = max(A(:,i) + B);
end
which takes about 0.18 seconds to multiply two 100×100 matrices 100 times on my computer. Since performing the same test using a "standard" multiplication takes about 0.004 seconds (45 times faster), I was wondering if there were a way to speed up the code and obtain more comparable timings using MATLAB.
If this is not possible, would it be worth trying to create a package similar to LAPACK for max-plus algebra to get better performance? Or is this the maximum achievable speed for reasons like "the CPU is inherently slower at calculating the maximum than the product of two numbers"?
4 comments
Jan on 13 Apr 2021
The naive max() implementation contains a branch. This generally slows down processing, because it impedes pipelining in the CPU, which has to predict one of the branches using heuristics. In the case of max(), about half of the predictions are wrong.
Matrix multiplication is performed by a highly optimized library. Even the naive implementation in Matlab profits from Matlab's JIT acceleration, which does not handle max() with the same efficiency.
Do you have a C compiler installed?
Davide Zorzenon on 13 Apr 2021
Thank you for the explanation! Yes, I have it.
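To illustrate the remark about the branch inside max(), here is a minimal C sketch of the inner max-plus loop with the maximum written as a conditional expression instead of an if-statement. Optimizing compilers are often able to turn such an expression into a branch-free max instruction, but that depends on the compiler and flags, and no timing is claimed here. The function name `mp_dot` is hypothetical and does not appear elsewhere in this thread.

```c
#include <assert.h>
#include <math.h>
#include <stddef.h>

/* Max-plus "dot product" of two length-p vectors: max over k of a[k] + b[k].
 * Returns -INFINITY for p == 0, the neutral element of max. */
double mp_dot(const double *a, const double *b, size_t p)
{
    assert(p == 0 || (a != NULL && b != NULL));
    double t = -INFINITY;
    for (size_t k = 0; k < p; k++) {
        double s = a[k] + b[k];
        t = (s > t) ? s : t;   /* candidate for a branchless max; not guaranteed */
    }
    return t;
}
```

Whether this form is faster than the explicit if-statement has to be measured on the target machine.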

Accepted Answer

Jan on 13 Apr 2021
Edited: Jan on 15 Apr 2021
A C-mex version:
[EDITED: Check of type double is added]
// mp_prod_mex.c
// C = mp_prod_mex(A, B)
// INPUT: A, B: Real double [m x k] and [k x n] matrices.
// OUTPUT: Real double [m x n] matrix.
//
// Equivalent Matlab code:
// [m, p] = size(A); n = size(B, 2);
// C = -inf(m, n);
// for i = 1:m
// for j = 1:n
// for k = 1:p
// C(i,j) = max(C(i,j), A(i,k) + B(k,j));
// end
// end
// end
//
// COMPILE:
// mex -O -R2018a mp_prod_mex.c
//
// Jan, Heidelberg, (C) 2021, License: CC BY-SA 3.0
// Handling of rectangular matrices: Bruno Luong
#include "mex.h"

void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
{
    double *A, *Ak, *B, *Bj, *Bk, *C, mInf = -mxGetInf();
    register double s, t;
    mwSize m, n, p, i, j, k;

    // Check inputs before touching their data:
    if (nrhs != 2) {
        mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadNInput",
                          "mp_prod_mex: 2 inputs required.");
    }
    if (!mxIsDouble(prhs[0]) || !mxIsDouble(prhs[1]) ||
        mxIsComplex(prhs[0]) || mxIsComplex(prhs[1])) {
        mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                          "mp_prod_mex: Inputs must be real double matrices.");
    }

    // Get inputs:
    A = mxGetDoubles(prhs[0]);
    B = mxGetDoubles(prhs[1]);
    m = mxGetM(prhs[0]);   // size(A, 1)
    n = mxGetN(prhs[1]);   // size(B, 2)
    p = mxGetN(prhs[0]);   // size(A, 2) and size(B, 1)
    if (mxGetM(prhs[1]) != p) {
        mexErrMsgIdAndTxt("Jan:mp_prod_mex:BadInput",
                          "mp_prod_mex: Sizes of A and B do not match.");
    }

    // Create output:
    plhs[0] = mxCreateDoubleMatrix(m, n, mxREAL);
    C = mxGetDoubles(plhs[0]);

    // Calculate result:
    for (j = 0; j < n; j++) {   // Loop over columns of B first for linear access to C
        Bj = B + j * p;         // Column j of the [p x n] matrix B starts at offset j*p
        for (i = 0; i < m; i++) {
            Ak = A + i;         // A(i,0); stepping by m walks along row i
            Bk = Bj;
            t = mInf;
            for (k = 0; k < p; k++) {
                s = *Ak + *Bk++;
                Ak += m;
                if (s > t) {
                    t = s;
                }
            }
            *C++ = t;
        }
    }
    return;
}
Timings, 100x100 input, 100 iterations:
% Matlab R2018b, i5 mobile:
Elapsed time is 0.365667 seconds. % 1st version from question
Elapsed time is 0.253912 seconds. % 2nd version from question
Elapsed time is 0.337319 seconds. % Bruno's version
Elapsed time is 0.077329 seconds. % C mex
Transposing A for a contiguous access was no significant advantage.
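For checking the loop structure without MATLAB, the same triple loop can be written as a plain C function. This is only a sketch under the assumption of column-major storage, as MATLAB uses; the name `mp_prod_kernel` and its argument order are hypothetical. With column-major storage, column j of a [p x n] matrix B begins at offset j*p, and element (i,k) of an [m x p] matrix A sits at offset i + k*m.

```c
#include <assert.h>
#include <math.h>
#include <stddef.h>

/* C = max-plus product of column-major A [m x p] and B [p x n]; C is [m x n]. */
void mp_prod_kernel(const double *A, const double *B, double *C,
                    size_t m, size_t p, size_t n)
{
    assert(A != NULL && B != NULL && C != NULL);
    for (size_t j = 0; j < n; j++) {
        const double *Bj = B + j * p;      /* start of column j of B */
        for (size_t i = 0; i < m; i++) {
            const double *Ak = A + i;      /* A(i,0) */
            double t = -INFINITY;          /* stays -inf when p == 0 */
            for (size_t k = 0; k < p; k++) {
                double s = *Ak + Bj[k];    /* A(i,k) + B(k,j) */
                Ak += m;                   /* next column of A, same row */
                if (s > t) {
                    t = s;
                }
            }
            *C++ = t;                      /* C is filled in linear order */
        }
    }
}
```

Testing with a non-square case such as [2 x 3] times [3 x 2] is worthwhile, because square inputs cannot reveal a mix-up between the dimensions p and n.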
12 comments
Davide Zorzenon on 14 Apr 2021
Edited: Davide Zorzenon on 14 Apr 2021
@Bruno Luong: thanks for pointing out the case with matrices of dimensions m×0 and 0×n. I had never thought about that, but to be consistent the result of the max-plus multiplication should be -inf(m,n), as you wrote.
Seth Younger on 14 Sep 2021
@Jan I would love to connect with you and learn more about how this implementation works. Are you slicing the matrix up? I am confused by the (A,1) and (A,2).
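On the (A,1) and (A,2) notation: in MATLAB, size(A,1) is the number of rows of A and size(A,2) is the number of columns, so nothing is being sliced. Inside a MEX function the matrix arrives as one contiguous array in column-major order, which is why the C code steps through a row of A with a stride of m. A minimal sketch of that indexing rule, with the hypothetical helper name `elem`:

```c
#include <assert.h>
#include <stddef.h>

/* Element (i, k), zero-based, of a column-major matrix with m rows.
 * Moving one column to the right advances the linear index by m. */
double elem(const double *A, size_t m, size_t i, size_t k)
{
    assert(A != NULL && i < m);
    return A[i + k * m];
}
```

For the 2×3 matrix [1 2 3; 4 5 6], the memory layout is 1, 4, 2, 5, 3, 6, so the element in the second row and third column is found at linear index 1 + 2*2 = 5.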

More Answers (1)

Bruno Luong on 13 Apr 2021
Edited: Bruno Luong on 13 Apr 2021
function C = mp_prod(A,B)
m = size(A,1);
n = size(B,2);
AA = reshape(A,m,1,[]);     % m x 1 x p
BB = reshape(B.',1,n,[]);   % 1 x n x p
C = max(AA+BB,[],3);        % maximum over the inner dimension
tic/toc result
>> A=rand(100);
>> B=rand(100);
>> tic; C = mp_prod(A,B); toc
Elapsed time is 0.005206 seconds.
8 comments
Jan on 14 Apr 2021
Edited: Jan on 14 Apr 2021
@Davide Zorzenon: Which machine and Matlab version are you using?
The different timings may mean that Davide's setup is more efficient for the loop, or that Bruno's setup is more efficient for the vectorized solution.
Davide Zorzenon on 14 Apr 2021
@Jan: I use Matlab R2019a and my computer is a Windows 10 x64 machine with an Intel Core i7 (2.20 GHz), 16 GB RAM and 6 cores.