Spaces:
Sleeping
Sleeping
Add parameters to RestAPI
Browse files- api.py +3 -1
- schemes.py +5 -1
- stegno.py +3 -2
api.py
CHANGED
|
@@ -45,14 +45,16 @@ async def encrypt_api(
|
|
| 45 |
seed_scheme=body.seed_scheme,
|
| 46 |
window_length=body.window_length,
|
| 47 |
private_key=body.private_key,
|
|
|
|
| 48 |
max_new_tokens_ratio=body.max_new_tokens_ratio,
|
|
|
|
| 49 |
num_beams=body.num_beams,
|
| 50 |
repetition_penalty=body.repetition_penalty,
|
| 51 |
)
|
| 52 |
return {
|
| 53 |
"texts": texts,
|
| 54 |
"msgs_rates": msgs_rates,
|
| 55 |
-
"
|
| 56 |
}
|
| 57 |
|
| 58 |
|
|
|
|
| 45 |
seed_scheme=body.seed_scheme,
|
| 46 |
window_length=body.window_length,
|
| 47 |
private_key=body.private_key,
|
| 48 |
+
min_new_tokens_ratio=body.min_new_tokens_ratio,
|
| 49 |
max_new_tokens_ratio=body.max_new_tokens_ratio,
|
| 50 |
+
do_sample=body.do_sample,
|
| 51 |
num_beams=body.num_beams,
|
| 52 |
repetition_penalty=body.repetition_penalty,
|
| 53 |
)
|
| 54 |
return {
|
| 55 |
"texts": texts,
|
| 56 |
"msgs_rates": msgs_rates,
|
| 57 |
+
"tokens_infos": tokens_infos,
|
| 58 |
}
|
| 59 |
|
| 60 |
|
schemes.py
CHANGED
|
@@ -49,7 +49,7 @@ class EncryptionBody(BaseModel):
|
|
| 49 |
title="Private key used to compute the seed for PRF",
|
| 50 |
ge=0,
|
| 51 |
)
|
| 52 |
-
|
| 53 |
default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
|
| 54 |
title="Min length of generated text compared to the minimum length required to hide the message",
|
| 55 |
ge=1,
|
|
@@ -64,6 +64,10 @@ class EncryptionBody(BaseModel):
|
|
| 64 |
title="Number of beams used in beam search",
|
| 65 |
ge=1,
|
| 66 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
repetition_penalty: float = Field(
|
| 69 |
default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
|
|
|
|
| 49 |
title="Private key used to compute the seed for PRF",
|
| 50 |
ge=0,
|
| 51 |
)
|
| 52 |
+
min_new_tokens_ratio: float = Field(
|
| 53 |
default=GlobalConfig.get("encrypt.default", "min_new_tokens_ratio"),
|
| 54 |
title="Min length of generated text compared to the minimum length required to hide the message",
|
| 55 |
ge=1,
|
|
|
|
| 64 |
title="Number of beams used in beam search",
|
| 65 |
ge=1,
|
| 66 |
)
|
| 67 |
+
do_sample: bool = Field(
|
| 68 |
+
default=GlobalConfig.get("encrypt.default", "do_sample"),
|
| 69 |
+
title="Whether to use greedy or sampling generating"
|
| 70 |
+
)
|
| 71 |
|
| 72 |
repetition_penalty: float = Field(
|
| 73 |
default=GlobalConfig.get("encrypt.default", "repetition_penalty"),
|
stegno.py
CHANGED
|
@@ -78,18 +78,19 @@ def generate(
|
|
| 78 |
salt_key=salt_key,
|
| 79 |
private_key=private_key,
|
| 80 |
)
|
| 81 |
-
min_length = (
|
| 82 |
prompt_size
|
| 83 |
+ start_pos
|
| 84 |
+ logits_processor.get_message_len() * min_new_tokens_ratio
|
| 85 |
)
|
| 86 |
-
max_length = (
|
| 87 |
prompt_size
|
| 88 |
+ start_pos
|
| 89 |
+ logits_processor.get_message_len() * max_new_tokens_ratio
|
| 90 |
)
|
| 91 |
max_length = min(max_length, tokenizer.model_max_length)
|
| 92 |
min_length = min(min_length, max_length)
|
|
|
|
| 93 |
output_tokens = model.generate(
|
| 94 |
**tokenized_input,
|
| 95 |
logits_processor=transformers.LogitsProcessorList([logits_processor]),
|
|
|
|
| 78 |
salt_key=salt_key,
|
| 79 |
private_key=private_key,
|
| 80 |
)
|
| 81 |
+
min_length = int(
|
| 82 |
prompt_size
|
| 83 |
+ start_pos
|
| 84 |
+ logits_processor.get_message_len() * min_new_tokens_ratio
|
| 85 |
)
|
| 86 |
+
max_length = int(
|
| 87 |
prompt_size
|
| 88 |
+ start_pos
|
| 89 |
+ logits_processor.get_message_len() * max_new_tokens_ratio
|
| 90 |
)
|
| 91 |
max_length = min(max_length, tokenizer.model_max_length)
|
| 92 |
min_length = min(min_length, max_length)
|
| 93 |
+
|
| 94 |
output_tokens = model.generate(
|
| 95 |
**tokenized_input,
|
| 96 |
logits_processor=transformers.LogitsProcessorList([logits_processor]),
|