Image input for trained neural network

Amir5121 on June 20, 2018
Answered: Parag on March 7, 2025
I'm very new to neural networks in MATLAB, and I have this trained neural network:
1 'imageinput' Image Input 16x16x1 images with 'zerocenter' normalization
2 'relu_1' ReLU ReLU
3 'fc_1' Fully Connected 16 fully connected layer
4 'relu_2' ReLU ReLU
5 'fc_2' Fully Connected 10 fully connected layer
6 'softmax' Softmax softmax
7 'classoutput' Classification Output crossentropyex with '0', '1', and 8 other classes
I have 20000 16x16 binary images for training and validation, and training goes fine. But when I try to classify a random image with the trained network (the input image is also 16x16), I get:
>> net(image)
Index exceeds matrix dimensions.
I'm using R2017a.

Answers (1)

Parag on March 7, 2025
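Hi, the issue may occur because MATLAB's trainNetwork function (and likewise classify and predict) expects image data in a 4-D format: (Height, Width, Channels, Batch Size). Since your trained network was designed with an image input layer (imageInputLayer([16 16 1])), the input image must match this format; a single 16x16 image counts as a 16x16x1x1 array. Also note that a network returned by trainNetwork is not evaluated by calling net(image), as you would with a shallow network object; MATLAB treats that syntax as indexing into the network object, which is consistent with the "Index exceeds matrix dimensions" error. Pass the image to classify(net, image) or predict(net, image) instead.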
Here is an example implementation on dummy data:
% Define Layers
layers = [
imageInputLayer([16 16 1])
convolution2dLayer(3, 8, 'Padding', 1) % Use numeric padding instead of 'same'
reluLayer
maxPooling2dLayer(2, 'Stride', 2)
fullyConnectedLayer(10)
softmaxLayer
classificationLayer];
% Define Training Options
options = trainingOptions("sgdm", ...
"MaxEpochs", 5, ...
"MiniBatchSize", 32, ...
"Verbose", false);
% Generate Random Binary Training Data (1000 samples)
numTrain = 1000;
trainImages = randi([0, 1], 16, 16, 1, numTrain); % 4D array (Height, Width, Channels, Batch)
trainLabels = categorical(randi([0, 9], numTrain, 1)); % Labels from 0 to 9
% Train Network
net = trainNetwork(trainImages, trainLabels, layers, options);
% Test with a Random 16x16 Image
testImage = randi([0, 1], 16, 16, 1); % Binary test image
testImage = reshape(testImage, [16 16 1 1]); % Explicit H-by-W-by-C-by-N shape (MATLAB treats trailing singleton dimensions as implicit)
predLabel = classify(net, testImage);
disp("Predicted Label: " + string(predLabel))
Example output (varies from run to run because the data is random):
Predicted Label: 6
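The following is a minimal sketch, assuming the Deep Learning Toolbox is installed, that rebuilds the exact layer stack from the question (no convolution layer) and shows how to classify one 16x16 image and a small batch. The training data and labels are random placeholders standing in for the real 20000 images; the variable names (X, Y, img, batch) are illustrative only. The key change from the question is calling classify/predict on the network rather than net(image).

```matlab
% Layer stack from the question: input -> ReLU -> FC(16) -> ReLU -> FC(10) -> softmax -> output
layers = [
    imageInputLayer([16 16 1])
    reluLayer
    fullyConnectedLayer(16)
    reluLayer
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];

% Placeholder training data (replace with the real images and labels)
X = single(randi([0, 1], 16, 16, 1, 200));       % 16x16x1x200 array: H-by-W-by-C-by-N
Y = categorical(randi([0, 9], 200, 1), 0:9);     % fix all 10 categories explicitly
opts = trainingOptions('sgdm', 'MaxEpochs', 1, 'Verbose', false);
net = trainNetwork(X, Y, layers, opts);

% Classify ONE image: use classify/predict, not net(img)
img = randi([0, 1], 16, 16) > 0;                 % logical 16x16, like a binarized image
img = single(img);                               % convert to numeric before passing it in
label  = classify(net, img);                     % categorical scalar, e.g. '3'
scores = predict(net, img);                      % 1-by-10 row of class scores

% Classify SEVERAL images at once: stack them along the 4th dimension
batch  = cat(4, img, 1 - img);                   % 16x16x1x2
labels = classify(net, batch);                   % 2-by-1 categorical
```

Stacking along the fourth dimension is what lets the same call handle one image or many; predict is useful when you want the per-class scores rather than only the winning label.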

Category: Deep Learning Toolbox
