From 6b9b5c935181afbb3a808cd948ae0a95687e40f8 Mon Sep 17 00:00:00 2001 From: AuroraWright Date: Wed, 26 Jun 2024 07:15:03 +0200 Subject: [PATCH] Improve filtering --- owocr/run.py | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/owocr/run.py b/owocr/run.py index 872bd9a..db1b5c4 100644 --- a/owocr/run.py +++ b/owocr/run.py @@ -13,6 +13,7 @@ import websockets import socketserver import queue import io +import re from PIL import Image from PIL import UnidentifiedImageError @@ -254,6 +255,7 @@ class TextFiltering: def __init__(self): from pysbd import Segmenter self.segmenter = Segmenter(language='ja', clean=True) + self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') try: from transformers import pipeline, AutoTokenizer model_ckpt = 'papluca/xlm-roberta-base-language-detection' @@ -267,12 +269,27 @@ class TextFiltering: import langid self.classify = langid.classify - def __call__(self, text, last_text): + def __call__(self, text, last_result): orig_text = self.segmenter.segment(text) - if last_text[1] != engine_index: - new_blocks = orig_text + + orig_text_filtered = [] + for block in orig_text: + block_filtered = self.kana_kanji_regex.findall(block) + if block_filtered: + orig_text_filtered.append(''.join(block_filtered)) + else: + orig_text_filtered.append(None) + + if last_result[1] == engine_index: + last_text = last_result[0] else: - new_blocks = [block for block in orig_text if block not in last_text[0]] + last_text = [] + + new_blocks = [] + for idx, block in enumerate(orig_text): + if orig_text_filtered[idx] and (orig_text_filtered[idx] not in last_text): + new_blocks.append(block) + final_blocks = [] if self.accurate_filtering: detection_results = self.pipe(new_blocks, top_k=2, truncation=True) @@ -286,7 +303,7 @@ class TextFiltering: final_blocks.append(block) text = '\n'.join(final_blocks) - return text, orig_text + return text, orig_text_filtered def pause_handler(is_combo=True): @@ -456,17 +473,17 @@ def are_images_identical(img1, img2): return (img1.shape == img2.shape) and (img1 == img2).all() -def process_and_write_results(img_or_path, write_to, notifications, last_text, filtering): +def process_and_write_results(img_or_path, write_to, notifications, last_result, filtering): engine_instance = engine_instances[engine_index] t0 = time.time() res, text = engine_instance(img_or_path) t1 = time.time() - orig_text = None + orig_text = [] engine_color = config.get_general('engine_color') if res: if filtering: - text, orig_text = filtering(text, last_text) + text, orig_text = filtering(text, last_result) text = post_process(text) logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}: {text}') if notifications: @@ -657,7 +674,7 @@ def run(read_from=None, screencapture_mode = None screencapture_window_active = True screencapture_window_visible = True - last_text = ([], engine_index) + last_result = ([], engine_index) if screen_capture_coords == '': screencapture_mode = 0 elif len(screen_capture_coords.split(',')) == 4: @@ -909,9 +926,9 @@ def run(read_from=None, else: sct_img = sct.grab(sct_params) img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX') - res = process_and_write_results(img, write_to, notifications, last_text, filtering) + res = process_and_write_results(img, write_to, notifications, last_result, filtering) if res: - last_text = (res, engine_index) + last_result = (res, engine_index) delay = screen_capture_delay_secs else: delay = delay_secs