Add EasyOCR/PaddleOCR, remove unneeded stuff
This commit is contained in:
@@ -4,3 +4,5 @@ from manga_ocr.ocr import MangaOcr
|
||||
from manga_ocr.ocr import GoogleVision
|
||||
from manga_ocr.ocr import AppleVision
|
||||
from manga_ocr.ocr import AzureComputerVision
|
||||
from manga_ocr.ocr import EasyOCR
|
||||
from manga_ocr.ocr import PaddleOCR
|
||||
|
||||
@@ -10,6 +10,7 @@ import platform
|
||||
|
||||
import jaconv
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
|
||||
@@ -33,9 +34,19 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import easyocr
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from paddleocr import PaddleOCR as POCR
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
class MangaOcr:
|
||||
def __init__(self, pretrained_model_name_or_path='kha-white/manga-ocr-base', force_cpu=False):
|
||||
logger.info(f'Loading OCR model from {pretrained_model_name_or_path}')
|
||||
logger.info(f'Loading Manga OCR model from {pretrained_model_name_or_path}')
|
||||
self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
||||
self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path)
|
||||
@@ -222,6 +233,76 @@ class AzureComputerVision:
|
||||
image_io.seek(0)
|
||||
return image_io
|
||||
|
||||
class EasyOCR:
|
||||
def __init__(self):
|
||||
if 'easyocr' not in sys.modules:
|
||||
logger.warning('easyocr not available, EasyOCR will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
logger.info('Loading EasyOCR model')
|
||||
self.model = easyocr.Reader(['ja','en'])
|
||||
self.available = True
|
||||
logger.info('EasyOCR ready')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
img = img_or_path
|
||||
else:
|
||||
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
|
||||
|
||||
res = ''
|
||||
read_result = self.model.readtext(self._preprocess(img), detail=0)
|
||||
for text in read_result:
|
||||
res += text + ' '
|
||||
|
||||
x = post_process(res)
|
||||
return x
|
||||
|
||||
def _preprocess(self, img):
|
||||
image_bytes = io.BytesIO()
|
||||
img.save(image_bytes, format=img.format)
|
||||
return image_bytes.getvalue()
|
||||
|
||||
class PaddleOCR:
|
||||
def __init__(self):
|
||||
if 'paddleocr' not in sys.modules:
|
||||
logger.warning('easyocr not available, PaddleOCR will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
logger.info('Loading PaddleOCR model')
|
||||
self.model = POCR(use_angle_cls=True, show_log=False, lang='japan')
|
||||
self.available = True
|
||||
logger.info('PaddleOCR ready')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
img = img_or_path
|
||||
else:
|
||||
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
|
||||
|
||||
res = ''
|
||||
read_results = self.model.ocr(self._preprocess(img), cls=True)
|
||||
for read_result in read_results:
|
||||
if read_result:
|
||||
for text in read_result:
|
||||
res += text[1][0] + ' '
|
||||
|
||||
x = post_process(res)
|
||||
return x
|
||||
|
||||
def _preprocess(self, img):
|
||||
return np.array(img.convert('RGB'))
|
||||
|
||||
|
||||
def post_process(text):
|
||||
text = ''.join(text.split())
|
||||
@@ -229,4 +310,4 @@ def post_process(text):
|
||||
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
|
||||
text = jaconv.h2z(text, ascii=True, digit=True)
|
||||
|
||||
return text
|
||||
return text
|
||||
@@ -12,17 +12,7 @@ from PIL import UnidentifiedImageError
|
||||
from loguru import logger
|
||||
from pynput import keyboard
|
||||
|
||||
from manga_ocr import MangaOcr
|
||||
from manga_ocr import GoogleVision
|
||||
from manga_ocr import AppleVision
|
||||
from manga_ocr import AzureComputerVision
|
||||
|
||||
engines = ['avision', 'gvision', 'azure', 'mangaocr']
|
||||
|
||||
|
||||
def get_engine_name(engine):
|
||||
engine_names = ['Apple Vision', 'Google Vision', 'Azure Computer Vision', 'Manga OCR']
|
||||
return engine_names[engines.index(engine)]
|
||||
from manga_ocr import *
|
||||
|
||||
|
||||
def are_images_identical(img1, img2):
|
||||
@@ -35,19 +25,12 @@ def are_images_identical(img1, img2):
|
||||
return (img1.shape == img2.shape) and (img1 == img2).all()
|
||||
|
||||
|
||||
def process_and_write_results(mocr, avision, gvision, azure, img_or_path, write_to, engine):
|
||||
def process_and_write_results(engine_instance, engine_name, img_or_path, write_to):
|
||||
t0 = time.time()
|
||||
if engine == 'gvision':
|
||||
text = gvision(img_or_path)
|
||||
elif engine == 'avision':
|
||||
text = avision(img_or_path)
|
||||
elif engine == 'azure':
|
||||
text = azure(img_or_path)
|
||||
else:
|
||||
text = mocr(img_or_path)
|
||||
text = engine_instance(img_or_path)
|
||||
t1 = time.time()
|
||||
|
||||
logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using <cyan>{get_engine_name(engine)}</cyan>: {text}")
|
||||
logger.opt(ansi=True).info(f"Text recognized in {t1 - t0:0.03f}s using <cyan>{engine_name}</cyan>: {text}")
|
||||
|
||||
if write_to == 'clipboard':
|
||||
pyperclip.copy(text)
|
||||
@@ -81,7 +64,7 @@ def run(read_from='clipboard',
|
||||
:param pretrained_model_name_or_path: Path to a trained model, either local or from Transformers' model hub.
|
||||
:param force_cpu: If True, OCR will use CPU even if GPU is available.
|
||||
:param delay_secs: How often to check for new images, in seconds.
|
||||
:param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure".
|
||||
:param engine: OCR engine to use. Available: "mangaocr", "gvision", "avision", "azure", "easyocr", "paddleocr".
|
||||
:param verbose: If True, unhides all warnings.
|
||||
"""
|
||||
|
||||
@@ -93,10 +76,20 @@ def run(read_from='clipboard',
|
||||
}
|
||||
logger.configure(**config)
|
||||
|
||||
mocr = MangaOcr(pretrained_model_name_or_path, force_cpu)
|
||||
avision = AppleVision()
|
||||
gvision = GoogleVision()
|
||||
azure = AzureComputerVision()
|
||||
avision = AppleVision()
|
||||
mangaocr = MangaOcr(pretrained_model_name_or_path, force_cpu)
|
||||
easyocr = EasyOCR()
|
||||
paddleocr = PaddleOCR()
|
||||
|
||||
engines = ['avision', 'gvision', 'azure', 'mangaocr', 'easyocr', 'paddleocr']
|
||||
engine_names = ['Apple Vision', 'Google Vision', 'Azure Computer Vision', 'Manga OCR', 'EasyOCR', 'PaddleOCR']
|
||||
engine_instances = [avision, gvision, azure, mangaocr, easyocr, paddleocr]
|
||||
engine_keys = 'agvmeo'
|
||||
|
||||
def get_engine_name(engine):
|
||||
return engine_names[engines.index(engine)]
|
||||
|
||||
if engine not in engines:
|
||||
msg = 'Unknown OCR engine!'
|
||||
@@ -203,8 +196,8 @@ def run(read_from='clipboard',
|
||||
engine = engines[engines.index(engine) + 1]
|
||||
|
||||
logger.opt(ansi=True).info(f"Switched to <cyan>{get_engine_name(engine)}</cyan>!")
|
||||
elif user_input.lower() in 'agvm':
|
||||
new_engine = engines['agvm'.find(user_input.lower())]
|
||||
elif user_input.lower() in engine_keys:
|
||||
new_engine = engines[engine_keys.find(user_input.lower())]
|
||||
if engine != new_engine:
|
||||
engine = new_engine
|
||||
logger.opt(ansi=True).info(f"Switched to <cyan>{get_engine_name(engine)}</cyan>!")
|
||||
@@ -228,7 +221,7 @@ def run(read_from='clipboard',
|
||||
logger.warning('Error while reading from clipboard ({})'.format(error))
|
||||
else:
|
||||
if not just_unpaused and isinstance(img, Image.Image) and not are_images_identical(img, old_img):
|
||||
process_and_write_results(mocr, avision, gvision, azure, img, write_to, engine)
|
||||
process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to)
|
||||
|
||||
if just_unpaused:
|
||||
just_unpaused = False
|
||||
@@ -244,7 +237,7 @@ def run(read_from='clipboard',
|
||||
except (UnidentifiedImageError, OSError) as e:
|
||||
logger.warning(f'Error while reading file {path}: {e}')
|
||||
else:
|
||||
process_and_write_results(mocr, avision, gvision, azure, img, write_to, engine)
|
||||
process_and_write_results(engine_instances[engines.index(engine)], get_engine_name(engine), img, write_to)
|
||||
|
||||
time.sleep(delay_secs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user