- Use the current network (info.TrainedNetwork) instead of loading checkpoints.
- Compute the MSE on the same mini-batch, or on a validation set, for consistency.
trainnet difference between TrainingLoss and manually computed MSE loss
SYt
on 17 Nov 2025 at 19:01
Answered: Leepakshi
on 20 Nov 2025 at 9:33
I train a dlnetwork using trainnet with a custom OutputFcn that, at a fixed iteration frequency, loads the most recent checkpointed network saved through the checkpoint options.
This is how the training options are defined:
options = trainingOptions(algo, ...
    'MaxEpochs', epochs, ...
    'MiniBatchSize', 1024, ...
    'InitialLearnRate', learnrate, ...
    'Verbose', false, ...
    'CheckpointPath', checkdir, ...
    'CheckpointFrequencyUnit', 'iteration', ...
    'CheckpointFrequency', check_freq, ...
    OutputFcn=@(info)updatePlotAndStopTraining(info, lines, checkdir, check_freq, XTest, YTrain, XTrain));
This is the custom OutputFcn, where I also compute the MSE manually:
function stop = updatePlotAndStopTraining(info, lines, directory__, checkFreq, XTest, YTrain, XTrain)
    global msee
    iteration = info.Iteration;
    trainingLoss = info.TrainingLoss;
    if (~isempty(iteration)) && (mod(iteration, checkFreq) == 0) && (iteration ~= 0)
        d = dir(fullfile(directory__, '*.mat'));
        dates = {d.date};
        files = {d.name};
        [~, idx] = sort(datenum(dates));
        latest_file_name = files{idx(end)};
        checknet = load(fullfile(directory__, latest_file_name));
        msee = (mse(predict(checknet.net, XTrain.'), YTrain));
    end
    if iteration < checkFreq
        if isvalid(lines.distanceToBase)
            addpoints(lines.trainingLossLine, iteration, 1.0)
            addpoints(lines.mse, iteration, 0.0)
        end
    elseif ~isempty(trainingLoss)
        if isvalid(lines.distanceToBase)
            addpoints(lines.trainingLossLine, iteration, trainingLoss)
            addpoints(lines.mse, iteration, msee)
        end
    end
    stop = false;
end
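The OutputFcn above picks the newest checkpoint by sorting the date strings returned by dir. Those strings only have one-second resolution, so checkpoints written within the same second tie and the wrong file can be loaded. A minimal sketch of an alternative parses the iteration number out of each file name and takes the largest. It assumes the checkpoint files follow the net_checkpoint__<iteration>__<date>__<time>.mat pattern used by trainNetwork checkpoints, which is assumed (not confirmed) to carry over to trainnet; check the names actually written to checkdir. The file names in the sketch are hypothetical and deliberately share one timestamp.

```matlab
% Hypothetical checkpoint file names (all written in the same second).
% Inside the OutputFcn these would come from:
%   d = dir(fullfile(directory__, 'net_checkpoint__*.mat')); names = {d.name};
names = {'net_checkpoint__200__2025_11_17__19_01_05.mat', ...
         'net_checkpoint__1000__2025_11_17__19_01_05.mat', ...
         'net_checkpoint__600__2025_11_17__19_01_05.mat', ...
         'notes.mat'};

% Assumed naming pattern: net_checkpoint__<iteration>__<date>__<time>.mat.
% regexprep leaves non-matching names unchanged, so str2double returns NaN for them.
checkpointIter = @(name) str2double(regexprep(name, '^net_checkpoint__(\d+)__.*$', '$1'));
iters = cellfun(checkpointIter, names);

% max ignores NaN, so unrelated .mat files are skipped.
[latestIter, k] = max(iters);
latestFile = names{k};

% Inside the OutputFcn, the lag behind the live network would be:
%   lag = info.Iteration - latestIter;
fprintf('Latest checkpoint: %s (iteration %d)\n', latestFile, latestIter);
```

Comparing latestIter with info.Iteration at each call makes any checkpoint lag visible instead of silently plotting a stale MSE.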
This is the training call:
[finalnet,info] = trainnet(XTrain.', YTrain.', resetNet,'mse', options);
I would expect the MSE and the training loss to be very close, differing only because TrainingLoss is normalized, but they move in opposite directions: my manually computed MSE suggests the model is not converging, while TrainingLoss shows some convergence.

Answers (1)
Leepakshi
on 20 Nov 2025 at 9:33
Hi,
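You computed the MSE with a checkpointed network, which may lag behind the current training state. In addition, predict(checknet.net, XTrain.') evaluates the entire training set, while TrainingLoss is computed per mini-batch and normalized. Finally, the predictions are compared against YTrain, whereas trainnet received the targets as YTrain.'. These data-orientation and timing mismatches produce misleading trends.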
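The orientation mismatch can go unnoticed in MATLAB. trainnet receives the targets as YTrain.' (one row per observation), and predict on XTrain.' is assumed here to return predictions in that same layout (a single-response network with a feature input layer), but the OutputFcn compares them with YTrain. If the error is computed element-wise, as in the sketch below, implicit expansion turns an N-by-1 minus 1-by-N subtraction into an N-by-N matrix of every prediction/target pair instead of raising an error. The numbers are made up for illustration (N = 5).

```matlab
% Made-up data for a single-response regression with N = 5 observations.
YTrain = [1 2 3 4 5];                 % 1-by-N, as in the question (trainnet gets YTrain.')
preds  = [1.1; 1.9; 3.2; 3.8; 5.1];   % N-by-1, the layout assumed for predict(net, XTrain.')

% Mismatched orientation: implicit expansion builds an N-by-N matrix of
% differences between every prediction and every target, with no error.
errWrong = preds - YTrain;
mseWrong = mean(errWrong(:).^2);

% Matching orientation: N-by-1 predictions against N-by-1 targets.
errRight = preds - YTrain.';
mseRight = mean(errRight(:).^2);

fprintf('size(errWrong) = %dx%d, MSE = %.4f\n', size(errWrong, 1), size(errWrong, 2), mseWrong);
fprintf('size(errRight) = %dx%d, MSE = %.4f\n', size(errRight, 1), size(errRight, 2), mseRight);
```

Here the mismatched version reports an MSE of about 3.98 for predictions that are never more than 0.2 away from their targets, while the matched version gives 0.022.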
These approaches can be used to sort it:
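preds = predict(info.TrainedNetwork, XTrain.');   % one row per observation, like YTrain.'
msee = mse(preds, YTrain.');                       % transpose YTrain so both arrays have the same orientation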
This aligns your metric with training progress.
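Even when the metric uses the current network and consistently oriented targets, the two curves are not expected to coincide exactly: TrainingLoss reports the loss on the mini-batch just processed, whereas the manual metric averages over the whole training set. A minimal sketch with hypothetical per-observation squared errors (not real training data, and a mini-batch size of 4 instead of 1024) shows how far a single mini-batch can sit from the full-set mean:

```matlab
% Hypothetical per-observation squared errors for N = 8 training samples,
% processed in two mini-batches of 4.
sqErr = [0.01 0.01 0.01 0.01 4 4 4 4];
miniBatchSize = 4;

% Loss of each mini-batch: the mean over the observations in that batch.
numBatches = numel(sqErr) / miniBatchSize;
batchLoss = zeros(1, numBatches);
for b = 1:numBatches
    idx = (b-1)*miniBatchSize + (1:miniBatchSize);
    batchLoss(b) = mean(sqErr(idx));
end

% Full-set MSE, which is what predict on all of XTrain measures.
fullLoss = mean(sqErr);

fprintf('Mini-batch losses: %s\n', mat2str(batchLoss));
fprintf('Full-set MSE: %.4f\n', fullLoss);
```

With the weights held fixed, the average of equal-sized mini-batch losses equals the full-set MSE. During training the weights change after every mini-batch, so each TrainingLoss value is also measured against a different network state, which is why a per-iteration curve can be noisy even when the full-set MSE is stable.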
Hope it helps!