| import re | |
| from hat_splitter import HATSplitter as RustHATSplitter | |
| class HATSplitter: | |
| def __init__(self, special_token_dict: dict | None = None, max_word_size: int = 128): | |
| self.hat_splitter = RustHATSplitter() | |
| self.max_word_size = max_word_size | |
| self.special_token_dict = special_token_dict | |
| self.special_token_replace: dict[int, list[int]] = {token: list(text.encode("utf-8")) for text, token in self.special_token_dict.items()} | |
| self.special_token_pattern = re.compile(rf"({'|'.join(map(re.escape, special_token_dict.keys()))})") if special_token_dict else re.compile(r"(?!)") | |
| def encode(self, text: str) -> list[list[int]]: | |
| chunks = [] | |
| for str_chunk in self.special_token_pattern.split(text): | |
| if str_chunk: | |
| if str_chunk in self.special_token_dict: | |
| chunks.append([self.special_token_dict[str_chunk]]) | |
| else: | |
| chunks.extend(list(chunk) for chunk in self.hat_splitter.split_with_limit(str_chunk, self.max_word_size)) | |
| return chunks | |
| def decode(self, token_ids: list[int], errors: str = "replace", skip_special_tokens: bool = False) -> str: | |
| assert isinstance(token_ids, list), "token_ids must be a list" | |
| assert all(isinstance(token_id, int) for token_id in token_ids), "token_ids must be a list of integers" | |
| new_token_ids: list[int] | |
| if skip_special_tokens: | |
| new_token_ids = [token_id for token_id in token_ids if token_id not in self.special_token_replace] | |
| else: | |
| new_token_ids = [] | |
| for token in token_ids: | |
| if token in self.special_token_replace: | |
| new_token_ids.extend(self.special_token_replace[token]) | |
| else: | |
| new_token_ids.append(token) | |
| return bytes(new_token_ids).decode("utf-8", errors=errors) | |