
Reducing overfitting in neural networks

11 views (last 30 days)
Daniel on 8 Jun 2017
Answered: Greg Heath on 9 Jun 2017
I am using the MATLAB Neural Network Toolbox to train an ANN. From past experience, implementing cross-validation when working with ML algorithms can help reduce the problem of overfitting, as well as allowing use of the entire available dataset without adding bias.
My question is: is there any advantage to implementing k-fold cross-validation when using the NN toolbox, or are overfitting and bias already mitigated by the implementation (e.g. in 'trainbr' mode)?

Accepted Answer

Greg Heath on 9 Jun 2017
K-FOLD CROSS-VALIDATION IS NOT A CURE FOR THE ILLS OF AN OVERFIT NET.
BACKGROUND:
1. OVERFITTING:
Nw > Ntrneq
where
Nw = number of unknown weights
Ntrneq = number of training equations
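For a single-hidden-layer I-H-O net, this condition can be checked directly (a minimal sketch; the values of I, H, O, and Ntrn are only illustrative):
  % Illustrative check of the overfitting condition Nw > Ntrneq
  I = 4;  H = 10;  O = 1;          % inputs, hidden nodes, outputs
  Ntrn   = 50;                     % number of training cases
  Nw     = (I+1)*H + (H+1)*O       % unknown weights, including biases: 61
  Ntrneq = Ntrn*O                  % training equations: 50
  overfit = Nw > Ntrneq            % true, so this net is overfit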
2. GENERALIZATION:
If y0 is a solution to the system of equations f(x0,y) = 0, then the system
generalizes well if
0 = f(x0 + dx , y) ==> y = y0 + dy
for NON-INFINITESIMALLY SMALL dx and dy
3. NONUNIQUENESS and INSTABILITY
Solutions to overfit systems are typically not unique. More importantly, the
non-uniqueness can lead to instability and poor generalization of iterative
solutions. Typically, there are an infinite number of solutions to an overfit
system of equations, but many of them do not generalize well. In particular,
iterative training of an overfit net can converge to one of these inappropriate
solutions. I call this problem
4. OVERTRAINING AN OVERFIT NET
There are several approaches to avoid overtraining an overfit net:
a. NONOVERFITTING: Do not overfit the net in the first place by using the rule
Ntrneq >= Nw
b. STOPPED TRAINING: Use train/val/test data division and STOP TRAINING when the
validation subset error increases continually for a prespecified number of
epochs (the MATLAB default is 6). This technique is used by the
LEVENBERG-MARQUARDT training function TRAINLM and by the conjugate-gradient
training functions (e.g., TRAINSCG).
c. BAYESIAN REGULARIZATION: Constrain the size of the weights by adding to the
minimization function a penalty term proportional to the squared Euclidean
norm of the weights. Although this technique is built into the training
function TRAINBR, it can also be specified with other training functions.
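For illustration, safeguards b and c can be selected as in the sketch below (the simplefit_dataset example data, the hidden-layer size of 10, and max_fail = 6 are only illustrative choices):
  [x, t] = simplefit_dataset;               % toy regression data shipped with the toolbox
  % b. Stopped training with a random train/val/test division (the TRAINLM default)
  net1 = fitnet(10, 'trainlm');
  net1.divideFcn           = 'dividerand';  % random 70/15/15 split
  net1.trainParam.max_fail = 6;             % stop after 6 consecutive validation-error increases
  net1 = train(net1, x, t);
  % c. Bayesian regularization penalizing the squared weight norm
  net2 = fitnet(10, 'trainbr');             % validation stopping is disabled for TRAINBR
  net2 = train(net2, x, t);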
5. Perhaps you have confused k-fold CROSS-VALIDATION with DATA-DIVISION STOPPED TRAINING
as a technique for avoiding overtraining an overfit net. It is not the same thing. See below.
6. K-FOLD CROSSVALIDATION
a. This widely known technique is not offered in the MATLAB NN TOOLBOX
b. Nonetheless, my use of the CROSSVAL and CVPARTITION functions from other
toolboxes can be found in both the NEWSGROUP and ANSWERS by including "greg" as a
searchword with cross validation, cross-validation, and crossvalidation.
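For reference, a minimal sketch of that approach, assuming the Statistics and Machine Learning Toolbox is available for CVPARTITION (the example data, hidden-layer size, and K = 10 are only illustrative):
  [x, t] = simplefit_dataset;                 % illustrative data
  N    = size(x, 2);
  K    = 10;
  cvp  = cvpartition(N, 'KFold', K);          % Statistics and Machine Learning Toolbox
  nmse = zeros(1, K);
  for k = 1:K
      trnInd = training(cvp, k);              % logical index of in-fold cases
      tstInd = test(cvp, k);                  % logical index of held-out cases
      net = fitnet(10);                       % default TRAINLM with internal trn/val/tst division
      net = train(net, x(:, trnInd), t(trnInd));
      y   = net(x(:, tstInd));
      e   = t(tstInd) - y;
      nmse(k) = mean(e.^2) / var(t(tstInd), 1);   % normalized MSE on the held-out fold
  end
  meanNMSE = mean(nmse)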
7. However, instead of using other toolboxes to implement k-fold cross-validation, I compensate
by using m multiple designs (typically 10 <= m <= 30) that differ only in the random division into
training, validation, and test subsets, in addition to the default random selection of initial weights.
8. My technique is trivial to implement:
Given an I-H-O net topology
a. Initialize the random number generator so that designs can be duplicated.
b. Store the current state of the RNG at the beginning of the loop so that any
design can be recreated at a later date without regenerating the others.
c. Design a net and store the performance results (e.g., Normalized Mean
Square Error NMSE). Storing the net is not necessary since it is easily redesigned
given the stored state of the RNG.
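A minimal sketch of steps a-c (the example data, hidden-layer size H, and number of designs m are only illustrative):
  [x, t] = simplefit_dataset;          % illustrative data for an I-H-O = 1-10-1 net
  m = 10;  H = 10;
  NMSE = zeros(1, m);
  rng(0)                               % a. make the whole experiment reproducible
  for i = 1:m
      s(i) = rng;                      % b. store the RNG state so design i can be recreated later
      net  = fitnet(H);                % random initial weights
      net.divideFcn = 'dividerand';    % random trn/val/tst division for each design
      [net, tr] = train(net, x, t);
      y = net(x(:, tr.testInd));       % evaluate on the held-out test subset
      e = t(tr.testInd) - y;
      NMSE(i) = mean(e.^2) / var(t(tr.testInd), 1);   % c. store the performance, not the net
  end
  [bestNMSE, ibest] = min(NMSE)        % the best design can be rebuilt from rng(s(ibest))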
Hope this helps.
Thank you for formally accepting my answer
Greg

More Answers (0)
