
Chat templates #293

Open · wants to merge 2 commits into base: chat-templates

Conversation

SamGalanakis

Moved state to the PromptState and made further changes to align with Hugging Face's expected chat templating.

@SamGalanakis
Author

SamGalanakis commented Dec 10, 2023

import lmql
from transformers import (
    AutoTokenizer,
)

tokenizer_string = "HuggingFaceH4/zephyr-7b-beta"

lmql_model = lmql.model(
    "local:gpt2",
    tokenizer=tokenizer_string,
    cuda=True,
)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_string)


dialogue = [
    {"role": "system", "content": "You are a bot"},
    {"role": "user", "content": "Hey bot"},
    {"role": "assistant", "content": "hey there pal"},
    {"role": "assistant", "content": "more assistant continuation"},
    {"role": "user", "content": "Nice"}
]


@lmql.query(model=lmql_model, name="lmql_chat", chat_template=tokenizer.chat_template)
def lmql_chat():
    '''argmax 
        "{:system} You are a bot"
        "{:user} Hey bot"
        "{:assistant} [ANSWER]" where ANSWER in set(["hey there"]) 
        "pal"
        "{:user} Nice"
    '''


out = lmql_chat()
prompt_from_lmql = out.prompt
prompt_from_huggingface = tokenizer.apply_chat_template(dialogue, tokenize=False)
assert prompt_from_huggingface == prompt_from_huggingface, "Prompt from lmql does not match prompt from huggingface"

Tested with this example @lbeurerkellner

@lbeurerkellner
Collaborator

lbeurerkellner commented Dec 17, 2023

State handling looks good now, thanks. I just wanted to test with "meta-llama/Llama-2-7b-chat-hf"; however, it seems the start/end tag extraction does not work very well there.

First of all, some templates seem to expect a raise_exception callback for errors. E.g. the Llama-2 template raises an exception if we don't alternate user/assistant/user/assistant.

I ignored exceptions using the following implementation of get_start_end, but this leads to invalid rendered templates:

# Body of get_start_end; `role` and `chat_template` come from the enclosing scope.
# Requires `from jinja2 import Template` and `import warnings`.
eos_token = self.model.adapter._tokenizer.tokenizer_impl.tokenizer.eos_token
bos_token = self.model.adapter._tokenizer.tokenizer_impl.tokenizer.bos_token
# Dummy text to split the rendered template on - just needs to not be present in the template
split_template_text = "split_template_text"

def raise_exception(message):
    # Downgrade template errors to warnings instead of aborting rendering
    warnings.warn(message)

role_start, role_end = Template(chat_template, trim_blocks=True, lstrip_blocks=True)\
    .render(
        messages=[{'role': role, 'content': split_template_text}],
        bos_token=bos_token,
        eos_token=eos_token,
        raise_exception=raise_exception,
    ).split(split_template_text)

return role_start, role_end

I am testing with the following query:

argmax
    "{:user} Say 'this is a test'" 
    "{:assistant} [RESPONSE]"
from
    "meta-llama/Llama-2-7b-chat-hf"
where 
    len(TOKENS(RESPONSE)) < 10

With this, we get the following.

<s> [INST] Say 'this is a test' [/INST]None of us are sure what you're asking

"None " seems to be an artefact of a failing call to get_start_end, where we ignore exceptions. So it seems, we need to find a way to get start+end, without violating these kind of constraints like alternating user/assistant.

@SamGalanakis
Author

I see, we could track the previous role as well and pass it in with another dummy text, but parsing the start+end from the resulting string would be very tricky. Also, for the same meta-llama/Llama-2-7b-chat-hf, the template returns an empty string when passing a system role without any messages after it, so that would also break it. Not sure I see any reasonable way to deal with these kinds of templates. Would only supporting "simple" templates make sense? Or maybe make this feature more general, where the user can pass a function that takes in a role and message and spits out start and end, so people can deal with more complex requirements on a case-by-case basis?
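
To illustrate the second idea, a rough sketch of what a user-supplied tag function could look like; the chat_template_fn keyword is hypothetical (not an existing LMQL parameter), and the tags are only an approximation of the Llama-2 format, which also merges the system message into the first user turn and therefore cannot be fully expressed as per-role start/end pairs.

import lmql

def llama2_role_tags(role: str) -> tuple[str, str]:
    # Hand-written, approximate start/end tags for the Llama-2 chat format.
    tags = {
        "system": ("<s>[INST] <<SYS>>\n", "\n<</SYS>>\n\n"),
        "user": ("<s>[INST] ", " [/INST]"),
        "assistant": ("", " </s>"),
    }
    return tags[role]

# Hypothetical API: pass the callback instead of a chat_template string.
@lmql.query(model="meta-llama/Llama-2-7b-chat-hf", chat_template_fn=llama2_role_tags)
def chat():
    '''argmax
        "{:user} Say 'this is a test'"
        "{:assistant} [RESPONSE]" where len(TOKENS(RESPONSE)) < 10
    '''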

@lbeurerkellner
Collaborator

lbeurerkellner commented Jan 7, 2024

The problem with the Llama template seems to be that it does not allow all tag interleavings that are expressible with our current {:role} syntax. If users specify e.g. two system messages, it will raise an error.

I think it would be fine to limit support to simpler templates for now. For Llama, we should probably still add a custom template so that it works out of the box, as it is a common prompt format. But this can be a separate PR.

Given this, it should be fine to merge once we can make sure that template errors are reported properly and that users know how to fix them by providing a simpler template. With the current system, what is the concrete requirement a template has to satisfy to be considered "simple"? Maybe that could be specified in the error message?
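
One possible way to make the requirement explicit, as a rough sketch only (validate_chat_template is not an existing LMQL function): treat a template as "simple" if rendering a single dummy message for each role succeeds and wraps the dummy text in exactly one start/end pair, and report anything else with an actionable error message.

from jinja2 import Template
from jinja2.exceptions import TemplateError


def validate_chat_template(chat_template, roles=("system", "user", "assistant"),
                           bos_token="<s>", eos_token="</s>"):
    dummy = "split_template_text"

    def raise_exception(message):
        raise TemplateError(message)

    for role in roles:
        try:
            rendered = Template(chat_template, trim_blocks=True, lstrip_blocks=True).render(
                messages=[{"role": role, "content": dummy}],
                bos_token=bos_token,
                eos_token=eos_token,
                raise_exception=raise_exception,
            )
        except TemplateError as e:
            raise ValueError(
                f"chat_template not supported: rendering a single {role!r} message failed ({e}). "
                "Please provide a simpler template (or explicit role tags) that can render "
                "each role in isolation."
            )
        if rendered.count(dummy) != 1:
            raise ValueError(
                f"chat_template not supported: rendering a single {role!r} message did not "
                "produce exactly one start/end pair around the message content."
            )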

@AJHoeh

AJHoeh commented Mar 7, 2024

prompt_from_lmql = out.prompt
prompt_from_huggingface = tokenizer.apply_chat_template(dialogue, tokenize=False)
assert prompt_from_huggingface == prompt_from_huggingface, "Prompt from lmql does not match prompt from huggingface"

Don't know if this is (still) relevant, but in the last line of the code above, the HF prompt is compared to itself instead of the LMQL one. Just wanted to let you know in case you didn't catch that yet :)
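
For reference, the corrected comparison would be:

# Compare the LMQL-rendered prompt against the Hugging Face reference,
# rather than the reference against itself:
assert prompt_from_lmql == prompt_from_huggingface, "Prompt from lmql does not match prompt from huggingface"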
