-
Notifications
You must be signed in to change notification settings - Fork 4
/
trellis_mfm.yaml
70 lines (60 loc) · 1.2 KB
/
trellis_mfm.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# @package _global_

# Config-group selections for this experiment.
defaults:
  - override /model: trellis_gnn_mlp_mfm.yaml
  - override /datamodule: trellis_dataloader.yaml
  # Two configs are selected from the logger group, so they form a list
  # nested under this entry (not sibling entries of `defaults`).
  - override /logger:
      - csv
      - wandb
  - override /trainer: gpu
# NOTE(review): nesting restored from a copy with stripped indentation;
# `name` is assumed to belong under `hydra.launcher` — confirm.
hydra:
  launcher:
    name: "mfm_trellis"

# Top-level random seed.
seed: 0
datamodule:
  batch_size: 1
  ivp_batch_size: 1024
  split: patients
  # Left unset here; model.pca_space_eval only takes effect when this is
  # not null.
  num_components: null
  plot_pca: false
  use_small_exp_num: false
  seed: 0
model:
  name: mfm_trellis
  # Written as 1.0e-4 so that plain YAML 1.1 loaders (e.g. PyYAML) also read
  # a float; bare `1e-4` is only a float under OmegaConf's custom resolver.
  flow_lr: 1.0e-4
  gnn_lr: 1.0e-4
  dim: 43
  num_hidden: 512
  num_layers_decoder: 7
  num_hidden_gnn: 128
  num_layers_gnn: 2
  knn_k: 100
  num_treat_conditions: 11
  num_cell_conditions: 2
  base: source
  pca_space_eval: true  # only used if datamodule.num_components is not null
  run_validation: true
  integrate_time_steps: 500
  seed: 0
trainer:
  # min_epochs == max_epochs: training always runs the full 1500 epochs.
  max_epochs: 1500
  min_epochs: 1500
  # Equal to max_epochs, so validation presumably runs only once, after the
  # final epoch.
  check_val_every_n_epoch: 1500
  accelerator: gpu
  devices: 1
  # log_every_n_steps: 5
# NOTE(review): restored as top-level keys from a copy with stripped
# indentation — confirm they are not meant to sit under `trainer`.
checkpoint:
  filename: "ckpt"

slurm:
  requeue: true
# NOTE: early stopping is NOT being used. With trainer.min_epochs equal to
# max_epochs it could not end training early anyway; the block is kept so
# the callback config stays complete.
callbacks:
  model_checkpoint:
    monitor: "val/2-Wasserstein-PDO"
    # Same direction as early_stopping below for the same metric (a lower
    # distance is better); set explicitly so an inherited mode cannot flip it.
    mode: "min"
  early_stopping:
    monitor: "val/2-Wasserstein-PDO"
    mode: "min"
    patience: 100
    min_delta: 0
# Extra settings for the wandb logger selected in `defaults`.
logger:
  wandb:
    tags: ["trellis", "mfm"]