training and synthetic data generation code

Maciej Budyś
2022-02-09 20:39:01 +01:00
parent a9085393f4
commit 975dbf4d5e
42 changed files with 7089 additions and 15 deletions

98
manga_ocr_dev/README.md Normal file

@@ -0,0 +1,98 @@
# Project structure
```
assets/ # assets (see description below)
manga_ocr/ # release code (inference only)
manga_ocr_dev/ # development code
env.py # global constants
data/ # data preprocessing
synthetic_data_generator/ # generation of synthetic image-text pairs
training/ # model training
```
## assets
### fonts.csv
csv with columns:
- font_path: path to font file, relative to `FONTS_ROOT`
- supported_chars: string of characters supported by this font
- num_chars: number of supported characters
- label: common/regular/special (used to sample regular fonts more often than special)
A list of fonts with metadata, used by the synthetic data generator.
The provided file is just an example; generate a similar file for your own set of fonts
using the `manga_ocr_dev/synthetic_data_generator/scan_fonts.py` script.
Note that `label` will be filled with `regular` by default; you have to label your special fonts manually.
### lines_example.csv
csv with columns:
- source: source of text
- id: unique id of the line
- line: line from language corpus
An example of a csv file used for synthetic data generation.
### len_to_p.csv
csv with columns:
- len: length of text
- p: probability of text of this length occurring in manga
Used by the synthetic data generator to roughly match the natural distribution of text lengths.
Computed based on Manga109-s dataset.
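For illustration, this is roughly how the generator samples a target text length from that distribution (a sketch based on `generator.py`; the path is relative to the repo root):
```python
import numpy as np
import pandas as pd

len_to_p = pd.read_csv('assets/len_to_p.csv')
# pick a target text length with probability proportional to its frequency in manga
max_text_len = np.random.choice(len_to_p.len, p=len_to_p.p)
```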
### vocab.csv
List of all characters supported by the tokenizer.
# Training OCR
`env.py` contains global constants used across the repo. Set your paths to data etc. there.
1. Download [Manga109-s](http://www.manga109.org/en/download_s.html) dataset.
2. Set `MANGA109_ROOT`, so that your directory structure looks like this:
```
<MANGA109_ROOT>/
Manga109s_released_2021_02_28/
annotations/
annotations.v2018.05.31/
images/
books.txt
readme.txt
```
3. Preprocess Manga109-s with `data/process_manga109s.py`.
4. Optionally generate synthetic data (see below).
5. Train with `manga_ocr_dev/training/train.py` (see the example below).
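For example, training can be launched from Python roughly like this (the parameter values are illustrative; see `training/train.py` for all options and defaults):
```python
from manga_ocr_dev.training.train import run

# hypothetical settings; any argument left out falls back to the defaults in train.py
run(
    run_name='my_experiment',
    batch_size=32,
    num_epochs=8,
)
```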
# Synthetic data generation
Generated data is split into packages (named `0000`, `0001`, etc.) for easier management of a large dataset.
Each package is assumed to have a similar data distribution, so that a properly balanced dataset
can be built from any subset of packages.
The data generation pipeline assumes the following directory structure:
```
<DATA_SYNTHETIC_ROOT>/
img/ # generated images (output from generation pipeline)
0000/
0001/
...
lines/ # lines from corpus (input to generation pipeline)
0000.csv
0001.csv
...
meta/ # metadata (output from generation pipeline)
0000.csv
0001.csv
...
```
To use a language corpus for data generation, `lines/*.csv` files must be provided.
For a small example of such a file, see `assets/lines_example.csv`.
To generate synthetic data:
1. Generate backgrounds with `data/generate_backgrounds.py`.
2. Put your fonts in `<FONTS_ROOT>`.
3. Generate fonts metadata with `synthetic_data_generator/scan_fonts.py`.
4. Optionally manually label your fonts with `common/regular/special` labels.
5. Provide `<DATA_SYNTHETIC_ROOT>/lines/*.csv`.
6. Run `synthetic_data_generator/run_generate.py` for each package (see the example below).
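For example, package `0000` can be generated from Python like this (the values are illustrative; thanks to `fire`, the same arguments can also be passed on the command line):
```python
from manga_ocr_dev.synthetic_data_generator.run_generate import run

# generate package 0000, adding 1000 random-text samples on top of the corpus lines
run(package=0, n_random=1000, max_workers=16)
```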

@@ -0,0 +1,85 @@
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
from manga_ocr_dev.env import MANGA109_ROOT, BACKGROUND_DIR
def find_rectangle(mask, y, x, aspect_ratio_range=(0.33, 3.0)):
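"""Grow a rectangle around the seed point (y, x) until each side hits a masked pixel or the image border.
If the aspect ratio drifts outside aspect_ratio_range, the current rectangle is returned early.
Returns (ymin, ymax, xmin, xmax)."""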
ymin_ = ymax_ = y
xmin_ = xmax_ = x
ymin = ymax = xmin = xmax = None
while True:
if ymin is None:
ymin_ -= 1
if ymin_ == 0 or mask[ymin_, xmin_:xmax_].any():
ymin = ymin_
if ymax is None:
ymax_ += 1
if ymax_ == mask.shape[0] - 1 or mask[ymax_, xmin_:xmax_].any():
ymax = ymax_
if xmin is None:
xmin_ -= 1
if xmin_ == 0 or mask[ymin_:ymax_, xmin_].any():
xmin = xmin_
if xmax is None:
xmax_ += 1
if xmax_ == mask.shape[1] - 1 or mask[ymin_:ymax_, xmax_].any():
xmax = xmax_
h = ymax_ - ymin_
w = xmax_ - xmin_
if h > 1 and w > 1:
ratio = w / h
if ratio < aspect_ratio_range[0] or ratio > aspect_ratio_range[1]:
return ymin_, ymax_, xmin_, xmax_
if None not in (ymin, ymax, xmin, xmax):
return ymin, ymax, xmin, xmax
def generate_backgrounds(crops_per_page=5, min_size=40):
data = pd.read_csv(MANGA109_ROOT / 'data.csv')
frames_df = pd.read_csv(MANGA109_ROOT / 'frames.csv')
BACKGROUND_DIR.mkdir(parents=True, exist_ok=True)
page_paths = data.page_path.unique()
for page_path in tqdm(page_paths):
page = cv2.imread(str(MANGA109_ROOT / page_path))
mask = np.zeros((page.shape[0], page.shape[1]), dtype=bool)
for row in data[data.page_path == page_path].itertuples():
mask[row.ymin:row.ymax, row.xmin:row.xmax] = True
frames_mask = np.zeros((page.shape[0], page.shape[1]), dtype=bool)
for row in frames_df[frames_df.page_path == page_path].itertuples():
frames_mask[row.ymin:row.ymax, row.xmin:row.xmax] = True
mask = mask | ~frames_mask
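# background crops are sampled only from unmasked pixels: inside panel frames and away from text boxes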
if mask.all():
continue
unmasked_points = np.stack(np.where(~mask), axis=1)
for i in range(crops_per_page):
p = unmasked_points[np.random.randint(0, unmasked_points.shape[0])]
y, x = p
ymin, ymax, xmin, xmax = find_rectangle(mask, y, x)
crop = page[ymin:ymax, xmin:xmax]
if crop.shape[0] >= min_size and crop.shape[1] >= min_size:
out_filename = '_'.join(
Path(page_path).with_suffix('').parts[-2:]) + f'_{ymin}_{ymax}_{xmin}_{xmax}.png'
cv2.imwrite(str(BACKGROUND_DIR / out_filename), crop)
if __name__ == '__main__':
generate_backgrounds()


@@ -0,0 +1,103 @@
import xml.etree.ElementTree as ET
from pathlib import Path
import cv2
import pandas as pd
from tqdm import tqdm
from manga_ocr_dev.env import MANGA109_ROOT
def get_books():
root = MANGA109_ROOT / 'Manga109s_released_2021_02_28'
books = (root / 'books.txt').read_text().splitlines()
books = pd.DataFrame({
'book': books,
'annotations': [str(root / 'annotations' / f'{book}.xml') for book in books],
'images': [str(root / 'images' / book) for book in books],
})
return books
def export_frames():
books = get_books()
data = []
for book in tqdm(books.itertuples(), total=len(books)):
tree = ET.parse(book.annotations)
root = tree.getroot()
for page in root.findall('./pages/page'):
for frame in page.findall('./frame'):
row = {}
row['book'] = book.book
row['page_index'] = int(page.attrib['index'])
row['page_path'] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
row['page_width'] = int(page.attrib['width'])
row['page_height'] = int(page.attrib['height'])
row['id'] = frame.attrib['id']
row['xmin'] = int(frame.attrib['xmin'])
row['ymin'] = int(frame.attrib['ymin'])
row['xmax'] = int(frame.attrib['xmax'])
row['ymax'] = int(frame.attrib['ymax'])
data.append(row)
data = pd.DataFrame(data)
data.page_path = data.page_path.apply(lambda x: '/'.join(Path(x).parts[-4:]))
data.to_csv(MANGA109_ROOT / 'frames.csv', index=False)
def export_crops():
crops_root = MANGA109_ROOT / 'crops'
crops_root.mkdir(parents=True, exist_ok=True)
margin = 10
books = get_books()
data = []
for book in tqdm(books.itertuples(), total=len(books)):
tree = ET.parse(book.annotations)
root = tree.getroot()
for page in root.findall('./pages/page'):
for text in page.findall('./text'):
row = {}
row['book'] = book.book
row['page_index'] = int(page.attrib['index'])
row['page_path'] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
row['page_width'] = int(page.attrib['width'])
row['page_height'] = int(page.attrib['height'])
row['id'] = text.attrib['id']
row['text'] = text.text
row['xmin'] = int(text.attrib['xmin'])
row['ymin'] = int(text.attrib['ymin'])
row['xmax'] = int(text.attrib['xmax'])
row['ymax'] = int(text.attrib['ymax'])
data.append(row)
data = pd.DataFrame(data)
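# hold out a random 10% of the text crops as the test split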
n_test = int(0.1 * len(data))
data['split'] = 'train'
data.loc[data.sample(len(data)).iloc[:n_test].index, 'split'] = 'test'
data['crop_path'] = data.id.apply(lambda x: str(crops_root / f'{x}.png'))  # platform-independent path join
data.page_path = data.page_path.apply(lambda x: '/'.join(Path(x).parts[-4:]))
data.crop_path = data.crop_path.apply(lambda x: '/'.join(Path(x).parts[-2:]))
data.to_csv(MANGA109_ROOT / 'data.csv', index=False)
for page_path, boxes in tqdm(data.groupby('page_path'), total=data.page_path.nunique()):
img = cv2.imread(str(MANGA109_ROOT / page_path))
for box in boxes.itertuples():
xmin = max(box.xmin - margin, 0)
xmax = min(box.xmax + margin, img.shape[1])
ymin = max(box.ymin - margin, 0)
ymax = min(box.ymax + margin, img.shape[0])
crop = img[ymin:ymax, xmin:xmax]
out_path = (crops_root / box.id).with_suffix('.png')
cv2.imwrite(str(out_path), crop)
if __name__ == '__main__':
export_frames()
export_crops()

9
manga_ocr_dev/env.py Normal file

@@ -0,0 +1,9 @@
from pathlib import Path
ASSETS_PATH = Path(__file__).parent.parent / 'assets'
FONTS_ROOT = Path('~/data/jp_fonts').expanduser()
DATA_SYNTHETIC_ROOT = Path('~/data/manga/synthetic').expanduser()
BACKGROUND_DIR = Path('~/data/manga/Manga109s/background').expanduser()
MANGA109_ROOT = Path('~/data/manga/Manga109s').expanduser()
TRAIN_ROOT = Path('~/data/manga/out').expanduser()


@@ -0,0 +1,24 @@
datasets
jiwer
torchinfo
transformers>=4.12.5
unidic-lite
ipadic
mecab-python3
fugashi
matplotlib
numpy
opencv-python
pandas
Pillow
scikit-image
scikit-learn
scipy
torch
torchvision
tqdm
wandb
fire
budou
albumentations>=1.1
html2image
fonttools


@@ -0,0 +1,38 @@
# Synthetic data generator
Generation of synthetic image-text pairs imitating Japanese manga for the purpose of training OCR.
Features:
- using either text from a corpus or random text
- text overlaid on background images
- drawing text bubbles
- various fonts and font styles
- variety of text layouts:
- vertical and horizontal text
- multi-line text
- [furigana](https://en.wikipedia.org/wiki/Furigana) (added randomly)
- [tate chū yoko](https://www.w3.org/International/articles/vertical-text/#tcy)
Text rendering is done using [html2image](https://github.com/vgalin/html2image),
a wrapper around Chrome/Chromium's headless mode.
It's not the most elegant solution, and it is very slow, but it only needs to be run once,
and when parallelized, the processing time is manageable (~17 min per 10,000 images on a 16-thread machine).
The upside of this approach is that the fairly complex problem of typesetting and text rendering
(especially when mixing horizontal and vertical text) is offloaded to
the browser engine, keeping the codebase relatively simple and extensible.
The high-level generation pipeline, driven end to end by `SyntheticDataGenerator.process` (see the usage sketch below the steps), is as follows:
1. Preprocess text (truncate and/or split into lines, add random furigana).
2. Render text on a transparent background, using HTML engine.
3. Select background image from backgrounds dataset.
4. Overlay the text on the background, optionally drawing a bubble around the text.
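A minimal usage sketch (the sample text below is arbitrary):
```python
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator

generator = SyntheticDataGenerator()

# returns the rendered image, the ground-truth text, and the rendering parameters;
# pass text=None to generate random text instead
img, text_gt, params = generator.process('これはテストです')
```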
# Examples
## Images generated with text from [CC-100 Japanese corpus](https://data.statmt.org/cc-100/)
![](../../assets/examples/cc-100.jpg)
## Images generated with random text
![](../../assets/examples/random.jpg)


@@ -0,0 +1,198 @@
import budou
import numpy as np
import pandas as pd
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
from manga_ocr_dev.synthetic_data_generator.renderer import Renderer
from manga_ocr_dev.synthetic_data_generator.utils import get_font_meta, get_charsets, is_ascii, is_kanji
class SyntheticDataGenerator:
def __init__(self):
self.vocab, self.hiragana, self.katakana = get_charsets()
self.len_to_p = pd.read_csv(ASSETS_PATH / 'len_to_p.csv')
self.parser = budou.get_parser('tinysegmenter')
self.fonts_df, self.font_map = get_font_meta()
self.font_labels, self.font_p = self.get_font_labels_prob()
self.renderer = Renderer()
def process(self, text=None, override_css_params=None):
"""
Generate image, text pair. Use source text if provided, otherwise generate random text.
"""
if override_css_params is None:
override_css_params = {}
if text is None:
# if using random text, choose font first,
# and then generate text using only characters supported by that font
if 'font_path' not in override_css_params:
font_path = self.get_random_font()
vocab = self.font_map[font_path]
override_css_params['font_path'] = font_path
else:
font_path = override_css_params['font_path']
vocab = self.font_map[font_path]
words = self.get_random_words(vocab)
else:
text = text.replace('　', ' ')  # normalize ideographic (full-width) space to a regular space
text = text.replace('…', '...')  # normalize the ellipsis character to three dots
words = self.split_into_words(text)
lines = self.words_to_lines(words)
text_gt = '\n'.join(lines)
if 'font_path' not in override_css_params:
override_css_params['font_path'] = self.get_random_font(text_gt)
font_path = override_css_params.get('font_path')
if font_path:
vocab = self.font_map.get(font_path)
# remove unsupported characters
lines = [''.join([c for c in line if c in vocab]) for line in lines]
text_gt = '\n'.join(lines)
else:
vocab = None
if np.random.random() < 0.5:
word_prob = np.random.choice([0.33, 1.0], p=[0.3, 0.7])
lines = [self.add_random_furigana(line, word_prob, vocab) for line in lines]
img, params = self.renderer.render(lines, override_css_params)
return img, text_gt, params
def get_random_words(self, vocab):
vocab = list(vocab)
max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)
words = []
text_len = 0
while True:
word = ''.join(np.random.choice(vocab, np.random.randint(1, 4)))
words.append(word)
text_len += len(word)
if text_len + len(word) >= max_text_len:
break
return words
def split_into_words(self, text):
max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)
words = []
text_len = 0
for chunk in self.parser.parse(text)['chunks']:
words.append(chunk.word)
text_len += len(chunk.word)
if text_len + len(chunk.word) >= max_text_len:
break
return words
def words_to_lines(self, words):
text = ''.join(words)
max_num_lines = 10
min_line_len = len(text) // max_num_lines
max_line_len = 20
max_line_len = np.clip(np.random.poisson(6), min_line_len, max_line_len)
lines = []
line = ''
for word in words:
line += word
if len(line) >= max_line_len:
lines.append(line)
line = ''
if line:
lines.append(line)
return lines
def add_random_furigana(self, line, word_prob=1.0, vocab=None):
if vocab is None:
vocab = self.vocab
else:
vocab = list(vocab)
processed = ''
kanji_group = ''
ascii_group = ''
for i, c in enumerate(line):
if is_kanji(c):
c_type = 'kanji'
kanji_group += c
elif is_ascii(c):
c_type = 'ascii'
ascii_group += c
else:
c_type = 'other'
if c_type != 'kanji' or i == len(line) - 1:
if kanji_group:
if np.random.uniform() < word_prob:
furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
char_source = np.random.choice(['hiragana', 'katakana', 'all'], p=[0.8, 0.15, 0.05])
char_source = {
'hiragana': self.hiragana,
'katakana': self.katakana,
'all': vocab
}[char_source]
furigana = ''.join(np.random.choice(char_source, furigana_len))
processed += f'<ruby>{kanji_group}<rt>{furigana}</rt></ruby>'
else:
processed += kanji_group
kanji_group = ''
if c_type != 'ascii' or i == len(line) - 1:
if ascii_group:
if len(ascii_group) <= 3 and np.random.uniform() < 0.7:
processed += f'<span style="text-combine-upright: all">{ascii_group}</span>'
else:
processed += ascii_group
ascii_group = ''
if c_type == 'other':
processed += c
return processed
def is_font_supporting_text(self, font_path, text):
chars = self.font_map[font_path]
for c in text:
if c.isspace():
continue
if c not in chars:
return False
return True
def get_font_labels_prob(self):
labels = {
'common': 0.2,
'regular': 0.75,
'special': 0.05,
}
labels = {k: labels[k] for k in self.fonts_df.label.unique()}
p = np.array(list(labels.values()))
p = p / p.sum()
labels = list(labels.keys())
return labels, p
def get_random_font(self, text=None):
label = np.random.choice(self.font_labels, p=self.font_p)
df = self.fonts_df[self.fonts_df.label == label]
if text is None:
return df.sample(1).iloc[0].font_path
valid_mask = df.font_path.apply(lambda x: self.is_font_supporting_text(x, text))
if not valid_mask.any():
# if text contains characters not supported by any font, just pick some of the more capable fonts
valid_mask = (df.num_chars >= 4000)
return str(FONTS_ROOT / df[valid_mask].sample(1).iloc[0].font_path)


@@ -0,0 +1,265 @@
import os
import uuid
import albumentations as A
import cv2
import numpy as np
from html2image import Html2Image
from manga_ocr_dev.env import BACKGROUND_DIR
from manga_ocr_dev.synthetic_data_generator.utils import get_background_df
class Renderer:
def __init__(self):
self.hti = Html2Image()
self.background_df = get_background_df(BACKGROUND_DIR)
self.max_size = 600
def render(self, lines, override_css_params=None):
img, params = self.render_text(lines, override_css_params)
img = self.render_background(img)
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img = A.LongestMaxSize(self.max_size)(image=img)['image']
return img, params
def render_text(self, lines, override_css_params=None):
"""Render text on transparent background and return as BGRA image."""
params = self.get_random_css_params()
if override_css_params:
params.update(override_css_params)
css = get_css(**params)
# this is just a rough estimate, image is cropped later anyway
size = (
int(max(len(line) for line in lines) * params['font_size'] * 1.5),
int(len(lines) * params['font_size'] * (3 + params['line_height'])),
)
if params['vertical']:
size = size[::-1]
html = self.lines_to_html(lines)
filename = str(uuid.uuid4()) + '.png'
self.hti.screenshot(html_str=html, css_str=css, save_as=filename, size=size)
img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
os.remove(filename)
return img, params
@staticmethod
def get_random_css_params():
params = {
'font_size': 48,
'vertical': True if np.random.rand() < 0.7 else False,
'line_height': 0.5,
'background_color': 'transparent',
'text_color': 'black',
}
if np.random.rand() < 0.7:
params['text_orientation'] = 'upright'
stroke_variant = np.random.choice(['stroke', 'shadow', 'none'], p=[0.8, 0.15, 0.05])
if stroke_variant == 'stroke':
params['stroke_size'] = np.random.choice([1, 2, 3, 4, 8])
params['stroke_color'] = 'white'
elif stroke_variant == 'shadow':
params['shadow_size'] = np.random.choice([2, 5, 10])
params['shadow_color'] = 'white' if np.random.rand() < 0.8 else 'black'
elif stroke_variant == 'none':
pass
return params
def render_background(self, img):
"""Add background and/or text bubble to a BGRA image, crop and return as BGR image."""
draw_bubble = np.random.random() < 0.7
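# pad the rendered text with a generous margin (30% of its smaller dimension); most of it is cropped away at the end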
m0 = int(min(img.shape[:2]) * 0.3)
img = crop_by_alpha(img, m0)
background_path = self.background_df.sample(1).iloc[0].path
background = cv2.imread(background_path)
t = [
A.HorizontalFlip(),
A.RandomRotate90(),
A.InvertImg(),
A.RandomBrightnessContrast((-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1),
A.Blur((3, 5), p=0.3),
A.Resize(img.shape[0], img.shape[1]),
]
background = A.Compose(t)(image=background)['image']
if not draw_bubble:
if np.random.rand() < 0.5:
img[:, :, :3] = 255 - img[:, :, :3]
else:
radius = np.random.uniform(0.7, 1.)
thickness = np.random.choice([1, 2, 3])
alpha = np.random.randint(60, 100)
sigma = np.random.randint(10, 15)
ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
bubble_fill_color = (255, 255, 255, 255)
bubble_contour_color = (0, 0, 0, 255)
bubble = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_fill_color,
thickness=-1)
bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_contour_color,
thickness=thickness)
t = [
A.ElasticTransform(alpha=alpha, sigma=sigma, alpha_affine=0, p=0.8),
]
bubble = A.Compose(t)(image=bubble)['image']
background = blend(bubble, background)
img = blend(img, background)
ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
img = img[ymin:ymax, xmin:xmax]
return img
def lines_to_html(self, lines):
lines_str = '\n'.join(['<p>' + line + '</p>' for line in lines])
html = f"<html><body>\n{lines_str}\n</body></html>"
return html
def crop_by_alpha(img, margin):
y, x = np.where(img[:, :, 3] > 0)
ymin = y.min()
ymax = y.max()
xmin = x.min()
xmax = x.max()
img = img[ymin:ymax, xmin:xmax]
img = np.pad(img, ((margin, margin), (margin, margin), (0, 0)))
return img
def blend(img, background):
alpha = (img[:, :, 3] / 255)[:, :, np.newaxis]
img = img[:, :, :3]
img = (background * (1 - alpha) + img * alpha).astype(np.uint8)
return img
def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
"""From https://stackoverflow.com/a/60210706"""
# corners:
# p1 - p2
# | |
# p4 - p3
p1 = top_left
p2 = (bottom_right[0], top_left[1])
p3 = bottom_right
p4 = (top_left[0], bottom_right[1])
height = abs(bottom_right[1] - top_left[1])
width = abs(bottom_right[0] - top_left[0])
if radius > 1:
radius = 1
corner_radius = int(radius * (min(height, width) / 2))
if thickness < 0:
# big rect
top_left_main_rect = (int(p1[0] + corner_radius), int(p1[1]))
bottom_right_main_rect = (int(p3[0] - corner_radius), int(p3[1]))
top_left_rect_left = (p1[0], p1[1] + corner_radius)
bottom_right_rect_left = (p4[0] + corner_radius, p4[1] - corner_radius)
top_left_rect_right = (p2[0] - corner_radius, p2[1] + corner_radius)
bottom_right_rect_right = (p3[0], p3[1] - corner_radius)
all_rects = [
[top_left_main_rect, bottom_right_main_rect],
[top_left_rect_left, bottom_right_rect_left],
[top_left_rect_right, bottom_right_rect_right]]
[cv2.rectangle(src, rect[0], rect[1], color, thickness) for rect in all_rects]
# draw straight lines
cv2.line(src, (p1[0] + corner_radius, p1[1]), (p2[0] - corner_radius, p2[1]), color, abs(thickness), line_type)
cv2.line(src, (p2[0], p2[1] + corner_radius), (p3[0], p3[1] - corner_radius), color, abs(thickness), line_type)
cv2.line(src, (p3[0] - corner_radius, p4[1]), (p4[0] + corner_radius, p3[1]), color, abs(thickness), line_type)
cv2.line(src, (p4[0], p4[1] - corner_radius), (p1[0], p1[1] + corner_radius), color, abs(thickness), line_type)
# draw arcs
cv2.ellipse(src, (p1[0] + corner_radius, p1[1] + corner_radius), (corner_radius, corner_radius), 180.0, 0, 90,
color, thickness, line_type)
cv2.ellipse(src, (p2[0] - corner_radius, p2[1] + corner_radius), (corner_radius, corner_radius), 270.0, 0, 90,
color, thickness, line_type)
cv2.ellipse(src, (p3[0] - corner_radius, p3[1] - corner_radius), (corner_radius, corner_radius), 0.0, 0, 90, color,
thickness, line_type)
cv2.ellipse(src, (p4[0] + corner_radius, p4[1] - corner_radius), (corner_radius, corner_radius), 90.0, 0, 90, color,
thickness, line_type)
return src
def get_css(
font_size,
font_path,
vertical=True,
background_color='white',
text_color='black',
shadow_size=0,
shadow_color='black',
stroke_size=0,
stroke_color='black',
letter_spacing=None,
line_height=0.5,
text_orientation=None,
):
styles = [
f"background-color: {background_color};",
f"font-size: {font_size}px;",
f"color: {text_color};",
"font-family: custom;",
f"line-height: {line_height};",
"margin: 20px;",
]
if text_orientation:
styles.append(f"text-orientation: {text_orientation};")
if vertical:
styles.append("writing-mode: vertical-rl;")
if shadow_size > 0:
styles.append(f"text-shadow: 0 0 {shadow_size}px {shadow_color};")
if stroke_size > 0:
# stroke is simulated by shadow overlaid multiple times
styles.extend([
f"text-shadow: " + ','.join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size) + ";",
"-webkit-font-smoothing: antialiased;",
])
if letter_spacing:
styles.append(f"letter-spacing: {letter_spacing}em;")
font_path = font_path.replace('\\', '/')
styles_str = '\n'.join(styles)
css = ""
css += '\n@font-face {\nfont-family: custom;\nsrc: url("' + font_path + '");\n}\n'
css += "body {\n" + styles_str + "\n}"
return css


@@ -0,0 +1,64 @@
import traceback
from pathlib import Path
import cv2
import fire
import pandas as pd
from tqdm.contrib.concurrent import thread_map
from manga_ocr_dev.env import FONTS_ROOT, DATA_SYNTHETIC_ROOT
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator
generator = SyntheticDataGenerator()
def f(args):
try:
i, source, id_, text = args
filename = f'{id_}.jpg'
img, text_gt, params = generator.process(text)
cv2.imwrite(str(OUT_DIR / filename), img)
font_path = Path(params['font_path']).relative_to(FONTS_ROOT)
ret = source, id_, text_gt, params['vertical'], str(font_path)
return ret
except Exception as e:
print(traceback.format_exc())
def run(package=0, n_random=1000, n_limit=None, max_workers=16):
"""
:param package: number of data package to generate
:param n_random: how many samples with random text to generate
:param n_limit: limit number of generated samples (for debugging)
:param max_workers: max number of workers
"""
package = f'{package:04d}'
lines = pd.read_csv(DATA_SYNTHETIC_ROOT / f'lines/{package}.csv')
random_lines = pd.DataFrame({
'source': 'random',
'id': [f'random_{package}_{i}' for i in range(n_random)],
'line': None
})
lines = pd.concat([lines, random_lines], ignore_index=True)
if n_limit:
lines = lines.sample(n_limit)
args = [(i, *values) for i, values in enumerate(lines.values)]
global OUT_DIR
OUT_DIR = DATA_SYNTHETIC_ROOT / 'img' / package
OUT_DIR.mkdir(parents=True, exist_ok=True)
data = thread_map(f, args, max_workers=max_workers, desc=f'Processing package {package}')
data = pd.DataFrame(data, columns=['source', 'id', 'text', 'vertical', 'font_path'])
meta_path = DATA_SYNTHETIC_ROOT / f'meta/{package}.csv'
meta_path.parent.mkdir(parents=True, exist_ok=True)
data.to_csv(meta_path, index=False)
if __name__ == '__main__':
fire.Fire(run)


@@ -0,0 +1,72 @@
import PIL
import numpy as np
import pandas as pd
from PIL import ImageDraw, ImageFont
from fontTools.ttLib import TTFont
from tqdm.contrib.concurrent import process_map
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
vocab = pd.read_csv(ASSETS_PATH / 'vocab.csv').char.values
def has_glyph(font, glyph):
for table in font['cmap'].tables:
if ord(glyph) in table.cmap.keys():
return True
return False
def process(font_path):
"""
Get supported characters list for a given font.
Font metadata is not always reliable, so try to render each character and see if anything shows up.
Still not perfect, because sometimes unsupported characters show up as rectangles.
"""
try:
font_path = str(font_path)
ttfont = TTFont(font_path)
pil_font = ImageFont.truetype(font_path, 24)
supported_chars = []
for char in vocab:
if not has_glyph(ttfont, char):
continue
image = PIL.Image.new('L', (40, 40), 255)
draw = ImageDraw.Draw(image)
draw.text((10, 0), char, 0, font=pil_font)
if (np.array(image) != 255).sum() == 0:
continue
supported_chars.append(char)
supported_chars = ''.join(supported_chars)
except Exception as e:
print(f'Error while processing {font_path}: {e}')
supported_chars = ''
return supported_chars
def main():
path_in = FONTS_ROOT
out_path = ASSETS_PATH / 'fonts.csv'
suffixes = {'.TTF', '.otf', '.ttc', '.ttf'}
font_paths = [path for path in path_in.glob('**/*') if
path.suffix in suffixes]
data = process_map(process, font_paths, max_workers=16)
font_paths = [str(path.relative_to(FONTS_ROOT)) for path in font_paths]
data = pd.DataFrame({'font_path': font_paths, 'supported_chars': data})
data['num_chars'] = data.supported_chars.str.len()
data['label'] = 'regular'
data.to_csv(out_path, index=False)
if __name__ == '__main__':
main()


@@ -0,0 +1,54 @@
import pandas as pd
import unicodedata
from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
def get_background_df(background_dir):
background_df = []
for path in background_dir.iterdir():
ymin, ymax, xmin, xmax = [int(v) for v in path.stem.split('_')[-4:]]
h = ymax - ymin
w = xmax - xmin
ratio = w / h
background_df.append({
'path': str(path),
'h': h,
'w': w,
'ratio': ratio,
})
background_df = pd.DataFrame(background_df)
return background_df
def is_kanji(ch):
return 'CJK UNIFIED IDEOGRAPH' in unicodedata.name(ch)
def is_hiragana(ch):
return 'HIRAGANA' in unicodedata.name(ch)
def is_katakana(ch):
return 'KATAKANA' in unicodedata.name(ch)
def is_ascii(ch):
return ord(ch) < 128
def get_charsets(vocab_path=None):
if vocab_path is None:
vocab_path = ASSETS_PATH / 'vocab.csv'
vocab = pd.read_csv(vocab_path).char.values
hiragana = vocab[[is_hiragana(c) for c in vocab]][:-6]
katakana = vocab[[is_katakana(c) for c in vocab]][3:]
return vocab, hiragana, katakana
def get_font_meta():
df = pd.read_csv(ASSETS_PATH / 'fonts.csv')
df.font_path = df.font_path.apply(lambda x: str(FONTS_ROOT / x))
font_map = {row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()}
return df, font_map


@@ -0,0 +1,165 @@
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from manga_ocr_dev.env import MANGA109_ROOT, DATA_SYNTHETIC_ROOT
class MangaDataset(Dataset):
def __init__(self, processor, split, max_target_length, limit_size=None, augment=False, skip_packages=None):
self.processor = processor
self.max_target_length = max_target_length
data = []
print(f'Initializing dataset {split}...')
if skip_packages is None:
skip_packages = set()
else:
skip_packages = {f'{x:04d}' for x in skip_packages}
for path in sorted((DATA_SYNTHETIC_ROOT / 'meta').glob('*.csv')):
if path.stem in skip_packages:
print(f'Skipping package {path}')
continue
if not (DATA_SYNTHETIC_ROOT / 'img' / path.stem).is_dir():
print(f'Missing image data for package {path}, skipping')
continue
df = pd.read_csv(path)
df = df.dropna()
df['path'] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / 'img' / path.stem / f'{x}.jpg'))
df = df[['path', 'text']]
df['synthetic'] = True
data.append(df)
df = pd.read_csv(MANGA109_ROOT / 'data.csv')
df = df[df.split == split].reset_index(drop=True)
df['path'] = df.crop_path.apply(lambda x: str(MANGA109_ROOT / x))
df = df[['path', 'text']]
df['synthetic'] = False
data.append(df)
data = pd.concat(data, ignore_index=True)
if limit_size:
data = data.iloc[:limit_size]
self.data = data
print(f'Dataset {split}: {len(self.data)}')
self.augment = augment
self.transform_medium, self.transform_heavy = self.get_transforms()
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data.loc[idx]
text = sample.text
if self.augment:
medium_p = 0.8
heavy_p = 0.02
transform_variant = np.random.choice(['none', 'medium', 'heavy'],
p=[1 - medium_p - heavy_p, medium_p, heavy_p])
transform = {
'none': None,
'medium': self.transform_medium,
'heavy': self.transform_heavy,
}[transform_variant]
else:
transform = None
pixel_values = self.read_image(self.processor, sample.path, transform)
labels = self.processor.tokenizer(text,
padding="max_length",
max_length=self.max_target_length,
truncation=True).input_ids
labels = np.array(labels)
# important: make sure that PAD tokens are ignored by the loss function
labels[labels == self.processor.tokenizer.pad_token_id] = -100
encoding = {
"pixel_values": pixel_values,
"labels": torch.tensor(labels),
}
return encoding
@staticmethod
def read_image(processor, path, transform=None):
img = cv2.imread(str(path))
if transform is None:
transform = A.ToGray(always_apply=True)
img = transform(image=img)['image']
pixel_values = processor(img, return_tensors="pt").pixel_values
return pixel_values.squeeze()
@staticmethod
def get_transforms():
t_medium = A.Compose([
A.Rotate(5, border_mode=cv2.BORDER_REPLICATE, p=0.2),
A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
A.InvertImg(p=0.05),
A.OneOf([
A.Downscale(0.25, 0.5, interpolation=cv2.INTER_LINEAR),
A.Downscale(0.25, 0.5, interpolation=cv2.INTER_NEAREST),
], p=0.1),
A.Blur(p=0.2),
A.Sharpen(p=0.2),
A.RandomBrightnessContrast(p=0.5),
A.GaussNoise((50, 200), p=0.3),
A.ImageCompression(0, 30, p=0.1),
A.ToGray(always_apply=True),
])
t_heavy = A.Compose([
A.Rotate(10, border_mode=cv2.BORDER_REPLICATE, p=0.2),
A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
A.InvertImg(p=0.05),
A.OneOf([
A.Downscale(0.1, 0.2, interpolation=cv2.INTER_LINEAR),
A.Downscale(0.1, 0.2, interpolation=cv2.INTER_NEAREST),
], p=0.1),
A.Blur((4, 9), p=0.5),
A.Sharpen(p=0.5),
A.RandomBrightnessContrast(0.8, 0.8, p=1),
A.GaussNoise((1000, 10000), p=0.3),
A.ImageCompression(0, 10, p=0.5),
A.ToGray(always_apply=True),
])
return t_medium, t_heavy
if __name__ == '__main__':
from manga_ocr_dev.training.get_model import get_processor
from manga_ocr_dev.training.utils import tensor_to_image
encoder_name = 'facebook/deit-tiny-patch16-224'
decoder_name = 'cl-tohoku/bert-base-japanese-char-v2'
max_length = 300
processor = get_processor(encoder_name, decoder_name)
ds = MangaDataset(processor, 'train', max_length, augment=True)
for i in range(20):
sample = ds[0]
img = tensor_to_image(sample['pixel_values'])
tokens = sample['labels']
tokens[tokens == -100] = processor.tokenizer.pad_token_id
text = ''.join(processor.decode(tokens, skip_special_tokens=True).split())
print(f'{i}:\n{text}\n')
plt.imshow(img)
plt.show()


@@ -0,0 +1,63 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel, \
AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderConfig
class TrOCRProcessorCustom(TrOCRProcessor):
"""The only point of this class is to bypass type checks of base class."""
def __init__(self, feature_extractor, tokenizer):
self.feature_extractor = feature_extractor
self.tokenizer = tokenizer
self.current_processor = self.feature_extractor
def get_processor(encoder_name, decoder_name):
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
tokenizer = AutoTokenizer.from_pretrained(decoder_name)
processor = TrOCRProcessorCustom(feature_extractor, tokenizer)
return processor
def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
encoder_config = AutoConfig.from_pretrained(encoder_name)
encoder_config.is_decoder = False
encoder_config.add_cross_attention = False
encoder = AutoModel.from_config(encoder_config)
decoder_config = AutoConfig.from_pretrained(decoder_name)
decoder_config.max_length = max_length
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
decoder = AutoModelForCausalLM.from_config(decoder_config)
if num_decoder_layers is not None:
if decoder_config.model_type == 'bert':
decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
elif decoder_config.model_type in ('roberta', 'xlm-roberta'):
decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
else:
raise ValueError(f'Unsupported model_type: {decoder_config.model_type}')
decoder_config.num_hidden_layers = num_decoder_layers
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
config.tie_word_embeddings = False
model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)
processor = get_processor(encoder_name, decoder_name)
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size
# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = max_length
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
return model, processor


@@ -0,0 +1,32 @@
import numpy as np
from datasets import load_metric
class Metrics:
def __init__(self, processor):
self.cer_metric = load_metric("cer")
self.processor = processor
def compute_metrics(self, pred):
label_ids = pred.label_ids
pred_ids = pred.predictions
print(label_ids.shape, pred_ids.shape)
pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id
label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)
pred_str = np.array([''.join(text.split()) for text in pred_str])
label_str = np.array([''.join(text.split()) for text in label_str])
results = {}
try:
results['cer'] = self.cer_metric.compute(predictions=pred_str, references=label_str)
except Exception as e:
print(e)
print(pred_str)
print(label_str)
results['cer'] = 0
results['accuracy'] = (pred_str == label_str).mean()
return results


@@ -0,0 +1,64 @@
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator
from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics
def run(
run_name='debug',
encoder_name='facebook/deit-tiny-patch16-224',
decoder_name='cl-tohoku/bert-base-japanese-char-v2',
max_len=300,
num_decoder_layers=2,
batch_size=64,
num_epochs=8,
fp16=True,
):
wandb.login()
model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)
# keep package 0 for validation
train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))
metrics = Metrics(processor)
training_args = Seq2SeqTrainingArguments(
predict_with_generate=True,
evaluation_strategy='steps',
save_strategy='steps',
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
fp16=fp16,
fp16_full_eval=fp16,
dataloader_num_workers=16,
output_dir=TRAIN_ROOT,
logging_steps=10,
save_steps=20000,
eval_steps=20000,
num_train_epochs=num_epochs,
run_name=run_name
)
# instantiate trainer
trainer = Seq2SeqTrainer(
model=model,
tokenizer=processor.feature_extractor,
args=training_args,
compute_metrics=metrics.compute_metrics,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=default_data_collator,
)
trainer.train()
wandb.finish()
if __name__ == '__main__':
fire.Fire(run)


@@ -0,0 +1,27 @@
import numpy as np
import torch
from torchinfo import summary
def encoder_summary(model, batch_size=4):
img_size = model.config.encoder.image_size
return summary(model.encoder, input_size=(batch_size, 3, img_size, img_size), depth=3,
col_names=["output_size", "num_params", "mult_adds"], device='cpu')
def decoder_summary(model, batch_size=4):
img_size = model.config.encoder.image_size
encoder_hidden_shape = (batch_size, (img_size // 16) ** 2 + 1, model.config.decoder.hidden_size)
decoder_inputs = {
'input_ids': torch.zeros(batch_size, 1, dtype=torch.int64),
'attention_mask': torch.ones(batch_size, 1, dtype=torch.int64),
'encoder_hidden_states': torch.rand(encoder_hidden_shape, dtype=torch.float32),
'return_dict': False
}
return summary(model.decoder, input_data=decoder_inputs, depth=4,
col_names=["output_size", "num_params", "mult_adds"],
device='cpu')
def tensor_to_image(img):
return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)