
MNIST classification using batch method

nadia, 29 November 2015 (edited by Greg Heath, 5 December 2015)
Hi. I want to train a neural network on the MNIST database using the batch method. I use the code below, but my accuracy is very low, even though I think the code is correct. Can anyone help me, please?
function [hiddenWeights, outputWeights, error] = train_network_batch(numberOfHiddenUnits, input, target, epochs, batchSize, learningRate,lambda)
% The number of training vectors.
trainingSetSize = size(input, 2);
% Input vector has 784 dimensions.
inputDimensions = size(input, 1);
% We have to distinguish 10 digits.
outputDimensions = size(target, 1);
% Initialize the weights for the hidden layer and the output layer.
% hiddenWeights = randn(NHiddenUnit, inputDimensions)*1/sqrt(size(input, 1));
% outputWeights = randn(outputDimensions, NHiddenUnit)*1/sqrt(size(input, 1));
hiddenWeights = rand(numberOfHiddenUnits, inputDimensions);
outputWeights = rand(outputDimensions, numberOfHiddenUnits);
hiddenWeights = hiddenWeights./size(hiddenWeights, 2);
outputWeights = outputWeights./size(outputWeights, 2);
hiddenWeights_store = hiddenWeights;
outputWeights_store = outputWeights;
n = zeros(batchSize,1);
validation_count=0;
validation_accuracy=0;
figure; hold on;
%batch method
for t = 1: epochs
    for k = 1: batchSize
        % Select which input vector to train on.
        n(k) = floor(rand(1)*trainingSetSize + 1);
        % n(k) =k;
        % Propagate the input vector through the network.
        inputVector = input(:, n(k));
        hiddenActualInput = hiddenWeights*inputVector;
        hiddenOutputVector = linear_func(hiddenActualInput);
        outputActualInput = outputWeights*hiddenOutputVector;
        outputVector = linear_func(outputActualInput);
        targetVector = target(:, n(k));
        % Backpropagate the errors.
        outputDelta = dlinear_func(outputActualInput).*(outputVector - targetVector);
        hiddenDelta = dlinear_func(hiddenActualInput).*(outputWeights'*outputDelta);
        % outputWeights_store = outputWeights_store -(learningRate*lambda/batchSize).*outputWeights- learningRate.*outputDelta*hiddenOutputVector';
        hiddenWeights_store = hiddenWeights_store -(learningRate*lambda/batchSize).*hiddenWeights-learningRate.*hiddenDelta*inputVector';
        % outputWeights =(1-(learningRate*lambda/batchSize)).*outputWeights - learningRate.*outputDelta*hiddenOutputVector';
        % hiddenWeights = (1-(learningRate*lambda/batchSize)).*hiddenWeights - learningRate.*hiddenDelta*inputVector';
    end;
    outputWeights=outputWeights+(outputWeights_store./batchSize);
    hiddenWeights=hiddenWeights+(hiddenWeights_store./batchSize);
    outputWeights_store=0;
    hiddenWeights_store=0;
    % %*********************************end of batch method***************
    % Calculate the error for plotting.
    error = 0;
    for k = 1: batchSize
        inputVector = input(:, n(k));
        targetVector = target(:, n(k));
        error = error + norm(linear_func(outputWeights*linear_func(hiddenWeights*inputVector)) - targetVector, 2);
    end;
    error = error/batchSize;
    plot(t, error,'*');
    title(['MSE_ batch','NH= ',num2str(numberOfHiddenUnits),',',' alfa=',num2str(learningRate),' ,epoch=',num2str(epochs)]);
    xlabel('epoch');
    ylabel('cost');
    inputValues=load('validation.mat');
    inputValues=inputValues.v;
    labels=load('label.mat');
    labels=labels.l;
    [correctlyClassified, classificationErrors]=validation_network(hiddenWeights,outputWeights,inputValues',labels);
    correctlyClassified=correctlyClassified/10000;
    if correctlyClassified<= validation_accuracy
        validation_count=validation_count+1;
    else
        validation_count=0;
    end
    if validation_count>7
        break;
    end
    validation_accuracy=correctlyClassified;
end;
end
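A likely contributor to the low accuracy is the update bookkeeping: as posted, at least the line that accumulates into outputWeights_store is commented out, and both accumulators start as copies of the weights rather than as zeros, so after the first epoch the output layer receives no further updates. The sketch below is an illustration, not the poster's code, and rests on these assumptions: a tiny synthetic two-class problem stands in for MNIST, a sigmoid replaces linear_func (assumed here to be the identity, as its name suggests), the loss is squared error, weights are drawn with randn scaled by 1/sqrt(fan-in) instead of the all-positive rand, batches come from shuffling once per epoch (the post samples with replacement instead), and every name and hyperparameter is hypothetical. Per-sample gradients are summed from zero over the batch, averaged, combined with an L2 weight-decay term, and applied once per batch.

```matlab
% Mini-batch gradient descent for a 1-hidden-layer network (illustrative sketch).
rng(0);                                            % reproducible random numbers
N = 200;                                           % number of training samples
X = [randn(2, N/2) - 1.5, randn(2, N/2) + 1.5];    % 2-D inputs, one column per sample
labels = [ones(1, N/2), 2 * ones(1, N/2)];         % class index of each column
T = zeros(2, N);
T(sub2ind(size(T), labels, 1:N)) = 1;              % one-hot targets, 2 x N

nIn = 2; nHidden = 8; nOut = 2;
W1 = randn(nHidden, nIn) / sqrt(nIn);              % zero-mean init scaled by fan-in
W2 = randn(nOut, nHidden) / sqrt(nHidden);
sigmoid = @(z) 1 ./ (1 + exp(-z));                 % nonlinear activation
learningRate = 1.0; lambda = 1e-4; batchSize = 20; epochs = 100;

% Mean squared-error loss over the whole training set (for monitoring only).
lossAt = @(A, B) 0.5 * sum(sum((sigmoid(B * sigmoid(A * X)) - T) .^ 2)) / N;
initialLoss = lossAt(W1, W2);

for epoch = 1:epochs
    order = randperm(N);                           % visit every sample once per epoch
    for first = 1:batchSize:N
        idx = order(first:min(first + batchSize - 1, N));
        gradW1 = zeros(size(W1));                  % accumulators start at zero
        gradW2 = zeros(size(W2));
        for k = idx
            x = X(:, k);
            t = T(:, k);
            h = sigmoid(W1 * x);                   % hidden activations
            y = sigmoid(W2 * h);                   % output activations
            deltaOut = (y - t) .* y .* (1 - y);            % dE/d(W2*h)
            deltaHid = (W2' * deltaOut) .* h .* (1 - h);   % dE/d(W1*x)
            gradW2 = gradW2 + deltaOut * h';
            gradW1 = gradW1 + deltaHid * x';
        end
        m = numel(idx);
        % Average over the batch, add L2 weight decay, take one step per batch.
        W2 = W2 - learningRate * (gradW2 / m + lambda * W2);
        W1 = W1 - learningRate * (gradW1 / m + lambda * W1);
    end
end

finalLoss = lossAt(W1, W2);
[~, predicted] = max(sigmoid(W2 * sigmoid(W1 * X)), [], 1);
accuracy = mean(predicted == labels);
fprintf('loss %.4f -> %.4f, training accuracy %.2f\n', initialLoss, finalLoss, accuracy);
```

Swapping the sigmoid back to an identity function turns this into a purely linear model, which is the issue raised in the accepted answer.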

Accepted Answer

Greg Heath, 5 December 2015 (edited by Greg Heath, 5 December 2015)
1. I don't think anyone wants to wade through all of that code when you can just use the MATLAB classification functions:
help PATTERNNET
doc PATTERNNET
2. If none of your hidden or output functions is nonlinear, then all you have is a complicated linear classifier which can be implemented with BACKSLASH.
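Both parts of point 2 can be checked numerically. The sketch below uses random matrices and synthetic Gaussian classes (all sizes and names are hypothetical). It shows that W2*(W1*x) equals (W2*W1)*x up to round-off, so stacking linear layers adds no expressive power, and it fits a linear least-squares classifier in a single step with the backslash operator, appending a bias row and classifying by the largest output. This is a plain least-squares fit, not what PATTERNNET does internally.

```matlab
% Linear layers collapse to one matrix; a linear classifier via backslash (sketch).
rng(1);                                            % reproducible random numbers
nIn = 5; nHidden = 4; nOut = 3; N = 300;

% 1) Two stacked linear layers equal a single linear map.
W1 = randn(nHidden, nIn);
W2 = randn(nOut, nHidden);
Xrand = randn(nIn, N);
twoLayer = W2 * (W1 * Xrand);
oneLayer = (W2 * W1) * Xrand;
collapseError = max(abs(twoLayer(:) - oneLayer(:)));   % round-off only

% 2) Least-squares linear classifier: minimize ||W*Xb - T|| in the Frobenius norm.
labels = randi(nOut, 1, N);                        % synthetic class labels
centers = 3 * randn(nIn, nOut);                    % one mean vector per class
X = centers(:, labels) + randn(nIn, N);            % class-dependent inputs
T = zeros(nOut, N);
T(sub2ind(size(T), labels, 1:N)) = 1;              % one-hot targets
Xb = [X; ones(1, N)];                              % append a bias row
W = (Xb' \ T')';                                   % least-squares solution of Xb'*W' = T'
[~, predicted] = max(W * Xb, [], 1);               % pick the largest output per column
accuracy = mean(predicted == labels);
fprintf('collapse error %.2e, training accuracy %.2f\n', collapseError, accuracy);
```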
Hope this helps.
Thank you for formally accepting my answer.
Greg


