diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py
index b67d858ed..db93d29ec 100644
--- a/nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py
+++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_knowledge_distillation.py
@@ -72,7 +72,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
             required_keys.update(("tokens", "position_ids"))
 
         if parallel_state.is_pipeline_last_stage():
-            required_keys.update(("labels", "loss_mask"))
+            required_keys.update(("labels", "loss_mask", "topk_logits", "topk_token_ids"))
 
     batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()}
@@ -83,7 +83,9 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
     tokens = batch["tokens"]
     labels = batch["labels"]
-    loss_mask = batch["loss_mask"].clamp(min=0, max=1)
+    loss_mask = batch["loss_mask"]
+    if loss_mask is not None:
+        loss_mask = loss_mask.clamp(min=0, max=1)
     target_topk_logits = batch["topk_logits"]
     target_topk_token_ids = batch["topk_token_ids"]
     # Model forward pass
diff --git a/tests/functional/kd.sh b/tests/functional/kd.sh
index 83e472f52..fa7c2a4b9 100644
--- a/tests/functional/kd.sh
+++ b/tests/functional/kd.sh
@@ -83,7 +83,7 @@ torchrun --nproc-per-node 2 ${GPFS}/examples/nlp/gpt/train_gpt_knowledge_distill
     exp_manager.create_checkpoint_callback=False \
     model.data.num_workers=2 \
     ++model.tensor_model_parallel_size=1 \
-    ++model.pipeline_model_parallel_size=1 \
+    ++model.pipeline_model_parallel_size=2 \
     exp_manager.explicit_log_dir=${RESULTS_DIR} \
     ++model.activations_checkpoint_granularity=full \
     ++model.activations_checkpoint_method=uniform \
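
Note on the None guard: with pipeline parallelism enabled (which tests/functional/kd.sh now exercises via ++model.pipeline_model_parallel_size=2), the dict comprehension above replaces every key outside required_keys with None on the current stage. On non-last stages "loss_mask" is therefore None, and the old unconditional batch["loss_mask"].clamp(min=0, max=1) would raise AttributeError ('NoneType' object has no attribute 'clamp'). Below is a minimal, self-contained sketch of that failure mode and the fix; filter_batch and the toy tensors are illustrative, not NeMo-Aligner API, and the .cuda(non_blocking=True) move is omitted so the sketch runs on CPU.

import torch

def filter_batch(batch, required_keys):
    # Mirrors the dict comprehension in fwd_output_and_loss_func: keys the
    # current pipeline stage does not need are replaced with None.
    return {key: val if key in required_keys else None for key, val in batch.items()}

batch = {
    "tokens": torch.randint(0, 100, (2, 8)),
    "loss_mask": torch.tensor([[0.0, 1.0, 2.0], [1.0, 1.0, 0.0]]),
}

# A non-last pipeline stage only requires "tokens", so "loss_mask" becomes None.
filtered = filter_batch(batch, required_keys={"tokens"})

# The guard from the patch: clamp into [0, 1] only when the tensor is present.
loss_mask = filtered["loss_mask"]
if loss_mask is not None:
    loss_mask = loss_mask.clamp(min=0, max=1)

On the last pipeline stage, where "loss_mask" (along with "labels", "topk_logits", and "topk_token_ids") is in required_keys, the same two lines clamp the mask exactly as the pre-patch code did.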