import json
import logging
import os
import zipfile
from abc import ABC
import torch
import transformers
from transformers import BloomForCausalLM, BloomTokenizerFast
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)
logger.info("Transformers version %s", transformers.__version__)

TORCH_DTYPES = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
}
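
# Illustrative setup_config.json consumed by initialize() below. The keys
# match the lookups in this handler; the values are hypothetical and should
# be adjusted to the deployment:
# {
#     "revision": "main",
#     "max_memory": {"0": "18GB", "cpu": "30GB"},
#     "low_cpu_mem_usage": true,
#     "device_map": "auto",
#     "offload_folder": "offload",
#     "offload_state_dict": true,
#     "torch_dtype": "float16",
#     "max_length": 80
# }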


class TransformersSeqClassifierHandler(BaseHandler, ABC):
    """
    Transformers handler class for text generation with the BLOOM causal
    language model.
    """

    def __init__(self):
        super(TransformersSeqClassifierHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the BLOOM model and tokenizer from the extracted checkpoint,
        using the settings read from setup_config.json.
        Args:
            ctx (context): Context object containing information about the
                model artifacts and system properties.
        """
self.manifest = ctx.manifest
properties = ctx.system_properties
model_dir = properties.get("model_dir")
self.device = torch.device(
"cuda:" + str(properties.get("gpu_id"))
if torch.cuda.is_available() and properties.get("gpu_id") is not None
else "cpu"
)
        # Load the model and tokenizer from the checkpoint and config files
        # extracted from model.zip; further setup config can be added.
with zipfile.ZipFile(model_dir + "/model.zip", "r") as zip_ref:
zip_ref.extractall(model_dir + "/model")
        # read settings such as revision, max_memory, and torch_dtype from setup_config.json
setup_config_path = os.path.join(model_dir, "setup_config.json")
if os.path.isfile(setup_config_path):
with open(setup_config_path) as setup_config_file:
self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")
            raise RuntimeError("setup_config.json is required to initialize the model")
self.model = BloomForCausalLM.from_pretrained(
model_dir + "/model",
revision=self.setup_config["revision"],
max_memory={
int(key) if key.isnumeric() else key: value
for key, value in self.setup_config["max_memory"].items()
},
low_cpu_mem_usage=self.setup_config["low_cpu_mem_usage"],
device_map=self.setup_config["device_map"],
offload_folder=self.setup_config["offload_folder"],
offload_state_dict=self.setup_config["offload_state_dict"],
torch_dtype=TORCH_DTYPES[self.setup_config["torch_dtype"]],
)
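        # The dict comprehension above converts numeric JSON keys to ints so
        # accelerate can treat them as GPU indices, while string keys such as
        # "cpu" pass through unchanged, e.g. (illustrative values):
        # {"0": "18GB", "cpu": "30GB"} -> {0: "18GB", "cpu": "30GB"}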
        self.tokenizer = BloomTokenizerFast.from_pretrained(model_dir + "/model")
self.model.eval()
logger.info("Transformer model from path %s loaded successfully", model_dir)
self.initialized = True

    def preprocess(self, requests):
        """Basic text preprocessing, based on the user's choice of application mode.
        Args:
            requests (list): A list of dictionaries with the input text under
                the "data" or "body" key.
        Returns:
            tuple: A tuple of the batched input id and attention mask tensors.
        """
input_ids_batch = None
attention_mask_batch = None
for idx, data in enumerate(requests):
input_text = data.get("data")
if input_text is None:
input_text = data.get("body")
if isinstance(input_text, (bytes, bytearray)):
input_text = input_text.decode("utf-8")
max_length = self.setup_config["max_length"]
logger.info("Received text: '%s'", input_text)
            # pad/truncate every request to the configured max_length so the
            # per-request tensors can be concatenated into one batch
            inputs = self.tokenizer(
                input_text,
                max_length=int(max_length),
                padding="max_length",
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
input_ids = inputs["input_ids"].to(self.device)
attention_mask = inputs["attention_mask"].to(self.device)
            # making a batch out of the received requests
            # attention masks are passed for cases where input tokens are padded.
            if input_ids_batch is None:
                input_ids_batch = input_ids
                attention_mask_batch = attention_mask
            else:
                input_ids_batch = torch.cat((input_ids_batch, input_ids), 0)
                attention_mask_batch = torch.cat(
                    (attention_mask_batch, attention_mask), 0
                )
return (input_ids_batch, attention_mask_batch)

    def inference(self, input_batch):
        """Generate text from the batched input ids using the
        serialized transformers checkpoint.
        Args:
            input_batch (tuple): A tuple of the batched input id and attention
                mask tensors produced by the preprocess function.
        Returns:
            list: A list of the generated text, one entry per input.
        """
        (input_ids_batch, attention_mask_batch) = input_batch
        inferences = []
        input_ids_batch = input_ids_batch.to(self.device)
        attention_mask_batch = attention_mask_batch.to(self.device)
        outputs = self.model.generate(
            input_ids_batch,
            attention_mask=attention_mask_batch,
            do_sample=True,
            max_new_tokens=int(self.setup_config["max_length"]),
            top_p=0.95,
            top_k=60,
        )
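        # do_sample enables stochastic decoding: top_k=60 keeps only the 60
        # most likely next tokens, and top_p=0.95 then restricts sampling to
        # the smallest set of tokens whose cumulative probability reaches
        # 0.95 (nucleus sampling). max_new_tokens reuses the "max_length"
        # entry from setup_config.json, so generation length and input
        # padding length are controlled by the same setting.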
        for output in outputs:
            inferences.append(
                self.tokenizer.decode(output, skip_special_tokens=True)
            )
        logger.info("Generated text: '%s'", inferences)
return inferences

    def postprocess(self, inference_output):
        """Post-process function that converts the generated text into a
        TorchServe-readable format.
        Args:
            inference_output (list): The generated text for each input request.
        Returns:
            list: The list of generated predictions.
        """
return inference_output
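
# A minimal sketch of packaging and serving this handler, assuming the
# checkpoint was zipped into model.zip next to a setup_config.json like the
# example above (the model name "bloom" and file names are hypothetical):
#
#   torch-model-archiver --model-name bloom --version 1.0 \
#       --handler custom_handler.py \
#       --extra-files "model.zip,setup_config.json" \
#       --export-path model_store
#
#   torchserve --start --model-store model_store --models bloom=bloom.mar
#
#   curl -X POST http://localhost:8080/predictions/bloom -T sample_text.txt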