-
Notifications
You must be signed in to change notification settings - Fork 27
/
app.py
172 lines (145 loc) · 5.41 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
from contextlib import asynccontextmanager
from logging import getLogger
from typing import Optional, List, Tuple
from typing import Union

from fastapi import FastAPI, Depends, Response, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from config import TRUST_REMOTE_CODE, get_allowed_tokens
from vectorizer import Vectorizer, VectorInput
from meta import Meta
logger = getLogger("uvicorn")

# Set by ``lifespan`` at application startup; the request handlers read them.
vec: Vectorizer
meta_config: Meta

# auto_error=False: when the Authorization header is missing the dependency
# yields None instead of erroring, so ``is_authorized`` makes the decision.
get_bearer_token = HTTPBearer(auto_error=False)

# None disables authentication (see ``is_authorized``); otherwise only these
# bearer tokens are accepted.
allowed_tokens: Optional[List[str]] = None
def is_authorized(auth: Optional[HTTPAuthorizationCredentials]) -> bool:
    """Decide whether a request's bearer credentials may use the API.

    When the module-level ``allowed_tokens`` is None, authentication is
    disabled and every request is accepted. Otherwise a token must be supplied
    and must appear in ``allowed_tokens``.
    """
    if allowed_tokens is None:
        return True
    return auth is not None and auth.credentials in allowed_tokens
def _is_truthy_env(value: Optional[str]) -> bool:
    """Return True when an environment variable value is "true" or "1"."""
    return value in ("true", "1")


def _get_cuda_per_process_memory_fraction() -> float:
    """Read CUDA_PER_PROCESS_MEMORY_FRACTION from the environment, defaulting to 1.0.

    Values that cannot be parsed as a float, or that fall outside [0.0, 1.0]
    (NaN included), are logged as errors and replaced with 1.0 so an invalid
    setting is never handed to the vectorizer.
    """
    fraction = 1.0
    raw_fraction = os.getenv("CUDA_PER_PROCESS_MEMORY_FRACTION")
    if raw_fraction is not None:
        try:
            fraction = float(raw_fraction)
        except ValueError:
            logger.error(
                "Invalid CUDA_PER_PROCESS_MEMORY_FRACTION (should be between 0.0-1.0)"
            )
            fraction = 1.0
    # NaN also fails this chained comparison, so it is rejected as well.
    if not 0.0 <= fraction <= 1.0:
        logger.error(
            f"Invalid CUDA_PER_PROCESS_MEMORY_FRACTION {raw_fraction!r} "
            "(should be between 0.0-1.0), falling back to 1.0"
        )
        fraction = 1.0
    logger.info(f"CUDA_PER_PROCESS_MEMORY_FRACTION set to {fraction}")
    return fraction


def _get_cuda_settings() -> Tuple[bool, str]:
    """Return ``(cuda_support, cuda_core)`` derived from ENABLE_CUDA and CUDA_CORE.

    CUDA is enabled only when ENABLE_CUDA is "true" or "1"; an unset or empty
    CUDA_CORE then defaults to "cuda:0". On CPU the core is the empty string.
    """
    if not _is_truthy_env(os.getenv("ENABLE_CUDA")):
        logger.info("Running on CPU")
        return False, ""
    cuda_core = os.getenv("CUDA_CORE") or "cuda:0"
    logger.info(f"CUDA_CORE set to {cuda_core}")
    return True, cuda_core


