Esta página aún no se ha traducido para esta versión. Puede ver la versión más reciente de esta página en inglés.

Clasificar series temporales mediante el análisis de ondas y el aprendizaje profundo

Este ejemplo muestra cómo clasificar las señales de electrocardiograma humano (ECG) utilizando la transformación de onda continua (CWT) y una red neuronal convolucional profunda (CNN).

Entrenar una CNN profunda desde cero es costosa desde cero y requiere una gran cantidad de datos de entrenamiento. En varias aplicaciones, no se dispone de una cantidad suficiente de datos de entrenamiento, y la síntesis de nuevos ejemplos de formación realistas no es factible. En estos casos, es deseable aprovechar las redes neuronales existentes que se han entrenado en grandes conjuntos de datos para tareas conceptualmente similares. Este aprovechamiento de las redes neuronales existentes se denomina aprendizaje de transferencia. En este ejemplo adaptamos dos CNN profundas, GoogLeNet y SqueezeNet, previamente entrenadas para el reconocimiento de imágenes para clasificar las formas de onda ECG en función de una representación de frecuencia de tiempo.

GoogLeNet y SqueezeNet son CNN profundas diseñadas originalmente para clasificar imágenes en 1000 categorías. Reutilizamos la arquitectura de red de LA CNN para clasificar las señales ECG basadas en imágenes del CWT de los datos de las series temporales. Los datos utilizados en este ejemplo están disponibles públicamente desde .PhysioNet

Descripción de los datos

En este ejemplo, se utilizan datos de ECG obtenidos de tres grupos de personas: personas con arritmia cardíaca (ARR), personas con insuficiencia cardíaca congestiva (CHF) y personas con ritmos sinuales normales (NSR). En total, se utilizan 162 grabaciones ECG de tres bases de datos PhysioNet: [3][7], [3] y [1][3].Base de datos de arritmias del MIT-BIHBase de datos de ritmo sinusal normal MIT-BIHLa base de datos de insuficiencia cardíaca congestiva BIDMC Más concretamente, 96 grabaciones de personas con arritmia, 30 grabaciones de personas con insuficiencia cardíaca congestiva y 36 grabaciones de personas con ritmos sinuarios normales. El objetivo es entrenar a un clasificador para distinguir entre ARR, CHF y NSR.

Descargar datos

El primer paso es descargar los datos del archivo .Repositorio GitHub Para descargar los datos del sitio web, haga clic y seleccione .Clone or downloadDownload ZIP Guarde el archivo en una carpeta donde tenga permiso de escritura.physionet_ECG_data-master.zip En las instrucciones de este ejemplo se supone que ha descargado el archivo en el directorio temporal, , en MATLAB.tempdir Modifique las instrucciones posteriores para descomprimir y cargar los datos si decide descargar los datos en una carpeta diferente de .tempdir Si está familiarizado con Git, puede descargar la última versión de las herramientas ( ) y obtener los datos de un símbolo del sistema utilizandoGitgit clone https://github.com/mathworks/physionet_ECG_data/.

Después de descargar los datos de GitHub, descomprima el archivo en el directorio temporal.

unzip(fullfile(tempdir,'physionet_ECG_data-master.zip'),tempdir)

Descomprimir crea la carpeta en el directorio temporal.physionet-ECG_data-master Esta carpeta contiene el archivo de texto y .README.mdECGData.zip El archivo contieneECGData.zip

  • ECGData.mat

  • Modified_physionet_data.txt

  • License.txt

contiene los datos utilizados en este ejemplo.ECGData.mat El archivo de texto, , es requerido por la política de copia de PhysioNet y proporciona las atribuciones de origen para los datos, así como una descripción de los pasos de preprocesamiento aplicados a cada registro ECG.Modified_physionet_data.txt

Descomprima .ECGData.zipphysionet-ECG_data-master Cargue el archivo de datos en el espacio de trabajo de MATLAB.

unzip(fullfile(tempdir,'physionet_ECG_data-master','ECGData.zip'),...     fullfile(tempdir,'physionet_ECG_data-master')) load(fullfile(tempdir,'physionet_ECG_data-master','ECGData.mat'))

es una matriz de estructura con dos campos: y .ECGDataDataLabels El campo es una matriz de 162 por 65536 donde cada fila es una grabación ECG muestreada a 128 hercios. es una matriz de celdas de 162 por 1 de etiquetas de diagnóstico, una para cada fila de .DataLabelsData Las tres categorías de diagnóstico son: , , y .'ARR''CHF''NSR'

