Fix websockets > 14 (thanks @graingert), use GPU for text filtering if available (silences warnings)

This commit is contained in:
AuroraWright
2024-12-12 21:29:11 +01:00
parent b932f1f767
commit 1897324de2
2 changed files with 31 additions and 22 deletions

View File

@@ -59,8 +59,7 @@ config = None
class WindowsClipboardThread(threading.Thread):
def __init__(self):
super().__init__()
self.daemon = True
super().__init__(daemon=True)
self.last_update = time.time()
def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int):
@@ -90,11 +89,16 @@ class WindowsClipboardThread(threading.Thread):
class WebsocketServerThread(threading.Thread):
def __init__(self, read):
super().__init__()
self.daemon = True
self.loop = asyncio.new_event_loop()
super().__init__(daemon=True)
self._loop = None
self.read = read
self.clients = set()
self._event = threading.Event()
@property
def loop(self):
self._event.wait()
return self._loop
async def send_text_coroutine(self, text):
for client in self.clients:
@@ -124,19 +128,17 @@ class WebsocketServerThread(threading.Thread):
return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop)
def stop_server(self):
self.loop.call_soon_threadsafe(self.server.ws_server.close)
self.loop.call_soon_threadsafe(self.loop.stop)
self.loop.call_soon_threadsafe(self._stop_event.set)
def run(self):
asyncio.set_event_loop(self.loop)
start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
self.server = start_server
self.loop.run_until_complete(start_server)
self.loop.run_forever()
pending = asyncio.all_tasks(loop=self.loop)
if len(pending) > 0:
self.loop.run_until_complete(asyncio.wait(pending))
self.loop.close()
async def main():
self._loop = asyncio.get_running_loop()
self._stop_event = stop_event = asyncio.Event()
self._event.set()
self.server = start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
async with start_server:
await stop_event.wait()
asyncio.run(main())
class RequestHandler(socketserver.BaseRequestHandler):
@@ -164,8 +166,7 @@ class RequestHandler(socketserver.BaseRequestHandler):
class MacOSWindowTracker(threading.Thread):
def __init__(self, window_id):
super().__init__()
self.daemon = True
super().__init__(daemon=True)
self.stop = False
self.window_id = window_id
self.window_active = False
@@ -200,8 +201,7 @@ class MacOSWindowTracker(threading.Thread):
class WindowsWindowTracker(threading.Thread):
def __init__(self, window_handle, only_active):
super().__init__()
self.daemon = True
super().__init__(daemon=True)
self.stop = False
self.window_handle = window_handle
self.only_active = only_active
@@ -258,12 +258,21 @@ class TextFiltering:
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
try:
from transformers import pipeline, AutoTokenizer
import torch
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
tokenizer = AutoTokenizer.from_pretrained(
model_ckpt,
use_fast = False
)
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
if torch.cuda.is_available():
device = 0
elif torch.backends.mps.is_available():
device = "mps"
else:
device = -1
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
self.accurate_filtering = True
except:
import langid