Add support for fpng-py, update Azure APIs

This commit is contained in:
AuroraWright
2024-02-02 15:27:55 +01:00
parent a6581dd3ff
commit c2feb52233
2 changed files with 48 additions and 49 deletions

View File

@@ -33,10 +33,10 @@ except ImportError:
pass
try:
from azure.cognitiveservices.vision.computervision import ComputerVisionClient
from azure.cognitiveservices.vision.computervision.models import OperationStatusCodes
from msrest.authentication import CognitiveServicesCredentials
from msrest.exceptions import ClientRequestError
from azure.ai.vision.imageanalysis import ImageAnalysisClient
from azure.ai.vision.imageanalysis.models import VisualFeatures
from azure.core.credentials import AzureKeyCredential
from azure.core.exceptions import ServiceRequestError
except ImportError:
pass
@@ -66,6 +66,12 @@ try:
except ImportError:
pass
try:
import fpng_py
optimized_png_encode = True
except:
optimized_png_encode = False
def empty_post_process(text):
return text
@@ -79,6 +85,22 @@ def post_process(text):
return text
def pil_image_to_bytes(img, img_format='png', png_compression=6):
if img_format == 'png' and optimized_png_encode:
raw_data = img.convert('RGBA').tobytes()
width, height = img.size
image_bytes = fpng_py.fpng_encode_image_to_memory(raw_data, width, height)
else:
image_bytes = io.BytesIO()
img.save(image_bytes, format=img_format, compress_level=png_compression)
image_bytes = image_bytes.getvalue()
return image_bytes
def pil_image_to_numpy_array(img):
return np.array(img.convert('RGBA'))
class MangaOcr:
name = 'mangaocr'
readable_name = 'Manga OCR'
@@ -149,9 +171,7 @@ class GoogleVision:
return x
def _preprocess(self, img):
image_bytes = io.BytesIO()
img.save(image_bytes, format='png')
return image_bytes.getvalue()
return pil_image_to_bytes(img)
class GoogleLens:
name = 'glens'
@@ -216,9 +236,7 @@ class GoogleLens:
new_h = int(new_w / aspect_ratio)
img = img.resize((new_w, new_h), Image.LANCZOS)
image_bytes = io.BytesIO()
img.save(image_bytes, format='png')
return image_bytes.getvalue()
return pil_image_to_bytes(img)
class AppleVision:
name = 'avision'
@@ -268,9 +286,7 @@ class AppleVision:
return x
def _preprocess(self, img):
image_bytes = io.BytesIO()
img.save(image_bytes, format='tiff')
return image_bytes.getvalue()
return pil_image_to_bytes(img, 'tiff')
class WinRTOCR:
name = 'winrtocr'
@@ -326,27 +342,25 @@ class WinRTOCR:
return x
def _preprocess(self, img):
image_bytes = io.BytesIO()
img.save(image_bytes, format='png', compress_level=1)
return image_bytes.getvalue()
return pil_image_to_bytes(img, png_compression=1)
class AzureComputerVision:
class AzureImageAnalysis:
name = 'azure'
readable_name = 'Azure Computer Vision'
readable_name = 'Azure Image Analysis'
key = 'v'
available = False
def __init__(self, config={}):
if 'azure.cognitiveservices.vision.computervision' not in sys.modules:
logger.warning('azure-cognitiveservices-vision-computervision not available, Azure Computer Vision will not work!')
if 'azure.ai.vision.imageanalysis' not in sys.modules:
logger.warning('azure-ai-vision-imageanalysis not available, Azure Image Analysis will not work!')
else:
logger.info(f'Parsing Azure credentials')
try:
self.client = ComputerVisionClient(config['endpoint'], CognitiveServicesCredentials(config['api_key']))
self.client = ImageAnalysisClient(config['endpoint'], AzureKeyCredential(config['api_key']))
self.available = True
logger.info('Azure Computer Vision ready')
logger.info('Azure Image Analysis ready')
except:
logger.warning('Error parsing Azure credentials, Azure Computer Vision will not work!')
logger.warning('Error parsing Azure credentials, Azure Image Analysis will not work!')
def __call__(self, img_or_path):
if isinstance(img_or_path, str) or isinstance(img_or_path, Path):
@@ -356,30 +370,17 @@ class AzureComputerVision:
else:
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
image_io = self._preprocess(img)
logging.getLogger('urllib3.connectionpool').disabled = True
try:
read_response = self.client.read_in_stream(image_io, raw=True)
read_operation_location = read_response.headers['Operation-Location']
operation_id = read_operation_location.split('/')[-1]
while True:
read_result = self.client.get_read_result(operation_id)
if read_result.status.lower() not in [OperationStatusCodes.not_started, OperationStatusCodes.running]:
break
time.sleep(0.3)
except ClientRequestError:
read_result = self.client.analyze(image_data=self._preprocess(img), visual_features=[VisualFeatures.READ])
except ServiceRequestError:
return (False, 'Connection error!')
except:
return (False, 'Unknown error!')
res = ''
if read_result.status == OperationStatusCodes.succeeded:
for text_result in read_result.analyze_result.read_results:
for line in text_result.lines:
res += line.text + ' '
if read_result.read:
for line in read_result.read.blocks[0].lines:
res += line.text + ' '
else:
return (False, 'Unknown error!')
@@ -387,10 +388,7 @@ class AzureComputerVision:
return x
def _preprocess(self, img):
image_io = io.BytesIO()
img.save(image_io, format='png')
image_io.seek(0)
return image_io
return pil_image_to_bytes(img)
class EasyOCR:
name = 'easyocr'
@@ -424,7 +422,7 @@ class EasyOCR:
return x
def _preprocess(self, img):
return np.array(img.convert('RGB'))
return pil_image_to_numpy_array(img)
class RapidOCR:
name = 'rapidocr'
@@ -450,6 +448,7 @@ class RapidOCR:
logger.info('Loading RapidOCR model')
self.model = ROCR(rec_model_path=rapidocr_model_file)
logging.getLogger().disabled = True
self.available = True
logger.info('RapidOCR ready')
@@ -461,7 +460,6 @@ class RapidOCR:
else:
raise ValueError(f'img_or_path must be a path or PIL.Image, instead got: {img_or_path}')
logging.getLogger().disabled = True
res = ''
read_results, elapsed = self.model(self._preprocess(img))
if read_results:
@@ -472,4 +470,4 @@ class RapidOCR:
return x
def _preprocess(self, img):
return np.array(img.convert('RGB'))
return pil_image_to_numpy_array(img)