diff --git a/.gitignore b/.gitignore
index c28dee5..5580ed4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,7 +2,7 @@
 ckp/
 rollout/
 rollouts/
-wandb
+wandb/
 *.out
 datasets
 baselines
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index fc49dbe..db02477 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -19,7 +19,7 @@ repos:
       - id: check-yaml
       - id: requirements-txt-fixer
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: 'v0.1.8'
+    rev: 'v0.2.2'
    hooks:
      - id: ruff
        args: [ --fix ]
diff --git a/README.md b/README.md
index dd41baf..c508482 100644
--- a/README.md
+++ b/README.md
@@ -72,7 +72,7 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/
 ### MacOS
 Currently, only the CPU installation works. You will need to change a few small things to get it going:
 - Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
-- Configs: You will need to set `f64: False` and `num_workers: 0` in the `configs/` files.
+- Configs: You will need to set `main.dtype: float32` and `train.num_workers: 0` in the `configs/` files.
 
 Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports jax in general, there seems to be a missing feature used by `jax-md` related to padding -> see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071).
@@ -81,7 +81,7 @@ Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-m
 
 A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance.
 
 ### Running in a local clone (`main.py`)
-Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
+Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as 1) passed cli arguments, 2) YAML config and 3) [`config.py`](lagrangebench/config.py) (`lagrangebench` defaults).
 
 When loading a saved model with `--model_dir` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file.
@@ -127,7 +127,7 @@ The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https
 
 ### Notebooks
 We provide three notebooks that show LagrangeBench functionalities, namely:
 - [`tutorial.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/tutorial.ipynb), with a general overview of LagrangeBench library, with training and evaluation of a simple GNS model,
-- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations on the datasets, and
+- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations of the datasets, and
 - [`gns_data.ipynb`](notebooks/gns_data.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/gns_data.ipynb), showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405).
 
@@ -165,9 +165,9 @@ Welcome! We highly appreciate [Github issues](https://github.com/tumaer/lagrange
 You can also chat with us on [**Discord**](https://discord.gg/Ds8jRZ78hU).
 
 ### Contributing Guideline
-If you want to contribute to this repository, you will need the dev depencencies, i.e.
+If you want to contribute to this repository, you will need the dev dependencies, i.e.
 install the environment with `poetry install` without the ` --only main` flag.
-Then, we also recommend you to install the pre-commit hooks
+Then, we also recommend you install the pre-commit hooks
 if you don't want to manually run `pre-commit run` before each commit. To sum up:
 
 ```bash
@@ -220,6 +220,6 @@ The associated datasets can be cited as:
 
 ### Publications
-The following further publcations are based on the LagrangeBench codebase:
+The following further publications are based on the LagrangeBench codebase:
 
 1. [Learning Lagrangian Fluid Mechanics with E(3)-Equivariant Graph Neural Networks (GSI 2023)](https://arxiv.org/abs/2305.15603), A. P. Toshev, G. Galletti, J. Brandstetter, S. Adami, N. A. Adams
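The priority order described in the README maps onto `load_cfg` in the new `lagrangebench/config.py` (added further down in this diff): the `lagrangebench` defaults are baked into the global `cfg` first, the YAML file is merged on top, and CLI-style key-value pairs are merged last. A minimal sketch of that resolution order, assuming a checkout with the config files from this diff:

```python
from lagrangebench.config import cfg, load_cfg

# lowest priority: defaults (populated when lagrangebench.config is imported)
# middle priority: the YAML file, via cfg.merge_from_file(...)
# highest priority: CLI-style key-value pairs, via cfg.merge_from_list(...)
load_cfg(cfg, "configs/tgv_2d/gns.yaml", ["optimizer.lr_start", "1e-3"])

print(cfg.model.name)          # "gns"  (from the YAML file)
print(cfg.optimizer.lr_start)  # 0.001  (from the CLI override)
print(cfg.train.batch_size)    # 1      (lagrangebench default)
```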
diff --git a/configs/WaterDrop_2d/base.yaml b/configs/WaterDrop_2d/base.yaml
deleted file mode 100644
index be27172..0000000
--- a/configs/WaterDrop_2d/base.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-extends: defaults.yaml
-
-data_dir: /tmp/datasets/WaterDrop
-wandb_project: waterdrop_2d
-
-neighbor_list_backend: matscipy
diff --git a/configs/WaterDrop_2d/gns.yaml b/configs/WaterDrop_2d/gns.yaml
index b89287a..cb181fc 100644
--- a/configs/WaterDrop_2d/gns.yaml
+++ b/configs/WaterDrop_2d/gns.yaml
@@ -1,6 +1,16 @@
-extends: WaterDrop_2d/base.yaml
+main:
+  data_dir: /tmp/datasets/WaterDrop
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: waterdrop_2d
+
+neighbors:
+  backend: matscipy
diff --git a/configs/dam_2d/base.yaml b/configs/dam_2d/base.yaml
deleted file mode 100644
index be1d3bd..0000000
--- a/configs/dam_2d/base.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/2D_DAM_5740_20kevery100
-wandb_project: dam_2d
-
-neighbor_list_multiplier: 2.0
-noise_std: 0.001
diff --git a/configs/dam_2d/gns.yaml b/configs/dam_2d/gns.yaml
index 1b5891e..7d10992 100644
--- a/configs/dam_2d/gns.yaml
+++ b/configs/dam_2d/gns.yaml
@@ -1,6 +1,19 @@
-extends: dam_2d/base.yaml
+main:
+  data_dir: datasets/2D_DAM_5740_20kevery100
+
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+  noise_std: 0.001
+
+logging:
+  wandb_project: dam_2d
+
+neighbors:
+  multiplier: 2.0
+
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
diff --git a/configs/dam_2d/segnn.yaml b/configs/dam_2d/segnn.yaml
index e7facf7..2bfb40b 100644
--- a/configs/dam_2d/segnn.yaml
+++ b/configs/dam_2d/segnn.yaml
@@ -1,8 +1,20 @@
-extends: dam_2d/base.yaml
+main:
+  data_dir: datasets/2D_DAM_5740_20kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 5.e-4
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 5.e-4
+  noise_std: 0.001
+
+logging:
+  wandb_project: dam_2d
+
+neighbors:
+  multiplier: 2.0
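The flat, `extends`-based config files above are replaced by self-contained files whose top-level sections (`main`, `model`, `train`, `optimizer`, `logging`, `neighbors`) mirror the yacs nodes defined in `lagrangebench/config.py`. A short sketch of loading one of the rewritten files and reading nested values (paths as in this diff):

```python
from lagrangebench.config import cfg, load_cfg

# merge configs/dam_2d/gns.yaml on top of the built-in defaults
load_cfg(cfg, "configs/dam_2d/gns.yaml")

# old flat keys now live in nested sections, e.g.
# lr_start -> optimizer.lr_start, neighbor_list_multiplier -> neighbors.multiplier
assert cfg.model.name == "gns"
assert cfg.optimizer.noise_std == 0.001
assert cfg.neighbors.multiplier == 2.0
```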
diff --git a/configs/defaults.yaml b/configs/defaults.yaml
deleted file mode 100644
index 0771f6a..0000000
--- a/configs/defaults.yaml
+++ /dev/null
@@ -1,118 +0,0 @@
-# Fallback parameters for the config file. These are overwritten by the config file.
-extends:
-
-# Model settings
-# Model architecture name. gns, segnn, egnn
-model:
-# Length of the position input sequence
-input_seq_length: 6
-# Number of message passing steps
-num_mp_steps: 10
-# Number of MLP layers
-num_mlp_layers: 2
-# Hidden dimension
-latent_dim: 128
-# Load checkpointed model from this directory
-model_dir:
-
-# SEGNN only parameters
-# Steerable attributes level
-lmax_attributes: 1
-# Level of the hidden layer
-lmax_hidden: 1
-# SEGNN normalization. instance, batch, none
-segnn_norm: none
-# SEGNN velocity aggregation. avg or last
-velocity_aggregate: avg
-
-# Optimization settings
-# Max steps
-step_max: 500000
-# Batch size
-batch_size: 1
-# Starting learning rate
-lr_start: 1.e-4
-# Final learning rate after decay
-lr_final: 1.e-6
-# Rate of learning rate decay
-lr_decay_rate: 0.1
-# Number of steps for the learning rate to decay
-lr_decay_steps: 1.e+5
-# Standard deviation for the additive noise
-noise_std: 0.0003
-# Whether to use magnitudes or not
-magnitude_features: False
-# Whether to normalize inputs and outputs with the same value in x, y, ans z.
-isotropic_norm: False
-
-# Parameters related to the push-forward trick
-pushforward:
-  # At which training step to introduce next unroll stage
-  steps: [-1, 200000, 300000, 400000]
-  # For how many steps to unroll
-  unrolls: [0, 1, 2, 3]
-  # Which probability ratio to keep between the unrolls
-  probs: [18, 2, 1, 1]
-
-# Loss settings
-# Loss weight for position, acceleration, and velocity components
-loss_weight:
-  acc: 1.0
-
-# Run settings
-# train, infer, all
-mode: all
-# Dataset directory
-data_dir:
-# Number of rollout steps. If "-1", then defaults to sequence_length - input_seq_len.
-# n_rollout_steps must be <= ground truth len. For extrapolation use n_extrap_steps
-n_rollout_steps: 20
-# Number of evaluation trajectories. "-1" for all available
-eval_n_trajs: 50
-# Number of extrapolation steps
-n_extrap_steps: 0
-# Whether to use test or validation split
-test: False
-# Seed
-seed: 0
-# Cuda device. "-1" for cpu
-gpu: 0
-# GPU memory allocation https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
-xla_mem_fraction: 0.75
-# Double precision everywhere other than the ML model
-f64: True
-# Neighbour list backend. jaxmd_vmap, jaxmd_scan, matscipy
-neighbor_list_backend: jaxmd_vmap
-# Neighbour list capacity multiplier
-neighbor_list_multiplier: 1.25
-# number of workers for data loading
-num_workers: 4
-
-# Logging settings
-# Use wandb for logging
-wandb: False
-wandb_project: False
-# Change this with your own entity
-wandb_entity: lagrangebench
-# Number of steps between training logging
-log_steps: 1000
-# Number of steps between evaluation
-eval_steps: 10000
-# Checkpoint directory
-ckp_dir: ckp
-# Rollout/metrics directory
-rollout_dir:
-# Rollout storage format. vtk, pkl, none
-out_type: none
-# List of metrics. mse, mae, sinkhorn, e_kin
-metrics:
-  - mse
-metrics_stride: 10
-
-# Inference params (valid/test)
-metrics_infer:
-  - mse
-  - sinkhorn
-  - e_kin
-metrics_stride_infer: 1
-out_type_infer: pkl
-eval_n_trajs_infer: -1
-# batch size for validation/testing
-batch_size_infer: 2
diff --git a/configs/ldc_2d/base.yaml b/configs/ldc_2d/base.yaml
deleted file mode 100644
index d9fdc96..0000000
--- a/configs/ldc_2d/base.yaml
+++ /dev/null
@@ -1,7 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/2D_LDC_2708_10kevery100
-wandb_project: ldc_2d
-
-neighbor_list_multiplier: 2.0
-noise_std: 0.001
diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml
index fda8aea..65da6cb 100644
--- a/configs/ldc_2d/gns.yaml
+++ b/configs/ldc_2d/gns.yaml
@@ -1,6 +1,17 @@
-extends: ldc_2d/base.yaml
+main:
+  data_dir: datasets/2D_LDC_2708_10kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+  noise_std: 0.001
+
+logging:
+  wandb_project: ldc_2d
+
+neighbors:
+  multiplier: 2.0
\ No newline at end of file
diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml
index 1adece6..a15cc6b 100644
--- a/configs/ldc_2d/segnn.yaml
+++ b/configs/ldc_2d/segnn.yaml
@@ -1,8 +1,20 @@
-extends: ldc_2d/base.yaml
+main:
+  data_dir: datasets/2D_LDC_2708_10kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 5.e-4
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 5.e-4
+  noise_std: 0.001
+
+logging:
+  wandb_project: ldc_2d
+
+neighbors:
+  multiplier: 2.0
\ No newline at end of file
diff --git a/configs/ldc_3d/base.yaml b/configs/ldc_3d/base.yaml
deleted file mode 100644
index 5dfb668..0000000
--- a/configs/ldc_3d/base.yaml
+++ /dev/null
@@ -1,6 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/3D_LDC_8160_10kevery100
-wandb_project: ldc_3d
-
-neighbor_list_multiplier: 2.0
diff --git a/configs/ldc_3d/gns.yaml b/configs/ldc_3d/gns.yaml
index dbf14b4..b757e2f 100644
--- a/configs/ldc_3d/gns.yaml
+++ b/configs/ldc_3d/gns.yaml
@@ -1,6 +1,16 @@
-extends: ldc_3d/base.yaml
+main:
+  data_dir: datasets/3D_LDC_8160_10kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: ldc_3d
+
+neighbors:
+  multiplier: 2.0
\ No newline at end of file
diff --git a/configs/ldc_3d/segnn.yaml b/configs/ldc_3d/segnn.yaml
index fa4844c..4d5da64 100644
--- a/configs/ldc_3d/segnn.yaml
+++ b/configs/ldc_3d/segnn.yaml
@@ -1,8 +1,19 @@
-extends: ldc_3d/base.yaml
+main:
+  data_dir: datasets/3D_LDC_8160_10kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 5.e-4
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: ldc_3d
+
+neighbors:
+  multiplier: 2.0
\ No newline at end of file
diff --git a/configs/rpf_2d/base.yaml b/configs/rpf_2d/base.yaml
deleted file mode 100644
index 0916557..0000000
--- a/configs/rpf_2d/base.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/2D_RPF_3200_20kevery100
-wandb_project: rpf_2d
diff --git a/configs/rpf_2d/egnn.yaml b/configs/rpf_2d/egnn.yaml
index 82ab3b3..790b708 100644
--- a/configs/rpf_2d/egnn.yaml
+++ b/configs/rpf_2d/egnn.yaml
@@ -1,13 +1,21 @@
-extends: rpf_2d/base.yaml
-
-model: egnn
-num_mp_steps: 5
-latent_dim: 128
-lr_start: 1.e-4
-
-isotropic_norm: True
-magnitude_features: True
-loss_weight:
-  pos: 1.0
-  vel: 0.0
-  acc: 0.0
+main:
+  data_dir: datasets/2D_RPF_3200_20kevery100
+
+model:
+  name: egnn
+  num_mp_steps: 5
+  latent_dim: 128
+
+train:
+  isotropic_norm: True
+  magnitude_features: True
+
+optimizer:
+  lr_start: 5.e-4
+  loss_weight:
+    pos: 1.0
+    vel: 0.0
+    acc: 0.0
+
+logging:
+  wandb_project: rpf_2d
diff --git a/configs/rpf_2d/gns.yaml b/configs/rpf_2d/gns.yaml
index 87c2e81..82383ec 100644
--- a/configs/rpf_2d/gns.yaml
+++ b/configs/rpf_2d/gns.yaml
@@ -1,6 +1,13 @@
-extends: rpf_2d/base.yaml
+main:
+  data_dir: datasets/2D_RPF_3200_20kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: rpf_2d
\ No newline at end of file
diff --git a/configs/rpf_2d/painn.yaml b/configs/rpf_2d/painn.yaml
index 95c4e91..b05f189 100644
--- a/configs/rpf_2d/painn.yaml
+++ b/configs/rpf_2d/painn.yaml
@@ -1,9 +1,17 @@
-extends: rpf_2d/base.yaml
+main:
+  data_dir: datasets/2D_RPF_3200_20kevery100
 
-model: painn
-num_mp_steps: 5
-latent_dim: 128
-lr_start: 1.e-4
+model:
+  name: painn
+  num_mp_steps: 5
+  latent_dim: 128
 
-isotropic_norm: True
-magnitude_features: True
+train:
+  isotropic_norm: True
+  magnitude_features: True
+
+optimizer:
+  lr_start: 1.e-4
+
+logging:
+  wandb_project: rpf_2d
diff --git a/configs/rpf_2d/segnn.yaml b/configs/rpf_2d/segnn.yaml
index e65e2b4..7c510eb 100644
--- a/configs/rpf_2d/segnn.yaml
+++ b/configs/rpf_2d/segnn.yaml
@@ -1,8 +1,16 @@
-extends: rpf_2d/base.yaml
+main:
+  data_dir: datasets/2D_RPF_3200_20kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 1.e-3
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 1.e-3
+
+logging:
+  wandb_project: rpf_2d
diff --git a/configs/rpf_3d/base.yaml b/configs/rpf_3d/base.yaml
deleted file mode 100644
index 7a20c34..0000000
--- a/configs/rpf_3d/base.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/3D_RPF_8000_10kevery100
-wandb_project: rpf_3d
diff --git a/configs/rpf_3d/egnn.yaml b/configs/rpf_3d/egnn.yaml
index 1f793ff..25d6d0f 100644
--- a/configs/rpf_3d/egnn.yaml
+++ b/configs/rpf_3d/egnn.yaml
@@ -1,13 +1,21 @@
-extends: rpf_3d/base.yaml
-
-model: egnn
-num_mp_steps: 5
-latent_dim: 128
-lr_start: 1.e-4
-
-isotropic_norm: True
-magnitude_features: True
-loss_weight:
-  pos: 1.0
-  vel: 0.0
-  acc: 0.0
+main:
+  data_dir: datasets/3D_RPF_8000_10kevery100
+
+model:
+  name: egnn
+  num_mp_steps: 5
+  latent_dim: 128
+
+train:
+  isotropic_norm: True
+  magnitude_features: True
+
+optimizer:
+  lr_start: 1.e-4
+  loss_weight:
+    pos: 1.0
+    vel: 0.0
+    acc: 0.0
+
+logging:
+  wandb_project: rpf_3d
diff --git a/configs/rpf_3d/gns.yaml b/configs/rpf_3d/gns.yaml
index 8bb2053..416a993 100644
--- a/configs/rpf_3d/gns.yaml
+++ b/configs/rpf_3d/gns.yaml
@@ -1,6 +1,13 @@
-extends: rpf_3d/base.yaml
+main:
+  data_dir: datasets/3D_RPF_8000_10kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: rpf_3d
\ No newline at end of file
diff --git a/configs/rpf_3d/painn.yaml b/configs/rpf_3d/painn.yaml
index cdd5b62..27f735c 100644
--- a/configs/rpf_3d/painn.yaml
+++ b/configs/rpf_3d/painn.yaml
@@ -1,9 +1,17 @@
-extends: rpf_3d/base.yaml
+main:
+  data_dir: datasets/3D_RPF_8000_10kevery100
 
-model: painn
-num_mp_steps: 5
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: painn
+  num_mp_steps: 5
+  latent_dim: 128
 
-isotropic_norm: True
-magnitude_features: True
+train:
+  isotropic_norm: True
+  magnitude_features: True
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: rpf_3d
diff --git a/configs/rpf_3d/segnn.yaml b/configs/rpf_3d/segnn.yaml
index 0f6e6db..64df420 100644
--- a/configs/rpf_3d/segnn.yaml
+++ b/configs/rpf_3d/segnn.yaml
@@ -1,8 +1,16 @@
-extends: rpf_3d/base.yaml
+main:
+  data_dir: datasets/3D_RPF_8000_10kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 1.e-3
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 1.e-3
+
+logging:
+  wandb_project: rpf_3d
diff --git a/configs/tgv_2d/base.yaml b/configs/tgv_2d/base.yaml
deleted file mode 100644
index f37268e..0000000
--- a/configs/tgv_2d/base.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/2D_TGV_2500_10kevery100
-wandb_project: tgv_2d
diff --git a/configs/tgv_2d/gns.yaml b/configs/tgv_2d/gns.yaml
index 49c2330..289e849 100644
--- a/configs/tgv_2d/gns.yaml
+++ b/configs/tgv_2d/gns.yaml
@@ -1,6 +1,13 @@
-extends: tgv_2d/base.yaml
+main:
+  data_dir: datasets/2D_TGV_2500_10kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: tgv_2d
diff --git a/configs/tgv_2d/segnn.yaml b/configs/tgv_2d/segnn.yaml
index 865fce3..092d59e 100644
--- a/configs/tgv_2d/segnn.yaml
+++ b/configs/tgv_2d/segnn.yaml
@@ -1,8 +1,17 @@
-extends: tgv_2d/base.yaml
+main:
+  data_dir: datasets/2D_TGV_2500_10kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 5.e-4
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: tgv_2d
diff --git a/configs/tgv_3d/base.yaml b/configs/tgv_3d/base.yaml
deleted file mode 100644
index 7c655e4..0000000
--- a/configs/tgv_3d/base.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-extends: defaults.yaml
-
-data_dir: datasets/3D_TGV_8000_10kevery100
-wandb_project: tgv_3d
diff --git a/configs/tgv_3d/gns.yaml b/configs/tgv_3d/gns.yaml
index cf0b741..b286b7c 100644
--- a/configs/tgv_3d/gns.yaml
+++ b/configs/tgv_3d/gns.yaml
@@ -1,6 +1,13 @@
-extends: tgv_3d/base.yaml
+main:
+  data_dir: datasets/3D_TGV_8000_10kevery100
 
-model: gns
-num_mp_steps: 10
-latent_dim: 128
-lr_start: 5.e-4
+model:
+  name: gns
+  num_mp_steps: 10
+  latent_dim: 128
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: tgv_3d
diff --git a/configs/tgv_3d/segnn.yaml b/configs/tgv_3d/segnn.yaml
index ebc81cc..dddaeef 100644
--- a/configs/tgv_3d/segnn.yaml
+++ b/configs/tgv_3d/segnn.yaml
@@ -1,8 +1,16 @@
-extends: tgv_3d/base.yaml
+main:
+  data_dir: datasets/3D_TGV_8000_10kevery100
 
-model: segnn
-num_mp_steps: 10
-latent_dim: 64
-lr_start: 5.e-4
+model:
+  name: segnn
+  num_mp_steps: 10
+  latent_dim: 64
 
-isotropic_norm: True
+train:
+  isotropic_norm: True
+
+optimizer:
+  lr_start: 5.e-4
+
+logging:
+  wandb_project: tgv_3d
diff --git a/docs/conf.py b/docs/conf.py
index 589f76c..e73044e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -10,7 +10,11 @@
 copyright = "2023, Chair of Aerodynamics and Fluid Mechanics, TUM"
 author = "Artur Toshev, Gianluca Galletti"
 
-version = "0.0.1"
+# read the version from pyproject.toml
+import toml
+
+pyproject = toml.load("../pyproject.toml")
+version = pyproject["tool"]["poetry"]["version"]
 
 # -- Path setup --------------------------------------------------------------
diff --git a/experiments/config.py b/experiments/config.py
deleted file mode 100644
index c7d0438..0000000
--- a/experiments/config.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import argparse
-import os
-from typing import Dict
-
-import yaml
-
-
-def cli_arguments() -> Dict:
-    parser = argparse.ArgumentParser()
-    group = parser.add_mutually_exclusive_group(required=True)
-
-    # config arguments
-    group.add_argument("-c", "--config", type=str, help="Path to the config yaml.")
-    group.add_argument("--model_dir", type=str, help="Path to the model checkpoint.")
-
-    # run arguments
-    parser.add_argument(
-        "--mode", type=str, choices=["train", "infer", "all"], help="Train or evaluate."
-    )
-    parser.add_argument("--batch_size", type=int, required=False, help="Batch size.")
-    parser.add_argument(
-        "--lr_start", type=float, required=False, help="Starting learning rate."
-    )
-    parser.add_argument(
-        "--lr_final", type=float, required=False, help="Learning rate after decay."
-    )
-    parser.add_argument(
-        "--lr_decay_rate", type=float, required=False, help="Learning rate decay."
-    )
-    parser.add_argument(
-        "--lr_decay_steps", type=int, required=False, help="Learning rate decay steps."
-    )
-    parser.add_argument(
-        "--noise_std",
-        type=float,
-        required=False,
-        help="Additive noise standard deviation.",
-    )
-    parser.add_argument(
-        "--test",
-        action=argparse.BooleanOptionalAction,
-        help="Run test mode instead of validation.",
-    )
-    parser.add_argument("--seed", type=int, required=False, help="Random seed.")
-    parser.add_argument(
-        "--data_dir", type=str, help="Absolute/relative path to the dataset."
-    )
-    parser.add_argument("--ckp_dir", type=str, help="Path for checkpoints.")
-
-    # model arguments
-    parser.add_argument(
-        "--model",
-        type=str,
-        help="Model name.",
-    )
-    parser.add_argument(
-        "--input_seq_length",
-        type=int,
-        required=False,
-        help="Input position sequence length.",
-    )
-    parser.add_argument(
-        "--num_mp_steps",
-        type=int,
-        required=False,
-        help="Number of message passing layers.",
-    )
-    parser.add_argument(
-        "--num_mlp_layers", type=int, required=False, help="Number of MLP layers."
-    )
-    parser.add_argument(
-        "--latent_dim", type=int, required=False, help="Hidden layer dimension."
-    )
-    parser.add_argument(
-        "--magnitude_features",
-        action=argparse.BooleanOptionalAction,
-        help="Whether to include velocity magnitudes in node features.",
-    )
-    parser.add_argument(
-        "--isotropic_norm",
-        action=argparse.BooleanOptionalAction,
-        help="Use isotropic normalization.",
-    )
-
-    # output arguments
-    parser.add_argument(
-        "--out_type",
-        type=str,
-        required=False,
-        choices=["vtk", "pkl", "none"],
-        help="Output type to store rollouts during validation.",
-    )
-    parser.add_argument(
-        "--out_type_infer",
-        type=str,
-        required=False,
-        choices=["vtk", "pkl", "none"],
-        help="Output type to store rollouts during inference.",
-    )
-    parser.add_argument(
-        "--rollout_dir", type=str, required=False, help="Directory to write rollouts."
-    )
-
-    # segnn-specific arguments
-    parser.add_argument(
-        "--lmax_attributes",
-        type=int,
-        required=False,
-        help="Maximum degree of attributes.",
-    )
-    parser.add_argument(
-        "--lmax_hidden",
-        type=int,
-        required=False,
-        help="Maximum degree of hidden layers.",
-    )
-    parser.add_argument(
-        "--segnn_norm",
-        type=str,
-        required=False,
-        choices=["instance", "batch", "none"],
-        help="Normalisation type.",
-    )
-    parser.add_argument(
-        "--velocity_aggregate",
-        type=str,
-        required=False,
-        choices=["avg", "sum", "last", "all"],
-        help="Velocity aggregation function for node attributes.",
-    )
-    parser.add_argument(
-        "--attribute_mode",
-        type=str,
-        required=False,
-        choices=["add", "concat", "velocity"],
-        help="How to combine node attributes.",
-    )
-    # HAE-specific arguments
-    parser.add_argument(
-        "--right_attribute",
-        required=False,
-        action=argparse.BooleanOptionalAction,
-        help="Whether to use last velocity to steer the attribute embedding.",
-    )
-    parser.add_argument(
-        "--attribute_embedding_blocks",
-        required=False,
-        type=int,
-        help="Number of embedding layers for the attributes.",
-    )
-
-    # misc arguments
-    parser.add_argument(
-        "--gpu", type=int, required=False, help="CUDA device ID to use."
-    )
-    parser.add_argument(
-        "--f64",
-        required=False,
-        action=argparse.BooleanOptionalAction,
-        help="Whether to use double precision.",
-    )
-
-    parser.add_argument(
-        "--eval_n_trajs",
-        required=False,
-        type=int,
-        help="Number of trajectories to evaluate during validation.",
-    )
-    parser.add_argument(
-        "--eval_n_trajs_infer",
-        required=False,
-        type=int,
-        help="Number of trajectories to evaluate during inference.",
-    )
-
-    parser.add_argument(
-        "--metrics",
-        required=False,
-        nargs="+",
-        help="Validation metrics to evaluate. Choose from: mse, mae, sinkhorn, e_kin.",
-    )
-    parser.add_argument(
-        "--metrics_infer",
-        required=False,
-        nargs="+",
-        help="Inference metrics to evaluate during inference.",
-    )
-    parser.add_argument(
-        "--metrics_stride",
-        required=False,
-        type=int,
-        help="Stride for Sinkhorn and e_kin during validation",
-    )
-    parser.add_argument(
-        "--metrics_stride_infer",
-        required=False,
-        type=int,
-        help="Stride for Sinkhorn and e_kin during inference.",
-    )
-    parser.add_argument(
-        "--n_rollout_steps",
-        required=False,
-        type=int,
-        help="Number of rollout steps during validation/testing.",
-    )
-    # only keep passed arguments to avoid overwriting config
-    return {k: v for k, v in vars(parser.parse_args()).items() if v is not None}
-
-
-class NestedLoader(yaml.SafeLoader):
-    """Load yaml files with nested configs."""
-
-    def get_single_data(self):
-        parent = {}
-        config = super().get_single_data()
-        if "extends" in config and (included := config["extends"]):
-            del config["extends"]
-            with open(os.path.join("configs", included), "r") as f:
-                parent = yaml.load(f, NestedLoader)
-        return {**parent, **config}
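The deleted `cli_arguments` parser has no one-to-one replacement; with yacs, dotted key-value pairs are merged on top of the loaded YAML instead of dedicated argparse flags. A hedged sketch of the equivalent of the old `--lr_start`/`--batch_size` flags (the exact `main.py` wiring is not part of this diff):

```python
import sys

from lagrangebench.config import cfg, load_cfg

# e.g. python main.py configs/rpf_2d/gns.yaml optimizer.lr_start 1e-3 train.batch_size 2
config_path, *overrides = sys.argv[1:]

# YAML first, then dotted key-value pairs with the highest priority,
# mirroring the old "cli arguments > YAML > defaults" order
load_cfg(cfg, config_path, overrides)
```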
diff --git a/experiments/run.py b/experiments/run.py
deleted file mode 100644
index 33494ea..0000000
--- a/experiments/run.py
+++ /dev/null
@@ -1,169 +0,0 @@
-import copy
-import os
-import os.path as osp
-from argparse import Namespace
-from datetime import datetime
-
-import haiku as hk
-import jax.numpy as jnp
-import jmp
-import numpy as np
-import wandb
-import yaml
-
-from experiments.utils import setup_data, setup_model
-from lagrangebench import Trainer, infer
-from lagrangebench.case_setup import case_builder
-from lagrangebench.evaluate import averaged_metrics
-from lagrangebench.utils import PushforwardConfig
-
-
-def train_or_infer(args: Namespace):
-    data_train, data_valid, data_test, args = setup_data(args)
-
-    # neighbors search
-    bounds = np.array(data_train.metadata["bounds"])
-    args.box = bounds[:, 1] - bounds[:, 0]
-
-    args.info.len_train = len(data_train)
-    args.info.len_eval = len(data_valid)
-
-    # setup core functions
-    case = case_builder(
-        box=args.box,
-        metadata=data_train.metadata,
-        input_seq_length=args.config.input_seq_length,
-        isotropic_norm=args.config.isotropic_norm,
-        noise_std=args.config.noise_std,
-        magnitude_features=args.config.magnitude_features,
-        external_force_fn=data_train.external_force_fn,
-        neighbor_list_backend=args.config.neighbor_list_backend,
-        neighbor_list_multiplier=args.config.neighbor_list_multiplier,
-        dtype=(jnp.float64 if args.config.f64 else jnp.float32),
-    )
-
-    _, particle_type = data_train[0]
-
-    args.info.homogeneous_particles = particle_type.max() == particle_type.min()
-    args.metadata = data_train.metadata
-    args.normalization_stats = case.normalization_stats
-    args.config.has_external_force = data_train.external_force_fn is not None
-
-    # setup model from configs
-    model, MODEL = setup_model(args)
-    model = hk.without_apply_rng(hk.transform_with_state(model))
-
-    # mixed precision training based on this reference:
-    # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py
-    policy = jmp.get_policy("params=float32,compute=float32,output=float32")
-    hk.mixed_precision.set_policy(MODEL, policy)
-
-    if args.config.mode == "train" or args.config.mode == "all":
-        print("Start training...")
-        # save config file
-        run_prefix = f"{args.config.model}_{data_train.name}"
-        data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S")
-        args.info.run_name = f"{run_prefix}_{data_and_time}"
-
-        args.config.new_checkpoint = os.path.join(
-            args.config.ckp_dir, args.info.run_name
-        )
-        os.makedirs(args.config.new_checkpoint, exist_ok=True)
-        os.makedirs(os.path.join(args.config.new_checkpoint, "best"), exist_ok=True)
-        with open(os.path.join(args.config.new_checkpoint, "config.yaml"), "w") as f:
-            yaml.dump(vars(args.config), f)
-        with open(
-            os.path.join(args.config.new_checkpoint, "best", "config.yaml"), "w"
-        ) as f:
-            yaml.dump(vars(args.config), f)
-
-        if args.config.wandb:
-            # wandb doesn't like Namespace objects
-            args_dict = copy.copy(args)
-            args_dict.config = vars(args.config)
-            args_dict.info = vars(args.info)
-
-            wandb_run = wandb.init(
-                project=args.config.wandb_project,
-                entity=args.config.wandb_entity,
-                name=args.info.run_name,
-                config=args_dict,
-                save_code=True,
-            )
-        else:
-            wandb_run = None
-
-        pf_config = PushforwardConfig(
-            steps=args.config.pushforward["steps"],
-            unrolls=args.config.pushforward["unrolls"],
-            probs=args.config.pushforward["probs"],
-        )
-
-        trainer = Trainer(
-            model,
-            case,
-            data_train,
-            data_valid,
-            pushforward=pf_config,
-            metrics=args.config.metrics,
-            seed=args.config.seed,
-            batch_size=args.config.batch_size,
-            input_seq_length=args.config.input_seq_length,
-            noise_std=args.config.noise_std,
-            lr_start=args.config.lr_start,
-            lr_final=args.config.lr_final,
-            lr_decay_steps=args.config.lr_decay_steps,
-            lr_decay_rate=args.config.lr_decay_rate,
-            loss_weight=args.config.loss_weight,
-            n_rollout_steps=args.config.n_rollout_steps,
-            eval_n_trajs=args.config.eval_n_trajs,
-            rollout_dir=args.config.rollout_dir,
-            out_type=args.config.out_type,
-            log_steps=args.config.log_steps,
-            eval_steps=args.config.eval_steps,
-            metrics_stride=args.config.metrics_stride,
-            num_workers=args.config.num_workers,
-            batch_size_infer=args.config.batch_size_infer,
-        )
-        _, _, _ = trainer(
-            step_max=args.config.step_max,
-            load_checkpoint=args.config.model_dir,
-            store_checkpoint=args.config.new_checkpoint,
-            wandb_run=wandb_run,
-        )
-
-        if args.config.wandb:
-            wandb.finish()
-
-    if args.config.mode == "infer" or args.config.mode == "all":
-        print("Start inference...")
-        if args.config.mode == "all":
-            args.config.model_dir = os.path.join(args.config.new_checkpoint, "best")
-            assert osp.isfile(os.path.join(args.config.model_dir, "params_tree.pkl"))
-
-        args.config.rollout_dir = args.config.model_dir.replace("ckp", "rollout")
-        os.makedirs(args.config.rollout_dir, exist_ok=True)
-
-        if args.config.eval_n_trajs_infer is None:
-            args.config.eval_n_trajs_infer = args.config.eval_n_trajs
-
-        assert args.config.model_dir, "model_dir must be specified for inference."
-        metrics = infer(
-            model,
-            case,
-            data_test if args.config.test else data_valid,
-            load_checkpoint=args.config.model_dir,
-            metrics=args.config.metrics_infer,
-            rollout_dir=args.config.rollout_dir,
-            eval_n_trajs=args.config.eval_n_trajs_infer,
-            n_rollout_steps=args.config.n_rollout_steps,
-            out_type=args.config.out_type_infer,
-            n_extrap_steps=args.config.n_extrap_steps,
-            seed=args.config.seed,
-            metrics_stride=args.config.metrics_stride_infer,
-            batch_size=args.config.batch_size_infer,
-        )
-
-        split = "test" if args.config.test else "valid"
-        print(f"Metrics of {args.config.model_dir} on {split} split:")
-        print(averaged_metrics(metrics))
diff --git a/experiments/utils.py b/experiments/utils.py
deleted file mode 100644
index 8168178..0000000
--- a/experiments/utils.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import os
-import os.path as osp
-from argparse import Namespace
-from typing import Callable, Tuple, Type
-
-import jax
-import jax.numpy as jnp
-from e3nn_jax import Irreps
-from jax_md import space
-
-from lagrangebench import models
-from lagrangebench.data import H5Dataset
-from lagrangebench.models.utils import node_irreps
-from lagrangebench.utils import NodeType
-
-
-def setup_data(args: Namespace) -> Tuple[H5Dataset, H5Dataset, Namespace]:
-    if not osp.isabs(args.config.data_dir):
-        args.config.data_dir = osp.join(os.getcwd(), args.config.data_dir)
-
-    args.info.dataset_name = osp.basename(args.config.data_dir.split("/")[-1])
-    if args.config.ckp_dir is not None:
-        os.makedirs(args.config.ckp_dir, exist_ok=True)
-    if args.config.rollout_dir is not None:
-        os.makedirs(args.config.rollout_dir, exist_ok=True)
-
-    # dataloader
-    data_train = H5Dataset(
-        "train",
-        dataset_path=args.config.data_dir,
-        input_seq_length=args.config.input_seq_length,
-        extra_seq_length=args.config.pushforward["unrolls"][-1],
-        nl_backend=args.config.neighbor_list_backend,
-    )
-    data_valid = H5Dataset(
-        "valid",
-        dataset_path=args.config.data_dir,
-        input_seq_length=args.config.input_seq_length,
-        extra_seq_length=args.config.n_rollout_steps,
-        nl_backend=args.config.neighbor_list_backend,
-    )
-    data_test = H5Dataset(
-        "test",
-        dataset_path=args.config.data_dir,
-        input_seq_length=args.config.input_seq_length,
-        extra_seq_length=args.config.n_rollout_steps,
-        nl_backend=args.config.neighbor_list_backend,
-    )
-    if args.config.eval_n_trajs == -1:
-        args.config.eval_n_trajs = data_valid.num_samples
-    if args.config.eval_n_trajs_infer == -1:
-        args.config.eval_n_trajs_infer = data_valid.num_samples
-    assert data_valid.num_samples >= args.config.eval_n_trajs, (
-        f"Number of available evaluation trajectories ({data_valid.num_samples}) "
-        f"exceeds eval_n_trajs ({args.config.eval_n_trajs})"
-    )
-
-    args.info.has_external_force = bool(data_train.external_force_fn is not None)
-
-    return data_train, data_valid, data_test, args
-
-
-def setup_model(args: Namespace) -> Tuple[Callable, Type]:
-    """Setup model based on args."""
-    model_name = args.config.model.lower()
-    metadata = args.metadata
-
-    if model_name == "gns":
-
-        def model_fn(x):
-            return models.GNS(
-                particle_dimension=metadata["dim"],
-                latent_size=args.config.latent_dim,
-                blocks_per_step=args.config.num_mlp_layers,
-                num_mp_steps=args.config.num_mp_steps,
-                num_particle_types=NodeType.SIZE,
-                particle_type_embedding_size=16,
-            )(x)
-
-        MODEL = models.GNS
-    elif model_name == "segnn":
-        # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type
-        node_feature_irreps = node_irreps(
-            metadata,
-            args.config.input_seq_length,
-            args.config.has_external_force,
-            args.config.magnitude_features,
-            args.info.homogeneous_particles,
-        )
-        # 1o displacement, 0e distance
-        edge_feature_irreps = Irreps("1x1o + 1x0e")
-
-        def model_fn(x):
-            return models.SEGNN(
-                node_features_irreps=node_feature_irreps,
-                edge_features_irreps=edge_feature_irreps,
-                scalar_units=args.config.latent_dim,
-                lmax_hidden=args.config.lmax_hidden,
-                lmax_attributes=args.config.lmax_attributes,
-                output_irreps=Irreps("1x1o"),
-                num_mp_steps=args.config.num_mp_steps,
-                n_vels=args.config.input_seq_length - 1,
-                velocity_aggregate=args.config.velocity_aggregate,
-                homogeneous_particles=args.info.homogeneous_particles,
-                blocks_per_step=args.config.num_mlp_layers,
-                norm=args.config.segnn_norm,
-            )(x)
-
-        MODEL = models.SEGNN
-    elif model_name == "egnn":
-        box = args.box
-        if jnp.array(metadata["periodic_boundary_conditions"]).any():
-            displacement_fn, shift_fn = space.periodic(jnp.array(box))
-        else:
-            displacement_fn, shift_fn = space.free()
-
-        displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0))
-        shift_fn = jax.vmap(shift_fn, in_axes=(0, 0))
-
-        def model_fn(x):
-            return models.EGNN(
-                hidden_size=args.config.latent_dim,
-                output_size=1,
-                dt=metadata["dt"] * metadata["write_every"],
-                displacement_fn=displacement_fn,
-                shift_fn=shift_fn,
-                normalization_stats=args.normalization_stats,
-                num_mp_steps=args.config.num_mp_steps,
-                n_vels=args.config.input_seq_length - 1,
-                residual=True,
-            )(x)
-
-        MODEL = models.EGNN
-    elif model_name == "painn":
-        assert args.config.magnitude_features, "PaiNN requires magnitudes"
-        radius = metadata["default_connectivity_radius"] * 1.5
-
-        def model_fn(x):
-            return models.PaiNN(
-                hidden_size=args.config.latent_dim,
-                output_size=1,
-                n_vels=args.config.input_seq_length - 1,
-                radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True),
-                cutoff_fn=models.painn.cosine_cutoff(radius),
-                num_mp_steps=args.config.num_mp_steps,
-            )(x)
-
-        MODEL = models.PaiNN
-    elif model_name == "linear":
-
-        def model_fn(x):
-            return models.Linear(dim_out=metadata["dim"])(x)
-
-        MODEL = models.Linear
-
-    return model_fn, MODEL
diff --git a/lagrangebench/__init__.py b/lagrangebench/__init__.py
index 39cf0eb..3f004d6 100644
--- a/lagrangebench/__init__.py
+++ b/lagrangebench/__init__.py
@@ -1,9 +1,9 @@
 from .case_setup.case import case_builder
+from .config import cfg
 from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset
 from .evaluate import infer
 from .models import EGNN, GNS, SEGNN, PaiNN
 from .train.trainer import Trainer
-from .utils import PushforwardConfig
 
 __all__ = [
     "Trainer",
@@ -21,7 +21,10 @@
     "LDC2D",
     "LDC3D",
     "DAM2D",
-    "PushforwardConfig",
+    "cfg",
 ]
 
-__version__ = "0.0.1"
+import toml
+
+pyproject = toml.load("pyproject.toml")
+__version__ = pyproject["tool"]["poetry"]["version"]
diff --git a/lagrangebench/case_setup/case.py b/lagrangebench/case_setup/case.py
index 21ee2ec..cc3d92f 100644
--- a/lagrangebench/case_setup/case.py
+++ b/lagrangebench/case_setup/case.py
@@ -9,8 +9,8 @@
 from jax_md.dataclasses import dataclass, static_field
 from jax_md.partition import NeighborList, NeighborListFormat
 
+from lagrangebench.config import cfg
 from lagrangebench.data.utils import get_dataset_stats
-from lagrangebench.defaults import defaults
 from lagrangebench.train.strats import add_gns_noise
 
 from .features import FeatureDict, TargetDict, physical_feature_builder
@@ -62,14 +62,7 @@ class CaseSetupFn:
 def case_builder(
     box: Tuple[float, float, float],
     metadata: Dict,
-    input_seq_length: int,
-    isotropic_norm: bool = defaults.isotropic_norm,
-    noise_std: float = defaults.noise_std,
     external_force_fn: Optional[Callable] = None,
-    magnitude_features: bool = defaults.magnitude_features,
-    neighbor_list_backend: str = defaults.neighbor_list_backend,
-    neighbor_list_multiplier: float = defaults.neighbor_list_multiplier,
-    dtype: jnp.dtype = defaults.dtype,
 ):
     """Set up a CaseSetupFn that contains every required function besides the model.
 
@@ -83,15 +76,17 @@ def case_builder(
     Args:
         box: Box xyz sizes of the system.
         metadata: Dataset metadata dictionary.
-        input_seq_length: Length of the input sequence.
-        isotropic_norm: Whether to normalize dimensions equally.
-        noise_std: Noise standard deviation.
         external_force_fn: External force function.
-        magnitude_features: Whether to add velocity magnitudes in the features.
-        neighbor_list_backend: Backend of the neighbor list.
-        neighbor_list_multiplier: Capacity multiplier of the neighbor list.
-        dtype: Data type.
     """
+
+    input_seq_length = cfg.model.input_seq_length
+    isotropic_norm = cfg.train.isotropic_norm
+    noise_std = cfg.optimizer.noise_std
+    magnitude_features = cfg.train.magnitude_features
+    neighbor_list_backend = cfg.neighbors.backend
+    neighbor_list_multiplier = cfg.neighbors.multiplier
+    dtype = cfg.main.dtype
+
     normalization_stats = get_dataset_stats(metadata, isotropic_norm, noise_std)
 
     # apply PBC in all directions or not at all
diff --git a/lagrangebench/config.py b/lagrangebench/config.py
new file mode 100644
index 0000000..c72e7eb
--- /dev/null
+++ b/lagrangebench/config.py
@@ -0,0 +1,208 @@
+from typing import Dict, List, Optional
+
+import yaml
+from yacs.config import CfgNode as CN
+
+# lagrangebench-wide config object
+cfg = CN()
+
+
+def custom_config(fn):
+    """Decorator to add custom config functions."""
+    fn(cfg)
+    return fn
+
+
+def defaults(cfg):
+    """Default lagrangebench values."""
+
+    if cfg is None:
+        raise ValueError("cfg should be a yacs CfgNode")
+
+    # global and hardware-related configs
+    main = CN()
+
+    # One of "train", "infer" or "all" (= both)
+    main.mode = "all"
+    # random seed
+    main.seed = 0
+    # data type for preprocessing. One of "float32" or "float64"
+    main.dtype = "float64"
+    # gpu device
+    main.gpu = 0
+    # XLA memory fraction to be preallocated
+    main.xla_mem_fraction = 0.7
+    # data directory
+    main.data_dir = None
+
+    cfg.main = main
+
+    # model
+    model = CN()
+
+    model.name = None
+    # Length of the position input sequence
+    model.input_seq_length = 6
+    # Number of message passing steps
+    model.num_mp_steps = 10
+    # Number of MLP layers
+    model.num_mlp_layers = 2
+    # Hidden dimension
+    model.latent_dim = 128
+    # Load checkpointed model from this directory
+    model.model_dir = None
+
+    cfg.model = model
+
+    # training
+    train = CN()
+
+    # batch size
+    train.batch_size = 1
+    # max number of training steps
+    train.step_max = 500_000
+    # whether to include velocity magnitude features
+    train.magnitude_features = False
+    # whether to normalize dimensions equally
+    train.isotropic_norm = False
+    # number of workers for data loading
+    train.num_workers = 4
+
+    cfg.train = train
+
+    # optimizer
+    optimizer = CN()
+
+    # initial learning rate
+    optimizer.lr_start = 1e-4
+    # final learning rate (after exponential decay)
+    optimizer.lr_final = 1e-6
+    # learning rate decay rate
+    optimizer.lr_decay_rate = 0.1
+    # number of steps to decay learning rate
+    optimizer.lr_decay_steps = 1e5
+    # standard deviation of the GNS-style noise
+    optimizer.noise_std = 3e-4
+
+    # optimizer: pushforward
+    pushforward = CN()
+    # At which training step to introduce next unroll stage
+    pushforward.steps = [-1, 200000, 300000, 400000]
+    # For how many steps to unroll
+    pushforward.unrolls = [0, 1, 2, 3]
+    # Which probability ratio to keep between the unrolls
+    pushforward.probs = [18, 2, 1, 1]
+
+    # optimizer: loss weights
+    loss_weight = CN()
+    # weight for acceleration error
+    loss_weight.acc = 1.0
+    # weight for velocity error
+    loss_weight.vel = 0.0
+    # weight for position error
+    loss_weight.pos = 0.0
+
+    cfg.optimizer = optimizer
+    cfg.optimizer.loss_weight = loss_weight
+    cfg.optimizer.pushforward = pushforward
+
+    # evaluation
+    eval = CN()
+
+    # number of eval rollout steps. -1 is full rollout
+    eval.n_rollout_steps = 20
+    # number of trajectories to evaluate during training
+    eval.n_trajs_train = 1
+    # number of trajectories to evaluate during inference
+    eval.n_trajs_infer = 50
+    # metrics for training
+    eval.metrics_train = ["mse"]
+    # stride for e_kin and sinkhorn
+    eval.metrics_stride_train = 10
+    # metrics for inference
+    eval.metrics_infer = ["mse", "e_kin", "sinkhorn"]
+    # stride for e_kin and sinkhorn
+    eval.metrics_stride_infer = 1
+    # number of extrapolation steps in inference
+    eval.n_extrap_steps = 0
+    # batch size for validation/testing
+    eval.batch_size_infer = 2
+    # write validation rollouts. One of "none", "vtk", or "pkl"
+    eval.out_type_train = "none"
+    # write inference rollouts. One of "none", "vtk", or "pkl"
+    eval.out_type_infer = "pkl"
+    # rollouts directory
+    eval.rollout_dir = None
+    # whether to use the test split
+    eval.test = False
+
+    cfg.eval = eval
+
+    # logging
+    logging = CN()
+
+    # number of steps between loggings
+    logging.log_steps = 1000
+    # number of steps between evaluations and checkpoints
+    logging.eval_steps = 10000
+    # wandb enable
+    logging.wandb = False
+    # wandb project name
+    logging.wandb_project = None
+    # wandb entity name
+    logging.wandb_entity = "lagrangebench"
+    # checkpoint directory
+    logging.ckp_dir = "ckp"
+    # name of training run
+    logging.run_name = None
+
+    cfg.logging = logging
+
+    # neighbor list
+    neighbors = CN()
+
+    # backend for neighbor list computation
+    neighbors.backend = "jaxmd_vmap"
+    # multiplier for neighbor list capacity
+    neighbors.multiplier = 1.25
+
+    cfg.neighbors = neighbors
+
+
+def check_cfg(cfg):
+    assert cfg.main.data_dir is not None, "cfg.main.data_dir must be specified."
+    assert (
+        cfg.train.step_max is not None and cfg.train.step_max > 0
+    ), "cfg.train.step_max must be specified and larger than 0."
+
+
+def load_cfg(cfg: CN, config_path: str, extra_args: Optional[List] = None):
+    if cfg is None:
+        raise ValueError("cfg should be a yacs CfgNode")
+    if len(cfg) == 0:
+        defaults(cfg)
+    cfg.merge_from_file(config_path)
+    if extra_args is not None:
+        cfg.merge_from_list(extra_args)
+    check_cfg(cfg)
+
+
+def cfg_to_dict(cfg: CN) -> Dict:
+    return yaml.safe_load(cfg.dump())
+
+
+# TODO find a better way
+defaults(cfg)
+
+
+@custom_config
+def segnn_config(cfg):
+    """SEGNN only parameters."""
+    # Steerable attributes level
+    cfg.model.lmax_attributes = 1
+    # Level of the hidden layer
+    cfg.model.lmax_hidden = 1
+    # SEGNN normalization. instance, batch, none
+    cfg.model.segnn_norm = "none"
+    # SEGNN velocity aggregation. avg or last
+    cfg.model.velocity_aggregate = "avg"
diff --git a/lagrangebench/data/data.py b/lagrangebench/data/data.py
index 0febebe..99df047 100644
--- a/lagrangebench/data/data.py
+++ b/lagrangebench/data/data.py
@@ -15,6 +15,7 @@
 import wget
 from torch.utils.data import Dataset
 
+from lagrangebench.config import cfg
 from lagrangebench.utils import NodeType
 
 URLS = {
@@ -41,17 +42,16 @@ class H5Dataset(Dataset):
     def __init__(
         self,
         split: str,
-        dataset_path: str,
+        dataset_path: Optional[str] = None,
         name: Optional[str] = None,
         input_seq_length: int = 6,
        extra_seq_length: int = 0,
-        nl_backend: str = "jaxmd_vmap",
     ):
         """Initialize the dataset. If the dataset is not present, it is downloaded.
 
         Args:
             split: "train", "valid", or "test"
-            dataset_path: Path to the dataset
+            dataset_path: Path to the dataset. If None, it is read from the config.
             name: Name of the dataset. If None, it is inferred from the path.
             input_seq_length: Length of the input sequence. The number of historic
                 velocities is input_seq_length - 1. And during training, the returned
@@ -60,8 +60,9 @@ def __init__(
             extra_seq_length: During training, this is the maximum number of pushforward
                 unroll steps. During validation/testing, this specifies the largest
                 N-step MSE loss we are interested in, e.g. for best model checkpointing.
-            nl_backend: Which backend to use for the neighbor list
         """
+        if dataset_path is None:
+            dataset_path = cfg.main.data_dir
         if dataset_path.endswith("/"):  # remove trailing slash in dataset path
             dataset_path = dataset_path[:-1]
@@ -80,7 +81,7 @@ def __init__(
         self.dataset_path = dataset_path
         self.file_path = osp.join(dataset_path, split + ".h5")
         self.input_seq_length = input_seq_length
-        self.nl_backend = nl_backend
+        self.nl_backend = cfg.neighbors.backend
 
         force_fn_path = osp.join(dataset_path, "force.py")
         if osp.exists(force_fn_path):
diff --git a/lagrangebench/defaults.py b/lagrangebench/defaults.py
deleted file mode 100644
index 9cb3c22..0000000
--- a/lagrangebench/defaults.py
+++ /dev/null
@@ -1,70 +0,0 @@
-"""Default lagrangebench values."""
-
-from dataclasses import dataclass
-
-import jax.numpy as jnp
-
-
-@dataclass(frozen=True)
-class defaults:
-    """
-    Default lagrangebench values.
-
-    Attributes:
-        seed: random seed. Default 0.
-        batch_size: batch size. Default 1.
-        step_max: max number of training steps. Default ``1e7``.
-        dtype: data type. Default ``jnp.float32``.
-        magnitude_features: whether to include velocity magnitudes. Default False.
-        isotropic_norm: whether to normalize dimensions equally. Default False.
-        lr_start: initial learning rate. Default 1e-4.
-        lr_final: final learning rate (after exponential decay). Default 1e-6.
-        lr_decay_steps: number of steps to decay learning rate
-        lr_decay_rate: learning rate decay rate. Default 0.1.
-        noise_std: standard deviation of the GNS-style noise. Default 1e-4.
-        input_seq_length: number of input steps. Default 6.
-        n_rollout_steps: number of eval rollout steps. -1 is full rollout. Default -1.
-        eval_n_trajs: number of trajectories to evaluate. Default 1 trajectory.
-        rollout_dir: directory to save rollouts. Default None.
-        out_type: type of output. None means no rollout is stored. Default None.
-        n_extrap_steps: number of extrapolation steps. Default 0.
-        log_steps: number of steps between logs. Default 1000.
-        eval_steps: number of steps between evaluations and checkpoints. Default 5000.
-        neighbor_list_backend: neighbor list routine. Default "jaxmd_vmap".
-        neighbor_list_multiplier: multiplier for neighbor list capacity. Default 1.25.
-    """
-
-    # training
-    seed: int = 0  # random seed
-    batch_size: int = 1  # batch size
-    step_max: int = 5e5  # max number of training steps
-    dtype: jnp.dtype = jnp.float64  # data type for preprocessing
-    magnitude_features: bool = False  # whether to include velocity magnitude features
-    isotropic_norm: bool = False  # whether to normalize dimensions equally
-    num_workers: int = 4  # number of workers for data loading
-
-    # learning rate
-    lr_start: float = 1e-4  # initial learning rate
-    lr_final: float = 1e-6  # final learning rate (after exponential decay)
-    lr_decay_steps: int = 1e5  # number of steps to decay learning rate
-    lr_decay_rate: float = 0.1  # learning rate decay rate
-
-    noise_std: float = 3e-4  # standard deviation of the GNS-style noise
-
-    # evaluation
-    input_seq_length: int = 6  # number of input steps
-    n_rollout_steps: int = -1  # number of eval rollout steps. -1 is full rollout
-    eval_n_trajs: int = 1  # number of trajectories to evaluate
-    rollout_dir: str = None  # directory to save rollouts
-    out_type: str = "none"  # type of output. None means no rollout is stored
-    n_extrap_steps: int = 0  # number of extrapolation steps
-    metrics_stride: int = 10  # stride for e_kin and sinkhorn
-    batch_size_infer: int = 2  # batch size for validation/testing
-
-    # logging
-    log_steps: int = 1000  # number of steps between logs
-    eval_steps: int = 10000  # number of steps between evaluations and checkpoints
-
-    # neighbor list
-    neighbor_list_backend: str = "jaxmd_vmap"  # backend for neighbor list computation
-    neighbor_list_multiplier: float = 1.25  # multiplier for neighbor list capacity
diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py
index dde7627..28e50cf 100644
--- a/lagrangebench/evaluate/rollout.py
+++ b/lagrangebench/evaluate/rollout.py
@@ -4,7 +4,7 @@
 import pickle
 import time
 from functools import partial
-from typing import Callable, Iterable, List, Optional, Tuple
+from typing import Callable, Iterable, Optional, Tuple
 
 import haiku as hk
 import jax
@@ -13,9 +13,9 @@
 from jax import jit, vmap
 from torch.utils.data import DataLoader
 
+from lagrangebench.config import cfg
 from lagrangebench.data import H5Dataset
 from lagrangebench.data.utils import numpy_collate
-from lagrangebench.defaults import defaults
 from lagrangebench.evaluate.metrics import MetricsComputer, MetricsDict
 from lagrangebench.evaluate.utils import write_vtk
 from lagrangebench.utils import (
@@ -74,7 +74,7 @@ def _forward_eval(
     return current_positions, state
 
 
-def eval_batched_rollout(
+def _eval_batched_rollout(
     forward_eval_vmap: Callable,
     preprocess_eval_vmap: Callable,
     case,
@@ -237,7 +237,7 @@ def eval_rollout(
         # (pos_input_batch, particle_type_batch) = traj_batch_i
         # pos_input_batch.shape = (batch, num_particles, seq_length, dim)
 
-        example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout(
+        example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout(
            forward_eval_vmap=forward_eval_vmap,
            preprocess_eval_vmap=preprocess_eval_vmap,
            case=case,
@@ -289,7 +289,7 @@ def eval_rollout(
                     "tag": example_rollout["particle_type"],
                 }
                 write_vtk(ref_state_vtk, f"{file_prefix}_ref_{k}.vtk")
-        if out_type == "pkl":
+        elif out_type == "pkl":
             filename = f"{file_prefix}.pkl"
 
             with open(filename, "wb") as f:
@@ -314,15 +314,6 @@ def infer(
     params: Optional[hk.Params] = None,
     state: Optional[hk.State] = None,
     load_checkpoint: Optional[str] = None,
-    metrics: List = ["mse"],
-    rollout_dir: Optional[str] = None,
-    eval_n_trajs: int = defaults.eval_n_trajs,
-    n_rollout_steps: int = defaults.n_rollout_steps,
-    out_type: str = defaults.out_type,
-    n_extrap_steps: int = defaults.n_extrap_steps,
-    seed: int = defaults.seed,
-    metrics_stride: int = defaults.metrics_stride,
-    batch_size: int = defaults.batch_size_infer,
 ):
     """
     Infer on a dataset, compute metrics and optionally save rollout in out_type format.
@@ -357,21 +348,21 @@ def infer(
     else:
         params, state, _, _ = load_haiku(load_checkpoint)
 
-    key, seed_worker, generator = set_seed(seed)
+    key, seed_worker, generator = set_seed(cfg.main.seed)
 
     loader_test = DataLoader(
         dataset=data_test,
-        batch_size=batch_size,
+        batch_size=cfg.eval.batch_size_infer,
         collate_fn=numpy_collate,
         worker_init_fn=seed_worker,
         generator=generator,
     )
     metrics_computer = MetricsComputer(
-        metrics,
+        cfg.eval.metrics_infer,
         dist_fn=case.displacement,
         metadata=data_test.metadata,
         input_seq_length=data_test.input_seq_length,
-        stride=metrics_stride,
+        stride=cfg.eval.metrics_stride_infer,
     )
     # Precompile model
     model_apply = jit(model.apply)
@@ -389,10 +380,10 @@ def infer(
         state=state,
         neighbors=neighbors,
         loader_eval=loader_test,
-        n_rollout_steps=n_rollout_steps,
-        n_trajs=eval_n_trajs,
-        rollout_dir=rollout_dir,
-        out_type=out_type,
-        n_extrap_steps=n_extrap_steps,
+        n_rollout_steps=cfg.eval.n_rollout_steps,
+        n_trajs=cfg.eval.n_trajs_infer,
+        rollout_dir=cfg.eval.rollout_dir,
+        out_type=cfg.eval.out_type_infer,
+        n_extrap_steps=cfg.eval.n_extrap_steps,
     )
 
     return eval_metrics
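With its keyword arguments removed, `infer` now reads every evaluation setting from `cfg.eval` (and the seed from `cfg.main.seed`), so callers configure the global config before the call. A short usage sketch, assuming `model`, `case`, and `data_test` were built as in `runner.py` below; the checkpoint path is hypothetical:

```python
from lagrangebench import cfg, infer

# evaluation knobs now live on the global config instead of infer(...) kwargs
cfg.eval.n_trajs_infer = 10
cfg.eval.metrics_infer = ["mse", "e_kin"]
cfg.eval.out_type_infer = "none"

metrics = infer(
    model,
    case,
    data_test,
    load_checkpoint="ckp/gns_tgv2d/best",  # hypothetical checkpoint directory
)
```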
diff --git a/lagrangebench/models/egnn.py b/lagrangebench/models/egnn.py
index b98ed7f..0f09854 100644
--- a/lagrangebench/models/egnn.py
+++ b/lagrangebench/models/egnn.py
@@ -16,6 +16,7 @@
 from jax.tree_util import Partial
 from jax_md import space
 
+from lagrangebench.config import cfg
 from lagrangebench.utils import NodeType
 
 from .base import BaseModel
@@ -249,7 +250,6 @@ class EGNN(BaseModel):
 
     def __init__(
         self,
-        hidden_size: int,
         output_size: int,
         dt: float,
         n_vels: int,
@@ -257,7 +257,6 @@ def __init__(
         shift_fn: space.ShiftFn,
         normalization_stats: Optional[Dict[str, jnp.ndarray]] = None,
         act_fn: Callable = jax.nn.silu,
-        num_mp_steps: int = 4,
         homogeneous_particles: bool = True,
         residual: bool = True,
         attention: bool = False,
@@ -290,17 +289,17 @@ def __init__(
         """
         super().__init__()
         # network
-        self._hidden_size = hidden_size
+        self._hidden_size = cfg.model.latent_dim
         self._output_size = output_size
         self._act_fn = act_fn
-        self._num_mp_steps = num_mp_steps
+        self._num_mp_steps = cfg.model.num_mp_steps
         self._residual = residual
         self._attention = attention
         self._normalize = normalize
         self._tanh = tanh
 
         # integrator
-        self._dt = dt / num_mp_steps
+        self._dt = dt / self._num_mp_steps
         self._displacement_fn = displacement_fn
         self._shift_fn = shift_fn
         if normalization_stats is None:
diff --git a/lagrangebench/models/gns.py b/lagrangebench/models/gns.py
index 9020231..5756305 100644
--- a/lagrangebench/models/gns.py
+++ b/lagrangebench/models/gns.py
@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import jraph
 
+from lagrangebench.config import cfg
 from lagrangebench.utils import NodeType
 
 from .base import BaseModel
@@ -35,27 +36,21 @@ class GNS(BaseModel):
     def __init__(
         self,
         particle_dimension: int,
-        latent_size: int,
-        blocks_per_step: int,
-        num_mp_steps: int,
-        particle_type_embedding_size: int,
+        particle_type_embedding_size: int = 16,
         num_particle_types: int = NodeType.SIZE,
     ):
         """Initialize the model.
 
         Args:
             particle_dimension: Space dimensionality (e.g. 2 or 3).
-            latent_size: Size of the latent representations.
-            blocks_per_step: Number of MLP layers per block.
-            num_mp_steps: Number of message passing steps.
             particle_type_embedding_size: Size of the particle type embedding.
             num_particle_types: Max number of particle types.
         """
         super().__init__()
         self._output_size = particle_dimension
-        self._latent_size = latent_size
-        self._blocks_per_step = blocks_per_step
-        self._mp_steps = num_mp_steps
+        self._latent_size = cfg.model.latent_dim
+        self._blocks_per_step = cfg.model.num_mlp_layers
+        self._mp_steps = cfg.model.num_mp_steps
         self._num_particle_types = num_particle_types
 
         self._embedding = hk.Embed(
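After this change `GNS`, like the other models, pulls `latent_dim`, `num_mlp_layers`, and `num_mp_steps` from `cfg.model` instead of its constructor, and only data-dependent arguments remain explicit. A minimal construction sketch under that assumption, using the same Haiku transform as `runner.py`:

```python
import haiku as hk

from lagrangebench import cfg
from lagrangebench.models import GNS

# architecture hyperparameters come from cfg.model, not from GNS(...) kwargs
cfg.model.latent_dim = 128
cfg.model.num_mp_steps = 10


def model_fn(features):
    # only the space dimensionality is still passed explicitly
    return GNS(particle_dimension=2)(features)


model = hk.without_apply_rng(hk.transform_with_state(model_fn))
```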
""" super().__init__() self._output_size = particle_dimension - self._latent_size = latent_size - self._blocks_per_step = blocks_per_step - self._mp_steps = num_mp_steps + self._latent_size = cfg.model.latent_dim + self._blocks_per_step = cfg.model.num_mlp_layers + self._mp_steps = cfg.model.num_mp_steps self._num_particle_types = num_particle_types self._embedding = hk.Embed( diff --git a/lagrangebench/models/painn.py b/lagrangebench/models/painn.py index 0447361..e394ba1 100644 --- a/lagrangebench/models/painn.py +++ b/lagrangebench/models/painn.py @@ -16,6 +16,7 @@ import jax.tree_util as tree import jraph +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .utils import LinearXav @@ -366,9 +367,7 @@ class PaiNN(hk.Module): def __init__( self, - hidden_size: int, output_size: int, - num_mp_steps: int, radial_basis_fn: Callable, cutoff_fn: Callable, n_vels: int, @@ -399,8 +398,8 @@ def __init__( self._n_vels = n_vels self._homogeneous_particles = homogeneous_particles - self._hidden_size = hidden_size - self._num_mp_steps = num_mp_steps + self._hidden_size = cfg.model.latent_dim + self._num_mp_steps = cfg.model.num_mp_steps self._eps = eps self._shared_filters = shared_filters self._shared_interactions = shared_interactions @@ -408,27 +407,27 @@ def __init__( self.radial_basis_fn = radial_basis_fn self.cutoff_fn = cutoff_fn - self.scalar_emb = LinearXav(hidden_size, name="scalar_embedding") + self.scalar_emb = LinearXav(self._hidden_size, name="scalar_embedding") # mix vector channels (only used if vector features are present in input) self.vector_emb = LinearXav( - hidden_size, with_bias=False, name="vector_embedding" + self._hidden_size, with_bias=False, name="vector_embedding" ) if shared_filters: - self.filter_net = LinearXav(3 * hidden_size, name="filter_net") + self.filter_net = LinearXav(3 * self._hidden_size, name="filter_net") else: self.filter_net = LinearXav( - num_mp_steps * 3 * hidden_size, name="filter_net" + self._num_mp_steps * 3 * self._hidden_size, name="filter_net" ) if self._shared_interactions: self.layers = [ - PaiNNLayer(hidden_size, 0, activation, eps=eps) - ] * num_mp_steps + PaiNNLayer(self._hidden_size, 0, activation, eps=eps) + ] * self._num_mp_steps else: self.layers = [ - PaiNNLayer(hidden_size, i, activation, eps=eps) - for i in range(num_mp_steps) + PaiNNLayer(self._hidden_size, i, activation, eps=eps) + for i in range(self._num_mp_steps) ] self._readout = PaiNNReadout(self._hidden_size, out_channels=output_size) diff --git a/lagrangebench/models/segnn.py b/lagrangebench/models/segnn.py index 5f3ee66..90964f2 100644 --- a/lagrangebench/models/segnn.py +++ b/lagrangebench/models/segnn.py @@ -8,7 +8,6 @@ Standalone implementation + validation: https://github.com/gerkone/segnn-jax """ - import warnings from math import prod from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -21,6 +20,7 @@ from e3nn_jax import Irreps, IrrepsArray from jax.tree_util import Partial, tree_map +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .base import BaseModel @@ -445,16 +445,9 @@ def __init__( self, node_features_irreps: Irreps, edge_features_irreps: Irreps, - scalar_units: int, - lmax_hidden: int, - lmax_attributes: int, output_irreps: Irreps, - num_mp_steps: int, n_vels: int, - velocity_aggregate: str = "avg", homogeneous_particles: bool = True, - norm: Optional[str] = None, - blocks_per_step: int = 2, embed_msg_features: bool = False, ): """ @@ -463,30 +456,23 @@ def __init__( Args: 
             node_features_irreps: Irreps of the node features.
             edge_features_irreps: Irreps of the additional message passing features.
-            scalar_units: Hidden units (lower bound). Actual number depends on lmax.
-            lmax_hidden: Maximum L of the hidden layer representations.
-            lmax_attributes: Maximum L of the attributes.
             output_irreps: Output representation.
-            num_mp_steps: Number of message passing layers
             n_vels: Number of velocities in the history.
-            velocity_aggregate: Velocity sequence aggregation method.
             homogeneous_particles: If all particles are of homogeneous type.
-            norm: Normalization type. Either None, 'instance' or 'batch'
-            blocks_per_step: Number of tensor product blocks in each message passing
             embed_msg_features: Set to true to also embed edges/message passing features
         """
         super().__init__()
 
         # network
-        self._attribute_irreps = Irreps.spherical_harmonics(lmax_attributes)
+        self._attribute_irreps = Irreps.spherical_harmonics(cfg.model.lmax_attributes)
         self._hidden_irreps = weight_balanced_irreps(
-            scalar_units, self._attribute_irreps, lmax_hidden
+            cfg.model.latent_dim, self._attribute_irreps, cfg.model.lmax_hidden
         )
         self._output_irreps = output_irreps
-        self._num_mp_steps = num_mp_steps
+        self._num_mp_steps = cfg.model.num_mp_steps
+        self._norm = cfg.model.segnn_norm
+        self._blocks_per_step = cfg.model.num_mlp_layers
         self._embed_msg_features = embed_msg_features
-        self._norm = norm
-        self._blocks_per_step = blocks_per_step
 
         self._embedding = O3Embedding(
             self._hidden_irreps,
@@ -500,13 +486,13 @@ def __init__(
         )
 
         # transform
-        assert velocity_aggregate in [
+        assert cfg.model.velocity_aggregate in [
             "avg",
             "last",
         ], "Invalid velocity aggregate. Must be one of 'avg' or 'last'."
         self._node_features_irreps = node_features_irreps
         self._edge_features_irreps = edge_features_irreps
-        self._velocity_aggregate = velocity_aggregate
+        self._velocity_aggregate = cfg.model.velocity_aggregate
         self._n_vels = n_vels
         self._homogeneous_particles = homogeneous_particles
diff --git a/lagrangebench/runner.py b/lagrangebench/runner.py
new file mode 100644
index 0000000..beacb6b
--- /dev/null
+++ b/lagrangebench/runner.py
@@ -0,0 +1,244 @@
+import os
+import os.path as osp
+from datetime import datetime
+from typing import Callable, Dict, Optional, Tuple, Type
+
+import haiku as hk
+import jax
+import jax.numpy as jnp
+import jmp
+import numpy as np
+from e3nn_jax import Irreps
+from jax_md import space
+
+from lagrangebench import Trainer, infer, models
+from lagrangebench.case_setup import case_builder
+from lagrangebench.data import H5Dataset
+from lagrangebench.evaluate import averaged_metrics
+from lagrangebench.models.utils import node_irreps
+from lagrangebench.utils import NodeType
+
+
+def train_or_infer(cfg):
+    mode = cfg.main.mode
+    old_model_dir = cfg.model.model_dir
+    is_test = cfg.eval.test
+
+    data_train, data_valid, data_test = setup_data(cfg)
+
+    metadata = data_train.metadata
+    # neighbors search
+    bounds = np.array(metadata["bounds"])
+    box = bounds[:, 1] - bounds[:, 0]
+
+    # setup core functions
+    case = case_builder(
+        box=box,
+        metadata=metadata,
+        external_force_fn=data_train.external_force_fn,
+    )
+
+    _, particle_type = data_train[0]
+
+    # setup model from configs
+    model, MODEL = setup_model(
+        cfg,
+        metadata=metadata,
+        homogeneous_particles=particle_type.max() == particle_type.min(),
+        has_external_force=data_train.external_force_fn is not None,
+        normalization_stats=case.normalization_stats,
+    )
+    model = hk.without_apply_rng(hk.transform_with_state(model))
+
+    # mixed precision training based on this reference:
+    # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py
+    policy = jmp.get_policy("params=float32,compute=float32,output=float32")
+    hk.mixed_precision.set_policy(MODEL, policy)
+
+    if mode == "train" or mode == "all":
+        print("Start training...")
+
+        if cfg.logging.run_name is None:
+            run_prefix = f"{cfg.model.name}_{data_train.name}"
+            date_and_time = datetime.today().strftime("%Y%m%d-%H%M%S")
+            cfg.logging.run_name = f"{run_prefix}_{date_and_time}"
+
+        cfg.model.model_dir = os.path.join(cfg.logging.ckp_dir, cfg.logging.run_name)
+        os.makedirs(cfg.model.model_dir, exist_ok=True)
+        os.makedirs(os.path.join(cfg.model.model_dir, "best"), exist_ok=True)
+        with open(os.path.join(cfg.model.model_dir, "config.yaml"), "w") as f:
+            cfg.dump(stream=f)
+        with open(os.path.join(cfg.model.model_dir, "best", "config.yaml"), "w") as f:
+            cfg.dump(stream=f)
+
+        trainer = Trainer(model, case, data_train, data_valid)
+        _, _, _ = trainer(
+            step_max=cfg.train.step_max,
+            load_checkpoint=old_model_dir,
+            store_checkpoint=cfg.model.model_dir,
+        )
+
+    if mode == "infer" or mode == "all":
+        print("Start inference...")
+
+        if mode == "infer":
+            model_dir = cfg.model.model_dir
+        if mode == "all":
+            model_dir = os.path.join(cfg.model.model_dir, "best")
+        assert osp.isfile(os.path.join(model_dir, "params_tree.pkl"))
+
+        cfg.eval.rollout_dir = model_dir.replace("ckp", "rollout")
+        os.makedirs(cfg.eval.rollout_dir, exist_ok=True)
+
+        if cfg.eval.n_trajs_infer is None:
+            cfg.eval.n_trajs_infer = cfg.eval.n_trajs_train
+
+        assert model_dir, "model_dir must be specified for inference."
+        metrics = infer(
+            model,
+            case,
+            data_test if is_test else data_valid,
+            load_checkpoint=model_dir,
+        )
+
+        split = "test" if is_test else "valid"
+        print(f"Metrics of {model_dir} on {split} split:")
+        print(averaged_metrics(metrics))
+
+
+def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, H5Dataset]:
+    data_dir = cfg.main.data_dir
+    ckp_dir = cfg.logging.ckp_dir
+    rollout_dir = cfg.eval.rollout_dir
+    input_seq_length = cfg.model.input_seq_length
+    n_rollout_steps = cfg.eval.n_rollout_steps
+    if not osp.isabs(data_dir):
+        data_dir = osp.join(os.getcwd(), data_dir)
+
+    if ckp_dir is not None:
+        os.makedirs(ckp_dir, exist_ok=True)
+    if rollout_dir is not None:
+        os.makedirs(rollout_dir, exist_ok=True)
+
+    # dataloader
+    data_train = H5Dataset(
+        "train",
+        dataset_path=data_dir,
+        input_seq_length=input_seq_length,
+        extra_seq_length=cfg.optimizer.pushforward.unrolls[-1],
+    )
+    data_valid = H5Dataset(
+        "valid",
+        dataset_path=data_dir,
+        input_seq_length=input_seq_length,
+        extra_seq_length=n_rollout_steps,
+    )
+    data_test = H5Dataset(
+        "test",
+        dataset_path=data_dir,
+        input_seq_length=input_seq_length,
+        extra_seq_length=n_rollout_steps,
+    )
+
+    # TODO find another way to set these
+    if cfg.eval.n_trajs_train == -1:
+        cfg.eval.n_trajs_train = data_valid.num_samples
+    if cfg.eval.n_trajs_infer == -1:
+        cfg.eval.n_trajs_infer = data_valid.num_samples
+
+    assert data_valid.num_samples >= cfg.eval.n_trajs_train, (
+        f"eval.n_trajs_train ({cfg.eval.n_trajs_train}) exceeds the number of "
+        f"available validation trajectories ({data_valid.num_samples})"
+    )
+
+    return data_train, data_valid, data_test
+
+
+def setup_model(
+    cfg,
+    metadata: Dict,
+    homogeneous_particles: bool = False,
+    has_external_force: bool = False,
+    normalization_stats: Optional[Dict] = None,
+) -> Tuple[Callable, Type]:
+    """Setup model based on cfg."""
+    model_name = cfg.model.name.lower()
+
+    input_seq_length = cfg.model.input_seq_length
+    magnitude_features = cfg.train.magnitude_features
+
+    if model_name == "gns":
+
+        def model_fn(x):
+            return models.GNS(
+                particle_dimension=metadata["dim"],
+                num_particle_types=NodeType.SIZE,
+                particle_type_embedding_size=16,
+            )(x)
+
+        MODEL = models.GNS
+    elif model_name == "segnn":
+        # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type
+        node_feature_irreps = node_irreps(
+            metadata,
+            input_seq_length,
+            has_external_force,
+            magnitude_features,
+            homogeneous_particles,
+        )
+        # 1o displacement, 0e distance
+        edge_feature_irreps = Irreps("1x1o + 1x0e")
+
+        def model_fn(x):
+            return models.SEGNN(
+                node_features_irreps=node_feature_irreps,
+                edge_features_irreps=edge_feature_irreps,
+                output_irreps=Irreps("1x1o"),
+                n_vels=input_seq_length - 1,
+                homogeneous_particles=homogeneous_particles,
+            )(x)
+
+        MODEL = models.SEGNN
+    elif model_name == "egnn":
+        box = cfg.box
+        if jnp.array(metadata["periodic_boundary_conditions"]).any():
+            displacement_fn, shift_fn = space.periodic(jnp.array(box))
+        else:
+            displacement_fn, shift_fn = space.free()
+
+        displacement_fn = jax.vmap(displacement_fn, in_axes=(0, 0))
+        shift_fn = jax.vmap(shift_fn, in_axes=(0, 0))
+
+        def model_fn(x):
+            return models.EGNN(
+                output_size=1,
+                dt=metadata["dt"] * metadata["write_every"],
+                displacement_fn=displacement_fn,
+                shift_fn=shift_fn,
+                normalization_stats=normalization_stats,
+                n_vels=input_seq_length - 1,
+                residual=True,
+            )(x)
+
+        MODEL = models.EGNN
+    elif model_name == "painn":
+        assert magnitude_features, "PaiNN requires magnitudes"
+        radius = metadata["default_connectivity_radius"] * 1.5
+
+        def model_fn(x):
+            return models.PaiNN(
+                output_size=1,
+                n_vels=input_seq_length - 1,
+                radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True),
+                cutoff_fn=models.painn.cosine_cutoff(radius),
+            )(x)
+
+        MODEL = models.PaiNN
+    elif model_name == "linear":
+
+        def model_fn(x):
+            return models.Linear(dim_out=metadata["dim"])(x)
+
+        MODEL = models.Linear
+
+    return model_fn, MODEL
diff --git a/lagrangebench/train/strats.py b/lagrangebench/train/strats.py
index da47056..a585983 100644
--- a/lagrangebench/train/strats.py
+++ b/lagrangebench/train/strats.py
@@ -95,7 +95,7 @@ def push_forward_sample_steps(key, step, pushforward):
     key, key_unroll = jax.random.split(key, 2)
 
     # steps needs to be an ordered list
-    steps = jnp.array(pushforward["steps"])
+    steps = jnp.array(pushforward.steps)
     assert all(steps[i] <= steps[i + 1] for i in range(len(steps) - 1))
 
     # until which index to sample from
@@ -103,8 +103,8 @@ def push_forward_sample_steps(key, step, pushforward):
 
     unroll_steps = jax.random.choice(
         key_unroll,
-        a=jnp.array(pushforward["unrolls"][:idx]),
-        p=jnp.array(pushforward["probs"][:idx]),
+        a=jnp.array(pushforward.unrolls[:idx]),
+        p=jnp.array(pushforward.probs[:idx]),
     )
 
     return key, unroll_steps
diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py
index 322b6c5..057dc5d 100644
--- a/lagrangebench/train/trainer.py
+++ b/lagrangebench/train/trainer.py
@@ -2,24 +2,23 @@
 import os
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, Optional, Tuple
 
 import haiku as hk
 import jax
 import jax.numpy as jnp
 import jraph
 import optax
+import wandb
 from jax import vmap
 from torch.utils.data import DataLoader
-from wandb.wandb_run import Run
 
+from lagrangebench.config import cfg, cfg_to_dict
 from lagrangebench.data import H5Dataset
 from lagrangebench.data.utils import numpy_collate
-from lagrangebench.defaults import defaults
 from lagrangebench.evaluate import MetricsComputer, averaged_metrics, eval_rollout
 from lagrangebench.utils import (
     LossConfig,
-    PushforwardConfig,
     broadcast_from_batch,
     broadcast_to_batch,
     get_kinematic_mask,
@@ -40,18 +39,17 @@ def _mse(
     particle_type: jnp.ndarray,
     target: jnp.ndarray,
     model_fn: Callable,
-    loss_weight: LossConfig,
+    loss_weight: Dict[str, float],
 ):
     pred, state = model_fn(params, state, (features, particle_type))
     # check active (non zero) output shapes
-    keys = list(set(loss_weight.nonzero) & set(pred.keys()))
-    assert all(target[k].shape == pred[k].shape for k in keys)
+    assert all(target[k].shape == pred[k].shape for k in pred)
     # particle mask
     non_kinematic_mask = jnp.logical_not(get_kinematic_mask(particle_type))
     num_non_kinematic = non_kinematic_mask.sum()
     # loss components
     losses = []
-    for t in keys:
+    for t in pred:
         losses.append((loss_weight[t] * (pred[t] - target[t]) ** 2).sum(axis=-1))
     total_loss = jnp.array(losses).sum(0)
     total_loss = jnp.where(non_kinematic_mask, total_loss, 0)
@@ -94,26 +92,6 @@ def Trainer(
     case,
     data_train: H5Dataset,
     data_valid: H5Dataset,
-    pushforward: Optional[PushforwardConfig] = None,
-    metrics: List = ["mse"],
-    seed: int = defaults.seed,
-    batch_size: int = defaults.batch_size,
-    input_seq_length: int = defaults.input_seq_length,
-    noise_std: float = defaults.noise_std,
-    lr_start: float = defaults.lr_start,
-    lr_final: float = defaults.lr_final,
-    lr_decay_steps: int = defaults.lr_decay_steps,
-    lr_decay_rate: float = defaults.lr_decay_rate,
-    loss_weight: Optional[LossConfig] = None,
-    n_rollout_steps: int = defaults.n_rollout_steps,
-    eval_n_trajs: int = defaults.eval_n_trajs,
-    rollout_dir: str = defaults.rollout_dir,
-    out_type: str = defaults.out_type,
-    log_steps: int = defaults.log_steps,
-    eval_steps: int = defaults.eval_steps,
-    metrics_stride: int = defaults.metrics_stride,
-    num_workers: int = defaults.num_workers,
-    batch_size_infer: int = defaults.batch_size_infer,
 ) -> Callable:
     """
     Builds a function that automates model training and evaluation.
@@ -130,26 +108,6 @@ def Trainer(
         case: Case setup class.
         data_train: Training dataset.
         data_valid: Validation dataset.
-        pushforward: Pushforward configuration. None for no pushforward.
-        metrics: Metrics to evaluate the model on.
-        seed: Random seed for model init, training tricks and dataloading.
-        batch_size: Training batch size.
-        input_seq_length: Input sequence length. Default is 6.
-        noise_std: Noise standard deviation for the GNS-style noise.
-        lr_start: Initial learning rate.
-        lr_final: Final learning rate.
-        lr_decay_steps: Number of steps to reach the final learning rate.
-        lr_decay_rate: Learning rate decay rate.
-        loss_weight: Loss weight object.
-        n_rollout_steps: Number of autoregressive rollout steps.
-        eval_n_trajs: Number of trajectories to evaluate.
-        rollout_dir: Rollout directory.
-        out_type: Output type.
-        log_steps: Wandb/screen logging frequency.
-        eval_steps: Evaluation and checkpointing frequency.
-        metrics_stride: stride for e_kin and sinkhorn.
-        num_workers: number of workers for data loading.
-        batch_size_infer: batch size for validation/testing.
 
     Returns:
         Configured training function.
@@ -158,14 +116,23 @@ def Trainer(
     assert isinstance(
         model, hk.TransformedWithState
     ), "Model must be passed as a Haiku transformed function."
 
-    base_key, seed_worker, generator = set_seed(seed)
+    input_seq_length = cfg.model.input_seq_length
+    noise_std = cfg.optimizer.noise_std
+    n_rollout_steps = cfg.eval.n_rollout_steps
+    eval_n_trajs = cfg.eval.n_trajs_train
+    # make immutable for jitting
+    # TODO look for simpler alternatives to LossConfig
+    loss_weight = LossConfig(**dict(cfg.optimizer.loss_weight))
+    pushforward = cfg.optimizer.pushforward
+
+    base_key, seed_worker, generator = set_seed(cfg.main.seed)
 
     # dataloaders
     loader_train = DataLoader(
         dataset=data_train,
-        batch_size=batch_size,
+        batch_size=cfg.train.batch_size,
         shuffle=True,
-        num_workers=num_workers,
+        num_workers=cfg.train.num_workers,
         collate_fn=numpy_collate,
         drop_last=True,
         worker_init_fn=seed_worker,
@@ -173,7 +140,7 @@ def Trainer(
     )
     loader_valid = DataLoader(
         dataset=data_valid,
-        batch_size=batch_size_infer,
+        batch_size=cfg.eval.batch_size_infer,
         collate_fn=numpy_collate,
         worker_init_fn=seed_worker,
         generator=generator,
@@ -181,37 +148,30 @@ def Trainer(
 
     # learning rate decays from lr_start to lr_final over lr_decay_steps exponentially
     lr_scheduler = optax.exponential_decay(
-        init_value=lr_start,
-        transition_steps=lr_decay_steps,
-        decay_rate=lr_decay_rate,
-        end_value=lr_final,
+        init_value=cfg.optimizer.lr_start,
+        transition_steps=cfg.optimizer.lr_decay_steps,
+        decay_rate=cfg.optimizer.lr_decay_rate,
+        end_value=cfg.optimizer.lr_final,
     )
     # optimizer
     opt_init, opt_update = optax.adamw(learning_rate=lr_scheduler, weight_decay=1e-8)
 
-    # loss config
-    loss_weight = LossConfig() if loss_weight is None else LossConfig(**loss_weight)
-    # pushforward config
-    if pushforward is None:
-        pushforward = PushforwardConfig()
-
     # metrics computer config
     metrics_computer = MetricsComputer(
-        metrics,
+        cfg.eval.metrics_train,
         dist_fn=case.displacement,
         metadata=data_train.metadata,
-        input_seq_length=data_train.input_seq_length,
-        stride=metrics_stride,
+        input_seq_length=input_seq_length,
+        stride=cfg.eval.metrics_stride_train,
     )
 
     def _train(
-        step_max: int = defaults.step_max,
+        step_max: Optional[int] = None,
         params: Optional[hk.Params] = None,
         state: Optional[hk.State] = None,
         opt_state: Optional[optax.OptState] = None,
         store_checkpoint: Optional[str] = None,
         load_checkpoint: Optional[str] = None,
-        wandb_run: Optional[Run] = None,
     ) -> Tuple[hk.Params, hk.State, optax.OptState]:
         """
         Training loop.
@@ -226,7 +186,6 @@ def _train(
             opt_state: Optional optimizer state.
             store_checkpoint: Checkpoints destination. Without it params aren't saved.
             load_checkpoint: Initial checkpoint directory. If provided resumes training.
-            wandb_run: Wandb run.
 
         Returns:
             Tuple containing the final model parameters, state and optimizer state.
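The `Trainer` refactor follows the same pattern: hyperparameters that used to be keyword arguments are now read from the global `cfg` when the trainer is built. A rough sketch of a migrated training setup (the override values are illustrative, not shipped defaults; `model`, `case` and the datasets are constructed as in `runner.py`):

```python
from lagrangebench import Trainer
from lagrangebench.config import cfg

# formerly: Trainer(..., batch_size=2, lr_start=5e-4, eval_n_trajs=10)
cfg.train.batch_size = 2
cfg.optimizer.lr_start = 5e-4
cfg.eval.n_trajs_train = 10

trainer = Trainer(model, case, data_train, data_valid)
params, state, opt_state = trainer(step_max=cfg.train.step_max)
```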
@@ -239,6 +198,9 @@ def _train(
             loader_valid
         ), "eval_n_trajs must be <= len(loader_valid)"
 
+        if step_max is None:
+            step_max = cfg.train.step_max
+
         # Precompile model for evaluation
         model_apply = jax.jit(model.apply)
@@ -264,9 +226,26 @@ def _train(
             key, subkey = jax.random.split(key, 2)
             params, state = model.init(subkey, (features, particle_type[0]))
 
-        if wandb_run is not None:
-            wandb_run.log({"info/num_params": get_num_params(params)}, 0)
-            wandb_run.log({"info/step_start": step}, 0)
+        # start logging
+        if cfg.logging.wandb:
+            cfg_dict = cfg_to_dict(cfg)
+            cfg_dict["info"] = {
+                "dataset_name": data_train.name,
+                "len_train": len(data_train),
+                "len_eval": len(data_valid),
+                "num_params": get_num_params(params).item(),
+                "step_start": step,
+            }
+
+            wandb_run = wandb.init(
+                project=cfg.logging.wandb_project,
+                entity=cfg.logging.wandb_entity,
+                name=cfg.logging.run_name,
+                config=cfg_dict,
+                save_code=True,
+            )
+        else:
+            wandb_run = None
 
         # initialize optimizer state
         if opt_state is None:
@@ -340,7 +319,7 @@ def _train(
                 opt_state=opt_state,
             )
 
-            if step % log_steps == 0:
+            if step % cfg.logging.log_steps == 0:
                 loss.block_until_ready()
                 if wandb_run:
                     wandb_run.log({"train/loss": loss.item()}, step)
@@ -348,7 +327,7 @@ def _train(
                 step_str = str(step).zfill(len(str(int(step_max))))
                 print(f"{step_str}, train/loss: {loss.item():.5f}.")
 
-            if step % eval_steps == 0 and step > 0:
+            if step % cfg.logging.eval_steps == 0 and step > 0:
                 nbrs = broadcast_from_batch(neighbors_batch, index=0)
                 eval_metrics = eval_rollout(
                     case=case,
@@ -360,8 +339,8 @@ def _train(
                     loader_eval=loader_valid,
                     n_rollout_steps=n_rollout_steps,
                     n_trajs=eval_n_trajs,
-                    rollout_dir=rollout_dir,
-                    out_type=out_type,
+                    rollout_dir=cfg.eval.rollout_dir,
+                    out_type=cfg.eval.out_type_train,
                 )
 
                 metrics = averaged_metrics(eval_metrics)
@@ -383,6 +362,9 @@ def _train(
             if step == step_max + 1:
                 break
 
+        if cfg.logging.wandb:
+            wandb.finish()
+
         return params, state, opt_state
 
     return _train
diff --git a/lagrangebench/utils.py b/lagrangebench/utils.py
index 9589e39..3cef392 100644
--- a/lagrangebench/utils.py
+++ b/lagrangebench/utils.py
@@ -5,8 +5,8 @@
 import os
 import pickle
 import random
-from dataclasses import dataclass, field
-from typing import Callable, List, Tuple
+from dataclasses import dataclass
+from typing import Callable, Tuple
 
 import cloudpickle
 import jax
@@ -171,27 +171,5 @@ class LossConfig:
     vel: float = 0.0
     acc: float = 1.0
 
-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    @property
-    def nonzero(self):
-        return [field for field in self.__annotations__ if self[field] != 0]
-
-
-@dataclass(frozen=False)
-class PushforwardConfig:
-    """Pushforward trick configuration.
-
-    Attributes:
-        steps: When to introduce each unroll stage, e.g. [-1, 20000, 50000]
-        unrolls: For how many timesteps to unroll, e.g. [0, 1, 20]
-        probs: Probability (ratio) between the relative unrolls, e.g. [5, 4, 1]
-    """
-
-    steps: List[int] = field(default_factory=lambda: [-1])
-    unrolls: List[int] = field(default_factory=lambda: [0])
-    probs: List[float] = field(default_factory=lambda: [1.0])
-
-    def __getitem__(self, item):
-        return getattr(self, item)
+    def __getitem__(self, key):
+        return getattr(self, key)
diff --git a/main.py b/main.py
index bc25a09..89fde31 100644
--- a/main.py
+++ b/main.py
@@ -1,38 +1,91 @@
+import argparse
 import os
-import pprint
-from argparse import Namespace
+from typing import List
 
-import yaml
-
-from experiments.config import NestedLoader, cli_arguments
+
+def cli_arguments():
+    """Inspired by https://stackoverflow.com/a/51686813"""
+    parser = argparse.ArgumentParser()
+
+    # config arguments
+    parser.add_argument("-c", "--config", type=str, help="Path to the config yaml.")
+
+    args, extras = parser.parse_known_args()
+    if extras is None:
+        extras = []
+    args.extra = preprocess_extras(extras)
+
+    return args
+
+
+def preprocess_extras(extras: List[str]):
+    """Preprocess extras.
+
+    Args:
+        extras: key value pairs in any of the following formats:
+            `--key value`, `--key=value`, `key value`, `key=value`
+
+    Returns:
+        All key value pairs formatted as `key value`
+    """
+
+    temp = []
+    for arg in extras:
+        if arg.startswith("--"):  # remove preceding "--"
+            arg = arg[2:]
+        temp += arg.split("=")  # split key value pairs
+
+    return temp
+
+
+def import_cfg(config_path, extras):
+    """Import cfg without executing lagrangebench.__init__().
+
+    Based on:
+    https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
+    """
+    import importlib.util
+
+    spec = importlib.util.spec_from_file_location("temp", "lagrangebench/config.py")
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+
+    cfg = module.cfg
+    load_cfg = module.load_cfg
+    load_cfg(cfg, config_path, extras)
+    return cfg
+
 
 if __name__ == "__main__":
     cli_args = cli_arguments()
-    if "config" in cli_args:  # to (re)start training
-        config_path = cli_args["config"]
-    elif "model_dir" in cli_args:  # to run inference
-        config_path = os.path.join(cli_args["model_dir"], "config.yaml")
 
-    with open(config_path, "r") as f:
-        args = yaml.load(f, NestedLoader)
+    if cli_args.config is not None:  # start from config.yaml
+        config_path = cli_args.config.strip()
+    elif "model.model_dir" in cli_args.extra:  # start from a checkpoint
+        model_dir = cli_args.extra[cli_args.extra.index("model.model_dir") + 1]
+        config_path = os.path.join(model_dir, "config.yaml")
 
-    # priority to command line arguments
-    args.update(cli_args)
-    args = Namespace(config=Namespace(**args), info=Namespace())
-    print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
-    pprint.pprint(vars(args.config))
-    print("#" * 79)
+    # load cfg without executing lagrangebench.__init__() -> temporary cfg for cuda
+    cfg = import_cfg(config_path, cli_args.extra)
 
     # specify cuda device
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152 from TensorFlow
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.config.gpu)
-    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(args.config.xla_mem_fraction)
-
-    if args.config.f64:
+    os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.main.gpu)
+    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.main.xla_mem_fraction)
+    if cfg.main.dtype == "float64":
         from jax import config
 
         config.update("jax_enable_x64", True)
 
-    from experiments.run import train_or_infer
+    # load cfg once again, this time executing lagrangebench.__init__() -> global cfg
+    from lagrangebench.config import cfg, load_cfg
+
+    load_cfg(cfg, config_path, cli_args.extra)
+
+    print("#" * 79, "\nStarting a LagrangeBench run with the following configs:")
+    print(cfg.dump())
+    print("#" * 79)
+
+    from lagrangebench.runner import train_or_infer
 
-    train_or_infer(args)
+    train_or_infer(cfg)
diff --git a/neighbors_search/scaling.py b/neighbors_search/scaling.py
index cd47352..adb4434 100644
--- a/neighbors_search/scaling.py
+++ b/neighbors_search/scaling.py
@@ -28,7 +28,7 @@ def update_wrapper(neighbors_old, r_new):
 
 def compute_neighbors(args):
     Nx = args.Nx
-    mode = args.mode
+    mode = args.main.mode
     nl_backend = args.nl_backend
     num_partitions = args.num_partitions
     print(f"Start with Nx={Nx}, mode={mode}, backend={nl_backend}")
diff --git a/poetry.lock b/poetry.lock
index fdc1a02..43a7c40 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1745,8 +1745,8 @@ files = [
 [package.dependencies]
 numpy = [
     {version = ">=1.23.3", markers = "python_version > \"3.10\""},
-    {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""},
     {version = ">1.20", markers = "python_version <= \"3.9\""},
+    {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""},
 ]
 
 [package.extras]
@@ -3139,6 +3139,17 @@ files = [
 ml-dtypes = ">=0.3.1"
 numpy = ">=1.16.0"
 
+[[package]]
+name = "toml"
+version = "0.10.2"
+description = "Python Library for Tom's Obvious, Minimal Language"
+optional = false
+python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
+files = [
+    {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
+    {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
+]
+
 [[package]]
 name = "tomli"
 version = "2.0.1"
@@ -3395,6 +3406,21 @@ files = [
     {file = "wget-3.2.zip", hash = "sha256:35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061"},
 ]
 
+[[package]]
+name = "yacs"
+version = "0.1.8"
+description = "Yet Another Configuration System"
+optional = false
+python-versions = "*"
+files = [
+    {file = "yacs-0.1.8-py2-none-any.whl", hash = "sha256:d43d1854c1ffc4634c5b349d1c1120f86f05c3a294c9d141134f282961ab5d94"},
+    {file = "yacs-0.1.8-py3-none-any.whl", hash = "sha256:99f893e30497a4b66842821bac316386f7bd5c4f47ad35c9073ef089aa33af32"},
+    {file = "yacs-0.1.8.tar.gz", hash = "sha256:efc4c732942b3103bea904ee89af98bcd27d01f0ac12d8d4d369f1e7a2914384"},
+]
+
+[package.dependencies]
+PyYAML = "*"
+
 [[package]]
 name = "zipp"
 version = "3.17.0"
@@ -3413,4 +3439,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<=3.11"
-content-hash = "5fc2e88ec569a667ab5076bf43acf88c3bf3d7d359756359b31a9ccdd25148d7"
+content-hash = "9fe394e52f5db4b405b0ab8f8ba4d444ca4cacc7b87ee2839fc3025ee01ecb09"
diff --git a/pyproject.toml b/pyproject.toml
index 8dabb99..42b1bf4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,8 @@ ott-jax = "^0.4.2"
 matscipy = "^0.8.0"
 torch = {version = "2.1.0+cpu", source = "torchcpu"}
 wget = "^3.2"
+yacs = "^0.1.8"
+toml = "^0.10.2"
 
 [tool.poetry.group.dev.dependencies]
 # mypy = ">=1.8.0" - consider in the future
@@ -73,7 +75,6 @@ url = "https://download.pytorch.org/whl/cpu"
 priority = "explicit"
 
 [tool.ruff]
-ignore = ["F811", "E402"]
 exclude = [
     ".git",
     ".venv",
@@ -85,6 +86,7 @@ show-fixes = true
 line-length = 88
 
 [tool.ruff.lint]
+ignore = ["F811", "E402"]
 select = [
     "E",  # pycodestyle
     "F",  # Pyflakes
@@ -93,6 +95,9 @@ select = [
     # "D",  # pydocstyle - consider in the future
 ]
 
+[tool.ruff.lint.isort]
+known-third-party = ["wandb"]
+
 [tool.pytest.ini_options]
 testpaths = "tests/"
 addopts = "--cov=lagrangebench --cov-fail-under=50"
diff --git a/requirements_cuda.txt b/requirements_cuda.txt
index 0bc59df..3c969e3 100644
--- a/requirements_cuda.txt
+++ b/requirements_cuda.txt
@@ -18,3 +18,4 @@ PyYAML
 torch==2.1.0+cpu
 wandb
 wget
+yacs>=0.1.8
diff --git a/tests/case_test.py b/tests/case_test.py
index 373eb28..d3cb33f 100644
--- a/tests/case_test.py
+++ b/tests/case_test.py
@@ -5,6 +5,14 @@
 import numpy as np
 
 from lagrangebench.case_setup import case_builder
+from lagrangebench.config import custom_config
+
+
+@custom_config
+def case_test_config(cfg):
+    cfg.model.input_seq_length = 3  # two past velocities
+    cfg.train.isotropic_norm = False
+    cfg.optimizer.noise_std = 0.0
 
 
 class TestCaseBuilder(unittest.TestCase):
@@ -25,14 +33,7 @@ def setUp(self):
         bounds = np.array(self.metadata["bounds"])
         box = bounds[:, 1] - bounds[:, 0]
 
-        self.case = case_builder(
-            box,
-            self.metadata,
-            input_seq_length=3,  # two past velocities
-            isotropic_norm=False,
-            noise_std=0.0,
-            external_force_fn=None,
-        )
+        self.case = case_builder(box, self.metadata, external_force_fn=None)
 
         self.key = jax.random.PRNGKey(0)
         # position input shape (num_particles, sequence_len, dim) = (3, 5, 3)
@@ -63,7 +64,7 @@ def setUp(self):
         )
         self.particle_types = np.array([0, 0, 0])
 
-        key, features, target_dict, neighbors = self.case.allocate(
+        _, _, _, neighbors = self.case.allocate(
             self.key, (self.position_data, self.particle_types)
         )
         self.neighbors = neighbors
diff --git a/tests/models_test.py b/tests/models_test.py
index 702280c..ec6d20e 100644
--- a/tests/models_test.py
+++ b/tests/models_test.py
@@ -7,9 +7,19 @@
 import numpy as np
 
 from lagrangebench import models
+from lagrangebench.config import custom_config
 from lagrangebench.utils import NodeType
 
 
+@custom_config
+def model_test_config(cfg):
+    cfg.model.latent_dim = 8
+    cfg.model.output_size = 1
+    cfg.model.num_mp_steps = 1
+    cfg.model.lmax_attributes = 1
+    cfg.model.lmax_hidden = 1
+
+
 class ModelTest(unittest.TestCase):
     def dummy_sample(self, vel=None, pos=None):
         key = self.key()
@@ -72,11 +82,7 @@ def segnn(x):
             return models.SEGNN(
                 node_features_irreps="5x1o + 5x0e",
                 edge_features_irreps="1x1o + 1x0e",
-                scalar_units=8,
-                lmax_hidden=1,
-                lmax_attributes=1,
                 n_vels=5,
-                num_mp_steps=1,
                 output_irreps="1x1o",
             )(x)
 
@@ -89,9 +95,7 @@ def segnn(x):
     def test_egnn(self):
         def egnn(x):
             return models.EGNN(
-                hidden_size=8,
                 output_size=1,
-                num_mp_steps=1,
                 dt=0.01,
                 n_vels=5,
                 displacement_fn=lambda x, y: x - y,
@@ -107,9 +111,7 @@ def egnn(x):
     def test_painn(self):
         def painn(x):
             return models.PaiNN(
-                hidden_size=8,
                 output_size=1,
-                num_mp_steps=1,
                 radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True),
                 cutoff_fn=models.painn.cosine_cutoff(10),
                 n_vels=5,
diff --git a/tests/pushforward_test.py b/tests/pushforward_test.py
index 06d77d8..ff379b1 100644
--- a/tests/pushforward_test.py
+++ b/tests/pushforward_test.py
@@ -3,19 +3,22 @@
 import jax
 import numpy as np
 
-from lagrangebench import PushforwardConfig
+from lagrangebench.config import cfg, custom_config
 from lagrangebench.train.strats import push_forward_sample_steps
 
 
+@custom_config
+def pf_test_config(cfg):
+    cfg.optimizer.pushforward.steps = [-1, 20000, 50000, 100000]
+    cfg.optimizer.pushforward.unrolls = [0, 1, 3, 20]
+    cfg.optimizer.pushforward.probs = [4.05, 4.05, 1.0, 1.0]
+
+
 class TestPushForward(unittest.TestCase):
     """Class for unit testing the push-forward functions."""
 
     def setUp(self):
-        self.pf = PushforwardConfig(
-            steps=[-1, 20000, 50000, 100000],
-            unrolls=[0, 1, 3, 20],
-            probs=[4.05, 4.05, 1.0, 1.0],
-        )
+        self.pf = cfg.optimizer.pushforward
 
         self.key = jax.random.PRNGKey(42)
diff --git a/tests/rollout_test.py b/tests/rollout_test.py
index a559c48..a557b56 100644
--- a/tests/rollout_test.py
+++ b/tests/rollout_test.py
@@ -1,5 +1,4 @@
 import unittest
-from argparse import Namespace
 from functools import partial
 
 import haiku as hk
@@ -14,32 +13,35 @@
 jax_config.update("jax_enable_x64", True)
 
 from lagrangebench.case_setup import case_builder
+from lagrangebench.config import cfg, custom_config
 from lagrangebench.data import H5Dataset
 from lagrangebench.data.utils import get_dataset_stats, numpy_collate
 from lagrangebench.evaluate import MetricsComputer
-from lagrangebench.evaluate.rollout import _forward_eval, eval_batched_rollout
+from lagrangebench.evaluate.rollout import _eval_batched_rollout, _forward_eval
 from lagrangebench.utils import broadcast_from_batch
 
 
+@custom_config
+def eval_test_config(cfg):
+    # setup the configuration
+    cfg.main.data_dir = "tests/3D_LJ_3_1214every1"  # Lennard-Jones dataset
+    cfg.model.input_seq_length = 3
+    cfg.metrics = ["mse"]
+    cfg.eval.n_rollout_steps = 100
+    cfg.train.isotropic_norm = False
+    cfg.optimizer.noise_std = 0.0
+
+
 class TestInferBuilder(unittest.TestCase):
     """Class for unit testing the evaluate_single_rollout function."""
 
     def setUp(self):
-        self.config = Namespace(
-            data_dir="tests/3D_LJ_3_1214every1",  # Lennard-Jones dataset
-            input_seq_length=3,  # two past velocities
-            metrics=["mse"],
-            n_rollout_steps=100,
-            isotropic_norm=False,
-            noise_std=0.0,
-        )
-
         data_valid = H5Dataset(
             split="valid",
-            dataset_path=self.config.data_dir,
+            dataset_path=cfg.main.data_dir,
             name="lj3d",
-            input_seq_length=self.config.input_seq_length,
-            extra_seq_length=self.config.n_rollout_steps,
+            input_seq_length=cfg.model.input_seq_length,
+            extra_seq_length=cfg.eval.n_rollout_steps,
         )
         self.loader_valid = DataLoader(
             dataset=data_valid, batch_size=1, collate_fn=numpy_collate
@@ -47,19 +49,14 @@ def setUp(self):
         )
         self.metadata = data_valid.metadata
 
         self.normalization_stats = get_dataset_stats(
-            self.metadata, self.config.isotropic_norm, self.config.noise_std
+            self.metadata, cfg.train.isotropic_norm, cfg.optimizer.noise_std
         )
 
         bounds = np.array(self.metadata["bounds"])
         box = bounds[:, 1] - bounds[:, 0]
         self.displacement_fn, self.shift_fn = space.periodic(side=box)
 
-        self.case = case_builder(
-            box,
-            self.metadata,
-            self.config.input_seq_length,
-            noise_std=self.config.noise_std,
-        )
+        self.case = case_builder(box, self.metadata)
 
         self.key = jax.random.PRNGKey(0)
 
@@ -139,7 +136,7 @@ def model(x):
 
         for n_extrap_steps in [0, 5, 10]:
             with self.subTest(n_extrap_steps):
-                example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout(
+                example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout(
                     forward_eval_vmap=forward_eval_vmap,
                     preprocess_eval_vmap=preprocess_eval_vmap,
                     case=self.case,
@@ -148,7 +145,7 @@ def model(x):
                     traj_batch_i=traj_batch_i,
                     neighbors=neighbors,
                     metrics_computer_vmap=metrics_computer_vmap,
-                    n_rollout_steps=self.config.n_rollout_steps,
+                    n_rollout_steps=cfg.eval.n_rollout_steps,
                     n_extrap_steps=n_extrap_steps,
                     t_window=isl,
                 )
@@ -183,7 +180,7 @@ def model(x):
                 "Wrong rollout prediction",
             )
 
-            total_steps = self.config.n_rollout_steps + n_extrap_steps
+            total_steps = cfg.eval.n_rollout_steps + n_extrap_steps
             assert example_rollout_batch.shape[1] == total_steps
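For reference, the reworked `main.py` above accepts config overrides in any of the formats handled by `preprocess_extras`. Hypothetical invocations (the config file name is an example, not a shipped file):

```bash
# equivalent override spellings: --key value, --key=value, key value, key=value
python main.py -c configs/my_config.yaml main.gpu=0 eval.n_rollout_steps=20

# restart from a checkpoint; the config is read from <model_dir>/config.yaml
python main.py model.model_dir=ckp/my_run
```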