Custom deep learning network - gradient function using dlfeval
I want to create a custom deep learning training function, the output of which is an array Y. I have two inputs, the arrays X1 and X2. I want to find the gradient of Y with respect to X1 and X2.
This is my network:
layers1 = [
sequenceInputLayer(sizeInput,"Name","XTrain1")
fullyConnectedLayer(numHiddenDimension,"Name","fc_1")
softplusLayer('Name','s_1')];
layers2 = [
sequenceInputLayer(sizeInput,"Name","XTrain2")
fullyConnectedLayer(numHiddenDimension,"Name","fc_2")
softplusLayer('Name','s_2')];
lgraph = layerGraph(layers1);
lgraph = addLayers(lgraph,layers2); % connect layers -> 2 in, 1 out
add = additionLayer(2,'Name','add');
lgraph = addLayers(lgraph,add);
lgraph = connectLayers(lgraph,'s_1','add/in1');
lgraph = connectLayers(lgraph,'s_2','add/in2');
fc = fullyConnectedLayer(sizeInput,"Name","fc_3");
lgraph = addLayers(lgraph,fc);
lgraph = connectLayers(lgraph,'add','fc_3');
dlnet = dlnetwork(lgraph);
The output of this network should become my output. Then, every iteration, I do:
dlX1 = dlarray(X1,'CTB');
dlX2 = dlarray(X2,'CTB');% to differentiate: dlarray/dlgradient
for i = 1:sizeInput
[gradx1(i), gradx2(i), dlY] = dlfeval(@modelGradientsX,dlnet,dlX1(i),dlX2(i)); % here is where I get my error
end
and I call my function modelGradientsX, which is supposed to get the derivative of my output with respect to my inputs:
function [gradx1, gradx2, dlY] = modelGradientsX(dlnet,dlX1,dlX2)
dlY = forward(dlnet,dlX1,dlX2);
[gradx1, gradx2] = dlgradient(dlY,dlX1,dlX2);
end
And the error I get is: "Input data must be formatted dlarray objects". I have seen similar approaches in other examples (like this one: https://www.mathworks.com/matlabcentral/fileexchange/74760-image-classification-using-cnn-with-multi-input-cnn), so I don't understand: why is the data I pass in not the correct type?
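For comparison, here is a minimal sketch of a variant that passes the whole formatted arrays to dlfeval instead of indexed elements such as dlX1(i). It is an illustration under assumptions, not a confirmed fix for the error above: the sizes, the random data, the 'CBT' layout, and the helper name modelGradientsSum are all made up for the example. Because dlgradient differentiates a scalar, the sketch sums the network output first, so the result is the gradient of that sum rather than a full Jacobian of Y.

```matlab
% Hypothetical sizes, chosen only for this illustration.
sizeInput = 3;
numHiddenDimension = 8;
numTimeSteps = 5;

% Same two-branch network as in the question.
layers1 = [
    sequenceInputLayer(sizeInput,"Name","XTrain1")
    fullyConnectedLayer(numHiddenDimension,"Name","fc_1")
    softplusLayer('Name','s_1')];
layers2 = [
    sequenceInputLayer(sizeInput,"Name","XTrain2")
    fullyConnectedLayer(numHiddenDimension,"Name","fc_2")
    softplusLayer('Name','s_2')];
lgraph = layerGraph(layers1);
lgraph = addLayers(lgraph,layers2);
lgraph = addLayers(lgraph,additionLayer(2,'Name','add'));
lgraph = addLayers(lgraph,fullyConnectedLayer(sizeInput,"Name","fc_3"));
lgraph = connectLayers(lgraph,'s_1','add/in1');
lgraph = connectLayers(lgraph,'s_2','add/in2');
lgraph = connectLayers(lgraph,'add','fc_3');
dlnet = dlnetwork(lgraph);

% Random placeholder data laid out as channel x batch x time (assumption).
dlX1 = dlarray(rand(sizeInput,1,numTimeSteps,'single'),'CBT');
dlX2 = dlarray(rand(sizeInput,1,numTimeSteps,'single'),'CBT');

% Pass the complete formatted dlarrays; no indexing before the call.
[gradx1,gradx2,dlY] = dlfeval(@modelGradientsSum,dlnet,dlX1,dlX2);
disp(size(gradx1))   % gradients have the same size as the inputs

function [gradx1,gradx2,dlY] = modelGradientsSum(dlnet,dlX1,dlX2)
    % Inputs are passed in the order given by dlnet.InputNames.
    dlY = forward(dlnet,dlX1,dlX2);
    % dlgradient needs a scalar, so reduce the output to its sum.
    total = sum(stripdims(dlY),'all');
    [gradx1,gradx2] = dlgradient(total,dlX1,dlX2);
end
```

If the derivative of each individual output element is needed, one option is to call dlgradient once per element of dlY inside the same function with 'RetainData' set to true, at the cost of one backward pass per element.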