Error when fine-tuning LLaVA #6

Open
njucckevin opened this issue Jun 12, 2024 · 4 comments

Comments

@njucckevin

[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/dpo.py", line 146, in <module>
[rank1]:     dpo_trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/base/trainer.py", line 305, in training_step
[rank1]:     loss_step = super().training_step(model, inputs)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/transformers/trainer.py", line 3238, in training_step
[rank1]:     loss = self.compute_loss(model, inputs)
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1081, in compute_loss
[rank1]:     loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
[rank1]:   File "/home/data_91_d/anaconda3/envs/chengkz_lvlm/lib/python3.10/site-packages/trl/trainer/dpo_trainer.py", line 1022, in get_batch_loss_metrics
[rank1]:     ) = self.concatenated_forward(model, batch)
[rank1]:   File "/home/nfs04/chengkz/VL-RLHF/src/vlrlhf/models/Llava/__init__.py", line 502, in concatenated_forward
[rank1]:     pixel_values=concatenated_batch["pixel_values"],
[rank1]: KeyError: 'pixel_values'

How can I fix this error?
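The KeyError above is raised inside the repository's concatenated_forward, which runs the chosen and rejected responses of a DPO pair through the model as one batch. That step reads concatenated_batch["pixel_values"], so the image tensor has to survive the concatenation. The following is a rough pure-Python sketch of that idea, with a clearer error when the key is absent. The helper name concatenate_preference_batch and the chosen_/rejected_ key layout are hypothetical and for illustration only; this is not VL-RLHF's or TRL's actual code, and plain lists stand in for tensors.

```python
from typing import Any, Dict, List


def concatenate_preference_batch(batch: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    """Stack chosen and rejected examples into one batch for a single forward pass.

    Hypothetical helper for illustration; assumes keys prefixed with "chosen_" and
    "rejected_" plus one shared "pixel_values" entry per prompt.
    """
    if "pixel_values" not in batch:
        # Fail early with a message listing the keys that actually arrived
        raise KeyError(f"'pixel_values' missing from batch; got keys {sorted(batch)}")
    concatenated: Dict[str, List[Any]] = {}
    for key in batch:
        if key.startswith("chosen_"):
            suffix = key[len("chosen_"):]
            # Chosen rows first, then rejected rows, so row i pairs with row i + B
            concatenated["concatenated_" + suffix] = batch[key] + batch["rejected_" + suffix]
    # Both responses answer the same image, so repeat the images for the second half
    concatenated["pixel_values"] = batch["pixel_values"] + batch["pixel_values"]
    return concatenated


if __name__ == "__main__":
    batch = {
        "chosen_input_ids": [[1, 2, 3]],
        "rejected_input_ids": [[1, 2, 4]],
        "pixel_values": ["img0"],
    }
    out = concatenate_preference_batch(batch)
    print(out["concatenated_input_ids"])  # [[1, 2, 3], [1, 2, 4]]
    print(out["pixel_values"])  # ['img0', 'img0']
```

Repeating the images rather than storing them twice in the dataset keeps the collated batch small; the cost is that the model sees 2B images per step.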

@TideDra
Owner

TideDra commented Jun 12, 2024

Fixed. Thanks for the report.

@njucckevin
Author

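Hi, apart from evaluating on benchmarks, is there any inference code I could refer to? I mean simple single-sample inference with either a pretrained or a fine-tuned checkpoint. Or is there another repository/model whose code I should follow?
Great work, thanks!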

@TideDra
Owner

TideDra commented Jun 15, 2024

You can use the helpers provided in src/vlrlhf/eval/utils.py:

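from vlrlhf.eval.utils import load_model_and_processor

# Placeholder: path to your pretrained or fine-tuned checkpoint
model_path = "path/to/your/checkpoint"
model, processor, generation_kwargs = load_model_and_processor(model_path, None)

image_path = "a.jpg"
prompt = "Describe this image"
# Wrap the text prompt in the model-specific multimodal prompt format
prompt = processor.format_multimodal_prompt(prompt, image_path)
inputs = processor(texts=[prompt], images_path=[image_path], check_format=False)
# Labels are only needed for training; drop them before generation
inputs.pop("label", None)
outputs = model.generate(**inputs, use_cache=True, **generation_kwargs)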

@njucckevin
Copy link
Author

Does the repository currently support KTO? I see KTO-related scripts under scripts, e.g. kto_qwenvl. If it isn't supported yet, are there plans to add it?
Thanks!
