Clasificar y actualizar el estado de una red en Simulink
En este ejemplo se muestra cómo clasificar datos de una red neuronal recurrente entrenada en Simulink® mediante el bloque Stateful Classify
. Este ejemplo usa una red de memoria de corto-largo plazo (LSTM) preentrenada.
Cargar una red preentrenada
Cargue JapaneseVowelsNet
, una red de LSTM preentrenada en el conjunto de datos de vocales japonesas, como se describe en [1] y [2]. Esta red se ha entrenado con las secuencias ordenadas por su longitud con un tamaño de minilote de 27.
load JapaneseVowelsNet
Visualice la arquitectura de red.
analyzeNetwork(net);
Cargar datos de prueba
Cargue los datos de prueba de las vocales japonesas. XTest
es un arreglo de celdas que contiene 370 secuencias de dimensión 12 con diferentes longitudes. TTest
es un vector categórico de las etiquetas "1","2",...,"9", que se corresponden con los nueve hablantes.
Cree un arreglo de horario simin
con filas con marcas de tiempo y copias repetidas de X
.
load JapaneseVowelsTestData; X = XTest{94}; numTimeSteps = size(X,2); simin = timetable(repmat(X,1,4)','TimeStep',seconds(0.2));
Modelo de Simulink para clasificar datos
El modelo de Simulink para clasificar datos contiene un bloque Stateful Classify
para predecir las etiquetas y un bloque From Workspace
para cargar la secuencia de datos de entrada sobre las unidades de tiempo.
Para restablecer la red neuronal recurrente a su estado inicial durante la simulación, coloque el bloque Stateful Classify
dentro de un Resettable Subsystem
y use la señal de control Reset
como activador.
open_system('StatefulClassifyExample');
Configurar modelo para la simulación
Ajuste los parámetros de configuración del modelo del bloque Stateful Classify
.
set_param('StatefulClassifyExample/Stateful Classify','NetworkFilePath','JapaneseVowelsNet.mat'); set_param('StatefulClassifyExample','SimulationMode','Normal');
Ejecutar la simulación
Para calcular las respuestas de la red JapaneseVowelsNet
, ejecute la simulación. Las etiquetas de la predicción se guardan en el área de trabajo de MATLAB®.
out = sim('StatefulClassifyExample');
Represente las etiquetas predichas en una gráfica escalonada. La gráfica muestra cómo cambian las predicciones entre unidades de tiempo.
labels = squeeze(out.YPred.Data(1:numTimeSteps,1)); figure stairs(labels, '-o') xlim([1 numTimeSteps]) xlabel("Time Step") ylabel("Predicted Class") title("Classification Over Time Steps")
Compare las predicciones con la etiqueta verdadera. Represente una línea horizontal que muestre la etiqueta verdadera de la observación.
trueLabel = double(TTest(94)); hold on line([1 numTimeSteps],[trueLabel trueLabel], ... 'Color','red', ... 'LineStyle','--') legend(["Prediction" "True Label"]) axis([1 numTimeSteps+1 0 9]);
Referencias
[1] M. Kudo, J. Toyama, and M. Shimbo. "Multidimensional Curve Classification Using Passing-Through Regions." Pattern Recognition Letters. Vol. 20, No. 11–13, pages 1103–1111.
[2] UCI Machine Learning Repository: Japanese Vowels Dataset. https://archive.ics.uci.edu/ml/datasets/Japanese+Vowels
Consulte también
Stateful Predict | Stateful Classify | Predict | Image Classifier