- Insufficient training data: Ensure that you have a diverse and representative training dataset with an adequate number of samples for each class. Make sure the dataset covers a wide range of object variations. You can explore data augmentation techniques to artificially increase the size of your training dataset, such as random scaling, rotation, cropping, or flipping.
- Inaccurate bounding box annotations: Verify that the bounding box annotations in your ground truth data are accurate and tightly enclose the objects of interest. Incorrect or imprecise annotations can lead to poor performance.
- Insufficient training or fine-tuning: Consider training the model for more epochs or adjusting the learning rate schedule so that the model converges better. For example, you can try decreasing the learning rate over time to improve training dynamics and performance.
- Inappropriate network architecture or feature extraction layer: Experiment with different feature extraction networks or layers to capture more relevant features for object detection.
- Inadequate hyperparameter tuning: Tune the training options, such as the optimizer (e.g. Adam, RMSprop), mini-batch size, and overlap ranges, to find the optimal configuration for your specific dataset.
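As a rough illustration of the learning-rate and augmentation suggestions above, the sketch below uses a piecewise learning-rate schedule and a slightly stronger augmentation function. The drop factor, drop period, scale range, and the function name augmentDataMore are illustrative assumptions, not tuned or recommended values. The variables trainingData and validationData are the combined datastores from the question.

```matlab
% Sketch only: the schedule values are assumptions to experiment with.
options = trainingOptions('sgdm', ...
    'MaxEpochs',40, ...
    'MiniBatchSize',2, ...
    'InitialLearnRate',1e-4, ...
    'LearnRateSchedule','piecewise', ...  % lower the rate in steps
    'LearnRateDropFactor',0.1, ...        % multiply the rate by 0.1 ...
    'LearnRateDropPeriod',15, ...         % ... every 15 epochs
    'ValidationData',validationData);

% To try the stronger augmentation, use it in place of augmentData:
% augmentedTrainingData = transform(trainingData,@augmentDataMore);

function data = augmentDataMore(data)
% Hypothetical variant of augmentData: random horizontal flip plus mild
% random scaling. data is {image, boxes, labels} as read from the
% combined datastore.
original = data;
sz = size(data{1});
tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]);
rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
data{1} = imwarp(data{1},tform,'OutputView',rout);

% Warp the boxes and drop any box that keeps less than 25% of its area
% inside the output view; keep the labels in step with the boxes.
[data{2},indices] = bboxwarp(data{2},tform,rout,'OverlapThreshold',0.25);
data{3} = data{3}(indices);

% If every box was lost, fall back to the unaugmented sample.
if isempty(indices)
    data = original;
end
end
```

Keeping the labels in step with the warped boxes matters once scaling or cropping is added, because bboxwarp can then discard boxes that leave the image.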
Low average precision Faster R-CNN
Hello, I am trying to train a Faster R-CNN object detector. After running this code, training took 7 hours, but the average precision is very bad, and one of the classes has an average precision of zero. I don't know what's wrong.
data = load('airportDatasetGroundTruth3.mat');
LabelData = data.gTruth.LabelData;
% Display first few rows of the data set.
LabelData(1:4,:)
rng(0)
shuffledIndices = randperm(height(LabelData));
idx = floor(0.6 * height(LabelData));
trainingIdx = 1:idx;
trainingDataTbl = LabelData(shuffledIndices(trainingIdx),:);
validationIdx = idx+1 : idx + 1 + floor(0.1 * length(shuffledIndices) );
validationDataTbl = LabelData(shuffledIndices(validationIdx),:);
testIdx = validationIdx(end)+1 : length(shuffledIndices);
testDataTbl = LabelData(shuffledIndices(testIdx),:);
imdsTrain = imageDatastore(trainingDataTbl{:,'imageFilename'});
bldsTrain = boxLabelDatastore(trainingDataTbl(:,2:6));
imdsValidation = imageDatastore(validationDataTbl{:,'imageFilename'});
bldsValidation = boxLabelDatastore(validationDataTbl(:,2:6));
imdsTest = imageDatastore(testDataTbl{:,'imageFilename'});
bldsTest = boxLabelDatastore(testDataTbl(:,2:6));
trainingData = combine(imdsTrain,bldsTrain);
validationData = combine(imdsValidation,bldsValidation);
testData = combine(imdsTest,bldsTest);
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
inputSize = [224 224 3];
preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors)
featureExtractionNetwork = resnet50;
featureLayer = 'activation_40_relu';
numClasses = width(LabelData)-1;
lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
augmentedTrainingData = transform(trainingData,@augmentData);
augmentedData = cell(4,1);
for k = 1:4
data = read(augmentedTrainingData);
augmentedData{k} = insertShape(data{1},'rectangle',data{2});
reset(augmentedTrainingData);
end
figure
montage(augmentedData,'BorderSize',10)
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,'rectangle',bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
options = trainingOptions('sgdm',...
'MaxEpochs',40,...
'MiniBatchSize',2,...
'InitialLearnRate',1e-4,...
'CheckpointPath','E:\ADRIAN\BAKALARKA\DATASET\checkpoint_fasterrcnn',...
'ValidationData',validationData);
doTraining = true;
if doTraining
% Train the Faster R-CNN detector.
% * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
% that training samples tightly overlap with ground truth.
[detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
'NegativeOverlapRange',[0 0.3], ...
'PositiveOverlapRange',[0.6 1]);
else
% Load pretrained detector for the example.
pretrained = load('fasterRCNNResNet50EndToEndVehicleExample.mat');
detector = pretrained.detector;
end
I = imread(testDataTbl.imageFilename{1});
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);
I = insertObjectAnnotation(I,'rectangle',bboxes,scores);
figure
imshow(I)
testData = transform(testData,@(data)preprocessData(data,inputSize));
detectionResults = detect(detector,testData,'MinibatchSize',2);
[ap, recall, precision] = evaluateDetectionPrecision(detectionResults,testData);
recallv = cell2mat(recall);
precisionv = cell2mat(precision);
[r,index] = sort(recallv);
p = precisionv(index);
figure
plot(r,p)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f",mean(ap)))
function data = augmentData(data)
% Randomly flip images and bounding boxes horizontally.
tform = randomAffine2d('XReflection',true);
sz = size(data{1});
rout = affineOutputView(sz,tform);
data{1} = imwarp(data{1},tform,'OutputView',rout);
% Warp boxes.
data{2} = bboxwarp(data{2},tform,rout);
end
function data = preprocessData(data,targetSize)
% Resize image and bounding boxes to targetSize.
sz = size(data{1},[1 2]);
scale = targetSize(1:2)./sz;
data{1} = imresize(data{1},targetSize(1:2));
%
% Resize boxes.
data{2} = bboxresize(data{2},scale);
end
Average precision is only 33% (15% for the first class, 8% for the second, 73% for the third, 0% for the fourth, and 66% for the fifth).
Answers (1)
Rohit
22 May 2023
Hi Adrian,
I understand that you are trying to train a Faster R-CNN model for object detection but are getting low average precision scores.
Here are the potential reasons for the low average precision of your Faster R-CNN object detector:
By addressing these aspects and making appropriate adjustments, you can work towards improving the average precision of your Faster R-CNN object detector.
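Before changing the training setup, it may help to look at each class separately, since the question reports very different results per class. The sketch below counts labelled boxes per class in the training set and plots one precision/recall curve per class. It reuses bldsTrain, detector, detectionResults, and testData from the question's script. It assumes that the order of the values in ap matches detector.ClassNames, which is worth verifying for your release. Note that the question's plot concatenates the recall and precision vectors of all classes with cell2mat, so it mixes the classes into a single curve.

```matlab
% Count labelled boxes per class in the training set. A class with very
% few boxes is a plausible (not certain) cause of a 0% average precision.
labelCounts = countEachLabel(bldsTrain)

% Per-class average precision, recall, and precision. For a multiclass
% detector, recall and precision are cell arrays with one entry per class.
[ap,recall,precision] = evaluateDetectionPrecision(detectionResults,testData);

% Assumption: ap follows the same class order as detector.ClassNames.
classNames = string(detector.ClassNames);

% Plot one precision/recall curve per class instead of a merged curve.
figure
hold on
for k = 1:numel(ap)
    plot(recall{k},precision{k}, ...
        'DisplayName',sprintf('%s (AP = %.2f)',classNames(k),ap(k)))
end
hold off
xlabel('Recall')
ylabel('Precision')
grid on
legend('Location','southwest')
title(sprintf('Mean Average Precision = %.2f',mean(ap)))
```

If the class with zero average precision also has a very low count in labelCounts, the first suggestion (more data for that class) is the most likely place to start.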