Level Sampler
The LevelSampler provides all of the functionality associated with a level buffer in a PLR/ACCEL-type method. In the standard JAX style, the level sampler class does not store any data itself; instead, most operations accept a sampler object that holds the buffer state.
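The split between a stateless class and an explicit state object can be illustrated with a small dependency-free sketch. The `TinyBuffer` class and its `push` method are hypothetical stand-ins, not part of the library; they only show configuration living on the class while all data travels through the state dictionary.

```python
class TinyBuffer:
    """Hypothetical stand-in for a stateless buffer class (not the library's code)."""

    def __init__(self, capacity):
        # Only static configuration is stored on the object.
        self.capacity = capacity

    def initialize(self):
        # All mutable data lives in the returned state dictionary.
        return {"items": [None] * self.capacity, "size": 0}

    def push(self, state, item):
        # Return a new state rather than mutating the one passed in.
        if state["size"] >= self.capacity:
            return state
        items = list(state["items"])
        items[state["size"]] = item
        return {"items": items, "size": state["size"] + 1}


buffer = TinyBuffer(capacity=2)
state0 = buffer.initialize()
state1 = buffer.push(state0, "level-a")
print(state0["size"], state1["size"])  # prints: 0 1
```

Because every method takes the state and returns a new one, the caller decides which version of the state to keep, which is the same calling pattern the example below uses with `sampler`.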
Examples:
pholder_level = ...
pholder_level_extra = ...
level_sampler = LevelSampler(4000)
sampler = level_sampler.initialize(pholder_level, pholder_level_extra)
should_replay = level_sampler.sample_replay_decision(sampler, rng)
# Sample 32 replay levels: returns the updated sampler, the buffer indices, and the levels
sampler, (level_inds, replay_levels) = level_sampler.sample_replay_levels(sampler, rng, 32)
scores = ... # eval agent on replay_levels
# Write the new scores back to the sampled buffer entries
sampler = level_sampler.update_batch(sampler, level_inds, scores)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
capacity | int | The maximum number of levels that can be stored in the buffer. | required |
replay_prob | float | The chance of performing on_replay vs on_new. Defaults to 0.95. | 0.95 |
staleness_coeff | float | The weighting factor for staleness. Defaults to 0.5. | 0.5 |
minimum_fill_ratio | float | The class will never sample a replay decision until the level buffer is at least as full as specified by this value. Defaults to 1.0. | 1.0 |
prioritization_params | dict | If prioritization="rank", this has a "temperature" field; for "topk" it has a "k" field. If not provided, by default this is initialized to a temperature of 1.0 and k=1. Defaults to None. | None |
duplicate_check | bool | If this is true, duplicate levels cannot be added to the buffer. This adds some computation to check for duplicates. Defaults to False. | False |
find(sampler, level)
Returns the index of level in the level buffer. If level is not present, -1 is returned.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
level | Level | The level to find | required |
Returns:
Name | Type | Description |
---|---|---|
int | int | Index of the level, or -1 if not found. |
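The lookup contract (an index on success, -1 otherwise) can be sketched in plain Python. This is a minimal illustration, assuming levels can be compared with `==` and that only the first `size` slots of the buffer hold valid levels; the standalone `find` function and its arguments are hypothetical and not the library's signature.

```python
def find(levels, size, level):
    """Return the index of `level` among the first `size` slots, or -1."""
    for idx in range(size):
        if levels[idx] == level:
            return idx
    # Not present in the filled part of the buffer.
    return -1


levels = ["maze-a", "maze-b", None, None]  # capacity 4, two slots filled
found_idx = find(levels, 2, "maze-b")
missing_idx = find(levels, 2, "maze-z")
print(found_idx, missing_idx)  # prints: 1 -1
```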
flush(sampler)
Flushes this sampler, putting it back to its empty state. Note that this updates the sampler in place, except when jitted.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
Returns:
Name | Type | Description |
---|---|---|
Sampler | Sampler | |
freshness_weights(sampler)
Returns freshness weights for each level.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
Returns:
Type | Description |
---|---|
Array | chex.Array: shape (self.capacity) |
get_levels(sampler, level_idx)
Returns the level at a particular index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
level_idx | int | The index to return | required |
Returns:
Name | Type | Description |
---|---|---|
Level | Level | |
get_levels_extra(sampler, level_idx)
Returns the level extras associated with a particular index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
level_idx | int | The index to return | required |
Returns:
Name | Type | Description |
---|---|---|
dict | dict | |
initialize(pholder_level, pholder_level_extra=None)
Returns the sampler object as a dictionary.
Sampler Object Keys
- "levels" (shape (self.capacity, ...)): the levels themselves
- "scores" (shape (self.capacity)): the scores of the levels
- "timestamps" (shape (self.capacity)): the timestamps of the levels
- "size" (int): the number of levels currently in the buffer
- "episode_count" (int): the number of episodes that have been played so far
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pholder_level | Level | A placeholder level that will be used to initialize the level buffer. | required |
pholder_level_extra | dict | If given, this should be a dictionary with arbitrary keys that is kept track of alongside each level. An example is "max_return" for each level. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Sampler | Sampler | The initialized sampler object |
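The layout of the sampler dictionary can be sketched with plain Python lists standing in for arrays. The initial values used here (negative-infinity scores, zero timestamps) and the `"levels_extra"` key name are assumptions made for illustration; only the five keys listed above come from this page.

```python
def initialize(capacity, pholder_level, pholder_level_extra=None):
    """Build an empty sampler dictionary with the documented keys (sketch only)."""
    sampler = {
        # Every slot starts out holding the placeholder level.
        "levels": [pholder_level] * capacity,
        # Assumption: empty slots carry the lowest possible score.
        "scores": [float("-inf")] * capacity,
        # Assumption: timestamps start at zero.
        "timestamps": [0] * capacity,
        "size": 0,
        "episode_count": 0,
    }
    if pholder_level_extra is not None:
        # Hypothetical key name for the per-level extras.
        sampler["levels_extra"] = [dict(pholder_level_extra) for _ in range(capacity)]
    return sampler


sampler = initialize(4, pholder_level="empty-maze", pholder_level_extra={"max_return": 0.0})
print(sorted(sampler.keys()))
```

The placeholder fixes the shape of every slot up front, so later insertions only overwrite entries and never change the size of the stored containers.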
insert(sampler, level, score, level_extra=None)
Attempt to insert level into the level buffer.
Insertion occurs when either of the following holds:
- Corresponding score exceeds the score of the lowest weighted level currently in the buffer (in which case it will replace it).
- Buffer is not yet at capacity.
Optionally, if the level to be inserted already exists in the level buffer, the corresponding buffer entry will be updated instead (see duplicate_check).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
level | Level | Level to insert | required |
score | float | Its score | required |
level_extra | dict | If level extra was given in | None |
Returns:
Type | Description |
---|---|
tuple[Sampler, int] | tuple[Sampler, int]: The updated sampler, and the level's index in the buffer (-1 if it was not inserted) |
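The insertion rule can be sketched as follows. This is a simplified illustration, not the library's implementation: it uses the raw score as the weight when choosing which entry to evict (the page says "lowest weighted"), it ignores duplicate checking and level extras, and the standalone `insert` signature with an explicit `capacity` argument is hypothetical.

```python
def insert(sampler, level, score, capacity):
    """Return (new_sampler, index); index is -1 when the level is rejected."""
    levels = list(sampler["levels"])
    scores = list(sampler["scores"])
    size = sampler["size"]
    if size < capacity:
        # Buffer not yet full: append into the next free slot.
        idx = size
        size += 1
    else:
        # Buffer full: candidate slot is the lowest-scoring entry (simplification).
        idx = min(range(capacity), key=lambda i: scores[i])
        if score <= scores[idx]:
            return sampler, -1
    levels[idx] = level
    scores[idx] = score
    return {**sampler, "levels": levels, "scores": scores, "size": size}, idx


capacity = 2
sampler = {"levels": [None, None], "scores": [0.0, 0.0], "size": 0}
sampler, idx_a = insert(sampler, "a", 0.5, capacity)  # fills slot 0
sampler, idx_b = insert(sampler, "b", 0.2, capacity)  # fills slot 1
sampler, idx_c = insert(sampler, "c", 0.1, capacity)  # rejected, score too low
sampler, idx_d = insert(sampler, "d", 0.9, capacity)  # replaces "b"
print(idx_a, idx_b, idx_c, idx_d)  # prints: 0 1 -1 1
```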
insert_batch(sampler, levels, scores, level_extras=None)
Inserts a batch of levels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
levels | Level | The levels to insert. This must be a | required |
scores | Array | The scores of each level | required |
level_extras | dict | The optional level_extras. Defaults to None. | None |
level_weights(sampler, prioritization=None, prioritization_params=None)
Returns the weights for each level, taking into account both staleness and score.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler | required |
prioritization | Prioritization | Possibly overrides self.prioritization. Defaults to None. | None |
prioritization_params | dict | Possibly overrides self.prioritization_params. Defaults to None. | None |
Returns:
Type | Description |
---|---|
Array | chex.Array: Weights, shape (self.capacity) |
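In the PLR formulation, the replay distribution is a convex combination of a score-based and a staleness-based distribution. The sketch below assumes the combination here takes that form, with `staleness_coeff` as the mixing weight, and that both inputs are already normalized; the standalone function is hypothetical.

```python
def level_weights(score_w, staleness_w, staleness_coeff=0.5):
    """Mix normalized score and staleness weights (assumed PLR-style combination)."""
    return [
        (1.0 - staleness_coeff) * s + staleness_coeff * t
        for s, t in zip(score_w, staleness_w)
    ]


score_w = [0.75, 0.25, 0.0]      # favours high-scoring levels
staleness_w = [0.0, 0.5, 0.5]    # favours levels not sampled recently
weights = level_weights(score_w, staleness_w, staleness_coeff=0.5)
print(weights)  # prints: [0.375, 0.375, 0.25]
```

With `staleness_coeff=0.0` the result reduces to the score weights alone, and with `1.0` to the staleness weights alone.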
sample_replay_decision(sampler, rng)
Returns a single boolean indicating if a replay or new step should be taken. This is based on the proportion of the buffer that is filled and the replay_prob parameter.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
rng | PRNGKey | The random key used for the decision | required |
Returns:
Name | Type | Description |
---|---|---|
bool | bool | |
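The decision rule can be sketched in plain Python, assuming it combines the two documented ingredients as a fill-ratio gate followed by a biased coin flip. The standalone signature and the use of Python's `random.Random` in place of a JAX PRNG key are illustrative assumptions.

```python
import random


def sample_replay_decision(size, capacity, rng, replay_prob=0.95, minimum_fill_ratio=1.0):
    """Return True for a replay step, False for a new step (sketch only)."""
    # Never replay until the buffer is sufficiently full.
    filled_enough = (size / capacity) >= minimum_fill_ratio
    # Otherwise replay with probability replay_prob.
    return filled_enough and rng.random() < replay_prob


rng = random.Random(0)
too_empty = sample_replay_decision(size=10, capacity=100, rng=rng)
always = sample_replay_decision(size=100, capacity=100, rng=rng, replay_prob=1.0)
never = sample_replay_decision(size=100, capacity=100, rng=rng, replay_prob=0.0)
print(too_empty, always, never)  # prints: False True False
```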
sample_replay_level(sampler, rng)
Samples a replay level from the buffer. It does this by first computing the weights of each level (using level_weights), and then sampling from the buffer using these weights. The sampler object is updated to reflect the new episode count and the level that was sampled. The level itself is returned as well as the index of the level in the buffer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
rng | PRNGKey | The random key used for sampling | required |
Returns:
Type | Description |
---|---|
tuple[Sampler, tuple[int, Level]] | tuple[Sampler, tuple[int, Level]]: The updated sampler object, the sampled level's index and the level itself. |
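The sample-then-update sequence can be sketched as below. The weights are passed in directly rather than computed, Python's `random.Random` stands in for a JAX PRNG key, and the bookkeeping (incrementing the episode count and stamping the sampled slot with it) is an assumption about what "updated to reflect the new episode count" means.

```python
import random


def sample_replay_level(sampler, weights, rng):
    """Return (new_sampler, (index, level)) using weighted sampling (sketch only)."""
    # Draw one buffer index in proportion to its weight.
    idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    # Assumed bookkeeping: advance the episode counter and stamp the sampled slot.
    episode_count = sampler["episode_count"] + 1
    timestamps = list(sampler["timestamps"])
    timestamps[idx] = episode_count
    new_sampler = {**sampler, "timestamps": timestamps, "episode_count": episode_count}
    return new_sampler, (idx, sampler["levels"][idx])


sampler = {"levels": ["a", "b", "c"], "timestamps": [0, 0, 0], "episode_count": 0}
rng = random.Random(0)
# All probability mass on index 1, so the draw is deterministic here.
new_sampler, (idx, level) = sample_replay_level(sampler, [0.0, 1.0, 0.0], rng)
print(idx, level, new_sampler["episode_count"])  # prints: 1 b 1
```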
sample_replay_levels(sampler, rng, num)
Samples several levels by iteratively calling sample_replay_level. The sampler object is updated to reflect the new episode count and the levels that were sampled.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
rng | PRNGKey | The random key used for sampling | required |
num | int | How many levels to sample | required |
Returns:
Type | Description |
---|---|
tuple[Sampler, tuple[Array, Level]] | tuple[Sampler, tuple[chex.Array, Level]]: The updated sampler, an array of indices, and multiple levels. |
score_weights(sampler, prioritization=None, prioritization_params=None)
Returns an array of shape (self.capacity) with the weights of each level (for sampling purposes).
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
prioritization | Prioritization | Possibly overrides self.prioritization. Defaults to None. | None |
prioritization_params | dict | Possibly overrides self.prioritization_params. Defaults to None. | None |
Returns:
Type | Description |
---|---|
Array | chex.Array: Score weights, shape (self.capacity) |
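The two prioritization modes can be sketched as follows, assuming they follow the PLR definitions: rank prioritization weights each level by `(1 / rank) ** (1 / temperature)`, and top-k gives equal weight to the `k` highest-scoring levels. Giving empty slots zero weight and normalizing the result are further assumptions, and the standalone signature is hypothetical.

```python
def score_weights(scores, size, prioritization="rank", temperature=1.0, k=1):
    """Score-based sampling weights over a buffer with `size` filled slots (sketch)."""
    capacity = len(scores)
    # Rank the filled slots by descending score; rank 1 is the highest score.
    order = sorted(range(size), key=lambda i: scores[i], reverse=True)
    ranks = {idx: rank for rank, idx in enumerate(order, start=1)}
    if prioritization == "rank":
        raw = [(1.0 / ranks[i]) ** (1.0 / temperature) if i in ranks else 0.0
               for i in range(capacity)]
    elif prioritization == "topk":
        raw = [1.0 if i in ranks and ranks[i] <= k else 0.0 for i in range(capacity)]
    else:
        raise ValueError(f"unknown prioritization: {prioritization}")
    # Normalize so the weights sum to one (empty slots stay at zero).
    total = sum(raw)
    return [w / total for w in raw] if total > 0 else raw


scores = [0.1, 0.9, 0.5, 0.0]  # last slot is empty (size=3)
rank_w = score_weights(scores, size=3, prioritization="rank", temperature=1.0)
topk_w = score_weights(scores, size=3, prioritization="topk", k=1)
print(topk_w)  # prints: [0.0, 1.0, 0.0, 0.0]
```

A lower temperature sharpens the rank distribution toward the best-scoring level, while a higher one flattens it toward uniform over the filled slots.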
staleness_weights(sampler)
Returns staleness weights for each level.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
Returns:
Type | Description |
---|---|
Array | chex.Array: shape (self.capacity) |
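A staleness weight can be sketched from the `timestamps` and `episode_count` fields of the sampler. The sketch assumes staleness is the number of episodes since a level was last sampled, normalized over the filled slots, as in the PLR formulation; the standalone signature is hypothetical.

```python
def staleness_weights(timestamps, size, episode_count):
    """Weights proportional to episodes elapsed since each level was last used."""
    raw = [
        float(episode_count - t) if i < size else 0.0  # empty slots get no weight
        for i, t in enumerate(timestamps)
    ]
    total = sum(raw)
    # If nothing is stale yet, return the all-zero vector unchanged.
    return [w / total for w in raw] if total > 0 else raw


timestamps = [2, 8, 10, 0]  # last slot is empty (size=3)
weights = staleness_weights(timestamps, size=3, episode_count=10)
print(weights)  # prints: [0.8, 0.2, 0.0, 0.0]
```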
update(sampler, idx, score, level_extra=None)
This updates the score and level_extras of a level.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
idx | int | The index of the level | required |
score | float | The score of the level | required |
level_extra | dict | The associated. Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Sampler | Sampler | Updated Sampler |
update_batch(sampler, level_inds, scores, level_extras=None)
Updates the scores and level_extras of a batch of levels.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sampler | Sampler | The sampler object | required |
level_inds | Array | Level indices | required |
scores | Array | Scores | required |
level_extras | dict | . Defaults to None. | None |
Returns:
Name | Type | Description |
---|---|---|
Sampler | Sampler | Updated Sampler |
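A batched update can be sketched as a single-entry update applied once per index. This plain-Python illustration only touches scores; level extras, and any other bookkeeping the library may perform on update, are left out, and both standalone functions are hypothetical.

```python
def update(sampler, idx, score):
    """Return a new sampler with the score at `idx` replaced."""
    scores = list(sampler["scores"])
    scores[idx] = score
    return {**sampler, "scores": scores}


def update_batch(sampler, level_inds, scores):
    """Apply `update` for each (index, score) pair in order."""
    for idx, score in zip(level_inds, scores):
        sampler = update(sampler, idx, score)
    return sampler


sampler = {"levels": ["a", "b", "c"], "scores": [0.1, 0.2, 0.3], "size": 3}
updated = update_batch(sampler, level_inds=[2, 0], scores=[0.9, 0.5])
print(updated["scores"])  # prints: [0.5, 0.2, 0.9]
```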