diff --git a/examples/notebooks/training-8gaussians-to-moons.ipynb b/examples/notebooks/training-8gaussians-to-moons.ipynb index 003e02a..418d35a 100644 --- a/examples/notebooks/training-8gaussians-to-moons.ipynb +++ b/examples/notebooks/training-8gaussians-to-moons.ipynb @@ -2,7 +2,6 @@ "cells": [ { "cell_type": "markdown", - "id": "a7cf12e8", "metadata": { "tags": [] }, @@ -27,7 +26,6 @@ { "cell_type": "code", "execution_count": 1, - "id": "977e54ff", "metadata": {}, "outputs": [], "source": [ @@ -53,7 +51,6 @@ }, { "cell_type": "markdown", - "id": "f86cc21f", "metadata": { "tags": [] }, @@ -75,7 +72,6 @@ { "cell_type": "code", "execution_count": 2, - "id": "f7eb28f3", "metadata": { "tags": [] }, @@ -199,7 +195,6 @@ }, { "cell_type": "markdown", - "id": "f195ebf0", "metadata": { "tags": [] }, @@ -221,7 +216,6 @@ { "cell_type": "code", "execution_count": 3, - "id": "b14304ba", "metadata": {}, "outputs": [ { @@ -343,7 +337,6 @@ }, { "cell_type": "markdown", - "id": "99ef2f00", "metadata": { "tags": [] }, @@ -369,7 +362,6 @@ { "cell_type": "code", "execution_count": 4, - "id": "0e4e0920", "metadata": {}, "outputs": [ { @@ -456,7 +448,8 @@ "batch_size = 256\n", "model = MLP(dim=dim, time_varying=True)\n", "optimizer = torch.optim.Adam(model.parameters())\n", - "FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)\n", + "# For best performance, use ot_method=\"exact\". To follow the theory, use ot_method=\"sinkhorn\"\n", + "FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method=\"exact\")\n", "\n", "start = time.time()\n", "for k in range(20000):\n", @@ -491,7 +484,6 @@ }, { "cell_type": "markdown", - "id": "f3597e13", "metadata": { "tags": [] }, @@ -512,7 +504,6 @@ { "cell_type": "code", "execution_count": 5, - "id": "c2272ec1", "metadata": {}, "outputs": [ { @@ -633,7 +624,6 @@ }, { "cell_type": "markdown", - "id": "dc89e19a", "metadata": { "tags": [] }, @@ -655,7 +645,6 @@ { "cell_type": "code", "execution_count": 6, - "id": "d047920d", "metadata": {}, "outputs": [ { @@ -781,7 +770,6 @@ { "cell_type": "code", "execution_count": 7, - "id": "74afe15e", "metadata": {}, "outputs": [ { @@ -917,7 +905,6 @@ { "cell_type": "code", "execution_count": null, - "id": "05cc909b", "metadata": {}, "outputs": [], "source": [] @@ -925,7 +912,6 @@ { "cell_type": "code", "execution_count": null, - "id": "ecd5bc7a", "metadata": {}, "outputs": [], "source": [] @@ -933,7 +919,7 @@ ], "metadata": { "kernelspec": { - "display_name": "myenv", + "display_name": "riemanian_flow_matching", "language": "python", "name": "myenv" }, @@ -947,7 +933,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.17" } }, "nbformat": 4,