Skip to content

Commit

Permalink
feat[dpo]: Allows DPO Dataset to pad to a multiple
Browse files Browse the repository at this point in the history
- Needed for sequence parallel where the sequence length needs to be
  divisible

Signed-off-by: Terry Kong <[email protected]>

missing plumbing

Signed-off-by: Terry Kong <[email protected]>

fix: update other APIs required for MCore dist opt

Signed-off-by: Terry Kong <[email protected]>

feat: more informative error for DPO dataset tokenization failure

Signed-off-by: Terry Kong <[email protected]>

more functional tests

Signed-off-by: Terry Kong <[email protected]>

rename

Signed-off-by: Terry Kong <[email protected]>

convenience script to run all functional tests

Signed-off-by: Terry Kong <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>

more

Signed-off-by: Terry Kong <[email protected]>

fix closure, but need NVIDIA/NeMo#11189

Signed-off-by: Terry Kong <[email protected]>

eosid -> -100

Signed-off-by: Terry Kong <[email protected]>

revert black

Signed-off-by: Terry Kong <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>

global pad

Signed-off-by: Terry Kong <[email protected]>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: NeMo-Aligner CI <[email protected]>

test cleanup

Signed-off-by: Terry Kong <[email protected]>

make all scripts output the same

Signed-off-by: Terry Kong <[email protected]>

license

Signed-off-by: Terry Kong <[email protected]>

more comment

Signed-off-by: Terry Kong <[email protected]>
  • Loading branch information
terrykong committed Nov 7, 2024
1 parent e2c9695 commit 4fd8963
Show file tree
Hide file tree
Showing 31 changed files with 640 additions and 114 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ jobs:
- dpo-llama3
- sft-llama3
- rm-llama3
- dpo-mixtral-ep
- dpo-mixtral-sp
with:
RUNNER: self-hosted-azure
# Fairly aggresive timeout that all functional tests should try to adhere to
Expand Down
6 changes: 5 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,20 @@ git fetch -a
# 60e677423667c029dd05875da72bf0719774f844: [feat] Update get_model_parallel_src_rank to support tp-pp-dp ordering NeMo#10652
# 0deaf6716cb4f20766c995ce25d129795f1ae200: fix[export]: update API for disabling device reassignment in TRTLLM for Aligner NeMo#10863
# (superceded by 10863) 148543d6e9c66ff1f8562e84484448202249811d: feat: Migrate GPTSession refit path in Nemo export to ModelRunner for Aligner NeMo#10654
# ba8edbd2063f3349c40c9c73e5bae46abbe65f94: fix: regular torch optims (e.g., sgd) no longer error with closure spec NeMo#11189
# 35a7f718237cf011215db9e92273ed7236d0e8b1: Fix for crash with LoRA + tp_overlap_comm=false + sequence_parallel=true NeMo#10920
for pr_and_commit in \
"10651 0c92fe17df4642ffc33d5d8c0c83fda729e3910c" \
"10652 60e677423667c029dd05875da72bf0719774f844" \
"10863 0deaf6716cb4f20766c995ce25d129795f1ae200" \
"11189 ba8edbd2063f3349c40c9c73e5bae46abbe65f94" \
"10920 53cf6527571b29379188c8bb0dba8e507db3cca1" \
; do
pr=$(cut -f1 -d' ' <<<"$pr_and_commit")
head_pr_commit=$(cut -f2 -d' ' <<<"$pr_and_commit")
git fetch origin $head_pr_commit:PR-${pr}
# cherry-picks all commits between main and the top of the PR
git cherry-pick --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
git cherry-pick -m 1 --allow-empty $(git merge-base origin/main PR-${pr})..PR-${pr}
# Tag cherry-picks to help
git tag cherry-pick-PR-${pr}
done
Expand Down
4 changes: 4 additions & 0 deletions examples/nlp/gpt/conf/gpt_dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# dpo specific args
dpo:
Expand All @@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down Expand Up @@ -57,6 +59,7 @@ model:
micro_batch_size: 1
global_batch_size: 64
megatron_amp_O2: True
variable_seq_lengths: ${not:model.data.pad_length_to_multiple_of}

dpo:
# This default value ensures there are no numeric differences beween trained and reference policies when computing log probs.
Expand Down Expand Up @@ -114,6 +117,7 @@ model:
data_impl: jsonl
splits_string: null
seq_length: ${model.encoder_seq_length}
pad_length_to_multiple_of: null # Use if sequence_parallel is enabled to ensure seq_length is divisible by the ...
skip_warmup: True
num_workers: 0
reset_position_ids: False # Reset position ids after end-of-document token
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_kto.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# kto specific args
kto:
Expand All @@ -17,6 +18,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_actor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
# How many steps we train warmup the critic for (without training the policy)
Expand All @@ -21,6 +22,7 @@ trainer:
max_steps: -1 # max PPO steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# PPO args to generate the data for training
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_ppo_critic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

