More refactoring

This commit is contained in:
AuroraWright
2024-01-28 10:31:34 +01:00
parent 7c8cf92745
commit b7b87beaa0
3 changed files with 70 additions and 83 deletions

View File

@@ -6,21 +6,15 @@ def main():
init_config()
from owocr.run import config
fullargspec = inspect.getfullargspec(run)
old_defaults = fullargspec[0]
old_default_values = fullargspec[3]
new_defaults = []
cli_args = inspect.getfullargspec(run)[0]
defaults = []
if config.has_config:
index = 0
for argument in old_defaults:
if config.get_general(argument) == None:
new_defaults.append(old_default_values[index])
else:
new_defaults.append(config.get_general(argument))
index += 1
index = 0
for arg in cli_args:
defaults.append(config.get_general(arg))
index += 1
run.__defaults__ = tuple(new_defaults)
run.__defaults__ = tuple(defaults)
fire.Fire(run)

View File

@@ -3,10 +3,27 @@ import configparser
class Config:
has_config = False
general_config = {}
engine_config = {}
__general_config = {}
__engine_config = {}
__default_config = {
'read_from': 'clipboard',
'write_to': 'clipboard',
'engine': '',
'pause_at_startup': False,
'ignore_flag': False,
'delete_images': False,
'engines': [],
'logger_format': '<green>{time:HH:mm:ss.SSS}</green> | <level>{message}</level>',
'engine_color': 'cyan',
'delay_secs': 0.5,
'websocket_port': 7331,
'notifications': False,
'screen_capture_monitor': 1,
'screen_capture_coords': '',
'screen_capture_delay_secs': 3
}
def _parse(self, value):
def __parse(self, value):
value = value.strip()
if value.lower() == 'false':
return False
@@ -34,20 +51,23 @@ class Config:
for key in config:
if key == 'general':
for sub_key in config[key]:
self.general_config[sub_key.lower()] = self._parse(config[key][sub_key])
self.__general_config[sub_key.lower()] = self.__parse(config[key][sub_key])
elif key != 'DEFAULT':
self.engine_config[key.lower()] = {}
self.__engine_config[key.lower()] = {}
for sub_key in config[key]:
self.engine_config[key.lower()][sub_key.lower()] = self._parse(config[key][sub_key])
self.__engine_config[key.lower()][sub_key.lower()] = self.__parse(config[key][sub_key])
def get_general(self, value):
try:
return self.general_config[value]
return self.__general_config[value]
except KeyError:
return None
if value in self.__default_config:
return self.__default_config[value]
else:
return None
def get_engine(self, value):
try:
return self.engine_config[value]
return self.__engine_config[value]
except KeyError:
return None

View File

