Skip to content

Commit

Permalink
update SBCFM notebook details
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Sep 12, 2023
1 parent 1588442 commit 5cef086
Showing 1 changed file with 4 additions and 18 deletions.
22 changes: 4 additions & 18 deletions examples/notebooks/training-8gaussians-to-moons.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"cells": [
{
"cell_type": "markdown",
"id": "a7cf12e8",
"metadata": {
"tags": []
},
Expand All @@ -27,7 +26,6 @@
{
"cell_type": "code",
"execution_count": 1,
"id": "977e54ff",
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -53,7 +51,6 @@
},
{
"cell_type": "markdown",
"id": "f86cc21f",
"metadata": {
"tags": []
},
Expand All @@ -75,7 +72,6 @@
{
"cell_type": "code",
"execution_count": 2,
"id": "f7eb28f3",
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -199,7 +195,6 @@
},
{
"cell_type": "markdown",
"id": "f195ebf0",
"metadata": {
"tags": []
},
Expand All @@ -221,7 +216,6 @@
{
"cell_type": "code",
"execution_count": 3,
"id": "b14304ba",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -343,7 +337,6 @@
},
{
"cell_type": "markdown",
"id": "99ef2f00",
"metadata": {
"tags": []
},
Expand All @@ -369,7 +362,6 @@
{
"cell_type": "code",
"execution_count": 4,
"id": "0e4e0920",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -456,7 +448,8 @@
"batch_size = 256\n",
"model = MLP(dim=dim, time_varying=True)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma)\n",
"# For best performance, use ot_method=\"exact\". To follow the theory, use ot_method=\"sinkhorn\"\n",
"FM = SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method=\"exact\")\n",
"\n",
"start = time.time()\n",
"for k in range(20000):\n",
Expand Down Expand Up @@ -491,7 +484,6 @@
},
{
"cell_type": "markdown",
"id": "f3597e13",
"metadata": {
"tags": []
},
Expand All @@ -512,7 +504,6 @@
{
"cell_type": "code",
"execution_count": 5,
"id": "c2272ec1",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -633,7 +624,6 @@
},
{
"cell_type": "markdown",
"id": "dc89e19a",
"metadata": {
"tags": []
},
Expand All @@ -655,7 +645,6 @@
{
"cell_type": "code",
"execution_count": 6,
"id": "d047920d",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -781,7 +770,6 @@
{
"cell_type": "code",
"execution_count": 7,
"id": "74afe15e",
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -917,23 +905,21 @@
{
"cell_type": "code",
"execution_count": null,
"id": "05cc909b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "ecd5bc7a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "myenv",
"display_name": "riemanian_flow_matching",
"language": "python",
"name": "myenv"
},
Expand All @@ -947,7 +933,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.17"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 5cef086

Please sign in to comment.