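"""Training entry point.

Builds a LightningCLI around SepDataModule, optionally logs the run to
Weights & Biases (project "HILSep"), saves the resolved config next to the
checkpoints, and registers ModelCheckpoint callbacks for a periodic snapshot,
the best training-loss checkpoint, and the best validation-loss checkpoint.
"""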
import os

import yaml
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.cli import LightningCLI, SaveConfigCallback
from pytorch_lightning.loggers import WandbLogger

from sepcodec.dataloading.datamodules import SepDataModule


class LoggerSaveConfigCallback(SaveConfigCallback):
    """Push the resolved config to the W&B run and keep a copy next to the checkpoints."""

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.logger is not None:
            experiment_name = trainer.logger.experiment.name
            # Required for proper reproducibility
            config = self.parser.dump(self.config, skip_none=False)
            # Reload the config file from disk (this replaces the dumped string above)
            # and sync it to the W&B run.
            with open(self.config_filename, "r") as config_file:
                config = yaml.load(config_file, Loader=yaml.FullLoader)
            trainer.logger.experiment.config.update(config, allow_val_change=True)
            # Keep a copy of the config in this run's checkpoint directory.
            with open(os.path.join(self.config['ckpt_path'], experiment_name, "config.yaml"), 'w') as outfile:
                yaml.dump(config, outfile, default_flow_style=False)


class MyLightningCLI(LightningCLI):
    """CLI wrapper adding script-level arguments for logging, checkpointing and resuming."""

    def add_arguments_to_parser(self, parser):
        parser.add_argument("--log", default=False)
        parser.add_argument("--log_model", default=False)
        parser.add_argument("--ckpt_path", default="checkpoints")
        parser.add_argument("--resume_id", default=None)
        parser.add_argument("--resume_from_checkpoint", default=None)


if __name__ == "__main__":
    cli = MyLightningCLI(model_class=None, datamodule_class=SepDataModule, seed_everything_default=123,
                         run=False, save_config_callback=LoggerSaveConfigCallback, save_config_kwargs={"overwrite": True})
    cli.instantiate_classes()

    if cli.config.log:
        logger = WandbLogger(project="HILSep", id=cli.config.resume_id)
        experiment_name = logger.experiment.name
        ckpt_path = cli.config.ckpt_path
    else:
        logger = None
    cli.trainer.logger = logger

    # Create the per-run checkpoint directory. When logging is disabled,
    # ckpt_path and experiment_name are never defined, so the NameError is ignored.
    try:
        if not os.path.exists(os.path.join(ckpt_path, experiment_name)):
            os.makedirs(os.path.join(ckpt_path, experiment_name))
    except NameError:
        pass
    if logger is not None:
        # Periodic snapshot taken every 200 epochs; with the default save_top_k=1
        # only the most recent snapshot file is kept.
        recent_callback = ModelCheckpoint(
            dirpath=os.path.join(cli.config.ckpt_path, experiment_name),
            filename='checkpoint-{step}',
            every_n_epochs=200  # replace with your desired snapshot interval
        )
        # Best checkpoint according to the epoch-level training loss.
        best_callback = ModelCheckpoint(
            monitor='train_loss_epoch',
            dirpath=os.path.join(cli.config.ckpt_path, experiment_name),
            filename='best-{step}',
            save_top_k=1,
            mode='min',
            every_n_epochs=1
        )
        # Best checkpoint according to the validation loss.
        best_val_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=os.path.join(cli.config.ckpt_path, experiment_name),
            filename='best-val-{step}',
            save_top_k=1,
            mode='min',
            every_n_epochs=1
        )
        # early_stopping_callback = EarlyStopping(
        #     monitor='val_loss',
        #     patience=5,
        #     mode='min'
        # )
        # Replace the last callback currently on the trainer with the checkpoint callbacks above.
        cli.trainer.callbacks = cli.trainer.callbacks[:-1] + [recent_callback, best_callback, best_val_callback]
    # cli.trainer.fit(model=cli.model, datamodule=cli.datamodule, ckpt_path=cli.config.resume_from_checkpoint)
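# Usage sketch (the config path below is an assumption; the actual model, data and
# trainer settings come from the project's own YAML config):
#   python train.py --config configs/train.yaml --log True --ckpt_path checkpoints
#
# Since the CLI is constructed with run=False, it only parses arguments and
# instantiates the classes; the commented-out cli.trainer.fit(...) call above
# would be uncommented to actually launch training.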