Transformer-XL (PPO-TrXL)
Overview
Real-world tasks may expose imperfect information (e.g. partial observability). Such tasks require an agent to leverage memory capabilities. One way to do this is to use recurrent neural networks (e.g. LSTM), as seen in ppo_atari_lstm.py (docs). Here, Transformer-XL is used as episodic memory in Proximal Policy Optimization (PPO).
Original Paper and Implementation
- Memory Gym: Towards Endless Tasks to Benchmark Memory Capabilities of Agents
- neroRL, Episodic Transformer Memory PPO
- Interactive Visualizations of Trained Agents
Related Publications and Repositories
- Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context
- Stabilizing Transformers for Reinforcement Learning
- Towards mental time travel: a hierarchical memory for reinforcement learning agents
- Grounded Language Learning Fast and Slow
- transformerXL_PPO_JAX
Implemented Variants
Variants Implemented | Description |
---|---|
ppo_trxl.py, docs | For training on tasks like Endless-MortarMayhem-v0. |
Below is our single-file implementation of PPO-TrXL:
ppo_trxl.py
ppo_trxl.py has the following features:
- Works with Memory Gym's environments (84x84 RGB image observation).
- Works with Minigrid Memory (84x84 RGB image observation).
- Also works with environments that expose only game state vector observations (e.g. Proof of Memory Environment).
- Supports only discrete action spaces (single or multi-discrete).
Usage
cd cleanrl/ppo_trxl
poetry install
poetry run python ppo_trxl.py --help
poetry run python ppo_trxl.py --env-id Endless-MortarMayhem-v0
Explanation of the logged metrics
- episode/r_mean: mean of the episodic return of the game
- episode/l_mean: mean of the episode length of the game in steps
- episode/t_mean: mean of the episode duration of the game in seconds
- episode/advantage_mean: mean of all computed advantages
- episode/value_mean: mean of all approximated values
- charts/SPS: number of steps per second
- charts/learning_rate: the current learning rate
- charts/entropy_coefficient: the current entropy coefficient
- losses/value_loss: the mean value loss across all data points
- losses/policy_loss: the mean policy loss across all data points
- losses/entropy: the mean entropy value across all data points
- losses/reconstruction_loss: the mean observation reconstruction loss value across all data points
- losses/loss: the mean of all summed losses across all data points
- losses/old_approx_kl: the approximate Kullback–Leibler divergence, measured by (-logratio).mean(), which corresponds to the k1 estimator in John Schulman’s blog post on approximating KL
- losses/approx_kl: better alternative to old_approx_kl, measured by (logratio.exp() - 1) - logratio, which corresponds to the k3 estimator in approximating KL
- losses/clipfrac: the fraction of the training data that triggered the clipped objective
- losses/explained_variance: the explained variance for the value function
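To make the diagnostic metrics above more concrete, the following sketch computes the two KL estimators, the clip fraction, and the explained variance in plain Python. It is an illustration only, not the code in ppo_trxl.py: the function names are hypothetical, Python lists stand in for tensors, and the explained variance uses the common definition 1 - Var(returns - predictions) / Var(returns), which is an assumption about how the logged value is computed.

```python
import math
from statistics import pvariance


def kl_estimators(new_logprobs, old_logprobs):
    """Return the (k1, k3) KL estimates from per-sample log-probabilities."""
    logratios = [new - old for new, old in zip(new_logprobs, old_logprobs)]
    n = len(logratios)
    # k1 estimator: mean of -logratio (logged as old_approx_kl)
    k1 = sum(-lr for lr in logratios) / n
    # k3 estimator: mean of (exp(logratio) - 1) - logratio (logged as approx_kl)
    k3 = sum((math.exp(lr) - 1.0) - lr for lr in logratios) / n
    return k1, k3


def clip_fraction(new_logprobs, old_logprobs, clip_coef):
    """Fraction of samples whose probability ratio leaves [1 - clip, 1 + clip]."""
    ratios = [math.exp(new - old) for new, old in zip(new_logprobs, old_logprobs)]
    return sum(abs(r - 1.0) > clip_coef for r in ratios) / len(ratios)


def explained_variance(predicted_values, returns):
    """1 - Var(returns - predictions) / Var(returns); NaN for constant returns."""
    var_returns = pvariance(returns)
    if var_returns == 0:
        return float("nan")
    residuals = [r - v for r, v in zip(returns, predicted_values)]
    return 1.0 - pvariance(residuals) / var_returns
```

Each k3 term is non-negative because exp(x) - 1 - x >= 0 for every x, whereas the k1 estimate can turn negative on a finite batch. This is one reason the k3 estimator is usually the preferred diagnostic.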
Implementation details
Most details are derived from ppo.py. These are additional or differing details:
- The policy and value function share parameters.
- Multi-head attention is implemented so that all heads share parameters.
- Absolute positional encoding is used by default. Learned positional encodings are also supported.
- Previously computed hidden states of the TrXL layers are cached and re-used for up to trxl_memory_length steps. Only one hidden state is computed anew.
- TrXL layers adhere to pre-layer normalization.
- Support for multi-discrete action spaces.
- Support for an auxiliary observation reconstruction loss, which decodes TrXL's output back into the visual observation that was fed to the model.
- The learning rate and the entropy bonus coefficient linearly decay until reaching a lower threshold.
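Two of the details above lend themselves to a short sketch: the sliding cache of hidden states and the linear decay schedule. The snippet below is a minimal illustration in plain Python under stated assumptions, not the code in ppo_trxl.py. The names run_episode_with_cache and linear_decay are hypothetical, strings stand in for hidden-state tensors, and the schedule assumes the decay is driven by a step counter with a fixed horizon. It only mimics the window the agent attends to while acting, not how the script stores hidden states for optimization.

```python
from collections import deque


def run_episode_with_cache(num_steps, memory_length):
    """Simulate an episode in which each step reuses cached hidden states.

    Returns the memory window that was visible at every step.
    """
    # A bounded deque drops the oldest entry once memory_length is reached.
    cache = deque(maxlen=memory_length)
    windows = []
    for t in range(num_steps):
        # The cached states are reused as-is; none of them is recomputed.
        windows.append(list(cache))
        # Placeholder for the single hidden state computed at this step.
        new_hidden = f"h{t}"
        cache.append(new_hidden)
    return windows


def linear_decay(initial, final, step, max_steps):
    """Linearly interpolate from `initial` to `final`, then hold `final`."""
    if step >= max_steps:
        return final
    remaining = 1.0 - step / max_steps
    return (initial - final) * remaining + final


if __name__ == "__main__":
    for t, window in enumerate(run_episode_with_cache(num_steps=5, memory_length=3)):
        print(t, window)
    print([linear_decay(1.0, 0.5, s, 10) for s in (0, 5, 10, 25)])
```

With a window of three, the cache fills during the first steps and then slides forward, so the oldest hidden state is discarded each time a new one is appended. The decay function stays at the lower threshold once the horizon is reached, which is the behaviour described for both the learning rate and the entropy coefficient.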
Experiment results
Note: When training on potentially endless episodes, the cached hidden states demand a large amount of GPU memory. Reproducing the following experiments requires at least 40GB of GPU memory.
Environment | PPO-TrXL |
---|---|
MortarMayhem-Grid-v0 | 0.99 ± 0.00 |
MortarMayhem-v0 | 0.99 ± 0.00 |
Endless-MortarMayhem-v0 | 1.50 ± 0.02 |
MysteryPath-Grid-v0 | 0.97 ± 0.01 |
MysteryPath-v0 | 1.67 ± 0.02 |
Endless-MysteryPath-v0 | 1.84 ± 0.06 |
SearingSpotlights-v0 | 1.11 ± 0.08 |
Endless-SearingSpotlights-v0 | 1.60 ± 0.03 |
Learning curves:
Tracked experiments:
Enjoy pre-trained models
Use cleanrl/ppo_trxl/enjoy.py to watch pre-trained agents. Pre-trained models can be retrieved from Hugging Face.
Run models from the hub:
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-MortarMayhem-v0_12.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-MysterPath-v0_11.nn
python cleanrl/ppo_trxl/enjoy.py --hub --name Endless-SearingSpotlights-v0_30.nn
Run local models (or download them from the hub manually):
python cleanrl/ppo_trxl/enjoy.py --name Your.cleanrl_model