Para almacenar los datos preprocesados de cada categoría, cree primero un directorio de datos ECG dentro de .dataDirtempdir A continuación, cree tres subdirectorios con el nombre de cada categoría de ECG.'data' La función auxiliar hace esto. acepta , el nombre de un directorio de datos ECG y el nombre de un directorio primario como argumentos de entrada.helperCreateECGDirectorieshelperCreateECGDirectoriesECGData Puede reemplazar con otro directorio en el que tenga permiso de escritura.tempdir Puede encontrar el código fuente de esta función auxiliar en la sección Funciones auxiliares al final de este ejemplo.

parentDir = tempdir; dataDir = 'data'; helperCreateECGDirectories(ECGData,parentDir,dataDir)

Trazar un representante de cada categoría de ECG. La función auxiliar hace esto. acepta como entrada.helperPlotRepshelperPlotRepsECGData Puede encontrar el código fuente de esta función auxiliar en la sección Funciones auxiliares al final de este ejemplo.

helperPlotReps(ECGData)

Crear representaciones de frecuencia de tiempo

Después de crear las carpetas, cree representaciones de frecuencia de tiempo de las señales ECG. Estas representaciones se denominan escalogramas. Un scalogram es el valor absoluto de los coeficientes CWT de una señal.

Para crear los escalogramas, precomputeunde un banco de filtros CWT. La precomputación del banco de filtros CWT es el método preferido al obtener el CWT de muchas señales utilizando los mismos parámetros.

Antes de generar los escalogramas, examine uno de ellos. Cree un banco de filtros CWT utilizando una señal con 1000 muestras.cwtfilterbank Utilice el banco de filtros para tomar el CWT de las primeras 1000 muestras de la señal y obtener el escalograma de los coeficientes.

Fs = 128; fb = cwtfilterbank('SignalLength',1000,...     'SamplingFrequency',Fs,...     'VoicesPerOctave',12); sig = ECGData.Data(1,1:1000); [cfs,frq] = wt(fb,sig); t = (0:999)/Fs;figure;pcolor(t,frq,abs(cfs)) set(gca,'yscale','log');shading interp;axis tight; title('Scalogram');xlabel('Time (s)');ylabel('Frequency (Hz)')

Utilice la función auxiliar para crear los escalogramas como imágenes RGB y escríbalos en el subdirectorio adecuado en .helperCreateRGBfromTFdataDir El código fuente de esta función auxiliar se encuentra en la sección Funciones auxiliares al final de este ejemplo. Para ser compatible con la arquitectura GoogLeNet, cada imagen RGB es una matriz de tamaño 224 por 224 por 3.

helperCreateRGBfromTF(ECGData,parentDir,dataDir)

Dividir en datos de capacitación y validación

Cargue las imágenes de scalogram como un almacén de datos de imágenes. La función etiqueta automáticamente las imágenes en función de los nombres de carpeta y almacena los datos como un ImageDatastore objeto.imageDatastore Un almacén de datos de imágenes le permite almacenar datos de imágenes de gran tamaño, incluidos datos que no caben en la memoria, y leer de forma eficiente lotes de imágenes durante el entrenamiento de una CNN.

allImages = imageDatastore(fullfile(parentDir,dataDir),...     'IncludeSubfolders',true,...     'LabelSource','foldernames');

Divida aleatoriamente las imágenes en dos grupos, uno para el entrenamiento y el otro para la validación. Utilice el 80% de las imágenes para el entrenamiento y el resto para la validación. A efectos de reproducibilidad, establecemos la semilla aleatoria en el valor predeterminado.

rng default [imgsTrain,imgsValidation] = splitEachLabel(allImages,0.8,'randomized'); disp(['Number of training images: ',num2str(numel(imgsTrain.Files))]);
Number of training images: 130 
disp(['Number of validation images: ',num2str(numel(imgsValidation.Files))]);
Number of validation images: 32 

GoogLeNet

Carga

Cargue la red neuronal GoogLeNet preentrenada. Si Deep Learning Toolbox™ paquete de soporte técnico de Modelo no está instalado, el software proporciona un vínculo al paquete de soporte necesario en el Explorador de complementos.for GoogLeNet Network Para instalar el paquete de soporte técnico, haga clic en el vínculo y, a continuación, haga clic en .Instalar

net = googlenet;

Extraiga y muestre el gráfico de capas de la red.

