Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose flask endpoint #376

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions handlers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .restore_face_v1 import restore_face_route_v1
128 changes: 128 additions & 0 deletions handlers/restore_face_v1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import base64

import cv2
import numpy as np
import torch
from flask import request, jsonify, Blueprint
from torchvision.transforms.functional import normalize

from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils.misc import gpu_is_available, get_device
from basicsr.utils.registry import ARCH_REGISTRY
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from facelib.utils.misc import is_gray
from json_bodies.ImageRestoreFace import ImageRestoreFace
from utils.logs import log

pretrain_model_url = {
'restoration': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
}


def set_realesrgan(bg_tile):
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.realesrgan_utils import RealESRGANer

use_half = False
if torch.cuda.is_available():
no_half_gpu_list = ['1650', '1660']
if not any(gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list):
use_half = True

model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
upsampler = RealESRGANer(
scale=2,
model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
model=model,
tile=bg_tile,
tile_pad=40,
pre_pad=0,
half=use_half
)

if not gpu_is_available():
import warnings
warnings.warn(
'Running on CPU now! Make sure your PyTorch version matches your CUDA. The unoptimized RealESRGAN is slow '
'on CPU.',
category=RuntimeWarning)
return upsampler


restore_face_route_v1 = Blueprint('restore_face_v1', __name__, url_prefix='/v1/restore_face')


@restore_face_route_v1.route("/image", methods=["POST"])
def restore_face():
r = ImageRestoreFace(**request.get_json())
device = get_device()

bg_upsampler = set_realesrgan(r.bg_tile) if r.bg_upsampler == 'realesrgan' else None
face_upsampler = bg_upsampler if r.face_upsampler else None

net = ARCH_REGISTRY.get('CodeFormer')(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9,
connect_list=['32', '64', '128', '256']).to(device)
ckpt_path = load_file_from_url(url=pretrain_model_url['restoration'], model_dir='models/CodeFormer', progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()

face_helper = FaceRestoreHelper(
r.upscale,
face_size=512,
crop_ratio=(1, 1),
det_model=r.detection_model,
save_ext='png',
use_parse=True,
device=device
)

restored_images = []

for img_b64 in r.input_images:
face_helper.clean_all()
img_data = base64.b64decode(img_b64)
img_array = np.frombuffer(img_data, np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)

if r.has_aligned:
img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
face_helper.is_gray = is_gray(img, threshold=10)
face_helper.cropped_faces = [img]
else:
face_helper.read_image(img)
num_det_faces = face_helper.get_face_landmarks_5(only_center_face=r.only_center_face, resize=640,
eye_dist_threshold=5)
log.info(f"Detected {num_det_faces} faces")
face_helper.align_warp_face()

for cropped_face in face_helper.cropped_faces:
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, [0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)

with torch.no_grad():
output = net(cropped_face_t, w=r.fidelity_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face, cropped_face)

if not r.has_aligned:
bg_img = bg_upsampler.enhance(img, outscale=r.upscale)[0] if bg_upsampler else None
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=r.draw_box,
face_upsampler=face_upsampler) if face_upsampler else face_helper.paste_faces_to_input_image(
upsample_img=bg_img, draw_box=r.draw_box)
_, buffer = cv2.imencode('.png', restored_img)
img_b64 = base64.b64encode(buffer).decode('utf-8')
restored_images.append(img_b64)

return jsonify(restored_images)
23 changes: 23 additions & 0 deletions json_bodies/ImageRestoreFace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from enum import Enum
from typing import Optional, List

from pydantic import BaseModel


class DeviceEnum(str, Enum):
cuda = "cuda"
cpu = "cpu"


class ImageRestoreFace(BaseModel):
input_images: List[str]
fidelity_weight: Optional[float] = 0.5
upscale: Optional[int] = 2
has_aligned: Optional[bool] = False
only_center_face: Optional[bool] = False
draw_box: Optional[bool] = False
detection_model: Optional[str] = "retinaface_resnet50"
bg_upsampler: Optional[str] = None
face_upsampler: Optional[bool] = False
bg_tile: Optional[int] = 400
device: Optional[DeviceEnum] = DeviceEnum.cuda
Empty file added json_bodies/__init__.py
Empty file.
42 changes: 42 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse

from flask import Flask, Blueprint, jsonify

from handlers import restore_face_route_v1
from utils.logs import log

# create a parser object
parser = argparse.ArgumentParser(description="A Flask app to frontend Codeformer")

# add arguments
parser.add_argument('--port', type=int, default=5000, help='The port to run the server on')
parser.add_argument('--prefix', type=str, default='codeformer', help='The route prefix for server to use')
parser.add_argument('--host', type=str, default='127.0.0.1', help='The host')

args = parser.parse_args()

# create a flask app with port, server as command line arguments
app = Flask(__name__)
prefix_route = Blueprint(args.prefix, __name__, url_prefix=f'/{args.prefix}')
prefix_route.register_blueprint(restore_face_route_v1)

app.register_blueprint(prefix_route)


@app.errorhandler(404)
def page_not_found(error):
log.error(error)
data = {"data": {"message": "Invalid route"}, "err": {}}
return jsonify(data), 404


@app.errorhandler(405)
def method_not_allowed(error):
log.error(error)
data = {"data": {"message": "Method not allowed"}, "err": {}}
return jsonify(data), 405


# run the app
if __name__ == '__main__':
app.run(port=args.port, host=args.host)
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ torchvision
tqdm
yapf
lpips
gdown # supports downloading the large file from Google Drive
gdown # supports downloading the large file from Google Drive
flask
structlog
pydantic
Empty file added utils/__init__.py
Empty file.
26 changes: 26 additions & 0 deletions utils/logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import logging

import structlog


def init_logging():
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.StackInfoRenderer(),
structlog.dev.set_exc_info,
structlog.processors.TimeStamper(fmt="iso"),
structlog.processors.JSONRenderer(),
],
wrapper_class=structlog.make_filtering_bound_logger(logging.NOTSET),
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(),
cache_logger_on_first_use=False,
)
logger = structlog.get_logger()
logger.debug("Logging initialized")
return logger


log = init_logging()