ECG Classification Using Transfer Learning | AI Techniques for ECG Classification, Part 4
From the series: AI Techniques for ECG Classification
Learn how you can easily generate sharp time-frequency representations with continuous wavelet transforms, and then use those representations to train pretrained convolutional networks like AlexNet and SqueezeNet to build ECG classifiers.
Published: 28 Jan 2021
Hello, everyone. Thank you very much for tuning in to this video. My name is Kirthi Devleker, and today's topic is ECG Classification Using Transfer Learning. This video is part of the series Developing AI Models for Biomedical ECG Signals. So let's begin.
So the goal of this video is to develop a classifier that can classify ECG signals into three distinct categories. The three categories here are ARR (arrhythmia), CHF (congestive heart failure), and NSR (normal sinus rhythm), with 96, 30, and 36 signals in each class, respectively.

Our approach today is to use transfer learning, which involves taking a pretrained convolutional neural network and adapting it to build a model that can classify these signals quickly and easily. Each of the signals I have here is roughly 65,000 samples long.
So what is transfer learning? Transfer learning is a deep learning technique where you take a pretrained network (there are many networks available) and fine-tune some of its parameters to fit your task.

This is in stark contrast with training a deep neural network from scratch, which requires you to assemble all the layers and figure out the right parameters for them. The reason we are using transfer learning is that we can quickly take layers that have already been created and train those networks on our problem, so we don't spend a lot of time creating new networks or troubleshooting them.

Because if you think about it, when you're developing a model and you don't get a good result, you won't necessarily know whether the cause is a problem with the network or something to do with your data.
So that's the reason we are using transfer learning. And if you notice here, there are many different networks you can leverage, like AlexNet, VGG, GoogLeNet, and SqueezeNet. We will be using AlexNet just as an example, and the techniques I'm going to cover can be reapplied using any network you see here. MATLAB actually provides you with a lot of different networks, and the list is continuously growing, so you can take a look at what networks are available.
OK, so let's move on. What do we need to actually build these classifiers? Convolutional neural networks work very well on images, so we need something called a time-frequency representation, which you can see on the left.

The time-frequency representation of a signal shows you how the frequency content of the signal evolves as a function of time, and it can be thought of as a pattern.

These time-frequency representations can be saved as images, and pretrained convolutional neural networks can take in those images and work on them. If you really think about it, these signals can have distinct patterns depending on the class. It doesn't matter exactly what the pattern is; as long as you throw enough data at the network, it can learn those patterns and identify which class a signal belongs to.
So the next question people may ask is, OK, which time-frequency representation should we choose? And the answer is, it really depends on your application. But the specific time-frequency representation I'm going to use here is called the wavelet scalogram.

The reason we're using this technique is that we don't really have to specify any parameters: the scalogram automatically gives you a sharp time-frequency representation without you having to worry about figuring out the right parameters, which is the case with some of the other time-frequency representations.
So just to give you some idea, I have an ECG signal here on the left. If you notice, the spectrogram, for instance, kind of smears your time-frequency representation a little bit, and it's hard to tell where the features are in your signal.

Compare that with the time-frequency representation you get from a scalogram of the same ECG. Now you can see these beats, which are distinct and clearly separated, and you can actually see the frequency range of the beats and a lot of other good information. So in some sense, the continuous wavelet transform is preserving all the information that's present in your ECG signal. And the benefit is that you don't need to make any assumptions about which parameters or which window to use, which is the case with the spectrogram, to get a nice time-frequency view of your signals. OK?
So this is what we will be covering in our video today. One other thing I want to quickly address: although a lot of pretrained convolutional neural networks are available, if you don't find the network you're looking for, MATLAB lets you import models from Keras, TensorFlow, or Caffe, and you can import and export models in the ONNX format.

ONNX stands for Open Neural Network Exchange, and it is becoming a widely accepted standard in the artificial intelligence community. It allows you to interface with frameworks like PyTorch, Caffe2, and so on, OK?
So now let's look at the overall workflow for classifying these signals with CNNs and transfer learning. The idea is we have some signals, and we will generate time-frequency representations of these signals. Once we generate these time-frequency representations, we will save them as images.

Then we will take a deep convolutional neural network and, following the transfer learning approach, make a couple of small modifications to adapt that network to our problem. Then we will train the model.

Once the model is trained, we will evaluate how good or bad it is by using some new signals and following the same process: generate the time-frequency representation and have the model predict. So this is the workflow we are going to follow. And with that, let's now jump into MATLAB to actually see this in action.
So here is MATLAB, and here is my script, which we are going to walk through quickly now, OK? Just to show you the signals: I have 162 signals here for today, and each one belongs to one of the three categories, ARR, CHF, and NSR. The idea is we want to build a model that can automatically classify a signal into these three distinct categories, OK?

To visualize the signals, you can use an app called Signal Analyzer, which helps you analyze signals in the time, frequency, and time-frequency domains. You can open up the Apps gallery to find the Signal Analyzer app. We will not cover it in this video, so you can take a look and try it yourself.
But coming to the main point of the video, let's first look at the short-time Fourier transform approach. If you are familiar with the short-time Fourier transform (sometimes people also call it the spectrogram), the spectrogram typically requires your signal as input, which we certainly have, and then a set of parameters. These parameters are the type of window, how many segments you want to chunk your signal into, the number of FFT points, and so on.
Now, in MATLAB, if you don't specify values for these parameters, you get default behavior, meaning the function chooses values for you. And when you execute this, take a look: here's your signal, on the left is class I, and here's the spectrogram of the normal signal.
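The on-screen code isn't captured in the transcript, but a minimal MATLAB sketch of this default-parameter spectrogram call might look like this (the variable names and the 128 Hz sampling rate are illustrative assumptions, not shown in the video):

```matlab
Fs  = 128;                 % assumed sampling rate of the ECG records
sig = ecgSignals{1};       % one ECG record (illustrative variable name)

% Empty brackets for the window, overlap, and FFT length make
% spectrogram fall back to its default parameter choices.
spectrogram(sig, [], [], [], Fs, 'yaxis');
```

Called with no output arguments like this, the function plots the spectrogram directly, which is what the narrator is pointing at on screen.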
So if you notice, in this time-frequency representation you're missing out on some features. For instance, these relatively higher-frequency QRS waves: you don't see their separation in the time-frequency view, OK?

So this is one of the reasons. I can show you the same thing for the other class as well. If you look at the spectrogram for class II, you'll probably see the same thing. And again, the most important point here is that you have a signal with features occurring at different scales, but in the time-frequency view you see a smeared picture, right?
So whenever you have situations like this and you're working with real signals, I highly recommend you use the continuous wavelet transform. If you notice, the continuous wavelet transform can actually help you get a very sharp time-frequency view of your signal. I personally find it very useful when I'm looking at signals I don't know: if I don't know what frequency components are present in the signal and at what times, the continuous wavelet transform does a terrific job of identifying those components.
So at a minimum, the function just takes in the signal. The sampling frequency is an optional parameter; it just helps annotate the axes. And if you take a look at the same signal now, you notice these QRS waves are nicely separated: wherever you see a peak in the time-domain view, you see a cleanly separated peak here. This is the QRS of the ECG signal.
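As a sketch of this step (again assuming an illustrative `ecgSignals` variable and a 128 Hz sampling rate, which the transcript doesn't specify), the scalogram call in MATLAB could look like:

```matlab
Fs  = 128;                 % assumed sampling rate; only annotates the axes
sig = ecgSignals{1};       % same illustrative ECG record as before

% With no output arguments, cwt plots the scalogram directly.
cwt(sig, Fs);

% With output arguments, cwt returns the wavelet coefficients and the
% corresponding frequencies instead of plotting.
[cfs, frq] = cwt(sig, Fs);
```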
So once you have this, if you notice, it looks like a nice pattern. This one belongs to the normal class, and you'll see that the peaks are very nicely separated.

If you do the same thing for class II, you'll get a very similar picture, but now you'll start seeing some differences when you compare it with class I, right? For instance, you'll see that the distance between the peaks is irregular, which is probably why this could be arrhythmia.

But again, the main point is that you have a function that can give you a very sharp pattern, and now you can leverage those patterns to train a network. Deep networks are typically very good at identifying patterns and building a model out of them, as long as your input representations are pretty sharp, right?
So that's basically the main idea of the continuous wavelet transform. The next step is generating time-frequency representations for a lot of signals, and there is one recommendation here: use something called a filter bank. Creating the filter bank is a one-time operation; you then apply the filter bank, subsequently, in a loop across all your signals. This will tremendously speed up the generation of your time-frequency representations. You get all the coefficients, and you can take those coefficients and save them as JPEG files.
So what I've done is written a small script that loops through all the signals and generates the time-frequency representations, meaning it generates all the coefficients, and I have a function here that takes those coefficients and saves them as JPEG files. So if you want to look at the process, this is how it looks: it loads in one signal, generates the time-frequency representation, and saves it to disk, right? This is the main workflow here.
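The script itself isn't shown in the transcript; a hedged sketch of the filter-bank loop might look like the following. The names `ecgSignals`, `labels`, and `outFolder` are illustrative assumptions, and the coefficients-to-RGB conversion is one common way to turn scalogram coefficients into images sized for AlexNet:

```matlab
Fs = 128;                                        % assumed sampling rate
fb = cwtfilterbank('SignalLength', numel(ecgSignals{1}), ...
                   'SamplingFrequency', Fs);     % one-time filter bank creation

for k = 1:numel(ecgSignals)
    cfs = abs(wt(fb, ecgSignals{k}));            % reuse the filter bank per signal
    im  = ind2rgb(im2uint8(rescale(cfs)), jet(128));  % coefficients -> RGB image
    im  = imresize(im, [227 227]);               % AlexNet expects 227x227x3 inputs
    imwrite(im, fullfile(outFolder, char(labels(k)), sprintf('sig%d.jpg', k)));
end
```

Precomputing the filter bank once and calling `wt` in the loop avoids rebuilding the wavelet filters for every signal, which is where the speedup the narrator mentions comes from.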
Once you're done with this, the next step is to build a model. But before you build a model, you need to organize these time-frequency images, which you can see here on the left: you have three folders, or three directories, containing the time-frequency representation images for the three different classes.
And by using a datastore, you can manage your data here very well. For instance, the datastore can be used to split each label: I'm using an 80-20 split, which means 80% of my images go to training and 20% to test. But again, it really depends on your application.

The good thing about the datastore is that, while this is just one small example, if you have lots and lots of data, the image datastore helps isolate your logic from dealing with those large datasets: writing extra code to load the data into memory, figuring out what operations to perform, and so on. The image datastore gives you a nice way of working with that without having to write a lot of code.
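In MATLAB this step is typically just a couple of lines. In this sketch, `'scalograms'` stands in for whatever folder holds the three class subdirectories of JPEG images (the actual folder name isn't given in the transcript):

```matlab
% Point an imageDatastore at the parent folder; with these options the
% subfolder names (ARR, CHF, NSR) become the image labels.
imds = imageDatastore('scalograms', ...
    'IncludeSubfolders', true, 'LabelSource', 'foldernames');

% 80-20 split per label, randomized, as described in the video.
[imdsTrain, imdsTest] = splitEachLabel(imds, 0.8, 'randomized');
```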
So now that you have your training and test images, we will take the training image set and build a model on it, and the test image set will be used to evaluate the model. So how do we go about building the model? As I said, for transfer learning we will take a pretrained network, in this case AlexNet.
So AlexNet is one of the most popular deep networks. Originally, AlexNet was designed to identify objects in 1,000 different categories; you can see the 1,000 here, for classifying among 1,000 different classes. In our case, we'll just make a small modification: we'll take this 23rd layer, which is fully connected, and change its size from 1,000 to three, because we have just three classes to worry about. And the last layer will be the classification layer.

So if I make these changes, you will notice that my fully connected layer now has three outputs instead of 1,000, and my last layer is the classification layer.
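A minimal sketch of this layer swap in MATLAB, matching what the narrator describes (layer 23 is AlexNet's fully connected `fc8` layer; layer 25 is the classification output):

```matlab
net    = alexnet;        % requires the AlexNet support package
layers = net.Layers;

% Replace the 1,000-class fully connected layer with a 3-class one,
% and give the network a fresh classification output layer.
layers(23) = fullyConnectedLayer(3);
layers(25) = classificationLayer;
```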
So once I have made these changes, the next step is to set up some training options. And again, there is no magic way of finding these values: you can use trial and error, or you can use techniques like Bayesian optimization to figure out the right parameters.
Once you do this, you can train your network. The good thing is that if you have GPUs, this training process can be sped up on them. Since training can take a little time, I have recorded a training video here, so let me pull that up.
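The specific option values aren't stated in the transcript; the ones below are illustrative starting points, not tuned settings, and `imdsTrain` and `layers` are assumed to come from the earlier datastore and layer-swap steps:

```matlab
% trainNetwork automatically uses a supported GPU if one is available.
options = trainingOptions('sgdm', ...
    'MiniBatchSize',    10, ...
    'MaxEpochs',        10, ...
    'InitialLearnRate', 1e-4, ...
    'Plots',            'training-progress');   % live accuracy/loss plot

myNet = trainNetwork(imdsTrain, layers, options);
```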
So here's the video, playing from the beginning. You'll notice that once you hit Train (the video is sped up, obviously), the model takes some time to train. Eventually the accuracy settles at 100% and the loss goes to almost zero, which means the model is now fully trained. Once the model is trained, we take the result, stored in this variable called myNet, and evaluate it on the test images.

So in this case, when I evaluate it on the test images, my model has an accuracy of one, which is 100%, meaning it has done a very good job of classifying new time-frequency images it has never seen. You can see here that it has classified pretty much all the signals correctly.
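The evaluation step is a one-liner plus an accuracy computation; this sketch assumes the `myNet` and `imdsTest` variables from the earlier steps:

```matlab
% Predict labels for the held-out test images and compare them with
% the ground-truth labels stored in the datastore.
predicted = classify(myNet, imdsTest);
accuracy  = mean(predicted == imdsTest.Labels)
```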
So this is the main idea of this video; I hope you enjoyed it. To summarize, transfer learning can be used to quickly develop a model without you having to write a lot of code. And as you can see, we will make this code available for you, so you can take it, make some changes, apply it to your own problems, and see how it works.

OK, so thank you very much. That's all for this video, and I'll see you in the next one. Bye bye.