Create a plot with gradient descent vectors for a function of two variables

Asked by Askic V on 1 Feb 2024
Answered: Austin M. Weber on 1 Feb 2024
Hello,
I have the following code for a function of one variable, y = f(x), and I would like to extend it to a function of two variables, z = x.^2 + y.^2. The gradient is a vector that points in the direction of the fastest increase of the function, while the gradient descent vector points in the direction of the steepest decrease.
```matlab
% Define the function
f = @(x) 2*x.^2 - 4*x - 2;
grad_f = @(x) 4*x - 4;
% Plot the function
x = linspace(-5, 5, 100);
y = f(x);
figure;
plot(x, y, 'LineWidth', 1.5);
hold on;
% Set the starting point for gradient descent
start_x = -2;
start_y = f(start_x);
% Plot the starting point
scatter(start_x, start_y, 100, 'r', 'filled');
% Points to show gradient vectors
points_to_show = -2;
% Gradient descent parameters
alpha = 0.5;
deltax = 0.5;
for i = 1:numel(points_to_show)
    x = points_to_show(i);
    % Plot the point
    scatter(x, f(x), 100, 'g', 'filled');
    % Calculate the gradient at the current point using the gradient function
    gradient_at_x = grad_f(x);
    % Plot the gradient descent vector at the specified point
    quiver(x, f(x), -alpha*(x+deltax), -alpha*gradient_at_x*(x+deltax), 'Color', 'r', 'LineWidth', 1.5);
end
hold off;
% Add labels and title
xlabel('X-axis');
ylabel('Y-axis');
title('Gradient Descent Vectors');
% Show grid lines
grid on;
```
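For the one-variable case, the arrow can also be derived directly from a single gradient-descent step instead of the deltax term. This is a minimal sketch, assuming a smaller step size alpha = 0.05 so the arrow stays on screen: the horizontal component is -alpha*grad_f(x0), and the vertical component is the actual change in f along that step, so the arrow ends on the curve at the next iterate.

```matlab
% Sketch: one gradient-descent step for f(x) = 2x^2 - 4x - 2, drawn as an arrow
f      = @(x) 2*x.^2 - 4*x - 2;   % function from the question
grad_f = @(x) 4*x - 4;            % its derivative
alpha  = 0.05;                    % assumed step size (learning rate)
x0     = -2;                      % starting point from the question

% Gradient-descent step: move against the gradient
dx = -alpha * grad_f(x0);         % horizontal component of the step
x1 = x0 + dx;                     % next iterate
dy = f(x1) - f(x0);               % change in f along the step (negative = downhill)

xs = linspace(-5, 5, 100);
figure;
plot(xs, f(xs), 'LineWidth', 1.5);
hold on;
scatter(x0, f(x0), 100, 'r', 'filled');
% Scale factor 0 turns off quiver's auto-scaling, so the arrow ends exactly at (x1, f(x1))
quiver(x0, f(x0), dx, dy, 0, 'Color', 'r', 'LineWidth', 1.5);
hold off;
grid on;
title('One gradient-descent step');
```

Here the step goes from x = -2 to x = -1.4, and f drops from 14 to 7.52.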
Below is what I have done so far for the two-variable case, but I can't figure out how to plot the vector. Can you give me some advice or hints? From searching the web, I noticed that gradient descent vectors are usually plotted on contour plots rather than on the surface itself.
```matlab
% Define the function
z = @(x, y) x.^2 + y.^2;
% Point of interest
point_x = -1;
point_y = -1;
point_z = z(point_x, point_y);
deltax = 0.5;
deltay = 0.5;
% Calculate the gradient analytically
gradient_x = 2*point_x;
gradient_y = 2*point_y;
gradient_z = 0;
% Plot the surface
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
z_vals = z(x_vals, y_vals);
surf(x_vals, y_vals, z_vals);
hold on;
% Plot the point of interest
scatter3(point_x, point_y, point_z, 100, 'r', 'filled');
% Plot the gradient vector at the specified point
quiver3(point_x, point_y, point_z, gradient_x, gradient_y, gradient_z, 'Color', 'r', 'LineWidth', 2);
hold off;
% Add labels and title
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
```
Thank you.
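For reference, the quantities involved for z = x^2 + y^2 are written out below: the gradient, its value and negative at the point of interest (-1, -1), one gradient-descent update with an assumed step size α, and the first-order change in z along a small step (Δx, Δy), which is the vertical component an arrow needs in order to lie on the surface. The gradient (-2, -2) points uphill, so the question's quiver3 call, which uses (gradient_x, gradient_y, 0), draws the uphill direction flat in the x-y plane; the descent arrow uses the negative gradient instead.

$$\nabla z(x, y) = \left( \frac{\partial z}{\partial x},\ \frac{\partial z}{\partial y} \right) = (2x,\ 2y), \qquad \nabla z(-1, -1) = (-2,\ -2), \qquad -\nabla z(-1, -1) = (2,\ 2)$$

$$(x_{k+1},\ y_{k+1}) = (x_k,\ y_k) - \alpha\, \nabla z(x_k, y_k)$$

$$\Delta z \approx \frac{\partial z}{\partial x}\,\Delta x + \frac{\partial z}{\partial y}\,\Delta y = 2x\,\Delta x + 2y\,\Delta y$$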

Answers (1)

Austin M. Weber on 1 Feb 2024
My apologies for the double-response. I accidentally answered your question as a comment rather than as an answer.
If I am understanding correctly, I think I have found a simple way of doing what you want.
First, here is your code without the quiver3 call, with a few changes to the plotting that I explain below:
```matlab
% Define the function
z = @(x, y) x.^2 + y.^2;
% Point of interest
point_x = -1;
point_y = -1;
point_z = z(point_x, point_y);
deltax = 0.5;
deltay = 0.5;
% Calculate the gradient analytically
gradient_x = 2*point_x;
gradient_y = 2*point_y;
gradient_z = 0;
% Plot the surface with a contour plot underneath (no edge lines, interpolated colors, 70% opacity)
[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50), linspace(-2, 2, 50));
z_vals = z(x_vals, y_vals);
surfc(x_vals, y_vals, z_vals, 'EdgeColor', 'none', 'FaceColor', 'interp', 'FaceAlpha', 0.7);
hold on;
% Plot the point of interest
scatter3(point_x, point_y, point_z, 30, 'r', 'filled');
% Add labels
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
% Camera angle: azimuth -22 degrees, elevation 6 degrees
view(-22, 6)
```
I swapped surf for surfc, which also plots a contour map underneath the surface. I also removed the edge lines and interpolated the face colors to make the visualization less busy; you can revert to your original surf plot if you prefer. Finally, I used the view function to set the azimuth (-22 degrees) and elevation (6 degrees) to get a different perspective on the axes.
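Since surfc already draws the contours, the observation in the question that descent vectors are usually drawn on contour plots can also be tried directly in 2-D. This is a minimal sketch, assuming the analytic gradient (2x, 2y) of z = x.^2 + y.^2 and an arbitrarily chosen 9-by-9 grid of arrow positions; it draws -∇z as a field of arrows, each perpendicular to the contour line it starts on.

```matlab
% Sketch: steepest-descent field (-grad z) drawn on a contour plot of z = x^2 + y^2
z = @(x, y) x.^2 + y.^2;                     % function from the question

[X, Y]   = meshgrid(linspace(-2, 2, 50));    % fine grid for the contour lines
[Xq, Yq] = meshgrid(linspace(-2, 2, 9));     % coarser grid for the arrows (assumed spacing)

% Analytic gradient of z: (dz/dx, dz/dy) = (2x, 2y)
Gx = 2*Xq;
Gy = 2*Yq;

figure;
contour(X, Y, z(X, Y), 15);
hold on;
% Negate the gradient so each arrow points downhill (steepest descent)
quiver(Xq, Yq, -Gx, -Gy, 'Color', 'r');
scatter(-1, -1, 50, 'k', 'filled');          % point of interest from the question
hold off;
axis equal;                                  % equal scaling so arrows look perpendicular to contours
xlabel('X-axis');
ylabel('Y-axis');
title('Steepest-descent directions on contours');
```

The arrows shrink toward the origin because the gradient magnitude, 2*sqrt(x^2 + y^2), vanishes at the minimum; quiver's default auto-scaling only rescales all arrows by a common factor.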
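Now, to add the vector arrow, I take a very small step from the point of interest in both x and y and compute the resulting change in z. The displacement (dx, dy, dz) follows the local slope of the surface, so scaling it up to a visible length gives the arrow. Because the step is taken along the (+1, +1) direction, it matches the steepest-descent direction only where -∇z points that way, which is the case at (-1, -1): the gradient there is (2x, 2y) = (-2, -2).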
```matlab
% Very small step in x and y (a finite-difference step, not truly infinitesimal)
infchange = 1e-10;
point_xinf = point_x + infchange;
point_yinf = point_y + infchange;
point_zinf = z(point_xinf, point_yinf);
% Calculate the difference relative to the original point
dx = point_xinf - point_x;
dy = point_yinf - point_y;
dz = point_zinf - point_z;
% Scale factor that stretches the tiny displacement to a visible arrow (20e8 = 2e9)
arrowScale = 20e8;
% Add vector arrow to map
quiver3(point_x, point_y, point_z, ...
    dx*arrowScale, dy*arrowScale, dz*arrowScale, 'Color', 'r', ...
    'LineWidth', 1, ...
    'ShowArrowHead', 'on', ...
    'MaxHeadSize', 0.6)
```
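A variant that does not depend on the (+1, +1) step direction is to build the arrow from the analytic gradient instead of a finite difference. This is a minimal sketch, assuming alpha = 0.1 as a display length: the horizontal part is one gradient-descent step, -alpha*(gradient_x, gradient_y), and the vertical part is the first-order change in z along that step, so the arrow lies tangent to the surface and points downhill. At (-1, -1) it gives the components (0.2, 0.2, -0.8), the same as the scaled finite-difference displacement above, but it also points downhill at other points.

```matlab
% Sketch: steepest-descent arrow on the surface z = x^2 + y^2, built from the analytic gradient
z = @(x, y) x.^2 + y.^2;
point_x = -1;                              % point of interest from the question
point_y = -1;
point_z = z(point_x, point_y);

% Analytic gradient at the point: (2x, 2y)
gradient_x = 2*point_x;
gradient_y = 2*point_y;

% Step against the gradient; alpha is an assumed step length for display
alpha = 0.1;
u = -alpha*gradient_x;
v = -alpha*gradient_y;
% Vertical component that keeps the arrow tangent to the surface:
% dz ~ (dz/dx)*u + (dz/dy)*v, which is negative here, i.e. downhill
w = gradient_x*u + gradient_y*v;

[x_vals, y_vals] = meshgrid(linspace(-2, 2, 50));
figure;
surfc(x_vals, y_vals, z(x_vals, y_vals), 'EdgeColor', 'none', 'FaceAlpha', 0.7);
hold on;
scatter3(point_x, point_y, point_z, 30, 'r', 'filled');
% Scale factor 0 disables auto-scaling so the arrow keeps the computed length
quiver3(point_x, point_y, point_z, u, v, w, 0, 'Color', 'r', 'LineWidth', 1.5);
hold off;
xlabel('X-axis');
ylabel('Y-axis');
zlabel('Z-axis');
view(-22, 6);
```

Changing point_x and point_y moves the arrow, and it keeps pointing toward the minimum at the origin.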
