From e3c9c8870fb2381d522d2254c05a5da7b8359241 Mon Sep 17 00:00:00 2001 From: Terry Kong Date: Thu, 7 Nov 2024 20:05:26 +0000 Subject: [PATCH] dpo pad fix if none Signed-off-by: Terry Kong --- nemo_aligner/algorithms/dpo.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_aligner/algorithms/dpo.py b/nemo_aligner/algorithms/dpo.py index 428ae7bee..75b773106 100644 --- a/nemo_aligner/algorithms/dpo.py +++ b/nemo_aligner/algorithms/dpo.py @@ -53,7 +53,8 @@ def dpo_custom_collate( This collate happens outside of the torch data loader and is not compatible with the multiprocessing logic due to requiring communication collectives. """ - assert pad_length_to_multiple_of >= 0, f"{pad_length_to_multiple_of=} must be >=0" + if pad_length_to_multiple_of is not None and pad_length_to_multiple_of < 0: + raise ValueError(f"{pad_length_to_multiple_of=} must be >= 0") chosen_tokens = [item["chosen"] for item in batch] rejected_tokens = [item["rejected"] for item in batch] chosen_lengths = torch.LongTensor([item["chosen_length"] for item in batch])