diff --git a/docs/notebooks/summary_tutorial.ipynb b/docs/notebooks/summary_tutorial.ipynb index 41b3064..e97e7a0 100644 --- a/docs/notebooks/summary_tutorial.ipynb +++ b/docs/notebooks/summary_tutorial.ipynb @@ -353,7 +353,7 @@ "source": [ "Then we can transform the `loss` function with the function transformation: `summary.with_summary_output_reduced`.\n", "This transformation goes through the computation and extracts all the tagged values and returns them to us by name in a dictionary.\n", - "In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/tensorflow/probability/blob/main/spinoffs/oryx/oryx/core/interpreters/harvest.py)).\n", + "In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/jax-ml/oryx/tree/main/oryx/core/interpreters/harvest.py)).\n", "When we wrap a function this, we return a tuple containing the original result, and a dictionary with the desired metrics." ] }, @@ -893,7 +893,7 @@ "source": [ "def monitor(a):\n", " summary.summary(\"with_input\", a)\n", - " summary.summary(\"constant\", 2.0)\n", + " summary.summary(\"constant\", jnp.asarray(2.0))\n", " summary.summary(\"constant_with_inp\", 2.0 + (a * 0))\n", " return a\n", "\n", diff --git a/docs/notebooks/summary_tutorial.md b/docs/notebooks/summary_tutorial.md index 3c4c8a9..24d08d8 100644 --- a/docs/notebooks/summary_tutorial.md +++ b/docs/notebooks/summary_tutorial.md @@ -112,7 +112,7 @@ def loss(parameters): Then we can transform the `loss` function with the function transformation: `summary.with_summary_output_reduced`. This transformation goes through the computation and extracts all the tagged values and returns them to us by name in a dictionary. -In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/tensorflow/probability/blob/main/spinoffs/oryx/oryx/core/interpreters/harvest.py)). +In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/jax-ml/oryx/tree/main/oryx/core/interpreters/harvest.py)). When we wrap a function this, we return a tuple containing the original result, and a dictionary with the desired metrics. ```{code-cell} @@ -399,7 +399,7 @@ outputId: b1ccd1db-5615-45b9-b52c-0ae78ea9369f --- def monitor(a): summary.summary("with_input", a) - summary.summary("constant", 2.0) + summary.summary("constant", jnp.asarray(2.0)) summary.summary("constant_with_inp", 2.0 + (a * 0)) return a diff --git a/docs/notebooks/summary_tutorial.py b/docs/notebooks/summary_tutorial.py index 9a36604..ef2431c 100644 --- a/docs/notebooks/summary_tutorial.py +++ b/docs/notebooks/summary_tutorial.py @@ -25,6 +25,7 @@ # kernelspec: # display_name: Python 3 # name: python3 +# pylint: disable=line-too-long # --- # + [markdown] id="ryqPvTKI19zH" @@ -99,7 +100,7 @@ def loss(parameters): # + [markdown] id="AL9_xgfR4yPS" # Then we can transform the `loss` function with the function transformation: `summary.with_summary_output_reduced`. # This transformation goes through the computation and extracts all the tagged values and returns them to us by name in a dictionary. -# In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/tensorflow/probability/blob/main/spinoffs/oryx/oryx/core/interpreters/harvest.py)). +# In implementation, all the hard work here is done by the wonderful `oryx` library (in particular [harvest](https://github.com/jax-ml/oryx/tree/main/oryx/core/interpreters/harvest.py)). # When we wrap a function this, we return a tuple containing the original result, and a dictionary with the desired metrics. # + colab={"base_uri": "https://localhost:8080/"} id="hZQkB6Um8PI5" outputId="984e4f64-7562-48ae-ca68-ff4014037553" @@ -299,7 +300,7 @@ def loss(parameters): # + colab={"base_uri": "https://localhost:8080/"} id="iIbrjrJ4HEd-" outputId="b1ccd1db-5615-45b9-b52c-0ae78ea9369f" def monitor(a): summary.summary("with_input", a) - summary.summary("constant", 2.0) + summary.summary("constant", jnp.asarray(2.0)) summary.summary("constant_with_inp", 2.0 + (a * 0)) return a