Fix websockets > 14 (thanks @graingert), use GPU for text filtering if available (silences warnings)
This commit is contained in:
51
owocr/run.py
51
owocr/run.py
@@ -59,8 +59,7 @@ config = None
|
||||
|
||||
class WindowsClipboardThread(threading.Thread):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
super().__init__(daemon=True)
|
||||
self.last_update = time.time()
|
||||
|
||||
def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int):
|
||||
@@ -90,11 +89,16 @@ class WindowsClipboardThread(threading.Thread):
|
||||
|
||||
class WebsocketServerThread(threading.Thread):
|
||||
def __init__(self, read):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
self.loop = asyncio.new_event_loop()
|
||||
super().__init__(daemon=True)
|
||||
self._loop = None
|
||||
self.read = read
|
||||
self.clients = set()
|
||||
self._event = threading.Event()
|
||||
|
||||
@property
|
||||
def loop(self):
|
||||
self._event.wait()
|
||||
return self._loop
|
||||
|
||||
async def send_text_coroutine(self, text):
|
||||
for client in self.clients:
|
||||
@@ -124,19 +128,17 @@ class WebsocketServerThread(threading.Thread):
|
||||
return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop)
|
||||
|
||||
def stop_server(self):
|
||||
self.loop.call_soon_threadsafe(self.server.ws_server.close)
|
||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
||||
self.loop.call_soon_threadsafe(self._stop_event.set)
|
||||
|
||||
def run(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
|
||||
self.server = start_server
|
||||
self.loop.run_until_complete(start_server)
|
||||
self.loop.run_forever()
|
||||
pending = asyncio.all_tasks(loop=self.loop)
|
||||
if len(pending) > 0:
|
||||
self.loop.run_until_complete(asyncio.wait(pending))
|
||||
self.loop.close()
|
||||
async def main():
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._stop_event = stop_event = asyncio.Event()
|
||||
self._event.set()
|
||||
self.server = start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
|
||||
async with start_server:
|
||||
await stop_event.wait()
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
class RequestHandler(socketserver.BaseRequestHandler):
|
||||
@@ -164,8 +166,7 @@ class RequestHandler(socketserver.BaseRequestHandler):
|
||||
|
||||
class MacOSWindowTracker(threading.Thread):
|
||||
def __init__(self, window_id):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
super().__init__(daemon=True)
|
||||
self.stop = False
|
||||
self.window_id = window_id
|
||||
self.window_active = False
|
||||
@@ -200,8 +201,7 @@ class MacOSWindowTracker(threading.Thread):
|
||||
|
||||
class WindowsWindowTracker(threading.Thread):
|
||||
def __init__(self, window_handle, only_active):
|
||||
super().__init__()
|
||||
self.daemon = True
|
||||
super().__init__(daemon=True)
|
||||
self.stop = False
|
||||
self.window_handle = window_handle
|
||||
self.only_active = only_active
|
||||
@@ -258,12 +258,21 @@ class TextFiltering:
|
||||
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
||||
try:
|
||||
from transformers import pipeline, AutoTokenizer
|
||||
import torch
|
||||
|
||||
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_ckpt,
|
||||
use_fast = False
|
||||
)
|
||||
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = 0
|
||||
elif torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
else:
|
||||
device = -1
|
||||
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
|
||||
self.accurate_filtering = True
|
||||
except:
|
||||
import langid
|
||||
|
||||
Reference in New Issue
Block a user