Commit 2e5d00e: fix bug

MXueguang committed Nov 30, 2023
1 parent e9e7ae2

Showing 3 changed files with 4 additions and 19 deletions.
5 changes: 3 additions & 2 deletions examples/repllama/README.md
@@ -72,7 +72,7 @@ deepspeed --include localhost:0,1,2,3 train.py \
   --deepspeed ds_config.json \
   --output_dir model_repllama \
   --model_name_or_path meta-llama/Llama-2-7b-hf \
-  --save_steps 20 \
+  --save_steps 200 \
   --dataset_name Tevatron/msmarco-passage \
   --bf16 \
   --per_device_train_batch_size 8 \
@@ -86,7 +86,8 @@ deepspeed --include localhost:0,1,2,3 train.py \
   --logging_steps 10 \
   --overwrite_output_dir \
   --dataset_proc_num 32 \
-  --negatives_x_device
+  --negatives_x_device \
+  --warmup_steps 100
 ```
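
Two training-arguments changes here: `--save_steps 200` writes a checkpoint every 200 optimizer steps instead of every 20, and the new `--warmup_steps 100` ramps the learning rate up linearly over the first 100 steps before the scheduler's decay begins. A minimal sketch of what `warmup_steps` does, assuming the standard Hugging Face linear schedule (the stand-in model, peak LR, and total step count below are illustrative, not values from this README):

```python
import torch
from transformers import get_linear_schedule_with_warmup

# Illustrative only: how --warmup_steps 100 shapes the learning rate.
model = torch.nn.Linear(8, 8)    # stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,        # LR climbs from 0 to 1e-4 over 100 steps
    num_training_steps=1000,     # then decays linearly back to 0
)
lrs = []
for _ in range(1000):
    optimizer.step()             # no-op here; a real loop runs backward() first
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
print(lrs[0], lrs[99], lrs[-1])  # ~1e-6 (ramping), 1e-4 (peak), ~0 (decayed)
```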


2 changes: 1 addition & 1 deletion examples/repllama/repllama.py
@@ -48,7 +48,7 @@ def encode_query(self, qry):
 
     def compute_similarity(self, q_reps, p_reps):
-        return torch.matmul(q_reps, p_reps.transpose(0, 1))
+        return torch.matmul(q_reps, p_reps.transpose(0, 1)) / 0.01
 
     def gradient_checkpointing_enable(self):
         self.lm_q.base_model.gradient_checkpointing_enable()
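
This is the actual bug fix: the similarity scores were missing their temperature. Dividing the dot products by 0.01 (a temperature of 0.01, which appears to match the RepLLaMA setup) sharpens the contrastive softmax; if the embeddings are L2-normalized, as RepLLaMA's seem to be, the raw scores lie in [-1, 1], so an untempered softmax over in-batch passages is nearly uniform and yields very weak gradients. A minimal sketch of where the temperature enters, assuming an InfoNCE-style cross-entropy (the helper and shape conventions are illustrative, not Tevatron's API):

```python
import torch
import torch.nn.functional as F

# Illustrative sketch: temperature-scaled in-batch contrastive loss.
def contrastive_loss(q_reps: torch.Tensor, p_reps: torch.Tensor,
                     temperature: float = 0.01) -> torch.Tensor:
    # q_reps: (num_queries, dim); p_reps: (num_queries * group_size, dim),
    # with one positive per query followed by its hard negatives.
    scores = torch.matmul(q_reps, p_reps.transpose(0, 1)) / temperature
    group_size = p_reps.size(0) // q_reps.size(0)
    # The positive for query i sits at column i * group_size.
    target = torch.arange(q_reps.size(0), device=scores.device) * group_size
    return F.cross_entropy(scores, target)
```

With a temperature of 0.01, a cosine-similarity gap of 0.05 becomes a logit gap of 5, so the loss concentrates on the hardest negatives instead of spreading probability almost evenly across the batch.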
16 changes: 0 additions & 16 deletions examples/repllama/trainer.py
@@ -16,7 +16,6 @@
 class TevatronTrainer(Trainer):
     def __init__(self, *args, **kwargs):
         super(TevatronTrainer, self).__init__(*args, **kwargs)
-        self._dist_loss_scale_factor = dist.get_world_size() if self.args.negatives_x_device else 1
 
     def _save(self, output_dir: Optional[str] = None, state_dict=None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
@@ -35,21 +34,6 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
         torch.save(lora_state_dict, os.path.join(output_dir, "adapter_model.bin"))
         print(f"Save adapter model at {output_dir}")
 
-    def _prepare_inputs(
-            self,
-            inputs: Tuple[Dict[str, Union[torch.Tensor, Any]], ...]
-    ) -> List[Dict[str, Union[torch.Tensor, Any]]]:
-        prepared = []
-        for x in inputs:
-            if isinstance(x, torch.Tensor):
-                prepared.append(x.to(self.args.device))
-            else:
-                prepared.append(super()._prepare_inputs(x))
-        return prepared
-
     def compute_loss(self, model, inputs):
         query, passage = inputs
         return model(query=query, passage=passage).loss
-
-    def training_step(self, *args):
-        return super(TevatronTrainer, self).training_step(*args) / self._dist_loss_scale_factor
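
For context on the deletions: `_prepare_inputs` duplicated device placement the base `Trainer` already handles (and its tuple-based signature may clash with newer `transformers` releases), while the `training_step` override divided the loss by the world size through the now-removed `_dist_loss_scale_factor`. Dropping the pair suggests any `--negatives_x_device` scaling now lives elsewhere (presumably in the model's forward), so keeping the division would have shrunk the loss twice. A hedged sketch of the pattern that motivates such scaling in the first place (`gather_with_grad` is a hypothetical helper, not Tevatron code):

```python
import torch
import torch.distributed as dist

# Hypothetical helper (requires an initialized process group): gather
# passage representations from every rank so each rank can score its
# queries against the full global pool of in-batch negatives.
def gather_with_grad(local_reps: torch.Tensor) -> torch.Tensor:
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local_reps) for _ in range(world_size)]
    dist.all_gather(gathered, local_reps)
    # all_gather does not carry gradients; splice the local shard back in
    # so its autograd path survives.
    gathered[dist.get_rank()] = local_reps
    return torch.cat(gathered, dim=0)

# Every rank then computes the same global loss, but DDP still averages
# gradients across ranks (an implicit divide by world_size), so a
# compensating scale factor must live in exactly one place -- which is
# why this factor and its counterpart are added or removed together.
```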
