Files
owocr/manga_ocr_dev/training/metrics.py
2022-02-09 20:39:37 +01:00

33 lines
1.1 KiB
Python

import numpy as np
from datasets import load_metric
class Metrics:
    """Compute character error rate and exact-match accuracy for a batch.

    The instance keeps the ``"cer"`` metric obtained from ``load_metric`` and
    the processor used to decode token ids back into strings.
    """

    def __init__(self, processor):
        """Load the CER metric and store *processor*.

        Args:
            processor: Object providing ``batch_decode(ids,
                skip_special_tokens=...)`` and ``tokenizer.pad_token_id``;
                both are used by :meth:`compute_metrics`.
        """
        self.cer_metric = load_metric("cer")
        self.processor = processor

    def compute_metrics(self, pred):
        """Decode predictions and labels and score them.

        Args:
            pred: Object exposing ``label_ids`` and ``predictions`` arrays of
                token ids. Neither array is modified.

        Returns:
            dict with two entries: ``'cer'`` (the value returned by the CER
            metric, or ``0`` if the metric raised) and ``'accuracy'`` (the
            fraction of predictions that equal their label exactly once all
            whitespace has been removed from both).
        """
        pred_ids = pred.predictions

        # Entries equal to -100 are swapped for the pad token id so the
        # labels can be decoded (presumably -100 is the loss-ignore marker;
        # confirm against the data collator). np.where returns a new array,
        # so the caller's label_ids keeps its -100 entries.
        label_ids = np.where(
            pred.label_ids == -100,
            self.processor.tokenizer.pad_token_id,
            pred.label_ids,
        )

        pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)

        # Drop every whitespace character so that spacing differences in the
        # decoded text do not count as errors.
        pred_str = np.array([''.join(text.split()) for text in pred_str])
        label_str = np.array([''.join(text.split()) for text in label_str])

        results = {}
        try:
            results['cer'] = self.cer_metric.compute(
                predictions=pred_str, references=label_str
            )
        except Exception as e:
            # Best-effort: a failing metric must not abort evaluation. The
            # inputs are dumped to help diagnose the failure.
            # NOTE(review): 0 is a placeholder here, not a measured score.
            print(e)
            print(pred_str)
            print(label_str)
            results['cer'] = 0

        results['accuracy'] = (pred_str == label_str).mean()
        return results