import os
from datetime import datetime
from copy import deepcopy
import json
import base64
from argparse import ArgumentParser

from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
import torch
from datasets import load_dataset

from model_factory import ModelFactory
from stegno import generate

rng = torch.Generator(device="cpu")
rng.manual_seed(0)
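# The fixed seed above makes the "random" payloads drawn in load_msgs
# reproducible across runs.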


def load_msgs(msg_lens: list[int]):
    """Build message payloads: for each requested length, one random byte
    string and one ASCII-encodable prefix of a C4 validation document."""
    msgs = []
    c4_en = load_dataset("allenai/c4", "en", split="validation", streaming=True)
    iterator = iter(c4_en)
    for length in tqdm(msg_lens, desc="Loading messages"):
        random_msg = torch.randint(256, (length,), generator=rng)
        msgs.append(["random", bytes(random_msg.tolist())])
        # Skip documents whose prefix cannot be encoded as ASCII.
        while True:
            readable_msg = next(iterator)["text"]
            try:
                msgs.append(["readable", readable_msg[:length].encode("ascii")])
                break
            except UnicodeEncodeError:
                continue
    return msgs


def load_prompts(tokenizer, n: int, prompt_size: int):
    """Collect n prompts from the C4 train split, each truncated to
    prompt_size tokens."""
    prompts = []
    c4_en = load_dataset("allenai/c4", "en", split="train", streaming=True)
    iterator = iter(c4_en)
    with tqdm(total=n, desc="Loading prompts") as pbar:
        while len(prompts) < n:
            text = next(iterator)["text"]
            input_ids = tokenizer.encode(text, return_tensors="pt")
            # Skip documents shorter than the requested prompt size.
            if input_ids.size(1) < prompt_size:
                continue
            truncated_text = tokenizer.batch_decode(input_ids[:, :prompt_size])[0]
            prompts.append(truncated_text)
            pbar.update()
    return prompts


class AnalyseProcessor:
    params_names = [
        "msgs",
        "bases",
        "deltas",
    ]
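    # run() sweeps these dimensions as nested loops, msgs outermost and
    # deltas innermost. The matching "ptrs" dict records the current index in
    # each dimension and is persisted by save_data, so an interrupted sweep
    # can resume from the last saved combination via load_data.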

    def __init__(
        self,
        save_file: str,
        save_freq: int | None = None,
        gen_model: str | None = None,
        judge_model: str | None = None,
        msgs: list | None = None,  # entries are [msg_type, payload_bytes] pairs
        bases: list[int] | None = None,
        deltas: list[float] | None = None,
        prompts: list[str] | None = None,
        repeat: int = 1,
        gen_params: dict | None = None,
        batch_size: int = 1,
    ):
        self.save_file = save_file
        self.save_freq = save_freq
        self.data = {
            "params": {
                "gen_model": gen_model,
                "judge_model": judge_model,
                "ptrs": {
                    "msgs": 0,
                    "bases": 0,
                    "deltas": 0,
                },
                "values": {
                    "msgs": msgs,
                    "bases": bases,
                    "deltas": deltas,
                },
                "prompts": prompts,
                "batch_size": batch_size,
                "repeat": repeat,
                "gen": gen_params,
            },
            "results": [],
        }
        self.__pbar = None
        self.last_saved = None
        self.skip_first = False

    def run(self, depth=0):
        """Recursively sweep msgs x bases x deltas and generate for each
        prompt batch at the innermost level."""
        if self.__pbar is None:
            total = 1
            for name, values in self.data["params"]["values"].items():
                if values is None:
                    raise RuntimeError(
                        f"values['{name}'] must not be None when running"
                    )
            # Compute how many combinations are already done so the progress
            # bar starts at the right position after a resume.
            initial = 0
            for param_name in self.params_names[::-1]:
                initial += total * self.data["params"]["ptrs"][param_name]
                total *= len(self.data["params"]["values"][param_name])
            if self.skip_first:
                initial += 1
            self.__pbar = tqdm(
                desc="Generating",
                total=total,
                initial=initial,
            )
        if depth < len(self.params_names):
            param_name = self.params_names[depth]
            while self.data["params"]["ptrs"][param_name] < len(
                self.data["params"]["values"][param_name]
            ):
                self.run(depth + 1)
                self.data["params"]["ptrs"][param_name] += 1
            self.data["params"]["ptrs"][param_name] = 0
            if depth == 0:
                self.save_data(self.save_file)
        else:
            if self.skip_first:
                # The combination at the loaded pointers was already processed
                # before the checkpoint was written; skip it once.
                self.skip_first = False
                return
            prompts = self.data["params"]["prompts"]
            msg_ptr = self.data["params"]["ptrs"]["msgs"]
            msg_type, msg = self.data["params"]["values"]["msgs"][msg_ptr]
            base_ptr = self.data["params"]["ptrs"]["bases"]
            base = self.data["params"]["values"]["bases"][base_ptr]
            delta_ptr = self.data["params"]["ptrs"]["deltas"]
            delta = self.data["params"]["values"]["deltas"][delta_ptr]
            model, tokenizer = ModelFactory.load_model(
                self.data["params"]["gen_model"]
            )
            left = 0
            while left < len(prompts):
                start = datetime.now()
                right = min(left + self.data["params"]["batch_size"], len(prompts))
                texts, msgs_rates, _ = generate(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompts[left:right],
                    msg=msg,
                    msg_base=base,
                    delta=delta,
                    **self.data["params"]["gen"],
                )
                end = datetime.now()
                for i in range(len(texts)):
                    prompt_ptr = left + i
                    text = texts[i]
                    msg_rate = msgs_rates[i]
                    self.data["results"].append(
                        {
                            "ptrs": {
                                "prompts": prompt_ptr,
                                "msgs": msg_ptr,
                                "bases": base_ptr,
                                "deltas": delta_ptr,
                            },
                            "perplexity": ModelFactory.compute_perplexity(
                                self.data["params"]["judge_model"], text
                            ),
                            "text": text,
                            "msg_rate": msg_rate,
                            # total_seconds() covers the whole wall-clock span;
                            # .microseconds alone drops everything above 1 s.
                            "run_time (ms)": (end - start).total_seconds()
                            * 1000
                            / len(texts),
                        }
                    )
                left += self.data["params"]["batch_size"]
            postfix = {
                "base": base,
                "msg_len": len(msg),
                "delta": delta,
            }
            self.__pbar.refresh()
            if self.save_freq and (self.__pbar.n + 1) % self.save_freq == 0:
                self.save_data(self.save_file)
            if self.last_saved is not None:
                seconds = (datetime.now() - self.last_saved).seconds
                minutes, seconds = divmod(seconds, 60)
                hours, minutes = divmod(minutes, 60)
                postfix["last_saved"] = f"{hours}:{minutes:02d}:{seconds:02d} ago"
            self.__pbar.set_postfix(postfix)
            self.__pbar.update()

    def __get_mean(self, ptrs: dict, value_name: str):
        """Average value_name over results matching the pointer filter.

        Keys in ptrs may be result pointers ("bases", "deltas", ...) or the
        derived keys "msg_len" and "msg_type"."""
        s = 0
        cnt = 0
        for r in self.data["results"]:
            msg_type, msg = self.data["params"]["values"]["msgs"][r["ptrs"]["msgs"]]
            valid = True
            for k in ptrs:
                if (
                    (k in r["ptrs"] and r["ptrs"][k] != ptrs[k])
                    or (k == "msg_len" and len(msg) != ptrs[k])
                    or (k == "msg_type" and msg_type != ptrs[k])
                ):
                    valid = False
                    break
            if valid:
                s += r[value_name]
                cnt += 1
        # Avoid division by zero when no result matches; the mean is then 0.
        return s / max(cnt, 1)
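
    # Example filter (hypothetical index values), as plot() uses below: the
    # mean perplexity for base index 0 and delta index 2, restricted to
    # readable messages of length 10:
    #   self.__get_mean(
    #       ptrs={"bases": 0, "deltas": 2, "msg_type": "readable", "msg_len": 10},
    #       value_name="perplexity",
    #   )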

    def plot(self, figs_dir: str):
        """Plot mean perplexity and message rate against delta and message
        length."""
        os.makedirs(figs_dir, exist_ok=True)
        msg_set = set()
        for msg_type, msg in self.data["params"]["values"]["msgs"]:
            msg_set.add((msg_type, len(msg)))
        msg_set = sorted(msg_set)

        # Delta effect
        os.makedirs(os.path.join(figs_dir, "delta_effect"), exist_ok=True)
        for value_name in ["perplexity", "msg_rate"]:
            fig = plt.figure(dpi=300)
            for base_ptr, base in enumerate(self.data["params"]["values"]["bases"]):
                for msg_type, msg_len in msg_set:
                    x = []
                    y = []
                    for delta_ptr, delta in enumerate(
                        self.data["params"]["values"]["deltas"]
                    ):
                        x.append(delta)
                        y.append(
                            self.__get_mean(
                                ptrs={
                                    "bases": base_ptr,
                                    "msg_type": msg_type,
                                    "msg_len": msg_len,
                                    "deltas": delta_ptr,
                                },
                                value_name=value_name,
                            )
                        )
                    plt.plot(
                        x,
                        y,
                        label=f"B={base}, msg_type={msg_type}, msg_len={msg_len}",
                    )
            plt.ylim(bottom=0)
            plt.legend()
            plt.savefig(
                os.path.join(figs_dir, "delta_effect", f"{value_name}.pdf"),
                bbox_inches="tight",
            )
            plt.close(fig)

        # Message length effect
        os.makedirs(os.path.join(figs_dir, "msg_len_effect"), exist_ok=True)
        for value_name in ["perplexity", "msg_rate"]:
            fig = plt.figure(dpi=300)
            for base_ptr, base in enumerate(self.data["params"]["values"]["bases"]):
                for delta_ptr, delta in enumerate(
                    self.data["params"]["values"]["deltas"]
                ):
                    x = {}
                    y = {}
                    for msg_type, msg_len in msg_set:
                        x.setdefault(msg_type, []).append(msg_len)
                        y.setdefault(msg_type, []).append(
                            self.__get_mean(
                                ptrs={
                                    "bases": base_ptr,
                                    "msg_type": msg_type,
                                    "msg_len": msg_len,
                                    "deltas": delta_ptr,
                                },
                                value_name=value_name,
                            )
                        )
                    for msg_type in x:
                        plt.plot(
                            x[msg_type],
                            y[msg_type],
                            label=f"B={base}, msg_type={msg_type}, delta={delta}",
                        )
            plt.ylim(bottom=0)
            plt.legend()
            plt.savefig(
                os.path.join(figs_dir, "msg_len_effect", f"{value_name}.pdf"),
                bbox_inches="tight",
            )
            plt.close(fig)
        print(f"Saved figures to {figs_dir}")

    def save_data(self, file_name: str):
        """Serialize params and results to JSON, encoding message bytes as
        text."""
        if file_name is None:
            return
        dir_name = os.path.dirname(file_name)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        data = deepcopy(self.data)
        for i in range(len(data["params"]["values"]["msgs"])):
            msg_type, msg = data["params"]["values"]["msgs"][i]
            if msg_type == "random":
                # Raw random bytes are not valid JSON text, so base64-encode them.
                str_msg = base64.b64encode(msg).decode("ascii")
            else:
                str_msg = msg.decode("ascii")
            data["params"]["values"]["msgs"][i] = [msg_type, str_msg]
        with open(file_name, "w") as f:
            json.dump(data, f, indent=2)
        if self.__pbar is None:
            print(f"Saved AnalyseProcessor data to {file_name}")
        else:
            self.last_saved = datetime.now()
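
    # Sketch of the saved JSON layout (field names taken from the code above;
    # the values shown are placeholders):
    # {
    #   "params": {
    #     "gen_model": "gpt2", "judge_model": "gpt2",
    #     "ptrs": {"msgs": 0, "bases": 0, "deltas": 0},
    #     "values": {"msgs": [["random", "<base64>"], ...],
    #                "bases": [...], "deltas": [...]},
    #     "prompts": [...], "batch_size": 1, "repeat": 1, "gen": {...}
    #   },
    #   "results": [
    #     {"ptrs": {...}, "perplexity": ..., "text": "...",
    #      "msg_rate": ..., "run_time (ms)": ...}
    #   ]
    # }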

    def load_data(self, file_name: str):
        """Load a checkpoint written by save_data and decode messages back to
        bytes."""
        with open(file_name, "r") as f:
            self.data = json.load(f)
        for i in range(len(self.data["params"]["values"]["msgs"])):
            msg_type, str_msg = self.data["params"]["values"]["msgs"][i]
            if msg_type == "random":
                msg = base64.b64decode(str_msg)
            else:
                msg = str_msg.encode("ascii")
            self.data["params"]["values"]["msgs"][i] = [msg_type, msg]
        # If results already exist, the combination at the saved pointers was
        # processed before the checkpoint was written; skip it on resume.
        self.skip_first = len(self.data["results"]) > 0
        self.__pbar = None


def create_args():
    parser = ArgumentParser()
    # Messages
    parser.add_argument(
        "--msgs-file", type=str, default=None, help="Where messages are stored"
    )
    parser.add_argument(
        "--msgs-lengths",
        nargs=3,
        type=int,
        help="Range of message lengths, parsed as: <start> <end> <num>",
    )
    parser.add_argument(
        "--msgs-per-length",
        type=int,
        default=5,
        help="Number of messages per length",
    )
    # Prompts
    parser.add_argument(
        "--prompts-file",
        type=str,
        default=None,
        help="Where prompts are stored",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=10,
        help="Number of prompts",
    )
    parser.add_argument(
        "--prompt-size",
        type=int,
        default=50,
        help="Size of prompts (in tokens)",
    )
    # Others
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="Whether to overwrite prompts and messages files",
    )
    # Hyperparameters
    parser.add_argument(
        "--gen-model",
        type=str,
        default="gpt2",
        help="Model used to generate text",
    )
    parser.add_argument(
        "--judge-model",
        type=str,
        default="gpt2",
        help="Model used to compute perplexity of generated text",
    )
    parser.add_argument(
        "--deltas",
        nargs=3,
        type=float,
        help="Range of delta values, parsed as: <start> <end> <num>",
    )
    parser.add_argument(
        "--bases",
        nargs="+",
        type=int,
        help="Bases used in base encoding",
    )
    # Generation parameters
    parser.add_argument(
        "--do-sample",
        action="store_true",
        help="Whether to sample instead of using greedy search",
    )
    parser.add_argument(
        "--num-beams", type=int, default=1, help="How many beams to use"
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1,
        help="Batch size used for generation",
    )
    # Results
    parser.add_argument(
        "--repeat",
        type=int,
        default=1,
        help="How many times to repeat each set of parameters, prompts and messages",
    )
    parser.add_argument(
        "--load-file",
        type=str,
        default=None,
        help="Where to load AnalyseProcessor data from",
    )
    parser.add_argument(
        "--save-file",
        type=str,
        default=None,
        help="Where to save AnalyseProcessor data",
    )
    parser.add_argument(
        "--save-freq", type=int, default=100, help="Save frequency"
    )
    parser.add_argument(
        "--figs-dir",
        type=str,
        default=None,
        help="Where to save figures",
    )
    return parser.parse_args()


def main(args):
    if not args.load_file:
        # Fresh run: sample prompts and messages, then build the processor.
        model, tokenizer = ModelFactory.load_model(args.gen_model)
        prompts = load_prompts(tokenizer, args.num_prompts, args.prompt_size)
        msgs_lens = []
        for i in np.linspace(
            args.msgs_lengths[0],
            args.msgs_lengths[1],
            int(args.msgs_lengths[2]),
            dtype=np.int64,
        ):
            for _ in range(args.msgs_per_length):
                msgs_lens.append(int(i))
        msgs = load_msgs(msgs_lens)
        processor = AnalyseProcessor(
            save_file=args.save_file,
            save_freq=args.save_freq,
            gen_model=args.gen_model,
            judge_model=args.judge_model,
            msgs=msgs,
            bases=args.bases,
            deltas=np.linspace(
                args.deltas[0], args.deltas[1], int(args.deltas[2])
            ).tolist(),
            prompts=prompts,
            repeat=args.repeat,
            batch_size=args.batch_size,
            gen_params=dict(
                start_pos_p=[0],
                seed_scheme="dummy_hash",
                window_length=1,
                min_new_tokens_ratio=1,
                max_new_tokens_ratio=1,
                do_sample=args.do_sample,
                num_beams=args.num_beams,
                repetition_penalty=1.0,
            ),
        )
        processor.save_data(args.save_file)
    else:
        # Resume from a checkpoint written by a previous run.
        processor = AnalyseProcessor(
            save_file=args.save_file,
            save_freq=args.save_freq,
        )
        processor.load_data(args.load_file)
    processor.run()
    if args.figs_dir:
        processor.plot(args.figs_dir)


if __name__ == "__main__":
    args = create_args()
    main(args)
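
# Example invocation (illustrative values; only flags defined in create_args
# above are used):
#
#   python analyse.py \
#       --gen-model gpt2 --judge-model gpt2 \
#       --num-prompts 10 --prompt-size 50 \
#       --msgs-lengths 4 20 5 --msgs-per-length 5 \
#       --bases 2 16 --deltas 0.5 5.0 10 \
#       --batch-size 4 --save-freq 100 \
#       --save-file results/analyse.json --figs-dir figs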