Including tokenizer to onnx model / basic usage of the onnxruntime-extensions #798

MLRadfys opened this issue Aug 28, 2024 · 6 comments


@MLRadfys

MLRadfys commented Aug 28, 2024

Hi and thanks for this great library!

I am very new to ONNX and I am trying to include the RoBERTa tokenizer in a RoBERTa ONNX model.
As far as I have understood, one can get the ONNX graph for the tokenizer using:

import onnxruntime as _ort
from transformers import RobertaTokenizer
from onnxruntime_extensions import OrtPyFunction, gen_processing_models

# Roberta tokenizer
tokenizer = RobertaTokenizer.from_pretrained("roberta-base", model_max_length=512)
tokenizer_onnx = OrtPyFunction(gen_processing_models(tokenizer, pre_kwargs={})[0])

Now I am wondering what the next step is. How can I combine the ONNX tokenizer (or its graph) with a model?

Thanks in advance for any help,

cheers,

M

@wenbingl
Member

gen_processing_models returns two ONNX models: one for pre-processing, and the other for post-processing if post_kwargs is present in the kwargs.

If you want to combine any processing model into the ONNX model, please use this function https://onnx.ai/onnx/api/compose.html#onnx.compose.merge_models
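
For reference, a minimal sketch of such a merge (the file names and the io_map names input_ids / attention_mask are assumptions; the actual names can be read from each graph's input and output lists):

import onnx
from onnx import compose

# Hypothetical paths to the two exported graphs
tokenizer_model = onnx.load("tokenizer.onnx")
core_model = onnx.load("model.onnx")

# Wire the tokenizer's outputs into the matching model inputs
combined = compose.merge_models(
    tokenizer_model,
    core_model,
    io_map=[("input_ids", "input_ids"), ("attention_mask", "attention_mask")],
)

onnx.save(combined, "combined_model.onnx")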

@MLRadfys
Author

Thank you so much for the quick reply!

I will give it a try! Are there any alternatives?
I just looked at the different test functions and saw that one can create a model using

node = [helper.make_node(
    'RobertaTokenizer',
    ['string_input'], ['input_ids'],
    vocab=_get_file_content(vocab_file),
    merges=_get_file_content(merges_file),
    name='bpetok',
    padding_length=max_length,
    domain='ai.onnx.contrib')]

graph = helper.make_graph(node, 'test0', [input1], [output1])
tokenizer_model = make_onnx_model(graph)

Would it then be possible to create a pipeline using:

full_model = pnp.SequentialProcessingModule(tokenizer_model, Roberta_model)

Or in other words, which one is the easiest and most straightforward method? :-)

Really appreciate your help!

Cheers,

M

@wenbingl
Member

wenbingl commented Aug 28, 2024


That approach (building the tokenizer node yourself with helper.make_node) works at a lower level; it requires knowledge of ONNX and of the tokenization data and is prone to errors. So it is recommended to use only the gen_processing_models API, and users can get support if there is any problem.

@MLRadfys
Author

MLRadfys commented Aug 29, 2024

Alright, I think I solved it using the gen_processing_models() and merge_models() functions :-)

I attach my solution as a reference for others who encounter a similar problem:

import torch
from onnxruntime_extensions import gen_processing_models
from onnxruntime_extensions import get_library_path
import onnx
import onnxruntime as ort
import numpy as np
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Step 1: Load the Huggingface Roberta tokenizer and model
input_text = "A test text!"
model_type = "roberta-base"
model = RobertaForSequenceClassification.from_pretrained(model_type)
tokenizer = RobertaTokenizer.from_pretrained(model_type)

# Step 2: Export the tokenizer to ONNX using gen_processing_models
onnx_tokenizer_path = "tokenizer.onnx"

# Generate the tokenizer ONNX model
tokenizer_onnx_model = gen_processing_models(tokenizer, pre_kwargs={})[0]

# Save the tokenizer ONNX model
with open(onnx_tokenizer_path, "wb") as f:
    f.write(tokenizer_onnx_model.SerializeToString())

# Step 3: Export the Huggingface Roberta model to ONNX
onnx_model_path = "model.onnx"
dummy_input = tokenizer("This is a dummy input", return_tensors="pt")


# Export the model to ONNX
torch.onnx.export(
    model,                                                              # model to be exported
    (dummy_input['input_ids'], dummy_input["attention_mask"]),          # model inputs (dummy input)
    onnx_model_path,                                                    # where to save the ONNX model
    input_names=["input_ids", "attention_mask_input"],                  # input tensor names
    output_names=["logits"],                                            # output tensor names
    dynamic_axes={                                                      # dynamic axes
        "input_ids": {0: "batch_size", 1: "sequence_length"},
        "logits": {0: "batch_size"},
    },
)

# Step 4: Merge the tokenizer and model ONNX files into one
onnx_combined_model_path = "combined_model_tokenizer.onnx"

# Load the tokenizer and model ONNX files
tokenizer_onnx_model = onnx.load(onnx_tokenizer_path)
model_onnx_model = onnx.load(onnx_model_path)

# Inspect the ONNX models to find the correct input/output names
print("Tokenizer Model Inputs:", [node.name for node in tokenizer_onnx_model.graph.input])
print("Tokenizer Model Outputs:", [node.name for node in tokenizer_onnx_model.graph.output])
print("Model Inputs:", [node.name for node in model_onnx_model.graph.input])
print("Model Outputs:", [node.name for node in model_onnx_model.graph.output])

# Merge the tokenizer and model ONNX files
combined_model = onnx.compose.merge_models(
    tokenizer_onnx_model,
    model_onnx_model,
    io_map=[('input_ids', 'input_ids'), ('attention_mask', 'attention_mask_input')]
)

# Save the combined model
onnx.save(combined_model, onnx_combined_model_path)

# Step 5: Test the combined ONNX model using an Inference session with ONNX Runtime Extensions
# Initialize ONNX Runtime SessionOptions and load custom ops library
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())

# Initialize ONNX Runtime Inference session with Extensions
session = ort.InferenceSession(onnx_combined_model_path, sess_options=sess_options, providers=['CPUExecutionProvider'])

# Prepare dummy input text
input_feed = {"input_text": np.asarray([input_text])}  # Assuming "input_text" is the input expected by the tokenizer

# Run the model
outputs = session.run(None, input_feed)

# Print the outputs
print("logits:", outputs[1][0])

Thanks for the help!

Cheers,

M

@r4ghu

r4ghu commented Oct 17, 2024

Hi @MLRadfys ,
Thanks for this example. It's very useful for understanding how to convert a Hugging Face tokenizer to ONNX Runtime. I am able to successfully encode text to input_ids.

I am currently stuck in the process of decoding ids back to text. Is there an example I can use for reference?

@wenbingl
Member


There is an example here:

m_tok, m_detok = gen_processing_models(...)

The second return value of gen_processing_models is the ONNX model which can decode the ids back into text.
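
A minimal sketch of that round trip, assuming a Hugging Face tokenizer and wrapping both generated models with OrtPyFunction (the exact output layout depends on the tokenizer, so the indexing below is an assumption):

import numpy as np
from transformers import RobertaTokenizer
from onnxruntime_extensions import gen_processing_models, OrtPyFunction

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# Request both the pre-processing (encode) and post-processing (decode) graphs
m_tok, m_detok = gen_processing_models(tokenizer, pre_kwargs={}, post_kwargs={})

encode = OrtPyFunction(m_tok)
decode = OrtPyFunction(m_detok)

ids = encode(np.asarray(["A test text!"]))[0]  # first output assumed to be the token ids
text = decode(ids)                             # the post-processing model maps ids back to text
print(text)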
