
Design of a neural network with custom loss

Ramon Suarez, 19 February 2024
Answered: Ben, 9 April 2024
I would like to design a simple feedforward neural network with 1 input and 2 outputs.
The input is a parameter λ within a predefined range (for instance between -5 and 5), and the output is a two-component vector $x = (x_1, x_2)^\top$.
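The loss function I would like to implement is

$$\mathcal{L}(x;\lambda) = \lVert A(\lambda)\,x - f(\lambda) \rVert^2,$$

where

$$A(\lambda) = \begin{pmatrix} \lambda^2 + 1 & \lambda \\ \lambda & 1 \end{pmatrix}, \qquad f(\lambda) = \begin{pmatrix} \lambda \\ 1 - \lambda \end{pmatrix}.$$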
As can be seen from the loss definition, this network does not need any target outputs. The objective is a network that predicts, for each λ, the output vector that minimizes the loss.
Any help would be greatly appreciated!
------------
I have read the following threads about customizing a loss function:
I also read the response at https://uk.mathworks.com/matlabcentral/answers/1763785-how-to-customize-performance-functions-neural-network#answer_1043515, which describes a newer way of doing this with custom training loops, but I have not managed to implement it for my problem.

Answers (1)

Ben, 9 April 2024
The term $\lVert A(\lambda)x - f(\lambda) \rVert^2$ is minimized (it reaches zero) when $A(\lambda)x = f(\lambda)$, which is a linear problem as you've stated, so you can use classic linear-algebra methods to solve for x:
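For this particular A(λ) and f(λ) (as defined in the code that follows), the solve can even be done by hand, which gives an exact reference to compare any numerical or network solution against. The determinant is $(\lambda^2+1)\cdot 1 - \lambda\cdot\lambda = 1$ for every λ, so A(λ) is always invertible and

$$
A(\lambda)^{-1} = \begin{pmatrix} 1 & -\lambda \\ -\lambda & \lambda^2 + 1 \end{pmatrix},
\qquad
x(\lambda) = A(\lambda)^{-1} f(\lambda) = \begin{pmatrix} \lambda^2 \\ 1 - \lambda - \lambda^3 \end{pmatrix}.
$$

Substituting back, $A(\lambda)x(\lambda) = \bigl(\lambda^4 + \lambda^2 + \lambda - \lambda^2 - \lambda^4,\; \lambda^3 + 1 - \lambda - \lambda^3\bigr)^\top = (\lambda,\, 1-\lambda)^\top = f(\lambda)$, so the minimum value of the loss is 0.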
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) A(lambda)\f(lambda);
lambda = 0.123; % random choice
xlambda = x(lambda);
A(lambda)*xlambda - f(lambda) % returns [0;0] up to round-off, i.e. xlambda solves the system
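To get the solution for a whole range of λ at once (for example, the same 10000 points used for training below), the 2-by-2 systems can be stacked along the third dimension and solved page by page. This is a minimal sketch, assuming R2022a or later for pagemldivide; the variable names lambdaGrid, Ap, fp, X and res are illustrative only:

```matlab
% Solve A(lambda)*x = f(lambda) for many lambda values in one call.
% Assumes R2022a+ (pagemldivide). Each 2x2 system is one "page" along dim 3.
lambdaGrid = linspace(-5,5,10000);          % 1-by-N grid of lambda values
lam = reshape(lambdaGrid,1,1,[]);           % 1-by-1-by-N
Ap = [lam.^2 + 1, lam; lam, ones(size(lam))];   % 2-by-2-by-N stack of A(lambda)
fp = [lam; 1 - lam];                            % 2-by-1-by-N stack of f(lambda)

X = pagemldivide(Ap,fp);                    % 2-by-1-by-N, one solve per page
X = reshape(X,2,[]);                        % 2-by-N: column k is x(lambdaGrid(k))

% Residual A*x - f for every page; expected to be at round-off level
res = pagemtimes(Ap,reshape(X,2,1,[])) - fp;
fprintf("Max residual: %.3g\n", max(abs(res),[],"all"));
```

The same page-wise layout (2-by-2-by-N matrices multiplied with pagemtimes) is what modelLoss below uses to evaluate A(λ)x for every sample in the batch.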
If you still want to model x(λ) with a neural network N, you have to use a custom training loop: your loss is unsupervised, whereas trainNetwork and trainnet are designed for supervised training. You can write a custom training loop as follows; note, however, that I was unable to get it to train well, and it is certainly not as fast as computing the solution directly as above.
net = [featureInputLayer(1)
    fullyConnectedLayer(100)
    reluLayer
    fullyConnectedLayer(2)];
net = dlnetwork(net);
lambda = dlarray(linspace(-5,5,10000),"CB");
maxIters = 10000;
vel = [];
lr = 1e-4;
lossFcn = dlaccelerate(@modelLoss);
for iter = 1:maxIters
    [loss,grad] = dlfeval(lossFcn,net,lambda);
    fprintf("Iter: %d, Loss: %.4f\n",iter,extractdata(loss));
    [net,vel] = sgdmupdate(net,grad,vel,lr);
end
function [loss,grad] = modelLoss(net,lambda)
    % Network output is 2-by-BatchSize; reshape it to 2x1xBatchSize
    x = forward(net,lambda);
    x = stripdims(x);
    x = permute(x,[1,3,2]);
    % Reshape lambda to 1x1xBatchSize so that A has one 2x2 page per sample
    lambda = stripdims(lambda);
    lambda = permute(lambda,[1,3,2]);
    A = [lambda.^2 + 1, lambda; lambda, ones(1,1,size(lambda,3),like=lambda)];
    Ax = pagemtimes(A,x);
    f = [lambda;1-lambda];
    % Squared error between A(lambda)*x and f(lambda); no data targets are needed
    loss = l2loss(Ax,f,DataFormat="CUB");
    grad = dlgradient(loss,net.Learnables);
end
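Written out, the loop above minimizes the mean squared residual over the B = 10000 sampled values λ_b (up to the normalization constant that l2loss applies), where $N_\theta$ denotes the network with learnable parameters θ:

$$
\min_{\theta}\; \frac{1}{B}\sum_{b=1}^{B} \bigl\lVert A(\lambda_b)\,N_\theta(\lambda_b) - f(\lambda_b) \bigr\rVert^2 .
$$

Because A(λ) is invertible for every λ (its determinant is $(\lambda^2+1)\cdot 1 - \lambda\cdot\lambda = 1$), each term is zero exactly when $N_\theta(\lambda_b) = A(\lambda_b)^{-1} f(\lambda_b)$, so this unsupervised loss implicitly asks the network to reproduce the linear-solve solution at every sample.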
If you want to use a linear solve as above but need it to be compatible with automatic differentiation, you can use pinv, which supports dlarray as of R2024a:
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) pinv(A(lambda))*f(lambda);
This supports automatic differentiation with dlarray, so you can compute derivatives such as $dx/d\lambda$.
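As a reference for what the autodiff computation should return, differentiating $A(\lambda)x(\lambda) = f(\lambda)$ with respect to λ (implicit differentiation) gives

$$
\frac{dx}{d\lambda} = A(\lambda)^{-1}\left(\frac{df}{d\lambda} - \frac{dA}{d\lambda}\,x(\lambda)\right),
\qquad
\frac{dA}{d\lambda} = \begin{pmatrix} 2\lambda & 1 \\ 1 & 0 \end{pmatrix},
\quad
\frac{df}{d\lambda} = \begin{pmatrix} 1 \\ -1 \end{pmatrix}.
$$

For this A and f the expression simplifies to $dx/d\lambda = (2\lambda,\; -1 - 3\lambda^2)^\top$, so the first component at λ = 0.123 is $2 \cdot 0.123 = 0.246$.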
% x supports auto-diff, e.g. we can compute dx/dlambda.
% In a script, local functions must come after all other code.
lambda0 = dlarray(0.123);
dx1dlambda = dlfeval(@dxidlambda, lambda0, 1)

function dxi = dxidlambda(lambda,i)
    A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
    f = @(lambda) [lambda; 1-lambda];
    x = @(lambda) pinv(A(lambda))*f(lambda);
    xlambda = x(lambda);
    xlambdai = xlambda(i);
    dxi = dlgradient(xlambdai,lambda); % derivative of the i-th component w.r.t. lambda
end
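If you do not have R2024a, or want to cross-check the dlgradient result, a central finite difference on the plain numeric backslash solve should give the same derivative. This is a sketch: the step size h = 1e-6 is an assumption, and dxdlambdaExact uses $dx/d\lambda = (2\lambda,\; -1 - 3\lambda^2)^\top$, obtained by differentiating the hand-solved $x(\lambda) = (\lambda^2,\; 1 - \lambda - \lambda^3)^\top$:

```matlab
% Finite-difference check of dx/dlambda without dlarray (base MATLAB only).
A = @(lambda) [lambda^2 + 1, lambda; lambda, 1];
f = @(lambda) [lambda; 1-lambda];
x = @(lambda) A(lambda)\f(lambda);

lambda0 = 0.123;
h = 1e-6;                                                  % assumed step size
dxdlambdaFD = (x(lambda0 + h) - x(lambda0 - h)) / (2*h);   % 2-by-1 central difference
dxdlambdaExact = [2*lambda0; -1 - 3*lambda0^2];            % derivative of the closed form

disp([dxdlambdaFD, dxdlambdaExact])                        % columns should agree closely
```

The first row of both columns corresponds to the dx1dlambda value computed with dlfeval above.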
