Improve filtering

This commit is contained in:
AuroraWright
2024-06-26 07:15:03 +02:00
parent 7a4b9a73b9
commit 6b9b5c9351

View File

@@ -13,6 +13,7 @@ import websockets
import socketserver import socketserver
import queue import queue
import io import io
import re
from PIL import Image from PIL import Image
from PIL import UnidentifiedImageError from PIL import UnidentifiedImageError
@@ -254,6 +255,7 @@ class TextFiltering:
def __init__(self): def __init__(self):
from pysbd import Segmenter from pysbd import Segmenter
self.segmenter = Segmenter(language='ja', clean=True) self.segmenter = Segmenter(language='ja', clean=True)
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
try: try:
from transformers import pipeline, AutoTokenizer from transformers import pipeline, AutoTokenizer
model_ckpt = 'papluca/xlm-roberta-base-language-detection' model_ckpt = 'papluca/xlm-roberta-base-language-detection'
@@ -267,12 +269,27 @@ class TextFiltering:
import langid import langid
self.classify = langid.classify self.classify = langid.classify
def __call__(self, text, last_text): def __call__(self, text, last_result):
orig_text = self.segmenter.segment(text) orig_text = self.segmenter.segment(text)
if last_text[1] != engine_index:
new_blocks = orig_text orig_text_filtered = []
for block in orig_text:
block_filtered = self.kana_kanji_regex.findall(block)
if block_filtered:
orig_text_filtered.append(''.join(block_filtered))
else:
orig_text_filtered.append(None)
if last_result[1] == engine_index:
last_text = last_result[0]
else: else:
new_blocks = [block for block in orig_text if block not in last_text[0]] last_text = []
new_blocks = []
for idx, block in enumerate(orig_text):
if orig_text_filtered[idx] and (orig_text_filtered[idx] not in last_text):
new_blocks.append(block)
final_blocks = [] final_blocks = []
if self.accurate_filtering: if self.accurate_filtering:
detection_results = self.pipe(new_blocks, top_k=2, truncation=True) detection_results = self.pipe(new_blocks, top_k=2, truncation=True)
@@ -286,7 +303,7 @@ class TextFiltering:
final_blocks.append(block) final_blocks.append(block)
text = '\n'.join(final_blocks) text = '\n'.join(final_blocks)
return text, orig_text return text, orig_text_filtered
def pause_handler(is_combo=True): def pause_handler(is_combo=True):
@@ -456,17 +473,17 @@ def are_images_identical(img1, img2):
return (img1.shape == img2.shape) and (img1 == img2).all() return (img1.shape == img2.shape) and (img1 == img2).all()
def process_and_write_results(img_or_path, write_to, notifications, last_text, filtering): def process_and_write_results(img_or_path, write_to, notifications, last_result, filtering):
engine_instance = engine_instances[engine_index] engine_instance = engine_instances[engine_index]
t0 = time.time() t0 = time.time()
res, text = engine_instance(img_or_path) res, text = engine_instance(img_or_path)
t1 = time.time() t1 = time.time()
orig_text = None orig_text = []
engine_color = config.get_general('engine_color') engine_color = config.get_general('engine_color')
if res: if res:
if filtering: if filtering:
text, orig_text = filtering(text, last_text) text, orig_text = filtering(text, last_result)
text = post_process(text) text = post_process(text)
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}') logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
if notifications: if notifications:
@@ -657,7 +674,7 @@ def run(read_from=None,
screencapture_mode = None screencapture_mode = None
screencapture_window_active = True screencapture_window_active = True
screencapture_window_visible = True screencapture_window_visible = True
last_text = ([], engine_index) last_result = ([], engine_index)
if screen_capture_coords == '': if screen_capture_coords == '':
screencapture_mode = 0 screencapture_mode = 0
elif len(screen_capture_coords.split(',')) == 4: elif len(screen_capture_coords.split(',')) == 4:
@@ -909,9 +926,9 @@ def run(read_from=None,
else: else:
sct_img = sct.grab(sct_params) sct_img = sct.grab(sct_params)
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX') img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
res = process_and_write_results(img, write_to, notifications, last_text, filtering) res = process_and_write_results(img, write_to, notifications, last_result, filtering)
if res: if res:
last_text = (res, engine_index) last_result = (res, engine_index)
delay = screen_capture_delay_secs delay = screen_capture_delay_secs
else: else:
delay = delay_secs delay = delay_secs