
Classification Experiment Design - stopping training

Dink on 19 Apr 2013
Hi,
I have a question (or several) regarding experiment design using MATLAB to conduct binary classification. I have a "design" set (for training and validation) and a separate test set to evaluate generalisation. My problem is when to stop training and apply the resulting net to my test set. From the nnet FAQ: "Statisticians tend to be skeptical of stopped training because it appears to be statistically inefficient due to the use of the split-sample technique". So I use cross-validation by the following general method:
  • For a given number of hidden nodes (H), I create 100 random starting weight sets (S).
  • For each S, I randomly divide the design set into k equally sized, mutually exclusive subsets and train k nets, using subset K(i) as the validation set and the remaining k-1 subsets as the training set.
  • Each net is trained to stop at mse_goal = 1e-6.
  • I evaluate the validation error for each K(i) and retrain the relevant net to the number of epochs where the validation error was lowest.
  • (Do I need to do this, or can I somehow select/return the net with the weights trained to this epoch from the [net tr] output?)
  • I apply this net to my test set to evaluate generalisation.
  • The net with the lowest generalisation error among the S*k nets is then the best trained net for my given H using my available data.
Does this make sense?
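
For reference, the fold construction described in the list above can be sketched in base MATLAB as follows. This is only an illustration under assumed values: `N`, `k`, `folds`, `valInd` and `trnInd` are placeholder names and sizes, not taken from the post.

```matlab
% Sketch: split N design-set cases into k mutually exclusive folds.
% N and k are placeholder values, not taken from the post.
N = 100;                      % number of design-set cases (assumed)
k = 10;                       % number of folds (assumed)

rng(0);                       % fix the seed so the split is reproducible
perm   = randperm(N);         % random ordering of the case indices
foldId = mod(0:N-1, k) + 1;   % fold labels 1..k, fold sizes differ by at most 1

folds = cell(1, k);
for i = 1:k
    folds{i} = perm(foldId == i);   % case indices held out in fold i
end

% For fold i: validate on folds{i}, train on everything else.
i      = 1;
valInd = folds{i};
trnInd = setdiff(1:N, valInd);
```

If the held-out fold is handed to `train` as the validation set (for example with `net.divideFcn = 'divideind'` and the `trainInd`/`valInd` fields of `net.divideParam`), the training record `tr` has a `best_epoch` field and, to my understanding, the returned net already carries the weights from that epoch, which would make a separate retraining pass unnecessary.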
  1 Comment
Greg Heath on 20 Apr 2013
You failed to give 4 important values:
1. N: size of the data set
2. I: input dimensionality
3. O: output dimensionality
4. MSE00: mean target variance, mean(var(target'))
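
A minimal sketch of how these four values could be read off a data set, assuming the usual toolbox layout where columns are cases (`x` is I-by-N, `t` is O-by-N). The data here are random placeholders, not the poster's.

```matlab
% Sketch: the four values for placeholder data.
% x is I-by-N inputs, t is O-by-N targets (columns are cases); both are made up.
rng(0);
x = randn(3, 50);                 % hypothetical inputs
t = double(rand(1, 50) > 0.5);    % hypothetical 0/1 class targets

[I, N] = size(x);                 % input dimensionality, number of cases
[O, ~] = size(t);                 % output dimensionality
MSE00  = mean(var(t'));           % mean target variance, as written above
```

MSE00 is roughly the MSE of a naive model that always outputs the target mean, so it serves as a reference level against which a trained net's MSE can be judged.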


Accepted Answer

Greg Heath on 20 Apr 2013
1. Initialize the RNG before the H loop and record the current RNG seed at the beginning of each inner loop. You can retrain any individual net by knowing the corresponding values of H, ntrial and SEED.
2. Train with k-2 subsets, using MSEtrngoal = max(0,0.01*Ndof*MSE00a/Ntrneq).
3. Record MSEtrn and MSEtrna for the combined training set and BOTH MSEs of the nontraining subsets at the MSE minima of EACH nontraining set:
   a. MSE(k-1) at the minimum of MSE(k) is an unbiased estimate of generalization error.
   b. So is MSE(k) at the minimum of MSE(k-1).
   c. So is their average, MSEtst = (MSE(k) + MSE(k-1))/2.
4. Calculate the summary stats (min, median, mean, stdv, max) of R2trn, R2trna and R2tst, and plot them as a function of H.
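
The four steps above can be sketched roughly as follows. This is not a definitive implementation: synthetic error curves stand in for a trained net so the snippet runs without the toolbox, and several definitions are assumptions that the answer does not spell out, namely Ntrneq = Ntrn*O, Nw = (I+1)*H + (H+1)*O for a single hidden layer, Ndof = Ntrneq - Nw, and R2 = 1 - MSE/MSE00. The names `Hvec`, `Ntrials`, `seeds`, `MSEa`, `MSEb` and `stats` are illustrative only.

```matlab
% Sketch of steps 1-4 with synthetic error curves standing in for a trained net.
% All sizes and curves are made up; Ntrneq, Nw, Ndof and R2 follow assumed definitions.
I = 3; O = 1; Ntrn = 80;          % assumed sizes
MSE00a  = 0.25;                   % placeholder reference MSE used in the goal
MSE00   = 0.25;                   % placeholder reference MSE used for R2
Hvec    = 1:4;                    % candidate hidden-node counts
Ntrials = 5;                      % random initializations per H

rng(0);                           % step 1: initialize the RNG once, before the H loop
R2tst = zeros(Ntrials, numel(Hvec));
seeds = cell(Ntrials, numel(Hvec));

for j = 1:numel(Hvec)
    H      = Hvec(j);
    Ntrneq = Ntrn * O;                        % number of training equations (assumed)
    Nw     = (I + 1) * H + (H + 1) * O;       % number of weights, one hidden layer (assumed)
    Ndof   = Ntrneq - Nw;                     % degrees of freedom (assumed)
    MSEtrngoal = max(0, 0.01 * Ndof * MSE00a / Ntrneq);   % step 2

    for n = 1:Ntrials
        seeds{n, j} = rng;                    % step 1: RNG state needed to redo this trial

        % Stand-ins for the per-epoch MSE of the two nontraining subsets.
        epochs = 1:50;
        MSEa = 0.05 + 0.2 * exp(-epochs / 8)  + 0.0005 * epochs + 0.01 * rand(1, 50);
        MSEb = 0.06 + 0.2 * exp(-epochs / 10) + 0.0004 * epochs + 0.01 * rand(1, 50);

        % step 3: read each curve at the minimum of the other, then average
        [~, ia] = min(MSEa);
        [~, ib] = min(MSEb);
        MSEtst  = (MSEb(ia) + MSEa(ib)) / 2;
        R2tst(n, j) = 1 - MSEtst / MSE00;     % normalized score (assumed definition)
    end
end

% step 4: summary statistics, one column per H
stats = [min(R2tst); median(R2tst); mean(R2tst); std(R2tst); max(R2tst)];
```

With a real net, the two curves would presumably come from the training record (for example `tr.vperf` and `tr.tperf`), and `rng(seeds{n, j})` restores the state needed to retrain a particular trial. `plot(Hvec, stats')` then gives the summary as a function of H.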
Hope this helps.
Thank you for formally accepting my answer
Greg
