Skip to content

Commit

Permalink
fix: support more aspect ratios
Browse files Browse the repository at this point in the history
  • Loading branch information
bugale committed Apr 27, 2024
1 parent c94e731 commit 37ec162
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 22 deletions.
16 changes: 8 additions & 8 deletions buganime/buganime.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,20 @@ def process_file(input_path: str) -> None:
logging.info('ffprobe %s wrote %s, %s', str(proc.args), proc.stderr, proc.stdout)
video_info = parse_streams(json.loads(proc.stdout)['streams'])

try:
with lock_mutex(name=UPSCALE_MUTEX_NAME):
logging.info('Running Upscaler')
asyncio.run(transcode.Transcoder(input_path=input_path, output_path=output_path, height_out=2160, video_info=video_info).run())
logging.info('Upscaler for %s finished', input_path)
except Exception:
logging.exception('Failed to convert %s', input_path)
with lock_mutex(name=UPSCALE_MUTEX_NAME):
logging.info('Running Upscaler')
asyncio.run(transcode.Transcoder(input_path=input_path, output_path=output_path, height_out=2160, width_out=3840, video_info=video_info).run())
logging.info('Upscaler for %s finished', input_path)


def process_path(input_path: str) -> None:
if os.path.isdir(input_path):
for root, _, files in os.walk(input_path):
for file in files:
process_file(input_path=os.path.join(root, file))
try:
process_file(input_path=os.path.join(root, file))
except Exception:
logging.exception('Failed to convert %s', input_path)
else:
process_file(input_path=input_path)

Expand Down
19 changes: 14 additions & 5 deletions buganime/transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import tempfile
import asyncio
import logging
import warnings
from dataclasses import dataclass
from typing import AsyncIterator, cast, Optional

Expand Down Expand Up @@ -44,14 +45,14 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor:
tensor = body(tensor)
return cast(torch.Tensor, self.__upsampler(tensor) + base)

