Skip to content

Commit

Permalink
add new log
Browse files Browse the repository at this point in the history
  • Loading branch information
truonghm committed Sep 22, 2023
1 parent cdaf8ba commit c79699e
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 11 deletions.
20 changes: 10 additions & 10 deletions api/app/modules/js_detect/actions.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import logging
import os
from enum import Enum
import numpy as np
from typing import Sequence

import numpy as np
import torch
from src.codebert_bimodal.model import Model
from src.codebert_bimodal.utils import convert_examples_to_features
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer
from typing import Sequence

logger = logging.getLogger("js_detection")

Expand Down Expand Up @@ -42,13 +42,13 @@ def __init__(self, tokenizer, args, data, type=None):
js["label"] = 0
self.examples.append(convert_examples_to_features(js, tokenizer, args))

for idx, example in enumerate(self.examples[:3]):
logger.debug("*** Example ***")
logger.debug("idx: {}".format(idx))
logger.debug("code_tokens: {}".format([x.replace("\u0120", "_") for x in example.code_tokens]))
logger.debug("code_ids: {}".format(" ".join(map(str, example.code_ids))))
logger.debug("nl_tokens: {}".format([x.replace("\u0120", "_") for x in example.nl_tokens]))
logger.debug("nl_ids: {}".format(" ".join(map(str, example.nl_ids))))
# for idx, example in enumerate(self.examples[:3]):
# logger.debug("*** Example ***")
# logger.debug("idx: {}".format(idx))
# logger.debug("code_tokens: {}".format([x.replace("\u0120", "_") for x in example.code_tokens]))
# logger.debug("code_ids: {}".format(" ".join(map(str, example.code_ids))))
# logger.debug("nl_tokens: {}".format([x.replace("\u0120", "_") for x in example.nl_tokens]))
# logger.debug("nl_ids: {}".format(" ".join(map(str, example.nl_ids))))

def __len__(self):
return len(self.examples)
Expand All @@ -75,7 +75,7 @@ def __init__(self):
config=config,
)
self.model = Model(model, config, self.tokenizer, model_config)
self.model.load_state_dict(torch.load(model_config.model_path))
self.model.load_state_dict(torch.load(model_config.model_path, map_location=self.device))
self.model.to(self.args.device)

def predict(self, input_data):
Expand Down
File renamed without changes.
Loading

0 comments on commit c79699e

Please sign in to comment.