Rework stuff, add Windows 10/11 OCR

This commit is contained in:
AuroraWright
2023-12-13 10:33:47 +01:00
parent 3eb8db3a48
commit 25d101f806
4 changed files with 168 additions and 79 deletions

View File

@@ -3,6 +3,7 @@ __version__ = '0.1.10'
from manga_ocr.ocr import MangaOcr from manga_ocr.ocr import MangaOcr
from manga_ocr.ocr import GoogleVision from manga_ocr.ocr import GoogleVision
from manga_ocr.ocr import AppleVision from manga_ocr.ocr import AppleVision
from manga_ocr.ocr import WinRTOCR
from manga_ocr.ocr import AzureComputerVision from manga_ocr.ocr import AzureComputerVision
from manga_ocr.ocr import EasyOCR from manga_ocr.ocr import EasyOCR
from manga_ocr.ocr import PaddleOCR from manga_ocr.ocr import PaddleOCR

View File

@@ -3,7 +3,6 @@ import os
import io import io
from pathlib import Path from pathlib import Path
import warnings import warnings
import configparser
import time import time
import sys import sys
import platform import platform
@@ -11,6 +10,7 @@ import platform
import jaconv import jaconv
import torch import torch
import numpy as np import numpy as np
import json
from PIL import Image from PIL import Image
from loguru import logger from loguru import logger
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
@@ -44,8 +44,38 @@ try:
except ImportError: except ImportError:
pass pass
try:
import requests
except ImportError:
pass
try:
import winocr
except ImportError:
pass
def post_process(text):
text = ''.join(text.split())
text = text.replace('', '...')
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
text = jaconv.h2z(text, ascii=True, digit=True)
return text
class MangaOcr: class MangaOcr:
def __init__(self, pretrained_model_name_or_path='kha-white/manga-ocr-base', force_cpu=False): name = "mangaocr"
readable_name = "Manga OCR"
key = "m"
available = True
def __init__(self, config={'pretrained_model_name_or_path':'kha-white/manga-ocr-base','force_cpu':'False'}, pretrained_model_name_or_path='', force_cpu=False):
if pretrained_model_name_or_path == '':
pretrained_model_name_or_path = config['pretrained_model_name_or_path']
if config['force_cpu'] == 'True':
force_cpu = True
logger.info(f'Loading Manga OCR model from {pretrained_model_name_or_path}') logger.info(f'Loading Manga OCR model from {pretrained_model_name_or_path}')
self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path) self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
@@ -84,10 +114,14 @@ class MangaOcr:
return pixel_values.squeeze() return pixel_values.squeeze()
class GoogleVision: class GoogleVision:
name = "gvision"
readable_name = "Google Vision"
key = "g"
available = False
def __init__(self): def __init__(self):
if 'google.cloud' not in sys.modules: if 'google.cloud' not in sys.modules:
logger.warning('google-cloud-vision not available, Google Vision will not work!') logger.warning('google-cloud-vision not available, Google Vision will not work!')
self.available = False
else: else:
logger.info(f'Parsing Google credentials') logger.info(f'Parsing Google credentials')
google_credentials_file = os.path.join(os.path.expanduser('~'),'.config','google_vision.json') google_credentials_file = os.path.join(os.path.expanduser('~'),'.config','google_vision.json')
@@ -98,12 +132,8 @@ class GoogleVision:
logger.info('Google Vision ready') logger.info('Google Vision ready')
except: except:
logger.warning('Error parsing Google credentials, Google Vision will not work!') logger.warning('Error parsing Google credentials, Google Vision will not work!')
self.available = False
def __call__(self, img_or_path): def __call__(self, img_or_path):
if not self.available:
return "Engine not available!"
if isinstance(img_or_path, str) or isinstance(img_or_path, Path): if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path) img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image): elif isinstance(img_or_path, Image.Image):
@@ -124,25 +154,24 @@ class GoogleVision:
return image_bytes.getvalue() return image_bytes.getvalue()
class AppleVision: class AppleVision:
name = "avision"
readable_name = "Apple Vision"
key = "a"
available = False
def __init__(self): def __init__(self):
if sys.platform != "darwin": if sys.platform != "darwin":
logger.warning('Apple Vision is not supported on non-macOS platforms!') logger.warning('Apple Vision is not supported on non-macOS platforms!')
self.available = False
elif int(platform.mac_ver()[0].split('.')[0]) < 13: elif int(platform.mac_ver()[0].split('.')[0]) < 13:
logger.warning('Apple Vision is not supported on macOS older than Ventura/13.0!') logger.warning('Apple Vision is not supported on macOS older than Ventura/13.0!')
self.available = False
else: else:
if 'objc' not in sys.modules: if 'objc' not in sys.modules:
logger.warning('pyobjc not available, Apple Vision will not work!') logger.warning('pyobjc not available, Apple Vision will not work!')
self.available = False
else: else:
self.available = True self.available = True
logger.info('Apple Vision ready') logger.info('Apple Vision ready')
def __call__(self, img_or_path): def __call__(self, img_or_path):
if not self.available:
return "Engine not available!"
if isinstance(img_or_path, str) or isinstance(img_or_path, Path): if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path) img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image): elif isinstance(img_or_path, Image.Image):
@@ -177,28 +206,78 @@ class AppleVision:
img.save(image_bytes, format=img.format) img.save(image_bytes, format=img.format)
return image_bytes.getvalue() return image_bytes.getvalue()
class WinRTOCR:
name = "winrtocr"
readable_name = "WinRT OCR"
key = "w"
available = False
def __init__(self, config={}):
if os.name == 'nt':
if int(platform.release()) < 10:
logger.warning('WinRT OCR is not supported on Windows older than 10!')
elif 'winocr' not in sys.modules:
logger.warning('winocr not available, WinRT OCR will not work!')
else:
self.available = True
logger.info('WinRT OCR ready')
else:
if 'requests' not in sys.modules:
logger.warning('requests not available, WinRT OCR will not work!')
else:
try:
self.url = config['url']
self.available = True
logger.info('WinRT OCR ready')
except:
logger.warning('Error reading URL from config, WinRT OCR will not work!')
def __call__(self, img_or_path):
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image):
img = img_or_path
else:
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
if os.name == 'nt':
res = winocr.recognize_pil_sync(img, lang='ja')['text']
else:
params = {'lang': 'ja'}
try:
res = requests.post(self.url, params=params, data=self._preprocess(img), timeout=3)
except requests.exceptions.Timeout:
return "Request timeout!"
res = json.loads(res.text)['text']
x = post_process(res)
return x
def _preprocess(self, img):
image_bytes = io.BytesIO()
img.save(image_bytes, format=img.format)
return image_bytes.getvalue()
class AzureComputerVision: class AzureComputerVision:
def __init__(self): name = "azure"
readable_name = "Azure Computer Vision"
key = "v"
available = False
def __init__(self, config={}):
if 'azure.cognitiveservices.vision.computervision' not in sys.modules: if 'azure.cognitiveservices.vision.computervision' not in sys.modules:
logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!') logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!')
self.available = False
else: else:
logger.info(f'Parsing Azure credentials') logger.info(f'Parsing Azure credentials')
azure_credentials_file = os.path.join(os.path.expanduser('~'),'.config','azure_computer_vision.ini')
try: try:
azure_credentials = configparser.ConfigParser() self.client = ComputerVisionClient(config['endpoint'], CognitiveServicesCredentials(config['api_key']))
azure_credentials.read(azure_credentials_file)
self.client = ComputerVisionClient(azure_credentials['config']['endpoint'], CognitiveServicesCredentials(azure_credentials['config']['api_key']))
self.available = True self.available = True
logger.info('Azure Computer Vision ready') logger.info('Azure Computer Vision ready')
except: except:
logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!') logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!')
self.available = False
def __call__(self, img_or_path): def __call__(self, img_or_path):
if not self.available:
return "Engine not available!"
if isinstance(img_or_path, str) or isinstance(img_or_path, Path): if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path) img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image): elif isinstance(img_or_path, Image.Image):
@@ -234,10 +313,14 @@ class AzureComputerVision:
return image_io return image_io
class EasyOCR: class EasyOCR:
name = "easyocr"
readable_name = "EasyOCR"
key = "e"
available = False
def __init__(self): def __init__(self):
if 'easyocr' not in sys.modules: if 'easyocr' not in sys.modules:
logger.warning('easyocr not available, EasyOCR will not work!') logger.warning('easyocr not available, EasyOCR will not work!')
self.available = False
else: else:
logger.info('Loading EasyOCR model') logger.info('Loading EasyOCR model')
self.model = easyocr.Reader(['ja','en']) self.model = easyocr.Reader(['ja','en'])
@@ -245,9 +328,6 @@ class EasyOCR:
logger.info('EasyOCR ready') logger.info('EasyOCR ready')
def __call__(self, img_or_path): def __call__(self, img_or_path):
if not self.available:
return "Engine not available!"
if isinstance(img_or_path, str) or isinstance(img_or_path, Path): if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path) img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image): elif isinstance(img_or_path, Image.Image):
@@ -269,10 +349,14 @@ class EasyOCR:
return image_bytes.getvalue() return image_bytes.getvalue()
class PaddleOCR: class PaddleOCR:
name = "paddleocr"
readable_name = "PaddleOCR"
key = "o"
available = False
def __init__(self): def __init__(self):
if 'paddleocr' not in sys.modules: if 'paddleocr' not in sys.modules:
logger.warning('easyocr not available, PaddleOCR will not work!') logger.warning('paddleocr not available, PaddleOCR will not work!')
self.available = False
else: else:
logger.info('Loading PaddleOCR model') logger.info('Loading PaddleOCR model')
self.model = POCR(use_angle_cls=True, show_log=False, lang='japan') self.model = POCR(use_angle_cls=True, show_log=False, lang='japan')
@@ -280,9 +364,6 @@ class PaddleOCR:
logger.info('PaddleOCR ready') logger.info('PaddleOCR ready')
def __call__(self, img_or_path): def __call__(self, img_or_path):
if not self.available:
return "Engine not available!"
if isinstance(img_or_path, str) or isinstance(img_or_path, Path): if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path) img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image): elif isinstance(img_or_path, Image.Image):
@@ -302,12 +383,3 @@ class PaddleOCR:
def _preprocess(self, img): def _preprocess(self, img):
return np.array(img.convert('RGB')) return np.array(img.convert('RGB'))
def post_process(text):
text = ''.join(text.split())
text = text.replace('', '...')
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
text = jaconv.h2z(text, ascii=True, digit=True)
return text

