Error occurred while training an RNN model for audio denoising

Niteesh on 21 Apr 2023
Answered: Krishna on 9 Jul 2024
clc;
clear;
close all;
% Load and preprocess the audio data
fs = 16000; % Sampling rate of the audio
[noisy_audio, ~] = audioread('noisy_piano.mp3');
[clean_audio_data, ~] = audioread('piano.mp3');
% Extract left and right channels
left_channel = clean_audio_data(:, 1);
right_channel = clean_audio_data(:, 2);
% Preprocess the audio data if needed
% Reshape the audio data to column vectors
left_channel = reshape(left_channel, [], 1);
right_channel = reshape(right_channel, [], 1);
% Concatenate the left and right channels as input data
clean_audio = [left_channel, right_channel];
% Perform any necessary preprocessing on the audio data, such as normalization, framing, etc.
% Split the data into training and validation sets
train_ratio = 0.8; % Percentage of data to use for training
train_idx = 1 : round(train_ratio * length(noisy_audio));
val_idx = (train_idx(end) + 1) : length(noisy_audio);
noisy_train = noisy_audio(train_idx); % Extract training data
clean_train = clean_audio(train_idx); % Extract training data
noisy_val = noisy_audio(val_idx); % Extract validation data
clean_val = clean_audio(val_idx); % Extract validation data
% Reshape the training data to have feature dimension of 1
noisy_train = reshape(noisy_train, [], 1);
clean_train = reshape(clean_train, [], 1);
noisy_val = reshape(noisy_val, [], 1);
clean_val = reshape(clean_val, [], 1);
% Define the LSTM denoising model
inputSize = 1; % Size of the input (1 for audio signals)
outputSize = 1; % Size of the output (1 for audio signals)
numHiddenUnits = 256; % Number of hidden units in the LSTM layer
layers = [
sequenceInputLayer(inputSize)
lstmLayer(numHiddenUnits, 'OutputMode', 'sequence')
fullyConnectedLayer(outputSize)
regressionLayer
];
% Define the training options
miniBatchSize = 128; % Mini-batch size for training
numEpochs = 50; % Number of epochs for training
options = trainingOptions('adam', ...
'MiniBatchSize', miniBatchSize, ...
'MaxEpochs', numEpochs, ...
'ValidationData', {noisy_val, clean_val}, ...
'ValidationFrequency', floor(length(noisy_train) / miniBatchSize), ...
'Plots', 'training-progress');
% Train the LSTM denoising model
denoising_net = trainNetwork(noisy_train, clean_train, layers, options);
% Use the trained LSTM denoising model to denoise the test data
denoised_audio = predict(denoising_net, noisy_audio);
% Perform any necessary post-processing on the denoised audio data
% Evaluate the denoising performance, e.g., by computing signal-to-noise ratio (SNR) or other relevant metrics

Error using trainNetwork
The training sequences are of feature dimension 9637171 but the input layer expects sequences of feature dimension 1.
Error in rnn_model (line 67)
denoising_net = trainNetwork(noisy_train, clean_train, layers, options);

Answers (1)

Krishna on 9 Jul 2024
Hello Niteesh,
The problem is in how the data is structured for LSTM training. For sequence-to-sequence regression, trainNetwork expects the predictors and the responses each as an n-by-1 cell array, where n is the number of sequences.
Each cell holds one sequence as a (number of features)-by-(sequence length) matrix. For example, an input with 4 features is 4-by-(sequence length), and a single output value per time step is 1-by-(sequence length). The sequence length may differ between sequences.
In your code, reshape(noisy_train, [], 1) produces a 9637171-by-1 column, which is read as one time step with 9,637,171 features. That is why the error reports feature dimension 9637171 while sequenceInputLayer(1) expects 1. Since you have only one sequence, n equals 1, and that sequence should be a 1-by-9637171 row inside a cell array. Please review the documentation examples on sequence data to learn more about structuring your data for training.
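A minimal sketch of the reshaping fix, assuming the stereo files are first mixed down to mono (for example with mean(x, 2)). Note that noisy_audio(train_idx) in the question uses linear indexing, which silently takes only the first channel of an N-by-2 matrix. Synthetic signals stand in for the audio so the snippet runs on its own.

```matlab
% Synthetic stand-ins (assumption): in the original script these would be the
% noisy and clean audio, mixed down to one channel, as T-by-1 columns.
fs = 16000;
t = (0:2*fs-1).' / fs;                   % 2 s of sample times
clean = sin(2*pi*440*t);                 % T-by-1 "clean" signal
noisy = clean + 0.1*randn(size(clean));  % T-by-1 "noisy" signal

% 80/20 split along time, as in the question.
nTrain = round(0.8 * numel(noisy));

% trainNetwork reads each sequence as features-by-time steps (c-by-s).
% A T-by-1 column is read as T features at a single time step, which is
% what triggered the error, so transpose each piece to 1-by-s and wrap it
% in a cell array with one cell per sequence (here, one sequence each).
XTrain = {noisy(1:nTrain).'};
YTrain = {clean(1:nTrain).'};
XVal   = {noisy(nTrain+1:end).'};
YVal   = {clean(nTrain+1:end).'};
```

With this layout, the layers from the question (sequenceInputLayer(1)) should accept the data as trainNetwork(XTrain, YTrain, layers, options), with 'ValidationData', {XVal, YVal} in trainingOptions. The input passed to predict needs the same 1-by-(sequence length) layout.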
Also, because you have only one sequence, you currently use the first 80 percent of it for training and the remaining 20 percent for validation.
Passing the entire sequence as a single observation is inefficient: the network must process the whole sequence before the optimizer can apply one backpropagation update, and with only one observation every epoch is a single iteration (so MaxEpochs = 50 gives just 50 updates). Instead, consider dividing the sequence into many shorter sequences. For example, with segments of length 10, you could train the network to predict the 11th value. Structuring the data this way can improve the results.
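A minimal sketch of this segmentation, using synthetic signals in place of the audio. It keeps the sequence-to-sequence setup of the question's network (each noisy segment maps to the matching clean segment) rather than the next-value variant mentioned above. segLen = 1000 is an arbitrary, hypothetical choice, and the incomplete tail segment is simply dropped.

```matlab
% Synthetic stand-ins (assumption): replace with the mono noisy/clean columns.
fs = 16000;
t = (0:fs-1).' / fs;
clean = sin(2*pi*440*t);                  % T-by-1 "clean" signal
noisy = clean + 0.1*randn(size(clean));   % T-by-1 "noisy" signal

segLen = 1000;                            % hypothetical segment length (samples)
numSeg = floor(numel(noisy) / segLen);    % number of complete segments
n = numSeg * segLen;                      % samples kept; the tail is dropped

% reshape fills column by column, so each column is one contiguous segment.
noisySeg = reshape(noisy(1:n), segLen, numSeg);
cleanSeg = reshape(clean(1:n), segLen, numSeg);

% Transpose so each row is a segment, then num2cell(..., 2) turns every row
% into its own cell: numSeg-by-1 cell arrays of 1-by-segLen sequences.
XTrain = num2cell(noisySeg.', 2);
YTrain = num2cell(cleanSeg.', 2);
```

Each segment is now its own observation, so MiniBatchSize = 128 yields many updates per epoch: roughly 75 for the 9,637,171-sample training split with segLen = 1000.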
Additionally, consider adding more sequences (for example, more recordings) to your dataset instead of relying on a single one. Evaluating on the same sequence you trained on is not a reliable test, because it cannot reveal overfitting: the results may look good on this recording yet fail to generalize to real-world noise cancellation.
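A sketch of splitting by recording rather than by position within one signal. The three synthetic recordings, the {noisy, clean} layout per recording, and segLen are all hypothetical; with real data, each pair would come from audioread on separate files.

```matlab
% Hypothetical dataset: three recordings, each stored as {noisy, clean}
% T-by-1 columns. Synthetic tones stand in for audioread output.
fs = 16000;
t = (0:fs-1).' / fs;
recordings = cell(3, 2);
for k = 1:3
    c = sin(2*pi*(200*k)*t);                       % a different tone per recording
    recordings(k, :) = {c + 0.1*randn(size(c)), c};
end

% Split by recording: a whole recording is held out for testing,
% rather than the tail of the same signal used for training.
trainIdx = [1 2];
testIdx  = 3;

% Cut a column vector into a numSeg-by-1 cell array of 1-by-segLen rows.
segLen  = 1000;                                    % hypothetical segment length
segment = @(x) num2cell(reshape(x(1:floor(numel(x)/segLen)*segLen), segLen, []).', 2);

XTrain = cell(0, 1);
YTrain = cell(0, 1);
for k = trainIdx
    XTrain = [XTrain; segment(recordings{k, 1})];  %#ok<AGROW>
    YTrain = [YTrain; segment(recordings{k, 2})];  %#ok<AGROW>
end
XTest = segment(recordings{testIdx, 1});
YTest = segment(recordings{testIdx, 2});
```

XTest and YTest come from a recording the network never saw during training, so an error or SNR measured on them says more about generalization than a score computed on the training recording.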
Hope this helps.

Category: AI for Signals