From d69a1df6ecfba5e6d1c570c9e400b18762bea4b5 Mon Sep 17 00:00:00 2001
From: AuroraWright
Date: Sun, 4 Feb 2024 12:43:31 +0100
Subject: [PATCH] Improve screen reading filtering again x2

---
 owocr/run.py | 52 +++++++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 41 insertions(+), 11 deletions(-)

diff --git a/owocr/run.py b/owocr/run.py
index 3748124..8564082 100644
--- a/owocr/run.py
+++ b/owocr/run.py
@@ -13,7 +13,6 @@ import asyncio
 import websockets
 import queue
 import io
-import unicodedata
 
 from PIL import Image
 from PIL import UnidentifiedImageError
@@ -122,6 +121,43 @@ class WebsocketServerThread(threading.Thread):
         self.loop.close()
 
 
+class TextFiltering:
+    accurate_filtering = False
+
+    def __init__(self):
+        self.segmenter = Segmenter(language='ja', clean=True)
+        try:
+            from transformers import pipeline, AutoTokenizer
+            model_ckpt = 'papluca/xlm-roberta-base-language-detection'
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_ckpt,
+                use_fast = False
+            )
+            self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
+            self.accurate_filtering = True
+        except:
+            import langid
+            self.classify = langid.classify
+
+    def __call__(self, text, last_text):
+        orig_text = self.segmenter.segment(text)
+        new_blocks = [block for block in orig_text if block not in last_text]
+        final_blocks = []
+        if self.accurate_filtering:
+            detection_results = self.pipe(new_blocks, top_k=2, truncation=True)
+            for idx, block in enumerate(new_blocks):
+                if((detection_results[idx][0]['label'] == 'ja' and detection_results[idx][0]['score'] >= 0.85) or
+                   (detection_results[idx][1]['label'] == 'ja' and detection_results[idx][1]['score'] >= 0.85)):
+                    final_blocks.append(block)
+        else:
+            for block in new_blocks:
+                if self.classify(block)[0] == 'ja':
+                    final_blocks.append(block)
+
+        text = '\n'.join(final_blocks)
+        return text, orig_text
+
+
 def user_input_thread_run(engine_instances, engine_keys):
     def _terminate_handler(user_input):
         global terminated
@@ -254,12 +290,7 @@ def are_images_identical(img1, img2):
     return (img1.shape == img2.shape) and (img1 == img2).all()
 
 
-def is_japanese(text):
-    japanese_count = sum(1 for char in text if 'HIRAGANA' in unicodedata.name(char) or 'KATAKANA' in unicodedata.name(char) or 'CJK UNIFIED' in unicodedata.name(char))
-    return japanese_count / len(text) >= 0.7
-
-
-def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, segmenter):
+def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, filtering):
     t0 = time.time()
     res, text = engine_instance(img_or_path)
     t1 = time.time()
@@ -268,8 +299,7 @@ def process_and_write_results(engine_instance, img_or_path, write_to, enable_fil
     engine_color = config.get_general('engine_color')
     if res:
         if enable_filtering:
-            orig_text = segmenter.segment(text)
-            text = '\n'.join([block for block in orig_text if (block not in last_text and is_japanese(block))])
+            text, orig_text = filtering(text, last_text)
        text = post_process(text)
        logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
        if config.get_general('notifications'):
@@ -482,8 +512,8 @@ def run(read_from=None,
         global sct_params
         sct_params = {'top': coord_top, 'left': coord_left, 'width': coord_width, 'height': coord_height, 'mon': screen_capture_monitor}
 
-        segmenter = Segmenter(language="ja", clean=True)
+        filtering = TextFiltering()
 
         logger.opt(ansi=True).info(f"Reading with screen capture using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
     else:
         read_from = Path(read_from)
@@ -581,7 +611,7 @@ def run(read_from=None,
         if take_screenshot and screencapture_window_visible:
             sct_img = sct.grab(sct_params)
             img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
-            res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, segmenter)
+            res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, filtering)
             if res != '':
                 last_text = res
             delay = screen_capture_delay_secs
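
A minimal sketch (not part of the patch) of how the new TextFiltering class could be exercised on its own, assuming run.py's existing Segmenter import (pysbd) and that either transformers or langid is installed so one of the two detection paths loads; the sample string and variable names are illustrative only.

    # Hypothetical standalone use of the filtering added by this patch.
    from owocr.run import TextFiltering

    filtering = TextFiltering()        # first run may download the xlm-roberta checkpoint

    last_text = []                                       # blocks seen on the previous OCR pass
    ocr_output = '読み込み中です。Loading, please wait.'  # mixed-language OCR result

    text, orig_text = filtering(ocr_output, last_text)
    # `text` keeps only the newly seen blocks classified as Japanese ('ja'),
    # joined with newlines; the English sentence should normally be dropped.
    print(text)
    # `orig_text` is the full segmentation of this pass; passing it back in as
    # `last_text` on the next call is what suppresses repeated blocks.
    last_text = orig_text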