lgraph = layerGraph(net); numberOfLayers = numel(lgraph.Layers); figure('Units','normalized','Position',[0.1 0.1 0.8 0.8]); plot(lgraph) title(['GoogLeNet Layer Graph: ',num2str(numberOfLayers),' Layers']);

Inspeccione el primer elemento de la propiedad Layers de red. Confirme que GoogLeNet requiere imágenes RGB de tamaño 224 por 224 por 3.

net.Layers(1)
ans =    ImageInputLayer with properties:                  Name: 'data'            InputSize: [224 224 3]     Hyperparameters     DataAugmentation: 'none'        Normalization: 'zerocenter'                 Mean: [224×224×3 single]  

Modificar parámetros de red de GoogLeNet

Cada capa de la arquitectura de red se puede considerar un filtro. Las capas anteriores identifican entidades más comunes de imágenes, como blobs, bordes y colores. Las capas posteriores se centran en entidades más específicas para diferenciar las categorías. GoogLeNet está preentrenado para clasificar imágenes en 1000 categorías de objetos. Debe volver a entrenar GoogLeNet para nuestro problema de clasificación ecleste.

Para evitar el sobreajuste, se utiliza una capa desplegable. Una capa desplegable establece aleatoriamente los elementos de entrada en cero con una probabilidad determinada. Consulte para obtener más información.dropoutLayer La probabilidad predeterminada es 0.5. Reemplace la capa de abandono final en la red, , con una capa de abandono de probabilidad 0.6.'pool5-drop_7x7_s1'

newDropoutLayer = dropoutLayer(0.6,'Name','new_Dropout'); lgraph = replaceLayer(lgraph,'pool5-drop_7x7_s1',newDropoutLayer);

Las capas convolucionales de la imagen de extracción de red se utilizan para clasificar la imagen de entrada. Estas dos capas, y en GoogLeNet, contienen información sobre cómo combinar las entidades que la red extrae en probabilidades de clase, un valor de pérdida y etiquetas predichas.'loss3-classifier''output' Para volver a entrenar GoogLeNet para clasificar las imágenes RGB, sustituya estas dos capas por nuevas capas adaptadas a los datos.

Reemplace la capa totalmente conectada por una nueva capa totalmente conectada con el número de filtros igual al número de clases.'loss3-classifier' Para aprender más rápido en las nuevas capas que en las capas transferidas, aumente los factores de tasa de aprendizaje de la capa totalmente conectada.

numClasses = numel(categories(imgsTrain.Labels)); newConnectedLayer = fullyConnectedLayer(numClasses,'Name','new_fc',...     'WeightLearnRateFactor',5,'BiasLearnRateFactor',5); lgraph = replaceLayer(lgraph,'loss3-classifier',newConnectedLayer);

La capa de clasificación especifica las clases de salida de la red. Reemplace la capa de clasificación por una nueva sin etiquetas de clase. establece automáticamente las clases de salida de la capa en el momento del entrenamiento.trainNetwork

newClassLayer = classificationLayer('Name','new_classoutput'); lgraph = replaceLayer(lgraph,'output',newClassLayer);

Establecer opciones de entrenamiento y entrenar GoogLeNet

El entrenamiento de una red neuronal es un proceso iterativo que implica minimizar una función de pérdida. Para minimizar la función de pérdida, se utiliza un algoritmo de descenso de degradado. En cada iteración, se evalúa el gradiente de la función de pérdida y se actualizan las ponderaciones del algoritmo de descenso.

El entrenamiento se puede ajustar estableciendo varias opciones. especifica el tamaño del paso inicial en la dirección del gradiente negativo de la función de pérdida. especifica el tamaño de un subconjunto del conjunto de entrenamiento que se usará en cada iteración.InitialLearnRateMiniBatchSize Una época es un pase completo del algoritmo de entrenamiento en todo el conjunto de entrenamiento. especifica el número máximo de épocas a utilizar para el entrenamiento.MaxEpochs Elegir el número correcto de épocas no es una tarea trivial. Disminuir el número de épocas tiene el efecto de subajustar el modelo, y aumentar el número de épocas resulta en sobreajuste.

