Skip to content

Commit

Permalink
add fastapi demo
Browse files Browse the repository at this point in the history
  • Loading branch information
truonghm committed Sep 22, 2023
1 parent fe89b78 commit 34ff3a9
Show file tree
Hide file tree
Showing 23 changed files with 367 additions and 16 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# model assets
models/*
**/models/*

# report
report/_extensions/*
report/QTDublinIrish.otf
Expand Down
12 changes: 12 additions & 0 deletions api/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# FROM python:3.9
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel

WORKDIR /api/

COPY requirements.txt ./requirements.txt
RUN pip install --user --upgrade pip
RUN pip install -r requirements.txt

EXPOSE 8000

CMD [ "python", "main.py", "--server.port=8501", "--server.address=0.0.0.0"]
13 changes: 13 additions & 0 deletions api/app/base/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os


class Settings(object):
MONGO_DB_URI = os.getenv(
"SIMCEL_MONGODB_URI",
)
API_HOST_PORT = int(os.getenv("API_HOST_PORT", 9771))
API_HOST_DOMAIN = os.getenv("API_HOST_DOMAIN", "0.0.0.0")
RELOAD_CODE = os.getenv("RELOAD_CODE", False)
NUMBER_OF_WORKER = int(os.getenv("NUMBER_OF_WORKER", 4))

settings = Settings()
Empty file added api/app/base/db.py
Empty file.
Empty file added api/app/base/errors.py
Empty file.
25 changes: 25 additions & 0 deletions api/app/base/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
def get_log_config(logger_name: str, log_level: str = "INFO") -> dict:
log_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"()": "uvicorn.logging.DefaultFormatter",
"fmt": "%(levelprefix)s %(asctime)s :: %(message)s",
"datefmt": "%Y-%m-%d %H:%M:%S",

},
},
"handlers": {
"default": {
"formatter": "default",
"class": "logging.StreamHandler",
"stream": "ext://sys.stderr",
},
},
"loggers": {
logger_name: {"handlers": ["default"], "level": log_level},
},
}

return log_config
135 changes: 135 additions & 0 deletions api/app/modules/js_detect/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import logging
import os
from enum import Enum
import numpy as np

import torch
from src.codebert_bimodal.model import Model
from src.codebert_bimodal.utils import convert_examples_to_features
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from typing import Sequence

logger = logging.getLogger("js_detection")


class ModelConfig:
n_gpu = torch.cuda.device_count()
per_gpu_eval_batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_dir = "models"
model_path = "models/pytorch_model.bin"
local_rank = -1
max_seq_length = 200


class TextDescription(str, Enum):
MALICIOUS = "javascript perform malicious actions to trick users, steal data from users, \
or otherwise cause harm."
BENIGN = "javascript perform normal, non-harmful actions"


model_config = ModelConfig()


class JavaScriptDataset(Dataset):
def __init__(self, tokenizer, args, data, type=None):
self.examples = []
self.type = type
for js in data:
if self.type == "test":
js["label"] = 0
self.examples.append(convert_examples_to_features(js, tokenizer, args))

for idx, example in enumerate(self.examples[:3]):
logger.debug("*** Example ***")
logger.debug("idx: {}".format(idx))
logger.debug("code_tokens: {}".format([x.replace("\u0120", "_") for x in example.code_tokens]))
logger.debug("code_ids: {}".format(" ".join(map(str, example.code_ids))))
logger.debug("nl_tokens: {}".format([x.replace("\u0120", "_") for x in example.nl_tokens]))
logger.debug("nl_ids: {}".format(" ".join(map(str, example.nl_ids))))

def __len__(self):
return len(self.examples)

def __getitem__(self, i):
"""return both tokenized code ids and nl ids and label"""
return (
torch.tensor(self.examples[i].code_ids),
torch.tensor(self.examples[i].nl_ids),
torch.tensor(self.examples[i].label),
)


class JavaScriptClassifier:
def __init__(self):
self.device = model_config.device
self.tokenizer = RobertaTokenizer.from_pretrained(model_config.model_dir)
self.args = model_config
config = RobertaConfig.from_pretrained("microsoft/codebert-base")
config.num_labels = 2
model = RobertaModel.from_pretrained(
"microsoft/codebert-base",
from_tf=False,
config=config,
)
self.model = Model(model, config, self.tokenizer, model_config)
self.model.load_state_dict(torch.load(model_config.model_path))
self.model.to(self.args.device)

def predict(self, input_data):
eval_dataset = JavaScriptDataset(self.tokenizer, self.args, input_data, "test")

eval_batch_size = self.args.per_gpu_eval_batch_size * max(1, self.args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = (
SequentialSampler(eval_dataset) if self.args.local_rank == -1 else DistributedSampler(eval_dataset)
)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)

# multi-gpu evaluate
# if self.args.n_gpu > 1:
# model = torch.nn.DataParallel(self.model)
model = self.model
# model.eval()
# Eval!
logger.info("***** Running Test *****")
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", eval_batch_size)

nb_eval_steps = 0
all_predictions = []
for batch in eval_dataloader:
code_inputs = batch[0].to(self.args.device)
nl_inputs = batch[1].to(self.args.device)
labels = batch[2].to(self.args.device)
with torch.no_grad():
_, predictions = model(code_inputs, nl_inputs, labels)
all_predictions.append(predictions.cpu())
nb_eval_steps += 1
all_predictions = torch.cat(all_predictions, 0).squeeze().numpy()

# if isinstance(all_predictions, np.array):
# all_predictions = np.array(all_predictions)
logger.debug(all_predictions)
results = []

for example, pred in zip(input_data, all_predictions.tolist()):
if example["doc"] == TextDescription.MALICIOUS.value and pred == 1:
results.append({"idx": example["idx"], "label": "malicious"})
elif example["doc"] == TextDescription.MALICIOUS.value and pred == 0:
results.append({"idx": example["idx"], "label": "benign"})
elif example["doc"] == TextDescription.BENIGN.value and pred == 1:
results.append({"idx": example["idx"], "label": "benign"})
elif example["doc"] == TextDescription.BENIGN.value and pred == 0:
results.append({"idx": example["idx"], "label": "malicious"})

return results


cls = JavaScriptClassifier()


def get_cls():
return cls
40 changes: 40 additions & 0 deletions api/app/modules/js_detect/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import logging

from fastapi import Depends, FastAPI
from fastapi.routing import APIRouter

from . import actions, schemas

logger = logging.getLogger("js_detection")
router = APIRouter()


@router.post("/dummy_predict", response_model=schemas.JavaScriptResponse)
def dummy_predict(request: schemas.JavaScriptRequest):
examples = [
schemas.JavaScriptResponseItem(**{"idx": "0", "label": "malicious"}),
schemas.JavaScriptResponseItem(**{"idx": "1", "label": "benign"}),
]
return schemas.JavaScriptResponse(
**{"results": examples}
)

@router.post("/predict", response_model=schemas.JavaScriptResponse)
def predict(request: schemas.JavaScriptRequest, classifier: actions.JavaScriptClassifier = Depends(actions.get_cls)):
code_list = request.javascript

cls_input = []
for item in code_list:
cls_input.append(
{
"idx": item.idx,
"code": item.code,
"doc": actions.TextDescription.MALICIOUS.value
}
)
logger.debug(cls_input)
results = classifier.predict(cls_input)
results_by_schema = [schemas.JavaScriptResponseItem(**item) for item in results]
return schemas.JavaScriptResponse(
**{"results": results_by_schema}
)
17 changes: 17 additions & 0 deletions api/app/modules/js_detect/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from pydantic import BaseModel


class JavaScriptRequestItem(BaseModel):
idx: str
code: str

class JavaScriptRequest(BaseModel):
javascript: list[JavaScriptRequestItem]


class JavaScriptResponseItem(BaseModel):
idx: str
label: str

class JavaScriptResponse(BaseModel):
results: list[JavaScriptResponseItem]
12 changes: 12 additions & 0 deletions api/example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"javascript": [
{
"idx": "string1",
"code": "console.log('Hello World!');"
},
{
"idx": "string2",
"code": "document.write('<center>' '<iframe width=\"11\" height=\"1\" ' 'src=\"http://laghzesh.rzb.ir\" ' 'style=\"border: 0px;\" ' 'frameborder=\"0\" ' 'scrolling=\"auto\">' '</iframe>');"
}
]
}
69 changes: 69 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import time
from logging.config import dictConfig

import uvicorn
from app.base.config import settings
from app.base.logging import get_log_config
from app.modules.js_detect.routers import router as js_router
from fastapi.applications import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from uvicorn.config import LOGGING_CONFIG

dictConfig(
get_log_config(
logger_name="js_detection",
log_level="DEBUG",
)
)

app = FastAPI(
title="Malicious JavaScript Detection API Demo",
description="Malicious JavaScript Detection API Demo",
version="0.0.1",
docs_url="/documentation",
redoc_url="/redoc",
)

def add_all_routers(app):
app.include_router(
js_router, prefix="/js-detection"
)

add_all_routers(app)

@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
response.headers["X-Process-Time"] = str(process_time)
return response

ALLOWED_CORS_DOMAINS = ["*"]

app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_CORS_DOMAINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"]
)

@app.get("/")
async def read_main():
return {"msg": "Streamlit Backend APIs"}


if __name__ == "__main__":
LOGGING_CONFIG["formatters"]["access"][
"fmt"
] = '%(asctime)s %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"

uvicorn.run(
"main:app",
host=settings.API_HOST_DOMAIN,
port=settings.API_HOST_PORT,
reload=settings.RELOAD_CODE,
workers=settings.NUMBER_OF_WORKER,
)
15 changes: 15 additions & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
fastapi
numpy
pandas
pydantic
uvicorn
fastapi-utils
fastapi-cache2
SQLAlchemy
python-dotenv
aioredis
redis
pymysql
cryptography
transformers[torch]==4.33.1
scikit-learn==1.3.0
33 changes: 21 additions & 12 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,26 @@ version: '3.9'

services:

mongodb:
image: mongo:4.2.1
container_name: js-code
api:
build: ./api
restart: unless-stopped
container_name: api
env_file:
- ./api/.env
ports:
- 27017:27017
volumes: [ mongodbM1Data:/data/db, mongodbM1Config:/data/configdb ]
environment:
MONGO_INITDB_ROOT_USERNAME: jscode
MONGO_INITDB_ROOT_PASSWORD: jscode

volumes:
mongodbM1Data: null
mongodbM1Config: null
- 9771:9771
healthcheck:
test: ["CMD", "curl", "-f", "http://0.0.0.0:9771"]
timeout: 30s
retries: 10
volumes:
- ./api:/api
- ./src:/api/src
- ./models/codebert-bimodal/checkpoint-best-aver:/api/models
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
Loading

0 comments on commit 34ff3a9

Please sign in to comment.