Improve screen reading filtering again x2
This commit is contained in:
52
owocr/run.py
52
owocr/run.py
@@ -13,7 +13,6 @@ import asyncio
|
|||||||
import websockets
|
import websockets
|
||||||
import queue
|
import queue
|
||||||
import io
|
import io
|
||||||
import unicodedata
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from PIL import UnidentifiedImageError
|
from PIL import UnidentifiedImageError
|
||||||
@@ -122,6 +121,43 @@ class WebsocketServerThread(threading.Thread):
|
|||||||
self.loop.close()
|
self.loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TextFiltering:
|
||||||
|
accurate_filtering = False
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.segmenter = Segmenter(language='ja', clean=True)
|
||||||
|
try:
|
||||||
|
from transformers import pipeline, AutoTokenizer
|
||||||
|
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_ckpt,
|
||||||
|
use_fast = False
|
||||||
|
)
|
||||||
|
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
|
||||||
|
self.accurate_filtering = True
|
||||||
|
except:
|
||||||
|
import langid
|
||||||
|
self.classify = langid.classify
|
||||||
|
|
||||||
|
def __call__(self, text, last_text):
|
||||||
|
orig_text = self.segmenter.segment(text)
|
||||||
|
new_blocks = [block for block in orig_text if block not in last_text]
|
||||||
|
final_blocks = []
|
||||||
|
if self.accurate_filtering:
|
||||||
|
detection_results = self.pipe(new_blocks, top_k=2, truncation=True)
|
||||||
|
for idx, block in enumerate(new_blocks):
|
||||||
|
if((detection_results[idx][0]['label'] == 'ja' and detection_results[idx][0]['score'] >= 0.85) or
|
||||||
|
(detection_results[idx][1]['label'] == 'ja' and detection_results[idx][1]['score'] >= 0.85)):
|
||||||
|
final_blocks.append(block)
|
||||||
|
else:
|
||||||
|
for block in new_blocks:
|
||||||
|
if self.classify(block)[0] == 'ja':
|
||||||
|
final_blocks.append(block)
|
||||||
|
|
||||||
|
text = '\n'.join(final_blocks)
|
||||||
|
return text, orig_text
|
||||||
|
|
||||||
|
|
||||||
def user_input_thread_run(engine_instances, engine_keys):
|
def user_input_thread_run(engine_instances, engine_keys):
|
||||||
def _terminate_handler(user_input):
|
def _terminate_handler(user_input):
|
||||||
global terminated
|
global terminated
|
||||||
@@ -254,12 +290,7 @@ def are_images_identical(img1, img2):
|
|||||||
return (img1.shape == img2.shape) and (img1 == img2).all()
|
return (img1.shape == img2.shape) and (img1 == img2).all()
|
||||||
|
|
||||||
|
|
||||||
def is_japanese(text):
|
def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, filtering):
|
||||||
japanese_count = sum(1 for char in text if 'HIRAGANA' in unicodedata.name(char) or 'KATAKANA' in unicodedata.name(char) or 'CJK UNIFIED' in unicodedata.name(char))
|
|
||||||
return japanese_count / len(text) >= 0.7
|
|
||||||
|
|
||||||
|
|
||||||
def process_and_write_results(engine_instance, img_or_path, write_to, enable_filtering, last_text, segmenter):
|
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
res, text = engine_instance(img_or_path)
|
res, text = engine_instance(img_or_path)
|
||||||
t1 = time.time()
|
t1 = time.time()
|
||||||
@@ -268,8 +299,7 @@ def process_and_write_results(engine_instance, img_or_path, write_to, enable_fil
|
|||||||
engine_color = config.get_general('engine_color')
|
engine_color = config.get_general('engine_color')
|
||||||
if res:
|
if res:
|
||||||
if enable_filtering:
|
if enable_filtering:
|
||||||
orig_text = segmenter.segment(text)
|
text, orig_text = filtering(text, last_text)
|
||||||
text = '\n'.join([block for block in orig_text if (block not in last_text and is_japanese(block))])
|
|
||||||
text = post_process(text)
|
text = post_process(text)
|
||||||
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
|
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
|
||||||
if config.get_general('notifications'):
|
if config.get_general('notifications'):
|
||||||
@@ -482,8 +512,8 @@ def run(read_from=None,
|
|||||||
|
|
||||||
global sct_params
|
global sct_params
|
||||||
sct_params = {'top': coord_top, 'left': coord_left, 'width': coord_width, 'height': coord_height, 'mon': screen_capture_monitor}
|
sct_params = {'top': coord_top, 'left': coord_left, 'width': coord_width, 'height': coord_height, 'mon': screen_capture_monitor}
|
||||||
segmenter = Segmenter(language="ja", clean=True)
|
|
||||||
|
|
||||||
|
filtering = TextFiltering()
|
||||||
logger.opt(ansi=True).info(f"Reading with screen capture using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
|
logger.opt(ansi=True).info(f"Reading with screen capture using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
|
||||||
else:
|
else:
|
||||||
read_from = Path(read_from)
|
read_from = Path(read_from)
|
||||||
@@ -581,7 +611,7 @@ def run(read_from=None,
|
|||||||
if take_screenshot and screencapture_window_visible:
|
if take_screenshot and screencapture_window_visible:
|
||||||
sct_img = sct.grab(sct_params)
|
sct_img = sct.grab(sct_params)
|
||||||
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
|
img = Image.frombytes('RGB', sct_img.size, sct_img.bgra, 'raw', 'BGRX')
|
||||||
res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, segmenter)
|
res = process_and_write_results(engine_instances[engine_index], img, write_to, True, last_text, filtering)
|
||||||
if res != '':
|
if res != '':
|
||||||
last_text = res
|
last_text = res
|
||||||
delay = screen_capture_delay_secs
|
delay = screen_capture_delay_secs
|
||||||
|
|||||||
Reference in New Issue
Block a user