Utilice la función para especificar las opciones de entrenamiento.trainingOptions Establézalo en 10, en 10 y en 0.0001.MiniBatchSizeMaxEpochsInitialLearnRate Visualice el progreso del entrenamiento estableciendo en .Plotstraining-progress Utilice el descenso de gradiente estocástico con el optimizador de impulso. De forma predeterminada, el entrenamiento se realiza en una GPU si hay una disponible (requiere Parallel Computing Toolbox™ y una GPU habilitada ® CUDA con capacidad de proceso 3.0 o superior). Para fines de reproducibilidad, establezca para que se utilizara la CPU.ExecutionEnvironmentcputrainNetwork Establezca el valor predeterminado. Los tiempos de ejecución serán más rápidos si puedes usar una GPU.

options = trainingOptions('sgdm',...     'MiniBatchSize',15,...     'MaxEpochs',20,...     'InitialLearnRate',1e-4,...     'ValidationData',imgsValidation,...     'ValidationFrequency',10,...     'Verbose',1,...     'ExecutionEnvironment','cpu',...     'Plots','training-progress'); rng default

Entrena la red. El proceso de entrenamiento suele tardar de 1 a 5 minutos en una CPU de escritorio. La ventana de comandos muestra información de entrenamiento durante la ejecución. Los resultados incluyen el número de época, el número de iteración, el tiempo transcurrido, la precisión del minilote, la precisión de validación y el valor de la función de pérdida para los datos de validación.

trainedGN = trainNetwork(imgsTrain,lgraph,options);

Initializing input data normalization. |======================================================================================================================| |  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  | |         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       | |======================================================================================================================| |       1 |           1 |       00:00:03 |        6.67% |       18.75% |       4.9207 |       2.4141 |      1.0000e-04 | |       2 |          10 |       00:00:23 |       66.67% |       62.50% |       0.9589 |       1.3191 |      1.0000e-04 | |       3 |          20 |       00:00:43 |       46.67% |       75.00% |       1.2973 |       0.5928 |      1.0000e-04 | |       4 |          30 |       00:01:04 |       60.00% |       78.13% |       0.7219 |       0.4576 |      1.0000e-04 | |       5 |          40 |       00:01:25 |       73.33% |       84.38% |       0.4750 |       0.3367 |      1.0000e-04 | |       7 |          50 |       00:01:46 |       93.33% |       84.38% |       0.2714 |       0.2892 |      1.0000e-04 | |       8 |          60 |       00:02:07 |       80.00% |       87.50% |       0.3617 |       0.2433 |      1.0000e-04 | |       9 |          70 |       00:02:29 |       86.67% |       87.50% |       0.3246 |       0.2526 |      1.0000e-04 | |      10 |          80 |       00:02:50 |      100.00% |       96.88% |       0.0701 |       0.1876 |      1.0000e-04 | |      12 |          90 |       00:03:11 |       86.67% |      100.00% |       0.2836 |       0.1681 |      1.0000e-04 | |      13 |         100 |       00:03:32 |       86.67% |       96.88% |       0.4160 |       0.1607 |      1.0000e-04 | |      14 |         110 |       00:03:53 |       86.67% |       96.88% |       0.3237 |       0.1565 |      1.0000e-04 | |      15 |         120 |       00:04:14 |       93.33% |       96.88% |       0.1646 |       0.1476 |      1.0000e-04 | |      17 |         130 |       00:04:35 |      100.00% |       96.88% |       0.0551 |       0.1330 |      1.0000e-04 | |      18 |         140 |       00:04:57 |       93.33% |       96.88% |       0.0927 |       0.1347 |      1.0000e-04 | |      19 |         150 |       00:05:18 |       93.33% |       93.75% |       0.1666 |       0.1325 |      1.0000e-04 | |      20 |         160 |       00:05:39 |       93.33% |       96.88% |       0.0873 |       0.1164 |      1.0000e-04 | |======================================================================================================================| 

Inspeccione la última capa de la red entrenada. Confirme que la capa Classification Output incluye las tres clases.

trainedGN.Layers(end)
ans =    ClassificationOutputLayer with properties:              Name: 'new_classoutput'          Classes: [ARR    CHF    NSR]       OutputSize: 3     Hyperparameters     LossFunction: 'crossentropyex'  

Evaluar la precisión de GoogLeNet

Evalúe la red utilizando los datos de validación.

[YPred,probs] = classify(trainedGN,imgsValidation); accuracy = mean(YPred==imgsValidation.Labels); disp(['GoogLeNet Accuracy: ',num2str(100*accuracy),'%'])
GoogLeNet Accuracy: 96.875% 

