diff --git a/manga_ocr/__init__.py b/manga_ocr/__init__.py index bf78026..2cc257a 100644 --- a/manga_ocr/__init__.py +++ b/manga_ocr/__init__.py @@ -3,6 +3,7 @@ __version__ = '0.1.10' from manga_ocr.ocr import MangaOcr from manga_ocr.ocr import GoogleVision from manga_ocr.ocr import AppleVision +from manga_ocr.ocr import WinRTOCR from manga_ocr.ocr import AzureComputerVision from manga_ocr.ocr import EasyOCR from manga_ocr.ocr import PaddleOCR diff --git a/manga_ocr/ocr.py b/manga_ocr/ocr.py index 499064b..ccc0c26 100644 --- a/manga_ocr/ocr.py +++ b/manga_ocr/ocr.py @@ -3,7 +3,6 @@ import os import io from pathlib import Path import warnings -import configparser import time import sys import platform @@ -11,6 +10,7 @@ import platform import jaconv import torch import numpy as np +import json from PIL import Image from loguru import logger from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel @@ -44,8 +44,38 @@ try: except ImportError: pass +try: + import requests +except ImportError: + pass + +try: + import winocr +except ImportError: + pass + + +def post_process(text): + text = ''.join(text.split()) + text = text.replace('…', '...') + text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text) + text = jaconv.h2z(text, ascii=True, digit=True) + + return text + + class MangaOcr: - def __init__(self, pretrained_model_name_or_path='kha-white/manga-ocr-base', force_cpu=False): + name = "mangaocr" + readable_name = "Manga OCR" + key = "m" + available = True + + def __init__(self, config={'pretrained_model_name_or_path':'kha-white/manga-ocr-base','force_cpu':'False'}, pretrained_model_name_or_path='', force_cpu=False): + if pretrained_model_name_or_path == '': + pretrained_model_name_or_path = config['pretrained_model_name_or_path'] + if config['force_cpu'] == 'True': + force_cpu = True + logger.info(f'Loading Manga OCR model from {pretrained_model_name_or_path}') self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) @@ -84,10 +114,14 @@ class MangaOcr: return pixel_values.squeeze() class GoogleVision: + name = "gvision" + readable_name = "Google Vision" + key = "g" + available = False + def __init__(self): if 'google.cloud' not in sys.modules: logger.warning('google-cloud-vision not available, Google Vision will not work!') - self.available = False else: logger.info(f'Parsing Google credentials') google_credentials_file = os.path.join(os.path.expanduser('~'),'.config','google_vision.json') @@ -98,12 +132,8 @@ class GoogleVision: logger.info('Google Vision ready') except: logger.warning('Error parsing Google credentials, Google Vision will not work!') - self.available = False def __call__(self, img_or_path): - if not self.available: - return "Engine not available!" - if isinstance(img_or_path, str) or isinstance(img_or_path, Path): img = Image.open(img_or_path) elif isinstance(img_or_path, Image.Image): @@ -124,25 +154,24 @@ class GoogleVision: return image_bytes.getvalue() class AppleVision: + name = "avision" + readable_name = "Apple Vision" + key = "a" + available = False + def __init__(self): if sys.platform != "darwin": logger.warning('Apple Vision is not supported on non-macOS platforms!') - self.available = False elif int(platform.mac_ver()[0].split('.')[0]) < 13: logger.warning('Apple Vision is not supported on macOS older than Ventura/13.0!') - self.available = False else: if 'objc' not in sys.modules: logger.warning('pyobjc not available, Apple Vision will not work!') - self.available = False else: self.available = True logger.info('Apple Vision ready') def __call__(self, img_or_path): - if not self.available: - return "Engine not available!" - if isinstance(img_or_path, str) or isinstance(img_or_path, Path): img = Image.open(img_or_path) elif isinstance(img_or_path, Image.Image): @@ -177,28 +206,78 @@ class AppleVision: img.save(image_bytes, format=img.format) return image_bytes.getvalue() +class WinRTOCR: + name = "winrtocr" + readable_name = "WinRT OCR" + key = "w" + available = False + + def __init__(self, config={}): + if os.name == 'nt': + if int(platform.release()) < 10: + logger.warning('WinRT OCR is not supported on Windows older than 10!') + elif 'winocr' not in sys.modules: + logger.warning('winocr not available, WinRT OCR will not work!') + else: + self.available = True + logger.info('WinRT OCR ready') + else: + if 'requests' not in sys.modules: + logger.warning('requests not available, WinRT OCR will not work!') + else: + try: + self.url = config['url'] + self.available = True + logger.info('WinRT OCR ready') + except: + logger.warning('Error reading URL from config, WinRT OCR will not work!') + + def __call__(self, img_or_path): + if isinstance(img_or_path, str) or isinstance(img_or_path, Path): + img = Image.open(img_or_path) + elif isinstance(img_or_path, Image.Image): + img = img_or_path + else: + raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}') + + if os.name == 'nt': + res = winocr.recognize_pil_sync(img, lang='ja')['text'] + else: + params = {'lang': 'ja'} + try: + res = requests.post(self.url, params=params, data=self._preprocess(img), timeout=3) + except requests.exceptions.Timeout: + return "Request timeout!" + + res = json.loads(res.text)['text'] + + x = post_process(res) + return x + + def _preprocess(self, img): + image_bytes = io.BytesIO() + img.save(image_bytes, format=img.format) + return image_bytes.getvalue() + class AzureComputerVision: - def __init__(self): + name = "azure" + readable_name = "Azure Computer Vision" + key = "v" + available = False + + def __init__(self, config={}): if 'azure.cognitiveservices.vision.computervision' not in sys.modules: logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!') - self.available = False else: logger.info(f'Parsing Azure credentials') - azure_credentials_file = os.path.join(os.path.expanduser('~'),'.config','azure_computer_vision.ini') try: - azure_credentials = configparser.ConfigParser() - azure_credentials.read(azure_credentials_file) - self.client = ComputerVisionClient(azure_credentials['config']['endpoint'], CognitiveServicesCredentials(azure_credentials['config']['api_key'])) + self.client = ComputerVisionClient(config['endpoint'], CognitiveServicesCredentials(config['api_key'])) self.available = True logger.info('Azure Computer Vision ready') except: logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!') - self.available = False def __call__(self, img_or_path): - if not self.available: - return "Engine not available!" - if isinstance(img_or_path, str) or isinstance(img_or_path, Path): img = Image.open(img_or_path) elif isinstance(img_or_path, Image.Image): @@ -234,10 +313,14 @@ class AzureComputerVision: return image_io class EasyOCR: + name = "easyocr" + readable_name = "EasyOCR" + key = "e" + available = False + def __init__(self): if 'easyocr' not in sys.modules: logger.warning('easyocr not available, EasyOCR will not work!') - self.available = False else: logger.info('Loading EasyOCR model') self.model = easyocr.Reader(['ja','en']) @@ -245,9 +328,6 @@ class EasyOCR: logger.info('EasyOCR ready') def __call__(self, img_or_path): - if not self.available: - return "Engine not available!" - if isinstance(img_or_path, str) or isinstance(img_or_path, Path): img = Image.open(img_or_path) elif isinstance(img_or_path, Image.Image): @@ -269,10 +349,14 @@ class EasyOCR: return image_bytes.getvalue() class PaddleOCR: + name = "paddleocr" + readable_name = "PaddleOCR" + key = "o" + available = False + def __init__(self): if 'paddleocr' not in sys.modules: - logger.warning('easyocr not available, PaddleOCR will not work!') - self.available = False + logger.warning('paddleocr not available, PaddleOCR will not work!') else: logger.info('Loading PaddleOCR model') self.model = POCR(use_angle_cls=True, show_log=False, lang='japan') @@ -280,9 +364,6 @@ class PaddleOCR: logger.info('PaddleOCR ready') def __call__(self, img_or_path): - if not self.available: - return "Engine not available!" - if isinstance(img_or_path, str) or isinstance(img_or_path, Path): img = Image.open(img_or_path) elif isinstance(img_or_path, Image.Image): @@ -302,12 +383,3 @@ class PaddleOCR: def _preprocess(self, img): return np.array(img.convert('RGB')) - - -def post_process(text): - text = ''.join(text.split()) - text = text.replace('…', '...') - text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text) - text = jaconv.h2z(text, ascii=True, digit=True) - - return text \ No newline at end of file diff --git a/manga_ocr/run.py b/manga_ocr/run.py index 7a4c467..542e28e 100644 --- a/manga_ocr/run.py +++ b/manga_ocr/run.py @@ -2,6 +2,7 @@ import sys import time import threading import os +import configparser from pathlib import Path import fire @@ -25,12 +26,12 @@ def are_images_identical(img1, img2): return (img1.shape == img2.shape) and (img1 == img2).all() -def process_and_write_results(engine_instance, engine_name, img_or_path, write_to): +def process_and_write_results(engine_instance, img_or_path, write_to): t0 = time.time() text = engine_instance(img_or_path) t1 = time.time() - logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using {engine_name}: {text}") + logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using {engine_instance.readable_name}: {text}") if write_to == 'clipboard': pyperclip.copy(text) @@ -49,10 +50,8 @@ def get_path_key(path): def run(read_from='clipboard', write_to='clipboard', - pretrained_model_name_or_path='kha-white/manga-ocr-base', - force_cpu=False, delay_secs=0.5, - engine='mangaocr', + engine='', start_paused=False, verbose=False ): @@ -62,10 +61,8 @@ def run(read_from='clipboard', :param read_from: Specifies where to read input images from. Can be either "clipboard", or a path to a directory. :param write_to: Specifies where to save recognized texts to. Can be either "clipboard", or a path to a text file. - :param pretrained_model_name_or_path: Path to a trained model, either local or from Transformers' model hub. - :param force_cpu: If True, OCR will use CPU even if GPU is available. :param delay_secs: How often to check for new images, in seconds. - :param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure", "easyocr", "paddleocr". + :param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure", "winrtocr", "easyocr", "paddleocr". :param start_paused: Pause at startup. :param verbose: If True, unhides all warnings. """ @@ -78,28 +75,47 @@ def run(read_from='clipboard', } logger.configure(**config) - avision = AppleVision() - gvision = GoogleVision() - azure = AzureComputerVision() - mangaocr = MangaOcr(pretrained_model_name_or_path, force_cpu) - easyocr = EasyOCR() - paddleocr = PaddleOCR() + engine_classes = [AppleVision, WinRTOCR, GoogleVision, AzureComputerVision, MangaOcr, EasyOCR, PaddleOCR] + engines = [] + engine_instances = [] + engine_keys = [] - engines = ['avision', 'gvision', 'azure', 'mangaocr', 'easyocr', 'paddleocr'] - engine_names = ['Apple Vision', 'Google Vision', 'Azure Computer Vision', 'Manga OCR', 'EasyOCR', 'PaddleOCR'] - engine_instances = [avision, gvision, azure, mangaocr, easyocr, paddleocr] - engine_keys = 'agvmeo' + logger.info(f'Parsing config file') + config_file = os.path.join(os.path.expanduser('~'),'.config','ocr_config.ini') + config = configparser.ConfigParser() + res = config.read(config_file) - def get_engine_name(engine): - return engine_names[engines.index(engine)] + if len(res) == 0: + logger.warning('No config file, defaults will be used') + else: + try: + for config_engine in config['common']['engines'].split(','): + engines.append(config_engine.strip()) + except KeyError: + pass - if engine not in engines: - msg = 'Unknown OCR engine!' + default_engine = '' + for engine_class in engine_classes: + if len(engines) == 0 or engine_class.name in engines: + try: + engine_instance = engine_class(config[engine_class.name]) + except KeyError: + engine_instance = engine_class() + + if engine_instance.available: + engine_instances.append(engine_instance) + engine_keys.append(engine_class.key) + if engine == engine_class.name: + default_engine = engine_class.key + + if len(engine_keys) == 0: + msg = 'No engines available!' raise NotImplementedError(msg) + engine_index = engine_keys.index(default_engine) if default_engine != '' else 0 + if sys.platform not in ('darwin', 'win32') and write_to == 'clipboard': # Check if the system is using Wayland - import os if os.environ.get('WAYLAND_DISPLAY'): # Check if the wl-clipboard package is installed if os.system("which wl-copy > /dev/null") == 0: @@ -119,7 +135,7 @@ def run(read_from='clipboard', tmp_paused = False img = None - logger.opt(ansi=True).info(f"Reading from clipboard using {get_engine_name(engine)}{' (paused)' if paused else ''}") + logger.opt(ansi=True).info(f"Reading from clipboard using {engine_instances[engine_index].readable_name}{' (paused)' if paused else ''}") def on_key_press(key): global tmp_paused @@ -142,7 +158,7 @@ def run(read_from='clipboard', if not read_from.is_dir(): raise ValueError('read_from must be either "clipboard" or a path to a directory') - logger.opt(ansi=True).info(f'Reading from directory {read_from} using {get_engine_name(engine)}') + logger.opt(ansi=True).info(f'Reading from directory {read_from} using {engine_instances[engine_index].readable_name}') old_paths = set() for path in read_from.iterdir(): @@ -150,7 +166,6 @@ def run(read_from='clipboard', def getchar_thread(): global user_input - import os if os.name == 'nt': # how it works on windows import msvcrt while True: @@ -184,6 +199,9 @@ def run(read_from='clipboard', user_input_thread.join() logger.info('Terminated!') break + + new_engine_index = engine_index + if read_from == 'clipboard' and user_input.lower() == 'p': if paused: logger.info('Unpaused!') @@ -192,17 +210,16 @@ def run(read_from='clipboard', logger.info('Paused!') paused = not paused elif user_input.lower() == 's': - if engine == engines[-1]: - engine = engines[0] + if engine_index == len(engine_keys) - 1: + new_engine_index = 0 else: - engine = engines[engines.index(engine) + 1] - - logger.opt(ansi=True).info(f"Switched to {get_engine_name(engine)}!") + new_engine_index = engine_index + 1 elif user_input.lower() in engine_keys: - new_engine = engines[engine_keys.find(user_input.lower())] - if engine != new_engine: - engine = new_engine - logger.opt(ansi=True).info(f"Switched to {get_engine_name(engine)}!") + new_engine_index = engine_keys.index(user_input.lower()) + + if engine_index != new_engine_index: + engine_index = new_engine_index + logger.opt(ansi=True).info(f"Switched to {engine_instances[engine_index].readable_name}!") user_input = '' @@ -223,7 +240,7 @@ def run(read_from='clipboard', logger.warning('Error while reading from clipboard ({})'.format(error)) else: if not just_unpaused and isinstance(img, Image.Image) and not are_images_identical(img, old_img): - process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to) + process_and_write_results(engine_instances[engine_index], img, write_to) if just_unpaused: just_unpaused = False @@ -239,7 +256,7 @@ def run(read_from='clipboard', except (UnidentifiedImageError, OSError) as e: logger.warning(f'Error while reading file {path}: {e}') else: - process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to) + process_and_write_results(engine_instances[engine_index], img, write_to) time.sleep(delay_secs) diff --git a/requirements.txt b/requirements.txt index 09a534e..7960522 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,6 @@ transformers>=4.25.0 unidic_lite google-cloud-vision azure-cognitiveservices-vision-computervision -pyobjc pynput easyocr paddleocr \ No newline at end of file