diff --git a/owocr/run.py b/owocr/run.py index 9a2dec4..231033b 100644 --- a/owocr/run.py +++ b/owocr/run.py @@ -59,8 +59,7 @@ config = None class WindowsClipboardThread(threading.Thread): def __init__(self): - super().__init__() - self.daemon = True + super().__init__(daemon=True) self.last_update = time.time() def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int): @@ -90,11 +89,16 @@ class WindowsClipboardThread(threading.Thread): class WebsocketServerThread(threading.Thread): def __init__(self, read): - super().__init__() - self.daemon = True - self.loop = asyncio.new_event_loop() + super().__init__(daemon=True) + self._loop = None self.read = read self.clients = set() + self._event = threading.Event() + + @property + def loop(self): + self._event.wait() + return self._loop async def send_text_coroutine(self, text): for client in self.clients: @@ -124,19 +128,17 @@ class WebsocketServerThread(threading.Thread): return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop) def stop_server(self): - self.loop.call_soon_threadsafe(self.server.ws_server.close) - self.loop.call_soon_threadsafe(self.loop.stop) + self.loop.call_soon_threadsafe(self._stop_event.set) def run(self): - asyncio.set_event_loop(self.loop) - start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000) - self.server = start_server - self.loop.run_until_complete(start_server) - self.loop.run_forever() - pending = asyncio.all_tasks(loop=self.loop) - if len(pending) > 0: - self.loop.run_until_complete(asyncio.wait(pending)) - self.loop.close() + async def main(): + self._loop = asyncio.get_running_loop() + self._stop_event = stop_event = asyncio.Event() + self._event.set() + self.server = start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000) + async with start_server: + await stop_event.wait() + asyncio.run(main()) class RequestHandler(socketserver.BaseRequestHandler): @@ -164,8 +166,7 @@ class RequestHandler(socketserver.BaseRequestHandler): class MacOSWindowTracker(threading.Thread): def __init__(self, window_id): - super().__init__() - self.daemon = True + super().__init__(daemon=True) self.stop = False self.window_id = window_id self.window_active = False @@ -200,8 +201,7 @@ class MacOSWindowTracker(threading.Thread): class WindowsWindowTracker(threading.Thread): def __init__(self, window_handle, only_active): - super().__init__() - self.daemon = True + super().__init__(daemon=True) self.stop = False self.window_handle = window_handle self.only_active = only_active @@ -258,12 +258,21 @@ class TextFiltering: self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]') try: from transformers import pipeline, AutoTokenizer + import torch + model_ckpt = 'papluca/xlm-roberta-base-language-detection' tokenizer = AutoTokenizer.from_pretrained( model_ckpt, use_fast = False ) - self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer) + + if torch.cuda.is_available(): + device = 0 + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = -1 + self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device) self.accurate_filtering = True except: import langid diff --git a/setup.py b/setup.py index c0a9863..784b09a 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ long_description = (Path(__file__).parent / "README.md").read_text('utf-8') setup( name="owocr", - version='1.8.0', + version='1.8.1', description="Japanese OCR", long_description=long_description, long_description_content_type="text/markdown",