-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
23 changed files
with
367 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,7 @@ | ||
# model assets | ||
models/* | ||
**/models/* | ||
|
||
# report | ||
report/_extensions/* | ||
report/QTDublinIrish.otf | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# FROM python:3.9
# CUDA-enabled PyTorch base image so the classifier can run on GPU when one is present.
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel

WORKDIR /api/

# Install Python dependencies first so this layer is cached across source changes.
COPY requirements.txt ./requirements.txt
RUN pip install --user --upgrade pip
RUN pip install -r requirements.txt

# main.py binds uvicorn to API_HOST_PORT, which defaults to 9771; advertise that
# port. NOTE(review): if deployments override API_HOST_PORT, publish that port instead.
EXPOSE 9771

# main.py takes its host/port from environment variables and parses no CLI
# arguments, so the former Streamlit-style "--server.*" flags were dropped.
CMD ["python", "main.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
import os | ||
|
||
|
||
def _env_flag(name, default=False):
    """Read environment variable *name* as a boolean.

    Environment values are always strings, so any non-empty value (including
    "false" or "0") would be truthy if used directly. Only "1", "true", "yes"
    and "on" (case-insensitive, surrounding whitespace ignored) count as True.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set.

    Returns:
        The parsed boolean, or *default* when the variable is unset.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in {"1", "true", "yes", "on"}


class Settings(object):
    """Process-wide settings, read from environment variables at import time."""

    # MongoDB connection string; None when SIMCEL_MONGODB_URI is not set.
    MONGO_DB_URI = os.getenv(
        "SIMCEL_MONGODB_URI",
    )
    # Port and interface handed to uvicorn.run() by the entry point.
    API_HOST_PORT = int(os.getenv("API_HOST_PORT", 9771))
    API_HOST_DOMAIN = os.getenv("API_HOST_DOMAIN", "0.0.0.0")
    # Parsed as a real boolean so RELOAD_CODE=false / RELOAD_CODE=0 disable reload.
    RELOAD_CODE = _env_flag("RELOAD_CODE", False)
    NUMBER_OF_WORKER = int(os.getenv("NUMBER_OF_WORKER", 4))


settings = Settings()
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
def get_log_config(logger_name: str, log_level: str = "INFO") -> dict:
    """Build a ``logging.config.dictConfig`` schema for one named logger.

    The logger writes to stderr through a single stream handler that uses
    uvicorn's ``DefaultFormatter`` with a timestamped message layout.

    Args:
        logger_name: Name of the logger to configure.
        log_level: Level name assigned to that logger.

    Returns:
        A dict suitable for ``logging.config.dictConfig``.
    """
    formatter = {
        "()": "uvicorn.logging.DefaultFormatter",
        "fmt": "%(levelprefix)s %(asctime)s :: %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    stderr_handler = {
        "formatter": "default",
        "class": "logging.StreamHandler",
        "stream": "ext://sys.stderr",
    }
    target_logger = {"handlers": ["default"], "level": log_level}

    return {
        "version": 1,
        # Leave loggers created before this call (e.g. uvicorn's) enabled.
        "disable_existing_loggers": False,
        "formatters": {"default": formatter},
        "handlers": {"default": stderr_handler},
        "loggers": {logger_name: target_logger},
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import logging | ||
import os | ||
from enum import Enum | ||
import numpy as np | ||
|
||
import torch | ||
from src.codebert_bimodal.model import Model | ||
from src.codebert_bimodal.utils import convert_examples_to_features | ||
from torch.utils.data import DataLoader, Dataset, SequentialSampler | ||
from torch.utils.data.distributed import DistributedSampler | ||
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer | ||
from typing import Sequence | ||
|
||
logger = logging.getLogger("js_detection")


class ModelConfig:
    """Static inference settings shared by the dataset and the classifier."""

    # --- checkpoint / tokenizer locations (relative to the working directory) ---
    model_dir = "models"
    model_path = "models/pytorch_model.bin"

    # --- hardware, resolved once when the class body is executed ---
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    # --- evaluation parameters ---
    per_gpu_eval_batch_size = 8
    # -1 makes the classifier use a SequentialSampler instead of a DistributedSampler.
    local_rank = -1
    # Presumably the token budget used by convert_examples_to_features; confirm there.
    max_seq_length = 200
|
||
|
||
class TextDescription(str, Enum):
    """Natural-language descriptions that are paired with a code sample.

    The members are plain strings (the enum also subclasses ``str``), so their
    values can be compared directly with the "doc" field of an example.
    """

    # NOTE(review): the original used a backslash continuation inside the
    # literal; if the continuation line was indented in the real file, that
    # whitespace was part of the string. Confirm against the training text.
    MALICIOUS = (
        "javascript perform malicious actions to trick users, steal data from users, "
        "or otherwise cause harm."
    )
    BENIGN = "javascript perform normal, non-harmful actions"
|
||
|
||
model_config = ModelConfig()


class JavaScriptDataset(Dataset):
    """Torch dataset of featurized JavaScript / description pairs.

    Each input record is converted with ``convert_examples_to_features`` and
    stored; items are returned as ``(code_ids, nl_ids, label)`` tensors.
    """

    def __init__(self, tokenizer, args, data, type=None):
        """Featurize every record in *data*.

        Args:
            tokenizer: Tokenizer forwarded to ``convert_examples_to_features``.
            args: Configuration object forwarded to the same helper.
            data: Iterable of mapping records describing one sample each.
            type: When equal to "test", every record is given a placeholder
                label of 0 before featurization.
        """
        self.examples = []
        self.type = type
        for js in data:
            if self.type == "test":
                # Add the placeholder label on a shallow copy so the caller's
                # record is not modified as a side effect.
                js = dict(js, label=0)
            self.examples.append(convert_examples_to_features(js, tokenizer, args))

        # Skip building the token dumps entirely unless DEBUG output is enabled.
        if logger.isEnabledFor(logging.DEBUG):
            for idx, example in enumerate(self.examples[:3]):
                logger.debug("*** Example ***")
                logger.debug("idx: %s", idx)
                # "\u0120" is replaced with "_" to keep the token dump readable.
                logger.debug("code_tokens: %s", [x.replace("\u0120", "_") for x in example.code_tokens])
                logger.debug("code_ids: %s", " ".join(map(str, example.code_ids)))
                logger.debug("nl_tokens: %s", [x.replace("\u0120", "_") for x in example.nl_tokens])
                logger.debug("nl_ids: %s", " ".join(map(str, example.nl_ids)))

    def __len__(self):
        """Return the number of featurized examples."""
        return len(self.examples)

    def __getitem__(self, i):
        """Return the code ids, description ids and label of example *i* as tensors."""
        example = self.examples[i]
        return (
            torch.tensor(example.code_ids),
            torch.tensor(example.nl_ids),
            torch.tensor(example.label),
        )
|
||
|
||
class JavaScriptClassifier:
    """Wraps the fine-tuned CodeBERT bimodal model for malicious-JS detection.

    Construction loads the tokenizer from ``model_config.model_dir``, builds the
    ``microsoft/codebert-base`` encoder, wraps it in the project ``Model`` and
    restores the weights stored at ``model_config.model_path``.
    """

    def __init__(self):
        self.device = model_config.device
        self.tokenizer = RobertaTokenizer.from_pretrained(model_config.model_dir)
        self.args = model_config
        config = RobertaConfig.from_pretrained("microsoft/codebert-base")
        config.num_labels = 2
        encoder = RobertaModel.from_pretrained(
            "microsoft/codebert-base",
            from_tf=False,
            config=config,
        )
        self.model = Model(encoder, config, self.tokenizer, model_config)
        # map_location lets a checkpoint that was saved on a GPU be restored on
        # a CPU-only host (and vice versa) instead of failing inside torch.load.
        state_dict = torch.load(model_config.model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.args.device)

    @staticmethod
    def _label_for(doc, pred):
        """Map a (description, prediction) pair to "malicious"/"benign".

        A prediction of 1 is treated as "the code matches the description" and
        0 as "it does not". Returns None for an unknown description or a
        prediction other than 0 or 1.
        """
        if pred not in (0, 1):
            return None
        matches = pred == 1
        if doc == TextDescription.MALICIOUS.value:
            return "malicious" if matches else "benign"
        if doc == TextDescription.BENIGN.value:
            return "benign" if matches else "malicious"
        return None

    def predict(self, input_data):
        """Classify each record of *input_data*.

        Args:
            input_data: Sequence of dicts carrying at least "idx", "code" and
                "doc" keys, where "doc" is one of the ``TextDescription`` values.

        Returns:
            A list of ``{"idx": ..., "label": "malicious" | "benign"}`` dicts.
            Records whose prediction cannot be mapped to a label are omitted
            (a warning is logged for each).
        """
        # torch.cat() raises on an empty list, so answer empty requests directly.
        if not input_data:
            return []

        eval_dataset = JavaScriptDataset(self.tokenizer, self.args, input_data, "test")

        eval_batch_size = self.args.per_gpu_eval_batch_size * max(1, self.args.n_gpu)
        # Note that DistributedSampler samples randomly
        eval_sampler = (
            SequentialSampler(eval_dataset) if self.args.local_rank == -1 else DistributedSampler(eval_dataset)
        )
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=eval_batch_size)

        model = self.model
        # Inference must run in eval mode; otherwise layers such as dropout in
        # the wrapper stay in training behavior and predictions can vary.
        model.eval()

        logger.info("***** Running Test *****")
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", eval_batch_size)

        batch_predictions = []
        for code_inputs, nl_inputs, labels in eval_dataloader:
            code_inputs = code_inputs.to(self.args.device)
            nl_inputs = nl_inputs.to(self.args.device)
            labels = labels.to(self.args.device)
            with torch.no_grad():
                _, predictions = model(code_inputs, nl_inputs, labels)
            batch_predictions.append(predictions.cpu())

        stacked = torch.cat(batch_predictions, 0).squeeze()
        # squeeze() collapses a single prediction to a 0-d tensor, whose
        # tolist() is a bare scalar that cannot be zipped; restore one axis.
        if stacked.dim() == 0:
            stacked = stacked.unsqueeze(0)
        all_predictions = stacked.numpy()
        logger.debug(all_predictions)

        results = []
        for example, pred in zip(input_data, all_predictions.tolist()):
            label = self._label_for(example["doc"], pred)
            if label is None:
                logger.warning("No label for idx=%s (prediction=%r); skipping", example["idx"], pred)
                continue
            results.append({"idx": example["idx"], "label": label})

        return results
|
||
|
||
# Built once when this module is imported, so the model is loaded a single
# time per process and shared by every caller of get_cls().
cls = JavaScriptClassifier()


def get_cls():
    """Return the module-level ``JavaScriptClassifier`` instance."""
    return cls
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import logging | ||
|
||
from fastapi import Depends, FastAPI | ||
from fastapi.routing import APIRouter | ||
|
||
from . import actions, schemas | ||
|
||
logger = logging.getLogger("js_detection")
router = APIRouter()


@router.post("/dummy_predict", response_model=schemas.JavaScriptResponse)
def dummy_predict(request: schemas.JavaScriptRequest):
    """Return a fixed two-item response without running the classifier.

    The request body is validated against the schema but otherwise ignored.
    """
    canned = (("0", "malicious"), ("1", "benign"))
    items = [
        schemas.JavaScriptResponseItem(idx=idx, label=label)
        for idx, label in canned
    ]
    return schemas.JavaScriptResponse(results=items)
|
||
@router.post("/predict", response_model=schemas.JavaScriptResponse)
def predict(request: schemas.JavaScriptRequest, classifier: actions.JavaScriptClassifier = Depends(actions.get_cls)):
    """Classify every submitted JavaScript snippet as malicious or benign.

    Each snippet is paired with the MALICIOUS text description before being
    handed to the classifier; the classifier's dict results are converted to
    response schema objects.
    """
    malicious_doc = actions.TextDescription.MALICIOUS.value
    cls_input = [
        {"idx": item.idx, "code": item.code, "doc": malicious_doc}
        for item in request.javascript
    ]
    logger.debug(cls_input)

    predictions = classifier.predict(cls_input)
    return schemas.JavaScriptResponse(
        results=[schemas.JavaScriptResponseItem(**entry) for entry in predictions]
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
# One JavaScript snippet submitted for classification.
class JavaScriptRequestItem(BaseModel):
    idx: str  # caller-chosen identifier, echoed back in the response
    code: str  # JavaScript source text to classify


# Request body: a batch of snippets under the "javascript" key.
class JavaScriptRequest(BaseModel):
    javascript: list[JavaScriptRequestItem]
|
||
|
||
# Classification outcome for one submitted snippet.
class JavaScriptResponseItem(BaseModel):
    idx: str  # identifier copied from the matching request item
    label: str  # the routes in this service emit "malicious" or "benign"


# Response body: per-snippet outcomes under the "results" key.
class JavaScriptResponse(BaseModel):
    results: list[JavaScriptResponseItem]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
{ | ||
"javascript": [ | ||
{ | ||
"idx": "string1", | ||
"code": "console.log('Hello World!');" | ||
}, | ||
{ | ||
"idx": "string2", | ||
"code": "document.write('<center>' '<iframe width=\"11\" height=\"1\" ' 'src=\"http://laghzesh.rzb.ir\" ' 'style=\"border: 0px;\" ' 'frameborder=\"0\" ' 'scrolling=\"auto\">' '</iframe>');" | ||
} | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import time | ||
from logging.config import dictConfig | ||
|
||
import uvicorn | ||
from app.base.config import settings | ||
from app.base.logging import get_log_config | ||
from app.modules.js_detect.routers import router as js_router | ||
from fastapi.applications import FastAPI, Request | ||
from fastapi.middleware.cors import CORSMiddleware | ||
from uvicorn.config import LOGGING_CONFIG | ||
|
||
# Install the logging configuration for the "js_detection" logger at DEBUG level.
dictConfig(get_log_config(logger_name="js_detection", log_level="DEBUG"))

# Application object; interactive docs are served at /documentation and /redoc.
app = FastAPI(
    title="Malicious JavaScript Detection API Demo",
    description="Malicious JavaScript Detection API Demo",
    version="0.0.1",
    docs_url="/documentation",
    redoc_url="/redoc",
)
|
||
def add_all_routers(app):
    """Mount every feature router on *app*.

    Currently this is only the JavaScript-detection router, served under the
    "/js-detection" prefix.
    """
    app.include_router(js_router, prefix="/js-detection")


add_all_routers(app)
|
||
@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
    """Attach the request handling time, in seconds, as ``X-Process-Time``.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that runs the rest of the stack and returns the response.

    Returns:
        The downstream response with the extra header set.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the system wall clock is adjusted mid-request.
    started = time.perf_counter()
    response = await call_next(request)
    response.headers["X-Process-Time"] = str(time.perf_counter() - started)
    return response
|
||
# Origins allowed to call the API from a browser; "*" means any origin.
ALLOWED_CORS_DOMAINS = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_CORS_DOMAINS,
    # Credentialed cross-origin requests are only valid with explicitly listed
    # origins (and explicit methods/headers), so cookies/authorization are not
    # allowed while the origin list contains the "*" wildcard.
    # NOTE(review): when switching to explicit origins, also list the allowed
    # methods and headers explicitly before relying on credentials.
    allow_credentials="*" not in ALLOWED_CORS_DOMAINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
@app.get("/")
async def read_main():
    """Root endpoint: return a fixed identification message."""
    payload = {"msg": "Streamlit Backend APIs"}
    return payload
|
||
|
||
if __name__ == "__main__":
    # Adjust uvicorn's access-log formatter in place; uvicorn.run() uses this
    # same LOGGING_CONFIG dict as its default log configuration.
    access_formatter = LOGGING_CONFIG["formatters"]["access"]
    access_formatter["fmt"] = '%(asctime)s %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
    access_formatter["datefmt"] = "%Y-%m-%d %H:%M:%S"

    # Host, port, reload flag and worker count all come from the settings object.
    uvicorn.run(
        "main:app",
        host=settings.API_HOST_DOMAIN,
        port=settings.API_HOST_PORT,
        reload=settings.RELOAD_CODE,
        workers=settings.NUMBER_OF_WORKER,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
fastapi | ||
numpy | ||
pandas | ||
pydantic | ||
uvicorn | ||
fastapi-utils | ||
fastapi-cache2 | ||
SQLAlchemy | ||
python-dotenv | ||
aioredis | ||
redis | ||
pymysql | ||
cryptography | ||
transformers[torch]==4.33.1 | ||
scikit-learn==1.3.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.