Allow SPieceTokenizer to load model from a byte string.

This commit is contained in:
comfyanonymous 2024-07-23 14:17:42 -04:00
parent 334ba48cea
commit 88ed893034

View File

@ -1,14 +1,18 @@
import os
class SPieceTokenizer:
add_eos = True
@staticmethod
def from_pretrained(path):
return SPieceTokenizer(path)
def __init__(self, tokenizer_path):
import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
self.end = self.tokenizer.eos_id()
if isinstance(tokenizer_path, bytes):
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
else:
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
def get_vocab(self):
out = {}
@ -18,5 +22,4 @@ class SPieceTokenizer:
def __call__(self, string):
out = self.tokenizer.encode(string)
out += [self.end]
return {"input_ids": out}