Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
bugale committed Apr 27, 2024
1 parent a119c3f commit 3fc5604
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/check-code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,6 @@ jobs:
- name: Install test dependencies
run: python -m pip install -r dev-requirements.txt
- name: Test
run: pytest tests --full-trace
run: |
$Env:PYTHONASYNCIODEBUG = 1
pytest --full-trace -o log_cli_level=DEBUG tests
22 changes: 21 additions & 1 deletion buganime/transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ async def __read_input_frames(self) -> AsyncIterator[bytes]:
args = ('-i', self.__input_path,
'-f', 'rawvideo', '-pix_fmt', 'rgb24', 'pipe:',
'-loglevel', 'warning')
proc = await asyncio.subprocess.create_subprocess_exec('ffmpeg', *args, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
logging.warning('c0')
proc = await asyncio.subprocess.create_subprocess_exec('ffmpeg', *args)
logging.warning('c-1')
assert proc.stdout
assert proc.stderr
try:
frame_length = self.__video_info.width * self.__video_info.height * 3
with contextlib.suppress(asyncio.IncompleteReadError):
while True:
logging.warning('c1')
yield await proc.stdout.readexactly(frame_length)
logging.warning('c2')
finally:
logging.warning('c3')
with contextlib.suppress(ProcessLookupError):
proc.terminate()
logging.info('ffmpeg input: %s', str(await proc.stderr.read()))
await proc.wait()
logging.warning('c4')

async def __write_output_frames(self, frames: AsyncIterator[bytes]) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
Expand Down Expand Up @@ -115,29 +121,43 @@ async def __write_output_frames(self, frames: AsyncIterator[bytes]) -> None:
@retry.retry(RuntimeError, tries=10, delay=1)
def __gpu_upscale(self, frame: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
logging.warning('x')
if torch.cuda.is_available():
frame = frame.cuda()
logging.warning('a')
frame_float = frame.permute(2, 0, 1).half() / 255
logging.warning('b')
frame_upscaled_float = self.__model(frame_float.unsqueeze(0)).data.squeeze().clamp_(0, 1)
logging.warning('c')
return cast(torch.Tensor, (frame_upscaled_float * 255.0).round().byte().permute(1, 2, 0).cpu())

async def __upscale_frame(self, frame: bytes) -> bytes:
logging.warning('a1')
if self.__video_info.height == self.__height_out:
return frame
logging.warning('a2')
with torch.no_grad():
with warnings.catch_warnings(action='ignore'):
frame_arr = torch.frombuffer(frame, dtype=torch.uint8).reshape([self.__video_info.height, self.__video_info.width, 3])
logging.warning('a3')
assert self.__gpu_lock
async with self.__gpu_lock:
logging.warning('a4')
frame_cpu = await asyncio.to_thread(self.__gpu_upscale, frame_arr)
logging.warning('a5')
logging.warning('a6')
return cast(bytes, await asyncio.to_thread(
lambda: cv2.resize(frame_cpu.numpy(), (self.__width_out, self.__height_out), interpolation=cv2.INTER_LANCZOS4).tobytes()))

async def __generate_upscaling_tasks(self) -> None:
assert self.__frame_tasks_queue
logging.warning('b2')
async for frame in self.__read_input_frames():
logging.warning('b3')
await self.__frame_tasks_queue.put(asyncio.create_task(self.__upscale_frame(frame)))
logging.warning('b4')
await self.__frame_tasks_queue.put(None)
logging.warning('b5')

async def __get_output_frames(self) -> AsyncIterator[bytes]:
assert self.__frame_tasks_queue
Expand Down

0 comments on commit 3fc5604

Please sign in to comment.