training and synthetic data generation code

manga_ocr_dev/README.md (new file, 98 lines)
@@ -0,0 +1,98 @@
# Project structure

```
assets/                        # assets (see description below)
manga_ocr/                     # release code (inference only)
manga_ocr_dev/                 # development code
    env.py                     # global constants
    data/                      # data preprocessing
    synthetic_data_generator/  # generation of synthetic image-text pairs
    training/                  # model training
```

## assets

### fonts.csv

CSV with columns:
- font_path: path to the font file, relative to `FONTS_ROOT`
- supported_chars: string of characters supported by this font
- num_chars: number of supported characters
- label: common/regular/special (used to sample regular fonts more often than special ones)

List of fonts with metadata, used by the synthetic data generator.
The provided file is just an example; you have to generate a similar file for your own set of fonts,
using the `manga_ocr_dev/synthetic_data_generator/scan_fonts.py` script.
Note that `label` is filled with `regular` by default; you have to label your special fonts manually.

### lines_example.csv

CSV with columns:
- source: source of the text
- id: unique id of the line
- line: line from a language corpus

Example of a CSV used for synthetic data generation.

### len_to_p.csv

CSV with columns:
- len: length of text
- p: probability of text of this length occurring in manga

Used by the synthetic data generator to roughly match the natural distribution of text lengths (see the sketch below).
Computed from the Manga109-s dataset.
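
For illustration, a minimal sketch of how this table is consumed (mirrors the sampling in `generator.py`; the asset path is an assumption):

```python
import numpy as np
import pandas as pd

# sample a target text length according to the empirical distribution
len_to_p = pd.read_csv('assets/len_to_p.csv')
max_text_len = np.random.choice(len_to_p.len, p=len_to_p.p)
```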

### vocab.csv

List of all characters supported by the tokenizer.

# Training OCR

`env.py` contains global constants used across the repo. Set your data paths etc. there.

1. Download the [Manga109-s](http://www.manga109.org/en/download_s.html) dataset.
2. Set `MANGA109_ROOT`, so that your directory structure looks like this:
```
<MANGA109_ROOT>/
    Manga109s_released_2021_02_28/
        annotations/
        annotations.v2018.05.31/
        images/
        books.txt
        readme.txt
```
3. Preprocess Manga109-s with `data/process_manga109s.py`.
4. Optionally generate synthetic data (see below).
5. Train with `manga_ocr_dev/training/train.py` (example invocation below).
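
For example, training could be launched like this (a sketch; it assumes the repo root is on `PYTHONPATH` and the paths in `env.py` are set; Weights & Biases login is requested inside `run`):

```python
from manga_ocr_dev.training.train import run

# arguments below are the defaults from train.py
run(run_name='my_run', batch_size=64, num_epochs=8, fp16=True)
```

Since `train.py` wraps `run` with `fire.Fire`, the same arguments can also be passed as command-line flags (e.g. `--run_name`, `--batch_size`).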

# Synthetic data generation

Generated data is split into packages (named `0000`, `0001` etc.) for easier management of a large dataset.
Each package is assumed to have a similar data distribution, so that a properly balanced dataset
can be built from any subset of packages.

The data generation pipeline assumes the following directory structure:

```
<DATA_SYNTHETIC_ROOT>/
    img/    # generated images (output from generation pipeline)
        0000/
        0001/
        ...
    lines/  # lines from corpus (input to generation pipeline)
        0000.csv
        0001.csv
        ...
    meta/   # metadata (output from generation pipeline)
        0000.csv
        0001.csv
        ...
```

To use a language corpus for data generation, `lines/*.csv` files must be provided.
For a small example of such a file, see `assets/lines_example.csv`.

To generate synthetic data:
1. Generate backgrounds with `data/generate_backgrounds.py`.
2. Put your fonts in `<FONTS_ROOT>`.
3. Generate font metadata with `synthetic_data_generator/scan_fonts.py`.
4. Optionally manually label your fonts with `common/regular/special` labels.
5. Provide `<DATA_SYNTHETIC_ROOT>/lines/*.csv`.
6. Run `synthetic_data_generator/run_generate.py` for each package (see the sketch after this list).
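
A minimal sketch of step 6 (assumes `lines/0000.csv`, `lines/0001.csv` etc. already exist under `<DATA_SYNTHETIC_ROOT>`):

```python
from manga_ocr_dev.synthetic_data_generator.run_generate import run

# generate the first two packages; arguments mirror run_generate.py's CLI
for package in range(2):
    run(package=package, n_random=1000, max_workers=16)
```

`run_generate.py` also exposes `run` through `fire.Fire`, so the same arguments can be passed as command-line flags.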

manga_ocr_dev/__init__.py (new file, empty)

manga_ocr_dev/data/__init__.py (new file, empty)

manga_ocr_dev/data/generate_backgrounds.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

from manga_ocr_dev.env import MANGA109_ROOT, BACKGROUND_DIR


def find_rectangle(mask, y, x, aspect_ratio_range=(0.33, 3.0)):
    """Grow a rectangle around (y, x) until each side hits masked pixels or the image border,
    or until the aspect ratio leaves the allowed range."""
    ymin_ = ymax_ = y
    xmin_ = xmax_ = x

    ymin = ymax = xmin = xmax = None

    while True:
        if ymin is None:
            ymin_ -= 1
            if ymin_ == 0 or mask[ymin_, xmin_:xmax_].any():
                ymin = ymin_

        if ymax is None:
            ymax_ += 1
            if ymax_ == mask.shape[0] - 1 or mask[ymax_, xmin_:xmax_].any():
                ymax = ymax_

        if xmin is None:
            xmin_ -= 1
            if xmin_ == 0 or mask[ymin_:ymax_, xmin_].any():
                xmin = xmin_

        if xmax is None:
            xmax_ += 1
            if xmax_ == mask.shape[1] - 1 or mask[ymin_:ymax_, xmax_].any():
                xmax = xmax_

        h = ymax_ - ymin_
        w = xmax_ - xmin_
        if h > 1 and w > 1:
            ratio = w / h
            if ratio < aspect_ratio_range[0] or ratio > aspect_ratio_range[1]:
                return ymin_, ymax_, xmin_, xmax_

        if None not in (ymin, ymax, xmin, xmax):
            return ymin, ymax, xmin, xmax


def generate_backgrounds(crops_per_page=5, min_size=40):
    """Crop random text-free regions (inside panel frames) from Manga109 pages and save them to BACKGROUND_DIR."""
    data = pd.read_csv(MANGA109_ROOT / 'data.csv')
    frames_df = pd.read_csv(MANGA109_ROOT / 'frames.csv')

    BACKGROUND_DIR.mkdir(parents=True, exist_ok=True)

    page_paths = data.page_path.unique()
    for page_path in tqdm(page_paths):
        page = cv2.imread(str(MANGA109_ROOT / page_path))
        mask = np.zeros((page.shape[0], page.shape[1]), dtype=bool)
        for row in data[data.page_path == page_path].itertuples():
            mask[row.ymin:row.ymax, row.xmin:row.xmax] = True

        frames_mask = np.zeros((page.shape[0], page.shape[1]), dtype=bool)
        for row in frames_df[frames_df.page_path == page_path].itertuples():
            frames_mask[row.ymin:row.ymax, row.xmin:row.xmax] = True

        mask = mask | ~frames_mask

        if mask.all():
            continue

        unmasked_points = np.stack(np.where(~mask), axis=1)
        for i in range(crops_per_page):
            p = unmasked_points[np.random.randint(0, unmasked_points.shape[0])]
            y, x = p
            ymin, ymax, xmin, xmax = find_rectangle(mask, y, x)
            crop = page[ymin:ymax, xmin:xmax]

            if crop.shape[0] >= min_size and crop.shape[1] >= min_size:
                out_filename = '_'.join(
                    Path(page_path).with_suffix('').parts[-2:]) + f'_{ymin}_{ymax}_{xmin}_{xmax}.png'
                cv2.imwrite(str(BACKGROUND_DIR / out_filename), crop)


if __name__ == '__main__':
    generate_backgrounds()

manga_ocr_dev/data/process_manga109s.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import xml.etree.ElementTree as ET
from pathlib import Path

import cv2
import pandas as pd
from tqdm import tqdm

from manga_ocr_dev.env import MANGA109_ROOT


def get_books():
    """Return a DataFrame with book names and paths to their annotations and images."""
    root = MANGA109_ROOT / 'Manga109s_released_2021_02_28'
    books = (root / 'books.txt').read_text().splitlines()
    books = pd.DataFrame({
        'book': books,
        'annotations': [str(root / 'annotations' / f'{book}.xml') for book in books],
        'images': [str(root / 'images' / book) for book in books],
    })

    return books


def export_frames():
    """Export frame (panel) bounding boxes from Manga109 annotations to frames.csv."""
    books = get_books()

    data = []
    for book in tqdm(books.itertuples(), total=len(books)):
        tree = ET.parse(book.annotations)
        root = tree.getroot()
        for page in root.findall('./pages/page'):
            for frame in page.findall('./frame'):
                row = {}
                row['book'] = book.book
                row['page_index'] = int(page.attrib['index'])
                row['page_path'] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                row['page_width'] = int(page.attrib['width'])
                row['page_height'] = int(page.attrib['height'])
                row['id'] = frame.attrib['id']
                row['xmin'] = int(frame.attrib['xmin'])
                row['ymin'] = int(frame.attrib['ymin'])
                row['xmax'] = int(frame.attrib['xmax'])
                row['ymax'] = int(frame.attrib['ymax'])
                data.append(row)
    data = pd.DataFrame(data)

    data.page_path = data.page_path.apply(lambda x: '/'.join(Path(x).parts[-4:]))
    data.to_csv(MANGA109_ROOT / 'frames.csv', index=False)


def export_crops():
    """Export text box crops (with a small margin) and data.csv with transcriptions and a train/test split."""
    crops_root = MANGA109_ROOT / 'crops'
    crops_root.mkdir(parents=True, exist_ok=True)
    margin = 10

    books = get_books()

    data = []
    for book in tqdm(books.itertuples(), total=len(books)):
        tree = ET.parse(book.annotations)
        root = tree.getroot()
        for page in root.findall('./pages/page'):
            for text in page.findall('./text'):
                row = {}
                row['book'] = book.book
                row['page_index'] = int(page.attrib['index'])
                row['page_path'] = str(Path(book.images) / f'{row["page_index"]:03d}.jpg')
                row['page_width'] = int(page.attrib['width'])
                row['page_height'] = int(page.attrib['height'])
                row['id'] = text.attrib['id']
                row['text'] = text.text
                row['xmin'] = int(text.attrib['xmin'])
                row['ymin'] = int(text.attrib['ymin'])
                row['xmax'] = int(text.attrib['xmax'])
                row['ymax'] = int(text.attrib['ymax'])
                data.append(row)
    data = pd.DataFrame(data)

    n_test = int(0.1 * len(data))
    data['split'] = 'train'
    data.loc[data.sample(len(data)).iloc[:n_test].index, 'split'] = 'test'

    # crop paths are reduced to 'crops/<id>.png' (relative to MANGA109_ROOT) below
    data['crop_path'] = data.id.apply(lambda x: str(crops_root / f'{x}.png'))

    data.page_path = data.page_path.apply(lambda x: '/'.join(Path(x).parts[-4:]))
    data.crop_path = data.crop_path.apply(lambda x: '/'.join(Path(x).parts[-2:]))
    data.to_csv(MANGA109_ROOT / 'data.csv', index=False)

    for page_path, boxes in tqdm(data.groupby('page_path'), total=data.page_path.nunique()):
        img = cv2.imread(str(MANGA109_ROOT / page_path))

        for box in boxes.itertuples():
            xmin = max(box.xmin - margin, 0)
            xmax = min(box.xmax + margin, img.shape[1])
            ymin = max(box.ymin - margin, 0)
            ymax = min(box.ymax + margin, img.shape[0])
            crop = img[ymin:ymax, xmin:xmax]
            out_path = (crops_root / box.id).with_suffix('.png')
            cv2.imwrite(str(out_path), crop)


if __name__ == '__main__':
    export_frames()
    export_crops()

manga_ocr_dev/env.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from pathlib import Path

ASSETS_PATH = Path(__file__).parent.parent / 'assets'

FONTS_ROOT = Path('~/data/jp_fonts').expanduser()
DATA_SYNTHETIC_ROOT = Path('~/data/manga/synthetic').expanduser()
BACKGROUND_DIR = Path('~/data/manga/Manga109s/background').expanduser()
MANGA109_ROOT = Path('~/data/manga/Manga109s').expanduser()
TRAIN_ROOT = Path('~/data/manga/out').expanduser()

manga_ocr_dev/requirements.txt (new file, 24 lines)
@@ -0,0 +1,24 @@
datasets
jiwer
torchinfo
transformers>=4.12.5
unidic-lite
ipadic
mecab-python3
fugashi
matplotlib
numpy
opencv-python
pandas
Pillow
scikit-image
scikit-learn
scipy
torch
torchvision
tqdm
wandb
fire
budou
albumentations>=1.1
html2image

manga_ocr_dev/synthetic_data_generator/README.md (new file, 38 lines)
@@ -0,0 +1,38 @@
# Synthetic data generator

Generation of synthetic image-text pairs imitating Japanese manga, for the purpose of training OCR.

Features:
- using either text from a corpus or random text
- text overlaid on background images
- drawing text bubbles
- various fonts and font styles
- variety of text layouts:
  - vertical and horizontal text
  - multi-line text
  - [furigana](https://en.wikipedia.org/wiki/Furigana) (added randomly)
  - [tate chū yoko](https://www.w3.org/International/articles/vertical-text/#tcy)

Text rendering is done using [html2image](https://github.com/vgalin/html2image),
which is a wrapper around the Chrome/Chromium browser's headless mode.
It's not a very elegant solution, and it is quite slow, but it only needs to be run once,
and when parallelized, processing time is manageable (~17 min per 10000 images on a 16-thread machine).

The upside of this approach is that the quite complex problem of typesetting and text rendering
(especially when dealing with both horizontal and vertical text) is offloaded to
the browser engine, keeping the codebase relatively simple and extensible.

The high-level generation pipeline is as follows (a usage sketch follows the list):
1. Preprocess text (truncate and/or split into lines, add random furigana).
2. Render text on a transparent background, using the HTML engine.
3. Select a background image from the backgrounds dataset.
4. Overlay the text on the background, optionally drawing a bubble around the text.
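
A minimal usage sketch (assumes `fonts.csv`, `vocab.csv`, the background crops and a working Chrome/Chromium for html2image are already in place):

```python
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator

generator = SyntheticDataGenerator()

# pass a corpus line; passing no text generates random text instead
img, text_gt, params = generator.process('これはテストです')
print(text_gt, params['vertical'], params['font_path'])
```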

# Examples

## Images generated with text from [CC-100 Japanese corpus](https://data.statmt.org/cc-100/)


## Images generated with random text


manga_ocr_dev/synthetic_data_generator/__init__.py (new file, empty)

manga_ocr_dev/synthetic_data_generator/generator.py (new file, 198 lines)
@@ -0,0 +1,198 @@
import budou
import numpy as np
import pandas as pd

from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT
from manga_ocr_dev.synthetic_data_generator.renderer import Renderer
from manga_ocr_dev.synthetic_data_generator.utils import get_font_meta, get_charsets, is_ascii, is_kanji


class SyntheticDataGenerator:
    def __init__(self):
        self.vocab, self.hiragana, self.katakana = get_charsets()
        self.len_to_p = pd.read_csv(ASSETS_PATH / 'len_to_p.csv')
        self.parser = budou.get_parser('tinysegmenter')
        self.fonts_df, self.font_map = get_font_meta()
        self.font_labels, self.font_p = self.get_font_labels_prob()
        self.renderer = Renderer()

    def process(self, text=None, override_css_params=None):
        """
        Generate an image-text pair. Use the source text if provided, otherwise generate random text.
        """

        if override_css_params is None:
            override_css_params = {}

        if text is None:
            # if using random text, choose the font first,
            # and then generate text using only characters supported by that font
            if 'font_path' not in override_css_params:
                font_path = self.get_random_font()
                vocab = self.font_map[font_path]
                override_css_params['font_path'] = font_path
            else:
                font_path = override_css_params['font_path']
                vocab = self.font_map[font_path]

            words = self.get_random_words(vocab)

        else:
            text = text.replace('　', ' ')  # normalize full-width spaces
            text = text.replace('…', '...')
            words = self.split_into_words(text)

        lines = self.words_to_lines(words)
        text_gt = '\n'.join(lines)

        if 'font_path' not in override_css_params:
            override_css_params['font_path'] = self.get_random_font(text_gt)

        font_path = override_css_params.get('font_path')
        if font_path:
            vocab = self.font_map.get(font_path)

            # remove unsupported characters
            lines = [''.join([c for c in line if c in vocab]) for line in lines]
            text_gt = '\n'.join(lines)
        else:
            vocab = None

        if np.random.random() < 0.5:
            word_prob = np.random.choice([0.33, 1.0], p=[0.3, 0.7])

            lines = [self.add_random_furigana(line, word_prob, vocab) for line in lines]

        img, params = self.renderer.render(lines, override_css_params)
        return img, text_gt, params

    def get_random_words(self, vocab):
        """Generate random 1-3 character words from the given vocab until a sampled target length is reached."""
        vocab = list(vocab)
        max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)

        words = []
        text_len = 0
        while True:
            word = ''.join(np.random.choice(vocab, np.random.randint(1, 4)))
            words.append(word)
            text_len += len(word)
            if text_len + len(word) >= max_text_len:
                break

        return words

    def split_into_words(self, text):
        """Split corpus text into words with budou, truncating at a sampled target length."""
        max_text_len = np.random.choice(self.len_to_p.len, p=self.len_to_p.p)

        words = []
        text_len = 0
        for chunk in self.parser.parse(text)['chunks']:
            words.append(chunk.word)
            text_len += len(chunk.word)
            if text_len + len(chunk.word) >= max_text_len:
                break

        return words

    def words_to_lines(self, words):
        """Greedily wrap words into lines of randomized maximum length."""
        text = ''.join(words)

        max_num_lines = 10
        min_line_len = len(text) // max_num_lines
        max_line_len = 20
        max_line_len = np.clip(np.random.poisson(6), min_line_len, max_line_len)
        lines = []
        line = ''
        for word in words:
            line += word
            if len(line) >= max_line_len:
                lines.append(line)
                line = ''
        if line:
            lines.append(line)

        return lines

    def add_random_furigana(self, line, word_prob=1.0, vocab=None):
        """Wrap kanji groups in <ruby> tags with random furigana; short ASCII runs may be set as tate chū yoko."""
        if vocab is None:
            vocab = self.vocab
        else:
            vocab = list(vocab)

        processed = ''
        kanji_group = ''
        ascii_group = ''
        for i, c in enumerate(line):

            if is_kanji(c):
                c_type = 'kanji'
                kanji_group += c
            elif is_ascii(c):
                c_type = 'ascii'
                ascii_group += c
            else:
                c_type = 'other'

            if c_type != 'kanji' or i == len(line) - 1:
                if kanji_group:
                    if np.random.uniform() < word_prob:
                        furigana_len = int(np.clip(np.random.normal(1.5, 0.5), 1, 4) * len(kanji_group))
                        char_source = np.random.choice(['hiragana', 'katakana', 'all'], p=[0.8, 0.15, 0.05])
                        char_source = {
                            'hiragana': self.hiragana,
                            'katakana': self.katakana,
                            'all': vocab
                        }[char_source]
                        furigana = ''.join(np.random.choice(char_source, furigana_len))
                        processed += f'<ruby>{kanji_group}<rt>{furigana}</rt></ruby>'
                    else:
                        processed += kanji_group
                    kanji_group = ''

            if c_type != 'ascii' or i == len(line) - 1:
                if ascii_group:
                    if len(ascii_group) <= 3 and np.random.uniform() < 0.7:
                        processed += f'<span style="text-combine-upright: all">{ascii_group}</span>'
                    else:
                        processed += ascii_group
                    ascii_group = ''

            if c_type == 'other':
                processed += c

        return processed

    def is_font_supporting_text(self, font_path, text):
        chars = self.font_map[font_path]
        for c in text:
            if c.isspace():
                continue
            if c not in chars:
                return False
        return True

    def get_font_labels_prob(self):
        labels = {
            'common': 0.2,
            'regular': 0.75,
            'special': 0.05,
        }
        labels = {k: labels[k] for k in self.fonts_df.label.unique()}
        p = np.array(list(labels.values()))
        p = p / p.sum()
        labels = list(labels.keys())
        return labels, p

    def get_random_font(self, text=None):
        """Sample a font according to the label probabilities, optionally restricted to fonts supporting the text."""
        label = np.random.choice(self.font_labels, p=self.font_p)
        df = self.fonts_df[self.fonts_df.label == label]

        if text is None:
            return df.sample(1).iloc[0].font_path

        valid_mask = df.font_path.apply(lambda x: self.is_font_supporting_text(x, text))
        if not valid_mask.any():
            # if text contains characters not supported by any font, just pick one of the more capable fonts
            valid_mask = (df.num_chars >= 4000)

        return str(FONTS_ROOT / df[valid_mask].sample(1).iloc[0].font_path)

manga_ocr_dev/synthetic_data_generator/renderer.py (new file, 265 lines)
@@ -0,0 +1,265 @@
import os
import uuid

import albumentations as A
import cv2
import numpy as np
from html2image import Html2Image

from manga_ocr_dev.env import BACKGROUND_DIR
from manga_ocr_dev.synthetic_data_generator.utils import get_background_df


class Renderer:
    def __init__(self):
        self.hti = Html2Image()
        self.background_df = get_background_df(BACKGROUND_DIR)
        self.max_size = 600

    def render(self, lines, override_css_params=None):
        """Render text, composite it over a background, and return a grayscale image with the CSS params used."""
        img, params = self.render_text(lines, override_css_params)
        img = self.render_background(img)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img = A.LongestMaxSize(self.max_size)(image=img)['image']
        return img, params

    def render_text(self, lines, override_css_params=None):
        """Render text on a transparent background and return it as a BGRA image."""

        params = self.get_random_css_params()
        if override_css_params:
            params.update(override_css_params)

        css = get_css(**params)

        # this is just a rough estimate, the image is cropped later anyway
        size = (
            int(max(len(line) for line in lines) * params['font_size'] * 1.5),
            int(len(lines) * params['font_size'] * (3 + params['line_height'])),
        )
        if params['vertical']:
            size = size[::-1]
        html = self.lines_to_html(lines)

        filename = str(uuid.uuid4()) + '.png'
        self.hti.screenshot(html_str=html, css_str=css, save_as=filename, size=size)
        img = cv2.imread(filename, cv2.IMREAD_UNCHANGED)
        os.remove(filename)
        return img, params

    @staticmethod
    def get_random_css_params():
        params = {
            'font_size': 48,
            'vertical': True if np.random.rand() < 0.7 else False,
            'line_height': 0.5,
            'background_color': 'transparent',
            'text_color': 'black',
        }

        if np.random.rand() < 0.7:
            params['text_orientation'] = 'upright'

        stroke_variant = np.random.choice(['stroke', 'shadow', 'none'], p=[0.8, 0.15, 0.05])
        if stroke_variant == 'stroke':
            params['stroke_size'] = np.random.choice([1, 2, 3, 4, 8])
            params['stroke_color'] = 'white'
        elif stroke_variant == 'shadow':
            params['shadow_size'] = np.random.choice([2, 5, 10])
            params['shadow_color'] = 'white' if np.random.rand() < 0.8 else 'black'
        elif stroke_variant == 'none':
            pass

        return params

    def render_background(self, img):
        """Add a background and/or text bubble to a BGRA image, crop and return it as a BGR image."""
        draw_bubble = np.random.random() < 0.7

        m0 = int(min(img.shape[:2]) * 0.3)
        img = crop_by_alpha(img, m0)

        background_path = self.background_df.sample(1).iloc[0].path
        background = cv2.imread(background_path)

        t = [
            A.HorizontalFlip(),
            A.RandomRotate90(),
            A.InvertImg(),
            A.RandomBrightnessContrast((-0.2, 0.4), (-0.8, -0.3), p=0.5 if draw_bubble else 1),
            A.Blur((3, 5), p=0.3),
            A.Resize(img.shape[0], img.shape[1]),
        ]

        background = A.Compose(t)(image=background)['image']

        if not draw_bubble:
            if np.random.rand() < 0.5:
                img[:, :, :3] = 255 - img[:, :, :3]

        else:
            radius = np.random.uniform(0.7, 1.)
            thickness = np.random.choice([1, 2, 3])
            alpha = np.random.randint(60, 100)
            sigma = np.random.randint(10, 15)

            ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))
            xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.07, 0.12))

            bubble_fill_color = (255, 255, 255, 255)
            bubble_contour_color = (0, 0, 0, 255)
            bubble = np.zeros((img.shape[0], img.shape[1], 4), dtype=np.uint8)
            bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_fill_color,
                                       thickness=-1)
            bubble = rounded_rectangle(bubble, (xmin, ymin), (xmax, ymax), radius=radius, color=bubble_contour_color,
                                       thickness=thickness)

            t = [
                A.ElasticTransform(alpha=alpha, sigma=sigma, alpha_affine=0, p=0.8),
            ]
            bubble = A.Compose(t)(image=bubble)['image']

            background = blend(bubble, background)

        img = blend(img, background)

        ymin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        ymax = img.shape[0] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        xmin = m0 - int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        xmax = img.shape[1] - m0 + int(min(img.shape[:2]) * np.random.uniform(0.01, 0.2))
        img = img[ymin:ymax, xmin:xmax]
        return img

    def lines_to_html(self, lines):
        lines_str = '\n'.join(['<p>' + line + '</p>' for line in lines])
        html = f"<html><body>\n{lines_str}\n</body></html>"
        return html


def crop_by_alpha(img, margin):
    y, x = np.where(img[:, :, 3] > 0)
    ymin = y.min()
    ymax = y.max()
    xmin = x.min()
    xmax = x.max()
    img = img[ymin:ymax, xmin:xmax]
    img = np.pad(img, ((margin, margin), (margin, margin), (0, 0)))
    return img


def blend(img, background):
    alpha = (img[:, :, 3] / 255)[:, :, np.newaxis]
    img = img[:, :, :3]
    img = (background * (1 - alpha) + img * alpha).astype(np.uint8)
    return img


def rounded_rectangle(src, top_left, bottom_right, radius=1, color=255, thickness=1, line_type=cv2.LINE_AA):
    """From https://stackoverflow.com/a/60210706"""

    # corners:
    # p1 - p2
    # |     |
    # p4 - p3

    p1 = top_left
    p2 = (bottom_right[0], top_left[1])
    p3 = bottom_right
    p4 = (top_left[0], bottom_right[1])

    height = abs(bottom_right[1] - top_left[1])
    width = abs(bottom_right[0] - top_left[0])

    if radius > 1:
        radius = 1

    corner_radius = int(radius * (min(height, width) / 2))

    if thickness < 0:
        # big rect
        top_left_main_rect = (int(p1[0] + corner_radius), int(p1[1]))
        bottom_right_main_rect = (int(p3[0] - corner_radius), int(p3[1]))

        top_left_rect_left = (p1[0], p1[1] + corner_radius)
        bottom_right_rect_left = (p4[0] + corner_radius, p4[1] - corner_radius)

        top_left_rect_right = (p2[0] - corner_radius, p2[1] + corner_radius)
        bottom_right_rect_right = (p3[0], p3[1] - corner_radius)

        all_rects = [
            [top_left_main_rect, bottom_right_main_rect],
            [top_left_rect_left, bottom_right_rect_left],
            [top_left_rect_right, bottom_right_rect_right]]

        [cv2.rectangle(src, rect[0], rect[1], color, thickness) for rect in all_rects]

    # draw straight lines
    cv2.line(src, (p1[0] + corner_radius, p1[1]), (p2[0] - corner_radius, p2[1]), color, abs(thickness), line_type)
    cv2.line(src, (p2[0], p2[1] + corner_radius), (p3[0], p3[1] - corner_radius), color, abs(thickness), line_type)
    cv2.line(src, (p3[0] - corner_radius, p4[1]), (p4[0] + corner_radius, p3[1]), color, abs(thickness), line_type)
    cv2.line(src, (p4[0], p4[1] - corner_radius), (p1[0], p1[1] + corner_radius), color, abs(thickness), line_type)

    # draw arcs
    cv2.ellipse(src, (p1[0] + corner_radius, p1[1] + corner_radius), (corner_radius, corner_radius), 180.0, 0, 90,
                color, thickness, line_type)
    cv2.ellipse(src, (p2[0] - corner_radius, p2[1] + corner_radius), (corner_radius, corner_radius), 270.0, 0, 90,
                color, thickness, line_type)
    cv2.ellipse(src, (p3[0] - corner_radius, p3[1] - corner_radius), (corner_radius, corner_radius), 0.0, 0, 90, color,
                thickness, line_type)
    cv2.ellipse(src, (p4[0] + corner_radius, p4[1] - corner_radius), (corner_radius, corner_radius), 90.0, 0, 90, color,
                thickness, line_type)

    return src


def get_css(
        font_size,
        font_path,
        vertical=True,
        background_color='white',
        text_color='black',
        shadow_size=0,
        shadow_color='black',
        stroke_size=0,
        stroke_color='black',
        letter_spacing=None,
        line_height=0.5,
        text_orientation=None,
):
    styles = [
        f"background-color: {background_color};",
        f"font-size: {font_size}px;",
        f"color: {text_color};",
        "font-family: custom;",
        f"line-height: {line_height};",
        "margin: 20px;",
    ]

    if text_orientation:
        styles.append(f"text-orientation: {text_orientation};")

    if vertical:
        styles.append("writing-mode: vertical-rl;")

    if shadow_size > 0:
        styles.append(f"text-shadow: 0 0 {shadow_size}px {shadow_color};")

    if stroke_size > 0:
        # stroke is simulated by a shadow overlaid multiple times
        styles.extend([
            f"text-shadow: " + ','.join([f"0 0 {stroke_size}px {stroke_color}"] * 10 * stroke_size) + ";",
            "-webkit-font-smoothing: antialiased;",
        ])

    if letter_spacing:
        styles.append(f"letter-spacing: {letter_spacing}em;")

    font_path = font_path.replace('\\', '/')

    styles_str = '\n'.join(styles)
    css = ""
    css += '\n@font-face {\nfont-family: custom;\nsrc: url("' + font_path + '");\n}\n'
    css += "body {\n" + styles_str + "\n}"
    return css

manga_ocr_dev/synthetic_data_generator/run_generate.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import traceback
from pathlib import Path

import cv2
import fire
import pandas as pd
from tqdm.contrib.concurrent import thread_map

from manga_ocr_dev.env import FONTS_ROOT, DATA_SYNTHETIC_ROOT
from manga_ocr_dev.synthetic_data_generator.generator import SyntheticDataGenerator

generator = SyntheticDataGenerator()


def f(args):
    """Generate a single sample: save the image to OUT_DIR and return its metadata row."""
    try:
        i, source, id_, text = args
        filename = f'{id_}.jpg'
        img, text_gt, params = generator.process(text)

        cv2.imwrite(str(OUT_DIR / filename), img)

        font_path = Path(params['font_path']).relative_to(FONTS_ROOT)
        ret = source, id_, text_gt, params['vertical'], str(font_path)
        return ret

    except Exception:
        print(traceback.format_exc())


def run(package=0, n_random=1000, n_limit=None, max_workers=16):
    """
    :param package: number of the data package to generate
    :param n_random: how many samples with random text to generate
    :param n_limit: limit the number of generated samples (for debugging)
    :param max_workers: max number of workers
    """

    package = f'{package:04d}'
    lines = pd.read_csv(DATA_SYNTHETIC_ROOT / f'lines/{package}.csv')
    random_lines = pd.DataFrame({
        'source': 'random',
        'id': [f'random_{package}_{i}' for i in range(n_random)],
        'line': None
    })
    lines = pd.concat([lines, random_lines], ignore_index=True)
    if n_limit:
        lines = lines.sample(n_limit)
    args = [(i, *values) for i, values in enumerate(lines.values)]

    # OUT_DIR is set globally so that the worker function can see it
    global OUT_DIR
    OUT_DIR = DATA_SYNTHETIC_ROOT / 'img' / package
    OUT_DIR.mkdir(parents=True, exist_ok=True)

    data = thread_map(f, args, max_workers=max_workers, desc=f'Processing package {package}')

    data = pd.DataFrame(data, columns=['source', 'id', 'text', 'vertical', 'font_path'])
    meta_path = DATA_SYNTHETIC_ROOT / f'meta/{package}.csv'
    meta_path.parent.mkdir(parents=True, exist_ok=True)
    data.to_csv(meta_path, index=False)


if __name__ == '__main__':
    fire.Fire(run)

manga_ocr_dev/synthetic_data_generator/scan_fonts.py (new file, 72 lines)
@@ -0,0 +1,72 @@
import PIL
import numpy as np
import pandas as pd
from PIL import ImageDraw, ImageFont
from fontTools.ttLib import TTFont
from tqdm.contrib.concurrent import process_map

from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT

vocab = pd.read_csv(ASSETS_PATH / 'vocab.csv').char.values


def has_glyph(font, glyph):
    for table in font['cmap'].tables:
        if ord(glyph) in table.cmap.keys():
            return True
    return False


def process(font_path):
    """
    Get the list of supported characters for a given font.
    Font metadata is not always reliable, so try to render each character and see if anything shows up.
    Still not perfect, because sometimes unsupported characters show up as rectangles.
    """

    try:
        font_path = str(font_path)
        ttfont = TTFont(font_path)
        pil_font = ImageFont.truetype(font_path, 24)

        supported_chars = []

        for char in vocab:
            if not has_glyph(ttfont, char):
                continue

            image = PIL.Image.new('L', (40, 40), 255)
            draw = ImageDraw.Draw(image)
            draw.text((10, 0), char, 0, font=pil_font)
            if (np.array(image) != 255).sum() == 0:
                continue

            supported_chars.append(char)

        supported_chars = ''.join(supported_chars)
    except Exception as e:
        print(f'Error while processing {font_path}: {e}')
        supported_chars = ''

    return supported_chars


def main():
    path_in = FONTS_ROOT
    out_path = ASSETS_PATH / 'fonts.csv'

    suffixes = {'.TTF', '.otf', '.ttc', '.ttf'}
    font_paths = [path for path in path_in.glob('**/*') if
                  path.suffix in suffixes]

    data = process_map(process, font_paths, max_workers=16)

    font_paths = [str(path.relative_to(FONTS_ROOT)) for path in font_paths]
    data = pd.DataFrame({'font_path': font_paths, 'supported_chars': data})
    data['num_chars'] = data.supported_chars.str.len()
    data['label'] = 'regular'
    data.to_csv(out_path, index=False)


if __name__ == '__main__':
    main()

manga_ocr_dev/synthetic_data_generator/utils.py (new file, 54 lines)
@@ -0,0 +1,54 @@
import unicodedata

import pandas as pd

from manga_ocr_dev.env import ASSETS_PATH, FONTS_ROOT


def get_background_df(background_dir):
    background_df = []
    for path in background_dir.iterdir():
        ymin, ymax, xmin, xmax = [int(v) for v in path.stem.split('_')[-4:]]
        h = ymax - ymin
        w = xmax - xmin
        ratio = w / h

        background_df.append({
            'path': str(path),
            'h': h,
            'w': w,
            'ratio': ratio,
        })
    background_df = pd.DataFrame(background_df)
    return background_df


def is_kanji(ch):
    return 'CJK UNIFIED IDEOGRAPH' in unicodedata.name(ch)


def is_hiragana(ch):
    return 'HIRAGANA' in unicodedata.name(ch)


def is_katakana(ch):
    return 'KATAKANA' in unicodedata.name(ch)


def is_ascii(ch):
    return ord(ch) < 128


def get_charsets(vocab_path=None):
    if vocab_path is None:
        vocab_path = ASSETS_PATH / 'vocab.csv'
    vocab = pd.read_csv(vocab_path).char.values
    hiragana = vocab[[is_hiragana(c) for c in vocab]][:-6]
    katakana = vocab[[is_katakana(c) for c in vocab]][3:]
    return vocab, hiragana, katakana


def get_font_meta():
    df = pd.read_csv(ASSETS_PATH / 'fonts.csv')
    df.font_path = df.font_path.apply(lambda x: str(FONTS_ROOT / x))
    font_map = {row.font_path: set(row.supported_chars) for row in df.dropna().itertuples()}
    return df, font_map

manga_ocr_dev/training/__init__.py (new file, empty)

manga_ocr_dev/training/dataset.py (new file, 165 lines)
@@ -0,0 +1,165 @@
import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from manga_ocr_dev.env import MANGA109_ROOT, DATA_SYNTHETIC_ROOT


class MangaDataset(Dataset):
    """Dataset of (image, text) pairs combining synthetic data packages with Manga109 crops."""

    def __init__(self, processor, split, max_target_length, limit_size=None, augment=False, skip_packages=None):
        self.processor = processor
        self.max_target_length = max_target_length

        data = []

        print(f'Initializing dataset {split}...')

        if skip_packages is None:
            skip_packages = set()
        else:
            skip_packages = {f'{x:04d}' for x in skip_packages}

        for path in sorted((DATA_SYNTHETIC_ROOT / 'meta').glob('*.csv')):
            if path.stem in skip_packages:
                print(f'Skipping package {path}')
                continue
            if not (DATA_SYNTHETIC_ROOT / 'img' / path.stem).is_dir():
                print(f'Missing image data for package {path}, skipping')
                continue
            df = pd.read_csv(path)
            df = df.dropna()
            df['path'] = df.id.apply(lambda x: str(DATA_SYNTHETIC_ROOT / 'img' / path.stem / f'{x}.jpg'))
            df = df[['path', 'text']]
            df['synthetic'] = True
            data.append(df)

        df = pd.read_csv(MANGA109_ROOT / 'data.csv')
        df = df[df.split == split].reset_index(drop=True)
        df['path'] = df.crop_path.apply(lambda x: str(MANGA109_ROOT / x))
        df = df[['path', 'text']]
        df['synthetic'] = False
        data.append(df)

        data = pd.concat(data, ignore_index=True)

        if limit_size:
            data = data.iloc[:limit_size]
        self.data = data

        print(f'Dataset {split}: {len(self.data)}')

        self.augment = augment
        self.transform_medium, self.transform_heavy = self.get_transforms()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.loc[idx]
        text = sample.text

        if self.augment:
            medium_p = 0.8
            heavy_p = 0.02
            transform_variant = np.random.choice(['none', 'medium', 'heavy'],
                                                 p=[1 - medium_p - heavy_p, medium_p, heavy_p])
            transform = {
                'none': None,
                'medium': self.transform_medium,
                'heavy': self.transform_heavy,
            }[transform_variant]
        else:
            transform = None

        pixel_values = self.read_image(self.processor, sample.path, transform)
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        labels = np.array(labels)
        # important: make sure that PAD tokens are ignored by the loss function
        labels[labels == self.processor.tokenizer.pad_token_id] = -100

        encoding = {
            "pixel_values": pixel_values,
            "labels": torch.tensor(labels),
        }
        return encoding

    @staticmethod
    def read_image(processor, path, transform=None):
        img = cv2.imread(str(path))

        if transform is None:
            transform = A.ToGray(always_apply=True)

        img = transform(image=img)['image']

        pixel_values = processor(img, return_tensors="pt").pixel_values
        return pixel_values.squeeze()

    @staticmethod
    def get_transforms():
        t_medium = A.Compose([
            A.Rotate(5, border_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.InvertImg(p=0.05),

            A.OneOf([
                A.Downscale(0.25, 0.5, interpolation=cv2.INTER_LINEAR),
                A.Downscale(0.25, 0.5, interpolation=cv2.INTER_NEAREST),
            ], p=0.1),
            A.Blur(p=0.2),
            A.Sharpen(p=0.2),
            A.RandomBrightnessContrast(p=0.5),
            A.GaussNoise((50, 200), p=0.3),
            A.ImageCompression(0, 30, p=0.1),
            A.ToGray(always_apply=True),
        ])

        t_heavy = A.Compose([
            A.Rotate(10, border_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.Perspective((0.01, 0.05), pad_mode=cv2.BORDER_REPLICATE, p=0.2),
            A.InvertImg(p=0.05),

            A.OneOf([
                A.Downscale(0.1, 0.2, interpolation=cv2.INTER_LINEAR),
                A.Downscale(0.1, 0.2, interpolation=cv2.INTER_NEAREST),
            ], p=0.1),
            A.Blur((4, 9), p=0.5),
            A.Sharpen(p=0.5),
            A.RandomBrightnessContrast(0.8, 0.8, p=1),
            A.GaussNoise((1000, 10000), p=0.3),
            A.ImageCompression(0, 10, p=0.5),
            A.ToGray(always_apply=True),
        ])

        return t_medium, t_heavy


if __name__ == '__main__':
    from manga_ocr_dev.training.get_model import get_processor
    from manga_ocr_dev.training.utils import tensor_to_image

    encoder_name = 'facebook/deit-tiny-patch16-224'
    decoder_name = 'cl-tohoku/bert-base-japanese-char-v2'

    max_length = 300

    processor = get_processor(encoder_name, decoder_name)
    ds = MangaDataset(processor, 'train', max_length, augment=True)

    for i in range(20):
        sample = ds[0]
        img = tensor_to_image(sample['pixel_values'])
        tokens = sample['labels']
        tokens[tokens == -100] = processor.tokenizer.pad_token_id
        text = ''.join(processor.decode(tokens, skip_special_tokens=True).split())

        print(f'{i}:\n{text}\n')
        plt.imshow(img)
        plt.show()

manga_ocr_dev/training/get_model.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, TrOCRProcessor, VisionEncoderDecoderModel, \
    AutoFeatureExtractor, AutoTokenizer, VisionEncoderDecoderConfig


class TrOCRProcessorCustom(TrOCRProcessor):
    """The only point of this class is to bypass the type checks of the base class."""

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.current_processor = self.feature_extractor


def get_processor(encoder_name, decoder_name):
    feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_name)
    tokenizer = AutoTokenizer.from_pretrained(decoder_name)
    processor = TrOCRProcessorCustom(feature_extractor, tokenizer)
    return processor


def get_model(encoder_name, decoder_name, max_length, num_decoder_layers=None):
    """Build a VisionEncoderDecoderModel from the given encoder/decoder configs (weights initialized from scratch),
    optionally truncating the decoder to its last `num_decoder_layers` layers."""
    encoder_config = AutoConfig.from_pretrained(encoder_name)
    encoder_config.is_decoder = False
    encoder_config.add_cross_attention = False
    encoder = AutoModel.from_config(encoder_config)

    decoder_config = AutoConfig.from_pretrained(decoder_name)
    decoder_config.max_length = max_length
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True
    decoder = AutoModelForCausalLM.from_config(decoder_config)

    if num_decoder_layers is not None:
        if decoder_config.model_type == 'bert':
            decoder.bert.encoder.layer = decoder.bert.encoder.layer[-num_decoder_layers:]
        elif decoder_config.model_type in ('roberta', 'xlm-roberta'):
            decoder.roberta.encoder.layer = decoder.roberta.encoder.layer[-num_decoder_layers:]
        else:
            raise ValueError(f'Unsupported model_type: {decoder_config.model_type}')

        decoder_config.num_hidden_layers = num_decoder_layers

    config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
    config.tie_word_embeddings = False
    model = VisionEncoderDecoderModel(encoder=encoder, decoder=decoder, config=config)

    processor = get_processor(encoder_name, decoder_name)

    # set special tokens used for creating the decoder_input_ids from the labels
    model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    # make sure vocab size is set correctly
    model.config.vocab_size = model.config.decoder.vocab_size

    # set beam search parameters
    model.config.eos_token_id = processor.tokenizer.sep_token_id
    model.config.max_length = max_length
    model.config.early_stopping = True
    model.config.no_repeat_ngram_size = 3
    model.config.length_penalty = 2.0
    model.config.num_beams = 4

    return model, processor

manga_ocr_dev/training/metrics.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import numpy as np
from datasets import load_metric


class Metrics:
    def __init__(self, processor):
        self.cer_metric = load_metric("cer")
        self.processor = processor

    def compute_metrics(self, pred):
        label_ids = pred.label_ids
        pred_ids = pred.predictions
        print(label_ids.shape, pred_ids.shape)

        pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True)
        label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id
        label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True)

        pred_str = np.array([''.join(text.split()) for text in pred_str])
        label_str = np.array([''.join(text.split()) for text in label_str])

        results = {}
        try:
            results['cer'] = self.cer_metric.compute(predictions=pred_str, references=label_str)
        except Exception as e:
            print(e)
            print(pred_str)
            print(label_str)
            results['cer'] = 0
        results['accuracy'] = (pred_str == label_str).mean()

        return results

manga_ocr_dev/training/train.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import fire
import wandb
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, default_data_collator

from manga_ocr_dev.env import TRAIN_ROOT
from manga_ocr_dev.training.dataset import MangaDataset
from manga_ocr_dev.training.get_model import get_model
from manga_ocr_dev.training.metrics import Metrics


def run(
        run_name='debug',
        encoder_name='facebook/deit-tiny-patch16-224',
        decoder_name='cl-tohoku/bert-base-japanese-char-v2',
        max_len=300,
        num_decoder_layers=2,
        batch_size=64,
        num_epochs=8,
        fp16=True,
):
    wandb.login()

    model, processor = get_model(encoder_name, decoder_name, max_len, num_decoder_layers)

    # keep package 0 for validation
    train_dataset = MangaDataset(processor, 'train', max_len, augment=True, skip_packages=[0])
    eval_dataset = MangaDataset(processor, 'test', max_len, augment=False, skip_packages=range(1, 9999))

    metrics = Metrics(processor)

    training_args = Seq2SeqTrainingArguments(
        predict_with_generate=True,
        evaluation_strategy='steps',
        save_strategy='steps',
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        fp16=fp16,
        fp16_full_eval=fp16,
        dataloader_num_workers=16,
        output_dir=TRAIN_ROOT,
        logging_steps=10,
        save_steps=20000,
        eval_steps=20000,
        num_train_epochs=num_epochs,
        run_name=run_name,
    )

    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        tokenizer=processor.feature_extractor,
        args=training_args,
        compute_metrics=metrics.compute_metrics,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=default_data_collator,
    )
    trainer.train()

    wandb.finish()


if __name__ == '__main__':
    fire.Fire(run)

manga_ocr_dev/training/utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
import numpy as np
import torch
from torchinfo import summary


def encoder_summary(model, batch_size=4):
    img_size = model.config.encoder.image_size
    return summary(model.encoder, input_size=(batch_size, 3, img_size, img_size), depth=3,
                   col_names=["output_size", "num_params", "mult_adds"], device='cpu')


def decoder_summary(model, batch_size=4):
    img_size = model.config.encoder.image_size
    encoder_hidden_shape = (batch_size, (img_size // 16) ** 2 + 1, model.config.decoder.hidden_size)
    decoder_inputs = {
        'input_ids': torch.zeros(batch_size, 1, dtype=torch.int64),
        'attention_mask': torch.ones(batch_size, 1, dtype=torch.int64),
        'encoder_hidden_states': torch.rand(encoder_hidden_shape, dtype=torch.float32),
        'return_dict': False
    }
    return summary(model.decoder, input_data=decoder_inputs, depth=4,
                   col_names=["output_size", "num_params", "mult_adds"],
                   device='cpu')


def tensor_to_image(img):
    return ((img.cpu().numpy() + 1) / 2 * 255).clip(0, 255).astype(np.uint8).transpose(1, 2, 0)