Add initial version of two-pass OCR processing

AuroraWright
2025-10-06 21:50:31 +02:00
parent 1921ecc849
commit ed9b05d2e0
3 changed files with 196 additions and 71 deletions

View File

@@ -26,6 +26,8 @@ parser.add_argument('-w', '--write_to', type=str, default=argparse.SUPPRESS,
help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.')
parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS,
help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".')
parser.add_argument('-es', '--engine_secondary', type=str, default=argparse.SUPPRESS,
help='OCR engine to use for the first pass of two-pass processing. Must be a local engine with coordinate support.')
parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
help='Pause at startup.')
parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
@@ -66,6 +68,7 @@ class Config:
'read_from_secondary': '',
'write_to': 'clipboard',
'engine': '',
'engine_secondary': '',
'pause_at_startup': False,
'auto_pause' : 0,
'ignore_flag': False,
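
As an illustration of the new option: the secondary engine acts as a fast local first pass in front of the main engine, and (as the runner changes later in this commit show) it is only consulted when text filtering is active and only if it is a local engine with coordinate support. Assuming the CLI entry point is named owocr (the name this project uses for its notifications), a hypothetical invocation pairing Google Lens with a local OneOCR first pass would be:

owocr -e glens -es oneocr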

View File

@@ -197,6 +197,7 @@ class MangaOcr:
readable_name = 'Manga OCR'
key = 'm'
available = False
local = True
manual_language = False
coordinate_support = False
@@ -229,6 +230,7 @@ class GoogleVision:
readable_name = 'Google Vision'
key = 'g'
available = False
local = False
manual_language = False
coordinate_support = False
@@ -275,6 +277,7 @@ class GoogleLens:
readable_name = 'Google Lens'
key = 'l'
available = False
local = False
manual_language = False
coordinate_support = True
@@ -421,6 +424,7 @@ class GoogleLensWeb:
readable_name = 'Google Lens (web)'
key = 'k'
available = False
local = False
manual_language = False
coordinate_support = False
@@ -518,6 +522,7 @@ class Bing:
readable_name = 'Bing'
key = 'b'
available = False
local = False
manual_language = False
coordinate_support = True
@@ -697,6 +702,7 @@ class AppleVision:
readable_name = 'Apple Vision'
key = 'a'
available = False
local = True
manual_language = True
coordinate_support = False
@@ -748,6 +754,7 @@ class AppleLiveText:
readable_name = 'Apple Live Text'
key = 'd'
available = False
local = True
manual_language = True
coordinate_support = True
@@ -888,6 +895,7 @@ class WinRTOCR:
readable_name = 'WinRT OCR'
key = 'w'
available = False
local = True
manual_language = True
coordinate_support = False
@@ -945,6 +953,7 @@ class OneOCR:
readable_name = 'OneOCR'
key = 'z'
available = False
local = True
manual_language = False
coordinate_support = True
@@ -1068,6 +1077,7 @@ class AzureImageAnalysis:
readable_name = 'Azure Image Analysis'
key = 'v'
available = False
local = False
manual_language = False
coordinate_support = False
@@ -1123,6 +1133,7 @@ class EasyOCR:
readable_name = 'EasyOCR'
key = 'e'
available = False
local = True
manual_language = True
coordinate_support = False
@@ -1160,6 +1171,7 @@ class RapidOCR:
readable_name = 'RapidOCR'
key = 'r'
available = False
local = True
manual_language = True
coordinate_support = False
@@ -1168,10 +1180,10 @@ class RapidOCR:
logger.warning('rapidocr not available, RapidOCR will not work!')
else:
logger.info('Loading RapidOCR model')
lang_det, lang_rec = self.language_to_model_language(language)
lang_rec = self.language_to_model_language(language)
self.model = ROCR(params={
'Det.engine_type': EngineType.ONNXRUNTIME,
'Det.lang_type': lang_det,
'Det.lang_type': LangDet.CH,
'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE,
'Det.ocr_version': OCRVersion.PPOCRV5,
'Rec.engine_type': EngineType.ONNXRUNTIME,
@@ -1185,19 +1197,19 @@ class RapidOCR:
def language_to_model_language(self, language):
if language == 'ja':
return LangDet.CH, LangRec.CH
return LangRec.CH
if language == 'zh':
return LangDet.CH, LangRec.CH
return LangRec.CH
elif language == 'ko':
return LangDet.MULTI, LangRec.KOREAN
return LangRec.KOREAN
elif language == 'ru':
return LangDet.MULTI, LangRec.ESLAV
return LangRec.ESLAV
elif language == 'el':
return LangDet.MULTI, LangRec.EL
return LangRec.EL
elif language == 'th':
return LangDet.MULTI, LangRec.TH
return LangRec.TH
else:
return LangDet.MULTI, LangRec.LATIN
return LangRec.LATIN
def __call__(self, img):
img, is_path = input_to_pil_image(img)
@@ -1224,6 +1236,7 @@ class OCRSpace:
readable_name = 'OCRSpace'
key = 'o'
available = False
local = False
manual_language = True
coordinate_support = False
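
The new local flag, combined with the existing coordinate_support flag, is what qualifies an engine as a secondary: the runner (third file below) only accepts a secondary engine that is both local and coordinate-aware. A minimal, self-contained sketch of that rule (the classes here are stand-ins, not the real engine classes):

class FakeEngine:
    def __init__(self, name, local, coordinate_support):
        self.name = name
        self.local = local
        self.coordinate_support = coordinate_support

engines = [FakeEngine('glens', False, True),     # remote, has coordinates
           FakeEngine('oneocr', True, True),     # local, has coordinates
           FakeEngine('mangaocr', True, False)]  # local, text only
eligible = [e.name for e in engines if e.local and e.coordinate_support]
print(eligible)  # ['oneocr'] — the only engine usable for the first pass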

View File

@@ -10,6 +10,7 @@ import logging
import inspect
import os
import json
import copy
from dataclasses import asdict
import numpy as np
@@ -24,6 +25,7 @@ from PIL import Image, UnidentifiedImageError
from loguru import logger
from pynput import keyboard
from desktop_notifier import DesktopNotifierSync, Urgency
from rapidfuzz import fuzz
from .ocr import *
from .config import config
@@ -300,38 +302,15 @@ class RequestHandler(socketserver.BaseRequestHandler):
class TextFiltering:
accurate_filtering = False
def __init__(self):
from pysbd import Segmenter
import langid
self.language = config.get_general('language')
self.segmenter = Segmenter(language=self.language, clean=True)
self.classify = langid.classify
self.regex = self.get_regex()
self.last_result = ([], engine_index)
try:
from transformers import pipeline, AutoTokenizer
import torch
logging.getLogger('transformers').setLevel(logging.ERROR)
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
tokenizer = AutoTokenizer.from_pretrained(
model_ckpt,
use_fast = False
)
if torch.cuda.is_available():
device = 0
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = -1
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
self.accurate_filtering = True
except:
import langid
self.classify = langid.classify
def get_regex(self):
if self.language == 'ja':
return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
@@ -354,11 +333,26 @@ class TextFiltering:
return re.compile(
r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]')
def convert_small_kana_to_big(self, text):
small_to_big = {
# Hiragana
'ぁ': 'あ', 'ぃ': 'い', 'ぅ': 'う', 'ぇ': 'え', 'ぉ': 'お',
'っ': 'つ', 'ゃ': 'や', 'ゅ': 'ゆ', 'ょ': 'よ', 'ゎ': 'わ',
# Katakana
'ァ': 'ア', 'ィ': 'イ', 'ゥ': 'ウ', 'ェ': 'エ', 'ォ': 'オ',
'ッ': 'ツ', 'ャ': 'ヤ', 'ュ': 'ユ', 'ョ': 'ヨ', 'ヮ': 'ワ'
}
converted_text = ''.join(small_to_big.get(char, char) for char in text)
return converted_text
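
As a standalone illustration of the conversion above (re-implemented here with a few of the pairs): widening small kana normalizes strings before the fuzzy comparisons used later, so engines that disagree on small versus full-size kana still match.

small_to_big = {'ゃ': 'や', 'ゅ': 'ゆ', 'ょ': 'よ', 'っ': 'つ'}
text = 'ちっちゃい'
print(''.join(small_to_big.get(c, c) for c in text))  # ちつちやい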
def __call__(self, text):
orig_text = self.segmenter.segment(text)
orig_text_filtered = []
for block in orig_text:
block_filtered = self.regex.findall(block)
if self.language == 'ja':
block_filtered = self.convert_small_kana_to_big(block_filtered)
if block_filtered:
orig_text_filtered.append(''.join(block_filtered))
@@ -376,18 +370,10 @@ class TextFiltering:
new_blocks.append(block)
final_blocks = []
if self.accurate_filtering:
detection_results = self.pipe(new_blocks, top_k=3, truncation=True)
for idx, block in enumerate(new_blocks):
for result in detection_results[idx]:
if result['label'] == self.language:
final_blocks.append(block)
break
else:
for block in new_blocks:
# This only looks at language IF language is ja or zh, otherwise it keeps all text
if self.language not in ["ja", "zh"] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
final_blocks.append(block)
for block in new_blocks:
# This only looks at language IF language is ja or zh, otherwise it keeps all text
if self.language not in ['ja', 'zh'] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
final_blocks.append(block)
text = '\n'.join(final_blocks)
@@ -675,7 +661,7 @@ class ScreenshotThread(threading.Thread):
try:
win32gui.ReleaseDC(self.window_handle, hwnd_dc)
except:
pass
pass
else:
sct_img = sct.grab(self.sct_params)
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
@@ -724,64 +710,183 @@ class OutputResult:
def __init__(self, init_filtering):
self.filtering = TextFiltering() if init_filtering else None
self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
self.previous_result = None
def _coordinate_format_to_string(self, result_data):
full_text_parts = []
for p in result_data.paragraphs:
for l in p.lines:
if l.text != None:
full_text_parts.append(l.text)
else:
for w in l.words:
full_text_parts.append(w.text)
if w.separator != None:
full_text_parts.append(w.separator)
else:
full_text_parts.append(' ')
full_text_parts.append(self._get_line_text(l))
full_text_parts.append('\n')
return "".join(full_text_parts)
return ''.join(full_text_parts)
def _post_process(self, text):
def _post_process(self, text, strip_spaces):
is_cj_text = self.cj_regex.search(text)
line_separator = '' if strip_spaces else ' '
if is_cj_text:
text = ' '.join([''.join(i.split()) for i in text.splitlines()])
text = line_separator.join([''.join(i.split()) for i in text.splitlines()])
else:
text = ' '.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()])
text = line_separator.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()])
text = text.replace('…', '...')
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
if is_cj_text:
text = jaconv.h2z(text, ascii=True, digit=True)
return text
def _get_line_text(self, line):
if line.text is not None:
return line.text
text_parts = []
for w in line.words:
text_parts.append(w.text)
if w.separator is not None:
text_parts.append(w.separator)
else:
text_parts.append(' ')
return ''.join(text_parts)
def _compare_text(self, current_text, prev_text, threshold=80):
if current_text in prev_text:
return True
if len(prev_text) > len(current_text):
return fuzz.partial_ratio(current_text, prev_text) >= threshold
return fuzz.ratio(current_text, prev_text) >= threshold
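
The asymmetry above matters: when the previous text is the longer string, rapidfuzz's partial_ratio scores the best-matching window of it, so an unchanged line still scores high even though the previous frame contains extra text; plain ratio is only appropriate when the strings are comparable in length. A quick demonstration (uses the rapidfuzz package this file now imports):

from rapidfuzz import fuzz

prev = 'いらっしゃいませご注文はお決まりですか'
line = 'ご注文はお決まりですか'
print(fuzz.partial_ratio(line, prev))  # 100.0 — exact window match inside prev
print(fuzz.ratio(line, prev))          # ~73, penalized by the length gap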
def _find_changed_lines(self, current_result, previous_result):
changed_lines = []
# If no previous result, all lines are considered changed
if previous_result is None:
for p in current_result.paragraphs:
changed_lines.extend(p.lines)
return changed_lines
# Check if image sizes are different - if so, treat all lines as changed
if (current_result.image_properties.width != previous_result.image_properties.width or
current_result.image_properties.height != previous_result.image_properties.height):
for p in current_result.paragraphs:
changed_lines.extend(p.lines)
return changed_lines
current_lines = []
previous_lines = []
for p in current_result.paragraphs:
current_lines.extend(p.lines)
for p in previous_result.paragraphs:
previous_lines.extend(p.lines)
all_previous_text = ''
for prev_line in previous_lines:
prev_text = self._get_line_text(prev_line)
prev_text = ''.join(self.filtering.regex.findall(prev_text))
if self.filtering.language == 'ja':
prev_text = self.filtering.convert_small_kana_to_big(prev_text)
all_previous_text += prev_text
for current_line in current_lines:
current_text = self._get_line_text(current_line)
current_text = ''.join(self.filtering.regex.findall(current_text))
if self.filtering.language == 'ja':
current_text = self.filtering.convert_small_kana_to_big(current_text)
text_similar = self._compare_text(current_text, all_previous_text)
if not text_similar:
changed_lines.append(current_line)
return changed_lines
def _create_changed_regions_image(self, pil_image, changed_lines, margin=5):
img_width, img_height = pil_image.size
# Convert normalized coordinates to pixel coordinates
regions = []
for line in changed_lines:
bbox = line.bounding_box
# Convert center-based bbox to corner-based
x1 = (bbox.center_x - bbox.width/2) * img_width - margin
y1 = (bbox.center_y - bbox.height/2) * img_height - margin
x2 = (bbox.center_x + bbox.width/2) * img_width + margin
y2 = (bbox.center_y + bbox.height/2) * img_height + margin
# Ensure coordinates are within image bounds
x1 = max(0, int(x1))
y1 = max(0, int(y1))
x2 = min(img_width, int(x2))
y2 = min(img_height, int(y2))
if x2 > x1 and y2 > y1: # Only add valid regions
regions.append((x1, y1, x2, y2))
if not regions:
return None
# Calculate the bounding box that contains all regions
overall_x1 = min(x1 for x1, y1, x2, y2 in regions)
overall_y1 = min(y1 for x1, y1, x2, y2 in regions)
overall_x2 = max(x2 for x1, y1, x2, y2 in regions)
overall_y2 = max(y2 for x1, y1, x2, y2 in regions)
# Crop the single rectangle containing all changed regions
result_image = pil_image.crop((overall_x1, overall_y1, overall_x2, overall_y2))
return result_image
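
A worked example of the coordinate math above: bounding boxes arrive normalized and center-based, so a line centered at (0.5, 0.25) with normalized size 0.4 x 0.1 on a 1000x500 image crops to the following pixel rectangle (5 px margin, clamped to the image):

img_w, img_h, margin = 1000, 500, 5
cx, cy, w, h = 0.5, 0.25, 0.4, 0.1
x1 = max(0, int((cx - w / 2) * img_w - margin))      # 295
y1 = max(0, int((cy - h / 2) * img_h - margin))      # 95
x2 = min(img_w, int((cx + w / 2) * img_w + margin))  # 705
y2 = min(img_h, int((cy + h / 2) * img_h + margin))  # 155
print((x1, y1, x2, y2))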
def __call__(self, img_or_path, filter_text, notify):
if auto_pause_handler and not filter_text:
auto_pause_handler.stop()
output_format = config.get_general('output_format')
engine_color = config.get_general('engine_color')
engine_instance = engine_instances[engine_index]
if filter_text and engine_index_2 != -1 and engine_index_2 != engine_index:
engine_instance_2 = engine_instances[engine_index_2]
start_time = time.time()
res2, result_data_2 = engine_instance_2(img_or_path)
end_time = time.time()
if not res2:
logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data_2}')
else:
changed_lines = self._find_changed_lines(result_data_2, self.previous_result)
self.previous_result = copy.deepcopy(result_data_2)
if len(changed_lines) > 0:
logger.opt(ansi=True).info(f"<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> found {len(changed_lines)} changed line(s) in {end_time - start_time:0.03f}s, re-OCRing with <{engine_color}>{engine_instance.readable_name}</{engine_color}>")
if output_format != 'json':
changed_regions_image = self._create_changed_regions_image(img_or_path, changed_lines)
if changed_regions_image:
img_or_path = changed_regions_image
else:
logger.warning('Error occurred while creating the differential image.')
else:
return
start_time = time.time()
res, result_data = engine_instance(img_or_path)
end_time = time.time()
orig_text = []
engine_color = config.get_general('engine_color')
if not res:
logger.opt(ansi=True).info(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}')
return orig_text
logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}')
return
output_format = config.get_general('output_format')
verbosity = config.get_general('verbosity')
output_string = ''
log_message = ''
result_data_text = None
# Check if the engine returned a structured OcrResult object
if isinstance(result_data, OcrResult):
unprocessed_text = self._coordinate_format_to_string(result_data)
if output_format == 'json':
result_dict = asdict(result_data)
output_string = json.dumps(result_dict, ensure_ascii=False)
log_message = self._post_process(unprocessed_text)
log_message = self._post_process(unprocessed_text, False)
else:
result_data_text = unprocessed_text
else:
@@ -792,9 +897,9 @@ class OutputResult:
logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.")
if filter_text:
text_to_process = self.filtering(result_data_text)
output_string = self._post_process(text_to_process)
output_string = self._post_process(text_to_process, True)
else:
output_string = self._post_process(result_data_text)
output_string = self._post_process(result_data_text, False)
log_message = output_string
if verbosity != 0:
@@ -810,7 +915,6 @@ class OutputResult:
if notify and config.get_general('notifications'):
notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency())
# Write the final formatted string to the destination
write_to = config.get_general('write_to')
if write_to == 'websocket':
websocket_server_thread.send_text(output_string)
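
Reduced to a runnable toy, the control flow that the two-pass branch above adds to __call__ looks like this (the engine callables and the change check are stand-ins for the real engine instances and _find_changed_lines):

import copy

def secondary(img):  # fast local engine with coordinates (stand-in)
    return True, {'lines': ['line one', 'line two']}

def primary(img):    # accurate engine, run only on the changed crop (stand-in)
    return True, 'final text'

previous = None

def process(img):
    global previous
    ok, coarse = secondary(img)
    if ok:
        changed = [l for l in coarse['lines']
                   if previous is None or l not in previous['lines']]
        previous = copy.deepcopy(coarse)
        if not changed:
            return None  # nothing changed on screen: skip the expensive pass
        # (the real code crops img down to the changed region here)
    return primary(img)[1]

print(process('frame'))  # 'final text' — everything is new on the first frame
print(process('frame'))  # None — identical frame, second pass skipped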
@@ -932,6 +1036,7 @@ def run():
config_engines = []
engine_keys = []
default_engine = ''
engine_secondary = ''
if len(config.get_general('engines')) > 0:
for config_engine in config.get_general('engines').split(','):
@@ -955,12 +1060,15 @@ def run():
engine_keys.append(engine_class.key)
if config.get_general('engine') == engine_class.name:
default_engine = engine_class.key
if config.get_general('engine_secondary') == engine_class.name and engine_class.local and engine_class.coordinate_support:
engine_secondary = engine_class.key
if len(engine_keys) == 0:
msg = 'No engines available!'
raise NotImplementedError(msg)
global engine_index
global engine_index_2
global terminated
global paused
global notifier
@@ -987,6 +1095,7 @@ def run():
init_filtering = False
auto_pause_handler = None
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
engine_index_2 = engine_keys.index(engine_secondary) if engine_secondary != '' else -1
engine_color = config.get_general('engine_color')
combo_pause = config.get_general('combo_pause')
combo_engine_switch = config.get_general('combo_engine_switch')
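
The sentinel wiring in one line: engine_index_2 stays -1 (two-pass disabled) unless the configured secondary engine passed the local-and-coordinate check above, and the output path only enters the two-pass branch when engine_index_2 != -1. A trivial standalone check:

engine_keys = ['l', 'z']  # e.g. Google Lens, OneOCR
engine_secondary = 'z'    # set only if the engine qualified
engine_index_2 = engine_keys.index(engine_secondary) if engine_secondary != '' else -1
print(engine_index_2)     # 1; would be -1 (disabled) with engine_secondary == ''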