Add more options to gruLayer's GateActivationFunction
Greetings,
I am trying to train a GRU RNN as a regression network to estimate a time-series signal. I want to see the effect of changing GateActivationFunction, but only two options are available: 'sigmoid' and 'hard-sigmoid'. I tried to add two more options, 'tanh' and 'radbasn', to the files GRULayer, gruForwardGeneral, and gruLayer (the modified files are listed below). However, when I run
gruLayer(Hiddenlayers1,'Name','gru1','OutputMode','sequence','StateActivationFunction','tanh','GateActivationFunction','tanh')
I still get an error saying that only the two original options are allowed. Is there any way to add more options? If so, which other files must I edit?
Thanks,
Hamza Al Kouzbary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%gruLayer%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function layer = gruLayer(varargin)
%gruLayer Gated recurrent unit layer
%
% layer = gruLayer(numHiddenUnits) creates a Gated Recurrent Unit layer.
% numHiddenUnits is the number of hidden units in the layer, specified as
% a positive integer.
%
% layer = gruLayer(numHiddenUnits, 'PARAM1', VAL1, 'PARAM2', VAL2, ...)
% specifies optional parameter name/value pairs for creating the layer:
%
% 'Name' - Name for the layer, specified
% as a character vector or a
% string. The default value is
% ''.
% 'InputWeights' - Input weights, specified by a
% 3*numHiddenUnits-by-D matrix or
% [], where D is the number of
% features of the input data. The
% default is [].
% 'RecurrentWeights' - Recurrent weights, specified as
% a 3*numHiddenUnits-by-
% numHiddenUnits
% matrix or []. The default is
% [].
% 'Bias' - Layer biases, specified as a
% 3*numHiddenUnits-by-1 vector, a
% 6*numHiddenUnits-by-1 vector,
% or []. The default is [].
% 'HiddenState' - Initial hidden state, specified
% as a numHiddenUnits-by-1 vector
% or []. The default is [].
% 'OutputMode' - The format of the output of the
% layer. Options are:
% - 'sequence', to output a
% full sequence.
% - 'last', to output the
% last element only.
% The default value is
% 'sequence'.
% 'StateActivationFunction' - Activation function to update
% the hidden state.
% Options are:
% - 'tanh'
% - 'softsign'
% The default value is 'tanh'.
% 'GateActivationFunction' - Activation function to apply to
% the gates. Options are:
% - 'sigmoid'
% - 'tanh'
% - 'radbasn'
% - 'hard-sigmoid'
% The default value is 'sigmoid'.
% 'InputWeightsLearnRateFactor' - Multiplier for the learning
% rate of the input weights,
% specified as a scalar or a
% three-element vector. The
% default value is 1.
% 'RecurrentWeightsLearnRateFactor' - Multiplier for the learning
% rate of the recurrent weights,
% specified as a scalar or a
% three-element vector. The
% default value is 1.
% 'BiasLearnRateFactor' - Multiplier for the learning
% rate of the bias, specified as
% a scalar or a three-element
% vector. The default value is 1.
% 'InputWeightsL2Factor' - Multiplier for the L2
% regularizer of the input
% weights, specified as a scalar
% or a three-element vector. The
% default value is 1.
% 'RecurrentWeightsL2Factor' - Multiplier for the L2
% regularizer of the recurrent
% weights, specified as a scalar
% or a three-element vector. The
% default value is 1.
% 'BiasL2Factor' - Multiplier for the L2
% regularizer of the bias,
% specified as a scalar or a
% three-element vector. The
% default value is 0.
% 'InputWeightsInitializer' - The function to initialize the
% input weights, specified as
% 'glorot', 'he', 'orthogonal',
% 'narrow-normal', 'zeros',
% 'ones' or a function handle.
% The default is 'glorot'.
% 'RecurrentWeightsInitializer' - The function to initialize the
% recurrent weights, specified as
% 'glorot', 'he', 'orthogonal',
% 'narrow-normal', 'zeros',
% 'ones' or a function handle.
% The default is 'orthogonal'.
% 'BiasInitializer' - The function to initialize the
% bias, specified as 'zeros',
% 'narrow-normal', 'ones' or a
% function handle. The default is
% 'zeros'.
% 'ResetGateMode' - Reset gate mode, specified as
% one of the following:
% - 'after-multiplication',
% apply reset gate after
% matrix multiplication. This
% option uses the cuDNN
% library when running on
% GPU.
% - 'before-multiplication',
% apply reset gate before
% matrix multiplication.
% - 'recurrent-bias-after-multiplication',
% apply reset gate after
% matrix multiplication and
% use recurrent bias.
% The default value is
% 'after-multiplication'.
%
% Example 1:
% % Create a GRU layer with 100 hidden units.
%
% layer = gruLayer(100);
%
% Example 2:
% % Create a GRU layer with 50 hidden units which returns the last
% % output element of the sequence. Manually initialize the recurrent
% % weights from a Gaussian distribution with standard deviation
% % 0.01.
%
% numHiddenUnits = 50;
% layer = gruLayer(numHiddenUnits, 'OutputMode', 'last', ...
% 'RecurrentWeights', randn([3*numHiddenUnits numHiddenUnits])*0.01);
%
% See also nnet.cnn.layer.GRULayer
%
% <a href="matlab:helpview('deeplearning','list_of_layers')">List of Deep Learning Layers</a>
% Copyright 2019-2020 The MathWorks, Inc.
% Parse the input arguments.
varargin = nnet.internal.cnn.layer.util.gatherParametersToCPU(varargin);
args = nnet.cnn.layer.GRULayer.parseInputArguments(varargin{:});
% Create an internal representation of the layer.
internalLayer = nnet.internal.cnn.layer.GRU(args.Name, ...
args.InputSize, ...
args.NumHiddenUnits, ...
true, ...
iGetReturnSequence(args.OutputMode), ...
args.StateActivationFunction, ...
args.GateActivationFunction, ...
args.ResetGateMode);
% Use the internal layer to construct a user visible layer.
layer = nnet.cnn.layer.GRULayer(internalLayer);
% Set learnable parameters, learn rate, L2 factors and initializers.
layer.InputWeights = args.InputWeights;
layer.InputWeightsL2Factor = args.InputWeightsL2Factor;
layer.InputWeightsLearnRateFactor = args.InputWeightsLearnRateFactor;
layer.InputWeightsInitializer = args.InputWeightsInitializer;
layer.RecurrentWeights = args.RecurrentWeights;
layer.RecurrentWeightsL2Factor = args.RecurrentWeightsL2Factor;
layer.RecurrentWeightsLearnRateFactor = args.RecurrentWeightsLearnRateFactor;
layer.RecurrentWeightsInitializer = args.RecurrentWeightsInitializer;
layer.Bias = args.Bias;
layer.BiasL2Factor = args.BiasL2Factor;
layer.BiasLearnRateFactor = args.BiasLearnRateFactor;
layer.BiasInitializer = args.BiasInitializer;
% Set hidden state.
layer.HiddenState = args.HiddenState;
end
function tf = iGetReturnSequence( mode )
tf = true;
if strcmp( mode, 'last' )
tf = false;
end
end
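One way to narrow down which remaining file rejects the new option: gruLayer above forwards GateActivationFunction to nnet.internal.cnn.layer.GRU, which is not among the edited files and may run its own check. The following is a debugging sketch, assuming Deep Learning Toolbox is installed and the edited copies live under matlabroot/toolbox. MATLAB caches toolbox folders, so edits there may not be picked up until the cache is refreshed.

```matlab
% Debugging sketch (assumes Deep Learning Toolbox; edited files under
% matlabroot/toolbox).

% Refresh MATLAB's cache of toolbox files after editing them.
rehash toolboxcache

% List every copy of the relevant files that MATLAB can see, to make sure
% the edited versions are the ones being used.
names = {'gruLayer', 'nnet.cnn.layer.GRULayer', 'nnet.internal.cnn.layer.GRU'};
for k = 1:numel(names)
    paths = which(names{k}, '-all');
    fprintf('%s:\n', names{k});
    for p = 1:numel(paths)
        fprintf('  %s\n', paths{p});
    end
end

% Try the new option and report where it is rejected. The stack may be
% incomplete, because the toolbox rethrows errors with throwAsCaller.
try
    layer = gruLayer(10, 'GateActivationFunction', 'tanh');
    fprintf('Accepted: GateActivationFunction = %s\n', layer.GateActivationFunction);
catch err
    fprintf('Rejected: %s\n', err.message);
    for k = 1:numel(err.stack)
        fprintf('  %s (line %d)\n', err.stack(k).name, err.stack(k).line);
    end
end
```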
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%gruForwardGeneral%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [h, H] = gruForwardGeneral(X,learnable,state,options)
% gruForwardGeneral implementation for gru forward call, see
% https://arxiv.org/abs/1406.1078v1
% Copyright 2019 The MathWorks, Inc.
W = learnable.W;
R = learnable.R;
b = learnable.b;
h0 = state.h0;
% Determine dimensions: number of hidden units, batch size N and
% sequence length T
numHidden = size(R,2);
[~, N, T] = size(X);
% Indexing helpers
[rInd, zInd, hInd] = nnet.internal.cnn.util.gruGateIndices(numHidden);
% Input weights
Wrz = W([rInd,zInd], :);
Wh = W(hInd, :);
% Recurrent weights
Rrz = R([rInd,zInd], :);
Rh = R(hInd, :);
% Biases
brz = b([rInd,zInd], :);
bh = b(hInd, :);
% Pre-allocate hidden state
h = zeros(numHidden, N, T, 'like', X);
if isstring(options.StateActivationFunction) || ischar(options.StateActivationFunction)
stateActivationFunction = iGetStateActivation( options.StateActivationFunction );
elseif isa(options.StateActivationFunction,'function_handle')
stateActivationFunction = options.StateActivationFunction;
end
if isstring(options.GateActivationFunction) || ischar(options.GateActivationFunction)
gateActivationFunction = iGetGateActivation( options.GateActivationFunction );
elseif isa(options.GateActivationFunction,'function_handle')
gateActivationFunction = options.GateActivationFunction;
end
% First iteration of forward loop
% Update r and z gates
rz = gateActivationFunction( Wrz*X(:, :, 1) + Rrz*h0 + brz );
r = rz(rInd, :);
z = rz(zInd, :);
% Compute candidate state hs
hs = stateActivationFunction( Wh*X(:, :, 1) + r.*(Rh*h0) + bh );
% Update hidden state h
h(:, :, 1) = (1 - z).*hs + z.*h0;
% Main forward loop
for tt = 2:T
hIdx = h(:, :, tt-1);
% Update r and z gates
rz = gateActivationFunction( Wrz*X(:, :, tt) + Rrz*hIdx + brz );
r = rz(rInd, :);
z = rz(zInd, :);
% Compute candidate state hs
hs = stateActivationFunction( Wh*X(:, :, tt) + r.*(Rh*hIdx) + bh );
% Update hidden state h
h(:, :, tt) = (1 - z).*hs + z.*hIdx;
end
if options.ReturnLast
H = h;
h = h(:, :, end);
else
H = h(:, :, end);
end
end
%% Helper functions
function act = iGetStateActivation( activation )
switch activation
case 'tanh'
act = @nnet.internal.cnnhost.tanhForward;
case 'softsign'
act = @iSoftSign;
end
end
function act = iGetGateActivation( activation )
switch activation
case 'sigmoid'
act = @nnet.internal.cnnhost.sigmoidForward;
case 'hard-sigmoid'
act = @nnet.internal.cnnhost.hardSigmoidForward;
case 'tanh'
act = @nnet.internal.cnnhost.tanhForward;
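case 'radbasn'
% Normalized radial basis function. This previously mapped to the
% hard sigmoid by mistake. Because rz stacks the r and z
% pre-activations, the normalization runs over both gates.
act = @iRadBasN;
end
end
%% Activation functions
function y = iRadBasN(x)
% exp(-x.^2), normalized so that each column (observation) sums to 1
a = exp(-x.^2);
y = a./sum(a, 1);
end
function y = iSoftSign(x)
y = x./(1 + abs(x));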
end
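For reference, this is the recurrence that gruForwardGeneral implements (reset gate applied after the recurrent matrix multiplication), written with the same names as the code: $\sigma_g$ is GateActivationFunction, $\sigma_s$ is StateActivationFunction, $W$, $R$, $b$ are split into the r, z and h rows by gruGateIndices, and $\odot$ is element-wise multiplication. It shows where a new gate activation enters: a tanh gate can be negative, so $1 - z_t$ can exceed 1, and because the code applies $\sigma_g$ to the stacked r and z pre-activations in one call, a non-element-wise function such as radbasn would normalize across both gates.

```latex
\begin{aligned}
r_t &= \sigma_g\left(W_r x_t + R_r h_{t-1} + b_r\right) \\
z_t &= \sigma_g\left(W_z x_t + R_z h_{t-1} + b_z\right) \\
\tilde{h}_t &= \sigma_s\left(W_h x_t + r_t \odot \left(R_h h_{t-1}\right) + b_h\right) \\
h_t &= \left(1 - z_t\right) \odot \tilde{h}_t + z_t \odot h_{t-1}
\end{aligned}
```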
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%GRULayer%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
classdef GRULayer < nnet.cnn.layer.Layer & nnet.internal.cnn.layer.Externalizable
% GRULayer Gated Recurrent Unit (GRU) layer
%
% To create a GRU layer, use gruLayer.
%
% GRULayer properties:
% Name - Name of the layer
% InputSize - Input size of the layer
% NumHiddenUnits - Number of hidden units in the layer
% OutputMode - Output as sequence or last
% StateActivationFunction - Activation function to
% update the hidden state
% GateActivationFunction - Activation function to
% apply to the gates
% ResetGateMode - Reset gate
% mode. Apply reset gate
% before or after matrix
% multiplication, and with or
% without recurrent bias.
% NumInputs - The number of inputs for
% the layer.
% InputNames - The names of the inputs of
% the layer.
% NumOutputs - The number of outputs of
% the layer.
% OutputNames - The names of the outputs of
% the layer.
%
% Properties for learnable parameters:
% InputWeights - Input weights
% InputWeightsInitializer - The function for
% initializing the input
% weights.
% InputWeightsLearnRateFactor - Learning rate multiplier
% for the input weights
% InputWeightsL2Factor - L2 multiplier for the
% input weights
%
% RecurrentWeights - Recurrent weights
% RecurrentWeightsInitializer - The function for
% initializing the recurrent
% weights.
% RecurrentWeightsLearnRateFactor - Learning rate multiplier
% for the recurrent weights
% RecurrentWeightsL2Factor - L2 multiplier for the
% recurrent weights
%
% Bias - Bias vector
% BiasInitializer - The function for
% initializing the bias.
% BiasLearnRateFactor - Learning rate multiplier
% for the bias
% BiasL2Factor - L2 multiplier for the bias
%
% State parameters:
% HiddenState - Hidden state vector
%
% Example:
% Create a Gated Recurrent Unit layer.
%
% layer = gruLayer(10)
%
% See also gruLayer
% Copyright 2019-2020 The MathWorks, Inc.
properties(Dependent)
% Name A name for the layer
% The name for the layer. If this is set to '', then a name will
% be automatically set at training time.
Name
end
properties(SetAccess = private, Dependent)
% InputSize The input size for the layer. If this is set to
% 'auto', then the input size will be automatically set during
% training
InputSize
% NumHiddenUnits The number of hidden units in the layer
NumHiddenUnits
% OutputMode The output format of the layer. If 'sequence',
% output is a sequence. If 'last', the output is the last element
% in a sequence
OutputMode
% StateActivationFunction The activation function to update the
% hidden state. Valid options are 'tanh' or 'softsign'. The default
% is 'tanh'.
StateActivationFunction
% GateActivationFunction The activation function to apply to the
% gates. Valid options are 'sigmoid', 'hard-sigmoid', 'tanh' or
% 'radbasn'. The default is 'sigmoid'.
GateActivationFunction
% ResetGateMode Reset gate mode, specified as one of the
% following:
% 'after-multiplication' - apply reset gate after matrix
% multiplication. With this option the bias has size
% 3*numHiddenUnits-by-1. This option is cuDNN compatible.
% 'before-multiplication' - apply reset gate before matrix
% multiplication. With this option the bias has size
% 3*numHiddenUnits-by-1.
% 'recurrent-bias-after-multiplication' - apply reset gate
% after matrix multiplication and use a recurrent bias. With
% this option the bias has size 6*numHiddenUnits-by-1. This
% option is cuDNN compatible.
ResetGateMode
end
properties(Dependent)
% InputWeights The input weights for the layer
% The input weight matrix for the GRU layer. The input weight
% matrix is a vertical concatenation of the input weight matrices
% for the three components of a GRU, in the following order:
% reset gate, update gate, candidate state. This matrix will have
% size 3*NumHiddenUnits-by-InputSize.
InputWeights
% InputWeightsInitializer The function for initializing the
% input weights.
InputWeightsInitializer
% InputWeightsLearnRateFactor The learning rate factor for the
% input weights
% The learning rate factor for the input weights. This factor is
% multiplied with the global learning rate to determine the
% learning rate for the input weights in this layer. For example,
% if it is set to 2, then the learning rate for the input weights
% in this layer will be twice the current global learning rate.
% To control the value of the learn rate for the three individual
% matrices in the InputWeights, a 1-by-3 vector can be assigned.
InputWeightsLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% InputWeightsL2Factor The L2 regularization factor for the input
% weights
% The L2 regularization factor for the input weights. This factor
% is multiplied with the global L2 regularization setting to
% determine the L2 regularization setting for the input weights
% in this layer. For example, if it is set to 2, then the L2
% regularization for the input weights in this layer will be
% twice the global L2 regularization setting. To control the
% value of the L2 factor for the three individual matrices in the
% InputWeights, a 1-by-3 vector can be assigned.
InputWeightsL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% RecurrentWeights The recurrent weights for the layer
% The recurrent weight matrix for the GRU layer. The recurrent
% weight matrix is a vertical concatenation of the recurrent
% weight matrices for the three components of a GRU, in the
% following order: reset gate, update gate, candidate state. This
% matrix will have size 3*NumHiddenUnits-by-NumHiddenUnits.
RecurrentWeights
% RecurrentWeightsInitializer The function for initializing the
% recurrent weights.
RecurrentWeightsInitializer
% RecurrentWeightsLearnRateFactor The learning rate factor for
% the recurrent weights
% The learning rate factor for the recurrent weights. This factor
% is multiplied with the global learning rate to determine the
% learning rate for the recurrent weights in this layer. For
% example, if it is set to 2, then the learning rate for the
% recurrent weights in this layer will be twice the current
% global learning rate. To control the value of the learn rate
% for the three individual matrices in the RecurrentWeights, a
% 1-by-3 vector can be assigned.
RecurrentWeightsLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% RecurrentWeightsL2Factor The L2 regularization factor for the
% recurrent weights
% The L2 regularization factor for the recurrent weights. This
% factor is multiplied with the global L2 regularization setting
% to determine the L2 regularization setting for the recurrent
% weights in this layer. For example, if it is set to 2, then the
% L2 regularization for the recurrent weights in this layer will
% be twice the global L2 regularization setting. To control the
% value of the L2 factor for the three individual matrices in the
% RecurrentWeights, a 1-by-3 vector can be assigned.
RecurrentWeightsL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% Bias The biases for the layer
% The bias vector for the GRU layer. The bias vector is a
% concatenation of the bias vectors for the three components of a
% GRU, in the following order: reset gate, update gate, candidate
% state. This vector has size 3*NumHiddenUnits-by-1, or
% 6*NumHiddenUnits-by-1 when ResetGateMode is
% 'recurrent-bias-after-multiplication'.
Bias
% BiasInitializer The function for initializing the bias
BiasInitializer
% BiasLearnRateFactor The learning rate factor for the biases
% The learning rate factor for the bias. This factor is
% multiplied with the global learning rate to determine the
% learning rate for the bias in this layer. For example, if it is
% set to 2, then the learning rate for the bias in this layer
% will be twice the current global learning rate. To control the
% value of the learn rate for the three individual vectors in the
% Bias, a 1-by-3 vector can be assigned.
BiasLearnRateFactor (1,:) {mustBeNumeric, iCheckFactorDimensions}
% BiasL2Factor The L2 regularization factor for the biases
% The L2 regularization factor for the biases. This factor is
% multiplied with the global L2 regularization setting to
% determine the L2 regularization setting for the biases in this
% layer. For example, if it is set to 2, then the L2
% regularization for the biases in this layer will be twice the
% global L2 regularization setting. To control the value of the
% L2 factor for the three individual vectors in the Bias, a
% 1-by-3 vector can be assigned.
BiasL2Factor (1,:) {mustBeNumeric, iCheckFactorDimensions}
end
properties(Dependent)
% HiddenState The initial value of the hidden state.
% The initial value of the hidden state. This vector will have
% size NumHiddenUnits-by-1. Setting this value sets the default
% value to which the hidden state is reset when calling the
% resetState method of SeriesNetwork.
HiddenState
end
properties(SetAccess = private, Hidden, Dependent)
% OutputSize The number of hidden units in the layer. See
% NumHiddenUnits.
OutputSize
% OutputState The hidden state of the layer. See HiddenState.
OutputState
end
methods
function this = GRULayer(privateLayer)
this.PrivateLayer = privateLayer;
end
function val = get.Name(this)
val = this.PrivateLayer.Name;
end
function this = set.Name(this, val)
iAssertValidLayerName(val);
this.PrivateLayer.Name = char(val);
end
function val = get.InputSize(this)
val = this.PrivateLayer.InputSize;
if isempty(val)
val = 'auto';
end
end
function val = get.NumHiddenUnits(this)
val = this.PrivateLayer.HiddenSize;
end
function val = get.OutputMode(this)
val = iGetOutputMode( this.PrivateLayer.ReturnSequence );
end
function val = get.StateActivationFunction(this)
val = this.PrivateLayer.Activation;
end
function val = get.GateActivationFunction(this)
val = this.PrivateLayer.RecurrentActivation;
end
function val = get.InputWeights(this)
val = this.PrivateLayer.InputWeights.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function val = get.ResetGateMode(this)
val = this.PrivateLayer.ResetGateMode;
end
function this = set.InputWeights(this, value)
if isequal(this.InputSize, 'auto')
expectedInputSize = NaN;
else
expectedInputSize = this.InputSize;
end
attributes = {'size', [3*this.NumHiddenUnits expectedInputSize],...
'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
if ~isempty(value)
this.PrivateLayer = this.PrivateLayer.configureForInputs( ...
{iMakeSizeOnlyArray([size(value,2) NaN NaN],'CBT')} );
end
this.PrivateLayer.InputWeights.Value = value;
end
function val = get.InputWeightsInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.InputWeights.Initializer)
val = this.PrivateLayer.InputWeights.Initializer.Fcn;
else
val = this.PrivateLayer.InputWeights.Initializer.Name;
end
end
function this = set.InputWeightsInitializer(this, value)
value = iAssertValidWeightsInitializer(value, 'InputWeightsInitializer');
% Create the initializer with in and out indices of the weights
% size: 3*NumHiddenUnits-by-InputSize
this.PrivateLayer.InputWeights.Initializer = ...
iInitializerFactory(value, 2, 1);
end
function val = get.RecurrentWeights(this)
val = this.PrivateLayer.RecurrentWeights.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function this = set.RecurrentWeights(this, value)
attributes = {'size', [3*this.NumHiddenUnits this.NumHiddenUnits],...
'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
this.PrivateLayer.RecurrentWeights.Value = value;
end
function val = get.RecurrentWeightsInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.RecurrentWeights.Initializer)
val = this.PrivateLayer.RecurrentWeights.Initializer.Fcn;
else
val = this.PrivateLayer.RecurrentWeights.Initializer.Name;
end
end
function this = set.RecurrentWeightsInitializer(this, value)
value = iAssertValidWeightsInitializer(value, 'RecurrentWeightsInitializer');
% Create the initializer with in and out indices of the weights
% size: 3*NumHiddenUnits-by-NumHiddenUnits
this.PrivateLayer.RecurrentWeights.Initializer = ...
iInitializerFactory(value, 2, 1);
end
function val = get.Bias(this)
val = this.PrivateLayer.Bias.HostValue;
if isa(val, 'dlarray')
val = extractdata(val);
end
end
function this = set.Bias(this, value)
biasnrowfactor = 1 + double(isequal(this.ResetGateMode, ...
'recurrent-bias-after-multiplication'));
attributes = {'column', 'real', 'nonsparse'};
value = iGatherAndValidateParameter(value, attributes);
expectedSize = 3*biasnrowfactor*this.NumHiddenUnits;
% Valid input value is empty or has size either
% 3*NumHiddenUnits, if 'ResetGateMode' is
% 'after-multiplication' or 'before-multiplication', or
% 6*NumHiddenUnits, if 'ResetGateMode' is
% 'recurrent-bias-after-multiplication'.
if length(value)~=expectedSize && ~isequal(value,[])
error(message('nnet_cnn:layer:GRULayer:BiasSize',...
3*biasnrowfactor,this.ResetGateMode));
end
this.PrivateLayer.Bias.Value = value;
end
function val = get.BiasInitializer(this)
if iIsCustomInitializer(this.PrivateLayer.Bias.Initializer)
val = this.PrivateLayer.Bias.Initializer.Fcn;
else
val = this.PrivateLayer.Bias.Initializer.Name;
end
end
function this = set.BiasInitializer(this, value)
value = iAssertValidBiasInitializer(value);
% The bias initializer needs to know the recurrent layer type
this.PrivateLayer.Bias.Initializer = iInitializerFactory(value,...
'GRU');
end
function val = get.HiddenState(this)
val = gather(this.PrivateLayer.HiddenState.Value);
end
function this = set.HiddenState(this, value)
value = iGatherAndValidateParameter(value, 'default', [this.NumHiddenUnits 1]);
this.PrivateLayer.InitialHiddenState = value;
this.PrivateLayer.HiddenState.Value = value;
end
function val = get.InputWeightsLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.InputWeights.LearnRateFactor);
end
function this = set.InputWeightsLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.InputWeights.LearnRateFactor = this.setFactor(val);
end
function val = get.InputWeightsL2Factor(this)
val = this.getFactor(this.PrivateLayer.InputWeights.L2Factor);
end
function this = set.InputWeightsL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.InputWeights.L2Factor = this.setFactor(val);
end
function val = get.RecurrentWeightsLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.RecurrentWeights.LearnRateFactor);
end
function this = set.RecurrentWeightsLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.RecurrentWeights.LearnRateFactor = this.setFactor(val);
end
function val = get.RecurrentWeightsL2Factor(this)
val = this.getFactor(this.PrivateLayer.RecurrentWeights.L2Factor);
end
function this = set.RecurrentWeightsL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.RecurrentWeights.L2Factor = this.setFactor(val);
end
function val = get.BiasLearnRateFactor(this)
val = this.getFactor(this.PrivateLayer.Bias.LearnRateFactor);
end
function this = set.BiasLearnRateFactor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.Bias.LearnRateFactor = this.setFactor(val);
end
function val = get.BiasL2Factor(this)
val = this.getFactor(this.PrivateLayer.Bias.L2Factor);
end
function this = set.BiasL2Factor(this, val)
val = gather(val);
iAssertValidFactor(val)
this.PrivateLayer.Bias.L2Factor = this.setFactor(val);
end
function val = get.OutputSize(this)
val = this.NumHiddenUnits;
end
function val = get.OutputState(this)
val = this.HiddenState;
end
function out = saveobj(this)
privateLayer = this.PrivateLayer;
out.Version = 1.0;
out.Name = privateLayer.Name;
out.InputSize = privateLayer.InputSize;
out.NumHiddenUnits = privateLayer.HiddenSize;
out.ReturnSequence = privateLayer.ReturnSequence;
out.ResetGateMode = privateLayer.ResetGateMode;
out.StateActivationFunction = privateLayer.Activation;
out.GateActivationFunction = privateLayer.RecurrentActivation;
out.InputWeights = toStruct(privateLayer.InputWeights);
out.RecurrentWeights = toStruct(privateLayer.RecurrentWeights);
out.Bias = toStruct(privateLayer.Bias);
out.HiddenState = toStruct(privateLayer.HiddenState);
out.InitialHiddenState = gather(privateLayer.InitialHiddenState);
end
end
methods(Static)
function inputArguments = parseInputArguments(varargin)
parser = iCreateParser();
parser.parse(varargin{:});
inputArguments = iConvertToCanonicalForm(parser);
inputArguments.InputSize = [];
end
function this = loadobj(in)
internalLayer = nnet.internal.cnn.layer.GRU( in.Name, ...
in.InputSize, ...
in.NumHiddenUnits, ...
true, ...
in.ReturnSequence, ...
in.StateActivationFunction, ...
in.GateActivationFunction, ...
in.ResetGateMode );
internalLayer.InputWeights = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.InputWeights);
internalLayer.RecurrentWeights = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.RecurrentWeights);
internalLayer.Bias = nnet.internal.cnn.layer.learnable.PredictionLearnableParameter.fromStruct(in.Bias);
internalLayer.HiddenState = nnet.internal.cnn.layer.dynamic.TrainingDynamicParameter.fromStruct(in.HiddenState);
internalLayer.InitialHiddenState = in.InitialHiddenState;
this = nnet.cnn.layer.GRULayer(internalLayer);
end
end
methods(Hidden, Access = protected)
function [description, type] = getOneLineDisplay(obj)
description = iGetMessageString( ...
'nnet_cnn:layer:GRULayer:oneLineDisplay', ...
num2str(obj.NumHiddenUnits));
type = iGetMessageString( 'nnet_cnn:layer:GRULayer:Type' );
end
function groups = getPropertyGroups( this )
generalParameters = { 'Name' };
hyperParameters = { 'InputSize', ...
'NumHiddenUnits', ...
'OutputMode', ...
'StateActivationFunction', ...
'GateActivationFunction', ...
'ResetGateMode'};
learnableParameters = { 'InputWeights', ...
'RecurrentWeights', ...
'Bias' };
stateParameters = { 'HiddenState' };
groups = [
this.propertyGroupGeneral( generalParameters )
this.propertyGroupHyperparameters( hyperParameters )
this.propertyGroupLearnableParameters( learnableParameters )
this.propertyGroupDynamicParameters( stateParameters )
];
end
function footer = getFooter( this )
variableName = inputname(1);
footer = this.createShowAllPropertiesFooter( variableName );
end
function val = getFactor(this, val)
if isscalar(val)
% No operation needed
elseif numel(val) == (3*this.NumHiddenUnits)
val = val(1:this.NumHiddenUnits:end);
val = val(:)';
else
% Error - the factor has incorrect size
end
end
function val = setFactor(this, val)
if isscalar(val)
% No operation needed
elseif numel(val) == 3
% Expand a three-element vector into a 3*NumHiddenUnits-by-1
% column vector
expandedValues = repelem( val, this.NumHiddenUnits );
val = expandedValues(:);
else
% Error - the factor has incorrect size
end
end
end
end
function messageString = iGetMessageString( varargin )
messageString = getString( message( varargin{:} ) );
end
function p = iCreateParser()
p = inputParser;
defaultName = '';
defaultOutputMode = 'sequence';
defaultStateActivationFunction = 'tanh';
defaultGateActivationFunction = 'sigmoid';
defaultWeightLearnRateFactor = 1;
defaultBiasLearnRateFactor = 1;
defaultWeightL2Factor = 1;
defaultBiasL2Factor = 0;
defaultInputWeightsInitializer = 'glorot';
defaultRecurrentWeightsInitializer = 'orthogonal';
defaultBiasInitializer = 'zeros';
defaultLearnable = [];
defaultState = [];
defaultResetGateMode = 'after-multiplication';
p.addRequired('NumHiddenUnits', @(x)validateattributes(x, {'numeric'}, {'scalar', 'positive', 'integer'}));
p.addParameter('Name', defaultName, @nnet.internal.cnn.layer.paramvalidation.validateLayerName);
p.addParameter('OutputMode', defaultOutputMode, @(x)any(iAssertAndReturnValidOutputMode(x)));
p.addParameter('StateActivationFunction', defaultStateActivationFunction, @(x)any(iAssertAndReturnValidStateActivation(x)));
p.addParameter('GateActivationFunction', defaultGateActivationFunction, @(x)any(iAssertAndReturnValidGateActivation(x)));
p.addParameter('InputWeightsLearnRateFactor', defaultWeightLearnRateFactor, @(x)iAssertValidFactor(x));
p.addParameter('RecurrentWeightsLearnRateFactor', defaultWeightLearnRateFactor,@(x)iAssertValidFactor(x));
p.addParameter('BiasLearnRateFactor', defaultBiasLearnRateFactor,@(x)iAssertValidFactor(x));
p.addParameter('InputWeightsL2Factor', defaultWeightL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('RecurrentWeightsL2Factor', defaultWeightL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('BiasL2Factor', defaultBiasL2Factor, @(x)iAssertValidFactor(x));
p.addParameter('InputWeightsInitializer', defaultInputWeightsInitializer);
p.addParameter('RecurrentWeightsInitializer', defaultRecurrentWeightsInitializer);
p.addParameter('BiasInitializer', defaultBiasInitializer);
p.addParameter('InputWeights', defaultLearnable);
p.addParameter('RecurrentWeights', defaultLearnable);
p.addParameter('Bias', defaultLearnable);
p.addParameter('HiddenState', defaultState);
p.addParameter('ResetGateMode', defaultResetGateMode, @(x)any(iAssertAndReturnValidResetGateMode(x)));
end
function inputArguments = iConvertToCanonicalForm(parser)
results = parser.Results;
inputArguments = struct;
inputArguments.NumHiddenUnits = double( results.NumHiddenUnits );
inputArguments.Name = convertStringsToChars(results.Name);
inputArguments.OutputMode = iAssertAndReturnValidOutputMode(results.OutputMode);
inputArguments.StateActivationFunction = iAssertAndReturnValidStateActivation(convertStringsToChars(results.StateActivationFunction));
inputArguments.GateActivationFunction = iAssertAndReturnValidGateActivation(convertStringsToChars(results.GateActivationFunction));
inputArguments.InputWeightsLearnRateFactor = results.InputWeightsLearnRateFactor;
inputArguments.RecurrentWeightsLearnRateFactor = results.RecurrentWeightsLearnRateFactor;
inputArguments.BiasLearnRateFactor = results.BiasLearnRateFactor;
inputArguments.InputWeightsL2Factor = results.InputWeightsL2Factor;
inputArguments.RecurrentWeightsL2Factor = results.RecurrentWeightsL2Factor;
inputArguments.BiasL2Factor = results.BiasL2Factor;
inputArguments.InputWeightsInitializer = results.InputWeightsInitializer;
inputArguments.RecurrentWeightsInitializer = results.RecurrentWeightsInitializer;
inputArguments.BiasInitializer = results.BiasInitializer;
inputArguments.InputWeights = results.InputWeights;
inputArguments.RecurrentWeights = results.RecurrentWeights;
inputArguments.Bias = results.Bias;
inputArguments.HiddenState = results.HiddenState;
inputArguments.ResetGateMode = iAssertAndReturnValidResetGateMode(results.ResetGateMode);
end
function mode = iGetOutputMode( tf )
if tf
mode = 'sequence';
else
mode = 'last';
end
end
function iCheckFactorDimensions( value )
dim = numel( value );
if ~(dim == 1 || dim == 3)
exception = MException(message('nnet_cnn:layer:GRULayer:InvalidFactor'));
throwAsCaller(exception);
end
end
function validString = iAssertAndReturnValidOutputMode(value)
validString = validatestring(value, {'sequence', 'last'});
end
function validString = iAssertAndReturnValidStateActivation(value)
validString = validatestring(value, {'tanh', 'softsign'});
end
function validString = iAssertAndReturnValidGateActivation(value)
validString = validatestring(value, {'sigmoid','tanh', 'hard-sigmoid','radbasn'});
end
function iAssertValidFactor(value)
validateattributes(value, {'numeric'}, {'vector', 'real', 'nonnegative', 'finite'});
end
function value = iAssertValidWeightsInitializer(value, name)
validateattributes(value, {'function_handle','char','string'}, {});
if(ischar(value) || isstring(value))
value = validatestring(value, {'narrow-normal', ...
'glorot', ...
'he', ...
'orthogonal', ...
'zeros', ...
'ones'}, '', name);
end
end
function value = iAssertValidBiasInitializer(value)
validateattributes(value, {'function_handle','char','string'}, {});
if(ischar(value) || isstring(value))
value = validatestring(value, {'zeros', ...
'narrow-normal', ...
'ones'});
end
end
function initializer = iInitializerFactory(varargin)
initializer = nnet.internal.cnn.layer.learnable.initializer...
.initializerFactory(varargin{:});
end
function tf = iIsCustomInitializer(init)
tf = isa(init, 'nnet.internal.cnn.layer.learnable.initializer.Custom');
end
function iAssertValidLayerName(name)
iEvalAndThrow(@()...
nnet.internal.cnn.layer.paramvalidation.validateLayerName(name));
end
function iEvalAndThrow(func)
% Omit the stack containing internal functions by throwing as caller
try
func();
catch exception
throwAsCaller(exception)
end
end
function value = iGatherAndValidateParameter(varargin)
try
value = nnet.internal.cnn.layer.paramvalidation...
.gatherAndValidateNumericParameter(varargin{:});
catch exception
throwAsCaller(exception)
end
end
function value = iAssertAndReturnValidResetGateMode(value)
value = validatestring(value, {'after-multiplication', 'before-multiplication', 'recurrent-bias-after-multiplication'});
end
function dlX = iMakeSizeOnlyArray(varargin)
dlX = deep.internal.PlaceholderArray(varargin{:});
end
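iAssertAndReturnValidGateActivation above now accepts four names, but the values they produce as gates differ a lot. The sketch below compares them with plain anonymous functions. These are my own definitions, not the toolbox internals: hard-sigmoid is assumed to be 0.2x + 0.5 clipped to [0, 1], and radbasn is assumed to be the normalized radial basis exp(-x.^2) divided by its column sum.

```matlab
% Plain-MATLAB versions of the four gate activations (assumed forms).
% All are element-wise except radbasn, which normalizes each column.
gates = struct( ...
    'sigmoid',     @(x) 1./(1 + exp(-x)), ...
    'hardSigmoid', @(x) max(0, min(1, 0.2*x + 0.5)), ...
    'tanh',        @tanh, ...
    'radbasn',     @(x) exp(-x.^2)./sum(exp(-x.^2), 1));

x = [-3; -1; 0; 1; 3];   % example gate pre-activations, one observation
names = fieldnames(gates);
for k = 1:numel(names)
    f = gates.(names{k});
    g = f(x);
    fprintf('%-12s min %7.4f  max %7.4f  sum %7.4f\n', ...
        names{k}, min(g), max(g), sum(g));
end
```

Two things stand out: tanh gives negative gate values, so 1 - z can exceed 1 in the state update, and radbasn outputs always sum to 1 across units, so each unit's gate shrinks as the number of hidden units grows.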
Answers (1)
Ben
on 13 Mar 2023
I would recommend implementing this extended GRU layer as a custom layer following this example:
You may be able to follow the code you have found in gruForwardGeneral to do this.
It is not recommended that you try to modify the source code directly.
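A rough sketch of the custom-layer route, based on the equations in gruForwardGeneral (this is not the example linked in the answer, which is not reproduced here). gruCustomGateLayer is a hypothetical name. It assumes the layer receives unformatted C-by-N-by-T input, as custom layers that do not inherit nnet.layer.Formattable do for vector sequences, and it uses a simple random initialization instead of the built-in initializers. Because predict only uses dlarray-compatible operations, gradients should come from automatic differentiation without a backward method. Save it as gruCustomGateLayer.m.

```matlab
classdef gruCustomGateLayer < nnet.layer.Layer
    % gruCustomGateLayer  Hypothetical GRU layer with a selectable gate
    % activation function. Sketch only: it always outputs the full
    % sequence, starts every sequence from a zero hidden state, and uses
    % the 'after-multiplication' reset gate, as in gruForwardGeneral.

    properties
        NumHiddenUnits          % number of hidden units H
        GateActivationFunction  % 'sigmoid', 'hard-sigmoid', 'tanh' or 'radbasn'
    end

    properties (Learnable)
        InputWeights      % 3H-by-C, rows: [reset; update; candidate]
        RecurrentWeights  % 3H-by-H, same row order
        Bias              % 3H-by-1, same row order
    end

    methods
        function layer = gruCustomGateLayer(inputSize, numHiddenUnits, gateFcn, name)
            layer.Name = name;
            layer.NumHiddenUnits = numHiddenUnits;
            layer.GateActivationFunction = validatestring(gateFcn, ...
                {'sigmoid', 'hard-sigmoid', 'tanh', 'radbasn'});
            H = numHiddenUnits;
            % Small random weights and zero bias (an assumption; the
            % built-in layer defaults to glorot/orthogonal/zeros).
            layer.InputWeights = 0.01*randn(3*H, inputSize, 'single');
            layer.RecurrentWeights = 0.01*randn(3*H, H, 'single');
            layer.Bias = zeros(3*H, 1, 'single');
        end

        function Y = predict(layer, X)
            % X is a C-by-N-by-T array (channels, observations, time steps).
            H = layer.NumHiddenUnits;
            rInd = 1:H;
            zInd = H+1:2*H;
            hInd = 2*H+1:3*H;
            W = layer.InputWeights;
            R = layer.RecurrentWeights;
            b = layer.Bias;
            gate = iGateActivation(layer.GateActivationFunction);

            [~, N, T] = size(X);
            h = zeros(H, N, 'like', X);     % zero initial hidden state
            Y = zeros(H, N, T, 'like', X);  % output for every time step
            for t = 1:T
                x = X(:, :, t);
                % Reset and update gates use the selectable activation
                r = gate(W(rInd,:)*x + R(rInd,:)*h + b(rInd));
                z = gate(W(zInd,:)*x + R(zInd,:)*h + b(zInd));
                % Candidate state keeps tanh as the state activation
                hs = tanh(W(hInd,:)*x + r.*(R(hInd,:)*h) + b(hInd));
                h = (1 - z).*hs + z.*h;
                Y(:, :, t) = h;
            end
        end
    end
end

function f = iGateActivation(name)
% Map the option name to a gate function. All are element-wise except
% radbasn, which normalizes over the H units of one gate per observation.
switch name
    case 'sigmoid'
        f = @(x) 1./(1 + exp(-x));
    case 'hard-sigmoid'
        f = @(x) max(0, min(1, 0.2*x + 0.5));
    case 'tanh'
        f = @tanh;
    case 'radbasn'
        f = @(x) exp(-x.^2)./sum(exp(-x.^2), 1);
end
end
```

Unlike the patched gruForwardGeneral, this sketch applies the gate function to r and z separately, so radbasn normalizes within each gate. For the regression in the question, the layer could replace gruLayer in a layer array, for example between sequenceInputLayer(1) and fullyConnectedLayer(1), followed by regressionLayer.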