Spaces:
Sleeping
Sleeping
| import os | |
| from typing import Union | |
| import torch | |
| from transformers import LogitsProcessor | |
| from seed_scheme_factory import SeedSchemeFactory | |
| from utils import bytes_to_base, base_to_bytes, get_values_per_byte | |
class BaseProcessor(object):
    """Shared state and helpers for mapping tokens to message values.

    For every position the vocabulary is shuffled with an RNG seeded from the
    preceding tokens, and the shuffled vocabulary is cut into ``msg_base``
    contiguous slices. A token therefore corresponds to the value whose slice
    contains it.
    """

    def __init__(
        self,
        msg_base: int,
        vocab: list[int],
        device: torch.device,
        seed_scheme: str,
        window_length: int = 1,
        salt_key: Union[int, None] = None,
        private_key: Union[int, None] = None,
    ):
        """
        Args:
            msg_base: base of the message.
            vocab: vocabulary list.
            device: device to load processor.
            seed_scheme: scheme used to compute the seed.
            window_length: length of window to compute the seed.
            salt_key: salt to add to the seed.
            private_key: private key used to compute the seed.

        Raises:
            ValueError: if ``seed_scheme`` is not known to
                ``SeedSchemeFactory``.
        """
        # Universal parameters
        self.msg_base = msg_base
        self.vocab = vocab
        self.vocab_size = len(vocab)
        self.device = device

        # Seed parameters
        seed_fn = SeedSchemeFactory.get_instance(
            seed_scheme,
            salt_key=salt_key,
            private_key=private_key,
        )
        if seed_fn is None:
            raise ValueError(f'Seed scheme "{seed_scheme}" is invalid')
        self.seed_fn = seed_fn
        self.window_length = window_length

        # Initialize RNG, always use cpu generator
        self.rng = torch.Generator(device="cpu")

        # Boundaries (on the CPU) of the slice of the shuffled vocabulary
        # that belongs to each value in base.
        self.ranges = self._compute_ranges(self.vocab_size, self.msg_base)

    @staticmethod
    def _compute_ranges(vocab_size: int, msg_base: int) -> torch.Tensor:
        """
        Split ``vocab_size`` positions into ``msg_base`` contiguous slices.

        Args:
            vocab_size: number of positions to split.
            msg_base: number of slices.

        Returns:
            An int64 tensor of ``msg_base + 1`` cumulative boundaries; value
            ``v`` owns positions ``ranges[v]`` up to (excluding)
            ``ranges[v + 1]``. The first ``vocab_size % msg_base`` slices
            are one position longer than the others.
        """
        # Integer division: slice sizes are whole numbers of tokens.
        chunk_size, remainder = divmod(vocab_size, msg_base)
        sizes = torch.zeros(msg_base + 1, dtype=torch.int64)
        sizes[1:] = chunk_size
        sizes[1 : remainder + 1] += 1
        return torch.cumsum(sizes, dim=0)

    def _seed_rng(self, input_ids: torch.Tensor):
        """
        Set the seed for the rng based on the current sequences.

        Args:
            input_ids: id in the input sequence; only the last
                ``window_length`` ids are passed to the seed function.
        """
        seed = self.seed_fn(input_ids[-self.window_length :])
        self.rng.manual_seed(seed)

    def _get_valid_list_ids(self, input_ids: torch.Tensor, value: int):
        """
        Get ids of tokens in the valid list for the current sequences.

        Args:
            input_ids: ids generated so far, used to seed the shuffle.
            value: message value whose slice of the shuffle is requested.

        Returns:
            Tensor on ``self.device`` holding the ids in the slice of the
            shuffled vocabulary that belongs to ``value``.
        """
        self._seed_rng(input_ids)
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        ).to(self.device)
        vocab_list = vocab_perm[self.ranges[value] : self.ranges[value + 1]]
        return vocab_list

    def _get_value(self, input_ids: torch.Tensor):
        """
        Recover the message value encoded by the last token of a sequence.

        The shuffle is seeded from every id except the last one, i.e. the
        same context ``_get_valid_list_ids`` sees when that token is chosen.

        Args:
            input_ids: ids of the sequence, ending with the token to decode.

        Returns:
            Index of the slice of the shuffled vocabulary that contains the
            last token.
        """
        self._seed_rng(input_ids[:-1])
        # The permutation and ``self.ranges`` are both created on the CPU,
        # so the lookup is done there; comparing a position tensor living on
        # another device against ``self.ranges`` would raise.
        vocab_perm = torch.randperm(
            self.vocab_size, generator=self.rng, device="cpu"
        )
        cur_token = torch.as_tensor(input_ids[-1], device="cpu")
        cur_id = (vocab_perm == cur_token).nonzero(as_tuple=True)[0]
        # The first boundary strictly above the position closes the slice
        # that holds the token; boundaries are offset by one from values.
        value = (cur_id < self.ranges).type(torch.int).argmax().item() - 1
        return value
