Main Content

Clasificar y actualizar el estado de una red en Simulink

En este ejemplo se muestra cómo clasificar datos de una red neuronal recurrente entrenada en Simulink® mediante el bloque Stateful Classify. Este ejemplo usa una red de memoria de corto-largo plazo (LSTM) preentrenada.

Cargar una red preentrenada

Cargue JapaneseVowelsNet, una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Esta red se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.

load JapaneseVowelsNet

Visualice la arquitectura de red.

analyzeNetwork(net);

Cargar datos de prueba

Cargue los datos de prueba de las vocales japonesas. XTest es un arreglo de celdas que contiene 370 secuencias de dimensión 12 con diferentes longitudes. TTest es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes.

Cree un arreglo de horario simin con filas con marcas de tiempo y copias repetidas de X.

load JapaneseVowelsTestData;
X = XTest{94};
numTimeSteps = size(X,2);
simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));

Modelo de Simulink para clasificar datos

El modelo de Simulink para clasificar datos contiene un bloque Stateful Classify para predecir las etiquetas y un bloque From Workspace para cargar la secuencia de datos de entrada sobre las unidades de tiempo.

Para restablecer la red neuronal recurrente a su estado inicial durante la simulación, coloque el bloque Stateful Classify dentro de un Resettable Subsystem y use la señal de control Reset como activador.

open_system('StatefulClassifyExample');

Configurar modelo para la simulación

Ajuste los parámetros de configuración del modelo del bloque Stateful Classify.

set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat');
set_param('StatefulClassifyExample','SimulationMode','Normal');

Ejecutar la simulación

Para calcular las respuestas de la red JapaneseVowelsNet, ejecute la simulación. Las etiquetas de la predicción se guardan en el área de trabajo de MATLAB®.

out = sim('StatefulClassifyExample');

Represente las etiquetas predichas en una gráfica escalonada. La gráfica muestra cómo cambian las predicciones entre unidades de tiempo.

labels = squeeze(out.YPred.Data(1:numTimeSteps,1));

figure
stairs(labels, '-o')
xlim([1 numTimeSteps])
xlabel("Time Step")
ylabel("Predicted Class")
title("Classification Over Time Steps")

Compare las predicciones con la etiqueta verdadera. Represente una línea horizontal que muestre la etiqueta verdadera de la observación.

trueLabel = double(TTest(94));
hold on
line([1 numTimeSteps],[trueLabel trueLabel], ...
    'Color','red', ...
    'LineStyle','--')
legend(["Prediction" "True Label"])
axis([1 numTimeSteps+1 0 9]);

Referencias

[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.

[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels

Consulte también

| | |

Temas relacionados