
feat: add context parallel support for SFT #420

Closed · wants to merge 17 commits
Conversation

@ashors1 ashors1 (Collaborator) commented Nov 27, 2024

What does this PR do?

Adds context parallel (CP) support for SFT in NeMo-Aligner.

Changelog

  • Please update the CHANGELOG.md under next version with high level changes in this PR.

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this 
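Since the template leaves the snippet blank, here is a purely illustrative sketch of how CP might be enabled for SFT by overriding the model config. The config path and the exact keys (`context_parallel_size`, `tensor_model_parallel_size`) are assumptions based on common NeMo-Aligner conventions, not taken from this PR.

```python
# Hypothetical sketch, not from this PR: enable context parallelism for SFT by
# overriding the model config before launching training. The config path and
# key names below are assumptions, not quotes from the PR.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/gpt/conf/gpt_sft.yaml")  # assumed config location
cfg.model.context_parallel_size = 2       # shard each sequence across 2 CP ranks
cfg.model.tensor_model_parallel_size = 2  # TP still composes with CP as usual
# ...then pass `cfg` to the usual SFT training entry point.
```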

Before your PR is "Ready for review"

Pre checks:

Checklist when contributing a new algorithm

  • Does the trainer resume and restore all model states?
  • Does the trainer support all parallelism techniques (PP, TP, DP)?
  • Does the trainer support max_steps=-1 and validation?
  • Does the trainer only call APIs defined in alignable_interface.py?
  • Does the trainer have proper logging?

Additional Information

  • Related to # (issue)

@ashors1 ashors1 marked this pull request as draft November 27, 2024 06:25
@ashors1 ashors1 changed the title Add context parallel support for SFT feat: add context parallel support for SFT Nov 27, 2024
ashors1 and others added 3 commits November 29, 2024 21:59
Signed-off-by: ashors1 <[email protected]>
for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>
Signed-off-by: Anna Shors <[email protected]>
@ashors1 ashors1 marked this pull request as ready for review December 2, 2024 20:40
@ashors1 ashors1 added the Run CICD Set + un-set to retrigger label Dec 2, 2024
@ashors1 ashors1 requested a review from terrykong December 2, 2024 20:40
@ashors1 ashors1 (Collaborator, Author) commented Dec 2, 2024

Note that CP support for SFT was recently added to NeMo. We need a NeMo commit at least as recent as 8c921dc19a905d8b5a0f90f6e2a34607c2e0660d

@ashors1 ashors1 added Run CICD Set + un-set to retrigger and removed Run CICD Set + un-set to retrigger labels Dec 2, 2024
@ashors1 ashors1 added Run CICD Set + un-set to retrigger and removed Run CICD Set + un-set to retrigger labels Dec 2, 2024
@ashors1 ashors1 changed the base branch from main to dev December 2, 2024 23:29
@ashors1 ashors1 changed the base branch from dev to main December 2, 2024 23:29
@ashors1 ashors1 added Run CICD Set + un-set to retrigger and removed Run CICD Set + un-set to retrigger labels Dec 3, 2024
@ashors1 ashors1 added Run CICD Set + un-set to retrigger and removed Run CICD Set + un-set to retrigger labels Dec 3, 2024
Signed-off-by: ashors1 <[email protected]>
@ashors1 ashors1 added Run CICD Set + un-set to retrigger and removed Run CICD Set + un-set to retrigger labels Dec 4, 2024
@terrykong terrykong (Collaborator) left a comment

TODO:

restore_from_path: ??? # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training.
sync_batch_comm: False
megatron_amp_O2: False
encoder_seq_length: 4096 # the sequence length of the encoder model; it will be overwritten by the loaded GPT model
transformer_engine: True
Collaborator

For my education, why do we need to specify this now?

Collaborator Author (ashors1)

I don't actually think it's necessary since TE is enabled by default; I just wanted to make it explicit that we are using TE. But I will remove this.

Comment on lines +391 to +396
pad_seq_length_to_mult = 16
if model_cfg is not None:
    pad_seq_length_to_mult = (
        8 * model_cfg.get("tensor_model_parallel_size", 1) if model_cfg.get("sequence_parallel", False) else 16
    )
    pad_seq_length_to_mult *= model_cfg.get("context_parallel_size", 1)
Collaborator

According to that fp8 comment above, should this be:

Suggested change
-pad_seq_length_to_mult = 16
-if model_cfg is not None:
-    pad_seq_length_to_mult = (
-        8 * model_cfg.get("tensor_model_parallel_size", 1) if model_cfg.get("sequence_parallel", False) else 16
-    )
-    pad_seq_length_to_mult *= model_cfg.get("context_parallel_size", 1)
+pad_seq_length_to_mult = 16
+if model_cfg is not None:
+    if model_cfg.get("sequence_parallel", False):
+        pad_seq_length_to_mult = math.lcm(pad_seq_length_to_mult, model_cfg.get("tensor_model_parallel_size", 1))
+    pad_seq_length_to_mult *= model_cfg.get("context_parallel_size", 1)

? From the comment it sounds like if someone is doing fp8 SFT with TP=1 and sets sequence_parallel, the padding multiple would be too small (8 * 1 = 8, while fp8 expects sequence lengths that are multiples of 16).

Collaborator Author (ashors1)

That chunk was taken directly from here: https://github.com/NVIDIA/NeMo/blob/b847bf75c371931e4f17ea426741c1d023afa0c0/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py#L262-L268, but the code does seem to contradict the comment. I'll follow up with the TE team

Collaborator Author (ashors1)

From the TE team: "when SP=True, the dimensions are flipped so the sequence dimension is first. So we only need to make sure it's divisible by 8 after the TP split to comply with TE's expectations."
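To make that rule concrete, here is a minimal sketch (mirroring the PR's padding logic, not the PR code itself) of how the padding multiple follows from TP, CP, and sequence parallelism:

```python
# Minimal sketch of the padding rule discussed above (not the PR code itself).
# With sequence_parallel=True the sequence dimension is split across TP ranks,
# so each rank's shard must be divisible by 8, i.e. the full sequence length
# must be a multiple of 8 * TP; otherwise a multiple of 16 satisfies fp8.
# Context parallelism splits the sequence again, hence the final * CP.
def pad_seq_length_multiple(tp: int, cp: int, sequence_parallel: bool) -> int:
    base = 8 * tp if sequence_parallel else 16
    return base * cp

assert pad_seq_length_multiple(tp=2, cp=2, sequence_parallel=True) == 32   # 8 * 2 * 2
assert pad_seq_length_multiple(tp=1, cp=4, sequence_parallel=False) == 64  # 16 * 4
```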

@@ -88,7 +88,7 @@ def get_loss_and_metrics(self, batch, forward_only):
        set_sync_funcs(self, forward_only)

        fwd_bwd_function = get_forward_backward_func()
-       fwd_loss_fn = self.get_forward_output_and_loss_func(forward_only)
+       fwd_loss_fn = self.get_forward_output_and_loss_func(forward_only, tuning=True)
Collaborator

What does tuning do?

Collaborator Author (ashors1)

It controls the keys that are returned in the batch: https://github.com/NVIDIA/NeMo/blob/b847bf75c371931e4f17ea426741c1d023afa0c0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L1211-L1222. If tuning=False, we don't return the keys that are necessary for sequence packing (and thus CP) in TE.

Note also that tuning is set to True in NeMo's SFT: https://github.com/NVIDIA/NeMo/blob/b847bf75c371931e4f17ea426741c1d023afa0c0/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py#L407
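As an illustration of that behavior, here is a rough sketch that paraphrases the linked NeMo code; the exact key names are assumptions for illustration, not quotes from NeMo.

```python
# Rough sketch (not the actual NeMo implementation): tuning=True keeps the
# packed-sequence keys that TE needs for context parallelism; tuning=False
# drops them. The key names here are assumptions for illustration.
def select_required_keys(batch: dict, tuning: bool) -> dict:
    required = {"tokens", "labels", "loss_mask", "attention_mask", "position_ids"}
    if tuning:
        # Packed-sequence metadata (cumulative sequence lengths, etc.) that
        # TE's context-parallel attention path relies on.
        required |= {"cu_seqlens", "cu_seqlens_argmin", "max_seqlen"}
    return {k: v for k, v in batch.items() if k in required}
```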

@terrykong terrykong changed the base branch from main to dev December 4, 2024 00:58
@ashors1 ashors1 changed the base branch from dev to main December 4, 2024 06:43
@ashors1 ashors1 removed the Run CICD Set + un-set to retrigger label Dec 4, 2024
@ashors1 ashors1 (Collaborator, Author) commented Dec 4, 2024

Closing in favor of #430. @terrykong, I'll address your comments there.

@ashors1 ashors1 closed this Dec 4, 2024