mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-01-11 10:25:16 +00:00
Allow SPieceTokenizer to load model from a byte string.
This commit is contained in:
parent
334ba48cea
commit
88ed893034
@ -1,14 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
class SPieceTokenizer:
|
class SPieceTokenizer:
|
||||||
|
add_eos = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(path):
|
def from_pretrained(path):
|
||||||
return SPieceTokenizer(path)
|
return SPieceTokenizer(path)
|
||||||
|
|
||||||
def __init__(self, tokenizer_path):
|
def __init__(self, tokenizer_path):
|
||||||
import sentencepiece
|
import sentencepiece
|
||||||
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
|
if isinstance(tokenizer_path, bytes):
|
||||||
self.end = self.tokenizer.eos_id()
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
|
||||||
|
else:
|
||||||
|
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
|
||||||
|
|
||||||
def get_vocab(self):
|
def get_vocab(self):
|
||||||
out = {}
|
out = {}
|
||||||
@ -18,5 +22,4 @@ class SPieceTokenizer:
|
|||||||
|
|
||||||
def __call__(self, string):
|
def __call__(self, string):
|
||||||
out = self.tokenizer.encode(string)
|
out = self.tokenizer.encode(string)
|
||||||
out += [self.end]
|
|
||||||
return {"input_ids": out}
|
return {"input_ids": out}
|
||||||
|
Loading…
Reference in New Issue
Block a user