diff --git a/buganime/transcode.py b/buganime/transcode.py index 53b343c..5f0d95d 100644 --- a/buganime/transcode.py +++ b/buganime/transcode.py @@ -55,7 +55,10 @@ def __init__(self, input_path: str, output_path: str, height_out: int, width_out self.__width_out = width_out model = Transcoder.Module(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4) model.load_state_dict(torch.load(MODEL_PATH)['params'], strict=True) - self.__model = model.eval().cuda().half() + if torch.cuda.is_available(): + self.__model = model.eval().cuda().half() + else: + self.__model = model.eval().half() self.__gpu_lock: Optional[asyncio.Lock] = None self.__frame_tasks_queue: Optional[asyncio.Queue[Optional[asyncio.Task[bytes]]]] = None @@ -111,7 +114,9 @@ async def __write_output_frames(self, frames: AsyncIterator[bytes]) -> None: @retry.retry(RuntimeError, tries=10, delay=1) def __gpu_upscale(self, frame: torch.Tensor) -> torch.Tensor: with torch.no_grad(): - frame_float = frame.cuda().permute(2, 0, 1).half() / 255 + if torch.cuda.is_available(): + frame = frame.cuda() + frame_float = frame.permute(2, 0, 1).half() / 255 frame_upscaled_float = self.__model(frame_float.unsqueeze(0)).data.squeeze().clamp_(0, 1) return cast(torch.Tensor, (frame_upscaled_float * 255.0).round().byte().permute(1, 2, 0).cpu())