Can you plot the gradient for CNNs using trainNetwork?

Asked by Arjun Desai on 27 May 2018
Answered: Snehal on 27 Mar 2025
I am using the trainNetwork command to train my network, but I noticed that there is no way to plot the gradients over iterations. The trainInfo output does contain some information, but it does not seem to include anything about the gradients.

Answers (1)

Snehal on 27 Mar 2025
I understand that you want to extract gradient information while training a CNN and plot it over iterations. The ‘trainNetwork’ function in MATLAB does not expose gradients during training, but there are two possible workarounds:
  • Write a custom training loop with a ‘dlnetwork’ object and compute the gradients yourself with ‘dlgradient’, which must be called inside a function that is evaluated with ‘dlfeval’. Below is a sample code snippet:
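% Assume 'layers' is a layer array defined earlier that ends in a softmaxLayer
% (no classificationLayer, which dlnetwork does not accept), 'XBatch' is a
% formatted dlarray (for example 'SSCB'), and 'YBatch' holds one-hot targets
% (numClasses x batchSize). dlnetwork requires R2019b or later.
net = dlnetwork(layerGraph(layers));
% dlgradient only works inside a function that is evaluated with dlfeval
[loss, gradients] = dlfeval(@modelLoss, net, XBatch, YBatch);
% 'gradients' is a table with the same layout as net.Learnables, holding the
% gradients of the loss with respect to each learnable parameter

function [loss, gradients] = modelLoss(net, X, T)
    Y = forward(net, X);                          % forward pass
    loss = crossentropy(Y, T);                    % cross-entropy loss
    gradients = dlgradient(loss, net.Learnables); % backward pass
end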
  • If you want to keep using ‘trainNetwork’, you can pass a custom output function through the ‘OutputFcn’ option of ‘trainingOptions’. The information passed to that function includes fields such as ‘TrainingLoss’ and ‘ValidationLoss’ but no gradients, so you can only monitor indirect, gradient-related patterns, for example how quickly the loss changes from one iteration to the next.
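Building on the custom-training-loop workaround, here is a minimal end-to-end sketch that trains a small CNN and plots the global L2 norm of the gradients at every iteration. Everything specific here is an illustrative assumption, not part of the original answer: random stand-in data, a toy architecture, Adam updates via ‘adamupdate’, and the global norm as the plotted summary (per-layer norms would work just as well). The loss and gradient computation is wrapped in a one-line anonymous function passed to ‘dlfeval’ so the sketch stays a single script; a named local function at the end of the file works the same way. It requires R2019b or later, so it will not run in R2018a.

```matlab
% Sketch: custom training loop that logs the global L2 norm of the gradients
% at every iteration. Requires R2019b or later (dlnetwork, dlfeval, dlgradient).
% Assumption: random data and a tiny network stand in for your own.
rng(0);
numObs = 96; numClasses = 3; miniBatchSize = 32; numEpochs = 5;
X = rand(28, 28, 1, numObs, 'single');      % H x W x C x N images
y = randi(numClasses, 1, numObs);           % integer class labels
T = single((1:numClasses)' == y);           % numClasses x numObs one-hot targets

layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];                           % no output layer for dlnetwork
net = dlnetwork(layerGraph(layers));

% dlgradient must run inside dlfeval; the anonymous function computes the
% cross-entropy loss and returns its gradients w.r.t. all learnables
gradFcn = @(n, XB, TB) dlgradient(crossentropy(forward(n, XB), TB), n.Learnables);

numIter = numEpochs * numObs / miniBatchSize;
gradNorms = zeros(1, numIter);
avgGrad = []; avgSqGrad = [];               % Adam state, initialised empty
iteration = 0;
for epoch = 1:numEpochs
    idx = randperm(numObs);                 % shuffle every epoch
    for i = 1:miniBatchSize:numObs
        batch = idx(i:i+miniBatchSize-1);
        XBatch = dlarray(X(:, :, :, batch), 'SSCB');
        TBatch = dlarray(T(:, batch), 'CB');
        gradients = dlfeval(gradFcn, net, XBatch, TBatch);
        iteration = iteration + 1;
        % global L2 norm: sqrt of the summed squares of every learnable's gradient
        sq = cellfun(@(g) sum(extractdata(g).^2, 'all'), gradients.Value);
        gradNorms(iteration) = sqrt(double(sum(sq)));
        [net, avgGrad, avgSqGrad] = adamupdate(net, gradients, avgGrad, avgSqGrad, iteration);
    end
end

figure
plot(1:numIter, gradNorms, '-o')
xlabel('Iteration')
ylabel('Global gradient L2 norm')
```

Plotting a single norm keeps the figure readable; if you suspect vanishing or exploding gradients in a specific layer, plot the entries of ‘sq’ separately instead of summing them.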
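If you stay with ‘trainNetwork’, the per-iteration loss is also returned in its second output (the training-information structure), so you can plot the iteration-to-iteration change in ‘TrainingLoss’ after training. This is a sketch with random stand-in data and a toy network (both assumptions); it reads the loss after training rather than through a live ‘OutputFcn’, and the loss differences are only a rough proxy, not the gradients themselves.

```matlab
% Sketch: train with trainNetwork, then plot the training loss and its
% per-iteration change as an indirect, gradient-related indicator.
% Assumption: random data and a tiny network stand in for your own.
rng(0);
numObs = 96; numClasses = 3;
X = rand(28, 28, 1, numObs, 'single');                  % H x W x C x N images
Y = categorical(randi(numClasses, numObs, 1), 1:numClasses);

layers = [
    imageInputLayer([28 28 1], 'Normalization', 'none')
    convolution2dLayer(3, 8, 'Padding', 'same')
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer];

options = trainingOptions('sgdm', 'MaxEpochs', 5, 'MiniBatchSize', 32, ...
    'Plots', 'none', 'Verbose', false);
[net, info] = trainNetwork(X, Y, layers, options);

% change in loss between consecutive iterations
lossStep = diff(info.TrainingLoss);

figure
subplot(2, 1, 1)
plot(info.TrainingLoss)
ylabel('Training loss')
subplot(2, 1, 2)
plot(2:numel(info.TrainingLoss), lossStep)
xlabel('Iteration')
ylabel('\Delta training loss')
```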
Below are some documentation links that you can refer to for more information:
Hope this helps.

Categories

Find more on Image Data Workflows in Help Center and File Exchange.


Version

R2018a

