Why Choose Model-Based Reinforcement Learning?
From the series: Reinforcement Learning
What is the difference between model-free and model-based reinforcement learning? Explore the differences and results as the learning models are applied to balancing a cart/pole system as an example. By the end, you will have a better understanding of situations where you may want to choose one method over the other.
Published: 15 Jul 2022
What is the difference between model-free and model-based reinforcement learning? Well, in this video, I'm going to answer that question. And as part of the answer, I'm going to show you the results of training an agent to balance a cart pole system with both model-free and model-based reinforcement learning. And hopefully, by the end, you're going to have a better understanding of the situations where you may want to choose one over the other. So I hope you stick around for it. I'm Brian, and welcome to a MATLAB Tech Talk.
With reinforcement learning, an agent learns how to interact with an environment in a way that maximizes a reward. And it does this by adjusting its policy over the course of many different attempts. The agent takes an action. And then if that's deemed to be good based on the reward that's received, then that action is reinforced. And there is a higher likelihood that it will perform the same action again the next time it's in that state.
What I've drawn here is model-free reinforcement learning, and with model-free RL, an agent learns the optimal policy solely through direct interactions with an environment. So all of the experience that the agent uses to learn is generated by the real environment.
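To make the model-free loop a bit more concrete, here is a minimal sketch in Python. It assumes a made-up toy environment (five states in a row with a reward at the right end) and uses tabular Q-learning as a stand-in for whatever learning rule the agent actually runs. The names `step` and `train_model_free` and all of the parameter values are illustrative assumptions, not anything from the MATLAB example shown later.

```python
import random

def step(state, action):
    """Hypothetical toy environment: states 0..4 in a row, reward at state 4.
    Action 1 moves right, action 0 moves left."""
    next_state = max(0, min(4, state + (1 if action == 1 else -1)))
    reward = 1.0 if next_state == 4 else 0.0
    done = next_state == 4
    return next_state, reward, done

def train_model_free(episodes=200, alpha=0.5, gamma=0.9, epsilon=0.2, seed=0):
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(5)]  # q[state][action]
    for _ in range(episodes):
        state = 0
        for _ in range(100):  # cap the episode length
            # Mostly act greedily, sometimes explore at random.
            greedy = 0 if q[state][0] > q[state][1] else 1
            action = rng.randrange(2) if rng.random() < epsilon else greedy
            # The only source of experience is the environment itself.
            next_state, reward, done = step(state, action)
            target = reward + (0.0 if done else gamma * max(q[next_state]))
            q[state][action] += alpha * (target - q[state][action])
            state = next_state
            if done:
                break
    return q

q_table = train_model_free()
```

Every update in this loop comes from a call to `step`, which is what makes it model-free: the agent never predicts an outcome on its own.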
Now, on the other hand, with model-based RL, the agent still uses some real experiences. But the difference is that within the agent, there's also a model of the environment, which it can use to create additional simulated experiences. It's sort of like having a mental model of the world that allows the agent to learn by thinking rather than just acting.
Now, I know that's kind of a confusing statement. So to understand it better, let's use sports as an analogy. There are two teams, A and B, and they've never played each other before. And therefore, Team A doesn't know how Team B plays or what to expect for their strategy.
And at the start of the match, Team A will try some tactics and decide whether those actions worked or not, and in real time, adjust their strategy to attempt to win the game. This is a model-free approach to learning because Team A had to interact with the environment, or Team B in this case, in order to learn how Team B plays and improve their chances of winning.
However, during this process, Team A could have developed a model in their head of how Team B plays. This model might not be perfectly accurate, but now, before the two teams play again, Team A can think about different strategies and assess whether they'll be successful or not based on their understanding of how Team B plays.
This is a model-based approach to learning because Team A is able to update their policy or learn a better playing strategy without directly interacting with Team B but just by thinking about it. They are supplementing their real experiences with simulated experiences. And because of this, they're able to learn and improve between episodes or between real physical matches.
In this same way, a software-based agent can have a mental model of the environment and use that to update its policy between physical interactions with the real-world environment. And this strategy can investigate more of the state space than just real experiences alone.
Conceptually, it works like this. Imagine this 2D rectangle represents the entire state space of the environment, and we want the agent to explore the entire thing and learn the optimal policy. If the starting state is here and the agent takes real actions, then over time, the agent may transition to this environment state. These are the real-world experiences. And with model-free reinforcement learning, the agent can only learn from these experiences, which as we can see doesn't cover much of the total state space area.
However, if the agent has a model of the environment, then at each time step, it can simulate what would happen if it had taken a different action. So from a single real trajectory, the agent could generate thousands of simulations each time step. And even if those simulations are only looking ahead one step in the future, the combination of real and simulated experiences covers a much wider swath of the state space. And so in this way, it takes fewer real interactions with the environment overall in order to cover the entire area.
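As a rough illustration of that coverage argument, the sketch below branches one step from every state on a short real trajectory. The model here is a hypothetical one-line function (`toy_model`) on integer states, chosen only so the counts are easy to check by hand; a learned model would take its place in practice.

```python
def toy_model(state, action):
    """Hypothetical internal model: integer states, actions shift the state."""
    return state + action

def branch_from_trajectory(real_trajectory, actions, model):
    """For each real (state, action) pair, simulate every action NOT taken."""
    simulated = []
    for state, taken_action in real_trajectory:
        for action in actions:
            if action != taken_action:
                simulated.append((state, action, model(state, action)))
    return simulated

real_trajectory = [(0, 1), (1, 1), (2, 1)]  # (state, action actually taken)
actions = [-2, -1, 1, 2]
simulated = branch_from_trajectory(real_trajectory, actions, toy_model)

# States seen through real experience alone, including the final state reached.
real_states = {s for s, _ in real_trajectory} | {toy_model(*real_trajectory[-1])}
# States seen once the one-step simulated experiences are added.
covered = real_states | {next_state for _, _, next_state in simulated}

print(len(real_states), len(covered))  # prints "4 7"
```

Three real steps produce nine simulated transitions here, and the set of visited states grows from four to seven without any extra interaction with the environment.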
And this is why model-based RL is more sample efficient than model-free RL. However, it does come at a cost, and that cost is computation. Running a simulation requires more compute power than just interacting with the real environment. So what we gain with fewer samples, we lose in more compute operations.
All right, now before we move on, I want to clear something up. When I say real environment, this simply means the environment that is external to the agent. I don't mean that the environment has to be a real physical entity. I mean, very often in reinforcement learning, we still train agents using a model of the environment. It's just that that model is external to the agent, and therefore, the only way for the agent to learn is by taking actions and receiving a response.
From the agent's perspective, that model is the environment, and it learns solely through direct interactions. It's only when the model exists within the agent and the agent can query that model internally without ever taking an action that it becomes model-based reinforcement learning.
All right, so now that we have a handle on the differences between model-free and model-based RL, I want to talk about why you may choose one versus the other. In comparison, model-free RL is a simple and straightforward approach. You let the agent take an action and then update the policy based on the observations and reward. You don't have to mess around with having a model or setting up any kind of simulations. And so one benefit of model-free is the simplicity of it.
Also, models are less accurate than the real world, and therefore, an agent that is updating its policy using a possibly imperfect mental model might learn some behaviors that don't work in the real environment. And so the agent would have to use the real environment to correct that policy.
And if the model continues to stay inaccurate, the agent would continually learn the wrong thing and then have to correct it and then learn the wrong thing again and so on, which is really inefficient. Model-free RL doesn't have this problem since the agent is always learning directly from the external environment.
Now, on the other hand, model-based RL requires fewer interactions with the external environment. And this is because, like we said earlier, the agent can do a lot of learning using the model in between real-world interactions. And this is particularly important if the agent is controlling a physical system like a robot or a vehicle because real actions have the possibility of damaging hardware, either by the agent requesting an action that causes that damage or just through the wearing down of parts over many consecutive movements.
There is also the benefit of being able to gain experience faster than real time by running models. And it's also easier to place the agent in states that would be difficult with the real environment. So those are some of the reasons why you may want to choose a model-based approach.
Now, if the environment is software based or it's a simulation already like a Simulink model of the real environment, then sample efficiency might not matter as much as total compute resources. It costs money and time to perform more calculations, and you might as well just interact directly with the environment with model-free RL. But in the case of physical systems, it can be beneficial to reduce the number of real-world training episodes with model-based RL even at the expense of more computations.
All right, so let's say that model-based RL is what you want to use to train an agent on a physical system. But what if you don't have a model of the environment ahead of time? How do you get one so that you can actually do model-based RL? Well, one way is that we could just create a model and then give it to the agent. And modeling is done all the time, right? So this doesn't seem like a bad idea.
However, creating a model often requires data from the physical system. So we need the agent to physically interact with the environment anyway to generate the data that we want to fit a model to, or that we want to use to learn a model. So if we need to interact with the environment to develop a model and we want to train an agent using reinforcement learning, why not combine the two and set up a reinforcement learning algorithm that simultaneously learns the environment model and learns the optimal policy?
In this approach, when the agent gets real experiences from the environment, it uses that data to update the model and also to update its policy. And so now, before the next episode with the real environment, the agent can use the model to generate additional simulated experiences.
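One simple way to picture this combined loop is a Dyna-Q-style tabular sketch, shown below in Python. To be clear, this is a swapped-in textbook technique used for illustration, not the MBPO algorithm from the MATLAB example, and the toy environment (`real_step`), the dictionary model, and all parameter values are assumptions.

```python
import random

def real_step(state, action):
    """Hypothetical external environment: states 0..4 in a row, goal at 4."""
    next_state = max(0, min(4, state + (1 if action == 1 else -1)))
    return next_state, (1.0 if next_state == 4 else 0.0), next_state == 4

def q_update(q, s, a, r, s2, done, alpha=0.5, gamma=0.9):
    target = r + (0.0 if done else gamma * max(q[s2]))
    q[s][a] += alpha * (target - q[s][a])

def train_dyna(episodes=20, planning_steps=20, epsilon=0.2, seed=0):
    rng = random.Random(seed)
    q = [[0.0, 0.0] for _ in range(5)]
    model = {}  # learned model: (state, action) -> (reward, next_state, done)
    real_steps = 0
    for _ in range(episodes):
        state = 0
        for _ in range(100):  # cap the episode length
            greedy = 0 if q[state][0] > q[state][1] else 1
            action = rng.randrange(2) if rng.random() < epsilon else greedy
            next_state, reward, done = real_step(state, action)
            real_steps += 1
            # The same real experience updates the policy and the model.
            q_update(q, state, action, reward, next_state, done)
            model[(state, action)] = (reward, next_state, done)
            # Simulated experience: replay transitions predicted by the model.
            for _ in range(planning_steps):
                (s, a), (r, s2, d) = rng.choice(list(model.items()))
                q_update(q, s, a, r, s2, d)
            state = next_state
            if done:
                break
    return q, model, real_steps

q_table, learned_model, real_steps = train_dyna()
```

Because this toy environment is deterministic, the dictionary model is exact for every state-action pair it has seen and simply has no entry for the rest. A neural network model would instead return some prediction everywhere, including places where it has no data.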
But a question might be, do we have to fully learn the model before the agent can start using it? I mean, we know it doesn't do us any good to make predictions with the model if the predictions are garbage. And early in the learning process, most of the model is garbage, right? Therefore, if we train the agent with that model, we'll end up with reinforcing behaviors that will perform badly in the real world. And that is true.
But when the bad actions are attempted in the real world and a lower reward is received, then that experience will correct the agent's policy. And it's going to correct the model for that given state and action. So now the agent can use that corrected model to explore the nearby states and actions more thoroughly. So it's self-correcting over time.
But there are ways of reducing the chance that the agent learns from areas of the model that aren't accurate in the first place. And one is by giving the agent multiple models or an ensemble of models. Here, the agent is maintaining three models of the environment, and each one is a neural network which is initialized with random weights and biases. And therefore, if they're given the same state action input, then it's likely that each will make different random predictions.
And then, after a training episode, these models will have a higher chance of making the same prediction for that trajectory that they were trained on while all of the other state action inputs would still produce random outputs. In this way, over time, if the predictions from all of the models agree, then it's likely that the agent is simulating a state that it's already learned. And if they don't agree, then it's safe to assume that the predictions are from random parts of the model.
All right, so one way to handle this is to just check the predictions and then only use the model experiences when they agree. And if they don't agree, just ignore them. However, a simpler approach is to just not worry about it and let the agent learn from all of the simulated experience whether it's correct or not.
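The "only use them when they agree" option could look something like the sketch below. The three models are hypothetical stand-ins built by `make_model`: they return identical predictions near the states they were "trained" on and different ones elsewhere. The tolerance value and the choice to average the agreeing predictions are also assumptions made for illustration.

```python
def make_model(offset):
    """Hypothetical stand-in for one learned model in the ensemble.
    It is accurate for |state| <= 1 and off by `offset` everywhere else."""
    def model(state, action):
        error = 0.0 if abs(state) <= 1 else offset
        return state + action + error
    return model

def ensemble_agrees(predictions, tolerance=0.1):
    """True when the spread of the predictions is within the tolerance."""
    return max(predictions) - min(predictions) <= tolerance

def filter_simulated(candidates, models, tolerance=0.1):
    """Keep only simulated transitions on which all models agree."""
    kept = []
    for state, action in candidates:
        predictions = [m(state, action) for m in models]
        if ensemble_agrees(predictions, tolerance):
            mean_prediction = sum(predictions) / len(predictions)
            kept.append((state, action, mean_prediction))
    return kept

models = [make_model(-0.8), make_model(0.1), make_model(0.9)]
candidates = [(0, 1), (1, -1), (5, 1)]  # (state, action) pairs to simulate
kept = filter_simulated(candidates, models)
print(kept)  # prints "[(0, 1, 1.0), (1, -1, 0.0)]"
```

The query at state 5 is dropped because the three predictions there are spread far apart, which is the signal that the agent is asking about a part of the model it hasn't learned yet.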
If the models agree, then the agent is going to reinforce that behavior multiple times, once for each model. And if they don't agree, then the agent will reinforce several random behaviors but each probably in a different direction so they kind of offset each other. And so in this way, statistically, there is a strong reinforcement if the models agree and a weaker reinforcement if they don't.
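A tiny numeric sketch of that offsetting effect is below. It treats each model's contribution as a single signed number, which is a big simplification of a real policy update, and the specific values are made up purely to show the averaging.

```python
def net_update(per_model_updates):
    """Average the policy nudges contributed by each model's simulated experience."""
    return sum(per_model_updates) / len(per_model_updates)

# Hypothetical nudges to one policy parameter from three models.
agreeing = [0.5, 0.5, 0.5]          # models predict the same outcome
disagreeing = [0.5, -0.25, -0.125]  # models predict different outcomes

strong = net_update(agreeing)
weak = net_update(disagreeing)
print(strong, abs(weak) < abs(strong))  # prints "0.5 True"
```

When the nudges line up, the average keeps its full size. When they point in different directions, most of the effect cancels, so the inaccurate parts of the model move the policy much less.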
All right, so let's wrap all of this up. With model-based reinforcement learning, an agent has a model of the environment that it can use to generate simulated experiences. Additionally, some algorithms allow the agent to learn that model as it interacts with the real environment. But whether it learns that model or not, just having a model as part of the agent is what makes it model-based RL.
And having multiple models provides a way for the agent to learn more quickly from simulated experiences that are more accurate and more slowly from simulated experiences that are less accurate. And all of this typically extends the amount of time and calculations it takes to train an agent, but it reduces the number of real-world samples that are needed.
Now, I told you that I would show you the results of both a model-free and model-based reinforcement learning approach to training an agent. And to set up this problem, I'm using the MATLAB example called Train MBPO Agent to Balance Cart-Pole System. This example uses a Model-Based Policy Optimization Agent, or MBPO, to learn how to move a cart such that it balances a freely spinning pole.
And this agent is learning three neural network models of the environment, which it also uses to generate simulated experiences that augment the real experiences that it gets from the physical environment. I've left a link to this example below, and I recommend that you take some time and play around with it so that you can get a better feel for what this example is actually doing and how model-based RL works.
But what I want to show you is the result from running this example five times with different random seeds. And then here, I'm comparing the number of episodes needed to exceed the training criteria for a model-based approach with the number of episodes needed for a model-free approach. And notice that in four of the five cases, the model-based agent reached the training criteria in fewer episodes. It took just about half of the number of episodes. Now, there is this one that took about the same number, but that's just kind of the random luck of the draw.
On average, the model-based agent is more sample efficient. Now, I don't have the number of calculations or the exact wall clock time that it took for each of these examples. But on average, the model-based approach took about twice as long to run. So that's the main trade-off that we have to consider between these two methods, calculation time versus number of real-world interactions.
All right, well, that's where I'm going to leave this video for now. If you don't want to miss any other future Tech Talk videos, don't forget to subscribe to this channel. And if you want to check out my channel, Control System Lectures, I cover more control theory topics there as well. Thanks for watching, and I'll see you next time.