Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681093047
  • Loading branch information
gnecula authored and learned_optimization authors committed Oct 1, 2024
1 parent 4bcaeb0 commit 36fd2e1
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 12 deletions.
7 changes: 3 additions & 4 deletions docs/notebooks/summary_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -997,9 +997,9 @@
"id": "jNt9CNJf2HJN"
},
"source": [
"### jax.experimental.host_callback\n",
"### jax external callbacks\n",
"\n",
"Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.\n",
"Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.\n",
"\n",
"One can use this to print which is a quick way to get data out of a network."
]
Expand All @@ -1025,13 +1025,12 @@
}
],
"source": [
"from jax.experimental import host_callback as hcb\n",
"\n",
"\n",
"def loss(parameters):\n",
" loss = jnp.mean(parameters**2)\n",
" to_look_at = jnp.mean(123.)\n",
" hcb.id_print(to_look_at, name=\"to_look_at\")\n",
" jax.debug.print(\"to_look_at={}\", to_look_at)\n",
" return loss\n",
"\n",
"\n",
Expand Down
7 changes: 3 additions & 4 deletions docs/notebooks/summary_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -461,9 +461,9 @@ print(to_look_at)

+++ {"id": "jNt9CNJf2HJN"}

### jax.experimental.host_callback
### jax external callbacks

Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.
Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.

One can use this to print which is a quick way to get data out of a network.

Expand All @@ -474,13 +474,12 @@ colab:
id: 1Ih2LxP22MZD
outputId: 0dd0b8ec-2c9e-414d-eadf-843122b7b8ab
---
from jax.experimental import host_callback as hcb
def loss(parameters):
loss = jnp.mean(parameters**2)
to_look_at = jnp.mean(123.)
hcb.id_print(to_look_at, name="to_look_at")
jax.debug.print("to_look_at={}", to_look_at)
return loss
Expand Down
9 changes: 5 additions & 4 deletions docs/notebooks/summary_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,20 +348,21 @@ def loss(parameters):
print(to_look_at)

# + [markdown] id="jNt9CNJf2HJN"
# ### jax.experimental.host_callback
# ### jax external callbacks
#
# Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.
# Jax has some support to send data back from an accelerator back to the host
# while a jax program is running. This is exposed in
# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.
#
# One can use this to print which is a quick way to get data out of a network.

# + colab={"base_uri": "https://localhost:8080/"} id="1Ih2LxP22MZD" outputId="0dd0b8ec-2c9e-414d-eadf-843122b7b8ab"
from jax.experimental import host_callback as hcb


def loss(parameters):
loss = jnp.mean(parameters**2)
to_look_at = jnp.mean(123.)
hcb.id_print(to_look_at, name="to_look_at")
jax.debug.print("to_look_at={}", to_look_at)
return loss


Expand Down

0 comments on commit 36fd2e1

Please sign in to comment.