Skip to content

Commit

Permalink
doc: update the multidevice algorithm guide
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Nov 19, 2024
1 parent d064a55 commit c6c7de3
Showing 1 changed file with 21 additions and 96 deletions.
117 changes: 21 additions & 96 deletions docs/source/guide/experimental/multidevice_algorithm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"from evox import Algorithm, dataclass, pytree_field, problems, workflows, monitors, use_state\n",
"from evox import dataclass, pytree_field, problems, workflows, monitors, algorithms, use_state\n",
"from evox.core.distributed import ShardingType\n",
"from evox.utils import *"
]
Expand All @@ -41,13 +39,20 @@
"When running in a distributed setup, we need to make decisions on how to place the data on these GPUs."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here, we use the vanilla PSO algorithm as an example. In PSO, each GPU can independently update the local information for its particles. On the other hand, updating the global information requires communication between GPUs, but this process can be handled rather efficiently using an all-reduce operation."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# The only changes:\n",
"# The only change:\n",
"# Add the sharding metadata\n",
"@dataclass\n",
"class SpecialPSOState:\n",
Expand All @@ -61,40 +66,19 @@
" key: jax.random.PRNGKey\n",
"\n",
"\n",
"# inherit from the base PSO algorithm\n",
"# and replace the State type with SpecialPSOState, which contains the sharding metadata\n",
"@dataclass\n",
"class PSO(Algorithm):\n",
" dim: jax.Array = pytree_field(static=True, init=False)\n",
" lb: jax.Array\n",
" ub: jax.Array\n",
" pop_size: jax.Array = pytree_field(static=True)\n",
" w: jax.Array = pytree_field(default=0.6)\n",
" phi_p: jax.Array = pytree_field(default=2.5)\n",
" phi_g: jax.Array = pytree_field(default=0.8)\n",
" mean: Optional[jax.Array] = pytree_field(default=None)\n",
" stdev: Optional[jax.Array] = pytree_field(default=None)\n",
" bound_method: str = pytree_field(static=True, default=\"clip\")\n",
"\n",
" def __post_init__(self):\n",
" self.set_frozen_attr(\"dim\", self.lb.shape[0])\n",
"\n",
"class PSO(algorithms.PSO):\n",
" def setup(self, key):\n",
" state_key, init_pop_key, init_v_key = jax.random.split(key, 3)\n",
" if self.mean is not None and self.stdev is not None:\n",
" population = self.stdev * jax.random.normal(\n",
" init_pop_key, shape=(self.pop_size, self.dim)\n",
" )\n",
" population = jnp.clip(population, self.lb, self.ub)\n",
" velocity = self.stdev * jax.random.normal(\n",
" init_v_key, shape=(self.pop_size, self.dim)\n",
" )\n",
" else:\n",
" length = self.ub - self.lb\n",
" population = jax.random.uniform(\n",
" init_pop_key, shape=(self.pop_size, self.dim)\n",
" )\n",
" population = population * length + self.lb\n",
" velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))\n",
" velocity = velocity * length * 2 - length\n",
" length = self.ub - self.lb\n",
" population = jax.random.uniform(\n",
" init_pop_key, shape=(self.pop_size, self.dim)\n",
" )\n",
" population = population * length + self.lb\n",
" velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))\n",
" velocity = velocity * length * 2 - length\n",
"\n",
" return SpecialPSOState(\n",
" population=population,\n",
Expand All @@ -105,66 +89,7 @@
" global_best_location=population[0],\n",
" global_best_fitness=jnp.array([jnp.inf]),\n",
" key=state_key,\n",
" )\n",
"\n",
" def ask(self, state):\n",
" return state.population, state\n",
"\n",
" def tell(self, state, fitness):\n",
" key, rg_key, rp_key = jax.random.split(state.key, 3)\n",
"\n",
" rg = jax.random.uniform(rg_key, shape=(self.pop_size, self.dim))\n",
" rp = jax.random.uniform(rp_key, shape=(self.pop_size, self.dim))\n",
"\n",
" compare = state.local_best_fitness > fitness\n",
" local_best_location = jnp.where(\n",
" compare[:, jnp.newaxis], state.population, state.local_best_location\n",
" )\n",
" local_best_fitness = jnp.minimum(state.local_best_fitness, fitness)\n",
"\n",
" global_best_location, global_best_fitness = min_by(\n",
" [state.global_best_location[jnp.newaxis, :], state.population],\n",
" [state.global_best_fitness, fitness],\n",
" )\n",
"\n",
" global_best_fitness = jnp.atleast_1d(global_best_fitness)\n",
"\n",
" velocity = (\n",
" self.w * state.velocity\n",
" + self.phi_p * rp * (local_best_location - state.population)\n",
" + self.phi_g * rg * (global_best_location - state.population)\n",
" )\n",
" population = state.population + velocity\n",
"\n",
" if self.bound_method == \"clip\":\n",
" population = jnp.clip(population, self.lb, self.ub)\n",
" velocity = jnp.clip(velocity, self.lb, self.ub)\n",
" elif self.bound_method == \"reflect\":\n",
" lower_bound_violation = population < self.lb\n",
" upper_bound_violation = population > self.ub\n",
"\n",
" population = jnp.where(\n",
" lower_bound_violation, 2 * self.lb - population, population\n",
" )\n",
" population = jnp.where(\n",
" upper_bound_violation, 2 * self.ub - population, population\n",
" )\n",
" velocity = jnp.where(\n",
" lower_bound_violation | upper_bound_violation, -velocity, velocity\n",
" )\n",
" # enforce the bounds in case the reflected particles are still out of bounds\n",
" population = jnp.clip(population, self.lb, self.ub)\n",
" velocity = jnp.clip(velocity, self.lb, self.ub)\n",
"\n",
" return state.replace(\n",
" population=population,\n",
" velocity=velocity,\n",
" local_best_location=local_best_location,\n",
" local_best_fitness=local_best_fitness,\n",
" global_best_location=global_best_location,\n",
" global_best_fitness=global_best_fitness,\n",
" key=key,\n",
" )\n"
" )"
]
},
{
Expand Down

0 comments on commit c6c7de3

Please sign in to comment.