Add support for fpng-py, update Azure APIs
owocr/ocr.py
@@ -33,10 +33,10 @@ except ImportError:
     pass
 
 try:
-    from azure.cognitiveservices.vision.computervision import ComputerVisionClient
-    from azure.cognitiveservices.vision.computervision.models import OperationStatusCodes
-    from msrest.authentication import CognitiveServicesCredentials
-    from msrest.exceptions import ClientRequestError
+    from azure.ai.vision.imageanalysis import ImageAnalysisClient
+    from azure.ai.vision.imageanalysis.models import VisualFeatures
+    from azure.core.credentials import AzureKeyCredential
+    from azure.core.exceptions import ServiceRequestError
 except ImportError:
     pass
 
@@ -66,6 +66,12 @@ try:
 except ImportError:
     pass
 
+try:
+    import fpng_py
+    optimized_png_encode = True
+except:
+    optimized_png_encode = False
+
 
 def empty_post_process(text):
     return text
@@ -79,6 +85,22 @@ def post_process(text):
     return text
 
 
+def pil_image_to_bytes(img, img_format='png', png_compression=6):
+    if img_format == 'png' and optimized_png_encode:
+        raw_data = img.convert('RGBA').tobytes()
+        width, height = img.size
+        image_bytes = fpng_py.fpng_encode_image_to_memory(raw_data, width, height)
+    else:
+        image_bytes = io.BytesIO()
+        img.save(image_bytes, format=img_format, compress_level=png_compression)
+        image_bytes = image_bytes.getvalue()
+    return image_bytes
+
+
+def pil_image_to_numpy_array(img):
+    return np.array(img.convert('RGBA'))
+
+
 class MangaOcr:
     name = 'mangaocr'
     readable_name = 'Manga OCR'
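For context, a minimal usage sketch of the new helper added above (illustrative only, not part of the commit; the owocr.ocr import path and the page.png file name are assumptions, and the engine-specific calls mirror the _preprocess changes later in this diff):

    # Illustrative usage of pil_image_to_bytes; not part of the commit.
    from PIL import Image
    from owocr.ocr import pil_image_to_bytes  # assumed import path for owocr/ocr.py

    img = Image.open('page.png')  # hypothetical input image

    # Default: PNG bytes, using fpng_py's fast encoder when the optional
    # dependency imported successfully, otherwise PIL's encoder at compress_level=6.
    png_bytes = pil_image_to_bytes(img)

    # Callers can pick another format or a lighter PNG compression level:
    tiff_bytes = pil_image_to_bytes(img, 'tiff')            # as used by Apple Vision
    quick_png = pil_image_to_bytes(img, png_compression=1)  # as used by WinRT OCR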
@@ -149,9 +171,7 @@ class GoogleVision:
         return x
 
     def _preprocess(self, img):
-        image_bytes = io.BytesIO()
-        img.save(image_bytes, format='png')
-        return image_bytes.getvalue()
+        return pil_image_to_bytes(img)
 
 class GoogleLens:
     name = 'glens'
@@ -216,9 +236,7 @@ class GoogleLens:
             new_h = int(new_w / aspect_ratio)
             img = img.resize((new_w, new_h), Image.LANCZOS)
 
-        image_bytes = io.BytesIO()
-        img.save(image_bytes, format='png')
-        return image_bytes.getvalue()
+        return pil_image_to_bytes(img)
 
 class AppleVision:
     name = 'avision'
@@ -268,9 +286,7 @@ class AppleVision:
         return x
 
     def _preprocess(self, img):
-        image_bytes = io.BytesIO()
-        img.save(image_bytes, format='tiff')
-        return image_bytes.getvalue()
+        return pil_image_to_bytes(img, 'tiff')
 
 class WinRTOCR:
     name = 'winrtocr'
@@ -326,27 +342,25 @@ class WinRTOCR:
         return x
 
     def _preprocess(self, img):
-        image_bytes = io.BytesIO()
-        img.save(image_bytes, format='png', compress_level=1)
-        return image_bytes.getvalue()
+        return pil_image_to_bytes(img, png_compression=1)
 
-class AzureComputerVision:
+class AzureImageAnalysis:
     name = 'azure'
-    readable_name = 'Azure Computer Vision'
+    readable_name = 'Azure Image Analysis'
     key = 'v'
     available = False
 
     def __init__(self, config={}):
-        if 'azure.cognitiveservices.vision.computervision' not in sys.modules:
-            logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!')
+        if 'azure.ai.vision.imageanalysis' not in sys.modules:
+            logger.warning('azure-ai-vision-imageanalysis not available, Azure Image Analysis will not work!')
         else:
             logger.info(f'Parsing Azure credentials')
             try:
-                self.client = ComputerVisionClient(config['endpoint'], CognitiveServicesCredentials(config['api_key']))
+                self.client = ImageAnalysisClient(config['endpoint'], AzureKeyCredential(config['api_key']))
                 self.available = True
-                logger.info('Azure Computer Vision ready')
+                logger.info('Azure Image Analysis ready')
             except:
-                logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!')
+                logger.warning('Error parsing Azure credentials, Azure Image Analysis will not work!')
 
     def __call__(self, img_or_path):
         if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
@@ -356,30 +370,17 @@ class AzureComputerVision:
         else:
             raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
 
-        image_io = self._preprocess(img)
-        logging.getLogger('urllib3.connectionpool').disabled = True
-
         try:
-            read_response = self.client.read_in_stream(image_io, raw=True)
-
-            read_operation_location = read_response.headers['Operation-Location']
-            operation_id = read_operation_location.split('/')[-1]
-
-            while True:
-                read_result = self.client.get_read_result(operation_id)
-                if read_result.status.lower() not in [OperationStatusCodes.not_started, OperationStatusCodes.running]:
-                    break
-                time.sleep(0.3)
-        except ClientRequestError:
+            read_result = self.client.analyze(image_data=self._preprocess(img), visual_features=[VisualFeatures.READ])
+        except ServiceRequestError:
             return (False, 'Connection error!')
         except:
             return (False, 'Unknown error!')
 
         res = ''
-        if read_result.status == OperationStatusCodes.succeeded:
-            for text_result in read_result.analyze_result.read_results:
-                for line in text_result.lines:
-                    res += line.text + ' '
+        if read_result.read:
+            for line in read_result.read.blocks[0].lines:
+                res += line.text + ' '
         else:
             return (False, 'Unknown error!')
 
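For reference, a self-contained sketch of the new azure-ai-vision-imageanalysis call path that replaces the old read_in_stream/get_read_result polling loop above (the endpoint, key and file name are placeholders, not values from this repo; error handling is trimmed):

    # Standalone version of the flow used by AzureImageAnalysis above; placeholders only.
    from azure.ai.vision.imageanalysis import ImageAnalysisClient
    from azure.ai.vision.imageanalysis.models import VisualFeatures
    from azure.core.credentials import AzureKeyCredential

    client = ImageAnalysisClient('https://<resource>.cognitiveservices.azure.com/',
                                 AzureKeyCredential('<api_key>'))

    with open('page.png', 'rb') as f:  # hypothetical input image
        result = client.analyze(image_data=f.read(),
                                visual_features=[VisualFeatures.READ])

    # The READ feature returns blocks of lines; the class above joins line.text
    # values with spaces and treats a missing result.read as an error.
    if result.read:
        text = ' '.join(line.text for line in result.read.blocks[0].lines)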
@@ -387,10 +388,7 @@ class AzureComputerVision:
         return x
 
     def _preprocess(self, img):
-        image_io = io.BytesIO()
-        img.save(image_io, format='png')
-        image_io.seek(0)
-        return image_io
+        return pil_image_to_bytes(img)
 
 class EasyOCR:
     name = 'easyocr'
@@ -424,7 +422,7 @@ class EasyOCR:
         return x
 
    def _preprocess(self, img):
-        return np.array(img.convert('RGB'))
+        return pil_image_to_numpy_array(img)
 
 class RapidOCR:
     name = 'rapidocr'
@@ -450,6 +448,7 @@ class RapidOCR:
 
         logger.info('Loading RapidOCR model')
         self.model = ROCR(rec_model_path=rapidocr_model_file)
+        logging.getLogger().disabled = True
         self.available = True
         logger.info('RapidOCR ready')
 
@@ -461,7 +460,6 @@
         else:
             raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
 
-        logging.getLogger().disabled = True
         res = ''
         read_results, elapsed = self.model(self._preprocess(img))
         if read_results:
@@ -472,4 +470,4 @@ class RapidOCR:
         return x
 
     def _preprocess(self, img):
-        return np.array(img.convert('RGB'))
+        return pil_image_to_numpy_array(img)