Spaces:
Sleeping
Sleeping
| import base64 | |
| import json | |
| import torch | |
| from fastapi import FastAPI | |
| from fastapi.openapi.utils import get_openapi | |
| import uvicorn | |
| from stegno import generate, decrypt | |
| from utils import load_model | |
| from seed_scheme_factory import SeedSchemeFactory | |
| from model_factory import ModelFactory | |
| from global_config import GlobalConfig | |
| from schemes import DecryptionBody, EncryptionBody | |
app = FastAPI()

# Example payloads bundled with the service.
# NOTE(review): `examples` is not referenced in the code visible here;
# presumably it is consumed elsewhere (e.g. OpenAPI examples). Confirm.
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale-dependent default encoding. The path is resolved
# relative to the current working directory.
with open("resources/examples.json", "r", encoding="utf-8") as f:
    examples = json.load(f)
async def encrypt_api(
    body: EncryptionBody,
):
    """Hide a base64-encoded message inside model-generated text.

    The message in ``body.msg`` is base64-decoded to raw bytes. The model and
    tokenizer named by ``body.gen_model`` are obtained from ``ModelFactory``.
    Everything is then handed to ``generate`` together with the remaining
    generation/encoding options carried by ``body``.

    Returns:
        A dict with the three values produced by ``generate``, under the keys
        ``"texts"``, ``"msgs_rates"`` and ``"tokens_infos"``.
    """
    # NOTE(review): no route decorator is visible for this handler in the
    # reviewed view; confirm how it is registered on `app`.
    payload = base64.b64decode(body.msg)
    gen_model, gen_tokenizer = ModelFactory.load_model(body.gen_model)

    generated = generate(
        tokenizer=gen_tokenizer,
        model=gen_model,
        prompt=body.prompt,
        msg=payload,
        # A single start position is sent, wrapped in a list.
        start_pos_p=[body.start_pos],
        delta=body.delta,
        msg_base=body.msg_base,
        seed_scheme=body.seed_scheme,
        window_length=body.window_length,
        private_key=body.private_key,
        min_new_tokens_ratio=body.min_new_tokens_ratio,
        max_new_tokens_ratio=body.max_new_tokens_ratio,
        do_sample=body.do_sample,
        num_beams=body.num_beams,
        repetition_penalty=body.repetition_penalty,
    )
    # generate() must yield exactly three values.
    texts, msgs_rates, tokens_infos = generated

    return dict(
        texts=texts,
        msgs_rates=msgs_rates,
        tokens_infos=tokens_infos,
    )
async def decrypt_api(body: DecryptionBody):
    """Recover hidden messages from ``body.text``.

    The model named by ``body.gen_model`` is loaded only to learn its device.
    The tokenizer and the decoding options from ``body`` are passed to
    ``decrypt``.

    Returns:
        A dict that maps each position in the sequence returned by ``decrypt``
        to a list holding ``base64.b64encode`` of every message at that
        position. The values are ``bytes``.
    """
    gen_model, gen_tokenizer = ModelFactory.load_model(body.gen_model)

    recovered = decrypt(
        tokenizer=gen_tokenizer,
        device=gen_model.device,
        text=body.text,
        msg_base=body.msg_base,
        seed_scheme=body.seed_scheme,
        window_length=body.window_length,
        private_key=body.private_key,
    )

    # NOTE(review): values are bytes; presumably FastAPI's encoder turns them
    # into strings in the JSON response. Confirm against the framework.
    return {
        idx: [base64.b64encode(candidate) for candidate in candidates]
        for idx, candidates in enumerate(recovered)
    }
async def default_config():
    """Report the default request parameters and the available choices.

    Every default is read from the ``"encrypt.default"`` section of
    ``GlobalConfig``. This includes the ``"decrypt"`` defaults, which reuse the
    encryption values for the keys the two operations share.

    Returns:
        A dict with these entries:

        - ``"default"``: nested ``"encrypt"`` and ``"decrypt"`` dicts.
        - ``"seed_schemes"``: from ``SeedSchemeFactory.get_schemes_name()``.
        - ``"models"``: from ``ModelFactory.get_models_names()``.
    """
    section = "encrypt.default"
    encrypt_keys = (
        "gen_model",
        "start_pos",
        "delta",
        "msg_base",
        "seed_scheme",
        "window_length",
        "private_key",
        "min_new_tokens_ratio",
        "max_new_tokens_ratio",
        "do_sample",
        "num_beams",
        "repetition_penalty",
    )
    # NOTE(review): decrypt defaults come from the *encrypt* section as well.
    # Presumably intentional, since decoding must mirror encoding settings.
    # Confirm.
    decrypt_keys = (
        "gen_model",
        "msg_base",
        "seed_scheme",
        "window_length",
        "private_key",
    )

    encrypt_defaults = {key: GlobalConfig.get(section, key) for key in encrypt_keys}
    decrypt_defaults = {key: GlobalConfig.get(section, key) for key in decrypt_keys}

    return {
        "default": {
            "encrypt": encrypt_defaults,
            "decrypt": decrypt_defaults,
        },
        "seed_schemes": SeedSchemeFactory.get_schemes_name(),
        "models": ModelFactory.get_models_names(),
    }
if __name__ == "__main__":
    # Server settings may be absent from GlobalConfig (None). Coerce present
    # values to the types uvicorn expects; otherwise use fixed fallbacks.
    # This mostly exists to satisfy the linter's type checks.
    host = GlobalConfig.get("server", "host")
    if host is None:
        host = "0.0.0.0"
    else:
        host = str(host)

    port = GlobalConfig.get("server", "port")
    if port is None:
        port = 8000
    else:
        port = int(port)

    workers = GlobalConfig.get("server", "workers")
    if workers is None:
        workers = 1
    else:
        workers = int(workers)

    # The app is passed as an import string ("module:attribute").
    uvicorn.run("api:app", host=host, port=port, workers=workers)