How to use trainNetwork to train a network with sequence and image inputs in R2022a

zhuangxin Fang, March 27, 2022
Answered: Milan Bansal on September 29, 2023
I have built a network with a sequence input and an image input, but when I use trainNetwork to train it, it returns an error saying that my input format is incorrect.
% Training data
dsX1Train = arrayDatastore(XTrain1,IterationDimension=4);  % XTrain1: image input, observations along dimension 4
dsX2Train = arrayDatastore(XTrain2,IterationDimension=1);  % XTrain2: sequence input, observations along dimension 1
dsYTrain = arrayDatastore(YTrain);                         % responses (labels)
dsTrain = combine(dsX1Train,dsX2Train,dsYTrain);           % each read: {image, sequence, label}
% Validation data (same IterationDimension as training; 1 is also the default)
dsX1Val = arrayDatastore(XVal1,IterationDimension=4);
dsX2Val = arrayDatastore(XVal2,IterationDimension=1);
dsYVal = arrayDatastore(YVal);
dsVal = combine(dsX1Val,dsX2Val,dsYVal);
% dsVal is only used if it is passed to trainingOptions, e.g. ValidationData=dsVal
[net,info] = trainNetwork(dsTrain,lgraph,options);
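One way to narrow down an input-format error is to look at what a single read of the combined datastore actually returns. The sketch below uses made-up sizes (10 observations, 28x28x1 images, sequences with 3 channels and 50 time steps, 4 classes); these are assumptions, not the asker's data. With a numeric N-by-C-by-S array and IterationDimension=1, each sequence comes out as 1-by-C-by-S, whereas sequenceInputLayer expects each observation as a C-by-S matrix. A transform that permutes the dimensions is one possible fix, shown here as a hypothetical example.

```matlab
% Hypothetical sizes (not from the question)
numObs = 10;
XTrain1 = rand(28,28,1,numObs);            % H-by-W-by-C-by-N images
XTrain2 = rand(numObs,3,50);               % N-by-C-by-S sequences (assumed layout)
YTrain  = categorical(randi(4,numObs,1));  % one label per observation

dsX1Train = arrayDatastore(XTrain1,IterationDimension=4);
dsX2Train = arrayDatastore(XTrain2,IterationDimension=1);
dsYTrain  = arrayDatastore(YTrain);
dsTrain   = combine(dsX1Train,dsX2Train,dsYTrain);

% With the default ReadSize of 1, each read is a 1-by-3 cell: {image, sequence, label}
sample = read(dsTrain);
reset(dsTrain);           % rewind so training starts from the first observation
disp(size(sample{2}))     % 1 3 50: leading singleton dimension, not C-by-S

% Move the observation dimension to the end so each sequence becomes C-by-S
dsX2Fixed = transform(dsX2Train,@(c) {permute(c{1},[2 3 1])});
dsTrainFixed = combine(dsX1Train,dsX2Fixed,dsYTrain);
fixed = read(dsTrainFixed);
reset(dsTrainFixed);
disp(size(fixed{2}))      % 3 50
```

Whether this is the cause of the error in the question depends on how XTrain2 is laid out, which the question does not show.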

Answers (1)

Milan Bansal, September 29, 2023
Hi,
As I understand it, you are building a neural network with multiple inputs (an image and a sequence) and get an error when training it with the "trainNetwork" function.
To build and train a neural network with multiple inputs, create the network architecture as a "layerGraph" object.
Define the main block of layers with an "imageInputLayer" as the input layer, followed by convolution layers. Add a "concatenationLayer" with two inputs (for example, concatenationLayer(1,2,'Name','cat')) so that a "sequenceInputLayer" can later be connected to its second input. Then convert this layer array into a layer graph.
lgraph = layerGraph(layers);
Create a "sequenceInputLayer" for the sequence input, add it to the layer graph, and connect it to the second input of the "concatenationLayer".
seqInput = sequenceInputLayer(inputSize,'Name','seq');  % inputSize: number of channels per time step
lgraph = addLayers(lgraph,seqInput);
lgraph = connectLayers(lgraph,"seq","cat/in2");
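The steps above can be sketched end to end as follows. All sizes, layer choices, and layer names other than "seq" and "cat" are hypothetical. One assumption goes beyond the answer: a sequenceInputLayer produces data with a time dimension, while the flattened image features do not, so this sketch inserts an lstmLayer with OutputMode="last" to turn each sequence into a single feature vector before the concatenation. It also assumes your release of trainNetwork accepts a sequence input in a multi-input network.

```matlab
% Hypothetical sizes
inputSize   = [28 28 1];   % image height, width, channels
numChannels = 3;           % channels per sequence time step
numClasses  = 4;

% Image branch: flattened image features feed input 1 of "cat"
layers = [
    imageInputLayer(inputSize,Normalization="none",Name="img")
    convolution2dLayer(3,8,Padding="same",Name="conv")
    batchNormalizationLayer(Name="bn")
    reluLayer(Name="relu")
    fullyConnectedLayer(50,Name="fcImg")
    flattenLayer(Name="flatten")
    concatenationLayer(1,2,Name="cat")
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer(Name="softmax")
    classificationLayer(Name="output")];
lgraph = layerGraph(layers);   % "cat/in2" is left unconnected for now

% Sequence branch (assumption): the LSTM returns only the last time step,
% giving one feature vector per observation that can be concatenated
seqBranch = [
    sequenceInputLayer(numChannels,Name="seq")
    lstmLayer(32,OutputMode="last",Name="lstm")];
lgraph = addLayers(lgraph,seqBranch);
lgraph = connectLayers(lgraph,"lstm","cat/in2");

% Network inputs; the combined datastore columns should follow this order
disp(lgraph.InputNames)
```

You can inspect the resulting graph with analyzeNetwork(lgraph) before calling trainNetwork.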
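Train the network by calling "trainNetwork" with "dsTrain" as the datastore and "lgraph" as the network. The combined datastore must return one column per network input followed by the responses, in the same order as the input layers listed in lgraph.InputNames.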
For more details, see the MathWorks documentation example on training a neural network with multiple inputs, and the "layerGraph" reference page.
Hope it helps!

Categories: Deep Learning Toolbox
Release: R2022a