Main Content

Entrenar una red con características numéricas

En este ejemplo se muestra cómo crear y entrenar una red neuronal sencilla para la clasificación de datos de características mediante deep learning.

Si tiene un conjunto de datos de características numéricas (por ejemplo, una colección de datos numéricos sin dimensiones espaciales ni temporales), puede entrenar una red de deep learning utilizando una capa de entrada de características. Para ver un ejemplo de cómo entrenar una red para la clasificación de imágenes, consulte Crear una red neuronal de deep learning sencilla para clasificación.

En este ejemplo se muestra cómo entrenar una red para clasificar la condición de los dientes del engranaje de un sistema de transmisión con una mezcla de lecturas numéricas de sensores, estadísticas y etiquetas categóricas.

Cargar datos

Cargue el conjunto de datos de la caja de engranajes para el entrenamiento. Este conjunto de datos está formado por 208 lecturas sintéticas de un sistema de engranajes formado por 18 lecturas numéricas y 3 etiquetas categóricas:

  1. SigMean: media de la señal de vibración

  2. SigMedian: mediana de la señal de vibración

  3. SigRMS: RMS de la señal de vibración

  4. SigVar: varianza de la señal de vibración

  5. SigPeak: pico de la señal de vibración

  6. SigPeak2Peak: pico a pico de la señal de vibración

  7. SigSkewness: asimetría de la señal de vibración

  8. SigKurtosis: curtosis de la señal de vibración

  9. SigCrestFactor: factor de cresta de la señal de vibración

  10. SigMAD: MAD de la señal de vibración

  11. SigRangeCumSum: suma de intervalos de la señal de vibración

  12. SigCorrDimension: dimensión de correlación de la señal de vibración

  13. SigApproxEntropy: entropía aproximada de la señal de vibración

  14. SigLyapExponent: exponente de Lyapunov de la señal de vibración

  15. PeakFreq: frecuencia pico

  16. HighFreqPower: potencia de frecuencia alta

  17. EnvPower: potencia de entorno

  18. PeakSpecKurtosis: frecuencia pico de curtosis espectral

  19. SensorCondition: condición de sensor, especificada como "desvío de sensor" o "sin desvío de sensor"

  20. ShaftCondition: condición de eje, especificada como "desgaste de eje" o "sin desgaste de eje"

  21. GearToothCondition: condición de diente de engranaje, especificada como "diente con error" o "diente sin error"

Lea los datos de la caja de engranajes del archivo CSV "transmissionCasingData.csv".

filename = "transmissionCasingData.csv";
tbl = readtable(filename,TextType="String");

Convierta las etiquetas para la predicción en categóricas utilizando la función convertvars.

labelName = "GearToothCondition";
tbl = convertvars(tbl,labelName,"categorical");

Visualice las primeras filas de la tabla.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    SensorCondition    ShaftCondition     GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    _______________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12         "Sensor Drift"     "No Shaft Wear"      No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39         "Sensor Drift"     "Shaft Wear"         No Tooth Fault  

Para entrenar una red utilizando características categóricas, primero debe convertir las características categóricas en numéricas. Primero, convierta los predictores categóricos en numéricos con la función convertvars especificando un arreglo de cadena que contenga los nombres de todas las variables de entrada categórica. En este conjunto de datos, hay dos características categóricas con los nombres "SensorCondition" y "ShaftCondition".

categoricalInputNames = ["SensorCondition" "ShaftCondition"];
tbl = convertvars(tbl,categoricalInputNames,"categorical");

Forme un lazo con las variables de entrada categórica. Para cada variable:

  • Convierta los valores categóricos en vectores codificados one-hot usando la función onehotencode.

  • Añada los vectores one-hot a la tabla utilizando la función addvars. Especifique que los vectores se inserten después de la columna que contiene los datos categóricos correspondientes.

  • Elimine la columna correspondiente que contiene los datos categóricos.