ppo:
port: 5556
Expand All @@ -15,6 +16,7 @@ trainer:

# used to set the learning rate scheduler
max_steps: 10000
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# a PyTriton parameter to specify
Expand Down
4 changes: 3 additions & 1 deletion examples/nlp/gpt/conf/gpt_rs_actor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

rs:
max_epochs: 1
max_steps: -1 # max rs steps (-1 to go through the whole train set)
val_check_interval: 10
save_interval: ${.val_check_interval}
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# pick up from the model
Expand Down Expand Up @@ -177,4 +179,4 @@ model:
# define fields from the base model's config that should be ignored when merging with this config.
overwrite_base_config:
data:
data_prefix: True
data_prefix: True
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ trainer:
devices: 1
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

sft:
max_epochs: 1
Expand All @@ -15,6 +16,7 @@ trainer:
limit_train_batches: 1.0

limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# can be used to register any custom metrics that require token-by-token generation
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/gpt_spin.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16-mixed
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# spin specific args
spin:
Expand All @@ -18,6 +19,7 @@ trainer:

# how many GBS we loop over
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
2 changes: 2 additions & 0 deletions examples/nlp/gpt/conf/training_rm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ trainer:
devices: 8
accelerator: gpu
precision: bf16
gradient_clip_val: 0.0 # No need to change. Megatron Core optimizer uses this value

# rm specific args
rm:
Expand All @@ -20,6 +21,7 @@ trainer:
# set to float for a percentage
# of the validation dataset
limit_val_batches: 1.0
# TODO: delete once Megatron Core optimizer becomes default
gradient_clip_val: 1.0

# do not change these
Expand Down
29 changes: 13 additions & 16 deletions examples/nlp/gpt/train_gpt_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.algorithms.dpo import DPOTrainer, dpo_custom_collate
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets
from nemo_aligner.data.nlp.builders import build_dataloader, build_train_valid_test_dpo_datasets, identity_collate
from nemo_aligner.models.nlp.gpt.megatron_gpt_dpo_model import MegatronGPTDPOModel
from nemo_aligner.utils.distributed import Timer
from nemo_aligner.utils.train_script_utils import (
Expand All @@ -37,6 +37,7 @@

OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True)
OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True)
OmegaConf.register_new_resolver("not", lambda x: not x, replace=True)

mp.set_start_method("spawn", force=True)

Expand Down Expand Up @@ -85,7 +86,7 @@ def main(cfg) -> None:
# use the entire dataset
train_valid_test_num_samples = [-1 * cfg.model.global_batch_size] * 3

train_ds, validation_ds, test_ds = build_train_valid_test_dpo_datasets(
train_ds, validation_ds, _ = build_train_valid_test_dpo_datasets(
cfg=cfg.model,
data_prefix=cfg.model.data.data_prefix,
data_impl=cfg.model.data.data_impl,
Expand All @@ -104,13 +105,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
collate_fn=identity_collate,
)

val_dataloader = build_dataloader(
Expand All @@ -121,13 +116,7 @@ def main(cfg) -> None:
gbs=cfg.model.global_batch_size,
load_gbs=True,
pad_samples_to_global_batch_size=False,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
),
collate_fn=identity_collate,
use_random_sampler=False,
)

Expand All @@ -147,6 +136,14 @@ def main(cfg) -> None:
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=None,
collate_fn=partial(
dpo_custom_collate,
eos_id=ptl_model.tokenizer.eos_id,
reset_position_ids=cfg.model.data.get("reset_position_ids", False),
reset_attention_mask=cfg.model.data.get("reset_attention_mask", False),
eod_mask_loss=cfg.model.data.get("eod_mask_loss", False),
pad_length_to_multiple_of=cfg.model.data.get("pad_length_to_multiple_of", None),
),
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/critic_server_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def run_training(self, tokens=None, returns=None, prev_values=None, mask=None):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
Expand Down
57 changes: 52 additions & 5 deletions nemo_aligner/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import defaultdict
from statistics import mean
from typing import Any, Protocol

import torch
import torch.distributed
from omegaconf.dictconfig import DictConfig
from tqdm import tqdm

