diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py index d471ae367..8888456ce 100644 --- a/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py @@ -111,6 +111,8 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys.add("attention_mask") if "cu_seqlens" in batch: required_keys.add("cu_seqlens") + required_keys.add("max_seqlen") + required_keys.add("cu_seqlens_argmin") if parallel_state.is_pipeline_first_stage(): if packed: