Add initial version of two-pass OCR processing
This commit is contained in:
@@ -26,6 +26,8 @@ parser.add_argument('-w', '--write_to', type=str, default=argparse.SUPPRESS,
|
|||||||
help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.')
|
help='Where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.')
|
||||||
parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS,
|
parser.add_argument('-e', '--engine', type=str, default=argparse.SUPPRESS,
|
||||||
help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".')
|
help='OCR engine to use. Available: "mangaocr", "glens", "glensweb", "bing", "gvision", "avision", "alivetext", "azure", "winrtocr", "oneocr", "easyocr", "rapidocr", "ocrspace".')
|
||||||
|
parser.add_argument('-es', '--engine_secondary', type=str, default=argparse.SUPPRESS,
|
||||||
|
help='OCR engine to use for two-pass processing.')
|
||||||
parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
|
parser.add_argument('-p', '--pause_at_startup', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
|
||||||
help='Pause at startup.')
|
help='Pause at startup.')
|
||||||
parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
|
parser.add_argument('-i', '--ignore_flag', type=str2bool, nargs='?', const=True, default=argparse.SUPPRESS,
|
||||||
@@ -66,6 +68,7 @@ class Config:
|
|||||||
'read_from_secondary': '',
|
'read_from_secondary': '',
|
||||||
'write_to': 'clipboard',
|
'write_to': 'clipboard',
|
||||||
'engine': '',
|
'engine': '',
|
||||||
|
'engine_secondary': '',
|
||||||
'pause_at_startup': False,
|
'pause_at_startup': False,
|
||||||
'auto_pause' : 0,
|
'auto_pause' : 0,
|
||||||
'ignore_flag': False,
|
'ignore_flag': False,
|
||||||
|
|||||||
31
owocr/ocr.py
31
owocr/ocr.py
@@ -197,6 +197,7 @@ class MangaOcr:
|
|||||||
readable_name = 'Manga OCR'
|
readable_name = 'Manga OCR'
|
||||||
key = 'm'
|
key = 'm'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -229,6 +230,7 @@ class GoogleVision:
|
|||||||
readable_name = 'Google Vision'
|
readable_name = 'Google Vision'
|
||||||
key = 'g'
|
key = 'g'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -275,6 +277,7 @@ class GoogleLens:
|
|||||||
readable_name = 'Google Lens'
|
readable_name = 'Google Lens'
|
||||||
key = 'l'
|
key = 'l'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = True
|
coordinate_support = True
|
||||||
|
|
||||||
@@ -421,6 +424,7 @@ class GoogleLensWeb:
|
|||||||
readable_name = 'Google Lens (web)'
|
readable_name = 'Google Lens (web)'
|
||||||
key = 'k'
|
key = 'k'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -518,6 +522,7 @@ class Bing:
|
|||||||
readable_name = 'Bing'
|
readable_name = 'Bing'
|
||||||
key = 'b'
|
key = 'b'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = True
|
coordinate_support = True
|
||||||
|
|
||||||
@@ -697,6 +702,7 @@ class AppleVision:
|
|||||||
readable_name = 'Apple Vision'
|
readable_name = 'Apple Vision'
|
||||||
key = 'a'
|
key = 'a'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -748,6 +754,7 @@ class AppleLiveText:
|
|||||||
readable_name = 'Apple Live Text'
|
readable_name = 'Apple Live Text'
|
||||||
key = 'd'
|
key = 'd'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = True
|
coordinate_support = True
|
||||||
|
|
||||||
@@ -888,6 +895,7 @@ class WinRTOCR:
|
|||||||
readable_name = 'WinRT OCR'
|
readable_name = 'WinRT OCR'
|
||||||
key = 'w'
|
key = 'w'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -945,6 +953,7 @@ class OneOCR:
|
|||||||
readable_name = 'OneOCR'
|
readable_name = 'OneOCR'
|
||||||
key = 'z'
|
key = 'z'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = True
|
coordinate_support = True
|
||||||
|
|
||||||
@@ -1068,6 +1077,7 @@ class AzureImageAnalysis:
|
|||||||
readable_name = 'Azure Image Analysis'
|
readable_name = 'Azure Image Analysis'
|
||||||
key = 'v'
|
key = 'v'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = False
|
manual_language = False
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -1123,6 +1133,7 @@ class EasyOCR:
|
|||||||
readable_name = 'EasyOCR'
|
readable_name = 'EasyOCR'
|
||||||
key = 'e'
|
key = 'e'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -1160,6 +1171,7 @@ class RapidOCR:
|
|||||||
readable_name = 'RapidOCR'
|
readable_name = 'RapidOCR'
|
||||||
key = 'r'
|
key = 'r'
|
||||||
available = False
|
available = False
|
||||||
|
local = True
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
@@ -1168,10 +1180,10 @@ class RapidOCR:
|
|||||||
logger.warning('rapidocr not available, RapidOCR will not work!')
|
logger.warning('rapidocr not available, RapidOCR will not work!')
|
||||||
else:
|
else:
|
||||||
logger.info('Loading RapidOCR model')
|
logger.info('Loading RapidOCR model')
|
||||||
lang_det, lang_rec = self.language_to_model_language(language)
|
lang_rec = self.language_to_model_language(language)
|
||||||
self.model = ROCR(params={
|
self.model = ROCR(params={
|
||||||
'Det.engine_type': EngineType.ONNXRUNTIME,
|
'Det.engine_type': EngineType.ONNXRUNTIME,
|
||||||
'Det.lang_type': lang_det,
|
'Det.lang_type': LangDet.CH,
|
||||||
'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE,
|
'Det.model_type': ModelType.SERVER if config['high_accuracy_detection'] else ModelType.MOBILE,
|
||||||
'Det.ocr_version': OCRVersion.PPOCRV5,
|
'Det.ocr_version': OCRVersion.PPOCRV5,
|
||||||
'Rec.engine_type': EngineType.ONNXRUNTIME,
|
'Rec.engine_type': EngineType.ONNXRUNTIME,
|
||||||
@@ -1185,19 +1197,19 @@ class RapidOCR:
|
|||||||
|
|
||||||
def language_to_model_language(self, language):
|
def language_to_model_language(self, language):
|
||||||
if language == 'ja':
|
if language == 'ja':
|
||||||
return LangDet.CH, LangRec.CH
|
return LangRec.CH
|
||||||
if language == 'zh':
|
if language == 'zh':
|
||||||
return LangDet.CH, LangRec.CH
|
return LangRec.CH
|
||||||
elif language == 'ko':
|
elif language == 'ko':
|
||||||
return LangDet.MULTI, LangRec.KOREAN
|
return LangRec.KOREAN
|
||||||
elif language == 'ru':
|
elif language == 'ru':
|
||||||
return LangDet.MULTI, LangRec.ESLAV
|
return LangRec.ESLAV
|
||||||
elif language == 'el':
|
elif language == 'el':
|
||||||
return LangDet.MULTI, LangRec.EL
|
return LangRec.EL
|
||||||
elif language == 'th':
|
elif language == 'th':
|
||||||
return LangDet.MULTI, LangRec.TH
|
return LangRec.TH
|
||||||
else:
|
else:
|
||||||
return LangDet.MULTI, LangRec.LATIN
|
return LangRec.LATIN
|
||||||
|
|
||||||
def __call__(self, img):
|
def __call__(self, img):
|
||||||
img, is_path = input_to_pil_image(img)
|
img, is_path = input_to_pil_image(img)
|
||||||
@@ -1224,6 +1236,7 @@ class OCRSpace:
|
|||||||
readable_name = 'OCRSpace'
|
readable_name = 'OCRSpace'
|
||||||
key = 'o'
|
key = 'o'
|
||||||
available = False
|
available = False
|
||||||
|
local = False
|
||||||
manual_language = True
|
manual_language = True
|
||||||
coordinate_support = False
|
coordinate_support = False
|
||||||
|
|
||||||
|
|||||||
233
owocr/run.py
233
owocr/run.py
@@ -10,6 +10,7 @@ import logging
|
|||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import copy
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -24,6 +25,7 @@ from PIL import Image, UnidentifiedImageError
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
from desktop_notifier import DesktopNotifierSync, Urgency
|
from desktop_notifier import DesktopNotifierSync, Urgency
|
||||||
|
from rapidfuzz import fuzz
|
||||||
|
|
||||||
from .ocr import *
|
from .ocr import *
|
||||||
from .config import config
|
from .config import config
|
||||||
@@ -300,38 +302,15 @@ class RequestHandler(socketserver.BaseRequestHandler):
|
|||||||
|
|
||||||
|
|
||||||
class TextFiltering:
|
class TextFiltering:
|
||||||
accurate_filtering = False
|
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from pysbd import Segmenter
|
from pysbd import Segmenter
|
||||||
|
import langid
|
||||||
self.language = config.get_general('language')
|
self.language = config.get_general('language')
|
||||||
self.segmenter = Segmenter(language=self.language, clean=True)
|
self.segmenter = Segmenter(language=self.language, clean=True)
|
||||||
|
self.classify = langid.classify
|
||||||
self.regex = self.get_regex()
|
self.regex = self.get_regex()
|
||||||
self.last_result = ([], engine_index)
|
self.last_result = ([], engine_index)
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers import pipeline, AutoTokenizer
|
|
||||||
import torch
|
|
||||||
logging.getLogger('transformers').setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_ckpt,
|
|
||||||
use_fast = False
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = 0
|
|
||||||
elif torch.backends.mps.is_available():
|
|
||||||
device = 'mps'
|
|
||||||
else:
|
|
||||||
device = -1
|
|
||||||
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
|
|
||||||
self.accurate_filtering = True
|
|
||||||
except:
|
|
||||||
import langid
|
|
||||||
self.classify = langid.classify
|
|
||||||
|
|
||||||
def get_regex(self):
|
def get_regex(self):
|
||||||
if self.language == 'ja':
|
if self.language == 'ja':
|
||||||
return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
return re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
||||||
@@ -354,11 +333,26 @@ class TextFiltering:
|
|||||||
return re.compile(
|
return re.compile(
|
||||||
r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]')
|
r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F\uA720-\uA7FF\uAB30-\uAB6F]')
|
||||||
|
|
||||||
|
def convert_small_kana_to_big(self, text):
    """Return ``text`` with small kana replaced by their full-size forms.

    Each element produced by iterating ``text`` is looked up on its own, so
    a string or a list of single-character strings both work; the result is
    always a single joined string. Elements without a mapping pass through
    unchanged.
    """
    full_size = {
        # Hiragana
        'ぁ': 'あ', 'ぃ': 'い', 'ぅ': 'う', 'ぇ': 'え', 'ぉ': 'お',
        'っ': 'つ', 'ゃ': 'や', 'ゅ': 'ゆ', 'ょ': 'よ', 'ゎ': 'わ',
        # Katakana
        'ァ': 'ア', 'ィ': 'イ', 'ゥ': 'ウ', 'ェ': 'エ', 'ォ': 'オ',
        'ッ': 'ツ', 'ャ': 'ヤ', 'ュ': 'ユ', 'ョ': 'ヨ', 'ヮ': 'ワ'
    }

    pieces = []
    for element in text:
        pieces.append(full_size.get(element, element))
    return ''.join(pieces)
|
||||||
|
|
||||||
def __call__(self, text):
|
def __call__(self, text):
|
||||||
orig_text = self.segmenter.segment(text)
|
orig_text = self.segmenter.segment(text)
|
||||||
orig_text_filtered = []
|
orig_text_filtered = []
|
||||||
for block in orig_text:
|
for block in orig_text:
|
||||||
block_filtered = self.regex.findall(block)
|
block_filtered = self.regex.findall(block)
|
||||||
|
if self.language == 'ja':
|
||||||
|
block_filtered = self.convert_small_kana_to_big(block_filtered)
|
||||||
|
|
||||||
if block_filtered:
|
if block_filtered:
|
||||||
orig_text_filtered.append(''.join(block_filtered))
|
orig_text_filtered.append(''.join(block_filtered))
|
||||||
@@ -376,18 +370,10 @@ class TextFiltering:
|
|||||||
new_blocks.append(block)
|
new_blocks.append(block)
|
||||||
|
|
||||||
final_blocks = []
|
final_blocks = []
|
||||||
if self.accurate_filtering:
|
for block in new_blocks:
|
||||||
detection_results = self.pipe(new_blocks, top_k=3, truncation=True)
|
# This only looks at language IF language is ja or zh, otherwise it keeps all text
|
||||||
for idx, block in enumerate(new_blocks):
|
if self.language not in ['ja', 'zh'] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
|
||||||
for result in detection_results[idx]:
|
final_blocks.append(block)
|
||||||
if result['label'] == self.language:
|
|
||||||
final_blocks.append(block)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
for block in new_blocks:
|
|
||||||
# This only looks at language IF language is ja or zh, otherwise it keeps all text
|
|
||||||
if self.language not in ["ja", "zh"] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
|
|
||||||
final_blocks.append(block)
|
|
||||||
|
|
||||||
text = '\n'.join(final_blocks)
|
text = '\n'.join(final_blocks)
|
||||||
|
|
||||||
@@ -675,7 +661,7 @@ class ScreenshotThread(threading.Thread):
|
|||||||
try:
|
try:
|
||||||
win32gui.ReleaseDC(self.window_handle, hwnd_dc)
|
win32gui.ReleaseDC(self.window_handle, hwnd_dc)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
sct_img = sct.grab(self.sct_params)
|
sct_img = sct.grab(self.sct_params)
|
||||||
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
|
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
|
||||||
@@ -724,64 +710,183 @@ class OutputResult:
|
|||||||
def __init__(self, init_filtering):
|
def __init__(self, init_filtering):
|
||||||
self.filtering = TextFiltering() if init_filtering else None
|
self.filtering = TextFiltering() if init_filtering else None
|
||||||
self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
self.cj_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
||||||
|
self.previous_result = None
|
||||||
|
|
||||||
def _coordinate_format_to_string(self, result_data):
|
def _coordinate_format_to_string(self, result_data):
|
||||||
full_text_parts = []
|
full_text_parts = []
|
||||||
for p in result_data.paragraphs:
|
for p in result_data.paragraphs:
|
||||||
for l in p.lines:
|
for l in p.lines:
|
||||||
if l.text != None:
|
full_text_parts.append(self._get_line_text(l))
|
||||||
full_text_parts.append(l.text)
|
|
||||||
else:
|
|
||||||
for w in l.words:
|
|
||||||
full_text_parts.append(w.text)
|
|
||||||
if w.separator != None:
|
|
||||||
full_text_parts.append(w.separator)
|
|
||||||
else:
|
|
||||||
full_text_parts.append(' ')
|
|
||||||
full_text_parts.append('\n')
|
full_text_parts.append('\n')
|
||||||
return "".join(full_text_parts)
|
return ''.join(full_text_parts)
|
||||||
|
|
||||||
def _post_process(self, text):
|
def _post_process(self, text, strip_spaces):
|
||||||
is_cj_text = self.cj_regex.search(text)
|
is_cj_text = self.cj_regex.search(text)
|
||||||
|
line_separator = '' if strip_spaces else ' '
|
||||||
if is_cj_text:
|
if is_cj_text:
|
||||||
text = ' '.join([''.join(i.split()) for i in text.splitlines()])
|
text = line_separator.join([''.join(i.split()) for i in text.splitlines()])
|
||||||
else:
|
else:
|
||||||
text = ' '.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()])
|
text = line_separator.join([re.sub(r'\s+', ' ', i).strip() for i in text.splitlines()])
|
||||||
text = text.replace('…', '...')
|
text = text.replace('…', '...')
|
||||||
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
|
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
|
||||||
if is_cj_text:
|
if is_cj_text:
|
||||||
text = jaconv.h2z(text, ascii=True, digit=True)
|
text = jaconv.h2z(text, ascii=True, digit=True)
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
def _get_line_text(self, line):
|
||||||
|
if line.text is not None:
|
||||||
|
return line.text
|
||||||
|
text_parts = []
|
||||||
|
for w in line.words:
|
||||||
|
text_parts.append(w.text)
|
||||||
|
if w.separator is not None:
|
||||||
|
text_parts.append(w.separator)
|
||||||
|
else:
|
||||||
|
text_parts.append(' ')
|
||||||
|
return ''.join(text_parts)
|
||||||
|
|
||||||
|
def _compare_text(self, current_text, prev_text, threshold=80):
|
||||||
|
if current_text in prev_text:
|
||||||
|
return True
|
||||||
|
if len(prev_text) > len(current_text):
|
||||||
|
return fuzz.partial_ratio(current_text, prev_text) >= threshold
|
||||||
|
return fuzz.ratio(current_text, prev_text) >= threshold
|
||||||
|
|
||||||
|
def _find_changed_lines(self, current_result, previous_result):
|
||||||
|
changed_lines = []
|
||||||
|
|
||||||
|
# If no previous result, all lines are considered changed
|
||||||
|
if previous_result is None:
|
||||||
|
for p in current_result.paragraphs:
|
||||||
|
changed_lines.extend(p.lines)
|
||||||
|
return changed_lines
|
||||||
|
|
||||||
|
# Check if image sizes are different - if so, treat all lines as changed
|
||||||
|
if (current_result.image_properties.width != previous_result.image_properties.width or
|
||||||
|
current_result.image_properties.height != previous_result.image_properties.height):
|
||||||
|
for p in current_result.paragraphs:
|
||||||
|
changed_lines.extend(p.lines)
|
||||||
|
return changed_lines
|
||||||
|
|
||||||
|
current_lines = []
|
||||||
|
previous_lines = []
|
||||||
|
|
||||||
|
for p in current_result.paragraphs:
|
||||||
|
current_lines.extend(p.lines)
|
||||||
|
for p in previous_result.paragraphs:
|
||||||
|
previous_lines.extend(p.lines)
|
||||||
|
|
||||||
|
all_previous_text = ''
|
||||||
|
for prev_line in previous_lines:
|
||||||
|
prev_text = self._get_line_text(prev_line)
|
||||||
|
prev_text = ''.join(self.filtering.regex.findall(prev_text))
|
||||||
|
if self.filtering.language == 'ja':
|
||||||
|
prev_text = self.filtering.convert_small_kana_to_big(prev_text)
|
||||||
|
all_previous_text += prev_text
|
||||||
|
|
||||||
|
for current_line in current_lines:
|
||||||
|
current_text = self._get_line_text(current_line)
|
||||||
|
current_text = ''.join(self.filtering.regex.findall(current_text))
|
||||||
|
if self.filtering.language == 'ja':
|
||||||
|
current_text = self.filtering.convert_small_kana_to_big(current_text)
|
||||||
|
|
||||||
|
text_similar = self._compare_text(current_text, all_previous_text)
|
||||||
|
|
||||||
|
if not text_similar:
|
||||||
|
changed_lines.append(current_line)
|
||||||
|
|
||||||
|
return changed_lines
|
||||||
|
|
||||||
|
def _create_changed_regions_image(self, pil_image, changed_lines, margin=5):
|
||||||
|
img_width, img_height = pil_image.size
|
||||||
|
|
||||||
|
# Convert normalized coordinates to pixel coordinates
|
||||||
|
regions = []
|
||||||
|
for line in changed_lines:
|
||||||
|
bbox = line.bounding_box
|
||||||
|
# Convert center-based bbox to corner-based
|
||||||
|
x1 = (bbox.center_x - bbox.width/2) * img_width - margin
|
||||||
|
y1 = (bbox.center_y - bbox.height/2) * img_height - margin
|
||||||
|
x2 = (bbox.center_x + bbox.width/2) * img_width + margin
|
||||||
|
y2 = (bbox.center_y + bbox.height/2) * img_height + margin
|
||||||
|
|
||||||
|
# Ensure coordinates are within image bounds
|
||||||
|
x1 = max(0, int(x1))
|
||||||
|
y1 = max(0, int(y1))
|
||||||
|
x2 = min(img_width, int(x2))
|
||||||
|
y2 = min(img_height, int(y2))
|
||||||
|
|
||||||
|
if x2 > x1 and y2 > y1: #Only add valid regions
|
||||||
|
regions.append((x1, y1, x2, y2))
|
||||||
|
|
||||||
|
if not regions:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Calculate the bounding box that contains all regions
|
||||||
|
overall_x1 = min(x1 for x1, y1, x2, y2 in regions)
|
||||||
|
overall_y1 = min(y1 for x1, y1, x2, y2 in regions)
|
||||||
|
overall_x2 = max(x2 for x1, y1, x2, y2 in regions)
|
||||||
|
overall_y2 = max(y2 for x1, y1, x2, y2 in regions)
|
||||||
|
|
||||||
|
# Crop the single rectangle containing all changed regions
|
||||||
|
result_image = pil_image.crop((overall_x1, overall_y1, overall_x2, overall_y2))
|
||||||
|
|
||||||
|
return result_image
|
||||||
|
|
||||||
def __call__(self, img_or_path, filter_text, notify):
|
def __call__(self, img_or_path, filter_text, notify):
|
||||||
if auto_pause_handler and not filter_text:
|
if auto_pause_handler and not filter_text:
|
||||||
auto_pause_handler.stop()
|
auto_pause_handler.stop()
|
||||||
|
|
||||||
|
output_format = config.get_general('output_format')
|
||||||
|
engine_color = config.get_general('engine_color')
|
||||||
engine_instance = engine_instances[engine_index]
|
engine_instance = engine_instances[engine_index]
|
||||||
|
|
||||||
|
if filter_text and engine_index_2 != -1 and engine_index_2 != engine_index:
|
||||||
|
engine_instance_2 = engine_instances[engine_index_2]
|
||||||
|
start_time = time.time()
|
||||||
|
res2, result_data_2 = engine_instance_2(img_or_path)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
if not res2:
|
||||||
|
logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data_2}')
|
||||||
|
else:
|
||||||
|
changed_lines = self._find_changed_lines(result_data_2, self.previous_result)
|
||||||
|
|
||||||
|
self.previous_result = copy.deepcopy(result_data_2)
|
||||||
|
|
||||||
|
if len(changed_lines) > 0:
|
||||||
|
logger.opt(ansi=True).info(f"<{engine_color}>{engine_instance_2.readable_name}</{engine_color}> found {len(changed_lines)} changed line(s) in {end_time - start_time:0.03f}s, re-OCRing with <{engine_color}>{engine_instance.readable_name}</{engine_color}>")
|
||||||
|
|
||||||
|
if output_format != 'json':
|
||||||
|
changed_regions_image = self._create_changed_regions_image(img_or_path, changed_lines)
|
||||||
|
|
||||||
|
if changed_regions_image:
|
||||||
|
img_or_path = changed_regions_image
|
||||||
|
else:
|
||||||
|
logger.warning('Error occurred while creating the differential image.')
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
res, result_data = engine_instance(img_or_path)
|
res, result_data = engine_instance(img_or_path)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
orig_text = []
|
|
||||||
engine_color = config.get_general('engine_color')
|
|
||||||
if not res:
|
if not res:
|
||||||
logger.opt(ansi=True).info(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}')
|
logger.opt(ansi=True).warning(f'<{engine_color}>{engine_instance.readable_name}</{engine_color}> reported an error after {end_time - start_time:0.03f}s: {result_data}')
|
||||||
return orig_text
|
return
|
||||||
|
|
||||||
output_format = config.get_general('output_format')
|
|
||||||
verbosity = config.get_general('verbosity')
|
verbosity = config.get_general('verbosity')
|
||||||
output_string = ''
|
output_string = ''
|
||||||
log_message = ''
|
log_message = ''
|
||||||
result_data_text = None
|
result_data_text = None
|
||||||
|
|
||||||
# Check if the engine returned a structured OcrResult object
|
|
||||||
if isinstance(result_data, OcrResult):
|
if isinstance(result_data, OcrResult):
|
||||||
unprocessed_text = self._coordinate_format_to_string(result_data)
|
unprocessed_text = self._coordinate_format_to_string(result_data)
|
||||||
|
|
||||||
if output_format == 'json':
|
if output_format == 'json':
|
||||||
result_dict = asdict(result_data)
|
result_dict = asdict(result_data)
|
||||||
output_string = json.dumps(result_dict, ensure_ascii=False)
|
output_string = json.dumps(result_dict, ensure_ascii=False)
|
||||||
log_message = self._post_process(unprocessed_text)
|
log_message = self._post_process(unprocessed_text, False)
|
||||||
else:
|
else:
|
||||||
result_data_text = unprocessed_text
|
result_data_text = unprocessed_text
|
||||||
else:
|
else:
|
||||||
@@ -792,9 +897,9 @@ class OutputResult:
|
|||||||
logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.")
|
logger.warning(f"Engine '{engine_instance.name}' does not support JSON output. Falling back to text.")
|
||||||
if filter_text:
|
if filter_text:
|
||||||
text_to_process = self.filtering(result_data_text)
|
text_to_process = self.filtering(result_data_text)
|
||||||
output_string = self._post_process(text_to_process)
|
output_string = self._post_process(text_to_process, True)
|
||||||
else:
|
else:
|
||||||
output_string = self._post_process(result_data_text)
|
output_string = self._post_process(result_data_text, False)
|
||||||
log_message = output_string
|
log_message = output_string
|
||||||
|
|
||||||
if verbosity != 0:
|
if verbosity != 0:
|
||||||
@@ -810,7 +915,6 @@ class OutputResult:
|
|||||||
if notify and config.get_general('notifications'):
|
if notify and config.get_general('notifications'):
|
||||||
notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency())
|
notifier.send(title='owocr', message='Text recognized: ' + log_message, urgency=get_notification_urgency())
|
||||||
|
|
||||||
# Write the final formatted string to the destination
|
|
||||||
write_to = config.get_general('write_to')
|
write_to = config.get_general('write_to')
|
||||||
if write_to == 'websocket':
|
if write_to == 'websocket':
|
||||||
websocket_server_thread.send_text(output_string)
|
websocket_server_thread.send_text(output_string)
|
||||||
@@ -932,6 +1036,7 @@ def run():
|
|||||||
config_engines = []
|
config_engines = []
|
||||||
engine_keys = []
|
engine_keys = []
|
||||||
default_engine = ''
|
default_engine = ''
|
||||||
|
engine_secondary = ''
|
||||||
|
|
||||||
if len(config.get_general('engines')) > 0:
|
if len(config.get_general('engines')) > 0:
|
||||||
for config_engine in config.get_general('engines').split(','):
|
for config_engine in config.get_general('engines').split(','):
|
||||||
@@ -955,12 +1060,15 @@ def run():
|
|||||||
engine_keys.append(engine_class.key)
|
engine_keys.append(engine_class.key)
|
||||||
if config.get_general('engine') == engine_class.name:
|
if config.get_general('engine') == engine_class.name:
|
||||||
default_engine = engine_class.key
|
default_engine = engine_class.key
|
||||||
|
if config.get_general('engine_secondary') == engine_class.name and engine_class.local and engine_class.coordinate_support:
|
||||||
|
engine_secondary = engine_class.key
|
||||||
|
|
||||||
if len(engine_keys) == 0:
|
if len(engine_keys) == 0:
|
||||||
msg = 'No engines available!'
|
msg = 'No engines available!'
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
global engine_index
|
global engine_index
|
||||||
|
global engine_index_2
|
||||||
global terminated
|
global terminated
|
||||||
global paused
|
global paused
|
||||||
global notifier
|
global notifier
|
||||||
@@ -987,6 +1095,7 @@ def run():
|
|||||||
init_filtering = False
|
init_filtering = False
|
||||||
auto_pause_handler = None
|
auto_pause_handler = None
|
||||||
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
|
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
|
||||||
|
engine_index_2 = engine_keys.index(engine_secondary) if engine_secondary != '' else -1
|
||||||
engine_color = config.get_general('engine_color')
|
engine_color = config.get_general('engine_color')
|
||||||
combo_pause = config.get_general('combo_pause')
|
combo_pause = config.get_general('combo_pause')
|
||||||
combo_engine_switch = config.get_general('combo_engine_switch')
|
combo_engine_switch = config.get_general('combo_engine_switch')
|
||||||
|
|||||||
Reference in New Issue
Block a user