MATLAB Finding Output Weight Matrix of a Recurrent Neural Network (RNN) With Stochastic Gradient Descent (SGD)

I'm trying to find the output weight matrix of a recurrent neural network. I currently use the following linear regression formula:
Wout = pinv(r)*TD
where r is my RNN state matrix and TD is my training data set matrix. pinv is the pseudoinverse operation. r is a D by t matrix, where D is the one-dimensional size of my RNN and t is the number of time steps I am simulating. TD is a t by N matrix, where N is the number of training data collections in my training data set.
My training data is too large and is producing a bunch of NaNs and zeros in Wout. Rather than using linear regression, I would like to use stochastic gradient descent (SGD) to find Wout. What is the best way to accomplish this in MATLAB?
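For reference, here is a minimal sketch of the regression on small synthetic data. It assumes Wout is meant to be D by N so that Wout' * r approximates TD', which requires applying the pseudoinverse to r' rather than r for the stated shapes; that transpose, the sizes, and the name WoutTrue are illustrative assumptions, not taken from the question.

```matlab
% Synthetic stand-ins with the shapes described in the question
rng(0);
D = 4; t = 50; N = 3;
r = randn(D, t);            % state matrix, D by t
WoutTrue = randn(D, N);     % hypothetical "true" weights used to build TD
TD = r' * WoutTrue;         % training data, t by N

% Least-squares solution via the pseudoinverse of r' (t by D)
Wout = pinv(r') * TD;       % D by N

% Backslash solves the same least-squares problem without forming pinv
WoutLS = r' \ TD;           % D by N
```

On large or ill-conditioned problems the backslash form is often the cheaper of the two, though it does not by itself remove conditioning problems.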

Accepted Answer

SOUMNATH PAUL on 29 Nov 2023
As I understand it, you are computing the output weight matrix of an RNN using linear regression, the result contains NaNs and zeros, and you would like to solve for it with SGD instead.
Here are some steps you can follow to implement SGD for the output weight matrix 'Wout' of an RNN in MATLAB.
The idea is to adjust 'Wout' iteratively, taking small steps in the direction that reduces the error between the RNN's predictions and the actual training data.
  1. Initialize 'Wout'; you can start from a random or zero matrix.
  2. Loop over batches, i.e. divide your training data into small batches.
  3. For each batch, compute the RNN's predictions from the batch of states (forward pass).
  4. Measure the error between the RNN's predictions and the actual data.
  5. Compute the gradient of the error with respect to 'Wout' (backward pass).
  6. Adjust 'Wout' by a small step in the direction opposite to the gradient.
  7. Repeat the above until the error is sufficiently low or for a fixed number of iterations.
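Steps 4 to 6 can be written out for one batch as follows. This is a sketch under assumptions: the batch holds b time steps, R_B is the D by b block of r, T_B is the matching b by N block of TD, the loss is a halved mean squared error, and η denotes the learning rate. None of these symbols appear in the original post.

```latex
\begin{aligned}
E(W_{\text{out}}) &= \frac{1}{2b}\,\bigl\lVert W_{\text{out}}^{\top} R_B - T_B^{\top} \bigr\rVert_F^{2} \\
\nabla_{W_{\text{out}}} E &= \frac{1}{b}\, R_B \bigl( W_{\text{out}}^{\top} R_B - T_B^{\top} \bigr)^{\top} \\
W_{\text{out}} &\leftarrow W_{\text{out}} - \eta\, \nabla_{W_{\text{out}}} E
\end{aligned}
```

The gradient is D by N, the same size as Wout, so the update needs no extra transpose.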
Here is basic code illustrating SGD for updating 'Wout':
% Assuming 'r' is your state matrix (D by t) and 'TD' is your training data (t by N)
% Reshape 'TD' if it is a vector
if isvector(TD)
    TD = TD(:); % Ensure TD is a column vector (t by 1)
end
% Initialize parameters
D = size(r, 1); % RNN state dimension
N = size(TD, 2); % Number of training data collections
learningRate = 0.01; % This is the step size in the gradient update
numEpochs = 100; % Number of times to go through the entire training data
batchSize = 50; % Size of each batch for training
Wout = randn(D, N); % Initialize Wout randomly
% Perform SGD
for epoch = 1:numEpochs
    for startIdx = 1:batchSize:size(r, 2)
        endIdx = min(startIdx + batchSize - 1, size(r, 2));
        % Extract the batch (b time steps; the last batch may be shorter)
        rBatch = r(:, startIdx:endIdx); % D by b
        TDBatch = TD(startIdx:endIdx, :); % b by N
        % Forward pass: Calculate predictions
        predictions = Wout' * rBatch; % N by b
        % Calculate error for the batch ('err' avoids shadowing the built-in 'error')
        err = predictions - TDBatch'; % N by b
        % Backward pass: Compute gradient, averaged over the actual batch length
        gradWout = rBatch * err' / size(rBatch, 2); % D by N
        % Update Wout (gradWout already has the same size as Wout)
        Wout = Wout - learningRate * gradWout;
    end
    % Optional: Calculate and print total error after each epoch
    totalError = norm(Wout' * r - TD', 'fro')^2;
    fprintf('Epoch %d, Total Error: %f\n', epoch, totalError);
end
% 'Wout' is now trained using SGD
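To sanity-check a loop like this before running it on real data, one option is a small synthetic problem where the answer is known. The sketch below is self-contained and makes several assumptions that are not in the original post: the sizes, the noise-free targets built from a hypothetical WoutTrue, a zero initialization, and a per-epoch shuffle with randperm so that batches are drawn in random order.

```matlab
% Synthetic problem with a known solution (illustrative sizes only)
rng(1);
D = 5; t = 200; N = 2;
r = randn(D, t);             % state matrix, D by t
WoutTrue = randn(D, N);      % hypothetical weights to recover
TD = r' * WoutTrue;          % noise-free training data, t by N

learningRate = 0.05;
numEpochs = 200;
batchSize = 20;
Wout = zeros(D, N);

for epoch = 1:numEpochs
    order = randperm(t);     % visit time steps in a new random order each epoch
    for startIdx = 1:batchSize:t
        idx = order(startIdx:min(startIdx + batchSize - 1, t));
        rBatch = r(:, idx);                    % D by b
        err = Wout' * rBatch - TD(idx, :)';    % N by b
        gradWout = rBatch * err' / numel(idx); % D by N
        Wout = Wout - learningRate * gradWout;
    end
end

fprintf('Max abs deviation from WoutTrue: %g\n', max(abs(Wout(:) - WoutTrue(:))));
```

If the deviation does not shrink on a problem this small, the usual suspects are a transposed operand or a learning rate that is too large for the scale of r.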
Additionally, you can use the Deep Learning Toolbox to train your model directly without writing your own optimization loop. The toolbox documentation on training options covers the settings for SGD.
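As a rough sketch of the toolbox route, a single fully connected layer with no bias update plays the role of 'Wout'. This assumes the Deep Learning Toolbox is installed and that trainNetwork is used with feature input; the synthetic data, the layer choices, and turning off momentum and L2 regularization to mimic plain SGD are all assumptions for illustration, not the original author's setup.

```matlab
% Synthetic stand-ins for the real data (illustrative sizes only)
rng(0);
D = 5; t = 200; N = 2;
r = randn(D, t);                 % state matrix, D by t
TD = r' * randn(D, N);           % training data, t by N

% One linear layer mapping D features to N responses
layers = [
    featureInputLayer(D, 'Normalization', 'none')
    fullyConnectedLayer(N, 'BiasLearnRateFactor', 0)  % keep the bias at its zero initial value
    regressionLayer];

% 'sgdm' with Momentum 0 behaves like plain mini-batch SGD
options = trainingOptions('sgdm', ...
    'Momentum', 0, ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 100, ...
    'MiniBatchSize', 50, ...
    'L2Regularization', 0, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false);

% trainNetwork expects observations in rows, hence r' (t by D)
net = trainNetwork(r', TD, layers, options);

% The layer stores weights as N by D; transpose to get the D by N convention
Wout = net.Layers(2).Weights';
```

The learning rate, epoch count, and batch size mirror the values in the hand-written loop and will likely need tuning for real data.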
Hope it helps!
Regards,
Soumnath
  Comments: 5
SOUMNATH PAUL on 6 Dec 2023
The issue arises because 'gradWout' should be a 2501-by-1 matrix, matching the dimensions of 'Wout'.
I believe the mismatch comes from the way the error and gradient are calculated over the batch.
% Initialize parameters
learningRate = 0.01; % This is the step size in the gradient update
numEpochs = 100; % Number of times to go through the entire training data
batchSize = 50; % Size of each batch for training
Wout = randn(2501, 1); % Initialize Wout as a 2501 x 1 matrix
% Perform SGD
for epoch = 1:numEpochs
    for startIdx = 1:batchSize:size(r, 2)
        endIdx = min(startIdx + batchSize - 1, size(r, 2));
        % Extract the batch (b time steps; the last batch may be shorter)
        rBatch = r(:, startIdx:endIdx); % 2501 x b
        TDBatch = TD(startIdx:endIdx); % Assuming TD is t x 1
        % Forward pass: Calculate predictions
        predictions = Wout' * rBatch; % 1 x b
        % Calculate error for the batch ('err' avoids shadowing the built-in 'error')
        err = predictions - TDBatch'; % 1 x b
        % Backward pass: Compute gradient, averaged over the actual batch length
        gradWout = rBatch * err' / size(rBatch, 2); % (2501 x b) * (b x 1) => 2501 x 1
        % Update Wout
        Wout = Wout - learningRate * gradWout; % 2501 x 1 - 2501 x 1 => 2501 x 1
    end
    % Optional: Calculate and print total error after each epoch
    totalError = norm(Wout' * r - TD', 'fro')^2; % Assuming TD is t x 1
    fprintf('Epoch %d, Total Error: %f\n', epoch, totalError);
end
% 'Wout' is now trained using SGD
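One way to confirm that the gradient has the right size and value is a finite-difference check on a tiny batch. The sketch below assumes a halved mean squared error as the loss; the sizes, the lossFn handle, and the step h are illustrative choices and not part of the thread.

```matlab
% Tiny single-output batch (illustrative sizes only)
rng(0);
D = 4; b = 6;
rBatch = randn(D, b);       % D x b
TDBatch = randn(b, 1);      % b x 1
Wout = randn(D, 1);         % D x 1

% Assumed loss: half the mean squared error over the batch
lossFn = @(W) sum((W' * rBatch - TDBatch').^2) / (2 * b);

% Analytical gradient, same expression as in the SGD loop
err = Wout' * rBatch - TDBatch';    % 1 x b
gradWout = rBatch * err' / b;       % D x 1

% Central finite differences, one coordinate at a time
h = 1e-6;
numGrad = zeros(D, 1);
for k = 1:D
    e = zeros(D, 1);
    e(k) = h;
    numGrad(k) = (lossFn(Wout + e) - lossFn(Wout - e)) / (2 * h);
end

fprintf('Max abs difference: %g\n', max(abs(gradWout - numGrad)));
```

A size mismatch between gradWout and Wout, or a large difference from numGrad, points to a misplaced transpose in the error or gradient line.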
Jonathan Frutschy on 7 Dec 2023
@SOUMNATH PAUL This works for me using N = 1. I was able to get the original code you posted working for any arbitrary N by making three changes:
#1: change error = predictions - TDBatch; to error = predictions' - TDBatch;
#2: change Wout = Wout - learningRate * gradWout'; to Wout = Wout - learningRate * gradWout;
#3: change totalError = norm(WoutSGD' * r - TD, 'fro')^2; to totalError = norm(WoutSGD' * r - TD', 'fro')^2;

Release

R2023b