for i = 1:numel(categoricalInputNames)
    name = categoricalInputNames(i);
    oh = onehotencode(tbl(:,name));
    tbl = addvars(tbl,oh,After=name);
    tbl(:,name) = [];
end

Divida los vectores en columnas independientes utilizando la función splitvars.

tbl = splitvars(tbl);

Visualice las primeras filas de la tabla. Observe que los predictores categóricos se han dividido en varias columnas con los valores categóricos como los nombres de las variables.

head(tbl)
    SigMean     SigMedian    SigRMS    SigVar     SigPeak    SigPeak2Peak    SigSkewness    SigKurtosis    SigCrestFactor    SigMAD     SigRangeCumSum    SigCorrDimension    SigApproxEntropy    SigLyapExponent    PeakFreq    HighFreqPower    EnvPower    PeakSpecKurtosis    No Sensor Drift    Sensor Drift    No Shaft Wear    Shaft Wear    GearToothCondition
    ________    _________    ______    _______    _______    ____________    ___________    ___________    ______________    _______    ______________    ________________    ________________    _______________    ________    _____________    ________    ________________    _______________    ____________    _____________    __________    __________________

    -0.94876     -0.9722     1.3726    0.98387    0.81571       3.6314        -0.041525       2.2666           2.0514         0.8081        28562              1.1429             0.031581            79.931            0          6.75e-06       3.23e-07         162.13                0                1                1              0           No Tooth Fault  
    -0.97537    -0.98958     1.3937    0.99105    0.81571       3.6314        -0.023777       2.2598           2.0203        0.81017        29418              1.1362             0.037835            70.325            0          5.08e-08       9.16e-08         226.12                0                1                1              0           No Tooth Fault  
      1.0502      1.0267     1.4449    0.98491     2.8157       3.6314         -0.04162       2.2658           1.9487        0.80853        31710              1.1479             0.031565            125.19            0          6.74e-06       2.85e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0227      1.0045     1.4288    0.99553     2.8157       3.6314        -0.016356       2.2483           1.9707        0.81324        30984              1.1472             0.032088             112.5            0          4.99e-06        2.4e-07         162.13                0                1                0              1           No Tooth Fault  
      1.0123      1.0024     1.4202    0.99233     2.8157       3.6314        -0.014701       2.2542           1.9826        0.81156        30661              1.1469              0.03287            108.86            0          3.62e-06       2.28e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0275      1.0102     1.4338     1.0001     2.8157       3.6314         -0.02659       2.2439           1.9638        0.81589        31102              1.0985             0.033427            64.576            0          2.55e-06       1.65e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0464      1.0275     1.4477     1.0011     2.8157       3.6314        -0.042849       2.2455           1.9449        0.81595        31665              1.1417             0.034159            98.838            0          1.73e-06       1.55e-07         230.39                0                1                0              1           No Tooth Fault  
      1.0459      1.0257     1.4402    0.98047     2.8157       3.6314        -0.035405       2.2757            1.955        0.80583        31554              1.1345               0.0353            44.223            0          1.11e-06       1.39e-07         230.39                0                1                0              1           No Tooth Fault  

Visualice los nombres de las clases del conjunto de datos.

classNames = categories(tbl{:,labelName})
classNames = 2x1 cell
    {'No Tooth Fault'}
    {'Tooth Fault'   }

Dividir un conjunto de datos en conjuntos de entrenamiento y de validación

Divida el conjunto de datos en particiones de entrenamiento, de validación y de prueba. Reserve el 15% de los datos para la validación y otro 15% para las pruebas.

Visualice el número de observaciones del conjunto de datos.

numObservations = size(tbl,1)
numObservations = 208

Determine el número de observaciones para cada partición.

numObservationsTrain = floor(0.7*numObservations)
numObservationsTrain = 145
numObservationsValidation = floor(0.15*numObservations)
numObservationsValidation = 31
numObservationsTest = numObservations - numObservationsTrain - numObservationsValidation
numObservationsTest = 32

