
Is it possible to share common weights and bias among different LSTM layers?

11 views (last 30 days)
Abel on 5 Feb 2020
Answered: Conor Daly on 17 Feb 2023
I am building a network that looks like the figure below.
There are three LSTM layers, namely LSTM_common_1, LSTM_common_2, and LSTM_common_3.
Can I restrict their weights and biases so that all of the LSTM_common_x layers share the same set of weights and biases?
[Figure: network diagram with three parallel LSTM_common layers feeding downstream layers]

Answers (1)

Conor Daly on 17 Feb 2023
One way to share weights like this is to use nested layers, that is, layers whose learnable parameters are themselves defined by neural networks. The general idea is to create a custom layer that applies the shared sub-network (in this case just a single LSTM layer) wherever it is needed.
Here's an example for the case above:
classdef commonLSTMLayer < nnet.layer.Layer ...
        & nnet.layer.Formattable ...
        & nnet.layer.Acceleratable

    properties (Learnable)
        % Nested dlnetwork holding the single, shared LSTM layer
        Network
    end

    methods
        function this = commonLSTMLayer(numHiddenUnits, numOutputs, args)
            arguments
                numHiddenUnits (1,1) {mustBePositive, mustBeInteger}
                numOutputs (1,1) {mustBePositive, mustBeInteger}
                args.OutputMode {mustBeMember(args.OutputMode, ["last","sequence"])} = "sequence"
                args.Name {mustBeTextScalar}
            end
            this.Name = args.Name;

            % The shared sub-network: one LSTM layer wrapped in an
            % uninitialized dlnetwork. It is initialized together with
            % the outer network.
            layer = lstmLayer(numHiddenUnits, OutputMode=args.OutputMode);
            this.Network = dlnetwork(layer, Initialize=false);

            % One output per input channel, all produced by the same LSTM.
            this.NumOutputs = numOutputs;
            this.OutputNames = "out" + (1:numOutputs);
        end

        function varargout = predict(this, X)
            % Apply the shared LSTM to each input channel in turn, so every
            % output is computed with the same weights and biases.
            varargout = cell(1, this.NumOutputs);
            for n = 1:this.NumOutputs
                varargout{n} = predict(this.Network, X(n,:,:));
            end
        end
    end
end
Using this layer we can construct the network as follows:
numInputChannels = 3;
numHiddenUnits = 64;

layers = [
    sequenceInputLayer(numInputChannels)
    commonLSTMLayer(numHiddenUnits, numInputChannels, OutputMode="last", Name="lstm")
    fullyConnectedLayer(2, Name="fc1")
    concatenationLayer(1, 3, Name="cat")
    regressionLayer() ];

lg = layerGraph(layers);
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc2"));
lg = addLayers(lg, fullyConnectedLayer(2, Name="fc3"));
lg = connectLayers(lg, "lstm/out2", "fc2");
lg = connectLayers(lg, "lstm/out3", "fc3");
lg = connectLayers(lg, "fc2", "cat/in2");
lg = connectLayers(lg, "fc3", "cat/in3");
analyzeNetwork(lg)
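To confirm that the branches really do share one set of parameters, you can train the graph and inspect the nested dlnetwork of the custom layer. Below is a minimal sketch using synthetic sequence-to-one data (100 random 3-channel sequences, 6 responses to match the concatenated fully connected outputs); the data, training options, and variable names are illustrative assumptions, not part of the original answer:
% Hypothetical synthetic data: 3-channel sequences, 6 regression responses
% (three fully connected branches with 2 outputs each, concatenated).
numObservations = 100;
seqLength = 20;
XTrain = cell(numObservations, 1);
for i = 1:numObservations
    XTrain{i} = randn(numInputChannels, seqLength);
end
YTrain = randn(numObservations, 6);

options = trainingOptions("adam", MaxEpochs=5, Verbose=false);
net = trainNetwork(XTrain, YTrain, lg, options);

% The shared weights are stored once, inside the nested dlnetwork of the
% custom layer, so every branch uses the same parameters.
sharedLSTM = net.Layers(2);     % the commonLSTMLayer
sharedLSTM.Network.Learnables
The Learnables table should list a single InputWeights, RecurrentWeights, and Bias entry for the nested LSTM, which is the one set of parameters used by all three branches.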
