-
Notifications
You must be signed in to change notification settings - Fork 314
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Hybrid Data Pipeline #495
base: main
Are you sure you want to change the base?
Conversation
我做了一个 one-shot ,学着写一个,python解释器的样例,佬看一下对不对
|
|
||
|
||
model = dict( | ||
type=HybridFinetune, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个名字有点奇怪,要不叫做 HybridFinetuneModel,不过还有一个疑问,如果直接写了 finetune,用户会不会以为只能 finetune model 而不能 pretrain model?
chat_template=chat_template, | ||
max_length=max_length, | ||
pack_to_max_length=True, | ||
num_workers = dataloader_num_workers, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方也有 dataloader_num_workers?
type=HybridDataset, | ||
data_dir=data_dir, | ||
data_files=data_files, | ||
data_cached='cached_llava', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
支持自动 cached 功能,即用户指定 data_cached 路径后,如果不存在则自动缓存,如果存在则直接读取并告诉用户
"role": "user", | ||
"content": [ | ||
{ | ||
"type": "image_url", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方感觉无法做到通用,因为可能会插入一些图片区分的 token,大部分情况下可能都会要重写 tokenizer 逻辑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同样的问题,是否有办法兼容以下这几种处理方式?
<image>
Picture X: <image>
<IMG><image></IMG>
self.dataset = dataset | ||
|
||
self._ori_img_urls = dataset['image_urls'] | ||
self._ori_img_rngs = dataset['image_ranges'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需要加点注释,否则不知道这个字段是啥意思
xtuner/dataset/hybrid/collate.py
Outdated
'pixel_values': pixel_values, | ||
'cumulative_len': cumulative_len, | ||
'image_ranges': image_ranges, | ||
'image_belong': image_belong |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉要说明下,有些字段只有在特定模式下才需要吧,如果没有点注释,自定义会很难
from xtuner.types import HybridChatTemplate | ||
from xtuner.utils import build_tokenizer | ||
|
||
os.environ['TOKENIZERS_PARALLELISM'] = 'true' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种环境变量有没有 false 的可能,如果有,则最好可以通过让用户环境变量设置,默认值为 true
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
考虑加前缀?XTUNER_XXXXXXX
added_keys=dict(tokens=int), | ||
) | ||
def _register_tokens(data, tokenizer=None, chat_template=None): | ||
data['tokens'] = len(data['input_ids']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感觉 tokens 这个名字难以理解,最好应该是 token_len 清晰很多
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接叫length?对齐transformers的一些默认行为,同时方便 LengthGroupedSampler
https://huggingface.co/docs/transformers/v4.39.1/en/main_classes/trainer#transformers.TrainingArguments.length_column_name
added_keys=dict(position_ids=list), | ||
) | ||
def _register_position_ids(data, tokenizer=None, chat_template=None): | ||
data['position_ids'] = [i for i in range(len(data['input_ids']))] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data['position_ids'] = [i for i in range(len(data['input_ids']))] | |
data['position_ids'] = list(range(len(data['input_ids']))) |
input_keys=dict(input_ids=list), | ||
added_keys=dict(cumulative_len=list), | ||
) | ||
def _register_cumulative_len(data, tokenizer=None, chat_template=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
过于简单的函数,可以考虑不要这个封装,否则看起来有点复杂,过度设计
if not isinstance(data[key], _type): | ||
breakpoint() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if not isinstance(data[key], _type): | |
breakpoint() | |
assert isinstance(data[key], _type) |
|
||
for url, ind in zip(image_urls, img_token_inds): | ||
# image = self.load_image(url) | ||
h, w = 336 // 14, 336 // 14 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如何兼容其他分辨率和patch size,通过 config 传入这个参数?
|
||
img_ranges = [] | ||
for i, _ in enumerate(zip(input_ids, labels)): | ||
if isinstance(input_ids[i], list): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if isinstance(input_ids[i], list): | |
if isinstance(input_ids[i], list): # image pad tokens |
new_ids.extend(input_ids[i]) | ||
new_labels.extend(labels[i]) | ||
|
||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
else: | |
else: # text token |
) | ||
def llava_to_openai(data, tokenizer=None, chat_template=None): | ||
|
||
image_token = '<image>' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
projector_config = ProjectorConfig( | ||
visual_hidden_size=self.visual_encoder.config.hidden_size, | ||
llm_hidden_size=self.llm.config.hidden_size, | ||
depth=projector_depth) | ||
self.projector = ProjectorModel(projector_config).to( | ||
self.visual_encoder.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个地方的初始化是不是在纯LLM时会出问题?
elif self.role == 'user': | ||
if len(self.files) > 0: | ||
stop_word = chat_template.stop_words[0] | ||
text += f'\n{stop_word}\n{chat_template.decorate_files(self.files)}' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
为何会引入 stop_word?同时为啥只取第0个?
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': | ||
eos_token_ids = tokenizer.eos_token_id |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif tokenizer.__class__.__name__ == 'ChatGLMTokenizer': | |
eos_token_ids = tokenizer.eos_token_id |
No description provided.