How to add a distance transform map to the loss function of a classification layer

Raza Ali on 27 Aug 2020
Commented: Raza Ali on 10 Sep 2020
Hi everyone, I am trying to include distance map information in the loss function, which I compute in the classification layer of my CNN.
However, when I calculate the distance map with the "bwdist(Y)" command, MATLAB produces these errors during training:
"Error using 'forwardLoss' in Layer ClassificationLayer. The function threw an error and could not be executed".
"Expected input image to be a 2-D real-valued, non-sparse gpuArray with underlying class uint8, uint16, uint32, int8, int16, int32, logical, single or double".
How can I add a distance transform map to the loss function, or how can I resolve this error?
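The second message suggests that bwdist is being called on a gpuArray: when training runs on a GPU, forwardLoss receives Y and T as gpuArrays, and the distance transform is computed from data derived from them. Under that assumption, one workaround is to gather the target to host memory before the distance transform, and to build the map from the targets T (as the layer code posted in the comments does) rather than from the predictions Y, so the map stays a constant during backpropagation. The minimal sketch below uses a hypothetical helper name, cpuDistanceMap; it is not a toolbox function.

```matlab
% cpuDistanceMap is a hypothetical helper, not a toolbox function.
% gather copies a gpuArray to host memory and returns an in-memory
% CPU array unchanged, so the helper works in both cases.
% bwdist (Image Processing Toolbox) then returns, for every pixel, the
% Euclidean distance to the nearest true pixel of the mask.
cpuDistanceMap = @(target) bwdist(gather(target) > 0.5);

% Usage on a small CPU target with a single foreground pixel at (3,3)
T1 = zeros(5, 'single');
T1(3, 3) = 1;
D = cpuDistanceMap(T1);
disp(D)
```

Inside a layer, the same pattern would be applied to one 2-D slice of the target per observation, for example cpuDistanceMap(T(:,:,1,b)).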
3 Comments
Raza Ali on 10 Sep 2020
The classification layer code:
%%%%%%%%%%%%%%%%%%
classdef CEDLossLayer < nnet.layer.ClassificationLayer
    properties
        % Scalar weight that balances the positive (T = 1) and
        % negative (T = 0) terms of the cross-entropy loss.
        Beta = 0.7;
    end
    methods
        function layer = CEDLossLayer(name)
            % layer = CEDLossLayer(name) creates a weighted cross-entropy
            % classification layer and sets its name.
            if nargin == 1
                layer.Name = name;
            end
            % Set layer description.
            layer.Description = 'cross entropy';
        end
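        function loss = forwardLoss(layer, Y, T)
            % loss = forwardLoss(layer, Y, T) returns the weighted cross-entropy
            % loss between the predictions Y and the training targets T.
            % Y and T are H-by-W-by-C-by-B arrays (gpuArray when training on a GPU).
            N = size(Y, 1) * size(Y, 2) * size(Y, 4);
            % WeightMap.m must be on the path: call
            % addpath('D:\Ditsnave Transform') once before training instead
            % of on every iteration. Gather the target so that bwlabel/bwdist
            % inside WeightMap receive ordinary CPU arrays, and compute one
            % weight map per observation in the mini-batch.
            W = zeros(size(T, 1), size(T, 2), 1, size(T, 4));
            for b = 1:size(T, 4)
                W(:,:,1,b) = WeightMap(gather(T(:,:,1,b)));
            end
            Yp = nnet.internal.cnn.util.boundAwayFromZero(Y);
            Yn = nnet.internal.cnn.util.boundAwayFromZero(1 - Y);
            loss_i = layer.Beta .* W .* T .* log(Yp) + ...
                (1 - layer.Beta) .* (1 - W) .* (1 - T) .* log(Yn);
            % Sum over all pixels, classes and observations to get a scalar.
            loss = -sum(loss_i(:)) / N;
        end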
        function dLdY = backwardLoss(layer, Y, T)
            % dLdY = backwardLoss(layer, Y, T) returns the derivative of the
            % weighted cross-entropy loss with respect to the predictions Y.
            % It must match forwardLoss term by term; W is treated as a
            % constant because it depends only on the targets T.
            N = size(Y, 1) * size(Y, 2) * size(Y, 4);
            W = zeros(size(T, 1), size(T, 2), 1, size(T, 4));
            for b = 1:size(T, 4)
                W(:,:,1,b) = WeightMap(gather(T(:,:,1,b)));
            end
            Yp = nnet.internal.cnn.util.boundAwayFromZero(Y);
            Yn = nnet.internal.cnn.util.boundAwayFromZero(1 - Y);
            dLdY = -(layer.Beta .* W .* T ./ Yp - ...
                (1 - layer.Beta) .* (1 - W) .* (1 - T) ./ Yn) ./ N;
        end
    end
end
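Because backwardLoss is written by hand, it has to be the exact derivative of forwardLoss. With L = -(1/N) Σ [β·W·T·log(Y) + (1-β)·(1-W)·(1-T)·log(1-Y)], the derivative is ∂L/∂Y = -(1/N)·[β·W·T/Y - (1-β)·(1-W)·(1-T)/(1-Y)], with W held constant since it depends only on T. The sketch below checks this numerically with central finite differences. It is standalone: W is random data standing in for the WeightMap output (an assumption, so no toolbox is needed), and the boundAwayFromZero clamping is sidestepped by keeping Y inside (0.05, 0.95).

```matlab
% Finite-difference check of the weighted cross-entropy gradient.
% W is random, a hypothetical stand-in for the WeightMap output.
H = 4; Wd = 5; C = 2; B = 3;               % height, width, classes, batch
beta = 0.7;
Y = 0.05 + 0.9 * rand(H, Wd, C, B);        % predictions inside (0.05, 0.95)
T = double(rand(H, Wd, C, B) > 0.5);       % binary targets
W = 0.5 + rand(H, Wd, 1, B);               % one weight map per observation
N = H * Wd * B;

% Loss and analytic gradient, same formulas as forwardLoss/backwardLoss
lossFcn = @(Y) -sum(reshape(beta .* W .* T .* log(Y) + ...
    (1 - beta) .* (1 - W) .* (1 - T) .* log(1 - Y), [], 1)) / N;
gradFcn = @(Y) -(beta .* W .* T ./ Y - ...
    (1 - beta) .* (1 - W) .* (1 - T) ./ (1 - Y)) ./ N;

G = gradFcn(Y);
h = 1e-6;
Gnum = zeros(size(Y));
for k = 1:numel(Y)
    Yplus = Y;  Yplus(k) = Yplus(k) + h;
    Yminus = Y; Yminus(k) = Yminus(k) - h;
    Gnum(k) = (lossFcn(Yplus) - lossFcn(Yminus)) / (2 * h);
end
maxErr = max(abs(G(:) - Gnum(:)));
fprintf('max |analytic - numeric| = %g\n', maxErr);
```

For the complete layer, the checkLayer function in Deep Learning Toolbox can run comparable consistency checks, provided WeightMap is on the path.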
Raza Ali on 10 Sep 2020
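%% Weight Map Function
function weight = WeightMap(gt)
    % WeightMap  U-Net style weight map for a binary 2-D target mask gt.
    % gt must be an ordinary CPU array (gather gpuArray inputs first).

    % Class-balance weights w_c(x): inverse pixel frequency of each value.
    uvals = unique(gt);
    wmp = zeros(1, length(uvals));
    for uv = 1:length(uvals)
        wmp(uv) = 1 / sum(gt(:) == uvals(uv));
    end
    % This normalization is important: the least frequent class gets
    % weight 1 and more frequent classes get weights below 1.
    wmp = wmp / max(wmp);
    wc = zeros(size(gt));
    for uv = 1:length(uvals)
        wc(gt == uvals(uv)) = wmp(uv);
    end

    % Label the individual objects (4-connected regions where gt == 1).
    cells = bwlabel(gt == 1, 4);
    % Border term: large weights on background pixels between nearby objects.
    bwgt = zeros(size(gt));
    nCells = max(cells(:));
    if nCells >= 2
        maps = zeros(size(gt, 1), size(gt, 2), nCells);
        for ci = 1:nCells
            % Distance from every pixel to the nearest pixel of object ci
            maps(:,:,ci) = bwdist(cells == ci);
        end
        maps = sort(maps, 3);
        d1 = maps(:,:,1);   % distance to the nearest object
        d2 = maps(:,:,2);   % distance to the second-nearest object
        bwgt = 10 * exp(-((d1 + d2).^2) ./ (2 * 25)) .* (cells == 0);
    end

    % U-Net weights: class balance plus border term.
    weight = wc + bwgt;
end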
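WeightMap follows the U-Net weight map of Ronneberger et al. (2015), as its "unet weights" comment indicates. Reading the constants off the code (w_0 = 10 and 2σ² = 2·25, so σ = 5 pixels), it computes for each pixel x:

```latex
w(\mathbf{x}) = w_c(\mathbf{x})
  + w_0 \exp\!\left(-\frac{\bigl(d_1(\mathbf{x}) + d_2(\mathbf{x})\bigr)^2}{2\sigma^2}\right)
    \cdot \mathbb{1}\bigl[\mathrm{cells}(\mathbf{x}) = 0\bigr],
\qquad w_0 = 10,\ \sigma = 5
```

Here w_c is the class-balance weight, d_1 and d_2 are the distances to the nearest and second-nearest object (the first two slices of the sorted bwdist maps), and the indicator restricts the border term to background pixels. When the mask contains fewer than two objects, the border term is zero and w = w_c.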

Answers (0)