# owocr/owocr/run.py
# Listing metadata: Python, 552 lines, 21 KiB.

import sys
import time
import threading
from pathlib import Path
import fire
import numpy as np
import pyperclipfix
import mss
import pywinctl
import asyncio
import websockets
import queue
import io
from PIL import Image
from PIL import UnidentifiedImageError
from loguru import logger
from pynput import keyboard
from notifypy import Notify
import inspect
from owocr.ocr import *
from owocr.config import Config
# Win32 modules needed only by WindowsClipboardThread, which run() starts only
# when sys.platform == 'win32'. Elsewhere (or without pywin32) the ImportError is
# swallowed and these names simply stay undefined.
try:
    import win32gui
    import win32api
    import win32con
    import win32clipboard
    import ctypes
except ImportError:
    pass

# Module-wide Config instance; populated by init_config() on first use in run().
config = None
class WindowsClipboardThread(threading.Thread):
    """Daemon thread that watches the Windows clipboard for new bitmaps.

    It creates an invisible window, registers it with
    ``AddClipboardFormatListener`` and pumps its message loop. When a clipboard
    update carries a bitmap, and OCR is not paused, it sets the module-level
    ``clipboard_event``. Updates are accepted at most once per second.
    """

    def __init__(self):
        super().__init__(daemon=True)
        # Time of the last clipboard update that was forwarded to the main loop.
        self.last_update = time.time()

    def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int):
        """Window procedure for the hook window; always returns 0."""
        WM_CLIPBOARDUPDATE = 0x031D
        now = time.time()
        # Debounce: only updates more than a second after the last accepted one count.
        is_fresh_update = msg == WM_CLIPBOARDUPDATE and now - self.last_update > 1
        if is_fresh_update and not (paused or tmp_paused):
            if win32clipboard.IsClipboardFormatAvailable(win32con.CF_BITMAP):
                clipboard_event.set()
                self.last_update = now
        return 0

    def create_window(self):
        """Register the 'ClipboardHook' window class and create a window of it."""
        class_name = 'ClipboardHook'
        window_class = win32gui.WNDCLASS()
        window_class.lpfnWndProc = self.process_message
        window_class.lpszClassName = class_name
        window_class.hInstance = win32api.GetModuleHandle(None)
        atom = win32gui.RegisterClass(window_class)
        return win32gui.CreateWindow(atom, class_name, 0, 0, 0, 0, 0, 0, 0, window_class.hInstance, None)

    def run(self):
        hwnd = self.create_window()
        # Recorded so the main thread can post WM_QUIT here to end PumpMessages().
        self.thread_id = win32api.GetCurrentThreadId()
        ctypes.windll.user32.AddClipboardFormatListener(hwnd)
        win32gui.PumpMessages()
