From 716e5033a462f7f0b48f1ffd0b2852e276645aa5 Mon Sep 17 00:00:00 2001 From: Alexander Bukharin <59148829+abukharin3@users.noreply.github.com> Date: Fri, 22 Nov 2024 17:07:55 -0500 Subject: [PATCH 1/5] feat: adds REINFORCE algorithm (#357) Signed-off-by: Terry Kong Signed-off-by: NeMo-Aligner CI Signed-off-by: abukharin Co-authored-by: abukharin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Terry Kong --- .github/workflows/cicd-main.yml | 1 + CHANGELOG.md | 1 + README.md | 2 + docs/user-guide/reinforce.rst | 256 ++++++++ .../nlp/gpt/conf/gpt_reinforce_actor.yaml | 216 +++++++ examples/nlp/gpt/train_gpt_reinforce_actor.py | 198 ++++++ nemo_aligner/algorithms/reinforce.py | 599 ++++++++++++++++++ .../nlp/gpt/megatron_gpt_reinforce_actor.py | 393 ++++++++++++ nemo_aligner/utils/ppo_utils.py | 28 + tests/functional/reinforce.sh | 178 ++++++ .../test_cases/reinforce-llama3-pp2-reshard | 28 + tests/test_ppo_utils.py | 22 +- 12 files changed, 1921 insertions(+), 1 deletion(-) create mode 100644 docs/user-guide/reinforce.rst create mode 100644 examples/nlp/gpt/conf/gpt_reinforce_actor.yaml create mode 100644 examples/nlp/gpt/train_gpt_reinforce_actor.py create mode 100644 nemo_aligner/algorithms/reinforce.py create mode 100644 nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py create mode 100755 tests/functional/reinforce.sh create mode 100755 tests/functional/test_cases/reinforce-llama3-pp2-reshard diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index a2784d592..d2d27e95a 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -90,6 +90,7 @@ jobs: matrix: test_case: - ppo-llama3-pp2-reshard + - reinforce-llama3-pp2-reshard - dpo-llama3 - kd-llama3 - sft-llama3 diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c2f34819..63cd9ba5c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) durations = timer.consume_durations() ``` - Add code and instructions for replicating Reward Modeling training in HelpSteer2 and HelpSteer2-Preference +- Implement REINFORCE algorithm. ### Breaking Changes - Upgrade TRTLLM dependency from v0.10.0 to v0.12.0 and migrate from `GPTSession` cpp runtime to `ModelRunner` python runtime. Please use the latest Dockerfile. diff --git a/README.md b/README.md index 66029e4da..a2a500e26 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ The toolkit is currently in it's early stages. We are committed to improving the * **Reward Model Training** * **Reinforcement Learning from Human Feedback using the [PPO](https://arxiv.org/pdf/1707.06347.pdf) Algorithm** * [Llama3-70B-PPO-Chat](https://huggingface.co/nvidia/Llama3-70B-PPO-Chat) aligned with NeMo-Aligner using TRT-LLM. +* **Reinforcement Learning from Human Feedback using the REINFORCE Algorithm** + * [Llama-3.1-Nemotron-70B-Instruct](https://huggingface.co/nvidia/Llama-3.1-Nemotron-70B-Instruct) aligned with NeMo-Aligner using TRT-LLM. * **Direct Preference Optimization** as described in [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/pdf/2305.18290) * [Llama3-70B-DPO-Chat](https://huggingface.co/nvidia/Llama3-70B-DPO-Chat) aligned with NeMo Aligner. * **Self-Play Fine-Tuning (SPIN)** as described in [Self-Play Fine-Tuning Converts Weak Language Models to Strong Language Models](https://arxiv.org/pdf/2401.01335) diff --git a/docs/user-guide/reinforce.rst b/docs/user-guide/reinforce.rst new file mode 100644 index 000000000..cc3005db1 --- /dev/null +++ b/docs/user-guide/reinforce.rst @@ -0,0 +1,256 @@ +.. include:: /content/nemo.rsts + +.. _model-aligner-reinforce: + +Model Alignment by REINFORCE +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ + +In this tutorial, we will guide you through the process of aligning a NeMo Framework model using REINFORCE. This method can be applied to various models, including LLaMa2 and Mistral, with our scripts functioning consistently across different models. + +REINFORCE is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the REINFORCE algorithm on the `Anthropic-HH-RLHF `__ dataset. + +REINFORCE Training +############ + +After you have fine-tuned a GPT model using Supervised Fine-Tuning (SFT), and trained a reward model as explained in the preceding section, you can start aligning the policy using REINFORCE. + +During REINFORCE training, we have three models interacting with each other, which Aligner runs in two separate jobs: + +#. The Policy Network: This is the model we are training and it should start from an SFT model. +#. The Reward Model (RM): This model accepts a prompt combined with a response as input and produces a single scalar value, known as the reward. The REINFORCE algorithm aims to maximize this reward. +#. The Initial Policy Network (also known as the Reference Model): We use this model to compute a KL Divergence penalty term that ensures that the PPO Actor does not diverge too much from the Initial Policy. This way, we prevent the REINFORCE Actor from overfitting to the rewards given by the RM, and ensure it does not forget the knowledge it acquired during pretraining and SFT. This model should be the one used to initialize the REINFORCE Actor Network. + +The next section discusses how to launch each of these two jobs. + +Launching the Reward Model and Critic Server +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +To launch the server: + +.. code-block:: bash + + #!/bin/bash + RM_NEMO_FILE="/path/to/trained_rm.nemo" + GPFS="/path/to/nemo-aligner-repo" + + RESULTS_DIR="critic_results_dir" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/serve_reward_model.py \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + ++model.tensor_model_parallel_size=4 \ + rm_model_file=${RM_NEMO_FILE} + + +The above example launches the reward model server on eight GPUs and one node. Make sure to change trainer.devices, trainer.num_nodes depending on your model size and scale. Aligner will work on any scale. Also, make sure to tune the trainer.reinforce.inference_micro_batch_size argument. This argument sets the size of the batch the REINFORCE actor is allowed to send to the reward per DP rank. + +Launch the Initial Policy and REINFORCE Actor Training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +The REINFORCE Actor training job contains the master controller that makes the HTTP calls to all servers when needed. To launch the REINFORCE Actor and Initial Policy server: + +.. code-block:: bash + + GPFS="/path/to/nemo-aligner-repo" + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + PRETRAINED_ACTOR_NEMO_FILE="/path/to/sft_checkpoint.nemo" + RESULTS_DIR="/path/to/actor_results_dir" + + USE_FLASK=False + ACTOR_LR=1e-6 + KL=0.01 + NUM_ROLLOUTS=32 + ACTOR_GBS=32 + REWARD_PORT=5555 + # Change this to the hostname of server hosting the reward model + host_reward="localhost" + + cd ${GPFS} + export PYTHONPATH="${GPFS}:${PYTHONPATH}" \ + && export HYDRA_FULL_ERROR=1 \ + && python -u examples/nlp/gpt/train_gpt_reinforce_actor.py \ + "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=\"${ACTOR_NEMO_FILE}\" \ + exp_manager.checkpoint_callback_params.save_top_k=1 \ + exp_manager.explicit_log_dir=\"${RESULTS_DIR}\" \ + trainer.reinforce.max_epochs=1 \ + trainer.reinforce.max_steps=313 \ + trainer.reinforce.val_check_interval=4 \ + trainer.num_nodes=1 \ + trainer.devices=8 \ + trainer.reinforce.trt_llm.enable=True \ + trainer.reinforce.trt_llm.reshard=True \ + trainer.reinforce.trt_llm.unload_engine_train=False \ + ++model.tensor_model_parallel_size=4 \ + ++model.reinforce.num_rollout_samples=${NUM_ROLLOUTS} \ + model.global_batch_size=${ACTOR_GBS} \ + model.micro_batch_size=1 \ + model.optim.lr=\"\\\$\{multiply:${ACTOR_LR},1.001\}\" \ + model.optim.sched.warmup_steps=0 \ + model.optim.sched.constant_steps=312 \ + model.optim.sched.min_lr=${ACTOR_LR} \ + model.optim.weight_decay=0.01 \ + model.reinforce.rollout_micro_batch_size=16 \ + model.reinforce.forward_micro_batch_size=16 \ + model.reinforce.val_rollout_micro_batch_size=8 \ + model.data.data_impl=jsonl \ + remote_rm.reward_model.ip=${host_reward} \ + remote_rm.reward_model.port=${REWARD_PORT} \ + ++model.reinforce.length_params.max_length=2048 \ + trainer.reinforce.initial_policy_kl_penalty="${KL}" \ + ++model.optim.bucket_cap_mb=200 \ + ++model.dist_ckpt_format=zarr \ + ++model.optim.overlap_grad_sync=False \ + ++model.optim.contiguous_grad_buffer=True \ + ++model.enable_nge=True \ + trainer.reinforce.batch_iterator.use_flask=${USE_FLASK} \ + trainer.reinforce.rollout_batch_seq_length=4096 + +The above command launches the initial and actor server on one node with eight GPUs. + +Launching Both Servers for REINFORCE training +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +You can use slurm to launch the two jobs and get them to coordinate together in a full REINFORCE job through the following: + +.. code-block:: bash + + #!/bin/bash + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + #SBATCH hetjob + #SBATCH -N 1 --ntasks-per-node 8 -A <> -p <> --job-name <> -t 4:00:00 --exclusive + + NAME="reinforce" + + # PARAMETERS + RM_NEMO_FILE="/path/to/trained_rm.nemo" + + ACTOR_NEMO_FILE="/path/to/sft_model.nemo" + + TRAIN_DATA_PATH="/path/to/train_prompts.jsonl" + VALID_DATA_PATH="/path/to/test_prompts.jsonl" + + RESULTS_DIR="/path/to/results_dir" + mkdir -p $RESULTS_DIR + + GPFS="/path/to/nemo-aligner-repo" + MOUNTS="--container-mounts=MOUNTS" # mounts + + CONTAINER=<<>> # use the latest NeMo Training container, Aligner will work there + + PROJECT=reinforce_run + + CRITIC_LOG_DIR="${RESULTS_DIR}/critic_results" + CRITIC_OUTFILE="${CRITIC_LOG_DIR}/critic_output_%j_%t.log" + CRITIC_ERRFILE="${CRITIC_LOG_DIR}/critic_error_%j_%t.err" + REWARD_PORT=5567 + CRITIC_CONFIG_PATH="${GPFS}/examples/nlp/gpt/conf" + CRITIC_CONFIG_NAME="inference_rm" + + CONF_DIR="${GPFS}/examples/nlp/gpt/conf" + CONFIG_NAME="gpt_reinforce_actor" + + mkdir -p $CRITIC_LOG_DIR + + CRITIC_NAME="${NAME}_critic" + + read -r -d '' cmd_critic_inference <`__ script from the NeMo codebase to run more rigorous evaluation of your trained model. \ No newline at end of file diff --git a/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml new file mode 100644 index 000000000..8efe26bb5 --- /dev/null +++ b/examples/nlp/gpt/conf/gpt_reinforce_actor.yaml @@ -0,0 +1,216 @@ +defaults: + - optional tp_overlap@model.ub_tp_comm_overlap_cfg: + +trainer: + # these args are respected + num_nodes: 8 + devices: 8 + accelerator: gpu + precision: bf16 + + reinforce: + + max_epochs: 1 + max_steps: -1 # max REINFORCE steps (-1 to go through the whole train set) + val_check_interval: 10 + save_interval: ${.val_check_interval} + gradient_clip_val: 1.0 + + # REINFORCE args to generate the data for training + initial_policy_kl_penalty: 0.01 + use_absolute_kl: True + num_rollouts_per_prompt: 4 + + + # the sequence length to pad the rollout batch for training to + # reduce fragmentation at the cost of using more + # memory, set to null if we don't want to pad it + # to a constant size + # if actual seq length is higher than this a warning will be raised + # but will not crash and training will still proceed on the larger + # sequence length + rollout_batch_seq_length: null + + # Speed-up training by accelerating inference stage using TRTLLM + trt_llm: + enable: True + reshard: False # if True then reshard the model into TP only for inference + + # TRTLLM preallocates activation memory according to the number of input tokens + # By default, assume the max input length is the difference between the model sequence length and the max number of tokens to generate + max_input_len: ${subtract:${model.encoder_seq_length}, ${model.reinforce.length_params.max_length}} + + # the seed to use for trt-llm generation + seed: ${model.seed} + + # for supported values see: https://github.com/NVIDIA/NeMo/blob/db6244857af3b012f645c7f4672522978bb608b1/nemo/export/trt_llm/converter/utils.py#L26 + model_type: llama # can be gptj, gptnext, llama, gemma, falcon + + # Save GPU memory by unloading and reloading the TRTLLM engine before and after the training stage + # Reloading the engine incurs a constant time overhead + unload_engine_train: False + + batch_iterator: + # When use_flask is True, we will spawn a flask server on rank 0 to balance the work of policy rollouts. + # This option is useful in cases where the generation length varies greatly across DP ranks since + # the flask server will allow DP ranks with shorter responses to process more samples and DP ranks + # with longer responses to process less samples. Thereby lowering the DP wait time. + use_flask: False + port: 5557 + + # pick up from the model + # *do not change this* + model_gbs: ${model.global_batch_size} + model_mbs: ${model.micro_batch_size} + + # no need to change these + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_time: null + max_epochs: ${.reinforce.max_epochs} + max_steps: ${.reinforce.max_steps} + +remote_rm: + # what to batch the inputs to + # set to None if no batching when sending inference to the reward model + pad_to_length: ${model.encoder_seq_length} + + # reward model server + reward_model: + name: reward_model + ip: localhost + port: 5555 + + +exp_manager: + explicit_log_dir: /results + exp_dir: null + name: megatron_gpt_reinforce_actor + create_wandb_logger: False + wandb_logger_kwargs: + project: nemo_aligner_reinforce + name: gpt3_reinforce_2b + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_rewards + save_top_k: 1 + mode: max + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt-{step}-{consumed_samples}-{reinforce_optimization_step}-{epoch}-{val_rewards:.3f}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + +pretrained_checkpoint: + restore_from_path: null + +model: + + reinforce: + # training generation mbs + rollout_micro_batch_size: 8 + num_rollout_samples: 512 + + # mbs to do log prob inference, can be set to + # lower than rollout_micro_batch_size to reduce + # memory usage + forward_micro_batch_size: ${.rollout_micro_batch_size} + + # val generation mbs + val_rollout_micro_batch_size: ${.rollout_micro_batch_size} + num_val_samples: ${.num_rollout_samples} + + # to offload during generation or not + offload_adam_states: True + + # params for generation + sampling_params: + use_greedy: False + temperature: 1.0 + top_k: 0 + top_p: 1.0 + repetition_penalty: 1.0 + add_BOS: False + all_probs: False + compute_logprob: False + # will be used in NeMo version > 1.20.0 + # keeping it for now + end_strings: ["<|endoftext|>", ""] + + # length argument for autoregressive sampling + # max length means max amount of tokens to generate + length_params: + max_length: ${int_div:${model.encoder_seq_length}, 2} + min_length: 1 + + trt_llm: ${trainer.reinforce.trt_llm} + + peft: + peft_scheme: "none" # ["lora", "none"] + restore_from_path: null + restore_from_ckpt: + checkpoint_dir: null + checkpoint_name: null + + lora_tuning: + target_modules: ['attention_qkv'] # this can either be 'attention_qkv','attention_dense','mlp_fc1','mlp_fc2', 'attention' (qkv & dense), 'mlp' (fc1 & fc2), 'all' + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + mcore_gpt: True + # these control the mbs/gbs during REINFORCE training + micro_batch_size: 1 + global_batch_size: 64 + megatron_amp_O2: True + + encoder_seq_length: 4096 + max_position_embeddings: ${model.encoder_seq_length} + + ## Sequence Parallelism + sequence_parallel: False + + # miscellaneous + seed: 1234 + + optim: + name: distributed_fused_adam + bucket_cap_mb: 200 + overlap_grad_sync: False + contiguous_grad_buffer: True + lr: 9e-7 + weight_decay: 0.1 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 10 + constant_steps: 1000 + min_lr: 9e-8 + + precision: ${trainer.precision} + + data: + data_impl: jsonl + splits_string: null + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix + data_prefix: null + + # define fields from the base model's config that should be ignored when merging with this config. + overwrite_base_config: + data: + data_prefix: True \ No newline at end of file diff --git a/examples/nlp/gpt/train_gpt_reinforce_actor.py b/examples/nlp/gpt/train_gpt_reinforce_actor.py new file mode 100644 index 000000000..0aa238fc4 --- /dev/null +++ b/examples/nlp/gpt/train_gpt_reinforce_actor.py @@ -0,0 +1,198 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial + +import torch +import torch.multiprocessing as mp +from megatron.core.utils import divide +from omegaconf.omegaconf import OmegaConf + +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager +from nemo_aligner.algorithms.reinforce import ReinforceTrainer +from nemo_aligner.data.nlp.builders import ( + build_dataloader, + build_train_valid_test_rlhf_datasets, + collate_with_pad_to_max_batch, +) +from nemo_aligner.models.nlp.gpt.megatron_gpt_reinforce_actor import MegatronGPTReinforceActorModel +from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMClient +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.batch_iterators import get_batch_iterator_cls +from nemo_aligner.utils.distributed import Timer +from nemo_aligner.utils.train_script_utils import ( + CustomLoggerWrapper, + add_custom_checkpoint_callback, + extract_optimizer_scheduler_from_ptl_model, + init_distributed, + init_peft, + init_using_ptl, + resolve_and_create_trainer, + retrieve_custom_trainer_state_dict, +) +from nemo_aligner.utils.utils import load_and_override_model_config, load_from_nemo, retrieve_model_state_dict_in_cpu + +"""Script to start REINFORCE training""" + +OmegaConf.register_new_resolver("multiply", lambda x, y: x * y, replace=True) +OmegaConf.register_new_resolver("int_div", lambda x, y: x // y, replace=True) +OmegaConf.register_new_resolver("subtract", lambda x, y: x - y, replace=True) + +mp.set_start_method("spawn", force=True) + + +@hydra_runner(config_path="conf", config_name="gpt_reinforce_actor") +def main(cfg) -> None: + cfg.model = load_and_override_model_config(cfg.pretrained_checkpoint.restore_from_path, cfg.model) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + + trainer = resolve_and_create_trainer(cfg, "reinforce") + + exp_manager(trainer, cfg.exp_manager) + + logger = CustomLoggerWrapper(trainer.loggers) + + ptl_model = load_from_nemo( + MegatronGPTReinforceActorModel, + cfg.model, + trainer, + strict=True, + restore_path=cfg.pretrained_checkpoint.restore_from_path, + ) + + init_peft(ptl_model, cfg.model) + + init_policy_state_dict = None + + # only need this if we are running with inital kl penalty & full-parameter tuning + if cfg.trainer.reinforce.initial_policy_kl_penalty > 0 and cfg.model.peft.peft_scheme == "none": + init_policy_state_dict = retrieve_model_state_dict_in_cpu( + ptl_model, megatron_amp_O2=cfg.model.get("megatron_amp_O2", False) + ) + + ptl_model.init_policy_state_dict = init_policy_state_dict + + # pull values from checkpoint + trainer_restore_path = trainer.ckpt_path + + # TODO: log this restore path + if trainer_restore_path is not None: + custom_trainer_state_dict = retrieve_custom_trainer_state_dict(trainer) + else: + custom_trainer_state_dict = None + + init_distributed(trainer, ptl_model, cfg.model.get("transformer_engine", False)) + + # use the entire dataset + train_valid_test_num_samples = [-1, -1, -1] + train_ds, validation_ds, _ = build_train_valid_test_rlhf_datasets( + cfg=cfg.model, + data_prefix=cfg.model.data.data_prefix, + data_impl=cfg.model.data.data_impl, + splits_string=cfg.model.data.splits_string, + train_valid_test_num_samples=train_valid_test_num_samples, + seq_length=cfg.model.data.seq_length, + seed=cfg.model.seed, + tokenizer=ptl_model.tokenizer, + ) + + max_seqlen = cfg.model.reinforce.length_params.max_length + eos_id = ptl_model.tokenizer.eos_id + + # collate fn to pad to the max seq length in the batch + collate_fn = collate_with_pad_to_max_batch(max_seqlen, eos_id, cfg, generate_masks_and_position_ids=False) + + train_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=train_ds, + mbs=cfg.model.reinforce.rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_rollout_samples, + collate_fn=collate_fn, + load_gbs=False, + ) + + val_dataloader_builder = partial( + build_dataloader, + cfg=cfg, + dataset=validation_ds, + mbs=cfg.model.reinforce.val_rollout_micro_batch_size, + gbs=cfg.model.reinforce.num_val_samples, + collate_fn=collate_fn, + load_gbs=False, + use_random_sampler=False, + ) + + # nemo uses the train dataloader to figure out + # max steps to take when max_steps = -1 + # but our train dataloader is for the prompts + # so we instaniate a dummy dataloader + # to get the proper max *optimization* steps + # nemo treats batch size of normal dataloader as GBS/DP + # so we need to offset it by DP + dummy_train_dataloader = torch.utils.data.DataLoader( + dataset=train_ds, batch_size=divide(cfg.model.global_batch_size, parallel_state.get_data_parallel_world_size()) + ) + + init_using_ptl(trainer, ptl_model, dummy_train_dataloader, train_ds) + # make sure the dummy train dataloader is never used + del ptl_model._train_dl + del dummy_train_dataloader + + optimizer, scheduler = extract_optimizer_scheduler_from_ptl_model(ptl_model) + ckpt_callback = add_custom_checkpoint_callback(trainer, ptl_model) + + logger.log_hyperparams(OmegaConf.to_container(cfg)) + + rm = RemoteGPTRMClient(cfg.remote_rm) + timer = Timer(cfg.exp_manager.get("max_time_per_run") if cfg.exp_manager else None) + + batch_iterator_cfg = cfg.trainer.reinforce.get("batch_iterator", {}) + batch_iterator_cls = get_batch_iterator_cls(batch_iterator_cfg) + + reinforce_trainer = ReinforceTrainer( + cfg=cfg.trainer.reinforce, + model=ptl_model, + optimizer=optimizer, + scheduler=scheduler, + train_dataloader_builder=train_dataloader_builder, + val_dataloader_builder=val_dataloader_builder, + collate_fn=collate_fn, + rm=rm, + batch_iterator_cls=batch_iterator_cls, + logger=logger, + ckpt_callback=ckpt_callback, + run_timer=timer, + ) + + if custom_trainer_state_dict is not None: + reinforce_trainer.load_state_dict(custom_trainer_state_dict) + + reinforce_trainer.fit() + + # Note: The main loop creates multiple HTTPCommunicators which own a + # pytriton.client.FuturesModelClient. At the end of the loop, we manually + # close all FuturesModelClients since we do not use the context manager + # syntax. This guarantees all dangling threads are no longer blocking. + # `atexit` does not suffice since the registered cleanup function can be + # queued behind another blocking atexit registered function. + # TODO: utilize context managers to avoid manual cleanup + rm.communicator.close() + + +if __name__ == "__main__": + main() diff --git a/nemo_aligner/algorithms/reinforce.py b/nemo_aligner/algorithms/reinforce.py new file mode 100644 index 000000000..3bb127cee --- /dev/null +++ b/nemo_aligner/algorithms/reinforce.py @@ -0,0 +1,599 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +from collections import UserDict +from contextlib import nullcontext +from typing import Dict, List, Optional, Union + +import pandas as pd +import torch +from megatron.core import parallel_state as mcore_parallel_state +from megatron.core.utils import divide +from omegaconf.dictconfig import DictConfig +from tqdm import tqdm +from typing_extensions import Self + +from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingRandomSampler +from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split +from nemo.utils import logging +from nemo_aligner.models.nlp.gpt.megatron_gpt_reinforce_actor import MegatronGPTReinforceActorModel +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + ScopedTimer, + all_reduce_dict, + masked_global_mean_var, + normalize_tensor, + rebalance_nd_tensor, +) +from nemo_aligner.utils.parallel_state import is_trt_llm_reshard, trt_llm_reshard_region +from nemo_aligner.utils.ppo_utils import calculate_kl_penalty, calculate_rloo_baseline, create_mask +from nemo_aligner.utils.server_utils import FutureResult +from nemo_aligner.utils.train_utils import clip_gradients +from nemo_aligner.utils.trainer_utils import check_progress, compute_num_steps_per_epoch +from nemo_aligner.utils.utils import clear_memory, cpu_dict, masked_mean + + +class ReinforceRolloutBatch(UserDict): + @classmethod + def from_rollout_batches( + cls: Self, rollout_batches: List[Dict], eos_id: int, rollout_batch_seq_length: Optional[int] + ) -> Self: + """Given a list of rollout batches, stack the tensors within and put them in a single dictionary + """ + stacked_dict = cls() + + for k in sorted(rollout_batches[0]): + + list_of_tensors = [item[k] for item in rollout_batches] + + if all(x.ndim == 1 for x in list_of_tensors): + tensor = torch.cat(list_of_tensors) + else: + pad_value = eos_id if k == "response_tokens" else 0 + + list_of_tensors = [row.flatten() for tensor in list_of_tensors for row in tensor] + # TODO: can we avoid padding locally then padding globally? + tensor = torch.nn.utils.rnn.pad_sequence(list_of_tensors, batch_first=True, padding_value=pad_value) + + # find the max sequence length globally + max_seqlen = torch.tensor([tensor.size(-1)], dtype=torch.long, device=torch.cuda.current_device()) + torch.distributed.all_reduce(max_seqlen, op=torch.distributed.ReduceOp.MAX) + + if rollout_batch_seq_length is None or max_seqlen >= rollout_batch_seq_length: + pad_seq_len = max_seqlen.item() + else: + # response tokens must be B x S because computing log probs requires us to offset by 1 + pad_seq_len = rollout_batch_seq_length if k == "response_tokens" else rollout_batch_seq_length - 1 + + tensor = torch.nn.functional.pad(tensor, (0, pad_seq_len - tensor.size(-1)), value=pad_value) + + stacked_dict[k] = tensor + + return stacked_dict + + def gather_and_balance_globally(self): + global_rollout_batch = type(self)() + + for k, tensor in self.data.items(): + # with reshard enabled, PP groups turn into DP groups. So need to balance them first and then + # balance by all the original DP groups + # NOTE: this logic needs to use the pure parallel state, that is one without sharding but needs + # to ping the is_trt_llm_reshard variable + if is_trt_llm_reshard(): + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_pipeline_model_parallel_group()) + + tensor = rebalance_nd_tensor(tensor, group=mcore_parallel_state.get_data_parallel_group()) + global_rollout_batch[k] = tensor + + return global_rollout_batch + + def chunk(self, rank, split_size, seed): + chunked_rollout_batch = type(self)() + + batch_set = set(tensor.size(0) for tensor in self.data.values()) + assert len(batch_set) == 1, "batch sizes are not the same across the rollout batch" + B = batch_set.pop() + + g_cpu = torch.Generator() + g_cpu.manual_seed(seed) + indices = torch.arange(B) + + for k in self.data: + chunked_rollout_batch[k] = self.data[k][indices].clone() + + return chunked_rollout_batch + + +def compute_num_rollout_microbatches(dataloader): + return divide( + divide(dataloader.batch_sampler.global_batch_size, dataloader.batch_sampler.micro_batch_size), + parallel_state.get_data_parallel_world_size(), + ) + + +class ReinforceTrainer: + """Trainer to coordinate REINFORCE training + """ + + def __init__( + self, + cfg: DictConfig, + model: MegatronGPTReinforceActorModel, + optimizer, + scheduler, + train_dataloader_builder, + val_dataloader_builder, + collate_fn, + rm, + batch_iterator_cls, + logger, + ckpt_callback, + run_timer, + ): + self.cfg = cfg + self.model = model + self.optimizer = optimizer + self.scheduler = scheduler + self.train_dataloader_builder = train_dataloader_builder + self.val_dataloader_builder = val_dataloader_builder + self.collate_fn = collate_fn + self.rm = rm + self.batch_iterator_cls = batch_iterator_cls + self.logger = logger + self.ckpt_callback = ckpt_callback + + # this timer checks if we should stop training + self.run_timer = run_timer + + self.trtllm_reshard = "trt_llm" in cfg and cfg.trt_llm.enable and cfg.trt_llm.reshard + + self.consumed_samples = 0 + # the step here is REINFORCE step + self.step = 0 + # keep track of how many times we optimized the actor + self.reinforce_optimization_step = 0 + + # compute `max_steps` + train_dataloader = self.train_dataloader_builder(consumed_samples=0) + if (not isinstance(train_dataloader.batch_sampler, MegatronPretrainingRandomSampler)) and ( + self.cfg.max_epochs is not None and self.cfg.max_epochs > 1 + ): + # if you use MegatronPretrainingBatchSampler as the batch_sampler passed to your train dataloader (in builders.py) + # then each epoch will repeat all your samples in the same order as the previous epoch, there is no shuffling + # to fix this, you should use MegatronPretrainingRandomSampler instead, which alleviates this issue and allows + # random shuffling for each epoch. + raise ValueError( + "max_epochs > 1 is not supported unless using `MegatronPretrainingRandomSampler` as the batch_sampler for your train dataloader" + ) + + self.num_steps_per_epoch = compute_num_steps_per_epoch(train_dataloader.batch_sampler) + self.set_max_steps() + + self.compute_init_policy_kl = self.cfg.initial_policy_kl_penalty > 0 + # size to pad our rollout batch to + self.rollout_batch_seq_length = self.cfg.rollout_batch_seq_length + + # for wandb table + self.train_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + self.val_df = pd.DataFrame(columns=["step", "prompt", "response", "reward"]) + + self.timer = ScopedTimer(reduction="mean", sync_cuda=True, buffer_size=1) + + def generate_reinforce_data(self, rollout_batch): + """generate reinforce specific data for training + """ + reinforce_rollout_data = {} + reinforce_rollout_metrics = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + prompt_tokens = rollout_batch["prompt_tokens"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + logprobs = rollout_batch["logprobs"] + is_end = rollout_batch["is_end"] + + if self.compute_init_policy_kl: + init_policy_kl = calculate_kl_penalty( + log_probs_a=rollout_batch["logprobs"], + log_probs_b=rollout_batch["init_logprobs"], + use_absolute_kl=self.cfg.use_absolute_kl, + ) + else: + init_policy_kl = torch.tensor(0, dtype=logprobs.dtype, device=logprobs.device) + + mask = create_mask(values=logprobs, prompt_lengths=prompt_lengths, response_lengths=response_lengths) + + init_policy_kl = masked_mean(init_policy_kl, mask, dim=-1) + rewards_with_kl = rewards - self.cfg.initial_policy_kl_penalty * init_policy_kl + + baseline = calculate_rloo_baseline(prompts=prompt_tokens, reward=rewards_with_kl, mask=is_end.float()) + + # collect everything we need to train REINFORCE + reinforce_rollout_data["mask"] = mask + reinforce_rollout_data["rewards_with_kl"] = rewards_with_kl + reinforce_rollout_data["baseline"] = baseline + reinforce_rollout_data["response_tokens"] = response_tokens + reinforce_rollout_data["is_end"] = is_end + + # compute metrics + # these are not global yet + reinforce_rollout_metrics["init_policy_kl"] = init_policy_kl.sum().item() if self.compute_init_policy_kl else 0 + reinforce_rollout_metrics["rewards_with_kl"] = rewards_with_kl.sum().item() + reinforce_rollout_metrics["num_samples"] = prompt_lengths.size(0) + + # now the metrics are global + reinforce_rollout_metrics = all_reduce_dict( + reinforce_rollout_metrics, + group=parallel_state.get_data_parallel_group(), + op=torch.distributed.ReduceOp.SUM, + ) + num_samples = reinforce_rollout_metrics.pop("num_samples") + reinforce_rollout_metrics = {k: v / num_samples for k, v in reinforce_rollout_metrics.items()} + + return reinforce_rollout_data, cpu_dict(reinforce_rollout_metrics) + + def _run_inference(self, dataloader_builder, consumed_samples, is_validation): + """this function is run per DP so the metrics need to be computed globally + assumes that the dataloader is built with the proper consumed samples value + """ + reshard_context = trt_llm_reshard_region if self.trtllm_reshard else nullcontext + + rollout_batches, futures = [], [] + + with reshard_context(): + # dataloader must be built within the reshard context because it uses DP rank and size + dataloader = dataloader_builder(consumed_samples=consumed_samples) + sampler_iter = iter(dataloader.batch_sampler) + + # must compute the number of microbatches in the reshard context + # so the DP groups are correct + num_microbatches = compute_num_rollout_microbatches(dataloader) + + with self.timer("batch_iterator_init"): + batch_iterator = self.batch_iterator_cls( + sampler_iter, num_microbatches, dataloader.dataset, self.collate_fn + ) + + with self.timer("generate"): + for batch in batch_iterator: + if not is_validation: + for _ in range(self.cfg.num_rollouts_per_prompt): + rollout_batch = self.model.infer(batch) + rollout_batch["prompt_tokens"] = batch["text"] + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + else: + rollout_batch = self.model.infer(batch) + rollout_batches.append(rollout_batch) + futures.append(self.rm.infer_rm(rollout_batch)) + + unbalanced_local_batch = ReinforceRolloutBatch.from_rollout_batches( + rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=self.cfg.rollout_batch_seq_length, + ) + global_rollout_batch = unbalanced_local_batch.gather_and_balance_globally() + + padded_rollout_sequence_length = global_rollout_batch["response_tokens"].size(-1) + + # the chunking must be outside of the TRT-LLM context because we do logprob calculation in nemo + balanced_local_batch = global_rollout_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + # since we compute the logprobs in nemo we need to disable the resharding + batched_response_tokens = balanced_local_batch["response_tokens"] + + with self.timer("logprobs"): + rollout_logprobs = self.model.get_inference_log_probs(batched_response_tokens) + balanced_local_batch["logprobs"] = rollout_logprobs + + compute_init_policy_kl = not is_validation and self.compute_init_policy_kl + if compute_init_policy_kl: + with self.timer("init_logprobs"): + rollout_init_logprobs = self.model.get_init_policy_logprobs(batched_response_tokens) + balanced_local_batch["init_logprobs"] = rollout_init_logprobs + + # we send the request in sharded context, so we need to keep this sharding and then undo it + with reshard_context(): + with self.timer("critic_wait"): + rm_rollout_batches = [] + for future in futures: + rewards = future.result().squeeze(1) + rm_rollout_batches.append({"rewards": rewards}) + + unbalanced_rm_batch = ReinforceRolloutBatch.from_rollout_batches( + rm_rollout_batches, + eos_id=self.model.tokenizer.eos_id, + rollout_batch_seq_length=padded_rollout_sequence_length, + ) + global_rm_batch = unbalanced_rm_batch.gather_and_balance_globally() + + # chunking needs to be outside of reshard region + # NOTE: the seed here must be the same as the chunk before since we need to shuffle + # these values the same way as the other values + balanced_rm_batch = global_rm_batch.chunk( + rank=parallel_state.get_data_parallel_rank(), + split_size=parallel_state.get_data_parallel_world_size(), + seed=self.step, + ) + balanced_local_batch.update(balanced_rm_batch) + + global_rollout_batch.update(global_rm_batch) + + return balanced_local_batch, cpu_dict(self.compute_rollout_metrics(global_rollout_batch)) + + def compute_rollout_metrics(self, rollout_batch): + table = {} + + prompt_lengths = rollout_batch["prompt_lengths"] + response_lengths = rollout_batch["response_lengths"] + response_tokens = rollout_batch["response_tokens"] + rewards = rollout_batch["rewards"] + is_end = rollout_batch["is_end"] + + # take the first sample for logging + reward = rewards[0] + prompt_length = prompt_lengths[0] + response_length = response_lengths[0] + response_token = response_tokens[0] + + table["reward"] = reward.item() + table["prompt"] = self.model.tokenizer.ids_to_text(response_token[:prompt_length].tolist()) + table["response"] = self.model.tokenizer.ids_to_text(response_token[prompt_length:response_length].tolist()) + + metrics = { + "table": table, + "rollout_size": prompt_lengths.size(0), + "response_lengths": response_lengths.float().mean().item(), + "prompt_lengths": prompt_lengths.float().mean().item(), + "generation_length": (response_lengths - prompt_lengths).float().mean().item(), + "rewards": rewards.mean().item(), + "fraction_of_samples_properly_ended": is_end.float().mean().item(), + } + + return metrics + + @torch.no_grad() + def run_validation(self): + self.model.prepare_for_inference() + + _, rollout_metrics = self._run_inference(self.val_dataloader_builder, consumed_samples=0, is_validation=True) + + self.model.finish_inference() + return rollout_metrics + + @torch.no_grad() + def generate_rollouts(self): + with self.timer("prepare_for_inference"): + # Timing includes build if first step and refit if step > 1 + self.model.prepare_for_inference() + + rollout_batch, rollout_metrics = self._run_inference( + self.train_dataloader_builder, consumed_samples=self.consumed_samples, is_validation=False + ) + + self.consumed_samples += rollout_metrics["rollout_size"] + + reinforce_rollout_data, reinforce_rollout_metrics = self.generate_reinforce_data(rollout_batch) + + with self.timer("finish_inference"): + # Timing includes engine unloading if enabled + self.model.finish_inference() + + return ( + reinforce_rollout_data, + rollout_metrics | reinforce_rollout_metrics | {"consumed_samples": self.consumed_samples}, + self.timer.consume_durations(), + ) + + def run_training(self, dataloader_iter): + self.model.prepare_for_training() + + for batch in dataloader_iter: + self.optimizer.zero_grad() + + self.model.prepare_for_training_step() + loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False) + self.model.finish_training_step() + + grad_norm = clip_gradients(self.model, self.cfg.gradient_clip_val) + grad_norm = grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm + lr = self.optimizer.param_groups[0]["lr"] + + self.optimizer.step() + self.scheduler.step() + + if grad_norm is not None: + metrics["grad_norm"] = grad_norm + if lr is not None: + # Some optimizers like adafactor do not require a LR in their initializer + metrics["lr"] = lr + + metrics.update({"loss": loss_mean, "optim_step": self.reinforce_optimization_step}) + + self.logger.log_metrics( + metrics, step=self.step, prefix="train_optim/", + ) + + self.reinforce_optimization_step += 1 + + self.model.finish_training() + + # zero grad again incase it frees up grad mem + self.optimizer.zero_grad() + return loss_mean, metrics + + def fit(self): + epoch_iter = range(self.epoch, self.cfg.max_epochs) + if len(epoch_iter) <= 0: + # epoch done + return + + for _ in epoch_iter: + num_steps_in_epoch = min( + self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch + ) + loop_iter = range(num_steps_in_epoch) + + if not loop_iter: + return # training ended + + global_pbar = tqdm( + loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="REINFORCE Global Step" + ) + + dp_size = parallel_state.get_data_parallel_world_size() + + num_to_load_on_each_dp = divide(self.cfg.model_gbs, dp_size) + + self.run_timer.start_time() + for _ in global_pbar: + step_metrics = {} + timing_metrics = {} + + clear_memory() + with self.timer("rollout_time"): + reinforce_rollout_data, metrics, rollout_timer_metrics = self.generate_rollouts() + # Consume rollout_time + timing_metrics.update(self.timer.consume_durations()) + + rollout_timer_metrics = all_reduce_dict(rollout_timer_metrics, op=torch.distributed.ReduceOp.MAX) + timing_metrics.update(rollout_timer_metrics) + + # logging + table_metrics = metrics.pop("table") + self.train_df.loc[len(self.train_df)] = [ + self.step, + table_metrics["prompt"], + table_metrics["response"], + table_metrics["reward"], + ] + metrics["epoch"] = self.epoch + 1 + self.logger.log_metrics( + metrics, step=self.step, prefix="train_rollouts/", + ) + self.logger.log_table( + key="table/train_rollouts", dataframe=self.train_df, step=self.step, + ) + + rollout_size = reinforce_rollout_data["response_tokens"].size(0) + rollout_dataloader_iter = get_iterator_k_split( + reinforce_rollout_data, divide(rollout_size, num_to_load_on_each_dp) + ) + # start training + clear_memory() + with self.timer("train_time"): + self.run_training(rollout_dataloader_iter) + + self.logger.log_metrics( + timing_metrics | self.timer.consume_durations(), step=self.step, prefix="timers/" + ) + + self.step += 1 + + run_time_exceeded = self.run_timer.is_finished() + run_val, save_model, is_train_end = check_progress( + self.step, + self.max_steps, + self.cfg.val_check_interval, + self.cfg.save_interval, + 1.0, # TODO:(geshen): allow for limit val batches + run_time_exceeded=run_time_exceeded, + ) + + if run_val: + with self.timer("validation_time"): + val_metrics = self.run_validation() + # Note: validation_time is logged one step behind (val step 5 means we've completed step 4) + timing_metrics.update(self.timer.consume_durations()) + + val_table_metrics = val_metrics.pop("table") + + self.val_df.loc[len(self.val_df)] = [ + self.step, + val_table_metrics["prompt"], + val_table_metrics["response"], + val_table_metrics["reward"], + ] + self.logger.log_metrics(val_metrics, step=self.step, prefix="val_rollouts/") + self.logger.log_table("table/val_rollouts", dataframe=self.val_df, step=self.step) + + step_metrics.update({f"val_{k}": v for k, v in val_metrics.items()}) + + step_metrics.update(timing_metrics) + step_metrics.update({f"train_{k}": v for k, v in metrics.items()}) + global_pbar.set_postfix(step_metrics) + + if save_model: + step_metrics = {k: torch.as_tensor(v) for k, v in step_metrics.items()} + self.save(step_metrics, is_train_end=is_train_end) + + if run_time_exceeded: + logging.info(f"Time limit given by run_timer={self.run_timer} reached. Stopping run") + return + + self.logger.finalize() + + def state_dict(self): + return { + "step": self.step, + "consumed_samples": self.consumed_samples, + "epoch": self.epoch, + "reinforce_optimization_step": self.reinforce_optimization_step, + } + + def load_state_dict(self, state_dict): + self.step = state_dict["step"] + self.consumed_samples = state_dict["consumed_samples"] + self.reinforce_optimization_step = state_dict["reinforce_optimization_step"] + + loaded_values = [self.step, self.consumed_samples, self.reinforce_optimization_step] + + # make sure everyone loaded the same checkpoint as rank 0 + to_broadcast = torch.tensor(loaded_values, dtype=torch.float32, device=torch.cuda.current_device()) + torch.distributed.broadcast(to_broadcast, 0) + + assert loaded_values == to_broadcast.tolist() + # restore max steps we need to run for + self.set_max_steps() + + def save(self, extra_candidates=None, is_train_end=False): + self.model.prepare_for_training() + # load back in the adam states if needed + torch.cuda.synchronize() + torch.distributed.barrier() + + if extra_candidates is None: + extra_candidates = {} + + monitor_candidates = {k: torch.tensor(v, dtype=torch.int32) for k, v in self.state_dict().items()} + monitor_candidates.update(extra_candidates) + + self.ckpt_callback.custom_save(monitor_candidates=monitor_candidates, is_train_end=is_train_end) + + self.model.finish_training() + + def set_max_steps(self): + self.max_steps = self.num_steps_per_epoch * self.cfg.max_epochs + + if (max_steps := self.cfg.get("max_steps", -1)) >= 0: + self.max_steps = min(self.max_steps, max_steps) + + @property + def epoch(self): + return self.step // self.num_steps_per_epoch diff --git a/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py new file mode 100644 index 000000000..a98180fe8 --- /dev/null +++ b/nemo_aligner/models/nlp/gpt/megatron_gpt_reinforce_actor.py @@ -0,0 +1,393 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import nullcontext + +import torch +import torch.distributed +from megatron.core.num_microbatches_calculator import get_num_microbatches +from megatron.core.pipeline_parallel.schedules import get_forward_backward_func +from megatron.core.utils import divide +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + get_ltor_masks_and_position_ids, +) +from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.utils import logging +from nemo_aligner.models.alignable_interface import AlignableGenerativeInterface +from nemo_aligner.utils import parallel_state +from nemo_aligner.utils.distributed import ( + broadcast_2d_tensor_within_pp, + calculate_distributed_entropy, + from_parallel_logits_to_logprobs, +) +from nemo_aligner.utils.text_generation_utils import ( + TrackLengthGPTModelTextGenerationStrategy, + verify_is_valid_and_clamp_range_, +) +from nemo_aligner.utils.train_utils import ( + grad_reductions, + prepare_for_training_step, + set_eval, + set_sync_funcs, + set_train, +) +from nemo_aligner.utils.trt_llm import GPTGenerateTRTLLM +from nemo_aligner.utils.utils import ( + adapter_control, + clear_memory, + configure_batch_sizes, + cpu_weight_swap, + masked_mean, + offload_distributed_adam, +) + + +class MegatronGPTReinforceActorModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface): + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer=trainer) + self.automatic_optimization = False + + self.init_policy_state_dict = None + self.distributed_adam_offload_manager = None + + # length parameters for generation + self._length_params = OmegaConf.to_container(self.cfg.reinforce.length_params, resolve=True) + # sampling parameters for generation + self._sampling_params = OmegaConf.to_container(self.cfg.reinforce.sampling_params, resolve=True) + + self.to_offload_adam_states = self.cfg.reinforce.offload_adam_states and self.with_distributed_adam + self.forward_micro_batch_size = self.cfg.reinforce.forward_micro_batch_size + + self.use_trtllm_generation = "trt_llm" in self.cfg.reinforce and self.cfg.reinforce.trt_llm.enable + if self.use_trtllm_generation: + self.trtllm_generate = GPTGenerateTRTLLM( + model_cfg=self.cfg, + max_generation_length=self.cfg.reinforce.length_params.get("max_length", 1024), + max_input_len=self.cfg.reinforce.trt_llm.get("max_input_len", 1024), + generation_batch_size=self.cfg.reinforce.get("rollout_micro_batch_size", 4), + unload_engine_train=self.cfg.reinforce.trt_llm.get("unload_engine_train", False), + trt_model_type=self.cfg.reinforce.trt_llm.get("model_type", "llama"), + end_strings=self.cfg.reinforce.sampling_params["end_strings"], + reshard_model=self.cfg.reinforce.trt_llm.get("reshard", False), + sample_temperature=self.cfg.reinforce.sampling_params["temperature"], + sample_top_k=self.cfg.reinforce.sampling_params["top_k"], + sample_top_p=self.cfg.reinforce.sampling_params["top_p"], + repetition_penalty=self.cfg.reinforce.sampling_params["repetition_penalty"], + use_greedy=self.cfg.reinforce.sampling_params.get("use_greedy", False), + tokenizer=self.tokenizer, + seed=self.cfg.reinforce.trt_llm.get("seed", self.cfg.seed), + ) + + # training calls + def get_actor_forward_output_and_loss_func(self): + def fwd_output_and_loss_func(data_iterator, model): + batch = next(data_iterator) + required_keys = set() + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add("attention_mask") + + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("response_tokens", "position_ids")) + + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("response_tokens", "baseline", "mask", "rewards_with_kl", "is_end")) + + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + parallel_logits = model( + batch["response_tokens"], batch["position_ids"], batch["attention_mask"], labels=None, + ) + + def loss_func(parallel_logits): + mask = batch["mask"] + rewards_with_kl = batch["rewards_with_kl"] + baseline = batch["baseline"] + tokens = batch["response_tokens"] + is_end = batch["is_end"] + + is_end_mask = mask * is_end.view(-1, 1) + + curr_log_probs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=parallel_logits, target=tokens, higher_stability=True + ) + + reinforce_loss = -1 * curr_log_probs * (rewards_with_kl - baseline) + + if is_end_mask.sum() > 0: + loss = masked_mean(reinforce_loss, mask) + else: + # hack to disable this update since there are no valid tokens + loss = reinforce_loss.view(-1)[0] * 0 + + reduced_actor_loss = average_losses_across_data_parallel_group([loss]) + return ( + loss, + {"loss": reduced_actor_loss,}, + ) + + return parallel_logits, loss_func + + return fwd_output_and_loss_func + + def prepare_for_training(self): + configure_batch_sizes( + mbs=self.cfg.micro_batch_size, + gbs=self.cfg.global_batch_size, + dp=parallel_state.get_data_parallel_world_size(), + ) + self.onload_adam_states() + + def prepare_for_training_step(self): + # custom trainers will always zero grad for us + prepare_for_training_step(self, zero_grad=False) + + def get_loss_and_metrics(self, batch, forward_only): + sequence_length = batch["response_tokens"].size(1) + + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(tokens=batch["response_tokens"]) + batch["attention_mask"] = attention_mask + batch["position_ids"] = position_ids + + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + set_sync_funcs(self, forward_only) + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_actor_forward_output_and_loss_func(), + data_iterator=self._make_data_iterator_list(data_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=sequence_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + metrics = {} + + for key in ["loss"]: + if losses_reduced_per_micro_batch: + metric_mean = torch.stack( + [loss_reduced[key] for loss_reduced in losses_reduced_per_micro_batch] + ).mean() + else: + metric_mean = torch.tensor(0.0, device=torch.cuda.current_device()) + + torch.distributed.broadcast(metric_mean, get_last_rank()) + + metrics[key] = metric_mean.cpu().item() + + return metrics["loss"], metrics + + def finish_training_step(self): + grad_reductions(self) + + def finish_training(self): + """no need to offload adam states here + """ + + # inference calls + def get_logprob_output_only_func(self, inference_only=True): + fwd_output_only_func = self.get_forward_output_only_func() + + def log_prob_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + + output_tensor, _ = fwd_output_only_func(iter([batch,]), model) + + def id_func(output_tensor, non_loss_data=True): + logprobs = from_parallel_logits_to_logprobs( + vocab_parallel_logits=output_tensor, + target=batch[0], + inference_only=inference_only, + higher_stability=True, + ) + return logprobs + + return output_tensor, id_func + + return log_prob_output_only_func + + @torch.no_grad() + def get_inference_log_probs(self, response_tokens, forward_micro_batch_size=None): + if forward_micro_batch_size is None: + forward_micro_batch_size = self.forward_micro_batch_size + + set_sync_funcs(self, forward_only=True) + + mbs, seq_length = response_tokens.size() + num_microbatches = divide(mbs, forward_micro_batch_size) + attention_mask, _, position_ids = self.get_ltor_masks_and_position_ids(response_tokens) + + batch_iter = get_iterator_k_split([response_tokens, attention_mask, position_ids], num_microbatches) + + fwd_bwd_function = get_forward_backward_func() + logprobs_list = fwd_bwd_function( + forward_step_func=self.get_logprob_output_only_func(inference_only=True), + data_iterator=self._make_data_iterator_list(batch_iter), + model=self.model, + num_microbatches=num_microbatches, + forward_only=True, + seq_length=seq_length, + micro_batch_size=forward_micro_batch_size, + collect_non_loss_data=True, + ) + + logprobs = torch.cat(logprobs_list) if len(logprobs_list) > 0 else None + + # Broadcast it from last PP stage to everything else. + logprobs = broadcast_2d_tensor_within_pp(logprobs) + + return logprobs + + def prepare_for_inference(self): + """normally we would configure the micro batch calculator here + but the nemo generation already does the configuration""" + self._reset_activation_checkpointing_args() + self._reset_sequence_parallelism_args() + set_eval(self) + self.offload_adam_states() + + if self.use_trtllm_generation: + # TODO this might be optimized to avoid calling `refit()` twice in a row after a validation step + self.trtllm_generate.refit(self.model) + clear_memory() + + @torch.no_grad() + def infer(self, inference_batch): + prompt_tokens = inference_batch["text"].cuda(non_blocking=True) + prompt_lengths = inference_batch["length"].cuda(non_blocking=True) + inputs = (prompt_tokens, prompt_lengths) + + strategy = TrackLengthGPTModelTextGenerationStrategy( + model=self, context_lengths=prompt_lengths, max_length=self._length_params["max_length"] + ) + + if self.use_trtllm_generation: + actor_output = self.trtllm_generate.generate(inputs) + response_tokens = actor_output["response_tokens"] + response_lengths = actor_output["response_lengths"] + else: + actor_output = self.generate( + inputs=inputs, + length_params=self._length_params, + sampling_params=self._sampling_params, + strategy=strategy, + ) + response_tokens = torch.cuda.LongTensor(actor_output["token_ids"]) if actor_output else None + response_tokens = broadcast_2d_tensor_within_pp(response_tokens, dtype=torch.long) + response_lengths = strategy.get_lengths() + + max_response_length = response_lengths.max().item() + + # Sanity check to validate response length. + if max_response_length != response_tokens.size(1): + # This may actually happen because NeMo does not always stop generation after `max_length` in batch mode + # => `response_tokens` may contain up to `max_length + max_context_length` tokens. + # TODO once NeMo fixes this issue we should be able to always raise an exception when the check above fails, + # and remove the `if` below. + if ( + max_response_length >= response_tokens.size(1) + or response_tokens.size(1) != prompt_lengths.max().item() + self._length_params["max_length"] + ): + raise AssertionError( + f"max response length ({max_response_length}) does not match the size of " + f"`response_tokens` ({response_tokens.size(1)})" + ) + + # sometimes backends like TRT-LLM will generate invalid tokens + # so we need to also inplace mutate the response_tokens to be within the tokenizer range + is_valid = verify_is_valid_and_clamp_range_( + response_tokens, + response_lengths, + strategy, + self.tokenizer, + self.cfg.reinforce.sampling_params["end_strings"], + ) + + rollout_batch = { + "response_tokens": response_tokens, + "response_lengths": response_lengths, + "prompt_lengths": prompt_lengths, + "is_end": is_valid, + } + + # return in GPU, trainer needs to move to cpu + + return rollout_batch + + def get_init_policy_logprobs(self, response_tokens): + use_peft_init_policy = self.use_peft and self.init_policy_state_dict is None + + context_mgr = ( + adapter_control(self) + if use_peft_init_policy + else cpu_weight_swap(self, self.init_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2) + ) + + with context_mgr: + return self.get_inference_log_probs(response_tokens) + + def finish_inference(self): + # training will onload the adam states, no need to onload it here + self._restore_activation_checkpointing_args() + self._restore_sequence_parallelism_args() + + if self.use_trtllm_generation: + self.trtllm_generate.free() + + set_train(self) + + def offload_adam_states(self): + if self.distributed_adam_offload_manager is None: + + self.distributed_adam_offload_manager = ( + offload_distributed_adam( + self._optimizer.state_dict(state_dict_format=1, gather_on_root=False), force_clear_memory=True + ) + if self.to_offload_adam_states + else nullcontext() + ) + + # offload onto cpu + self.distributed_adam_offload_manager.__enter__() + + def onload_adam_states(self): + if self.distributed_adam_offload_manager is not None: + # load back onto GPU + self.distributed_adam_offload_manager.__exit__(None, None, None) + + self.distributed_adam_offload_manager = None + + def get_ltor_masks_and_position_ids(self, tokens): + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=self.tokenizer.eos_id, + reset_position_ids=self.cfg.data.get("reset_position_ids", False), + reset_attention_mask=self.cfg.data.get("reset_attention_mask", False), + eod_mask_loss=False, # since we ignore the loss mask here + ) + attention_mask = attention_mask.expand(tokens.size(0), -1, -1, -1) + position_ids = position_ids.expand(tokens.size(0), -1) + + return attention_mask, loss_mask, position_ids diff --git a/nemo_aligner/utils/ppo_utils.py b/nemo_aligner/utils/ppo_utils.py index 1d1f5cf67..0a69e3b9a 100644 --- a/nemo_aligner/utils/ppo_utils.py +++ b/nemo_aligner/utils/ppo_utils.py @@ -112,3 +112,31 @@ def select_topk(batch, num_select=1): selected_batch = {k: batch[k][selected_idx] for k in batch.keys()} return selected_batch + + +def calculate_rloo_baseline(prompts, reward, mask): + """ + Function to select the RLOO baseline for each (prompt, response) pair in the batch. + The same baseline is calculated for each prompt. Masked samples are not included + in the baseline calculation. + """ + unique_prompts = torch.unique(prompts, dim=0) + + baseline = torch.zeros_like(reward) + reward_device = reward.get_device() + if reward_device == -1: + reward_device = "cpu" + + for i in range(len(unique_prompts)): + is_matching_prompt = (prompts == unique_prompts[i]).all(1) + prompt_idx = torch.arange(len(prompts), device=reward_device)[is_matching_prompt] + rloo_mat = (1 - torch.eye(len(prompt_idx))).to(reward_device) + + if mask[prompt_idx].sum() <= 1: + # Ignore sample: set baseline equal to reward + baseline[prompt_idx] = reward[prompt_idx] + else: + rloo = torch.matmul(rloo_mat, reward[prompt_idx] * mask[prompt_idx]) / (mask[prompt_idx].sum() - 1) + baseline[prompt_idx] = rloo + + return baseline diff --git a/tests/functional/reinforce.sh b/tests/functional/reinforce.sh new file mode 100755 index 000000000..3529acdbb --- /dev/null +++ b/tests/functional/reinforce.sh @@ -0,0 +1,178 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR +set -eoux pipefail + +export NCCL_ALGO=Tree +export NVTE_APPLY_QK_LAYER_SCALING=1 + +KL=${KL:-0.03} +LR=${LR:-9e-7} +RUN_ONLY=${RUN_ONLY:-} +GBS=${GBS:-2} +TP_SIZE=${TP_SIZE:-1} +PP_SIZE=${PP_SIZE:-2} +RESHARD=${RESHARD:-True} +RM_NEMO_FILE=${RM_NEMO_FILE} +ACTOR_NEMO_FILE=${ACTOR_NEMO_FILE} + + +MIN_LR=$(awk -v var="$LR" 'BEGIN {print var - 1e-11}') + +TRAIN_DATA_PATH=$SCRIPT_DIR/test_data/synthetic-123.jsonl +VALID_DATA_PATH=$SCRIPT_DIR/test_data/synthetic-123.jsonl + +NAME="reinforce_test" + +# PARAMETERS +RESULTS_DIR="/tmp/${NAME}" +mkdir -p $RESULTS_DIR + +GPFS=$(git rev-parse --show-toplevel) + +# W&B Logging +PROJECT=reinforce_test + +REWARD_LOG_DIR="${RESULTS_DIR}/reward_results" +REWARD_PORT=5555 + +mkdir -p $REWARD_LOG_DIR + +REWARD_NAME="${NAME}_reward" + +reward() { +export CUDA_VISIBLE_DEVICES=0 +export PYTHONPATH="${GPFS}:${PYTHONPATH:-}" +export HYDRA_FULL_ERROR=1 +python -u ${GPFS}/examples/nlp/gpt/serve_reward_model.py \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + inference.port=${REWARD_PORT} \ + ++model.tensor_model_parallel_size=1 \ + ++model.pipeline_model_parallel_size=1 \ + ++model.dist_ckpt_load_strictness=log_all \ + rm_model_file=${RM_NEMO_FILE} +} +reward_log_file=$(mktemp /tmp/reward-reinforce-log-XXXXXX) +if [[ $RUN_ONLY =~ actor* ]]; then + echo SKIPPING REWARD +elif [[ $RUN_ONLY == reward ]]; then + reward 2>&1 | stdbuf -o0 sed 's/^/[REWARD_SERVER]: /' | tee $reward_log_file + exit $? +else + reward 2>&1 | stdbuf -o0 sed 's/^/[REWARD_SERVER]: /' | tee $reward_log_file & +fi + +if [[ -z "${FAST:-}" ]]; then + sleep 15 +fi +######################################################################################### + +ACTOR_LOG_DIR="${RESULTS_DIR}/actor_results" +mkdir -p $ACTOR_LOG_DIR + +ACTOR_NAME="${NAME}_actor" +host_reward=localhost + +actor() { +export CUDA_VISIBLE_DEVICES=0,1 +export PYTHONPATH="${GPFS}:${PYTHONPATH:-}" +export HYDRA_FULL_ERROR=1 +mpirun -np 2 --allow-run-as-root python -u ${GPFS}/examples/nlp/gpt/train_gpt_reinforce_actor.py \ + "++model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \ + pretrained_checkpoint.restore_from_path=${ACTOR_NEMO_FILE} \ + exp_manager.explicit_log_dir=${ACTOR_LOG_DIR} \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name=${ACTOR_NAME} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} \ + exp_manager.create_checkpoint_callback=True \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.reinforce.trt_llm.enable=True \ + ++model.offload_adam_states=False \ + trainer.reinforce.trt_llm.reshard=${RESHARD} \ + trainer.reinforce.val_check_interval=2 \ + ++trainer.reinforce.save_interval=2 \ + ++model.micro_batch_size=1 \ + ++model.global_batch_size=${GBS} \ + ++model.tensor_model_parallel_size=${TP_SIZE} \ + ++model.pipeline_model_parallel_size=${PP_SIZE} \ + ++model.reinforce.entropy_bonus=0.0 \ + ++model.reinforce.ratio_eps=0.2 \ + ++model.encoder_seq_length=64 \ + ++exp_manager.checkpoint_callback_params.save_top_k=10 \ + ++model.reinforce.num_rollout_samples=${GBS} \ + ++model.reinforce.rollout_micro_batch_size=1 \ + ++model.reinforce.length_params.max_length=32 \ + ++model.reinforce.forward_micro_batch_size=1 \ + trainer.reinforce.initial_policy_kl_penalty="${KL}" \ + trainer.reinforce.rollout_batch_seq_length=32 \ + ++trainer.reinforce.flask_server.enable=True \ + ++model.optim.lr=${LR} \ + ++model.optim.sched.min_lr=${MIN_LR} \ + ++model.activations_checkpoint_granularity=full \ + ++model.activations_checkpoint_method=uniform \ + ++model.activations_checkpoint_num_layers=1 \ + ++model.optim.bucket_cap_mb=200 \ + ++model.optim.overlap_grad_sync=False \ + ++model.optim.contiguous_grad_buffer=True \ + ++model.enable_nge=True \ + remote_rm.reward_model.ip=${host_reward} \ + remote_rm.reward_model.port=${REWARD_PORT} \ + \ + +model.overwrite_base_config.optim=True \ + '~model.optim' \ + '++model.optim={name:sgd}' \ + model.reinforce.sampling_params.use_greedy=True \ + trainer.reinforce.save_interval=0 \ + trainer.reinforce.max_steps=3 \ + trainer.reinforce.trt_llm.model_type=llama \ + ++exp_manager=null \ + \ + ++model.dist_ckpt_load_strictness=log_all \ + $@ +} + +actor_log_file=$(mktemp /tmp/actor-reinforce-log-XXXXXX) +if [[ -z "$RUN_ONLY" || "$RUN_ONLY" == actor_trt || "$RUN_ONLY" == trt ]]; then + actor 2>&1 | stdbuf -o0 sed 's/^/[ACTOR_TRT]: /' +elif [[ "$RUN_ONLY" == actor_nemo || "$RUN_ONLY" == nemo ]]; then + actor trainer.reinforce.trt_llm.enable=False 2>&1 | stdbuf -o0 sed 's/^/[ACTOR_NEMO]: /' +else + echo "Only accepts RUN_ONLY=actor_nemo or actor_trt" + exit 1 +fi | tee $actor_log_file || true + +REWARD_ID=$(grep -oP "kill -SIGINT \K\d+" $reward_log_file) +if [[ $REWARD_ID =~ ^[0-9]+$ ]]; then + echo "Valid integer: $REWARD_ID" + kill -SIGINT $REWARD_ID +else + echo "Invalid REWARD_ID=$REWARD_ID detected!" + exit 1 +fi + +if ! fgrep 'Cleaning up communicator' $actor_log_file &>/dev/null; then + echo "[ERROR] Did not find 'Cleaning up communicator' in the actor logs ($actor_log_file) indicating the actor reached the end" + exit 1 +fi + +echo "Waiting for backgrounded processes to finish..." +wait +set +x +echo "[Finished] $0" diff --git a/tests/functional/test_cases/reinforce-llama3-pp2-reshard b/tests/functional/test_cases/reinforce-llama3-pp2-reshard new file mode 100755 index 000000000..7ee8342d7 --- /dev/null +++ b/tests/functional/test_cases/reinforce-llama3-pp2-reshard @@ -0,0 +1,28 @@ +#!/bin/bash + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +cd $SCRIPT_DIR + +set -eoux pipefail + +GBS=2 \ + TP_SIZE=1 \ + PP_SIZE=2 \ + RESHARD=True \ + RM_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/llama3--nlayers4-hidden64-ffn224-dummy_rm-megatron_gpt.nemo \ + ACTOR_NEMO_FILE=${ALIGNER_CI_DIR}/checkpoints/tiny-llama3-results-nlayers2-hidden128-ffn448-nhead4-qgroup2-megatron_gpt.nemo \ + bash ../reinforce.sh 2>&1 | tee $(basename $0).log diff --git a/tests/test_ppo_utils.py b/tests/test_ppo_utils.py index a12db274b..0959e9add 100644 --- a/tests/test_ppo_utils.py +++ b/tests/test_ppo_utils.py @@ -17,7 +17,12 @@ import torch import torch.nn.functional as F -from nemo_aligner.utils.ppo_utils import calculate_advantages_and_returns, calculate_entropy, calculate_ppo_rewards +from nemo_aligner.utils.ppo_utils import ( + calculate_advantages_and_returns, + calculate_entropy, + calculate_ppo_rewards, + calculate_rloo_baseline, +) class TestCalculateEntropy: @@ -120,3 +125,18 @@ def test_calculate_advantage_and_returns_small_example(self): assert torch.allclose(advantages, gt_advantages), "computed advantage is not the same as hand example" assert torch.allclose(returns, gt_advantages + values), "computed returns is not the same as hand example" + + +class TestCalculateRLOOBaseline: + def test_calculate_rloo_baseline_small_example(self): + + prompts = torch.Tensor([[1, 0], [1, 0], [0, 1], [1, 0], [1, 0], [0, 1], [0, 1], [0, 1],]) + + rewards = torch.Tensor([1, 0, 2, -3, 5, 7, -1, 0]) + mask = torch.Tensor([1, 1, 1, 1, 1, 1, 1, 0]) + + baseline = calculate_rloo_baseline(prompts, rewards, mask) + + gt_baseline = torch.Tensor([2 / 3, 1.0, 3.0, 2.0, -2 / 3, 1 / 2, 9 / 2, 8 / 2]) + + assert torch.allclose(baseline, gt_baseline), "computed baseline is not the same as hand example" From 4247dc5b5e4ac803220fd1a5533264b86d430d81 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Wed, 27 Nov 2024 01:04:12 -0800 Subject: [PATCH 2/5] docs: Apply other feedback from 24.09 VDR (#411) Signed-off-by: Terry Kong --- docs/user-guide/aligner-algo-header.rst | 4 + docs/user-guide/cai.rst | 78 +++++----- docs/user-guide/dpo.rst | 41 ++++-- docs/user-guide/draftp.rst | 9 +- docs/user-guide/index.rst | 134 ++++++++++++++++-- docs/user-guide/knowledge-distillation.rst | 13 +- .../{modelalignment.rsts => nemoaligner.rsts} | 37 ++++- docs/user-guide/rlhf.rst | 11 +- docs/user-guide/rs.rst | 9 +- docs/user-guide/sft.rst | 11 +- docs/user-guide/spin.rst | 7 +- docs/user-guide/steerlm.rst | 11 +- docs/user-guide/steerlm2.rst | 24 ++-- 13 files changed, 262 insertions(+), 127 deletions(-) create mode 100644 docs/user-guide/aligner-algo-header.rst rename docs/user-guide/{modelalignment.rsts => nemoaligner.rsts} (67%) diff --git a/docs/user-guide/aligner-algo-header.rst b/docs/user-guide/aligner-algo-header.rst new file mode 100644 index 000000000..15114dc02 --- /dev/null +++ b/docs/user-guide/aligner-algo-header.rst @@ -0,0 +1,4 @@ +.. important:: + Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. + + If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. \ No newline at end of file diff --git a/docs/user-guide/cai.rst b/docs/user-guide/cai.rst index 2f24d1000..49253685a 100644 --- a/docs/user-guide/cai.rst +++ b/docs/user-guide/cai.rst @@ -1,6 +1,8 @@ .. include:: /content/nemo.rsts -.. _model-aligner-cai: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-cai: Constitutional AI: Harmlessness from AI Feedback @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@ -14,12 +16,12 @@ CAI allows training a harmless, but non-evasive AI assistant that engages with h .. _Constitutional AI (CAI): https://arxiv.org/abs/2212.08073 CAI -############### -The basic steps of CAI are described in this section and illustrated in the figure below (`Figure 1 `_). +### +The basic steps of CAI are described in this section and illustrated in the `figure below `_. (Supervised Stage) Critique → Revision → Supervised Learning: The AI generates responses to harmfulness prompts using a helpful-only AI assistant, then critiques and revises its own responses according to a principle in the constitution, and then fine-tunes the original model on the revised responses. -(RL Stage) AI Comparison Evaluations → Reward Model → Reinforcement Learning: The AI generates pairs of responses to harmfulness prompts using the finetuned model, then evaluates which response is better according to a principle in the constitution, and then trains a reward model based on this dataset of AI preferences and a human helpfulness preferences. The AI then trains with RL using the learned reward model. +(RL Stage) AI Comparison Evaluations → Reward Model → Reinforcement Learning: The AI generates pairs of responses to harmfulness prompts using the fine-tuned model, then evaluates which response is better according to a principle in the constitution, and then trains a reward model based on this dataset of AI preferences and a human helpfulness preferences. The AI then trains with RL using the learned reward model. .. image:: ../assets/cai_diagram.png :alt: basic steps of the CAI process @@ -29,25 +31,22 @@ The basic steps of CAI are described in this section and illustrated in the figu Critiques, revisions, and AI harmlessness feedback are steered by a small set of principles drawn from a ‘constitution’. The supervised stage significantly improves the initial model. It gives some control over the initial behavior at the start of the RL phase, while addressing potential exploration problems. The RL stage significantly improves performance and reliability. Motivation -############### +########## Constitutional AI motivation refers to designing AI systems in such a way that their objectives and behaviors are guided by a set of predefined rules or principles. It includes the following: -Scaling supervision: using AI to help humans supervise other AIs more efficiently and effectively, especially for tasks where AI capabilities may exceed human ones. - -A harmless but non-evasive assistant: reducing the tension between helpfulness and harmlessness, and avoiding evasive responses that reduce transparency and helpfulness. - -Simplicity and transparency: encoding the training goals in a simple list of natural language instructions or principles, and using chain-of-thought reasoning to make AI decision making explicit and understandable. +- Scaling Supervision: Use AI to assist humans in supervising other AIs more efficiently and effectively, particularly for tasks where AI capabilities may surpass human ones. +- A Harmless but Non-Evasive Assistant: Minimize the tension between helpfulness and harmlessness, and avoid evasive responses that reduce transparency and helpfulness. +- Simplicity and Transparency: Encode training goals in a straightforward list of natural language instructions or principles, and employ chain-of-thought reasoning to make AI decision-making explicit and understandable. +- Reducing Iteration Time: Eliminate the need to collect new human feedback labels when modifying objectives or testing different behaviors. -Reducing iteration time: obviating the need to collect new human feedback labels when altering the objective or testing different behaviors. - -Train a CAI model -##################### +Train a CAI Model +################# This section is a step-by-step tutorial that walks you through how to run a full CAI pipeline with a ``Mistral-7B`` LLM model. It includes the following: -1. Data download and preprocessing. +1. Download the models and datasets. -2. Generate responses to harmfulness prompts using a helpful-only AI assistant. Ask the model to critique its response according to a principle in the constitution, and then revise the original response in light of the critique. +2. Generate and revise responses to harmful prompts creating the SL-CAI dataset. Ask the model to critique its response according to a principle in the constitution, and then revise the original response in light of the critique. 3. Fine-tune ``Mistral-7B`` with SFT on the revised responses to create a ``Mistral-7B-SL-CAI`` model. @@ -56,24 +55,22 @@ This section is a step-by-step tutorial that walks you through how to run a full b. Formulate each prompt and pair into a multiple choice question, where we ask ``Mixtral-8x7B`` which response is best according to the constitution. c. Blend the AI feedback preference dataset (prompts and pairs) with human feedback helpfulness dataset. -5. Train a Reward Model (RM). +5. Train the Reward Model (RM). 6. Fine-tune the ``Mistral-7B-SL-CAI`` with Proximal Policy Optimization (PPO) and the RM to train a ``Mistral-7B-RL-CAI`` model. 7. Run inference. -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. +.. _nemo-aligner-cai-flow-diagram: .. image:: ../assets/cai_flow.png -Step 1: Download models and datasets -############################################################################# -1. Download ``Mistral-7B-Instruct`` and ``Mistral-7B`` LLM models from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 and https://huggingface.co/mistralai/Mistral-7B-v0.1 into the models folder. +Step 1: Download the models and datasets +######################################## + +1. Download the ``Mistral-7B-Instruct`` and ``Mistral-7B`` LLM models from https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 and https://huggingface.co/mistralai/Mistral-7B-v0.1 into the models folder. - Then, convert into .nemo format: + Then, convert them into .nemo format: .. code-block:: bash @@ -92,7 +89,7 @@ Step 1: Download models and datasets This command will download the dataset to ``/path/to/anthropic_red_team_attempts_train.json`` -3. Download SFT helpfulness dataset: +3. Download the SFT helpfulness dataset: .. code-block:: bash @@ -101,7 +98,7 @@ Step 1: Download models and datasets This command will download the dataset to ``/path/to/nvidia_sft_datablend_v1_train.json`` -4. Download and process preference helpfulness dataset: +4. Download and process the preference helpfulness dataset: .. code-block:: bash @@ -112,7 +109,7 @@ Step 1: Download models and datasets Step 2: Generate and revise responses to harmful prompts creating the SL-CAI dataset -################################################################################################### +#################################################################################### Run an inference server in the background using the following command: @@ -158,16 +155,16 @@ Please wait for the server to be ready before proceeding. --apply_chat_template False \ --response_extract_pattern "[/INST]" -This will generate an SL-CAI dataset of prompts and revised responses as ``cai_revisions_aligner_chat_template.json`` +This will generate an SL-CAI dataset of prompts and revised responses as ``cai_revisions_aligner_chat_template.json``. -The few-shot samples should be provided following the template in ``few_shot_samples_example.json`` (filling in the `content` tags, and choosing how many samples to use), and should include a red teaming prompt, a response from the helpful model (e.g. ``Mistral-7B`` in this tutorial), critique and revision requests and responses. An example is shown in the `Anthropic repo `_. +The few-shot samples should be provided following the template in ``few_shot_samples_example.json``. Fill in the `content` tags and choose how many samples to use. The samples should include a red teaming prompt, a response from the helpful model (e.g., ``Mistral-7B`` in this tutorial), critique and revision requests, and responses. An example is shown in the `Anthropic repo `_. -*NOTE: The tokenizer file can be found by extracting the .nemo checkpoint using `tar -xf /models/mistral/mistral-7b-Instruct.nemo`. -There are 2 tokenizer files that end with `.model` in the model checkpoint and they are the same, so you can use either one for data processing.* +.. note:: + The tokenizer file can be found by extracting the .nemo checkpoint using `tar -xf /models/mistral/mistral-7b-Instruct.nemo`. There are two tokenizer files that end with `.model` in the model checkpoint, and they are identical. You can use either one for data processing. Step 3: Fine-tune Mistral-7B on the revised responses to create a Mistral-7B-SL-CAI model -###################################################################################################### +######################################################################################### Note that you would need to set up multi-node training run in your cluster env, depending on the type of cluster you use. For details, please refer to https://lightning.ai/docs/pytorch/stable/clouds/cluster.html . @@ -199,10 +196,9 @@ Note that you would need to set up multi-node training run in your cluster env, Step 4: Generate the RL-CAI (preference) dataset for RM and PPO training -############################################################################################################## +######################################################################## -The following section runs an inference server with the SL-CAI model that we've previously trained, and queries it with red teaming prompts asking for several responses per prompt. -The responses will then be ranked by a judge LLM being run from NVIDIA's NGC. An NGC API key can be acquired `here`_. +The following section runs an inference server with the SL-CAI model that we've previously trained. It queries the server with red teaming prompts, requesting several responses per prompt. These responses will then be ranked by a judge LLM running from NVIDIA's NGC. You can acquire an NGC API key `here`_. The following command will run the inference server: @@ -257,8 +253,8 @@ Using a different terminal, run the following command to start the RL-CAI datase This command will create the ``rl-cai`` dataset files in the defined output folder with the given output filename prefix. -Step 5: Train the RM -##################### +Step 5: Train the Reward Model (RM) +################################### Run the following command to train the RM: @@ -285,7 +281,7 @@ Run the following command to train the RM: The trained RM checkpoint will be saved to output dir given by ``exp_manager.explicit_log_dir``. -Step 6: Fine-tune Mistral-7B-SL-CAI with PPO and the RM to train a Mistral-7B-RL-CAI model +Step 6: Fine-tune the Mistral-7B-SL-CAI with PPO and the RM to train a Mistral-7B-RL-CAI model ############################################################################################## Run the following command in the background to launch a RM and PPO critic training server: @@ -329,8 +325,8 @@ Run the following command to launch actor training and a reference policy server The trained LLM policy checkpoint will be saved to the output dir given by ``exp_manager.explicit_log_dir``. -Step 7: Inference -################## +Step 7: Run inference +##################### To start inference, run an inference server in the background using the following command: .. code-block:: bash diff --git a/docs/user-guide/dpo.rst b/docs/user-guide/dpo.rst index 70e4b0d9e..901ceee37 100644 --- a/docs/user-guide/dpo.rst +++ b/docs/user-guide/dpo.rst @@ -1,15 +1,12 @@ .. include:: /content/nemo.rsts -.. _model-aligner-dpo: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-dpo: Model Alignment by DPO, RPO, and IPO @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - The NeMo Framework supports efficient model alignment via the NeMo-Aligner codebase. All algorithms in NeMo-Aligner will work with any GPT-based model that is from Megatron Core (in the config it has ``mcore_gpt=True``). For the purposes of this tutorial, we will go through the entire Direct Preference Optimization (DPO) pipeline using the newly released `2B GPT model with 4096 sequence length `__. The same tutorial also works for GPT models (such as LLaMa3) of any size. @@ -22,7 +19,7 @@ In full-parameter DPO, there exists an actor and a reference model. The actor is For LoRA-based DPO, the actor is initialized by the reference model plus LoRA weights, where only the LoRA weights are trainable. Therefore, it allows us to switch between the actor/reference models by simply enabling or disabling LoRA. In addition, there is no need to store two sets of LLM weights. RPO and IPO Variations -####################### +###################### Besides the vanilla DPO algorithm, we support other variants of DPO algorithms, including Identity Preference Optimization (IPO) and Reward-aware Preference Optimization (RPO). @@ -31,7 +28,7 @@ The algorithm is identified with the ``dpo.preference_loss`` config variable. We To use the RPO algorithm, each dataset example should have ``chosen_reward`` and ``rejected_reward``, which might come from human labelers or reward models. If ``chosen_reward`` and ``rejected_reward`` are not existent in the data, ``dpo.default_chosen_reward`` and ``dpo.default_rejected_reward`` are used. Obtain a Pretrained Model -############################ +######################### To start, we must first get a pretrained model to align. There are two models we recommend to get started. The rest of the tutorial will work with either model, but for demonstration purposes, we will use the smaller 2B model. .. tab-set:: @@ -81,7 +78,7 @@ Instruction Following Taught by Supervised Fine-Tuning (SFT) For best DPO training performance, it is recommended that you start with a SFT model, rather than the base model. For a full guide on how to perform SFT on a Megatron GPT model, please refer to the :ref:`SFT guide `. DPO Model Training -##################### +################## Before running the core DPO training, you must prepare your training and validation data to the format required for DPO training. DPO expects ``.jsonl`` files where each line is a JSON dict corresponding to a single, complete sample, as shown below:: @@ -100,6 +97,25 @@ Your JSONL file must contain at least as many samples as the Global Batch Size ( Once your data is processed into the correct format, you are ready to begin DPO training. You must start with a pretrained or SFT trained model. For this section, we will use the SFT model trained in the previous step to train the DPO model. For the purposes of the following sections, we assume that your training ``.jsonl`` file is located in ``/path/to/train_dpo_format.jsonl`` and your validation ``.jsonl`` file is located in ``/path/to/valid_dpo_format.jsonl``. +.. tip:: + + If you don't have a DPO dataset readily available, you can generate a toy one to get started. Here's + an example to generate ``NUM_EXAMPLES_TO_GENERATE`` examples. Ensure this value is larger than the + global_batch_size. + + .. code-block:: bash + + # Generates a dummy dataset in /path/to/train_dpo_format.jsonl /path/to/valid_dpo_format.jsonl + + NUM_EXAMPLES_TO_GENERATE=200 + + mkdir -p /path/to + for i in $(seq 1 $NUM_EXAMPLES_TO_GENERATE); do + cat <System\n\nUser\n${i}*10=?\nAssistant\n", "chosen_response": "$((i * 10))\n", "rejected_response": "I refuse to answer this question.\n"} + EOF + done | tee /path/to/train_dpo_format.jsonl /path/to/valid_dpo_format.jsonl >/dev/null + For the following parameters, the ``model.dpo.ref_policy_kl_penalty`` corresponds to the beta parameter in the DPO paper. .. tab-set:: @@ -111,7 +127,7 @@ For the following parameters, the ``model.dpo.ref_policy_kl_penalty`` correspond .. code-block:: bash - export GPFS="/path/to/nemo-aligner-repo" + export GPFS="/opt/NeMo-Aligner" export TRAIN_DATA_PATH="/path/to/train_dpo_format.jsonl" export VALID_DATA_PATH="/path/to/valid_dpo_format.jsonl" @@ -147,7 +163,7 @@ For the following parameters, the ``model.dpo.ref_policy_kl_penalty`` correspond #SBATCH --exclusive #SBATCH --overcommit - GPFS="/path/to/nemo-aligner-repo" + export GPFS="/opt/NeMo-Aligner" PRETRAINED_CHECKPOINT_NEMO_FILE="/path/to/megatron_gpt_sft.nemo" TRAIN_DATA_PATH="/path/to/train_comparisons.jsonl" @@ -187,7 +203,6 @@ For the following parameters, the ``model.dpo.ref_policy_kl_penalty`` correspond EOF srun --no-container-mount-home -o $OUTFILE -e $ERRFILE --container-image=$CONTAINER $MOUNTS bash -c "${cmd}" - set +x The default DPO training tunes all parameters. To use LoRA, we can set ``model.peft.peft_scheme=lora`` and use different parameters in ``model.peft.lora_tuning``. Please check the parameters in `the config file `__. @@ -204,4 +219,4 @@ However, the following list is a brief overview of which hyperparameters we have * global_batch_size: Generally, we have found that, all other parameters held equal, lower GBS performs worse. GBS of 256 or 512 seems to be the sweet spot for most models we trained. * epochs: Highly sensitive to training data size. We recommend you start with 1 epoch and then add on from there. We did not see any improvements beyond 3 epochs. * learning rate: We tested cosine annealing with a warmup of 10 steps, followed by a slow decay to a constant rate. That constant rate should be fairly low. We saw the best performance with 9e-7 and 5-e7. -* ref_policy_kl_penalty: We generally saw better performance with lower values of 0.1, 0.2, 0.5, and 1.0. Occasionally, values as high as 5.0 worked too. \ No newline at end of file +* ref_policy_kl_penalty: We generally saw better performance with lower values of 0.1, 0.2, 0.5, and 1.0. Occasionally, values as high as 5.0 worked too. diff --git a/docs/user-guide/draftp.rst b/docs/user-guide/draftp.rst index d4e504e32..e2227f944 100644 --- a/docs/user-guide/draftp.rst +++ b/docs/user-guide/draftp.rst @@ -1,15 +1,12 @@ .. include:: /content/nemo.rsts -.. _model-aligner-draftp: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-draftp: Fine-Tuning Stable Diffusion with DRaFT+ @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - In this tutorial, we will go through the step-by-step guide for fine-tuning a Stable Diffusion model using DRaFT+ algorithm by NVIDIA. DRaFT+ enhances the DRaFT `DRaFT `__ algorithm by mitigating mode collapse and improving diversity through regularization. For more technical details on the DRaFT+ algorithm, check out our technical blog. diff --git a/docs/user-guide/index.rst b/docs/user-guide/index.rst index 650d67a6e..bf80fb618 100644 --- a/docs/user-guide/index.rst +++ b/docs/user-guide/index.rst @@ -1,16 +1,17 @@ .. include:: /content/nemo.rsts -.. include:: modelalignment.rsts +.. include:: nemoaligner.rsts .. toctree:: - :maxdepth: 4 - :titlesonly: + :maxdepth: 2 sft.rst + knowledge-distillation.rst + dpo.rst rlhf.rst steerlm.rst steerlm2.rst - dpo.rst + rs.rst spin.rst draftp.rst cai.rst @@ -18,26 +19,131 @@ :ref:`Prerequisite Obtaining a Pre-Trained Model ` This section provides instructions on how to download pre-trained LLMs in .nemo format. The following section will use these base LLMs for further fine-tuning and alignment. -:ref:`Model Alignment by Supervised Fine-Tuning (SFT) ` +:ref:`Model Alignment by Supervised Fine-Tuning (SFT) ` In this section, we walk you through the most straightforward alignment method. We use a supervised dataset in the prompt-response pairs format to fine-tune the base model according to the desired behavior. -:ref:`Model Alignment by RLHF ` +:ref:`Supervised Fine-Tuning (SFT) with Knowledge Distillation ` + In this section, we walk through a variation of SFT using Knowledge Distillation where we train a smaller "student" model using a larger "teacher" model. + +:ref:`Model Alignment by DPO, RPO and IPO ` + DPO, RPO, and IPO are simpler alignment methods compared to RLHF. DPO introduces a novel parameterization of the reward model in RLHF, which allows us to extract the corresponding optimal policy. Similarly, RPO and IPO provide alternative parameterizations or optimization strategies, each contributing unique approaches to refining model alignment. + +:ref:`Model Alignment by RLHF ` RLHF is the next step up in alignment and is still responsible for most state-of-the-art chat models. In this section, we walk you through the process of RLHF alignment, including training a reward model and RLHF training with the PPO algorithm. -:ref:`Model Alignment by SteerLM Method ` +:ref:`Model Alignment by SteerLM Method ` SteerLM is a novel approach developed by NVIDIA. SteerLM simplifies alignment compared to RLHF. It is based on SFT, but allows user-steerable AI by enabling you to adjust attributes at inference time. -:ref:`Model Alignment by SteerLM 2.0 Method ` +:ref:`Model Alignment by SteerLM 2.0 Method ` SteerLM 2.0 is an extension to SteerLM method that introduces an iterative training procedure to explicitly enforce the generated responses to follow the desired attribute distribution. -:ref:`Model Alignment by DPO, RPO and IPO ` - DPO, RPO, and IPO are simpler alignment methods compared to RLHF. DPO introduces a novel parameterization of the reward model in RLHF, which allows us to extract the corresponding optimal policy. Similarly, RPO and IPO provide alternative parameterizations or optimization strategies, each contributing unique approaches to refining model alignment. - -:ref:`Model Alignment by Rejection Sampling (RS) ` +:ref:`Model Alignment by Rejection Sampling (RS) ` RS is a simple online alignment algorithm. In RS, the policy model generates several responses. These responses are assigned a score by the reward model, and the highest scoring responses are used for SFT. -:ref:`Fine-tuning Stable Diffusion with DRaFT+ ` +:ref:`Fine-tuning Stable Diffusion with DRaFT+ ` DRaFT+ is an algorithm for fine-tuning text-to-image generative diffusion models. It achieves this by directly backpropagating through a reward model. This approach addresses the mode collapse issues from the original DRaFT algorithm and improves diversity through regularization. -:ref:`Constitutional AI: Harmlessness from AI Feedback ` +:ref:`Constitutional AI: Harmlessness from AI Feedback ` CAI, an alignment method developed by Anthropic, enables the incorporation of AI feedback for aligning LLMs. This feedback is grounded in a small set of principles (referred to as the ‘Constitution’) that guide the model toward desired behaviors, emphasizing helpfulness, honesty, and harmlessness. + +.. list-table:: Algorithm vs. (NLP) Models + :widths: auto + :header-rows: 1 + :stub-columns: 1 + + * - Algorithm + - TRTLLM Accelerated + - GPT 2B + - LLaMA2 + - LLaMA3 + - Mistral + - Nemotron-4 + - Mixtral + * - :ref:`SFT ` + - + - Yes (✓) + - Yes + - Yes + - Yes + - Yes (✓) + - + * - :ref:`SFT with Knowledge Distillation ` + - + - Yes (✓) + - Yes + - Yes + - Yes + - Yes + - + * - :ref:`DPO ` + - + - Yes (✓) + - Yes + - Yes + - Yes + - Yes (✓) + - In active development + * - :ref:`RLHF ` + - Yes + - Yes + - Yes + - Yes (✓) + - Yes + - Yes (✓) + - + * - :ref:`SteerLM ` + - + - Yes + - Yes (✓) + - Yes + - Yes + - Yes + - + * - :ref:`SteerLM 2.0 ` + - + - Yes + - Yes + - Yes + - Yes + - Yes + - + * - :ref:`Rejection Sampling ` + - + - Yes + - Yes + - Yes + - Yes + - Yes + - + * - :ref:`CAI ` + - + - Yes + - Yes + - Yes + - Yes (✓) + - Yes + - + +.. list-table:: Algorithm vs. (Multimodal) Models + :widths: auto + :header-rows: 1 + :stub-columns: 1 + + * - Algorithm + - Stable Diffusion + * - :ref:`Draft+ ` + - Yes (✓) + +.. note:: + + * (✓): Indicates the model is verified to work with the algorithm. Models without this demarcation are expected to work but have not been formally verified yet. + +Hardware Requirements +##################### + +NeMo-Aligner is powered by other NVIDIA libraries that support several NVIDIA GPUs. +NeMo-Aligner is tested on H100 but also works on A100. Several tutorials assume 80GB VRAM, +so if you are following along with GPUs with 40GB, adjust your config accordingly. + +Examples of config adjustments are increasing node count, introducing more tensor/pipeline +parallelism, lowering batch size, and increasing gradient accumulation. diff --git a/docs/user-guide/knowledge-distillation.rst b/docs/user-guide/knowledge-distillation.rst index db793a882..51387c162 100644 --- a/docs/user-guide/knowledge-distillation.rst +++ b/docs/user-guide/knowledge-distillation.rst @@ -1,5 +1,9 @@ .. include:: /content/nemo.rsts +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-knowledge-distillation: + Supervised Fine-Tuning (SFT) with Knowledge Distillation @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@ -10,14 +14,15 @@ There are two primary benefits of knowledge distillation compared to standard su There are many variants of knowledge distillation. NeMo Aligner supports training the student model to match the top-K logits of the teacher model. In this tutorial, we will go through fine-tuning a 2B student using a fine-tuned Nemotron 8B chat model. .. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. + Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. -Obtain the fine-tuned teacher and pre-trained student models +Obtain the Fine-Tuned Teacher and Pre-Trained Student Models ############################################################ -To start, we must first download both the pre-trained student and fine-tuned teacher models + +To start, we must first download both the pre-trained student and fine-tuned teacher models. .. tab-set:: @@ -42,7 +47,7 @@ To start, we must first download both the pre-trained student and fine-tuned tea .. code-block:: bash huggingface-cli download nvidia/nemotron-3-8b-chat-4k-sft --local-dir teacher_checkpoint -After these steps you should have files ``2b_student.nemo`` and ``teacher_checkpoint/Nemotron-3-8B-Chat-4k-SFT.nemo`` to use in NeMo-Aligner. +After these steps, you should have files ``2b_student.nemo`` and ``teacher_checkpoint/Nemotron-3-8B-Chat-4k-SFT.nemo`` to use in NeMo-Aligner. .. note:: Megatron Core models use TransformerEngine as a backend, which attempts to find efficient kernels. However, depending on your GPU, it may not always succeed. If you encounter errors related to kernel finding, set these variables at the top of your script. diff --git a/docs/user-guide/modelalignment.rsts b/docs/user-guide/nemoaligner.rsts similarity index 67% rename from docs/user-guide/modelalignment.rsts rename to docs/user-guide/nemoaligner.rsts index 099a5bbe4..1e54ed676 100644 --- a/docs/user-guide/modelalignment.rsts +++ b/docs/user-guide/nemoaligner.rsts @@ -1,8 +1,8 @@ -.. _model-aligner-intro: +.. _nemo-aligner-intro: -Model Alignment -!!!!!!!!!!!!!!! +NeMo-Aligner +!!!!!!!!!!!! Introduction ############ @@ -13,6 +13,8 @@ The NeMo-Aligner toolkit is built using the `NeMo Toolkit `__. Once you have logged in, you can get the container here: `NVIDIA NGC NeMo Framework `__. -To use a pre-built container, run the following code: +To run interactively using a pre-built container, run the following code: .. code-block:: bash - docker run -it --gpus=all --shm-size=8g --workdir /opt/NeMo-Aligner nvcr.io/nvidia/nemo:24.09 + docker run --rm -it \ + --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --shm-size=8g \ + --workdir /opt/NeMo-Aligner \ + nvcr.io/nvidia/nemo:24.09 Please use the latest tag in the form yy.mm.(patch). -.. note:: +.. important:: - Some of the subsequent tutorials require accessing gated Hugging Face models. For details on how to access these models, refer to `this document `__. - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. +Build a NeMo-Aligner Dockerfile +############################### + +NeMo-Aligner also provides its own `dockerfile `__ if you want to customize the environment. +Run the following to build the image: + + .. code-block:: bash + git clone https://github.com/NVIDIA/NeMo-Aligner.git + cd NeMo-Aligner + + # Replace with branch to build from + ALIGNER_COMMIT=main + TARGET_IMAGE=aligner-custom + + docker buildx build \ + -t $TARGET_IMAGE \ + --build-arg=ALIGNER_COMMIT=$ALIGNER_COMMIT \ + . + + # Run the image using the above command in "Get Started" and swap out "nvcr.io/nvidia/nemo:24.09" with "aligner-custom". diff --git a/docs/user-guide/rlhf.rst b/docs/user-guide/rlhf.rst index 1caebf1a5..5c68edb60 100644 --- a/docs/user-guide/rlhf.rst +++ b/docs/user-guide/rlhf.rst @@ -1,21 +1,18 @@ .. include:: /content/nemo.rsts -.. _model-aligner-rlhf: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-rlhf: Model Alignment by RLHF @@@@@@@@@@@@@@@@@@@@@@@ -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - For the purposes of this tutorial, we will go through the entire Reinforcement Learning from Human Feedback (RLHF) pipeline using models from the NeMo Framework. These models can include LLaMa or Mistral, and our scripts will function consistently across them. RLHF is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will use this to start the RLHF process. We will use the `PPO `__ algorithm for reinforcement learning on the `Anthropic-HH-RLHF `__ dataset. Data Processing for RLHF -######################### +######################## We have a script ready to use for processing the Anthropic-HH dataset into a JSONL format. Run the following command on the `download_and_process.py `__ script for anthropic HH. diff --git a/docs/user-guide/rs.rst b/docs/user-guide/rs.rst index be3b0546d..526284785 100644 --- a/docs/user-guide/rs.rst +++ b/docs/user-guide/rs.rst @@ -1,15 +1,12 @@ .. include:: /content/nemo.rsts -.. _model-aligner-rs: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-rs: Model Alignment by Rejection Sampling @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - In this tutorial, we will guide you through the process of aligning a NeMo Framework model using rejection sampling. This method can be applied to various models, including LLaMa and Mistral, with our scripts functioning consistently across different models. Rejection Sampling is usually preceded by a Supervised Fine-Tuning (SFT). We should first follow the :ref:`Prerequisite guide ` and the :ref:`SFT guide `. After obtaining the SFT model, we will also need to train a reward model as in :ref:`PPO guide `. We will use the rejection sampling algorithm on the `Anthropic-HH-RLHF `__ dataset. diff --git a/docs/user-guide/sft.rst b/docs/user-guide/sft.rst index 9d2c86d99..3dde7c32e 100644 --- a/docs/user-guide/sft.rst +++ b/docs/user-guide/sft.rst @@ -1,5 +1,7 @@ .. include:: /content/nemo.rsts +.. include:: aligner-algo-header.rst + .. _prerequisite: Obtain a Pretrained Model @@ -58,7 +60,7 @@ After these steps, you will have a file called ``mcore_gpt.nemo`` to use in NeMo export NVTE_FLASH_ATTN=0 export NVTE_FUSED_ATTN=0 -.. _model-aligner-sft: +.. _nemo-aligner-sft: Model Alignment by Supervised Fine-Tuning (SFT) ############################################### @@ -69,13 +71,8 @@ Model Alignment by Supervised Fine-Tuning (SFT) 2. **Chat**. In the *Chat* format, each example contains a multi-turn conversation between different roles (e.g., *User* and *Assistant*). Fine-tuning the base model on a chat format dataset is useful to align a chatbot. -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - Fine-Tune with a Prompt-Response Dataset -%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step 1: Format the data. ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/user-guide/spin.rst b/docs/user-guide/spin.rst index 77fb2f6e9..a43432961 100644 --- a/docs/user-guide/spin.rst +++ b/docs/user-guide/spin.rst @@ -1,5 +1,7 @@ .. include:: /content/nemo.rsts +.. include:: aligner-algo-header.rst + Model Alignment by Self-Play Fine-Tuning (SPIN) @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@ -9,11 +11,6 @@ All algorithms in NeMo-Aligner will work with any GPT-based model that is from M For details on the SPIN algorithm, refer to the paper: `https://arxiv.org/abs/2401.01335 `__. -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - Obtain a Pretrained Model ######################### To start, we must first get a pretrained model to align. There are two models we recommend to get started. The rest of the tutorial will work with either model, but for demonstration purposes, we will use the smaller 2B model. diff --git a/docs/user-guide/steerlm.rst b/docs/user-guide/steerlm.rst index 983850277..1c2eea2b1 100644 --- a/docs/user-guide/steerlm.rst +++ b/docs/user-guide/steerlm.rst @@ -1,6 +1,8 @@ .. include:: /content/nemo.rsts -.. _model-aligner-steerlm: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-steerlm: Model Alignment by SteerLM Method @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ @@ -12,7 +14,7 @@ The current approach for LLM improvement combines Supervised Fine-Tuning (SFT) a SteerLM addresses these challenges and represents a significant advancement in the field, making it easier to tailor LLMs to specific needs and preferences. This document delves into how SteerLM operates and offers guidance on training a SteerLM model. SteerLM -############### +####### SteerLM leverages a SFT method that empowers you to control responses during inference. It overcomes the limitations of prior alignment techniques, and consists of four key steps: 1. Train an attribute prediction model on human-annotated datasets to evaluate response quality on any number of attributes like helpfulness, humor, and creativity. @@ -45,11 +47,6 @@ Train a SteerLM Model This section is a step-by-step tutorial that walks you through how to run a full SteerLM pipeline with a Llama2 70B LLM model. -.. note:: - Before starting this tutorial, be sure to review the :ref:`introduction ` for tips on setting up your NeMo-Aligner environment. - - If you run into any problems, refer to NeMo's `Known Issues page `__. The page enumerates known issues and provides suggested workarounds where appropriate. - Download the Llama 2 LLM Model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/user-guide/steerlm2.rst b/docs/user-guide/steerlm2.rst index 366e0be03..10ca277f2 100644 --- a/docs/user-guide/steerlm2.rst +++ b/docs/user-guide/steerlm2.rst @@ -1,12 +1,14 @@ -.. .. include:: /content/nemo.rsts +.. include:: /content/nemo.rsts -.. _model-aligner-steerlm2: +.. include:: aligner-algo-header.rst + +.. _nemo-aligner-steerlm2: SteerLM 2.0: Iterative Training for Attribute-Conditioned Language Model Alignment -@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ +@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ -**SteerLM 2.0** is a novel approach for aligning large language models (LLMs) to generate responses with desired attribute values, building upon the original `SteerLM `_ method [1]_ . While SteerLM conducts attribute-conditioned Supervised Fine-Tuning (SFT) to steer LLM outputs, SteerLM 2.0 introduces an iterative training procedure to explicitly enforce the generated responses to follow the desired attribute distribution. +**SteerLM 2.0** is a novel approach for aligning large language models (LLMs) to generate responses with desired attribute values, building upon the original `SteerLM `_ method [1]_. While SteerLM conducts attribute-conditioned Supervised Fine-Tuning (SFT) to steer LLM outputs, SteerLM 2.0 introduces an iterative training procedure to explicitly enforce the generated responses to follow the desired attribute distribution. Overview ######## @@ -21,7 +23,7 @@ SteerLM 2.0 accomplishes this by minimizing the Kullback-Leibler (KL) divergence This KL divergence loss can be optimized using samples from an initial SteerLM model :math:`Q'(y|a, x)`, leading to an efficient gradient estimation procedure (see [2]_ for derivations). Method Details -############### +############## **Construct the optimal conditional distribution** :math:`P(y|a, x)`: Using Bayes' rule and the attribute prediction model :math:`P(a|x, y)`, we can derive the optimal conditional distribution as: @@ -44,7 +46,7 @@ where :math:`w'_i` and :math:`b'_i` are normalized importance weights targeting By iteratively training on this loss, SteerLM 2.0 can learn to generate responses :math:`y` that better conform to specified attribute values :math:`a` for a given prompt :math:`x`. Train a SteerLM 2.0 Model -########################### +######################### Prepare the Training Dataset ---------------------------- @@ -103,7 +105,7 @@ For a given attribute string a and prompt x (constructed from prompt turns and t These values are provided as log(P(a|x,y)), log(P(y|x)), and log(Q(y|a,x)), respectively, for each sampled response :math:`y_i`. Training Example ------------------- +---------------- By organizing the data in this format, the SteerLM 2.0 model can be effectively trained to generate responses that conform to the desired attribute values while approximating the optimal conditional distribution :math:`P(y|a, x)`. The following is an example of launching the training of SteerLM 2.0: @@ -158,16 +160,16 @@ By organizing the data in this format, the SteerLM 2.0 model can be effectively exp_manager.explicit_log_dir=/results/acsft_70b \ exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True -``/path/to/steerlm1/model`` is the path to the initial SteerLM model. For details on training the initial SteerLM model, refer to the :ref:`SteerLM documentation `. +``/path/to/steerlm1/model`` is the path to the initial SteerLM model. For details on training the initial SteerLM model, refer to the :ref:`SteerLM documentation `. Inference ------------------- +--------- -Since the SteerLM 2.0 Model is an extension of the original SteerLM model, the inference process is similar. Please refer to the `SteerLM `_ documentation for more details. +Since the SteerLM 2.0 Model is an extension of the original SteerLM model, the inference process is similar. Refer to the `SteerLM `_ documentation for more details. References ---------- .. [1] Dong, Y., Delalleau, O., Zeng, J., Shen, G., Zhang, J.J., Sreedhar, M.N., Kuchaiev, O. (2023). SteerLM: Attribute Conditioned SFT as an (User-Steerable) Alternative to RLHF. -.. [2] Wang, Z., Dong, Y., Delalleau, O., Zeng, J., Shen, G., Zhang, J.J., Sreedhar, M.N., Kuchaiev, O. (2024). HelpSteer2: Open-source dataset for training top-performing reward models. \ No newline at end of file +.. [2] Wang, Z., Dong, Y., Delalleau, O., Zeng, J., Shen, G., Zhang, J.J., Sreedhar, M.N., Kuchaiev, O. (2024). HelpSteer2: Open-source dataset for training top-performing reward models. From 35d0c59a8b0479c93a5ae7b580a6a5c2a7782420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Thu, 28 Nov 2024 01:20:34 +0100 Subject: [PATCH 3/5] ci: Allow dry-run of release (#421) Signed-off-by: Oliver Koenig --- .github/workflows/release.yaml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 12548461e..6991a5cfb 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -20,10 +20,15 @@ on: description: Ref (SHA or branch name) to release required: true type: string + dry-run: + description: Do not publish a wheel and GitHub release. + required: true + default: true + type: boolean jobs: release: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.12.3 + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.15.0 with: release-ref: ${{ inputs.release-ref }} image-name: nemo_aligner_container @@ -36,8 +41,10 @@ jobs: python-package: nemo_aligner container-workdir: /opt/NeMo-Aligner library-name: NeMo-Aligner + dry-run: ${{ inputs.dry-run }} secrets: TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} SLACK_RELEASE_ENDPOINT: ${{ secrets.SLACK_RELEASE_ENDPOINT }} PAT: ${{ secrets.PAT }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} From 70e4f31dea1fa0dcf7e425af46344fa1db45f44f Mon Sep 17 00:00:00 2001 From: Alexander Bukharin <59148829+abukharin3@users.noreply.github.com> Date: Mon, 2 Dec 2024 21:26:28 -0500 Subject: [PATCH 4/5] fix: correct REINFORCE to resume training (#427) Signed-off-by: abukharin Signed-off-by: NeMo-Aligner CI Co-authored-by: abukharin Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Terry Kong --- nemo_aligner/utils/train_script_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/nemo_aligner/utils/train_script_utils.py b/nemo_aligner/utils/train_script_utils.py index c6f6f8089..eeed1a538 100644 --- a/nemo_aligner/utils/train_script_utils.py +++ b/nemo_aligner/utils/train_script_utils.py @@ -50,12 +50,18 @@ def retrieve_custom_trainer_state_dict(ptl_trainer): consumed_samples = extract_value_from_ckpt(key="consumed_samples", ckpt_path=trainer_restore_path) step = extract_value_from_ckpt(key="step", ckpt_path=trainer_restore_path) epoch = extract_value_from_ckpt(key="epoch", ckpt_path=trainer_restore_path) + + # TODO: unify alignment step key to avoid adding one for each algo ppo_optimization_step = extract_value_from_ckpt(key="ppo_optimization_step", ckpt_path=trainer_restore_path) + reinforce_optimization_step = extract_value_from_ckpt( + key="reinforce_optimization_step", ckpt_path=trainer_restore_path + ) trainer_state_dict = { "step": step, "consumed_samples": consumed_samples, "epoch": epoch, "ppo_optimization_step": ppo_optimization_step, + "reinforce_optimization_step": reinforce_optimization_step, } return trainer_state_dict From 1d732adfd817f6d1c779040d4100e1145cfd3a08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Dec 2024 23:20:20 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: NeMo-Aligner CI --- tests/test_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 334ec05fe..79fb5e77d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -169,7 +169,7 @@ def test_dpo_dataset_conversion(): # (@adithyare) bonus test! convert oai style messages back into a string using Jinja # Attempt to import jinja2 via importorskip jinja2 = pytest.importorskip("jinja2", reason="jinja2 library is not installed") - + # Now it's safe to use jinja2 from jinja2 import Template