Training a model using different shapes

Ahmed, 27 January 2023
Answered by Conor Daly, 28 March 2023
I have training data of around 1000 shapes of different sizes and dimensions. The data is stored in a cell array in which each cell holds one shape as an n-by-2 array: n is the number of points that draw the shape, and the two columns hold the x and y coordinates of those points. In the training data the points are ordered, so connecting them with straight lines in the order they appear in the array draws the desired shape accurately.
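To make the data layout concrete, here is a minimal sketch that builds a few synthetic shapes in the same format: a cell array in which each cell is an n-by-2 array of ordered [x y] points. The ellipses and the variable names (numShapes, n, a, b) are hypothetical stand-ins, not the actual shapes from the question.

```matlab
% Hypothetical synthetic data in the layout described above: a cell array
% in which each cell is an n-by-2 array of ordered [x y] points.
numShapes = 5;
shapes = cell(numShapes, 1);
for k = 1:numShapes
    n = randi([20 60]);                  % each shape has its own point count
    t = linspace(0, 2*pi, n).';          % ordered parameter along the outline
    a = 1 + rand;                        % random ellipse radius along x
    b = 0.5 + rand;                      % random ellipse radius along y
    shapes{k} = [a*cos(t), b*sin(t)];    % n-by-2: column 1 = x, column 2 = y
end
% Connecting the points in row order traces each outline, e.g.:
% plot(shapes{1}(:,1), shapes{1}(:,2), '-o')
```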
I would like to train a model on those 1000 shapes so that, given a new shape whose points are not in order, the model can reorder the points and draw the shape based on what it learned from the training shapes.
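Put another way, each scrambled shape is a row permutation of an ordered shape, and reordering means recovering that permutation. The sketch below uses a hypothetical 8-point circle: it scrambles the rows with randperm and shows that knowing the permutation lets you undo it exactly. A trained model would instead have to infer the permutation from the point positions alone.

```matlab
% One ordered shape (hypothetical example: 8 points on a unit circle).
t = linspace(0, 2*pi, 9).';
t(end) = [];                              % drop the duplicate end point
ordered = [cos(t), sin(t)];               % 8-by-2, rows in drawing order

% Scramble the rows, as in the "new shape" case from the question.
idx = randperm(size(ordered, 1));
scrambled = ordered(idx, :);

% Reordering = recovering the permutation. With the true idx this is just
% the inverse permutation; a model would have to estimate it.
invIdx = zeros(size(idx));
invIdx(idx) = 1:numel(idx);
recovered = scrambled(invIdx, :);
```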
I am very new to training models. So far in MATLAB I have only given a neural network a set of inputs and outputs and let it learn what it can. Here, though, I have many separate cases to learn from, and I'm not sure that adding all those points to one long array of coordinates is the right thing to do, because that defeats the purpose of having distinct shapes with ordered points. Any advice is appreciated.
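Concatenation does not have to lose the shape boundaries if the number of points per shape is stored alongside the long array. The sketch below, using three hypothetical random shapes, concatenates them with vertcat, splits them back with mat2cell, and also shows the layout that sequence networks in MATLAB work with: a cell array holding one features-by-time-steps (here 2-by-n) array per shape.

```matlab
% Hypothetical example: three shapes with different point counts.
shapes = {rand(5, 2); rand(8, 2); rand(3, 2)};

% Concatenating into one long array loses the boundaries between shapes...
allPoints = vertcat(shapes{:});               % 16-by-2
% ...unless the per-shape lengths are kept alongside it.
lengths = cellfun(@(s) size(s, 1), shapes);   % [5; 8; 3]
restored = mat2cell(allPoints, lengths, 2);   % back to one cell per shape

% Sequence networks take one sequence per observation instead, stored as a
% numFeatures-by-numTimeSteps array (2-by-n for [x y] points).
sequences = cellfun(@transpose, shapes, 'UniformOutput', false);
```

Keeping one cell per shape is what lets the network see each shape as a separate sequence.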
% Step 1: Prepare the data
% Load the shapes (shapes.mat is expected to contain a cell array "shapes"
% in which each cell is an n-by-2 array of [x y] coordinates)
load('shapes.mat');
% Concatenate the x and y coordinates of all shapes into one N-by-2 array
% (note: this discards the boundaries between shapes)
data = vertcat(shapes{:});
% Step 2: Define the LSTM architecture
layers = [
    sequenceInputLayer(2)                     % 2 features per time step: x and y
    lstmLayer(64, 'OutputMode', 'sequence')
    dropoutLayer(0.1)
    lstmLayer(64, 'OutputMode', 'sequence')
    dropoutLayer(0.1)
    fullyConnectedLayer(2)                    % one (x, y) output per time step
    regressionLayer];
% Step 3: Train the model on all shapes
% Split the data into training and test sets
% (split_data is not a built-in MATLAB function and must be user-defined)
[XTrain, XTest, YTrain, YTest] = split_data(data, 0.8);
options = trainingOptions('adam', ...
    'InitialLearnRate', 0.01, ...
    'MaxEpochs', 4, ...
    'Shuffle', 'every-epoch', ...
    'Verbose', false, ...
    'Plots', 'training-progress');
net = trainNetwork(XTrain, YTrain, layers, options);
% Step 4: Use the trained model to make predictions on new shapes
predictedCoordinates = predict(net, XTest);
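split_data in the code above is not a MATLAB function, so it has to be written by hand. One possible approach, sketched here under the assumption that the shapes are kept in a cell array rather than concatenated, is to split at the shape level so that points from one shape never end up in both the training and test sets. The random shapes, trainFraction, and the other variable names are hypothetical.

```matlab
% Hypothetical data: 10 shapes, each an n-by-2 array with 10 to 30 points.
shapes = arrayfun(@(n) rand(n, 2), randi([10 30], 10, 1), 'UniformOutput', false);

% Shape-level split (a sketch, not a built-in): shuffle whole shapes and
% assign each one entirely to either the training or the test set.
trainFraction = 0.8;
numShapes = numel(shapes);
order = randperm(numShapes);              % shuffle the shapes, not the points
numTrain = round(trainFraction * numShapes);
trainShapes = shapes(order(1:numTrain));
testShapes  = shapes(order(numTrain+1:end));
```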
Ahmed, 27 January 2023
I want to train the model on these shapes first, before testing it on randomly ordered shapes.
Ahmed, 27 January 2023
@KSSV, is there a resource you could point me to where I can learn how to train a model to order points from different examples?


Answers (1)

Conor Daly, 28 March 2023
To train a model that can unscramble the order of the data, the model needs to be trained specifically for that task. One way to do this is to create a set of scrambled predictors and use the unscrambled (ordered) data as targets.
Here's an example to get you started. The model doesn't train very well, but it's just an example.
% Load the data.
load('shapes_2.mat');
% Transpose each shape to 2x(numPoints), the features-by-time-steps layout
% used for sequence data.
shapes = cellfun(@transpose, shapes, UniformOutput=false);
% Standardize data.
M = mean( cat(2, shapes{:}), 2 );
S = std( cat(2, shapes{:}), [], 2 );
shapes = cellfun(@(x)(x-M)./S, shapes, UniformOutput=false);
% Create training predictors/targets by scrambling the order of the
% predictors.
X = shapes;
T = shapes;
for n = 1:numel(X)
    idx = randperm(size(X{n},2));
    X{n} = X{n}(:, idx);
end
% Split into train/test sets (the first 150 shapes are used for training;
% adjust these indices to the size of your data set).
XTrain = X(1:150);
TTrain = T(1:150);
XTest = X(151:end);
TTest = T(151:end);
% Define network architecture.
layers = [
    sequenceInputLayer(2)
    bilstmLayer(64)
    dropoutLayer(0.1)
    bilstmLayer(64)
    dropoutLayer(0.1)
    fullyConnectedLayer(2)
    regressionLayer ];
% Train the network.
options = trainingOptions("adam", ...
    MiniBatchSize=50, ...
    MaxEpochs=300, ...
    Shuffle="every-epoch", ...
    ValidationData={XTest,TTest}, ...
    Verbose=false, ...
    OutputNetwork="best-validation-loss", ...
    Plots="training-progress" );
net = trainNetwork(XTrain, TTrain, layers, options);
% Test the trained network.
YTest = predict(net, XTest);
meanAbsError = mean( cellfun(@(y,t)mean(abs(y - t),'all'), YTest, TTest ));
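Because the answer standardizes the coordinates before training, the predictions in YTest are in standardized units and in 2-by-n layout. A minimal sketch for mapping them back to the original coordinates and to the question's n-by-2 layout, assuming M, S, and YTest as defined in the code above (the small values here are hypothetical stand-ins so the snippet runs on its own):

```matlab
% Hypothetical stand-ins for the answer's variables: M and S are 2-by-1
% mean/std vectors, YTest is a cell array of 2-by-n standardized predictions.
M = [10; -3];
S = [2; 0.5];
YTest = {[0 1 -1; 0 2 -2]; [0.5; -0.5]};

% Undo the standardization (x - M)./S that was applied before training.
YTestOriginal = cellfun(@(y) y.*S + M, YTest, 'UniformOutput', false);

% Transpose back to the question's n-by-2 layout ([x y] per row) for plotting.
shapesOut = cellfun(@transpose, YTestOriginal, 'UniformOutput', false);
% plot(shapesOut{1}(:,1), shapesOut{1}(:,2), '-o')
```

In the real script, drop the three stand-in definitions and use the M, S, and YTest computed earlier.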

Category: Deep Learning Toolbox
Release: R2021b
