Error using trainNetwork (line 191) Too many input arguments.

Ayat, 31 July 2024
Commented: Ayat, 16 August 2024
Hello, I am trying to code an automatic detection of Alzheimer's disease from EEG signals, but my code throws an error when calling trainNetwork. It worked perfectly with an SVM but doesn't with a CNN. I tried looking online, but nothing seems to work. I got this error:
Error using trainNetwork (line 191)
Too many input arguments.
Error in CNN (line 178)
net = trainNetwork(X_train, y_train, layers, options);
Caused by:
Error using gather
Too many input arguments.
Does anyone have an idea? Here is the part of my code that builds the CNN:
X = all_features{:, 1:end-1};
y = all_features.Label;
y = categorical(y);
disp(['Feature matrix dimensions: ', num2str(size(X))]);
disp(['Labels vector dimensions: ', num2str(size(y))]);
X = zscore(X);
numFeatures = size(X, 2);
numObservations = size(X, 1);
X = reshape(X, [numObservations, numFeatures, 1, 1]);
layers = [
    imageInputLayer([numFeatures 1 1])
    convolution2dLayer([3 1], 8, 'Padding', 'same')
    maxPooling2dLayer([2 1], 'Stride', 2)
    convolution2dLayer([3 1], 16, 'Padding', 'same')
    % The original post is truncated here; the closing layers below are an
    % assumed minimal classification head, not the poster's actual layers.
    reluLayer
    fullyConnectedLayer(numel(categories(y)))
    softmaxLayer
    classificationLayer];
options = trainingOptions('adam', ...
    'MaxEpochs', 30, ...
    'MiniBatchSize', 16, ...
    'InitialLearnRate', 0.001, ...
    'ValidationFrequency', 10, ... % no effect unless 'ValidationData' is also set
    'Verbose', false, ...
    'Plots', 'training-progress');
cv = cvpartition(y, 'KFold', 5, 'Stratify', true);
accuracies = zeros(cv.NumTestSets, 1);
confusion_matrices = cell(cv.NumTestSets, 1);
for k = 1:cv.NumTestSets
    train_idx = training(cv, k);
    test_idx = test(cv, k);
    X_train = X(train_idx, :, :, :);
    y_train = y(train_idx);
    X_test = X(test_idx, :, :, :);
    y_test = y(test_idx);
    net = trainNetwork(X_train, y_train, layers, options);
    y_pred = classify(net, X_test);
    confusion_matrices{k} = confusionmat(y_test, y_pred);
    cm = confusion_matrices{k};
    accuracies(k) = sum(diag(cm)) / sum(cm(:));
end
mean_accuracy = mean(accuracies);
fprintf('Mean Accuracy across 5 folds: %.2f%%\n', mean_accuracy * 100);
save('eeg_cnn_classifier_cv.mat', 'net');   % saves only the last fold's network
disp('Confusion Matrices for each fold:');
for k = 1:cv.NumTestSets
    disp(['Fold ', num2str(k), ':']);
    disp(confusion_matrices{k});   % assumed: the original post is truncated here
end

Accepted Answer

Shantanu Dixit, 6 August 2024 (edited 6 August 2024)
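Hi Ayat, it seems that you are facing an issue when calling trainNetwork for the CNN. The problem lies in how the data is rearranged with reshape.
If you perform an operation (e.g., reshape) that creates trailing singleton (size 1) dimensions beyond the second dimension, MATLAB automatically removes those dimensions from the result, so X stays a numObservations-by-numFeatures matrix. With an imageInputLayer, however, trainNetwork expects numeric input as a height-by-width-by-channels-by-observations array, that is, with the observations along the fourth dimension.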
% Does not change the dimensions of X.
X = reshape(X,[numObservations, numFeatures,1, 1]);
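A minimal sketch of this behavior, using a small hypothetical matrix in place of the EEG features (the sizes 6 and 4 are arbitrary): asking reshape for trailing singleton dimensions leaves the array exactly as it was.

```matlab
% Hypothetical small feature matrix: 6 observations x 4 features.
numObservations = 6;
numFeatures = 4;
X = reshape(1:numObservations*numFeatures, numObservations, numFeatures);

% Requesting trailing singleton dimensions does not add dimensions:
% MATLAB drops them, so the result is still a 6-by-4 matrix.
Y = reshape(X, [numObservations, numFeatures, 1, 1]);
disp(size(Y))        % displays 6 4
disp(ndims(Y))       % displays 2
disp(isequal(Y, X))  % displays 1
```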
Here the reshaping can be done as follows, transposing first so that each row (observation) of X stays together:
X = reshape(X.', [numFeatures, 1, 1, numObservations]);
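One caveat about this step, shown as a sketch on a hypothetical 3-by-4 matrix: reshape preserves MATLAB's column-major element order, so reshaping the numObservations-by-numFeatures matrix directly into numFeatures-by-1-by-1-by-numObservations mixes values from different rows into one "observation". The code still runs, but the features are scrambled. Transposing first, or using permute, keeps each row together.

```matlab
% Hypothetical data: 3 observations (rows) x 4 features (columns).
X = [11 12 13 14;
     21 22 23 24;
     31 32 33 34];
[numObservations, numFeatures] = size(X);

% Reshaping directly follows column-major order, so "observation" 1 is
% X(1,1), X(2,1), X(3,1), X(1,2): values from different rows are mixed.
Xdirect = reshape(X, [numFeatures, 1, 1, numObservations]);
disp(Xdirect(:, 1, 1, 1).')   % displays 11 21 31 12

% Transposing first (or permuting the dimensions) keeps each row intact.
Xtrans = reshape(X.', [numFeatures, 1, 1, numObservations]);
Xperm  = permute(X, [2 3 4 1]);
disp(Xtrans(:, 1, 1, 1).')    % displays 11 12 13 14
disp(isequal(Xtrans, Xperm))  % displays 1
disp(size(Xtrans))            % displays 4 1 1 3
```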
The way X_train and X_test are indexed must then change accordingly, since the observations now run along the fourth dimension:
%% access X_train, X_test as follows
for k = 1:cv.NumTestSets
    train_idx = training(cv, k);
    test_idx = test(cv, k);
    X_train = X(:, :, :, train_idx);
    y_train = y(train_idx);
    X_test = X(:, :, :, test_idx);
    y_test = y(test_idx);
    net = trainNetwork(X_train, y_train, layers, options);
    y_pred = classify(net, X_test);
    confusion_matrices{k} = confusionmat(y_test, y_pred);
    cm = confusion_matrices{k};
    accuracies(k) = sum(diag(cm)) / sum(cm(:));
end
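A quick sanity check for this indexing, as a sketch on hypothetical random data: the logical mask stands in for training(cv, k) so the snippet runs without cvpartition, and the class labels are made up. The point is that the fourth dimension of X_train must match the number of labels.

```matlab
% Hypothetical data already in height-by-width-by-channels-by-observations
% layout: 4 features, 10 observations.
numFeatures = 4;
numObservations = 10;
X = rand(numFeatures, 1, 1, numObservations);
y = categorical(mod(0:numObservations-1, 2)).';   % two hypothetical classes

% A logical training mask, standing in for training(cv, k).
train_idx = false(numObservations, 1);
train_idx(1:8) = true;
test_idx = ~train_idx;

X_train = X(:, :, :, train_idx);
y_train = y(train_idx);
X_test  = X(:, :, :, test_idx);

% The fourth dimension must match the number of labels.
assert(size(X_train, 4) == numel(y_train));
disp(size(X_train))   % displays 4 1 1 8
disp(size(X_test))    % displays 4 1 1 2
```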
Refer to the related forum threads and the MathWorks documentation for more information.
Ayat, 16 August 2024
Sorry for the late reply. Indeed, the code now works with this method. I had actually managed to find this on my own by experimenting with my code, but thanks to you I now also know the reason why. Thanks a lot.
