Retraining YAMNet for audio classification returns channel mismatch error in "deep.internal.train.Trainer/train"
Ben on 3 Oct 2024
Commented: Ben on 4 Oct 2024
I am retraining YAMNet for a binary classification task on spectrograms of audio signals. My training audio has two classes, positive and negative. The audio is preprocessed and features are extracted using yamnetPreprocess(). When training the network, trainnet() produces the following error:
Error using deep.internal.train.Trainer/train (line 74)
Number of channels in predictions (2) must match the number of
channels in the targets (3).
Error in deep.internal.train.ParallelTrainer>iTrainWithSplitCommunicator (line 227)
remoteNetwork = train(remoteTrainer, remoteNetwork, workerMbq);
Error in deep.internal.train.ParallelTrainer/computeTraining (line 127)
spmd
Error in deep.internal.train.Trainer/train (line 59)
net = computeTraining(trainer, net, mbq);
Error in deep.internal.train.trainnet (line 54)
net = train(trainer, net, mbq);
Error in trainnet (line 42)
[net,info] = deep.internal.train.trainnet(mbq, net, loss, options, ...
Error in train_DenseNet_detector_from_semi_synthetic_dataset (line 192)
[trained_network, train_info] = trainnet(trainFeatures, trainLabels', net, "crossentropy", options);
My understanding is that this error indicates a mismatch between the number of classes the network expects and the number of classes in the dataset. I don't see how that is possible, since the number of classes in the network is set explicitly from the number of classes in the datastore:
classNames = unique(ads.Labels);
numClasses = numel(classNames);
net = audioPretrainedNetwork("yamnet", NumClasses=numClasses);
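The class count above comes from unique(), which only reports label values that actually occur in the data. As a minimal sketch (the label names, including the empty "positiveClean" category, are hypothetical), a categorical array can carry more categories than unique() reports:

```matlab
% Hypothetical labels: created with three categories, but every
% "positiveClean" sample has since been removed from the data
labels = categorical({'negative', 'positiveNoisy', 'negative'}, ...
    {'negative', 'positiveClean', 'positiveNoisy'});

% unique() counts only the values actually present in the data
numPresent = numel(unique(labels));         % 2

% categories() lists every category the array still carries,
% including empty ones
numCategories = numel(categories(labels));  % 3
```

If the training targets are encoded against all categories, as the error message suggests, a NumClasses derived from unique() ends up one short.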
My script is based on the example in the MATLAB audioPretrainedNetwork documentation, and there are no functional differences in how I build the datastores or preprocess the data. The training options and the call to trainnet() are configured as follows:
options = trainingOptions('adam', ...
InitialLearnRate = initial_learn_rate, ...
MaxEpochs = max_epochs, ...
MiniBatchSize = mini_batch_size, ...
Shuffle = "every-epoch", ...
Plots = "training-progress", ...
Metrics = "accuracy", ...
Verbose = 1, ...
ValidationData = {single(validationFeatures), validationLabels'}, ...
ValidationFrequency = validationFrequency,...
ExecutionEnvironment="parallel-auto");
[trained_network, train_info] = trainnet(trainFeatures, trainLabels', net, "crossentropy", options);
Relevant variable dimensions are as follows:
>> unique(ads.Labels)
ans =
2×1 categorical array
negative
positiveNoisy
>> size(trainLabels)
ans =
1 16240
>> size(trainFeatures)
ans =
96 64 1 16240
>> size(validationLabels)
ans =
1 6960
>> size(validationFeatures)
ans =
96 64 1 6960
The only real differences between my script and the MATLAB example are that I'm using parallel execution in the training solver and that the datastore OutputEnvironment is set to "gpu". If I set ExecutionEnvironment = "auto" instead of "parallel-auto" and set ads.OutputEnvironment = 'cpu', the error stack is shorter, but the problem is the same:
Error using trainnet (line 46)
Number of channels in predictions (2) must match the number of channels in
the targets (3).
Error in train_DenseNet_detector_from_semi_synthetic_dataset (line 189)
[trained_network, train_info] = trainnet(trainFeatures, trainLabels', net, "crossentropy", options);
Could someone give me some advice? The root cause seems to be buried inside Deep Learning Toolbox, and it's a little beyond me right now.
Thanks,
Ben
Accepted Answer
Joss Knight on 3 Oct 2024
I think the issue is that your label data is a categorical array with three categories. Run
categories(trainLabels)
to confirm. You might need to delete the unused category using removecats.
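A minimal sketch of this fix, assuming the label arrays are categorical and the leftover category is simply empty (the category name "positiveClean" and the sample values are hypothetical stand-ins). Calling removecats with only the array removes every unused category, so the name of the stale category does not need to be known:

```matlab
% Stand-in labels (hypothetical): both arrays still carry a third,
% empty category called "positiveClean"
allCats = {'negative', 'positiveClean', 'positiveNoisy'};
trainLabels      = categorical({'negative', 'positiveNoisy', 'negative'}, allCats);
validationLabels = categorical({'positiveNoisy', 'negative'}, allCats);

% removecats with no category list drops every category that has no elements
trainLabels      = removecats(trainLabels);
validationLabels = removecats(validationLabels);

% Training and validation targets should use the same categories,
% and NumClasses should be derived from them
assert(isequal(categories(trainLabels), categories(validationLabels)), ...
    'Training and validation labels have different categories.');
numClasses = numel(categories(trainLabels));
```

The assert matters because removecats is applied to each array separately: if the validation set happened to lack a class that the training set contains, the two arrays would end up with different categories.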
More Answers (3)
Joss Knight on 3 Oct 2024
It looks like your network is returning output with three channels instead of two. Could you try running analyzeNetwork(net) to see what it is outputting?
Ben on 3 Oct 2024
Joss Knight on 3 Oct 2024
Edited: Joss Knight on 3 Oct 2024
Yes, I see. It's not enough just to remove the instances of one class from the data, because that class is still one of the label categories. You need to remove that category from your targets using removecats; see my other answer.
Here is a simplified example:
% Create label data with 3 classes
randomLabels = categorical(randi(3, 1, 100));
mycats = categories(randomLabels) % 3 categories, '1', '2' and '3'
% Remove all the '2's
randomLabels(randomLabels==mycats(2)) = [];
mycats = categories(randomLabels) % Still 3 categories!
% Remove the '2' category from the data
randomLabels = removecats(randomLabels, mycats(2));
mycats = categories(randomLabels) % Now there's only '1' and '3'
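The example above removes a category whose name is already known. When it isn't, countcats can list the categories that have no samples. A small deterministic variant of the same example (the data values are chosen for illustration):

```matlab
% Three categories, then every '2' removed from the data
labels = categorical([1 2 3 1 3 2 1]);
labels(labels == '2') = [];

% countcats on a column input returns one count per category, as a column
cats   = categories(labels);
counts = countcats(labels(:));

% Categories that are still defined but no longer have any samples
emptyCats = cats(counts == 0)      % {'2'}

% Drop exactly those categories
labels = removecats(labels, emptyCats);
categories(labels)                 % {'1'; '3'}
```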
Ben on 4 Oct 2024
Edited: Ben on 4 Oct 2024
Joss Knight on 4 Oct 2024
I'll pass on your comments.
I don't think this is quite as clear-cut as you make out. You asked your underlying datastore to use the folder names as the label source; this information is gathered when the original audioDatastore is constructed. subset() shouldn't make any subsequent assumptions about your choice of labels. You may have removed all the data from one class because you want to fine-tune your model to favour other classes, or for many other reasons. To put it another way: if a model accepts data from a datastore, it should also accept data from a subset of that datastore, but if subset() pruned missing classes, it wouldn't.
Nevertheless, you raise some interesting points; in particular, your point about using numel(categories(...)) instead of unique is a very good one.
Thanks.