def _read_model_dir_file(model_dir: str, file_name: str) -> Optional[str]:
    """Return the whitespace-stripped contents of ``model_dir/file_name``, or None if absent.

    Stripping tolerates a trailing newline in the marker file, which an exact
    comparison such as ``== "true"`` would otherwise reject.
    """
    path = f"{model_dir}/{file_name}"
    if not os.path.exists(path):
        return None
    with open(path, "r") as f:
        return f.read().strip()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Configure the service and load the model at startup, then yield to FastAPI.

    Sets the module-level ``allowed_tokens``, ``meta_config`` and ``vec`` that
    the request handlers read. Wrapped with ``asynccontextmanager`` because
    Starlette deprecates passing a bare async generator function as ``lifespan``.
    """
    global vec
    global meta_config
    global allowed_tokens
    allowed_tokens = get_allowed_tokens()

    cuda_per_process_memory_fraction = _get_cuda_per_process_memory_fraction()
    cuda_support, cuda_core = _get_cuda_settings()

    # Batch text tokenization enabled by default
    direct_tokenize = _is_truthy_env(os.getenv("T2V_TRANSFORMERS_DIRECT_TOKENIZE"))

    model_dir = "./models/model"

    # A model_name marker file supplies the model name and turns on
    # use_sentence_transformer_vectorizer; without it the model is loaded
    # from the default directory ./models/model.
    stored_model_name = _read_model_dir_file(model_dir, "model_name")
    use_sentence_transformer_vectorizer = stored_model_name is not None
    model_name = stored_model_name if use_sentence_transformer_vectorizer else model_dir

    onnx_runtime = _read_model_dir_file(model_dir, "onnx_runtime") == "true"

    # The marker file, when present, overrides the TRUST_REMOTE_CODE config value.
    trust_remote_code_flag = _read_model_dir_file(model_dir, "trust_remote_code")
    if trust_remote_code_flag is None:
        trust_remote_code = TRUST_REMOTE_CODE
    else:
        trust_remote_code = trust_remote_code_flag == "true"

    if onnx_runtime:
        onnx_quantization_info = _read_model_dir_file(model_dir, "onnx_quantization_info")
        if onnx_quantization_info is None:
            onnx_quantization_info = "missing"
        logger.info(
            f"Running ONNX vectorizer with quantized model for {onnx_quantization_info}"
        )

    meta_config = Meta(
        model_dir,
        model_name,
        use_sentence_transformer_vectorizer,
        trust_remote_code,
    )
    vec = Vectorizer(
        model_dir,
        cuda_support,
        cuda_core,
        cuda_per_process_memory_fraction,
        meta_config.get_model_type(),
        meta_config.get_architecture(),
        direct_tokenize,
        onnx_runtime,
        use_sentence_transformer_vectorizer,
        model_name,
        trust_remote_code,
    )
    yield
# The application object; startup work (config, model loading) runs in ``lifespan``.
app = FastAPI(
    lifespan=lifespan,
)
@app.get("/.well-known/live", response_class=Response)
@app.get("/.well-known/ready", response_class=Response)
async def live_and_ready(response: Response):
    """Liveness and readiness probe: answer with an empty 204 No Content."""
    no_content = status.HTTP_204_NO_CONTENT
    response.status_code = no_content
    return None
@app.get("/meta")
def meta(
    response: Response,
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
):
    """Return the loaded model's metadata, or a 401 error body if the token is rejected."""
    if not is_authorized(auth):
        response.status_code = status.HTTP_401_UNAUTHORIZED
        return {"error": "Unauthorized"}
    return meta_config.get()
@app.post("/vectors")
@app.post("/vectors/")
async def vectorize(
    item: VectorInput,
    response: Response,
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
):
    """Embed ``item.text`` and return the text, its vector and the vector length.

    Unauthorized requests get a 401 error body. Any failure while vectorizing
    or serializing the result is logged and reported as a 500 with the
    exception message.
    """
    if not is_authorized(auth):
        response.status_code = status.HTTP_401_UNAUTHORIZED
        return {"error": "Unauthorized"}
    try:
        embedding = await vec.vectorize(item.text, item.config)
        return {"text": item.text, "vector": embedding.tolist(), "dim": len(embedding)}
    except Exception as err:
        logger.exception("Something went wrong while vectorizing data.")
        response.status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
        return {"error": str(err)}