training and synthetic data generation code
This commit is contained in:
64
manga_ocr_dev/training/train.py
Normal file
64
manga_ocr_dev/training/train.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import fire
|
||||
import wandb
|
||||
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
|
||||
|
||||
from manga_ocr_dev.env import TRAIN_ROOT
|
||||
from manga_ocr_dev.training.dataset import MangaDataset
|
||||
from manga_ocr_dev.training.get_model import get_model
|
||||
from manga_ocr_dev.training.metrics import Metrics
|
||||
|
||||
|
||||
def run(
    run_name='debug',
    encoder_name='facebook/deit-tiny-patch16-224',
    decoder_name='cl-tohoku/bert-base-japanese-char-v2',
    max_len=300,
    num_decoder_layers=2,
    batch_size=64,
    num_epochs=8,
    fp16=True,
):
    """Build the model and datasets, then train with ``Seq2SeqTrainer``.

    Logs in to Weights & Biases before training and closes the W&B run once
    ``trainer.train()`` returns.

    Args:
        run_name: Passed to the training arguments as ``run_name``.
        encoder_name: Encoder identifier forwarded to ``get_model``.
        decoder_name: Decoder identifier forwarded to ``get_model``.
        max_len: Forwarded to ``get_model`` and to both datasets.
        num_decoder_layers: Forwarded to ``get_model``.
        batch_size: Used as both the per-device train and eval batch size.
        num_epochs: Passed to the training arguments as ``num_train_epochs``.
        fp16: Sets both ``fp16`` and ``fp16_full_eval`` in the training
            arguments.
    """
    wandb.login()

    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

    # Package 0 is held out of training and is the only package left for
    # evaluation (the eval dataset skips packages 1..9998).
    train_dataset = MangaDataset(
        processor,
        'train',
        max_len,
        augment=True,
        skip_packages=[0],
    )
    eval_dataset = MangaDataset(
        processor,
        'test',
        max_len,
        augment=False,
        skip_packages=range(1, 9999),
    )

    metrics = Metrics(processor)

    # Checkpointing and evaluation share the same step interval.
    checkpoint_interval = 20000
    arg_settings = {
        'predict_with_generate': True,
        'evaluation_strategy': 'steps',
        'save_strategy': 'steps',
        'per_device_train_batch_size': batch_size,
        'per_device_eval_batch_size': batch_size,
        'fp16': fp16,
        'fp16_full_eval': fp16,
        'dataloader_num_workers': 16,
        'output_dir': TRAIN_ROOT,
        'logging_steps': 10,
        'save_steps': checkpoint_interval,
        'eval_steps': checkpoint_interval,
        'num_train_epochs': num_epochs,
        'run_name': run_name,
    }
    training_args = Seq2SeqTrainingArguments(**arg_settings)

    # Assemble the trainer; the processor's feature extractor is passed in
    # the ``tokenizer`` slot.
    trainer_settings = {
        'model': model,
        'tokenizer': processor.feature_extractor,
        'args': training_args,
        'compute_metrics': metrics.compute_metrics,
        'train_dataset': train_dataset,
        'eval_dataset': eval_dataset,
        'data_collator': default_data_collator,
    }
    trainer = Seq2SeqTrainer(**trainer_settings)
    trainer.train()

    wandb.finish()
|
||||
|
||||
|
||||
# Script entry point: hand `run` to python-fire, which builds the command-line
# interface from it.
if __name__ == '__main__':
    fire.Fire(run)
|
||||
Reference in New Issue
Block a user