Level Sampler

The LevelSampler provides all of the functionality associated with a level buffer in a PLR/ACCEL-type method. Following the usual JAX style, the LevelSampler object does not store any data itself: most methods take a sampler object, which holds the buffer state, and methods that modify the buffer return an updated sampler.

Examples:

pholder_level       = ...
pholder_level_extra = ...
level_sampler       = LevelSampler(4000)
sampler             = level_sampler.initialize(pholder_level, pholder_level_extra)
should_replay       = level_sampler.sample_replay_decision(sampler, rng)
sampler, (level_inds, replay_levels) = level_sampler.sample_replay_levels(sampler, rng, 32) # 32 replay levels
scores              = ... # eval agent on replay_levels
level_extras        = ... # extras for each replayed level, matching pholder_level_extra
sampler             = level_sampler.update_batch(sampler, level_inds, scores, level_extras)

Parameters:

  • capacity (int, required): The maximum number of levels that can be stored in the buffer.
  • replay_prob (float, default 0.95): The probability of taking a replay step (on_replay) rather than a new-level step (on_new).
  • staleness_coeff (float, default 0.5): The weighting factor for staleness.
  • minimum_fill_ratio (float, default 1.0): sample_replay_decision never chooses replay until the level buffer is at least this full, as a fraction of capacity.
  • prioritization_params (dict, default None): If prioritization="rank", this has a "temperature" field; if prioritization="topk", it has a "k" field. If not provided, it is initialized to a temperature of 1.0 and k=1.
  • duplicate_check (bool, default False): If True, duplicate levels cannot be added to the buffer. This adds some computation to check for duplicates.

find(sampler, level)

Returns the index of level in the level buffer. If level is not present, -1 is returned.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • level (Level, required): The level to find.

Returns:

  • int: The index of level, or -1 if it is not found.
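A minimal plain-Python sketch of the contract find describes, assuming that only the first sampler["size"] slots hold valid levels and that levels can be compared with ==. The real method works on JAX pytrees; the dict layout and level values here are illustrative stand-ins.

```python
def find(sampler: dict, level) -> int:
    """Return the buffer index of `level`, or -1 if it is not stored.

    Assumes `sampler` is a dict with "levels" and "size" keys and that levels
    support == (a simplification of comparing JAX pytrees leaf by leaf).
    """
    for idx in range(sampler["size"]):  # only occupied slots are searched
        if sampler["levels"][idx] == level:
            return idx
    return -1


sampler = {"levels": ["maze_a", "maze_b", None, None], "size": 2}
print(find(sampler, "maze_b"))  # 1
print(find(sampler, "maze_c"))  # -1
```

Restricting the search to the first size slots matters: unused slots still hold placeholder data, which must never be reported as a match.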

flush(sampler)

Flushes this sampler, returning it to its empty state. Outside of jit this updates the sampler in place; when jitted it does not, so always use the returned sampler.

Parameters:

  • sampler (Sampler, required): The sampler object.

Returns:

  • Sampler: The flushed sampler.

freshness_weights(sampler)

Returns freshness weights for each level.

Parameters:

  • sampler (Sampler, required): The sampler object.

Returns:

  • chex.Array: Freshness weights, shape (self.capacity).

get_levels(sampler, level_idx)

Returns the level at a particular index.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • level_idx (int, required): The index to return.

Returns:

  • Level: The level stored at level_idx.

get_levels_extra(sampler, level_idx)

Returns the level extras associated with a particular index.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • level_idx (int, required): The index to return.

Returns:

  • dict: The level extras stored at level_idx.

initialize(pholder_level, pholder_level_extra=None)

Creates and returns a new sampler object, represented as a dictionary.

Sampler object keys:
  • "levels" (shape (self.capacity, ...)): the levels themselves
  • "scores" (shape (self.capacity)): the scores of the levels
  • "timestamps" (shape (self.capacity)): the timestamps of the levels
  • "size" (int): the number of levels currently in the buffer
  • "episode_count" (int): the number of episodes that have been played so far

Parameters:

  • pholder_level (Level, required): A placeholder level used to initialize the level buffer.
  • pholder_level_extra (dict, default None): If given, a dictionary with arbitrary keys whose values are tracked alongside each level, for example "max_return" for each level.

Returns:

  • Sampler: The initialized sampler object.
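To illustrate the stateless design, here is a sketch in which the sampler class holds only configuration and initialize builds a fresh state dict with the keys listed above. Python lists stand in for JAX arrays, and LevelSamplerSketch, the initial score/timestamp values, and the "levels_extra" key name are all hypothetical choices for this example, not the library's actual layout.

```python
import copy
from dataclasses import dataclass


@dataclass(frozen=True)
class LevelSamplerSketch:
    """Hypothetical stand-in: stores configuration only, never buffer data."""

    capacity: int

    def initialize(self, pholder_level, pholder_level_extra=None) -> dict:
        # Every slot starts as its own copy of the placeholder, mirroring the
        # (self.capacity, ...) layout of "levels".
        sampler = {
            "levels": [copy.deepcopy(pholder_level) for _ in range(self.capacity)],
            "scores": [0.0] * self.capacity,     # assumed initial value
            "timestamps": [0] * self.capacity,   # assumed initial value
            "size": 0,
            "episode_count": 0,
        }
        if pholder_level_extra is not None:
            # Hypothetical key name: one list per extra field, one entry per slot.
            sampler["levels_extra"] = {
                key: [copy.deepcopy(value) for _ in range(self.capacity)]
                for key, value in pholder_level_extra.items()
            }
        return sampler


level_sampler = LevelSamplerSketch(capacity=3)
sampler = level_sampler.initialize({"walls": [0, 0]}, {"max_return": 0.0})
```

Because the object holds no buffer data, the same LevelSamplerSketch instance can manage any number of independent sampler dicts, which is what makes this style compatible with jit and vmap in JAX.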

insert(sampler, level, score, level_extra=None)

Attempts to insert level into the level buffer.

Insertion occurs when either:
  • the buffer is not yet at capacity; or
  • the level's score exceeds the score of the lowest-weighted level currently in the buffer, in which case the new level replaces it.

If duplicate_check is enabled and the level already exists in the buffer, the existing entry is updated instead.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • level (Level, required): The level to insert.
  • score (float, required): The level's score.
  • level_extra (dict, default None): If pholder_level_extra was given to initialize, this must be given too.

Returns:

  • tuple[Sampler, int]: The updated sampler, and the level's index in the buffer (-1 if it was not inserted).
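The insertion rule above can be sketched in plain Python as follows. For a standalone example, the weights (which the real method would get from level_weights) and the capacity are passed in explicitly, duplicate_check is not modeled, and stamping the new level with the current episode count is an assumption.

```python
import copy


def insert(sampler: dict, level, score: float, weights: list, capacity: int):
    """Return (new_sampler, idx); idx is -1 when the level is rejected.

    `weights` stands in for level_weights(sampler); `capacity` for self.capacity.
    """
    if sampler["size"] < capacity:
        idx, new_size = sampler["size"], sampler["size"] + 1  # free slot: always accepted
    else:
        idx = min(range(capacity), key=lambda i: weights[i])  # lowest-weighted level
        if score <= sampler["scores"][idx]:
            return sampler, -1  # does not beat the level it would replace
        new_size = capacity
    new = copy.deepcopy(sampler)  # leave the caller's sampler untouched
    new["size"] = new_size
    new["levels"][idx] = level
    new["scores"][idx] = score
    new["timestamps"][idx] = new["episode_count"]  # assumption: stamp with current episode
    return new, idx


full = {"levels": ["a", "b"], "scores": [0.5, 0.2], "timestamps": [0, 0],
        "size": 2, "episode_count": 7}
rejected, i = insert(full, "c", 0.1, weights=[0.7, 0.3], capacity=2)  # 0.1 <= 0.2
replaced, j = insert(full, "d", 0.9, weights=[0.7, 0.3], capacity=2)  # replaces "b"
```

Note that the candidate for replacement is chosen by weight, but the comparison is made on score, exactly as the description states.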

insert_batch(sampler, levels, scores, level_extras=None)

Inserts a batch of levels.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • levels (Level, required): The levels to insert. This must be a batched level, i.e. one with an extra leading (batch) dimension.
  • scores (Array, required): The score of each level.
  • level_extras (dict, default None): The optional level extras.
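One way to picture a batched insert is as a sequential loop that threads the sampler through one insert per level, which is what a jax.lax.scan over the batch would do. This sketch assumes that structure; toy_insert is a hypothetical stand-in for insert that only fills free slots, and returning the per-level indices is a choice made here, since the return value of insert_batch is not documented above.

```python
def insert_batch(sampler, levels, scores, insert_fn):
    """Insert levels one at a time, threading the updated sampler through.

    `insert_fn(sampler, level, score)` stands in for insert and returns
    (sampler, idx). Returns the final sampler and the per-level indices.
    """
    indices = []
    for level, score in zip(levels, scores):  # iterate over the batch dimension
        sampler, idx = insert_fn(sampler, level, score)
        indices.append(idx)
    return sampler, indices


def toy_insert(sampler, level, score, capacity=3):
    """Hypothetical insert: fill free slots, reject once the buffer is full."""
    if sampler["size"] >= capacity:
        return sampler, -1
    idx = sampler["size"]
    new = {**sampler,
           "levels": sampler["levels"] + [level],
           "scores": sampler["scores"] + [score],
           "size": idx + 1}
    return new, idx


empty = {"levels": [], "scores": [], "size": 0}
sampler, idxs = insert_batch(empty, ["a", "b", "c", "d"], [1.0, 2.0, 3.0, 4.0], toy_insert)
print(idxs)  # [0, 1, 2, -1]
```

Because the inserts are sequential, a later level in the batch sees the buffer as modified by earlier ones and can, under the real replacement rule, displace a level inserted earlier in the same batch.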

level_weights(sampler, prioritization=None, prioritization_params=None)

Returns the weights for each level, taking into account both staleness and score.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • prioritization (Prioritization, default None): If given, overrides self.prioritization.
  • prioritization_params (dict, default None): If given, overrides self.prioritization_params.

Returns:

  • chex.Array: Weights, shape (self.capacity).
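A sketch of how score and staleness weights might be combined, assuming the Prioritized Level Replay mixture in which staleness_coeff (written ρ) interpolates between the two normalized distributions:

P_replay = (1 - ρ) · P_score + ρ · P_staleness

The function below is a plain-Python illustration over the occupied slots only, not the library's code.

```python
def level_weights(score_w, staleness_w, staleness_coeff=0.5):
    """Mix normalized score and staleness weights (assumed PLR-style formula)."""
    return [
        (1.0 - staleness_coeff) * s + staleness_coeff * t
        for s, t in zip(score_w, staleness_w)
    ]


w = level_weights([0.6, 0.3, 0.1], [0.0, 0.5, 0.5], staleness_coeff=0.5)
print(w)  # approximately [0.3, 0.4, 0.3]
```

Since both inputs sum to 1 and the mixture is convex, the result also sums to 1, so it can be used directly as a sampling distribution.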

sample_replay_decision(sampler, rng)

Returns a single boolean indicating whether a replay step (True) or a new-level step (False) should be taken. The decision depends on the proportion of the buffer that is filled (see minimum_fill_ratio) and on the replay_prob parameter.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • rng (PRNGKey, required): The JAX random key.

Returns:

  • bool: True if a replay step should be taken.
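A plain-Python sketch of the decision, assuming the logic is: never replay while size / capacity is below minimum_fill_ratio, otherwise replay with probability replay_prob. A random.Random instance stands in for the JAX PRNGKey, and passing size and capacity directly (rather than a sampler) is a simplification.

```python
import random


def sample_replay_decision(size, capacity, rng, replay_prob=0.95, minimum_fill_ratio=1.0):
    """Return True for a replay step, False for a new-level step (assumed logic)."""
    proportion_filled = size / capacity
    if proportion_filled < minimum_fill_ratio:
        return False  # buffer not full enough: always explore new levels
    return rng.random() < replay_prob


rng = random.Random(0)
early = [sample_replay_decision(3999, 4000, rng) for _ in range(100)]
later = [sample_replay_decision(4000, 4000, rng) for _ in range(10_000)]
```

With the default minimum_fill_ratio of 1.0, every step is a new-level step until the buffer is completely full; afterwards roughly 95% of steps are replay steps.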

sample_replay_level(sampler, rng)

Samples a replay level from the buffer. It does this by first computing the weights of each level (using level_weights), and then sampling from the buffer using these weights. The sampler object is updated to reflect the new episode count and the level that was sampled. The level itself is returned as well as the index of the level in the buffer.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • rng (PRNGKey, required): The JAX random key.

Returns:

  • tuple[Sampler, tuple[int, Level]]: The updated sampler object, the sampled level's index, and the level itself.
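The steps described above (weighted sampling, then bookkeeping) can be sketched as follows. The weights are passed in where the real method calls level_weights, random.Random replaces the JAX PRNGKey, and the exact bookkeeping (episode count incremented by one, sampled level's timestamp set to the new count) is an assumption.

```python
import random


def sample_replay_level(sampler, weights, rng):
    """Return (new_sampler, (idx, level)) using weighted sampling."""
    idx = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    episode_count = sampler["episode_count"] + 1
    timestamps = list(sampler["timestamps"])  # copy so the input is not mutated
    timestamps[idx] = episode_count           # assumed: mark the level as just played
    new = {**sampler, "episode_count": episode_count, "timestamps": timestamps}
    return new, (idx, sampler["levels"][idx])


sampler = {"levels": ["a", "b", "c"], "timestamps": [0, 0, 0], "episode_count": 5}
new_sampler, (idx, level) = sample_replay_level(sampler, [0.0, 1.0, 0.0], random.Random(0))
print(idx, level)  # 1 b
```

Updating the timestamp is what feeds back into staleness: a level that was just replayed becomes less likely to be picked again soon.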

sample_replay_levels(sampler, rng, num)

Samples several levels by iteratively calling sample_replay_level. The sampler object is updated to reflect the new episode count and the levels that were sampled.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • rng (PRNGKey, required): The JAX random key.
  • num (int, required): How many levels to sample.

Returns:

  • tuple[Sampler, tuple[chex.Array, Level]]: The updated sampler, an array of level indices, and the sampled levels.
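A sketch of the iterative scheme, threading the sampler from one draw to the next. Here sample_one stands in for sample_replay_level, and round_robin is a hypothetical deterministic sampler used only so that the threading is easy to see.

```python
import random


def sample_replay_levels(sampler, rng, num, sample_one):
    """Call `sample_one(sampler, rng)` num times, threading the sampler through."""
    indices, levels = [], []
    for _ in range(num):
        sampler, (idx, level) = sample_one(sampler, rng)
        indices.append(idx)
        levels.append(level)
    return sampler, (indices, levels)


def round_robin(sampler, rng):
    """Hypothetical sample_one: picks slots in order and ignores rng."""
    idx = sampler["episode_count"] % len(sampler["levels"])
    new = {**sampler, "episode_count": sampler["episode_count"] + 1}
    return new, (idx, sampler["levels"][idx])


start = {"levels": ["a", "b", "c"], "episode_count": 0}
sampler, (inds, lvls) = sample_replay_levels(start, random.Random(0), 4, round_robin)
print(inds, lvls)  # [0, 1, 2, 0] ['a', 'b', 'c', 'a']
```

Each draw sees the episode count (and, in the real method, the timestamps) left by the previous one, so the draws are not independent samples from a fixed distribution.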

score_weights(sampler, prioritization=None, prioritization_params=None)

Returns an array of shape (self.capacity) with the weights of each level (for sampling purposes).

Parameters:

  • sampler (Sampler, required): The sampler object.
  • prioritization (Prioritization, default None): If given, overrides self.prioritization.
  • prioritization_params (dict, default None): If given, overrides self.prioritization_params.

Returns:

  • chex.Array: Score weights, shape (self.capacity).
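The "rank" and "topk" prioritizations can be sketched as below, assuming the Prioritized Level Replay rank formulation (h_i = 1 / rank_i, raised to 1 / temperature and normalized) and, for topk, uniform weight over the k highest-scoring levels. This plain-Python version works only on the occupied slots and is an illustration, not the library's implementation.

```python
def score_weights(scores, prioritization="rank", temperature=1.0, k=1):
    """Normalized sampling weights from level scores (assumed formulas)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])  # best score first
    rank = [0] * len(scores)
    for position, i in enumerate(order):
        rank[i] = position + 1  # rank 1 = highest score
    if prioritization == "rank":
        h = [(1.0 / r) ** (1.0 / temperature) for r in rank]
    elif prioritization == "topk":
        h = [1.0 if r <= k else 0.0 for r in rank]
    else:
        raise ValueError(f"unknown prioritization: {prioritization}")
    total = sum(h)
    return [x / total for x in h]


print(score_weights([0.1, 0.9, 0.5]))                        # ranks 3, 1, 2
print(score_weights([0.1, 0.9, 0.5], prioritization="topk", k=2))  # [0.0, 0.5, 0.5]
```

A higher temperature flattens the rank distribution toward uniform, while topk with k=1 (the stated default) always selects the single best-scoring level.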

staleness_weights(sampler)

Returns staleness weights for each level.

Parameters:

  • sampler (Sampler, required): The sampler object.

Returns:

  • chex.Array: Staleness weights, shape (self.capacity).
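A sketch of staleness weights, assuming the Prioritized Level Replay definition: each level's weight is proportional to the number of episodes since it was last sampled, w_i = (episode_count - timestamps[i]) / Σ_j (episode_count - timestamps[j]). The uniform fallback when every level is equally fresh is a choice made for this example to avoid dividing by zero.

```python
def staleness_weights(timestamps, episode_count):
    """Weights proportional to episodes since each level was last sampled."""
    staleness = [episode_count - t for t in timestamps]
    total = sum(staleness)
    if total == 0:
        return [1.0 / len(timestamps)] * len(timestamps)  # assumed fallback
    return [s / total for s in staleness]


print(staleness_weights([10, 6, 8], episode_count=10))  # staleness 0, 4, 2
```

The level sampled most recently (timestamp 10) gets zero staleness weight, so only the score term can select it again right away.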

update(sampler, idx, score, level_extra=None)

Updates the score and level extra of the level at index idx.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • idx (int, required): The index of the level.
  • score (float, required): The new score of the level.
  • level_extra (dict, default None): The level extra associated with the level.

Returns:

  • Sampler: The updated sampler.

update_batch(sampler, level_inds, scores, level_extras=None)

Updates the scores and level_extras of a batch of levels.

Parameters:

  • sampler (Sampler, required): The sampler object.
  • level_inds (Array, required): The indices of the levels to update.
  • scores (Array, required): The new score of each level.
  • level_extras (dict, default None): The level extras associated with each level.

Returns:

  • Sampler: The updated sampler.
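A plain-Python sketch of update and update_batch, assuming the batched extras are a dict whose values have one entry per index in level_inds, and that extras are stored under a hypothetical "levels_extra" key. Only scores and extras change; levels and timestamps are left alone.

```python
def update(sampler, idx, score, level_extra=None):
    """Return a copy of `sampler` with the score (and extras) at idx replaced."""
    scores = list(sampler["scores"])
    scores[idx] = score
    new = {**sampler, "scores": scores}
    if level_extra is not None:
        new["levels_extra"] = {  # hypothetical key name
            key: [level_extra[key] if i == idx else v for i, v in enumerate(values)]
            for key, values in sampler["levels_extra"].items()
        }
    return new


def update_batch(sampler, level_inds, scores, level_extras=None):
    """Apply update once per index; level_extras maps field -> per-level values."""
    for n, (idx, score) in enumerate(zip(level_inds, scores)):
        extra = None
        if level_extras is not None:
            extra = {key: values[n] for key, values in level_extras.items()}
        sampler = update(sampler, idx, score, extra)
    return sampler


s = {"scores": [0.0, 0.0, 0.0], "levels_extra": {"max_return": [0, 0, 0]}}
s2 = update_batch(s, [2, 0], [0.7, 0.4], {"max_return": [5, 3]})
print(s2["scores"])  # [0.4, 0.0, 0.7]
```

This is the call pattern used after evaluating replayed levels: the indices returned by sample_replay_levels are passed straight back in as level_inds.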