Geshen/upgrade to 24 05 (#187)
* upgrade to latest te and mcore

Signed-off-by: Gerald Shen <[email protected]>

* add changelog and bump dockerfile

Signed-off-by: Gerald Shen <[email protected]>

* update package info

Signed-off-by: Gerald Shen <[email protected]>

* update changelog

Signed-off-by: Gerald Shen <[email protected]>

---------

Signed-off-by: Gerald Shen <[email protected]>
gshennvm authored Jun 3, 2024
1 parent 5b898b6 commit 0baba99
Showing 5 changed files with 40 additions and 41 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Next Version]

### New features and optimizations

### Breaking changes

### Bug Fixes

## [0.3.1] - 2024-05
- SPIN: added `rollout_micro_batch_size` parameter, which lets users set the batch size used for generation during SPIN training.
  Previously, the generation batch size was automatically set to the model's data parallel (DP) size.

@@ -25,6 +33,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Fixed issue when `model.encoder_seq_length` is mismatched with `model.data.train_ds.max_seq_length` in SFT and SPIN.
- Deleted MegatronPretrainingRandomSampler from Aligner since it has been upstreamed into NeMo

## [0.3.0] - 2024-05

### New features and optimizations
- Special TRT-LLM release. See [Accelerated-RLHF](https://github.com/NVIDIA/NeMo-Aligner/blob/v0.3.0.trtllm/Accelerated-RLHF.md) and [Accelerated-RLHF-Release](https://github.com/NVIDIA/NeMo-Aligner/releases/tag/v0.3.0.trtllm) for more details.

## [0.2.0] - 2024-02
### New features and optimizations
- Added public-facing official Dockerfile for NeMo-Aligner.
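The `rollout_micro_batch_size` entry in the changelog excerpt above decouples the generation batch size from the model's data-parallel size. A minimal, standalone sketch of the idea follows; `generate_rollouts` and `generate_fn` are hypothetical names used only for illustration, not NeMo-Aligner APIs.

```python
from typing import Callable, List


def generate_rollouts(
    prompts: List[str],
    generate_fn: Callable[[List[str]], List[str]],
    rollout_micro_batch_size: int,
) -> List[str]:
    """Run generation over `prompts` in chunks of `rollout_micro_batch_size` prompts.

    Exposing the chunk size (rather than tying it to the data-parallel size) lets
    users trade generation memory for throughput.
    """
    outputs: List[str] = []
    for start in range(0, len(prompts), rollout_micro_batch_size):
        outputs.extend(generate_fn(prompts[start : start + rollout_micro_batch_size]))
    return outputs


# Toy usage: with a chunk size of 2, six prompts are processed in three passes.
echoed = generate_rollouts([f"p{i}" for i in range(6)], lambda batch: batch, rollout_micro_batch_size=2)
assert echoed == [f"p{i}" for i in range(6)]
```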
12 changes: 6 additions & 6 deletions Dockerfile
@@ -1,12 +1,12 @@
 # CUDA 12.3
-FROM nvcr.io/nvidia/pytorch:24.01-py3
+FROM nvcr.io/nvidia/pytorch:24.02-py3

 ### config tags
-ARG APEX_TAG=master
-ARG TE_TAG=release_v1.4
-ARG MLM_TAG=43792028f003ed25a3ee8c5a0d4cad82317d81b5
-ARG NEMO_TAG=9d86acd5ebf3cec020f84dfe7e25c109506803b1
-ARG PYTRITON_VERSION=0.4.1
+ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
+ARG TE_TAG=a51ff542dcb1f605aa54f9b0e1aaadb132acd53d
+ARG MLM_TAG=core_r0.7.0
+ARG NEMO_TAG=r2.0.0rc0
+ARG PYTRITON_VERSION=0.5.5
 ARG PROTOBUF_VERSION=4.24.4
 ARG ALIGNER_COMMIT=main

20 changes: 2 additions & 18 deletions nemo_aligner/models/nlp/gpt/gpt_reward_model.py
@@ -233,23 +233,7 @@ def sharded_state_dict(self, prefix=""):
         # from the parent
         sharded_state_dict = super().sharded_state_dict(prefix=prefix)

-        if self.post_process and self.return_rm_head_in_state_dict:
-            rm_head_prefix = f"{prefix}rm_head."
-            rm_head_state_dict = self.rm_head.state_dict(prefix=rm_head_prefix, keep_vars=True)
-
-            # weights are sharded row wise
-            weight_key = f"{rm_head_prefix}weight"
-
-            sharded_state_dict[weight_key] = make_tp_sharded_tensor_for_checkpoint(
-                tensor=rm_head_state_dict[weight_key],
-                key=weight_key,
-                replica_id=parallel_state.get_data_parallel_rank(),
-                allow_shape_mismatch=False,
-                tp_axis=1,
-            )
-
-            # biases are not sharded
-            bias_key = f"{rm_head_prefix}bias"
-            sharded_state_dict[bias_key] = make_sharded_tensor_for_checkpoint(rm_head_state_dict[bias_key], bias_key)
+        if not self.return_rm_head_in_state_dict:
+            sharded_state_dict = {k: v for k, v in sharded_state_dict.items() if "rm_head" not in k}

         return sharded_state_dict
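A toy illustration of the new behavior in the hunk above: instead of manually registering sharded tensors for the reward-model head, the head's entries are simply filtered out of the parent's sharded state dict when they should not be saved. `filter_rm_head` is a hypothetical helper written only to demonstrate the filtering pattern; it is not part of NeMo-Aligner.

```python
def filter_rm_head(sharded_state_dict: dict, return_rm_head_in_state_dict: bool) -> dict:
    # Drop any entry whose key mentions the rm_head when the head should be excluded.
    if not return_rm_head_in_state_dict:
        return {k: v for k, v in sharded_state_dict.items() if "rm_head" not in k}
    return sharded_state_dict


toy = {"decoder.weight": 1, "rm_head.weight": 2, "rm_head.bias": 3}
assert set(filter_rm_head(toy, return_rm_head_in_state_dict=False)) == {"decoder.weight"}
assert set(filter_rm_head(toy, return_rm_head_in_state_dict=True)) == set(toy)
```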
34 changes: 18 additions & 16 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
@@ -317,22 +317,24 @@ def on_load_checkpoint(self, checkpoint) -> None:
"""NOTE: Have to set strict to False because we have a rm head
"""
# mcore uses distributed checkpointing
if "state_dict" in checkpoint and checkpoint["state_dict"]:
for index, module in enumerate(self.get_model_module_list()):
if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"]
else:
checkpoint_state_dict = checkpoint["state_dict"]
# checkpoint_state_dict has "model." but module does not so we need to remove it when loading
checkpoint_state_dict = {
key.replace("model.", ""): checkpoint_state_dict.pop(key)
for key in list(checkpoint_state_dict.keys())
}
module.load_state_dict(checkpoint_state_dict, strict=False)
else:
# when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict
# see NLPModel.on_load_checkpoint
checkpoint["state_dict"] = {}
+        # FSDP supports the legacy checkpointing or torch-FSDP-native sharded checkpointing
+        if not self.use_fsdp:
+            if "state_dict" in checkpoint and checkpoint["state_dict"]:
+                for index, module in enumerate(self.get_model_module_list()):
+                    if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None:
+                        checkpoint_state_dict = checkpoint["state_dict"][f"model_{index}"]
+                    else:
+                        checkpoint_state_dict = checkpoint["state_dict"]
+                    # checkpoint_state_dict has "model." but module does not so we need to remove it when loading
+                    checkpoint_state_dict = {
+                        key.replace("model.", ""): checkpoint_state_dict.pop(key)
+                        for key in list(checkpoint_state_dict.keys())
+                    }
+                    module.load_state_dict(checkpoint_state_dict, strict=False)
+            else:
+                # when restoring a distributed checkpoint from a ptl checkpoint we need to defer loading the state_dict
+                # see NLPModel.on_load_checkpoint
+                checkpoint["state_dict"] = {}

     def prepare_for_training_step(self):
         # custom trainers will always zero grad for us
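A simplified, standalone sketch of the key-renaming step shown in the hunk above: PyTorch Lightning checkpoints store module weights under a "model." prefix that the mcore module itself does not use, so keys are rewritten before `load_state_dict` is called with `strict=False` (the rm_head has no entry in the pretrained checkpoint). `strip_model_prefix` is an assumed name used only for illustration.

```python
def strip_model_prefix(checkpoint_state_dict: dict) -> dict:
    # Mirrors the dict comprehension in the diff above: remove the "model." prefix so the
    # checkpoint keys line up with the mcore module's own parameter names.
    return {key.replace("model.", ""): value for key, value in checkpoint_state_dict.items()}


toy_ckpt = {"model.decoder.layers.0.weight": 1, "model.embedding.weight": 2}
assert strip_model_prefix(toy_ckpt) == {"decoder.layers.0.weight": 1, "embedding.weight": 2}
```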
2 changes: 1 addition & 1 deletion nemo_aligner/package_info.py
@@ -15,7 +15,7 @@

 MAJOR = 0
 MINOR = 3
-PATCH = 0
+PATCH = 1
 PRE_RELEASE = "dev"

 # Use the following formatting: (major, minor, patch, pre-release)
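For context, a hedged sketch of how a (major, minor, patch, pre-release) tuple is commonly assembled into version strings. The actual assembly lines sit below the displayed hunk and are not part of this diff, so treat the names here as illustrative rather than a quote of package_info.py.

```python
MAJOR, MINOR, PATCH, PRE_RELEASE = 0, 3, 1, "dev"

# Illustrative only: one common pattern for turning the tuple into version strings.
VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
__shortversion__ = ".".join(map(str, VERSION[:3]))      # "0.3.1"
__version__ = __shortversion__ + "".join(VERSION[3:])   # "0.3.1dev"
```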