class WebsocketServerThread(threading.Thread):
    """Daemon thread running a websockets server on its own asyncio event loop.

    Recognized text is pushed to every connected client via ``send_text``. When
    ``read`` is true, messages received from clients are put on the
    module-level ``websocket_queue`` (unless OCR is paused) and acknowledged
    with ``'True'``; otherwise the sender gets ``'False'``.
    """

    def __init__(self, port, read):
        super().__init__(daemon=True)
        self.loop = asyncio.new_event_loop()
        self.port = port
        self.read = read
        self.clients = set()
        # Listening server object returned by awaiting websockets.serve();
        # set in run() and closed by stop_server().
        self.ws_server = None

    async def send_text_coroutine(self, text):
        """Send ``text`` to every connected client.

        Iterates over a snapshot of ``self.clients``: while we await ``send``,
        the handler of a disconnecting client removes itself from the set, and
        mutating a set during iteration raises RuntimeError. A client whose
        connection is already closed is skipped so the others still get the text.
        """
        for client in list(self.clients):
            try:
                await client.send(text)
            except websockets.exceptions.ConnectionClosed:
                pass

    async def server_handler(self, websocket):
        """Register a client for the lifetime of its connection and handle its messages."""
        self.clients.add(websocket)
        try:
            async for message in websocket:
                if self.read and not (paused or tmp_paused):
                    websocket_queue.put(message)
                    try:
                        await websocket.send('True')
                    except websockets.exceptions.ConnectionClosedOK:
                        pass
                else:
                    try:
                        await websocket.send('False')
                    except websockets.exceptions.ConnectionClosedOK:
                        pass
        finally:
            self.clients.remove(websocket)

    def send_text(self, text):
        """Thread-safe: schedule a broadcast on the server loop and return its concurrent Future."""
        return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop)

    def stop_server(self):
        """Thread-safe: close the listening server (if started) and stop the event loop."""
        if self.ws_server is not None:
            self.loop.call_soon_threadsafe(self.ws_server.close)
        self.loop.call_soon_threadsafe(self.loop.stop)

    def run(self):
        asyncio.set_event_loop(self.loop)
        start_server = websockets.serve(self.server_handler, '0.0.0.0', self.port, max_size=50000000)
        self.server = start_server
        # Awaiting serve() yields the listening server object. Keeping it avoids
        # relying on the legacy ``Serve.ws_server`` attribute, which newer
        # websockets releases may not provide.
        self.ws_server = self.loop.run_until_complete(start_server)
        self.loop.run_forever()
        # After stop(): let remaining tasks (closing handshakes, handlers) finish.
        pending = asyncio.all_tasks(loop=self.loop)
        if pending:
            self.loop.run_until_complete(asyncio.wait(pending))
        self.loop.close()
def user_input_thread_run(engine_instances, engine_keys, engine_color):
    """Read single keypresses from the terminal and act on them until told to stop.

    Keys: ``t``/``q`` terminate (sets the module-level ``terminated`` flag),
    ``p`` toggles pause, ``s`` cycles to the next engine, and any key found in
    ``engine_keys`` switches to that engine. On POSIX an empty read (stdin at
    EOF) also terminates.
    """

    def _pause_handler(user_input):
        global paused
        global just_unpaused
        if paused:
            logger.info('Unpaused!')
            just_unpaused = True
        else:
            logger.info('Paused!')
        paused = not paused

    def _engine_change_handler(user_input):
        global engine_index
        old_engine_index = engine_index

        if user_input.lower() == 's':
            if engine_index == len(engine_keys) - 1:
                engine_index = 0
            else:
                engine_index += 1
        elif user_input.lower() != '' and user_input.lower() in engine_keys:
            engine_index = engine_keys.index(user_input.lower())

        if engine_index != old_engine_index:
            logger.opt(ansi=True).info(f'Switched to <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>!')

    def _handle_key(user_input):
        """Dispatch one key; return True when the input loop should stop."""
        global terminated
        key = user_input.lower()
        # '' only happens when stdin hits EOF; stop instead of spinning forever.
        if key == '' or key in ('t', 'q'):
            logger.info('Terminated!')
            terminated = True
            return True
        if key == 'p':
            _pause_handler(user_input)
        else:
            _engine_change_handler(user_input)
        return False

    if sys.platform == 'win32':
        import msvcrt
        while True:
            user_input_bytes = msvcrt.getch()
            # Function/arrow keys arrive as a 0x00 or 0xE0 prefix followed by a
            # scan-code byte (e.g. Page Down is b'\xe0' then b'Q'). Consume the
            # pair so the scan code is not mistaken for a command letter.
            if user_input_bytes in (b'\x00', b'\xe0'):
                msvcrt.getch()
                continue
            try:
                user_input = user_input_bytes.decode()
            except UnicodeDecodeError:
                continue
            if _handle_key(user_input):
                break
    else:
        import tty, termios
        fd = sys.stdin.fileno()
        old_settings = termios.tcgetattr(fd)
        try:
            # cbreak mode delivers keys one at a time without waiting for Enter.
            tty.setcbreak(fd)
            while True:
                user_input = sys.stdin.read(1)
                if _handle_key(user_input):
                    break
        finally:
            termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
