training and synthetic data generation code

This commit is contained in:
Maciej Budyś
2022-02-09 20:39:01 +01:00
parent a9085393f4
commit 975dbf4d5e
42 changed files with 7089 additions and 15 deletions

View File

@@ -0,0 +1,38 @@
# Synthetic data generator
Generation of synthetic image-text pairs imitating Japanese manga for the purpose of training OCR.
Features:
- using either text from corpus or random text
- text overlaid on background images
- drawing text bubbles
- various fonts and font styles
- variety of text layouts:
  - vertical and horizontal text
  - multi-line text
  - [furigana](https://en.wikipedia.org/wiki/Furigana) (added randomly)
  - [tate chū yoko](https://www.w3.org/International/articles/vertical-text/#tcy)

Text is rendered with [html2image](https://github.com/vgalin/html2image),
a wrapper around the headless mode of the Chrome/Chromium browser.
It is not the most elegant solution and it is very slow, but it only needs to be run once,
and when parallelized, the processing time is manageable (~17 min per 10000 images on a 16-thread machine).
The upside of this approach is that a quite complex problem of typesetting and text rendering
(especially when dealing with both horizontal and vertical text) is offloaded to
the browser engine, keeping the codebase relatively simple and extendable.
High-level generation pipeline is as follows:
1. Preprocess text (truncate and/or split into lines, add random furigana).
2. Render text on a transparent background, using HTML engine.
3. Select background image from backgrounds dataset.
4. Overlay the text on the background, optionally drawing a bubble around the text.
# Examples
## Images generated with text from [CC-100 Japanese corpus](https://data.statmt.org/cc-100/)
![](../../assets/examples/cc-100.jpg)
## Images generated with random text
![](../../assets/examples/random.jpg)

View File

@@ -0,0 +1,198 @@
import budou
import numpy as np
import pandas as pd
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
from manga_ocr_dev.synthetic_data_generator.renderer import Renderer
from manga_ocr_dev.synthetic_data_generator.utils import get_font_meta, get_charsets, is_ascii, is_kanji
class SyntheticDataGenerator:
def __init__(self):
self.vocab, self.hiragana, self.katakana = get_charsets()
self.len_to_p = pd.read_csv(ASSETS_PATH / 'len_to_p.csv')
self.parser = budou.get_parser('tinysegmenter')
self.fonts_df, self.font_map = get_font_meta()
self.font_labels, self.font_p = self.get_font_labels_prob()
self.renderer = Renderer()
def process(self, text=None, override_css_params=None):
"""
Generate image, text pair. Use source text if provided, otherwise generate random text.
"""
if override_css_params is None:
override_css_params = {}
if text is None:
# if using random text, choose font first,
# and then generate text using only characters supported by that font
if 'font_path' not in override_css_params:
font_path = self.get_random_font()
vocab = self.font_map[font_path]
override_css_params['font_path'] = font_path
else:
font_path = override_css_params['font_path']
vocab = self.font_map[font_path]
words = self.get_random_words(vocab)
else:
text = text.replace(' ', ' ')
text = text.replace('', '...')
words = self.split_into_words(text)
lines = self.words_to_lines(words)
text_gt = '\n'.join(lines)
if 'font_path' not in override_css_params:
override_css_params['font_path'] = self.get_random_font(text_gt)
font_path = override_css_params.get('font_path')
if font_path:
vocab = self.font_map.get(font_path)
# remove unsupported characters
lines = [''.join([c for c in line if c in vocab]) for line in lines]
text_gt = '\n'.join(lines)
else:
vocab = None
if np.random.random() < 0.5:
word_prob = np.random.choice([0.33, 1.0], p=[0.3, 0.7])
lines = [self.add_random_furigana(line, word_prob, vocab) for line in lines]
img, params = self.renderer.render(lines, override_css_params)
return img, text_gt, params
def get_random_words(self, vocab):
vocab = list(vocab)
max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)
words = []
text_len = 0
while True:
word = ''.join(np.random.choice(vocab, np.random.randint(1, 4)))
words.append(word)
text_len += len(word)
if text_len + len(word) >= max_text_len:
break
return words
def split_into_words(self, text):
max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)
words = []
text_len = 0
for chunk in self.parser.parse(text)['chunks']:
words.append(chunk.word)
text_len += len(chunk.word)
if text_len + len(chunk.word) >= max_text_len:
break
return words
def words_to_lines(self, words):
text = ''.join(words)
max_num_lines = 10
min_line_len = len(text) // max_num_lines
max_line_len = 20
max_line_len = np.clip(np.random.poisson(6), min_line_len, max_line_len)
lines = []
line = ''
for word in words:
line += word
if len(line) >= max_line_len:
lines.append(line)
line = ''
if line:
lines.append(line)
return lines
def add_random_furigana(self, line, word_prob=1.0, vocab=None):
if vocab is None:
vocab = self.vocab
else:
vocab = list(vocab)
processed = ''
kanji_group = ''
ascii_group = ''
for i, c in enumerate(line):
if is_kanji(c):
c_type = 'kanji'
kanji_group += c
elif is_ascii(c):
c_type = 'ascii'
ascii_group += c
else:
c_type = 'other'
if c_type != 'kanji' or i == len(line) - 1:
if kanji_group:
if np.random.uniform() < word_prob:
furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
char_source = np.random.choice(['hiragana', 'katakana', 'all'], p=[0.8, 0.15, 0.05])
char_source = {
'hiragana': self.hiragana,
'katakana': self.katakana,
'all': vocab
}[char_source]
furigana = ''.join(np.random.choice(char_source, furigana_len))
processed += f'<ruby>{kanji_group}<rt>{furigana}</rt></ruby>'
else:
processed += kanji_group
kanji_group = ''
if c_type != 'ascii' or i == len(line) - 1:
if ascii_group:
if len(ascii_group) <= 3 and np.random.uniform() < 0.7:
processed += f'<span style="text-combine-upright: all">{ascii_group}</span>'
else:
processed += ascii_group
ascii_group = ''
if c_type == 'other':
processed += c
return processed
def is_font_supporting_text(self, font_path, text):
chars = self.font_map[font_path]
for c in text:
if c.isspace():
continue
if c not in chars:
return False
return True
def get_font_labels_prob(self):
labels = {
'common': 0.2,
'regular': 0.75,
'special': 0.05,
}
labels = {k: labels[k] for k in self.fonts_df.label.unique()}
p = np.array(list(labels.values()))
p = p / p.sum()
labels = list(labels.keys())
return labels, p
def get_random_font(self, text=None):
label = np.random.choice(self.font_labels, p=self.font_p)
df = self.fonts_df[self.fonts_df.label == label]
if text is None:
return df.sample(1).iloc[0].font_path
valid_mask = df.font_path.apply(lambda x: self.is_font_supporting_text(x, text))
if not valid_mask.any():
# if text contains characters not supported by any font, just pick some of the more capable fonts
valid_mask = (df.num_chars >= 4000)
return str(FONTS_ROOT / df[valid_mask].sample(1).iloc[0].font_path)

View File

@@ -0,0 +1,265 @@
import os
import uuid
import albumentations as A
import cv2
import numpy as np
from html2image import Html2Image
from manga_ocr_dev.env import BACKGROUND_DIR
from manga_ocr_dev.synthetic_data_generator.utils import get_background_df
class Renderer:
    """Render text lines to an image via html2image and composite them onto a background."""

    def __init__(self):
        self.hti = Html2Image()
        self.background_df = get_background_df(BACKGROUND_DIR)
        self.max_size = 600

    def render(self, lines, override_css_params=None):
        """Render ``lines`` over a random background.

        :param lines: list of HTML fragments, one per text line
        :param override_css_params: CSS parameters overriding the random ones
        :return: ``(img, params)``; the image is converted to grayscale and
            passed through ``A.LongestMaxSize(self.max_size)``, ``params`` are
            the CSS parameters used
        """
        img, params = self.render_text(lines, override_css_params)
        img = self.render_background(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = A.LongestMaxSize(self.max_size)(image=img)['image']
        return img, params

    def render_text(self, lines, override_css_params=None):
        """Render text on transparent background and return as BGRA image."""
        params = self.get_random_css_params()
        if override_css_params:
            params.update(override_css_params)

        css = get_css(**params)

        # this is just a rough estimate, image is cropped later anyway
        size = (
            int(max(len(line) for line in lines) * params['font_size'] * 1.5),
            int(len(lines) * params['font_size'] * (3 + params['line_height'])),
        )
        if params['vertical']:
            size = size[::-1]
        html = self.lines_to_html(lines)

        filename = str(uuid.uuid4()) + '.png'
        try:
            self.hti.screenshot(html_str=html, css_str=css, save_as=filename, size=size)
            img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
        finally:
            # Remove the temporary screenshot even when rendering or reading fails.
            if os.path.exists(filename):
                os.remove(filename)

        if img is None:
            # cv2.imread signals failure by returning None; fail here with a
            # clear message instead of later on an attribute access.
            raise RuntimeError(f'Failed to read rendered screenshot {filename}')
        return img, params

    @staticmethod
    def get_random_css_params():
        """Return a random set of keyword arguments for ``get_css`` (without ``font_path``)."""
        params = {
            'font_size': 48,
            'vertical': bool(np.random.rand() < 0.7),
            'line_height': 0.5,
            'background_color': 'transparent',
            'text_color': 'black',
        }

        if np.random.rand() < 0.7:
            params['text_orientation'] = 'upright'

        stroke_variant = np.random.choice(['stroke', 'shadow', 'none'], p=[0.8, 0.15, 0.05])
        if stroke_variant == 'stroke':
            params['stroke_size'] = np.random.choice([1, 2, 3, 4, 8])
            params['stroke_color'] = 'white'
        elif stroke_variant == 'shadow':
            params['shadow_size'] = np.random.choice([2, 5, 10])
            # Must be a plain string: it is interpolated into the CSS as-is.
            params['shadow_color'] = 'white' if np.random.rand() < 0.8 else 'black'
        elif stroke_variant == 'none':
            pass

        return params

    def render_background(self, img):
        """Add background and/or text bubble to a BGRA image, crop and return as BGR image."""
        draw_bubble = np.random.random() < 0.7

        # Margin kept around the text; the final crop below is measured against it.
        m0 = int(min(img.shape[:2]) * 0.3)
        img = crop_by_alpha(img, m0)

        background_path = self.background_df.sample(1).iloc[0].path
        background = cv2.imread(background_path)

        t = [
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.InvertImg(),
            A.RandomBrightnessContrast((-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1),
            A.Blur((3, 5), p=0.3),
            A.Resize(img.shape[0], img.shape[1]),
        ]
        background = A.Compose(t)(image=background)['image']

        if not draw_bubble:
            # Without a bubble, invert the text colors half of the time.
            if np.random.rand() < 0.5:
                img[:, :, :3] = 255 - img[:, :, :3]
        else:
            radius = np.random.uniform(0.7, 1.)
            thickness = np.random.choice([1, 2, 3])
            alpha = np.random.randint(60, 100)
            sigma = np.random.randint(10, 15)

            ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))

            bubble_fill_color = (255, 255, 255, 255)
            bubble_contour_color = (0, 0, 0, 255)
            bubble = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
            bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_fill_color,
                                       thickness=-1)
            bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_contour_color,
                                       thickness=thickness)

            t = [
                A.ElasticTransform(alpha=alpha, sigma=sigma, alpha_affine=0, p=0.8),
            ]
            bubble = A.Compose(t)(image=bubble)['image']

            background = blend(bubble, background)

        img = blend(img, background)

        # Random final crop around the text region.
        ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        img = img[ymin:ymax, xmin:xmax]
        return img

    def lines_to_html(self, lines):
        """Wrap each line in a <p> element and return a complete HTML document string.

        Lines are inserted verbatim (they may already contain markup), so no
        HTML escaping is applied here.
        """
        lines_str = '\n'.join(['<p>' + line + '</p>' for line in lines])
        html = f"<html><body>\n{lines_str}\n</body></html>"
        return html
