
Commit

wip
Signed-off-by: arendu <[email protected]>
arendu committed Oct 30, 2024
1 parent bd590d6 commit 3ed1cb1
Showing 4 changed files with 428 additions and 4 deletions.
1 change: 1 addition & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
@@ -57,6 +57,7 @@ model:
   micro_batch_size: 1
   global_batch_size: 64
   megatron_amp_O2: True
+  mamba_hybrid: False
 
   dpo:
     # This default value ensures there are no numeric differences between trained and reference policies when computing log probs.
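
The flag defaults to False, so existing GPT DPO configs keep their current behavior; a run against a hybrid Mamba checkpoint would opt in explicitly. A minimal sketch of the relevant model section with the flag enabled (only the lines shown in this hunk; the rest of gpt_dpo.yaml is assumed unchanged):

    model:
      micro_batch_size: 1
      global_batch_size: 64
      megatron_amp_O2: True
      mamba_hybrid: True  # select MegatronMambaDPOModel instead of MegatronGPTDPOModel

Since the training script is driven by a Hydra-style config, the same setting could presumably also be passed as a command-line override such as model.mamba_hybrid=True instead of editing the YAML.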
4 changes: 2 additions & 2 deletions examples/nlp/gpt/train_gpt_dpo.py
@@ -21,7 +21,7 @@
 from nemo.utils.exp_manager import exp_manager
 from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
 from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
-from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
+from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel, MegatronMambaDPOModel
 from nemo_aligner.utils.distributed import Timer
 from nemo_aligner.utils.train_script_utils import (
     CustomLoggerWrapper,
@@ -53,7 +53,7 @@ def main(cfg) -> None:
     logger = CustomLoggerWrapper(trainer.loggers)
 
     ptl_model = load_from_nemo(
-        MegatronGPTDPOModel,
+        MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel,
         cfg.model,
         trainer,
         strict=True,
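
Taken together, the two hunks make the model class a function of the new flag while leaving the rest of main() untouched. A condensed sketch of that selection logic follows; the helper name select_dpo_model_class is a hypothetical wrapper added here purely for illustration, and cfg is the Hydra/OmegaConf config the script already receives:

    def select_dpo_model_class(cfg):
        # Hypothetical helper mirroring the inline expression in train_gpt_dpo.py:
        # use the Mamba-hybrid DPO model when model.mamba_hybrid is set, else the GPT one.
        from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import (
            MegatronGPTDPOModel,
            MegatronMambaDPOModel,
        )

        return MegatronMambaDPOModel if cfg.model.mamba_hybrid else MegatronGPTDPOModel

The selected class is then handed to load_from_nemo exactly as before; the remaining arguments of that call are not shown in this diff.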
