pass max_length to generate method

This commit is contained in:
Maciej Budyś
2022-11-03 21:33:10 +01:00
parent 6309289958
commit 694ff2e829

View File

@@ -39,7 +39,7 @@ class MangaOcr:
img = img.convert('L').convert('RGB') img = img.convert('L').convert('RGB')
x = self._preprocess(img) x = self._preprocess(img)
x = self.model.generate(x[None].to(self.model.device))[0].cpu() x = self.model.generate(x[None].to(self.model.device), max_length=300)[0].cpu()
x = self.tokenizer.decode(x, skip_special_tokens=True) x = self.tokenizer.decode(x, skip_special_tokens=True)
x = post_process(x) x = post_process(x)
return x return x