Add initial version of two-pass OCR processing

This commit is contained in:
AuroraWright
2025-10-06 21:50:31 +02:00
parent 1921ecc849
commit ed9b05d2e0
3 changed files with 196 additions and 71 deletions

View File

@@ -26,6 +26,8 @@ parser.add_argument('-w', '--write_to', type=str, default=argparse.SUPPRESS,
help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.') help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.')
parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS, parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS,
help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".') help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".')
parser.add_argument('-es', '--engine_secondary', type=str, default=argparse.SUPPRESS,
help='OCR engine to use for two-pass processing.')
parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS, parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
help='Pause at startup.') help='Pause at startup.')
parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS, parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
@@ -66,6 +68,7 @@ class Config:
'read_from_secondary': '', 'read_from_secondary': '',
'write_to': 'clipboard', 'write_to': 'clipboard',
'engine': '', 'engine': '',
'engine_secondary': '',
'pause_at_startup': False, 'pause_at_startup': False,
'auto_pause' : 0, 'auto_pause' : 0,
'ignore_flag': False, 'ignore_flag': False,

View File

@@ -197,6 +197,7 @@ class MangaOcr:
readable_name = 'Manga OCR' readable_name = 'Manga OCR'
key = 'm' key = 'm'
available = False available = False
local = True
manual_language = False manual_language = False
coordinate_support = False coordinate_support = False
@@ -229,6 +230,7 @@ class GoogleVision:
readable_name = 'Google Vision' readable_name = 'Google Vision'
key = 'g' key = 'g'
available = False available = False
local = False
manual_language = False manual_language = False
coordinate_support = False coordinate_support = False
@@ -275,6 +277,7 @@ class GoogleLens:
readable_name = 'Google Lens' readable_name = 'Google Lens'
key = 'l' key = 'l'
available = False available = False
local = False
manual_language = False manual_language = False
coordinate_support = True coordinate_support = True
@@ -421,6 +424,7 @@ class GoogleLensWeb:
readable_name = 'Google Lens (web)' readable_name = 'Google Lens (web)'
key = 'k' key = 'k'
available = False available = False
local = False
manual_language = False manual_language = False
coordinate_support = False coordinate_support = False
@@ -518,6 +522,7 @@ class Bing:
readable_name = 'Bing' readable_name = 'Bing'
key = 'b' key = 'b'
available = False available = False
local = False
manual_language = False manual_language = False
coordinate_support = True coordinate_support = True
@@ -697,6 +702,7 @@ class AppleVision:
readable_name = 'Apple Vision' readable_name = 'Apple Vision'
key = 'a' key = 'a'
available = False available = False
local = True
manual_language = True manual_language = True
coordinate_support = False coordinate_support = False
@@ -748,6 +754,7 @@ class AppleLiveText:
readable_name = 'Apple Live Text' readable_name = 'Apple Live Text'
key = 'd' key = 'd'
available = False available = False
local = True
manual_language = True manual_language = True
coordinate_support = True coordinate_support = True
@@ -888,6 +895,7 @@ class WinRTOCR:
readable_name = 'WinRT OCR' readable_name = 'WinRT OCR'
key = 'w' key = 'w'
available = False available = False
local = True
manual_language = True manual_language = True
coordinate_support = False coordinate_support = False
@@ -945,6 +953,7 @@ class OneOCR:
readable_name = 'OneOCR' readable_name = 'OneOCR'
key = 'z' key = 'z'
available = False available = False
local = True
manual_language = False manual_language = False
coordinate_support = True coordinate_support = True
@@ -1068,6 +1077,7 @@ class AzureImageAnalysis:
readable_name = 'Azure Image Analysis' readable_name = 'Azure Image Analysis'
key = 'v' key = 'v'
available = False available = False
local = False
manual_language = False manual_language = False
coordinate_support = False coordinate_support = False
@@ -1123,6 +1133,7 @@ class EasyOCR:
readable_name = 'EasyOCR' readable_name = 'EasyOCR'
key = 'e' key = 'e'
available = False available = False
local = True
manual_language = True manual_language = True
coordinate_support = False coordinate_support = False
@@ -1160,6 +1171,7 @@ class RapidOCR:
readable_name = 'RapidOCR' readable_name = 'RapidOCR'
key = 'r' key = 'r'
available = False available = False
local = True
manual_language = True manual_language = True
coordinate_support = False coordinate_support = False
@@ -1168,10 +1180,10 @@ class RapidOCR:
logger.warning('rapidocr not available, RapidOCR will not work!') logger.warning('rapidocr not available, RapidOCR will not work!')
else: else:
logger.info('Loading RapidOCR model') logger.info('Loading RapidOCR model')
lang_det, lang_rec = self.language_to_model_language(language) lang_rec = self.language_to_model_language(language)
self.model = ROCR(params={ self.model = ROCR(params={
'Det.engine_type': EngineType.ONNXRUNTIME, 'Det.engine_type': EngineType.ONNXRUNTIME,
'Det.lang_type': lang_det, 'Det.lang_type': LangDet.CH,
'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE, 'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE,
'Det.ocr_version': OCRVersion.PPOCRV5, 'Det.ocr_version': OCRVersion.PPOCRV5,
'Rec.engine_type': EngineType.ONNXRUNTIME, 'Rec.engine_type': EngineType.ONNXRUNTIME,
@@ -1185,19 +1197,19 @@ class RapidOCR:
def language_to_model_language(self, language): def language_to_model_language(self, language):
if language == 'ja': if language == 'ja':
return LangDet.CH, LangRec.CH return LangRec.CH
if language == 'zh': if language == 'zh':
return LangDet.CH, LangRec.CH return LangRec.CH
elif language == 'ko': elif language == 'ko':
return LangDet.MULTI, LangRec.KOREAN return LangRec.KOREAN
elif language == 'ru': elif language == 'ru':
return LangDet.MULTI, LangRec.ESLAV return LangRec.ESLAV
elif language == 'el': elif language == 'el':
return LangDet.MULTI, LangRec.EL return LangRec.EL
elif language == 'th': elif language == 'th':
return LangDet.MULTI, LangRec.TH return LangRec.TH
else: else:
return LangDet.MULTI, LangRec.LATIN return LangRec.LATIN
def __call__(self, img): def __call__(self, img):
img, is_path = input_to_pil_image(img) img, is_path = input_to_pil_image(img)
@@ -1224,6 +1236,7 @@ class OCRSpace:
readable_name = 'OCRSpace' readable_name = 'OCRSpace'
key = 'o' key = 'o'
available = False available = False
local = False
manual_language = True manual_language = True
coordinate_support = False coordinate_support = False

View File

@@ -10,6 +10,7 @@ import logging
import inspect import inspect
import os import os
import json import json
import copy
from dataclasses import asdict from dataclasses import asdict
import numpy as np import numpy as np
@@ -24,6 +25,7 @@ from PIL import Image, UnidentifiedImageError
from loguru import logger from loguru import logger
from pynput import keyboard from pynput import keyboard
from desktop_notifier import DesktopNotifierSync, Urgency from desktop_notifier import DesktopNotifierSync, Urgency
from rapidfuzz import fuzz
from .ocr import * from .ocr import *
from .config import config from .config import config
@@ -300,38 +302,15 @@ class RequestHandler(socketserver.BaseRequestHandler):
class TextFiltering: class TextFiltering:
accurate_filtering = False
def __init__(self): def __init__(self):
from pysbd import Segmenter from pysbd import Segmenter
import langid
self.language = config.get_general('language') self.language = config.get_general('language')
self.segmenter = Segmenter(language=self.language, clean=True) self.segmenter = Segmenter(language=self.language, clean=True)
self.classify = langid.classify
self.regex = self.get_regex() self.regex = self.get_regex()
self.last_result = ([], engine_index) self.last_result = ([], engine_index)
try:
from transformers import pipeline, AutoTokenizer
import torch
logging.getLogger('transformers').setLevel(logging.ERROR)
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
tokenizer = AutoTokenizer.from_pretrained(
model_ckpt,
use_fast = False
)
if torch.cuda.is_available():
device = 0
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = -1
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
self.accurate_filtering = True
except:
import langid
self.classify = langid.classify
def get_regex(self): def get_regex(self):
if self.language == 'ja': if self.language == 'ja':
return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
@@ -354,11 +333,26 @@ class TextFiltering:
return re.compile( return re.compile(
r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]') r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]')
def convert_small_kana_to_big(self, text):
small_to_big = {
# Hiragana
'': '', '': '', '': '', '': '', '': '',
'': '', '': '', '': '', '': '', '': '',
# Katakana
'': '', '': '', '': '', '': '', '': '',
'': '', '': '', '': '', '': '', '': ''
}
converted_text = ''.join(small_to_big.get(char, char) for char in text)
return converted_text
def __call__(self, text): def __call__(self, text):
orig_text = self.segmenter.segment(text) orig_text = self.segmenter.segment(text)
orig_text_filtered = [] orig_text_filtered = []
for block in orig_text: for block in orig_text:
block_filtered = self.regex.findall(block) block_filtered = self.regex.findall(block)
if self.language == 'ja':
block_filtered = self.convert_small_kana_to_big(block_filtered)
if block_filtered: if block_filtered:
orig_text_filtered.append(''.join(block_filtered)) orig_text_filtered.append(''.join(block_filtered))
@@ -376,18 +370,10 @@ class TextFiltering:
new_blocks.append(block) new_blocks.append(block)
final_blocks = [] final_blocks = []
if self.accurate_filtering: for block in new_blocks:
detection_results = self.pipe(new_blocks, top_k=3, truncation=True) # This only looks at language IF language is ja or zh, otherwise it keeps all text
for idx, block in enumerate(new_blocks): if self.language not in ['ja', 'zh'] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
for result in detection_results[idx]: final_blocks.append(block)
if result['label'] == self.language:
final_blocks.append(block)
break
else:
for block in new_blocks:
# This only looks at language IF language is ja or zh, otherwise it keeps all text
if self.language not in ["ja", "zh"] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
final_blocks.append(block)
text = '\n'.join(final_blocks) text = '\n'.join(final_blocks)
@@ -724,64 +710,183 @@ class OutputResult:
def __init__(self, init_filtering): def __init__(self, init_filtering):
self.filtering = TextFiltering() if init_filtering else None self.filtering = TextFiltering() if init_filtering else None
self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
self.previous_result = None
def _coordinate_format_to_string(self, result_data): def _coordinate_format_to_string(self, result_data):
full_text_parts = [] full_text_parts = []
for p in result_data.paragraphs: for p in result_data.paragraphs:
for l in p.lines: for l in p.lines:
if l.text != None: full_text_parts.append(self._get_line_text(l))
full_text_parts.append(l.text)
else:
for w in l.words:
full_text_parts.append(w.text)
if w.separator != None:
full_text_parts.append(w.separator)
else:
full_text_parts.append(' ')
full_text_parts.append('\n') full_text_parts.append('\n')
return "".join(full_text_parts) return ''.join(full_text_parts)
def _post_process(self, text): def _post_process(self, text, strip_spaces):
is_cj_text = self.cj_regex.search(text) is_cj_text = self.cj_regex.search(text)
line_separator = '' if strip_spaces else ' '
if is_cj_text: if is_cj_text:
text = ' '.join([''.join(i.split()) for i in text.splitlines()]) text = line_separator.join([''.join(i.split()) for i in text.splitlines()])
else: else:
text = ' '.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()]) text = line_separator.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()])
text = text.replace('', '...') text = text.replace('', '...')
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text) text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
if is_cj_text: if is_cj_text:
text = jaconv.h2z(text, ascii=True, digit=True) text = jaconv.h2z(text, ascii=True, digit=True)
return text return text
def _get_line_text(self, line):
if line.text is not None:
return line.text
text_parts = []
for w in line.words:
text_parts.append(w.text)
if w.separator is not None:
text_parts.append(w.separator)
else:
text_parts.append(' ')
return ''.join(text_parts)
def _compare_text(self, current_text, prev_text, threshold=80):
if current_text in prev_text:
return True
if len(prev_text) > len(current_text):
return fuzz.partial_ratio(current_text, prev_text) >= threshold
return fuzz.ratio(current_text, prev_text) >= threshold
def _find_changed_lines(self, current_result, previous_result):
changed_lines = []
# If no previous result, all lines are considered changed
if previous_result is None:
for p in current_result.paragraphs:
changed_lines.extend(p.lines)
return changed_lines
# Check if image sizes are different - if so, treat all lines as changed
if (current_result.image_properties.width != previous_result.image_properties.width or
current_result.image_properties.height != previous_result.image_properties.height):
for p in current_result.paragraphs:
changed_lines.extend(p.lines)
return changed_lines
current_lines = []
previous_lines = []
for p in current_result.paragraphs:
current_lines.extend(p.lines)
for p in previous_result.paragraphs:
previous_lines.extend(p.lines)
all_previous_text = ''
for prev_line in previous_lines:
prev_text = self._get_line_text(prev_line)
prev_text = ''.join(self.filtering.regex.findall(prev_text))
if self.filtering.language == 'ja':
prev_text = self.filtering.convert_small_kana_to_big(prev_text)
all_previous_text += prev_text
for current_line in current_lines:
current_text = self._get_line_text(current_line)
current_text = ''.join(self.filtering.regex.findall(current_text))
if self.filtering.language == 'ja':
current_text = self.filtering.convert_small_kana_to_big(current_text)
text_similar = self._compare_text(current_text, all_previous_text)
if not text_similar:
changed_lines.append(current_line)
return changed_lines
def _create_changed_regions_image(self, pil_image, changed_lines, margin=5):
img_width, img_height = pil_image.size
# Convert normalized coordinates to pixel coordinates
regions = []
for line in changed_lines:
bbox = line.bounding_box
# Convert center-based bbox to corner-based
x1 = (bbox.center_x - bbox.width/2) * img_width - margin
y1 = (bbox.center_y - bbox.height/2) * img_height - margin
x2 = (bbox.center_x + bbox.width/2) * img_width + margin
y2 = (bbox.center_y + bbox.height/2) * img_height + margin
# Ensure coordinates are within image bounds
x1 = max(0, int(x1))
y1 = max(0, int(y1))
x2 = min(img_width, int(x2))
y2 = min(img_height, int(y2))
if x2 > x1 and y2 > y1: #Only add valid regions
regions.append((x1, y1, x2, y2))
if not regions:
return None
# Calculate the bounding box that contains all regions
overall_x1 = min(x1 for x1, y1, x2, y2 in regions)
overall_y1 = min(y1 for x1, y1, x2, y2 in regions)
overall_x2 = max(x2 for x1, y1, x2, y2 in regions)
overall_y2 = max(y2 for x1, y1, x2, y2 in regions)
# Crop the single rectangle containing all changed regions
result_image = pil_image.crop((overall_x1, overall_y1, overall_x2, overall_y2))
return result_image
def __call__(self, img_or_path, filter_text, notify): def __call__(self, img_or_path, filter_text, notify):
if auto_pause_handler and not filter_text: if auto_pause_handler and not filter_text:
auto_pause_handler.stop() auto_pause_handler.stop()
output_format = config.get_general('output_format')
engine_color = config.get_general('engine_color')
engine_instance = engine_instances[engine_index] engine_instance = engine_instances[engine_index]
if filter_text and engine_index_2 != -1 and engine_index_2 != engine_index:
engine_instance_2 = engine_instances[engine_index_2]
start_time = time.time()
res2, result_data_2 = engine_instance_2(img_or_path)
end_time = time.time()
if not res2:
logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data_2}')
else:
changed_lines = self._find_changed_lines(result_data_2, self.previous_result)
self.previous_result = copy.deepcopy(result_data_2)
if len(changed_lines) > 0:
logger.opt(ansi=True).info(f"<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> found {len(changed_lines)} changed line(s) in {end_time - start_time:0.03f}s, re-OCRing with <{engine_color}>{engine_instance.readable_name}</{engine_color}>")
if output_format != 'json':
changed_regions_image = self._create_changed_regions_image(img_or_path, changed_lines)
if changed_regions_image:
img_or_path = changed_regions_image
else:
logger.warning('Error occurred while creating the differential image.')
else:
return
start_time = time.time() start_time = time.time()
res, result_data = engine_instance(img_or_path) res, result_data = engine_instance(img_or_path)
end_time = time.time() end_time = time.time()
orig_text = []
engine_color = config.get_general('engine_color')
if not res: if not res:
logger.opt(ansi=True).info(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}') logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}')
return orig_text return
output_format = config.get_general('output_format')
verbosity = config.get_general('verbosity') verbosity = config.get_general('verbosity')
output_string = '' output_string = ''
log_message = '' log_message = ''
result_data_text = None result_data_text = None
# Check if the engine returned a structured OcrResult object
if isinstance(result_data, OcrResult): if isinstance(result_data, OcrResult):
unprocessed_text = self._coordinate_format_to_string(result_data) unprocessed_text = self._coordinate_format_to_string(result_data)
if output_format == 'json': if output_format == 'json':
result_dict = asdict(result_data) result_dict = asdict(result_data)
output_string = json.dumps(result_dict, ensure_ascii=False) output_string = json.dumps(result_dict, ensure_ascii=False)
log_message = self._post_process(unprocessed_text) log_message = self._post_process(unprocessed_text, False)
else: else:
result_data_text = unprocessed_text result_data_text = unprocessed_text
else: else:
@@ -792,9 +897,9 @@ class OutputResult:
logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.") logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.")
if filter_text: if filter_text:
text_to_process = self.filtering(result_data_text) text_to_process = self.filtering(result_data_text)
output_string = self._post_process(text_to_process) output_string = self._post_process(text_to_process, True)
else: else:
output_string = self._post_process(result_data_text) output_string = self._post_process(result_data_text, False)
log_message = output_string log_message = output_string
if verbosity != 0: if verbosity != 0:
@@ -810,7 +915,6 @@ class OutputResult:
if notify and config.get_general('notifications'): if notify and config.get_general('notifications'):
notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency()) notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency())
# Write the final formatted string to the destination
write_to = config.get_general('write_to') write_to = config.get_general('write_to')
if write_to == 'websocket': if write_to == 'websocket':
websocket_server_thread.send_text(output_string) websocket_server_thread.send_text(output_string)
@@ -932,6 +1036,7 @@ def run():
config_engines = [] config_engines = []
engine_keys = [] engine_keys = []
default_engine = '' default_engine = ''
engine_secondary = ''
if len(config.get_general('engines')) > 0: if len(config.get_general('engines')) > 0:
for config_engine in config.get_general('engines').split(','): for config_engine in config.get_general('engines').split(','):
@@ -955,12 +1060,15 @@ def run():
engine_keys.append(engine_class.key) engine_keys.append(engine_class.key)
if config.get_general('engine') == engine_class.name: if config.get_general('engine') == engine_class.name:
default_engine = engine_class.key default_engine = engine_class.key
if config.get_general('engine_secondary') == engine_class.name and engine_class.local and engine_class.coordinate_support:
engine_secondary = engine_class.key
if len(engine_keys) == 0: if len(engine_keys) == 0:
msg = 'No engines available!' msg = 'No engines available!'
raise NotImplementedError(msg) raise NotImplementedError(msg)
global engine_index global engine_index
global engine_index_2
global terminated global terminated
global paused global paused
global notifier global notifier
@@ -987,6 +1095,7 @@ def run():
init_filtering = False init_filtering = False
auto_pause_handler = None auto_pause_handler = None
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0 engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
engine_index_2 = engine_keys.index(engine_secondary) if engine_secondary != '' else -1
engine_color = config.get_general('engine_color') engine_color = config.get_general('engine_color')
combo_pause = config.get_general('combo_pause') combo_pause = config.get_general('combo_pause')
combo_engine_switch = config.get_general('combo_engine_switch') combo_engine_switch = config.get_general('combo_engine_switch')