Add more options to gruLayer's GateActivationFunction

Hamza, 28 February 2023
Answered: Ben, 13 March 2023
Greetings,
I am trying to train a GRU RNN as a regression network to estimate a time-series signal. I want to see the effect of changing the GateActivationFunction, but I am limited to two options: 'sigmoid' and 'hard-sigmoid'. I tried to add two more options, 'tanh' and 'radbasn', by editing the files GRULayer, gruForwardGeneral, and gruLayer; the modified files are attached below. When I run gruLayer(Hiddenlayers1,'Name','gru1','OutputMode','sequence','StateActivationFunction','tanh','GateActivationFunction','tanh'), I get an error telling me that I am limited to the two options. My question is: is there any way to add more options? If yes, which other files must I edit?
Thanks,
Hamza Al Kouzbary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%gruLayer%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function layer = gruLayer(varargin)
%gruLayer Gated recurrent unit layer
%
% layer = gruLayer(numHiddenUnits) creates a Gated Recurrent Unit layer.
% numHiddenUnits is the number of hidden units in the layer, specified as
% a positive integer.
%
% layer = gruLayer(numHiddenUnits, 'PARAM1', VAL1, 'PARAM2', VAL2, ...)
% specifies optional parameter name/value pairs for creating the layer:
%
% 'Name' - Name for the layer, specified
% as a character vector or a
% string. The default value is
% ''.
% 'InputWeights' - Input weights, specified by a
% 3*numHiddenUnits-by-D matrix or
% [], where D is the number of
% features of the input data. The
% default is [].
% 'RecurrentWeights' - Recurrent weights, specified as
% a 3*numHiddenUnits-by-
% numHiddenUnits
% matrix or []. The default is
% [].
% 'Bias' - Layer biases, specified as a
% 3*numHiddenUnits-by-1 vector, a
% 6*numHiddenUnits-by-1 vector,
% or []. The default is [].
% 'HiddenState' - Initial hidden state, specified
% as a numHiddenUnits-by-1 vector
% or []. The default is [].
% 'OutputMode' - The format of the output of the
% layer. Options are:
% - 'sequence', to output a
% full sequence.
% - 'last', to output the
% last element only.
% The default value is
% 'sequence'.
% 'StateActivationFunction' - Activation function to update
% the hidden state.
% Options are:
% - 'tanh'
% - 'softsign'
% The default value is 'tanh'.
% 'GateActivationFunction' - Activation function to apply to
% the gates. Options are:
% - 'sigmoid'
% - 'tanh'
% - 'radbasn'
% - 'hard-sigmoid'
% The default value is 'sigmoid'.
% 'InputWeightsLearnRateFactor' - Multiplier for the learning
% rate of the input weights,
% specified as a scalar or a
% three-element vector. The
% default value is 1.
% 'RecurrentWeightsLearnRateFactor' - Multiplier for the learning
% rate of the recurrent weights,
% specified as a scalar or a
% three-element vector. The
% default value is 1.
% 'BiasLearnRateFactor' - Multiplier for the learning
% rate of the bias, specified as
% a scalar or a three-element
% vector. The default value is 1.
% 'InputWeightsL2Factor' - Multiplier for the L2
% regularizer of the input
% weights, specified as a scalar
% or a three-element vector. The
% default value is 1.
% 'RecurrentWeightsL2Factor' - Multiplier for the L2
% regularizer of the recurrent
% weights, specified as a scalar
% or a three-element vector. The
% default value is 1.
% 'BiasL2Factor' - Multiplier for the L2
% regularizer of the bias,
% specified as a scalar or a
% three-element vector. The
% default value is 0.
% 'InputWeightsInitializer' - The function to initialize the
% input weights, specified as
% 'glorot', 'he', 'orthogonal',
% 'narrow-normal', 'zeros',
% 'ones' or a function handle.
% The default is 'glorot'.
% 'RecurrentWeightsInitializer' - The function to initialize the
% recurrent weights, specified as
% 'glorot', 'he', 'orthogonal',
% 'narrow-normal', 'zeros',
% 'ones' or a function handle.
% The default is 'orthogonal'.
% 'BiasInitializer' - The function to initialize the
% bias, specified as 'zeros',
% 'narrow-normal', 'ones' or a
% function handle. The default is
% 'zeros'.
% 'ResetGateMode' - Reset gate mode, specified as
% one of the following:
% - 'after-multiplication',
% apply reset gate after
% matrix multiplication. This
% option uses the cuDNN
% library when running on
% GPU.
% - 'before-multiplication',
% apply reset gate before
% matrix multiplication.
% - 'recurrent-bias-after-multiplication',
% apply reset gate after
% matrix multiplication and
% use recurrent bias.
% The default value is
% 'after-multiplication'.
%
% Example 1:
% % Create a GRU layer with 100 hidden units.
%
% layer = gruLayer(100);
%
% Example 2:
% % Create a GRU layer with 50 hidden units which returns the last
% % output element of the sequence. Manually initialize the recurrent
% % weights from a Gaussian distribution with standard deviation
% % 0.01.
%
% numHiddenUnits = 50;
% layer = gruLayer(numHiddenUnits, 'OutputMode', 'last', ...
% 'RecurrentWeights', randn([3*numHiddenUnits numHiddenUnits])*0.01);
%
% See also nnet.cnn.layer.GRULayer
%
% <a href="matlab:helpview('deeplearning','list_of_layers')">List of Deep Learning Layers</a>
% Copyright 2019-2020 The MathWorks, Inc.
% Parse the input arguments.
varargin = nnet.internal.cnn.layer.util.gatherParametersToCPU(varargin);
args = nnet.cnn.layer.GRULayer.parseInputArguments(varargin{:});
% Create an internal representation of the layer.
internalLayer = nnet.internal.cnn.layer.GRU(args.Name, ...
args.InputSize, ...
args.NumHiddenUnits, ...
true, ...
iGetReturnSequence(args.OutputMode), ...
args.StateActivationFunction, ...
args.GateActivationFunction, ...
args.ResetGateMode);
% Use the internal layer to construct a user visible layer.
layer = nnet.cnn.layer.GRULayer(internalLayer);
% Set learnable parameters, learn rate, L2 factors and initializers.
layer.InputWeights = args.InputWeights;
layer.InputWeightsL2Factor = args.InputWeightsL2Factor;
layer.InputWeightsLearnRateFactor = args.InputWeightsLearnRateFactor;
layer.InputWeightsInitializer = args.InputWeightsInitializer;
layer.RecurrentWeights = args.RecurrentWeights;
layer.RecurrentWeightsL2Factor = args.RecurrentWeightsL2Factor;
layer.RecurrentWeightsLearnRateFactor = args.RecurrentWeightsLearnRateFactor;
layer.RecurrentWeightsInitializer = args.RecurrentWeightsInitializer;
layer.Bias = args.Bias;
layer.BiasL2Factor = args.BiasL2Factor;
layer.BiasLearnRateFactor = args.BiasLearnRateFactor;
layer.BiasInitializer = args.BiasInitializer;
% Set hidden state.
layer.HiddenState = args.HiddenState;
end
function tf = iGetReturnSequence( mode )
tf = true;
if strcmp( mode, 'last' )
tf = false;
end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%gruForwardGeneral%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [h, H] = gruForwardGeneral(X,learnable,state,options)
% gruForwardGeneral implementation for gru forward call, see
% https://arxiv.org/abs/1406.1078v1
% Copyright 2019 The MathWorks, Inc.
W = learnable.W;
R = learnable.R;
b = learnable.b;
h0 = state.h0;
% Determine dimensions
numHidden = size(R,2);
% Determine dimensions
[~, N, T] = size(X);
% Indexing helpers
[rInd, zInd, hInd] = nnet.internal.cnn.util.gruGateIndices(numHidden);
% Input weights
Wrz = W([rInd,zInd], :);
Wh = W(hInd, :);
% Recurrent weights
Rrz = R([rInd,zInd], :);
Rh = R(hInd, :);
% Biases
brz = b([rInd,zInd], :);
bh = b(hInd, :);
% Pre-allocate hidden state
h = zeros(numHidden, N, T, 'like', X);
if isstring(options.StateActivationFunction) || ischar(options.StateActivationFunction)
stateActivationFunction = iGetStateActivation( options.StateActivationFunction );
elseif isa(options.StateActivationFunction,'function_handle')
stateActivationFunction = options.StateActivationFunction;
end
if isstring(options.GateActivationFunction) || ischar(options.GateActivationFunction)
gateActivationFunction = iGetGateActivation( options.GateActivationFunction );
elseif isa(options.GateActivationFunction,'function_handle')
gateActivationFunction = options.GateActivationFunction;
end
% First iteration of forward loop
% Update r and z gates
rz = gateActivationFunction( Wrz*X(:, :, 1) + Rrz*h0 + brz );
r = rz(rInd, :);
z = rz(zInd, :);
% Compute candidate state hs
hs = stateActivationFunction( Wh*X(:, :, 1) + r.*(Rh*h0) + bh );
% Update hidden state h
h(:, :, 1) = (1 - z).*hs + z.*h0;
% Main forward loop
for tt = 2:T
hIdx = h(:, :, tt-1);
% Update r and z gates
rz = gateActivationFunction( Wrz*X(:, :, tt) + Rrz*hIdx + brz );
r = rz(rInd, :);
z = rz(zInd, :);
% Compute candidate state hs
hs = stateActivationFunction( Wh*X(:, :, tt) + r.*(Rh*hIdx) + bh );
% Update hidden state h
h(:, :, tt) = (1 - z).*hs + z.*hIdx;
end
if options.ReturnLast
H = h;
h = h(:, :, end);
else
H = h(:, :, end);
end
end
%% Helper functions
function act = iGetStateActivation( activation )
switch activation
case 'tanh'
act = @nnet.internal.cnnhost.tanhForward;
case 'softsign'
act = @iSoftSign;
end
end
function act = iGetGateActivation( activation )
switch activation
case 'sigmoid'
act = @nnet.internal.cnnhost.sigmoidForward;
case 'hard-sigmoid'
act = @nnet.internal.cnnhost.hardSigmoidForward;
case 'tanh'
act = @nnet.internal.cnnhost.tanhForward;
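case 'radbasn'
% NOTE: this case points at the hard-sigmoid kernel, so 'radbasn' does not
% compute a radial basis function here.
act = @nnet.internal.cnnhost.hardSigmoidForward;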
end
end
%% Activation functions
function y = iSoftSign(x)
y = x./(1 + abs(x));
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%GRULayer%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
classdef GRULayer < nnet.cnn.layer.Layer & nnet.internal.cnn.layer.Externalizable
% GRULayer Gated Recurrent Unit (GRU) layer
%
% To create a GRU layer, use gruLayer.
%
% GRULayer properties:
% Name - Name of the layer
% InputSize - Input size of the layer
% NumHiddenUnits - Number of hidden units in the layer
% OutputMode - Output as sequence or last
% StateActivationFunction - Activation function to
% update the hidden state
% GateActivationFunction - Activation function to
% apply to the gates
% ResetGateMode - Reset gate
% mode. Apply reset gate
% before or after matrix
% multiplication, and with or
% without recurrent bias.
% NumInputs - The number of inputs for
% the layer.
% InputNames - The names of the inputs of
% the layer.
% NumOutputs - The number of outputs of
% the layer.
% OutputNames - The names of the outputs of
% the layer.
%
% Properties for learnable parameters:
% InputWeights - Input weights
% InputWeightsInitializer - The function for
% initializing the input
% weights.
% InputWeightsLearnRateFactor - Learning rate multiplier
% for the input weights
% InputWeightsL2Factor - L2 multiplier for the
% input weights
%
% RecurrentWeights - Recurrent weights
% RecurrentWeightsInitializer - The function for
% initializing the recurrent
% weights.
% RecurrentWeightsLearnRateFactor - Learning rate multiplier
% for the recurrent weights
% RecurrentWeightsL2Factor - L2 multiplier for the
% recurrent weights
%
% Bias - Bias vector
% BiasInitializer - The function for
% initializing the bias.
% BiasLearnRateFactor - Learning rate multiplier
% for the bias
% BiasL2Factor - L2 multiplier for the bias
%
% State parameters:
% HiddenState - Hidden state vector
%
% Example:
% Create a Gated Recurrent Unit layer.
%
% layer = gruLayer(10)
%
% See also gruLayer
% Copyright 2019-2020 The MathWorks, Inc.
properties(Dependent)
% Name A name for the layer
% The name for the layer. If this is set to '', then a name will
% be automatically set at training time.
Name
end
properties(SetAccess = private, Dependent)
% InputSize The input size for the layer. If this is set to
% 'auto', then the input size will be automatically set during
% training
InputSize
% NumHiddenUnits The number of hidden units in the layer
NumHiddenUnits
% OutputMode The output format of the layer. If 'sequence',
% output is a sequence. If 'last', the output is the last element
% in a sequence
OutputMode
% StateActivationFunction The activation function to update the
% hidden state. Valid options are 'tanh' or 'softsign'. The default
% is 'tanh'.
StateActivationFunction
% GateActivationFunction The activation function to apply to the
% gates. Valid options are 'sigmoid' or 'hard-sigmoid'. The default
% is 'sigmoid'.
GateActivationFunction
% ResetGateMode Reset gate mode, specified as one of the
% following:
% 'after-multiplication' - apply reset gate after matrix
% multiplication. With this option the bias has size
% 3*numHiddenUnits-by-1. This option is cuDNN compatible.
% 'before-multiplication' - apply reset gate before matrix
% multiplication. With this option the bias has size
% 3*numHiddenUnits-by-1.
% 'recurrent-bias-after-multiplication' - apply reset gate
% after matrix multiplication and use a recurrent bias. With
% this option the bias has size 6*numHiddenUnits-by-1. This
% option is cuDNN compatible.
ResetGateMode
end
properties(Dependent)
% InputWeights The input weights for the layer
% The input weight matrix for the GRU layer. The input weight
% matrix is a vertical concatenation of the three "gate" input
% weight matrices in the forward pass of a GRU. Those individual
% matrices are concatenated in the following order: update gate,
% reset gate, output gate. This matrix will have size
% 3*NumHiddenUnits-by-InputSize.
InputWeights
% InputWeightsInitializer The function for initializing the
% input weights.
InputWeightsInitializer
% InputWeightsLearnRateFactor The learning rate factor for the
% input weights
% The learning rate factor for the input weights. This factor is
% multiplied with the global learning rate to determine the
% learning rate for the input weights in this layer. For example,
% if it is set to 2, then the learning rate for the input weights
% in this layer will be twice the current global learning rate.
% To control the value of the learn rate for the three individual
% matrices in the InputWeights, a 1-by-3 vector can be assigned.
InputWeightsLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% InputWeightsL2Factor The L2 regularization factor for the input
% weights
% The L2 regularization factor for the input weights. This factor
% is multiplied with the global L2 regularization setting to
% determine the L2 regularization setting for the input weights
% in this layer. For example, if it is set to 2, then the L2
% regularization for the input weights in this layer will be
% twice the global L2 regularization setting. To control the
% value of the L2 factor for the three individual matrices in the
% InputWeights, a 1-by-3 vector can be assigned.
InputWeightsL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% RecurrentWeights The recurrent weights for the layer
% The recurrent weight matrix for the GRU layer. The recurrent
% weight matrix is a vertical concatenation of the three "gate"
% recurrent weight matrices in the forward pass of a GRU. Those
% individual matrices are concatenated in the following order:
% update gate, reset gate, output gate. This matrix will have
% size 3*NumHiddenUnits-by-NumHiddenUnits.
RecurrentWeights
% RecurrentWeightsInitializer The function for initializing the
% recurrent weights.
RecurrentWeightsInitializer
% RecurrentWeightsLearnRateFactor The learning rate factor for
% the recurrent weights
% The learning rate factor for the recurrent weights. This factor
% is multiplied with the global learning rate to determine the
% learning rate for the recurrent weights in this layer. For
% example, if it is set to 2, then the learning rate for the
% recurrent weights in this layer will be twice the current
% global learning rate. To control the value of the learn rate
% for the three individual matrices in the RecurrentWeights, a
% 1-by-3 vector can be assigned.
RecurrentWeightsLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% RecurrentWeightsL2Factor The L2 regularization factor for the
% recurrent weights
% The L2 regularization factor for the recurrent weights. This
% factor is multiplied with the global L2 regularization setting
% to determine the L2 regularization setting for the recurrent
% weights in this layer. For example, if it is set to 2, then the
% L2 regularization for the recurrent weights in this layer will
% be twice the global L2 regularization setting. To control the
% value of the L2 factor for the three individual matrices in the
% RecurrentWeights, a 1-by-3 vector can be assigned.
RecurrentWeightsL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% Bias The biases for the layer
% The bias vector for the GRU layer. The bias vector is a
% concatenation of the three "gate" bias vectors in the forward
% pass of a GRU. Those individual vectors are concatenated in
% the following order: update gate, reset gate, output gate. This
% vector will have size 3*NumHiddenUnits-by-1.
Bias
% BiasInitializer The function for initializing the bias
BiasInitializer
% BiasLearnRateFactor The learning rate factor for the biases
% The learning rate factor for the bias. This factor is
% multiplied with the global learning rate to determine the
% learning rate for the bias in this layer. For example, if it is
% set to 2, then the learning rate for the bias in this layer
% will be twice the current global learning rate. To control the
% value of the learn rate for the three individual vectors in the
% Bias, a 1-by-3 vector can be assigned.
BiasLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% BiasL2Factor The L2 regularization factor for the biases
% The L2 regularization factor for the biases. This factor is
% multiplied with the global L2 regularization setting to
% determine the L2 regularization setting for the biases in this
% layer. For example, if it is set to 2, then the L2
% regularization for the biases in this layer will be twice the
% global L2 regularization setting. To control the value of the
% L2 factor for the three individual vectors in the Bias, a
% 1-by-3 vector can be assigned.
BiasL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
end
properties(Dependent)
% HiddenState The initial value of the hidden state.
% The initial value of the hidden state. This vector will have
% size NumHiddenUnits-by-1. Setting this value sets the default
% value to which the hidden state is reset to when calling the
% resetState method of SeriesNetwork.
HiddenState
end
properties(SetAccess = private, Hidden, Dependent)
% OutputSize The number of hidden units in the layer. See
% NumHiddenUnits.
OutputSize
% OutputState The hidden state of the layer. See HiddenState.
OutputState
end
methods
function this = GRULayer(privateLayer)
this.PrivateLayer = privateLayer;
end
function val = get.Name(this)
val = this.PrivateLayer.Name;
end
function this = set.Name(this, val)
iAssertValidLayerName(val);
this.PrivateLayer.Name = char(val);
end
function val = get.InputSize(this)
val = this.PrivateLayer.InputSize;
if isempty(val)
val = 'auto';
end
end
function val = get.NumHiddenUnits(this)
val = this.PrivateLayer.HiddenSize;
end
function val = get.OutputMode(this)
val = iGetOutputMode( this.PrivateLayer.ReturnSequence );
end
function val = get.StateActivationFunction(this)
val = this.PrivateLayer.Activation;
end
function val = get.GateActivationFunction(this)
val = this.PrivateLayer.RecurrentActivation;
end
function val = get.InputWeights(this)
val = this.PrivateLayer.InputWeights.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function val = get.ResetGateMode(this)
val = this.PrivateLayer.ResetGateMode;
end
function this = set.InputWeights(this, value)
if isequal(this.InputSize, 'auto')
expectedInputSize = NaN;
else
expectedInputSize = this.InputSize;
end
attributes = {'size', [3*this.NumHiddenUnits expectedInputSize],...
'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
if ~isempty(value)
this.PrivateLayer = this.PrivateLayer.configureForInputs( ...
{iMakeSizeOnlyArray([size(value,2) NaN NaN],'CBT')} );
end
this.PrivateLayer.InputWeights.Value = value;
end
function val = get.InputWeightsInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.InputWeights.Initializer)
val = this.PrivateLayer.InputWeights.Initializer.Fcn;
else
val = this.PrivateLayer.InputWeights.Initializer.Name;
end
end
function this = set.InputWeightsInitializer(this, value)
value = iAssertValidWeightsInitializer(value, 'InputWeightsInitializer');
% Create the initializer with in and out indices of the weights
% size: 3*NumHiddenUnits-by-InputSize
this.PrivateLayer.InputWeights.Initializer = ...
iInitializerFactory(value, 2, 1);
end
function val = get.RecurrentWeights(this)
val = this.PrivateLayer.RecurrentWeights.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function this = set.RecurrentWeights(this, value)
attributes = {'size', [3*this.NumHiddenUnits this.NumHiddenUnits],...
'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
this.PrivateLayer.RecurrentWeights.Value = value;
end
function val = get.RecurrentWeightsInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.RecurrentWeights.Initializer)
val = this.PrivateLayer.RecurrentWeights.Initializer.Fcn;
else
val = this.PrivateLayer.RecurrentWeights.Initializer.Name;
end
end
function this = set.RecurrentWeightsInitializer(this, value)
value = iAssertValidWeightsInitializer(value, 'RecurrentWeightsInitializer');
% Create the initializer with in and out indices of the weights
% size: 3*NumHiddenUnits-by-NumHiddenUnits
this.PrivateLayer.RecurrentWeights.Initializer = ...
iInitializerFactory(value, 2, 1);
end
function val = get.Bias(this)
val = this.PrivateLayer.Bias.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function this = set.Bias(this, value)
biasnrowfactor = 1 + double(isequal(this.ResetGateMode, ...
'recurrent-bias-after-multiplication'));
attributes = {'column', 'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
expectedSize = 3*biasnrowfactor*this.NumHiddenUnits;
% Valid input value is empty or has size either
% 3*NumHiddenUnits, if 'ResetGateMode' is
% 'after-multiplication' or 'before-multiplication', or
% 6*NumHiddenUnits, if 'ResetGateMode' is
% 'recurrent-bias-after-multiplication'.
if length(value)~=expectedSize && ~isequal(value,[])
error(message('nnet_cnn:layer:GRULayer:BiasSize',...
3*biasnrowfactor,this.ResetGateMode));
end
this.PrivateLayer.Bias.Value = value;
end
function val = get.BiasInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.Bias.Initializer)
val = this.PrivateLayer.Bias.Initializer.Fcn;
else
val = this.PrivateLayer.Bias.Initializer.Name;
end
end
function this = set.BiasInitializer(this, value)
value = iAssertValidBiasInitializer(value);
% The Bias initializer needs to know which recurrent type
this.PrivateLayer.Bias.Initializer = iInitializerFactory(value,...
'GRU');
end
function val = get.HiddenState(this)
val = gather(this.PrivateLayer.HiddenState.Value);
end
function this = set.HiddenState(this, value)
value = iGatherAndValidateParameter(value, 'default', [this.NumHiddenUnits 1]);
this.PrivateLayer.InitialHiddenState = value;
this.PrivateLayer.HiddenState.Value = value;
end
function val = get.InputWeightsLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.InputWeights.LearnRateFactor);
end
function this = set.InputWeightsLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.InputWeights.LearnRateFactor = this.setFactor(val);
end
function val = get.InputWeightsL2Factor(this)
val = this.getFactor(this.PrivateLayer.InputWeights.L2Factor);
end
function this = set.InputWeightsL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.InputWeights.L2Factor = this.setFactor(val);
end
function val = get.RecurrentWeightsLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.RecurrentWeights.LearnRateFactor);
end
function this = set.RecurrentWeightsLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.RecurrentWeights.LearnRateFactor = this.setFactor(val);
end
function val = get.RecurrentWeightsL2Factor(this)
val = this.getFactor(this.PrivateLayer.RecurrentWeights.L2Factor);
end
function this = set.RecurrentWeightsL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.RecurrentWeights.L2Factor = this.setFactor(val);
end
function val = get.BiasLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.Bias.LearnRateFactor);
end
function this = set.BiasLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.Bias.LearnRateFactor = this.setFactor(val);
end
function val = get.BiasL2Factor(this)
val = this.getFactor(this.PrivateLayer.Bias.L2Factor);
end
function this = set.BiasL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.Bias.L2Factor = this.setFactor(val);
end
function val = get.OutputSize(this)
val = this.NumHiddenUnits;
end
function val = get.OutputState(this)
val = this.HiddenState;
end
function out = saveobj(this)
privateLayer = this.PrivateLayer;
out.Version = 1.0;
out.Name = privateLayer.Name;
out.InputSize = privateLayer.InputSize;
out.NumHiddenUnits = privateLayer.HiddenSize;
out.ReturnSequence = privateLayer.ReturnSequence;
out.ResetGateMode = privateLayer.ResetGateMode;
out.StateActivationFunction = privateLayer.Activation;
out.GateActivationFunction = privateLayer.RecurrentActivation;
out.InputWeights = toStruct(privateLayer.InputWeights);
out.RecurrentWeights = toStruct(privateLayer.RecurrentWeights);
out.Bias = toStruct(privateLayer.Bias);
out.HiddenState = toStruct(privateLayer.HiddenState);
out.InitialHiddenState = gather(privateLayer.InitialHiddenState);
end
end
methods(Static)
function inputArguments = parseInputArguments(varargin)
parser = iCreateParser();
parser.parse(varargin{:});
inputArguments = iConvertToCanonicalForm(parser);
inputArguments.InputSize = [];
end
function this = loadobj(in)
internalLayer = nnet.internal.cnn.layer.GRU( in.Name, ...
in.InputSize, ...
in.NumHiddenUnits, ...
true, ...
in.ReturnSequence, ...
in.StateActivationFunction, ...
in.GateActivationFunction, ...
in.ResetGateMode );
internalLayer.InputWeights = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.InputWeights);
internalLayer.RecurrentWeights = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.RecurrentWeights);
internalLayer.Bias = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.Bias);
internalLayer.HiddenState = nnet.internal.cnn.layer.dynamic.TrainingDynamicParameter.fromStruct(in.HiddenState);
internalLayer.InitialHiddenState = in.InitialHiddenState;
this = nnet.cnn.layer.GRULayer(internalLayer);
end
end
methods(Hidden, Access = protected)
function [description, type] = getOneLineDisplay(obj)
description = iGetMessageString( ...
'nnet_cnn:layer:GRULayer:oneLineDisplay', ...
num2str(obj.NumHiddenUnits));
type = iGetMessageString( 'nnet_cnn:layer:GRULayer:Type' );
end
function groups = getPropertyGroups( this )
generalParameters = { 'Name' };
hyperParameters = { 'InputSize', ...
'NumHiddenUnits', ...
'OutputMode', ...
'StateActivationFunction', ...
'GateActivationFunction', ...
'ResetGateMode'};
learnableParameters = { 'InputWeights', ...
'RecurrentWeights', ...
'Bias' };
stateParameters = { 'HiddenState' };
groups = [
this.propertyGroupGeneral( generalParameters )
this.propertyGroupHyperparameters( hyperParameters )
this.propertyGroupLearnableParameters( learnableParameters )
this.propertyGroupDynamicParameters( stateParameters )
];
end
function footer = getFooter( this )
variableName = inputname(1);
footer = this.createShowAllPropertiesFooter( variableName );
end
function val = getFactor(this, val)
if isscalar(val)
% No operation needed
elseif numel(val) == (3*this.NumHiddenUnits)
val = val(1:this.NumHiddenUnits:end);
val = val(:)';
else
% Error - the factor has incorrect size
end
end
function val = setFactor(this, val)
if isscalar(val)
% No operation needed
elseif numel(val) == 3
% Expand a three-element vector into a 3*NumHiddenUnits-by-1
% column vector
expandedValues = repelem( val, this.NumHiddenUnits );
val = expandedValues(:);
else
% Error - the factor has incorrect size
end
end
end
end
function messageString = iGetMessageString( varargin )
messageString = getString( message( varargin{:} ) );
end
function p = iCreateParser()
p = inputParser;
defaultName = '';
defaultOutputMode = 'sequence';
defaultStateActivationFunction = 'tanh';
defaultGateActivationFunction = 'sigmoid';
defaultWeightLearnRateFactor = 1;
defaultBiasLearnRateFactor = 1;
defaultWeightL2Factor = 1;
defaultBiasL2Factor = 0;
defaultInputWeightsInitializer = 'glorot';
defaultRecurrentWeightsInitializer = 'orthogonal';
defaultBiasInitializer = 'zeros';
defaultLearnable = [];
defaultState = [];
defaultResetGateMode = 'after-multiplication';
p.addRequired('NumHiddenUnits', @(x)validateattributes(x, {'numeric'}, {'scalar', 'positive', 'integer'}));
p.addParameter('Name', defaultName, @nnet.internal.cnn.layer.paramvalidation.validateLayerName);
p.addParameter('OutputMode', defaultOutputMode, @(x)any(iAssertAndReturnValidOutputMode(x)));
p.addParameter('StateActivationFunction', defaultStateActivationFunction, @(x)any(iAssertAndReturnValidStateActivation(x)));
p.addParameter('GateActivationFunction', defaultGateActivationFunction, @(x)any(iAssertAndReturnValidGateActivation(x)));
p.addParameter('InputWeightsLearnRateFactor', defaultWeightLearnRateFactor, @(x)iAssertValidFactor(x));
p.addParameter('RecurrentWeightsLearnRateFactor', defaultWeightLearnRateFactor,@(x)iAssertValidFactor(x));
p.addParameter('BiasLearnRateFactor', defaultBiasLearnRateFactor,@(x)iAssertValidFactor(x));
p.addParameter('InputWeightsL2Factor', defaultWeightL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('RecurrentWeightsL2Factor', defaultWeightL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('BiasL2Factor', defaultBiasL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('InputWeightsInitializer', defaultInputWeightsInitializer);
p.addParameter('RecurrentWeightsInitializer', defaultRecurrentWeightsInitializer);
p.addParameter('BiasInitializer', defaultBiasInitializer);
p.addParameter('InputWeights', defaultLearnable);
p.addParameter('RecurrentWeights', defaultLearnable);
p.addParameter('Bias', defaultLearnable);
p.addParameter('HiddenState', defaultState);
p.addParameter('ResetGateMode', defaultResetGateMode, @(x)any(iAssertAndReturnValidResetGateMode(x)));
end
function inputArguments = iConvertToCanonicalForm(parser)
results = parser.Results;
inputArguments = struct;
inputArguments.NumHiddenUnits = double( results.NumHiddenUnits );
inputArguments.Name = convertStringsToChars(results.Name);
inputArguments.OutputMode = iAssertAndReturnValidOutputMode(results.OutputMode);
inputArguments.StateActivationFunction = iAssertAndReturnValidStateActivation(convertStringsToChars(results.StateActivationFunction));
inputArguments.GateActivationFunction = iAssertAndReturnValidGateActivation(convertStringsToChars(results.GateActivationFunction));
inputArguments.InputWeightsLearnRateFactor = results.InputWeightsLearnRateFactor;
inputArguments.RecurrentWeightsLearnRateFactor = results.RecurrentWeightsLearnRateFactor;
inputArguments.BiasLearnRateFactor = results.BiasLearnRateFactor;
inputArguments.InputWeightsL2Factor = results.InputWeightsL2Factor;
inputArguments.RecurrentWeightsL2Factor = results.RecurrentWeightsL2Factor;
inputArguments.BiasL2Factor = results.BiasL2Factor;
inputArguments.InputWeightsInitializer = results.InputWeightsInitializer;
inputArguments.RecurrentWeightsInitializer = results.RecurrentWeightsInitializer;
inputArguments.BiasInitializer = results.BiasInitializer;
inputArguments.InputWeights = results.InputWeights;
inputArguments.RecurrentWeights = results.RecurrentWeights;
inputArguments.Bias = results.Bias;
inputArguments.HiddenState = results.HiddenState;
inputArguments.ResetGateMode = iAssertAndReturnValidResetGateMode(results.ResetGateMode);
end
function mode = iGetOutputMode( tf )
if tf
mode = 'sequence';
else
mode = 'last';
end
end
function iCheckFactorDimensions( value )
dim = numel( value );
if ~(dim == 1 || dim == 3)
exception = MException(message('nnet_cnn:layer:GRULayer:InvalidFactor'));
throwAsCaller(exception);
end
end
function validString = iAssertAndReturnValidOutputMode(value)
validString = validatestring(value, {'sequence', 'last'});
end
function validString = iAssertAndReturnValidStateActivation(value)
validString = validatestring(value, {'tanh', 'softsign'});
end
function validString = iAssertAndReturnValidGateActivation(value)
validString = validatestring(value, {'sigmoid','tanh', 'hard-sigmoid','radbasn'});
end
function iAssertValidFactor(value)
validateattributes(value, {'numeric'}, {'vector', 'real', 'nonnegative', 'finite'});
end
function value = iAssertValidWeightsInitializer(value, name)
validateattributes(value, {'function_handle','char','string'}, {});
if(ischar(value) || isstring(value))
value = validatestring(value, {'narrow-normal', ...
'glorot', ...
'he', ...
'orthogonal', ...
'zeros', ...
'ones'}, '', name);
end
end
function value = iAssertValidBiasInitializer(value)
validateattributes(value, {'function_handle','char','string'}, {});
if(ischar(value) || isstring(value))
value = validatestring(value, {'zeros', ...
'narrow-normal', ...
'ones'});
end
end
function initializer = iInitializerFactory(varargin)
initializer = nnet.internal.cnn.layer.learnable.initializer...
.initializerFactory(varargin{:});
end
function tf = iIsCustomInitializer(init)
tf = isa(init, 'nnet.internal.cnn.layer.learnable.initializer.Custom');
end
function iAssertValidLayerName(name)
iEvalAndThrow(@()...
nnet.internal.cnn.layer.paramvalidation.validateLayerName(name));
end
function iEvalAndThrow(func)
% Omit the stack containing internal functions by throwing as caller
try
func();
catch exception
throwAsCaller(exception)
end
end
function value = iGatherAndValidateParameter(varargin)
try
value = nnet.internal.cnn.layer.paramvalidation...
.gatherAndValidateNumericParameter(varargin{:});
catch exception
throwAsCaller(exception)
end
end
function value = iAssertAndReturnValidResetGateMode(value)
value = validatestring(value, {'after-multiplication', 'before-multiplication', 'recurrent-bias-after-multiplication'});
end
function dlX = iMakeSizeOnlyArray(varargin)
dlX = deep.internal.PlaceholderArray(varargin{:});
end

Answers (1)

Ben, 13 March 2023
I would recommend implementing this extended GRU layer as a custom layer following this example:
You may be able to follow the code you have found in gruForwardGeneral to do this.
It is not recommended that you try to modify the source code directly.
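
As a starting point for such a custom layer, the recurrence in gruForwardGeneral above can be written in plain MATLAB with the gate activation passed as a function handle. This is only a minimal sketch under stated assumptions, not the toolbox implementation: the variable names, the stacked [r; z; h] weight ordering, the sizes, and the unnormalized Gaussian used as a stand-in for a radial basis gate are all illustrative choices, and it covers only the forward pass for the 'after-multiplication' reset gate mode.

```matlab
% Sketch of a GRU forward pass with a swappable gate activation.
% Plain MATLAB only; no Deep Learning Toolbox internals are called.
% All names and sizes here are hypothetical, chosen for illustration.

numHidden   = 4;   % hidden units
numFeatures = 3;   % input features
N = 2;             % observations in the mini-batch
T = 5;             % time steps

rng(0);            % reproducible random weights
W  = 0.1*randn(3*numHidden, numFeatures);  % input weights, assumed stacked [r; z; h]
R  = 0.1*randn(3*numHidden, numHidden);    % recurrent weights, same stacking
b  = zeros(3*numHidden, 1);                % bias
h0 = zeros(numHidden, N);                  % initial hidden state
X  = randn(numFeatures, N, T);             % input sequence

% Row indices of each gate inside the stacked matrices (assumed ordering).
rInd = 1:numHidden;
zInd = numHidden + (1:numHidden);
hInd = 2*numHidden + (1:numHidden);

% Candidate gate activations as function handles.
sigmoidGate     = @(x) 1./(1 + exp(-x));
hardSigmoidGate = @(x) max(0, min(1, 0.2*x + 0.5));
tanhGate        = @tanh;
gaussGate       = @(x) exp(-x.^2);   % unnormalized radial basis, range (0, 1]

gateAct  = gaussGate;   % swap in any handle above
stateAct = @tanh;

h     = zeros(numHidden, N, T);
hPrev = h0;
for tt = 1:T
    Xt = X(:, :, tt);
    % Reset and update gates, computed together as in gruForwardGeneral.
    rz = gateAct(W([rInd zInd], :)*Xt + R([rInd zInd], :)*hPrev + b([rInd zInd], :));
    r  = rz(rInd, :);
    z  = rz(numHidden + rInd, :);
    % Candidate state, reset gate applied after the matrix multiplication.
    hs = stateAct(W(hInd, :)*Xt + r.*(R(hInd, :)*hPrev) + b(hInd, :));
    % Blend the candidate with the previous hidden state.
    hPrev = (1 - z).*hs + z.*hPrev;
    h(:, :, tt) = hPrev;
end
```

One thing to keep in mind when choosing gateAct: sigmoid, hard-sigmoid, and the Gaussian above all map into [0, 1], so (1 - z).*hs + z.*hPrev stays a weighted average of the two states. tanh maps into (-1, 1), so with tanhGate the update gate can go negative and the blend is no longer an average, which may change how training behaves.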
