From ed9b05d2e09649cdde618d1dec0f3a3c44158afc Mon Sep 17 00:00:00 2001 From: AuroraWright Date: Mon, 6 Oct 2025 21:50:31 +0200 Subject: [PATCH] Add initial version of two-pass OCR processing --- owocr/config.py | 3 + owocr/ocr.py | 31 +++++-- owocr/run.py | 233 +++++++++++++++++++++++++++++++++++------------- 3 files changed, 196 insertions(+), 71 deletions(-) diff --git a/owocr/config.py b/owocr/config.py index 9acb604..fdbea45 100644 --- a/owocr/config.py +++ b/owocr/config.py @@ -26,6 +26,8 @@ parser.add_argument('-w', '--write_to', type=str, default=argparse.SUPPRESS, help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.') parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS, help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".') +parser.add_argument('-es', '--engine_secondary', type=str, default=argparse.SUPPRESS, + help='OCR engine to use for two-pass processing.') parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS, help='Pause at startup.') parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS, @@ -66,6 +68,7 @@ class Config: 'read_from_secondary': '', 'write_to': 'clipboard', 'engine': '', + 'engine_secondary': '', 'pause_at_startup': False, 'auto_pause' : 0, 'ignore_flag': False, diff --git a/owocr/ocr.py b/owocr/ocr.py index 43a1235..eef3afb 100644 --- a/owocr/ocr.py +++ b/owocr/ocr.py @@ -197,6 +197,7 @@ class MangaOcr: readable_name = 'Manga OCR' key = 'm' available = False + local = True manual_language = False coordinate_support = False @@ -229,6 +230,7 @@ class GoogleVision: readable_name = 'Google Vision' key = 'g' available = False + local = False manual_language = False coordinate_support = False @@ -275,6 +277,7 @@ class GoogleLens: readable_name = 'Google Lens' key = 'l' available = False + local = False manual_language = False coordinate_support = True @@ -421,6 +424,7 @@ class GoogleLensWeb: readable_name = 'Google Lens (web)' key = 'k' available = False + local = False manual_language = False coordinate_support = False @@ -518,6 +522,7 @@ class Bing: readable_name = 'Bing' key = 'b' available = False + local = False manual_language = False coordinate_support = True @@ -697,6 +702,7 @@ class AppleVision: readable_name = 'Apple Vision' key = 'a' available = False + local = True manual_language = True coordinate_support = False @@ -748,6 +754,7 @@ class AppleLiveText: readable_name = 'Apple Live Text' key = 'd' available = False + local = True manual_language = True coordinate_support = True @@ -888,6 +895,7 @@ class WinRTOCR: readable_name = 'WinRT OCR' key = 'w' available = False + local = True manual_language = True coordinate_support = False @@ -945,6 +953,7 @@ class OneOCR: readable_name = 'OneOCR' key = 'z' available = False + local = True manual_language = False coordinate_support = True @@ -1068,6 +1077,7 @@ class AzureImageAnalysis: readable_name = 'Azure Image Analysis' key = 'v' available = False + local = False manual_language = False coordinate_support = False @@ -1123,6 +1133,7 @@ class EasyOCR: readable_name = 'EasyOCR' key = 'e' available = False + local = True manual_language = True coordinate_support = False @@ -1160,6 +1171,7 @@ class RapidOCR: readable_name = 'RapidOCR' key = 'r' available = False + local = True manual_language = True coordinate_support = False @@ -1168,10 +1180,10 @@ class RapidOCR: logger.warning('rapidocr not available, RapidOCR will not work!') else: logger.info('Loading RapidOCR model') - lang_det, lang_rec = self.language_to_model_language(language) + lang_rec = self.language_to_model_language(language) self.model = ROCR(params={ 'Det.engine_type': EngineType.ONNXRUNTIME, - 'Det.lang_type': lang_det, + 'Det.lang_type': LangDet.CH, 'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE, 'Det.ocr_version': OCRVersion.PPOCRV5, 'Rec.engine_type': EngineType.ONNXRUNTIME, @@ -1185,19 +1197,19 @@ class RapidOCR: def language_to_model_language(self, language): if language == 'ja': - return LangDet.CH, LangRec.CH + return LangRec.CH if language == 'zh': - return LangDet.CH, LangRec.CH + return LangRec.CH elif language == 'ko': - return LangDet.MULTI, LangRec.KOREAN + return LangRec.KOREAN elif language == 'ru': - return LangDet.MULTI, LangRec.ESLAV + return LangRec.ESLAV elif language == 'el': - return LangDet.MULTI, LangRec.EL + return LangRec.EL elif language == 'th': - return LangDet.MULTI, LangRec.TH + return LangRec.TH else: - return LangDet.MULTI, LangRec.LATIN + return LangRec.LATIN def __call__(self, img): img, is_path = input_to_pil_image(img) @@ -1224,6 +1236,7 @@ class OCRSpace: readable_name = 'OCRSpace' key = 'o' available = False + local = False manual_language = True coordinate_support = False diff --git a/owocr/run.py b/owocr/run.py index ca11c47..a39ebde 100644 --- a/owocr/run.py +++ b/owocr/run.py @@ -10,6 +10,7 @@ import logging import inspect import os import json +import copy from dataclasses import asdict import numpy as np @@ -24,6 +25,7 @@ from PIL import Image, UnidentifiedImageError from loguru import logger from pynput import keyboard from desktop_notifier import DesktopNotifierSync, Urgency +from rapidfuzz import fuzz from .ocr import * from .config import config @@ -300,38 +302,15 @@ class RequestHandler(socketserver.BaseRequestHandler): class TextFiltering: - accurate_filtering = False - def __init__(self): from pysbd import Segmenter + import langid self.language = config.get_general('language') self.segmenter = Segmenter(language=self.language, clean=True) + self.classify = langid.classify self.regex = self.get_regex() self.last_result = ([], engine_index) - try: - from transformers import pipeline, AutoTokenizer - import torch - logging.getLogger('transformers').setLevel(logging.ERROR) - - model_ckpt = 'papluca/xlm-roberta-base-language-detection' - tokenizer = AutoTokenizer.from_pretrained( - model_ckpt, - use_fast = False - ) - - if torch.cuda.is_available(): - device = 0 - elif torch.backends.mps.is_available(): - device = 'mps' - else: - device = -1 - self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device) - self.accurate_filtering = True - except: - import langid - self.classify = langid.classify - def get_regex(self): if self.language == 'ja': return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') @@ -354,11 +333,26 @@ class TextFiltering: return re.compile( r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]') + def convert_small_kana_to_big(self, text): + small_to_big = { + # Hiragana + 'ぁ': 'あ', 'ぃ': 'い', 'ぅ': 'う', 'ぇ': 'え', 'ぉ': 'お', + 'っ': 'つ', 'ゃ': 'や', 'ゅ': 'ゆ', 'ょ': 'よ', 'ゎ': 'わ', + # Katakana + 'ァ': 'ア', 'ィ': 'イ', 'ゥ': 'ウ', 'ェ': 'エ', 'ォ': 'オ', + 'ッ': 'ツ', 'ャ': 'ヤ', 'ュ': 'ユ', 'ョ': 'ヨ', 'ヮ': 'ワ' + } + + converted_text = ''.join(small_to_big.get(char, char) for char in text) + return converted_text + def __call__(self, text): orig_text = self.segmenter.segment(text) orig_text_filtered = [] for block in orig_text: block_filtered = self.regex.findall(block) + if self.language == 'ja': + block_filtered = self.convert_small_kana_to_big(block_filtered) if block_filtered: orig_text_filtered.append(''.join(block_filtered)) @@ -376,18 +370,10 @@ class TextFiltering: new_blocks.append(block) final_blocks = [] - if self.accurate_filtering: - detection_results = self.pipe(new_blocks, top_k=3, truncation=True) - for idx, block in enumerate(new_blocks): - for result in detection_results[idx]: - if result['label'] == self.language: - final_blocks.append(block) - break - else: - for block in new_blocks: - # This only looks at language IF language is ja or zh, otherwise it keeps all text - if self.language not in ["ja", "zh"] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n": - final_blocks.append(block) + for block in new_blocks: + # This only looks at language IF language is ja or zh, otherwise it keeps all text + if self.language not in ['ja', 'zh'] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n": + final_blocks.append(block) text = '\n'.join(final_blocks) @@ -675,7 +661,7 @@ class ScreenshotThread(threading.Thread): try: win32gui.ReleaseDC(self.window_handle, hwnd_dc) except: - pass + pass else: sct_img = sct.grab(self.sct_params) img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX') @@ -724,64 +710,183 @@ class OutputResult: def __init__(self, init_filtering): self.filtering = TextFiltering() if init_filtering else None self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') + self.previous_result = None def _coordinate_format_to_string(self, result_data): full_text_parts = [] for p in result_data.paragraphs: for l in p.lines: - if l.text != None: - full_text_parts.append(l.text) - else: - for w in l.words: - full_text_parts.append(w.text) - if w.separator != None: - full_text_parts.append(w.separator) - else: - full_text_parts.append(' ') + full_text_parts.append(self._get_line_text(l)) full_text_parts.append('\n') - return "".join(full_text_parts) + return ''.join(full_text_parts) - def _post_process(self, text): + def _post_process(self, text, strip_spaces): is_cj_text = self.cj_regex.search(text) + line_separator = '' if strip_spaces else ' ' if is_cj_text: - text = ' '.join([''.join(i.split()) for i in text.splitlines()]) + text = line_separator.join([''.join(i.split()) for i in text.splitlines()]) else: - text = ' '.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()]) + text = line_separator.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()]) text = text.replace('…', '...') text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text) if is_cj_text: text = jaconv.h2z(text, ascii=True, digit=True) return text + def _get_line_text(self, line): + if line.text is not None: + return line.text + text_parts = [] + for w in line.words: + text_parts.append(w.text) + if w.separator is not None: + text_parts.append(w.separator) + else: + text_parts.append(' ') + return ''.join(text_parts) + + def _compare_text(self, current_text, prev_text, threshold=80): + if current_text in prev_text: + return True + if len(prev_text) > len(current_text): + return fuzz.partial_ratio(current_text, prev_text) >= threshold + return fuzz.ratio(current_text, prev_text) >= threshold + + def _find_changed_lines(self, current_result, previous_result): + changed_lines = [] + + # If no previous result, all lines are considered changed + if previous_result is None: + for p in current_result.paragraphs: + changed_lines.extend(p.lines) + return changed_lines + + # Check if image sizes are different - if so, treat all lines as changed + if (current_result.image_properties.width != previous_result.image_properties.width or + current_result.image_properties.height != previous_result.image_properties.height): + for p in current_result.paragraphs: + changed_lines.extend(p.lines) + return changed_lines + + current_lines = [] + previous_lines = [] + + for p in current_result.paragraphs: + current_lines.extend(p.lines) + for p in previous_result.paragraphs: + previous_lines.extend(p.lines) + + all_previous_text = '' + for prev_line in previous_lines: + prev_text = self._get_line_text(prev_line) + prev_text = ''.join(self.filtering.regex.findall(prev_text)) + if self.filtering.language == 'ja': + prev_text = self.filtering.convert_small_kana_to_big(prev_text) + all_previous_text += prev_text + + for current_line in current_lines: + current_text = self._get_line_text(current_line) + current_text = ''.join(self.filtering.regex.findall(current_text)) + if self.filtering.language == 'ja': + current_text = self.filtering.convert_small_kana_to_big(current_text) + + text_similar = self._compare_text(current_text, all_previous_text) + + if not text_similar: + changed_lines.append(current_line) + + return changed_lines + + def _create_changed_regions_image(self, pil_image, changed_lines, margin=5): + img_width, img_height = pil_image.size + + # Convert normalized coordinates to pixel coordinates + regions = [] + for line in changed_lines: + bbox = line.bounding_box + # Convert center-based bbox to corner-based + x1 = (bbox.center_x - bbox.width/2) * img_width - margin + y1 = (bbox.center_y - bbox.height/2) * img_height - margin + x2 = (bbox.center_x + bbox.width/2) * img_width + margin + y2 = (bbox.center_y + bbox.height/2) * img_height + margin + + # Ensure coordinates are within image bounds + x1 = max(0, int(x1)) + y1 = max(0, int(y1)) + x2 = min(img_width, int(x2)) + y2 = min(img_height, int(y2)) + + if x2 > x1 and y2 > y1: #Only add valid regions + regions.append((x1, y1, x2, y2)) + + if not regions: + return None + + # Calculate the bounding box that contains all regions + overall_x1 = min(x1 for x1, y1, x2, y2 in regions) + overall_y1 = min(y1 for x1, y1, x2, y2 in regions) + overall_x2 = max(x2 for x1, y1, x2, y2 in regions) + overall_y2 = max(y2 for x1, y1, x2, y2 in regions) + + # Crop the single rectangle containing all changed regions + result_image = pil_image.crop((overall_x1, overall_y1, overall_x2, overall_y2)) + + return result_image + def __call__(self, img_or_path, filter_text, notify): if auto_pause_handler and not filter_text: auto_pause_handler.stop() + output_format = config.get_general('output_format') + engine_color = config.get_general('engine_color') engine_instance = engine_instances[engine_index] + + if filter_text and engine_index_2 != -1 and engine_index_2 != engine_index: + engine_instance_2 = engine_instances[engine_index_2] + start_time = time.time() + res2, result_data_2 = engine_instance_2(img_or_path) + end_time = time.time() + + if not res2: + logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance_2.readable_name} reported an error after {end_time - start_time:0.03f}s: {result_data_2}') + else: + changed_lines = self._find_changed_lines(result_data_2, self.previous_result) + + self.previous_result = copy.deepcopy(result_data_2) + + if len(changed_lines) > 0: + logger.opt(ansi=True).info(f"<{engine_color}>{engine_instance_2.readable_name} found {len(changed_lines)} changed line(s) in {end_time - start_time:0.03f}s, re-OCRing with <{engine_color}>{engine_instance.readable_name}") + + if output_format != 'json': + changed_regions_image = self._create_changed_regions_image(img_or_path, changed_lines) + + if changed_regions_image: + img_or_path = changed_regions_image + else: + logger.warning('Error occurred while creating the differential image.') + else: + return + start_time = time.time() res, result_data = engine_instance(img_or_path) end_time = time.time() - orig_text = [] - engine_color = config.get_general('engine_color') if not res: - logger.opt(ansi=True).info(f'<{engine_color}>{engine_instance.readable_name} reported an error after {end_time - start_time:0.03f}s: {result_data}') - return orig_text + logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance.readable_name} reported an error after {end_time - start_time:0.03f}s: {result_data}') + return - output_format = config.get_general('output_format') verbosity = config.get_general('verbosity') output_string = '' log_message = '' result_data_text = None - - # Check if the engine returned a structured OcrResult object + if isinstance(result_data, OcrResult): unprocessed_text = self._coordinate_format_to_string(result_data) if output_format == 'json': result_dict = asdict(result_data) output_string = json.dumps(result_dict, ensure_ascii=False) - log_message = self._post_process(unprocessed_text) + log_message = self._post_process(unprocessed_text, False) else: result_data_text = unprocessed_text else: @@ -792,9 +897,9 @@ class OutputResult: logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.") if filter_text: text_to_process = self.filtering(result_data_text) - output_string = self._post_process(text_to_process) + output_string = self._post_process(text_to_process, True) else: - output_string = self._post_process(result_data_text) + output_string = self._post_process(result_data_text, False) log_message = output_string if verbosity != 0: @@ -810,7 +915,6 @@ class OutputResult: if notify and config.get_general('notifications'): notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency()) - # Write the final formatted string to the destination write_to = config.get_general('write_to') if write_to == 'websocket': websocket_server_thread.send_text(output_string) @@ -932,6 +1036,7 @@ def run(): config_engines = [] engine_keys = [] default_engine = '' + engine_secondary = '' if len(config.get_general('engines')) > 0: for config_engine in config.get_general('engines').split(','): @@ -955,12 +1060,15 @@ def run(): engine_keys.append(engine_class.key) if config.get_general('engine') == engine_class.name: default_engine = engine_class.key + if config.get_general('engine_secondary') == engine_class.name and engine_class.local and engine_class.coordinate_support: + engine_secondary = engine_class.key if len(engine_keys) == 0: msg = 'No engines available!' raise NotImplementedError(msg) global engine_index + global engine_index_2 global terminated global paused global notifier @@ -987,6 +1095,7 @@ def run(): init_filtering = False auto_pause_handler = None engine_index = engine_keys.index(default_engine) if default_engine != '' else 0 + engine_index_2 = engine_keys.index(engine_secondary) if engine_secondary != '' else -1 engine_color = config.get_general('engine_color') combo_pause = config.get_general('combo_pause') combo_engine_switch = config.get_general('combo_engine_switch')