I am trying to implement curriculum learning in a MATLAB class
I am trying to implement curriculum learning in MATLAB.
So, I have varying levels of difficulty in my resetEnvironmentForLevel() function.
The issue is that the agent explores level 1, and once it reaches the goal under the desired conditions (threshold = 0.6 and history size = 3), it should ideally move to the next level.
However, I believe that, mainly because of the resetImpl function, the environment keeps resetting to level 1.
What I ideally want is that the agent moves to the next level only once it reaches the goal under the desired conditions.
P.S. I also tried using a global variable, but that didn't work either.
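To be explicit, this is the advancement rule I have in mind, as a minimal standalone sketch (the variable names here are illustrative, not the exact ones used in the class below):
currentLevel = 1;
successHistory = [1 1 1];      % outcomes of the last HistorySize = 3 episodes
threshold = 0.6;               % LevelThreshold
if mean(successHistory) >= threshold && all(successHistory == 1)
    currentLevel = currentLevel + 1;   % only then advance to the next level
end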
classdef GridWorld < matlab.System
properties (Nontunable)
GridSize = [8, 8] % Size of the grid
Actions = {'up', 'down', 'left', 'right'} % Possible actions (no diagonals)
MaxSteps = 100 % Maximum steps per episode
LevelThreshold = 0.6 % Success rate threshold to advance to next level
HistorySize = 3 % Number of episodes to consider for level advancement
end
properties (Access = protected)
CurrentPosition % Agent's current position
StartPosition % Agent's starting position
GoalPosition % Agent's goal position
Obstacles % Obstacle positions
Explored % Explored cells
TotalReward % Accumulated reward
Steps % Counter for the number of steps
CurrentStep % Current step within an episode
PreviousActions % History of recent actions
StartLevel = 1 % Start at the easiest level
GoalReached % Flag to indicate if goal is reached
ShortestPathLength % Store the shortest path length
%
SuccessHistory % Array to store recent episode results
%
EpisodesPerLevel % Track episodes completed at each level
%
CurrLvl
TotalEpisodes = 0
end
methods (Access = protected)
function setupImpl(obj)
persistent currentLevel
if isempty(currentLevel)
currentLevel = 1;
end
obj.CurrLvl =1;
obj.Obstacles = obj.setObstacles();
obj.PreviousActions = zeros(1, 5);
obj.SuccessHistory = zeros(1, obj.HistorySize);
obj.TotalEpisodes = 0;
obj.resetEnvironmentForLevel(currentLevel);
obj.CurrentPosition = obj.StartPosition;
disp(['SetupImpl']);
end
function initializeEpisodeTracking(obj)
obj.SuccessHistory = zeros(1, obj.HistorySize);
obj.EpisodesPerLevel = zeros(1, 5); % Assuming 5 levels
end
% function globalvar = GridWorld
% Global Curr;
% globalvar.CurrLvl = Curr;
% end
%%
function resetImpl(obj)
% persistent currentLevel
% if isempty(currentLevel)
% currentLevel = obj.CurrLvl; % Default starting level
% end
currentLevel = obj.CurrLvl;
disp(obj.CurrLvl);
% Pass the current level to reset the environment correctly
obj.resetEnvironmentForLevel(currentLevel);
obj.setCurrentPosition(obj.StartPosition);
obj.Explored = false(obj.GridSize);
obj.Explored(obj.CurrentPosition(1), obj.CurrentPosition(2)) = true;
obj.TotalReward = 0;
obj.Steps = 0;
obj.CurrentStep = 0;
obj.PreviousActions = zeros(1, 5);
obj.GoalReached = false;
disp(['resetImpl - Current Level: ', num2str(currentLevel)]);
end
function [observation, reward, isDone] = stepImpl(obj, action)
persistent totalEpisodes successHistory currentLevel
if isempty(totalEpisodes)
totalEpisodes = 0;
successHistory = zeros(1, obj.HistorySize);
currentLevel = 1;
end
action = double(action);
obj.CurrentStep = obj.CurrentStep + 1;
[newPos, reward, isDone] = obj.takeAction(action, obj.Steps);
obj.setCurrentPosition(newPos);
obj.TotalReward = obj.TotalReward + reward;
obj.Steps = obj.Steps + 1;
obj.PreviousActions = [action, obj.PreviousActions(1:end-1)];
observation = zeros(obj.GridSize(1), obj.GridSize(2), 11, 'double');
observation(:,:,1) = double(obj.Obstacles);
observation(:,:,2) = double(obj.Explored);
observation(obj.CurrentPosition(1), obj.CurrentPosition(2), 3) = 1;
observation(obj.GoalPosition(1), obj.GoalPosition(2), 4) = 1;
relativeGoalPos = obj.GoalPosition - obj.CurrentPosition;
observation(:,:,5) = relativeGoalPos(1);
observation(:,:,6) = relativeGoalPos(2);
observation(:,:,7) = obj.distanceToNearestObstacle();
observation(:,:,8) = atan2(relativeGoalPos(2), relativeGoalPos(1));
observation(:,:,9:11) = repmat(reshape(obj.PreviousActions(1:3), 1, 1, []), obj.GridSize(1), obj.GridSize(2));
if obj.Steps >= obj.MaxSteps
isDone = true;
end
if isDone
totalEpisodes = totalEpisodes + 1;
successHistory = [isequal(newPos, obj.GoalPosition), successHistory(1:end-1)];
successRate = mean(successHistory);
disp(['Total Episodes: ', num2str(totalEpisodes)]);
disp(['Current Level: ', num2str(currentLevel)]);
disp(['Success Rate: ', num2str(successRate)]);
disp(['Success History: ', num2str(successHistory)]);
% Only advance if success history is [1 1 1] and success rate is >= 0.6
if totalEpisodes >= obj.HistorySize && ...
successRate >= obj.LevelThreshold && ...
all(successHistory == 1) && currentLevel < 5
currentLevel = currentLevel + 1;
obj.CurrLvl = currentLevel;
disp(['objCurrLvl in if:',num2str(obj.CurrLvl)]);
successHistory = zeros(1, obj.HistorySize);
disp(['Advanced to level ', num2str(currentLevel)]);
obj.resetEnvironmentForLevel(currentLevel);
else
% If level exceeds max (level 5), stay at max level
if currentLevel >= 5
currentLevel = 5;
end
disp('Did not advance level. Reasons:');
if totalEpisodes < obj.HistorySize
disp([' - Not enough episodes. Current: ', num2str(totalEpisodes), ', Required: ', num2str(obj.HistorySize)]);
end
if successRate < obj.LevelThreshold
disp([' - Success rate too low. Current: ', num2str(successRate), ', Required: ', num2str(obj.LevelThreshold)]);
end
end
obj.plot(obj.Steps)
end
end
methods (Access = public)
function setCurrentPosition(obj, newPosition)
obj.CurrentPosition = newPosition;
end
function resetEnvironmentForLevel(obj, currentLevel)
switch currentLevel
case 1
obj.Obstacles = false(obj.GridSize);
obj.Obstacles(3, 3) = true;
obj.StartPosition = [5, 5];
obj.GoalPosition = [1, 5];
case 2
obj.Obstacles = false(obj.GridSize);
obj.Obstacles(3, 3) = true;
obj.StartPosition = [2, 5];
obj.GoalPosition = [8, 1];
case 3
obj.Obstacles = false(obj.GridSize);
obj.Obstacles(3, 3:5) = true;
obj.StartPosition = [4, 5];
obj.GoalPosition = [7, 7];
case 4
obj.Obstacles = obj.setObstacles(); % Use existing obstacle setup
obj.StartPosition = [4, 5];
obj.setRandomGoal();
case 5
obj.Obstacles = obj.setObstacles();
obj.setRandomStart();
obj.setRandomGoal();
end
obj.CurrentPosition = obj.StartPosition;
obj.calculateShortestPath(); % Ensure shortest path is calculated for each level
end
Answers (1)
Anagha Mittal
17 Oct 2024
The issue is due to the "resetImpl" function. With the implemented logic, the environment is set to level 1 on each call instead of being set to the correct level ("CurrLvl").
I have made a few modifications to the code, with the changes noted as comments:
classdef MLAns < matlab.System
properties (Nontunable)
GridSize = [8, 8] % Size of the grid
Actions = {'up', 'down', 'left', 'right'} % Possible actions
MaxSteps = 100 % Maximum steps per episode
LevelThreshold = 0.6 % Success rate threshold to advance to next level
HistorySize = 3 % Number of episodes to consider for level advancement
end
properties (Access = protected)
CurrentPosition % Agent's current position
StartPosition % Agent's starting position
GoalPosition % Agent's goal position
Obstacles % Obstacle positions
Explored % Explored cells
TotalReward % Accumulated reward
Steps % Counter for the number of steps
CurrentStep % Current step within an episode
PreviousActions % History of recent actions
CurrLvl = 1 % Start at the easiest level (default: 1)
GoalReached % Flag to indicate if goal is reached
ShortestPathLength % Store the shortest path length
SuccessHistory % Array to store recent episode results
TotalEpisodes = 0 % Track total episodes across all levels
end
methods (Access = protected)
function setupImpl(obj)
obj.initializeEnvironment(); % Initialize once during setup
obj.initializeEpisodeTracking();
disp('Environment setup complete.');
end
function resetImpl(obj)
obj.resetEnvironmentForLevel(obj.CurrLvl); % Ensure correct level is set
obj.setCurrentPosition(obj.StartPosition);
obj.Explored = false(obj.GridSize);
obj.Explored(obj.CurrentPosition(1), obj.CurrentPosition(2)) = true;
obj.TotalReward = 0;
obj.Steps = 0;
obj.CurrentStep = 0;
obj.PreviousActions = zeros(1, 5);
obj.GoalReached = false;
disp(['Environment reset to Level: ', num2str(obj.CurrLvl)]);
end
function [observation, reward, isDone] = stepImpl(obj, action)
action = double(action); % Convert action to double if necessary
obj.CurrentStep = obj.CurrentStep + 1;
[newPos, reward, isDone] = obj.takeAction(action, obj.Steps);
obj.setCurrentPosition(newPos);
obj.TotalReward = obj.TotalReward + reward;
obj.Steps = obj.Steps + 1;
% Observation generation (similar to your original code)
observation = obj.generateObservation();
% Mark the episode as done once the step budget is exhausted, so the
% terminal flag is also returned to the agent (matches the original MaxSteps behavior)
if obj.Steps >= obj.MaxSteps
isDone = true;
end
if isDone
obj.TotalEpisodes = obj.TotalEpisodes + 1;
obj.updateSuccessHistory(newPos); % Update success based on goal reaching
obj.handleLevelProgression(); % Check for level advancement
obj.plot(obj.Steps); % Plot environment after episode ends
end
end
function handleLevelProgression(obj)
% Check if agent should advance to the next level
successRate = mean(obj.SuccessHistory);
disp(['Success Rate: ', num2str(successRate)]);
if obj.TotalEpisodes >= obj.HistorySize && ...
successRate >= obj.LevelThreshold && ...
all(obj.SuccessHistory == 1) && obj.CurrLvl < 5
obj.CurrLvl = obj.CurrLvl + 1;
disp(['Advancing to Level ', num2str(obj.CurrLvl)]);
obj.resetEnvironmentForLevel(obj.CurrLvl);
obj.SuccessHistory = zeros(1, obj.HistorySize); % Reset history for new level
else
disp('Did not advance level.');
if obj.CurrLvl >= 5
disp('Max level reached.');
end
end
end
function initializeEnvironment(obj)
obj.CurrentPosition = [5, 5]; % Default start position
obj.resetEnvironmentForLevel(obj.CurrLvl);
end
function initializeEpisodeTracking(obj)
obj.SuccessHistory = zeros(1, obj.HistorySize);
end
function updateSuccessHistory(obj, newPos)
% Update success history after each episode
obj.SuccessHistory = [isequal(newPos, obj.GoalPosition), obj.SuccessHistory(1:end-1)];
disp(['Success History: ', num2str(obj.SuccessHistory)]);
end
function resetEnvironmentForLevel(obj, currentLevel)
switch currentLevel
case 1
obj.setSimpleLevel();
case 2
obj.setMediumLevel();
case 3
obj.setHardLevel();
case 4
obj.setVeryHardLevel();
case 5
obj.setExtremeLevel();
end
obj.calculateShortestPath(); % Calculate the optimal path for each level
end
function observation = generateObservation(obj)
% Create observation for the current state (use your original code here)
observation = zeros(obj.GridSize(1), obj.GridSize(2), 11, 'double');
observation(:,:,1) = double(obj.Obstacles);
observation(:,:,2) = double(obj.Explored);
observation(obj.CurrentPosition(1), obj.CurrentPosition(2), 3) = 1;
observation(obj.GoalPosition(1), obj.GoalPosition(2), 4) = 1;
% Additional state-related calculations can go here...
end
% Define your environment levels (for resetEnvironmentForLevel)
function setSimpleLevel(obj)
obj.Obstacles = false(obj.GridSize);
obj.Obstacles(3, 3) = true;
obj.StartPosition = [5, 5];
obj.GoalPosition = [1, 5];
end
function setMediumLevel(obj)
obj.Obstacles = false(obj.GridSize);
obj.Obstacles(3, 3) = true;
obj.StartPosition = [2, 5];
obj.GoalPosition = [8, 1];
end
% Define further levels similarly...
end
methods (Access = public)
function setCurrentPosition(obj, newPosition)
obj.CurrentPosition = newPosition;
end
end
end
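To check that the level now carries over between episodes, you can drive the System object directly from the command line. This is only a rough sketch; it assumes the remaining helper methods from your original class (takeAction, setObstacles, plot, calculateShortestPath, and the level setters) are also implemented in MLAns, and that the action is an integer index 1-4 matching the four entries of Actions.
% Rough sketch for manually exercising the environment (assumes the helper
% methods from the original class are also present in MLAns).
env = MLAns();
[obs, reward, isDone] = env(1);   % first call runs setupImpl, then stepImpl
while ~isDone
    a = randi(4);                 % random action index: up/down/left/right
    [obs, reward, isDone] = env(a);
end
reset(env);                       % calls resetImpl; the current level should be kept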
Hope this helps!