이 질문을 팔로우합니다.
- 팔로우하는 게시물 피드에서 업데이트를 확인할 수 있습니다.
- 정보 수신 기본 설정에 따라 이메일을 받을 수 있습니다.
matlabのディープラーニングでは、なぜテストデータを使わずにバリデーションデータを使うのか
조회 수: 6 (최근 30일)
이전 댓글 표시
채택된 답변
Kenta
2019년 3월 12일
単に、ここではバリデーションデータをテストデータと読み替えて問題ないと思います。また、以下のように、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2);
などとして、画像を訓練、バリデーション、テストデータに分けると良いかもしれません。
リンクの学習曲線のところでは、バリデーションデータを使います。
そして、最後のところで
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
とすると、テストデータで正答率を計算できます。ここで、optionsのところに
'ValidationPatience', 3
を追加すれば学習の早期終了ができます。「'ValidationPatience' の値は、ネットワークの学習が停止するまでに、検証セットでの損失が前の最小損失以上になることが許容される回数です。」
とあります。学習がある程度のところで限界が来たらそこで学習がストップするので学習時間を短縮できたり、過学習が抑えられる可能性があります。
댓글 수: 11
ssk
2019년 3월 12일
itakuraさま、いつもご回答いただきまして誠にありがとうございます。本例につきまして、バリデーションデータとテストデータを置き換えることができる旨、ご教示いただきありがとうございます。imdstestも使った例でも試してみます。また、'ValidationPatience', も利用してみたいと思います。
あわせて、下記リンクにつきましてitakura様宛に追加でご質問させていただきましたのでご覧頂けますと幸いです。
https://jp.mathworks.com/matlabcentral/answers/447586-cross-validation
ssk
2019년 3월 13일
トレーニング、テスト、バリデーションの3つに分けたコードを試しに作成してみたのですが、以下のコードでご趣旨を反映できておりますでしょうか。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsValidation' ,'=', stname2,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
[imds11,imds12,imds13,imds14,imds15]...
= splitEachLabel(imds,0.2,0.2,0.2,0.2,'randomize');
imdsTest11 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds14.Files));
imdsTest11.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds14.Labels);
imdsTest12 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds15.Files));
imdsTest12.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds15.Labels);
imdsTest13 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds14.Files,imds15.Files));
imdsTest13.Labels = cat(1,imds11.Labels,imds12.Labels,imds14.Labels,imds15.Labels);
imdsTest14 = imageDatastore(cat(1,imds11.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest14.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
imdsTest15 = imageDatastore(cat(1,imds12.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest15.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
%% training for test data
accuracy=zeros(11,15);
for i3=11:15
stname3=sprintf('imdsTest%d',i3);
eval(['imdsTest' ,'=', stname3,';'])
%imdsTest.ReadFcn = @(filename)resize(filename);
i4=15-i+1;
stname4=sprintf('imds0%d',i4);
eval(['imdsValidation' ,'=', stname4,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
%%train network(中略)
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
Kenta
2019년 3월 14일
i番目のループのなかで、トレーニングデータ(仮)をトレーニングデータとバリデーションデータに分けたらいいと思います。そして、バリデーションデータをテストデータ(ただ名前を変えるだけ)としてテストしたらいいです。
ある程度までロスが下がり切ったりしたら計算時間が冗長になるし、訓練データに過適合するのを防げます。ただ、たくさんの枚数をこなしたときに必ずしももこの操作が必要かどうかは不明です。1クラス100枚くらいで交差検証なしでやってみてはどうでしょうか。CPUで計算してもそこまで計算時間はかからないと思います。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsTest' ,'=', stname2,';'])
imdsTest.ReadFcn = @(filename)resize(filename);
%% training for test data
%imdstrainで訓練
%imdsvalidationをoptionsのなかのvalidationに指定
%imdstestでテスト
ssk
2019년 3월 14일
편집: ssk
2019년 3월 14일
ありがとうございます!コードを試したところ無事に動きました。本コードにおけるクロスバリデーションのニュアンスの確認をしたいのですが、はじめに全ての画像をtrainingとして均等に10分割し、さらに10分割した画像をそれぞれtraining:validation = 8:2で分ける。このとき、testはvalidationと同視できるので、training:test = 8:2である。(つまり、本データの8割をtraining、2割をtest(validation)として使う。その後、組み合わせをかえてそれぞれの画像のaccuracyを調べて平均を取る。上記の認識でよろしいでしょうか?
以前あった例ですと、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2,0.1); で合計が100%ですが、今回の場合は、[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.8,0.2,0.2);で合計120%のような気もするのですが、例えば[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.6,0.2,0.2);のような形で修正する必要はないのでしょうか?
また、なぜテストデータとバリデーションデータを同視できるか理由をご存知でしたらご教示いただけますと幸いです。
ssk
2019년 3월 15일
ニュアンスをご教示いただきありがとうございました!
覚えが悪く大変申し訳なのでもう一度確認しますが、(1,000枚の画像がある場合、)まず900(training), 100(test)に分けて、その後で900を更に800(training)、100(varidation)に分けるということですね。以上から、800(training)、100(test)、100(validation)になるということでしょうか。
上記コードですと、for loop内では以下のようになっておりますので
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
imstrain1の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds10
これを順に10回続けていって・・・
imstrain10の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds01
以上の平均を求めるという認識でよろしいでしょうか?
추가 답변 (0개)
참고 항목
카테고리
Help Center 및 File Exchange에서 Deep Learning Toolbox에 대해 자세히 알아보기
Community Treasure Hunt
Find the treasures in MATLAB Central and discover how the community can help you!
Start Hunting!오류 발생
페이지가 변경되었기 때문에 동작을 완료할 수 없습니다. 업데이트된 상태를 보려면 페이지를 다시 불러오십시오.
웹사이트 선택
번역된 콘텐츠를 보고 지역별 이벤트와 혜택을 살펴보려면 웹사이트를 선택하십시오. 현재 계신 지역에 따라 다음 웹사이트를 권장합니다:
또한 다음 목록에서 웹사이트를 선택하실 수도 있습니다.
사이트 성능 최적화 방법
최고의 사이트 성능을 위해 중국 사이트(중국어 또는 영어)를 선택하십시오. 현재 계신 지역에서는 다른 국가의 MathWorks 사이트 방문이 최적화되지 않았습니다.
미주
- América Latina (Español)
- Canada (English)
- United States (English)
유럽
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom(English)
아시아 태평양
- Australia (English)
- India (English)
- New Zealand (English)
- 中国
- 日本Japanese (日本語)
- 한국Korean (한국어)