def on_key_press(key):
    """pynput press callback: holding Ctrl or Cmd temporarily pauses OCR.

    Only the first modifier pressed is remembered, so only releasing that same
    key (handled by on_key_release) lifts the temporary pause.
    """
    global tmp_paused
    global first_pressed
    if first_pressed is None and key in (keyboard.Key.cmd_l, keyboard.Key.cmd_r,
                                         keyboard.Key.ctrl_l, keyboard.Key.ctrl_r):
        first_pressed = key
        tmp_paused = True
def on_key_release(key):
    """pynput release callback: releasing the remembered modifier resumes OCR."""
    global tmp_paused, just_unpaused, first_pressed
    if not key == first_pressed:
        return
    # just_unpaused lets the non-Windows clipboard polling skip its first read after resuming.
    tmp_paused, just_unpaused, first_pressed = False, True, None
def on_window_activated(active):
    """pywinctl watchdog callback: record whether the captured window is active."""
    global screencapture_window_active
    # The screen-capture loop only grabs frames while this flag is truthy.
    screencapture_window_active = active
def on_window_resized(size):
    """pywinctl watchdog callback: keep the capture region's size equal to the window's."""
    # The shared dict is mutated in place, so no ``global`` rebinding is required.
    sct_params.update(width=size[0], height=size[1])
def on_window_moved(pos):
    """pywinctl watchdog callback: keep the capture region's origin equal to the window's."""
    # The shared dict is mutated in place, so no ``global`` rebinding is required.
    sct_params.update(left=pos[0], top=pos[1])
def are_images_identical(img1, img2):
    """Return True when both images contain exactly the same pixel data.

    Accepts anything ``numpy.asarray`` can convert (PIL images, numpy arrays)
    as well as ``None``. Two ``None`` values are identical; ``None`` never
    matches an actual image. Images of different shape are never identical.
    """
    # Explicit identity checks: ``None in (img1, img2)`` would call ``==`` on
    # the images themselves, which is ambiguous for numpy arrays.
    if img1 is None or img2 is None:
        return img1 is img2
    return bool(np.array_equal(np.asarray(img1), np.asarray(img2)))
def process_and_write_results(engine_instance, engine_color, img_or_path, write_to, notifications):
    """Run OCR on one image and deliver the recognized text.

    :param engine_instance: callable OCR engine; its ``readable_name`` is logged.
    :param engine_color: loguru color tag used when logging the engine name.
    :param img_or_path: image (or path) passed unchanged to the engine.
    :param write_to: "websocket", "clipboard", or a path to a ``.txt`` file to append to.
    :param notifications: when truthy, also show a desktop notification.
    :raises ValueError: if ``write_to`` is a path without a ``.txt`` suffix. This is
        checked before OCR runs, so a misconfigured destination costs no OCR call.
    """
    output_path = None
    if write_to not in ('websocket', 'clipboard'):
        output_path = Path(write_to)
        if output_path.suffix.lower() != '.txt':
            raise ValueError('write_to must be either "clipboard", "websocket" or a path to a text file')

    # perf_counter is monotonic and meant for measuring elapsed time.
    t0 = time.perf_counter()
    text = engine_instance(img_or_path)
    t1 = time.perf_counter()

    logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
    if notifications:
        notification = Notify()
        notification.application_name = 'owocr'
        notification.title = 'Text recognized:'
        notification.message = text
        notification.send(block=False)

    if write_to == 'websocket':
        websocket_server_thread.send_text(text)
    elif write_to == 'clipboard':
        pyperclipfix.copy(text)
    else:
        with output_path.open('a', encoding='utf-8') as f:
            f.write(text + '\n')
def get_path_key(path):
    """Return a ``(path, mtime)`` key that changes whenever the file is rewritten.

    ``lstat`` is used, so a symlink is keyed by the link itself rather than its target.
    """
    modified = path.lstat().st_mtime
    return (path, modified)
def init_config():
    """Create the module-wide ``Config`` instance and store it in ``config``."""
    global config
    loaded = Config()
    config = loaded
