Rework stuff, add Windows 10/11 OCR
This commit is contained in:
154
manga_ocr/ocr.py
154
manga_ocr/ocr.py
@@ -3,7 +3,6 @@ import os
|
||||
import io
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
import configparser
|
||||
import time
|
||||
import sys
|
||||
import platform
|
||||
@@ -11,6 +10,7 @@ import platform
|
||||
import jaconv
|
||||
import torch
|
||||
import numpy as np
|
||||
import json
|
||||
from PIL import Image
|
||||
from loguru import logger
|
||||
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
|
||||
@@ -44,8 +44,38 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import requests
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import winocr
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def post_process(text):
|
||||
text = ''.join(text.split())
|
||||
text = text.replace('…', '...')
|
||||
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
|
||||
text = jaconv.h2z(text, ascii=True, digit=True)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
class MangaOcr:
|
||||
def __init__(self, pretrained_model_name_or_path='kha-white/manga-ocr-base', force_cpu=False):
|
||||
name = "mangaocr"
|
||||
readable_name = "Manga OCR"
|
||||
key = "m"
|
||||
available = True
|
||||
|
||||
def __init__(self, config={'pretrained_model_name_or_path':'kha-white/manga-ocr-base','force_cpu':'False'}, pretrained_model_name_or_path='', force_cpu=False):
|
||||
if pretrained_model_name_or_path == '':
|
||||
pretrained_model_name_or_path = config['pretrained_model_name_or_path']
|
||||
if config['force_cpu'] == 'True':
|
||||
force_cpu = True
|
||||
|
||||
logger.info(f'Loading Manga OCR model from {pretrained_model_name_or_path}')
|
||||
self.processor = ViTImageProcessor.from_pretrained(pretrained_model_name_or_path)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
|
||||
@@ -84,10 +114,14 @@ class MangaOcr:
|
||||
return pixel_values.squeeze()
|
||||
|
||||
class GoogleVision:
|
||||
name = "gvision"
|
||||
readable_name = "Google Vision"
|
||||
key = "g"
|
||||
available = False
|
||||
|
||||
def __init__(self):
|
||||
if 'google.cloud' not in sys.modules:
|
||||
logger.warning('google-cloud-vision not available, Google Vision will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
logger.info(f'Parsing Google credentials')
|
||||
google_credentials_file = os.path.join(os.path.expanduser('~'),'.config','google_vision.json')
|
||||
@@ -98,12 +132,8 @@ class GoogleVision:
|
||||
logger.info('Google Vision ready')
|
||||
except:
|
||||
logger.warning('Error parsing Google credentials, Google Vision will not work!')
|
||||
self.available = False
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
@@ -124,25 +154,24 @@ class GoogleVision:
|
||||
return image_bytes.getvalue()
|
||||
|
||||
class AppleVision:
|
||||
name = "avision"
|
||||
readable_name = "Apple Vision"
|
||||
key = "a"
|
||||
available = False
|
||||
|
||||
def __init__(self):
|
||||
if sys.platform != "darwin":
|
||||
logger.warning('Apple Vision is not supported on non-macOS platforms!')
|
||||
self.available = False
|
||||
elif int(platform.mac_ver()[0].split('.')[0]) < 13:
|
||||
logger.warning('Apple Vision is not supported on macOS older than Ventura/13.0!')
|
||||
self.available = False
|
||||
else:
|
||||
if 'objc' not in sys.modules:
|
||||
logger.warning('pyobjc not available, Apple Vision will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
self.available = True
|
||||
logger.info('Apple Vision ready')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
@@ -177,28 +206,78 @@ class AppleVision:
|
||||
img.save(image_bytes, format=img.format)
|
||||
return image_bytes.getvalue()
|
||||
|
||||
class WinRTOCR:
|
||||
name = "winrtocr"
|
||||
readable_name = "WinRT OCR"
|
||||
key = "w"
|
||||
available = False
|
||||
|
||||
def __init__(self, config={}):
|
||||
if os.name == 'nt':
|
||||
if int(platform.release()) < 10:
|
||||
logger.warning('WinRT OCR is not supported on Windows older than 10!')
|
||||
elif 'winocr' not in sys.modules:
|
||||
logger.warning('winocr not available, WinRT OCR will not work!')
|
||||
else:
|
||||
self.available = True
|
||||
logger.info('WinRT OCR ready')
|
||||
else:
|
||||
if 'requests' not in sys.modules:
|
||||
logger.warning('requests not available, WinRT OCR will not work!')
|
||||
else:
|
||||
try:
|
||||
self.url = config['url']
|
||||
self.available = True
|
||||
logger.info('WinRT OCR ready')
|
||||
except:
|
||||
logger.warning('Error reading URL from config, WinRT OCR will not work!')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
img = img_or_path
|
||||
else:
|
||||
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
|
||||
|
||||
if os.name == 'nt':
|
||||
res = winocr.recognize_pil_sync(img, lang='ja')['text']
|
||||
else:
|
||||
params = {'lang': 'ja'}
|
||||
try:
|
||||
res = requests.post(self.url, params=params, data=self._preprocess(img), timeout=3)
|
||||
except requests.exceptions.Timeout:
|
||||
return "Request timeout!"
|
||||
|
||||
res = json.loads(res.text)['text']
|
||||
|
||||
x = post_process(res)
|
||||
return x
|
||||
|
||||
def _preprocess(self, img):
|
||||
image_bytes = io.BytesIO()
|
||||
img.save(image_bytes, format=img.format)
|
||||
return image_bytes.getvalue()
|
||||
|
||||
class AzureComputerVision:
|
||||
def __init__(self):
|
||||
name = "azure"
|
||||
readable_name = "Azure Computer Vision"
|
||||
key = "v"
|
||||
available = False
|
||||
|
||||
def __init__(self, config={}):
|
||||
if 'azure.cognitiveservices.vision.computervision' not in sys.modules:
|
||||
logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
logger.info(f'Parsing Azure credentials')
|
||||
azure_credentials_file = os.path.join(os.path.expanduser('~'),'.config','azure_computer_vision.ini')
|
||||
try:
|
||||
azure_credentials = configparser.ConfigParser()
|
||||
azure_credentials.read(azure_credentials_file)
|
||||
self.client = ComputerVisionClient(azure_credentials['config']['endpoint'], CognitiveServicesCredentials(azure_credentials['config']['api_key']))
|
||||
self.client = ComputerVisionClient(config['endpoint'], CognitiveServicesCredentials(config['api_key']))
|
||||
self.available = True
|
||||
logger.info('Azure Computer Vision ready')
|
||||
except:
|
||||
logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!')
|
||||
self.available = False
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
@@ -234,10 +313,14 @@ class AzureComputerVision:
|
||||
return image_io
|
||||
|
||||
class EasyOCR:
|
||||
name = "easyocr"
|
||||
readable_name = "EasyOCR"
|
||||
key = "e"
|
||||
available = False
|
||||
|
||||
def __init__(self):
|
||||
if 'easyocr' not in sys.modules:
|
||||
logger.warning('easyocr not available, EasyOCR will not work!')
|
||||
self.available = False
|
||||
else:
|
||||
logger.info('Loading EasyOCR model')
|
||||
self.model = easyocr.Reader(['ja','en'])
|
||||
@@ -245,9 +328,6 @@ class EasyOCR:
|
||||
logger.info('EasyOCR ready')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
@@ -269,10 +349,14 @@ class EasyOCR:
|
||||
return image_bytes.getvalue()
|
||||
|
||||
class PaddleOCR:
|
||||
name = "paddleocr"
|
||||
readable_name = "PaddleOCR"
|
||||
key = "o"
|
||||
available = False
|
||||
|
||||
def __init__(self):
|
||||
if 'paddleocr' not in sys.modules:
|
||||
logger.warning('easyocr not available, PaddleOCR will not work!')
|
||||
self.available = False
|
||||
logger.warning('paddleocr not available, PaddleOCR will not work!')
|
||||
else:
|
||||
logger.info('Loading PaddleOCR model')
|
||||
self.model = POCR(use_angle_cls=True, show_log=False, lang='japan')
|
||||
@@ -280,9 +364,6 @@ class PaddleOCR:
|
||||
logger.info('PaddleOCR ready')
|
||||
|
||||
def __call__(self, img_or_path):
|
||||
if not self.available:
|
||||
return "Engine not available!"
|
||||
|
||||
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
|
||||
img = Image.open(img_or_path)
|
||||
elif isinstance(img_or_path, Image.Image):
|
||||
@@ -302,12 +383,3 @@ class PaddleOCR:
|
||||
|
||||
def _preprocess(self, img):
|
||||
return np.array(img.convert('RGB'))
|
||||
|
||||
|
||||
def post_process(text):
|
||||
text = ''.join(text.split())
|
||||
text = text.replace('…', '...')
|
||||
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
|
||||
text = jaconv.h2z(text, ascii=True, digit=True)
|
||||
|
||||
return text
|
||||
Reference in New Issue
Block a user