View File

@@ -2,6 +2,7 @@ import sys
import time import time
import threading import threading
import os import os
import configparser
from pathlib import Path from pathlib import Path
import fire import fire
@@ -25,12 +26,12 @@ def are_images_identical(img1, img2):
return (img1.shape == img2.shape) and (img1 == img2).all() return (img1.shape == img2.shape) and (img1 == img2).all()
def process_and_write_results(engine_instance, engine_name, img_or_path, write_to): def process_and_write_results(engine_instance, img_or_path, write_to):
t0 = time.time() t0 = time.time()
text = engine_instance(img_or_path) text = engine_instance(img_or_path)
t1 = time.time() t1 = time.time()
logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using <cyan>{engine_name}</cyan>: {text}") logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using <cyan>{engine_instance.readable_name}</cyan>: {text}")
if write_to == 'clipboard': if write_to == 'clipboard':
pyperclip.copy(text) pyperclip.copy(text)
@@ -49,10 +50,8 @@ def get_path_key(path):
def run(read_from='clipboard', def run(read_from='clipboard',
write_to='clipboard', write_to='clipboard',
pretrained_model_name_or_path='kha-white/manga-ocr-base',
force_cpu=False,
delay_secs=0.5, delay_secs=0.5,
engine='mangaocr', engine='',
start_paused=False, start_paused=False,
verbose=False verbose=False
): ):
@@ -62,10 +61,8 @@ def run(read_from='clipboard',
:param read_from: Specifies where to read input images from. Can be either "clipboard", or a path to a directory. :param read_from: Specifies where to read input images from. Can be either "clipboard", or a path to a directory.
:param write_to: Specifies where to save recognized texts to. Can be either "clipboard", or a path to a text file. :param write_to: Specifies where to save recognized texts to. Can be either "clipboard", or a path to a text file.
:param pretrained_model_name_or_path: Path to a trained model, either local or from Transformers' model hub.
:param force_cpu: If True, OCR will use CPU even if GPU is available.
:param delay_secs: How often to check for new images, in seconds. :param delay_secs: How often to check for new images, in seconds.
:param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure", "easyocr", "paddleocr". :param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure", "winrtocr", "easyocr", "paddleocr".
:param start_paused: Pause at startup. :param start_paused: Pause at startup.
:param verbose: If True, unhides all warnings. :param verbose: If True, unhides all warnings.
""" """
@@ -78,28 +75,47 @@ def run(read_from='clipboard',
} }
logger.configure(**config) logger.configure(**config)
avision = AppleVision() engine_classes = [AppleVision, WinRTOCR, GoogleVision, AzureComputerVision, MangaOcr, EasyOCR, PaddleOCR]
gvision = GoogleVision() engines = []
azure = AzureComputerVision() engine_instances = []
mangaocr = MangaOcr(pretrained_model_name_or_path, force_cpu) engine_keys = []
easyocr = EasyOCR()
paddleocr = PaddleOCR()
engines = ['avision', 'gvision', 'azure', 'mangaocr', 'easyocr', 'paddleocr'] logger.info(f'Parsing config file')
engine_names = ['Apple Vision', 'Google Vision', 'Azure Computer Vision', 'Manga OCR', 'EasyOCR', 'PaddleOCR'] config_file = os.path.join(os.path.expanduser('~'),'.config','ocr_config.ini')
engine_instances = [avision, gvision, azure, mangaocr, easyocr, paddleocr] config = configparser.ConfigParser()
engine_keys = 'agvmeo' res = config.read(config_file)
def get_engine_name(engine): if len(res) == 0:
return engine_names[engines.index(engine)] logger.warning('No config file, defaults will be used')
else:
try:
for config_engine in config['common']['engines'].split(','):
engines.append(config_engine.strip())
except KeyError:
pass
if engine not in engines: default_engine = ''
msg = 'Unknown OCR engine!' for engine_class in engine_classes:
if len(engines) == 0 or engine_class.name in engines:
try:
engine_instance = engine_class(config[engine_class.name])
except KeyError:
engine_instance = engine_class()
if engine_instance.available:
engine_instances.append(engine_instance)
engine_keys.append(engine_class.key)
if engine == engine_class.name:
default_engine = engine_class.key
if len(engine_keys) == 0:
msg = 'No engines available!'
raise NotImplementedError(msg) raise NotImplementedError(msg)
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
if sys.platform not in ('darwin', 'win32') and write_to == 'clipboard': if sys.platform not in ('darwin', 'win32') and write_to == 'clipboard':
# Check if the system is using Wayland # Check if the system is using Wayland
import os
if os.environ.get('WAYLAND_DISPLAY'): if os.environ.get('WAYLAND_DISPLAY'):
# Check if the wl-clipboard package is installed # Check if the wl-clipboard package is installed
if os.system("which wl-copy > /dev/null") == 0: if os.system("which wl-copy > /dev/null") == 0:
@@ -119,7 +135,7 @@ def run(read_from='clipboard',
tmp_paused = False tmp_paused = False
img = None img = None
logger.opt(ansi=True).info(f"Reading from clipboard using <cyan>{get_engine_name(engine)}</cyan>{' (paused)' if paused else ''}") logger.opt(ansi=True).info(f"Reading from clipboard using <cyan>{engine_instances[engine_index].readable_name}</cyan>{' (paused)' if paused else ''}")
def on_key_press(key): def on_key_press(key):
global tmp_paused global tmp_paused
@@ -142,7 +158,7 @@ def run(read_from='clipboard',
if not read_from.is_dir(): if not read_from.is_dir():
raise ValueError('read_from must be either "clipboard" or a path to a directory') raise ValueError('read_from must be either "clipboard" or a path to a directory')
logger.opt(ansi=True).info(f'Reading from directory {read_from} using <cyan>{get_engine_name(engine)}</cyan>') logger.opt(ansi=True).info(f'Reading from directory {read_from} using <cyan>{engine_instances[engine_index].readable_name}</cyan>')
old_paths = set() old_paths = set()
for path in read_from.iterdir(): for path in read_from.iterdir():
@@ -150,7 +166,6 @@ def run(read_from='clipboard',
def getchar_thread(): def getchar_thread():
global user_input global user_input
import os
if os.name == 'nt': # how it works on windows if os.name == 'nt': # how it works on windows
import msvcrt import msvcrt
while True: while True:
@@ -184,6 +199,9 @@ def run(read_from='clipboard',
user_input_thread.join() user_input_thread.join()
logger.info('Terminated!') logger.info('Terminated!')
break break
new_engine_index = engine_index
if read_from == 'clipboard' and user_input.lower() == 'p': if read_from == 'clipboard' and user_input.lower() == 'p':
if paused: if paused:
logger.info('Unpaused!') logger.info('Unpaused!')
@@ -192,17 +210,16 @@ def run(read_from='clipboard',
logger.info('Paused!') logger.info('Paused!')
paused = not paused paused = not paused
elif user_input.lower() == 's': elif user_input.lower() == 's':
if engine == engines[-1]: if engine_index == len(engine_keys) - 1:
engine = engines[0] new_engine_index = 0
else: else:
engine = engines[engines.index(engine) + 1] new_engine_index = engine_index + 1
logger.opt(ansi=True).info(f"Switched to <cyan>{get_engine_name(engine)}</cyan>!")
elif user_input.lower() in engine_keys: elif user_input.lower() in engine_keys:
new_engine = engines[engine_keys.find(user_input.lower())] new_engine_index = engine_keys.index(user_input.lower())
if engine != new_engine:
engine = new_engine if engine_index != new_engine_index:
logger.opt(ansi=True).info(f"Switched to <cyan>{get_engine_name(engine)}</cyan>!") engine_index = new_engine_index
logger.opt(ansi=True).info(f"Switched to <cyan>{engine_instances[engine_index].readable_name}</cyan>!")
user_input = '' user_input = ''
@@ -223,7 +240,7 @@ def run(read_from='clipboard',
logger.warning('Error while reading from clipboard ({})'.format(error)) logger.warning('Error while reading from clipboard ({})'.format(error))
else: else:
if not just_unpaused and isinstance(img, Image.Image) and not are_images_identical(img, old_img): if not just_unpaused and isinstance(img, Image.Image) and not are_images_identical(img, old_img):
process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to) process_and_write_results(engine_instances[engine_index], img, write_to)
if just_unpaused: if just_unpaused:
just_unpaused = False just_unpaused = False
@@ -239,7 +256,7 @@ def run(read_from='clipboard',
except (UnidentifiedImageError, OSError) as e: except (UnidentifiedImageError, OSError) as e:
logger.warning(f'Error while reading file {path}: {e}') logger.warning(f'Error while reading file {path}: {e}')
else: else:
process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to) process_and_write_results(engine_instances[engine_index], img, write_to)
time.sleep(delay_secs) time.sleep(delay_secs)

View File

@@ -10,7 +10,6 @@ transformers>=4.25.0
unidic_lite unidic_lite
google-cloud-vision google-cloud-vision
azure-cognitiveservices-vision-computervision azure-cognitiveservices-vision-computervision
pyobjc
pynput pynput
easyocr easyocr
paddleocr paddleocr