Skip to content

Commit

Permalink
feat: support cpu transcoding
Browse files Browse the repository at this point in the history
  • Loading branch information
bugale committed Apr 27, 2024
1 parent 37ec162 commit 72dcc36
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions buganime/transcode.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ def __init__(self, input_path: str, output_path: str, height_out: int, width_out
self.__width_out = width_out
model = Transcoder.Module(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4)
model.load_state_dict(torch.load(MODEL_PATH)['params'], strict=True)
self.__model = model.eval().cuda().half()
if torch.cuda.is_available():
self.__model = model.eval().cuda().half()
else:
self.__model = model.eval().half()
self.__gpu_lock: Optional[asyncio.Lock] = None
self.__frame_tasks_queue: Optional[asyncio.Queue[Optional[asyncio.Task[bytes]]]] = None

Expand Down Expand Up @@ -111,7 +114,9 @@ async def __write_output_frames(self, frames: AsyncIterator[bytes]) -> None:
@retry.retry(RuntimeError, tries=10, delay=1)
def __gpu_upscale(self, frame: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
frame_float = frame.cuda().permute(2, 0, 1).half() / 255
if torch.cuda.is_available():
frame = frame.cuda()
frame_float = frame.permute(2, 0, 1).half() / 255
frame_upscaled_float = self.__model(frame_float.unsqueeze(0)).data.squeeze().clamp_(0, 1)
return cast(torch.Tensor, (frame_upscaled_float * 255.0).round().byte().permute(1, 2, 0).cpu())

Expand Down

0 comments on commit 72dcc36

Please sign in to comment.