Commit

doc: update brax's document
BillHuang2001 committed Nov 4, 2024
1 parent df01b11 commit 6338f77
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/evox/problems/neuroevolution/reinforcement_learning/brax.py
@@ -25,26 +25,35 @@ def __init__(
reduce_fn: Callable[[jax.Array, int], jax.Array] = jnp.mean,
backend: str = "generalized",
):
- """Contruct a brax-based problem
+ """Construct a brax-based problem.
+ First, define a jit-able policy function. The policy function should have one of the following signatures.
+ If your policy is not stateful:
+ :code:`fn(weights, obs) -> action`,
+ and if your policy is stateful:
+ :code:`fn(state, weights, obs) -> action, state`.
+ Then set the `environment name <https://github.com/google/brax/tree/main/brax/envs>`_,
+ the maximum episode length, and the number of episodes to evaluate for each individual.
+ For each individual,
+ it runs the policy in the environment num_episodes times and uses reduce_fn to reduce the rewards (defaults to the mean).
Parameters
----------
policy
- A callable if stateful: ``fn(state, weight, obs) -> action, state`` otherwise ``fn(weights, obs) -> action``
+ A callable: :code:`fn(state, weights, obs) -> action, state` if stateful, otherwise :code:`fn(weights, obs) -> action`.
env_name
The environment name.
batch_size
The number of brax environments to run in parallel.
Usually this should match the population size at the algorithm side.
max_episode_length
- The maximum number of timesteps of an episode.
+ The maximum number of timesteps of each episode.
num_episodes
- Evaluating the number of episodes for each individual.
+ The number of episodes to evaluate for each individual.
stateful_policy
Whether the policy is stateful (for example, RNN).
Default to False.
- If False, the policy should be a pure function with signature (weights, obs) -> action.
- If True, the policy should be a stateful function with signature (state, weights, obs) -> (action, state).
+ If False, the policy should be a pure function with signature :code:`fn(weights, obs) -> action`.
+ If True, the policy should be a stateful function with signature :code:`fn(state, weights, obs) -> action, state`.
initial_state
The initial state of the stateful policy.
Default to None.
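
The two policy signatures described in the docstring can be illustrated with a minimal sketch. Everything below is hypothetical and chosen only for illustration: the dimension constants, the weight dictionary layout, and the tiny network shapes do not come from the commit. The last lines assume that per-episode rewards are laid out as (individuals, episodes) and that reduce_fn is called with the array and an axis, which is what the type hint `Callable[[jax.Array, int], jax.Array]` suggests but the diff does not confirm.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for this illustration.
OBS_DIM, ACT_DIM, HIDDEN_DIM = 4, 2, 8


def stateless_policy(weights, obs):
    """Signature fn(weights, obs) -> action: one linear layer with tanh."""
    return jnp.tanh(obs @ weights["w"] + weights["b"])


def stateful_policy(state, weights, obs):
    """Signature fn(state, weights, obs) -> action, state: a minimal RNN cell."""
    new_state = jnp.tanh(obs @ weights["w_in"] + state @ weights["w_rec"])
    action = jnp.tanh(new_state @ weights["w_out"])
    return action, new_state


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
obs = jnp.ones((OBS_DIM,))

# Stateless case: both functions are pure, so they can be jit-compiled.
stateless_weights = {
    "w": jax.random.normal(k1, (OBS_DIM, ACT_DIM)),
    "b": jnp.zeros((ACT_DIM,)),
}
action = jax.jit(stateless_policy)(stateless_weights, obs)

# Stateful case: the caller supplies an initial state and threads it through.
stateful_weights = {
    "w_in": jax.random.normal(k2, (OBS_DIM, HIDDEN_DIM)),
    "w_rec": jax.random.normal(k3, (HIDDEN_DIM, HIDDEN_DIM)),
    "w_out": jax.random.normal(k4, (HIDDEN_DIM, ACT_DIM)),
}
initial_state = jnp.zeros((HIDDEN_DIM,))
rnn_action, next_state = jax.jit(stateful_policy)(initial_state, stateful_weights, obs)

# Assumed layout: rewards[i, j] is the return of individual i in episode j.
# jnp.mean accepts (array, axis), matching the reduce_fn type hint.
episode_rewards = jnp.array([[1.0, 3.0], [2.0, 6.0]])
fitness = jnp.mean(episode_rewards, 1)
```

Under these assumptions, the first function would be passed as `policy` with `stateful_policy` left at False, while the second would be passed with `stateful_policy=True` and `initial_state` set to the zero vector above.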