
How to plot an averaged ROC curve?

Ishfaque Ahmed on 20 April 2022
Commented: Adam Danz on 25 April 2022
I am trying to plot the ROC curve of my model over multiple iterations. The curves are not at the same locations, so I want to plot one averaged ROC curve from all 10 ROC curves. Please suggest a solution.
  1 Comment
Chunru on 21 April 2022
You can interpolate each curve onto the same grid and then average them.


Accepted Answer

Chunru on 21 April 2022
% Create sample data
numPoints = 50;
nCurves = 10;
x = sort(rand(numPoints, nCurves));
y = (sort(rand(numPoints, nCurves))).^(1/4);
plot(x, y);
grid on;
hold on;
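% Common x grid shared by all curves
x0 = linspace(0, 1, 100);
% Interpolate every curve onto the common grid (one column per curve)
yinterp = zeros(length(x0), nCurves);
for i = 1:nCurves
    yinterp(:, i) = interp1(x(:,i), y(:,i), x0, 'linear', 'extrap');
end
% Average the curves point by point across the columns
meany = mean(yinterp, 2);
% Plot the mean as a thick line on top of the individual curves (hold is still on)
plot(x0, meany, 'LineWidth', 2);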
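Real ROC data, such as the FPR/TPR vectors returned by `perfcurve`, usually needs two extra steps before this kind of averaging: ROC curves have vertical steps, so the same FPR value can appear several times, while `interp1` requires unique sample points; and `'extrap'` can push the averaged TPR outside [0, 1]. Below is a minimal sketch of vertical averaging under those assumptions. The cell arrays `fprs` and `tprs` and the random example curves are hypothetical stand-ins for your own per-iteration results; each curve is padded with (0,0) and (1,1), and at repeated FPR values the highest TPR is kept, which is one common convention rather than the only one.

```matlab
% Hypothetical example data: 10 ROC curves stored as FPR/TPR column vectors.
% Replace fprs/tprs with your own curves (e.g. from perfcurve).
nCurves = 10;
fprs = cell(1, nCurves);
tprs = cell(1, nCurves);
for k = 1:nCurves
    f = sort(rand(20, 1));
    f(5) = f(4);                          % force a repeated FPR value (a vertical ROC step)
    fprs{k} = f;
    tprs{k} = sort(rand(20, 1)).^(1/4);
end

x0 = linspace(0, 1, 101)';                % common FPR grid
tprInterp = zeros(numel(x0), nCurves);
for k = 1:nCurves
    % Pad every curve with (0,0) and (1,1) so no extrapolation is needed.
    fpr = [0; fprs{k}(:); 1];
    tpr = [0; tprs{k}(:); 1];
    % interp1 needs unique sample points: at repeated FPR values keep the highest TPR.
    [fprU, ~, idx] = unique(fpr);
    tprU = accumarray(idx, tpr, [], @max);
    tprInterp(:, k) = interp1(fprU, tprU, x0, 'linear');
end

% Vertically averaged ROC curve: mean TPR at each FPR grid point
meanTpr = mean(tprInterp, 2);
```

Plot the result with `plot(x0, meanTpr, 'k-', 'LineWidth', 2)`. Because every padded curve spans FPR 0 to 1, the interpolation never extrapolates and the average stays within [0, 1].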

More Answers (1)

Image Analyst on 21 April 2022
Edited: Image Analyst on 21 April 2022
Try this:
% Create sample data because the original poster didn't upload theirs.
numPoints = 30;
x1 = sort(rand(1, numPoints));
x2 = sort(rand(1, numPoints));
x3 = sort(rand(1, numPoints));
x4 = sort(rand(1, numPoints));
x5 = sort(rand(1, numPoints));
x6 = sort(rand(1, numPoints));
x7 = sort(rand(1, numPoints));
x8 = sort(rand(1, numPoints));
x9 = sort(rand(1, numPoints));
x10 = sort(rand(1, numPoints));
y1 = sort(rand(1, numPoints));
y2 = sort(rand(1, numPoints));
y3 = sort(rand(1, numPoints));
y4 = sort(rand(1, numPoints));
y5 = sort(rand(1, numPoints));
y6 = sort(rand(1, numPoints));
y7 = sort(rand(1, numPoints));
y8 = sort(rand(1, numPoints));
y9 = sort(rand(1, numPoints));
y10 = sort(rand(1, numPoints));
plot(x1, y1, '-');
hold on;
plot(x2, y2, '-');
plot(x3, y3, '-');
plot(x4, y4, '-');
plot(x5, y5, '-');
plot(x6, y6, '-');
plot(x7, y7, '-');
plot(x8, y8, '-');
plot(x9, y9, '-');
plot(x10, y10, '-');
grid on;
hold on;
%========================================================================
% Since you have your own data you'd start here
% and NOT create the sample data above.
allx = sort([x1,x2,x3,x4,x5,x6,x7,x8,x9,x10], 'ascend');
% Then interpolate all the other curves so they're on a common x axis.
y1a = interp1(x1, y1, allx);
y2a = interp1(x2, y2, allx);
y3a = interp1(x3, y3, allx);
y4a = interp1(x4, y4, allx);
y5a = interp1(x5, y5, allx);
y6a = interp1(x6, y6, allx);
y7a = interp1(x7, y7, allx);
y8a = interp1(x8, y8, allx);
y9a = interp1(x9, y9, allx);
y10a = interp1(x10, y10, allx);
% Get all y together in one matrix.
allY = [y1a;y2a;y3a;y4a;y5a;y6a;y7a;y8a;y9a;y10a];
% Find out how many curves have valid, non-nan values at each x location.
counts = sum(~isnan(allY), 1);
% Now set NaNs to zero so the sum is not NaN wherever one of the curves is NaN for some x value.
allY(isnan(allY)) = 0;
% Some y values were NaN (interp1 returns NaN outside the x range where a curve was originally defined),
% so we can't use mean(allY, 1) to get the mean value, because we'd be averaging in those zeros.
% Instead, sum allY vertically to get the sum of the formerly non-NaN values,
% and divide by counts, which already holds how many curves were not NaN at each x value.
% That gives the true mean.
meany = sum(allY, 1) ./ counts;
% Now plot the mean as a thick black curve.
hold on;
plot(allx, meany, 'k-', 'LineWidth', 4);
title('Thick black line is the mean of all curves')
Note how the plot gets a little wiggly near the ends: as the number of valid curves (non-NaN values) decreases, the mean moves closer to the few remaining valid curves. For example, say that after x = 0.9 only 5 curves have non-NaN values, not the full 10. There you'd want to average only those 5 curves, not all 10. In the picture above, close to x = 1 only the yellow curve has valid x values that far out, so the mean equals the yellow curve's y value there. That's why you can't simply call mean() on the zero-filled array; you have to divide the sum by the count, because the count changes with x. Does that make sense?
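MATLAB's `mean` can also skip NaN values directly through its `'omitnan'` flag, which gives the same result as dividing the sum by the count. A minimal sketch, assuming the curves are stored in the hypothetical cell arrays `xs` and `ys` (one row vector per curve) instead of ten separate variables, so the same code works for any number of curves:

```matlab
% Hypothetical example data: curves with different x samples, stored in cell arrays.
nCurves = 10;
numPoints = 30;
xs = cell(1, nCurves);
ys = cell(1, nCurves);
for k = 1:nCurves
    xs{k} = sort(rand(1, numPoints));
    ys{k} = sort(rand(1, numPoints));
end

% Common x axis: every x value from every curve.
allx = sort([xs{:}]);

% Interpolate each curve; outside its own x range interp1 returns NaN.
allY = zeros(nCurves, numel(allx));
for k = 1:nCurves
    allY(k, :) = interp1(xs{k}, ys{k}, allx);
end

% Average only the valid (non-NaN) values at each x.
meanOmit = mean(allY, 1, 'omitnan');

% The explicit sum / count approach from the answer, for comparison.
counts = sum(~isnan(allY), 1);
allY0 = allY;
allY0(isnan(allY0)) = 0;
meanSumCount = sum(allY0, 1) ./ counts;
```

With `'omitnan'` there is no need to replace NaNs with zeros; the sum/count lines are kept only to show that both approaches agree.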
  3 Comments
Image Analyst on 21 April 2022
Well, I guess you could compute the standard deviation at every x location and then get two curves:
  1. the average curve plus the locally varying standard deviation
  2. the average curve minus the locally varying standard deviation.
Then plot those curves. One will be above the mean curve and one will be below it. Where you have only one curve (at the outer ends), the standard deviation will of course be zero.
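The mean plus/minus standard deviation band described above can be sketched as follows. It assumes `allY` is arranged as in the answer (one interpolated curve per row, NaN where a curve is undefined) and uses a small hypothetical matrix so the numbers are easy to check. The standard deviation is computed only over the valid values at each x, with N-1 normalization like `std`; where a single curve remains, it is 0.

```matlab
% Hypothetical example: one interpolated curve per row, NaN where a curve is undefined.
allY = [0.1 0.2 0.4 NaN;
        0.3 0.4 0.6 0.9;
        NaN 0.3 0.5 NaN];
allx = [0 0.3 0.6 1];

valid  = ~isnan(allY);
counts = sum(valid, 1);                  % number of curves defined at each x

% Mean over the valid values only.
Y0 = allY;
Y0(~valid) = 0;
meany = sum(Y0, 1) ./ counts;

% Sample standard deviation over the valid values (N-1 normalization, like std).
% Where only one curve is defined the deviation is 0, matching std of a scalar.
dev = (allY - meany).^2;                 % implicit expansion (R2016b or later)
dev(~valid) = 0;
sdy = sqrt(sum(dev, 1) ./ max(counts - 1, 1));

yUpper = meany + sdy;                    % mean + local standard deviation
yLower = meany - sdy;                    % mean - local standard deviation
```

To draw the band, use `plot(allx, meany, 'k-', allx, yUpper, 'r--', allx, yLower, 'r--')`. In recent MATLAB releases, `std(allY, 0, 1, 'omitnan')` on the NaN-containing matrix should give the same values as `sdy`.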
Adam Danz on 25 April 2022
I wonder if curve fitting would be useful. Then you could get error estimates of the fit parameters and plot the smooth fit along with the range of error.
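One way to try the curve-fitting idea is to pool the points from all curves, fit a single smooth model with `polyfit`, and use the outputs of `polyfit`/`polyval` for the error estimates. This is only a sketch: the cubic polynomial and the generated data are hypothetical choices for illustration, and a polynomial is not constrained to behave like an ROC curve (it can leave [0, 1] or decrease), so a different model may suit real data better.

```matlab
% Hypothetical example data: 10 noisy curves sampled at different x locations.
nCurves = 10;
numPoints = 30;
xAll = sort(rand(numPoints, nCurves));
yAll = sqrt(xAll) + 0.05 * randn(size(xAll));   % assumed underlying shape plus noise

% Fit one smooth curve to the pooled points (cubic is an arbitrary choice).
degree = 3;
[p, S] = polyfit(xAll(:), yAll(:), degree);

% Standard errors of the fitted coefficients from the covariance of p.
Rinv = S.R \ eye(size(S.R));
covP = (Rinv * Rinv') * S.normr^2 / S.df;
paramSE = sqrt(diag(covP));

% Evaluate the fit and its prediction-error estimate on a common grid.
x0 = linspace(0, 1, 100)';
[yFit, delta] = polyval(p, x0, S);
yUpper = yFit + delta;
yLower = yFit - delta;
```

`paramSE` gives error estimates for the fit parameters, while `yFit ± delta` can be plotted as the range of error around the smooth fit.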
