
OmegaConf #28

Merged: 38 commits on Feb 26, 2024
Commits
b330ed5
yacs.CfgNode instead of argparse.Namespace
gerkone Feb 15, 2024
1a35dd2
yaml config cleanup
gerkone Feb 16, 2024
f76c218
configured models
gerkone Feb 16, 2024
2eb730f
updated ldc config
gerkone Feb 16, 2024
f5ca0a9
updated tests
gerkone Feb 16, 2024
3f9a9de
yacs dependency
gerkone Feb 16, 2024
cd0c82f
move experiments/* to lagrangebench/runner.py
arturtoshev Feb 19, 2024
c234b53
in docs and __init__ read the version from pyproject.toml
arturtoshev Feb 19, 2024
a4e509f
fix ruff warning
arturtoshev Feb 19, 2024
3ced17b
ruff
gerkone Feb 19, 2024
ca2d810
cfg.main and only --config directly to main.py
arturtoshev Feb 20, 2024
ab0e7f1
solve wandb isort issue with explicit third-party
arturtoshev Feb 20, 2024
776397d
remove pseudo-argument --
arturtoshev Feb 20, 2024
16ae58d
omegaconf
arturtoshev Feb 22, 2024
b139ad1
class Trainer
arturtoshev Feb 22, 2024
716c2fe
defaults.yaml -> defaults.py
arturtoshev Feb 22, 2024
ddc1609
fix defaults.yaml to defaults.py
arturtoshev Feb 22, 2024
0a16248
add runner_test
arturtoshev Feb 22, 2024
5e5233f
add cfg sanity checks
arturtoshev Feb 22, 2024
fa2d604
merge zenode fix
arturtoshev Feb 23, 2024
ecbe8fc
removed loss config
gerkone Feb 23, 2024
434f27f
wandb update configs
gerkone Feb 23, 2024
1cd5fa5
add check_cfg asserts
arturtoshev Feb 23, 2024
9481e40
rename data_dir to dataset_path
arturtoshev Feb 24, 2024
eefe8e9
use LAGRANGEBENCH_DEFAULTS as a special config extends value
arturtoshev Feb 24, 2024
8d440e6
remove .main. as an additional cfg hierarchy
arturtoshev Feb 24, 2024
1427316
reintroduce base.yaml per dataset and fix .main.
arturtoshev Feb 24, 2024
56806d2
move wandb_run to Trainer.train()
arturtoshev Feb 24, 2024
7d4829a
model_dir to load_ckp (and store_ckp)
arturtoshev Feb 24, 2024
9daf64a
pass cfg_model directly to case_builder
arturtoshev Feb 24, 2024
f1747f6
fix tests
arturtoshev Feb 24, 2024
e00711d
update readme and requirements
arturtoshev Feb 24, 2024
908407f
tutorial notebook ; add merge default configs
gerkone Feb 24, 2024
cacade4
automate version bumping
arturtoshev Feb 26, 2024
01ce3dc
update docs
arturtoshev Feb 26, 2024
05162fa
fix trainer docs
arturtoshev Feb 26, 2024
bfefdfe
defaults in the docs
gerkone Feb 26, 2024
c4d698a
doc fixes
gerkone Feb 26, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@
ckp/
rollout/
rollouts/
wandb
wandb/
*.out
datasets
baselines
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
- id: check-yaml
- id: requirements-txt-fixer
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.1.8'
rev: 'v0.2.2'
hooks:
- id: ruff
args: [ --fix ]
42 changes: 25 additions & 17 deletions README.md
@@ -74,7 +74,7 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/
### MacOS
Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
- Configs: You will need to set `f64: False` and `num_workers: 0` in the `configs/` files.
- Configs: You will need to set `dtype=float32` and `train.num_workers=0`.

Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports jax in general, there seems to be a missing feature used by `jax-md` related to padding -> see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071).

@@ -83,39 +83,39 @@ Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-m
A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance.

### Running in a local clone (`main.py`)
Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).

When loading a saved model with `--model_dir` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file.
When loading a saved model with `load_ckp` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`runner.py`](lagrangebench/runner.py) file.

**Train**

For example, to start a _GNS_ run from scratch on the RPF 2D dataset use
```
python main.py --config configs/rpf_2d/gns.yaml
python main.py config=configs/rpf_2d/gns.yaml
```
Some model presets can be found in `./configs/`.

If `--mode=all`, then training (`--mode=train`) and subsequent inference (`--mode=infer`) on the test split will be run in one go.
If `mode=all` is provided, then training (`mode=train`) and subsequent inference (`mode=infer`) on the test split will be run in one go.


**Restart training**

To restart training from the last checkpoint in `--model_dir` use
To restart training from the last checkpoint in `load_ckp` use
```
python main.py --model_dir ckp/gns_rpf2d_yyyymmdd-hhmmss
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss
```

**Inference**

To evaluate a trained model from `--model_dir` on the test split (`--test`) use
To evaluate a trained model from `load_ckp` on the test split (`test=True`) use
```
python main.py --model_dir ckp/gns_rpf2d_yyyymmdd-hhmmss/best --rollout_dir rollout/gns_rpf2d_yyyymmdd-hhmmss/best --mode infer --test
python main.py load_ckp=ckp/gns_rpf2d_yyyymmdd-hhmmss/best rollout_dir=rollout/gns_rpf2d_yyyymmdd-hhmmss/best mode=infer test=True
```

If the default `--out_type_infer=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to the `--rollout_dir`. The metrics file contains all `--metrics_infer` properties for each generated rollout.
If the default `eval.infer.out_type=pkl` is active, then the generated trajectories and a `metricsYYYY_MM_DD_HH_MM_SS.pkl` file will be written to `eval.rollout_dir`. The metrics file contains all `eval.infer.metrics` properties for each generated rollout.
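The metrics file can be inspected with plain `pickle`. A minimal sketch, assuming a `{rollout_name: {metric_name: list_of_values}}` layout (the exact structure written by lagrangebench may differ):

```python
import pickle

def summarize_metrics(path):
    """Average each metric across all rollouts in a metrics .pkl file.

    Assumes (hypothetically) the layout {rollout_name: {metric_name: [floats]}}.
    """
    with open(path, "rb") as f:
        per_rollout = pickle.load(f)
    # pool the values of each metric over all rollouts, then average
    pooled = {}
    for rollout in per_rollout.values():
        for metric, values in rollout.items():
            pooled.setdefault(metric, []).extend(values)
    return {metric: sum(v) / len(v) for metric, v in pooled.items()}
```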

## Datasets
The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). When creating a new dataset instance, the data is automatically downloaded. Alternatively, to manually download them use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely
The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https://zenodo.org/doi/10.5281/zenodo.10021925). If a dataset is not found in `dataset_path`, the data is automatically downloaded. Alternatively, to manually download the datasets use the `download_data.sh` shell script, either with a specific dataset name or "all". Namely
- __Taylor Green Vortex 2D__: `bash download_data.sh tgv_2d datasets/`
- __Reverse Poiseuille Flow 2D__: `bash download_data.sh rpf_2d datasets/`
- __Lid Driven Cavity 2D__: `bash download_data.sh ldc_2d datasets/`
@@ -129,7 +129,7 @@ The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https
### Notebooks
We provide three notebooks that show LagrangeBench functionalities, namely:
- [`tutorial.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/tutorial.ipynb), with a general overview of LagrangeBench library, with training and evaluation of a simple GNS model,
- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations on the datasets, and
- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations of the datasets, and
- [`gns_data.ipynb`](notebooks/gns_data.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/gns_data.ipynb), showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405).

## Directory structure
@@ -144,7 +144,8 @@ We provide three notebooks that show LagrangeBench functionalities, namely:
┃ ┗ 📜utils.py
┣ 📂evaluate # Evaluation and rollout generation tools
┃ ┣ 📜metrics.py
┃ ┗ 📜rollout.py
┃ ┣ 📜rollout.py
┃ ┗ 📜utils.py
┣ 📂models # Baseline models
┃ ┣ 📜base.py # BaseModel class
┃ ┣ 📜egnn.py
@@ -157,6 +158,7 @@ We provide three notebooks that show LagrangeBench functionalities, namely:
┃ ┣ 📜strats.py # Training tricks
┃ ┗ 📜trainer.py # Trainer method
┣ 📜defaults.py # Default values
┣ 📜runner.py # Runner wrapping training and inference
┗ 📜utils.py
```

@@ -167,9 +169,9 @@ Welcome! We highly appreciate [Github issues](https://github.com/tumaer/lagrange
You can also chat with us on [**Discord**](https://discord.gg/Ds8jRZ78hU).

### Contributing Guideline
If you want to contribute to this repository, you will need the dev depencencies, i.e.
If you want to contribute to this repository, you will need the dev dependencies, i.e.
install the environment with `poetry install` without the ` --only main` flag.
Then, we also recommend you to install the pre-commit hooks
Then, we also recommend you install the pre-commit hooks
if you don't want to manually run `pre-commit run` before each commit. To sum up:

```bash
@@ -181,6 +183,10 @@ source $PATH_TO_LAGRANGEBENCH_VENV/bin/activate
# install pre-commit hooks defined in .pre-commit-config.yaml
# ruff is configured in pyproject.toml
pre-commit install

# if you want to bump the version in both pyproject.toml and __init__.py, do
poetry self add poetry-bumpversion
poetry version patch # or minor/major
```

After you have run `git add <FILE>` and try to `git commit`, the pre-commit hook will
@@ -195,10 +201,11 @@ pytest

### Clone vs Library
LagrangeBench can be installed by cloning the repository or as a standalone library. This offers more flexibility, but it also comes with its disadvantages: the necessity to implement some things twice. If you change any of the following things, make sure to update its counterpart as well:
- General setup in `experiments/` and `notebooks/tutorial.ipynb`
- General setup in `lagrangebench/runner.py` and `notebooks/tutorial.ipynb`
- Configs in `configs/` and `lagrangebench/defaults.py`
- Zenodo URLs in `download_data.sh` and `lagrangebench/data/data.py`
- Dependencies in `pyproject.toml`, `requirements_cuda.txt`, and `docs/requirements.txt`
- Library version in `pyproject.toml` and `lagrangebench/__init__.py`


## Citation
@@ -229,6 +236,7 @@ The associated datasets can be cited as:


### Publications
The following further publcations are based on the LagrangeBench codebase:
The following further publications are based on the LagrangeBench codebase:

1. [Learning Lagrangian Fluid Mechanics with E(3)-Equivariant Graph Neural Networks (GSI 2023)](https://arxiv.org/abs/2305.15603), A. P. Toshev, G. Galletti, J. Brandstetter, S. Adami, N. A. Adams
2. [Neural SPH: Improved Neural Modeling of Lagrangian Fluid Dynamics](https://arxiv.org/abs/2402.06275), A. P. Toshev, J. A. Erbesdobler, N. A. Adams, J. Brandstetter
6 changes: 0 additions & 6 deletions configs/WaterDrop_2d/base.yaml

This file was deleted.

23 changes: 18 additions & 5 deletions configs/WaterDrop_2d/gns.yaml
@@ -1,6 +1,19 @@
extends: WaterDrop_2d/base.yaml
extends: LAGRANGEBENCH_DEFAULTS

model: gns
num_mp_steps: 10
latent_dim: 128
lr_start: 5.e-4
main:
dataset_path: /tmp/datasets/WaterDrop

model:
name: gns
num_mp_steps: 10
latent_dim: 128

train:
optimizer:
lr_start: 5.e-4

logging:
wandb_project: waterdrop_2d

neighbors:
backend: matscipy
12 changes: 7 additions & 5 deletions configs/dam_2d/base.yaml
@@ -1,7 +1,9 @@
extends: defaults.yaml
extends: LAGRANGEBENCH_DEFAULTS

data_dir: datasets/2D_DAM_5740_20kevery100
wandb_project: dam_2d
dataset_path: datasets/2D_DAM_5740_20kevery100

neighbor_list_multiplier: 2.0
noise_std: 0.001
logging:
wandb_project: dam_2d

neighbors:
multiplier: 2.0
15 changes: 10 additions & 5 deletions configs/dam_2d/gns.yaml
@@ -1,6 +1,11 @@
extends: dam_2d/base.yaml
extends: configs/dam_2d/base.yaml

model: gns
num_mp_steps: 10
latent_dim: 128
lr_start: 5.e-4
model:
name: gns
num_mp_steps: 10
latent_dim: 128

train:
noise_std: 0.001
optimizer:
lr_start: 5.e-4
16 changes: 10 additions & 6 deletions configs/dam_2d/segnn.yaml
@@ -1,8 +1,12 @@
extends: dam_2d/base.yaml
extends: configs/dam_2d/base.yaml

model: segnn
num_mp_steps: 10
latent_dim: 64
lr_start: 5.e-4
model:
name: segnn
num_mp_steps: 10
latent_dim: 64
isotropic_norm: True

isotropic_norm: True
train:
noise_std: 0.001
optimizer:
lr_start: 5.e-4
118 changes: 0 additions & 118 deletions configs/defaults.yaml

This file was deleted.

12 changes: 7 additions & 5 deletions configs/ldc_2d/base.yaml
@@ -1,7 +1,9 @@
extends: defaults.yaml
extends: LAGRANGEBENCH_DEFAULTS

data_dir: datasets/2D_LDC_2708_10kevery100
wandb_project: ldc_2d
dataset_path: datasets/2D_LDC_2708_10kevery100

neighbor_list_multiplier: 2.0
noise_std: 0.001
logging:
wandb_project: ldc_2d

neighbors:
multiplier: 2.0
15 changes: 10 additions & 5 deletions configs/ldc_2d/gns.yaml
@@ -1,6 +1,11 @@
extends: ldc_2d/base.yaml
extends: configs/ldc_2d/base.yaml

model: gns
num_mp_steps: 10
latent_dim: 128
lr_start: 5.e-4
model:
name: gns
num_mp_steps: 10
latent_dim: 128

train:
noise_std: 0.001
optimizer:
lr_start: 5.e-4