training and synthetic data generation code
manga_ocr_dev/training/__init__.py  (new file, 0 lines)
manga_ocr_dev/training/dataset.py  (new file, 165 lines)
@@ -0,0 +1,165 @@
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from manga_ocr_dev.env import MANGA109_ROOT, DATA_SYNTHETIC_ROOT


class MangaDataset(Dataset):
    def __init__(self, processor, split, max_target_length, limit_size=None, augment=False, skip_packages=None):
        self.processor = processor
        self.max_target_length = max_target_length

        data = []

        print(f'Initializing dataset {split}...')

        if skip_packages is None:
            skip_packages = set()
        else:
            skip_packages = {f'{x:04d}' for x in skip_packages}

        # synthetic data: one metadata csv per package, images in a matching directory
        for path in sorted((DATA_SYNTHETIC_ROOT / 'meta').glob('*.csv')):
            if path.stem in skip_packages:
                print(f'Skipping package {path}')
                continue
            if not (DATA_SYNTHETIC_ROOT / 'img' / path.stem).is_dir():
                print(f'Missing image data for package {path}, skipping')
                continue
            df = pd.read_csv(path)
            df = df.dropna()
            df['path'] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / 'img' / path.stem / f'{x}.jpg'))
            df = df[['path', 'text']]
            df['synthetic'] = True
            data.append(df)

        # real data: text-box crops from Manga109 annotations
        df = pd.read_csv(MANGA109_ROOT / 'data.csv')
        df = df[df.split == split].reset_index(drop=True)
        df['path'] = df.crop_path.apply(lambda x: str(MANGA109_ROOT / x))
        df = df[['path', 'text']]
        df['synthetic'] = False
        data.append(df)

        data = pd.concat(data, ignore_index=True)

        if limit_size:
            data = data.iloc[:limit_size]
        self.data = data

        print(f'Dataset {split}: {len(self.data)}')

        self.augment = augment
        self.transform_medium, self.transform_heavy = self.get_transforms()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.loc[idx]
        text = sample.text

        if self.augment:
            medium_p = 0.8
            heavy_p = 0.02
            transform_variant = np.random.choice(['none', 'medium', 'heavy'],
                                                 p=[1 - medium_p - heavy_p, medium_p, heavy_p])
            transform = {
                'none': None,
                'medium': self.transform_medium,
                'heavy': self.transform_heavy,
            }[transform_variant]
        else:
            transform = None

        pixel_values = self.read_image(self.processor, sample.path, transform)
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        labels = np.array(labels)
        # important: make sure that PAD tokens are ignored by the loss function
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        encoding = {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels),
        }
        return encoding

    @staticmethod
    def read_image(processor, path, transform=None):
        img = cv2.imread(str(path))

        if transform is None:
            transform = A.ToGray(always_apply=True)

        img = transform(image=img)['image']

        pixel_values = processor(img, return_tensors="pt").pixel_values
        return pixel_values.squeeze()

    @staticmethod
    def get_transforms():
        t_medium = A.Compose([
            A.Rotate(5, border_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.InvertImg(p=0.05),

            A.OneOf([
                A.Downscale(0.25, 0.5, interpolation=cv2.INTER_LINEAR),
                A.Downscale(0.25, 0.5, interpolation=cv2.INTER_NEAREST),
            ], p=0.1),
            A.Blur(p=0.2),
            A.Sharpen(p=0.2),
            A.RandomBrightnessContrast(p=0.5),
            A.GaussNoise((50, 200), p=0.3),
            A.ImageCompression(0, 30, p=0.1),
            A.ToGray(always_apply=True),
        ])

        t_heavy = A.Compose([
            A.Rotate(10, border_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.InvertImg(p=0.05),

            A.OneOf([
                A.Downscale(0.1, 0.2, interpolation=cv2.INTER_LINEAR),
                A.Downscale(0.1, 0.2, interpolation=cv2.INTER_NEAREST),
            ], p=0.1),
            A.Blur((4, 9), p=0.5),
            A.Sharpen(p=0.5),
            A.RandomBrightnessContrast(0.8, 0.8, p=1),
            A.GaussNoise((1000, 10000), p=0.3),
            A.ImageCompression(0, 10, p=0.5),
            A.ToGray(always_apply=True),
        ])

        return t_medium, t_heavy


if __name__ == '__main__':
    from manga_ocr_dev.training.get_model import get_processor
    from manga_ocr_dev.training.utils import tensor_to_image

    encoder_name = 'facebook/deit-tiny-patch16-224'
    decoder_name = 'cl-tohoku/bert-base-japanese-char-v2'

    max_length = 300

    processor = get_processor(encoder_name, decoder_name)
    ds = MangaDataset(processor, 'train', max_length, augment=True)

    # draw the same sample repeatedly to preview augmentation variety
    for i in range(20):
        sample = ds[0]
        img = tensor_to_image(sample['pixel_values'])
        tokens = sample['labels']
        tokens[tokens == -100] = processor.tokenizer.pad_token_id
        text = ''.join(processor.decode(tokens, skip_special_tokens=True).split())

        print(f'{i}:\n{text}\n')
        plt.imshow(img)
        plt.show()
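Because __getitem__ returns fixed-size tensors (pixel_values from the feature extractor, labels always padded to max_target_length), the dataset drops straight into a stock PyTorch DataLoader with default collation. A minimal sketch, assuming DATA_SYNTHETIC_ROOT and MANGA109_ROOT already point at prepared data:

from torch.utils.data import DataLoader

from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_processor

processor = get_processor('facebook/deit-tiny-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')
ds = MangaDataset(processor, 'train', max_target_length=300, augment=True)

# default collation stacks samples because every sample has identical shapes
loader = DataLoader(ds, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch['pixel_values'].shape)  # e.g. torch.Size([4, 3, 224, 224])
print(batch['labels'].shape)        # torch.Size([4, 300])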
manga_ocr_dev/training/get_model.py  (new file, 63 lines)
@@ -0,0 +1,63 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel, \
    AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderConfig


class TrOCRProcessorCustom(TrOCRProcessor):
    """The only point of this class is to bypass type checks of the base class."""

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor


def get_processor(encoder_name, decoder_name):
    feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
    tokenizer = AutoTokenizer.from_pretrained(decoder_name)
    processor = TrOCRProcessorCustom(feature_extractor, tokenizer)
    return processor


def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
    encoder_config = AutoConfig.from_pretrained(encoder_name)
    encoder_config.is_decoder = False
    encoder_config.add_cross_attention = False
    encoder = AutoModel.from_config(encoder_config)

    decoder_config = AutoConfig.from_pretrained(decoder_name)
    decoder_config.max_length = max_length
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    decoder = AutoModelForCausalLM.from_config(decoder_config)

    if num_decoder_layers is not None:
        # keep only the last num_decoder_layers transformer layers of the decoder
        if decoder_config.model_type == 'bert':
            decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
        elif decoder_config.model_type in ('roberta', 'xlm-roberta'):
            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
        else:
            raise ValueError(f'Unsupported model_type: {decoder_config.model_type}')

        decoder_config.num_hidden_layers = num_decoder_layers

    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
    config.tie_word_embeddings = False
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

    processor = get_processor(encoder_name, decoder_name)

    # set special tokens used for creating the decoder_input_ids from the labels
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size

    # set beam search parameters
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = max_length
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    return model, processor
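A quick sanity check of what get_model returns. Note that both halves are built with from_config, so the architecture matches the hub checkpoints but the weights start randomly initialized; the summary helpers come from utils.py below. A sketch, assuming transformers and torchinfo are installed:

from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.utils import decoder_summary, encoder_summary

model, processor = get_model(
    'facebook/deit-tiny-patch16-224',
    'cl-tohoku/bert-base-japanese-char-v2',
    max_length=300,
    num_decoder_layers=2,
)

# the BERT decoder was truncated to its last two layers
print(model.config.decoder.num_hidden_layers)  # 2
# CLS starts decoding and SEP ends it, since BERT has no dedicated BOS/EOS
print(model.config.decoder_start_token_id == processor.tokenizer.cls_token_id)  # True

print(encoder_summary(model))
print(decoder_summary(model))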
manga_ocr_dev/training/metrics.py  (new file, 32 lines)
@@ -0,0 +1,32 @@
import numpy as np
from datasets import load_metric


class Metrics:
    def __init__(self, processor):
        self.cer_metric = load_metric("cer")
        self.processor = processor

    def compute_metrics(self, pred):
        label_ids = pred.label_ids
        pred_ids = pred.predictions
        print(label_ids.shape, pred_ids.shape)

        pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id
        label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)

        # compare whitespace-stripped strings
        pred_str = np.array([''.join(text.split()) for text in pred_str])
        label_str = np.array([''.join(text.split()) for text in label_str])

        results = {}
        try:
            results['cer'] = self.cer_metric.compute(predictions=pred_str, references=label_str)
        except Exception as e:
            # cer computation can fail on degenerate outputs (e.g. empty strings)
            print(e)
            print(pred_str)
            print(label_str)
            results['cer'] = 0
        results['accuracy'] = (pred_str == label_str).mean()

        return results
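compute_metrics expects the EvalPrediction object that Seq2SeqTrainer passes in, where predictions are generated token ids (train.py sets predict_with_generate=True). A toy invocation, sketched with a stand-in namespace instead of a real trainer and assuming the jiwer backend for the cer metric is installed:

from types import SimpleNamespace

import numpy as np

from manga_ocr_dev.training.get_model import get_processor
from manga_ocr_dev.training.metrics import Metrics

processor = get_processor('facebook/deit-tiny-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')
metrics = Metrics(processor)

# encode two short strings as both "labels" and "predictions";
# the -100 padding mimics what MangaDataset writes into the labels
ids = processor.tokenizer(['テスト', 'まんが'], padding='max_length', max_length=8,
                          truncation=True).input_ids
label_ids = np.array(ids)
label_ids[label_ids == processor.tokenizer.pad_token_id] = -100

pred = SimpleNamespace(label_ids=label_ids.copy(), predictions=np.array(ids))
print(metrics.compute_metrics(pred))  # perfect match: {'cer': 0.0, 'accuracy': 1.0}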
manga_ocr_dev/training/train.py  (new file, 64 lines)
@@ -0,0 +1,64 @@
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator

from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics


def run(
        run_name='debug',
        encoder_name='facebook/deit-tiny-patch16-224',
        decoder_name='cl-tohoku/bert-base-japanese-char-v2',
        max_len=300,
        num_decoder_layers=2,
        batch_size=64,
        num_epochs=8,
        fp16=True,
):
    wandb.login()

    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

    # keep package 0 for validation
    train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
    eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))

    metrics = Metrics(processor)

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy='steps',
        save_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        fp16_full_eval=fp16,
        dataloader_num_workers=16,
        output_dir=TRAIN_ROOT,
        logging_steps=10,
        save_steps=20000,
        eval_steps=20000,
        num_train_epochs=num_epochs,
        run_name=run_name,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=metrics.compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    wandb.finish()


if __name__ == '__main__':
    fire.Fire(run)
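fire.Fire(run) turns every keyword argument of run into a command-line flag, so hyperparameters can be overridden per run without editing the file. A sketch of both invocation styles (flag values are illustrative, not from this commit):

# from a shell:
#   python -m manga_ocr_dev.training.train --run_name=deit_tiny_2l --batch_size=32 --fp16=False
#
# or programmatically, e.g. a short smoke-test run:
from manga_ocr_dev.training.train import run

run(run_name='smoke_test', num_epochs=1, batch_size=8, fp16=False)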
manga_ocr_dev/training/utils.py  (new file, 27 lines)
@@ -0,0 +1,27 @@
import numpy as np
import torch
from torchinfo import summary


def encoder_summary(model, batch_size=4):
    img_size = model.config.encoder.image_size
    return summary(model.encoder, input_size=(batch_size, 3, img_size, img_size), depth=3,
                   col_names=["output_size", "num_params", "mult_adds"], device='cpu')


def decoder_summary(model, batch_size=4):
    img_size = model.config.encoder.image_size
    encoder_hidden_shape = (batch_size, (img_size // 16) ** 2 + 1, model.config.decoder.hidden_size)
    decoder_inputs = {
        'input_ids': torch.zeros(batch_size, 1, dtype=torch.int64),
        'attention_mask': torch.ones(batch_size, 1, dtype=torch.int64),
        'encoder_hidden_states': torch.rand(encoder_hidden_shape, dtype=torch.float32),
        'return_dict': False,
    }
    return summary(model.decoder, input_data=decoder_inputs, depth=4,
                   col_names=["output_size", "num_params", "mult_adds"],
                   device='cpu')


def tensor_to_image(img):
    return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
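tensor_to_image assumes pixel values normalized to [-1, 1] (mean 0.5, std 0.5) and maps a CHW float tensor back to an HWC uint8 image, which is what the preview loop in dataset.py relies on. In decoder_summary, (img_size // 16) ** 2 + 1 is the encoder's output sequence length: one token per 16x16 patch plus the CLS token. A self-contained check with a random tensor:

import torch

from manga_ocr_dev.training.utils import tensor_to_image

# fake "pixel_values" in the [-1, 1] range the function expects
t = torch.rand(3, 224, 224) * 2 - 1

img = tensor_to_image(t)
print(img.shape, img.dtype)  # (224, 224, 3) uint8
print(img.min(), img.max())  # roughly 0 and 255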