La precisión es idéntica a la precisión de validación notificada en la figura de visualización del entrenamiento. Los escalogramas se dividieron en colecciones de entrenamiento y validación. Ambas colecciones se utilizaron para entrenar a GoogLeNet. La forma ideal de evaluar el resultado del entrenamiento es que la red clasifique los datos que no ha visto. Dado que hay una cantidad insuficiente de datos para dividir en entrenamiento, validación y pruebas, tratamos la precisión de validación calculada como la precisión de la red.

Explore las activaciones de GoogLeNet

Cada capa de una CNN produce una respuesta, o activación, a una imagen de entrada. Sin embargo, solo hay unas pocas capas dentro de una CNN que son adecuadas para la extracción de entidades de imagen. Las capas al principio de la red capturan entidades de imagen básicas, como bordes y blobs. Para ver esto, visualice los pesos del filtro de red de la primera capa convolucional. Hay 64 conjuntos individuales de pesos en la primera capa.

wghts = trainedGN.Layers(2).Weights; wghts = rescale(wghts); wghts = imresize(wghts,5); figure montage(wghts) title('First Convolutional Layer Weights')

Puede examinar las activaciones y descubrir qué características aprende GoogLeNet comparando áreas de activación con la imagen original. Para obtener más información, consulte y .

Examine qué áreas de las capas convolucionales se activan en una imagen de la clase.ARR Compare con las áreas correspondientes de la imagen original. Cada capa de una red neuronal convolucional consta de muchas matrices 2D llamadas .Canales Pase la imagen a través de la red y examine las activaciones de salida de la primera capa convolucional, .'conv1-7x7_s2'

convLayer = 'conv1-7x7_s2';  imgClass = 'ARR'; imgName = 'ARR_10.jpg'; imarr = imread(fullfile(parentDir,dataDir,imgClass,imgName));  trainingFeaturesARR = activations(trainedGN,imarr,convLayer); sz = size(trainingFeaturesARR); trainingFeaturesARR = reshape(trainingFeaturesARR,[sz(1) sz(2) 1 sz(3)]); figure montage(rescale(trainingFeaturesARR),'Size',[8 8]) title([imgClass,' Activations'])

Encuentra el canal más fuerte para esta imagen. Compare el canal más fuerte con la imagen original.

imgSize = size(imarr); imgSize = imgSize(1:2); [~,maxValueIndex] = max(max(max(trainingFeaturesARR))); arrMax = trainingFeaturesARR(:,:,:,maxValueIndex); arrMax = rescale(arrMax); arrMax = imresize(arrMax,imgSize); figure; imshowpair(imarr,arrMax,'montage') title(['Strongest ',imgClass,' Channel: ',num2str(maxValueIndex)])

SqueezeNet

SqueezeNet es una CNN profunda cuya arquitectura soporta imágenes de tamaño 227 por 227 por 3. Aunque las dimensiones de la imagen son diferentes para GoogLeNet, no es necesario generar nuevas imágenes RGB en las dimensiones de SqueezeNet. Puede utilizar las imágenes RGB originales.

Carga

Cargue la red neuronal SqueezeNet previamente entrenada. Si Deep Learning Toolbox™ paquete de soporte técnico de Modelo no está instalado, el software proporciona un vínculo al paquete de soporte necesario en el Explorador de complementos.for SqueezeNet Network Para instalar el paquete de soporte técnico, haga clic en el vínculo y, a continuación, haga clic en .Instalar

sqz = squeezenet;

Extraiga el gráfico de capas de la red. Confirme que SqueezeNet tiene menos capas que GoogLeNet. Confirme también que SqueezeNet está configurado para imágenes de tamaño 227 por 227 por 3

lgraphSqz = layerGraph(sqz); disp(['Number of Layers: ',num2str(numel(lgraphSqz.Layers))])
Number of Layers: 68 
disp(lgraphSqz.Layers(1).InputSize)
   227   227     3 

Modificar los parámetros de red de SqueezeNet

Para volver a entrenar SqueezeNet para clasificar nuevas imágenes, realice cambios similares a los realizados para GoogLeNet.

Inspeccione las últimas seis capas de red.

lgraphSqz.Layers(end-5:end)
ans =    6x1 Layer array with layers:       1   'drop9'                             Dropout                 50% dropout      2   'conv10'                            Convolution             1000 1x1x512 convolutions with stride [1  1] and padding [0  0  0  0]      3   'relu_conv10'                       ReLU                    ReLU      4   'pool10'                            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]      5   'prob'                              Softmax                 softmax      6   'ClassificationLayer_predictions'   Classification Output   crossentropyex with 'tench' and 999 other classes 

