Improve filtering

This commit is contained in:
AuroraWright
2024-06-26 07:15:03 +02:00
parent 7a4b9a73b9
commit 6b9b5c9351

View File

@@ -13,6 +13,7 @@ import websockets
import socketserver
import queue
import io
import re
from PIL import Image
from PIL import UnidentifiedImageError
@@ -254,6 +255,7 @@ class TextFiltering:
def __init__(self):
from pysbd import Segmenter
self.segmenter = Segmenter(language='ja', clean=True)
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
try:
from transformers import pipeline, AutoTokenizer
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
@@ -267,12 +269,27 @@ class TextFiltering:
import langid
self.classify = langid.classify
def __call__(self, text, last_text):
def __call__(self, text, last_result):
orig_text = self.segmenter.segment(text)
if last_text[1] != engine_index:
new_blocks = orig_text
orig_text_filtered = []
for block in orig_text:
block_filtered = self.kana_kanji_regex.findall(block)
if block_filtered:
orig_text_filtered.append(''.join(block_filtered))
else:
orig_text_filtered.append(None)
if last_result[1] == engine_index:
last_text = last_result[0]
else:
new_blocks = [block for block in orig_text if block not in last_text[0]]
last_text = []
new_blocks = []
for idx, block in enumerate(orig_text):
if orig_text_filtered[idx] and (orig_text_filtered[idx] not in last_text):
new_blocks.append(block)
final_blocks = []
if self.accurate_filtering:
detection_results = self.pipe(new_blocks, top_k=2, truncation=True)
@@ -286,7 +303,7 @@ class TextFiltering:
final_blocks.append(block)
text = '\n'.join(final_blocks)
return text, orig_text
return text, orig_text_filtered
def pause_handler(is_combo=True):
@@ -456,17 +473,17 @@ def are_images_identical(img1, img2):
return (img1.shape == img2.shape) and (img1 == img2).all()
def process_and_write_results(img_or_path, write_to, notifications, last_text, filtering):
def process_and_write_results(img_or_path, write_to, notifications, last_result, filtering):
engine_instance = engine_instances[engine_index]
t0 = time.time()
res, text = engine_instance(img_or_path)
t1 = time.time()
orig_text = None
orig_text = []
engine_color = config.get_general('engine_color')
if res:
if filtering:
text, orig_text = filtering(text, last_text)
text, orig_text = filtering(text, last_result)
text = post_process(text)
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
if notifications:
@@ -657,7 +674,7 @@ def run(read_from=None,
screencapture_mode = None
screencapture_window_active = True
screencapture_window_visible = True
last_text = ([], engine_index)
last_result = ([], engine_index)
if screen_capture_coords == '':
screencapture_mode = 0
elif len(screen_capture_coords.split(',')) == 4:
@@ -909,9 +926,9 @@ def run(read_from=None,
else:
sct_img = sct.grab(sct_params)
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
res = process_and_write_results(img, write_to, notifications, last_text, filtering)
res = process_and_write_results(img, write_to, notifications, last_result, filtering)
if res:
last_text = (res, engine_index)
last_result = (res, engine_index)
delay = screen_capture_delay_secs
else:
delay = delay_secs