diff --git a/examples/convert_to_tflite.py b/examples/convert_to_tflite.py
index a2b76df..691476a 100644
--- a/examples/convert_to_tflite.py
+++ b/examples/convert_to_tflite.py
@@ -6,10 +6,12 @@
 from dialognlu import TransformerNLU
 from dialognlu.utils.tf_utils import convert_to_tflite_model
 
-model_path = "../saved_models/joint_distilbert_model"
+# model_path = "../saved_models/joint_distilbert_model"
+model_path = "../saved_models/joint_trans_bert_model"
 
 print("Loading model ...")
 nlu = TransformerNLU.load(model_path)
 
-save_file_path = "../saved_models/joint_distilbert_model/model.tflite"
-convert_to_tflite_model(nlu.model.model, save_file_path, conversion_mode="fp16_quantization")
\ No newline at end of file
+save_file_path = model_path + "/model.tflite"
+convert_to_tflite_model(nlu.model.model, save_file_path, conversion_mode="hybrid_quantization")
+print("Done")
\ No newline at end of file
diff --git a/examples/evaluate_tflite_transformer_nlu.py b/examples/evaluate_tflite_transformer_nlu.py
new file mode 100644
index 0000000..b5b64a2
--- /dev/null
+++ b/examples/evaluate_tflite_transformer_nlu.py
@@ -0,0 +1,41 @@
+# -*- coding: utf-8 -*-
+"""
+@author: mwahdan
+"""
+
+# disable the GPU
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+from dialognlu import TransformerNLU
+from dialognlu.readers.goo_format_reader import Reader
+import time
+
+
+num_process = 2
+
+
+model_path = "../saved_models/joint_distilbert_model"
+# model_path = "../saved_models/joint_trans_bert_model"
+# model_path = "../saved_models/joint_trans_albert_model"
+# model_path = "../saved_models/joint_trans_roberta_model"
+
+print("Loading model ...")
+nlu = TransformerNLU.load(model_path, quantized=True, num_process=num_process)
+
+print("Loading dataset ...")
+test_path = "../data/snips/test"
+test_dataset = Reader.read(test_path)
+
+print("Evaluating model ...")
+t1 = time.time()
+token_f1_score, tag_f1_score, report, acc = nlu.evaluate(test_dataset)
+t2 = time.time()
+
+print('Slot Classification Report:', report)
+print('Slot token f1_score = %f' % token_f1_score)
+print('Slot tag f1_score = %f' % tag_f1_score)
+print('Intent accuracy = %f' % acc)
+
+print("Using %d processes took %f seconds" % (num_process, t2 - t1))
\ No newline at end of file
diff --git a/examples/predict_tflite_nlu.py b/examples/predict_tflite_nlu.py
new file mode 100644
index 0000000..c2ac4bc
--- /dev/null
+++ b/examples/predict_tflite_nlu.py
@@ -0,0 +1,26 @@
+# -*- coding: utf-8 -*-
+"""
+@author: mwahdan
+"""
+
+# disable the GPU
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+
+from dialognlu import TransformerNLU
+
+
+# model_path = "../saved_models/joint_distilbert_model"
+# model_path = "../saved_models/joint_trans_bert_model"
+# model_path = "../saved_models/joint_trans_albert_model"
+model_path = "../saved_models/joint_trans_roberta_model"
+
+print("Loading model ...")
+nlu = TransformerNLU.load(model_path, quantized=True, num_process=1)
+
+print("Prediction ...")
+utterance = "add sabrina salerno to the grime instrumentals playlist"
+print("utterance: {}".format(utterance))
+result = nlu.predict(utterance)
+print("result: {}".format(result))
\ No newline at end of file
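For reference, convert_to_tflite_model lives in dialognlu.utils.tf_utils and is not shown in this diff. A minimal sketch of what the two conversion modes used above typically correspond to in the stock tf.lite converter API; the function body is illustrative only, not the library's actual implementation:

    import tensorflow as tf

    def convert_keras_to_tflite(keras_model, save_path, conversion_mode="hybrid_quantization"):
        # Sketch only: maps the mode strings used in the examples onto TFLite settings.
        converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
        if conversion_mode == "hybrid_quantization":
            # Dynamic-range ("hybrid") quantization: int8 weights, float activations.
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
        elif conversion_mode == "fp16_quantization":
            # Half-precision weights.
            converter.optimizations = [tf.lite.Optimize.DEFAULT]
            converter.target_spec.supported_types = [tf.float16]
        tflite_model = converter.convert()
        with open(save_path, "wb") as f:
            f.write(tflite_model)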
diff --git a/examples/train_transformer_nlu.py b/examples/train_transformer_nlu.py
index 57079a6..260fcba 100644
--- a/examples/train_transformer_nlu.py
+++ b/examples/train_transformer_nlu.py
@@ -29,7 +29,7 @@
 
 pretrained_model_name_or_path = "distilbert-base-uncased"
 save_path = "../saved_models/joint_distilbert_model"
 
-epochs = 1 #3
+epochs = 3
 batch_size = 64
@@ -40,6 +40,8 @@
     "num_bert_fine_tune_layers": 10,
     "intent_loss_weight": 1.0,#0.2,
     "slots_loss_weight": 3.0,#2.0,
+
+    "max_length": 64,  # set max_length explicitly (recommended) or omit it; it is computed from the longest training example
 }
 
 
@@ -47,5 +49,5 @@
 nlu.train(train_dataset, val_dataset, epochs, batch_size)
 
 print("Saving ...")
-nlu.save(save_path)
+nlu.save(save_path, save_tflite=True, conversion_mode="hybrid_quantization")
 print("Done")
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4a571f3..31d1331 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
 
 setup(
     name="dialognlu",
-    version="0.1.0",
+    version="0.2.0",
    author="Mahmoud Wahdan",
     author_email="mahmoud.a.wahdan@gmail.com",
     description="State-of-the-art Dialog NLU (Natural Language Understanding) Library with TensorFlow 2.x and keras",
diff --git a/src/dialognlu/__init__.py b/src/dialognlu/__init__.py
index dd1ef7c..ac9d1dc 100644
--- a/src/dialognlu/__init__.py
+++ b/src/dialognlu/__init__.py
@@ -1,5 +1,5 @@
-__version__ = "0.1.0"
+__version__ = "0.2.0"
 
 
 from .nlu_components import TransformerNLU, BertNLU
 from .auto_nlu import AutoNLU
\ No newline at end of file
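Once a model has been exported with save_tflite=True (or via convert_to_tflite.py), the fixed max_length can be checked directly on the .tflite file. A small verification sketch; the path matches the example scripts above and the printed shapes are an expectation, not a guarantee:

    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="../saved_models/joint_distilbert_model/model.tflite")
    interpreter.allocate_tensors()
    for detail in interpreter.get_input_details():
        # With "max_length": 64, the input shapes should be static, e.g. [1 64].
        print(detail["name"], detail["shape"], detail["dtype"])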
diff --git a/src/dialognlu/models/base_joint_trans.py b/src/dialognlu/models/base_joint_trans.py
index 7d8d509..829c74c 100644
--- a/src/dialognlu/models/base_joint_trans.py
+++ b/src/dialognlu/models/base_joint_trans.py
@@ -22,6 +22,7 @@ def __init__(self, config, trans_model=None, is_load=False):
         self.num_bert_fine_tune_layers = config.get('num_bert_fine_tune_layers', 10)
         self.intent_loss_weight = config.get('intent_loss_weight', 1.0)
         self.slots_loss_weight = config.get('slots_loss_weight', 3.0)
+        self.max_length = config.get('max_length')
 
         self.model_params = config
 
@@ -115,4 +116,65 @@ def load_model_by_class(klazz, load_folder_path, trans_model_name):
         new_model = klazz(model_params, trans_model=None, is_load=True)
         new_model.model = tf.keras.models.load_model(os.path.join(load_folder_path, trans_model_name))
         new_model.compile_model()
-        return new_model
\ No newline at end of file
+        return new_model
+
+
+class TfliteBaseJointTransformerModel:
+
+    def __init__(self, config):
+        self.config = config
+        self.slots_num = config['slots_num']
+        self.interpreter = None
+
+    def predict_slots_intent(self, x, slots_vectorizer, intent_vectorizer, remove_start_end=True,
+                             include_intent_prob=False):
+        # x = {k:v[0] for k,v in x.items()}
+        valid_positions = x["valid_positions"]
+        x["valid_positions"] = self.prepare_valid_positions(valid_positions)
+        y_slots, y_intent = self.predict(x)
+        slots = slots_vectorizer.inverse_transform(y_slots, valid_positions)
+        if remove_start_end:
+            slots = [x[1:-1] for x in slots]
+
+        if not include_intent_prob:
+            intents = np.array([intent_vectorizer.inverse_transform([np.argmax(i)])[0] for i in y_intent])
+        else:
+            intents = np.array([(intent_vectorizer.inverse_transform([np.argmax(i)])[0], round(float(np.max(i)), 4)) for i in y_intent])
+        return slots[0], intents[0]
+
+    def prepare_valid_positions(self, in_valid_positions):
+        in_valid_positions = np.expand_dims(in_valid_positions, axis=2)
+        in_valid_positions = np.tile(in_valid_positions, (1, 1, self.slots_num))
+        return in_valid_positions
+
+    def predict(self, inputs):
+        raise NotImplementedError()
+
+    @staticmethod
+    def load_model_by_class(clazz, path):
+        with open(os.path.join(path, 'params.json'), 'r') as json_file:
+            model_params = json.load(json_file)
+
+        new_model = clazz(model_params)
+        quant_model_file = os.path.join(path, 'model.tflite')
+        new_model.interpreter = tf.lite.Interpreter(model_path=str(quant_model_file), num_threads=1)
+        new_model.interpreter.allocate_tensors()
+        return new_model
+
+
+class TfliteBaseJointTransformer4inputsModel(TfliteBaseJointTransformerModel):
+
+    def __init__(self, config):
+        super(TfliteBaseJointTransformer4inputsModel, self).__init__(config)
+
+    def predict(self, inputs):
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[0]["index"], inputs.get("input_word_ids").astype(np.int32))
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[1]["index"], inputs.get("input_mask").astype(np.int32))
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[2]["index"], inputs.get("input_type_ids").astype(np.int32))
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[3]["index"], inputs.get("valid_positions").astype(np.float32))
+        output_index_0 = self.interpreter.get_output_details()[0]["index"]
+        output_index_1 = self.interpreter.get_output_details()[1]["index"]
+        self.interpreter.invoke()
+        intent = self.interpreter.get_tensor(output_index_0)
+        slots = self.interpreter.get_tensor(output_index_1)
+        return slots, intent
\ No newline at end of file
diff --git a/src/dialognlu/models/joint_trans_albert.py b/src/dialognlu/models/joint_trans_albert.py
index a160022..655dc89 100644
--- a/src/dialognlu/models/joint_trans_albert.py
+++ b/src/dialognlu/models/joint_trans_albert.py
@@ -6,7 +6,7 @@
 import tensorflow as tf
 from tensorflow.keras.models import Model
 from tensorflow.keras.layers import Input, Dense, Multiply, TimeDistributed
-from .base_joint_trans import BaseJointTransformerModel
+from .base_joint_trans import BaseJointTransformerModel, TfliteBaseJointTransformerModel, TfliteBaseJointTransformer4inputsModel
 
 
 class JointTransAlbertModel(BaseJointTransformerModel):
@@ -16,10 +16,10 @@ def __init__(self, config, trans_model=None, is_load=False):
 
     def build_model(self):
-        in_id = Input(shape=(None,), name='input_word_ids', dtype=tf.int32)
-        in_mask = Input(shape=(None,), name='input_mask', dtype=tf.int32)
-        in_segment = Input(shape=(None,), name='input_type_ids', dtype=tf.int32)
-        in_valid_positions = Input(shape=(None, self.slots_num), name='valid_positions')
+        in_id = Input(shape=(self.max_length,), name='input_word_ids', dtype=tf.int32)
+        in_mask = Input(shape=(self.max_length,), name='input_mask', dtype=tf.int32)
+        in_segment = Input(shape=(self.max_length,), name='input_type_ids', dtype=tf.int32)
+        in_valid_positions = Input(shape=(self.max_length, self.slots_num), name='valid_positions')
 
         bert_inputs = [in_id, in_mask, in_segment]
         inputs = bert_inputs + [in_valid_positions]
 
@@ -41,3 +41,14 @@ def save(self, model_path):
     @staticmethod
     def load(load_folder_path):
         return BaseJointTransformerModel.load_model_by_class(JointTransAlbertModel, load_folder_path, 'joint_albert_model.h5')
+
+
+
+class TfliteJointTransAlbertModel(TfliteBaseJointTransformer4inputsModel):
+
+    def __init__(self, config):
+        super(TfliteJointTransAlbertModel, self).__init__(config)
+
+    @staticmethod
+    def load(path):
+        return TfliteBaseJointTransformerModel.load_model_by_class(TfliteJointTransAlbertModel, path)
\ No newline at end of file
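prepare_valid_positions above broadcasts the per-token validity mask across the slot dimension so it can be multiplied elementwise with the slot logits. A standalone numpy illustration with made-up sizes (batch 1, 4 positions, 3 slot labels):

    import numpy as np

    valid_positions = np.array([[1, 0, 1, 1]])          # (batch, max_length)
    expanded = np.expand_dims(valid_positions, axis=2)  # (1, 4, 1)
    tiled = np.tile(expanded, (1, 1, 3))                # (1, 4, 3): one copy per slot label
    print(tiled[0])
    # [[1 1 1]
    #  [0 0 0]
    #  [1 1 1]
    #  [1 1 1]]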
diff --git a/src/dialognlu/models/joint_trans_bert.py b/src/dialognlu/models/joint_trans_bert.py
index 36baeb1..845f499 100644
--- a/src/dialognlu/models/joint_trans_bert.py
+++ b/src/dialognlu/models/joint_trans_bert.py
@@ -6,7 +6,8 @@
 import tensorflow as tf
 from tensorflow.keras.models import Model
 from tensorflow.keras.layers import Input, Dense, Multiply, TimeDistributed
-from .base_joint_trans import BaseJointTransformerModel
+from .base_joint_trans import BaseJointTransformerModel, TfliteBaseJointTransformerModel, TfliteBaseJointTransformer4inputsModel
+import numpy as np
 
 
 class JointTransBertModel(BaseJointTransformerModel):
@@ -16,10 +17,10 @@ def __init__(self, config, trans_model=None, is_load=False):
 
     def build_model(self):
-        in_id = Input(shape=(None,), name='input_word_ids', dtype=tf.int32)
-        in_mask = Input(shape=(None,), name='input_mask', dtype=tf.int32)
-        in_segment = Input(shape=(None,), name='input_type_ids', dtype=tf.int32)
-        in_valid_positions = Input(shape=(None, self.slots_num), name='valid_positions')
+        in_id = Input(shape=(self.max_length,), name='input_word_ids', dtype=tf.int32)
+        in_mask = Input(shape=(self.max_length,), name='input_mask', dtype=tf.int32)
+        in_segment = Input(shape=(self.max_length,), name='input_type_ids', dtype=tf.int32)
+        in_valid_positions = Input(shape=(self.max_length, self.slots_num), name='valid_positions')
 
         bert_inputs = [in_id, in_mask, in_segment]
         inputs = bert_inputs + [in_valid_positions]
 
@@ -41,3 +42,14 @@ def save(self, model_path):
     @staticmethod
     def load(load_folder_path):
         return BaseJointTransformerModel.load_model_by_class(JointTransBertModel, load_folder_path, 'joint_bert_model.h5')
+
+
+
+class TfliteJointTransBertModel(TfliteBaseJointTransformer4inputsModel):
+
+    def __init__(self, config):
+        super(TfliteJointTransBertModel, self).__init__(config)
+
+    @staticmethod
+    def load(path):
+        return TfliteBaseJointTransformerModel.load_model_by_class(TfliteJointTransBertModel, path)
\ No newline at end of file
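With the fixed-shape inputs above, a TfliteJointTransBertModel-style predict expects numpy arrays whose second dimension equals max_length, with valid_positions already tiled to (batch, max_length, slots_num). A sketch of a dummy feed; max_length=64 and slots_num=3 are illustrative values, not defaults:

    import numpy as np

    max_length, slots_num = 64, 3  # illustrative values
    inputs = {
        "input_word_ids": np.zeros((1, max_length), dtype=np.int32),
        "input_mask": np.ones((1, max_length), dtype=np.int32),
        "input_type_ids": np.zeros((1, max_length), dtype=np.int32),
        "valid_positions": np.ones((1, max_length, slots_num), dtype=np.float32),
    }
    # model = TfliteJointTransBertModel.load("../saved_models/joint_trans_bert_model")
    # slots_logits, intent_logits = model.predict(inputs)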
diff --git a/src/dialognlu/models/joint_trans_distilbert.py b/src/dialognlu/models/joint_trans_distilbert.py
index bbc674a..8c75e24 100644
--- a/src/dialognlu/models/joint_trans_distilbert.py
+++ b/src/dialognlu/models/joint_trans_distilbert.py
@@ -3,7 +3,7 @@
 @author: mwahdan
 """
 
-from .base_joint_trans import BaseJointTransformerModel
+from .base_joint_trans import BaseJointTransformerModel, TfliteBaseJointTransformerModel
 from .callbacks import F1Metrics
 import tensorflow as tf
 from tensorflow.keras.models import Model
@@ -18,9 +18,9 @@ def __init__(self, config, trans_model=None, is_load=False):
 
     def build_model(self):
-        in_id = Input(shape=(None,), name='input_word_ids', dtype=tf.int32)
-        in_mask = Input(shape=(None,), name='input_mask', dtype=tf.int32)
-        in_valid_positions = Input(shape=(None, self.slots_num), name='valid_positions')
+        in_id = Input(shape=(self.max_length,), name='input_word_ids', dtype=tf.int32)
+        in_mask = Input(shape=(self.max_length,), name='input_mask', dtype=tf.int32)
+        in_valid_positions = Input(shape=(self.max_length, self.slots_num), name='valid_positions')
 
         bert_inputs = [in_id, in_mask]
         inputs = bert_inputs + [in_valid_positions]
 
@@ -78,4 +78,26 @@ def save(self, model_path):
 
     @staticmethod
     def load(load_folder_path):
-        return BaseJointTransformerModel.load_model_by_class(JointTransDistilBertModel, load_folder_path, 'joint_distilbert_model.h5')
\ No newline at end of file
+        return BaseJointTransformerModel.load_model_by_class(JointTransDistilBertModel, load_folder_path, 'joint_distilbert_model.h5')
+
+
+
+class TfliteJointTransDistilBertModel(TfliteBaseJointTransformerModel):
+
+    def __init__(self, config):
+        super(TfliteJointTransDistilBertModel, self).__init__(config)
+
+    def predict(self, inputs):
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[0]["index"], inputs.get("input_word_ids").astype(np.int32))
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[1]["index"], inputs.get("input_mask").astype(np.int32))
+        self.interpreter.set_tensor(self.interpreter.get_input_details()[2]["index"], inputs.get("valid_positions").astype(np.float32))
+        output_index_0 = self.interpreter.get_output_details()[0]["index"]
+        output_index_1 = self.interpreter.get_output_details()[1]["index"]
+        self.interpreter.invoke()
+        intent = self.interpreter.get_tensor(output_index_0)
+        slots = self.interpreter.get_tensor(output_index_1)
+        return slots, intent
+
+    @staticmethod
+    def load(path):
+        return TfliteBaseJointTransformerModel.load_model_by_class(TfliteJointTransDistilBertModel, path)
\ No newline at end of file
diff --git a/src/dialognlu/models/joint_trans_roberta.py b/src/dialognlu/models/joint_trans_roberta.py
index b45b729..f24c901 100644
--- a/src/dialognlu/models/joint_trans_roberta.py
+++ b/src/dialognlu/models/joint_trans_roberta.py
@@ -3,10 +3,13 @@
 @author: mwahdan
 """
 
-from .base_joint_trans import BaseJointTransformerModel
 import tensorflow as tf
+from tensorflow.keras.layers import Dense, Input, Multiply, TimeDistributed
 from tensorflow.keras.models import Model
-from tensorflow.keras.layers import Input, Dense, Multiply, TimeDistributed
+
+from .base_joint_trans import (BaseJointTransformerModel,
+                               TfliteBaseJointTransformer4inputsModel,
+                               TfliteBaseJointTransformerModel)
 
 
 class JointTransRobertaModel(BaseJointTransformerModel):
@@ -16,10 +19,10 @@ def __init__(self, config, trans_model=None, is_load=False):
 
     def build_model(self):
-        in_id = Input(shape=(None,), name='input_word_ids', dtype=tf.int32)
-        in_mask = Input(shape=(None,), name='input_mask', dtype=tf.int32)
-        in_segment = Input(shape=(None,), name='input_type_ids', dtype=tf.int32)
-        in_valid_positions = Input(shape=(None, self.slots_num), name='valid_positions')
+        in_id = Input(shape=(self.max_length,), name='input_word_ids', dtype=tf.int32)
+        in_mask = Input(shape=(self.max_length,), name='input_mask', dtype=tf.int32)
+        in_segment = Input(shape=(self.max_length,), name='input_type_ids', dtype=tf.int32)
+        in_valid_positions = Input(shape=(self.max_length, self.slots_num), name='valid_positions')
 
         bert_inputs = [in_id, in_mask, in_segment]
         inputs = bert_inputs + [in_valid_positions]
 
@@ -41,3 +44,13 @@ def save(self, model_path):
     @staticmethod
     def load(load_folder_path):
         return BaseJointTransformerModel.load_model_by_class(JointTransRobertaModel, load_folder_path, 'joint_roberta_model.h5')
+
+
+class TfliteJointTransRobertaModel(TfliteBaseJointTransformer4inputsModel):
+
+    def __init__(self, config):
+        super(TfliteJointTransRobertaModel, self).__init__(config)
+
+    @staticmethod
+    def load(path):
+        return TfliteBaseJointTransformerModel.load_model_by_class(TfliteJointTransRobertaModel, path)
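The predict implementations above index get_input_details() positionally, which assumes the converter preserves input order. A name-based lookup is a more defensive alternative; a sketch under the assumption that the converted tensor names still contain the original input names:

    import numpy as np
    import tensorflow as tf

    def input_index(interpreter, key):
        # Find an input tensor whose name contains `key`.
        for detail in interpreter.get_input_details():
            if key in detail["name"]:
                return detail["index"]
        raise KeyError(key)

    interpreter = tf.lite.Interpreter(model_path="model.tflite")  # illustrative path
    interpreter.allocate_tensors()
    word_ids = np.zeros((1, 64), dtype=np.int32)  # illustrative fixed-shape input
    interpreter.set_tensor(input_index(interpreter, "input_word_ids"), word_ids)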
diff --git a/src/dialognlu/models/model_pool.py b/src/dialognlu/models/model_pool.py
new file mode 100644
index 0000000..fd0aea9
--- /dev/null
+++ b/src/dialognlu/models/model_pool.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+"""
+@author: mwahdan
+"""
+
+from .joint_trans_bert import TfliteJointTransBertModel
+from .joint_trans_distilbert import TfliteJointTransDistilBertModel
+from .joint_trans_albert import TfliteJointTransAlbertModel
+from .joint_trans_roberta import TfliteJointTransRobertaModel
+import os
+import multiprocessing
+
+
+LOAD_TFLITE_CLASS_NAME_2_MODEL = {
+    'JointTransDistilBertModel': TfliteJointTransDistilBertModel,
+    'JointTransBertModel': TfliteJointTransBertModel,
+    'JointTransAlbertModel': TfliteJointTransAlbertModel,
+    'JointTransRobertaModel': TfliteJointTransRobertaModel
+}
+
+
+"""
+To work around the fact that a Pool initializer doesn't give us access to the
+variables it creates, a global variable can be used:
+    Each worker is a separate process, so an ordinary global variable is safe.
+Source: https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool/10118250#10118250
+
+An alternative that may seem the same but is a little cleaner than a global
+variable is a class with static methods. Because each process holds only one
+copy of the class, it is equally safe.
+"""
+class WorkerProcessor:
+
+    @staticmethod
+    def load_model(clazz_name, load_folder_path):
+        if clazz_name not in LOAD_TFLITE_CLASS_NAME_2_MODEL:
+            raise Exception('%s has no supported tflite model' % clazz_name)
+        model = LOAD_TFLITE_CLASS_NAME_2_MODEL[clazz_name].load(load_folder_path)
+        WorkerProcessor.model = model
+        print("Model Loaded, process id: %d" % os.getpid())
+
+    @staticmethod
+    def predict_slots_intent(x):
+        return WorkerProcessor.model.predict_slots_intent(x[0], x[1], x[2], x[3], x[4])
+
+
+class NluModelPool:
+
+    def __init__(self, clzz, path, num_process=2):
+        self.pool = multiprocessing.Pool(initializer=WorkerProcessor.load_model, initargs=(clzz, path,), processes=num_process)
+
+    def predict_slots_intent(self, X, slots_vectorizer, intent_vectorizer, remove_start_end=True,
+                             include_intent_prob=False):
+        parameters = []
+        for i in range(len(X["valid_positions"])):
+            parameters.append(({k: v[[i]] for k, v in X.items()}, slots_vectorizer, intent_vectorizer, remove_start_end, include_intent_prob,))
+        output = self.pool.map(WorkerProcessor.predict_slots_intent, parameters)
+        slots = []
+        intents = []
+        for i in output:
+            slots.append(i[0])
+            intents.append(i[1])
+        return slots, intents
\ No newline at end of file
diff --git a/src/dialognlu/models/trans_auto_model.py b/src/dialognlu/models/trans_auto_model.py
index ae54299..44cad2a 100644
--- a/src/dialognlu/models/trans_auto_model.py
+++ b/src/dialognlu/models/trans_auto_model.py
@@ -8,6 +8,7 @@
 from .joint_trans_albert import JointTransAlbertModel
 #from .joint_trans_xlnet import JointTransXlnetModel
 from .joint_trans_roberta import JointTransRobertaModel
+from .model_pool import NluModelPool
 from ..compression.commons import from_pretrained
 from transformers import TFAutoModel
 import json
@@ -67,11 +68,15 @@
     return joint_model
 
 
-def load_joint_trans_model(load_folder_path):
+def load_joint_trans_model(load_folder_path, quantized=False, num_process=4):
     with open(os.path.join(load_folder_path, 'params.json'), 'r') as json_file:
         model_params = json.load(json_file)
     clazz = model_params['class']
-    if clazz not in LOAD_CLASS_NAME_2_MODEL:
-        raise Exception('%s not supported')
-    model = LOAD_CLASS_NAME_2_MODEL[clazz].load(load_folder_path)
+    if quantized:
+        print("Loading quantized model in %d processes" % num_process)
+        model = NluModelPool(clazz, load_folder_path, num_process)
+    else:
+        if clazz not in LOAD_CLASS_NAME_2_MODEL:
+            raise Exception('%s not supported' % clazz)
+        model = LOAD_CLASS_NAME_2_MODEL[clazz].load(load_folder_path)
     return model
\ No newline at end of file
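The worker pattern model_pool.py describes, reduced to a runnable toy: the Pool initializer stashes per-process state on a class attribute, and mapped calls read it back inside the same process. Names here are illustrative, not part of the library:

    import multiprocessing
    import os

    class Worker:
        state = None  # one copy per worker process

        @staticmethod
        def init(payload):
            Worker.state = payload  # e.g. a loaded TFLite interpreter
            print("initialized in pid %d" % os.getpid())

        @staticmethod
        def run(x):
            return os.getpid(), Worker.state + x

    if __name__ == "__main__":
        pool = multiprocessing.Pool(processes=2, initializer=Worker.init, initargs=(100,))
        print(pool.map(Worker.run, [1, 2, 3, 4]))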
diff --git a/src/dialognlu/nlu_components.py b/src/dialognlu/nlu_components.py
index 8a502b4..eef4ad8 100644
--- a/src/dialognlu/nlu_components.py
+++ b/src/dialognlu/nlu_components.py
@@ -11,6 +11,7 @@
 from .vectorizers.tags_vectorizer import TagsVectorizer
 from .readers.dataset import NluDataset
 from .utils.data_utils import flatten, convert_to_slots
+from .utils.tf_utils import convert_to_tflite_model
 from sklearn.preprocessing import LabelEncoder
 from sklearn import metrics
 from seqeval.metrics import classification_report, f1_score
@@ -79,6 +80,11 @@ def train(self, train_dataset: NluDataset, val_dataset: NluDataset=None, epochs=
         print('Vectorizing training text ...')
         # train_input_ids, train_input_mask, train_segment_ids, train_valid_positions, train_sequence_lengths = self.text_vectorizer.transform(train_dataset.text)
         train_data = self.text_vectorizer.transform(train_dataset.text)
+        # use the configured max_length, or fall back to the value computed by the vectorizer
+        max_length = self.config.get("max_length", None)
+        if max_length is None:
+            max_length = self.text_vectorizer.max_length
+        self.config["max_length"] = max_length
         train_valid_positions = train_data["valid_positions"]
         if self.tags_vectorizer is None:
             print('Fitting tags encoder ...')
@@ -167,18 +173,22 @@ def from_config(config: dict):
 
     def init_text_vectorizer(self):
         pretrained_model_name_or_path = self.config["pretrained_model_name_or_path"]
         cache_dir = self.config["cache_dir"]
-        max_length = None # TODO: support max_length
+        max_length = self.config.get("max_length", None)  # may be None; if so, it is computed internally
         self.text_vectorizer = TransVectorizer(pretrained_model_name_or_path, max_length, cache_dir)
 
     def init_model(self):
         self.model = create_joint_trans_model(self.config)
 
     @staticmethod
-    def load(path):
+    def load(path, quantized=False, num_process=4):
         new_instance = JointNLU.load_pickles(path, TransformerNLU)
-        new_instance.model = load_joint_trans_model(path)
+        new_instance.model = load_joint_trans_model(path, quantized, num_process)
         return new_instance
 
+    def save(self, path, save_tflite=False, conversion_mode="hybrid_quantization"):
+        super(TransformerNLU, self).save(path)
+        if save_tflite:
+            convert_to_tflite_model(self.model.model, os.path.join(path, "model.tflite"), conversion_mode=conversion_mode)
 
 class BertNLU(JointNLU):
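Taken together, the new TransformerNLU surface looks like this in use; paths are illustrative and follow the example scripts above:

    from dialognlu import TransformerNLU

    model_path = "../saved_models/joint_distilbert_model"

    # Regular Keras-backed load; write a TFLite copy next to the saved model.
    nlu = TransformerNLU.load(model_path)
    nlu.save(model_path, save_tflite=True, conversion_mode="hybrid_quantization")

    # Quantized load backed by a pool of TFLite worker processes.
    nlu_tflite = TransformerNLU.load(model_path, quantized=True, num_process=2)
    print(nlu_tflite.predict("add sabrina salerno to the grime instrumentals playlist"))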
diff --git a/src/dialognlu/vectorizers/trans_vectorizer.py b/src/dialognlu/vectorizers/trans_vectorizer.py
index caffdd0..dc0d353 100644
--- a/src/dialognlu/vectorizers/trans_vectorizer.py
+++ b/src/dialognlu/vectorizers/trans_vectorizer.py
@@ -10,8 +10,8 @@
 
 class TransVectorizer:
 
-    def __init__(self, pretrained_model_name_or_path, max_length, cache_dir=None):
-        self.max_length =max_length
+    def __init__(self, pretrained_model_name_or_path, max_length=None, cache_dir=None):
+        self.max_length = max_length
         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, cache_dir=cache_dir)
         self.tokenizer_type = self.tokenizer.__class__.__name__
         self.valid_start = None
@@ -76,12 +76,15 @@
             valid_positions.append(valid_pos)
 
         sequence_lengths = np.array([len(i) for i in input_ids])
-        input_ids = tf.keras.preprocessing.sequence.pad_sequences(input_ids, padding='post')
-        input_mask = tf.keras.preprocessing.sequence.pad_sequences(input_mask, padding='post')
-        segment_ids = tf.keras.preprocessing.sequence.pad_sequences(segment_ids, padding='post')
-        valid_positions = tf.keras.preprocessing.sequence.pad_sequences(valid_positions, padding='post')
+        input_ids = tf.keras.preprocessing.sequence.pad_sequences(input_ids, padding='post', maxlen=self.max_length)
+        input_mask = tf.keras.preprocessing.sequence.pad_sequences(input_mask, padding='post', maxlen=self.max_length)
+        segment_ids = tf.keras.preprocessing.sequence.pad_sequences(segment_ids, padding='post', maxlen=self.max_length)
+        valid_positions = tf.keras.preprocessing.sequence.pad_sequences(valid_positions, padding='post', maxlen=self.max_length)
         result = {"input_word_ids": input_ids, "input_mask": input_mask, "input_type_ids": segment_ids,
                   "valid_positions": valid_positions, "sequence_lengths": sequence_lengths}
 
+        # record the computed max_length when it was not provided
+        if self.max_length is None:
+            self.max_length = input_ids.shape[1]
 
         return result
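One subtlety of passing maxlen here: pad_sequences truncates from the front by default (truncating='pre'), so an utterance longer than max_length silently loses its leading tokens unless truncating='post' is also passed. A quick demonstration:

    import tensorflow as tf

    seqs = [[1, 2, 3], [4, 5, 6, 7, 8, 9]]
    print(tf.keras.preprocessing.sequence.pad_sequences(seqs, padding='post', maxlen=5))
    # [[1 2 3 0 0]
    #  [5 6 7 8 9]]   <- the leading 4 was dropped (truncating defaults to 'pre')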