import budou
import numpy as np
import pandas as pd

from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
from manga_ocr_dev.synthetic_data_generator.renderer import Renderer
from manga_ocr_dev.synthetic_data_generator.utils import get_font_meta, get_charsets, is_ascii, is_kanji


class SyntheticDataGenerator:
    def __init__(self):
        self.vocab, self.hiragana, self.katakana = get_charsets()
        # empirical distribution of text lengths, used to sample target lengths
        self.len_to_p = pd.read_csv(ASSETS_PATH / 'len_to_p.csv')
        self.parser = budou.get_parser('tinysegmenter')
        self.fonts_df, self.font_map = get_font_meta()
        self.font_labels, self.font_p = self.get_font_labels_prob()
        self.renderer = Renderer()

    def process(self, text=None, override_css_params=None):
        """
        Generate an (image, text) pair. Use the source text if provided,
        otherwise generate random text.
        """
        if override_css_params is None:
            override_css_params = {}

        if text is None:
            # if using random text, choose the font first,
            # and then generate text using only characters supported by that font
            if 'font_path' not in override_css_params:
                override_css_params['font_path'] = self.get_random_font()
            font_path = override_css_params['font_path']
            vocab = self.font_map[font_path]
            words = self.get_random_words(vocab)
        else:
            # normalize full-width spaces and ellipses
            text = text.replace('　', ' ')
            text = text.replace('…', '...')
            words = self.split_into_words(text)

        lines = self.words_to_lines(words)
        text_gt = '\n'.join(lines)

        if 'font_path' not in override_css_params:
            override_css_params['font_path'] = self.get_random_font(text_gt)
        font_path = override_css_params.get('font_path')
        if font_path:
            vocab = self.font_map.get(font_path)
            # remove characters not supported by the chosen font
            lines = [''.join(c for c in line if c in vocab) for line in lines]
            text_gt = '\n'.join(lines)
        else:
            vocab = None

        if np.random.random() < 0.5:
            word_prob = np.random.choice([0.33, 1.0], p=[0.3, 0.7])
            lines = [self.add_random_furigana(line, word_prob, vocab) for line in lines]

        img, params = self.renderer.render(lines, override_css_params)
        return img, text_gt, params

    def get_random_words(self, vocab):
        """Generate random words from the given character set, up to a sampled total length."""
        vocab = list(vocab)
        max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)

        words = []
        text_len = 0
        while text_len < max_text_len:
            word = ''.join(np.random.choice(vocab, np.random.randint(1, 4)))
            words.append(word)
            text_len += len(word)
        return words

    def split_into_words(self, text):
        """Split text into words with budou, truncating to a sampled total length."""
        max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)

        words = []
        text_len = 0
        for chunk in self.parser.parse(text)['chunks']:
            words.append(chunk.word)
            text_len += len(chunk.word)
            if text_len >= max_text_len:
                break
        return words

    def words_to_lines(self, words):
        """Greedily pack words into lines of a randomly sampled length."""
        text = ''.join(words)
        max_num_lines = 10
        # force a minimum line length so the text fits in roughly max_num_lines lines
        min_line_len = len(text) // max_num_lines
        max_line_len = np.clip(np.random.poisson(6), min_line_len, 20)

        lines = []
        line = ''
        for word in words:
            line += word
            if len(line) >= max_line_len:
                lines.append(line)
                line = ''
        if line:
            lines.append(line)
        return lines

    def add_random_furigana(self, line, word_prob=1.0, vocab=None):
        if vocab is None:
            vocab = self.vocab
        else:
            vocab = list(vocab)

        processed = ''
        kanji_group = ''
        ascii_group = ''
        for i, c in enumerate(line):
            if is_kanji(c):
                c_type = 'kanji'
                kanji_group += c
            elif is_ascii(c):
                c_type = 'ascii'
                ascii_group += c
            else:
                c_type = 'other'

            if c_type != 'kanji' or i == len(line) - 1:
                if kanji_group:
                    if np.random.uniform() < word_prob:
                        furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
                        char_source = np.random.choice(['hiragana', 'katakana', 'all'], p=[0.8, 0.15, 0.05])
                        char_source = {
                            'hiragana': self.hiragana,
                            'katakana': self.katakana,
                            'all': vocab,
                        }[char_source]
                        furigana = ''.join(np.random.choice(char_source, furigana_len))
                        # ruby markup makes the HTML renderer draw the furigana alongside the kanji
                        processed += f'<ruby>{kanji_group}<rt>{furigana}</rt></ruby>'
                    else:
                        processed += kanji_group
                    kanji_group = ''

            if c_type != 'ascii' or i == len(line) - 1:
                if ascii_group:
                    if len(ascii_group) <= 3 and np.random.uniform() < 0.7:
                        # render short ASCII runs upright within vertical text (tate-chu-yoko)
                        processed += f'<span style="text-combine-upright: all">{ascii_group}</span>'
                    else:
                        processed += ascii_group
                    ascii_group = ''

            if c_type == 'other':
                processed += c

        return processed

    def is_font_supporting_text(self, font_path, text):
        chars = self.font_map[font_path]
        return all(c in chars for c in text if not c.isspace())

    def get_font_labels_prob(self):
        """Return font label names and their sampling probabilities."""
        labels = {
            'common': 0.2,
            'regular': 0.75,
            'special': 0.05,
        }
        # keep only the labels present in the metadata and renormalize
        labels = {k: labels[k] for k in self.fonts_df.label.unique()}
        p = np.array(list(labels.values()))
        p = p / p.sum()
        return list(labels.keys()), p

    def get_random_font(self, text=None):
        label = np.random.choice(self.font_labels, p=self.font_p)
        df = self.fonts_df[self.fonts_df.label == label]

        if text is None:
            return df.sample(1).iloc[0].font_path

        valid_mask = df.font_path.apply(lambda x: self.is_font_supporting_text(x, text))
        if not valid_mask.any():
            # if the text contains characters not supported by any font,
            # just pick one of the more capable fonts
            valid_mask = df.num_chars >= 4000
        # return the path as stored in fonts_df, so it remains a valid font_map key
        return df[valid_mask].sample(1).iloc[0].font_path
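

# A minimal usage sketch (an assumption, not part of the original module):
# it presumes the assets loaded above (len_to_p.csv, the font metadata read by
# get_font_meta, and the fonts themselves) are available locally. The sample
# text is an arbitrary placeholder.
if __name__ == '__main__':
    generator = SyntheticDataGenerator()

    # fully synthetic sample: a font is chosen first, then random text is drawn
    # from that font's supported character set
    img, text_gt, params = generator.process()

    # sample rendered from source text; a compatible font is chosen automatically
    img, text_gt, params = generator.process(text='これはテストです。')
    print(text_gt)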