training and synthetic data generation code

Maciej Budyś
2022-02-09 20:39:01 +01:00
parent a9085393f4
commit 975dbf4d5e
42 changed files with 7089 additions and 15 deletions


@@ -0,0 +1,64 @@
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator

from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics


def run(
    run_name='debug',
    encoder_name='facebook/deit-tiny-patch16-224',
    decoder_name='cl-tohoku/bert-base-japanese-char-v2',
    max_len=300,
    num_decoder_layers=2,
    batch_size=64,
    num_epochs=8,
    fp16=True,
):
    wandb.login()

    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

    # keep package 0 for validation: it is skipped during training,
    # and it is the only package kept in the eval dataset below
    train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
    eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))

    metrics = Metrics(processor)

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy='steps',
        save_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        fp16_full_eval=fp16,
        dataloader_num_workers=16,
        output_dir=TRAIN_ROOT,
        logging_steps=10,
        save_steps=20000,
        eval_steps=20000,
        num_train_epochs=num_epochs,
        run_name=run_name,
    )

    # instantiate trainer; the feature extractor is passed as the tokenizer
    # so that it gets saved alongside the model checkpoints
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=metrics.compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    wandb.finish()


if __name__ == '__main__':
    fire.Fire(run)
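
Because the entry point is wrapped in fire.Fire(run), every keyword argument of run is also exposed as a command-line flag. A minimal usage sketch, assuming the file lives at manga_ocr_dev/training/train.py (the flag values and experiment name below are illustrative, not prescribed by this commit):

    # from the shell, python-fire maps kwargs to flags:
    #   python -m manga_ocr_dev.training.train --run_name my_experiment --batch_size 32
    # or programmatically:
    from manga_ocr_dev.training.train import run  # assumed module path

    run(
        run_name='my_experiment',  # hypothetical wandb run name
        batch_size=32,             # halve the default of 64 if GPU memory is tight
        num_epochs=4,
    )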