How can I define a custom loss function using trainnet?

Matthew Murray on 29 Mar 2024
Edited: Matt J on 29 Mar 2024
Hello,
I am trying to define a custom loss function using trainnet. The documentation says:
If the trainnet function does not provide the loss function that you need for your task, then you can specify a custom loss function to the trainnet as a function handle. The function must have the syntax loss = f(Y,T), where Y and T are the predictions and targets, respectively.
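To illustrate that syntax, here is a minimal sketch of a handle of the form loss = f(Y,T). The mean-absolute-error formula and the name maeLoss are hypothetical, chosen only to show the shape of a valid loss. It assumes Y (network predictions) and T (targets) arrive as dlarray objects of matching size, and the key requirement is that the handle returns a single real scalar.

```matlab
% Hypothetical example loss (illustration only): mean absolute error
% over every element. Y = network predictions, T = matching targets,
% both dlarray objects; the handle must return ONE real scalar.
maeLoss = @(Y,T) mean(abs(Y - T),"all");

% Check the handle outside trainnet on random data laid out like an
% image batch: spatial x spatial x channel x batch ("SSCB").
Y = dlarray(rand(5,4,3,2,"single"),"SSCB");
T = dlarray(rand(5,4,3,2,"single"),"SSCB");
L = maeLoss(Y,T);
disp(size(L))   % a valid loss should be 1x1
```

Testing a candidate loss this way, before calling trainnet, catches a non-scalar result early.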
However, I am not sure how the predictions and targets are defined here. I am currently using trainnet as follows:
trainedNet = trainnet(dsTrain,layers,"mse",options);
dsTrain is a datastore containing the input and target images for the regression problem. But I would like to change the loss to a custom function involving ssim. I would like something similar to the following, although I know this isn't quite right:
trainedNet = trainnet(dsTrain,layers,@(Y,targets) 1-ssim(Y,targets),options);
I get the following error message:
Error using trainnet
Value to differentiate is non-scalar. It must be a traced real dlarray scalar.
Thanks!

Answers (1)

Matt J on 29 Mar 2024
Edited: Matt J on 29 Mar 2024
If you have multichannel output, the loss function will return one SSIM value per channel rather than a single scalar, e.g.,
loss = @(Y,targets) 1-ssim(Y,targets);
[Y,T]=deal(dlarray(rand(5,4,8),'SSC'));
L=loss(Y,T);
whos L
  Name      Size      Bytes  Class      Attributes
  L         1x1x8        70  dlarray
You need to decide how you want this reduced to a single value.
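One common choice is to average the values. The following is a minimal sketch, assuming that the mean SSIM over channels (and over observations, if a batch dimension is present) is the reduction you want, and that ssim accepts the dlarray inputs trainnet passes in, as the "SSC" example above suggests. The name ssimLoss is hypothetical; other reductions, such as a weighted sum over channels, would plug in the same way.

```matlab
% Reduce the per-channel SSIM values to one scalar by averaging them.
% mean(...,"all") also averages over a batch dimension, if present.
ssimLoss = @(Y,T) 1 - mean(ssim(Y,T),"all");

% Same check as in the example above: multichannel "SSC" dlarray inputs.
Y = dlarray(rand(5,4,8),"SSC");
T = dlarray(rand(5,4,8),"SSC");
L = ssimLoss(Y,T);
whos L   % L should now be a 1x1 dlarray

% In training, the handle replaces "mse" (dsTrain, layers, and options
% are the variables from the question):
% trainedNet = trainnet(dsTrain,layers,ssimLoss,options);
```

Writing the loss as 1 - mean(ssim(...)) is equivalent to averaging 1 - ssim(...) per channel, so the loss is lowest when every channel matches its target.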
