필터 지우기
필터 지우기

ニューラルネットワー​クの学習をdoubl​e型で行うことはでき​ますか?

조회 수: 1 (최근 30일)
Fumiya Watanabe
Fumiya Watanabe 2018년 6월 26일
댓글: Fumiya Watanabe 2018년 7월 5일
ニューラルネットワークの学習をdouble型で行うことはできますか?
現在、ある実数値ベクトルを入力とする回帰問題をNeural Network Toolboxを用いて実現しようとしています。 このベクトル入力を画像入力として扱うことで実現を考えています。しかしながら、trainNetworkを実行するとsingle型として扱われてしまう問題が生じており、解決法がわからず困っております。
例えば、次の自作の回帰層を考えます。
classdef testLayer < nnet.layer.RegressionLayer
methods
function layer = testLayer()
end
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(0);
end
function dLdX = backwardLoss(layer, Y, T)
dLdX = gpuArray(zeros(size(Y)));
end
end
end
この自作回帰層を用いて、次のように学習を実行します。
%%学習データ
x_in = rand(10, 1, 1, 6);
y_tr = rand(6, 5);
%%層構造とオプションの定義
layers = [
imageInputLayer([10 1 1], 'Normalization', 'none', 'Name', 'Input')
fullyConnectedLayer(2, 'Name', 'Layer1')
reluLayer('Name', 'ReLU1')
fullyConnectedLayer(5, 'Name', 'Output')
testLayer
];
layers(end).Name = 'Regression';
options = trainingOptions(...
'sgdm',...
'InitialLearnRate', 0.001, ...
'MiniBatchSize', 3, ...
'MaxEpochs', 1);
%%学習開始
net = trainNetwork(x_in, y_tr, layers, options);
すると、次のエラーが発生します。
エラー: trainNetwork (line 154)
Incorrect type of dLdX for 'backwardLoss' in the output layer. Expected gpuArray of underlying type 'single', but instead has
underlying type 'double'.
上記の自作回帰層で、gpuArrayの内部をsingleにキャストすることで実行することが可能となるのですが、実際に使っている自作回帰層ではdouble型でないと計算できない関数を利用しているため、
function loss = forwardLoss(layer, Y, T)
loss = gpuArray(single(myfun(double(Y), double(T))));
end
のようなキャストをしていく必要が生じてしまいます。これを避けるために学習をdouble型で実行したいのですが、解決法はありますでしょうか。

채택된 답변

Naoya
Naoya 2018년 6월 29일
Neural Network Toolbox で提供される 畳み込みニューラルネットワークですが、trainNetwork 側で与えるデータ型は single, double 両方を受け付けます。
しかしながら、基本的にGPU上では単精度演算として扱われますので、GPU へ渡すゲートウェイとなるデータ型は single型となってしまいます。
  댓글 수: 3
Naoya
Naoya 2018년 7월 3일
ご連絡ありがとうございます。 cpuモードの場合でも backwardLoss 関数のゲートウェイは single型にする必要があります。
Fumiya Watanabe
Fumiya Watanabe 2018년 7월 5일
ご回答ありがとうございます。
入力としてはdouble型を受け付けるが、計算内部はGPU・CPUどちらの場合でもsingle型で実行される形になっており、自作の層を扱う場合はsingle型でほかの層とのやり取りが必要であると理解いたしました。 ありがとうございました。

댓글을 달려면 로그인하십시오.

추가 답변 (0개)

카테고리

Help CenterFile Exchange에서 深層学習データの前処理에 대해 자세히 알아보기

제품


릴리스

R2018a

Community Treasure Hunt

Find the treasures in MATLAB Central and discover how the community can help you!

Start Hunting!