CNNへの交差検定(​Cross-Vali​dation)の導入​の仕方

조회 수: 9 (최근 30일)
ssk
ssk 2019년 2월 7일
편집: ssk 2019년 2월 11일
プログラミング初心者です。
現在、チュートリアルのコードを微修正して動かしており、以下のコードに交差検定の追加を検討しております。
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.5,'randomize');
help crossvarで検索すると、以下のようにでてきました。
TESTVAL = FUN(XTRAIN,XTEST)
こちらを、TESTVAL = FUN(imdsTrain, imdsValidation)とすると交差検定を導入できるという認識で
コンパイルしたのですが動きませんでした。
Undefined function or variable 'FUN'.
というエラーが出てしまいます。
交差検定の正しいやり方につきましてご教示いただけますと幸いです。
どうぞよろしくお願いいたします。

채택된 답변

Tohru Kikawada
Tohru Kikawada 2019년 2월 9일
crossvalのドキュメントに記載のある下記は指定する関数の戻り値と引数の一例です。
TESTVAL = FUN(XTRAIN,XTEST)
ドキュメントにあるいくつかの例題は試してみましたでしょうか。crossvalは様々な機械学習のアルゴリズムで使えるように汎用性のある関数ハンドルの受け渡しで実行されます。CNNで交差検定を実行する場合も下記のようにCNNのクラス分類結果を返すような関数を関数ハンドルとして渡してあげる必要があります。
%% データセットの読み込み
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
'nndatasets','DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders',true,'LabelSource','foldernames');
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
y = imds.Labels;
%% 交差検定にCNNの予測ラベル関数のポインタを渡す
mcr = crossval('mcr',X,y,'Predfun',@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
%% CNNを学習し、予測ラベルを出力する関数
function ypred = myCNNPredict(xtrain,ytrain,xtest,imds)
% 結果が一意になるように乱数シードをデフォルト値に設定
rng('default');
% ダミーの変数ベクトルを受けてimageDatastoreを学習用とテスト用に分割
imdsTrain = imageDatastore(imds.Files(xtrain));
imdsTrain.Labels = ytrain;
imdsValidation = imageDatastore(imds.Files(xtest));
% レイヤーの設定
layers = [
imageInputLayer([28 28 1])
convolution2dLayer(3,8,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
options = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',4, ...
'Shuffle','every-epoch', ...
'Verbose',false);
net = trainNetwork(imdsTrain,layers,options);
ypred = classify(net,imdsValidation);
end
  댓글 수: 4
ssk
ssk 2019년 2월 11일
ご回答ありがとうございます。おかげさまでチュートリアルのコードを無事、コンパイルすることができました。ありがとうございます。
DICOMファイルでも交差検定が使えるかどうか試したところ、以下のようなエラーが出てしまいます。
The function '@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)' generated
the following error:
Input folders or files contain non-standard file extensions.
拡張子が違うのが原因かもしれません。
currentdirectory = pwd;
% set categories of subdirectory
categories = {'a', 'b', 'c','d'};
imds = imageDatastore(fullfile(currentdirectory, categories),'IncludeSubfolders',true,'FileExtensions','.dcm','LabelSource', 'foldernames','ReadFcn',@dicomread);
mcr = crossval('mcr',X,y,'Predfun',(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds))
作成したコードは上記の通りですが、DICOMファイルでの交差検定の仕方につきまして、ご教示頂けますと幸いです。
どうぞよろしくお願いいたします。
ssk
ssk 2019년 2월 11일
편집: ssk 2019년 2월 11일
五月雨式のコメント失礼いたします。
頂いた回答につきまして以下の質問がございます。
%% ダミーのトレーニングインデックスを生成
X = (1:imds.numpartitions)';
(1)なぜ、ダミーのトレーニングインデックスを生成しているのか、
(2)なぜ、numpartitions(おそらくnumber of partition)を使っているのか、
(3)(1:imds.numpatition)の意味につきましてもご教示いただけますと幸いです。
@(xtrain,ytrain,xtest)myCNNPredict(xtrain,ytrain,xtest,imds)
また、mcrの意味につきましては、 misclassification rateの略語という意味でお間違えないでしょうか。
どうぞよろしくお願いいたします。

댓글을 달려면 로그인하십시오.

추가 답변 (0개)

카테고리

Help CenterFile Exchange에서 Deep Learning Toolbox에 대해 자세히 알아보기

태그

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!