class EncryptorLogitsProcessor(LogitsProcessor, BaseProcessor):
    """Logits processor that hides a message by biasing token scores.

    At each generation step that falls inside the message, ``gamma`` is
    added to the scores of the tokens in the valid list of the current
    message value.
    """

    def __init__(
        self,
        prompt_ids: torch.Tensor,
        msg: bytes,
        gamma: float,
        start_pos: int = 0,
        *args,
        **kwargs,
    ):
        """
        Args:
            prompt_ids: ids of the prompt; its first dimension must be 1.
            msg: message to hide in the text.
            gamma: bias add to scores of token in valid list.
            start_pos: number of generated tokens to skip before the first
                message value is hidden.
            *args: forwarded to ``BaseProcessor``.
            **kwargs: forwarded to ``BaseProcessor``.

        Raises:
            RuntimeError: if ``prompt_ids`` holds more than one prompt.
        """
        super().__init__(*args, **kwargs)
        if prompt_ids.size(0) != 1:
            raise RuntimeError(
                "EncryptorLogitsProcessor does not support multiple prompts input."
            )
        self.prompt_size = prompt_ids.size(1)
        self.start_pos = start_pos
        self.raw_msg = msg
        # Message re-expressed as a sequence of values in ``msg_base``.
        self.msg = bytes_to_base(msg, self.msg_base)
        self.gamma = gamma

    def __call__(
        self, input_ids_batch: torch.LongTensor, scores_batch: torch.FloatTensor
    ):
        """
        Bias the scores of every sequence that is still inside the message.

        Args:
            input_ids_batch: ids of each sequence so far.
            scores_batch: scores of the next token for each sequence;
                modified in place.

        Returns:
            ``scores_batch`` with the bias applied where relevant.
        """
        for i, input_ids in enumerate(input_ids_batch):
            cur_pos = input_ids.size(0)
            msg_ptr = cur_pos - (self.prompt_size + self.start_pos)
            # Before the start position, or once the whole message is hidden
            # already, the raw scores are kept.
            if msg_ptr < 0 or msg_ptr >= len(self.msg):
                continue
            scores_batch[i] = self._add_bias_to_valid_list(
                input_ids, scores_batch[i], self.msg[msg_ptr]
            )
        return scores_batch

    def _add_bias_to_valid_list(
        self, input_ids: torch.Tensor, scores: torch.Tensor, value: int
    ):
        """
        Add the bias (gamma) to the valid list tokens.

        Args:
            input_ids: ids of the sequence so far.
            scores: scores of the next token; modified in place.
            value: message value to hide at this step.

        Returns:
            The biased ``scores``.
        """
        ids = self._get_valid_list_ids(input_ids, value)
        scores[ids] = scores[ids] + self.gamma
        return scores

    def get_message_len(self):
        """Return the number of values in the base-converted message."""
        return len(self.msg)

    def validate(self, input_ids_batch: torch.Tensor):
        """
        Measure how much of the raw message can be read back from sequences.

        Args:
            input_ids_batch: sequences to decode, each read from
                ``start_pos`` onwards.

        Returns:
            One ratio per sequence: the number of byte positions where the
            decoded bytes equal the raw message, divided by the raw message
            length.

        Raises:
            ZeroDivisionError: if the raw message is empty.
        """
        res = []
        for input_ids in input_ids_batch:
            values = [
                self._get_value(input_ids[: i + 1])
                for i in range(self.start_pos, input_ids.size(0))
            ]
            enc_msg = base_to_bytes(values, self.msg_base)
            # Pairwise comparison stops at the shorter of the two, so a
            # sequence that decodes to fewer bytes than the message counts
            # the missing bytes as mismatches instead of raising IndexError.
            matches = sum(
                expected == decoded
                for expected, decoded in zip(self.raw_msg, enc_msg)
            )
            res.append(matches / len(self.raw_msg))
        return res
class DecryptorProcessor(BaseProcessor):
    """Recovers hidden message bytes from token sequences."""

    def __init__(self, *args, **kwargs):
        """Forward every argument to ``BaseProcessor``."""
        super().__init__(*args, **kwargs)

    def decrypt(self, input_ids_batch: torch.Tensor):
        """
        Decrypt the text sequences.

        Decoding is attempted once per possible alignment: for shift ``s``
        the first decoded position is ``window_length + s``.

        Args:
            input_ids_batch: sequences to decode.

        Returns:
            A list with one entry per shift in
            ``range(get_values_per_byte(msg_base))``; each entry is a list
            with the decoded bytes of every sequence in the batch.
        """
        # The value at a position depends only on the tokens before it, not
        # on the shift, so each position is decoded once and every shift
        # reuses the result by dropping its leading values.
        values_batch = [
            [
                self._get_value(input_ids[: j + 1])
                for j in range(self.window_length, len(input_ids))
            ]
            for input_ids in input_ids_batch
        ]
        return [
            [base_to_bytes(values[shift:], self.msg_base) for values in values_batch]
            for shift in range(get_values_per_byte(self.msg_base))
        ]