
Commit

Merge pull request #24 from walledata/fix-resume
Fix error on resume training
Stardust-minus authored Sep 8, 2023
2 parents 67f4d08 + 82409e2 commit 1f3ee63
Showing 1 changed file with 17 additions and 21 deletions.
train_ms.py: 38 changes (17 additions & 21 deletions)
@@ -172,24 +172,20 @@ def run():
         net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
     try:
         if net_dur_disc is not None:
-            _, _, _, epoch_str = utils.load_checkpoint(
-                utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"),
-                net_dur_disc,
-                optim_dur_disc,
-                skip_optimizer=True,
-            )
-        _, optim_g, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"),
-            net_g,
-            optim_g,
-            skip_optimizer=True,
-        )
-        _, optim_d, _, epoch_str = utils.load_checkpoint(
-            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"),
-            net_d,
-            optim_d,
-            skip_optimizer=True,
-        )
+            _, _, dur_resume_lr, epoch_str = utils.load_checkpoint(
+                utils.latest_checkpoint_path(hps.model_dir, "DUR_*.pth"), net_dur_disc, optim_dur_disc,
+                skip_optimizer=True)
+        _, optim_g, g_resume_lr, epoch_str = utils.load_checkpoint(
+            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
+            optim_g, skip_optimizer=True)
+        _, optim_d, d_resume_lr, epoch_str = utils.load_checkpoint(
+            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d,
+            optim_d, skip_optimizer=True)
+        if not optim_g.param_groups[0].get("initial_lr"):
+            optim_g.param_groups[0]["initial_lr"] = g_resume_lr
+        if not optim_d.param_groups[0].get("initial_lr"):
+            optim_d.param_groups[0]["initial_lr"] = d_resume_lr
+

         epoch_str = max(epoch_str, 1)
         global_step = (epoch_str - 1) * len(train_loader)
@@ -205,9 +201,9 @@ def run():
         optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
     )
     if net_dur_disc is not None:
-        scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(
-            optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
-        )
+        if not optim_dur_disc.param_groups[0].get("initial_lr"):
+            optim_dur_disc.param_groups[0]["initial_lr"] = dur_resume_lr
+        scheduler_dur_disc = torch.optim.lr_scheduler.ExponentialLR(optim_dur_disc, gamma=hps.train.lr_decay, last_epoch=epoch_str-2)
     else:
         scheduler_dur_disc = None
     scaler = GradScaler(enabled=hps.train.fp16_run)
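
Context on the change: the checkpoints are loaded with skip_optimizer=True, so the rebuilt optimizers never carry the "initial_lr" key that PyTorch's LR schedulers expect whenever they are constructed with last_epoch != -1, and resuming fails with a KeyError. This commit captures the learning rate returned by utils.load_checkpoint (dur_resume_lr, g_resume_lr, d_resume_lr) and backfills it into the param groups before the ExponentialLR schedulers are built. Below is a minimal standalone sketch of the same pattern; the model, learning rate, gamma, and resume_epoch values are illustrative placeholders, not values from the repo:

# Minimal sketch of the resume pattern applied in this fix (illustrative values).
import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

resume_lr = 2e-4    # learning rate recovered from the checkpoint
resume_epoch = 10   # epoch recovered from the checkpoint

# A freshly built optimizer has no "initial_lr" in its param groups, so creating
# a scheduler with last_epoch != -1 raises a KeyError complaining that
# 'initial_lr' is not specified in param_groups[0]. Backfill it first.
if not optimizer.param_groups[0].get("initial_lr"):
    optimizer.param_groups[0]["initial_lr"] = resume_lr

scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer, gamma=0.999875, last_epoch=resume_epoch - 2
)

With "initial_lr" in place, the scheduler can be constructed at the resumed epoch and continues the decay schedule where the previous run left off instead of failing at startup.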
