Main Content

Convertir una red de clasificación en una red de regresión

En este ejemplo se muestra cómo convertir una red de clasificación entrenada en una red de regresión.

Las redes de clasificación de imágenes preentrenadas se han entrenado con más de un millón de imágenes y pueden clasificarlas en 1000 categorías de objetos, como teclado, taza de café, lápiz y muchos animales. Las redes han aprendido representaciones ricas en características para una amplia gama de imágenes. La red toma una imagen como entrada y, a continuación, emite una etiqueta para el objeto de la imagen junto con las probabilidades de cada una de las categorías de objetos.

La transferencia del aprendizaje se suele usar en aplicaciones de deep learning. Se puede usar una red preentrenada como punto de partida para aprender una nueva tarea. En este ejemplo se muestra cómo tomar una red de clasificación preentrenada y volver a entrenarla para tareas de regresión.

En el ejemplo se carga una arquitectura de red neuronal convolucional preentrenada para clasificación, se reemplazan las capas para clasificación y se vuelve a entrenar la red para predecir ángulos de dígitos manuscritos rotados.

Cargar una red preentrenada

Cargue la red preentrenada desde el archivo de soporte digitsClassificationConvolutionNet.mat. Este archivo contiene una red de clasificación que clasifica dígitos manuscritos.

load digitsClassificationConvolutionNet
layers = net.Layers
layers = 
  13x1 Layer array with layers:

     1   'imageinput'    Image Input                  28x28x1 images
     2   'conv_1'        2-D Convolution              10 3x3x1 convolutions with stride [2  2] and padding [0  0  0  0]
     3   'batchnorm_1'   Batch Normalization          Batch normalization with 10 channels
     4   'relu_1'        ReLU                         ReLU
     5   'conv_2'        2-D Convolution              20 3x3x10 convolutions with stride [2  2] and padding [0  0  0  0]
     6   'batchnorm_2'   Batch Normalization          Batch normalization with 20 channels
     7   'relu_2'        ReLU                         ReLU
     8   'conv_3'        2-D Convolution              40 3x3x20 convolutions with stride [2  2] and padding [0  0  0  0]
     9   'batchnorm_3'   Batch Normalization          Batch normalization with 40 channels
    10   'relu_3'        ReLU                         ReLU
    11   'gap'           2-D Global Average Pooling   2-D global average pooling
    12   'fc'            Fully Connected              10 fully connected layer
    13   'softmax'       Softmax                      softmax

Cargar datos

El conjunto de datos contiene imágenes sintéticas de dígitos manuscritos junto con los ángulos correspondientes (en grados) que se rota cada imagen.

Cargue las imágenes de entrenamiento y de prueba como arreglos 4D desde los archivos de soporte DigitsDataTrain.mat y DigitsDataTest.mat. Las variables anglesTrain y anglesTest son los ángulos de rotación en grados. Cada conjunto de datos de entrenamiento y de prueba contiene 5000 imágenes.

load DigitsDataTrain
load DigitsDataTest

Muestre 20 imágenes aleatorias de entrenamiento mediante imshow.

numTrainImages = numel(anglesTrain);
figure
idx = randperm(numTrainImages,20);
for i = 1:numel(idx)
    subplot(4,5,i)    
    imshow(XTrain(:,:,:,idx(i)))
end

Figure contains 20 axes objects. Axes object 1 contains an object of type image. Axes object 2 contains an object of type image. Axes object 3 contains an object of type image. Axes object 4 contains an object of type image. Axes object 5 contains an object of type image. Axes object 6 contains an object of type image. Axes object 7 contains an object of type image. Axes object 8 contains an object of type image. Axes object 9 contains an object of type image. Axes object 10 contains an object of type image. Axes object 11 contains an object of type image. Axes object 12 contains an object of type image. Axes object 13 contains an object of type image. Axes object 14 contains an object of type image. Axes object 15 contains an object of type image. Axes object 16 contains an object of type image. Axes object 17 contains an object of type image. Axes object 18 contains an object of type image. Axes object 19 contains an object of type image. Axes object 20 contains an object of type image.

Sustituir capas finales

Las capas convolucionales de la red extraen características de la imagen que la última capa de aprendizaje usa para clasificar la imagen de entrada. La capa 'fc' contiene la información sobre cómo combinar las características que la red extrae en probabilidades de clase. Para volver a entrenar una red preentrenada para regresión, sustituya esta capa y la siguiente capa softmax por una nueva adaptada a la tarea.

Reemplace la capa completamente conectada final con una capa completamente conectada de tamaño 1 (el número de respuestas).

numResponses = 1;
layer = fullyConnectedLayer(numResponses,Name="fc");

net = replaceLayer(net,"fc",layer)
net = 
  dlnetwork with properties:

         Layers: [13x1 nnet.cnn.layer.Layer]
    Connections: [12x2 table]
     Learnables: [14x3 table]
          State: [6x3 table]
     InputNames: {'imageinput'}
    OutputNames: {'softmax'}
    Initialized: 0

  View summary with summary.

Elimine la capa softmax.

net = removeLayers(net,"softmax");

Ajustar los factores de tasa de aprendizaje de las capas

La red ahora está lista para volver a ser entrenada con los nuevos datos. Si lo prefiere, puede ralentizar el entrenamiento de los pesos de las capas anteriores de la red aumentando la tasa de aprendizaje de la nueva capa totalmente conectada y reduciendo la tasa de aprendizaje global cuando especifique las opciones de entrenamiento.

Aumente las tasas de aprendizaje de los parámetros de la capa totalmente conectada por un factor concreto usando la función setLearnRateFactor.

net = setLearnRateFactor(net,"fc","Weights",10);
net = setLearnRateFactor(net,"fc","Bias",10);

Especificar las opciones de entrenamiento

Especifique las opciones de entrenamiento. Para escoger entre las opciones se requiere un análisis empírico. Para explorar diferentes configuraciones de opciones de entrenamiento mediante la ejecución de experimentos, puede utilizar la app Experiment Manager.

  • Especifique una tasa de aprendizaje reducida de 0,0001.

  • Muestre el progreso del entrenamiento en una gráfica.

  • Deshabilite la salida detallada.

options = trainingOptions("sgdm",...
    InitialLearnRate=0.001, ...
    Plots="training-progress",...
    Verbose=false);

Entrenar redes neuronales

Entrene la red neuronal con la función trainnet. Para la regresión, utilice la pérdida de error cuadrático medio. De forma predeterminada, la función trainnet usa una GPU en caso de que esté disponible. Para utilizar una GPU se requiere una licencia de Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). De lo contrario, la función usa la CPU. Para especificar el entorno de ejecución, utilice la opción de entrenamiento ExecutionEnvironment.

net = trainnet(XTrain,anglesTrain,net,"mse",options);

Probar la red

Pruebe el rendimiento de la red evaluando la precisión de los datos de prueba.

Utilice predict para predecir los ángulos de rotación de las imágenes de validación.

YTest = predict(net,XTest);

Visualice las predicciones en una gráfica de dispersión. Represente los valores predichos frente a los valores reales.

figure
scatter(YTest,anglesTest,"+")
xlabel("Predicted Value")
ylabel("True Value")

hold on
plot([-60 60], [-60 60],"r--")

Figure contains an axes object. The axes object with xlabel Predicted Value, ylabel True Value contains 2 objects of type scatter, line.

Consulte también

| |

Temas relacionados