Reinitialize a network with weights from previous training

Vasu Sharma on 5 Oct 2023
Answered: Udit06 on 16 Oct 2023
Hi,
I have a custom neural network that I trained on an old data set. I now want to retrain the same model architecture with new data. One option would be to append the new data to the old data and shuffle, but since the new data is somewhat different, I want to use some basic transfer learning instead.
For this purpose, I need to reinitialise my network with the trained weights from the old data, which I have already saved in a MAT file. I want to know how to use these trained weights with the new network, i.e. create a new network (same structure as before), initialise it with the weights from the previously trained network, and retrain it with the new data.
Best Regards and thanks in advance. :)
Networklayer = [...
sequenceInputLayer(featureDimension)
fullyConnectedLayer(4*numHiddenUnits1)
reluLayer
fullyConnectedLayer(4*numHiddenUnits1)
reluLayer
fullyConnectedLayer(8*numHiddenUnits1)
reluLayer
gruLayer(LSTMStateNum,'OutputMode','sequence',InputWeightsInitializer='he',RecurrentWeightsInitializer='he')
fullyConnectedLayer(8*numHiddenUnits1)
reluLayer
fullyConnectedLayer(4*numHiddenUnits1)
reluLayer
fullyConnectedLayer(numResponses)
regressionLayer];

Answers (1)

Udit06 on 16 Oct 2023
Hi Vasu,
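I understand that you want to use the weights of a model trained on an old dataset as the starting point for retraining the same architecture on a new dataset. You can do this as follows:
  1. Create a new network with the same architecture as the old network.
  2. Initialize the weights of the new network with the trained weights from the old network. For a shallow "network" object, the "setwb" function does this from a single weight-and-bias vector:
new_network = setwb(new_network, weights); % Set the weights of the new network
     Note that "setwb" does not apply to a layer array like the one in your question. There, the learned weights are stored in the layers themselves (for example, the Weights and Bias properties of each fullyConnectedLayer), so the Layers property of the previously trained network already carries them.
  3. Train the new network with the new data using the "trainNetwork" function, passing the new data, the initialized layers, and the training options:
new_network = trainNetwork(new_X, new_Y, layers, options); % new_X, new_Y, layers and options are placeholder names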
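For a layer array like the one in the question, the steps above could look roughly like the following. This is only a minimal sketch: the layer sizes, the random data, the file name "oldNet.mat" and the variable names are all made up for illustration, and the first half merely stands in for the earlier training run so that the example is self-contained. It assumes Deep Learning Toolbox in a release that still provides "trainNetwork" (such as R2023a).

```matlab
% --- Stand-in for the earlier training run (hypothetical sizes and data) ---
featureDimension = 3;
numResponses = 1;
layers = [
    sequenceInputLayer(featureDimension)
    fullyConnectedLayer(8)
    reluLayer
    gruLayer(4, OutputMode="sequence")
    fullyConnectedLayer(numResponses)
    regressionLayer];
options = trainingOptions("adam", MaxEpochs=2, Verbose=false);

% Each cell holds one sequence: features-by-timeSteps (responses likewise).
XOld = {rand(featureDimension, 20); rand(featureDimension, 20)};
YOld = {rand(numResponses, 20); rand(numResponses, 20)};
oldNet = trainNetwork(XOld, YOld, layers, options);

% Save the trained network object to a MAT file (hypothetical file name).
matFile = fullfile(tempdir, "oldNet.mat");
save(matFile, "oldNet");

% --- Warm start: reuse the trained layers as the initial state ---
saved = load(matFile, "oldNet");
warmLayers = saved.oldNet.Layers;   % these layers carry the learned weights

% New data with the same feature and response dimensions (random here).
XNew = {rand(featureDimension, 20); rand(featureDimension, 20)};
YNew = {rand(numResponses, 20); rand(numResponses, 20)};

% A smaller learning rate is a common choice when fine-tuning; the value
% below is an assumption, not a recommendation from the thread.
fineTuneOptions = trainingOptions("adam", ...
    InitialLearnRate=1e-4, MaxEpochs=2, Verbose=false);

% Layers whose weights are already set are not re-initialized, so training
% continues from the old weights.
newNet = trainNetwork(XNew, YNew, warmLayers, fineTuneOptions);
```

If only the weight arrays were saved rather than the network object, assigning them to the matching properties of a fresh layer array before training (for example layers(2).Weights and layers(2).Bias) should have the same effect, provided the sizes match.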
You can refer to the following MathWorks documentation to learn more about the "setwb" and "trainNetwork" functions respectively:
  1. https://www.mathworks.com/help/deeplearning/ref/setwb.html
  2. https://www.mathworks.com/help/deeplearning/ref/trainnetwork.html
I hope this helps.

Release

R2023a
