diff --git a/manga_ocr/ocr.py b/manga_ocr/ocr.py index 4e028ef..b4cc825 100644 --- a/manga_ocr/ocr.py +++ b/manga_ocr/ocr.py @@ -18,6 +18,9 @@ class MangaOcr: if not force_cpu and torch.cuda.is_available(): logger.info('Using CUDA') self.model.cuda() + if not force_cpu and torch.backends.mps.is_available(): + logger.info('Using MPS') + self.model.to('mps') else: logger.info('Using CPU')