@@ -66,11 +66,10 @@ class WindowsClipboardThread(threading.Thread):
class WebsocketServerThread(threading.Thread):
def __init__(self, port, read):
def __init__(self, read):
super().__init__()
self.daemon = True
self.loop = asyncio.new_event_loop()
self.port = port
self.read = read
self.clients = set()
@@ -105,7 +104,7 @@ class WebsocketServerThread(threading.Thread):
def run(self):
asyncio.set_event_loop(self.loop)
start_server = websockets.serve(self.server_handler, '0.0.0.0', self.port, max_size=50000000)
start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=50000000)
self.server = start_server
self.loop.run_until_complete(start_server)
self.loop.run_forever()
@@ -115,7 +114,7 @@ class WebsocketServerThread(threading.Thread):
self.loop.close()
def user_input_thread_run(engine_instances, engine_keys, engine_color):
def user_input_thread_run(engine_instances, engine_keys):
def _terminate_handler(user_input):
global terminated
logger.info('Terminated!')
@@ -144,6 +143,7 @@ def user_input_thread_run(engine_instances, engine_keys, engine_color):
engine_index = engine_keys.index(user_input.lower())
if engine_index != old_engine_index:
engine_color = config.get_general('engine_color')
logger.opt(ansi=True).info(f'Switched to <{engine_color}>{engine_instances[engine_index].readable_name}</{engine_color}>!')
if sys.platform == 'win32':
@@ -223,13 +223,14 @@ def are_images_identical(img1, img2):
return (img1.shape == img2.shape) and (img1 == img2).all()
def process_and_write_results(engine_instance, engine_color, img_or_path, write_to, notifications):
def process_and_write_results(engine_instance, img_or_path, write_to):
t0 = time.time()
text = engine_instance(img_or_path)
t1 = time.time()
engine_color = config.get_general('engine_color')
logger.opt(ansi=True).info(f'Text recognized in {t1 - t0:0.03f}s using <{engine_color}>{engine_instance.readable_name}</{engine_color}>: {text}')
if notifications:
if config.get_general('notifications'):
notification = Notify()
notification.application_name = 'owocr'
notification.title = 'Text recognized:'
@@ -258,12 +259,12 @@ def init_config():
config = Config()
def run(read_from='clipboard',
write_to='clipboard',
engine='',
pause_at_startup=False,
ignore_flag=False,
delete_images=False
def run(read_from=None,
write_to=None,
engine=None,
pause_at_startup=None,
ignore_flag=None,
delete_images=None
):
"""
Japanese OCR client
@@ -280,58 +281,25 @@ def run(read_from='clipboard',
:param delete_images: Delete image files after processing when reading from a directory.
"""
engine_instances = []
config_engines = []
engine_keys = []
default_engine = ''
logger_format = '<green>{time:HH:mm:ss.SSS}</green> | <level>{message}</level>'
engine_color = 'cyan'
delay_secs = 0.5
websocket_port = 7331
notifications = False
screen_capture_monitor = 1
screen_capture_coords = ''
screen_capture_delay_secs = 3
if not config:
init_config()
if config.has_config:
if config.get_general('engines'):
for config_engine in config.get_general('engines').split(','):
config_engines.append(config_engine.strip().lower())
if config.get_general('logger_format'):
logger_format = config.get_general('logger_format')
if config.get_general('engine_color'):
engine_color = config.get_general('engine_color')
if config.get_general('delay_secs'):
delay_secs = config.get_general('delay_secs')
if config.get_general('websocket_port'):
websocket_port = config.get_general('websocket_port')
if config.get_general('notifications'):
notifications = config.get_general('notifications')
if config.get_general('screen_capture_monitor'):
screen_capture_monitor = config.get_general('screen_capture_monitor')
if config.get_general('screen_capture_delay_secs'):
screen_capture_delay_secs = config.get_general('screen_capture_delay_secs')
if config.get_general('screen_capture_coords'):
screen_capture_coords = config.get_general('screen_capture_coords')
logger.configure(handlers=[{'sink': sys.stderr, 'format': logger_format}])
logger.configure(handlers=[{'sink': sys.stderr, 'format': config.get_general('logger_format')}])
if config.has_config:
logger.info('Parsed config file')
else:
logger.warning('No config file, defaults will be used')
engine_instances = []
config_engines = []
engine_keys = []
default_engine = ''
if len(config.get_general('engines')) > 0:
for config_engine in config.get_general('engines').split(','):
config_engines.append(config_engine.strip().lower())
for _,engine_class in sorted(inspect.getmembers(sys.modules[__name__], lambda x: hasattr(x, '__module__') and __package__ + '.ocr' in x.__module__ and inspect.isclass(x))):
if len(config_engines) == 0 or engine_class.name in config_engines:
if config.get_engine(engine_class.name) == None:
@@ -361,8 +329,10 @@ def run(read_from='clipboard',
tmp_paused = False
first_pressed = None
engine_index = engine_keys.index(default_engine) if default_engine != '' else 0
engine_color = config.get_general('engine_color')
delay_secs = config.get_general('delay_secs')
user_input_thread = threading.Thread(target=user_input_thread_run, args=(engine_instances, engine_keys, engine_color), daemon=True)
user_input_thread = threading.Thread(target=user_input_thread_run, args=(engine_instances, engine_keys), daemon=True)
user_input_thread.start()
tmp_paused_listener = keyboard.Listener(
@@ -372,7 +342,7 @@ def run(read_from='clipboard',
if read_from == 'websocket' or write_to == 'websocket':
global websocket_server_thread
websocket_server_thread = WebsocketServerThread(websocket_port, read_from == 'websocket')
websocket_server_thread = WebsocketServerThread(read_from == 'websocket')
websocket_server_thread.start()
if read_from == 'websocket':
@@ -402,6 +372,9 @@ def run(read_from='clipboard',
else:
generic_clipboard_polling = True
elif read_from == 'screencapture':
screen_capture_monitor = config.get_general('screen_capture_monitor')
screen_capture_delay_secs = config.get_general('screen_capture_delay_secs')
screen_capture_coords = config.get_general('screen_capture_coords')
global screencapture_window_active
screencapture_window_mode = False
screencapture_window_active = True
@@ -471,7 +444,7 @@ def run(read_from='clipboard',
else:
if not paused and not tmp_paused:
img = Image.open(io.BytesIO(item))
process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
process_and_write_results(engine_instances[engine_index], img, write_to)
elif read_from == 'clipboard':
if windows_clipboard_polling:
clipboard_changed = clipboard_event.wait(delay_secs)
@@ -498,7 +471,7 @@ def run(read_from='clipboard',
isinstance(img, Image.Image) and \
(ignore_flag or pyperclipfix.paste() != '*ocr_ignore*') and \
((not generic_clipboard_polling) or (not are_images_identical(img, old_img))):
process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
process_and_write_results(engine_instances[engine_index], img, write_to)
just_unpaused = False
@@ -508,7 +481,7 @@ def run(read_from='clipboard',
if screencapture_window_active and not paused and not tmp_paused:
sct_img = sct.grab(sct_params)
img = Image.frombytes("RGB", sct_img.size, sct_img.bgra, "raw", "BGRX")
process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
process_and_write_results(engine_instances[engine_index], img, write_to)
time.sleep(screen_capture_delay_secs)
else:
time.sleep(delay_secs)
@@ -526,7 +499,7 @@ def run(read_from='clipboard',
except (UnidentifiedImageError, OSError) as e:
logger.warning(f'Error while reading file {path}: {e}')
else:
process_and_write_results(engine_instances[engine_index], engine_color, img, write_to, notifications)
process_and_write_results(engine_instances[engine_index], img, write_to)
img.close()
if delete_images:
Path.unlink(path)