Expand All @@ -24,13 +27,32 @@
)
from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids
from nemo.utils import logging
from nemo_aligner.utils import parallel_state
from nemo_aligner.utils.distributed import SyncTimer
from nemo_aligner.utils.train_utils import clip_gradients
from nemo_aligner.utils.trainer_utils import check_progress, compute_limit_batches, compute_num_steps_per_epoch
from nemo_aligner.utils.utils import clear_memory


def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
class DistributedCollateFunction(Protocol):
def __call__(self, batch: list[dict], **kwargs: Any) -> dict[str, torch.Tensor]:
...


def dpo_custom_collate(
batch: list[dict],
eos_id: int,
reset_position_ids: bool = False,
reset_attention_mask: bool = False,
eod_mask_loss: bool = False,
pad_length_to_multiple_of: int | None = None,
) -> dict[str, torch.Tensor]:
"""
Transposes minibatch from list[dict] -> dict[Tensor] and also pads
This collate happens outside of the torch data loader and is not compatible with the multiprocessing
logic due to requiring communication collectives.
"""
chosen_tokens = [item["chosen"] for item in batch]
rejected_tokens = [item["rejected"] for item in batch]
chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])
Expand All @@ -44,9 +66,32 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_
rejected_tokens = torch.nn.utils.rnn.pad_sequence(rejected_tokens, batch_first=True, padding_value=eos_id)
chosen_labels = torch.nn.utils.rnn.pad_sequence(chosen_labels, batch_first=True, padding_value=-100)
rejected_labels = torch.nn.utils.rnn.pad_sequence(rejected_labels, batch_first=True, padding_value=-100)
assert chosen_tokens.shape == rejected_tokens.shape
assert chosen_labels.shape == rejected_labels.shape

if pad_length_to_multiple_of:
# Assumes both chosen and rejected match
max_seq_len = torch.tensor(chosen_tokens.shape[1], device=torch.cuda.current_device())
torch.distributed.all_reduce(
max_seq_len, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_data_parallel_group()
)

padded_max_len = math.ceil(max_seq_len / pad_length_to_multiple_of) * pad_length_to_multiple_of
chosen_tokens = torch.nn.functional.pad(
chosen_tokens, (0, padded_max_len - chosen_tokens.shape[1]), mode="constant", value=eos_id
)
rejected_tokens = torch.nn.functional.pad(
rejected_tokens, (0, padded_max_len - rejected_tokens.shape[1]), mode="constant", value=eos_id
)
chosen_labels = torch.nn.functional.pad(
chosen_labels, (0, padded_max_len - chosen_labels.shape[1]), mode="constant", value=-100
)
rejected_labels = torch.nn.functional.pad(
rejected_labels, (0, padded_max_len - rejected_labels.shape[1]), mode="constant", value=-100
)

attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
chosen_tokens, eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
chosen_tokens.cuda(), eos_id, reset_position_ids, reset_attention_mask, eod_mask_loss,
)
assert attention_mask.ndim == 4, "attention_mask is incorrect shape for dpo_custom_collate"
if attention_mask.shape[0] == 1:
Expand All @@ -70,8 +115,7 @@ def dpo_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_


class DPOTrainer:
"""Trainer to coordinate DPO training
"""
"""Trainer to coordinate DPO training"""

def __init__(
self,
Expand All @@ -82,6 +126,7 @@ def __init__(
train_dataloader,
val_dataloader,
test_dataloader,
collate_fn: DistributedCollateFunction,
logger,
ckpt_callback,
run_timer,
Expand All @@ -90,6 +135,7 @@ def __init__(
self.train_dataloader = train_dataloader
self.val_dataloader = val_dataloader
self.test_dataloader = test_dataloader
self.collate_fn = collate_fn
self.logger = logger
self.cfg = cfg
self.optimizer = optimizer
Expand Down Expand Up @@ -172,7 +218,7 @@ def train_single_step(self, global_batch):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

trainer_metrics = {}
Expand Down Expand Up @@ -317,6 +363,7 @@ def augment_dataloader(self, dataloader):
while True:
try:
batch = next(iter_dataloader)
batch = self.collate_fn(batch)
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
batch["ref_policy_log_probs_chosen"] = chosen_logps
Expand Down
2 changes: 1 addition & 1 deletion nemo_aligner/algorithms/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def run_training(self, dataloader_iter):
grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm
lr = self.optimizer.param_groups[0]["lr"]

self.optimizer.step()
self.optimizer.step(closure=None)
self.scheduler.step()

if grad_norm is not None:
Expand Down
Loading

0 comments on commit 4fd8963

Please sign in to comment.