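"""Train the manga OCR model.

Fine-tunes a vision encoder-decoder model (an image transformer encoder paired
with a Japanese character-level BERT decoder) using the Hugging Face
Seq2SeqTrainer, with runs logged to Weights & Biases.
"""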
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator

from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics


def run(
    run_name='debug',
    encoder_name='facebook/deit-tiny-patch16-224',  # HF hub id of the image encoder
    decoder_name='cl-tohoku/bert-base-japanese-char-v2',  # HF hub id of the text decoder
    max_len=300,  # maximum length of the decoded token sequence
    num_decoder_layers=2,
    batch_size=64,
    num_epochs=8,
    fp16=True,  # mixed-precision training and evaluation
):
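    # log in to Weights & Biases (prompts for an API key if not already authenticated)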
    wandb.login()

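    # build the encoder-decoder model and the processor handling both image
    # preprocessing and text tokenization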
    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

    # keep package 0 for validation: training skips it, evaluation uses only package 0
    train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
    eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))

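    # metrics helper; takes the processor so it can decode predicted token ids for scoring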
    metrics = Metrics(processor)

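    # training configuration: evaluate and checkpoint every 20k steps, log every
    # 10 steps, write outputs under TRAIN_ROOT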
    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,  # generate sequences during eval so text metrics can be computed
        evaluation_strategy='steps',
        save_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        fp16_full_eval=fp16,
        dataloader_num_workers=16,
        output_dir=TRAIN_ROOT,
        logging_steps=10,
        save_steps=20000,
        eval_steps=20000,
        num_train_epochs=num_epochs,
        run_name=run_name,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,  # passed as `tokenizer` so it is saved with checkpoints
        args=training_args,
        compute_metrics=metrics.compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    wandb.finish()


if __name__ == '__main__':
    fire.Fire(run)
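
# Example CLI usage via python-fire (flags mirror run()'s keyword arguments);
# assuming this script is saved as train.py:
#   python train.py --run_name=my-experiment --batch_size=32 --fp16=False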