The adamupdate function in MATLAB R2024b incorrectly uses uint32 with sqrt and exhibits state corruption, causing errors even in minimal test cases.

% Test adamupdate function
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
state = []; % Initial state (empty)
optimizer = trainingOptions('adam', 'InitialLearnRate', 0.01); % Example optimizer
timeStep = uint32(1); % Initial time step
try
    % Perform a single adamupdate
    updatedLearnable = adamupdate(learnable, gradient, state, optimizer, timeStep);
    % Display results
    disp('adamupdate test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable);
catch ME
    % Display the error message
    disp('adamupdate test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace to help MathWorks track down the problem.
    disp('Stack Trace:');
    disp(ME.stack);
end

Output:
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
9x1 struct array with fields:
file
name
line
% Perform a second adamupdate to test state persistence.
timeStep = uint32(2);
try
    % Perform the second adamupdate, converting the time step to double
    updatedLearnable2 = adamupdate(learnable, gradient, state, optimizer, double(timeStep));
    % Display results
    disp('adamupdate second test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable2);
catch ME
    % Display the error message
    disp('adamupdate second test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace to help MathWorks track down the problem.
    disp('Stack Trace:');
    disp(ME.stack);
end

Output:
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
12x1 struct array with fields:
file
name
line
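The first error can be reproduced outside adamupdate. In MATLAB, sqrt is defined only for floating-point (single or double) inputs, so any uint32 value that reaches it raises this kind of error. The minimal sketch below uses a hypothetical variable t to show the behavior; it does not show where the uint32 value enters adamupdate's internals, which is not visible from this thread.

```matlab
% Hypothetical example: sqrt accepts only single or double inputs.
t = uint32(4);               % integer-class value, like timeStep in the question
try
    sqrt(t);                 % expected to fail: no sqrt for integer classes
catch err
    disp(err.message)        % print whatever error text MATLAB reports
end
r = sqrt(double(t));         % converting to double first succeeds
disp(r)
```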
1 Comment
Chika, 18 March 2025
Error messages:
adamupdate test failed!
Error: Undefined function 'sqrt' for input arguments of type 'uint32'.
Stack Trace:
7×1 struct array with fields:
file
name
line
adamupdate second test failed!
Error: dlarray is supported only for full arrays of data type double, single, or logical, or for full gpuArrays of these data types.
Stack Trace:
10×1 struct array with fields:
file
name
line


Accepted Answer

Joss Knight, 22 March 2025
Well, I admit the error messages aren't very helpful, but the basic problem is that passing a trainingOptions object as an argument to adamupdate is not supported. See the documentation for the correct syntax.
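A minimal sketch of the documented call, assuming Deep Learning Toolbox is installed. The Adam settings are passed as positional arguments after iteration (learnRate, gradDecay, sqGradDecay, epsilon). Here learnRate = 0.01 stands in for the 'InitialLearnRate' the question tried to set through trainingOptions, and the other three values are the documented defaults. For comparison, the question's call adamupdate(learnable, gradient, state, optimizer, timeStep) binds state ([]) to averageGrad, the trainingOptions object to averageSqGrad, and the uint32 timeStep to iteration, which plausibly explains why the two error messages point at sqrt and dlarray rather than at the real mistake.

```matlab
% Sketch: documented positional arguments instead of a trainingOptions object.
learnable = dlarray(randn(5, 1));    % example learnable parameter
gradient  = dlarray(randn(5, 1));    % example gradient, held fixed for illustration

averageGrad   = [];                  % empty state on the first call
averageSqGrad = [];

learnRate   = 0.01;                  % value the question passed as 'InitialLearnRate'
gradDecay   = 0.9;                   % documented defaults for the remaining settings
sqGradDecay = 0.999;
epsilon     = 1e-8;

for iteration = 1:3                  % iteration is a positive double, not uint32
    % Feed both the updated learnable and the updated state into the next call
    [learnable, averageGrad, averageSqGrad] = adamupdate(learnable, gradient, ...
        averageGrad, averageSqGrad, iteration, learnRate, gradDecay, sqGradDecay, epsilon);
end
disp(learnable)
```

In a real training loop the gradient would be recomputed on every iteration (for example with dlfeval and dlgradient); it is held fixed here only to keep the sketch short.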
1 Comment
Chika, 22 March 2025
I am extremely grateful to Joss Knight for pointing out the error and for advising me to look at the documentation for the adamupdate function.


More Answers (1)

Chika, 22 March 2025
% Corrected code, following the documentation as advised by Joss Knight
% Test adamupdate function (built-in)
clc;
clear;
% Define test parameters
learnable = dlarray(randn(5, 1)); % Example learnable parameter
gradient = dlarray(randn(5, 1)); % Example gradient
averageGrad = zeros(size(learnable)); % Initialize average gradient
averageSqGrad = zeros(size(learnable)); % Initialize average squared gradient
iteration = 1; % Initial iteration
try
    % Perform a single adamupdate
    [updatedLearnable, averageGrad, averageSqGrad] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
    % Display results
    disp('adamupdate test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable);
    disp('Average Gradient:');
    disp(averageGrad);
    disp('Average Squared Gradient:');
    disp(averageSqGrad);
catch ME
    % Display the error message
    disp('adamupdate test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace.
    disp('Stack Trace:');
    disp(ME.stack);
end

Output:
adamupdate test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.1693 -0.0385 0.0958 -0.0383 0.0295
Average Squared Gradient:
5x1 dlarray 0.0029 0.0001 0.0009 0.0001 0.0001
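The printed values are consistent with adamupdate's documented defaults gradDecay = 0.9 and sqGradDecay = 0.999: starting from zero state, one step gives averageGrad = 0.1*g and averageSqGrad = 0.001*g.^2. The plain-MATLAB sketch below (no toolbox needed) recovers g from the printed averageGrad and checks the printed averageSqGrad against it. The hard-coded numbers are copied from the output above and are rounded to four decimals.

```matlab
% Values printed by the first adamupdate call (rounded to 4 decimals).
averageGrad1   = [0.1693; -0.0385; 0.0958; -0.0383; 0.0295];
averageSqGrad1 = [0.0029;  0.0001; 0.0009;  0.0001; 0.0001];

% Assumed decay rates: adamupdate's documented defaults.
gradDecay   = 0.9;
sqGradDecay = 0.999;

% From zero state, one step gives averageGrad = (1 - gradDecay)*g,
% so the gradient can be recovered from the printed average.
g = averageGrad1 / (1 - gradDecay);

% Squared-gradient average predicted from that g, rounded to 4 decimals like the printout.
predictedSq = round((1 - sqGradDecay) * g.^2 * 1e4) / 1e4;
disp([averageSqGrad1, predictedSq])   % printed vs. predicted, side by side
```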
% Perform a second adamupdate to test state persistence.
iteration = 2;
try
    % Perform a second adamupdate, passing in the updated state.
    % Note: the original learnable (not updatedLearnable) is passed again,
    % so only the moving averages carry over from the first call.
    [updatedLearnable2, averageGrad2, averageSqGrad2] = adamupdate(learnable, gradient, averageGrad, averageSqGrad, iteration);
    % Display results
    disp('adamupdate second test successful!');
    disp('Updated Learnable:');
    disp(updatedLearnable2);
    disp('Average Gradient:');
    disp(averageGrad2);
    disp('Average Squared Gradient:');
    disp(averageSqGrad2);
catch ME
    % Display the error message
    disp('adamupdate second test failed!');
    disp(['Error: ', ME.message]);
    % Display the stack trace.
    disp('Stack Trace:');
    disp(ME.stack);
end

Output:
adamupdate second test successful!
Updated Learnable:
5x1 dlarray -0.7648 -1.0165 -0.0125 -0.5996 0.4997
Average Gradient:
5x1 dlarray 0.3217 -0.0732 0.1819 -0.0727 0.0560
Average Squared Gradient:
5x1 dlarray 0.0057 0.0003 0.0018 0.0003 0.0002
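Updated Learnable is identical after both calls because the second call starts again from the original learnable and, assuming adamupdate applies the standard bias-corrected Adam update, a repeated gradient produces the same step at iteration 2 as at iteration 1 (about learnRate*sign(g), with the default learnRate = 0.001). The plain-MATLAB sketch below computes both steps for the gradient implied by the printed averageGrad; g is reconstructed from rounded output, not taken from the original script.

```matlab
% Gradient implied by the printed averageGrad (0.1693/0.1, ...); reconstructed, not original data.
g = [1.693; -0.385; 0.958; -0.383; 0.295];

% Assumed hyperparameters: adamupdate's documented defaults.
learnRate = 0.001; gradDecay = 0.9; sqGradDecay = 0.999; epsilon = 1e-8;

m = zeros(size(g));            % moving average of the gradient
v = zeros(size(g));            % moving average of the squared gradient
steps = zeros(numel(g), 2);    % parameter change at iterations 1 and 2

for t = 1:2
    m = gradDecay*m + (1 - gradDecay)*g;
    v = sqGradDecay*v + (1 - sqGradDecay)*g.^2;
    mHat = m / (1 - gradDecay^t);          % bias-corrected first moment
    vHat = v / (1 - sqGradDecay^t);        % bias-corrected second moment
    steps(:, t) = learnRate * mHat ./ (sqrt(vHat) + epsilon);
end

disp(steps)      % both columns are about learnRate*sign(g)
```

Because the bias correction rescales m and v by the same factors they accumulated, mHat and vHat equal g and g.^2 at both iterations, so the step depends only on the sign of g here.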

Category: Image Data Workflows

Release: R2024b
