I am continuing to work my way through the Udacity Deep Reinforcement Learning Nanodegree. In this blog post I discuss and implement the Double DQN algorithm from Deep Reinforcement Learning with Double Q-Learning (Van Hasselt et al 2015). The Double DQN algorithm is a minor, but important, modification of the original DQN algorithm that I covered in a previous post.

The Van Hasselt et al 2015 paper makes several important contributions.

  1. Demonstration of how Q-learning can be overly optimistic in large-scale, even deterministic, problems due to the inherent estimation errors of learning.
  2. Demonstration that overestimations are more common and severe in practice than previously acknowledged.
  3. Implementation of Double Q-learning called Double DQN that extends, with minor modifications, the popular DQN algorithm and that can be used at scale to successfully reduce overestimations with the result being more stable and reliable learning.
  4. Demonstration that Double DQN finds better policies by obtaining new state-of-the-art results on the Atari 2600 dataset.

Q-learning overestimates Q-values

No matter what type of function approximation scheme is used to approximate the action-value function $Q$, there will always be approximation error. The presence of the max operator in the Bellman equation used to compute the $Q$-values means that these errors do not simply average out: the maximum over noisy estimates tends to exceed the maximum of the true action values, so the approximate $Q$-values are biased upward even when the individual approximation errors are unbiased. This potentially significant source of bias can impede learning and is often exacerbated by the use of flexible, non-linear function approximators such as neural networks.
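
To make the role of the max operator concrete, here is a minimal sketch (my own toy setup, not taken from the paper). I assume four actions whose true values are all zero and add zero-mean Gaussian noise to each estimate; the function name and the noise level are illustrative assumptions.

```python
import random
import statistics


def max_of_noisy_estimates(true_values, noise_std, rng):
    """Return the max over actions of the true values plus zero-mean Gaussian noise."""
    return max(q + rng.gauss(0.0, noise_std) for q in true_values)


rng = random.Random(42)
true_q_values = [0.0, 0.0, 0.0, 0.0]  # toy assumption: every action is truly worth 0

# Average the estimated maximum over many independent draws of the noise.
samples = [max_of_noisy_estimates(true_q_values, 1.0, rng) for _ in range(10_000)]
average_max = statistics.mean(samples)

print(f"true max: {max(true_q_values):.2f}")
print(f"average estimated max: {average_max:.2f}")
```

Each individual estimate is unbiased, yet the average of the estimated maximum comes out clearly above the true maximum of zero. The bias comes purely from taking a max over noisy numbers.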

Double Q-learning addresses these issues by explicitly separating action selection from action evaluation which allows each step to use a different function approximator resulting in a better overall approximation of the action-value function. Figure 2 (with caption) below, which is taken from Van Hasselt et al 2015, summarizes these ideas. See the paper for more details.
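
The effect of decoupling selection from evaluation can be sketched with a toy simulation. This is an illustration under my own assumptions (four actions with true value zero, two fully independent sets of noisy estimates), not the setup used in the paper. In Double DQN the two networks are not fully independent, so the reduction in bias should be expected to be smaller in practice.

```python
import random
import statistics


def noisy_estimates(true_values, noise_std, rng):
    """Add independent zero-mean Gaussian noise to each true action value."""
    return [q + rng.gauss(0.0, noise_std) for q in true_values]


def single_estimate(true_values, noise_std, rng):
    """Use one set of noisy estimates to both select and evaluate the greedy action."""
    estimates = noisy_estimates(true_values, noise_std, rng)
    return max(estimates)


def double_estimate(true_values, noise_std, rng):
    """Select the greedy action with one set of estimates; evaluate it with another."""
    selector = noisy_estimates(true_values, noise_std, rng)
    evaluator = noisy_estimates(true_values, noise_std, rng)
    greedy_action = max(range(len(selector)), key=selector.__getitem__)
    return evaluator[greedy_action]


rng = random.Random(42)
true_q_values = [0.0, 0.0, 0.0, 0.0]  # toy assumption: all actions are truly worth 0
number_trials = 10_000

single_bias = statistics.mean(
    [single_estimate(true_q_values, 1.0, rng) for _ in range(number_trials)]
)
double_bias = statistics.mean(
    [double_estimate(true_q_values, 1.0, rng) for _ in range(number_trials)]
)
print(f"single estimator bias: {single_bias:.2f}")
print(f"double estimator bias: {double_bias:.2f}")
```

Because the evaluator's noise is independent of the noise that drove the selection, the selected action's evaluated value is no longer systematically inflated.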

Implementing the Double DQN algorithm

The key idea behind Double Q-learning is to reduce overestimations of Q-values by separating the selection of actions from the evaluation of those actions so that a different Q-network can be used in each step. When applying Double Q-learning to extend the DQN algorithm one can use the online Q-network, $Q(S, a; \theta)$, to select the actions and then the target Q-network, $Q(S, a; \theta^{-})$, to evaluate the selected actions.

Before implementing the Double DQN algorithm, I am going to re-implement the Q-learning update from the DQN algorithm in a way that explicitly separates action selection from action evaluation. Once I have implemented this new version of Q-learning, implementing the Double DQN algorithm will be much easier. Formally separating action selection from action evaluation involves re-writing the Q-learning Bellman equation as follows.

$$ Y_t^{DQN} = R_{t+1} + \gamma Q\big(S_{t+1}, \underset{a}{\mathrm{argmax}}\ Q(S_{t+1}, a; \theta_t); \theta_t\big) $$
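
As a sanity check on the notation, the right-hand side above is the same quantity as the more familiar form written with an explicit max operator. The two agree because evaluating $Q$ at its own greedy action returns its maximum value.

$$ R_{t+1} + \gamma \max_{a} Q(S_{t+1}, a; \theta_t) = R_{t+1} + \gamma Q\big(S_{t+1}, \underset{a}{\mathrm{argmax}}\ Q(S_{t+1}, a; \theta_t); \theta_t\big) $$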

In Python this can be implemented as three separate functions.

import torch
from torch import nn


def select_greedy_actions(states: torch.Tensor, q_network: nn.Module) -> torch.Tensor:
    """Select the greedy action for the current state given some Q-network."""
    _, actions = q_network(states).max(dim=1, keepdim=True)
    return actions


def evaluate_selected_actions(states: torch.Tensor,
                              actions: torch.Tensor,
                              rewards: torch.Tensor,
                              dones: torch.Tensor,
                              gamma: float,
                              q_network: nn.Module) -> torch.Tensor:
    """Compute the Q-values by evaluating the actions given the current states and Q-network."""
    next_q_values = q_network(states).gather(dim=1, index=actions)        
    q_values = rewards + (gamma * next_q_values * (1 - dones))
    return q_values


def q_learning_update(states: torch.Tensor,
                      rewards: torch.Tensor,
                      dones: torch.Tensor,
                      gamma: float,
                      q_network: nn.Module) -> torch.Tensor:
    """Q-Learning update with explicitly decoupled action selection and evaluation steps."""
    actions = select_greedy_actions(states, q_network)
    q_values = evaluate_selected_actions(states, actions, rewards, dones, gamma, q_network)
    return q_values

From here it is straightforward to implement the Double DQN algorithm. All I need is a second action-value function. The target network in the DQN architecture provides a natural candidate for the second action-value function. Van Hasselt et al 2015 suggest using the online Q-network to select the greedy policy actions before using the target Q-network to estimate the value of the selected actions. Once again here are the maths...

$$ Y_t^{DoubleDQN} = R_{t+1} + \gamma Q\big(S_{t+1}, \underset{a}{\mathrm{argmax}}\ Q(S_{t+1}, a; \theta_t); \theta_t^{-}\big) $$

...and here is the Python implementation.

def double_q_learning_update(states: torch.Tensor,
                             rewards: torch.Tensor,
                             dones: torch.Tensor,
                             gamma: float,
                             q_network_1: nn.Module,
                             q_network_2: nn.Module) -> torch.Tensor:
    """Double Q-Learning uses Q-network 1 to select actions and Q-network 2 to evaluate the selected actions."""
    actions = select_greedy_actions(states, q_network_1)
    q_values = evaluate_selected_actions(states, actions, rewards, dones, gamma, q_network_2)
    return q_values

Note that the function double_q_learning_update is almost identical to the q_learning_update function above: all that is needed is to introduce a second Q-network parameter, q_network_2, to the function. This second Q-network will be used to evaluate the actions chosen using the original Q-network parameter, now called q_network_1.
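
To see how the two targets can differ for a single transition, here is a small worked example in plain Python. The Q-values, reward and discount factor are hypothetical numbers that I made up for illustration, and `td_target` is a scalar stand-in for the tensor functions above rather than a replacement for them.

```python
def greedy_action(q_values):
    """Index of the largest Q-value (ties broken by the lowest index)."""
    return max(range(len(q_values)), key=q_values.__getitem__)


def td_target(reward, done, gamma, selector_q, evaluator_q):
    """Select the greedy action with selector_q, then evaluate it with evaluator_q."""
    action = greedy_action(selector_q)
    return reward + gamma * evaluator_q[action] * (1 - done)


# Hypothetical next-state Q-values from the online and target networks.
online_q = [1.0, 2.0, 0.5]
target_q = [1.5, 1.0, 3.0]
reward, gamma = 1.0, 0.9

# DQN: the target network both selects and evaluates (picks action 2).
dqn_target = td_target(reward, 0, gamma, target_q, target_q)

# Double DQN: the online network selects (action 1), the target network evaluates.
double_dqn_target = td_target(reward, 0, gamma, online_q, target_q)

# Terminal transition: the bootstrap term is masked out by (1 - done).
terminal_target = td_target(reward, 1, gamma, online_q, target_q)

print(dqn_target, double_dqn_target, terminal_target)
```

In this made-up case the target network's largest value belongs to an action that the online network does not rank first, so the Double DQN target is lower than the DQN target.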

Experience Replay

Just like the DQN algorithm, the Double DQN algorithm uses an ExperienceReplayBuffer to stabilize the learning process.

import collections
import typing

import numpy as np


_field_names = [
    "state",
    "action",
    "reward",
    "next_state",
    "done"
]
Experience = collections.namedtuple("Experience", field_names=_field_names)


class ExperienceReplayBuffer:
    """Fixed-size buffer to store Experience tuples."""

    def __init__(self,
                 batch_size: int,
                 buffer_size: int = None,
                 random_state: np.random.RandomState = None) -> None:
        """
        Initialize an ExperienceReplayBuffer object.

        Parameters:
        -----------
        buffer_size (int): maximum size of buffer
        batch_size (int): size of each training batch
        random_state (np.random.RandomState): random number generator.
        
        """
        self._batch_size = batch_size
        self._buffer_size = buffer_size
        self._buffer = collections.deque(maxlen=buffer_size)
        self._random_state = np.random.RandomState() if random_state is None else random_state
        
    def __len__(self) -> int:
        return len(self._buffer)
    
    @property
    def batch_size(self) -> int:
        """Number of experience samples per training batch."""
        return self._batch_size
    
    @property
    def buffer_size(self) -> int:
        """Maximum number of experience samples that can be stored in memory."""
        return self._buffer_size
    
    def append(self, experience: Experience) -> None:
        """Add a new experience to memory."""
        self._buffer.append(experience)
    
    def sample(self) -> typing.List[Experience]:
        """Randomly sample a batch of experiences from memory."""
        idxs = self._random_state.randint(len(self._buffer), size=self._batch_size)
        experiences = [self._buffer[idx] for idx in idxs]
        return experiences
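
Two behaviours of the buffer above are worth seeing in isolation: a deque with maxlen silently evicts the oldest items, and drawing random indices independently samples with replacement. The sketch below uses only the standard library, with integers standing in for Experience tuples and a tiny capacity chosen purely for illustration.

```python
import collections
import random

# A deque with maxlen drops the oldest item once it is full.
buffer = collections.deque(maxlen=3)  # hypothetical tiny capacity
for step in range(5):
    buffer.append(step)  # integers stand in for Experience tuples
print(list(buffer))

# Independent random indices sample WITH replacement, so a batch may contain
# duplicates and may even be larger than the number of stored items.
rng = random.Random(0)
batch = [buffer[rng.randrange(len(buffer))] for _ in range(4)]
print(batch)
```

If sampling without replacement is preferred, `RandomState.choice` with `replace=False` is one option, at the cost of requiring at least a full batch of stored experiences.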

Refactoring the DeepQAgent class

Now that I have an implementation of the Double Q-learning algorithm I can refactor the DeepQAgent class from my previous post to incorporate the functionality above. The functions defined above can be added to the DeepQAgent either as static methods or as module-level functions. I tend to prefer module-level functions over static methods because module-level functions can be imported independently of class definitions, which makes them a bit more re-usable.

import typing

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F


class Agent:
    
    def choose_action(self, state: np.array) -> int:
        """Rule for choosing an action given the current state of the environment."""
        raise NotImplementedError
        
    def learn(self, experiences: typing.List[Experience]) -> None:
        """Update the agent's state based on a collection of recent experiences."""
        raise NotImplementedError

    def save(self, filepath) -> None:
        """Save any important agent state to a file."""
        raise NotImplementedError
        
    def step(self,
             state: np.array,
             action: int,
             reward: float,
             next_state: np.array,
             done: bool) -> None:
        """Update agent's state after observing the effect of its action on the environment."""
        raise NotImplementedError


class DeepQAgent(Agent):

    def __init__(self,
                 state_size: int,
                 action_size: int,
                 number_hidden_units: int,
                 optimizer_fn: typing.Callable[[typing.Iterable[nn.Parameter]], optim.Optimizer],
                 batch_size: int,
                 buffer_size: int,
                 epsilon_decay_schedule: typing.Callable[[int], float],
                 alpha: float,
                 gamma: float,
                 update_frequency: int,
                 double_dqn: bool = False,
                 seed: int = None) -> None:
        """
        Initialize a DeepQAgent.
        
        Parameters:
        -----------
        state_size (int): the size of the state space.
        action_size (int): the size of the action space.
        number_hidden_units (int): number of units in the hidden layers.
        optimizer_fn (callable): function that takes Q-network parameters and returns an optimizer.
        batch_size (int): number of experience tuples in each mini-batch.
        buffer_size (int): maximum number of experience tuples stored in the replay buffer.
        epsilon_decay_schedule (callable): function that takes episode number and returns epsilon.
        alpha (float): rate at which the target q-network parameters are updated.
        gamma (float): controls how much the agent discounts future rewards (0 < gamma <= 1).
        update_frequency (int): frequency (measured in time steps) with which q-network parameters are updated.
        double_dqn (bool): whether to use vanilla DQN algorithm or use the Double DQN algorithm.
        seed (int): random seed
        
        """
        self._state_size = state_size
        self._action_size = action_size
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # set seeds for reproducibility
        self._random_state = np.random.RandomState() if seed is None else np.random.RandomState(seed)
        if seed is not None:
            torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        
        # initialize agent hyperparameters
        _replay_buffer_kwargs = {
            "batch_size": batch_size,
            "buffer_size": buffer_size,
            "random_state": self._random_state
        }
        self._memory = ExperienceReplayBuffer(**_replay_buffer_kwargs)
        self._epsilon_decay_schedule = epsilon_decay_schedule
        self._alpha = alpha
        self._gamma = gamma
        self._double_dqn = double_dqn
        
        # initialize Q-Networks
        self._update_frequency = update_frequency
        self._online_q_network = self._initialize_q_network(number_hidden_units)
        self._target_q_network = self._initialize_q_network(number_hidden_units)
        self._synchronize_q_networks(self._target_q_network, self._online_q_network)        
        self._online_q_network.to(self._device)
        self._target_q_network.to(self._device)
        
        # initialize the optimizer
        self._optimizer = optimizer_fn(self._online_q_network.parameters())

        # initialize some counters
        self._number_episodes = 0
        self._number_timesteps = 0
        
    def _initialize_q_network(self, number_hidden_units: int) -> nn.Module:
        """Create a neural network for approximating the action-value function."""
        q_network = nn.Sequential(
            nn.Linear(in_features=self._state_size, out_features=number_hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=number_hidden_units, out_features=number_hidden_units),
            nn.ReLU(),
            nn.Linear(in_features=number_hidden_units, out_features=self._action_size)
        )
        return q_network
                 
    @staticmethod
    def _soft_update_q_network_parameters(q_network_1: nn.Module,
                                          q_network_2: nn.Module,
                                          alpha: float) -> None:
        """In-place, soft-update of q_network_1 parameters with parameters from q_network_2."""
        for p1, p2 in zip(q_network_1.parameters(), q_network_2.parameters()):
            p1.data.copy_(alpha * p2.data + (1 - alpha) * p1.data)
    
    @staticmethod
    def _synchronize_q_networks(q_network_1: nn.Module, q_network_2: nn.Module) -> None:
        """In place, synchronization of q_network_1 and q_network_2."""
        _ = q_network_1.load_state_dict(q_network_2.state_dict())
           
    def _uniform_random_policy(self, state: torch.Tensor) -> int:
        """Choose an action uniformly at random."""
        return self._random_state.randint(self._action_size)
        
    def _greedy_policy(self, state: torch.Tensor) -> int:
        """Choose an action that maximizes the action_values given the current state."""
        action = (self._online_q_network(state)
                      .argmax()
                      .cpu()  # action_values might reside on the GPU!
                      .item())
        return action
    
    def _epsilon_greedy_policy(self, state: torch.Tensor, epsilon: float) -> int:
        """With probability epsilon explore randomly; otherwise exploit knowledge optimally."""
        if self._random_state.random() < epsilon:
            action = self._uniform_random_policy(state)
        else:
            action = self._greedy_policy(state)
        return action

    def choose_action(self, state: np.array) -> int:
        """
        Return the action for given state as per current policy.
        
        Parameters:
        -----------
        state (np.array): current state of the environment.
        
        Return:
        --------
        action (int): an integer representing the chosen action.

        """
        # need to reshape state array and convert to tensor
        state_tensor = (torch.from_numpy(state)
                             .unsqueeze(dim=0)
                             .to(self._device))
            
        # choose uniform at random if agent has insufficient experience
        if not self.has_sufficient_experience():
            action = self._uniform_random_policy(state_tensor)
        else:
            epsilon = self._epsilon_decay_schedule(self._number_episodes)
            action = self._epsilon_greedy_policy(state_tensor, epsilon)
        return action
    
    def learn(self, experiences: typing.List[Experience]) -> None:
        """Update the agent's state based on a collection of recent experiences."""
        states, actions, rewards, next_states, dones = (torch.Tensor(vs).to(self._device) for vs in zip(*experiences))
        
        # need to add second dimension to some tensors
        actions = (actions.long()
                          .unsqueeze(dim=1))
        rewards = rewards.unsqueeze(dim=1)
        dones = dones.unsqueeze(dim=1)
        
        if self._double_dqn:
            target_q_values = double_q_learning_update(next_states,
                                                       rewards,
                                                       dones,
                                                       self._gamma,
                                                       self._online_q_network,
                                                       self._target_q_network)
        else:
            target_q_values = q_learning_update(next_states,
                                                rewards,
                                                dones,
                                                self._gamma,
                                                self._target_q_network)

        online_q_values = (self._online_q_network(states)
                               .gather(dim=1, index=actions))
        
        # compute the mean squared loss (targets are detached so that
        # no gradients are accumulated on the target network)
        loss = F.mse_loss(online_q_values, target_q_values.detach())
        
        # updates the parameters of the online network
        self._optimizer.zero_grad()
        loss.backward()
        self._optimizer.step()
        
        self._soft_update_q_network_parameters(self._target_q_network,
                                               self._online_q_network,
                                               self._alpha)
    
    def has_sufficient_experience(self) -> bool:
        """True if agent has enough experience to train on a batch of samples; False otherwise."""
        return len(self._memory) >= self._memory.batch_size
    
    def save(self, filepath: str) -> None:
        """
        Saves the state of the DeepQAgent.
        
        Parameters:
        -----------
        filepath (str): filepath where the serialized state should be saved.
        
        Notes:
        ------
        The method uses `torch.save` to serialize the state of the q-network, 
        the optimizer, as well as the dictionary of agent hyperparameters.
        
        """
        checkpoint = {
            "q-network-state": self._online_q_network.state_dict(),
            "optimizer-state": self._optimizer.state_dict(),
            "agent-hyperparameters": {
                "alpha": self._alpha,
                "batch_size": self._memory.batch_size,
                "buffer_size": self._memory.buffer_size,
                "gamma": self._gamma,
                "update_frequency": self._update_frequency
            }
        }
        torch.save(checkpoint, filepath)
        
    def step(self,
             state: np.array,
             action: int,
             reward: float,
             next_state: np.array,
             done: bool) -> None:
        """
        Updates the agent's state based on feedback received from the environment.
        
        Parameters:
        -----------
        state (np.array): the previous state of the environment.
        action (int): the action taken by the agent in the previous state.
        reward (float): the reward received from the environment.
        next_state (np.array): the resulting state of the environment following the action.
        done (bool): True if the training episode is finished; False otherwise.
        
        """
        experience = Experience(state, action, reward, next_state, done)
        self._memory.append(experience)
            
        if done:
            self._number_episodes += 1
        else:
            self._number_timesteps += 1
            
            # every so often the agent should learn from experiences
            if self._number_timesteps % self._update_frequency == 0 and self.has_sufficient_experience():
                experiences = self._memory.sample()
                self.learn(experiences)
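
The soft update performed at the end of learn is easier to reason about on a single number. The sketch below is a plain-Python analogue of _soft_update_q_network_parameters (lists of floats instead of network parameters, which is my simplification). It counts how many updates it takes for the target to close half of the gap to a fixed online value when alpha is 1e-3.

```python
def soft_update(target, online, alpha):
    """Move each target value a fraction alpha of the way toward the online value."""
    return [alpha * o + (1 - alpha) * t for t, o in zip(target, online)]


target = [0.0]
online = [1.0]  # held fixed here; in training the online network keeps changing
alpha = 1e-3

steps = 0
while online[0] - target[0] > 0.5:
    target = soft_update(target, online, alpha)
    steps += 1

print(f"updates needed to close half the gap: {steps}")
```

The gap shrinks by a factor of (1 - alpha) per update, so the half-life is roughly ln(2) / alpha updates. Since learn is only called every update_frequency timesteps, the target network lags the online network by a correspondingly larger number of environment steps.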

The code for the training loop remains unchanged from the previous post.

import collections
import typing

import gym


def _train_for_at_most(agent: Agent, env: gym.Env, max_timesteps: int) -> float:
    """Train agent for a maximum number of timesteps."""
    state = env.reset()
    score = 0
    for t in range(max_timesteps):
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
        if done:
            break
    return score

                
def _train_until_done(agent: Agent, env: gym.Env) -> float:
    """Train the agent until the current episode is complete."""
    state = env.reset()
    score = 0
    done = False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.step(state, action, reward, next_state, done)
        state = next_state
        score += reward
    return score


def train(agent: Agent,
          env: gym.Env,
          checkpoint_filepath: str,
          target_score: float,
          number_episodes: int,
          maximum_timesteps=None) -> typing.List[float]:
    """
    Reinforcement learning training loop.
    
    Parameters:
    -----------
    agent (Agent): an agent to train.
    env (gym.Env): an environment in which to train the agent.
    checkpoint_filepath (str): filepath used to save the state of the trained agent.
    target_score (float): average score (over the most recent 100 episodes) at which training stops.
    number_episodes (int): maximum number of training episodes.
    maximum_timesteps (int): maximum number of timesteps per episode.
    
    Returns:
    --------
    scores (list): collection of episode scores from training.
    
    """
    scores = []
    most_recent_scores = collections.deque(maxlen=100)
    for i in range(number_episodes):
        if maximum_timesteps is None:
            score = _train_until_done(agent, env)
        else:
            score = _train_for_at_most(agent, env, maximum_timesteps)         
        scores.append(score)
        most_recent_scores.append(score)
        
        average_score = sum(most_recent_scores) / len(most_recent_scores)
        if average_score >= target_score:
            print(f"\nEnvironment solved in {i:d} episodes!\tAverage Score: {average_score:.2f}")
            agent.save(checkpoint_filepath)
            break
        if (i + 1) % 100 == 0:
            print(f"\rEpisode {i + 1}\tAverage Score: {average_score:.2f}")

    return scores

Solving the LunarLander-v2 environment

In the rest of this blog post I will use the Double DQN algorithm to train an agent to solve the LunarLander-v2 environment from OpenAI and then compare it to the results obtained using the vanilla DQN algorithm.

In this environment the landing pad is always at coordinates (0,0). The reward for moving the lander from the top of the screen to the landing pad and arriving at zero speed is typically between 100 and 140 points. Firing the main engine costs 0.3 points each frame (so the lander is incentivized to fire the engine as few times as possible). If the lander moves away from the landing pad it loses reward (so the lander is incentivized to land in the designated landing area). The lander is also incentivized to land "gracefully" (and not crash in the landing area!).

A training episode finishes if the lander crashes (-100 points) or comes to rest (+100 points). Each leg with ground contact receives an additional +10 points. The task is considered "solved" if the lander is able to achieve 200 points (I will actually be more stringent and define "solved" as achieving over 200 points on average in the most recent 100 training episodes).
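
The stricter "solved" criterion can be written down as a small helper. This is a sketch under my own assumptions: the name `is_solved` is hypothetical, and unlike the train function above (which averages over however many episodes have been recorded so far and uses >=) it insists on a full window of 100 episodes and a strictly higher average.

```python
import collections


def is_solved(scores, target_score=200.0, window=100):
    """True if the average of the most recent `window` scores exceeds target_score."""
    recent = collections.deque(scores, maxlen=window)  # keeps only the last `window` scores
    if len(recent) < window:
        return False  # assumption: require a full window before declaring success
    return sum(recent) / window > target_score


print(is_solved([250.0] * 99))                 # not enough episodes yet
print(is_solved([0.0] * 50 + [250.0] * 100))   # early failures fall out of the window
```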

Action Space

There are four discrete actions available:

  1. Do nothing.
  2. Fire the left orientation engine.
  3. Fire main engine.
  4. Fire the right orientation engine.
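
When reading agent output it helps to remember that the environment encodes these four actions as the integers 0 through 3, in the order listed above. A small enum (the class and member names are my own, not part of gym) keeps that mapping explicit.

```python
import enum


class LunarLanderAction(enum.IntEnum):
    """Hypothetical names for the integer actions of LunarLander-v2."""
    DO_NOTHING = 0
    FIRE_LEFT_ENGINE = 1
    FIRE_MAIN_ENGINE = 2
    FIRE_RIGHT_ENGINE = 3


# An IntEnum member can be passed anywhere a plain int action is expected.
action = LunarLanderAction.FIRE_MAIN_ENGINE
print(action.name, int(action))
```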

Colab specific environment setup

If you are playing around with this notebook on Google Colab, then you will need to run the following cell in order to install the required OpenAI dependencies into the environment.

!pip install gym[box2d]==0.17.*
import gym

env = gym.make('LunarLander-v2')
_ = env.seed(42)

Creating a DeepQAgent

Before creating an instance of the DeepQAgent I need to define an $\epsilon$-decay schedule and choose an optimizer.

Epsilon decay schedule

As was the case with the DQN algorithm, when using the Double DQN algorithm the agent chooses its action using an $\epsilon$-greedy policy. When using an $\epsilon$-greedy policy, with probability $\epsilon$, the agent explores the state space by choosing an action uniformly at random from the set of feasible actions; with probability $1-\epsilon$, the agent exploits its current knowledge by choosing the optimal action given that current state.
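
Here is a minimal, framework-free sketch of an $\epsilon$-greedy policy over a fixed list of Q-values. The Q-values and the value of epsilon are made up for illustration; the agent above does the same thing with a Q-network in place of the list.

```python
import random


def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a uniformly random action; otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)


rng = random.Random(0)
q_values = [0.1, 0.7, 0.2, 0.0]  # hypothetical Q-values; action 1 is greedy
epsilon = 0.2

actions = [epsilon_greedy(q_values, epsilon, rng) for _ in range(10_000)]
greedy_share = actions.count(1) / len(actions)
print(f"share of greedy actions: {greedy_share:.3f}")
```

Note that the greedy action is chosen slightly more often than $1-\epsilon$ of the time, because the uniform random draw also lands on it with probability $\epsilon$ divided by the number of actions.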

As the agent learns and acquires additional knowledge about its environment it makes sense to decrease exploration and increase exploitation by decreasing $\epsilon$. In practice, it isn't a good idea to decrease $\epsilon$ to zero; instead one typically decreases $\epsilon$ over time according to some schedule until it reaches some minimum value.

def power_decay_schedule(episode_number: int,
                         decay_factor: float,
                         minimum_epsilon: float) -> float:
    """Power decay schedule found in other practical applications."""
    return max(decay_factor**episode_number, minimum_epsilon)

_epsilon_decay_schedule_kwargs = {
    "decay_factor": 0.99,
    "minimum_epsilon": 1e-2,
}
epsilon_decay_schedule = lambda n: power_decay_schedule(n, **_epsilon_decay_schedule_kwargs)
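
It is useful to know how quickly this schedule reaches its floor. The sketch below repeats the schedule so that it runs on its own and solves decay_factor**n <= minimum_epsilon for the first episode n at which the floor binds; the variable names outside the schedule function are my own.

```python
import math


def power_decay_schedule(episode_number, decay_factor, minimum_epsilon):
    """Same schedule as above: geometric decay with a floor."""
    return max(decay_factor**episode_number, minimum_epsilon)


decay_factor, minimum_epsilon = 0.99, 1e-2

# Smallest n with decay_factor**n <= minimum_epsilon.
first_floor_episode = math.ceil(math.log(minimum_epsilon) / math.log(decay_factor))
print(f"epsilon reaches its minimum at episode {first_floor_episode}")

for n in (0, 100, 200, 400, first_floor_episode):
    print(n, round(power_decay_schedule(n, decay_factor, minimum_epsilon), 4))
```

With these settings the agent is acting almost entirely greedily for most of a 2000 episode training run.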

Choosing an optimizer

As is the case in training any neural network, the choice of optimizer and the tuning of its hyper-parameters (in particular the learning rate) is important. Here I am going to use the Adam optimizer. In my previous post on the DQN algorithm I used RMSProp. In my experiments I found that the Adam optimizer significantly improves the efficiency and stability of both the Double DQN and DQN algorithms compared with RMSProp (on this task at least!). In fact it seemed that the improvements in terms of efficiency and stability from using the Adam optimizer instead of the RMSProp optimizer were more important than any gains from using Double DQN instead of DQN.

from torch import optim


_optimizer_kwargs = {
    "lr": 1e-3,
    "betas": (0.9, 0.999),
    "eps": 1e-08,
    "weight_decay": 0,
    "amsgrad": False,
}
optimizer_fn = lambda parameters: optim.Adam(parameters, **_optimizer_kwargs)

Training the DeepQAgent using Double DQN

Now I am finally ready to train the agent. The target score for the LunarLander-v2 environment is 200 points on average for at least 100 consecutive episodes. If the agent is able to "solve" the environment, then training will terminate early.

_agent_kwargs = {
    "state_size": env.observation_space.shape[0],
    "action_size": env.action_space.n, 
    "number_hidden_units": 64,
    "optimizer_fn": optimizer_fn,
    "epsilon_decay_schedule": epsilon_decay_schedule,
    "batch_size": 64,
    "buffer_size": 100000,
    "alpha": 1e-3,
    "gamma": 0.99,
    "update_frequency": 4,
    "double_dqn": True,  # True uses Double DQN; False uses DQN 
    "seed": None,
}
double_dqn_agent = DeepQAgent(**_agent_kwargs)

double_dqn_scores = train(double_dqn_agent,
                          env,
                          "double-dqn-checkpoint.pth",
                          number_episodes=2000,
                          target_score=200)
Episode 100	Average Score: -170.99
Episode 200	Average Score: -96.48
Episode 300	Average Score: -58.06
Episode 400	Average Score: -37.71
Episode 500	Average Score: -36.62
Episode 600	Average Score: -9.45
Episode 700	Average Score: 85.52
Episode 800	Average Score: 113.89
Episode 900	Average Score: 169.01
Episode 1000	Average Score: 193.31
Episode 1100	Average Score: 151.43
Episode 1200	Average Score: 198.34
Episode 1300	Average Score: 88.76

Environment solved in 1398 episodes!	Average Score: 205.27

Training the DeepQAgent using DQN

Next I will create another DeepQAgent and train it using the original DQN algorithm for comparison.

_agent_kwargs = {
    "state_size": env.observation_space.shape[0],
    "action_size": env.action_space.n, 
    "number_hidden_units": 64,
    "optimizer_fn": optimizer_fn,
    "epsilon_decay_schedule": epsilon_decay_schedule,
    "batch_size": 64,
    "buffer_size": 100000,
    "alpha": 1e-3,
    "gamma": 0.99,
    "update_frequency": 4,
    "double_dqn": False,  # True uses Double DQN; False uses DQN 
    "seed": None,
}
dqn_agent = DeepQAgent(**_agent_kwargs)

dqn_scores = train(dqn_agent,
                   env,
                   "dqn-checkpoint.pth",
                   number_episodes=2000,
                   target_score=200)
Episode 100	Average Score: -154.20
Episode 200	Average Score: -52.41
Episode 300	Average Score: 42.07
Episode 400	Average Score: 15.85
Episode 500	Average Score: 149.99

Environment solved in 581 episodes!	Average Score: 201.22

Comparing DQN and Double DQN

To make it a bit easier to compare the overall performance of the two algorithms I will now re-train both agents for the same number of episodes (rather than training for the minimum number of episodes required to achieve a target score).

_agent_kwargs = {
    "state_size": env.observation_space.shape[0],
    "action_size": env.action_space.n, 
    "number_hidden_units": 64,
    "optimizer_fn": optimizer_fn,
    "epsilon_decay_schedule": epsilon_decay_schedule,
    "batch_size": 64,
    "buffer_size": 100000,
    "alpha": 1e-3,
    "gamma": 0.99,
    "update_frequency": 4,
    "double_dqn": True,
    "seed": None,
}
double_dqn_agent = DeepQAgent(**_agent_kwargs)

double_dqn_scores = train(double_dqn_agent,
                          env,
                          "double-dqn-checkpoint.pth",
                          number_episodes=2000,
                          target_score=float("inf"), # hack to ensure that training never terminates early
                         )
Episode 100	Average Score: -167.21
Episode 200	Average Score: -126.37
Episode 300	Average Score: -40.27
Episode 400	Average Score: 62.17
Episode 500	Average Score: 198.44
Episode 600	Average Score: 220.72
Episode 700	Average Score: 235.76
Episode 800	Average Score: 239.27
Episode 900	Average Score: 227.32
Episode 1000	Average Score: 238.04
Episode 1100	Average Score: 230.81
Episode 1200	Average Score: 241.14
Episode 1300	Average Score: 241.82
Episode 1400	Average Score: 240.67
Episode 1500	Average Score: 248.40
Episode 1600	Average Score: 255.64
Episode 1700	Average Score: 257.87
Episode 1800	Average Score: 262.54
Episode 1900	Average Score: 252.53
Episode 2000	Average Score: 251.59
_agent_kwargs = {
    "state_size": env.observation_space.shape[0],
    "action_size": env.action_space.n, 
    "number_hidden_units": 64,
    "optimizer_fn": optimizer_fn,
    "epsilon_decay_schedule": epsilon_decay_schedule,
    "batch_size": 64,
    "buffer_size": 100000,
    "alpha": 1e-3,
    "gamma": 0.99,
    "update_frequency": 4,
    "double_dqn": False,
    "seed": None,
}
dqn_agent = DeepQAgent(**_agent_kwargs)

dqn_scores = train(dqn_agent,
                   env,
                   "dqn-checkpoint.pth",
                   number_episodes=2000,
                   target_score=float("inf"))
Episode 100	Average Score: -168.37
Episode 200	Average Score: -87.11
Episode 300	Average Score: -19.44
Episode 400	Average Score: 20.67
Episode 500	Average Score: 128.34
Episode 600	Average Score: 173.04
Episode 700	Average Score: 146.70
Episode 800	Average Score: 144.91
Episode 900	Average Score: 217.99
Episode 1000	Average Score: 209.41
Episode 1100	Average Score: 214.15
Episode 1200	Average Score: 219.80
Episode 1300	Average Score: 217.97
Episode 1400	Average Score: 242.82
Episode 1500	Average Score: 239.33
Episode 1600	Average Score: 234.26
Episode 1700	Average Score: 221.27
Episode 1800	Average Score: 239.80
Episode 1900	Average Score: 256.01
Episode 2000	Average Score: 266.62

Plotting the time series of scores

I can use Pandas to quickly plot the time series of scores along with a 100 episode moving average. Because both agents were re-trained with an infinite target score, neither run stops early and each series covers the full 2000 episodes.

import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
dqn_scores = pd.Series(dqn_scores, name="scores")
double_dqn_scores = pd.Series(double_dqn_scores, name="scores")
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True, sharey=True)
_ = dqn_scores.plot(ax=axes[0], label="DQN Scores")
_ = (dqn_scores.rolling(window=100)
               .mean()
               .rename("Rolling Average")
               .plot(ax=axes[0]))
_ = axes[0].legend()
_ = axes[0].set_ylabel("Score")

_ = double_dqn_scores.plot(ax=axes[1], label="Double DQN Scores")
_ = (double_dqn_scores.rolling(window=100)
                      .mean()
                      .rename("Rolling Average")
                      .plot(ax=axes[1]))
_ = axes[1].legend()
_ = axes[1].set_ylabel("Score")
_ = axes[1].set_xlabel("Episode Number")

Kernel density plot of the scores

In general, the kernel density plot will be bimodal with one mode less than -100 and a second mode greater than 200. The negative mode corresponds to those training episodes where the agent crash landed and thus scored at most -100; the positive mode corresponds to those training episodes where the agent "solved" the task. The kernel density of scores typically exhibits negative skewness (i.e., a fat left tail): there are lots of ways in which landing the lander can go horribly wrong (resulting in the agent getting a very low score) and only relatively few paths to a gentle landing (and a high score).

Depending on the training run, you may see that the distribution of scores for Double DQN has a significantly higher positive mode and lower negative mode when compared to the distribution for DQN, which indicates that the agent trained with Double DQN solved the task more frequently and crashed and burned less frequently than the agent trained with DQN.

fig, ax = plt.subplots(1,1)
_ = dqn_scores.plot(kind="kde", ax=ax, label="DQN")
_ = double_dqn_scores.plot(kind="kde", ax=ax, label="Double DQN")
_ = ax.set_xlabel("Score")
_ = ax.legend()

Where to go from here?

In a future post I plan to cover Prioritized Experience Replay, which improves the sampling scheme used by the ExperienceReplayBuffer so as to replay important transitions more frequently, which should lead to more efficient learning.