Skip to content

Commit

Permalink
update single-cell notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
kilianFatras committed Oct 17, 2023
1 parent dbeeb75 commit 98bf94a
Show file tree
Hide file tree
Showing 2 changed files with 375 additions and 94 deletions.
71 changes: 36 additions & 35 deletions examples/notebooks/model-comparison-plotting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -205,7 +205,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -222,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 4,
"metadata": {
"tags": []
},
Expand Down Expand Up @@ -327,14 +327,14 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_8787/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
"/tmp/ipykernel_10858/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
Expand All @@ -351,40 +351,25 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[17], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m models \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mmodels/gaussian-moons/cfm_v1.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m,\n\u001b[1;32m 3\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mOT-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/otcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 4\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSB-CFM (ours)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/sbcfm_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 5\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-CFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/stochastic_interpolant_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mFM\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/flow_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 7\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mVP-SDE\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/vp_flow_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 8\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mAction-Matching (Swish)\u001b[39m\u001b[38;5;124m\"\u001b[39m: torch\u001b[38;5;241m.\u001b[39mload(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodels/gaussian-moons/action_matching_swish_v1.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 10\u001b[0m }\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:791\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, **pickle_load_args)\u001b[0m\n\u001b[1;32m 788\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m pickle_load_args\u001b[38;5;241m.\u001b[39mkeys():\n\u001b[1;32m 789\u001b[0m pickle_load_args[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mencoding\u001b[39m\u001b[38;5;124m'\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[0;32m--> 791\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43m_open_file_like\u001b[49m\u001b[43m(\u001b[49m\u001b[43mf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mrb\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m opened_file:\n\u001b[1;32m 792\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_zipfile(opened_file):\n\u001b[1;32m 793\u001b[0m \u001b[38;5;66;03m# The zipfile reader is going to advance the current file position.\u001b[39;00m\n\u001b[1;32m 794\u001b[0m \u001b[38;5;66;03m# If we want to actually tail call to torch.jit.load, we need to\u001b[39;00m\n\u001b[1;32m 795\u001b[0m \u001b[38;5;66;03m# reset back to the original position.\u001b[39;00m\n\u001b[1;32m 796\u001b[0m orig_position \u001b[38;5;241m=\u001b[39m opened_file\u001b[38;5;241m.\u001b[39mtell()\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:271\u001b[0m, in \u001b[0;36m_open_file_like\u001b[0;34m(name_or_buffer, mode)\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_open_file_like\u001b[39m(name_or_buffer, mode):\n\u001b[1;32m 270\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m _is_path(name_or_buffer):\n\u001b[0;32m--> 271\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_open_file\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname_or_buffer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 273\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m mode:\n",
"File \u001b[0;32m~/anaconda3/envs/torchcfm2/lib/python3.10/site-packages/torch/serialization.py:252\u001b[0m, in \u001b[0;36m_open_file.__init__\u001b[0;34m(self, name, mode)\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, name, mode):\n\u001b[0;32m--> 252\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m)\u001b[49m)\n",
"\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'models/gaussian-moons/cfm_v1.pt'"
]
}
],
"outputs": [],
"source": [
"models = {\n",
" \"CFM\": torch.load(\"models/gaussian-moons/cfm_v1.pt\"),\n",
" \"OT-CFM (ours)\": torch.load(\"models/gaussian-moons/otcfm_v1.pt\"),\n",
" \"SB-CFM (ours)\": torch.load(\"models/gaussian-moons/sbcfm_v1.pt\"),\n",
" \"VP-CFM\": torch.load(\"models/gaussian-moons/stochastic_interpolant_v1.pt\"),\n",
" \"FM\": torch.load(\"models/gaussian-moons/flow_matching_v1.pt\"),\n",
" \"VP-SDE\": torch.load(\"models/gaussian-moons/vp_flow_v1.pt\"),\n",
" \"Action-Matching\": torch.load(\"models/gaussian-moons/action_matching_v1.pt\"),\n",
" \"Action-Matching (Swish)\": torch.load(\"models/gaussian-moons/action_matching_swish_v1.pt\"),\n",
" \"CFM\": torch.load(\"./models/8gaussian-moons/cfm_v1.pt\"),\n",
" \"OT-CFM (ours)\": torch.load(\"models/8gaussian-moons/otcfm_v1.pt\"),\n",
" \"SB-CFM (ours)\": torch.load(\"models/8gaussian-moons/sbcfm_v1.pt\"),\n",
" \"VP-CFM\": torch.load(\"models/8gaussian-moons/stochastic_interpolant_v1.pt\"),\n",
" #\"FM\": torch.load(\"models/8gaussian-moons/flow_matching_v1.pt\"),\n",
" #\"VP-SDE\": torch.load(\"models/8gaussian-moons/vp_flow_v1.pt\"),\n",
" \"Action-Matching\": torch.load(\"models/8gaussian-moons/action_matching_v1.pt\"),\n",
" \"Action-Matching (Swish)\": torch.load(\"models/8gaussian-moons/action_matching_swish_v1.pt\"),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -411,8 +396,8 @@
" traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n",
" trajs[name] = traj\n",
"names = [\n",
" \"VP-SDE\",\n",
" \"FM\",\n",
" #\"VP-SDE\",\n",
" #\"FM\",\n",
" \"CFM\",\n",
" \"Action-Matching\",\n",
" \"Action-Matching (Swish)\",\n",
Expand Down Expand Up @@ -490,9 +475,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_10858/1452409276.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n",
" image = imageio.imread(filename)\n"
]
}
],
"source": [
"gif_name = \"gaussians-to-moons\"\n",
"ts = torch.linspace(0, 1, 101)\n",
Expand All @@ -503,6 +497,13 @@
" image = imageio.imread(filename)\n",
" writer.append_data(image)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
398 changes: 339 additions & 59 deletions examples/notebooks/single-cell_example.ipynb

Large diffs are not rendered by default.

0 comments on commit 98bf94a

Please sign in to comment.