Fix websockets > 14 (thanks @graingert), use GPU for text filtering if available (silences warnings)
This commit is contained in:
51
owocr/run.py
51
owocr/run.py
@@ -59,8 +59,7 @@ config = None
|
|||||||
|
|
||||||
class WindowsClipboardThread(threading.Thread):
|
class WindowsClipboardThread(threading.Thread):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__(daemon=True)
|
||||||
self.daemon = True
|
|
||||||
self.last_update = time.time()
|
self.last_update = time.time()
|
||||||
|
|
||||||
def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int):
|
def process_message(self, hwnd: int, msg: int, wparam: int, lparam: int):
|
||||||
@@ -90,11 +89,16 @@ class WindowsClipboardThread(threading.Thread):
|
|||||||
|
|
||||||
class WebsocketServerThread(threading.Thread):
|
class WebsocketServerThread(threading.Thread):
|
||||||
def __init__(self, read):
|
def __init__(self, read):
|
||||||
super().__init__()
|
super().__init__(daemon=True)
|
||||||
self.daemon = True
|
self._loop = None
|
||||||
self.loop = asyncio.new_event_loop()
|
|
||||||
self.read = read
|
self.read = read
|
||||||
self.clients = set()
|
self.clients = set()
|
||||||
|
self._event = threading.Event()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loop(self):
|
||||||
|
self._event.wait()
|
||||||
|
return self._loop
|
||||||
|
|
||||||
async def send_text_coroutine(self, text):
|
async def send_text_coroutine(self, text):
|
||||||
for client in self.clients:
|
for client in self.clients:
|
||||||
@@ -124,19 +128,17 @@ class WebsocketServerThread(threading.Thread):
|
|||||||
return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop)
|
return asyncio.run_coroutine_threadsafe(self.send_text_coroutine(text), self.loop)
|
||||||
|
|
||||||
def stop_server(self):
|
def stop_server(self):
|
||||||
self.loop.call_soon_threadsafe(self.server.ws_server.close)
|
self.loop.call_soon_threadsafe(self._stop_event.set)
|
||||||
self.loop.call_soon_threadsafe(self.loop.stop)
|
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
asyncio.set_event_loop(self.loop)
|
async def main():
|
||||||
start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
|
self._loop = asyncio.get_running_loop()
|
||||||
self.server = start_server
|
self._stop_event = stop_event = asyncio.Event()
|
||||||
self.loop.run_until_complete(start_server)
|
self._event.set()
|
||||||
self.loop.run_forever()
|
self.server = start_server = websockets.serve(self.server_handler, '0.0.0.0', config.get_general('websocket_port'), max_size=1000000000)
|
||||||
pending = asyncio.all_tasks(loop=self.loop)
|
async with start_server:
|
||||||
if len(pending) > 0:
|
await stop_event.wait()
|
||||||
self.loop.run_until_complete(asyncio.wait(pending))
|
asyncio.run(main())
|
||||||
self.loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
class RequestHandler(socketserver.BaseRequestHandler):
|
class RequestHandler(socketserver.BaseRequestHandler):
|
||||||
@@ -164,8 +166,7 @@ class RequestHandler(socketserver.BaseRequestHandler):
|
|||||||
|
|
||||||
class MacOSWindowTracker(threading.Thread):
|
class MacOSWindowTracker(threading.Thread):
|
||||||
def __init__(self, window_id):
|
def __init__(self, window_id):
|
||||||
super().__init__()
|
super().__init__(daemon=True)
|
||||||
self.daemon = True
|
|
||||||
self.stop = False
|
self.stop = False
|
||||||
self.window_id = window_id
|
self.window_id = window_id
|
||||||
self.window_active = False
|
self.window_active = False
|
||||||
@@ -200,8 +201,7 @@ class MacOSWindowTracker(threading.Thread):
|
|||||||
|
|
||||||
class WindowsWindowTracker(threading.Thread):
|
class WindowsWindowTracker(threading.Thread):
|
||||||
def __init__(self, window_handle, only_active):
|
def __init__(self, window_handle, only_active):
|
||||||
super().__init__()
|
super().__init__(daemon=True)
|
||||||
self.daemon = True
|
|
||||||
self.stop = False
|
self.stop = False
|
||||||
self.window_handle = window_handle
|
self.window_handle = window_handle
|
||||||
self.only_active = only_active
|
self.only_active = only_active
|
||||||
@@ -258,12 +258,21 @@ class TextFiltering:
|
|||||||
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
|
||||||
try:
|
try:
|
||||||
from transformers import pipeline, AutoTokenizer
|
from transformers import pipeline, AutoTokenizer
|
||||||
|
import torch
|
||||||
|
|
||||||
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
|
model_ckpt = 'papluca/xlm-roberta-base-language-detection'
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
model_ckpt,
|
model_ckpt,
|
||||||
use_fast = False
|
use_fast = False
|
||||||
)
|
)
|
||||||
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer)
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = 0
|
||||||
|
elif torch.backends.mps.is_available():
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
device = -1
|
||||||
|
self.pipe = pipeline('text-classification', model=model_ckpt, tokenizer=tokenizer, device=device)
|
||||||
self.accurate_filtering = True
|
self.accurate_filtering = True
|
||||||
except:
|
except:
|
||||||
import langid
|
import langid
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -5,7 +5,7 @@ long_description = (Path(__file__).parent / "README.md").read_text('utf-8')
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="owocr",
|
name="owocr",
|
||||||
version='1.8.0',
|
version='1.8.1',
|
||||||
description="Japanese OCR",
|
description="Japanese OCR",
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
long_description_content_type="text/markdown",
|
long_description_content_type="text/markdown",
|
||||||
|
|||||||
Reference in New Issue
Block a user