matlabのディープラーニングでは、なぜテストデータを使わずにバリデーションデータを使うのか
이 질문을 팔로우합니다.
- 팔로우하는 게시물 피드에서 업데이트를 확인할 수 있습니다.
- 정보 수신 기본 설정에 따라 이메일을 받을 수 있습니다.
오류 발생
페이지가 변경되었기 때문에 동작을 완료할 수 없습니다. 업데이트된 상태를 보려면 페이지를 다시 불러오십시오.
이전 댓글 표시
0 개 추천
プログラミング初心者です。
下記リンクにつきまして、
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);
という一文がありますが、なぜ、テストデータを使わずにバリデーションデータを使うのでしょうか。
imdsValidationではなく、imdsTestだと納得できるのですが不思議です。
もしバリデーションデータを使うのであれば、テストデータは使わなくてもいいかご教示頂けますと幸いです。
채택된 답변
Kenta
2019년 3월 12일
3 개 추천
単に、ここではバリデーションデータをテストデータと読み替えて問題ないと思います。また、以下のように、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2);
などとして、画像を訓練、バリデーション、テストデータに分けると良いかもしれません。
リンクの学習曲線のところでは、バリデーションデータを使います。
そして、最後のところで
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
とすると、テストデータで正答率を計算できます。ここで、optionsのところに
'ValidationPatience', 3
を追加すれば学習の早期終了ができます。「'ValidationPatience' の値は、ネットワークの学習が停止するまでに、検証セットでの損失が前の最小損失以上になることが許容される回数です。」
とあります。学習がある程度のところで限界が来たらそこで学習がストップするので学習時間を短縮できたり、過学習が抑えられる可能性があります。
댓글 수: 11
ssk
2019년 3월 12일
itakuraさま、いつもご回答いただきまして誠にありがとうございます。本例につきまして、バリデーションデータとテストデータを置き換えることができる旨、ご教示いただきありがとうございます。imdstestも使った例でも試してみます。また、'ValidationPatience', も利用してみたいと思います。
あわせて、下記リンクにつきましてitakura様宛に追加でご質問させていただきましたのでご覧頂けますと幸いです。
https://jp.mathworks.com/matlabcentral/answers/447586-cross-validation
ssk
2019년 3월 13일
追加でのご質問失礼いたします。
[imdsTrain,imdsTest,imdsValidation] = splitEachLabel(imds,0.7,0.2);と仮にテストデータとバリデーションデータを分けた場合、クロスバリデーションは以前ご教示いただいた方法と同じものを使っても差し支えないでしょうか。それとも、imdstest向けのコードの変更が必要でしょうか。
Kenta
2019년 3월 13일
3つに分けたかったら少し改変する必要があります。
具体的には、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわける必要があります。
ただ、今回のような結果だといままでどおりの交差検証のコードでひとまずやってみたらよいと思います。枚数を増やしてうまくいかなかったらまた対策を考える方向がよいかと思います。
ssk
2019년 3월 13일
ご教示いただきありがとうございます。まずはもともとの交差検証のコードで進めたいと思います。念のため、9つ分のイメージデータストアを合わせたものをもう一度、imdsTrain, imdsValidationにわけるコードにつき、以前ご教示いただいたコードを基に自身で作成してみます。
ssk
2019년 3월 13일
トレーニング、テスト、バリデーションの3つに分けたコードを試しに作成してみたのですが、以下のコードでご趣旨を反映できておりますでしょうか。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsValidation' ,'=', stname2,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
[imds11,imds12,imds13,imds14,imds15]...
= splitEachLabel(imds,0.2,0.2,0.2,0.2,'randomize');
imdsTest11 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds14.Files));
imdsTest11.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds14.Labels);
imdsTest12 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds13.Files,imds15.Files));
imdsTest12.Labels = cat(1,imds11.Labels,imds12.Labels,imds13.Labels,imds15.Labels);
imdsTest13 = imageDatastore(cat(1,imds11.Files,imds12.Files,imds14.Files,imds15.Files));
imdsTest13.Labels = cat(1,imds11.Labels,imds12.Labels,imds14.Labels,imds15.Labels);
imdsTest14 = imageDatastore(cat(1,imds11.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest14.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
imdsTest15 = imageDatastore(cat(1,imds12.Files,imds13.Files,imds14.Files,imds15.Files));
imdsTest15.Labels = cat(1,imds11.Labels,imds13.Labels,imds14.Labels,imds15.Labels);
%% training for test data
accuracy=zeros(11,15);
for i3=11:15
stname3=sprintf('imdsTest%d',i3);
eval(['imdsTest' ,'=', stname3,';'])
%imdsTest.ReadFcn = @(filename)resize(filename);
i4=15-i+1;
stname4=sprintf('imds0%d',i4);
eval(['imdsValidation' ,'=', stname4,';'])
imdsValidation.ReadFcn = @(filename)resize(filename);
%%train network(中略)
[YPred,probs] = classify(net,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
[YPred,probs] = classify(net,imdsTest);
accuracy = mean(YPred == imdsTest.Labels)
Kenta
2019년 3월 14일
i番目のループのなかで、トレーニングデータ(仮)をトレーニングデータとバリデーションデータに分けたらいいと思います。そして、バリデーションデータをテストデータ(ただ名前を変えるだけ)としてテストしたらいいです。
ある程度までロスが下がり切ったりしたら計算時間が冗長になるし、訓練データに過適合するのを防げます。ただ、たくさんの枚数をこなしたときに必ずしももこの操作が必要かどうかは不明です。1クラス100枚くらいで交差検証なしでやってみてはどうでしょうか。CPUで計算してもそこまで計算時間はかからないと思います。
%% cross validation
[imds01,imds02,imds03,imds04,imds05,imds06,imds07,imds08,imds09,imds010]...
= splitEachLabel(imds,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,'randomize');
imdsTrain1 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files));
imdsTrain1.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels);
imdsTrain2 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds010.Files));
imdsTrain2.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds010.Labels);
imdsTrain3 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds09.Files,imds010.Files));
imdsTrain3.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds09.Labels,imds010.Labels);
imdsTrain4 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain4.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain5 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain5.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain6 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds04.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain6.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds04.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain7 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds03.Files,imds06.Files,imds05.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain7.Labels = cat(1,imds01.Labels,imds02.Labels,imds03.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain8 = imageDatastore(cat(1,imds01.Files,imds02.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain8.Labels = cat(1,imds01.Labels,imds02.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain9 = imageDatastore(cat(1,imds01.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain9.Labels = cat(1,imds01.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
imdsTrain10 = imageDatastore(cat(1,imds02.Files,imds03.Files,imds04.Files,imds05.Files,imds06.Files,imds07.Files,imds08.Files,imds09.Files,imds010.Files));
imdsTrain10.Labels = cat(1,imds02.Labels,imds03.Labels,imds04.Labels,imds05.Labels,imds06.Labels,imds07.Labels,imds08.Labels,imds09.Labels,imds010.Labels);
%% training
accuracy=zeros(1,10);
for i=1:10
stname1=sprintf('imdsTrain%d',i);
eval(['trainimds' ,'=', stname1,';'])
%trainimds.ReadFcn = @(filename)resize(filename);
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
i2=10-i+1;
stname2=sprintf('imds0%d',i2);
eval(['imdsTest' ,'=', stname2,';'])
imdsTest.ReadFcn = @(filename)resize(filename);
%% training for test data
%imdstrainで訓練
%imdsvalidationをoptionsのなかのvalidationに指定
%imdstestでテスト
ありがとうございます!コードを試したところ無事に動きました。本コードにおけるクロスバリデーションのニュアンスの確認をしたいのですが、はじめに全ての画像をtrainingとして均等に10分割し、さらに10分割した画像をそれぞれtraining:validation = 8:2で分ける。このとき、testはvalidationと同視できるので、training:test = 8:2である。(つまり、本データの8割をtraining、2割をtest(validation)として使う。その後、組み合わせをかえてそれぞれの画像のaccuracyを調べて平均を取る。上記の認識でよろしいでしょうか?
以前あった例ですと、
[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.7,0.2,0.1); で合計が100%ですが、今回の場合は、[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.8,0.2,0.2);で合計120%のような気もするのですが、例えば[imdsTrain,imdsValidation, imdsTest] = splitEachLabel(imds,0.6,0.2,0.2);のような形で修正する必要はないのでしょうか?
また、なぜテストデータとバリデーションデータを同視できるか理由をご存知でしたらご教示いただけますと幸いです。
Kenta
2019년 3월 15일
今回の場合は、1000枚あったとすると、900, 100に分けて、その900をさらに800と100に分けたイメージです。
もちろんおっしゃるように、1000枚を一気に8:1:1に分けても等価です。そのサイクルを10回して、それを平均すればいいです。
同視できるというのは、バリデーションデータと名付けられたデータでテストをしているので、その場合に限り、バリデーションデータをテストデータと読み替えて問題ないのでは?ということです。
本来、バリデーションデータとテストデータは異なったニュアンスを持っているものと思います。
ssk
2019년 3월 15일
ニュアンスをご教示いただきありがとうございました!
覚えが悪く大変申し訳なのでもう一度確認しますが、(1,000枚の画像がある場合、)まず900(training), 100(test)に分けて、その後で900を更に800(training)、100(varidation)に分けるということですね。以上から、800(training)、100(test)、100(validation)になるということでしょうか。
上記コードですと、for loop内では以下のようになっておりますので
[imdstrain,imdsvalidation]=splitEachLabel(trainimds,0.8);
imstrain1の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds10
これを順に10回続けていって・・・
imstrain10の場合、
training:900(枚)*0.8= 720(枚)
validation:900(枚)*0.2= 180(枚)
test: 100(枚)・・・imds01
以上の平均を求めるという認識でよろしいでしょうか?
Kenta
2019년 3월 17일
はい、それで正しいと思います。
ssk
2019년 3월 17일
ありがとうございます!
추가 답변 (0개)
카테고리
도움말 센터 및 File Exchange에서 Deep Learning Toolbox에 대해 자세히 알아보기
참고 항목
웹사이트 선택
번역된 콘텐츠를 보고 지역별 이벤트와 혜택을 살펴보려면 웹사이트를 선택하십시오. 현재 계신 지역에 따라 다음 웹사이트를 권장합니다:
또한 다음 목록에서 웹사이트를 선택하실 수도 있습니다.
사이트 성능 최적화 방법
최고의 사이트 성능을 위해 중국 사이트(중국어 또는 영어)를 선택하십시오. 현재 계신 지역에서는 다른 국가의 MathWorks 사이트 방문이 최적화되지 않았습니다.
미주
- América Latina (Español)
- Canada (English)
- United States (English)
유럽
- Belgium (English)
- Denmark (English)
- Deutschland (Deutsch)
- España (Español)
- Finland (English)
- France (Français)
- Ireland (English)
- Italia (Italiano)
- Luxembourg (English)
- Netherlands (English)
- Norway (English)
- Österreich (Deutsch)
- Portugal (English)
- Sweden (English)
- Switzerland
- United Kingdom (English)