Merge pull request #21 from tumaer/unit_tests
Unit tests and most of batched rollout
arturtoshev authored Jan 8, 2024
2 parents d213ba1 + e10922c commit af8d7cd
Showing 31 changed files with 1,037 additions and 257 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/publish.yml
@@ -0,0 +1,33 @@
name: Publish to PyPI

on:
  release:
    types: [created]

jobs:
  build-n-publish:
    name: Build and publish to PyPI
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install Poetry
        run: |
          python -m pip install --upgrade pip
          pip install poetry
          poetry config virtualenvs.in-project true
      - name: Install dependencies
        run: |
          poetry install
      - name: Build and publish
        env:
          POETRY_PYPI_TOKEN_PYPI: ${{ secrets.POETRY_PYPI_TOKEN_PYPI }}
        run: |
          poetry version $(git describe --tags --abbrev=0)
          poetry add $(cat requirements.txt)
          poetry build
          poetry publish
8 changes: 8 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,8 @@
name: Ruff
on: [push, pull_request]
jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: chartboost/ruff-action@v1
43 changes: 43 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,43 @@
name: Tests

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]

permissions:
  contents: read

jobs:
  tests:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install Poetry
        run: |
          python -m pip install --upgrade pip
          pip install poetry
          poetry config virtualenvs.in-project true
      - name: Install dependencies
        run: |
          poetry install
      - name: Run pytest and generate coverage report
        run: |
          .venv/bin/pytest --cov-report=xml
      - name: Upload coverage report to Codecov
        uses: codecov/codecov-action@v3
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: ./coverage.xml
          flags: unittests
          verbose: true
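The workflow above runs the new unit test suite via pytest with coverage reporting. For flavor, here is a minimal, self-contained test of the kind such a suite contains; the `displacement` helper is an illustrative stand-in, not code from this repository:

```python
# Illustrative pytest unit test; `displacement` is a toy stand-in.
import jax.numpy as jnp


def displacement(a, b):
    return a - b


def test_displacement_is_antisymmetric():
    a = jnp.array([0.1, 0.2])
    b = jnp.array([0.4, 0.9])
    assert jnp.allclose(displacement(a, b), -displacement(b, a))
```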
1 change: 1 addition & 0 deletions .gitignore
@@ -19,6 +19,7 @@ venv*/
rollouts
profile
dist
+.coverage

# Sphinx documentation
docs/_build/
16 changes: 4 additions & 12 deletions .pre-commit-config.yaml
@@ -21,17 +21,9 @@ repos:
  - id: trailing-whitespace
  - id: end-of-file-fixer
  - id: requirements-txt-fixer
-  - repo: https://github.com/pycqa/isort
-    rev: 5.12.0
-    hooks:
-      - id: isort
-        args: [ --profile, black ]
-  - repo: https://github.com/ambv/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        args: ['--config=./pyproject.toml']
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.265'
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: 'v0.1.8'
    hooks:
      - id: ruff
        args: [ --fix ]
+      - id: ruff-format
10 changes: 5 additions & 5 deletions README.md
@@ -39,18 +39,18 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/
```

### MacOS
-Currently, only the CPU installation works. You will need to change a few small things to get it going:
+Currently, only the CPU installation works. You will need to change a few small things to get it going:
- Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`.
- Configs: You will need to set `f64: False` and `num_workers: 0` in the `configs/` files.

Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports jax in general, there seems to be a missing feature used by `jax-md` related to padding -> see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071).

## Usage
### Standalone benchmark library
-A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating it's performance.
+A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance.

### Running in a local clone (`main.py`)
-Alternatively, experiments can also be set up with `main.py`, based around extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).
+Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults).

When loading a saved model with `--model_dir` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file.
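The priority order stated above (cli arguments over YAML over `lagrangebench` defaults) can be sketched as a dictionary merge; the keys below are illustrative, not the actual config schema:

```python
# Sketch of the config precedence, assuming illustrative keys.
defaults = {"batch_size_infer": 2, "seed": 0}  # lagrangebench defaults
yaml_cfg = {"seed": 42}                        # from a configs/*.yaml file
cli_args = {"batch_size_infer": 4}             # passed on the command line
config = {**defaults, **yaml_cfg, **cli_args}  # later sources win
assert config == {"batch_size_infer": 4, "seed": 42}
```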

@@ -94,8 +94,8 @@ The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https


### Notebooks
-Whe provide three notebooks that show LagrangeBench functionalities, namely:
-- [`tutorial.ipynb`](notebooks/tutorial.ipynb) with a general overview of LagrangeBench library, with trainin and evaluation of a simple GNS model,
+We provide three notebooks that show LagrangeBench functionalities, namely:
+- [`tutorial.ipynb`](notebooks/tutorial.ipynb) with a general overview of LagrangeBench library, with training and evaluation of a simple GNS model,
- [`datasets.ipynb`](notebooks/datasets.ipynb) with more details and visualizations on the datasets, and
- [`gns_data.ipynb`](notebooks/gns_data.ipynb) showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405).

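Putting together the pieces this README refers to, a rough sketch of the train-then-evaluate flow follows. The keywords `batch_size_infer` and `batch_size` appear in this PR's changes to `experiments/run.py`; the positional arguments and all placeholders are assumptions, not documented API:

```python
# Sketch only: model/case/dataset construction (as done by experiments.utils'
# setup_data/setup_model) is elided, and exact signatures are assumed.
from lagrangebench import Trainer, infer

model, case, data_train, data_valid = ...  # dataset- and model-specific setup

trainer = Trainer(
    model,
    case,
    data_train,
    data_valid,
    batch_size_infer=2,  # new in this PR: batch size for validation/testing
)
params, state, _ = trainer(step_max=500_000)

metrics = infer(
    model,
    case,
    data_valid,
    load_checkpoint="ckp/some_run",  # hypothetical checkpoint directory
    batch_size=2,                    # new in this PR
)
```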
2 changes: 2 additions & 0 deletions configs/defaults.yaml
@@ -114,3 +114,5 @@ metrics_infer:
metrics_stride_infer: 1
out_type_infer: pkl
eval_n_trajs_infer: -1
+# batch size for validation/testing
+batch_size_infer: 2
4 changes: 2 additions & 2 deletions docs/requirements.txt
@@ -2,7 +2,7 @@

cloudpickle
dm_haiku>=0.0.10
-e3nn_jax>=0.20.0
+e3nn_jax==0.20.3
h5py
jax[cpu]==0.4.20
jax_md>=0.2.8
@@ -15,6 +15,6 @@ pyvista
PyYAML
sphinx==7.2.6
sphinx-rtd-theme==1.3.0
-torch>=2.1.0+cpu
+torch==2.1.0+cpu
wandb
wget
6 changes: 4 additions & 2 deletions experiments/run.py
@@ -8,9 +8,9 @@
import jax.numpy as jnp
import jmp
import numpy as np
-import wandb
import yaml

+import wandb
from experiments.utils import setup_data, setup_model
from lagrangebench import Trainer, infer
from lagrangebench.case_setup import case_builder
@@ -123,6 +123,7 @@ def train_or_infer(args: Namespace):
        eval_steps=args.config.eval_steps,
        metrics_stride=args.config.metrics_stride,
        num_workers=args.config.num_workers,
+        batch_size_infer=args.config.batch_size_infer,
    )
    _, _, _ = trainer(
        step_max=args.config.step_max,
@@ -150,7 +151,7 @@ def train_or_infer(args: Namespace):
    metrics = infer(
        model,
        case,
-        data_test,
+        data_test if args.config.test else data_valid,
        load_checkpoint=args.config.model_dir,
        metrics=args.config.metrics_infer,
        rollout_dir=args.config.rollout_dir,
@@ -160,6 +161,7 @@
        n_extrap_steps=args.config.n_extrap_steps,
        seed=args.config.seed,
        metrics_stride=args.config.metrics_stride_infer,
+        batch_size=args.config.batch_size_infer,
    )

    split = "test" if args.config.test else "valid"
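The wiring above threads the new `batch_size_infer` option from the config into both `Trainer` and `infer`. What "batched rollout" (the PR title) buys can be illustrated with a toy example; the physics step below is a stand-in, not lagrangebench's rollout code:

```python
# Toy batched rollout: vmap a scan-based unroll over a batch of initial states.
import jax
import jax.numpy as jnp


def rollout(positions, n_steps=5):
    def body(pos, _):
        nxt = pos + 0.01  # stand-in for one learned/solver update
        return nxt, nxt

    _, traj = jax.lax.scan(body, positions, None, length=n_steps)
    return traj  # (n_steps, num_particles, dim)


batched_rollout = jax.vmap(rollout)  # evaluate several trajectories at once
init = jnp.zeros((2, 100, 2))        # (batch, num_particles, dim)
trajs = batched_rollout(init)        # (batch, n_steps, num_particles, dim)
```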
10 changes: 4 additions & 6 deletions lagrangebench/case_setup/case.py
@@ -3,7 +3,7 @@
from typing import Callable, Dict, Optional, Tuple, Union

import jax.numpy as jnp
-from jax import jit, lax, random, vmap
+from jax import Array, jit, lax, vmap
from jax_md import space
from jax_md.dataclasses import dataclass, static_field
from jax_md.partition import NeighborList, NeighborListFormat
@@ -15,16 +15,14 @@
from .features import FeatureDict, TargetDict, physical_feature_builder
from .partition import neighbor_list

-TrainCaseOut = Tuple[random.KeyArray, FeatureDict, TargetDict, NeighborList]
+TrainCaseOut = Tuple[Array, FeatureDict, TargetDict, NeighborList]
EvalCaseOut = Tuple[FeatureDict, NeighborList]
SampleIn = Tuple[jnp.ndarray, jnp.ndarray]

-AllocateFn = Callable[[random.KeyArray, SampleIn, float, int], TrainCaseOut]
+AllocateFn = Callable[[Array, SampleIn, float, int], TrainCaseOut]
AllocateEvalFn = Callable[[SampleIn], EvalCaseOut]

-PreprocessFn = Callable[
-    [random.KeyArray, SampleIn, float, NeighborList, int], TrainCaseOut
-]
+PreprocessFn = Callable[[Array, SampleIn, float, NeighborList, int], TrainCaseOut]
PreprocessEvalFn = Callable[[SampleIn, NeighborList], EvalCaseOut]

IntegrateFn = Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]
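The type-alias updates track JAX's deprecation of `random.KeyArray`: in recent JAX versions, PRNG keys are typed as plain `jax.Array`. A minimal illustration:

```python
import jax
from jax import Array

key: Array = jax.random.PRNGKey(0)  # PRNG keys are ordinary JAX arrays
key, subkey = jax.random.split(key)
sample = jax.random.normal(subkey, (3,))
```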
2 changes: 1 addition & 1 deletion lagrangebench/case_setup/partition.py
@@ -300,7 +300,7 @@ def scan_body(carry, input):
    if not is_sparse(format):
        capacity_limit = N - 1 if mask_self else N
    elif format is NeighborListFormat.Sparse:
-        capacity_limit = N * (N - 1) if mask_self else N ** 2
+        capacity_limit = N * (N - 1) if mask_self else N**2
    else:
        capacity_limit = N * (N - 1) // 2
    if max_occupancy > capacity_limit:
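For context, the reformatted line sits in the neighbor-list overflow check; the capacity logic visible in this hunk can be summarized as a small helper (illustrative, not the repo's API):

```python
# Mirror of the capacity limits above for a neighbor list over N particles.
def capacity_limit(N: int, fmt: str, mask_self: bool) -> int:
    if fmt == "Dense":  # neighbors stored per particle
        return N - 1 if mask_self else N
    if fmt == "Sparse":  # directed (i, j) pairs
        return N * (N - 1) if mask_self else N**2
    return N * (N - 1) // 2  # OrderedSparse: each unordered pair once
```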
2 changes: 1 addition & 1 deletion lagrangebench/data/data.py
@@ -242,7 +242,7 @@ def get_window(self, idx: int):
    def __getitem__(self, idx: int):
        """
        Get a sequence of positions (of size windows) from the dataset at index idx.
        Returns:
            Array of shape (num_particles_max, input_seq_length + 1, dim). Along axis=1
            the position sequence (length input_seq_length) and the last position to
Expand Down
1 change: 1 addition & 0 deletions lagrangebench/defaults.py
@@ -59,6 +59,7 @@ class defaults:
    out_type: str = "none"  # type of output. None means no rollout is stored
    n_extrap_steps: int = 0  # number of extrapolation steps
    metrics_stride: int = 10  # stride for e_kin and sinkhorn
+    batch_size_infer: int = 2  # batch size for validation/testing

    # logging
    log_steps: int = 1000  # number of steps between logs
2 changes: 1 addition & 1 deletion lagrangebench/evaluate/metrics.py
@@ -45,7 +45,7 @@ def __init__(
            metadata: Metadata of the dataset.
            loss_ranges: List of horizon lengths to compute the loss for.
            input_seq_length: Length of the input sequence.
-            stride: Rollout subsample frequency for Sinkhorn.
+            stride: Rollout subsample frequency for e_kin and sinkhorn.
            ot_backend: Backend for sinkhorn computation. "ott" or "pot".
        """
        if active_metrics is None:
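The corrected docstring clarifies that `stride` subsamples the rollout for both e_kin and sinkhorn, the expensive metrics. A toy illustration of the effect, with assumed shapes:

```python
import numpy as np

rollout = np.zeros((400, 1000, 2))  # (time, particles, dim), toy data
stride = 10
scored = rollout[::stride]          # only every 10th frame is scored
assert scored.shape[0] == 40
```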