Speeding up numerical gradient of tensor with FFTs

Shreyas Bharadwaj on 18 Jul 2024
Answered: Sahas on 19 Jul 2024
I have a gradient problem that involves the following function: add a phase to each column of a matrix, compute the FFT of each column, and aggregate all the FFTs into a new matrix.
My brute-force numerical gradient adds a small phase to each column in turn and then recomputes the loss by comparing the result with a known matrix. currX is the current guess of the phases, Ts2pHH is the matrix to whose columns these phases are added, and the loss is computed by summing the intensity over the first two dimensions, weighting the result by support, and summing everything.
I would like to know whether this can be done more efficiently: for my matrix size (~1000x1000) it takes around 2 minutes, which is very slow.
My code is shown below:
for k = 1:length(currX)
    % Perturb the k-th phase only
    currX_perturbed = currX;
    currX_perturbed(k) = currX_perturbed(k) + epsilon;
    % The first column keeps a fixed phase of 0
    phases_perturbed = exp(1i * [0, currX_perturbed]);
    Tcorr_perturbed = Ts2pHH .* phases_perturbed;
    % 4-D FFT of the reshaped matrix
    TcorrFFT_perturbed = fftshift(fft(fft(fft(fft(reshape(Tcorr_perturbed, [Npx, Npx, Nin, Nin]), [], 3), [], 4), [], 1), [], 2));
    % Intensity summed over the first two dimensions
    inputFreq_perturbed = squeeze(sum(sum(abs(TcorrFFT_perturbed).^2, 1), 2));
    % Forward difference against the unperturbed loss
    gradient(k) = gradient(k) + (-sum(inputFreq_perturbed .* support, 'all') - loss) / epsilon;
end

Answers (1)

Sahas on 19 Jul 2024
As I understand it, you would like to optimize the code so that it runs efficiently on larger inputs.
Since I do not know the entire algorithm, it is difficult to suggest an algorithmic optimization.
However, MATLAB offers several code optimization techniques and coding strategies, such as vectorization, parallelization, and pre-allocation. Combining these techniques can reduce the execution time. More information can be found at the links below:
https://www.mathworks.com/help/coder/ug/optimize-generated-code.html -- Optimization strategies for various scenarios (MATLAB Coder documentation)
https://www.mathworks.com/help/matlab/matlab_prog/vectorization.html -- Basics of vectorization in MATLAB
The code below shows how to use the pre-allocation technique.
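As one illustration of the parallelization option, the iterations of the finite-difference loop in the question are independent of each other, so the loop could be written with parfor. This is only a sketch under assumptions: the sizes and random inputs are placeholders rather than values from the question, parfor needs Parallel Computing Toolbox to actually run on several workers (as far as I know, it falls back to a serial loop without it), and any speed-up depends on the number of workers and the cost of each iteration.

```matlab
% Illustrative sizes and random data (assumptions, not taken from the question)
rng(0);
Npx = 4; Nin = 3;
K = Nin^2 - 1;                                    % number of free phases
Ts2pHH  = randn(Npx^2, Nin^2) + 1i * randn(Npx^2, Nin^2);
support = rand(Nin, Nin);
currX   = rand(1, K);
epsilon = 1e-6;

% Reshape once, outside the loop
T4 = reshape(Ts2pHH, [Npx, Npx, Nin, Nin]);

% Loss written as in the question's loop: negative support-weighted intensity
fft4 = @(A) fftshift(fft(fft(fft(fft(A, [], 3), [], 4), [], 1), [], 2));
lossFcn = @(x) -sum(squeeze(sum(sum(abs(fft4( ...
    T4 .* reshape(exp(1i * [0, x]), [1, 1, Nin, Nin]))).^2, 1), 2)) .* support, 'all');

loss = lossFcn(currX);
gradient = zeros(1, K);                           % sliced output variable
parfor k = 1:K
    % Perturb only the k-th phase; a single whole-array assignment keeps
    % xp a temporary variable inside the parfor body
    xp = currX + epsilon * ((1:K) == k);
    gradient(k) = (lossFcn(xp) - loss) / epsilon; % forward difference
end
```

For very short iterations the overhead of starting a pool and sending T4 to the workers can outweigh the gain, so it is worth timing both versions.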
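% Pre-allocation technique
numElements = numel(currX);   % number of phases to perturb
currX_perturbed = currX;
% 'like', 1i allocates complex storage, since these arrays hold complex values
phases_perturbed = zeros(numElements + 1, numElements, 'like', 1i);
Tcorr_perturbed = zeros(Npx, Npx, Nin, Nin, numElements, 'like', 1i);
TcorrFFT_perturbed = zeros(Npx, Npx, Nin, Nin, numElements, 'like', 1i);
inputFreq_perturbed = zeros(Nin, Nin, numElements);   % real intensities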
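Applied to the loop in the question, pre-allocation mainly concerns the output vector, because every other array is overwritten on each iteration. The following is a minimal sketch under assumed, illustrative sizes and random data (Npx, Nin, and the inputs are placeholders, not values from the question). It also moves the reshape of Ts2pHH out of the loop, since it does not depend on k.

```matlab
% Illustrative sizes and random data (assumptions, not taken from the question)
rng(0);
Npx = 4; Nin = 3;
K = Nin^2 - 1;                                    % number of free phases
Ts2pHH  = randn(Npx^2, Nin^2) + 1i * randn(Npx^2, Nin^2);
support = rand(Nin, Nin);
currX   = rand(1, K);
epsilon = 1e-6;

% Loop-invariant work done once: reshape to 4-D
T4 = reshape(Ts2pHH, [Npx, Npx, Nin, Nin]);

% Loss written as in the question's loop: negative support-weighted intensity
fft4 = @(A) fftshift(fft(fft(fft(fft(A, [], 3), [], 4), [], 1), [], 2));
lossFcn = @(x) -sum(squeeze(sum(sum(abs(fft4( ...
    T4 .* reshape(exp(1i * [0, x]), [1, 1, Nin, Nin]))).^2, 1), 2)) .* support, 'all');

loss = lossFcn(currX);
gradient = zeros(1, K);                           % pre-allocated once, before the loop
for k = 1:K
    xp = currX;
    xp(k) = xp(k) + epsilon;                      % perturb one phase
    gradient(k) = (lossFcn(xp) - loss) / epsilon; % forward difference
end
```

Pre-allocating the 5-D arrays for all perturbations at once is a different trade-off: for a 1000x1000 matrix that would mean roughly 10^9 complex elements per array.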
Below is a sample implementation of the vectorization method. Note that vectorization is memory-intensive: it trades memory for speed and may run out of memory for larger inputs.
% Vectorization method (save as compute_gradient.m)
function [gradient, elapsedTime] = compute_gradient(currX, Ts2pHH, support, epsilon)
    % Ts2pHH: the matrix from the question, reshaped to [Npx, Npx, Nin, Nin]
    % currX:  1 x (Nin^2 - 1) phases; the first column's phase is fixed at 0
    tic;                                             % start timing
    Nin = size(Ts2pHH, 3);
    K = numel(currX);                                % must equal Nin^2 - 1
    % Unperturbed loss, with the same sign as the perturbed losses below
    phases = reshape(exp(1i * [0, currX]), [1, 1, Nin, Nin]);
    loss = batch_loss(Ts2pHH .* phases, support);
    % Column k of this (K+1)-by-K matrix holds the phases with currX(k) perturbed
    perturbed = [0, currX].' + epsilon * [zeros(1, K); eye(K)];
    perturbed_phases = reshape(exp(1i * perturbed), [1, 1, Nin, Nin, K]);
    % The perturbation index lives in dimension 5, separate from the data
    Tcorr_perturbed = Ts2pHH .* perturbed_phases;
    % Forward-difference gradient, one entry per perturbed phase
    gradient = (batch_loss(Tcorr_perturbed, support) - loss) / epsilon;
    elapsedTime = toc;                               % stop timing
    fprintf('Elapsed time: %.2f seconds\n', elapsedTime);
end

function loss = batch_loss(Tcorr, support)
    % FFT over dimensions 1-4 only; a fifth dimension is treated as a batch index
    F = fft(fft(fft(fft(Tcorr, [], 3), [], 4), [], 1), [], 2);
    % Sum the intensity over dimensions 1-2, then centre dimensions 3-4 only,
    % so that fftshift does not reorder the batch dimension
    inputFreq = fftshift(fftshift(sum(sum(abs(F).^2, 1), 2), 3), 4);
    weighted = inputFreq .* reshape(support, [1, 1, size(support)]);
    loss = -reshape(sum(sum(weighted, 3), 4), 1, []);
end

% TESTBENCH (separate script)
% Sizes are kept small: the batch array has Npx^2 * Nin^2 * (Nin^2 - 1) elements
Npx = 16; Nin = 8;
currX = rand(1, Nin^2 - 1);            % Example current phases
Ts2pHH = rand(Npx, Npx, Nin, Nin);     % Example matrix, already reshaped to 4-D
support = rand(Nin, Nin);              % Example support matrix
epsilon = 1e-6;
[gradient, elapsedTime] = compute_gradient(currX, Ts2pHH, support, epsilon);
disp(gradient);
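Separately from vectorization, the structure of the loop in the question suggests two possible reductions in work. Because the intensity is summed over dimensions 1 and 2, Parseval's theorem lets the FFTs along those two dimensions be replaced by a constant factor of Npx^2. In addition, the derivative of the loss with respect to all phases can be written in closed form with one extra pair of FFTs, instead of one full pass per phase. The sketch below assumes the loss is the negative support-weighted intensity implied by the question's loop, uses small random placeholder inputs (Npx, Nin, and the data are not from the question), and compares the closed-form gradient with forward differences. It should be verified on real data before being relied on.

```matlab
% Illustrative sizes and random data (assumptions, not taken from the question)
rng(0);
Npx = 4; Nin = 3;
K = Nin^2 - 1;                                          % number of free phases
Ts2pHH  = randn(Npx^2, Nin^2) + 1i * randn(Npx^2, Nin^2);
support = rand(Nin, Nin);
currX   = rand(1, K);
epsilon = 1e-6;

T4 = reshape(Ts2pHH, [Npx, Npx, Nin, Nin]);
fft34 = @(A) fft(fft(A, [], 3), [], 4);                 % FFT over dims 3-4 only
% Applying ifftshift to the support replaces fftshift on the spectrum
Sp = reshape(ifftshift(support), [1, 1, Nin, Nin]);
applyPhase = @(x) T4 .* reshape(exp(1i * [0, x]), [1, 1, Nin, Nin]);

% Loss without the FFTs along dims 1-2 (Parseval: they contribute a factor Npx^2)
lossFast = @(x) -Npx^2 * sum(Sp .* sum(sum(abs(fft34(applyPhase(x))).^2, 1), 2), 'all');

% Reference loss, written as in the question's loop
lossRef = @(x) -sum(squeeze(sum(sum(abs(fftshift(fft(fft(fft34(applyPhase(x)), ...
    [], 1), [], 2))).^2, 1), 2)) .* support, 'all');

% Closed-form gradient, using d|G|^2/dphi_m = -2*imag(conj(G) .* Tc_m .* W_m)
Tc = applyPhase(currX);
G  = fft34(Tc);
B  = fft34(Sp .* conj(G));                              % one extra FFT pair
gFull = 2 * Npx^2 * imag(sum(sum(Tc .* B, 1), 2));      % 1 x 1 x Nin x Nin
gradAnalytic = reshape(gFull, 1, []);
gradAnalytic = gradAnalytic(2:end);                     % first phase is fixed at 0

% Forward-difference check
loss0 = lossFast(currX);
gradFD = zeros(1, K);
for k = 1:K
    gradFD(k) = (lossFast(currX + epsilon * ((1:K) == k)) - loss0) / epsilon;
end
fprintf('max |analytic - finite difference| = %.3g\n', max(abs(gradAnalytic - gradFD)));
```

If the two gradients agree on real data, the closed form needs two FFT pairs in total, whereas the finite-difference loop needs one full 4-D FFT per phase.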
Hope this is beneficial!

Release

R2023b
