How can I work with hybrid inputs (numerical + categorical variables) to create a neural network
I want to create a neural network model with two inputs:
Input 1: numerical values
Input 2: categorical dummy variables (calendar date: Sunday, Monday, ..., January, February, ...)
The first input is a vector. The second input is a matrix.
How can I handle this hybrid input data to create a NN model, or use any other machine learning technique?
Thanks
Accepted Answer
Tomaso Cetto
30 Sep 2021
Edited: Tomaso Cetto
30 Sep 2021
Hi Ismael,
One thing you could do is create a very simple neural network using two featureInputLayers. Let's break down the workflow into steps.
Preparing the raw data
The standard data format for the featureInputLayer is numObservations x numFeatures. If I understand your data correctly, your first input has a single feature, so it will be a numObservations x 1 vector. Your second input will be a matrix of dummified categorical inputs: neural networks can't process categorical inputs directly, so we need to turn them into numeric matrices by dummifying, or 'one-hot encoding', them. A very useful function for transforming a categorical array into a dummified matrix is onehotencode. The result is a numObservations x numFeatures matrix, where numFeatures is equal to the total number of categories in your categorical array.
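As a minimal sketch of that encoding step, assuming the calendar information starts out as text labels (the sample labels and the variable names dayNames, monthNames, dow, mon, XDay, and XMonth are made up for illustration), each calendar variable can be one-hot encoded separately and the results concatenated side by side:

```matlab
% Hypothetical sample data: one weekday and one month label per observation
dayNames = ["Sunday","Monday","Tuesday","Wednesday","Thursday","Friday","Saturday"];
monthNames = ["January","February","March","April","May","June", ...
    "July","August","September","October","November","December"];

% Passing the full value set fixes the category order and keeps categories
% that do not occur in the sample, so the encoded widths stay at 7 and 12.
dow = categorical(["Monday"; "Sunday"; "Friday"], dayNames);
mon = categorical(["January"; "February"; "January"], monthNames);

% Encode along dimension 2 so that each observation stays on its own row
XDay = onehotencode(dow, 2);     % 3 x 7
XMonth = onehotencode(mon, 2);   % 3 x 12

% Concatenate into a single numObservations x numFeatures matrix
XTrain2 = [XDay, XMonth];        % 3 x 19
```

With this layout the second input would carry 19 features, so the matching featureInputLayer would need an input size of 19 rather than the 7 used in the dummy data in this answer.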
Creating the datastores
The way to feed multiple inputs to trainNetwork is to use a combinedDatastore. A combinedDatastore object holds multiple underlying datastores and reads from all of them at training time. Because your inputs are 2D arrays, you can use arrayDatastores as the underlying datastores. Below is some code where I create some dummy data and responses, store them in individual arrayDatastores, and then combine them into one combinedDatastore:
% Create data
XTrain1 = rand(10, 1); % numObservations (10) x numFeatures (1)
XTrain2 = rand(10, 7); % numObservations (10) x numFeatures (7)
YTrain = rand(10, 1); % numObservations (10) x numResponses(1)
% Create arrayDatastores. We transpose the arrays because the datastore needs to read out
% predictors in the format numFeatures x 1, and so the 'IterationDimension' becomes the
% second one.
dsX1 = arrayDatastore(XTrain1',"IterationDimension", 2,'OutputType','cell');
dsX2 = arrayDatastore(XTrain2',"IterationDimension", 2,'OutputType','cell');
dsY = arrayDatastore(YTrain',"IterationDimension", 2,'OutputType','cell');
% Create the combined datastore.
ds = combine(dsX1, dsX2, dsY);
% Read one observation from the datastore
ds.read
For more information regarding how to build datastores for feature input, see the appropriate section of the trainNetwork doc.
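To check that the combined datastore returns what trainNetwork expects, it can help to look at the size of a single read. The sketch below rebuilds the same dummy data so that it runs on its own; the variable name sample is only illustrative.

```matlab
% Dummy data with the same shapes as in the code above
XTrain1 = rand(10, 1);
XTrain2 = rand(10, 7);
YTrain = rand(10, 1);

dsX1 = arrayDatastore(XTrain1', 'IterationDimension', 2, 'OutputType', 'cell');
dsX2 = arrayDatastore(XTrain2', 'IterationDimension', 2, 'OutputType', 'cell');
dsY = arrayDatastore(YTrain', 'IterationDimension', 2, 'OutputType', 'cell');
ds = combine(dsX1, dsX2, dsY);

% One read returns a 1 x 3 cell: {input 1, input 2, response}
sample = read(ds);
disp(size(sample))       % row of three cells
disp(size(sample{1}))    % numFeatures x 1 for the first input (1 x 1)
disp(size(sample{2}))    % numFeatures x 1 for the second input (7 x 1)
disp(size(sample{3}))    % numResponses x 1 (1 x 1)

% Rewind so that training starts from the first observation again
reset(ds);
```

The column order matters: the predictor columns come first, in the same order as the network's input layers, and the response column comes last.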
Creating the network
I've attached code below to build a simple network that fully-connects on each input branch, concatenates both branches, and then fully-connects one more time before output. Feel free to modify this network to suit your needs!
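% Give each input layer its own name so the layer names in the graph are unique
inputOne = [
featureInputLayer(1, 'Name', 'input1')
fullyConnectedLayer(20, 'Name', 'fc1')];
inputTwo = [
featureInputLayer(7, 'Name', 'input2')
fullyConnectedLayer(10, 'Name', 'fc2')];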
concat = concatenationLayer(1,2,'Name','concat');
lgraph = layerGraph(inputOne);
lgraph = addLayers(lgraph, inputTwo);
lgraph = addLayers(lgraph, concat);
lgraph = connectLayers(lgraph, 'fc1', 'concat/in1');
lgraph = connectLayers(lgraph, 'fc2', 'concat/in2');
numResponses = 1; % matches the numObservations x 1 response YTrain
outputLayers = [
fullyConnectedLayer(numResponses,'Name','fc3')
regressionLayer('Name', 'regression')];
lgraph = addLayers(lgraph, outputLayers);
lgraph = connectLayers(lgraph, 'concat', 'fc3');
% Visualize the network architecture
analyzeNetwork(lgraph)
Training the network
All that's left to do is train the network!
options = trainingOptions('sgdm', ...
'MaxEpochs', 5, ...
'MiniBatchSize', 8, ...
'Verbose', true);
net = trainNetwork(ds, lgraph, options);
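Once training finishes, predictions on new data can follow the same datastore pattern, with only the predictor datastores combined (no response column). The sketch below is self-contained and rebuilds the dummy data and network first; the helper toDs and the variables XNew1, XNew2, dsNew, and YPred are hypothetical names introduced for illustration, and the data is random, so the predictions themselves are meaningless.

```matlab
% Dummy training data with the same shapes as above
XTrain1 = rand(10, 1);
XTrain2 = rand(10, 7);
YTrain = rand(10, 1);

% Hypothetical helper: wrap a numObservations x numFeatures array so that
% each read returns one numFeatures x 1 observation in a cell
toDs = @(X) arrayDatastore(X', 'IterationDimension', 2, 'OutputType', 'cell');
dsTrain = combine(toDs(XTrain1), toDs(XTrain2), toDs(YTrain));

% Same two-branch architecture as above
lgraph = layerGraph([
    featureInputLayer(1, 'Name', 'input1')
    fullyConnectedLayer(20, 'Name', 'fc1')]);
lgraph = addLayers(lgraph, [
    featureInputLayer(7, 'Name', 'input2')
    fullyConnectedLayer(10, 'Name', 'fc2')]);
lgraph = addLayers(lgraph, concatenationLayer(1, 2, 'Name', 'concat'));
lgraph = addLayers(lgraph, [
    fullyConnectedLayer(1, 'Name', 'fc3')
    regressionLayer('Name', 'regression')]);
lgraph = connectLayers(lgraph, 'fc1', 'concat/in1');
lgraph = connectLayers(lgraph, 'fc2', 'concat/in2');
lgraph = connectLayers(lgraph, 'concat', 'fc3');

options = trainingOptions('sgdm', 'MaxEpochs', 5, ...
    'MiniBatchSize', 8, 'Verbose', false);
net = trainNetwork(dsTrain, lgraph, options);

% Hypothetical new observations: predictors only, in input-layer order
XNew1 = rand(4, 1);
XNew2 = rand(4, 7);
dsNew = combine(toDs(XNew1), toDs(XNew2));
YPred = predict(net, dsNew);   % one row per new observation
```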
Let me know if this solves your problem and if there is anything more I can do to help!