Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- README.md +149 -0
- config.json +30 -0
- custom_generate/beam_search.py +501 -0
- custom_generate/generate.py +539 -0
- generation_config.json +13 -0
- merges.txt +0 -0
- model.safetensors +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +239 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
|
| 3 |
+
library_name: transformers
|
| 4 |
+
tags:
|
| 5 |
+
- custom_generate
|
| 6 |
+
---
|
| 7 |
+
|
| 8 |
+
## Description
|
| 9 |
+
|
| 10 |
+
[Diverse beam search](https://hf.co/papers/1610.02424) is a variant of beam search that produces more diverse output candidates to choose from. This strategy measures the dissimilarity of sequences and a penalty is applied if sequences are too similar. To avoid high computation costs, the number of beams is divided into groups.
|
| 11 |
+
|
| 12 |
+
Enable diverse beam search with the `num_beams`, `num_beam_groups` and `diversity_penalty` parameters (the `num_beams` parameter should be divisible by `num_beam_groups`).
|
| 13 |
+
|
| 14 |
+
```py
|
| 15 |
+
import torch
|
| 16 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device
|
| 17 |
+
|
| 18 |
+
device = infer_device()
|
| 19 |
+
|
| 20 |
+
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
| 21 |
+
inputs = tokenizer("Hugging Face is an open-source company", return_tensors="pt").to(device)
|
| 22 |
+
|
| 23 |
+
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=torch.float16).to(device)
|
| 24 |
+
# explicitly set to 100 because Llama2 generation length is 4096
|
| 25 |
+
outputs = model.generate(**inputs, max_new_tokens=50, num_beams=6, num_beam_groups=3, diversity_penalty=1.0, do_sample=False)
|
| 26 |
+
tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 27 |
+
'Hugging Face is an open-source company 🤗\nWe are an open-source company. Our mission is to democratize AI and make it accessible to everyone. We believe that AI should be used for the benefit of humanity, not for the benefit of a'
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
DoLa works by **contrasting the logits** from the final layer with those from earlier layers of the model,
|
| 33 |
+
amplifying factual knowledge localized in specific layers and suppressing spurious information.
|
| 34 |
+
|
| 35 |
+
This can be useful for:
|
| 36 |
+
|
| 37 |
+
* **Short-answer tasks** (e.g., TruthfulQA) — using higher layers (`dola_layers="high"`)
|
| 38 |
+
* **Long-answer reasoning tasks** (e.g., GSM8K, StrategyQA, FACTOR, VicunaQA) — using lower layers (`dola_layers="low"`)
|
| 39 |
+
|
| 40 |
+
DoLa is **not recommended for smaller models** such as GPT-2, as the improvement may be negligible.
|
| 41 |
+
|
| 42 |
+
This implementation matches the `DoLa` functionality present in `transformers<4.53.0`.
|
| 43 |
+
|
| 44 |
+
---
|
| 45 |
+
|
| 46 |
+
## Base model
|
| 47 |
+
|
| 48 |
+
* [Qwen/Qwen3-0.6B](https://huggingface.co/Qwen/Qwen3-0.6B)
|
| 49 |
+
|
| 50 |
+
---
|
| 51 |
+
|
| 52 |
+
## Model compatibility
|
| 53 |
+
|
| 54 |
+
* Decoder-only transformer models
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## Additional Arguments
|
| 59 |
+
|
| 60 |
+
* **`dola_layers`** (*str* or *List\[int]*, optional):
|
| 61 |
+
Which earlier layers to contrast with the final layer. Can be:
|
| 62 |
+
|
| 63 |
+
* `"low"` — lower half of layers (recommended for long answers)
|
| 64 |
+
* `"high"` — upper half of layers (recommended for short answers)
|
| 65 |
+
* List of integer indices (e.g., `[18, 20]`)
|
| 66 |
+
|
| 67 |
+
**Note:**
|
| 68 |
+
|
| 69 |
+
* Layer 0 is the word embedding; layer 1 is the first transformer block.
|
| 70 |
+
* If the model has tied word embeddings, layer 0 is skipped and counting starts at layer 2.
|
| 71 |
+
* Typical defaults:
|
| 72 |
+
|
| 73 |
+
| # Layers | `"low"` range | `"high"` range |
|
| 74 |
+
| -------- | ------------------- | ------------------- |
|
| 75 |
+
| > 40 | `(0, 20, 2)` | `(N - 20, N, 2)` |
|
| 76 |
+
| ≤ 40 | `range(0, N//2, 2)` | `range(N//2, N, 2)` |
|
| 77 |
+
|
| 78 |
+
* **`repetition_penalty`** (*float*, optional, defaults to `None`):
|
| 79 |
+
Helps reduce repetition. A value of `1.2` is recommended.
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## Output Type changes
|
| 84 |
+
|
| 85 |
+
* The `generate` method output remains the same as default `transformers` generation,
|
| 86 |
+
but logits are post-processed using the DoLa contrastive scoring before token selection.
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## Example usage
|
| 91 |
+
|
| 92 |
+
### Using higher layers (short-answer tasks)
|
| 93 |
+
|
| 94 |
+
```python
|
| 95 |
+
# requires `transformers>=4.56.0`, previously, it was part of the library
|
| 96 |
+
import torch
|
| 97 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device
|
| 98 |
+
|
| 99 |
+
device = infer_device()
|
| 100 |
+
|
| 101 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 102 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 103 |
+
"Qwen/Qwen3-0.6B", torch_dtype=torch.float16
|
| 104 |
+
).to(device)
|
| 105 |
+
|
| 106 |
+
inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt").to(device)
|
| 107 |
+
|
| 108 |
+
outputs = model.generate(
|
| 109 |
+
**inputs,
|
| 110 |
+
max_new_tokens=50,
|
| 111 |
+
do_sample=False,
|
| 112 |
+
custom_generate="transformers-community/dola",
|
| 113 |
+
trust_remote_code=True,
|
| 114 |
+
dola_layers="high"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
|
| 118 |
+
```
|
| 119 |
+
|
| 120 |
+
---
|
| 121 |
+
|
| 122 |
+
### Contrasting specific layers
|
| 123 |
+
|
| 124 |
+
```python
|
| 125 |
+
import torch
|
| 126 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, infer_device
|
| 127 |
+
|
| 128 |
+
device = infer_device()
|
| 129 |
+
|
| 130 |
+
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| 131 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 132 |
+
"Qwen/Qwen3-0.6B", torch_dtype=torch.float16
|
| 133 |
+
).to(device)
|
| 134 |
+
|
| 135 |
+
inputs = tokenizer("What is the highest peak in the world?", return_tensors="pt").to(device)
|
| 136 |
+
|
| 137 |
+
outputs = model.generate(
|
| 138 |
+
**inputs,
|
| 139 |
+
max_new_tokens=50,
|
| 140 |
+
do_sample=False,
|
| 141 |
+
repetition_penalty=1.2,
|
| 142 |
+
custom_generate="transformers-community/dola",
|
| 143 |
+
trust_remote_code=True,
|
| 144 |
+
dola_layers=[18, 20]
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
# Only decode the newly generated tokens
|
| 148 |
+
print(tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[-1]:], skip_special_tokens=True))
|
| 149 |
+
```
|
config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"Qwen3ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_bias": false,
|
| 6 |
+
"attention_dropout": 0.0,
|
| 7 |
+
"bos_token_id": 151643,
|
| 8 |
+
"eos_token_id": 151645,
|
| 9 |
+
"head_dim": 128,
|
| 10 |
+
"hidden_act": "silu",
|
| 11 |
+
"hidden_size": 1024,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"max_position_embeddings": 40960,
|
| 15 |
+
"max_window_layers": 28,
|
| 16 |
+
"model_type": "qwen3",
|
| 17 |
+
"num_attention_heads": 16,
|
| 18 |
+
"num_hidden_layers": 28,
|
| 19 |
+
"num_key_value_heads": 8,
|
| 20 |
+
"rms_norm_eps": 1e-06,
|
| 21 |
+
"rope_scaling": null,
|
| 22 |
+
"rope_theta": 1000000,
|
| 23 |
+
"sliding_window": null,
|
| 24 |
+
"tie_word_embeddings": true,
|
| 25 |
+
"torch_dtype": "bfloat16",
|
| 26 |
+
"transformers_version": "4.56.0",
|
| 27 |
+
"use_cache": true,
|
| 28 |
+
"use_sliding_window": false,
|
| 29 |
+
"vocab_size": 151936
|
| 30 |
+
}
|
custom_generate/beam_search.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2020 The HuggingFace Inc. team
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from collections import UserDict
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from transformers.utils import add_start_docstrings
|
| 24 |
+
|
| 25 |
+
PROCESS_INPUTS_DOCSTRING = r"""
|
| 26 |
+
Args:
|
| 27 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 28 |
+
Indices of input sequence tokens in the vocabulary.
|
| 29 |
+
|
| 30 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
| 31 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 32 |
+
|
| 33 |
+
[What are input IDs?](../glossary#input-ids)
|
| 34 |
+
next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 35 |
+
Current scores of the top `2 * num_beams` non-finished beam hypotheses.
|
| 36 |
+
next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 37 |
+
`input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
|
| 38 |
+
next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
|
| 39 |
+
Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
|
| 40 |
+
pad_token_id (`int`, *optional*):
|
| 41 |
+
The id of the *padding* token.
|
| 42 |
+
eos_token_id (`Union[int, list[int]]`, *optional*):
|
| 43 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 44 |
+
beam_indices (`torch.LongTensor`, *optional*):
|
| 45 |
+
Beam indices indicating to which beam hypothesis each token correspond.
|
| 46 |
+
group_index (`int`, *optional*):
|
| 47 |
+
The index of the group of beams. Used with [`~PreTrainedModel.group_beam_search`].
|
| 48 |
+
|
| 49 |
+
Return:
|
| 50 |
+
`UserDict`: A dictionary composed of the fields as defined above:
|
| 51 |
+
|
| 52 |
+
- **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
|
| 53 |
+
non-finished beams.
|
| 54 |
+
- **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
|
| 55 |
+
to the non-finished beam_hypotheses.
|
| 56 |
+
- **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
|
| 57 |
+
indicating to which beam the next tokens shall be added.
|
| 58 |
+
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
FINALIZE_INPUTS_DOCSTRING = r"""
|
| 62 |
+
Args:
|
| 63 |
+
input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
|
| 64 |
+
Indices of input sequence tokens in the vocabulary.
|
| 65 |
+
|
| 66 |
+
Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
|
| 67 |
+
[`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
| 68 |
+
|
| 69 |
+
[What are input IDs?](../glossary#input-ids)
|
| 70 |
+
final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 71 |
+
The final scores of all non-finished beams.
|
| 72 |
+
final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 73 |
+
The last tokens to be added to the non-finished beam_hypotheses.
|
| 74 |
+
final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
|
| 75 |
+
The beam indices indicating to which beam the `final_beam_tokens` shall be added.
|
| 76 |
+
pad_token_id (`int`, *optional*):
|
| 77 |
+
The id of the *padding* token.
|
| 78 |
+
eos_token_id (`Union[int, list[int]]`, *optional*):
|
| 79 |
+
The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
|
| 80 |
+
|
| 81 |
+
Return:
|
| 82 |
+
`torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
|
| 83 |
+
The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
|
| 84 |
+
due to the `eos_token_id`.
|
| 85 |
+
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class BeamHypotheses:
|
| 90 |
+
def __init__(self, num_beams: int, length_penalty: float, early_stopping: bool, max_length: Optional[int] = None):
|
| 91 |
+
"""
|
| 92 |
+
Initialize n-best list of hypotheses.
|
| 93 |
+
"""
|
| 94 |
+
self.length_penalty = length_penalty
|
| 95 |
+
self.early_stopping = early_stopping
|
| 96 |
+
self.max_length = max_length
|
| 97 |
+
self.num_beams = num_beams
|
| 98 |
+
self.beams = []
|
| 99 |
+
self.worst_score = 1e9
|
| 100 |
+
|
| 101 |
+
if not isinstance(self.early_stopping, bool) and self.max_length is None:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
|
| 104 |
+
" BeamScorer class instance at initialization time."
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def __len__(self):
|
| 108 |
+
"""
|
| 109 |
+
Number of hypotheses in the list.
|
| 110 |
+
"""
|
| 111 |
+
return len(self.beams)
|
| 112 |
+
|
| 113 |
+
def add(
|
| 114 |
+
self,
|
| 115 |
+
hyp: torch.LongTensor,
|
| 116 |
+
sum_logprobs: float,
|
| 117 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 118 |
+
generated_len: Optional[int] = None,
|
| 119 |
+
):
|
| 120 |
+
"""
|
| 121 |
+
Add a new hypothesis to the list.
|
| 122 |
+
"""
|
| 123 |
+
if generated_len is not None:
|
| 124 |
+
score = sum_logprobs / (generated_len**self.length_penalty)
|
| 125 |
+
# This 'else' case exists for retrocompatibility
|
| 126 |
+
else:
|
| 127 |
+
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
|
| 128 |
+
|
| 129 |
+
if len(self) < self.num_beams or score > self.worst_score:
|
| 130 |
+
self.beams.append((score, hyp, beam_indices))
|
| 131 |
+
if len(self) > self.num_beams:
|
| 132 |
+
sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
|
| 133 |
+
del self.beams[sorted_next_scores[0][1]]
|
| 134 |
+
self.worst_score = sorted_next_scores[1][0]
|
| 135 |
+
else:
|
| 136 |
+
self.worst_score = min(score, self.worst_score)
|
| 137 |
+
|
| 138 |
+
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
|
| 139 |
+
"""
|
| 140 |
+
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
|
| 141 |
+
one in the heap, then we are done with this sentence.
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
if len(self) < self.num_beams:
|
| 145 |
+
return False
|
| 146 |
+
|
| 147 |
+
# `True`: stop as soon as at least `num_beams` hypotheses are finished
|
| 148 |
+
if self.early_stopping is True:
|
| 149 |
+
return True
|
| 150 |
+
# `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
|
| 151 |
+
# when `length_penalty` is positive. See the discussion below for more details.
|
| 152 |
+
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
|
| 153 |
+
elif self.early_stopping is False:
|
| 154 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
| 155 |
+
ret = self.worst_score >= highest_attainable_score
|
| 156 |
+
return ret
|
| 157 |
+
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
|
| 158 |
+
else:
|
| 159 |
+
# `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
|
| 160 |
+
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
|
| 161 |
+
# its max this way
|
| 162 |
+
if self.length_penalty > 0.0:
|
| 163 |
+
if self.max_length <= decoder_prompt_len:
|
| 164 |
+
raise ValueError("max_length is not larger than decoder prompt length")
|
| 165 |
+
highest_attainable_score = (
|
| 166 |
+
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
|
| 167 |
+
)
|
| 168 |
+
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
|
| 169 |
+
else:
|
| 170 |
+
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
|
| 171 |
+
ret = self.worst_score >= highest_attainable_score
|
| 172 |
+
return ret
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class BeamScorer(ABC):
|
| 177 |
+
"""
|
| 178 |
+
Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
|
| 179 |
+
[`~PreTrainedModel.beam_sample`].
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
@abstractmethod
|
| 183 |
+
@add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
|
| 184 |
+
def process(
|
| 185 |
+
self,
|
| 186 |
+
input_ids: torch.LongTensor,
|
| 187 |
+
next_scores: torch.FloatTensor,
|
| 188 |
+
next_tokens: torch.LongTensor,
|
| 189 |
+
next_indices: torch.LongTensor,
|
| 190 |
+
**kwargs,
|
| 191 |
+
) -> tuple[torch.Tensor]:
|
| 192 |
+
raise NotImplementedError("This is an abstract method.")
|
| 193 |
+
|
| 194 |
+
@abstractmethod
|
| 195 |
+
@add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
|
| 196 |
+
def finalize(
|
| 197 |
+
self,
|
| 198 |
+
input_ids: torch.LongTensor,
|
| 199 |
+
next_scores: torch.FloatTensor,
|
| 200 |
+
next_tokens: torch.LongTensor,
|
| 201 |
+
next_indices: torch.LongTensor,
|
| 202 |
+
max_length: int,
|
| 203 |
+
**kwargs,
|
| 204 |
+
) -> torch.LongTensor:
|
| 205 |
+
raise NotImplementedError("This is an abstract method.")
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class BeamSearchScorer(BeamScorer):
|
| 209 |
+
r"""
|
| 210 |
+
[`BeamScorer`] implementing standard beam search decoding.
|
| 211 |
+
|
| 212 |
+
Adapted in part from [Facebook's XLM beam search
|
| 213 |
+
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
|
| 214 |
+
|
| 215 |
+
Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
|
| 216 |
+
implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
batch_size (`int`):
|
| 220 |
+
Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
|
| 221 |
+
num_beams (`int`):
|
| 222 |
+
Number of beams for beam search.
|
| 223 |
+
device (`torch.device`):
|
| 224 |
+
Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
|
| 225 |
+
allocated.
|
| 226 |
+
length_penalty (`float`, *optional*, defaults to 1.0):
|
| 227 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
|
| 228 |
+
the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
|
| 229 |
+
likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
|
| 230 |
+
`length_penalty` < 0.0 encourages shorter sequences.
|
| 231 |
+
do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
|
| 232 |
+
Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
|
| 233 |
+
`True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
|
| 234 |
+
heuristic is applied and the generation stops when is it very unlikely to find better candidates;
|
| 235 |
+
`"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
|
| 236 |
+
beam search algorithm).
|
| 237 |
+
num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
|
| 238 |
+
The number of beam hypotheses that shall be returned upon calling
|
| 239 |
+
[`~transformers.BeamSearchScorer.finalize`].
|
| 240 |
+
num_beam_groups (`int`, *optional*, defaults to 1):
|
| 241 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
| 242 |
+
See [this paper](https://huggingface.co/papers/1610.02424) for more details.
|
| 243 |
+
max_length (`int`, *optional*):
|
| 244 |
+
The maximum length of the sequence to be generated.
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(
|
| 248 |
+
self,
|
| 249 |
+
batch_size: int,
|
| 250 |
+
num_beams: int,
|
| 251 |
+
device: torch.device,
|
| 252 |
+
length_penalty: Optional[float] = 1.0,
|
| 253 |
+
do_early_stopping: Optional[Union[bool, str]] = False,
|
| 254 |
+
num_beam_hyps_to_keep: Optional[int] = 1,
|
| 255 |
+
num_beam_groups: Optional[int] = 1,
|
| 256 |
+
max_length: Optional[int] = None,
|
| 257 |
+
):
|
| 258 |
+
self.num_beams = num_beams
|
| 259 |
+
self.device = device
|
| 260 |
+
self.length_penalty = length_penalty
|
| 261 |
+
self.do_early_stopping = do_early_stopping
|
| 262 |
+
self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
|
| 263 |
+
self.num_beam_groups = num_beam_groups
|
| 264 |
+
self.group_size = self.num_beams // self.num_beam_groups
|
| 265 |
+
|
| 266 |
+
self._is_init = False
|
| 267 |
+
# self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
|
| 268 |
+
# If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
|
| 269 |
+
self._beam_hyps = [
|
| 270 |
+
BeamHypotheses(
|
| 271 |
+
num_beams=self.group_size,
|
| 272 |
+
length_penalty=self.length_penalty,
|
| 273 |
+
early_stopping=self.do_early_stopping,
|
| 274 |
+
max_length=max_length,
|
| 275 |
+
)
|
| 276 |
+
for _ in range(batch_size * self.num_beam_groups)
|
| 277 |
+
]
|
| 278 |
+
# self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
|
| 279 |
+
# in the i-th mini-batch is complete.
|
| 280 |
+
self._done = torch.tensor(
|
| 281 |
+
[False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
if not isinstance(num_beams, int) or num_beams <= 1:
|
| 285 |
+
raise ValueError(
|
| 286 |
+
f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
|
| 287 |
+
" one should make use of `greedy_search` instead."
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
|
| 291 |
+
raise ValueError(
|
| 292 |
+
"`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
|
| 293 |
+
f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
@property
|
| 297 |
+
def is_done(self) -> bool:
|
| 298 |
+
return self._done.all()
|
| 299 |
+
|
| 300 |
+
def process(
|
| 301 |
+
self,
|
| 302 |
+
input_ids: torch.LongTensor,
|
| 303 |
+
next_scores: torch.FloatTensor,
|
| 304 |
+
next_tokens: torch.LongTensor,
|
| 305 |
+
next_indices: torch.LongTensor,
|
| 306 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 307 |
+
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
| 308 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 309 |
+
group_index: Optional[int] = 0,
|
| 310 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 311 |
+
) -> dict[str, torch.Tensor]:
|
| 312 |
+
# add up to the length which the next_scores is calculated on (including decoder prompt)
|
| 313 |
+
cur_len = input_ids.shape[-1] + 1
|
| 314 |
+
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
| 315 |
+
|
| 316 |
+
if batch_size != (input_ids.shape[0] // self.group_size):
|
| 317 |
+
if self.num_beam_groups > 1:
|
| 318 |
+
raise ValueError(
|
| 319 |
+
f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
|
| 320 |
+
f"size of {self.group_size} is expected by the beam scorer."
|
| 321 |
+
)
|
| 322 |
+
else:
|
| 323 |
+
raise ValueError(
|
| 324 |
+
f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
|
| 325 |
+
f"{self.group_size} is expected by the beam scorer."
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
device = input_ids.device
|
| 329 |
+
next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
|
| 330 |
+
next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
|
| 331 |
+
next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
|
| 332 |
+
|
| 333 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 334 |
+
if isinstance(eos_token_id, int):
|
| 335 |
+
eos_token_id = [eos_token_id]
|
| 336 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 337 |
+
|
| 338 |
+
for batch_idx in range(batch_size):
|
| 339 |
+
batch_group_idx = batch_idx * self.num_beam_groups + group_index
|
| 340 |
+
if self._done[batch_group_idx]:
|
| 341 |
+
if self.num_beams < len(self._beam_hyps[batch_group_idx]):
|
| 342 |
+
raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
|
| 343 |
+
if eos_token_id is None or pad_token_id is None:
|
| 344 |
+
raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
|
| 345 |
+
# pad the batch
|
| 346 |
+
next_beam_scores[batch_idx, :] = 0
|
| 347 |
+
next_beam_tokens[batch_idx, :] = pad_token_id
|
| 348 |
+
next_beam_indices[batch_idx, :] = 0
|
| 349 |
+
continue
|
| 350 |
+
|
| 351 |
+
# next tokens for this sentence
|
| 352 |
+
beam_idx = 0
|
| 353 |
+
for beam_token_rank, (next_token, next_score, next_index) in enumerate(
|
| 354 |
+
zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
|
| 355 |
+
):
|
| 356 |
+
batch_beam_idx = batch_idx * self.group_size + next_index
|
| 357 |
+
# add to generated hypotheses if end of sentence
|
| 358 |
+
if (eos_token_id is not None) and (next_token.item() in eos_token_id):
|
| 359 |
+
# if beam_token does not belong to top num_beams tokens, it should not be added
|
| 360 |
+
is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
|
| 361 |
+
if is_beam_token_worse_than_top_num_beams:
|
| 362 |
+
continue
|
| 363 |
+
if beam_indices is not None:
|
| 364 |
+
beam_index = beam_indices[batch_beam_idx]
|
| 365 |
+
beam_index = beam_index + (batch_beam_idx,)
|
| 366 |
+
else:
|
| 367 |
+
beam_index = None
|
| 368 |
+
|
| 369 |
+
self._beam_hyps[batch_group_idx].add(
|
| 370 |
+
input_ids[batch_beam_idx].clone(),
|
| 371 |
+
next_score.item(),
|
| 372 |
+
beam_indices=beam_index,
|
| 373 |
+
generated_len=cur_len - decoder_prompt_len,
|
| 374 |
+
)
|
| 375 |
+
else:
|
| 376 |
+
# add next predicted token since it is not eos_token
|
| 377 |
+
next_beam_scores[batch_idx, beam_idx] = next_score
|
| 378 |
+
next_beam_tokens[batch_idx, beam_idx] = next_token
|
| 379 |
+
next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
|
| 380 |
+
beam_idx += 1
|
| 381 |
+
|
| 382 |
+
# once the beam for next step is full, don't add more tokens to it.
|
| 383 |
+
if beam_idx == self.group_size:
|
| 384 |
+
break
|
| 385 |
+
|
| 386 |
+
if beam_idx < self.group_size:
|
| 387 |
+
raise ValueError(
|
| 388 |
+
f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
|
| 389 |
+
f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
# Check if we are done so that we can save a pad step if all(done)
|
| 393 |
+
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
|
| 394 |
+
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
return UserDict(
|
| 398 |
+
{
|
| 399 |
+
"next_beam_scores": next_beam_scores.view(-1),
|
| 400 |
+
"next_beam_tokens": next_beam_tokens.view(-1),
|
| 401 |
+
"next_beam_indices": next_beam_indices.view(-1),
|
| 402 |
+
}
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
def finalize(
|
| 406 |
+
self,
|
| 407 |
+
input_ids: torch.LongTensor,
|
| 408 |
+
final_beam_scores: torch.FloatTensor,
|
| 409 |
+
final_beam_tokens: torch.LongTensor,
|
| 410 |
+
final_beam_indices: torch.LongTensor,
|
| 411 |
+
max_length: int,
|
| 412 |
+
pad_token_id: Optional[Union[int, torch.Tensor]] = None,
|
| 413 |
+
eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
|
| 414 |
+
beam_indices: Optional[torch.LongTensor] = None,
|
| 415 |
+
decoder_prompt_len: Optional[int] = 0,
|
| 416 |
+
) -> tuple[torch.LongTensor]:
|
| 417 |
+
batch_size = len(self._beam_hyps) // self.num_beam_groups
|
| 418 |
+
|
| 419 |
+
if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
|
| 420 |
+
if isinstance(eos_token_id, int):
|
| 421 |
+
eos_token_id = [eos_token_id]
|
| 422 |
+
eos_token_id = torch.tensor(eos_token_id)
|
| 423 |
+
|
| 424 |
+
# finalize all open beam hypotheses and add to generated hypotheses
|
| 425 |
+
for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
|
| 426 |
+
if self._done[batch_group_idx]:
|
| 427 |
+
continue
|
| 428 |
+
|
| 429 |
+
# all open beam hypotheses are added to the beam hypothesis
|
| 430 |
+
# beam hypothesis class automatically keeps the best beams
|
| 431 |
+
for index_per_group in range(self.group_size):
|
| 432 |
+
batch_beam_idx = batch_group_idx * self.group_size + index_per_group
|
| 433 |
+
final_score = final_beam_scores[batch_beam_idx].item()
|
| 434 |
+
final_tokens = input_ids[batch_beam_idx]
|
| 435 |
+
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
|
| 436 |
+
generated_len = final_tokens.shape[-1] - decoder_prompt_len
|
| 437 |
+
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
|
| 438 |
+
|
| 439 |
+
# select the best hypotheses
|
| 440 |
+
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
|
| 441 |
+
best = []
|
| 442 |
+
best_indices = []
|
| 443 |
+
best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
|
| 444 |
+
|
| 445 |
+
# retrieve best hypotheses
|
| 446 |
+
for i in range(batch_size):
|
| 447 |
+
beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
|
| 448 |
+
candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
|
| 449 |
+
sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
|
| 450 |
+
for j in range(self.num_beam_hyps_to_keep):
|
| 451 |
+
best_hyp_tuple = sorted_hyps.pop()
|
| 452 |
+
best_score = best_hyp_tuple[0]
|
| 453 |
+
best_hyp = best_hyp_tuple[1]
|
| 454 |
+
best_index = best_hyp_tuple[2]
|
| 455 |
+
sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
|
| 456 |
+
|
| 457 |
+
# append hyp to lists
|
| 458 |
+
best.append(best_hyp)
|
| 459 |
+
|
| 460 |
+
# append indices to list
|
| 461 |
+
best_indices.append(best_index)
|
| 462 |
+
|
| 463 |
+
best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
|
| 464 |
+
|
| 465 |
+
# prepare for adding eos
|
| 466 |
+
sent_lengths_max = sent_lengths.max().item() + 1
|
| 467 |
+
sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
|
| 468 |
+
decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 469 |
+
|
| 470 |
+
if len(best_indices) > 0 and best_indices[0] is not None:
|
| 471 |
+
indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
|
| 472 |
+
else:
|
| 473 |
+
indices = None
|
| 474 |
+
|
| 475 |
+
# shorter batches are padded if needed
|
| 476 |
+
if sent_lengths.min().item() != sent_lengths.max().item():
|
| 477 |
+
if pad_token_id is None:
|
| 478 |
+
raise ValueError("`pad_token_id` has to be defined")
|
| 479 |
+
decoded.fill_(pad_token_id)
|
| 480 |
+
|
| 481 |
+
if indices is not None:
|
| 482 |
+
indices.fill_(-1)
|
| 483 |
+
|
| 484 |
+
# fill with hypotheses and eos_token_id if the latter fits in
|
| 485 |
+
for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
|
| 486 |
+
decoded[i, : sent_lengths[i]] = hypo
|
| 487 |
+
|
| 488 |
+
if indices is not None:
|
| 489 |
+
indices[i, : len(best_idx)] = torch.tensor(best_idx)
|
| 490 |
+
|
| 491 |
+
if sent_lengths[i] < sent_max_len:
|
| 492 |
+
# inserting only the first eos_token_id
|
| 493 |
+
decoded[i, sent_lengths[i]] = eos_token_id[0]
|
| 494 |
+
|
| 495 |
+
return UserDict(
|
| 496 |
+
{
|
| 497 |
+
"sequences": decoded,
|
| 498 |
+
"sequence_scores": best_scores,
|
| 499 |
+
"beam_indices": indices,
|
| 500 |
+
}
|
| 501 |
+
)
|
custom_generate/generate.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import LogitsProcessor, LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
| 3 |
+
from transformers.generation.utils import (
|
| 4 |
+
GenerationMixin,
|
| 5 |
+
GenerateBeamDecoderOnlyOutput,
|
| 6 |
+
GenerateBeamEncoderDecoderOutput,
|
| 7 |
+
)
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from .beam_search import BeamSearchScorer
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class HammingDiversityLogitsProcessor(LogitsProcessor):
|
| 16 |
+
r"""
|
| 17 |
+
[`LogitsProcessor`] that enforces diverse beam search.
|
| 18 |
+
|
| 19 |
+
Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam
|
| 20 |
+
Search: Decoding Diverse Solutions from Neural Sequence Models](https://huggingface.co/papers/1610.02424) for more
|
| 21 |
+
details.
|
| 22 |
+
|
| 23 |
+
Traditional beam search often generates very similar sequences across different beams.
|
| 24 |
+
`HammingDiversityLogitsProcessor` addresses this by penalizing beams that generate tokens already chosen by other
|
| 25 |
+
beams in the same time step.
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
diversity_penalty (`float`):
|
| 29 |
+
This value is subtracted from a beam's score if it generates a token same as any beam from other group at a
|
| 30 |
+
particular time. A higher `diversity_penalty` will enforce greater diversity among the beams. Adjusting
|
| 31 |
+
this value can help strike a balance between diversity and natural likelihood.
|
| 32 |
+
num_beams (`int`):
|
| 33 |
+
Number of beams for beam search. 1 means no beam search.
|
| 34 |
+
num_beam_groups (`int`):
|
| 35 |
+
Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
|
| 36 |
+
[this paper](https://huggingface.co/papers/1610.02424) for more details.
|
| 37 |
+
|
| 38 |
+
Examples:
|
| 39 |
+
|
| 40 |
+
```python
|
| 41 |
+
>>> from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 42 |
+
>>> import torch
|
| 43 |
+
|
| 44 |
+
>>> # Initialize the model and tokenizer
|
| 45 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-base")
|
| 46 |
+
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")
|
| 47 |
+
|
| 48 |
+
>>> # A long text about the solar system
|
| 49 |
+
>>> text = (
|
| 50 |
+
... "The Solar System is a gravitationally bound system comprising the Sun and the objects that orbit it, "
|
| 51 |
+
... "either directly or indirectly. Of the objects that orbit the Sun directly, the largest are the eight "
|
| 52 |
+
... "planets, with the remainder being smaller objects, such as the five dwarf planets and small Solar System "
|
| 53 |
+
... "bodies. The Solar System formed 4.6 billion years ago from the gravitational collapse of a giant "
|
| 54 |
+
... "interstellar molecular cloud."
|
| 55 |
+
... )
|
| 56 |
+
>>> inputs = tokenizer("summarize: " + text, return_tensors="pt")
|
| 57 |
+
|
| 58 |
+
>>> # Generate diverse summary
|
| 59 |
+
>>> outputs_diverse = model.generate(
|
| 60 |
+
... **inputs,
|
| 61 |
+
... num_beam_groups=2,
|
| 62 |
+
... diversity_penalty=10.0,
|
| 63 |
+
... max_length=100,
|
| 64 |
+
... num_beams=4,
|
| 65 |
+
... num_return_sequences=2,
|
| 66 |
+
... )
|
| 67 |
+
>>> summaries_diverse = tokenizer.batch_decode(outputs_diverse, skip_special_tokens=True)
|
| 68 |
+
|
| 69 |
+
>>> # Generate non-diverse summary
|
| 70 |
+
>>> outputs_non_diverse = model.generate(
|
| 71 |
+
... **inputs,
|
| 72 |
+
... max_length=100,
|
| 73 |
+
... num_beams=4,
|
| 74 |
+
... num_return_sequences=2,
|
| 75 |
+
... )
|
| 76 |
+
>>> summary_non_diverse = tokenizer.batch_decode(outputs_non_diverse, skip_special_tokens=True)
|
| 77 |
+
|
| 78 |
+
>>> # With `diversity_penalty`, the resulting beams are much more diverse
|
| 79 |
+
>>> print(summary_non_diverse)
|
| 80 |
+
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
|
| 81 |
+
'the Solar System formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.']
|
| 82 |
+
|
| 83 |
+
>>> print(summaries_diverse)
|
| 84 |
+
['the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets.',
|
| 85 |
+
'the solar system formed 4.6 billion years ago from the collapse of a giant interstellar molecular cloud. of the objects that orbit the Sun directly, the largest are the eight planets. the rest of the objects are smaller objects, such as the five dwarf planets and small solar system bodies.']
|
| 86 |
+
```
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int):
|
| 90 |
+
if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
|
| 91 |
+
raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
|
| 92 |
+
self._diversity_penalty = diversity_penalty
|
| 93 |
+
if not isinstance(num_beams, int) or num_beams < 2:
|
| 94 |
+
raise ValueError("`num_beams` should be an integer strictly larger than 1.")
|
| 95 |
+
self._num_beams = num_beams
|
| 96 |
+
if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
|
| 97 |
+
raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
|
| 98 |
+
if num_beam_groups > num_beams:
|
| 99 |
+
raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.")
|
| 100 |
+
self._num_sub_beams = num_beams // num_beam_groups
|
| 101 |
+
|
| 102 |
+
def __call__(
|
| 103 |
+
self,
|
| 104 |
+
input_ids: torch.LongTensor,
|
| 105 |
+
scores: torch.FloatTensor,
|
| 106 |
+
current_tokens: torch.LongTensor,
|
| 107 |
+
beam_group_idx: int,
|
| 108 |
+
) -> torch.FloatTensor:
|
| 109 |
+
r"""
|
| 110 |
+
Args:
|
| 111 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 112 |
+
Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
|
| 113 |
+
scores (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
|
| 114 |
+
Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
|
| 115 |
+
beam search or log softmax for each vocabulary token when using beam search
|
| 116 |
+
current_tokens (`torch.LongTensor` of shape `(batch_size)`):
|
| 117 |
+
Indices of input sequence tokens in the vocabulary, corresponding to the tokens selected by the other
|
| 118 |
+
beam groups in the current generation step.
|
| 119 |
+
beam_group_idx (`int`):
|
| 120 |
+
The index of the beam group currently being processed.
|
| 121 |
+
|
| 122 |
+
Return:
|
| 123 |
+
`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`:
|
| 124 |
+
The processed prediction scores.
|
| 125 |
+
"""
|
| 126 |
+
# hamming diversity: penalise using same token in current group which was used in previous groups at
|
| 127 |
+
# the same time step
|
| 128 |
+
batch_size = current_tokens.shape[0] // self._num_beams
|
| 129 |
+
group_start_idx = beam_group_idx * self._num_sub_beams
|
| 130 |
+
group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
|
| 131 |
+
group_size = group_end_idx - group_start_idx
|
| 132 |
+
vocab_size = scores.shape[-1]
|
| 133 |
+
|
| 134 |
+
if group_start_idx == 0:
|
| 135 |
+
return scores
|
| 136 |
+
|
| 137 |
+
scores_processed = scores.clone()
|
| 138 |
+
for batch_idx in range(batch_size):
|
| 139 |
+
# predicted tokens of last time step of previous groups
|
| 140 |
+
previous_group_tokens = current_tokens[
|
| 141 |
+
batch_idx * self._num_beams : batch_idx * self._num_beams + group_start_idx
|
| 142 |
+
]
|
| 143 |
+
token_frequency = torch.bincount(previous_group_tokens, minlength=vocab_size).to(scores.device)
|
| 144 |
+
scores_processed[batch_idx * group_size : (batch_idx + 1) * group_size] -= (
|
| 145 |
+
self._diversity_penalty * token_frequency
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return scores_processed
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _group_beam_search(
|
| 152 |
+
model,
|
| 153 |
+
input_ids: torch.LongTensor,
|
| 154 |
+
logits_processor: LogitsProcessorList,
|
| 155 |
+
stopping_criteria: StoppingCriteriaList,
|
| 156 |
+
generation_config: GenerationConfig,
|
| 157 |
+
synced_gpus: bool,
|
| 158 |
+
**model_kwargs,
|
| 159 |
+
):
|
| 160 |
+
r"""
|
| 161 |
+
Generates sequences of token ids for models with a language modeling head using **diverse beam search
|
| 162 |
+
decoding** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
|
| 163 |
+
|
| 164 |
+
Parameters:
|
| 165 |
+
input_ids (`torch.LongTensor` of shape `(batch_size*num_beams, sequence_length)`):
|
| 166 |
+
The sequence used as a prompt for the generation.
|
| 167 |
+
logits_processor (`LogitsProcessorList`):
|
| 168 |
+
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
|
| 169 |
+
used to modify the prediction scores of the language modeling head applied at each generation step.
|
| 170 |
+
stopping_criteria (`StoppingCriteriaList`):
|
| 171 |
+
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
|
| 172 |
+
used to tell if the generation loop should stop.
|
| 173 |
+
generation_config ([`~generation.GenerationConfig`]):
|
| 174 |
+
The generation configuration to be used as parametrization of the decoding method.
|
| 175 |
+
synced_gpus (`bool`):
|
| 176 |
+
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
|
| 177 |
+
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
|
| 178 |
+
model_kwargs:
|
| 179 |
+
Additional model specific kwargs that will be forwarded to the `forward` function of the model. If
|
| 180 |
+
model is an encoder-decoder model the kwargs should include `encoder_outputs`.
|
| 181 |
+
|
| 182 |
+
Return:
|
| 183 |
+
[`~generation.GenerateBeamDecoderOnlyOutput`], [`~generation.GenerateBeamEncoderDecoderOutput`] or
|
| 184 |
+
`torch.LongTensor`: A `torch.LongTensor` containing the generated tokens (default behaviour) or a
|
| 185 |
+
[`~generation.GenerateBeamDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
|
| 186 |
+
`return_dict_in_generate=True` or a [`~generation.GenerateBeamEncoderDecoderOutput`] if
|
| 187 |
+
`model.config.is_encoder_decoder=True`.
|
| 188 |
+
"""
|
| 189 |
+
# check parameters
|
| 190 |
+
assert (
|
| 191 |
+
generation_config.diversity_penalty != 0.0 and generation_config.num_beam_groups != 1
|
| 192 |
+
), "Group beam search requires diversity_penalty > 0.0 and num_beam_groups > 1"
|
| 193 |
+
if generation_config.do_sample is True:
|
| 194 |
+
raise ValueError("Group beam search requires `do_sample` to be set to `False`")
|
| 195 |
+
if generation_config.num_beams % generation_config.num_beam_groups != 0:
|
| 196 |
+
raise ValueError("Group beam search requires `num_beams` to be divisible by `num_beam_groups`")
|
| 197 |
+
if generation_config.diversity_penalty == 0.0:
|
| 198 |
+
raise ValueError("Group beam search requires `diversity_penalty` to be greater than `0.0`, otherwise your groups will be identical.")
|
| 199 |
+
|
| 200 |
+
if generation_config.diversity_penalty is not None and generation_config.diversity_penalty > 0.0:
|
| 201 |
+
logits_processor.append(
|
| 202 |
+
HammingDiversityLogitsProcessor(
|
| 203 |
+
diversity_penalty=generation_config.diversity_penalty,
|
| 204 |
+
num_beams=generation_config.num_beams,
|
| 205 |
+
num_beam_groups=generation_config.num_beam_groups,
|
| 206 |
+
)
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# define beam scorer
|
| 210 |
+
beam_scorer = BeamSearchScorer(
|
| 211 |
+
batch_size=input_ids.shape[0],
|
| 212 |
+
num_beams=generation_config.num_beams,
|
| 213 |
+
device=input_ids.device,
|
| 214 |
+
length_penalty=generation_config.length_penalty,
|
| 215 |
+
do_early_stopping=generation_config.early_stopping,
|
| 216 |
+
num_beam_hyps_to_keep=generation_config.num_return_sequences,
|
| 217 |
+
num_beam_groups=generation_config.num_beam_groups,
|
| 218 |
+
max_length=generation_config.max_length,
|
| 219 |
+
)
|
| 220 |
+
# init values
|
| 221 |
+
pad_token_id = generation_config._pad_token_tensor
|
| 222 |
+
eos_token_id = generation_config._eos_token_tensor
|
| 223 |
+
output_attentions = generation_config.output_attentions
|
| 224 |
+
output_hidden_states = generation_config.output_hidden_states
|
| 225 |
+
output_scores = generation_config.output_scores
|
| 226 |
+
output_logits = generation_config.output_logits
|
| 227 |
+
return_dict_in_generate = generation_config.return_dict_in_generate
|
| 228 |
+
|
| 229 |
+
num_beams = beam_scorer.num_beams
|
| 230 |
+
num_beam_groups = beam_scorer.num_beam_groups
|
| 231 |
+
num_sub_beams = num_beams // num_beam_groups
|
| 232 |
+
batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
|
| 233 |
+
device = input_ids.device
|
| 234 |
+
|
| 235 |
+
batch_beam_size, cur_len = input_ids.shape
|
| 236 |
+
model_kwargs = model._get_initial_cache_position(
|
| 237 |
+
cur_len, input_ids.device, model_kwargs
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
if return_dict_in_generate and output_scores:
|
| 241 |
+
beam_indices = [
|
| 242 |
+
tuple(() for _ in range(num_sub_beams * batch_size))
|
| 243 |
+
for _ in range(num_beam_groups)
|
| 244 |
+
]
|
| 245 |
+
else:
|
| 246 |
+
beam_indices = None
|
| 247 |
+
|
| 248 |
+
if num_beams * batch_size != batch_beam_size:
|
| 249 |
+
raise ValueError(
|
| 250 |
+
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# init attention / hidden states / scores tuples
|
| 254 |
+
scores = () if (return_dict_in_generate and output_scores) else None
|
| 255 |
+
raw_logits = () if (return_dict_in_generate and output_logits) else None
|
| 256 |
+
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 257 |
+
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
|
| 258 |
+
decoder_hidden_states = (
|
| 259 |
+
() if (return_dict_in_generate and output_hidden_states) else None
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
|
| 263 |
+
if return_dict_in_generate and model.config.is_encoder_decoder:
|
| 264 |
+
encoder_attentions = (
|
| 265 |
+
model_kwargs["encoder_outputs"].get("attentions")
|
| 266 |
+
if output_attentions
|
| 267 |
+
else None
|
| 268 |
+
)
|
| 269 |
+
encoder_hidden_states = (
|
| 270 |
+
model_kwargs["encoder_outputs"].get("hidden_states")
|
| 271 |
+
if output_hidden_states
|
| 272 |
+
else None
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
|
| 276 |
+
# the same group don't produce same tokens every time.
|
| 277 |
+
beam_scores = torch.full(
|
| 278 |
+
(batch_size, num_beams), -1e9, dtype=torch.float, device=device
|
| 279 |
+
)
|
| 280 |
+
beam_scores[:, ::num_sub_beams] = 0
|
| 281 |
+
beam_scores = beam_scores.view((batch_size * num_beams,))
|
| 282 |
+
|
| 283 |
+
this_peer_finished = False
|
| 284 |
+
|
| 285 |
+
decoder_prompt_len = input_ids.shape[1] # record the prompt length of decoder
|
| 286 |
+
while model._has_unfinished_sequences(
|
| 287 |
+
this_peer_finished, synced_gpus, device=input_ids.device
|
| 288 |
+
):
|
| 289 |
+
# predicted tokens in cur_len step
|
| 290 |
+
current_tokens = torch.zeros(
|
| 291 |
+
batch_size * num_beams, dtype=input_ids.dtype, device=device
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
# indices which will form the beams in the next time step
|
| 295 |
+
reordering_indices = torch.zeros(
|
| 296 |
+
batch_size * num_beams, dtype=torch.long, device=device
|
| 297 |
+
)
|
| 298 |
+
|
| 299 |
+
# do one decoder step on all beams of all sentences in batch
|
| 300 |
+
model_inputs = model.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 301 |
+
|
| 302 |
+
# prepare variable output controls (note: some models won't accept all output controls)
|
| 303 |
+
model_inputs.update(
|
| 304 |
+
{"output_attentions": output_attentions} if output_attentions else {}
|
| 305 |
+
)
|
| 306 |
+
model_inputs.update(
|
| 307 |
+
{"output_hidden_states": output_hidden_states}
|
| 308 |
+
if output_hidden_states
|
| 309 |
+
else {}
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
outputs = model(**model_inputs, return_dict=True)
|
| 313 |
+
|
| 314 |
+
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
|
| 315 |
+
model_kwargs = model._update_model_kwargs_for_generation(
|
| 316 |
+
outputs,
|
| 317 |
+
model_kwargs,
|
| 318 |
+
is_encoder_decoder=model.config.is_encoder_decoder,
|
| 319 |
+
)
|
| 320 |
+
if synced_gpus and this_peer_finished:
|
| 321 |
+
cur_len = cur_len + 1
|
| 322 |
+
continue
|
| 323 |
+
|
| 324 |
+
if output_scores:
|
| 325 |
+
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
|
| 326 |
+
if output_logits:
|
| 327 |
+
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
|
| 328 |
+
# (the clone itself is always small)
|
| 329 |
+
raw_logit_score = outputs.logits[:, -1, :].to(
|
| 330 |
+
copy=True, device=input_ids.device
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
for beam_group_idx in range(num_beam_groups):
|
| 334 |
+
group_start_idx = beam_group_idx * num_sub_beams
|
| 335 |
+
group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
|
| 336 |
+
group_size = group_end_idx - group_start_idx
|
| 337 |
+
|
| 338 |
+
# indices of beams of current group among all sentences in batch
|
| 339 |
+
batch_group_indices = []
|
| 340 |
+
|
| 341 |
+
for batch_idx in range(batch_size):
|
| 342 |
+
batch_group_indices.extend(
|
| 343 |
+
[
|
| 344 |
+
batch_idx * num_beams + idx
|
| 345 |
+
for idx in range(group_start_idx, group_end_idx)
|
| 346 |
+
]
|
| 347 |
+
)
|
| 348 |
+
group_input_ids = input_ids[batch_group_indices]
|
| 349 |
+
|
| 350 |
+
# select outputs of beams of current group only
|
| 351 |
+
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
|
| 352 |
+
# .float() is needed to retain precision for later logits manipulations
|
| 353 |
+
next_token_logits = outputs.logits[batch_group_indices, -1, :].to(
|
| 354 |
+
dtype=torch.float32, device=input_ids.device
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
next_token_scores = nn.functional.log_softmax(
|
| 358 |
+
next_token_logits, dim=-1
|
| 359 |
+
) # (batch_size * group_size, vocab_size)
|
| 360 |
+
vocab_size = next_token_scores.shape[-1]
|
| 361 |
+
|
| 362 |
+
next_token_scores_processed = logits_processor(
|
| 363 |
+
group_input_ids,
|
| 364 |
+
next_token_scores,
|
| 365 |
+
current_tokens=current_tokens,
|
| 366 |
+
beam_group_idx=beam_group_idx,
|
| 367 |
+
)
|
| 368 |
+
next_token_scores = next_token_scores_processed + beam_scores[
|
| 369 |
+
batch_group_indices
|
| 370 |
+
].unsqueeze(-1)
|
| 371 |
+
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
|
| 372 |
+
|
| 373 |
+
if output_scores:
|
| 374 |
+
processed_score[batch_group_indices] = next_token_scores_processed
|
| 375 |
+
|
| 376 |
+
# reshape for beam search
|
| 377 |
+
next_token_scores = next_token_scores.view(
|
| 378 |
+
batch_size, group_size * vocab_size
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
# Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
|
| 382 |
+
n_eos_tokens = eos_token_id.shape[0] if eos_token_id is not None else 0
|
| 383 |
+
next_token_scores, next_tokens = torch.topk(
|
| 384 |
+
next_token_scores,
|
| 385 |
+
max(2, 1 + n_eos_tokens) * group_size,
|
| 386 |
+
dim=1,
|
| 387 |
+
largest=True,
|
| 388 |
+
sorted=True,
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
|
| 392 |
+
next_tokens = next_tokens % vocab_size
|
| 393 |
+
|
| 394 |
+
# stateless
|
| 395 |
+
process_beam_indices = (
|
| 396 |
+
sum(beam_indices, ()) if beam_indices is not None else None
|
| 397 |
+
)
|
| 398 |
+
beam_outputs = beam_scorer.process(
|
| 399 |
+
group_input_ids,
|
| 400 |
+
next_token_scores,
|
| 401 |
+
next_tokens,
|
| 402 |
+
next_indices,
|
| 403 |
+
pad_token_id=pad_token_id,
|
| 404 |
+
eos_token_id=eos_token_id,
|
| 405 |
+
beam_indices=process_beam_indices,
|
| 406 |
+
group_index=beam_group_idx,
|
| 407 |
+
decoder_prompt_len=decoder_prompt_len,
|
| 408 |
+
)
|
| 409 |
+
beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
|
| 410 |
+
beam_next_tokens = beam_outputs["next_beam_tokens"]
|
| 411 |
+
beam_idx = beam_outputs["next_beam_indices"]
|
| 412 |
+
|
| 413 |
+
if return_dict_in_generate and output_scores:
|
| 414 |
+
beam_indices[beam_group_idx] = tuple(
|
| 415 |
+
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],)
|
| 416 |
+
for i in range(len(beam_indices[0]))
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
input_ids[batch_group_indices] = group_input_ids[beam_idx]
|
| 420 |
+
group_input_ids = torch.cat(
|
| 421 |
+
[group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1
|
| 422 |
+
)
|
| 423 |
+
current_tokens[batch_group_indices] = group_input_ids[:, -1]
|
| 424 |
+
|
| 425 |
+
# (beam_idx // group_size) -> batch_idx
|
| 426 |
+
# (beam_idx % group_size) -> offset of idx inside the group
|
| 427 |
+
reordering_indices[batch_group_indices] = (
|
| 428 |
+
num_beams * torch.div(beam_idx, group_size, rounding_mode="floor")
|
| 429 |
+
+ group_start_idx
|
| 430 |
+
+ (beam_idx % group_size)
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Store scores, attentions and hidden_states when required
|
| 434 |
+
if return_dict_in_generate:
|
| 435 |
+
if output_scores:
|
| 436 |
+
scores += (processed_score,)
|
| 437 |
+
if output_logits:
|
| 438 |
+
raw_logits += (raw_logit_score,)
|
| 439 |
+
if output_attentions:
|
| 440 |
+
decoder_attentions += (
|
| 441 |
+
(outputs.decoder_attentions,)
|
| 442 |
+
if model.config.is_encoder_decoder
|
| 443 |
+
else (outputs.attentions,)
|
| 444 |
+
)
|
| 445 |
+
if model.config.is_encoder_decoder:
|
| 446 |
+
cross_attentions += (outputs.cross_attentions,)
|
| 447 |
+
|
| 448 |
+
if output_hidden_states:
|
| 449 |
+
decoder_hidden_states += (
|
| 450 |
+
(outputs.decoder_hidden_states,)
|
| 451 |
+
if model.config.is_encoder_decoder
|
| 452 |
+
else (outputs.hidden_states,)
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)
|
| 456 |
+
|
| 457 |
+
# This is needed to properly delete outputs.logits which may be very large for first iteration
|
| 458 |
+
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
|
| 459 |
+
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
|
| 460 |
+
# (that way the memory peak does not include outputs.logits)
|
| 461 |
+
del outputs
|
| 462 |
+
|
| 463 |
+
# NOTE: we need to check if `model._reorder_cache` exists for special models like RAG, RecurrentGemma etc.
|
| 464 |
+
if model_kwargs.get("past_key_values", None) is not None:
|
| 465 |
+
if hasattr(model, "_reorder_cache"):
|
| 466 |
+
model_kwargs["past_key_values"] = model._reorder_cache(
|
| 467 |
+
model_kwargs["past_key_values"], reordering_indices
|
| 468 |
+
)
|
| 469 |
+
else:
|
| 470 |
+
model_kwargs["past_key_values"].reorder_cache(reordering_indices)
|
| 471 |
+
|
| 472 |
+
# increase cur_len
|
| 473 |
+
cur_len = cur_len + 1
|
| 474 |
+
|
| 475 |
+
if beam_scorer.is_done or all(stopping_criteria(input_ids, scores)):
|
| 476 |
+
this_peer_finished = True
|
| 477 |
+
|
| 478 |
+
final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
|
| 479 |
+
sequence_outputs = beam_scorer.finalize(
|
| 480 |
+
input_ids,
|
| 481 |
+
beam_scores,
|
| 482 |
+
next_tokens,
|
| 483 |
+
next_indices,
|
| 484 |
+
pad_token_id=pad_token_id,
|
| 485 |
+
eos_token_id=eos_token_id,
|
| 486 |
+
max_length=stopping_criteria.max_length,
|
| 487 |
+
beam_indices=final_beam_indices,
|
| 488 |
+
decoder_prompt_len=decoder_prompt_len,
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
if return_dict_in_generate:
|
| 492 |
+
if not output_scores:
|
| 493 |
+
sequence_outputs["sequence_scores"] = None
|
| 494 |
+
|
| 495 |
+
if model.config.is_encoder_decoder:
|
| 496 |
+
return GenerateBeamEncoderDecoderOutput(
|
| 497 |
+
sequences=sequence_outputs["sequences"],
|
| 498 |
+
sequences_scores=sequence_outputs["sequence_scores"],
|
| 499 |
+
scores=scores,
|
| 500 |
+
logits=raw_logits,
|
| 501 |
+
beam_indices=sequence_outputs["beam_indices"],
|
| 502 |
+
encoder_attentions=encoder_attentions,
|
| 503 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 504 |
+
decoder_attentions=decoder_attentions,
|
| 505 |
+
cross_attentions=cross_attentions,
|
| 506 |
+
decoder_hidden_states=decoder_hidden_states,
|
| 507 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 508 |
+
)
|
| 509 |
+
else:
|
| 510 |
+
return GenerateBeamDecoderOnlyOutput(
|
| 511 |
+
sequences=sequence_outputs["sequences"],
|
| 512 |
+
sequences_scores=sequence_outputs["sequence_scores"],
|
| 513 |
+
scores=scores,
|
| 514 |
+
logits=raw_logits,
|
| 515 |
+
beam_indices=sequence_outputs["beam_indices"],
|
| 516 |
+
attentions=decoder_attentions,
|
| 517 |
+
hidden_states=decoder_hidden_states,
|
| 518 |
+
past_key_values=model_kwargs.get("past_key_values"),
|
| 519 |
+
)
|
| 520 |
+
else:
|
| 521 |
+
return sequence_outputs["sequences"]
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
def generate(model, *args, **kwargs):
|
| 525 |
+
"""Custom generate function for group beam search decoding.
|
| 526 |
+
Args:
|
| 527 |
+
model (`PreTrainedModel`):
|
| 528 |
+
The model to generate from.
|
| 529 |
+
num_beams (`int`): The number of beams to use for beam search.
|
| 530 |
+
num_beam_groups (`int`): The number of beam groups to use for beam search.
|
| 531 |
+
length_penalty (`float`): The length penalty to use for beam search.
|
| 532 |
+
early_stopping (`bool`): Whether to stop beam search when sufficient beams have finished.
|
| 533 |
+
num_return_sequences (`int`): The number of sequences to return.
|
| 534 |
+
max_length (`int`): The maximum length of the generated sequence.
|
| 535 |
+
"""
|
| 536 |
+
generation_outputs = GenerationMixin.generate(
|
| 537 |
+
model, *args, custom_generate=_group_beam_search, **kwargs
|
| 538 |
+
)
|
| 539 |
+
return generation_outputs
|
generation_config.json
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token_id": 151643,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"eos_token_id": [
|
| 5 |
+
151645,
|
| 6 |
+
151643
|
| 7 |
+
],
|
| 8 |
+
"pad_token_id": 151643,
|
| 9 |
+
"temperature": 0.6,
|
| 10 |
+
"top_k": 20,
|
| 11 |
+
"top_p": 0.95,
|
| 12 |
+
"transformers_version": "4.56.0"
|
| 13 |
+
}
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
|
| 3 |
+
size 1503300328
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
|
| 3 |
+
size 11422654
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": false,
|
| 3 |
+
"add_prefix_space": false,
|
| 4 |
+
"added_tokens_decoder": {
|
| 5 |
+
"151643": {
|
| 6 |
+
"content": "<|endoftext|>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false,
|
| 11 |
+
"special": true
|
| 12 |
+
},
|
| 13 |
+
"151644": {
|
| 14 |
+
"content": "<|im_start|>",
|
| 15 |
+
"lstrip": false,
|
| 16 |
+
"normalized": false,
|
| 17 |
+
"rstrip": false,
|
| 18 |
+
"single_word": false,
|
| 19 |
+
"special": true
|
| 20 |
+
},
|
| 21 |
+
"151645": {
|
| 22 |
+
"content": "<|im_end|>",
|
| 23 |
+
"lstrip": false,
|
| 24 |
+
"normalized": false,
|
| 25 |
+
"rstrip": false,
|
| 26 |
+
"single_word": false,
|
| 27 |
+
"special": true
|
| 28 |
+
},
|
| 29 |
+
"151646": {
|
| 30 |
+
"content": "<|object_ref_start|>",
|
| 31 |
+
"lstrip": false,
|
| 32 |
+
"normalized": false,
|
| 33 |
+
"rstrip": false,
|
| 34 |
+
"single_word": false,
|
| 35 |
+
"special": true
|
| 36 |
+
},
|
| 37 |
+
"151647": {
|
| 38 |
+
"content": "<|object_ref_end|>",
|
| 39 |
+
"lstrip": false,
|
| 40 |
+
"normalized": false,
|
| 41 |
+
"rstrip": false,
|
| 42 |
+
"single_word": false,
|
| 43 |
+
"special": true
|
| 44 |
+
},
|
| 45 |
+
"151648": {
|
| 46 |
+
"content": "<|box_start|>",
|
| 47 |
+
"lstrip": false,
|
| 48 |
+
"normalized": false,
|
| 49 |
+
"rstrip": false,
|
| 50 |
+
"single_word": false,
|
| 51 |
+
"special": true
|
| 52 |
+
},
|
| 53 |
+
"151649": {
|
| 54 |
+
"content": "<|box_end|>",
|
| 55 |
+
"lstrip": false,
|
| 56 |
+
"normalized": false,
|
| 57 |
+
"rstrip": false,
|
| 58 |
+
"single_word": false,
|
| 59 |
+
"special": true
|
| 60 |
+
},
|
| 61 |
+
"151650": {
|
| 62 |
+
"content": "<|quad_start|>",
|
| 63 |
+
"lstrip": false,
|
| 64 |
+
"normalized": false,
|
| 65 |
+
"rstrip": false,
|
| 66 |
+
"single_word": false,
|
| 67 |
+
"special": true
|
| 68 |
+
},
|
| 69 |
+
"151651": {
|
| 70 |
+
"content": "<|quad_end|>",
|
| 71 |
+
"lstrip": false,
|
| 72 |
+
"normalized": false,
|
| 73 |
+
"rstrip": false,
|
| 74 |
+
"single_word": false,
|
| 75 |
+
"special": true
|
| 76 |
+
},
|
| 77 |
+
"151652": {
|
| 78 |
+
"content": "<|vision_start|>",
|
| 79 |
+
"lstrip": false,
|
| 80 |
+
"normalized": false,
|
| 81 |
+
"rstrip": false,
|
| 82 |
+
"single_word": false,
|
| 83 |
+
"special": true
|
| 84 |
+
},
|
| 85 |
+
"151653": {
|
| 86 |
+
"content": "<|vision_end|>",
|
| 87 |
+
"lstrip": false,
|
| 88 |
+
"normalized": false,
|
| 89 |
+
"rstrip": false,
|
| 90 |
+
"single_word": false,
|
| 91 |
+
"special": true
|
| 92 |
+
},
|
| 93 |
+
"151654": {
|
| 94 |
+
"content": "<|vision_pad|>",
|
| 95 |
+
"lstrip": false,
|
| 96 |
+
"normalized": false,
|
| 97 |
+
"rstrip": false,
|
| 98 |
+
"single_word": false,
|
| 99 |
+
"special": true
|
| 100 |
+
},
|
| 101 |
+
"151655": {
|
| 102 |
+
"content": "<|image_pad|>",
|
| 103 |
+
"lstrip": false,
|
| 104 |
+
"normalized": false,
|
| 105 |
+
"rstrip": false,
|
| 106 |
+
"single_word": false,
|
| 107 |
+
"special": true
|
| 108 |
+
},
|
| 109 |
+
"151656": {
|
| 110 |
+
"content": "<|video_pad|>",
|
| 111 |
+
"lstrip": false,
|
| 112 |
+
"normalized": false,
|
| 113 |
+
"rstrip": false,
|
| 114 |
+
"single_word": false,
|
| 115 |
+
"special": true
|
| 116 |
+
},
|
| 117 |
+
"151657": {
|
| 118 |
+
"content": "<tool_call>",
|
| 119 |
+
"lstrip": false,
|
| 120 |
+
"normalized": false,
|
| 121 |
+
"rstrip": false,
|
| 122 |
+
"single_word": false,
|
| 123 |
+
"special": false
|
| 124 |
+
},
|
| 125 |
+
"151658": {
|
| 126 |
+
"content": "</tool_call>",
|
| 127 |
+
"lstrip": false,
|
| 128 |
+
"normalized": false,
|
| 129 |
+
"rstrip": false,
|
| 130 |
+
"single_word": false,
|
| 131 |
+
"special": false
|
| 132 |
+
},
|
| 133 |
+
"151659": {
|
| 134 |
+
"content": "<|fim_prefix|>",
|
| 135 |
+
"lstrip": false,
|
| 136 |
+
"normalized": false,
|
| 137 |
+
"rstrip": false,
|
| 138 |
+
"single_word": false,
|
| 139 |
+
"special": false
|
| 140 |
+
},
|
| 141 |
+
"151660": {
|
| 142 |
+
"content": "<|fim_middle|>",
|
| 143 |
+
"lstrip": false,
|
| 144 |
+
"normalized": false,
|
| 145 |
+
"rstrip": false,
|
| 146 |
+
"single_word": false,
|
| 147 |
+
"special": false
|
| 148 |
+
},
|
| 149 |
+
"151661": {
|
| 150 |
+
"content": "<|fim_suffix|>",
|
| 151 |
+
"lstrip": false,
|
| 152 |
+
"normalized": false,
|
| 153 |
+
"rstrip": false,
|
| 154 |
+
"single_word": false,
|
| 155 |
+
"special": false
|
| 156 |
+
},
|
| 157 |
+
"151662": {
|
| 158 |
+
"content": "<|fim_pad|>",
|
| 159 |
+
"lstrip": false,
|
| 160 |
+
"normalized": false,
|
| 161 |
+
"rstrip": false,
|
| 162 |
+
"single_word": false,
|
| 163 |
+
"special": false
|
| 164 |
+
},
|
| 165 |
+
"151663": {
|
| 166 |
+
"content": "<|repo_name|>",
|
| 167 |
+
"lstrip": false,
|
| 168 |
+
"normalized": false,
|
| 169 |
+
"rstrip": false,
|
| 170 |
+
"single_word": false,
|
| 171 |
+
"special": false
|
| 172 |
+
},
|
| 173 |
+
"151664": {
|
| 174 |
+
"content": "<|file_sep|>",
|
| 175 |
+
"lstrip": false,
|
| 176 |
+
"normalized": false,
|
| 177 |
+
"rstrip": false,
|
| 178 |
+
"single_word": false,
|
| 179 |
+
"special": false
|
| 180 |
+
},
|
| 181 |
+
"151665": {
|
| 182 |
+
"content": "<tool_response>",
|
| 183 |
+
"lstrip": false,
|
| 184 |
+
"normalized": false,
|
| 185 |
+
"rstrip": false,
|
| 186 |
+
"single_word": false,
|
| 187 |
+
"special": false
|
| 188 |
+
},
|
| 189 |
+
"151666": {
|
| 190 |
+
"content": "</tool_response>",
|
| 191 |
+
"lstrip": false,
|
| 192 |
+
"normalized": false,
|
| 193 |
+
"rstrip": false,
|
| 194 |
+
"single_word": false,
|
| 195 |
+
"special": false
|
| 196 |
+
},
|
| 197 |
+
"151667": {
|
| 198 |
+
"content": "<think>",
|
| 199 |
+
"lstrip": false,
|
| 200 |
+
"normalized": false,
|
| 201 |
+
"rstrip": false,
|
| 202 |
+
"single_word": false,
|
| 203 |
+
"special": false
|
| 204 |
+
},
|
| 205 |
+
"151668": {
|
| 206 |
+
"content": "</think>",
|
| 207 |
+
"lstrip": false,
|
| 208 |
+
"normalized": false,
|
| 209 |
+
"rstrip": false,
|
| 210 |
+
"single_word": false,
|
| 211 |
+
"special": false
|
| 212 |
+
}
|
| 213 |
+
},
|
| 214 |
+
"additional_special_tokens": [
|
| 215 |
+
"<|im_start|>",
|
| 216 |
+
"<|im_end|>",
|
| 217 |
+
"<|object_ref_start|>",
|
| 218 |
+
"<|object_ref_end|>",
|
| 219 |
+
"<|box_start|>",
|
| 220 |
+
"<|box_end|>",
|
| 221 |
+
"<|quad_start|>",
|
| 222 |
+
"<|quad_end|>",
|
| 223 |
+
"<|vision_start|>",
|
| 224 |
+
"<|vision_end|>",
|
| 225 |
+
"<|vision_pad|>",
|
| 226 |
+
"<|image_pad|>",
|
| 227 |
+
"<|video_pad|>"
|
| 228 |
+
],
|
| 229 |
+
"bos_token": null,
|
| 230 |
+
"chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
|
| 231 |
+
"clean_up_tokenization_spaces": false,
|
| 232 |
+
"eos_token": "<|im_end|>",
|
| 233 |
+
"errors": "replace",
|
| 234 |
+
"model_max_length": 131072,
|
| 235 |
+
"pad_token": "<|endoftext|>",
|
| 236 |
+
"split_special_tokens": false,
|
| 237 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 238 |
+
"unk_token": null
|
| 239 |
+
}
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|