def crop_by_alpha(img, margin):
    """Crop an image to the bounding box of its non-transparent pixels and pad it.

    :param img: array of shape (H, W, C) whose channel 3 is alpha
    :param margin: number of zero-valued pixels added on every side after cropping
    :return: cropped and zero-padded array
    :raises ValueError: if the image has no pixel with alpha > 0
    """
    y, x = np.where(img[:, :, 3] > 0)
    if y.size == 0:
        raise ValueError('Cannot crop by alpha: image is fully transparent')

    # y.max()/x.max() are inclusive indices, while slice ends are exclusive;
    # without the +1 the last opaque row and column are cut off.
    ymin, ymax = y.min(), y.max() + 1
    xmin, xmax = x.min(), x.max() + 1
    img = img[ymin:ymax, xmin:xmax]
    img = np.pad(img, ((margin, margin), (margin, margin), (0, 0)))
    return img
def blend(img, background):
    """Alpha-composite a 4-channel image over a 3-channel background.

    Channel 3 of ``img`` is used as opacity (0-255); the result is uint8 with
    three channels.
    """
    opacity = img[:, :, 3:4] / 255
    foreground = img[:, :, :3]
    mixed = background * (1 - opacity) + foreground * opacity
    return mixed.astype(np.uint8)
def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
    """Draw a rectangle with rounded corners on ``src`` and return ``src``.

    From https://stackoverflow.com/a/60210706

    ``radius`` is a fraction (capped at 1) of half the shorter side. A
    negative ``thickness`` additionally fills the interior with rectangles.
    """
    # corners:
    # p1 - p2
    # |     |
    # p4 - p3
    p1 = top_left
    p2 = (bottom_right[0], top_left[1])
    p3 = bottom_right
    p4 = (top_left[0], bottom_right[1])

    height = abs(bottom_right[1] - top_left[1])
    width = abs(bottom_right[0] - top_left[0])

    radius = min(radius, 1)
    r = int(radius * (min(height, width) / 2))

    if thickness < 0:
        # Fill: one tall rectangle in the middle plus left and right strips.
        fill_rects = [
            ((int(p1[0] + r), int(p1[1])), (int(p3[0] - r), int(p3[1]))),
            ((p1[0], p1[1] + r), (p4[0] + r, p4[1] - r)),
            ((p2[0] - r, p2[1] + r), (p3[0], p3[1] - r)),
        ]
        for rect_top_left, rect_bottom_right in fill_rects:
            cv2.rectangle(src, rect_top_left, rect_bottom_right, color, thickness)

    # draw straight lines
    edges = [
        ((p1[0] + r, p1[1]), (p2[0] - r, p2[1])),
        ((p2[0], p2[1] + r), (p3[0], p3[1] - r)),
        ((p3[0] - r, p4[1]), (p4[0] + r, p3[1])),
        ((p4[0], p4[1] - r), (p1[0], p1[1] + r)),
    ]
    for start, end in edges:
        cv2.line(src, start, end, color, abs(thickness), line_type)

    # draw arcs
    corners = [
        ((p1[0] + r, p1[1] + r), 180.0),
        ((p2[0] - r, p2[1] + r), 270.0),
        ((p3[0] - r, p3[1] - r), 0.0),
        ((p4[0] + r, p4[1] - r), 90.0),
    ]
    for center, angle in corners:
        cv2.ellipse(src, center, (r, r), angle, 0, 90, color, thickness, line_type)

    return src
