Deep Learning ANN Classification Model
Hi,
I am trying to develop a pattern-recognition ANN classification model for 10 different classes using 11 inputs. The model runs, but the performance is poor (less than 20% accuracy). I want to try a deep learning technique. However, all the examples I can find are for image-based classification, and my problem is purely numerical (no images). Is there a way to build a deep learning pattern-recognition classification model in MATLAB? Is there an example of how to do that?
Accepted Answer
Cris LaPierre
5 Dec 2023
Edited: Cris LaPierre
5 Dec 2023
Absolutely. Here is a page showing multiple examples, none of which are images: https://www.mathworks.com/help/deeplearning/gs/pattern-recognition-with-a-shallow-neural-network.html
These are all shallow networks. You can turn a shallow network into a deep network by adding more layers.
I would suggest using the Neural Network Pattern Recognition App to create a network and then exporting the code. You can then expand it manually. Here's an example I built based on the Iris data set example in the app (4 inputs, 3 output classes):
% Solve a Pattern Recognition Problem with a Neural Network
% Script generated by Neural Pattern Recognition app
% Created 05-Dec-2023 10:55:42
%
% This script assumes these variables are defined:
%
% irisInputs - input data.
% irisTargets - target data.
x = irisInputs;
t = irisTargets;
% Choose a Training Function
% For a list of all training functions type: help nntrain
% 'trainlm' is usually fastest.
% 'trainbr' takes longer but may be better for challenging problems.
% 'trainscg' uses less memory. Suitable in low memory situations.
trainFcn = 'trainscg'; % Scaled conjugate gradient backpropagation.
This is the section of code that creates the layers. It is a shallow network because there is only one hidden layer:
% Create a Pattern Recognition Network
hiddenLayerSize = 10;
net = patternnet(hiddenLayerSize, trainFcn);
You can turn this into a deep learning network by adding more hidden layers. For example, this code creates a network with three hidden layers:
% Three hidden layers with 10, 20 and 15 neurons
hiddenLayerSize1 = 10;
hiddenLayerSize2 = 20;
hiddenLayerSize3 = 15;
net = patternnet([hiddenLayerSize1 hiddenLayerSize2 hiddenLayerSize3], trainFcn);
% Setup Division of Data for Training, Validation, Testing
net.divideParam.trainRatio = 70/100;
net.divideParam.valRatio = 15/100;
net.divideParam.testRatio = 15/100;
% Train the Network
[net,tr] = train(net,x,t);
% Test the Network
y = net(x);
e = gsubtract(t,y);
performance = perform(net,t,y)
tind = vec2ind(t);
yind = vec2ind(y);
percentErrors = sum(tind ~= yind)/numel(tind);
% View the Network
view(net)
% Plots
% Uncomment these lines to enable various plots.
%figure, plotperform(tr)
%figure, plottrainstate(tr)
%figure, ploterrhist(e)
%figure, plotconfusion(t,y)
%figure, plotroc(t,y)
7 Comments
You'll need to load the Iris data set to run this code.
% Load the Iris dataset
[IrisInput,IrisOutput] = iris_dataset;
irisInputs = IrisInput;    % names used by the generated script above
irisTargets = IrisOutput;
IrisInput contains 150 observations with four input features corresponding to four measurements:
- Sepal length in cm
- Sepal width in cm
- Petal length in cm
- Petal width in cm
IrisOutput is a 3-by-150 matrix that indicates the class of each observation by placing a 1 in the row corresponding to the correct species: the first row for setosa, the second row for versicolor, or the third row for virginica.
For more information on this dataset, run the following help command:
help iris_dataset
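Because each column of IrisOutput is a one-hot vector, it is often handy to switch between that layout and a plain vector of class numbers, for example when counting classes or comparing predictions. Deep Learning Toolbox provides vec2ind and ind2vec for this. The minimal sketch below does the same conversion with base MATLAB functions on a small hand-made target matrix, which is an illustrative assumption rather than the Iris data.

```matlab
% Toy one-hot targets in the same layout as IrisOutput:
% rows = classes, columns = observations (3 classes, 4 observations)
t = [1 0 0 0;
     0 1 0 1;
     0 0 1 0];

% One-hot -> class index: the row holding the 1 in each column
% (same result as vec2ind(t) for one-hot columns)
[~, labels] = max(t, [], 1);

% Class index -> one-hot, with the number of classes fixed explicitly
numClasses = 3;
n = numel(labels);
t2 = full(sparse(labels, 1:n, 1, numClasses, n));

disp(labels)          % 1 2 3 2
disp(isequal(t, t2))  % 1 (true)
```

Fixing numClasses explicitly keeps the target matrix the same height even if the highest class is missing from a subset of the data, whereas ind2vec called with only the index vector sizes its output from the largest index present.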
Mostafa
6 Dec 2023
Thanks for your answer. It was a great help. Unfortunately, while the accuracy did improve a little, it is still low (27% from the confusion matrix). Any other suggestions on how to improve the accuracy? Some of my inputs were binary (0 or 1, i.e. yes or no), and I tried removing them, thinking they might be causing prediction problems. That helped, but not significantly.
Mostafa
6 Dec 2023
I actually found that MATLAB is not defining my output classes correctly. I have 34,610 observations divided into 10 output classes. However, when I look at the confusion matrix, some of the output classes show NaN with zero observations. To define the output classes, I used binary codes. For example, the first output class is:
0 0 0 0 0 0 0 0 0 1
The second is:
0 0 0 0 0 0 0 0 1 0
Any ideas what may be causing this?
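One way to narrow this down is to check the targets directly before training. The sketch below assumes the targets were read from a spreadsheet with one row per observation (the four-row Ot matrix is made-up placeholder data, not the poster's file). It transposes them to the classes-by-observations layout that patternnet expects, confirms every observation has exactly one 1, and counts observations per class. It also shows one plausible source of NaN entries: a per-class percentage for a class with no observations is 0/0, which MATLAB evaluates to NaN. Note that under the row-index convention used by vec2ind, a target of 0 0 0 0 0 0 0 0 0 1 is class 10, not class 1.

```matlab
% Placeholder targets, one row per observation, as read from a spreadsheet
% (hypothetical data; replace with the readmatrix output)
Ot = [0 0 0 0 0 0 0 0 0 1;
      0 0 0 0 0 0 0 0 1 0;
      0 0 0 0 0 0 0 0 1 0;
      0 0 0 0 0 0 0 0 0 1];

t = Ot';                          % patternnet expects classes-by-observations
numClasses = size(t, 1);

% Every observation (column) should contain exactly one 1
badObs = find(sum(t, 1) ~= 1);
if ~isempty(badObs)
    warning('%d observations are not one-hot encoded', numel(badObs));
end

% Observations per class; a zero means the class never occurs in this data
classCounts = sum(t, 2)';
emptyClasses = find(classCounts == 0);

% Confusion counts: rows = true class, columns = predicted class
[~, tind] = max(t, [], 1);
yind = tind;                      % placeholder predictions; use vec2ind(net(x)) in practice
C = accumarray([tind(:) yind(:)], 1, [numClasses numClasses]);

% Per-class fraction correct: classes with no observations give 0/0 = NaN
perClassRecall = diag(C) ./ sum(C, 2);
```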
Mostafa
6 Dec 2023
As an update, I reran the file with only 219 observations and I am facing the same problems. Attached is the confusion matrix.
Consider attaching your data and sharing your code. Without that, we can only talk in general terms.
Here are the code and a small data set. I imported the data and transposed it.
Cris LaPierre
6 Dec 2023
Edited: Cris LaPierre
6 Dec 2023
Not knowing anything about your data, I think you may need to look into feature engineering. Three of your inputs are highly correlated, meaning they don't add anything new to the model. However, even after eliminating them, I get similar results. To me, this means your inputs do not differ enough between classes to produce an accurate model.
It = readmatrix("InputsMethod1.xlsx");
Ot = readmatrix("Outputs1.xlsx");
xnames = "input" + (1:4);
x = It';
t = Ot';
% Turn the one-hot t into a row vector of class labels (0.1, 0.2, ..., 1.0)
f = max((0.1:0.1:1)' .* t);
% Normalize data
[x,ps]=mapminmax(x,0,1);
figure
plot([x;f])

% Inputs 1, 2 and 3 are highly correlated
figure
gplotmatrix(It,[],f,[],[],[],[],[],xnames)

% Determine which features have the highest predictive power
[idx,scores] = fscmrmr(It,f')
% Output:
% idx = 1×4
%      1     4     2     3
% scores = 1×4
%     0.0397    0.0034    0.0031    0.0205
bar(scores(idx))
xlabel('Predictor rank')
ylabel('Predictor importance score')
xticklabels(xnames(idx));
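The correlation visible in the gplotmatrix output can also be checked numerically. This is a minimal sketch that assumes synthetic stand-in data (features 1 to 3 built from the same random signal, feature 4 independent) and an arbitrary cutoff of 0.9 for "highly correlated"; with the real data you would pass It from readmatrix instead.

```matlab
% Synthetic stand-in inputs, one column per feature (hypothetical data)
rng(1);
a = rand(200, 1);
It = [a, 2*a + 0.01*randn(200, 1), -a + 0.01*randn(200, 1), rand(200, 1)];

R = corrcoef(It);                     % 4-by-4 Pearson correlation matrix
threshold = 0.9;                      % assumed cutoff for "highly correlated"

% Feature pairs above the cutoff (upper triangle only, diagonal excluded)
[rowIdx, colIdx] = find(triu(abs(R), 1) > threshold);
highPairs = [rowIdx colIdx];
disp(highPairs)                       % for this synthetic data: (1,2), (1,3), (2,3)
```

Taking the absolute value matters because a strong negative correlation (features 2 and 3 here) is just as redundant as a positive one.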

Here are the results I get with a shallow network using the two top-ranked features (inputs 1 and 4).
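As a rough sketch of how such a run could be set up, the code below ranks features with fscmrmr (Statistics and Machine Learning Toolbox), keeps the top two, and trains a shallow patternnet (Deep Learning Toolbox) on them. It uses the Iris data as a stand-in because the poster's files are not available here, and the hidden layer size of 10 is an assumption. Accuracy is measured on the test split that train sets aside.

```matlab
% Stand-in data: Iris (4-by-150 inputs, 3-by-150 one-hot targets).
% For real data, substitute the inputs and targets from the code above.
[x, t] = iris_dataset;

% Rank features by MRMR (fscmrmr expects observations in rows)
labels = vec2ind(t);
idx = fscmrmr(x', labels');

% Keep only the two top-ranked features
nTop = 2;
xTop = x(idx(1:nTop), :);

% Shallow pattern-recognition network with one hidden layer (size is an assumption)
rng(0);                                % intended to make weight init and data division repeatable
net = patternnet(10, 'trainscg');
net.trainParam.showWindow = false;     % train without opening the training window
[net, tr] = train(net, xTop, t);

% Accuracy on the held-out test observations chosen by train
y = net(xTop);
testAcc = mean(vec2ind(y(:, tr.testInd)) == labels(tr.testInd));
fprintf('Test accuracy with top %d features: %.1f%%\n', nTop, 100 * testAcc);
```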
