training and synthetic data generation code

Maciej Budyś
2022-02-09 20:39:01 +01:00
parent a9085393f4
commit 975dbf4d5e
42 changed files with 7089 additions and 15 deletions

manga_ocr_dev/training/dataset.py
@@ -0,0 +1,165 @@
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from manga_ocr_dev.env import MANGA109_ROOT, DATA_SYNTHETIC_ROOT
class MangaDataset(Dataset):
def __init__(self, processor, split, max_target_length, limit_size=None, augment=False, skip_packages=None):
self.processor = processor
self.max_target_length = max_target_length
data = []
print(f'Initializing dataset {split}...')
if skip_packages is None:
skip_packages = set()
else:
skip_packages = {f'{x:04d}' for x in skip_packages}
for path in sorted((DATA_SYNTHETIC_ROOT / 'meta').glob('*.csv')):
if path.stem in skip_packages:
print(f'Skipping package {path}')
continue
if not (DATA_SYNTHETIC_ROOT / 'img' / path.stem).is_dir():
print(f'Missing image data for package {path}, skipping')
continue
df = pd.read_csv(path)
df = df.dropna()
df['path'] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / 'img' / path.stem / f'{x}.jpg'))
df = df[['path', 'text']]
df['synthetic'] = True
data.append(df)
df = pd.read_csv(MANGA109_ROOT / 'data.csv')
df = df[df.split == split].reset_index(drop=True)
df['path'] = df.crop_path.apply(lambda x: str(MANGA109_ROOT / x))
df = df[['path', 'text']]
df['synthetic'] = False
data.append(df)
data = pd.concat(data, ignore_index=True)
if limit_size:
data = data.iloc[:limit_size]
self.data = data
print(f'Dataset {split}: {len(self.data)}')
self.augment = augment
self.transform_medium, self.transform_heavy = self.get_transforms()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data.loc[idx]
text = sample.text
if self.augment:
medium_p = 0.8
heavy_p = 0.02
transform_variant = np.random.choice(['none', 'medium', 'heavy'],
p=[1 - medium_p - heavy_p, medium_p, heavy_p])
transform = {
'none': None,
'medium': self.transform_medium,
'heavy': self.transform_heavy,
}[transform_variant]
else:
transform = None
pixel_values = self.read_image(self.processor, sample.path, transform)
labels = self.processor.tokenizer(text,
padding="max_length",
max_length=self.max_target_length,
truncation=True).input_ids
labels = np.array(labels)
# important: make sure that PAD tokens are ignored by the loss function
labels[labels == self.processor.tokenizer.pad_token_id] = -100
encoding = {
"pixel_values": pixel_values,
"labels": torch.tensor(labels),
}
return encoding
@staticmethod
def read_image(processor, path, transform=None):
img = cv2.imread(str(path))
if transform is None:
transform = A.ToGray(always_apply=True)
img = transform(image=img)['image']
pixel_values = processor(img, return_tensors="pt").pixel_values
return pixel_values.squeeze()
@staticmethod
def get_transforms():
t_medium = A.Compose([
A.Rotate(5, border_mode=cv2.BORDER_REPLICATE, p=0.2),
A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
A.InvertImg(p=0.05),
A.OneOf([
A.Downscale(0.25, 0.5, interpolation=cv2.INTER_LINEAR),
A.Downscale(0.25, 0.5, interpolation=cv2.INTER_NEAREST),
], p=0.1),
A.Blur(p=0.2),
A.Sharpen(p=0.2),
A.RandomBrightnessContrast(p=0.5),
A.GaussNoise((50, 200), p=0.3),
A.ImageCompression(0, 30, p=0.1),
A.ToGray(always_apply=True),
])
t_heavy = A.Compose([
A.Rotate(10, border_mode=cv2.BORDER_REPLICATE, p=0.2),
A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
A.InvertImg(p=0.05),
A.OneOf([
A.Downscale(0.1, 0.2, interpolation=cv2.INTER_LINEAR),
A.Downscale(0.1, 0.2, interpolation=cv2.INTER_NEAREST),
], p=0.1),
A.Blur((4, 9), p=0.5),
A.Sharpen(p=0.5),
A.RandomBrightnessContrast(0.8, 0.8, p=1),
A.GaussNoise((1000, 10000), p=0.3),
A.ImageCompression(0, 10, p=0.5),
A.ToGray(always_apply=True),
])
return t_medium, t_heavy
if __name__ == '__main__':
from manga_ocr_dev.training.get_model import get_processor
from manga_ocr_dev.training.utils import tensor_to_image
encoder_name = 'facebook/deit-tiny-patch16-224'
decoder_name = 'cl-tohoku/bert-base-japanese-char-v2'
max_length = 300
processor = get_processor(encoder_name, decoder_name)
ds = MangaDataset(processor, 'train', max_length, augment=True)
for i in range(20):
sample = ds[0]
img = tensor_to_image(sample['pixel_values'])
tokens = sample['labels']
tokens[tokens == -100] = processor.tokenizer.pad_token_id
text = ''.join(processor.decode(tokens, skip_special_tokens=True).split())
print(f'{i}:\n{text}\n')
plt.imshow(img)
plt.show()
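
A minimal consumption sketch (not part of this commit): MangaDataset items are plain dicts, so they can be batched with a stock DataLoader. The encoder/decoder names below are the same ones used in the __main__ block above; the batch size and the shapes in the comment are illustrative and assume the synthetic/Manga109 data is present on disk.

from torch.utils.data import DataLoader
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_processor

processor = get_processor('facebook/deit-tiny-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')
ds = MangaDataset(processor, 'train', max_target_length=300, augment=False)
loader = DataLoader(ds, batch_size=4, shuffle=True)

batch = next(iter(loader))
# pixel_values: (4, 3, 224, 224); labels: (4, 300) with PAD positions already set to -100
print(batch['pixel_values'].shape, batch['labels'].shape)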

manga_ocr_dev/training/get_model.py
@@ -0,0 +1,63 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel, \
AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderConfig
class TrOCRProcessorCustom(TrOCRProcessor):
"""The only point of this class is to bypass type checks of base class."""
def __init__(self, feature_extractor, tokenizer):
self.feature_extractor = feature_extractor
self.tokenizer = tokenizer
self.current_processor = self.feature_extractor
def get_processor(encoder_name, decoder_name):
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
tokenizer = AutoTokenizer.from_pretrained(decoder_name)
processor = TrOCRProcessorCustom(feature_extractor, tokenizer)
return processor
def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
encoder_config = AutoConfig.from_pretrained(encoder_name)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
encoder = AutoModel.from_config(encoder_config)
decoder_config = AutoConfig.from_pretrained(decoder_name)
decoder_config.max_length = max_length
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
decoder = AutoModelForCausalLM.from_config(decoder_config)
if num_decoder_layers is not None:
if decoder_config.model_type == 'bert':
decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
elif decoder_config.model_type in ('roberta', 'xlm-roberta'):
decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
else:
raise ValueError(f'Unsupported model_type: {decoder_config.model_type}')
decoder_config.num_hidden_layers = num_decoder_layers
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
config.tie_word_embeddings = False
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)
processor = get_processor(encoder_name, decoder_name)
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_length
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
return model, processor
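
A rough end-to-end sketch of using the returned model and processor on a single crop ('example_crop.jpg' is a placeholder path, not from this commit). Note that get_model builds the model from configs with random weights, so the decoded text is meaningless until the model is trained; generation relies on the beam-search defaults set above.

import cv2
import torch
from manga_ocr_dev.training.get_model import get_model

model, processor = get_model('facebook/deit-tiny-patch16-224',
                             'cl-tohoku/bert-base-japanese-char-v2',
                             max_length=300, num_decoder_layers=2)
model.eval()

img = cv2.imread('example_crop.jpg')  # placeholder path, not from this commit
pixel_values = processor(img, return_tensors='pt').pixel_values

with torch.no_grad():
    token_ids = model.generate(pixel_values)[0]
text = processor.tokenizer.decode(token_ids, skip_special_tokens=True)
print(''.join(text.split()))  # drop the spaces a character-level tokenizer inserts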

manga_ocr_dev/training/metrics.py
@@ -0,0 +1,32 @@
import numpy as np
from datasets import load_metric
class Metrics:
def __init__(self, processor):
self.cer_metric = load_metric("cer")
self.processor = processor
def compute_metrics(self, pred):
label_ids = pred.label_ids
pred_ids = pred.predictions
print(label_ids.shape, pred_ids.shape)
pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id
label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)
pred_str = np.array([''.join(text.split()) for text in pred_str])
label_str = np.array([''.join(text.split()) for text in label_str])
results = {}
try:
results['cer'] = self.cer_metric.compute(predictions=pred_str, references=label_str)
except Exception as e:
print(e)
print(pred_str)
print(label_str)
results['cer'] = 0
results['accuracy'] = (pred_str == label_str).mean()
return results
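
A small sanity-check sketch, with SimpleNamespace standing in for the EvalPrediction object that Seq2SeqTrainer passes to compute_metrics when predict_with_generate=True. The texts and max_length are made up, and the expected result in the comment assumes predictions identical to the references.

from types import SimpleNamespace
from manga_ocr_dev.training.get_model import get_processor
from manga_ocr_dev.training.metrics import Metrics

processor = get_processor('facebook/deit-tiny-patch16-224', 'cl-tohoku/bert-base-japanese-char-v2')
metrics = Metrics(processor)

labels = processor.tokenizer(['これはテスト', 'まんが'], padding='max_length',
                             max_length=16, truncation=True, return_tensors='np').input_ids
pred = SimpleNamespace(predictions=labels.copy(), label_ids=labels.copy())
print(metrics.compute_metrics(pred))  # perfect match -> {'cer': 0.0, 'accuracy': 1.0}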

manga_ocr_dev/training/train.py
@@ -0,0 +1,64 @@
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics
def run(
run_name='debug',
encoder_name='facebook/deit-tiny-patch16-224',
decoder_name='cl-tohoku/bert-base-japanese-char-v2',
max_len=300,
num_decoder_layers=2,
batch_size=64,
num_epochs=8,
fp16=True,
):
wandb.login()
model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)
# keep package 0 for validation
train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))
metrics = Metrics(processor)
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
evaluation_strategy='steps',
save_strategy='steps',
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
fp16=fp16,
fp16_full_eval=fp16,
dataloader_num_workers=16,
output_dir=TRAIN_ROOT,
logging_steps=10,
save_steps=20000,
eval_steps=20000,
num_train_epochs=num_epochs,
run_name=run_name
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
compute_metrics=metrics.compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
trainer.train()
wandb.finish()
if __name__ == '__main__':
fire.Fire(run)
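
A launch sketch for reference (the module path manga_ocr_dev.training.train is an assumption; this page does not show the file name explicitly). fire.Fire exposes every keyword of run() both to Python callers and as command-line flags; the argument values below are examples, not settings from this commit, and calling run() starts a full training run including wandb.login().

from manga_ocr_dev.training.train import run

run(
    run_name='deit-tiny_bert-char_2layers',  # example run name, not from this commit
    batch_size=32,
    num_epochs=8,
    fp16=True,
)

# Equivalent CLI call, mapped by fire from the same keyword arguments:
#   python -m manga_ocr_dev.training.train --run_name=deit-tiny_bert-char_2layers --batch_size=32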

manga_ocr_dev/training/utils.py
@@ -0,0 +1,27 @@
import numpy as np
import torch
from torchinfo import summary
def encoder_summary(model, batch_size=4):
img_size = model.config.encoder.image_size
return summary(model.encoder, input_size=(batch_size, 3, img_size, img_size), depth=3,
col_names=["output_size", "num_params", "mult_adds"], device='cpu')
def decoder_summary(model, batch_size=4):
img_size = model.config.encoder.image_size
encoder_hidden_shape = (batch_size, (img_size // 16) ** 2 + 1, model.config.decoder.hidden_size)
decoder_inputs = {
'input_ids': torch.zeros(batch_size, 1, dtype=torch.int64),
'attention_mask': torch.ones(batch_size, 1, dtype=torch.int64),
'encoder_hidden_states': torch.rand(encoder_hidden_shape, dtype=torch.float32),
'return_dict': False
}
return summary(model.decoder, input_data=decoder_inputs, depth=4,
col_names=["output_size", "num_params", "mult_adds"],
device='cpu')
def tensor_to_image(img):
return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)
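
A usage sketch for these helpers, assuming a model and dataset built as in the files above: torchinfo prints layer-by-layer tables with the requested output sizes, parameter counts, and multiply-adds, and tensor_to_image maps the roughly [-1, 1] normalized tensor back to a HWC uint8 image for plotting.

import matplotlib.pyplot as plt
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.utils import decoder_summary, encoder_summary, tensor_to_image

model, processor = get_model('facebook/deit-tiny-patch16-224',
                             'cl-tohoku/bert-base-japanese-char-v2',
                             max_length=300, num_decoder_layers=2)
print(encoder_summary(model))
print(decoder_summary(model))

ds = MangaDataset(processor, 'train', max_target_length=300)
img = tensor_to_image(ds[0]['pixel_values'])  # back to HWC uint8 for plotting
plt.imshow(img)
plt.show()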