Cree un arreglo de índices aleatorios que se corresponda con las observaciones y divídalo utilizando los tamaños de partición.

idx = randperm(numObservations);
idxTrain = idx(1:numObservationsTrain);
idxValidation = idx(numObservationsTrain+1:numObservationsTrain+numObservationsValidation);
idxTest = idx(numObservationsTrain+numObservationsValidation+1:end);

Divida la tabla de datos en particiones de entrenamiento, de validación y de prueba utilizando los índices.

tblTrain = tbl(idxTrain,:);
tblValidation = tbl(idxValidation,:);
tblTest = tbl(idxTest,:);

Definir la arquitectura de red

Defina la red para la clasificación.

Defina una red con una capa de entrada de características y especifique el número de características. Configure también la capa de entrada para normalizar los datos utilizando la normalización de puntuación Z. A continuación, incluya una capa completamente conectada con un tamaño de salida de 50, seguida de una capa de normalización de lotes y una capa ReLU. Para la clasificación, especifique otra capa totalmente conectada con un tamaño de salida que se corresponda con el número de clases, seguida de una capa softmax.

numFeatures = size(tbl,2) - 1;
numClasses = numel(classNames);
 
layers = [
    featureInputLayer(numFeatures,Normalization="zscore")
    fullyConnectedLayer(50)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

Especificar las opciones de entrenamiento

Especifique las opciones de entrenamiento.

  • Entrene la red con Adam.

  • Realice el entrenamiento empleando minilotes de un tamaño de 16.

  • Cambie el orden de los datos en cada época.

  • Monitorice la precisión de la red durante el entrenamiento especificando datos de validación.

  • Muestre el progreso del entrenamiento en una gráfica y omita la salida de la ventana de comandos detallada.

El software entrena la red según los datos de entrenamiento y calcula la precisión de los datos de validación en intervalos regulares durante el entrenamiento. Los datos de validación no se utilizan para actualizar los pesos de la red.

miniBatchSize = 16;

options = trainingOptions("adam", ...
    MiniBatchSize=miniBatchSize, ...
    Shuffle="every-epoch", ...
    ValidationData=tblValidation, ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Entrenar la red

Entrene la red con la arquitectura definida por layers, los datos de entrenamiento y las opciones de entrenamiento. De forma predeterminada, trainnet usa una GPU en caso de que esté disponible. De lo contrario, usa una CPU. Entrenar en una GPU requiere Parallel Computing Toolbox™ y un dispositivo GPU compatible. Para obtener información sobre los dispositivos compatibles, consulte GPU Computing Requirements (Parallel Computing Toolbox). También puede especificar el entorno de ejecución con el argumento nombre-valor ExecutionEnvironment de trainingOptions.

La gráfica de progreso del entrenamiento muestra la pérdida y la precisión de minilotes y la pérdida y la precisión de validación. Para obtener más información sobre la gráfica de progreso del entrenamiento, consulte Monitorizar el progreso del entrenamiento de deep learning.

net = trainnet(tblTrain,layers,"crossentropy",options);

Probar la red

Prediga las etiquetas de los datos de prueba con la red entrenada y calcule la precisión. Especifique el mismo tamaño de minilote utilizado para el entrenamiento.

scores = minibatchpredict(net,tblTest(:,1:end-1),MiniBatchSize=miniBatchSize);
YPred = scores2label(scores,classNames);

Calcule la precisión de clasificación. La precisión es la proporción de etiquetas que la red predice correctamente.

YTest = tblTest{:,labelName};
accuracy = sum(YPred == YTest)/numel(YTest)
accuracy = 0.9375

Visualice los resultados en una matriz de confusión.

figure
confusionchart(YTest,YPred)

Figure contains an object of type ConfusionMatrixChart.

Consulte también

| | | | |

Ejemplos relacionados

Más acerca de