Hello! I am trying to create a custom op for a tokenizer. I would like to add the tokenizer to the model cardiffnlp/twitter-roberta-base-hate-latest as internal pre-processing. My idea is to use the onnxruntime_extensions library, specifically the PrePostProcessor pipeline, so I created a new Step to add to the pre-processing stage. The model saves correctly, but when I create an ORT session it throws the following error:
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from ./src/models/twitter_roberta_base/roberta_model.with_pre_post_processing.onnx failed:Fatal error: com.microsoft.extensions:HfRobertaTokenizer(-1) is not a registered function/op
I am not sure whether this is the right way to accomplish what I want, so any guidance would be much appreciated. My goal is to quantize the model and run it in an Android app for SequenceClassification (hate speech detection, specifically).
Here is the pipeline code:
# Imports assumed from the rest of the post (the exact import paths were not shown).
from typing import Optional

import onnx
import onnx.parser  # make sure the onnx.parser submodule is loaded
import onnxruntime
from transformers import AutoTokenizer, RobertaTokenizer
from onnxruntime_extensions import PyCustomOpDef, onnx_op, get_library_path as _lib_path
from onnxruntime_extensions.tools.pre_post_processing import (
    ArgMax, PrePostProcessor, Step, TokenizerParam, create_named_value)

@onnx_op(op_type="HfRobertaTokenizer",
         inputs=[PyCustomOpDef.dt_string],
         outputs=[PyCustomOpDef.dt_int64, PyCustomOpDef.dt_int64],
         attrs={"padding_length": PyCustomOpDef.dt_int64})
def roberta_tokenizer(input_text, **kwargs):
    tokenizer = RobertaTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-hate-latest")
    padding_length = kwargs.get("padding_length", 512)  # default padding length
    tokens = tokenizer(input_text[0], padding='max_length', max_length=padding_length, return_tensors='np')
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']
    return input_ids, attention_mask
class RobertaTokenizerStep(Step):
    def __init__(self, tokenizer_param: TokenizerParam, need_token_type_ids_output: bool = False,
                 name: Optional[str] = None):
        """
        Brief: This step converts the input text into input_ids and attention_mask.
               It supports a single-string input for classification models.
        Args:
            tokenizer_param: essential info needed to build a tokenizer.
                             You can create a TokenizerParam like this:
                             tokenizer_param = TokenizerParam(vocab=tokenizer.vocab,  # vocab is a dict or file path
                                                              merges_file=tokenizer.merges_file,  # merges file for BPE
                                                              add_prefix_space=True or False (optional))
            need_token_type_ids_output: whether to include a token_type_ids output.
            name: Optional name of the step. Defaults to 'RobertaTokenizerStep'.
        """
        outputs = ["input_ids", "attention_mask"]
        if need_token_type_ids_output:
            outputs.append("token_type_ids")
        super().__init__(["input_text"], outputs, name)
        self._tokenizer_param = tokenizer_param
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        input_type_str0, input_shape_str0 = self._get_input_type_and_shape_strs(graph, 0)
        input_shape_0 = input_shape_str0.split(",")
        prefix_ = f'step_{self.step_num}'
        batch_dim = input_shape_0[0] if len(input_shape_0) > 1 else "1"
        output_shape_str = f"{batch_dim}, _{prefix_}__num_ids"
        assert input_type_str0 == "string"

        onnx_tokenizer_impl = "HfRobertaTokenizer"

        def build_output_declare():
            return ",".join([f"int64[{output_shape_str}] {out}" for out in self.output_names])

        def get_tokenizer_ret():
            return ",".join(self.output_names)

        def build_output_imp():
            return ""

        def build_input_declare():
            return f"{input_type_str0}[{input_shape_str0}] {self.input_names[0]}"

        def build_unsqueeze():
            if len(input_shape_0) == 1:
                return f"""
                input_with_batch = Unsqueeze({self.input_names[0]}, i64_0)
                """
            else:
                return f"""
                input_with_batch = Identity({self.input_names[0]})
                """

        converter_graph = onnx.parser.parse_graph(
            f"""\
            {onnx_tokenizer_impl} ({build_input_declare()})
              => ({build_output_declare()})
            {{
                i64_0 = Constant <value = int64[1] {{0}}> ()
                {build_unsqueeze()}
                {get_tokenizer_ret()} = com.microsoft.extensions.{onnx_tokenizer_impl} (input_with_batch)
                {build_output_imp()}
            }}
            """
        )

        roberta_tokenizer_param = self._tokenizer_param
        token_model_attr = []
        attrs = {
            "vocab": roberta_tokenizer_param.vocab_or_file,
        }
        for attr in attrs:
            token_model_attr.append(onnx.helper.make_attribute(attr, attrs[attr]))

        node_idx = next(i for i, v in enumerate(converter_graph.node) if v.op_type == onnx_tokenizer_impl)
        converter_graph.node[node_idx].attribute.extend(token_model_attr)

        return converter_graph
onnx_opset = 16
roberta_hf_tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-hate-latest", use_fast=True)

# Add pre and post processing to the model
inputs = [create_named_value("input_text", onnx.TensorProto.STRING, [1, "num_sentences"])]
pipeline = PrePostProcessor(inputs, onnx_opset)
pipeline.add_pre_processing(
    [
        RobertaTokenizerStep(
            TokenizerParam(
                vocab_or_file=roberta_hf_tokenizer.vocab
            )
        )
    ]
)
pipeline.add_post_processing(
    [
        ArgMax(),
    ]
)

onnx_model = onnx.load("./src/models/twitter_roberta_base/roberta_model.onnx")
new_model = pipeline.run(onnx_model)
onnx.save_model(new_model, "./src/models/twitter_roberta_base/roberta_model.with_pre_post_processing.onnx")
Then I load the model:
session_options = onnxruntime.SessionOptions()
session_options.register_custom_ops_library(_lib_path())

# Load the ONNX model
onnx_model_path = './src/models/twitter_roberta_base/roberta_model.with_pre_post_processing.onnx'
onnx_model = onnx.load(onnx_model_path)

# Alternatively, iterate through the nodes to check for the custom operator
for node in onnx_model.graph.node:
    if node.op_type == "HfRobertaTokenizer":
        print(f"Custom operator {node.op_type} found in the model.")

ort_session = onnxruntime.InferenceSession(onnx_model_path, sess_options=session_options, providers=["CPUExecutionProvider"])
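For reference, once the session loads this is roughly how I expect to run the combined model; the input name and shape come from the create_named_value declaration above, and the example sentence is made up:

import numpy as np

# Hypothetical usage once the custom op resolves and the session can be created:
# "input_text" is the [1, "num_sentences"] string input declared in the pipeline.
text = np.array([["an example sentence to classify"]])
outputs = ort_session.run(None, {"input_text": text})
print(outputs[0])  # result of the ArgMax post-processing step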
You'll have to make sure you're registering the custom op in the same app/Python session that loads the model. If model quantization happens in a separate Python codebase from training, you could export the onnx_op-decorated method in a package and import it wherever you load the model.
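As a rough sketch of that idea (the module name my_roberta_tokenizer_op is hypothetical; it would simply contain the @onnx_op-decorated roberta_tokenizer function from the original post):

import onnxruntime
from onnxruntime_extensions import get_library_path

# Hypothetical module holding the @onnx_op-decorated roberta_tokenizer function.
# Importing it executes the decorator, which registers the Python custom op in this process.
import my_roberta_tokenizer_op  # noqa: F401

so = onnxruntime.SessionOptions()
# The extensions native library still needs to be registered so ORT can dispatch to the Python op.
so.register_custom_ops_library(get_library_path())

sess = onnxruntime.InferenceSession(
    "roberta_model.with_pre_post_processing.onnx",
    sess_options=so,
    providers=["CPUExecutionProvider"],
)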
If you're trying to use the Java onnx-extensions package to run a model with your custom layer, I believe the only ways to do so are to 1) package your Python code in the Java app, register your custom operator, and point to the CustomOps lib in your Java app, or 2) write your custom operator in C++ and build a library with your custom operator.
Hey @JTunis! Thanks for the response. I ended up creating a Rust crate with the Hugging Face tokenizers library and calling the tokenizer through the Android NDK. Nevertheless, I think it would be nice to explore the second option you mentioned; do you have a sample to build on top of?
I don't yet unfortunately. I've also been interested in implementing custom ops in Rust rather than C++ and just came across this. Haven't looked at it in detail yet, but maybe you could look at wrapping Huggingface's library in what they have going on there.