pass max_length to generate method

This commit is contained in:
Maciej Budyś
2022-11-03 21:33:10 +01:00
parent 6309289958
commit 694ff2e829

View File

@@ -39,7 +39,7 @@ class MangaOcr:
img = img.convert('L').convert('RGB')
x = self._preprocess(img)
x = self.model.generate(x[None].to(self.model.device))[0].cpu()
x = self.model.generate(x[None].to(self.model.device), max_length=300)[0].cpu()
x = self.tokenizer.decode(x, skip_special_tokens=True)
x = post_process(x)
return x