
How can I define a custom loss function using trainnet?

Matthew Murray, 29 Mar 2024 (edited by Matt J, 29 Mar 2024)
Hello,
I am trying to define a custom loss function using trainnet. The documentation says:
If the trainnet function does not provide the loss function that you need for your task, then you can specify a custom loss function to the trainnet as a function handle. The function must have the syntax loss = f(Y,T), where Y and T are the predictions and targets, respectively.
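For illustration, a custom loss with the documented `loss = f(Y,T)` syntax could be written as an anonymous function. This is a minimal sketch: the mean-absolute-error loss, the name `maeLoss`, and the toy `'SSCB'` data (5x4 images, 1 channel, 2 observations) are assumptions chosen only to show that `Y` and `T` arrive as same-sized `dlarray` objects and that the handle must return one scalar value.

```matlab
% Hypothetical custom loss (not built in): mean absolute error,
% written with the documented syntax loss = f(Y,T).
maeLoss = @(Y,T) mean(abs(Y - T),"all");

% Toy stand-ins for one mini-batch of predictions (Y) and targets (T):
% 5x4 images, 1 channel, 2 observations, formatted as 'SSCB'.
Y = dlarray(rand(5,4,1,2),'SSCB');
T = dlarray(rand(5,4,1,2),'SSCB');

% Averaging over "all" dimensions collapses everything to a single
% value, which is what trainnet needs in order to differentiate it.
loss = maeLoss(Y,T);
```

With a real network, the same handle would be passed in place of `"mse"`, e.g. `trainnet(dsTrain,layers,maeLoss,options)`.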
However, I am not sure how the predictions and targets are defined here. I am currently using trainnet as follows:
trainedNet = trainnet(dsTrain,layers,"mse",options);
dsTrain is a datastore containing the input and target images for the regression problem. I would like to change the loss to a custom function based on ssim, something like the following, although I know it isn't quite right:
trainedNet = trainnet(dsTrain,layers,@(Y,targets) 1-ssim(Y,targets),options);
I get the following error message:
Error using trainnet
Value to differentiate is non-scalar. It must be a traced real dlarray scalar.
Thanks!

Answers (1)

Matt J, 29 Mar 2024 (edited 29 Mar 2024)
If you have multichannel output, the loss function will return one SSIM value per channel, e.g.,
loss = @(Y,targets) 1-ssim(Y,targets);
[Y,T]=deal(dlarray(rand(5,4,8),'SSC'));
L=loss(Y,T);
whos L
  Name      Size     Bytes  Class      Attributes

  L         1x1x8       70  dlarray
trainnet requires the loss to be a scalar, so you need to decide how you want these per-channel values reduced to a single value.
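One possible reduction is to average the per-channel SSIM values. The sketch below assumes plain averaging is acceptable (a weighted combination over channels would also give a scalar) and reuses the same toy data. During training the predictions presumably carry a batch dimension as well, in which case `mean(...,"all")` would average over observations too. `dsTrain`, `layers`, and `options` are the variables from the question.

```matlab
% Average the per-channel SSIM values so the loss is a single scalar.
% Averaging is an assumed choice; other reductions are possible.
loss = @(Y,T) 1 - mean(ssim(Y,T),"all");

% Same toy data as above: 5x4 spatial size, 8 channels.
[Y,T] = deal(dlarray(rand(5,4,8),'SSC'));
L = loss(Y,T);   % L is now a single value instead of 1x1x8

% Pass the handle to trainnet in place of "mse":
% trainedNet = trainnet(dsTrain,layers,loss,options);
```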
