How dlarray works in MATLAB

I have data of size 1024x1. I predict the result using a trained network net, which takes input of size 1(C)x1(B)x2048(T). When I then compute the MSE with two different methods, why do they give two different answers?
A = dlarray(data,'TCB');
B = predict(net,A);
loss = mse(A,B)
% displays: loss = 28.6
loss = mse(squeeze(extractdata(B)),squeeze(extractdata(A)))
% displays: loss = 0.026
I need to work with dlarray for automatic differentiation. Could someone please guide me?

4 Comments

Matt J on 7 Sep 2024
Edited: Matt J on 7 Sep 2024
If it's only 1024x1, attach A and B in a .mat file so we can examine it.
SYED on 7 Sep 2024
I am attaching the data file and a trained model.
SYED on 7 Sep 2024
I uploaded the required file.
Matt J on 7 Sep 2024
Edited: Matt J on 7 Sep 2024
I can't load the net variable, I'm afraid, and the file does not contain a 1024x1 variable. Please attach a file containing only the A and B variables mentioned in your original post.


Answers (1)

Sahas on 9 Sep 2024

Hi @SYED,
In the first method, MATLAB's “mse” function receives two formatted “dlarray” objects, so the dlarray version of “mse” is called. In the second method, “extractdata” returns the underlying numeric arrays of the “dlarray” objects, so the call dispatches to a different “mse” function, the one that operates on ordinary numeric data.
The MathWorks documentation for the dlarray “mse” function states that the first input argument, “prediction”, must be a formatted or unformatted “dlarray” object, and that the function returns the half mean squared error, by default normalized by the number of observations. The numeric “mse” instead returns the ordinary mean of the squared differences over all elements. The two methods therefore compute different quantities, so their values are not directly comparable.
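Written out, and assuming the dlarray “mse” uses its default normalization (dividing by the number of observations N, the size of the 'B' dimension) while the numeric “mse” averages over all M elements, the two results for a prediction Y and a target T are related as follows:

```latex
\mathrm{loss}_{\mathrm{dlarray}} = \frac{1}{2N}\sum_{i=1}^{M}\left(Y_i - T_i\right)^2,
\qquad
\mathrm{loss}_{\mathrm{numeric}} = \frac{1}{M}\sum_{i=1}^{M}\left(Y_i - T_i\right)^2
\quad\Longrightarrow\quad
\mathrm{loss}_{\mathrm{dlarray}} = \frac{M}{2N}\,\mathrm{loss}_{\mathrm{numeric}}
```

With a single observation (N = 1) and M between 1024 and 2048 responses, the first value is 512 to 1024 times the second, the same order of magnitude as the gap between 28.6 and 0.026.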
Refer to the following MathWorks documentation for more information on input arguments of the “mse” function: https://www.mathworks.com/help/deeplearning/ref/dlarray.mse.html
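I suggest using the first method to calculate the loss, because it keeps the result as a “dlarray” that “dlfeval” and “dlgradient” can trace for automatic differentiation; calling “extractdata” breaks that trace. Refer to the following code snippet for reference: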
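clc
A = dlarray(data, 'TCB');          % formatted dlarray input
B = predict(net, A);               % the prediction B is also a formatted dlarray
loss = mse(B, A)                   % dlarray mse (prediction first): keep this for autodiff
% A2 = squeeze(extractdata(A))
% B2 = squeeze(extractdata(B))
Atemp = extractdata(A);            % plain numeric arrays, no longer dlarray
Btemp = extractdata(B);
losstemp = mse(Btemp, Atemp)       % numeric inputs call a different mse; not comparable to loss
% losstemp = mse(B, Atemp)         % also valid: dlarray prediction with a numeric target
A2 = squeeze(A);
B2 = squeeze(B);
loss2 = mse(B2, A2)                % alternate use with dlarray inputs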
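The relation between the two results can be checked without the original data or network. The following is a minimal sketch using random stand-ins for A and B (hypothetical values with 1024 time steps, one channel, and one observation, not the poster's data); it assumes the dlarray “mse” uses its default batch-size normalization.

```matlab
% Hypothetical stand-ins for the target A and the network output B:
% 1024 time steps, 1 channel, 1 observation (format 'TCB', as in the question)
rng(0);
A = dlarray(rand(1024, 1), 'TCB');
B = dlarray(rand(1024, 1), 'TCB');

% Method 1: dlarray mse (half mean squared error, divided by the batch size)
lossDl = extractdata(mse(B, A));

% Method 2: ordinary mean squared error on the underlying numeric data
a = extractdata(A);
b = extractdata(B);
lossPlain = mean((b - a).^2, 'all');

% Relation: lossDl = M/(2N) * lossPlain, with M responses and N observations
M = numel(a);                     % total number of responses
N = size(A, finddim(A, 'B'));     % number of observations (size of the 'B' dimension)
fprintf('dlarray mse: %.4f, plain mse: %.4f, ratio: %.1f, M/(2N): %.1f\n', ...
    lossDl, lossPlain, lossDl / lossPlain, M / (2 * N));
```

Because M/(2N) is a constant for a fixed input size, both quantities have the same minimizer; the scale only changes the size of the gradients, which interacts with the learning rate. Inside a loss function evaluated with “dlfeval”, the dlarray version is the one to use.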
For more information on the usage of “extractdata” function and data formats of “dlarray” objects, refer to the following MathWorks documentation links:
Hope this is beneficial!

Category

Find more on Deep Learning Toolbox in Help Center and File Exchange

Release

R2024a

Asked:

7 Sep 2024

Answered:

9 Sep 2024
