64 lines
2.7 KiB
Python
64 lines
2.7 KiB
Python
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel, \
|
|
AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderConfig
|
|
|
|
|
|
class TrOCRProcessorCustom(TrOCRProcessor):
    """TrOCR processor that accepts an arbitrary feature extractor / tokenizer pair.

    The only point of this class is to bypass the type checks of the base
    class, so that components from unrelated encoder and decoder checkpoints
    can be combined. ``super().__init__`` is therefore deliberately not
    called; the attributes the base-class methods rely on are set by hand.
    """

    def __init__(self, feature_extractor, tokenizer):
        """Store both components without running the base-class checks.

        Args:
            feature_extractor: Image preprocessor for the encoder side.
            tokenizer: Text tokenizer for the decoder side.
        """
        # Newer transformers releases rename this component to
        # ``image_processor`` and expose ``feature_extractor`` as a read-only,
        # deprecated property, so assigning to it would raise AttributeError.
        # Older releases store it under ``feature_extractor`` directly.
        if isinstance(getattr(type(self), "feature_extractor", None), property):
            self.image_processor = feature_extractor
        else:
            self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        # By default, calls are routed to the image preprocessor.
        self.current_processor = feature_extractor
        # Newer releases read this flag in ``TrOCRProcessor.__call__``. It is
        # normally set by the base ``__init__``, which this class skips.
        self._in_target_context_manager = False
|
|
|
|
|
|
def get_processor(encoder_name, decoder_name):
    """Pair the encoder checkpoint's feature extractor with the decoder checkpoint's tokenizer.

    Args:
        encoder_name: Model id or local path the feature extractor is loaded from.
        decoder_name: Model id or local path the tokenizer is loaded from.

    Returns:
        A ``TrOCRProcessorCustom`` wrapping both components.
    """
    # The feature extractor is loaded before the tokenizer, as in the original.
    return TrOCRProcessorCustom(
        AutoFeatureExtractor.from_pretrained(encoder_name),
        AutoTokenizer.from_pretrained(decoder_name),
    )
|
|
|
|
|
|
def _truncate_decoder_layers(decoder, model_type, num_layers):
    """Keep only the last ``num_layers`` transformer layers of a BERT/RoBERTa-style decoder."""
    if model_type == 'bert':
        backbone = decoder.bert
    elif model_type in ('roberta', 'xlm-roberta'):
        backbone = decoder.roberta
    else:
        raise ValueError(f'Unsupported model_type: {model_type}')

    total = len(backbone.encoder.layer)
    # ``layer[-0:]`` and any slice longer than the stack both keep every layer,
    # which would leave the model out of sync with ``num_hidden_layers``.
    if not 1 <= num_layers <= total:
        raise ValueError(
            f'num_decoder_layers must be between 1 and {total}, got {num_layers}'
        )
    backbone.encoder.layer = backbone.encoder.layer[-num_layers:]


def _set_generation_params(model, **params):
    """Write generation settings to ``model.config`` and, if present, ``model.generation_config``."""
    generation_config = getattr(model, 'generation_config', None)
    for name, value in params.items():
        setattr(model.config, name, value)
        # Newer transformers versions read these settings from
        # ``generation_config``. That object is created when the model is
        # constructed, i.e. before these assignments, so mirror the values.
        if generation_config is not None:
            setattr(generation_config, name, value)


def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
    """Build a TrOCR-style VisionEncoderDecoderModel and its matching processor.

    Args:
        encoder_name: Model id or path whose config defines the vision encoder
            and whose feature extractor is used by the processor.
        decoder_name: Model id or path whose config defines the text decoder
            and whose tokenizer is used by the processor.
        max_length: Maximum sequence length. It is stored on the decoder config
            and used as the combined model's generation ``max_length``.
        num_decoder_layers: If given, keep only the last this-many decoder
            layers. Supported only for ``bert``, ``roberta`` and
            ``xlm-roberta`` decoders.

    Returns:
        A ``(model, processor)`` tuple.

    Raises:
        ValueError: If ``num_decoder_layers`` is given and either the decoder
            type is unsupported or the value is outside
            ``1..<number of decoder layers>``.

    NOTE(review): both sub-models are created with ``from_config``, which
    builds the architecture without loading pretrained weights. Confirm that
    random initialisation is intended, or that weights are loaded elsewhere.
    """
    encoder_config = AutoConfig.from_pretrained(encoder_name)
    encoder_config.is_decoder = False
    encoder_config.add_cross_attention = False
    encoder = AutoModel.from_config(encoder_config)

    decoder_config = AutoConfig.from_pretrained(decoder_name)
    decoder_config.max_length = max_length
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    decoder = AutoModelForCausalLM.from_config(decoder_config)

    if num_decoder_layers is not None:
        _truncate_decoder_layers(decoder, decoder_config.model_type, num_decoder_layers)
        # Update the config attached to the model as well. Newer transformers
        # copy the config inside ``from_config``, so changing
        # ``decoder_config`` alone would not reach ``decoder.config``. The
        # combined config below, and hence the saved checkpoint, is built
        # from ``decoder.config``.
        decoder_config.num_hidden_layers = num_decoder_layers
        decoder.config.num_hidden_layers = num_decoder_layers

    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
    config.tie_word_embeddings = False
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

    processor = get_processor(encoder_name, decoder_name)
    tokenizer = processor.tokenizer

    _set_generation_params(
        model,
        # Special tokens used for creating the decoder_input_ids from the labels.
        decoder_start_token_id=tokenizer.cls_token_id,
        pad_token_id=tokenizer.pad_token_id,
        # Beam search parameters.
        eos_token_id=tokenizer.sep_token_id,
        max_length=max_length,
        early_stopping=True,
        no_repeat_ngram_size=3,
        length_penalty=2.0,
        num_beams=4,
    )
    # Make sure the vocab size is set correctly.
    model.config.vocab_size = model.config.decoder.vocab_size

    return model, processor
|