Improve screen reading filtering again x2

AuroraWright
2024-02-04 12:43:31 +01:00
parent f324dfc5a8
commit d69a1df6ec

@@ -13,7 +13,6 @@ import asyncio
 import websockets
 import queue
 import io
-import unicodedata
 from PIL import Image
 from PIL import UnidentifiedImageError
@@ -122,6 +121,43 @@ class WebsocketServerThread(threading.Thread):
         self.loop.close()
 
+
+class TextFiltering:
+    accurate_filtering = False
+
+    def __init__(self):
+        self.segmenter = Segmenter(language='ja', clean=True)
+        try:
+            from transformers import pipeline, AutoTokenizer
+            model_ckpt = 'papluca/xlm-roberta-base-language-detection'
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_ckpt,
+                use_fast = False
+            )
+            self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
+            self.accurate_filtering = True
+        except:
+            import langid
+            self.classify = langid.classify
+
+    def __call__(self, text, last_text):
+        orig_text = self.segmenter.segment(text)
+        new_blocks = [block for block in orig_text if block not in last_text]
+        final_blocks = []
+        if self.accurate_filtering:
+            detection_results = self.pipe(new_blocks, top_k=2, truncation=True)
+            for idx, block in enumerate(new_blocks):
+                if ((detection_results[idx][0]['label'] == 'ja' and detection_results[idx][0]['score'] >= 0.85) or
+                        (detection_results[idx][1]['label'] == 'ja' and detection_results[idx][1]['score'] >= 0.85)):
+                    final_blocks.append(block)
+        else:
+            for block in new_blocks:
+                if self.classify(block)[0] == 'ja':
+                    final_blocks.append(block)
+
+        text = '\n'.join(final_blocks)
+        return text, orig_text
+
+
 def user_input_thread_run(engine_instances, engine_keys):
     def _terminate_handler(user_input):
         global terminated
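
Note on the hunk above: the detection_results[idx][0] / [idx][1] indexing relies on the Hugging Face text-classification pipeline returning, for each input string, a list of its top_k labels sorted by score. A minimal standalone sketch of that call, not part of the commit (the sample strings are made up):

    from transformers import pipeline, AutoTokenizer

    model_ckpt = 'papluca/xlm-roberta-base-language-detection'
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt, use_fast=False)
    pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)

    # With a list input and top_k=2, the pipeline returns one list per string,
    # each holding the two most probable language labels, e.g.
    # [[{'label': 'ja', 'score': 0.99}, {'label': 'zh', 'score': 0.004}], ...]
    blocks = ['今日はいい天気ですね', 'mostly English OCR noise']
    results = pipe(blocks, top_k=2, truncation=True)
    for block, candidates in zip(blocks, results):
        is_ja = any(c['label'] == 'ja' and c['score'] >= 0.85 for c in candidates)
        print(block, is_ja)

If transformers is not installed, the try/except above falls back to langid.classify, which returns a (language, score) tuple, hence the [0] == 'ja' check.
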
@@ -254,12 +290,7 @@ def are_images_identical(img1, img2):
     return (img1.shape == img2.shape) and (img1 == img2).all()
 
 
-def is_japanese(text):
-    japanese_count = sum(1 for char in text if 'HIRAGANA' in unicodedata.name(char) or 'KATAKANA' in unicodedata.name(char) or 'CJK UNIFIED' in unicodedata.name(char))
-    return japanese_count / len(text) >= 0.7
-
-
-def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, segmenter):
+def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, filtering):
     t0 = time.time()
     res, text = engine_instance(img_or_path)
     t1 = time.time()
@@ -268,8 +299,7 @@ def process_and_write_results(engine_instance, img_or_path, write_to, enable_fil
engine_color = config.get_general('engine_color') engine_color = config.get_general('engine_color')
if res: if res:
if enable_filtering: if enable_filtering:
orig_text = segmenter.segment(text) text, orig_text = filtering(text, last_text)
text = '\n'.join([block for block in orig_text if (block not in last_text and is_japanese(block))])
text = post_process(text) text = post_process(text)
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}') logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
if config.get_general('notifications'): if config.get_general('notifications'):
@@ -482,8 +512,8 @@ def run(read_from=None,
         global sct_params
         sct_params = {'top': coord_top, 'left': coord_left, 'width': coord_width, 'height': coord_height, 'mon': screen_capture_monitor}
 
-        segmenter = Segmenter(language="ja", clean=True)
+        filtering = TextFiltering()
         logger.opt(ansi=True).info(f"Reading with screen capture using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
     else:
         read_from = Path(read_from)
@@ -581,7 +611,7 @@ def run(read_from=None,
             if take_screenshot and screencapture_window_visible:
                 sct_img = sct.grab(sct_params)
                 img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
-                res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, segmenter)
+                res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, filtering)
                 if res != '':
                     last_text = res
                 delay = screen_capture_delay_secs
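
Taken together, the commit replaces the old Segmenter + is_japanese() pair with a single TextFiltering callable that segments the OCR output, drops blocks already present in last_text, and keeps only blocks classified as Japanese. A hedged usage sketch of the new call contract, outside the capture loop (the sample strings and the empty last_text are illustrative only):

    filtering = TextFiltering()

    last_text = ''                  # whatever was kept from the previous capture
    ocr_text = '今日はいい天気ですね。PRESS START'

    filtered, segmented = filtering(ocr_text, last_text)
    print(filtered)    # newline-joined segments classified as Japanese and not already in last_text
    print(segmented)   # full segmentation of the raw OCR text, before any filtering

When transformers is unavailable, the same call silently uses the lighter langid fallback, so the caller's contract does not change.
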