Maze DR

This example implements domain randomisation (DR): the AutoReset wrapper automatically resets the environment to a new random level whenever an episode terminates. The training loop follows the PureJaxRL style.
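As a rough illustration of what the wrapper does, here is a minimal, self-contained sketch of the auto-reset pattern in JAX. The functions `sample_level`, `reset_to_level`, and `step_env` are toy stand-ins, not JaxUED's actual API:

```python
import jax
import jax.numpy as jnp

def sample_level(rng):
    # Domain randomisation: here a "level" is just a random integer seed.
    return jax.random.randint(rng, (), 0, 1000)

def reset_to_level(level):
    obs = jnp.zeros(2)
    return obs, level  # (observation, env state)

def step_env(state, action):
    obs = jnp.full((2,), action, dtype=jnp.float32)
    done = action > 0.5  # toy termination condition
    return obs, state, 1.0, done

def autoreset_step(rng, state, action):
    """Step the env; on termination, reset to a freshly sampled level."""
    obs, state, reward, done = step_env(state, action)
    obs, state = jax.lax.cond(
        done,
        lambda: reset_to_level(sample_level(rng)),  # episode over: new level
        lambda: (obs, state),                       # episode continues
    )
    return obs, state, reward, done

rng = jax.random.PRNGKey(0)
obs, state = reset_to_level(sample_level(rng))
obs, state, reward, done = jax.jit(autoreset_step)(rng, state, 0.7)
```

Because the reset happens inside the step function via `jax.lax.cond`, the whole rollout stays jit-compatible, which is what makes the PureJaxRL-style fully-jitted training loop possible.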

Outputs

Checkpoints are saved to ./checkpoints/<run_name>/<seed>/models/<update_step>. If mode=eval, results are saved to ./results in .npz format.
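The .npz results file can be read back with NumPy. The filename below is hypothetical, and the array names stored inside depend on the script, so list them first:

```python
import numpy as np

# Hypothetical path; the actual filename depends on the run configuration.
results = np.load("results/my_dr_test.npz")

print(results.files)  # names of the arrays stored in the file
for name in results.files:
    print(name, results[name].shape, results[name].dtype)
```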

Usage

python examples/maze_dr.py <args>

Examples

Run training:

python examples/maze_dr.py --run_name my_dr_test --project my_wandb_project --seed 0 --num_updates 10000

Once training has finished, run evaluation on the final checkpoint:

python examples/maze_dr.py --mode eval --checkpoint_directory checkpoints/my_dr_test/0 --checkpoint_to_eval=-1

Arguments

| Name | Description | Default |
| --- | --- | --- |
| `--project` | Wandb project | `JAXUED_TEST` |
| `--run_name` | Run name; controls where checkpoints are stored | `None` |
| `--seed` | Random seed | `0` |
| `--mode` | `"train"` or `"eval"` | `train` |
| `--checkpoint_directory` | Only valid if `mode==eval`; the directory to load the checkpoint from | `None` |
| `--checkpoint_to_eval` | Only valid if `mode==eval`; the update step to load from the checkpoint directory | `-1` |
| `--checkpoint_save_interval` | How often to save checkpoints | `0` |
| `--max_number_of_checkpoints` | Maximum number of checkpoints to save in total | `60` |
| `--eval_freq` | How often to evaluate the agent and log results | `250` |
| `--eval_num_attempts` | Number of attempts (episodes) per level during evaluation | `10` |
| `--eval_levels` | The evaluation levels to use | `"SixteenRooms"`, `"SixteenRooms2"`, `"Labyrinth"`, `"LabyrinthFlipped"`, `"Labyrinth2"`, `"StandardMaze"`, `"StandardMaze2"`, `"StandardMaze3"` |
| `--lr` | The agent's learning rate | `1e-4` |
| `--max_grad_norm` | The agent's maximum PPO gradient norm | `0.5` |
| `--num_updates` | Number of updates. Mutually exclusive with `num_env_steps`; in general, `num_env_steps = num_updates * num_steps * num_train_envs` (see the worked example below the table) | `30000` |
| `--num_env_steps` | Number of environment steps. Mutually exclusive with `num_updates` | `None` |
| `--num_steps` | Number of PPO rollout steps | `256` |
| `--num_train_envs` | Number of training environments | `32` |
| `--num_minibatches` | Number of PPO minibatches | `1` |
| `--gamma` | Discount factor | `0.995` |
| `--epoch_ppo` | Number of PPO epochs | `5` |
| `--clip_eps` | PPO clipping epsilon | `0.2` |
| `--gae_lambda` | GAE lambda | `0.98` |
| `--entropy_coeff` | PPO entropy coefficient | `1e-3` |
| `--critic_coeff` | Critic loss coefficient | `0.5` |
| `--agent_view_size` | The number of tiles the agent can see in front of it | `5` |
| `--n_walls` | Number of walls to generate | `25` |
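As noted in the `--num_updates` row above, the two budget flags are related by a simple product. With the defaults in this table:

```python
num_updates = 30_000     # --num_updates
num_steps = 256          # --num_steps (PPO rollout length)
num_train_envs = 32      # --num_train_envs

# num_env_steps = num_updates * num_steps * num_train_envs
num_env_steps = num_updates * num_steps * num_train_envs
print(f"{num_env_steps:,}")  # 245,760,000 environment steps
```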