From b330ed577c887334dd59643a65f346ff1e4cf329 Mon Sep 17 00:00:00 2001 From: gerkone Date: Thu, 15 Feb 2024 17:29:36 +0100 Subject: [PATCH 01/13] yacs.CfgNode instead of argparse.Namespace --- experiments/config.py | 220 ------------------------------ experiments/run.py | 168 ++++++++--------------- experiments/utils.py | 135 ++++++++++-------- lagrangebench/__init__.py | 2 - lagrangebench/case_setup/case.py | 25 ++-- lagrangebench/config.py | 190 ++++++++++++++++++++++++++ lagrangebench/defaults.py | 70 ---------- lagrangebench/evaluate/rollout.py | 35 ++--- lagrangebench/models/segnn.py | 15 +- lagrangebench/train/strats.py | 6 +- lagrangebench/train/trainer.py | 102 +++++--------- lagrangebench/utils.py | 30 +--- main.py | 80 ++++++++--- tests/rollout_test.py | 2 + 14 files changed, 464 insertions(+), 616 deletions(-) delete mode 100644 experiments/config.py create mode 100644 lagrangebench/config.py delete mode 100644 lagrangebench/defaults.py diff --git a/experiments/config.py b/experiments/config.py deleted file mode 100644 index c7d0438..0000000 --- a/experiments/config.py +++ /dev/null @@ -1,220 +0,0 @@ -import argparse -import os -from typing import Dict - -import yaml - - -def cli_arguments() -> Dict: - parser = argparse.ArgumentParser() - group = parser.add_mutually_exclusive_group(required=True) - - # config arguments - group.add_argument("-c", "--config", type=str, help="Path to the config yaml.") - group.add_argument("--model_dir", type=str, help="Path to the model checkpoint.") - - # run arguments - parser.add_argument( - "--mode", type=str, choices=["train", "infer", "all"], help="Train or evaluate." - ) - parser.add_argument("--batch_size", type=int, required=False, help="Batch size.") - parser.add_argument( - "--lr_start", type=float, required=False, help="Starting learning rate." - ) - parser.add_argument( - "--lr_final", type=float, required=False, help="Learning rate after decay." - ) - parser.add_argument( - "--lr_decay_rate", type=float, required=False, help="Learning rate decay." - ) - parser.add_argument( - "--lr_decay_steps", type=int, required=False, help="Learning rate decay steps." - ) - parser.add_argument( - "--noise_std", - type=float, - required=False, - help="Additive noise standard deviation.", - ) - parser.add_argument( - "--test", - action=argparse.BooleanOptionalAction, - help="Run test mode instead of validation.", - ) - parser.add_argument("--seed", type=int, required=False, help="Random seed.") - parser.add_argument( - "--data_dir", type=str, help="Absolute/relative path to the dataset." - ) - parser.add_argument("--ckp_dir", type=str, help="Path for checkpoints.") - - # model arguments - parser.add_argument( - "--model", - type=str, - help="Model name.", - ) - parser.add_argument( - "--input_seq_length", - type=int, - required=False, - help="Input position sequence length.", - ) - parser.add_argument( - "--num_mp_steps", - type=int, - required=False, - help="Number of message passing layers.", - ) - parser.add_argument( - "--num_mlp_layers", type=int, required=False, help="Number of MLP layers." - ) - parser.add_argument( - "--latent_dim", type=int, required=False, help="Hidden layer dimension." 
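# The flat flags in this removed parser map one-to-one onto yacs dot-list
# overrides handled by the new main.py further down in this patch. A hedged
# sketch of the correspondence (key paths follow lagrangebench/config.py):
#
#   old: python main.py -c config.yaml --latent_dim 64 --noise_std 3e-4
#   new: python main.py -c config.yaml model.latent_dim 64 optimizer.noise_std 3e-4
#
# The trailing key/value pairs are collected by an argparse.REMAINDER argument
# and applied with cfg.merge_from_list inside load_cfg.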
- ) - parser.add_argument( - "--magnitude_features", - action=argparse.BooleanOptionalAction, - help="Whether to include velocity magnitudes in node features.", - ) - parser.add_argument( - "--isotropic_norm", - action=argparse.BooleanOptionalAction, - help="Use isotropic normalization.", - ) - - # output arguments - parser.add_argument( - "--out_type", - type=str, - required=False, - choices=["vtk", "pkl", "none"], - help="Output type to store rollouts during validation.", - ) - parser.add_argument( - "--out_type_infer", - type=str, - required=False, - choices=["vtk", "pkl", "none"], - help="Output type to store rollouts during inference.", - ) - parser.add_argument( - "--rollout_dir", type=str, required=False, help="Directory to write rollouts." - ) - - # segnn-specific arguments - parser.add_argument( - "--lmax_attributes", - type=int, - required=False, - help="Maximum degree of attributes.", - ) - parser.add_argument( - "--lmax_hidden", - type=int, - required=False, - help="Maximum degree of hidden layers.", - ) - parser.add_argument( - "--segnn_norm", - type=str, - required=False, - choices=["instance", "batch", "none"], - help="Normalisation type.", - ) - parser.add_argument( - "--velocity_aggregate", - type=str, - required=False, - choices=["avg", "sum", "last", "all"], - help="Velocity aggregation function for node attributes.", - ) - parser.add_argument( - "--attribute_mode", - type=str, - required=False, - choices=["add", "concat", "velocity"], - help="How to combine node attributes.", - ) - # HAE-specific arguments - parser.add_argument( - "--right_attribute", - required=False, - action=argparse.BooleanOptionalAction, - help="Whether to use last velocity to steer the attribute embedding.", - ) - parser.add_argument( - "--attribute_embedding_blocks", - required=False, - type=int, - help="Number of embedding layers for the attributes.", - ) - - # misc arguments - parser.add_argument( - "--gpu", type=int, required=False, help="CUDA device ID to use." - ) - parser.add_argument( - "--f64", - required=False, - action=argparse.BooleanOptionalAction, - help="Whether to use double precision.", - ) - - parser.add_argument( - "--eval_n_trajs", - required=False, - type=int, - help="Number of trajectories to evaluate during validation.", - ) - parser.add_argument( - "--eval_n_trajs_infer", - required=False, - type=int, - help="Number of trajectories to evaluate during inference.", - ) - - parser.add_argument( - "--metrics", - required=False, - nargs="+", - help="Validation metrics to evaluate. 
Choose from: mse, mae, sinkhorn, e_kin.", - ) - parser.add_argument( - "--metrics_infer", - required=False, - nargs="+", - help="Inference metrics to evaluate during inference.", - ) - parser.add_argument( - "--metrics_stride", - required=False, - type=int, - help="Stride for Sinkhorn and e_kin during validation", - ) - parser.add_argument( - "--metrics_stride_infer", - required=False, - type=int, - help="Stride for Sinkhorn and e_kin during inference.", - ) - parser.add_argument( - "--n_rollout_steps", - required=False, - type=int, - help="Number of rollout steps during validation/testing.", - ) - # only keep passed arguments to avoid overwriting config - return {k: v for k, v in vars(parser.parse_args()).items() if v is not None} - - -class NestedLoader(yaml.SafeLoader): - """Load yaml files with nested configs.""" - - def get_single_data(self): - parent = {} - config = super().get_single_data() - if "extends" in config and (included := config["extends"]): - del config["extends"] - with open(os.path.join("configs", included), "r") as f: - parent = yaml.load(f, NestedLoader) - return {**parent, **config} diff --git a/experiments/run.py b/experiments/run.py index 33494ea..691efd9 100644 --- a/experiments/run.py +++ b/experiments/run.py @@ -1,56 +1,53 @@ -import copy import os import os.path as osp -from argparse import Namespace from datetime import datetime import haiku as hk -import jax.numpy as jnp import jmp import numpy as np import wandb -import yaml from experiments.utils import setup_data, setup_model from lagrangebench import Trainer, infer from lagrangebench.case_setup import case_builder +from lagrangebench.config import cfg_to_dict from lagrangebench.evaluate import averaged_metrics -from lagrangebench.utils import PushforwardConfig -def train_or_infer(args: Namespace): - data_train, data_valid, data_test, args = setup_data(args) +def train_or_infer(cfg): + mode = cfg.mode + old_model_dir = cfg.model.model_dir + is_test = cfg.eval.test + data_train, data_valid, data_test, dataset_name = setup_data(cfg) + + exp_info = {"dataset_name": dataset_name} + + metadata = data_train.metadata # neighbors search - bounds = np.array(data_train.metadata["bounds"]) - args.box = bounds[:, 1] - bounds[:, 0] + bounds = np.array(metadata["bounds"]) + box = bounds[:, 1] - bounds[:, 0] - args.info.len_train = len(data_train) - args.info.len_eval = len(data_valid) + exp_info["len_train"] = len(data_train) + exp_info["len_eval"] = len(data_valid) # setup core functions case = case_builder( - box=args.box, - metadata=data_train.metadata, - input_seq_length=args.config.input_seq_length, - isotropic_norm=args.config.isotropic_norm, - noise_std=args.config.noise_std, - magnitude_features=args.config.magnitude_features, + box=box, + metadata=metadata, external_force_fn=data_train.external_force_fn, - neighbor_list_backend=args.config.neighbor_list_backend, - neighbor_list_multiplier=args.config.neighbor_list_multiplier, - dtype=(jnp.float64 if args.config.f64 else jnp.float32), ) _, particle_type = data_train[0] - args.info.homogeneous_particles = particle_type.max() == particle_type.min() - args.metadata = data_train.metadata - args.normalization_stats = case.normalization_stats - args.config.has_external_force = data_train.external_force_fn is not None - # setup model from configs - model, MODEL = setup_model(args) + model, MODEL = setup_model( + cfg, + metadata=metadata, + homogeneous_particles=particle_type.max() == particle_type.min(), + has_external_force=data_train.external_force_fn is not None, + 
normalization_stats=case.normalization_stats, + ) model = hk.without_apply_rng(hk.transform_with_state(model)) # mixed precision training based on this reference: @@ -58,112 +55,67 @@ def train_or_infer(args: Namespace): policy = jmp.get_policy("params=float32,compute=float32,output=float32") hk.mixed_precision.set_policy(MODEL, policy) - if args.config.mode == "train" or args.config.mode == "all": + if mode == "train" or mode == "all": print("Start training...") # save config file - run_prefix = f"{args.config.model}_{data_train.name}" + run_prefix = f"{cfg.model.name}_{data_train.name}" data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S") - args.info.run_name = f"{run_prefix}_{data_and_time}" + exp_info["run_name"] = f"{run_prefix}_{data_and_time}" - args.config.new_checkpoint = os.path.join( - args.config.ckp_dir, args.info.run_name - ) - os.makedirs(args.config.new_checkpoint, exist_ok=True) - os.makedirs(os.path.join(args.config.new_checkpoint, "best"), exist_ok=True) - with open(os.path.join(args.config.new_checkpoint, "config.yaml"), "w") as f: - yaml.dump(vars(args.config), f) - with open( - os.path.join(args.config.new_checkpoint, "best", "config.yaml"), "w" - ) as f: - yaml.dump(vars(args.config), f) - - if args.config.wandb: - # wandb doesn't like Namespace objects - args_dict = copy.copy(args) - args_dict.config = vars(args.config) - args_dict.info = vars(args.info) + cfg.model.model_dir = os.path.join(cfg.logging.ckp_dir, exp_info["run_name"]) + os.makedirs(cfg.model.model_dir, exist_ok=True) + os.makedirs(os.path.join(cfg.model.model_dir, "best"), exist_ok=True) + with open(os.path.join(cfg.model.model_dir, "config.yaml"), "w") as f: + cfg.dump(stream=f) + with open(os.path.join(cfg.model.model_dir, "best", "config.yaml"), "w") as f: + cfg.dump(stream=f) + + if cfg.logging.wandb: + cfg_dict = cfg_to_dict(cfg) + cfg_dict.update(exp_info) wandb_run = wandb.init( - project=args.config.wandb_project, - entity=args.config.wandb_entity, - name=args.info.run_name, - config=args_dict, + project=cfg.logging.wandb_project, + entity=cfg.logging.wandb_entity, + name=cfg.logging.run_name, + config=cfg_dict, save_code=True, ) else: wandb_run = None - pf_config = PushforwardConfig( - steps=args.config.pushforward["steps"], - unrolls=args.config.pushforward["unrolls"], - probs=args.config.pushforward["probs"], - ) - - trainer = Trainer( - model, - case, - data_train, - data_valid, - pushforward=pf_config, - metrics=args.config.metrics, - seed=args.config.seed, - batch_size=args.config.batch_size, - input_seq_length=args.config.input_seq_length, - noise_std=args.config.noise_std, - lr_start=args.config.lr_start, - lr_final=args.config.lr_final, - lr_decay_steps=args.config.lr_decay_steps, - lr_decay_rate=args.config.lr_decay_rate, - loss_weight=args.config.loss_weight, - n_rollout_steps=args.config.n_rollout_steps, - eval_n_trajs=args.config.eval_n_trajs, - rollout_dir=args.config.rollout_dir, - out_type=args.config.out_type, - log_steps=args.config.log_steps, - eval_steps=args.config.eval_steps, - metrics_stride=args.config.metrics_stride, - num_workers=args.config.num_workers, - batch_size_infer=args.config.batch_size_infer, - ) + trainer = Trainer(model, case, data_train, data_valid) _, _, _ = trainer( - step_max=args.config.step_max, - load_checkpoint=args.config.model_dir, - store_checkpoint=args.config.new_checkpoint, + step_max=cfg.train.step_max, + load_checkpoint=old_model_dir, + store_checkpoint=cfg.model.model_dir, wandb_run=wandb_run, ) - if args.config.wandb: + if 
cfg.logging.wandb: wandb.finish() - if args.config.mode == "infer" or args.config.mode == "all": + if mode == "infer" or mode == "all": print("Start inference...") - if args.config.mode == "all": - args.config.model_dir = os.path.join(args.config.new_checkpoint, "best") - assert osp.isfile(os.path.join(args.config.model_dir, "params_tree.pkl")) + best_model_dir = old_model_dir + if mode == "all": + best_model_dir = os.path.join(cfg.model.model_dir, "best") + assert osp.isfile(os.path.join(best_model_dir, "params_tree.pkl")) - args.config.rollout_dir = args.config.model_dir.replace("ckp", "rollout") - os.makedirs(args.config.rollout_dir, exist_ok=True) + cfg.eval.rollout_dir = best_model_dir.replace("ckp", "rollout") + os.makedirs(cfg.eval.rollout_dir, exist_ok=True) - if args.config.eval_n_trajs_infer is None: - args.config.eval_n_trajs_infer = args.config.eval_n_trajs + if cfg.eval.n_trajs_infer is None: + cfg.eval.n_trajs_infer = cfg.eval.n_trajs_train - assert args.config.model_dir, "model_dir must be specified for inference." + assert old_model_dir, "model_dir must be specified for inference." metrics = infer( model, case, - data_test if args.config.test else data_valid, - load_checkpoint=args.config.model_dir, - metrics=args.config.metrics_infer, - rollout_dir=args.config.rollout_dir, - eval_n_trajs=args.config.eval_n_trajs_infer, - n_rollout_steps=args.config.n_rollout_steps, - out_type=args.config.out_type_infer, - n_extrap_steps=args.config.n_extrap_steps, - seed=args.config.seed, - metrics_stride=args.config.metrics_stride_infer, - batch_size=args.config.batch_size_infer, + data_test if is_test else data_valid, + load_checkpoint=best_model_dir, ) - split = "test" if args.config.test else "valid" - print(f"Metrics of {args.config.model_dir} on {split} split:") + split = "test" if is_test else "valid" + print(f"Metrics of {best_model_dir} on {split} split:") print(averaged_metrics(metrics)) diff --git a/experiments/utils.py b/experiments/utils.py index 8168178..458cdc5 100644 --- a/experiments/utils.py +++ b/experiments/utils.py @@ -1,7 +1,7 @@ import os import os.path as osp from argparse import Namespace -from typing import Callable, Tuple, Type +from typing import Callable, Dict, Optional, Tuple, Type import jax import jax.numpy as jnp @@ -14,78 +14,97 @@ from lagrangebench.utils import NodeType -def setup_data(args: Namespace) -> Tuple[H5Dataset, H5Dataset, Namespace]: - if not osp.isabs(args.config.data_dir): - args.config.data_dir = osp.join(os.getcwd(), args.config.data_dir) +def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]: + data_dir = cfg.data_dir + ckp_dir = cfg.logging.ckp_dir + rollout_dir = cfg.eval.rollout_dir + input_seq_length = cfg.model.input_seq_length + n_rollout_steps = cfg.eval.n_rollout_steps + neighbor_list_backend = cfg.neighbors.backend + if not osp.isabs(data_dir): + data_dir = osp.join(os.getcwd(), data_dir) - args.info.dataset_name = osp.basename(args.config.data_dir.split("/")[-1]) - if args.config.ckp_dir is not None: - os.makedirs(args.config.ckp_dir, exist_ok=True) - if args.config.rollout_dir is not None: - os.makedirs(args.config.rollout_dir, exist_ok=True) + dataset_name = osp.basename(data_dir.split("/")[-1]) + if ckp_dir is not None: + os.makedirs(ckp_dir, exist_ok=True) + if rollout_dir is not None: + os.makedirs(rollout_dir, exist_ok=True) # dataloader data_train = H5Dataset( "train", - dataset_path=args.config.data_dir, - input_seq_length=args.config.input_seq_length, - extra_seq_length=args.config.pushforward["unrolls"][-1], - 
nl_backend=args.config.neighbor_list_backend, + dataset_path=data_dir, + input_seq_length=input_seq_length, + extra_seq_length=cfg.optimizer.pushforward.unrolls[-1], + nl_backend=neighbor_list_backend, ) data_valid = H5Dataset( "valid", - dataset_path=args.config.data_dir, - input_seq_length=args.config.input_seq_length, - extra_seq_length=args.config.n_rollout_steps, - nl_backend=args.config.neighbor_list_backend, + dataset_path=data_dir, + input_seq_length=input_seq_length, + extra_seq_length=n_rollout_steps, + nl_backend=neighbor_list_backend, ) data_test = H5Dataset( "test", - dataset_path=args.config.data_dir, - input_seq_length=args.config.input_seq_length, - extra_seq_length=args.config.n_rollout_steps, - nl_backend=args.config.neighbor_list_backend, + dataset_path=data_dir, + input_seq_length=input_seq_length, + extra_seq_length=n_rollout_steps, + nl_backend=neighbor_list_backend, ) - if args.config.eval_n_trajs == -1: - args.config.eval_n_trajs = data_valid.num_samples - if args.config.eval_n_trajs_infer == -1: - args.config.eval_n_trajs_infer = data_valid.num_samples - assert data_valid.num_samples >= args.config.eval_n_trajs, ( + + # TODO find another way to set these + if cfg.eval.n_trajs_train == -1: + cfg.eval.n_trajs_train = data_valid.num_samples + if cfg.eval.n_trajs_infer == -1: + cfg.eval.n_trajs_infer = data_valid.num_samples + + assert data_valid.num_samples >= cfg.eval.n_trajs_train, ( f"Number of available evaluation trajectories ({data_valid.num_samples}) " - f"exceeds eval_n_trajs ({args.config.eval_n_trajs})" + f"exceeds eval_n_trajs ({cfg.eval.n_trajs_train})" ) - args.info.has_external_force = bool(data_train.external_force_fn is not None) + return data_train, data_valid, data_test, dataset_name + - return data_train, data_valid, data_test, args +def setup_model( + cfg, + metadata: Dict, + homogeneous_particles: bool = False, + has_external_force: bool = False, + normalization_stats: Optional[Dict] = None, +) -> Tuple[Callable, Type]: + """Setup model based on cfg.""" + model_name = cfg.model.name.lower() + latent_dim = cfg.model.latent_dim + num_mlp_layers = cfg.model.num_mlp_layers + num_mp_steps = cfg.model.num_mp_steps -def setup_model(args: Namespace) -> Tuple[Callable, Type]: - """Setup model based on args.""" - model_name = args.config.model.lower() - metadata = args.metadata + input_seq_length = cfg.model.input_seq_length + magnitude_features = cfg.train.magnitude_features if model_name == "gns": def model_fn(x): return models.GNS( particle_dimension=metadata["dim"], - latent_size=args.config.latent_dim, - blocks_per_step=args.config.num_mlp_layers, - num_mp_steps=args.config.num_mp_steps, + latent_size=latent_dim, + blocks_per_step=num_mlp_layers, + num_mp_steps=num_mp_steps, num_particle_types=NodeType.SIZE, particle_type_embedding_size=16, )(x) MODEL = models.GNS elif model_name == "segnn": + segnn_cfg = cfg.model.segnn # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type node_feature_irreps = node_irreps( metadata, - args.config.input_seq_length, - args.config.has_external_force, - args.config.magnitude_features, - args.info.homogeneous_particles, + input_seq_length, + has_external_force, + homogeneous_particles, ) # 1o displacement, 0e distance edge_feature_irreps = Irreps("1x1o + 1x0e") @@ -94,21 +113,21 @@ def model_fn(x): return models.SEGNN( node_features_irreps=node_feature_irreps, edge_features_irreps=edge_feature_irreps, - scalar_units=args.config.latent_dim, - lmax_hidden=args.config.lmax_hidden, - lmax_attributes=args.config.lmax_attributes, + 
scalar_units=latent_dim, + lmax_hidden=segnn_cfg.lmax_hidden, + lmax_attributes=segnn_cfg.lmax_attributes, output_irreps=Irreps("1x1o"), - num_mp_steps=args.config.num_mp_steps, - n_vels=args.config.input_seq_length - 1, - velocity_aggregate=args.config.velocity_aggregate, - homogeneous_particles=args.info.homogeneous_particles, - blocks_per_step=args.config.num_mlp_layers, - norm=args.config.segnn_norm, + num_mp_steps=num_mp_steps, + n_vels=input_seq_length - 1, + velocity_aggregate=segnn_cfg.velocity_aggregate, + homogeneous_particles=cfg.train.homogeneous_particles, + blocks_per_step=num_mlp_layers, + norm=segnn_cfg.segnn_norm, )(x) MODEL = models.SEGNN elif model_name == "egnn": - box = args.box + box = cfg.box if jnp.array(metadata["periodic_boundary_conditions"]).any(): displacement_fn, shift_fn = space.periodic(jnp.array(box)) else: @@ -119,30 +138,30 @@ def model_fn(x): def model_fn(x): return models.EGNN( - hidden_size=args.config.latent_dim, + hidden_size=cfg.latent_dim, output_size=1, dt=metadata["dt"] * metadata["write_every"], displacement_fn=displacement_fn, shift_fn=shift_fn, - normalization_stats=args.normalization_stats, - num_mp_steps=args.config.num_mp_steps, - n_vels=args.config.input_seq_length - 1, + normalization_stats=normalization_stats, + num_mp_steps=num_mp_steps, + n_vels=input_seq_length - 1, residual=True, )(x) MODEL = models.EGNN elif model_name == "painn": - assert args.config.magnitude_features, "PaiNN requires magnitudes" + assert magnitude_features, "PaiNN requires magnitudes" radius = metadata["default_connectivity_radius"] * 1.5 def model_fn(x): return models.PaiNN( - hidden_size=args.config.latent_dim, + hidden_size=latent_dim, output_size=1, - n_vels=args.config.input_seq_length - 1, + n_vels=input_seq_length - 1, radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True), cutoff_fn=models.painn.cosine_cutoff(radius), - num_mp_steps=args.config.num_mp_steps, + num_mp_steps=num_mp_steps, )(x) MODEL = models.PaiNN diff --git a/lagrangebench/__init__.py b/lagrangebench/__init__.py index 39cf0eb..d157e76 100644 --- a/lagrangebench/__init__.py +++ b/lagrangebench/__init__.py @@ -3,7 +3,6 @@ from .evaluate import infer from .models import EGNN, GNS, SEGNN, PaiNN from .train.trainer import Trainer -from .utils import PushforwardConfig __all__ = [ "Trainer", @@ -21,7 +20,6 @@ "LDC2D", "LDC3D", "DAM2D", - "PushforwardConfig", ] __version__ = "0.0.1" diff --git a/lagrangebench/case_setup/case.py b/lagrangebench/case_setup/case.py index 21ee2ec..28b624b 100644 --- a/lagrangebench/case_setup/case.py +++ b/lagrangebench/case_setup/case.py @@ -9,8 +9,8 @@ from jax_md.dataclasses import dataclass, static_field from jax_md.partition import NeighborList, NeighborListFormat +from lagrangebench.config import cfg from lagrangebench.data.utils import get_dataset_stats -from lagrangebench.defaults import defaults from lagrangebench.train.strats import add_gns_noise from .features import FeatureDict, TargetDict, physical_feature_builder @@ -62,14 +62,7 @@ class CaseSetupFn: def case_builder( box: Tuple[float, float, float], metadata: Dict, - input_seq_length: int, - isotropic_norm: bool = defaults.isotropic_norm, - noise_std: float = defaults.noise_std, external_force_fn: Optional[Callable] = None, - magnitude_features: bool = defaults.magnitude_features, - neighbor_list_backend: str = defaults.neighbor_list_backend, - neighbor_list_multiplier: float = defaults.neighbor_list_multiplier, - dtype: jnp.dtype = defaults.dtype, ): """Set up a CaseSetupFn that contains 
every required function besides the model. @@ -83,15 +76,17 @@ def case_builder( Args: box: Box xyz sizes of the system. metadata: Dataset metadata dictionary. - input_seq_length: Length of the input sequence. - isotropic_norm: Whether to normalize dimensions equally. - noise_std: Noise standard deviation. external_force_fn: External force function. - magnitude_features: Whether to add velocity magnitudes in the features. - neighbor_list_backend: Backend of the neighbor list. - neighbor_list_multiplier: Capacity multiplier of the neighbor list. - dtype: Data type. """ + + input_seq_length = cfg.model.input_seq_length + isotropic_norm = cfg.train.isotropic_norm + noise_std = cfg.optimizer.noise_std + magnitude_features = cfg.train.magnitude_features + neighbor_list_backend = cfg.neighbors.backend + neighbor_list_multiplier = cfg.neighbors.multiplier + dtype = cfg.dtype + normalization_stats = get_dataset_stats(metadata, isotropic_norm, noise_std) # apply PBC in all directions or not at all diff --git a/lagrangebench/config.py b/lagrangebench/config.py new file mode 100644 index 0000000..4302d42 --- /dev/null +++ b/lagrangebench/config.py @@ -0,0 +1,190 @@ +from typing import Any, Dict, List, Optional + +import yaml +from yacs.config import CfgNode as CN + +# lagrangebench-wide config object +cfg = CN() + + +__custom_cfg_fn: Dict[str, Any] = {} + + +def custom_config(fn): + """ "Decorator to add custom config functions.""" + __custom_cfg_fn[fn.__name__] = fn + return fn + + +def defaults(cfg): + """Default lagrangebench values.""" + + if cfg is None: + raise ValueError("cfg should be a yacs CfgNode") + + # random seed + cfg.seed = 0 + # data type for preprocessing + cfg.dtype = "float64" + + # data directory + cfg.data_dir = None + # run, evaluation or both + cfg.mode = "all" + + # model + model = CN() + + model.name = None + # Length of the position input sequence + model.input_seq_length = 6 + # Number of message passing steps + model.num_mp_steps = 10 + # Number of MLP layers + model.num_mlp_layers = 2 + # Hidden dimension + model.latent_dim = 128 + # Load checkpointed model from this directory + model.model_dir = None + + cfg.model = model + + # training + train = CN() + + # batch size + train.batch_size = 1 + # max number of training steps + train.step_max = 500_000 + # whether to include velocity magnitude features + train.magnitude_features = False + # whether to normalize dimensions equally + train.isotropic_norm = False + # number of workers for data loading + train.num_workers = 4 + + cfg.train = train + + # optimizer + optimizer = CN() + + # initial learning rate + optimizer.lr_start = 1e-4 + # final learning rate (after exponential decay) + optimizer.lr_final = 1e-6 + # learning rate decay rate + optimizer.lr_decay_rate = 0.1 + # number of steps to decay learning rate + optimizer.lr_decay_steps = 1e5 + # standard deviation of the GNS-style noise + optimizer.noise_std = 3e-4 + + # optimizer: pushforward + pushforward = CN() + # At which training step to introduce next unroll stage + pushforward.steps = [-1, 200000, 300000, 400000] + # For how many steps to unroll + pushforward.unrolls = [0, 1, 2, 3] + # Which probability ratio to keep between the unrolls + pushforward.probs = [18, 2, 1, 1] + + # optimizer: loss weights + loss_weight = CN() + # weight for acceleration error + loss_weight.acc = 1.0 + # weight for velocity error + loss_weight.vel = 0.0 + # weight for position error + loss_weight.pos = 0.0 + + cfg.optimizer = optimizer + cfg.optimizer.loss_weight = loss_weight + 
cfg.optimizer.pushforward = pushforward + + # evaluation + eval = CN() + + # number of eval rollout steps. -1 is full rollout + eval.n_rollout_steps = 20 + # number of trajectories to evaluate during training + eval.n_trajs_train = 1 + # number of trajectories to evaluate during inference + eval.n_trajs_infer = 50 + # metrics for training + eval.metrics_train = ["mse"] + # stride for e_kin and sinkhorn + eval.metrics_stride_train = 10 + # metrics for inference + eval.metrics_infer = ["mse", "e_kin", "sinkhorn"] + # stride for e_kin and sinkhorn + eval.metrics_stride_infer = 1 + # number of extrapolation steps in inference + eval.n_extrap_steps = 0 + # batch size for validation/testing + eval.batch_size_infer = 2 + # loggingging directory + eval.out_type = None + # rollouts directory + eval.rollout_dir = None + # whether to use the test split + eval.test = False + + cfg.eval = eval + + # logging + logging = CN() + + # number of steps between loggings + logging.log_steps = 1000 + # number of steps between evaluations and checkpoints + logging.eval_steps = 10000 + # wandb enable + logging.wandb = False + # wandb project name + logging.wandb_project = None + # wandb entity name + logging.wandb_entity = "lagrangebench" + # checkpoint directory + logging.ckp_dir = "ckp" + + cfg.logging = logging + + # neighbor list + neighbors = CN() + + # backend for neighbor list computation + neighbors.backend = "jaxmd_vmap" + # multiplier for neighbor list capacity + neighbors.multiplier = 1.25 + + cfg.neighbors = neighbors + + # custom and user configs + for cfg_fn in __custom_cfg_fn.values(): + cfg_fn(cfg) + + +def check_cfg(cfg): + assert cfg.data_dir is not None, "cfg.data_dir must be specified." + assert ( + cfg.train.step_max is not None and cfg.train.step_max > 0 + ), "cfg.train.step_max must be specified and larger than 0." + + +def load_cfg(cfg: CN, config_path: str, extra_args: Optional[List] = None): + if cfg is None: + raise ValueError("cfg should be a yacs CfgNode") + if len(cfg) == 0: + defaults(cfg) + cfg.merge_from_file(config_path) + if extra_args is not None: + cfg.merge_from_list(extra_args) + check_cfg(cfg) + + +def cfg_to_dict(cfg: CN) -> Dict: + return yaml.safe_load(cfg.dump()) + + +# TODO find a better way +defaults(cfg) diff --git a/lagrangebench/defaults.py b/lagrangebench/defaults.py deleted file mode 100644 index 9cb3c22..0000000 --- a/lagrangebench/defaults.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Default lagrangebench values.""" - -from dataclasses import dataclass - -import jax.numpy as jnp - - -@dataclass(frozen=True) -class defaults: - """ - Default lagrangebench values. - - Attributes: - seed: random seed. Default 0. - batch_size: batch size. Default 1. - step_max: max number of training steps. Default ``1e7``. - dtype: data type. Default ``jnp.float32``. - magnitude_features: whether to include velocity magnitudes. Default False. - isotropic_norm: whether to normalize dimensions equally. Default False. - lr_start: initial learning rate. Default 1e-4. - lr_final: final learning rate (after exponential decay). Default 1e-6. - lr_decay_steps: number of steps to decay learning rate - lr_decay_rate: learning rate decay rate. Default 0.1. - noise_std: standard deviation of the GNS-style noise. Default 1e-4. - input_seq_length: number of input steps. Default 6. - n_rollout_steps: number of eval rollout steps. -1 is full rollout. Default -1. - eval_n_trajs: number of trajectories to evaluate. Default 1 trajectory. - rollout_dir: directory to save rollouts. Default None. 
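# A short, hedged usage sketch of the config module defined above (the YAML
# path is an example, not part of the library):
#
#   from lagrangebench.config import cfg, load_cfg
#
#   load_cfg(cfg, "configs/rpf_2d/gns.yaml", ["train.batch_size", "2"])
#   assert cfg.model.name == "gns" and cfg.train.batch_size == 2
#
# merge_from_list expects a flat [key, value, ...] list with dotted key paths,
# which is exactly the format main.py forwards from the command line. Model-
# specific defaults (e.g. segnn_config later in this patch) are registered via
# the @custom_config decorator and injected when defaults(cfg) runs.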
- out_type: type of output. None means no rollout is stored. Default None. - n_extrap_steps: number of extrapolation steps. Default 0. - log_steps: number of steps between logs. Default 1000. - eval_steps: number of steps between evaluations and checkpoints. Default 5000. - neighbor_list_backend: neighbor list routine. Default "jaxmd_vmap". - neighbor_list_multiplier: multiplier for neighbor list capacity. Default 1.25. - """ - - # training - seed: int = 0 # random seed - batch_size: int = 1 # batch size - step_max: int = 5e5 # max number of training steps - dtype: jnp.dtype = jnp.float64 # data type for preprocessing - magnitude_features: bool = False # whether to include velocity magnitude features - isotropic_norm: bool = False # whether to normalize dimensions equally - num_workers: int = 4 # number of workers for data loading - - # learning rate - lr_start: float = 1e-4 # initial learning rate - lr_final: float = 1e-6 # final learning rate (after exponential decay) - lr_decay_steps: int = 1e5 # number of steps to decay learning rate - lr_decay_rate: float = 0.1 # learning rate decay rate - - noise_std: float = 3e-4 # standard deviation of the GNS-style noise - - # evaluation - input_seq_length: int = 6 # number of input steps - n_rollout_steps: int = -1 # number of eval rollout steps. -1 is full rollout - eval_n_trajs: int = 1 # number of trajectories to evaluate - rollout_dir: str = None # directory to save rollouts - out_type: str = "none" # type of output. None means no rollout is stored - n_extrap_steps: int = 0 # number of extrapolation steps - metrics_stride: int = 10 # stride for e_kin and sinkhorn - batch_size_infer: int = 2 # batch size for validation/testing - - # logging - log_steps: int = 1000 # number of steps between logs - eval_steps: int = 10000 # number of steps between evaluations and checkpoints - - # neighbor list - neighbor_list_backend: str = "jaxmd_vmap" # backend for neighbor list computation - neighbor_list_multiplier: float = 1.25 # multiplier for neighbor list capacity diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py index dde7627..f112c02 100644 --- a/lagrangebench/evaluate/rollout.py +++ b/lagrangebench/evaluate/rollout.py @@ -4,7 +4,7 @@ import pickle import time from functools import partial -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, Iterable, Optional, Tuple import haiku as hk import jax @@ -13,9 +13,9 @@ from jax import jit, vmap from torch.utils.data import DataLoader +from lagrangebench.config import cfg from lagrangebench.data import H5Dataset from lagrangebench.data.utils import numpy_collate -from lagrangebench.defaults import defaults from lagrangebench.evaluate.metrics import MetricsComputer, MetricsDict from lagrangebench.evaluate.utils import write_vtk from lagrangebench.utils import ( @@ -74,7 +74,7 @@ def _forward_eval( return current_positions, state -def eval_batched_rollout( +def _eval_batched_rollout( forward_eval_vmap: Callable, preprocess_eval_vmap: Callable, case, @@ -237,7 +237,7 @@ def eval_rollout( # (pos_input_batch, particle_type_batch) = traj_batch_i # pos_input_batch.shape = (batch, num_particles, seq_length, dim) - example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout( + example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( forward_eval_vmap=forward_eval_vmap, preprocess_eval_vmap=preprocess_eval_vmap, case=case, @@ -314,15 +314,6 @@ def infer( params: Optional[hk.Params] = None, state: Optional[hk.State] = 
None, load_checkpoint: Optional[str] = None, - metrics: List = ["mse"], - rollout_dir: Optional[str] = None, - eval_n_trajs: int = defaults.eval_n_trajs, - n_rollout_steps: int = defaults.n_rollout_steps, - out_type: str = defaults.out_type, - n_extrap_steps: int = defaults.n_extrap_steps, - seed: int = defaults.seed, - metrics_stride: int = defaults.metrics_stride, - batch_size: int = defaults.batch_size_infer, ): """ Infer on a dataset, compute metrics and optionally save rollout in out_type format. @@ -357,21 +348,21 @@ def infer( else: params, state, _, _ = load_haiku(load_checkpoint) - key, seed_worker, generator = set_seed(seed) + key, seed_worker, generator = set_seed(cfg.seed) loader_test = DataLoader( dataset=data_test, - batch_size=batch_size, + batch_size=cfg.eval.batch_size_infer, collate_fn=numpy_collate, worker_init_fn=seed_worker, generator=generator, ) metrics_computer = MetricsComputer( - metrics, + cfg.eval.metrics_infer, dist_fn=case.displacement, metadata=data_test.metadata, input_seq_length=data_test.input_seq_length, - stride=metrics_stride, + stride=cfg.eval.metrics_stride_infer, ) # Precompile model model_apply = jit(model.apply) @@ -389,10 +380,10 @@ def infer( state=state, neighbors=neighbors, loader_eval=loader_test, - n_rollout_steps=n_rollout_steps, - n_trajs=eval_n_trajs, - rollout_dir=rollout_dir, - out_type=out_type, - n_extrap_steps=n_extrap_steps, + n_rollout_steps=cfg.eval.n_rollout_steps, + n_trajs=cfg.eval.n_trajs_infer, + rollout_dir=cfg.eval.rollout_dir, + out_type=cfg.eval.out_type, + n_extrap_steps=cfg.eval.n_extrap_steps, ) return eval_metrics diff --git a/lagrangebench/models/segnn.py b/lagrangebench/models/segnn.py index 5f3ee66..0241ec7 100644 --- a/lagrangebench/models/segnn.py +++ b/lagrangebench/models/segnn.py @@ -8,7 +8,6 @@ Standalone implementation + validation: https://github.com/gerkone/segnn-jax """ - import warnings from math import prod from typing import Any, Callable, Dict, Optional, Tuple, Union @@ -21,6 +20,7 @@ from e3nn_jax import Irreps, IrrepsArray from jax.tree_util import Partial, tree_map +from lagrangebench.config import custom_config from lagrangebench.utils import NodeType from .base import BaseModel @@ -608,3 +608,16 @@ def __call__( nodes = self._decoder(st_graph) out = self._postprocess(nodes, dim) return out + + +@custom_config +def segnn_config(cfg): + """SEGNN only parameters.""" + # Steerable attributes level + cfg.model.lmax_attributes = 1 + # Level of the hidden layer + cfg.model.lmax_hidden = 1 + # SEGNN normalization. instance, batch, none + cfg.model.segnn_norm = "none" + # SEGNN velocity aggregation. 
avg or last + cfg.model.velocity_aggregate = "avg" diff --git a/lagrangebench/train/strats.py b/lagrangebench/train/strats.py index da47056..a585983 100644 --- a/lagrangebench/train/strats.py +++ b/lagrangebench/train/strats.py @@ -95,7 +95,7 @@ def push_forward_sample_steps(key, step, pushforward): key, key_unroll = jax.random.split(key, 2) # steps needs to be an ordered list - steps = jnp.array(pushforward["steps"]) + steps = jnp.array(pushforward.steps) assert all(steps[i] <= steps[i + 1] for i in range(len(steps) - 1)) # until which index to sample from @@ -103,8 +103,8 @@ def push_forward_sample_steps(key, step, pushforward): unroll_steps = jax.random.choice( key_unroll, - a=jnp.array(pushforward["unrolls"][:idx]), - p=jnp.array(pushforward["probs"][:idx]), + a=jnp.array(pushforward.unrolls[:idx]), + p=jnp.array(pushforward.probs[:idx]), ) return key, unroll_steps diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index 322b6c5..b87894d 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -2,7 +2,7 @@ import os from functools import partial -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import haiku as hk import jax @@ -13,13 +13,12 @@ from torch.utils.data import DataLoader from wandb.wandb_run import Run +from lagrangebench.config import cfg from lagrangebench.data import H5Dataset from lagrangebench.data.utils import numpy_collate -from lagrangebench.defaults import defaults from lagrangebench.evaluate import MetricsComputer, averaged_metrics, eval_rollout from lagrangebench.utils import ( LossConfig, - PushforwardConfig, broadcast_from_batch, broadcast_to_batch, get_kinematic_mask, @@ -40,18 +39,17 @@ def _mse( particle_type: jnp.ndarray, target: jnp.ndarray, model_fn: Callable, - loss_weight: LossConfig, + loss_weight: Dict[str, float], ): pred, state = model_fn(params, state, (features, particle_type)) # check active (non zero) output shapes - keys = list(set(loss_weight.nonzero) & set(pred.keys())) - assert all(target[k].shape == pred[k].shape for k in keys) + assert all(target[k].shape == pred[k].shape for k in pred) # particle mask non_kinematic_mask = jnp.logical_not(get_kinematic_mask(particle_type)) num_non_kinematic = non_kinematic_mask.sum() # loss components losses = [] - for t in keys: + for t in pred: losses.append((loss_weight[t] * (pred[t] - target[t]) ** 2).sum(axis=-1)) total_loss = jnp.array(losses).sum(0) total_loss = jnp.where(non_kinematic_mask, total_loss, 0) @@ -94,26 +92,6 @@ def Trainer( case, data_train: H5Dataset, data_valid: H5Dataset, - pushforward: Optional[PushforwardConfig] = None, - metrics: List = ["mse"], - seed: int = defaults.seed, - batch_size: int = defaults.batch_size, - input_seq_length: int = defaults.input_seq_length, - noise_std: float = defaults.noise_std, - lr_start: float = defaults.lr_start, - lr_final: float = defaults.lr_final, - lr_decay_steps: int = defaults.lr_decay_steps, - lr_decay_rate: float = defaults.lr_decay_rate, - loss_weight: Optional[LossConfig] = None, - n_rollout_steps: int = defaults.n_rollout_steps, - eval_n_trajs: int = defaults.eval_n_trajs, - rollout_dir: str = defaults.rollout_dir, - out_type: str = defaults.out_type, - log_steps: int = defaults.log_steps, - eval_steps: int = defaults.eval_steps, - metrics_stride: int = defaults.metrics_stride, - num_workers: int = defaults.num_workers, - batch_size_infer: int = defaults.batch_size_infer, ) -> Callable: """ Builds a function that automates 
model training and evaluation. @@ -130,26 +108,6 @@ def Trainer( case: Case setup class. data_train: Training dataset. data_valid: Validation dataset. - pushforward: Pushforward configuration. None for no pushforward. - metrics: Metrics to evaluate the model on. - seed: Random seed for model init, training tricks and dataloading. - batch_size: Training batch size. - input_seq_length: Input sequence length. Default is 6. - noise_std: Noise standard deviation for the GNS-style noise. - lr_start: Initial learning rate. - lr_final: Final learning rate. - lr_decay_steps: Number of steps to reach the final learning rate. - lr_decay_rate: Learning rate decay rate. - loss_weight: Loss weight object. - n_rollout_steps: Number of autoregressive rollout steps. - eval_n_trajs: Number of trajectories to evaluate. - rollout_dir: Rollout directory. - out_type: Output type. - log_steps: Wandb/screen logging frequency. - eval_steps: Evaluation and checkpointing frequency. - metrics_stride: stride for e_kin and sinkhorn. - num_workers: number of workers for data loading. - batch_size_infer: batch size for validation/testing. Returns: Configured training function. @@ -158,14 +116,23 @@ def Trainer( model, hk.TransformedWithState ), "Model must be passed as an Haiku transformed function." - base_key, seed_worker, generator = set_seed(seed) + input_seq_length = cfg.model.input_seq_length + noise_std = cfg.optimizer.noise_std + n_rollout_steps = cfg.eval.n_rollout_steps + eval_n_trajs = cfg.eval.n_trajs_train + # make immutable for jitting + # TODO look for simpler alternatives to LossConfig + loss_weight = LossConfig(**dict(cfg.optimizer.loss_weight)) + pushforward = cfg.optimizer.pushforward + + base_key, seed_worker, generator = set_seed(cfg.seed) # dataloaders loader_train = DataLoader( dataset=data_train, - batch_size=batch_size, + batch_size=cfg.train.batch_size, shuffle=True, - num_workers=num_workers, + num_workers=cfg.train.num_workers, collate_fn=numpy_collate, drop_last=True, worker_init_fn=seed_worker, @@ -173,7 +140,7 @@ def Trainer( ) loader_valid = DataLoader( dataset=data_valid, - batch_size=batch_size_infer, + batch_size=cfg.eval.batch_size_infer, collate_fn=numpy_collate, worker_init_fn=seed_worker, generator=generator, @@ -181,31 +148,25 @@ def Trainer( # learning rate decays from lr_start to lr_final over lr_decay_steps exponentially lr_scheduler = optax.exponential_decay( - init_value=lr_start, - transition_steps=lr_decay_steps, - decay_rate=lr_decay_rate, - end_value=lr_final, + init_value=cfg.optimizer.lr_start, + transition_steps=cfg.optimizer.lr_decay_steps, + decay_rate=cfg.optimizer.lr_decay_rate, + end_value=cfg.optimizer.lr_final, ) # optimizer opt_init, opt_update = optax.adamw(learning_rate=lr_scheduler, weight_decay=1e-8) - # loss config - loss_weight = LossConfig() if loss_weight is None else LossConfig(**loss_weight) - # pushforward config - if pushforward is None: - pushforward = PushforwardConfig() - # metrics computer config metrics_computer = MetricsComputer( - metrics, + cfg.eval.metrics_train, dist_fn=case.displacement, metadata=data_train.metadata, - input_seq_length=data_train.input_seq_length, - stride=metrics_stride, + input_seq_length=input_seq_length, + stride=cfg.eval.metrics_stride_train, ) def _train( - step_max: int = defaults.step_max, + step_max: Optional[int] = None, params: Optional[hk.Params] = None, state: Optional[hk.State] = None, opt_state: Optional[optax.OptState] = None, @@ -239,6 +200,9 @@ def _train( loader_valid ), "eval_n_trajs must be <= 
len(loader_valid)" + if step_max is None: + step_max = cfg.train.step_max + # Precompile model for evaluation model_apply = jax.jit(model.apply) @@ -340,7 +304,7 @@ def _train( opt_state=opt_state, ) - if step % log_steps == 0: + if step % cfg.logging.log_steps == 0: loss.block_until_ready() if wandb_run: wandb_run.log({"train/loss": loss.item()}, step) @@ -348,7 +312,7 @@ def _train( step_str = str(step).zfill(len(str(int(step_max)))) print(f"{step_str}, train/loss: {loss.item():.5f}.") - if step % eval_steps == 0 and step > 0: + if step % cfg.logging.eval_steps == 0 and step > 0: nbrs = broadcast_from_batch(neighbors_batch, index=0) eval_metrics = eval_rollout( case=case, @@ -360,8 +324,8 @@ def _train( loader_eval=loader_valid, n_rollout_steps=n_rollout_steps, n_trajs=eval_n_trajs, - rollout_dir=rollout_dir, - out_type=out_type, + rollout_dir=cfg.eval.rollout_dir, + out_type=cfg.eval.out_type_train, ) metrics = averaged_metrics(eval_metrics) diff --git a/lagrangebench/utils.py b/lagrangebench/utils.py index 9589e39..3cef392 100644 --- a/lagrangebench/utils.py +++ b/lagrangebench/utils.py @@ -5,8 +5,8 @@ import os import pickle import random -from dataclasses import dataclass, field -from typing import Callable, List, Tuple +from dataclasses import dataclass +from typing import Callable, Tuple import cloudpickle import jax @@ -171,27 +171,5 @@ class LossConfig: vel: float = 0.0 acc: float = 1.0 - def __getitem__(self, item): - return getattr(self, item) - - @property - def nonzero(self): - return [field for field in self.__annotations__ if self[field] != 0] - - -@dataclass(frozen=False) -class PushforwardConfig: - """Pushforward trick configuration. - - Attributes: - steps: When to introduce each unroll stage, e.g. [-1, 20000, 50000] - unrolls: For how many timesteps to unroll, e.g. [0, 1, 20] - probs: Probability (ratio) between the relative unrolls, e.g. [5, 4, 1] - """ - - steps: List[int] = field(default_factory=lambda: [-1]) - unrolls: List[int] = field(default_factory=lambda: [0]) - probs: List[float] = field(default_factory=lambda: [1.0]) - - def __getitem__(self, item): - return getattr(self, item) + def __getitem__(self, key): + return getattr(self, key) diff --git a/main.py b/main.py index bc25a09..f49b2f7 100644 --- a/main.py +++ b/main.py @@ -1,38 +1,74 @@ +import argparse import os -import pprint -from argparse import Namespace -import yaml -from experiments.config import NestedLoader, cli_arguments +def cli_arguments(): + parser = argparse.ArgumentParser() + group = parser.add_mutually_exclusive_group(required=True) + + # config arguments + group.add_argument("-c", "--config", type=str, help="Path to the config yaml.") + group.add_argument("--model_dir", type=str, help="Path to the model checkpoint.") + # misc arguments + parser.add_argument( + "--gpu", type=int, required=False, help="CUDA device ID to use." 
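# End-to-end usage sketch of this entry point (paths below are examples):
#
#   # (re)start training, overriding two config values from the command line:
#   python main.py -c configs/rpf_2d/gns.yaml --gpu 0 mode train train.batch_size 2
#   # run inference from a checkpoint; its stored config.yaml is reused:
#   python main.py --model_dir ckp/<run_name>/best
#
# Everything after the known flags lands in the `extra` REMAINDER argument and
# is merged into cfg via load_cfg/merge_from_list.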
+ ) + parser.add_argument( + "--f64", + required=False, + action=argparse.BooleanOptionalAction, + help="Whether to use double precision.", + ) + parser.add_argument( + "--xla_mem_fraction", + type=float, + required=False, + default=0.7, + help="Fraction of XLA memory to use.", + ) + # optional config overrides + parser.add_argument( + "extra", + default=None, + nargs=argparse.REMAINDER, + help="Extra config overrides as key value pairs.", + ) + + args = parser.parse_args() + if args.extra is None: + args.extra = [] + + return args + if __name__ == "__main__": cli_args = cli_arguments() - if "config" in cli_args: # to (re)start training - config_path = cli_args["config"] - elif "model_dir" in cli_args: # to run inference - config_path = os.path.join(cli_args["model_dir"], "config.yaml") - with open(config_path, "r") as f: - args = yaml.load(f, NestedLoader) - - # priority to command line arguments - args.update(cli_args) - args = Namespace(config=Namespace(**args), info=Namespace()) - print("#" * 79, "\nStarting a LagrangeBench run with the following configs:") - pprint.pprint(vars(args.config)) - print("#" * 79) + if cli_args.config is not None: # to (re)start training + config_path = cli_args.config.strip() + elif cli_args.model_dir is not None: # to run inference + config_path = os.path.join(cli_args.model_dir, "config.yaml") + cli_args.extra.extend(["model.model_dir", cli_args.model_dir]) # specify cuda device os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow - os.environ["CUDA_VISIBLE_DEVICES"] = str(args.config.gpu) - os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(args.config.xla_mem_fraction) - - if args.config.f64: + os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu) + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction) + if cli_args.f64: from jax import config config.update("jax_enable_x64", True) + else: + cli_args.extra.extend(["dtype", "float32"]) + + from lagrangebench.config import cfg, load_cfg + + load_cfg(cfg, config_path, cli_args.extra) + + print("#" * 79, "\nStarting a LagrangeBench run with the following configs:") + print(cfg.dump()) + print("#" * 79) from experiments.run import train_or_infer - train_or_infer(args) + train_or_infer(cfg) diff --git a/tests/rollout_test.py b/tests/rollout_test.py index a559c48..a48de3a 100644 --- a/tests/rollout_test.py +++ b/tests/rollout_test.py @@ -20,6 +20,8 @@ from lagrangebench.evaluate.rollout import _forward_eval, eval_batched_rollout from lagrangebench.utils import broadcast_from_batch +# TODO tests + class TestInferBuilder(unittest.TestCase): """Class for unit testing the evaluate_single_rollout function.""" From 1a35dd29371f42f74ef2cf463c1b502ce3b677b3 Mon Sep 17 00:00:00 2001 From: gerkone Date: Fri, 16 Feb 2024 09:19:46 +0100 Subject: [PATCH 02/13] yaml config cleanup --- configs/WaterDrop_2d/base.yaml | 6 -- configs/WaterDrop_2d/gns.yaml | 19 ++++-- configs/dam_2d/base.yaml | 7 -- configs/dam_2d/gns.yaml | 22 ++++-- configs/dam_2d/segnn.yaml | 23 +++++-- configs/defaults.yaml | 118 --------------------------------- configs/ldc_2d/base.yaml | 7 -- configs/ldc_2d/gns.yaml | 16 +++-- configs/ldc_2d/segnn.yaml | 19 ++++-- configs/ldc_3d/base.yaml | 6 -- configs/ldc_3d/gns.yaml | 16 +++-- configs/ldc_3d/segnn.yaml | 20 ++++-- configs/rpf_2d/base.yaml | 4 -- configs/rpf_2d/egnn.yaml | 33 +++++---- configs/rpf_2d/gns.yaml | 16 +++-- configs/rpf_2d/painn.yaml | 21 ++++-- configs/rpf_2d/segnn.yaml | 19 ++++-- configs/rpf_3d/base.yaml | 4 -- configs/rpf_3d/egnn.yaml | 
33 +++++---- configs/rpf_3d/gns.yaml | 16 +++-- configs/rpf_3d/painn.yaml | 22 ++++-- configs/rpf_3d/segnn.yaml | 19 ++++-- configs/tgv_2d/base.yaml | 4 -- configs/tgv_2d/gns.yaml | 16 +++-- configs/tgv_2d/segnn.yaml | 20 ++++-- configs/tgv_3d/base.yaml | 4 -- configs/tgv_3d/gns.yaml | 16 +++-- configs/tgv_3d/segnn.yaml | 19 ++++-- 28 files changed, 263 insertions(+), 282 deletions(-) delete mode 100644 configs/WaterDrop_2d/base.yaml delete mode 100644 configs/dam_2d/base.yaml delete mode 100644 configs/defaults.yaml delete mode 100644 configs/ldc_2d/base.yaml delete mode 100644 configs/ldc_3d/base.yaml delete mode 100644 configs/rpf_2d/base.yaml delete mode 100644 configs/rpf_3d/base.yaml delete mode 100644 configs/tgv_2d/base.yaml delete mode 100644 configs/tgv_3d/base.yaml diff --git a/configs/WaterDrop_2d/base.yaml b/configs/WaterDrop_2d/base.yaml deleted file mode 100644 index be27172..0000000 --- a/configs/WaterDrop_2d/base.yaml +++ /dev/null @@ -1,6 +0,0 @@ -extends: defaults.yaml - -data_dir: /tmp/datasets/WaterDrop -wandb_project: waterdrop_2d - -neighbor_list_backend: matscipy diff --git a/configs/WaterDrop_2d/gns.yaml b/configs/WaterDrop_2d/gns.yaml index b89287a..64c2602 100644 --- a/configs/WaterDrop_2d/gns.yaml +++ b/configs/WaterDrop_2d/gns.yaml @@ -1,6 +1,15 @@ -extends: WaterDrop_2d/base.yaml +data_dir: /tmp/datasets/WaterDrop -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: waterdrop_2d + +neighbors: + backend: matscipy diff --git a/configs/dam_2d/base.yaml b/configs/dam_2d/base.yaml deleted file mode 100644 index be1d3bd..0000000 --- a/configs/dam_2d/base.yaml +++ /dev/null @@ -1,7 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/2D_DAM_5740_20kevery100 -wandb_project: dam_2d - -neighbor_list_multiplier: 2.0 -noise_std: 0.001 diff --git a/configs/dam_2d/gns.yaml b/configs/dam_2d/gns.yaml index 1b5891e..3644fc0 100644 --- a/configs/dam_2d/gns.yaml +++ b/configs/dam_2d/gns.yaml @@ -1,6 +1,18 @@ -extends: dam_2d/base.yaml +data_dir: datasets/2D_DAM_5740_20kevery100 + +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + noise_std: 0.001 + +logging: + wandb_project: dam_2d + +neighbors: + multiplier: 2.0 + -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 diff --git a/configs/dam_2d/segnn.yaml b/configs/dam_2d/segnn.yaml index e7facf7..f8c8b74 100644 --- a/configs/dam_2d/segnn.yaml +++ b/configs/dam_2d/segnn.yaml @@ -1,8 +1,19 @@ -extends: dam_2d/base.yaml +data_dir: datasets/2D_DAM_5740_20kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 5.e-4 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + +optimizer: + lr_start: 5.e-4 + noise_std: 0.001 + +logging: + wandb_project: dam_2d + +neighbors: + multiplier: 2.0 diff --git a/configs/defaults.yaml b/configs/defaults.yaml deleted file mode 100644 index 0771f6a..0000000 --- a/configs/defaults.yaml +++ /dev/null @@ -1,118 +0,0 @@ -# Fallback parameters for the config file. These are overwritten by the config file. -extends: -# Model settings -# Model architecture name. 
gns, segnn, egnn -model: -# Length of the position input sequence -input_seq_length: 6 -# Number of message passing steps -num_mp_steps: 10 -# Number of MLP layers -num_mlp_layers: 2 -# Hidden dimension -latent_dim: 128 -# Load checkpointed model from this directory -model_dir: -# SEGNN only parameters -# Steerable attributes level -lmax_attributes: 1 -# Level of the hidden layer -lmax_hidden: 1 -# SEGNN normalization. instance, batch, none -segnn_norm: none -# SEGNN velocity aggregation. avg or last -velocity_aggregate: avg - -# Optimization settings -# Max steps -step_max: 500000 -# Batch size -batch_size: 1 -# Starting learning rate -lr_start: 1.e-4 -# Final learning rate after decay -lr_final: 1.e-6 -# Rate of learning rate decay -lr_decay_rate: 0.1 -# Number of steps for the learning rate to decay -lr_decay_steps: 1.e+5 -# Standard deviation for the additive noise -noise_std: 0.0003 -# Whether to use magnitudes or not -magnitude_features: False -# Whether to normalize inputs and outputs with the same value in x, y, ans z. -isotropic_norm: False -# Parameters related to the push-forward trick -pushforward: - # At which training step to introduce next unroll stage - steps: [-1, 200000, 300000, 400000] - # For how many steps to unroll - unrolls: [0, 1, 2, 3] - # Which probability ratio to keep between the unrolls - probs: [18, 2, 1, 1] - -# Loss settings -# Loss weight for position, acceleration, and velocity components -loss_weight: - acc: 1.0 - -# Run settings -# train, infer, all -mode: all -# Dataset directory -data_dir: -# Number of rollout steps. If "-1", then defaults to sequence_length - input_seq_len. -# n_rollout_steps must be <= ground truth len. For extrapolation use n_extrap_steps -n_rollout_steps: 20 -# Number of evaluation trajectories. "-1" for all available -eval_n_trajs: 50 -# Number of extrapolation steps -n_extrap_steps: 0 -# Whether to use test or validation split -test: False -# Seed -seed: 0 -# Cuda device. "-1" for cpu -gpu: 0 -# GPU memory allocation https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html -xla_mem_fraction: 0.75 -# Double precision everywhere other than the ML model -f64: True -# Neighbour list backend. jaxmd_vmap, jaxmd_scan, matscipy -neighbor_list_backend: jaxmd_vmap -# Neighbour list capacity multiplier -neighbor_list_multiplier: 1.25 -# number of workers for data loading -num_workers: 4 - -# Logging settings -# Use wandb for logging -wandb: False -wandb_project: False -# Change this with your own entity -wandb_entity: lagrangebench -# Number of steps between training logging -log_steps: 1000 -# Number of steps between evaluation -eval_steps: 10000 -# Checkpoint directory -ckp_dir: ckp -# Rollout/metrics directory -rollout_dir: -# Rollout storage format. vtk, pkl, none -out_type: none -# List of metrics. 
mse, mae, sinkhorn, e_kin -metrics: - - mse -metrics_stride: 10 - -# Inference params (valid/test) -metrics_infer: - - mse - - sinkhorn - - e_kin -metrics_stride_infer: 1 -out_type_infer: pkl -eval_n_trajs_infer: -1 -# batch size for validation/testing -batch_size_infer: 2 diff --git a/configs/ldc_2d/base.yaml b/configs/ldc_2d/base.yaml deleted file mode 100644 index d9fdc96..0000000 --- a/configs/ldc_2d/base.yaml +++ /dev/null @@ -1,7 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/2D_LDC_2708_10kevery100 -wandb_project: ldc_2d - -neighbor_list_multiplier: 2.0 -noise_std: 0.001 diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml index fda8aea..309c650 100644 --- a/configs/ldc_2d/gns.yaml +++ b/configs/ldc_2d/gns.yaml @@ -1,6 +1,12 @@ -extends: ldc_2d/base.yaml +data_dir: datasets/2D_LDC_2708_10kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: ldc_3d \ No newline at end of file diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml index 1adece6..b32420e 100644 --- a/configs/ldc_2d/segnn.yaml +++ b/configs/ldc_2d/segnn.yaml @@ -1,8 +1,15 @@ -extends: ldc_2d/base.yaml +data_dir: datasets/2D_LDC_2708_10kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 5.e-4 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: ldc_3d \ No newline at end of file diff --git a/configs/ldc_3d/base.yaml b/configs/ldc_3d/base.yaml deleted file mode 100644 index 5dfb668..0000000 --- a/configs/ldc_3d/base.yaml +++ /dev/null @@ -1,6 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/3D_LDC_8160_10kevery100 -wandb_project: ldc_3d - -neighbor_list_multiplier: 2.0 diff --git a/configs/ldc_3d/gns.yaml b/configs/ldc_3d/gns.yaml index dbf14b4..9bd118a 100644 --- a/configs/ldc_3d/gns.yaml +++ b/configs/ldc_3d/gns.yaml @@ -1,6 +1,12 @@ -extends: ldc_3d/base.yaml +data_dir: datasets/3D_LDC_8160_10kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: ldc_3d \ No newline at end of file diff --git a/configs/ldc_3d/segnn.yaml b/configs/ldc_3d/segnn.yaml index fa4844c..71c28a8 100644 --- a/configs/ldc_3d/segnn.yaml +++ b/configs/ldc_3d/segnn.yaml @@ -1,8 +1,16 @@ -extends: ldc_3d/base.yaml +data_dir: datasets/3D_LDC_8160_10kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 5.e-4 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True + +train: + isotropic_norm: True + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: ldc_3d \ No newline at end of file diff --git a/configs/rpf_2d/base.yaml b/configs/rpf_2d/base.yaml deleted file mode 100644 index 0916557..0000000 --- a/configs/rpf_2d/base.yaml +++ /dev/null @@ -1,4 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/2D_RPF_3200_20kevery100 -wandb_project: rpf_2d diff --git a/configs/rpf_2d/egnn.yaml b/configs/rpf_2d/egnn.yaml index 82ab3b3..404aabb 100644 --- a/configs/rpf_2d/egnn.yaml +++ b/configs/rpf_2d/egnn.yaml @@ -1,13 +1,20 @@ -extends: rpf_2d/base.yaml - -model: egnn -num_mp_steps: 5 -latent_dim: 128 -lr_start: 1.e-4 - -isotropic_norm: True -magnitude_features: True -loss_weight: - pos: 1.0 - vel: 0.0 - acc: 0.0 +data_dir: datasets/2D_RPF_3200_20kevery100 + +model: + name: egnn + 
num_mp_steps: 5 + latent_dim: 128 + +train: + isotropic_norm: True + magnitude_features: True + +optimizer: + lr_start: 5.e-4 + loss_weight: + pos: 1.0 + vel: 0.0 + acc: 0.0 + +logging: + wandb_project: rpf_3d diff --git a/configs/rpf_2d/gns.yaml b/configs/rpf_2d/gns.yaml index 87c2e81..c6740b0 100644 --- a/configs/rpf_2d/gns.yaml +++ b/configs/rpf_2d/gns.yaml @@ -1,6 +1,12 @@ -extends: rpf_2d/base.yaml +data_dir: datasets/2D_RPF_3200_20kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: rpf_2d \ No newline at end of file diff --git a/configs/rpf_2d/painn.yaml b/configs/rpf_2d/painn.yaml index 95c4e91..c53e7cd 100644 --- a/configs/rpf_2d/painn.yaml +++ b/configs/rpf_2d/painn.yaml @@ -1,9 +1,16 @@ -extends: rpf_2d/base.yaml +data_dir: datasets/2D_RPF_3200_20kevery100 -model: painn -num_mp_steps: 5 -latent_dim: 128 -lr_start: 1.e-4 +model: + name: painn + num_mp_steps: 5 + latent_dim: 128 -isotropic_norm: True -magnitude_features: True +train: + isotropic_norm: True + magnitude_features: True + +optimizer: + lr_start: 1.e-4 + +logging: + wandb_project: rpf_3d diff --git a/configs/rpf_2d/segnn.yaml b/configs/rpf_2d/segnn.yaml index e65e2b4..1a80641 100644 --- a/configs/rpf_2d/segnn.yaml +++ b/configs/rpf_2d/segnn.yaml @@ -1,8 +1,15 @@ -extends: rpf_2d/base.yaml +data_dir: datasets/2D_RPF_3200_20kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 1.e-3 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + +optimizer: + lr_start: 1.e-3 + +logging: + wandb_project: rpf_3d diff --git a/configs/rpf_3d/base.yaml b/configs/rpf_3d/base.yaml deleted file mode 100644 index 7a20c34..0000000 --- a/configs/rpf_3d/base.yaml +++ /dev/null @@ -1,4 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/3D_RPF_8000_10kevery100 -wandb_project: rpf_3d diff --git a/configs/rpf_3d/egnn.yaml b/configs/rpf_3d/egnn.yaml index 1f793ff..1417075 100644 --- a/configs/rpf_3d/egnn.yaml +++ b/configs/rpf_3d/egnn.yaml @@ -1,13 +1,20 @@ -extends: rpf_3d/base.yaml - -model: egnn -num_mp_steps: 5 -latent_dim: 128 -lr_start: 1.e-4 - -isotropic_norm: True -magnitude_features: True -loss_weight: - pos: 1.0 - vel: 0.0 - acc: 0.0 +data_dir: datasets/3D_RPF_8000_10kevery100 + +model: + name: egnn + num_mp_steps: 5 + latent_dim: 128 + +train: + isotropic_norm: True + magnitude_features: True + +optimizer: + lr_start: 1.e-4 + loss_weight: + pos: 1.0 + vel: 0.0 + acc: 0.0 + +logging: + wandb_project: rpf_3d diff --git a/configs/rpf_3d/gns.yaml b/configs/rpf_3d/gns.yaml index 8bb2053..48c42f4 100644 --- a/configs/rpf_3d/gns.yaml +++ b/configs/rpf_3d/gns.yaml @@ -1,6 +1,12 @@ -extends: rpf_3d/base.yaml +data_dir: datasets/3D_RPF_8000_10kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: rpf_3d \ No newline at end of file diff --git a/configs/rpf_3d/painn.yaml b/configs/rpf_3d/painn.yaml index cdd5b62..7bcc4eb 100644 --- a/configs/rpf_3d/painn.yaml +++ b/configs/rpf_3d/painn.yaml @@ -1,9 +1,17 @@ -extends: rpf_3d/base.yaml -model: painn -num_mp_steps: 5 -latent_dim: 128 -lr_start: 5.e-4 +data_dir: datasets/3D_RPF_8000_10kevery100 -isotropic_norm: True -magnitude_features: True +model: + name: painn + num_mp_steps: 5 + latent_dim: 128 + +train: + isotropic_norm: True + magnitude_features: 
True + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: rpf_3d diff --git a/configs/rpf_3d/segnn.yaml b/configs/rpf_3d/segnn.yaml index 0f6e6db..cd68478 100644 --- a/configs/rpf_3d/segnn.yaml +++ b/configs/rpf_3d/segnn.yaml @@ -1,8 +1,15 @@ -extends: rpf_3d/base.yaml +data_dir: datasets/3D_RPF_8000_10kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 1.e-3 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + +optimizer: + lr_start: 1.e-3 + +logging: + wandb_project: rpf_3d diff --git a/configs/tgv_2d/base.yaml b/configs/tgv_2d/base.yaml deleted file mode 100644 index f37268e..0000000 --- a/configs/tgv_2d/base.yaml +++ /dev/null @@ -1,4 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/2D_TGV_2500_10kevery100 -wandb_project: tgv_2d diff --git a/configs/tgv_2d/gns.yaml b/configs/tgv_2d/gns.yaml index 49c2330..edd4e71 100644 --- a/configs/tgv_2d/gns.yaml +++ b/configs/tgv_2d/gns.yaml @@ -1,6 +1,12 @@ -extends: tgv_2d/base.yaml +data_dir: datasets/2D_TGV_2500_10kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: tgv_2d diff --git a/configs/tgv_2d/segnn.yaml b/configs/tgv_2d/segnn.yaml index 865fce3..6c9d553 100644 --- a/configs/tgv_2d/segnn.yaml +++ b/configs/tgv_2d/segnn.yaml @@ -1,8 +1,16 @@ -extends: tgv_2d/base.yaml +data_dir: datasets/2D_TGV_2500_10kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 5.e-4 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: tgv_2d diff --git a/configs/tgv_3d/base.yaml b/configs/tgv_3d/base.yaml deleted file mode 100644 index 7c655e4..0000000 --- a/configs/tgv_3d/base.yaml +++ /dev/null @@ -1,4 +0,0 @@ -extends: defaults.yaml - -data_dir: datasets/3D_TGV_8000_10kevery100 -wandb_project: tgv_3d diff --git a/configs/tgv_3d/gns.yaml b/configs/tgv_3d/gns.yaml index cf0b741..99264a8 100644 --- a/configs/tgv_3d/gns.yaml +++ b/configs/tgv_3d/gns.yaml @@ -1,6 +1,12 @@ -extends: tgv_3d/base.yaml +data_dir: datasets/3D_TGV_8000_10kevery100 -model: gns -num_mp_steps: 10 -latent_dim: 128 -lr_start: 5.e-4 +model: + name: gns + num_mp_steps: 10 + latent_dim: 128 + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: tgv_3d diff --git a/configs/tgv_3d/segnn.yaml b/configs/tgv_3d/segnn.yaml index ebc81cc..72a6e09 100644 --- a/configs/tgv_3d/segnn.yaml +++ b/configs/tgv_3d/segnn.yaml @@ -1,8 +1,15 @@ -extends: tgv_3d/base.yaml +data_dir: datasets/3D_TGV_8000_10kevery100 -model: segnn -num_mp_steps: 10 -latent_dim: 64 -lr_start: 5.e-4 +model: + name: segnn + num_mp_steps: 10 + latent_dim: 64 -isotropic_norm: True +train: + isotropic_norm: True + +optimizer: + lr_start: 5.e-4 + +logging: + wandb_project: tgv_3d From f76c2180b73c9b45e1d06930be1cb4bef7396b56 Mon Sep 17 00:00:00 2001 From: gerkone Date: Fri, 16 Feb 2024 09:43:51 +0100 Subject: [PATCH 03/13] configured models --- experiments/utils.py | 22 ++-------------------- lagrangebench/__init__.py | 2 ++ lagrangebench/config.py | 13 +++---------- lagrangebench/models/egnn.py | 9 ++++----- lagrangebench/models/gns.py | 15 +++++---------- lagrangebench/models/painn.py | 23 +++++++++++------------ lagrangebench/models/segnn.py | 31 +++++++++---------------------- 7 files changed, 36 insertions(+), 79 deletions(-) diff --git a/experiments/utils.py 
b/experiments/utils.py index 458cdc5..c147a5a 100644 --- a/experiments/utils.py +++ b/experiments/utils.py @@ -77,10 +77,6 @@ def setup_model( """Setup model based on cfg.""" model_name = cfg.model.name.lower() - latent_dim = cfg.model.latent_dim - num_mlp_layers = cfg.model.num_mlp_layers - num_mp_steps = cfg.model.num_mp_steps - input_seq_length = cfg.model.input_seq_length magnitude_features = cfg.train.magnitude_features @@ -89,21 +85,18 @@ def setup_model( def model_fn(x): return models.GNS( particle_dimension=metadata["dim"], - latent_size=latent_dim, - blocks_per_step=num_mlp_layers, - num_mp_steps=num_mp_steps, num_particle_types=NodeType.SIZE, particle_type_embedding_size=16, )(x) MODEL = models.GNS elif model_name == "segnn": - segnn_cfg = cfg.model.segnn # Hx1o vel, Hx0e vel, 2x1o boundary, 9x0e type node_feature_irreps = node_irreps( metadata, input_seq_length, has_external_force, + magnitude_features, homogeneous_particles, ) # 1o displacement, 0e distance @@ -113,16 +106,9 @@ def model_fn(x): return models.SEGNN( node_features_irreps=node_feature_irreps, edge_features_irreps=edge_feature_irreps, - scalar_units=latent_dim, - lmax_hidden=segnn_cfg.lmax_hidden, - lmax_attributes=segnn_cfg.lmax_attributes, output_irreps=Irreps("1x1o"), - num_mp_steps=num_mp_steps, n_vels=input_seq_length - 1, - velocity_aggregate=segnn_cfg.velocity_aggregate, - homogeneous_particles=cfg.train.homogeneous_particles, - blocks_per_step=num_mlp_layers, - norm=segnn_cfg.segnn_norm, + homogeneous_particles=homogeneous_particles, )(x) MODEL = models.SEGNN @@ -138,13 +124,11 @@ def model_fn(x): def model_fn(x): return models.EGNN( - hidden_size=cfg.latent_dim, output_size=1, dt=metadata["dt"] * metadata["write_every"], displacement_fn=displacement_fn, shift_fn=shift_fn, normalization_stats=normalization_stats, - num_mp_steps=num_mp_steps, n_vels=input_seq_length - 1, residual=True, )(x) @@ -156,12 +140,10 @@ def model_fn(x): def model_fn(x): return models.PaiNN( - hidden_size=latent_dim, output_size=1, n_vels=input_seq_length - 1, radial_basis_fn=models.painn.gaussian_rbf(20, radius, trainable=True), cutoff_fn=models.painn.cosine_cutoff(radius), - num_mp_steps=num_mp_steps, )(x) MODEL = models.PaiNN diff --git a/lagrangebench/__init__.py b/lagrangebench/__init__.py index d157e76..ed803f2 100644 --- a/lagrangebench/__init__.py +++ b/lagrangebench/__init__.py @@ -1,4 +1,5 @@ from .case_setup.case import case_builder +from .config import cfg from .data import DAM2D, LDC2D, LDC3D, RPF2D, RPF3D, TGV2D, TGV3D, H5Dataset from .evaluate import infer from .models import EGNN, GNS, SEGNN, PaiNN @@ -20,6 +21,7 @@ "LDC2D", "LDC3D", "DAM2D", + "cfg", ] __version__ = "0.0.1" diff --git a/lagrangebench/config.py b/lagrangebench/config.py index 4302d42..c62cd2b 100644 --- a/lagrangebench/config.py +++ b/lagrangebench/config.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import yaml from yacs.config import CfgNode as CN @@ -7,12 +7,9 @@ cfg = CN() -__custom_cfg_fn: Dict[str, Any] = {} - - def custom_config(fn): - """ "Decorator to add custom config functions.""" - __custom_cfg_fn[fn.__name__] = fn + """Decorator to add custom config functions.""" + fn(cfg) return fn @@ -159,10 +156,6 @@ def defaults(cfg): cfg.neighbors = neighbors - # custom and user configs - for cfg_fn in __custom_cfg_fn.values(): - cfg_fn(cfg) - def check_cfg(cfg): assert cfg.data_dir is not None, "cfg.data_dir must be specified." 
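An aside on the `custom_config` change above: the decorator no longer collects functions in a `__custom_cfg_fn` registry for deferred execution; it now applies `fn(cfg)` to the shared global `cfg` immediately, at decoration time. A minimal usage sketch of the new behavior (the function name and the override values are illustrative only, not part of this patch):

```python
from lagrangebench.config import cfg, custom_config


@custom_config
def my_overrides(cfg):
    # runs as soon as the defining module is imported and
    # mutates the shared global cfg in place
    cfg.model.latent_dim = 64
    cfg.optimizer.lr_start = 1.0e-4
```

The trade-off is import-order sensitivity: a decorated function only sees the keys that exist on `cfg` when its module is imported, so config extensions such as `segnn_config` in `segnn.py` (flagged below with `# TODO figure out why this is not working`) presumably depend on the defaults having been populated first.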
diff --git a/lagrangebench/models/egnn.py b/lagrangebench/models/egnn.py index b98ed7f..0f09854 100644 --- a/lagrangebench/models/egnn.py +++ b/lagrangebench/models/egnn.py @@ -16,6 +16,7 @@ from jax.tree_util import Partial from jax_md import space +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .base import BaseModel @@ -249,7 +250,6 @@ class EGNN(BaseModel): def __init__( self, - hidden_size: int, output_size: int, dt: float, n_vels: int, @@ -257,7 +257,6 @@ def __init__( shift_fn: space.ShiftFn, normalization_stats: Optional[Dict[str, jnp.ndarray]] = None, act_fn: Callable = jax.nn.silu, - num_mp_steps: int = 4, homogeneous_particles: bool = True, residual: bool = True, attention: bool = False, @@ -290,17 +289,17 @@ def __init__( """ super().__init__() # network - self._hidden_size = hidden_size + self._hidden_size = cfg.model.latent_dim self._output_size = output_size self._act_fn = act_fn - self._num_mp_steps = num_mp_steps + self._num_mp_steps = cfg.model.num_mp_steps self._residual = residual self._attention = attention self._normalize = normalize self._tanh = tanh # integrator - self._dt = dt / num_mp_steps + self._dt = dt / self._num_mp_steps self._displacement_fn = displacement_fn self._shift_fn = shift_fn if normalization_stats is None: diff --git a/lagrangebench/models/gns.py b/lagrangebench/models/gns.py index 9020231..5756305 100644 --- a/lagrangebench/models/gns.py +++ b/lagrangebench/models/gns.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import jraph +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .base import BaseModel @@ -35,27 +36,21 @@ class GNS(BaseModel): def __init__( self, particle_dimension: int, - latent_size: int, - blocks_per_step: int, - num_mp_steps: int, - particle_type_embedding_size: int, + particle_type_embedding_size: int = 16, num_particle_types: int = NodeType.SIZE, ): """Initialize the model. Args: particle_dimension: Space dimensionality (e.g. 2 or 3). - latent_size: Size of the latent representations. - blocks_per_step: Number of MLP layers per block. - num_mp_steps: Number of message passing steps. particle_type_embedding_size: Size of the particle type embedding. num_particle_types: Max number of particle types. 
""" super().__init__() self._output_size = particle_dimension - self._latent_size = latent_size - self._blocks_per_step = blocks_per_step - self._mp_steps = num_mp_steps + self._latent_size = cfg.model.latent_dim + self._blocks_per_step = cfg.model.num_mlp_layers + self._mp_steps = cfg.model.num_mp_steps self._num_particle_types = num_particle_types self._embedding = hk.Embed( diff --git a/lagrangebench/models/painn.py b/lagrangebench/models/painn.py index 0447361..e394ba1 100644 --- a/lagrangebench/models/painn.py +++ b/lagrangebench/models/painn.py @@ -16,6 +16,7 @@ import jax.tree_util as tree import jraph +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .utils import LinearXav @@ -366,9 +367,7 @@ class PaiNN(hk.Module): def __init__( self, - hidden_size: int, output_size: int, - num_mp_steps: int, radial_basis_fn: Callable, cutoff_fn: Callable, n_vels: int, @@ -399,8 +398,8 @@ def __init__( self._n_vels = n_vels self._homogeneous_particles = homogeneous_particles - self._hidden_size = hidden_size - self._num_mp_steps = num_mp_steps + self._hidden_size = cfg.model.latent_dim + self._num_mp_steps = cfg.model.num_mp_steps self._eps = eps self._shared_filters = shared_filters self._shared_interactions = shared_interactions @@ -408,27 +407,27 @@ def __init__( self.radial_basis_fn = radial_basis_fn self.cutoff_fn = cutoff_fn - self.scalar_emb = LinearXav(hidden_size, name="scalar_embedding") + self.scalar_emb = LinearXav(self._hidden_size, name="scalar_embedding") # mix vector channels (only used if vector features are present in input) self.vector_emb = LinearXav( - hidden_size, with_bias=False, name="vector_embedding" + self._hidden_size, with_bias=False, name="vector_embedding" ) if shared_filters: - self.filter_net = LinearXav(3 * hidden_size, name="filter_net") + self.filter_net = LinearXav(3 * self._hidden_size, name="filter_net") else: self.filter_net = LinearXav( - num_mp_steps * 3 * hidden_size, name="filter_net" + self._num_mp_steps * 3 * self._hidden_size, name="filter_net" ) if self._shared_interactions: self.layers = [ - PaiNNLayer(hidden_size, 0, activation, eps=eps) - ] * num_mp_steps + PaiNNLayer(self._hidden_size, 0, activation, eps=eps) + ] * self._num_mp_steps else: self.layers = [ - PaiNNLayer(hidden_size, i, activation, eps=eps) - for i in range(num_mp_steps) + PaiNNLayer(self._hidden_size, i, activation, eps=eps) + for i in range(self._num_mp_steps) ] self._readout = PaiNNReadout(self._hidden_size, out_channels=output_size) diff --git a/lagrangebench/models/segnn.py b/lagrangebench/models/segnn.py index 0241ec7..c0cf25a 100644 --- a/lagrangebench/models/segnn.py +++ b/lagrangebench/models/segnn.py @@ -20,7 +20,7 @@ from e3nn_jax import Irreps, IrrepsArray from jax.tree_util import Partial, tree_map -from lagrangebench.config import custom_config +from lagrangebench.config import cfg, custom_config from lagrangebench.utils import NodeType from .base import BaseModel @@ -445,16 +445,9 @@ def __init__( self, node_features_irreps: Irreps, edge_features_irreps: Irreps, - scalar_units: int, - lmax_hidden: int, - lmax_attributes: int, output_irreps: Irreps, - num_mp_steps: int, n_vels: int, - velocity_aggregate: str = "avg", homogeneous_particles: bool = True, - norm: Optional[str] = None, - blocks_per_step: int = 2, embed_msg_features: bool = False, ): """ @@ -463,30 +456,23 @@ def __init__( Args: node_features_irreps: Irreps of the node features. edge_features_irreps: Irreps of the additional message passing features. 
- scalar_units: Hidden units (lower bound). Actual number depends on lmax. - lmax_hidden: Maximum L of the hidden layer representations. - lmax_attributes: Maximum L of the attributes. output_irreps: Output representation. - num_mp_steps: Number of message passing layers n_vels: Number of velocities in the history. - velocity_aggregate: Velocity sequence aggregation method. homogeneous_particles: If all particles are of homogeneous type. - norm: Normalization type. Either None, 'instance' or 'batch' - blocks_per_step: Number of tensor product blocks in each message passing embed_msg_features: Set to true to also embed edges/message passing features """ super().__init__() # network - self._attribute_irreps = Irreps.spherical_harmonics(lmax_attributes) + self._attribute_irreps = Irreps.spherical_harmonics(cfg.model.lmax_attributes) self._hidden_irreps = weight_balanced_irreps( - scalar_units, self._attribute_irreps, lmax_hidden + cfg.model.latent_dim, self._attribute_irreps, cfg.model.lmax_hidden ) self._output_irreps = output_irreps - self._num_mp_steps = num_mp_steps + self._num_mp_steps = cfg.model.num_mp_steps + self._norm = cfg.model.segnn_norm + self._blocks_per_step = cfg.model.num_mlp_layers self._embed_msg_features = embed_msg_features - self._norm = norm - self._blocks_per_step = blocks_per_step self._embedding = O3Embedding( self._hidden_irreps, @@ -500,13 +486,13 @@ def __init__( ) # transform - assert velocity_aggregate in [ + assert cfg.model.velocity_aggregate in [ "avg", "last", ], "Invalid velocity aggregate. Must be one of 'avg' or 'last'." self._node_features_irreps = node_features_irreps self._edge_features_irreps = edge_features_irreps - self._velocity_aggregate = velocity_aggregate + self._velocity_aggregate = cfg.model.velocity_aggregate self._n_vels = n_vels self._homogeneous_particles = homogeneous_particles @@ -610,6 +596,7 @@ def __call__( return out +# TODO figure out why this is not working @custom_config def segnn_config(cfg): """SEGNN only parameters.""" From 2eb730f65e962e983876981deda16d8810085d71 Mon Sep 17 00:00:00 2001 From: gerkone Date: Fri, 16 Feb 2024 13:34:07 +0100 Subject: [PATCH 04/13] updated ldc config --- configs/ldc_2d/gns.yaml | 6 +++++- configs/ldc_2d/segnn.yaml | 6 +++++- configs/ldc_3d/gns.yaml | 5 ++++- configs/ldc_3d/segnn.yaml | 6 ++++-- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml index 309c650..333de25 100644 --- a/configs/ldc_2d/gns.yaml +++ b/configs/ldc_2d/gns.yaml @@ -7,6 +7,10 @@ model: optimizer: lr_start: 5.e-4 + noise_std: 0.001 logging: - wandb_project: ldc_3d \ No newline at end of file + wandb_project: ldc_3d + +neighbors: + multiplier: 2.0 \ No newline at end of file diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml index b32420e..31effd3 100644 --- a/configs/ldc_2d/segnn.yaml +++ b/configs/ldc_2d/segnn.yaml @@ -10,6 +10,10 @@ train: optimizer: lr_start: 5.e-4 + noise_std: 0.001 logging: - wandb_project: ldc_3d \ No newline at end of file + wandb_project: ldc_3d + +neighbors: + multiplier: 2.0 \ No newline at end of file diff --git a/configs/ldc_3d/gns.yaml b/configs/ldc_3d/gns.yaml index 9bd118a..347bceb 100644 --- a/configs/ldc_3d/gns.yaml +++ b/configs/ldc_3d/gns.yaml @@ -9,4 +9,7 @@ optimizer: lr_start: 5.e-4 logging: - wandb_project: ldc_3d \ No newline at end of file + wandb_project: ldc_3d + +neighbors: + multiplier: 2.0 \ No newline at end of file diff --git a/configs/ldc_3d/segnn.yaml b/configs/ldc_3d/segnn.yaml index 
71c28a8..0433553 100644 --- a/configs/ldc_3d/segnn.yaml +++ b/configs/ldc_3d/segnn.yaml @@ -5,7 +5,6 @@ model: num_mp_steps: 10 latent_dim: 64 - train: isotropic_norm: True @@ -13,4 +12,7 @@ optimizer: lr_start: 5.e-4 logging: - wandb_project: ldc_3d \ No newline at end of file + wandb_project: ldc_3d + +neighbors: + multiplier: 2.0 \ No newline at end of file From f5ca0a9196428e603cb24584572fd47b5beef54e Mon Sep 17 00:00:00 2001 From: gerkone Date: Fri, 16 Feb 2024 13:51:28 +0100 Subject: [PATCH 05/13] updated tests --- lagrangebench/data/data.py | 11 +++++----- tests/case_test.py | 19 ++++++++-------- tests/models_test.py | 18 ++++++++------- tests/pushforward_test.py | 15 ++++++++----- tests/rollout_test.py | 45 +++++++++++++++++--------------------- 5 files changed, 55 insertions(+), 53 deletions(-) diff --git a/lagrangebench/data/data.py b/lagrangebench/data/data.py index 0febebe..3ab018f 100644 --- a/lagrangebench/data/data.py +++ b/lagrangebench/data/data.py @@ -15,6 +15,7 @@ import wget from torch.utils.data import Dataset +from lagrangebench.config import cfg from lagrangebench.utils import NodeType URLS = { @@ -41,17 +42,16 @@ class H5Dataset(Dataset): def __init__( self, split: str, - dataset_path: str, + dataset_path: Optional[str] = None, name: Optional[str] = None, input_seq_length: int = 6, extra_seq_length: int = 0, - nl_backend: str = "jaxmd_vmap", ): """Initialize the dataset. If the dataset is not present, it is downloaded. Args: split: "train", "valid", or "test" - dataset_path: Path to the dataset + dataset_path: Path to the dataset. If None, it is read from the config. name: Name of the dataset. If None, it is inferred from the path. input_seq_length: Length of the input sequence. The number of historic velocities is input_seq_length - 1. And during training, the returned extra_seq_length: During training, this is the maximum number of pushforward unroll steps. During validation/testing, this specifies the largest N-step MSE loss we are interested in, e.g. for best model checkpointing. 
- nl_backend: Which backend to use for the neighbor list """ + if dataset_path is None: + dataset_path = cfg.data_dir if dataset_path.endswith("/"): # remove trailing slash in dataset path dataset_path = dataset_path[:-1] @@ -80,7 +81,7 @@ def __init__( self.dataset_path = dataset_path self.file_path = osp.join(dataset_path, split + ".h5") self.input_seq_length = input_seq_length - self.nl_backend = nl_backend + self.nl_backend = cfg.neighbors.backend force_fn_path = osp.join(dataset_path, "force.py") if osp.exists(force_fn_path): diff --git a/tests/case_test.py b/tests/case_test.py index 373eb28..d3cb33f 100644 --- a/tests/case_test.py +++ b/tests/case_test.py @@ -5,6 +5,14 @@ import numpy as np from lagrangebench.case_setup import case_builder +from lagrangebench.config import custom_config + + +@custom_config +def case_test_config(cfg): + cfg.model.input_seq_length = 3 # two past velocities + cfg.train.isotropic_norm = False + cfg.optimizer.noise_std = 0.0 class TestCaseBuilder(unittest.TestCase): @@ -25,14 +33,7 @@ def setUp(self): bounds = np.array(self.metadata["bounds"]) box = bounds[:, 1] - bounds[:, 0] - self.case = case_builder( - box, - self.metadata, - input_seq_length=3, # two past velocities - isotropic_norm=False, - noise_std=0.0, - external_force_fn=None, - ) + self.case = case_builder(box, self.metadata, external_force_fn=None) self.key = jax.random.PRNGKey(0) # position input shape (num_particles, sequence_len, dim) = (3, 5, 3) @@ -63,7 +64,7 @@ def setUp(self): ) self.particle_types = np.array([0, 0, 0]) - key, features, target_dict, neighbors = self.case.allocate( + _, _, _, neighbors = self.case.allocate( self.key, (self.position_data, self.particle_types) ) self.neighbors = neighbors diff --git a/tests/models_test.py b/tests/models_test.py index 702280c..ec6d20e 100644 --- a/tests/models_test.py +++ b/tests/models_test.py @@ -7,9 +7,19 @@ import numpy as np from lagrangebench import models +from lagrangebench.config import custom_config from lagrangebench.utils import NodeType +@custom_config +def model_test_config(cfg): + cfg.model.latent_dim = 8 + cfg.model.output_size = 1 + cfg.model.num_mp_steps = 1 + cfg.model.lmax_attributes = 1 + cfg.model.lmax_hidden = 1 + + class ModelTest(unittest.TestCase): def dummy_sample(self, vel=None, pos=None): key = self.key() @@ -72,11 +82,7 @@ def segnn(x): return models.SEGNN( node_features_irreps="5x1o + 5x0e", edge_features_irreps="1x1o + 1x0e", - scalar_units=8, - lmax_hidden=1, - lmax_attributes=1, n_vels=5, - num_mp_steps=1, output_irreps="1x1o", )(x) @@ -89,9 +95,7 @@ def segnn(x): def test_egnn(self): def egnn(x): return models.EGNN( - hidden_size=8, output_size=1, - num_mp_steps=1, dt=0.01, n_vels=5, displacement_fn=lambda x, y: x - y, @@ -107,9 +111,7 @@ def egnn(x): def test_painn(self): def painn(x): return models.PaiNN( - hidden_size=8, output_size=1, - num_mp_steps=1, radial_basis_fn=models.painn.gaussian_rbf(20, 10, trainable=True), cutoff_fn=models.painn.cosine_cutoff(10), n_vels=5, diff --git a/tests/pushforward_test.py b/tests/pushforward_test.py index 06d77d8..ff379b1 100644 --- a/tests/pushforward_test.py +++ b/tests/pushforward_test.py @@ -3,19 +3,22 @@ import jax import numpy as np -from lagrangebench import PushforwardConfig +from lagrangebench.config import cfg, custom_config from lagrangebench.train.strats import push_forward_sample_steps +@custom_config +def pf_test_config(cfg): + cfg.optimizer.pushforward.steps = [-1, 20000, 50000, 100000] + cfg.optimizer.pushforward.unrolls = [0, 1, 3, 20] + 
cfg.optimizer.pushforward.probs = [4.05, 4.05, 1.0, 1.0] + + class TestPushForward(unittest.TestCase): """Class for unit testing the push-forward functions.""" def setUp(self): - self.pf = PushforwardConfig( - steps=[-1, 20000, 50000, 100000], - unrolls=[0, 1, 3, 20], - probs=[4.05, 4.05, 1.0, 1.0], - ) + self.pf = cfg.optimizer.pushforward self.key = jax.random.PRNGKey(42) diff --git a/tests/rollout_test.py b/tests/rollout_test.py index a48de3a..48ba729 100644 --- a/tests/rollout_test.py +++ b/tests/rollout_test.py @@ -1,5 +1,4 @@ import unittest -from argparse import Namespace from functools import partial import haiku as hk @@ -14,34 +13,35 @@ jax_config.update("jax_enable_x64", True) from lagrangebench.case_setup import case_builder +from lagrangebench.config import cfg, custom_config from lagrangebench.data import H5Dataset from lagrangebench.data.utils import get_dataset_stats, numpy_collate from lagrangebench.evaluate import MetricsComputer -from lagrangebench.evaluate.rollout import _forward_eval, eval_batched_rollout +from lagrangebench.evaluate.rollout import _eval_batched_rollout, _forward_eval from lagrangebench.utils import broadcast_from_batch -# TODO tests + +@custom_config +def eval_test_config(cfg): + # setup the configuration + cfg.data_dir = "tests/3D_LJ_3_1214every1" # Lennard-Jones dataset + cfg.model.input_seq_length = 3 + cfg.metrics = ["mse"] + cfg.eval.n_rollout_steps = 100 + cfg.train.isotropic_norm = False + cfg.optimizer.noise_std = 0.0 class TestInferBuilder(unittest.TestCase): """Class for unit testing the evaluate_single_rollout function.""" def setUp(self): - self.config = Namespace( - data_dir="tests/3D_LJ_3_1214every1", # Lennard-Jones dataset - input_seq_length=3, # two past velocities - metrics=["mse"], - n_rollout_steps=100, - isotropic_norm=False, - noise_std=0.0, - ) - data_valid = H5Dataset( split="valid", - dataset_path=self.config.data_dir, + dataset_path=cfg.data_dir, name="lj3d", - input_seq_length=self.config.input_seq_length, - extra_seq_length=self.config.n_rollout_steps, + input_seq_length=cfg.model.input_seq_length, + extra_seq_length=cfg.eval.n_rollout_steps, ) self.loader_valid = DataLoader( dataset=data_valid, batch_size=1, collate_fn=numpy_collate @@ -49,19 +49,14 @@ def setUp(self): self.metadata = data_valid.metadata self.normalization_stats = get_dataset_stats( - self.metadata, self.config.isotropic_norm, self.config.noise_std + self.metadata, cfg.train.isotropic_norm, cfg.optimizer.noise_std ) bounds = np.array(self.metadata["bounds"]) box = bounds[:, 1] - bounds[:, 0] self.displacement_fn, self.shift_fn = space.periodic(side=box) - self.case = case_builder( - box, - self.metadata, - self.config.input_seq_length, - noise_std=self.config.noise_std, - ) + self.case = case_builder(box, self.metadata) self.key = jax.random.PRNGKey(0) @@ -141,7 +136,7 @@ def model(x): for n_extrap_steps in [0, 5, 10]: with self.subTest(n_extrap_steps): - example_rollout_batch, metrics_batch, neighbors = eval_batched_rollout( + example_rollout_batch, metrics_batch, neighbors = _eval_batched_rollout( forward_eval_vmap=forward_eval_vmap, preprocess_eval_vmap=preprocess_eval_vmap, case=self.case, @@ -150,7 +145,7 @@ def model(x): traj_batch_i=traj_batch_i, neighbors=neighbors, metrics_computer_vmap=metrics_computer_vmap, - n_rollout_steps=self.config.n_rollout_steps, + n_rollout_steps=cfg.eval.n_rollout_steps, n_extrap_steps=n_extrap_steps, t_window=isl, ) @@ -185,7 +180,7 @@ def model(x): "Wrong rollout prediction", ) - total_steps = 
self.config.n_rollout_steps + n_extrap_steps + total_steps = cfg.eval.n_rollout_steps + n_extrap_steps assert example_rollout_batch.shape[1] == total_steps From 3f9a9de552e5242fd05957772d344cae4f2947f2 Mon Sep 17 00:00:00 2001 From: gerkone Date: Fri, 16 Feb 2024 13:56:33 +0100 Subject: [PATCH 06/13] yacs dependency --- poetry.lock | 32 +++++++++++++++++++++++++++++--- pyproject.toml | 1 + requirements_cuda.txt | 1 + 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index fdc1a02..e521c76 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.0 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "absl-py" @@ -1745,8 +1745,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.23.3", markers = "python_version > \"3.10\""}, - {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""}, {version = ">1.20", markers = "python_version <= \"3.9\""}, + {version = ">=1.21.2", markers = "python_version > \"3.9\" and python_version <= \"3.10\""}, ] [package.extras] @@ -2433,6 +2433,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2440,8 +2441,16 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = 
"PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2458,6 +2467,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2465,6 +2475,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3395,6 +3406,21 @@ files = [ {file = "wget-3.2.zip", hash = "sha256:35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061"}, ] +[[package]] +name = "yacs" +version = "0.1.8" +description = "Yet Another Configuration 
System" +optional = false +python-versions = "*" +files = [ + {file = "yacs-0.1.8-py2-none-any.whl", hash = "sha256:d43d1854c1ffc4634c5b349d1c1120f86f05c3a294c9d141134f282961ab5d94"}, + {file = "yacs-0.1.8-py3-none-any.whl", hash = "sha256:99f893e30497a4b66842821bac316386f7bd5c4f47ad35c9073ef089aa33af32"}, + {file = "yacs-0.1.8.tar.gz", hash = "sha256:efc4c732942b3103bea904ee89af98bcd27d01f0ac12d8d4d369f1e7a2914384"}, +] + +[package.dependencies] +PyYAML = "*" + [[package]] name = "zipp" version = "3.17.0" @@ -3413,4 +3439,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.11" -content-hash = "5fc2e88ec569a667ab5076bf43acf88c3bf3d7d359756359b31a9ccdd25148d7" +content-hash = "d49bfe19749c44cd94dd68d73b3658ea6144666668da557e922d4a95c118b0c0" diff --git a/pyproject.toml b/pyproject.toml index 8dabb99..ac4e8c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ ott-jax = "^0.4.2" matscipy = "^0.8.0" torch = {version = "2.1.0+cpu", source = "torchcpu"} wget = "^3.2" +yacs = "^0.1.8" [tool.poetry.group.dev.dependencies] # mypy = ">=1.8.0" - consider in the future diff --git a/requirements_cuda.txt b/requirements_cuda.txt index 0bc59df..5e90c0b 100644 --- a/requirements_cuda.txt +++ b/requirements_cuda.txt @@ -18,3 +18,4 @@ PyYAML torch==2.1.0+cpu wandb wget +yacs>=0.1.8 \ No newline at end of file From cd0c82f8114004609e667404c48e0e309377603f Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Mon, 19 Feb 2024 13:25:09 +0000 Subject: [PATCH 07/13] move experiments/* to lagrangebench/runner.py --- README.md | 2 +- configs/ldc_2d/gns.yaml | 2 +- configs/ldc_2d/segnn.yaml | 2 +- configs/rpf_2d/egnn.yaml | 2 +- configs/rpf_2d/painn.yaml | 2 +- configs/rpf_2d/segnn.yaml | 2 +- experiments/run.py | 121 ------------------ lagrangebench/config.py | 10 +- lagrangebench/evaluate/rollout.py | 4 +- .../utils.py => lagrangebench/runner.py | 102 ++++++++++++++- lagrangebench/train/trainer.py | 32 ++++- main.py | 8 +- 12 files changed, 139 insertions(+), 150 deletions(-) delete mode 100644 experiments/run.py rename experiments/utils.py => lagrangebench/runner.py (58%) diff --git a/README.md b/README.md index dd41baf..276619b 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ pip install --upgrade jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/ ### MacOS Currently, only the CPU installation works. You will need to change a few small things to get it going: - Clone installation: in `pyproject.toml` change the torch version from `2.1.0+cpu` to `2.1.0`. Then, remove the `poetry.lock` file and run `poetry install --only main`. -- Configs: You will need to set `f64: False` and `num_workers: 0` in the `configs/` files. +- Configs: You will need to set `f32: True` and `num_workers: 0` in the `configs/` files. Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-metal/) supports jax in general, there seems to be a missing feature used by `jax-md` related to padding -> see [this issue](https://github.com/google/jax/issues/16366#issuecomment-1591085071). 
diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml index 333de25..4574ef9 100644 --- a/configs/ldc_2d/gns.yaml +++ b/configs/ldc_2d/gns.yaml @@ -10,7 +10,7 @@ optimizer: noise_std: 0.001 logging: - wandb_project: ldc_3d + wandb_project: ldc_2d neighbors: multiplier: 2.0 \ No newline at end of file diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml index 31effd3..fb26eb4 100644 --- a/configs/ldc_2d/segnn.yaml +++ b/configs/ldc_2d/segnn.yaml @@ -13,7 +13,7 @@ optimizer: noise_std: 0.001 logging: - wandb_project: ldc_3d + wandb_project: ldc_2d neighbors: multiplier: 2.0 \ No newline at end of file diff --git a/configs/rpf_2d/egnn.yaml b/configs/rpf_2d/egnn.yaml index 404aabb..3a634d1 100644 --- a/configs/rpf_2d/egnn.yaml +++ b/configs/rpf_2d/egnn.yaml @@ -17,4 +17,4 @@ optimizer: acc: 0.0 logging: - wandb_project: rpf_3d + wandb_project: rpf_2d diff --git a/configs/rpf_2d/painn.yaml b/configs/rpf_2d/painn.yaml index c53e7cd..b9e3256 100644 --- a/configs/rpf_2d/painn.yaml +++ b/configs/rpf_2d/painn.yaml @@ -13,4 +13,4 @@ optimizer: lr_start: 1.e-4 logging: - wandb_project: rpf_3d + wandb_project: rpf_2d diff --git a/configs/rpf_2d/segnn.yaml b/configs/rpf_2d/segnn.yaml index 1a80641..94370c5 100644 --- a/configs/rpf_2d/segnn.yaml +++ b/configs/rpf_2d/segnn.yaml @@ -12,4 +12,4 @@ optimizer: lr_start: 1.e-3 logging: - wandb_project: rpf_3d + wandb_project: rpf_2d diff --git a/experiments/run.py b/experiments/run.py deleted file mode 100644 index 691efd9..0000000 --- a/experiments/run.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import os.path as osp -from datetime import datetime - -import haiku as hk -import jmp -import numpy as np -import wandb - -from experiments.utils import setup_data, setup_model -from lagrangebench import Trainer, infer -from lagrangebench.case_setup import case_builder -from lagrangebench.config import cfg_to_dict -from lagrangebench.evaluate import averaged_metrics - - -def train_or_infer(cfg): - mode = cfg.mode - old_model_dir = cfg.model.model_dir - is_test = cfg.eval.test - - data_train, data_valid, data_test, dataset_name = setup_data(cfg) - - exp_info = {"dataset_name": dataset_name} - - metadata = data_train.metadata - # neighbors search - bounds = np.array(metadata["bounds"]) - box = bounds[:, 1] - bounds[:, 0] - - exp_info["len_train"] = len(data_train) - exp_info["len_eval"] = len(data_valid) - - # setup core functions - case = case_builder( - box=box, - metadata=metadata, - external_force_fn=data_train.external_force_fn, - ) - - _, particle_type = data_train[0] - - # setup model from configs - model, MODEL = setup_model( - cfg, - metadata=metadata, - homogeneous_particles=particle_type.max() == particle_type.min(), - has_external_force=data_train.external_force_fn is not None, - normalization_stats=case.normalization_stats, - ) - model = hk.without_apply_rng(hk.transform_with_state(model)) - - # mixed precision training based on this reference: - # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py - policy = jmp.get_policy("params=float32,compute=float32,output=float32") - hk.mixed_precision.set_policy(MODEL, policy) - - if mode == "train" or mode == "all": - print("Start training...") - # save config file - run_prefix = f"{cfg.model.name}_{data_train.name}" - data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S") - exp_info["run_name"] = f"{run_prefix}_{data_and_time}" - - cfg.model.model_dir = os.path.join(cfg.logging.ckp_dir, exp_info["run_name"]) - os.makedirs(cfg.model.model_dir, exist_ok=True) - 
os.makedirs(os.path.join(cfg.model.model_dir, "best"), exist_ok=True) - with open(os.path.join(cfg.model.model_dir, "config.yaml"), "w") as f: - cfg.dump(stream=f) - with open(os.path.join(cfg.model.model_dir, "best", "config.yaml"), "w") as f: - cfg.dump(stream=f) - - if cfg.logging.wandb: - cfg_dict = cfg_to_dict(cfg) - cfg_dict.update(exp_info) - - wandb_run = wandb.init( - project=cfg.logging.wandb_project, - entity=cfg.logging.wandb_entity, - name=cfg.logging.run_name, - config=cfg_dict, - save_code=True, - ) - else: - wandb_run = None - - trainer = Trainer(model, case, data_train, data_valid) - _, _, _ = trainer( - step_max=cfg.train.step_max, - load_checkpoint=old_model_dir, - store_checkpoint=cfg.model.model_dir, - wandb_run=wandb_run, - ) - - if cfg.logging.wandb: - wandb.finish() - - if mode == "infer" or mode == "all": - print("Start inference...") - best_model_dir = old_model_dir - if mode == "all": - best_model_dir = os.path.join(cfg.model.model_dir, "best") - assert osp.isfile(os.path.join(best_model_dir, "params_tree.pkl")) - - cfg.eval.rollout_dir = best_model_dir.replace("ckp", "rollout") - os.makedirs(cfg.eval.rollout_dir, exist_ok=True) - - if cfg.eval.n_trajs_infer is None: - cfg.eval.n_trajs_infer = cfg.eval.n_trajs_train - - assert old_model_dir, "model_dir must be specified for inference." - metrics = infer( - model, - case, - data_test if is_test else data_valid, - load_checkpoint=best_model_dir, - ) - - split = "test" if is_test else "valid" - print(f"Metrics of {best_model_dir} on {split} split:") - print(averaged_metrics(metrics)) diff --git a/lagrangebench/config.py b/lagrangebench/config.py index c62cd2b..9b3cec0 100644 --- a/lagrangebench/config.py +++ b/lagrangebench/config.py @@ -79,7 +79,7 @@ def defaults(cfg): # optimizer: pushforward pushforward = CN() # At which training step to introduce next unroll stage - pushforward.steps = [-1, 200000, 300000, 400000] + pushforward.steps = [-1, 20000, 300000, 400000] # For how many steps to unroll pushforward.unrolls = [0, 1, 2, 3] # Which probability ratio to keep between the unrolls @@ -119,8 +119,10 @@ def defaults(cfg): eval.n_extrap_steps = 0 # batch size for validation/testing eval.batch_size_infer = 2 - # loggingging directory - eval.out_type = None + # write validation rollouts. One of "none", "vtk", or "pkl" + eval.out_type_train = "none" + # write inference rollouts. 
One of "none", "vtk", or "pkl" + eval.out_type_infer = "pkl" # rollouts directory eval.rollout_dir = None # whether to use the test split @@ -143,6 +145,8 @@ def defaults(cfg): logging.wandb_entity = "lagrangebench" # checkpoint directory logging.ckp_dir = "ckp" + # name of training run + logging.run_name = None cfg.logging = logging diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py index f112c02..200bca9 100644 --- a/lagrangebench/evaluate/rollout.py +++ b/lagrangebench/evaluate/rollout.py @@ -289,7 +289,7 @@ def eval_rollout( "tag": example_rollout["particle_type"], } write_vtk(ref_state_vtk, f"{file_prefix}_ref_{k}.vtk") - if out_type == "pkl": + elif out_type == "pkl": filename = f"{file_prefix}.pkl" with open(filename, "wb") as f: @@ -383,7 +383,7 @@ def infer( n_rollout_steps=cfg.eval.n_rollout_steps, n_trajs=cfg.eval.n_trajs_infer, rollout_dir=cfg.eval.rollout_dir, - out_type=cfg.eval.out_type, + out_type=cfg.eval.out_type_infer, n_extrap_steps=cfg.eval.n_extrap_steps, ) return eval_metrics diff --git a/experiments/utils.py b/lagrangebench/runner.py similarity index 58% rename from experiments/utils.py rename to lagrangebench/runner.py index c147a5a..eb30921 100644 --- a/experiments/utils.py +++ b/lagrangebench/runner.py @@ -1,30 +1,121 @@ import os import os.path as osp from argparse import Namespace +from datetime import datetime from typing import Callable, Dict, Optional, Tuple, Type +import haiku as hk import jax import jax.numpy as jnp +import jmp +import numpy as np from e3nn_jax import Irreps from jax_md import space -from lagrangebench import models +from lagrangebench import Trainer, infer, models +from lagrangebench.case_setup import case_builder from lagrangebench.data import H5Dataset +from lagrangebench.evaluate import averaged_metrics from lagrangebench.models.utils import node_irreps from lagrangebench.utils import NodeType +def train_or_infer(cfg): + mode = cfg.mode + old_model_dir = cfg.model.model_dir + is_test = cfg.eval.test + + data_train, data_valid, data_test = setup_data(cfg) + + metadata = data_train.metadata + # neighbors search + bounds = np.array(metadata["bounds"]) + box = bounds[:, 1] - bounds[:, 0] + + # setup core functions + case = case_builder( + box=box, + metadata=metadata, + external_force_fn=data_train.external_force_fn, + ) + + _, particle_type = data_train[0] + + # setup model from configs + model, MODEL = setup_model( + cfg, + metadata=metadata, + homogeneous_particles=particle_type.max() == particle_type.min(), + has_external_force=data_train.external_force_fn is not None, + normalization_stats=case.normalization_stats, + ) + model = hk.without_apply_rng(hk.transform_with_state(model)) + + # mixed precision training based on this reference: + # https://github.com/deepmind/dm-haiku/blob/main/examples/imagenet/train.py + policy = jmp.get_policy("params=float32,compute=float32,output=float32") + hk.mixed_precision.set_policy(MODEL, policy) + + if mode == "train" or mode == "all": + print("Start training...") + + if cfg.logging.run_name is None: + run_prefix = f"{cfg.model.name}_{data_train.name}" + data_and_time = datetime.today().strftime("%Y%m%d-%H%M%S") + cfg.logging.run_name = f"{run_prefix}_{data_and_time}" + + cfg.model.model_dir = os.path.join(cfg.logging.ckp_dir, cfg.logging.run_name) + os.makedirs(cfg.model.model_dir, exist_ok=True) + os.makedirs(os.path.join(cfg.model.model_dir, "best"), exist_ok=True) + with open(os.path.join(cfg.model.model_dir, "config.yaml"), "w") as f: + cfg.dump(stream=f) + with 
open(os.path.join(cfg.model.model_dir, "best", "config.yaml"), "w") as f: + cfg.dump(stream=f) + + trainer = Trainer(model, case, data_train, data_valid) + _, _, _ = trainer( + step_max=cfg.train.step_max, + load_checkpoint=old_model_dir, + store_checkpoint=cfg.model.model_dir, + ) + + if mode == "infer" or mode == "all": + print("Start inference...") + + if mode == "infer": + model_dir = cfg.model.model_dir + if mode == "all": + model_dir = os.path.join(cfg.model.model_dir, "best") + assert osp.isfile(os.path.join(model_dir, "params_tree.pkl")) + + cfg.eval.rollout_dir = model_dir.replace("ckp", "rollout") + os.makedirs(cfg.eval.rollout_dir, exist_ok=True) + + if cfg.eval.n_trajs_infer is None: + cfg.eval.n_trajs_infer = cfg.eval.n_trajs_train + + assert model_dir, "model_dir must be specified for inference." + metrics = infer( + model, + case, + data_test if is_test else data_valid, + load_checkpoint=model_dir, + ) + + split = "test" if is_test else "valid" + print(f"Metrics of {model_dir} on {split} split:") + print(averaged_metrics(metrics)) + + def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]: data_dir = cfg.data_dir ckp_dir = cfg.logging.ckp_dir rollout_dir = cfg.eval.rollout_dir input_seq_length = cfg.model.input_seq_length n_rollout_steps = cfg.eval.n_rollout_steps - neighbor_list_backend = cfg.neighbors.backend if not osp.isabs(data_dir): data_dir = osp.join(os.getcwd(), data_dir) - dataset_name = osp.basename(data_dir.split("/")[-1]) if ckp_dir is not None: os.makedirs(ckp_dir, exist_ok=True) if rollout_dir is not None: @@ -36,21 +127,18 @@ def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]: dataset_path=data_dir, input_seq_length=input_seq_length, extra_seq_length=cfg.optimizer.pushforward.unrolls[-1], - nl_backend=neighbor_list_backend, ) data_valid = H5Dataset( "valid", dataset_path=data_dir, input_seq_length=input_seq_length, extra_seq_length=n_rollout_steps, - nl_backend=neighbor_list_backend, ) data_test = H5Dataset( "test", dataset_path=data_dir, input_seq_length=input_seq_length, extra_seq_length=n_rollout_steps, - nl_backend=neighbor_list_backend, ) # TODO find another way to set these @@ -64,7 +152,7 @@ def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]: f"exceeds eval_n_trajs ({cfg.eval.n_trajs_train})" ) - return data_train, data_valid, data_test, dataset_name + return data_train, data_valid, data_test def setup_model( diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index b87894d..0204519 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -11,9 +11,9 @@ import optax from jax import vmap from torch.utils.data import DataLoader -from wandb.wandb_run import Run -from lagrangebench.config import cfg +import wandb +from lagrangebench.config import cfg, cfg_to_dict from lagrangebench.data import H5Dataset from lagrangebench.data.utils import numpy_collate from lagrangebench.evaluate import MetricsComputer, averaged_metrics, eval_rollout @@ -172,7 +172,6 @@ def _train( opt_state: Optional[optax.OptState] = None, store_checkpoint: Optional[str] = None, load_checkpoint: Optional[str] = None, - wandb_run: Optional[Run] = None, ) -> Tuple[hk.Params, hk.State, optax.OptState]: """ Training loop. @@ -187,7 +186,6 @@ def _train( opt_state: Optional optimizer state. store_checkpoint: Checkpoints destination. Without it params aren't saved. load_checkpoint: Initial checkpoint directory. If provided resumes training. - wandb_run: Wandb run. 
Returns: Tuple containing the final model parameters, state and optimizer state. @@ -228,9 +226,26 @@ def _train( key, subkey = jax.random.split(key, 2) params, state = model.init(subkey, (features, particle_type[0])) - if wandb_run is not None: - wandb_run.log({"info/num_params": get_num_params(params)}, 0) - wandb_run.log({"info/step_start": step}, 0) + # start logging + if cfg.logging.wandb: + cfg_dict = cfg_to_dict(cfg) + cfg_dict["info"] = { + "dataset_name": data_train.name, + "len_train": len(data_train), + "len_eval": len(data_valid), + "num_params": get_num_params(params).item(), + "step_start": step, + } + + wandb_run = wandb.init( + project=cfg.logging.wandb_project, + entity=cfg.logging.wandb_entity, + name=cfg.logging.run_name, + config=cfg_dict, + save_code=True, + ) + else: + wandb_run = None # initialize optimizer state if opt_state is None: @@ -347,6 +362,9 @@ def _train( if step == step_max + 1: break + if cfg.logging.wandb: + wandb.finish() + return params, state, opt_state return _train diff --git a/main.py b/main.py index f49b2f7..9d33c37 100644 --- a/main.py +++ b/main.py @@ -14,10 +14,10 @@ def cli_arguments(): "--gpu", type=int, required=False, help="CUDA device ID to use." ) parser.add_argument( - "--f64", + "--f32", required=False, action=argparse.BooleanOptionalAction, - help="Whether to use double precision.", + help="Whether to use single precision for pre-/postprocessing. Default is f64.", ) parser.add_argument( "--xla_mem_fraction", @@ -54,7 +54,7 @@ def cli_arguments(): os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu) os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction) - if cli_args.f64: + if not cli_args.f32: from jax import config config.update("jax_enable_x64", True) @@ -69,6 +69,6 @@ def cli_arguments(): print(cfg.dump()) print("#" * 79) - from experiments.run import train_or_infer + from lagrangebench.runner import train_or_infer train_or_infer(cfg) From c234b53905c66c62e822a632e50d9fefb1d11218 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Mon, 19 Feb 2024 19:53:10 +0000 Subject: [PATCH 08/13] in docs and __init__ read the version from pyproject.toml --- README.md | 10 +++++----- docs/conf.py | 6 +++++- lagrangebench/__init__.py | 5 ++++- poetry.lock | 26 +++++++++++++------------- pyproject.toml | 1 + requirements_cuda.txt | 2 +- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 276619b..c508482 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,7 @@ Although the current [`jax-metal==0.0.5` library](https://pypi.org/project/jax-m A general tutorial is provided in the example notebook "Training GNS on the 2D Taylor Green Vortex" under `./notebooks/tutorial.ipynb` on the [LagrangeBench repository](https://github.com/tumaer/lagrangebench). The notebook covers the basics of LagrangeBench, such as loading a dataset, setting up a case, training a model from scratch and evaluating its performance. ### Running in a local clone (`main.py`) -Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). By default, the arguments have priority as: 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults). +Alternatively, experiments can also be set up with `main.py`, based on extensive YAML config files and cli arguments (check [`configs/`](configs/)). 
By default, the arguments have priority as 1) passed cli arguments, 2) YAML config and 3) [`defaults.py`](lagrangebench/defaults.py) (`lagrangebench` defaults). When loading a saved model with `--model_dir` the config from the checkpoint is automatically loaded and training is restarted. For more details check the [`experiments/`](experiments/) directory and the [`run.py`](experiments/run.py) file. @@ -127,7 +127,7 @@ The datasets are hosted on Zenodo under the DOI: [10.5281/zenodo.10021925](https ### Notebooks We provide three notebooks that show LagrangeBench functionalities, namely: - [`tutorial.ipynb`](notebooks/tutorial.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/tutorial.ipynb), with a general overview of LagrangeBench library, with training and evaluation of a simple GNS model, -- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations on the datasets, and +- [`datasets.ipynb`](notebooks/datasets.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/datasets.ipynb), with more details and visualizations of the datasets, and - [`gns_data.ipynb`](notebooks/gns_data.ipynb) [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/tumaer/lagrangebench/blob/main/notebooks/gns_data.ipynb), showing how to train models within LagrangeBench on the datasets from the paper [Learning to Simulate Complex Physics with Graph Networks](https://arxiv.org/abs/2002.09405). ## Directory structure @@ -165,9 +165,9 @@ Welcome! We highly appreciate [Github issues](https://github.com/tumaer/lagrange You can also chat with us on [**Discord**](https://discord.gg/Ds8jRZ78hU). ### Contributing Guideline -If you want to contribute to this repository, you will need the dev depencencies, i.e. +If you want to contribute to this repository, you will need the dev dependencies, i.e. install the environment with `poetry install` without the ` --only main` flag. -Then, we also recommend you to install the pre-commit hooks +Then, we also recommend you install the pre-commit hooks if you don't want to manually run `pre-commit run` before each commit. To sum up: ```bash @@ -220,6 +220,6 @@ The associated datasets can be cited as: ### Publications -The following further publcations are based on the LagrangeBench codebase: +The following further publications are based on the LagrangeBench codebase: 1. [Learning Lagrangian Fluid Mechanics with E(3)-Equivariant Graph Neural Networks (GSI 2023)](https://arxiv.org/abs/2305.15603), A. P. Toshev, G. Galletti, J. Brandstetter, S. Adami, N. A. 
Adams diff --git a/docs/conf.py b/docs/conf.py index 589f76c..e73044e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,7 +10,11 @@ copyright = "2023, Chair of Aerodynamics and Fluid Mechanics, TUM" author = "Artur Toshev, Gianluca Galletti" -version = "0.0.1" +# read the version from pyproject.toml +import toml + +pyproject = toml.load("../pyproject.toml") +version = pyproject["tool"]["poetry"]["version"] # -- Path setup -------------------------------------------------------------- diff --git a/lagrangebench/__init__.py b/lagrangebench/__init__.py index ed803f2..3f004d6 100644 --- a/lagrangebench/__init__.py +++ b/lagrangebench/__init__.py @@ -24,4 +24,7 @@ "cfg", ] -__version__ = "0.0.1" +import toml + +pyproject = toml.load("pyproject.toml") +__version__ = pyproject["tool"]["poetry"]["version"] diff --git a/poetry.lock b/poetry.lock index e521c76..43a7c40 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.0 and should not be changed by hand. [[package]] name = "absl-py" @@ -2433,7 +2433,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -2441,16 +2440,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -2467,7 +2458,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -2475,7 +2465,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -3150,6 +3139,17 @@ files = [ ml-dtypes = ">=0.3.1" numpy = ">=1.16.0" +[[package]] +name = "toml" +version = "0.10.2" +description = "Python Library for Tom's Obvious, Minimal Language" +optional = false +python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" +files = [ + {file = "toml-0.10.2-py2.py3-none-any.whl", 
hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, + {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -3439,4 +3439,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<=3.11" -content-hash = "d49bfe19749c44cd94dd68d73b3658ea6144666668da557e922d4a95c118b0c0" +content-hash = "9fe394e52f5db4b405b0ab8f8ba4d444ca4cacc7b87ee2839fc3025ee01ecb09" diff --git a/pyproject.toml b/pyproject.toml index ac4e8c7..b2ee1fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ matscipy = "^0.8.0" torch = {version = "2.1.0+cpu", source = "torchcpu"} wget = "^3.2" yacs = "^0.1.8" +toml = "^0.10.2" [tool.poetry.group.dev.dependencies] # mypy = ">=1.8.0" - consider in the future diff --git a/requirements_cuda.txt b/requirements_cuda.txt index 5e90c0b..3c969e3 100644 --- a/requirements_cuda.txt +++ b/requirements_cuda.txt @@ -18,4 +18,4 @@ PyYAML torch==2.1.0+cpu wandb wget -yacs>=0.1.8 \ No newline at end of file +yacs>=0.1.8 From a4e509fe887684b16b79a0547ce39169bf5de6ab Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Mon, 19 Feb 2024 20:05:35 +0000 Subject: [PATCH 09/13] fix ruff warning --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b2ee1fc..b832e1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,7 +75,6 @@ url = "https://download.pytorch.org/whl/cpu" priority = "explicit" [tool.ruff] -ignore = ["F811", "E402"] exclude = [ ".git", ".venv", @@ -87,6 +86,7 @@ show-fixes = true line-length = 88 [tool.ruff.lint] +ignore = ["F811", "E402"] select = [ "E", # pycodestyle "F", # Pyflakes From 3ced17b8fb0f832096915cad2193c356da4c204f Mon Sep 17 00:00:00 2001 From: gerkone Date: Mon, 19 Feb 2024 23:08:34 +0100 Subject: [PATCH 10/13] ruff --- lagrangebench/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index 0204519..c239672 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -9,10 +9,10 @@ import jax.numpy as jnp import jraph import optax +import wandb from jax import vmap from torch.utils.data import DataLoader -import wandb from lagrangebench.config import cfg, cfg_to_dict from lagrangebench.data import H5Dataset from lagrangebench.data.utils import numpy_collate From ca2d8101b76708d6974c7aa6ec34d422ed816f78 Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Tue, 20 Feb 2024 02:44:07 +0000 Subject: [PATCH 11/13] cfg.main and only --config directly to main.py --- .gitignore | 2 +- .pre-commit-config.yaml | 2 +- configs/WaterDrop_2d/gns.yaml | 3 +- configs/dam_2d/gns.yaml | 3 +- configs/dam_2d/segnn.yaml | 3 +- configs/ldc_2d/gns.yaml | 3 +- configs/ldc_2d/segnn.yaml | 3 +- configs/ldc_3d/gns.yaml | 3 +- configs/ldc_3d/segnn.yaml | 3 +- configs/rpf_2d/egnn.yaml | 3 +- configs/rpf_2d/gns.yaml | 3 +- configs/rpf_2d/painn.yaml | 3 +- configs/rpf_2d/segnn.yaml | 3 +- configs/rpf_3d/egnn.yaml | 3 +- configs/rpf_3d/gns.yaml | 3 +- configs/rpf_3d/painn.yaml | 4 +- configs/rpf_3d/segnn.yaml | 3 +- configs/tgv_2d/gns.yaml | 3 +- configs/tgv_2d/segnn.yaml | 3 +- configs/tgv_3d/gns.yaml | 3 +- configs/tgv_3d/segnn.yaml | 3 +- lagrangebench/case_setup/case.py | 2 +- lagrangebench/config.py | 37 ++++++++++--- lagrangebench/data/data.py | 2 +- 
lagrangebench/evaluate/rollout.py | 2 +- lagrangebench/models/segnn.py | 16 +----- lagrangebench/runner.py | 4 +- lagrangebench/train/trainer.py | 2 +- main.py | 87 +++++++++++++++++++------------ neighbors_search/scaling.py | 2 +- tests/rollout_test.py | 4 +- 31 files changed, 134 insertions(+), 86 deletions(-) diff --git a/.gitignore b/.gitignore index c28dee5..5580ed4 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,7 @@ ckp/ rollout/ rollouts/ -wandb +wandb/ *.out datasets baselines diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fc49dbe..db02477 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: - id: check-yaml - id: requirements-txt-fixer - repo: https://github.com/astral-sh/ruff-pre-commit - rev: 'v0.1.8' + rev: 'v0.2.2' hooks: - id: ruff args: [ --fix ] diff --git a/configs/WaterDrop_2d/gns.yaml b/configs/WaterDrop_2d/gns.yaml index 64c2602..cb181fc 100644 --- a/configs/WaterDrop_2d/gns.yaml +++ b/configs/WaterDrop_2d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: /tmp/datasets/WaterDrop +main: + data_dir: /tmp/datasets/WaterDrop model: name: gns diff --git a/configs/dam_2d/gns.yaml b/configs/dam_2d/gns.yaml index 3644fc0..7d10992 100644 --- a/configs/dam_2d/gns.yaml +++ b/configs/dam_2d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_DAM_5740_20kevery100 +main: + data_dir: datasets/2D_DAM_5740_20kevery100 model: name: gns diff --git a/configs/dam_2d/segnn.yaml b/configs/dam_2d/segnn.yaml index f8c8b74..2bfb40b 100644 --- a/configs/dam_2d/segnn.yaml +++ b/configs/dam_2d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_DAM_5740_20kevery100 +main: + data_dir: datasets/2D_DAM_5740_20kevery100 model: name: segnn diff --git a/configs/ldc_2d/gns.yaml b/configs/ldc_2d/gns.yaml index 4574ef9..65da6cb 100644 --- a/configs/ldc_2d/gns.yaml +++ b/configs/ldc_2d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_LDC_2708_10kevery100 +main: + data_dir: datasets/2D_LDC_2708_10kevery100 model: name: gns diff --git a/configs/ldc_2d/segnn.yaml b/configs/ldc_2d/segnn.yaml index fb26eb4..a15cc6b 100644 --- a/configs/ldc_2d/segnn.yaml +++ b/configs/ldc_2d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_LDC_2708_10kevery100 +main: + data_dir: datasets/2D_LDC_2708_10kevery100 model: name: segnn diff --git a/configs/ldc_3d/gns.yaml b/configs/ldc_3d/gns.yaml index 347bceb..b757e2f 100644 --- a/configs/ldc_3d/gns.yaml +++ b/configs/ldc_3d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_LDC_8160_10kevery100 +main: + data_dir: datasets/3D_LDC_8160_10kevery100 model: name: gns diff --git a/configs/ldc_3d/segnn.yaml b/configs/ldc_3d/segnn.yaml index 0433553..4d5da64 100644 --- a/configs/ldc_3d/segnn.yaml +++ b/configs/ldc_3d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_LDC_8160_10kevery100 +main: + data_dir: datasets/3D_LDC_8160_10kevery100 model: name: segnn diff --git a/configs/rpf_2d/egnn.yaml b/configs/rpf_2d/egnn.yaml index 3a634d1..790b708 100644 --- a/configs/rpf_2d/egnn.yaml +++ b/configs/rpf_2d/egnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_RPF_3200_20kevery100 +main: + data_dir: datasets/2D_RPF_3200_20kevery100 model: name: egnn diff --git a/configs/rpf_2d/gns.yaml b/configs/rpf_2d/gns.yaml index c6740b0..82383ec 100644 --- a/configs/rpf_2d/gns.yaml +++ b/configs/rpf_2d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_RPF_3200_20kevery100 +main: + data_dir: datasets/2D_RPF_3200_20kevery100 model: name: gns diff --git a/configs/rpf_2d/painn.yaml b/configs/rpf_2d/painn.yaml index b9e3256..b05f189 100644 --- a/configs/rpf_2d/painn.yaml +++ 
b/configs/rpf_2d/painn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_RPF_3200_20kevery100 +main: + data_dir: datasets/2D_RPF_3200_20kevery100 model: name: painn diff --git a/configs/rpf_2d/segnn.yaml b/configs/rpf_2d/segnn.yaml index 94370c5..7c510eb 100644 --- a/configs/rpf_2d/segnn.yaml +++ b/configs/rpf_2d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_RPF_3200_20kevery100 +main: + data_dir: datasets/2D_RPF_3200_20kevery100 model: name: segnn diff --git a/configs/rpf_3d/egnn.yaml b/configs/rpf_3d/egnn.yaml index 1417075..25d6d0f 100644 --- a/configs/rpf_3d/egnn.yaml +++ b/configs/rpf_3d/egnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_RPF_8000_10kevery100 +main: + data_dir: datasets/3D_RPF_8000_10kevery100 model: name: egnn diff --git a/configs/rpf_3d/gns.yaml b/configs/rpf_3d/gns.yaml index 48c42f4..416a993 100644 --- a/configs/rpf_3d/gns.yaml +++ b/configs/rpf_3d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_RPF_8000_10kevery100 +main: + data_dir: datasets/3D_RPF_8000_10kevery100 model: name: gns diff --git a/configs/rpf_3d/painn.yaml b/configs/rpf_3d/painn.yaml index 7bcc4eb..27f735c 100644 --- a/configs/rpf_3d/painn.yaml +++ b/configs/rpf_3d/painn.yaml @@ -1,5 +1,5 @@ - -data_dir: datasets/3D_RPF_8000_10kevery100 +main: + data_dir: datasets/3D_RPF_8000_10kevery100 model: name: painn diff --git a/configs/rpf_3d/segnn.yaml b/configs/rpf_3d/segnn.yaml index cd68478..64df420 100644 --- a/configs/rpf_3d/segnn.yaml +++ b/configs/rpf_3d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_RPF_8000_10kevery100 +main: + data_dir: datasets/3D_RPF_8000_10kevery100 model: name: segnn diff --git a/configs/tgv_2d/gns.yaml b/configs/tgv_2d/gns.yaml index edd4e71..289e849 100644 --- a/configs/tgv_2d/gns.yaml +++ b/configs/tgv_2d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_TGV_2500_10kevery100 +main: + data_dir: datasets/2D_TGV_2500_10kevery100 model: name: gns diff --git a/configs/tgv_2d/segnn.yaml b/configs/tgv_2d/segnn.yaml index 6c9d553..092d59e 100644 --- a/configs/tgv_2d/segnn.yaml +++ b/configs/tgv_2d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/2D_TGV_2500_10kevery100 +main: + data_dir: datasets/2D_TGV_2500_10kevery100 model: name: segnn diff --git a/configs/tgv_3d/gns.yaml b/configs/tgv_3d/gns.yaml index 99264a8..b286b7c 100644 --- a/configs/tgv_3d/gns.yaml +++ b/configs/tgv_3d/gns.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_TGV_8000_10kevery100 +main: + data_dir: datasets/3D_TGV_8000_10kevery100 model: name: gns diff --git a/configs/tgv_3d/segnn.yaml b/configs/tgv_3d/segnn.yaml index 72a6e09..dddaeef 100644 --- a/configs/tgv_3d/segnn.yaml +++ b/configs/tgv_3d/segnn.yaml @@ -1,4 +1,5 @@ -data_dir: datasets/3D_TGV_8000_10kevery100 +main: + data_dir: datasets/3D_TGV_8000_10kevery100 model: name: segnn diff --git a/lagrangebench/case_setup/case.py b/lagrangebench/case_setup/case.py index 28b624b..cc3d92f 100644 --- a/lagrangebench/case_setup/case.py +++ b/lagrangebench/case_setup/case.py @@ -85,7 +85,7 @@ def case_builder( magnitude_features = cfg.train.magnitude_features neighbor_list_backend = cfg.neighbors.backend neighbor_list_multiplier = cfg.neighbors.multiplier - dtype = cfg.dtype + dtype = cfg.main.dtype normalization_stats = get_dataset_stats(metadata, isotropic_norm, noise_std) diff --git a/lagrangebench/config.py b/lagrangebench/config.py index 9b3cec0..c72e7eb 100644 --- a/lagrangebench/config.py +++ b/lagrangebench/config.py @@ -19,15 +19,23 @@ def defaults(cfg): if cfg is None: raise ValueError("cfg should be a yacs CfgNode") - # random seed - cfg.seed = 0 - # data type for 
preprocessing - cfg.dtype = "float64" + # global and hardware-related configs + main = CN() + # One of "train", "infer" or "all" (= both) + main.mode = "all" + # random seed + main.seed = 0 + # data type for preprocessing. One of "float32" or "float64" + main.dtype = "float64" + # gpu device + main.gpu = 0 + # XLA memory fraction to be preallocated + main.xla_mem_fraction = 0.7 # data directory - cfg.data_dir = None - # run, evaluation or both - cfg.mode = "all" + main.data_dir = None + + cfg.main = main # model model = CN() @@ -162,7 +170,7 @@ def defaults(cfg): def check_cfg(cfg): - assert cfg.data_dir is not None, "cfg.data_dir must be specified." + assert cfg.main.data_dir is not None, "cfg.main.data_dir must be specified." assert ( cfg.train.step_max is not None and cfg.train.step_max > 0 ), "cfg.train.step_max must be specified and larger than 0." @@ -185,3 +193,16 @@ def cfg_to_dict(cfg: CN) -> Dict: # TODO find a better way defaults(cfg) + + +@custom_config +def segnn_config(cfg): + """SEGNN only parameters.""" + # Steerable attributes level + cfg.model.lmax_attributes = 1 + # Level of the hidden layer + cfg.model.lmax_hidden = 1 + # SEGNN normalization. instance, batch, none + cfg.model.segnn_norm = "none" + # SEGNN velocity aggregation. avg or last + cfg.model.velocity_aggregate = "avg" diff --git a/lagrangebench/data/data.py b/lagrangebench/data/data.py index 3ab018f..99df047 100644 --- a/lagrangebench/data/data.py +++ b/lagrangebench/data/data.py @@ -62,7 +62,7 @@ def __init__( N-step MSE loss we are interested in, e.g. for best model checkpointing. """ if dataset_path is None: - dataset_path = cfg.data_dir + dataset_path = cfg.main.data_dir if dataset_path.endswith("/"): # remove trailing slash in dataset path dataset_path = dataset_path[:-1] diff --git a/lagrangebench/evaluate/rollout.py b/lagrangebench/evaluate/rollout.py index 200bca9..28e50cf 100644 --- a/lagrangebench/evaluate/rollout.py +++ b/lagrangebench/evaluate/rollout.py @@ -348,7 +348,7 @@ def infer( else: params, state, _, _ = load_haiku(load_checkpoint) - key, seed_worker, generator = set_seed(cfg.seed) + key, seed_worker, generator = set_seed(cfg.main.seed) loader_test = DataLoader( dataset=data_test, diff --git a/lagrangebench/models/segnn.py b/lagrangebench/models/segnn.py index c0cf25a..90964f2 100644 --- a/lagrangebench/models/segnn.py +++ b/lagrangebench/models/segnn.py @@ -20,7 +20,7 @@ from e3nn_jax import Irreps, IrrepsArray from jax.tree_util import Partial, tree_map -from lagrangebench.config import cfg, custom_config +from lagrangebench.config import cfg from lagrangebench.utils import NodeType from .base import BaseModel @@ -594,17 +594,3 @@ def __call__( nodes = self._decoder(st_graph) out = self._postprocess(nodes, dim) return out - - -# TODO figure out why this is not working -@custom_config -def segnn_config(cfg): - """SEGNN only parameters.""" - # Steerable attributes level - cfg.model.lmax_attributes = 1 - # Level of the hidden layer - cfg.model.lmax_hidden = 1 - # SEGNN normalization. instance, batch, none - cfg.model.segnn_norm = "none" - # SEGNN velocity aggregation. 
avg or last - cfg.model.velocity_aggregate = "avg" diff --git a/lagrangebench/runner.py b/lagrangebench/runner.py index eb30921..beacb6b 100644 --- a/lagrangebench/runner.py +++ b/lagrangebench/runner.py @@ -21,7 +21,7 @@ def train_or_infer(cfg): - mode = cfg.mode + mode = cfg.main.mode old_model_dir = cfg.model.model_dir is_test = cfg.eval.test @@ -108,7 +108,7 @@ def train_or_infer(cfg): def setup_data(cfg) -> Tuple[H5Dataset, H5Dataset, Namespace]: - data_dir = cfg.data_dir + data_dir = cfg.main.data_dir ckp_dir = cfg.logging.ckp_dir rollout_dir = cfg.eval.rollout_dir input_seq_length = cfg.model.input_seq_length diff --git a/lagrangebench/train/trainer.py b/lagrangebench/train/trainer.py index c239672..057dc5d 100644 --- a/lagrangebench/train/trainer.py +++ b/lagrangebench/train/trainer.py @@ -125,7 +125,7 @@ def Trainer( loss_weight = LossConfig(**dict(cfg.optimizer.loss_weight)) pushforward = cfg.optimizer.pushforward - base_key, seed_worker, generator = set_seed(cfg.seed) + base_key, seed_worker, generator = set_seed(cfg.main.seed) # dataloaders loader_train = DataLoader( diff --git a/main.py b/main.py index 9d33c37..6e52b1f 100644 --- a/main.py +++ b/main.py @@ -3,64 +3,87 @@ def cli_arguments(): + """Inspired by https://stackoverflow.com/a/51686813""" parser = argparse.ArgumentParser() - group = parser.add_mutually_exclusive_group(required=True) # config arguments - group.add_argument("-c", "--config", type=str, help="Path to the config yaml.") - group.add_argument("--model_dir", type=str, help="Path to the model checkpoint.") - # misc arguments - parser.add_argument( - "--gpu", type=int, required=False, help="CUDA device ID to use." - ) - parser.add_argument( - "--f32", - required=False, - action=argparse.BooleanOptionalAction, - help="Whether to use single precision for pre-/postprocessing. Default is f64.", - ) - parser.add_argument( - "--xla_mem_fraction", - type=float, - required=False, - default=0.7, - help="Fraction of XLA memory to use.", - ) - # optional config overrides + parser.add_argument("-c", "--config", type=str, help="Path to the config yaml.") parser.add_argument( "extra", default=None, - nargs=argparse.REMAINDER, - help="Extra config overrides as key value pairs.", + nargs="*", + help="Extra, optional config overrides. Need to be separated from '--config' " + "by the pseudo-argument '--'.", ) args = parser.parse_args() if args.extra is None: args.extra = [] + args.extra = preprocess_extras(args.extra) return args +def preprocess_extras(extras): + """Preprocess extras. + + args.extra can be in any of the following 6 formats: + {"--","-",""}key{" ","="}value + + Here we clean up {"--", "-", "="} and split into key value pairs + """ + + temp = [] + for arg in extras: + if arg.startswith("--"): # remove preceding "--" + arg = arg[2:] + elif arg.startswith("-"): # remove preceding "-" + arg = arg[1:] + temp += arg.split("=") # split key value pairs + + return temp + + +def import_cfg(config_path, extras): + """Import cfg without executing lagrangebench.__init__(). 
+ + Based on: + https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly + """ + import importlib.util + + spec = importlib.util.spec_from_file_location("temp", "lagrangebench/config.py") + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + cfg = module.cfg + load_cfg = module.load_cfg + load_cfg(cfg, config_path, extras) + return cfg + + if __name__ == "__main__": cli_args = cli_arguments() - if cli_args.config is not None: # to (re)start training + if cli_args.config is not None: # start from config.yaml config_path = cli_args.config.strip() - elif cli_args.model_dir is not None: # to run inference - config_path = os.path.join(cli_args.model_dir, "config.yaml") - cli_args.extra.extend(["model.model_dir", cli_args.model_dir]) + elif "model.model_dir" in cli_args.extra: # start from a checkpoint + model_dir = cli_args.extra[cli_args.extra.index("model.model_dir") + 1] + config_path = os.path.join(model_dir, "config.yaml") + + # load cfg without executing lagrangebench.__init__() -> temporary cfg for cuda + cfg = import_cfg(config_path, cli_args.extra) # specify cuda device os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152 from TensorFlow - os.environ["CUDA_VISIBLE_DEVICES"] = str(cli_args.gpu) - os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cli_args.xla_mem_fraction) - if not cli_args.f32: + os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.main.gpu) + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(cfg.main.xla_mem_fraction) + if cfg.main.dtype == "float64": from jax import config config.update("jax_enable_x64", True) - else: - cli_args.extra.extend(["dtype", "float32"]) + # load cfg once again, this time executing lagrangebench.__init__() -> global cfg from lagrangebench.config import cfg, load_cfg load_cfg(cfg, config_path, cli_args.extra) diff --git a/neighbors_search/scaling.py b/neighbors_search/scaling.py index cd47352..adb4434 100644 --- a/neighbors_search/scaling.py +++ b/neighbors_search/scaling.py @@ -28,7 +28,7 @@ def update_wrapper(neighbors_old, r_new): def compute_neighbors(args): Nx = args.Nx - mode = args.mode + mode = args.main.mode nl_backend = args.nl_backend num_partitions = args.num_partitions print(f"Start with Nx={Nx}, mode={mode}, backend={nl_backend}") diff --git a/tests/rollout_test.py b/tests/rollout_test.py index 48ba729..a557b56 100644 --- a/tests/rollout_test.py +++ b/tests/rollout_test.py @@ -24,7 +24,7 @@ @custom_config def eval_test_config(cfg): # setup the configuration - cfg.data_dir = "tests/3D_LJ_3_1214every1" # Lennard-Jones dataset + cfg.main.data_dir = "tests/3D_LJ_3_1214every1" # Lennard-Jones dataset cfg.model.input_seq_length = 3 cfg.metrics = ["mse"] cfg.eval.n_rollout_steps = 100 @@ -38,7 +38,7 @@ class TestInferBuilder(unittest.TestCase): def setUp(self): data_valid = H5Dataset( split="valid", - dataset_path=cfg.data_dir, + dataset_path=cfg.main.data_dir, name="lj3d", input_seq_length=cfg.model.input_seq_length, extra_seq_length=cfg.eval.n_rollout_steps, From ab0e7f1bc1eb927bfc17f5240d86b7a63b6b668e Mon Sep 17 00:00:00 2001 From: Artur Toshev Date: Tue, 20 Feb 2024 03:01:25 +0000 Subject: [PATCH 12/13] solve wandb isort issue with explicit third-party --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b832e1f..42b1bf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,9 @@ select = [ # "D", # pydocstyle - consider in the future ] +[tool.ruff.lint.isort] +known-third-party = ["wandb"] + 
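For illustration, this is the grouping the pin enforces, mirroring the `trainer.py` hunk above: without the `known-third-party` entry, isort-style sorting can misclassify `wandb` as first-party when a local `wandb/` run directory exists (which the earlier `.gitignore` change hints at):

```python
# with known-third-party = ["wandb"], all third-party imports sort into
# one block instead of wandb being pushed toward the first-party group
import optax
import wandb
from jax import vmap
from torch.utils.data import DataLoader

from lagrangebench.config import cfg  # first-party imports stay last
```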
[tool.pytest.ini_options]
testpaths = "tests/"
addopts = "--cov=lagrangebench --cov-fail-under=50"

From 776397d475e25e52868f2edddab1e958487e8246 Mon Sep 17 00:00:00 2001
From: Artur Toshev
Date: Tue, 20 Feb 2024 11:36:01 +0000
Subject: [PATCH 13/13] remove pseudo-argument --

---
 main.py | 30 ++++++++++++------------------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/main.py b/main.py
index 6e52b1f..89fde31 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,6 @@
 import argparse
 import os
+from typing import List
 
 
 def cli_arguments():
@@ -8,37 +9,30 @@ def cli_arguments():
 
     # config arguments
     parser.add_argument("-c", "--config", type=str, help="Path to the config yaml.")
-    parser.add_argument(
-        "extra",
-        default=None,
-        nargs="*",
-        help="Extra, optional config overrides. Need to be separated from '--config' "
-        "by the pseudo-argument '--'.",
-    )
-
-    args = parser.parse_args()
-    if args.extra is None:
-        args.extra = []
-    args.extra = preprocess_extras(args.extra)
+
+    args, extras = parser.parse_known_args()
+    if extras is None:
+        extras = []
+    args.extra = preprocess_extras(extras)
 
     return args
 
 
-def preprocess_extras(extras):
+def preprocess_extras(extras: List[str]):
     """Preprocess extras.
 
-    args.extra can be in any of the following 6 formats:
-    {"--","-",""}key{" ","="}value
+    Args:
+        extras: key-value pairs in any of the following formats:
+            `--key value`, `--key=value`, `key value`, `key=value`
 
-    Here we clean up {"--", "-", "="} and split into key value pairs
+    Returns:
+        All key-value pairs formatted as `key value`
     """
 
     temp = []
     for arg in extras:
         if arg.startswith("--"):  # remove preceding "--"
             arg = arg[2:]
-        elif arg.startswith("-"):  # remove preceding "-"
-            arg = arg[1:]
         temp += arg.split("=")  # split key value pairs
 
     return temp
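As a sanity check of the final override parsing (key and value are illustrative; `main.gpu` is one of the keys declared in `lagrangebench/config.py` earlier in this series):

```python
# every accepted spelling flattens to alternating key/value tokens,
# ready for a yacs-style merge_from_list
assert preprocess_extras(["--main.gpu", "1"]) == ["main.gpu", "1"]
assert preprocess_extras(["--main.gpu=1"]) == ["main.gpu", "1"]
assert preprocess_extras(["main.gpu", "1"]) == ["main.gpu", "1"]
assert preprocess_extras(["main.gpu=1"]) == ["main.gpu", "1"]
```

Note that the single-dash spelling (`-key value`) handled by the previous revision is no longer stripped, consistent with the narrowed docstring.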