Reemplace la capa, la última capa desplegable de la red, por una capa de deserción de probabilidad 0,6.'drop9'

tmpLayer = lgraphSqz.Layers(end-5); newDropoutLayer = dropoutLayer(0.6,'Name','new_dropout'); lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newDropoutLayer);

A diferencia de GoogLeNet, la última capa reconocible de SqueezeNet es una capa convolucional 1 por 1, y no una capa totalmente conectada.'conv10' Reemplace la capa por una nueva capa convolucional por el número de filtros igual al número de clases.'conv10' Como se hizo con GoogLeNet, aumente los factores de tasa de aprendizaje de la nueva capa.

numClasses = numel(categories(imgsTrain.Labels)); tmpLayer = lgraphSqz.Layers(end-4); newLearnableLayer = convolution2dLayer(1,numClasses, ...         'Name','new_conv', ...         'WeightLearnRateFactor',10, ...         'BiasLearnRateFactor',10); lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newLearnableLayer);

Reemplace la capa de clasificación por una nueva sin etiquetas de clase.

tmpLayer = lgraphSqz.Layers(end); newClassLayer = classificationLayer('Name','new_classoutput'); lgraphSqz = replaceLayer(lgraphSqz,tmpLayer.Name,newClassLayer);

Inspeccione las últimas seis capas de la red. Confirme que se han cambiado las capas de abandono, convolución y salida.

lgraphSqz.Layers(63:68)
ans =    6x1 Layer array with layers:       1   'new_dropout'       Dropout                 60% dropout      2   'new_conv'          Convolution             3 1x1 convolutions with stride [1  1] and padding [0  0  0  0]      3   'relu_conv10'       ReLU                    ReLU      4   'pool10'            Average Pooling         14x14 average pooling with stride [1  1] and padding [0  0  0  0]      5   'prob'              Softmax                 softmax      6   'new_classoutput'   Classification Output   crossentropyex 

Preparar datos RGB para SqueezeNet

Las imágenes RGB tienen dimensiones adecuadas para la arquitectura GoogLeNet. Cree almacenes de datos de imágenes aumentadas que redimensionen automáticamente las imágenes RGB existentes para la arquitectura SqueezeNet. Para obtener más información, consulte .augmentedImageDatastore

augimgsTrain = augmentedImageDatastore([227 227],imgsTrain); augimgsValidation = augmentedImageDatastore([227 227],imgsValidation);

Establecer opciones de entrenamiento y Train SqueezeNet

Cree un nuevo conjunto de opciones de entrenamiento para usar con SqueezeNet. Establezca la semilla aleatoria en el valor predeterminado y entrene la red. El proceso de entrenamiento suele tardar de 1 a 5 minutos en una CPU de escritorio.

ilr = 3e-4; miniBatchSize = 10; maxEpochs = 15; valFreq = floor(numel(augimgsTrain.Files)/miniBatchSize); opts = trainingOptions('sgdm',...     'MiniBatchSize',miniBatchSize,...     'MaxEpochs',maxEpochs,...     'InitialLearnRate',ilr,...     'ValidationData',augimgsValidation,...     'ValidationFrequency',valFreq,...     'Verbose',1,...     'ExecutionEnvironment','cpu',...     'Plots','training-progress');  rng default trainedSN = trainNetwork(augimgsTrain,lgraphSqz,opts);

