Customizing a layer and error

JaeHyeon Lee, January 28, 2021
Edited: JaeHyeon Lee, February 10, 2021
After writing a custom layer, I used the checkLayer function to check its validity. Although the input size I specified is (H, W, C), when I debug the layer the input has been expanded to (H, W, C, N). Can you tell me why?
  Comments (1)
Shubham Khatri, February 3, 2021
It should error out if you pass the third argument without specifying its value as the fourth argument. Can you share the steps to reproduce this?


Answers (1)

Shubham Khatri, February 3, 2021
Hello,
To my understanding, the checkLayer function can be called in two ways.
With two inputs:
checkLayer(layer,validInputSize)
Checks the validity of a custom layer using generated data of the sizes in validInputSize. For a layer with a single input, validInputSize is one size vector such as [28 28 128]; for a layer with multiple inputs, it is a cell array with one size vector per input.
With four inputs:
checkLayer(layer,validInputSize,Name,Value)
The third and fourth inputs form a name-value pair: the third input is the name of an additional check option (for example, 'ObservationDimension'), and the fourth input is the value of that option.
Please refer to the checkLayer documentation for details.
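On why the inputs become 4-D: as I read the checkLayer documentation, validInputSize describes a single observation, and passing 'ObservationDimension',4 makes checkLayer additionally test the layer with generated mini-batches of size 1 and 2 stacked along dimension 4. So with checkLayer(layer, {[28 28 128], [28 28 128], [28 28 128]}, 'ObservationDimension', 4), predict receives 28-by-28-by-128-by-N arrays. The sketch below does not call checkLayer; it only mimics those shapes with rand (the names obsSize and N are illustrative, not part of any API) to show what [H, W, C, D] = size(X) returns in each case.

```matlab
% Illustrative stand-in for checkLayer's generated inputs (assumption: one
% observation of size [28 28 128] plus mini-batches N = 1 and 2 stacked
% along dimension 4, the ObservationDimension).
obsSize = [28 28 128];
for N = [1 2]
    X = rand([obsSize N]);          % H-by-W-by-C-by-N batch
    [H, W, C, D] = size(X);         % D receives the number of observations
    % size(X) drops the trailing singleton, so it has 3 elements when N = 1.
    fprintf('N = %d: size(X) = %s, D = %d\n', N, mat2str(size(X)), D);
end
```

This prints `N = 1: size(X) = [28 28 128], D = 1` and then `N = 2: size(X) = [28 28 128 2], D = 2`. Because size drops the trailing singleton dimension when N = 1, code that branches on the batch size is safer using the fourth output of size, as the D > 1 checks in the layer do.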
Hope it helps.
  Comments (1)
JaeHyeon Lee, February 10, 2021
Edited: JaeHyeon Lee, February 10, 2021
I wrote the layer's source code as shown below. While debugging, I saw that X1, X2, and X3 become 4-D dlarray objects. Can you explain specifically why?
I checked the validity of the layer with checkLayer(layer, {[28 28 128], [28 28 128], [28 28 128]}, 'ObservationDimension', 4).
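classdef nonLocalBlock < nnet.layer.Layer
    % Non-local block with three inputs, each H-by-W-by-C-by-N.

    properties (Learnable)
        % Layer learnable parameters
        % Scaling coefficients
    end

    methods
        function layer = nonLocalBlock(numInputs, name)
            layer.NumInputs = numInputs;   % must be 3 to match predict below
            layer.Name = name;
            layer.Description = "Non-local Block";
        end

        function Z = predict(~, X1, X2, X3)
            % Debug: with 'ObservationDimension',4 these show a 4th
            % (observation) dimension whenever the mini-batch size N > 1.
            X1_size = size(X1);
            X2_size = size(X2);
            X3_size = size(X3);
            [H, W, C, D] = size(X1);       % D is the mini-batch size N

            % Flatten space: (H*W)-by-C pages, one page per observation.
            X1 = reshape(X1, H*W, C, 1, []);
            % X2 must be the page-wise transpose, C-by-(H*W). A plain
            % reshape to C-by-(H*W) would mix up positions and channels.
            X2 = permute(reshape(X2, H*W, C, 1, []), [2 1 3 4]);
            X3 = reshape(X3, H*W, C, 1, []);

            if D > 1
                f = pagemtimes(X1, X2);    % (H*W)-by-(H*W)-by-1-by-N
            else
                f = mtimes(X1, X2);
            end

            % Normalize each row over positions j (dimension 2). softmax
            % with 'DataFormat','SSCB' normalizes over dimension 3 ('C'),
            % which has size 1 here and would return all ones.
            f = exp(f - max(f, [], 2));
            f_div_C = f ./ sum(f, 2);

            if D > 1
                y = pagemtimes(f_div_C, X3);
            else
                y = mtimes(f_div_C, X3);
            end

            % Restore the H-by-W-by-C-by-N layout.
            Z = reshape(y, H, W, C, []);
        end
    end
end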
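Two details in predict are easy to get wrong, and a small plain-MATLAB check on one observation makes them visible (no Deep Learning Toolbox needed; H = 2, W = 3, C = 4 are arbitrary illustration sizes). First, reshaping an H-by-W-by-C array straight to C-by-(H*W) only reinterprets column-major memory, so it is not the transpose of its (H*W)-by-C reshape. Second, softmax with 'DataFormat','SSCB' normalizes over the 'C' dimension, which is dimension 3 and has size 1 for a similarity matrix f laid out as (H*W)-by-(H*W)-by-1-by-N. The helper softmaxAlong below is a hypothetical stand-in for softmax along an arbitrary dimension.

```matlab
% Illustration sizes (arbitrary): one H-by-W-by-C observation.
H = 2; W = 3; C = 4;
X = reshape(1:H*W*C, H, W, C);

% (H*W)-by-C: rows are spatial positions, columns are channels.
A = reshape(X, H*W, C);

% Reshaping straight to C-by-(H*W) reinterprets memory; it is NOT A'.
Bwrong = reshape(X, C, H*W);
% Transposing via permute gives the intended C-by-(H*W) matrix.
Bright = permute(A, [2 1]);

fprintf('reshape equals transpose: %d\n', isequal(Bwrong, A.'));
fprintf('permute equals transpose: %d\n', isequal(Bright, A.'));

% Similarity scores for one observation: (H*W)-by-(H*W), dimension 3 has size 1.
f = A * Bright;

% Hypothetical helper: numerically stable softmax along dimension d.
softmaxAlong = @(x, d) exp(x - max(x, [], d)) ./ sum(exp(x - max(x, [], d)), d);

overDim3 = softmaxAlong(f, 3);   % like the 'C' dimension of 'SSCB': size 1
overDim2 = softmaxAlong(f, 2);   % normalize each row over positions j

fprintf('all ones over dim 3: %d\n', all(overDim3(:) == 1));
fprintf('max |row sum - 1| over dim 2: %g\n', max(abs(sum(overDim2, 2) - 1)));
```

Normalizing over a size-1 dimension always yields exactly 1, so the attention weights carry no information; normalizing over dimension 2 gives rows that sum to 1, which is what a non-local block's softmax is meant to produce.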
