Skip to content

Commit

Permalink
Fixed Mask-RCNN model logging and NUM_CLASSES (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
luxedo authored and jsbroks committed Apr 30, 2019
1 parent 538ed1a commit d757cc5
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions backend/webserver/util/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from keras.preprocessing.image import img_to_array
from mrcnn.config import Config
import mrcnn.model as modellib
import logging
logger = logging.getLogger('gunicorn.error')


MODEL_DIR = "/workspace/models"
Expand All @@ -18,14 +20,14 @@ class CocoConfig(Config):
NAME = "coco"
GPU_COUNT = 1
IMAGES_PER_GPU = 1
NUM_CLASSES = 1 + 80
NUM_CLASSES = len(CLASS_NAMES)


class MaskRCNN():

def __init__(self):
self.config = CocoConfig()

self.config = CocoConfig()
self.model = modellib.MaskRCNN(
mode="inference",
model_dir=MODEL_DIR,
Expand All @@ -34,26 +36,27 @@ def __init__(self):
try:
self.model.load_weights(COCO_MODEL_PATH, by_name=True)
self.model.keras_model._make_predict_function()
logger.info(f"Loaded MaskRCNN model: {COCO_MODEL_PATH}")
except:
logger.error(f"Could not load MaskRCNN model (place 'mask_rcnn_coco.h5' in the models directory)")
self.model = None


def detect(self, image):

if self.model is None:
return {}

image = image.convert('RGB')
width, height = image.size
image.thumbnail((1024, 1024))

image = img_to_array(image)
result = self.model.detect([image])[0]

masks = result.get('masks')
class_ids = result.get('class_ids')

coco_image = im.Image(width=width, height=height)

for i in range(masks.shape[-1]):
Expand All @@ -68,4 +71,3 @@ def detect(self, image):


model = MaskRCNN()

0 comments on commit d757cc5

Please sign in to comment.