Add TextFiltering support for more languages

This commit is contained in:
Beangate
2025-07-28 12:07:15 -04:00
parent 01f63cd402
commit 1c2e844d7a
3 changed files with 39 additions and 8 deletions

View File

@@ -299,10 +299,11 @@ class RequestHandler(socketserver.BaseRequestHandler):
class TextFiltering:
accurate_filtering = False
def __init__(self):
def __init__(self, lang='ja'):
from pysbd import Segmenter
self.segmenter = Segmenter(language='ja', clean=True)
self.kana_kanji_regex = re.compile(r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]')
self.segmenter = Segmenter(language=lang, clean=True)
self.regex = self.get_regex(lang)
try:
from transformers import pipeline, AutoTokenizer
import torch
@@ -325,13 +326,35 @@ class TextFiltering:
except:
import langid
self.classify = langid.classify
def get_regex(self, lang):
    """Return a compiled regex that matches characters of *lang*'s script.

    Known language codes map to their Unicode script ranges; any other
    code falls back to a broad Latin/Latin-Extended character class
    (covers English and many European languages).
    """
    script_ranges = {
        'ja': r'[\u3041-\u3096\u30A1-\u30FA\u4E00-\u9FFF]',
        'zh': r'[\u4E00-\u9FFF]',
        'ko': r'[\uAC00-\uD7AF]',
        'ar': r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF\uFB50-\uFDFF\uFE70-\uFEFF]',
        'ru': r'[\u0400-\u04FF\u0500-\u052F\u2DE0-\u2DFF\uA640-\uA69F\u1C80-\u1C8F]',
        'el': r'[\u0370-\u03FF\u1F00-\u1FFF]',
        'he': r'[\u0590-\u05FF\uFB1D-\uFB4F]',
        'th': r'[\u0E00-\u0E7F]',
    }
    # Latin Extended regex for many European languages/English
    latin_fallback = (
        r'[a-zA-Z\u00C0-\u00FF\u0100-\u017F\u0180-\u024F\u0250-\u02AF'
        r'\u1D00-\u1D7F\u1D80-\u1DBF\u1E00-\u1EFF\u2C60-\u2C7F'
        r'\uA720-\uA7FF\uAB30-\uAB6F]'
    )
    return re.compile(script_ranges.get(lang, latin_fallback))
def __call__(self, text, last_result):
orig_text = self.segmenter.segment(text)
orig_text_filtered = []
for block in orig_text:
block_filtered = self.kana_kanji_regex.findall(block)
block_filtered = self.regex.findall(block)
if block_filtered:
orig_text_filtered.append(''.join(block_filtered))
else:
@@ -352,12 +375,13 @@ class TextFiltering:
detection_results = self.pipe(new_blocks, top_k=3, truncation=True)
for idx, block in enumerate(new_blocks):
for result in detection_results[idx]:
if result['label'] == 'ja':
if result['label'] == self.language:
final_blocks.append(block)
break
else:
for block in new_blocks:
if self.classify(block)[0] in ('ja', 'zh'):
# Only filter by detected language when the target language is ja or zh; otherwise keep all text
if self.language not in ["ja", "zh"] or self.classify(block)[0] in ['ja', 'zh'] or block == "\n":
final_blocks.append(block)
text = '\n'.join(final_blocks)