def get_css(
        font_size,
        font_path,
        vertical=True,
        background_color='white',
        text_color='black',
        shadow_size=0,
        shadow_color='black',
        stroke_size=0,
        stroke_color='black',
        letter_spacing=None,
        line_height=0.5,
        text_orientation=None,
):
    """Build the stylesheet: a ``@font-face`` rule for ``font_path`` followed by a ``body`` rule.

    Optional declarations (orientation, vertical writing mode, shadow, stroke,
    letter spacing) are appended only when the corresponding argument is set.
    """
    body_rules = [
        f"background-color: {background_color};",
        f"font-size: {font_size}px;",
        f"color: {text_color};",
        "font-family: custom;",
        f"line-height: {line_height};",
        "margin: 20px;",
    ]

    if text_orientation:
        body_rules.append(f"text-orientation: {text_orientation};")

    if vertical:
        body_rules.append("writing-mode: vertical-rl;")

    if shadow_size > 0:
        body_rules.append(f"text-shadow: 0 0 {shadow_size}px {shadow_color};")

    if stroke_size > 0:
        # stroke is simulated by shadow overlaid multiple times
        layers = [f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size
        body_rules.append("text-shadow: " + ','.join(layers) + ";")
        body_rules.append("-webkit-font-smoothing: antialiased;")

    if letter_spacing:
        body_rules.append(f"letter-spacing: {letter_spacing}em;")

    font_url = font_path.replace('\\', '/')
    font_face = '\n@font-face {\nfont-family: custom;\nsrc: url("' + font_url + '");\n}\n'
    return font_face + "body {\n" + '\n'.join(body_rules) + "\n}"

View File

@@ -0,0 +1,64 @@
import traceback
from pathlib import Path
import cv2
import fire
import pandas as pd
from tqdm.contrib.concurrent import thread_map
from manga_ocr_dev.env import FONTS_ROOT, DATA_SYNTHETIC_ROOT
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator
# Module-level generator used by every call to f(); it is constructed when this module is imported.
generator = SyntheticDataGenerator()
def f(args):
    """Generate one sample, write it to ``OUT_DIR`` and return its metadata row.

    ``args`` is an ``(index, source, id, text)`` tuple. On any exception the
    traceback is printed and None is returned.
    """
    try:
        _, source, id_, text = args
        image, text_gt, params = generator.process(text)

        out_path = OUT_DIR / f'{id_}.jpg'
        cv2.imwrite(str(out_path), image)

        relative_font = Path(params['font_path']).relative_to(FONTS_ROOT)
        return source, id_, text_gt, params['vertical'], str(relative_font)
    except Exception:
        print(traceback.format_exc())
def run(package=0, n_random=1000, n_limit=None, max_workers=16):
    """
    Generate one package of synthetic images and write its metadata CSV.

    :param package: number of data package to generate
    :param n_random: how many samples with random text to generate
    :param n_limit: limit number of generated samples (for debugging)
    :param max_workers: max number of workers
    """
    global OUT_DIR

    package = f'{package:04d}'
    lines = pd.read_csv(DATA_SYNTHETIC_ROOT / f'lines/{package}.csv')
    random_lines = pd.DataFrame({
        'source': 'random',
        'id': [f'random_{package}_{i}' for i in range(n_random)],
        'line': None
    })
    lines = pd.concat([lines, random_lines], ignore_index=True)
    if n_limit:
        # sample() raises when asked for more rows than exist; n_limit is only an upper bound.
        lines = lines.sample(min(n_limit, len(lines)))
    args = [(i, *values) for i, values in enumerate(lines.values)]

    # f() reads OUT_DIR as a module global.
    OUT_DIR = DATA_SYNTHETIC_ROOT / 'img' / package
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    data = thread_map(f, args, max_workers=max_workers, desc=f'Processing package {package}')

    # f() returns None for samples that failed; drop them so a single failure
    # does not break construction of the metadata table.
    n_total = len(data)
    data = [row for row in data if row is not None]
    if len(data) < n_total:
        print(f'Skipped {n_total - len(data)} failed samples out of {n_total}')

    data = pd.DataFrame(data, columns=['source', 'id', 'text', 'vertical', 'font_path'])
    meta_path = DATA_SYNTHETIC_ROOT / f'meta/{package}.csv'
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    data.to_csv(meta_path, index=False)


if __name__ == '__main__':
    fire.Fire(run)

View File

@@ -0,0 +1,72 @@
import PIL
import numpy as np
import pandas as pd
from PIL import ImageDraw, ImageFont
from fontTools.ttLib import TTFont
from tqdm.contrib.concurrent import process_map
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
# Characters that process() checks against each font; read from the `char` column of vocab.csv at import time.
vocab = pd.read_csv(ASSETS_PATH / 'vocab.csv').char.values
def has_glyph(font, glyph):
    """Return True if any cmap table of ``font`` contains the code point of ``glyph``."""
    return any(ord(glyph) in table.cmap for table in font['cmap'].tables)
def process(font_path):
    """
    Get supported characters list for a given font.

    Font metadata is not always reliable, so try to render each character and see if anything shows up.
    Still not perfect, because sometimes unsupported characters show up as rectangles.

    Returns the supported characters joined into one string, or an empty
    string if anything fails (the error is printed).
    """
    try:
        font_path = str(font_path)
        ttfont = TTFont(font_path)
        pil_font = ImageFont.truetype(font_path, 24)

        def leaves_ink(char):
            # Draw the character in black on a white canvas and check
            # whether any pixel changed.
            canvas = PIL.Image.new('L', (40, 40), 255)
            ImageDraw.Draw(canvas).text((10, 0), char, 0, font=pil_font)
            return (np.array(canvas) != 255).sum() != 0

        return ''.join(
            char for char in vocab
            if has_glyph(ttfont, char) and leaves_ink(char)
        )
    except Exception as e:
        print(f'Error while processing {font_path}: {e}')
        return ''
def main():
    """Scan every font file under FONTS_ROOT and write the results to fonts.csv in ASSETS_PATH."""
    font_suffixes = {'.TTF', '.otf', '.ttc', '.ttf'}
    font_paths = [
        candidate
        for candidate in FONTS_ROOT.glob('**/*')
        if candidate.suffix in font_suffixes
    ]

    supported = process_map(process, font_paths, max_workers=16)

    table = pd.DataFrame({
        'font_path': [str(p.relative_to(FONTS_ROOT)) for p in font_paths],
        'supported_chars': supported,
    })
    table['num_chars'] = table.supported_chars.str.len()
    # Every font starts with the 'regular' label.
    table['label'] = 'regular'
    table.to_csv(ASSETS_PATH / 'fonts.csv', index=False)


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,54 @@
import pandas as pd
import unicodedata
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
def get_background_df(background_dir):
    """Build a DataFrame describing every file in ``background_dir``.

    The last four underscore-separated fields of each file stem are parsed as
    ``ymin, ymax, xmin, xmax``. Columns: ``path``, ``h``, ``w``, ``ratio``
    (w / h).
    """
    records = []
    for entry in background_dir.iterdir():
        ymin, ymax, xmin, xmax = map(int, entry.stem.split('_')[-4:])
        height = ymax - ymin
        width = xmax - xmin
        records.append({
            'path': str(entry),
            'h': height,
            'w': width,
            'ratio': width / height,
        })
    return pd.DataFrame(records)
def is_kanji(ch):
    """Return True if the Unicode name of ``ch`` contains 'CJK UNIFIED IDEOGRAPH'.

    Characters without a Unicode name (e.g. control characters) yield False
    instead of raising ValueError.
    """
    return 'CJK UNIFIED IDEOGRAPH' in unicodedata.name(ch, '')
def is_hiragana(ch):
    """Return True if the Unicode name of ``ch`` contains 'HIRAGANA'.

    Characters without a Unicode name (e.g. control characters) yield False
    instead of raising ValueError.
    """
    return 'HIRAGANA' in unicodedata.name(ch, '')
def is_katakana(ch):
    """Return True if the Unicode name of ``ch`` contains 'KATAKANA'.

    Characters without a Unicode name (e.g. control characters) yield False
    instead of raising ValueError.
    """
    return 'KATAKANA' in unicodedata.name(ch, '')
def is_ascii(ch):
    """Return True if the code point of the single character ``ch`` is below 128."""
    code_point = ord(ch)
    return code_point < 128
def get_charsets(vocab_path=None):
    """Load the vocabulary and derive hiragana and katakana subsets from it.

    :param vocab_path: CSV with a ``char`` column; defaults to vocab.csv in ASSETS_PATH
    :return: ``(vocab, hiragana, katakana)`` arrays; the hiragana subset drops
        its last 6 matches and the katakana subset its first 3
    """
    csv_path = ASSETS_PATH / 'vocab.csv' if vocab_path is None else vocab_path
    vocab = pd.read_csv(csv_path).char.values

    hiragana_mask = [is_hiragana(ch) for ch in vocab]
    hiragana = vocab[hiragana_mask][:-6]

    katakana_mask = [is_katakana(ch) for ch in vocab]
    katakana = vocab[katakana_mask][3:]

    return vocab, hiragana, katakana
def get_font_meta():
    """Load font metadata from fonts.csv in ASSETS_PATH.

    :return: ``(df, font_map)`` where ``df`` holds one row per usable font
        with ``font_path`` prefixed by FONTS_ROOT, and ``font_map`` maps each
        of those paths to the set of characters the font supports
    """
    df = pd.read_csv(ASSETS_PATH / 'fonts.csv')
    df.font_path = df.font_path.apply(lambda x: str(FONTS_ROOT / x))
    # Rows with missing values (e.g. a font whose supported_chars is empty and
    # therefore read back as NaN) are dropped from the frame as well as from
    # the map, so every font that can be sampled from df has a font_map entry.
    df = df.dropna()
    font_map = {row.font_path: set(row.supported_chars) for row in df.itertuples()}
    return df, font_map