MATLAB Answers

What is the "YTrain" data in the MATLAB CNN regression example?

Sho Wright on 4 Feb 2021
Commented: Sho Wright on 28 Mar 2021
I've been working through the MATLAB regression example in which a network is trained to recognise the rotation angles of handwritten digits. I wanted to explore the results further, so I found the underlying files in the program directory. It contains 10 subfolders (0-9) holding a total of 10,000 images, plus two Excel sheets, "digitTest" and "digitTrain". Each sheet has 5000 rows and three columns: image file name, digit, and rotation angle. After running the example code myself and comparing the results, I can see that the response YTrain matches the Excel file "digitTrain" and the response YValidation matches the Excel file "digitTest". Later, in the post-processing, YPredicted and the prediction error are calculated as follows:
YPredicted = predict(net,XValidation);        % predicted angles for the validation images
predictionError = YValidation - YPredicted;    % true angle minus predicted angle
These three separate responses have confused me and I'm looking for some clarification. My understanding is that validation data consists of true values that are compared against the network's responses during training, to give a rough estimate of how accurate the network is. That makes sense, since the prediction error is the difference between the true and predicted values. What I'm not sure about is YTrain: if it is meant to represent the training responses, why is there already an Excel sheet with predefined responses in the program directory? What does YTrain represent, and if I trained my own network, would I need to generate a similar YTrain alongside my YValidation?
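To make the arithmetic of the two code lines in the question concrete, here is a minimal base-MATLAB sketch with made-up numbers (the four angles are hypothetical, not taken from the example). YValidation holds known true angles, YPredicted stands in for the output of predict(net,XValidation), and the root-mean-square error (RMSE) summarises the per-image errors.

```matlab
% Hypothetical true rotation angles (degrees) for four validation images
YValidation = [10; -25; 40; 0];

% Hypothetical network outputs for the same images,
% standing in for YPredicted = predict(net,XValidation)
YPredicted = [12; -20; 37; 1];

% Per-image prediction error, computed as in the example: true minus predicted
predictionError = YValidation - YPredicted;   % [-2; -5; 3; -1]

% Root-mean-square error over the validation set
rmse = sqrt(mean(predictionError.^2));        % sqrt(39/4), about 3.12
```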

Accepted Answer

Srivardhan Gadila on 6 Feb 2021
As you mentioned, the subfolders (0-9) contain the input images for both training (XTrain) and validation (XValidation), and the Excel sheets contain the corresponding response (target) data that the network is supposed to predict. The functions digitTrain4DArrayData and digitTest4DArrayData return the images as X data and their rotation angles as Y data. XTrain and YTrain are then used to train the network: during training, the network performs a forward pass on XTrain and produces predicted responses, the loss between those predictions and YTrain is computed, and the gradients from the backward pass are used to update the network weights. XValidation and YValidation, by contrast, are used to validate the network during training and after training is complete.
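The division of roles described above (YTrain drives the weight updates, YValidation only measures the result) can be illustrated without any toolbox. The sketch below is not the CNN from the example: it swaps in a one-feature linear model trained by plain gradient descent on noise-free, made-up data. Apart from XTrain, YTrain, XValidation and YValidation, all variable names are hypothetical.

```matlab
% Made-up data: one input feature per sample and a known true response for EVERY sample
xAll = linspace(-1, 1, 100)';
yAll = 30*xAll + 5;                       % ground truth (plays the role of the angles)

% Hold out every 5th sample for validation; the rest are training samples
valIdx   = 5:5:100;
trainIdx = setdiff(1:100, valIdx);
XTrain      = xAll(trainIdx);  YTrain      = yAll(trainIdx);
XValidation = xAll(valIdx);    YValidation = yAll(valIdx);

% Toy "network": y = w*x + b, trained by gradient descent on half mean squared error
w = 0; b = 0; learnRate = 0.5; numEpochs = 500;
lossHistory = zeros(numEpochs, 1);
for epoch = 1:numEpochs
    YHat = w*XTrain + b;                  % forward pass on the training inputs
    err  = YHat - YTrain;                 % compare against the TRAINING targets
    lossHistory(epoch) = 0.5*mean(err.^2);
    gradW = mean(err .* XTrain);          % backward pass: d(loss)/dw
    gradB = mean(err);                    % backward pass: d(loss)/db
    w = w - learnRate*gradW;              % updates depend on YTrain only
    b = b - learnRate*gradB;
end

% Validation: YValidation is used only to score the trained model, never in an update
YPredicted = w*XValidation + b;
rmseValidation = sqrt(mean((YValidation - YPredicted).^2));
```

Because these targets are noise-free, w and b converge to 30 and 5 and the validation RMSE is essentially zero. Note that every sample, training or validation, has a true response; the split only decides which ones the model learns from.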
To train your own network on the same dataset, you can use the existing data directly, as in the Load Data section of the example. If you want to train your network on images other than those used in the example, then yes, you have to generate the corresponding Y data (YTrain and YValidation) yourself.
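For your own images, here is a minimal sketch of that step. It assumes your labels live in one table with a row per image; the file names, angles, and the column names FileName and Angle are all hypothetical. Both YTrain and YValidation are read from the same column of true values and differ only in which images they belong to. With a real spreadsheet, the table would typically come from readtable.

```matlab
% Hypothetical label table: one row per image, with the TRUE angle for every image.
% With a real spreadsheet this might be: labels = readtable("myLabels.xlsx");
labels = table(["img1.png"; "img2.png"; "img3.png"; "img4.png"; "img5.png"], ...
               [12; -30; 45; 0; 20], ...
               'VariableNames', {'FileName', 'Angle'});

% Hold out every 5th image for validation; the rest are used for training
n = height(labels);
isValidation = mod((1:n)', 5) == 0;

% Both response vectors are ground truth, taken from the same column
YTrain      = labels.Angle(~isValidation);   % [12; -30; 45; 0]
YValidation = labels.Angle(isValidation);    % 20

% The matching images must be loaded in the same order as these file names
trainFiles      = labels.FileName(~isValidation);
validationFiles = labels.FileName(isValidation);
```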
Sho Wright on 28 Mar 2021
Hello again,
I have returned to this question after running into difficulty training my network, and was hoping you could guide me if possible. Since posting the original question, I have collected my training (XTrain) and validation (XValidation) images. I have also collected my "true" values, which I formatted like the 'digitTest' file. I am unsure what response values to put in my 'digitTrain' file.
When I asked my supervisor about this, they said that 'digitTrain' would not contain any initial response values, and that this column would be filled in after the early training epochs, based on the validation data in 'digitTest'. This did not entirely make sense to me, and when I tried it I had no luck: I can't use true values, as those are my validation data, and I can't use predicted values, as none have been produced yet. I also tried leaving the response column empty, and got the error "Responses must not contain NaNs".
For now, I have filled the response column in 'digitTrain' with randomized values in a range similar to my validation data. This works, and the RMSE drops as low as 0.4, but I don't believe this is the correct way to train my network.
I think my question is, what kind of response data am I meant to use for 'digitTrain'?
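The "Responses must not contain NaNs" error is consistent with how empty cells are imported: readtable reads an empty cell in a numeric column as NaN. A small check like the following (the response values are hypothetical) lists which rows still lack a response before training is attempted.

```matlab
% Hypothetical response column as imported from a spreadsheet;
% empty numeric cells come in as NaN
YTrain = [12; NaN; 45; NaN; 20];

% Rows that still need a true response value
missingRows = find(isnan(YTrain));            % [2; 4]
fprintf("%d of %d training responses are missing (rows: %s)\n", ...
    numel(missingRows), numel(YTrain), mat2str(missingRows'));
```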

More Answers (0)