Initializing input data normalization. |======================================================================================================================| |  Epoch  |  Iteration  |  Time Elapsed  |  Mini-batch  |  Validation  |  Mini-batch  |  Validation  |  Base Learning  | |         |             |   (hh:mm:ss)   |   Accuracy   |   Accuracy   |     Loss     |     Loss     |      Rate       | |======================================================================================================================| |       1 |           1 |       00:00:01 |       20.00% |       43.75% |       5.2508 |       1.2540 |          0.0003 | |       1 |          13 |       00:00:11 |       60.00% |       50.00% |       0.9912 |       1.0519 |          0.0003 | |       2 |          26 |       00:00:20 |       60.00% |       59.38% |       0.8554 |       0.8497 |          0.0003 | |       3 |          39 |       00:00:30 |       60.00% |       59.38% |       0.8120 |       0.8328 |          0.0003 | |       4 |          50 |       00:00:38 |       50.00% |              |       0.7885 |              |          0.0003 | |       4 |          52 |       00:00:40 |       60.00% |       65.63% |       0.7091 |       0.7314 |          0.0003 | |       5 |          65 |       00:00:49 |       90.00% |       87.50% |       0.4639 |       0.5893 |          0.0003 | |       6 |          78 |       00:00:59 |       70.00% |       87.50% |       0.6021 |       0.4355 |          0.0003 | |       7 |          91 |       00:01:08 |       90.00% |       90.63% |       0.2307 |       0.2945 |          0.0003 | |       8 |         100 |       00:01:15 |       90.00% |              |       0.1827 |              |          0.0003 | |       8 |         104 |       00:01:18 |       90.00% |       93.75% |       0.2139 |       0.2153 |          0.0003 | |       9 |         117 |       00:01:28 |      100.00% |       90.63% |       0.0521 |       0.1964 |          0.0003 | |      10 |         130 |       00:01:38 |       90.00% |       90.63% |       0.1134 |       0.2214 |          0.0003 | |      11 |         143 |       00:01:47 |      100.00% |       90.63% |       0.0855 |       0.2095 |          0.0003 | |      12 |         150 |       00:01:52 |       90.00% |              |       0.2394 |              |          0.0003 | |      12 |         156 |       00:01:57 |      100.00% |       90.63% |       0.0606 |       0.1849 |          0.0003 | |      13 |         169 |       00:02:06 |      100.00% |       90.63% |       0.0090 |       0.2071 |          0.0003 | |      14 |         182 |       00:02:16 |      100.00% |       93.75% |       0.0127 |       0.3597 |          0.0003 | |      15 |         195 |       00:02:25 |      100.00% |       93.75% |       0.0016 |       0.3414 |          0.0003 | |======================================================================================================================| 

Inspeccione la última capa de la red. Confirme que la capa Classification Output incluye las tres clases.

trainedSN.Layers(end)
ans =    ClassificationOutputLayer with properties:              Name: 'new_classoutput'          Classes: [ARR    CHF    NSR]       OutputSize: 3     Hyperparameters     LossFunction: 'crossentropyex'  

Evaluar la precisión de SqueezeNet

Evalúe la red utilizando los datos de validación.

[YPred,probs] = classify(trainedSN,augimgsValidation); accuracy = mean(YPred==imgsValidation.Labels); disp(['SqueezeNet Accuracy: ',num2str(100*accuracy),'%'])
SqueezeNet Accuracy: 93.75% 

Conclusión

En este ejemplo se muestra cómo utilizar el aprendizaje de transferencia y el análisis de wavelet continuo para clasificar tres clases de señales ECG aprovechando las CNN entrenadas previamente GoogLeNet y SqueezeNet. Las representaciones de frecuencia de tiempo basadas en ondas de señales ECG se utilizan para crear escalogramas. Se generan imágenes RGB de los escalogramas. Las imágenes se utilizan para afinar ambos CNN profundos. También se exploraron las activaciones de diferentes capas de red.

Este ejemplo ilustra un flujo de trabajo posible que puede usar para clasificar señales mediante modelos CNN entrenados previamente. Otros flujos de trabajo son posibles. GoogLeNet y SqueezeNet son modelos previamente entrenados en un subconjunto de la base de datos ImageNet [10], que se utiliza en el ImageNet Large-Scale Visual Recognition Challenge (ILSVRC) [8]. La colección ImageNet contiene imágenes de objetos del mundo real, como peces, aves, electrodomésticos y hongos. Los escalogramas quedan fuera de la clase de objetos del mundo real. Con el fin de encajar en la arquitectura GoogLeNet y SqueezeNet, los escalogramas también se sometieron a una reducción de datos. En lugar de afinar CNN preentrenados para distinguir diferentes clases de escalogramas, entrenar una CNN desde cero en las dimensiones originales del escalograma es una opción.