def run(read_from='clipboard',
        write_to='clipboard',
        engine='',
        pause_at_startup=False,
        ignore_flag=False,
        delete_images=False
        ):
    """
    Japanese OCR client

    Run OCR in the background, waiting for new images to appear either in system clipboard or a directory, or to be sent via a websocket.
    Recognized texts can be either saved to system clipboard, appended to a text file or sent via a websocket.

    :param read_from: Specifies where to read input images from. Can be either "clipboard", "websocket", "screencapture", or a path to a directory.
    :param write_to: Specifies where to save recognized texts to. Can be either "clipboard", "websocket", or a path to a text file.
    :param engine: OCR engine to use. Available: "mangaocr", "glens", "gvision", "avision", "azure", "winrtocr", "easyocr", "paddleocr".
    :param pause_at_startup: Pause at startup.
    :param ignore_flag: Process flagged clipboard images (images that are copied to the clipboard with the *ocr_ignore* string).
    :param delete_images: Delete image files after processing when reading from a directory.

    The polling interval (delay_secs), websocket port, notifications and screen capture settings are read from the config file.
    """
    # State shared with the keyboard/window callbacks and helper threads defined
    # at module level. A global declaration applies to the whole function.
    global engine_index
    global terminated
    global paused
    global tmp_paused
    global just_unpaused
    global first_pressed
    global websocket_server_thread
    global websocket_queue
    global clipboard_event
    global screencapture_window_active
    global sct_params

    engine_instances = []
    config_engines = []
    engine_keys = []
    default_engine = ''
    logger_format = '<green>{time:HH:mm:ss.SSS}</green> | <level>{message}</level>'
    engine_color = 'cyan'
    delay_secs = 0.5
    websocket_port = 7331
    notifications = False
    screen_capture_monitor = 1
    screen_capture_coords = ''
    screen_capture_delay_secs = 3

    if not config:
        init_config()
    if config.has_config:
        if config.get_general('engines'):
            for config_engine in config.get_general('engines').split(','):
                config_engines.append(config_engine.strip().lower())
        if config.get_general('logger_format'):
            logger_format = config.get_general('logger_format')
        if config.get_general('engine_color'):
            engine_color = config.get_general('engine_color')
        if config.get_general('delay_secs'):
            delay_secs = config.get_general('delay_secs')
        if config.get_general('websocket_port'):
            websocket_port = config.get_general('websocket_port')
        if config.get_general('notifications'):
            notifications = config.get_general('notifications')
        if config.get_general('screen_capture_monitor'):
            screen_capture_monitor = config.get_general('screen_capture_monitor')
        if config.get_general('screen_capture_delay_secs'):
            screen_capture_delay_secs = config.get_general('screen_capture_delay_secs')
        if config.get_general('screen_capture_coords'):
            screen_capture_coords = config.get_general('screen_capture_coords')

    logger.configure(handlers=[{'sink': sys.stderr, 'format': logger_format}])

    if config.has_config:
        logger.info('Parsed config file')
    else:
        logger.warning('No config file, defaults will be used')

    # Candidate engines: classes visible in this module whose __module__ is this
    # package's ``ocr`` module (presumably brought in by ``from owocr.ocr import *``).
    for _, engine_class in sorted(inspect.getmembers(sys.modules[__name__], lambda x: hasattr(x, '__module__') and __package__ + '.ocr' in x.__module__ and inspect.isclass(x))):
        if len(config_engines) == 0 or engine_class.name in config_engines:
            engine_config = config.get_engine(engine_class.name)
            if engine_config is None:
                engine_instance = engine_class()
            else:
                engine_instance = engine_class(engine_config)

            if engine_instance.available:
                engine_instances.append(engine_instance)
                engine_keys.append(engine_class.key)
                if engine == engine_class.name:
                    default_engine = engine_class.key

    if not engine_keys:
        msg = 'No engines available!'
        raise NotImplementedError(msg)

    terminated = False
    paused = pause_at_startup
    just_unpaused = True
    tmp_paused = False
    first_pressed = None
    engine_index = engine_keys.index(default_engine) if default_engine != '' else 0

    user_input_thread = threading.Thread(target=user_input_thread_run, args=(engine_instances, engine_keys, engine_color), daemon=True)
    user_input_thread.start()

    # Holding Ctrl/Cmd pauses OCR temporarily (see on_key_press / on_key_release).
    tmp_paused_listener = keyboard.Listener(
        on_press=on_key_press,
        on_release=on_key_release)
    tmp_paused_listener.start()

    if read_from == 'websocket' or write_to == 'websocket':
        websocket_server_thread = WebsocketServerThread(websocket_port, read_from == 'websocket')
        websocket_server_thread.start()

    if read_from == 'websocket':
        websocket_queue = queue.Queue()
        logger.opt(ansi=True).info(f"Reading from websocket using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
    elif read_from == 'clipboard':
        from PIL import ImageGrab
        mac_clipboard_polling = False
        windows_clipboard_polling = False
        generic_clipboard_polling = False
        img = None

        logger.opt(ansi=True).info(f"Reading from clipboard using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")

        if sys.platform == 'darwin':
            from AppKit import NSPasteboard, NSPasteboardTypePNG, NSPasteboardTypeTIFF
            pasteboard = NSPasteboard.generalPasteboard()
            count = pasteboard.changeCount()
            mac_clipboard_polling = True
        elif sys.platform == 'win32':
            clipboard_event = threading.Event()
            windows_clipboard_thread = WindowsClipboardThread()
            windows_clipboard_thread.start()
            windows_clipboard_polling = True
        else:
            generic_clipboard_polling = True
    elif read_from == 'screencapture':
        screencapture_window_mode = False
        screencapture_window_active = True
        with mss.mss() as sct:
            mon = sct.monitors
        if len(mon) <= screen_capture_monitor:
            msg = '"screen_capture_monitor" has to be a valid monitor number!'
            raise ValueError(msg)
        if screen_capture_coords == '':
            coord_left = mon[screen_capture_monitor]["left"]
            coord_top = mon[screen_capture_monitor]["top"]
            coord_width = mon[screen_capture_monitor]["width"]
            coord_height = mon[screen_capture_monitor]["height"]
        elif len(screen_capture_coords.split(',')) == 4:
            x, y, coord_width, coord_height = [int(c.strip()) for c in screen_capture_coords.split(',')]
            coord_left = mon[screen_capture_monitor]["left"] + x
            coord_top = mon[screen_capture_monitor]["top"] + y
        else:
            window_titles = pywinctl.getAllTitles()
            if screen_capture_coords in window_titles:
                window_title = screen_capture_coords
            else:
                # Otherwise use the first title containing the requested text,
                # or None when nothing matches (never an unrelated window).
                window_title = next((title for title in window_titles if screen_capture_coords in title), None)

            windows = pywinctl.getWindowsWithTitle(window_title) if window_title is not None else []
            if len(windows) == 0:
                msg = '"screen_capture_coords" has to be empty (for the whole screen), a valid set of coordinates, or a valid window name!'
                raise ValueError(msg)

            screencapture_window_mode = True
            target_window = windows[0]
            coord_top = target_window.top
            coord_left = target_window.left
            coord_width = target_window.width
            coord_height = target_window.height
            screencapture_window_active = target_window.isActive
            # Keep sct_params / screencapture_window_active in sync with the window.
            target_window.watchdog.start(isActiveCB=on_window_activated, resizedCB=on_window_resized, movedCB=on_window_moved)
            target_window.watchdog.setTryToFind(True)

        sct_params = {'top': coord_top, 'left': coord_left, 'width': coord_width, 'height': coord_height, 'mon': screen_capture_monitor}

        logger.opt(ansi=True).info(f"Reading with screen capture using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
    else:
        allowed_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.webp')
        read_from = Path(read_from)
        if not read_from.is_dir():
            raise ValueError('read_from must be either "clipboard", "websocket", "screencapture" or a path to a directory')

        logger.opt(ansi=True).info(f"Reading from directory {read_from} using <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>{' (paused)' if paused else ''}")
        # Images already present at startup are recorded as seen and not OCR'd.
        old_paths = set()
        for path in read_from.iterdir():
            if path.suffix.lower() in allowed_extensions:
                old_paths.add(get_path_key(path))

    while True:
        if terminated:
            if read_from == 'websocket' or write_to == 'websocket':
                websocket_server_thread.stop_server()
                websocket_server_thread.join()
            if read_from == 'clipboard' and windows_clipboard_polling:
                # WM_QUIT ends PumpMessages() in the clipboard hook thread.
                win32api.PostThreadMessage(windows_clipboard_thread.thread_id, win32con.WM_QUIT, 0, 0)
                windows_clipboard_thread.join()
            if read_from == 'screencapture' and screencapture_window_mode:
                target_window.watchdog.stop()
            user_input_thread.join()
            tmp_paused_listener.stop()
            break

        if read_from == 'websocket':
            # Drain the queue; the blocking get doubles as this loop's delay.
            while True:
                try:
                    item = websocket_queue.get(timeout=delay_secs)
                except queue.Empty:
                    break
                else:
                    if not paused and not tmp_paused:
                        try:
                            img = Image.open(io.BytesIO(item))
                        except (TypeError, UnidentifiedImageError, OSError) as e:
                            # Text frames (TypeError) or non-image bytes must not kill the loop.
                            logger.warning(f'Error while reading image from websocket: {e}')
                        else:
                            process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
        elif read_from == 'clipboard':
            if windows_clipboard_polling:
                clipboard_changed = clipboard_event.wait(delay_secs)
                if clipboard_changed:
                    clipboard_event.clear()
            elif mac_clipboard_polling:
                # Reset every tick: while paused this must stay False, otherwise a
                # stale True from an earlier tick would re-OCR the clipboard.
                clipboard_changed = False
                if not (paused or tmp_paused):
                    old_count = count
                    count = pasteboard.changeCount()
                    clipboard_changed = not just_unpaused and count != old_count and any(x in pasteboard.types() for x in [NSPasteboardTypePNG, NSPasteboardTypeTIFF])
            else:
                clipboard_changed = not (paused or tmp_paused)

            if clipboard_changed:
                old_img = img
                try:
                    img = ImageGrab.grabclipboard()
                except Exception:
                    # Best effort: an unreadable clipboard is retried on the next tick.
                    pass
                else:
                    if (windows_clipboard_polling or (not just_unpaused)) and \
                            isinstance(img, Image.Image) and \
                            (ignore_flag or pyperclipfix.paste() != '*ocr_ignore*') and \
                            ((not generic_clipboard_polling) or (not are_images_identical(img, old_img))):
                        process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)

            just_unpaused = False

            if not windows_clipboard_polling:
                time.sleep(delay_secs)
        elif read_from == 'screencapture':
            if screencapture_window_active and not paused and not tmp_paused:
                with mss.mss() as sct:
                    sct_img = sct.grab(sct_params)
                img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")
                process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
                time.sleep(screen_capture_delay_secs)
            else:
                time.sleep(delay_secs)
        else:
            for path in read_from.iterdir():
                if path.suffix.lower() in allowed_extensions:
                    path_key = get_path_key(path)
                    if path_key not in old_paths:
                        old_paths.add(path_key)

                        if not paused and not tmp_paused:
                            try:
                                img = Image.open(path)
                                img.load()
                            except (UnidentifiedImageError, OSError) as e:
                                logger.warning(f'Error while reading file {path}: {e}')
                            else:
                                process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
                                img.close()
                                if delete_images:
                                    path.unlink()

            time.sleep(delay_secs)
if __name__ == '__main__':
    # python-fire turns run()'s keyword parameters into command-line options.
    fire.Fire(component=run)