Initial commit

This commit is contained in:
Maciej Budyś
2022-01-17 21:38:00 +01:00
parent 686bd5b33f
commit 5f925fde15
7 changed files with 301 additions and 0 deletions

55
manga_ocr/ocr.py Normal file
View File

@@ -0,0 +1,55 @@
import re
from pathlib import Path
import jaconv
import torch
from PIL import Image
from loguru import logger
from transformers import AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderModel
class MangaOcr:
def __init__(self, pretrained_model_name_or_path='kha-white/manga-ocr-base', force_cpu=False):
logger.info(f'Loading OCR model from {pretrained_model_name_or_path}')
self.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
self.model = VisionEncoderDecoderModel.from_pretrained(pretrained_model_name_or_path)
if not force_cpu and torch.cuda.is_available():
logger.info('Using CUDA')
self.model.cuda()
else:
logger.info('Using CPU')
self(Path(__file__).parent.parent / 'assets/crop.png')
logger.info('OCR ready')
def __call__(self, img_or_path):
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
img = Image.open(img_or_path)
elif isinstance(img_or_path, Image.Image):
img = img_or_path
else:
raise ValueError(f'Invalid value of img_or_path: {img_or_path}')
img = img.convert('L').convert('RGB')
x = self._preprocess(img)
x = self.model.generate(x[None].to(self.model.device))[0].cpu()
x = self.tokenizer.decode(x, skip_special_tokens=True)
x = post_process(x)
return x
def _preprocess(self, img):
pixel_values = self.feature_extractor(img, return_tensors="pt").pixel_values
return pixel_values.squeeze()
def post_process(text):
text = ''.join(text.split())
text = text.replace('', '...')
text = re.sub('[・.]{2,}', lambda x: (x.end() - x.start()) * '.', text)
text = jaconv.h2z(text, ascii=True, digit=True)
return text