Referencias

  1. Baim, D. S., W. S. Colucci, E. S. Monrad, H. S. Smith, R. F. Wright, A. Lanoue, D. F. Gauthier, B. J. Ransil, W. Grossman y E. Braunwald. "Supervivencia de pacientes con insuficiencia cardíaca congestiva grave tratados con milrinona oral." .Journal of the American College of Cardiology Vol. 7, Número 3, 1986, págs. 661–670.

  2. Engin, M. "EcG beat classification using neuro-fuzzy network." .Pattern Recognition Letters Vol. 25, Número 15, 2004, págs. 1715–1722.

  3. Goldberger A. L., L. A. N. Amaral, L. Glass, J. M. Hausdorff, P. Ch. Ivanov, R. G. Mark, J. E. Mietus, G. B. Moody, C.-K. Peng, y H. E. Stanley. "PhysioBank, PhysioToolkit y PhysioNet: Componentes de un nuevo recurso de investigación para señales fisiológicas complejas." Circulación. Vol. 101, Número 23: e215–e220. [Páginas Electrónicas de Circulación; ]; 2000 (13 de junio). Doi:http://circ.ahajournals.org/content/101/23/e215.full 10.1161/01.CIR.101.23.e215.

  4. Leonarduzzi, R. F., G. Schlotthauer y M. E. Torres. "El análisis multifractal basado en el líder de wavelet de la variabilidad de la frecuencia cardíaca durante la isquemia miocárdica." En , , 110–113.Engineering in Medicine and Biology Society (EMBC)Annual International Conference of the IEEE Buenos Aires, Argentina: IEEE, 2010.

  5. Li, T., y M. Zhou. "Clasificación ECG usando entropía de paquetes de wavelet y bosques aleatorios." .Entropy Vol. 18, Número 8, 2016, p.285.

  6. Maharaj, E. A., y A. M. Alonso. "Análisis discriminatorio de series temporales multivariadas: Aplicación al diagnóstico basado en señales ECG." .Computational Statistics and Data Analysis Vol. 70, 2014, págs. 67–87.

  7. Moody, G.B., y R. G. Mark. "El impacto de la base de datos de arritmias del MIT-BIH." .IEEE Engineering in Medicine and Biology Magazine Vol. 20. Número 3, mayo-junio de 2001, págs. 45–50. (PMID: 11446209)

  8. Russakovsky, O., J. Deng, y H. Su et al. "Desafío de reconocimiento visual a gran escala de ImageNet." .International Journal of Computer Vision Vol. 115, Número 3, 2015, págs. 211–252.

  9. Zhao, Q., y L. Zhang. "Extracción y clasificación de características ECG mediante la transformación de wavelet y máquinas vectoriales de soporte." En , 1089-1092.IEEE International Conference on Neural Networks and Brain Beijing, China: IEEE, 2005.

  10. .ImageNethttp://www.image-net.org

Funciones de apoyo

crea un directorio de datos dentro de un directorio primario y, a continuación, crea tres subdirectorios dentro del directorio de datos.helperCreateECGDataDirectories Los subdirectorios reciben el nombre de cada clase de señal ECG que se encuentra en .ECGData

function helperCreateECGDirectories(ECGData,parentFolder,dataFolder) % This function is only intended to support the ECGAndDeepLearningExample. % It may change or be removed in a future release.  rootFolder = parentFolder; localFolder = dataFolder; mkdir(fullfile(rootFolder,localFolder))  folderLabels = unique(ECGData.Labels); for i = 1:numel(folderLabels)     mkdir(fullfile(rootFolder,localFolder,char(folderLabels(i)))); end end

traza las primeras mil muestras de un representante de cada clase de señal ECG encontrada en .helperPlotRepsECGData

function helperPlotReps(ECGData) % This function is only intended to support the ECGAndDeepLearningExample. % It may change or be removed in a future release.  folderLabels = unique(ECGData.Labels);  for k=1:3     ecgType = folderLabels{k};     ind = find(ismember(ECGData.Labels,ecgType));     subplot(3,1,k)     plot(ECGData.Data(ind(1),1:1000));     grid on     title(ecgType) end end

utiliza para obtener la transformación continua de las señales ECG y genera los escalogramas a partir de los coeficientes de wavelet.helperCreateRGBfromTFcwtfilterbank La función auxiliar cambia el tamaño de los escalogramas y los escribe en el disco como imágenes jpeg.

function helperCreateRGBfromTF(ECGData,parentFolder,childFolder) % This function is only intended to support the ECGAndDeepLearningExample. % It may change or be removed in a future release.  imageRoot = fullfile(parentFolder,childFolder);  data = ECGData.Data; labels = ECGData.Labels;  [~,signalLength] = size(data);  fb = cwtfilterbank('SignalLength',signalLength,'VoicesPerOctave',12); r = size(data,1);  for ii = 1:r     cfs = abs(fb.wt(data(ii,:)));     im = ind2rgb(im2uint8(rescale(cfs)),jet(128));          imgLoc = fullfile(imageRoot,char(labels(ii)));     imFileName = strcat(char(labels(ii)),'_',num2str(ii),'.jpg');     imwrite(imresize(im,[224 224]),fullfile(imgLoc,imFileName)); end end