ECG Classification Using Machine Learning | AI Techniques for ECG Classification, Part 2
From the series: AI Techniques for ECG Classification
Learn about feature extraction and feature selection techniques, such as minimum redundancy maximum relevance (mRMR), and leverage built-in apps such as the Classification Learner app to easily build machine learning models for ECG classification.
Published: 28 Jan 2021
Hello. Welcome to part two of this video series on developing AI models for biomedical ECG signals. In this video, we will talk about how to train machine learning algorithms on these ECG signals to build a classification model for the three classes of ECG signals: arrhythmia, congestive heart failure, and normal sinus rhythm. As we discussed in video one, each of these signals is about 65,000 samples long, and we have 162 records of ECG data to work with.
Now a typical, easy way to train the machine learning models would be to feed our raw ECG signals directly into a machine learning classification algorithm and have it give us outputs like arrhythmia or congestive heart failure. But unfortunately, this approach does not work, probably because the signals are really long and their features change very rapidly with time, which the machine learning model cannot interpret well.
So the way it works is we add an additional step of feature engineering to extract features from the ECG signals and use those features to train our machine learning algorithms, which helps us build a model with higher accuracy. Feature engineering has numerous advantages and numerous aspects to it. First and foremost, it greatly reduces the dimensionality of the input data.
In our case, our signals are 65,000 samples long. We no longer have to worry about a big data size; each signal is reduced to a small set of features, which is used to train the model. On top of that, feature selection is also a part of feature engineering. It allows you to keep the features that are more prominent or have a greater effect on training the machine learning models, and eliminate the other features, which are probably just adding noise to the input data set.
Some of the ideas for feature extraction are statistical techniques like principal component analysis (PCA) or neighborhood component analysis (NCA), or, because we are working with signals, we can compute the power spectral density, use wavelet feature extractors, and so on; a lot of different options are available to us.
And especially when we work with signals, MATLAB offers a lot of different tools that you can use for feature extraction, both in the time domain and in the frequency domain. For instance, we can extract all the R-peaks and use those as an input, find certain patterns in the signal like the PQRST wave, or compute the spectral spread of the signals and use that as a feature set for the machine learning algorithms.
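To make that concrete, here is a minimal sketch of what manual feature extraction of this kind could look like, assuming a single ECG record sig sampled at fs hertz. The variable names, the 128 Hz rate, and the peak-detection thresholds are illustrative assumptions, not code from the video.

% Illustrative sketch of manual time- and frequency-domain ECG features.
% Assumes "sig" is one raw ECG record (column vector) sampled at "fs" Hz.
fs = 128;                                        % assumed sampling rate

% Time domain: locate R-peaks and summarize the RR intervals
[~, rLocs] = findpeaks(sig, ...
    'MinPeakHeight',   0.6*max(sig), ...         % assumed amplitude threshold
    'MinPeakDistance', round(0.25*fs));          % at most ~240 beats per minute
rr     = diff(rLocs)/fs;                         % RR intervals in seconds
meanRR = mean(rr);
stdRR  = std(rr);

% Frequency domain: simple spectral summaries of the record
mnf = meanfreq(sig, fs);                         % mean frequency
bw  = obw(sig, fs);                              % 99% occupied bandwidth

features = [meanRR, stdRR, mnf, bw];             % small hand-crafted feature vector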
But the challenge with this manual feature extraction is that it is quite iterative, because when we are trying to create a new model, we do not know which features will actually have an impact in improving the accuracy of the model and which features are just adding noise to the input data set. So there is a lot of trial and error before we can eliminate the non-useful features and focus on the one or two features that have the most impact on training.
But the good news is that with MATLAB, you also have an automatic feature extraction routine using wavelet scattering filter banks. What this automatic feature extraction does is pass the signal through a series of filter banks. It is very similar to a convolutional neural network, with the difference being that the weights, or filter coefficients, are already known in advance. And just like in a convolutional neural network, we have a wavelet convolution with the filter bank, then a nonlinearity operation with a modulus, and then some averaging. After each layer, it produces a set of scattering coefficients, which are used as features for training the classifiers. We can have a number of layers, but usually two or three layers give us enough features to train the classifier.
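As a rough illustration of that layered structure, here is a minimal sketch using the waveletScattering object from Wavelet Toolbox, assuming a signal sig and a 128 Hz sampling rate (both assumptions made for illustration):

% Minimal sketch of a wavelet scattering network with the default two filter banks.
sf = waveletScattering('SignalLength', numel(sig), 'SamplingFrequency', 128);

% scatteringTransform returns the scattering coefficients S and the scalogram
% coefficients U, organized per order (layer) of the network.
[S, U] = scatteringTransform(sf, sig);
numel(S)   % number of orders; with the default two filter banks this is 3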
And it can relieve a lot of the requirements on the amount of data and on model complexity. What that means is that this large signal gets reduced to a small set of features, but at the same time, even in scenarios where we do not have enough data, this feature extraction technique still pulls out the most prominent pieces of the signal that influence training. So we can get away with a small amount of data and still train our models.
If you would like to learn more about this wavelet scattering approach, feel free to take a look at our documentation page on mathworks.com/wavelet-scattering, where you can see a little more about the theory of wavelet scattering as well as some relevant examples of how to use it. But anyway, we will be using the wavelet scattering features to train some machine learning models using the Classification Learner app in MATLAB. So without much delay, let's jump to MATLAB.
So this is the MATLAB script we have. The first step is to load our ECG data. As I mentioned before, our ECG data has 162 records, each 65,000 samples long, and you can see the corresponding labels NSR, CHF, and ARR. So let's go ahead and see what these wavelet scattering features look like. In this section, we are constructing the wavelet scattering filter bank, giving it the length of the signal, and then extracting the features using the featureMatrix function.
And when we do that, if you see over here, the sample signal, which is 65,000 samples long, gets reduced down to a feature set of 499 by 8. That is roughly a 95% reduction in data size, and this feature set is extracted automatically in just these two lines of code. So how powerful is that? We can also visualize these features, say the level one scattering coefficients. They will not make much sense just by inspecting them graphically, but you can see that for CHF the scattering coefficients look something like this. And these 499-by-8 features are what we'll use to train our models.
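For reference, the two lines doing this extraction look roughly like the sketch below. The ECGData variable and its Data field are assumed to match the loaded data set, and the 499-by-8 size is the one quoted in the video.

% Sketch of the automatic feature extraction step (illustrative).
% Assumes ECGData.Data is 162-by-N with one ECG record per row.
sf   = waveletScattering('SignalLength', size(ECGData.Data, 2));
feat = featureMatrix(sf, ECGData.Data(1,:)');   % one record -> about 499-by-8 per the video
size(feat)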
So let's go ahead, and actually I already have some partitioned data. What I did was randomly select 113 ECG samples for training and keep aside 49 samples for testing. So let's go ahead and compute the feature matrix for both the training and the test samples. I'll run this feature-matrix script here, splitting it into training data and testing data. Now if I look at the training data and the testing data, I have 499 features arranged as columns, and the last column, the 500th, holds the associated labels. One thing to note here is that the 499-by-8 features per record are spread out: it's no longer 499 by 8, but for each record I have 8 rows of 499 features. That means I'm expanding the 499-by-8-by-113 array into a 904-by-500 table for training and a 392-by-500 table for testing, repeating each record's label 8 times.
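A minimal sketch of that rearrangement is shown below, with illustrative variable names (trainData, trainLabels, and the Label column name are assumptions; the 113-record split is the one from the video):

% Sketch of arranging scattering features for the Classification Learner app.
trainFeat = featureMatrix(sf, trainData');           % about 499-by-8-by-113
[Npaths, Nwin, Nsig] = size(trainFeat);

% Treat every scattering time window as its own observation: 8*113 = 904 rows of 499 features
X = reshape(permute(trainFeat, [2 3 1]), Nwin*Nsig, Npaths);
Y = repelem(trainLabels(:), Nwin);                    % repeat each record's label 8 times

trainingData = [array2table(X), table(Y, 'VariableNames', {'Label'})];   % 904-by-500 table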
So that being said, my data is ready to be ingested into the Classification Learner app, which I can open from the Apps tab under Machine Learning and Deep Learning, or invoke from the command line by running the classificationLearner command. Now I can feed in this training data and build classifiers from it. So this is my Classification Learner app, and I'll start a new session. It asks me to select an input, and I select my training data. Note that it automatically picks up the feature set, the 499 columns over here, as the predictors, and the label column of this table as a categorical array with three unique classes. So let's go ahead and start a new session in the Classification Learner.
And now you can plot a scatter plot of these different predictors, all 499 of them, and see how they are correlated with each other for the three classes, ARR, CHF, and NSR. But then we want to train the algorithms, and this is where we can choose which algorithm we'd like to train. You can train nearest neighbor classifiers like cubic KNN, cosine KNN, and medium KNN, and then there are SVMs, and this whole library of discriminant analysis, decision trees, and so on. Now since we do not know which algorithm works best, what I like to do is just train all of them. So let's go ahead and train all of these classifiers together.
Note that it's starting up a parallel pool, which means I'm going to use parallel workers for the training because I selected the parallel option, and that speeds up the training process. I see that most of my algorithms are trained at this point, and two of them, the ensemble subspace discriminant and the subspace KNN, get the highest accuracy of 100%. So let's select one of these algorithms. I've got this Generate Function option, which makes MATLAB automatically generate the script for training this algorithm again, or I can just export the trained model to my workspace and test out my algorithm. That's what I'm going to do: I export it to a variable named trainedModel.
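As a point of reference, a roughly equivalent programmatic version of the ensemble subspace discriminant model might look like the sketch below. This is an assumption on my part; the code that Classification Learner's Generate Function option produces will differ in its details.

% Rough programmatic sketch of an ensemble subspace discriminant model.
predictors = trainingData(:, 1:end-1);
response   = trainingData.Label;                 % assumed label column name

mdl = fitcensemble(predictors, response, ...
    'Method', 'Subspace', 'Learners', 'discriminant');

cvmdl  = crossval(mdl, 'KFold', 5);              % cross-validated accuracy, as the app reports
valAcc = 1 - kfoldLoss(cvmdl)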
All right, so now let's go ahead and evaluate this trained model on the test data. I have this test data set, and I'm just going to run this: the accuracy of the predicted classes is compared with the true classes. And I see that while I was getting 100% accuracy during training in the Classification Learner app, it's not doing as good a job on the test data. It's close to 90% accuracy, which is not bad, but we do see that some of the ARR samples are misclassified, and so are some of the CHF samples.
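For completeness, the evaluation step looks roughly like this sketch; testingData, its Label column, and the exported struct's fields are assumptions consistent with the video's description.

% Sketch of evaluating the exported model on the held-out test data.
% "trainedModel" is the struct exported from the Classification Learner app.
predLabels = trainedModel.predictFcn(testingData(:, 1:end-1));
trueLabels = testingData.Label;

testAccuracy = mean(predLabels == trueLabels)    % about 0.90 in the video
confusionchart(trueLabels, predLabels);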
So what do we do now? We did start with 499-by-8 features for each signal, but like I said at the beginning in my slides, some of these features might just be adding noise to the data set and not helping the training. So how do we go about that problem? We can use automatic feature selectors in MATLAB to select the prominent features and remove the unwanted ones. The feature selector I'm using is fscmrmr, and if you open the help documentation on it, you'll see it is an automatic feature selection algorithm that ranks features for classification using the Minimum Redundancy Maximum Relevance (MRMR) algorithm. The documentation page walks you through how to use it, and if you'd like to see what the Minimum Redundancy Maximum Relevance algorithm actually is, that information is at the bottom of the documentation page, along with the references from which the algorithm is taken.
So this is the automated feature ranking and selection I'll use. Let's go ahead and run this. It will calculate scores over the training data and evaluate which features are more prominent and have a greater impact, and which features have less impact. If I plot the first 100 features, this is what the plot looks like. And if I open this matrix of MRMR features, you see that feature number 1 has a very high score, the second-highest score is feature number 141, the third-highest is 395, then 73, and so on. So what I'll do in my next training is use only the top 20 features instead of the entire 499-feature set, and then train on that in the Classification Learner app. So let's go ahead and do that.
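A minimal sketch of that ranking and selection step, with the same assumed table and column names as before:

% Sketch of MRMR feature ranking with fscmrmr and selecting the top 20 features.
[idx, scores] = fscmrmr(trainingData(:, 1:end-1), trainingData.Label);

bar(scores(idx(1:100)))                          % scores of the 100 top-ranked features
xlabel('Feature rank'); ylabel('MRMR score');

top20 = idx(1:20);                               % keep only the 20 highest-ranked features
trainingSelectedData = [trainingData(:, top20), trainingData(:, 'Label')];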
Again I'll start a new session, and this time my input is not the full training data but the training data with the selected features. Now you see there are only 20 features I'm going to use to train. So let's start the session, and again we'll train all the machine learning algorithms. Let's click train and give it a minute. This time there are fewer features, so the training is really quick. We see that all these algorithms are trained, and the cubic SVM and the quadratic SVM have the highest accuracy.
So let's go ahead and look at the confusion chart: there are only three misclassifications. The cubic SVM also has three misclassifications, but they are one for ARR and two for CHF. So let's stick with this cubic SVM and export the model, overriding the previous trained model in the workspace. Do you want to override it? Yes. Going back to the script now, let's evaluate our new trained model with the selected features. And now we see that all of a sudden our accuracy went from 89% to 95.6% on the test data set.
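The re-evaluation is the same as before, just restricted to the same 20 selected columns of the test table (again with assumed variable names):

% Sketch of re-evaluating the newly exported model on the selected test features.
testingSelectedData = [testingData(:, top20), testingData(:, 'Label')];

predLabels   = trainedModel.predictFcn(testingSelectedData(:, 1:end-1));
testAccuracy = mean(predLabels == testingSelectedData.Label)   % about 0.956 in the video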
And if you're still not satisfied, let's export the quadratic SVM to the workspace as well. It asks, do you want to override it? Yes. I re-evaluate the test data on it, and before I know it, it's 97% accuracy on test data that this algorithm has never seen before. So you can see that not only did we quickly train machine learning algorithms using the Classification Learner app and automatically extract features using the wavelet scattering network, but we also selected the relevant features using the MRMR algorithm to get better accuracy for the trained models.
So this concludes part two of the video series. In the part three video, we'll see how we can train deep learning networks on the same data set. Thank you.