move to lit gpt llama impl #3
base: main
Conversation
Force-pushed from f2a515e to 682d365.
 def get_model(config: Config) -> GPT:
     # Load model
-    config_model = LlamaConfig.from_pretrained(config.path_model, attn_implementation=config.attn_implementation)
-    return LlamaForCausalLM.from_pretrained(pretrained_model_name_or_path=config.path_model, config=config_model)
+    if isinstance(config.llama_config, ModelConfig):
+        llama_config = config.llama_config
+    else:
+        with open(config.llama_config) as f:
+            llama_config = ModelConfig(**json.load(f))
+
+    llama_config.attention_impl = config.attention_impl
+    return GPT(llama_config)
Hmm, would the initialisation be the same across DiLoCo workers?
Yes, indeed. Is there a way to use the HF Hub to push the init checkpoint like we did before?
It seems like GPT inherits from nn.Module instead of transformers' PreTrainedModel, so it doesn't have the from_pretrained function. Maybe we can change it to allow loading from the Hub.
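For illustration, one option while GPT is still a plain nn.Module is to push the initial state_dict to the Hub with the huggingface_hub client and have every DiLoCo worker download it before training; huggingface_hub's PyTorchModelHubMixin would be another route if from_pretrained-style loading on the class itself is wanted. A rough sketch (repo id and filename are placeholders):

```python
# Sketch only: share the init checkpoint through the HF Hub so all DiLoCo
# workers start from identical weights. Repo id and filename are placeholders.
import torch
from huggingface_hub import hf_hub_download, upload_file


def push_init_ckpt(model: torch.nn.Module, repo_id: str) -> None:
    # Serialize the freshly initialised weights and upload them to the Hub repo.
    torch.save(model.state_dict(), "init.pt")
    upload_file(path_or_fileobj="init.pt", path_in_repo="init.pt", repo_id=repo_id)


def load_init_ckpt(model: torch.nn.Module, repo_id: str) -> None:
    # Every worker downloads the same file and loads it before training starts.
    path = hf_hub_download(repo_id=repo_id, filename="init.pt")
    model.load_state_dict(torch.load(path, map_location="cpu"))
```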
input_ids = inputs_ids[:, :-1]
target = inputs_ids[:, 1:]

output = model(input_ids)
missing seqlens?
good catch
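On the missing seqlens, a minimal sketch of how the fix could look, assuming the batch carries a seqlens field and GPT.forward accepts a seqlens keyword for packed sequences (both are assumptions, not the actual signature in this PR):

```python
import torch
import torch.nn.functional as F


def train_step(model: torch.nn.Module, batch: dict) -> torch.Tensor:
    # Sketch only: "seqlens" in the batch and the seqlens keyword on the model
    # are assumptions about the packed-sequence interface, not confirmed APIs.
    inputs_ids = batch["input_ids"]   # (batch, seq_len + 1)
    seqlens = batch["seqlens"]        # document lengths inside each packed row

    input_ids = inputs_ids[:, :-1]
    target = inputs_ids[:, 1:]

    output = model(input_ids, seqlens=seqlens)  # assumed keyword argument
    return F.cross_entropy(output.reshape(-1, output.size(-1)), target.reshape(-1))
```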
FYI @Jackmin801, this PR was not meant to be merged for now; it was more to compare. Though we get 20% better MFU with it (even with torch.compile), so we might want it. I will address your comments.
Where are the MFU speedups coming from? If it's from better layer implementations, I think it might be better to just modify the HF implementation (copy it over to our repo and modify it). It's also easier to have FP8 support this way (I have an implementation using Transformer Engine that uses this method: https://github.com/PrimeIntellect-ai/OpenDiLoCo_internal/pull/71).
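For context on the FP8 point, Transformer Engine exposes FP8 through drop-in modules plus an autocast context, roughly as in the sketch below (illustrative only, not the code from the linked PR; it needs FP8-capable hardware such as H100):

```python
# Illustrative Transformer Engine FP8 usage; not the code from the linked PR.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Delayed-scaling recipe: keeps an amax history to pick FP8 scaling factors.
fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

layer = te.Linear(4096, 4096, bias=True).cuda()  # drop-in replacement for nn.Linear
x = torch.randn(32, 4096, device="cuda")

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)  # the GEMM inside runs in FP8
```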
add llama