def __init__(self, input_path: str, output_path: str, height_out: int, video_info: VideoInfo) -> None:
def __init__(self, input_path: str, output_path: str, height_out: int, width_out: int, video_info: VideoInfo) -> None:
if not os.path.isfile(MODEL_PATH):
with open(MODEL_PATH, 'wb') as file:
file.write(requests.get(MODEL_URL, timeout=600).content)
self.__input_path, self.__output_path = input_path, output_path
self.__video_info = video_info
self.__height_out = height_out
self.__width_out = round(self.__video_info.width * self.__height_out / self.__video_info.height)
self.__width_out = width_out
model = Transcoder.Module(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4)
model.load_state_dict(torch.load(MODEL_PATH)['params'], strict=True)
self.__model = model.eval().cuda().half()
Expand Down Expand Up @@ -79,9 +80,16 @@ async def __read_input_frames(self) -> AsyncIterator[bytes]:
async def __write_output_frames(self, frames: AsyncIterator[bytes]) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
os.link(self.__input_path, os.path.join(temp_dir, 'input.mkv'))
args = ('-f', 'rawvideo', '-framerate', str(self.__video_info.fps), '-pix_fmt', 'rgb24', '-s', f'{self.__width_out}x{self.__height_out}',
width_out = self.__width_out
height_out = self.__height_out
if self.__video_info.width / self.__video_info.height > self.__width_out / self.__height_out:
height_out = round(self.__video_info.height * self.__width_out / self.__video_info.width)
else:
width_out = round(self.__video_info.width * self.__height_out / self.__video_info.height)
args = ('-f', 'rawvideo', '-framerate', str(self.__video_info.fps), '-pix_fmt', 'rgb24', '-s', f'{width_out}x{height_out}',
'-i', 'pipe:', '-i', 'input.mkv',
'-map', '0', '-map', f'1:{self.__video_info.audio_index}', '-vf', f'subtitles=input.mkv:si={self.__video_info.subtitle_index}',
'-map', '0', '-map', f'1:{self.__video_info.audio_index}',
'-vf', f'subtitles=input.mkv:si={self.__video_info.subtitle_index}, pad={self.__width_out}:{self.__height_out}:(ow-iw)/2:(oh-ih)/2:black',
*FFMPEG_OUTPUT_ARGS, self.__output_path,
'-loglevel', 'warning', '-y')
proc = await asyncio.subprocess.create_subprocess_exec('ffmpeg', *args, stdin=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE,
Expand Down Expand Up @@ -111,7 +119,8 @@ async def __upscale_frame(self, frame: bytes) -> bytes:
if self.__video_info.height == self.__height_out:
return frame
with torch.no_grad():
frame_arr = torch.frombuffer(frame, dtype=torch.uint8).reshape([self.__video_info.height, self.__video_info.width, 3])
with warnings.catch_warnings(action='ignore'):
frame_arr = torch.frombuffer(frame, dtype=torch.uint8).reshape([self.__video_info.height, self.__video_info.width, 3])
assert self.__gpu_lock
async with self.__gpu_lock:
frame_cpu = await asyncio.to_thread(self.__gpu_upscale, frame_arr)
Expand Down
Binary file added tests/data/1.mkv
Binary file not shown.
Binary file added tests/data/2.mkv
Binary file not shown.
67 changes: 58 additions & 9 deletions tests/test_buganime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
import os
import tempfile
import json
import subprocess
import functools
import typing

import cv2
import numpy
import pytest

from buganime import buganime, transcode

NAME_CONVERSIONS = [
Expand Down Expand Up @@ -66,6 +74,11 @@
]


@pytest.mark.parametrize('path,result', NAME_CONVERSIONS)
def test_parse_filename(path: str, result: buganime.TVShow | buganime.Movie) -> None:
assert buganime.parse_filename(path) == result


STREAM_CONVERSIONS = [
('0.json', transcode.VideoInfo(audio_index=1, subtitle_index=1, width=1920, height=1080, fps='24000/1001', frames=34094)),
('1.json', transcode.VideoInfo(audio_index=1, subtitle_index=3, width=1920, height=1080, fps='24000/1001', frames=34095)),
Expand All @@ -78,18 +91,54 @@
]


def test_parse_filename() -> None:
for path, result in NAME_CONVERSIONS:
assert buganime.parse_filename(path) == result
@pytest.mark.parametrize('filename,result', STREAM_CONVERSIONS)
def test_parse_streams(filename: str, result: transcode.VideoInfo) -> None:
with open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb') as file:
assert buganime.parse_streams(json.loads(file.read())['streams']) == result


def _check_side_bars(frame: numpy.ndarray, bar_size: int) -> None:
assert max(cv2.mean(frame[0:, :bar_size])[:3]) < 1
assert max(cv2.mean(frame[0:, -bar_size:])[:3]) < 1
assert min(cv2.mean(frame[:1, bar_size:-bar_size])[:3]) > 254
assert min(cv2.mean(frame[-1:, bar_size:-bar_size])[:3]) > 254


def test_parse_streams() -> None:
for filename, result in STREAM_CONVERSIONS:
with open(os.path.join(os.path.dirname(__file__), 'data', filename), 'rb') as file:
assert buganime.parse_streams(json.loads(file.read())['streams']) == result
def _check_top_bottom_bars(frame: numpy.ndarray, bar_size: int) -> None:
assert max(cv2.mean(frame[:bar_size])[:3]) < 1
assert max(cv2.mean(frame[-bar_size:])[:3]) < 1
assert min(cv2.mean(frame[bar_size:-bar_size, :1])[:3]) > 254
assert min(cv2.mean(frame[bar_size:-bar_size, -1:])[:3]) > 254


VIDEO_TESTS = [
('0.mkv', '24000/1001', None),

# 1900x1080 -> 3840x2160, validate black bars on left/right
('1.mkv', '24000/1001', functools.partial(_check_side_bars, bar_size=20)),

# 1940x1080 -> 3840x2160, validate black bars on top/bottom
('2.mkv', '24000/1001', functools.partial(_check_top_bottom_bars, bar_size=11)),
]


def test_sanity() -> None:
@pytest.mark.parametrize('filename,fps,check_func', VIDEO_TESTS)
def test_transcode(filename: str, fps: str, check_func: typing.Callable[[numpy.ndarray], None] | None) -> None:
with tempfile.TemporaryDirectory() as tempdir:
buganime.OUTPUT_DIR = tempdir
buganime.process_file(os.path.join(os.path.dirname(__file__), 'data', '0.mkv'))
output_path = os.path.join(tempdir, 'Movies', filename)
buganime.process_file(os.path.join(os.path.dirname(__file__), 'data', filename))
assert os.path.isfile(output_path)
stream = json.loads(subprocess.run(['ffprobe', '-show_format', '-show_streams', '-of', 'json', output_path], text=True, capture_output=True,
check=True, encoding='utf-8').stdout)['streams'][0]
assert stream['codec_name'] == 'hevc'
assert stream['width'] == 3840
assert stream['height'] == 2160
assert stream['r_frame_rate'] == fps
if check_func is not None:
video = cv2.VideoCapture(output_path)
try:
frame = video.read()[1]
check_func(frame)
finally:
video.release()

0 comments on commit 37ec162

Please sign in to comment.