Why is my custom loss function extremely slow?

4 views (last 30 days)
Clemens H.
Clemens H. on 10 Nov 2023
Answered: Clemens H. on 27 Dec 2023
Hi,
I would like to train a physics-informed neural network (PINN). I have used the following example as a basis: https://de.mathworks.com/help/deeplearning/ug/solve-partial-differential-equations-with-lbfgs-method-and-deep-learning.html
The network itself is a fairly simple feedforward network. I have created the following custom loss function (here n_DOF = 3 and nMiniBatch = n_B = 500):
function [loss,gradients] = modelLossModalMDOF(net,Freq,LAMBDA,PHI,M,k,logicParam)
% === Input
% Freq       - (n_DOF,n_B)       natural frequencies (dlarray)
% LAMBDA     - (n_DOF,n_DOF,n_B) eigenvalue matrices, i.e. squared natural
%                                frequencies on the diagonal (numeric array)
% PHI        - (n_DOF,n_DOF,n_B) mode shapes (numeric array)
% M          - (n_DOF,n_DOF)     mass matrix (numeric array)
% k          - (1,n_DOF)         true stiffnesses (numeric array)
% logicParam - (1,n_DOF)         mask of the DOFs to be estimated (numeric array)
% === Output
% loss       - (1,1) PINN loss
% gradients  - gradients with respect to net.Learnables

% Initialization
nMiniBatch = size(Freq,2);                     % mini-batch size
dof = size(M,1);                               % number of DOFs
k_mod = dlarray(nan(nMiniBatch,dof),"BC");
tempLoss = dlarray(nan(dof,dof));              % (not used below)
f = dlarray(nan(dof,nMiniBatch*dof));

% Prediction
kPred = forward(net,Freq);

% Loop over all models in the batch
for j = 1:nMiniBatch
    % Combine predicted and known stiffnesses
    counter = 1;
    for j2 = 1:dof
        if logicParam(j2) == 1
            k_mod(j2,j) = kPred(counter,j);
            counter = counter + 1;
        else
            k_mod(j2,j) = k(j2);
        end
    end
    % Global stiffness matrix
    K_mod = dlgenK_MDOF(k_mod(:,j));
    % Eigenvalue-problem residual (zero for the correct stiffnesses)
    for j2 = 1:dof
        f(:,j2+dof*(j-1)) = (K_mod - LAMBDA(j2,j2,j)*M)*PHI(:,j2,j);
    end
end

% Set the data format again
f = dlarray(f,"CB");

% Eigenvalue-problem loss
zeroTarget = zeros(size(f),"like",f);
loss = l2loss(f,zeroTarget);
gradients = dlgradient(loss,net.Learnables);
end
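In the training loop, the loss function is evaluated through dlfeval so that dlgradient can trace the computation (as in the linked example); roughly:

% Evaluate loss and gradients; dlgradient requires the call to run
% inside dlfeval.
[loss,gradients] = dlfeval(@modelLossModalMDOF, ...
    net,Freq,LAMBDA,PHI,M,k,logicParam);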
I have noticed that the loss function is extremely slow, especially when computing the gradient: a few iterations take minutes. The loss does not really decrease (it stays on the order of 10e+6). During training, more and more RAM is used, so that after some time I reach 90% utilization even with 32 GB.
I have already tried ADAM and L-BFGS. Is there a way to speed up the training significantly?
Thank you in advance!
  3 comments
Clemens H.
Clemens H. on 10 Nov 2023
Edited: Clemens H. on 11 Nov 2023
Hi Stefan,
thank you for the quick reply!
I have attached the required input files (zip folder). There are three different scenarios:
1) a normal-sized training set with 5000 samples,
2) a medium set with 500 samples, and
3) a small set with 100 samples.
All values originate from solving the eigenvalue problem of a three-mass oscillator (modal analysis), for which the stiffness and mass parameters were varied to generate the training data. The PINN should determine the "unknown" stiffness parameters (the masses are known) from given squared natural frequencies and scaled mode shapes in the loss function.
Many thanks for your support,
Clemens
Venu
Venu on 20 Nov 2023
Edited: Venu on 20 Nov 2023
If you have implemented your loss function with the ADAM optimizer, try adjusting the learning rate; see the sketch below.
If you have used the L-BFGS optimizer, check its parameters: adjust the maximum number of iterations and the convergence threshold to see whether they affect the training speed.
If the issue still persists, feel free to provide the code with which you implemented the ADAM and L-BFGS optimizers with your loss function.
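For example, a minimal custom training loop with adamupdate could look roughly like this (learnRate and numIterations are illustrative values to experiment with, not taken from your post):

% Minimal Adam loop sketch; the learning rate and iteration count are
% illustrative starting points.
learnRate = 1e-3;
numIterations = 1000;
avgGrad = [];                    % Adam state, initialized empty
avgSqGrad = [];
for iteration = 1:numIterations
    [loss,gradients] = dlfeval(@modelLossModalMDOF, ...
        net,Freq,LAMBDA,PHI,M,k,logicParam);
    [net,avgGrad,avgSqGrad] = adamupdate(net,gradients, ...
        avgGrad,avgSqGrad,iteration,learnRate);
end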

Sign in to comment.

Accepted Answer

Clemens H.
Clemens H. on 27 Dec 2023
Hi @Venu,
Thank you for your reply. After some time I was able to find the "error". The problem lies in the formulation of the loss function: it is not advisable to use if statements and for loops there, because the loss function is called countless times during training.
After I removed the if statements and replaced the for loops with vectorized code, the loss function ran at an acceptable speed; a sketch of the vectorized structure follows.
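For reference, here is a sketch of one possible vectorized formulation (not necessarily my exact final code). dlgenK_MDOF_batched stands in for a batched version of my dlgenK_MDOF helper that returns all stiffness matrices as (n_DOF,n_DOF,n_B) pages, and pagemtimes is assumed to support dlarray inputs in your release:

function [loss,gradients] = modelLossModalMDOFVec(net,Freq,LAMBDA,PHI,M,k,logicParam)
% Vectorized sketch: no per-sample for loop and no if statements.
nMiniBatch = size(Freq,2);
dof = size(M,1);

% Prediction for the estimated stiffnesses: (nEstimated,nMiniBatch)
kPred = forward(net,Freq);

% Assemble k_mod without if statements: start from the known stiffnesses
% and overwrite the estimated DOFs via logical indexing.
idx = logical(logicParam(:));                  % (dof,1) mask
k_mod = repmat(dlarray(k(:)),1,nMiniBatch);    % (dof,nMiniBatch)
k_mod(idx,:) = stripdims(kPred);

% Build all stiffness matrices at once: (dof,dof,nMiniBatch).
% dlgenK_MDOF_batched is a hypothetical batched dlgenK_MDOF.
K_mod = dlgenK_MDOF_batched(k_mod);

% Residual K*PHI - M*PHI*LAMBDA for every sample in one call; since
% LAMBDA is diagonal, its column j2 equals (K - lambda_j2*M)*phi_j2.
MPL = pagemtimes(pagemtimes(M,PHI),LAMBDA);    % numeric, precomputable
F = pagemtimes(K_mod,PHI) - MPL;               % (dof,dof,nMiniBatch)

% Flatten to (dof,dof*nMiniBatch) and attach the data format.
f = dlarray(reshape(F,dof,[]),"CB");

zeroTarget = zeros(size(f),"like",f);
loss = l2loss(f,zeroTarget);
gradients = dlgradient(loss,net.Learnables);
end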
Maybe this answer will help someone in the future.
Best regards,
Clemens

More Answers (0)
