Commit 9a542c4
Parent(s): 1439daa
add chat function
Files changed: modeling_lingowhale.py (+121 -0)

modeling_lingowhale.py CHANGED
@@ -19,6 +19,8 @@
 
 import math
 import os
+from queue import Queue
+from threading import Thread
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -28,6 +30,7 @@ from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
+from transformers.generation.utils import GenerationConfig
 from transformers.modeling_outputs import (BaseModelOutputWithPast,
                                            CausalLMOutputWithPast)
 from transformers.utils import logging
@@ -106,6 +109,44 @@ def _expand_mask(mask: torch.Tensor,
                                  torch.finfo(dtype).min)


+class TextIterStreamer:
+
+    def __init__(self,
+                 tokenizer,
+                 skip_prompt=False,
+                 skip_special_tokens=False):
+        self.tokenizer = tokenizer
+        self.skip_prompt = skip_prompt
+        self.skip_special_tokens = skip_special_tokens
+        self.tokens = []
+        self.text_queue = Queue()
+        self.next_tokens_are_prompt = True
+
+    def put(self, value):
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+        else:
+            if len(value.shape) > 1:
+                value = value[0]
+            self.tokens.extend(value.tolist())
+            self.text_queue.put(
+                self.tokenizer.decode(
+                    self.tokens, skip_special_tokens=self.skip_special_tokens))
+
+    def end(self):
+        self.text_queue.put(None)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        value = self.text_queue.get()
+        if value is None:
+            raise StopIteration()
+        else:
+            return value
+
+
 class LingoWhaleRMSNorm(torch.nn.Module):
 
     def __init__(self, hidden_size: int, eps: float = 1e-6):
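The TextIterStreamer added above follows the streamer protocol that transformers' generate() drives: put() receives each new batch of token ids (the first call carries the prompt, which skip_prompt discards) and end() is called once generation finishes, after which iteration stops. A minimal sketch of that hand-off, using a stand-in ToyTokenizer and a fake producer thread purely for illustration (neither is part of this commit):

# Illustrative only: ToyTokenizer and fake_generate stand in for a real
# tokenizer and for model.generate(); TextIterStreamer is the class above.
from threading import Thread

import torch


class ToyTokenizer:

    def decode(self, token_ids, skip_special_tokens=False):
        return " ".join(str(t) for t in token_ids)


streamer = TextIterStreamer(ToyTokenizer(), skip_prompt=True)


def fake_generate():
    streamer.put(torch.tensor([[3, 7, 9]]))  # prompt batch, skipped
    for tok in (11, 12, 13):
        streamer.put(torch.tensor([tok]))    # one new token per step
    streamer.end()                           # ends iteration


Thread(target=fake_generate).start()
for partial_text in streamer:  # yields the decoded text so far, growing
    print(partial_text)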
@@ -931,6 +972,86 @@ class LingoWhaleForCausalLM(LingoWhalePreTrainedModel):
         })
         return model_inputs
 
+    def build_chat_input(self,
+                         tokenizer,
+                         messages: List[dict],
+                         max_new_tokens: int = 0,
+                         user_token_ids=[3],
+                         assistant_tokens=[4]):
+        max_input_tokens = self.config.model_max_length - max_new_tokens
+
+        def _parse_messages(messages):
+
+            chat_rounds, chat_round = [], []
+
+            for message in messages:
+                if message['role'] == 'user' and len(chat_round) > 0:
+                    chat_rounds.append(chat_round)
+                    chat_round = []
+                chat_round.append(message)
+
+            if len(chat_round) > 0:
+                chat_rounds.append(chat_round)
+
+            return chat_rounds
+
+        chat_rounds = _parse_messages(messages)[::-1]
+
+        def get_chat_tokens(tokenizer, chat_round, user_token_ids,
+                            assistant_tokens):
+            tokens = []
+            tokens += user_token_ids
+            assert len(chat_round) < 3
+
+            if len(chat_round) == 1:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+            else:
+                tokens += tokenizer.encode(chat_round[0]['content'])
+                tokens += assistant_tokens
+                tokens += tokenizer.encode(chat_round[1]['content'])
+
+            return tokens
+
+        input_tokens = []
+        for chat_round in chat_rounds:
+            chat_tokens = get_chat_tokens(tokenizer, chat_round,
+                                          user_token_ids, assistant_tokens)
+            if len(chat_tokens + input_tokens) > max_input_tokens:
+                return input_tokens
+
+            input_tokens = chat_tokens + input_tokens
+        return torch.LongTensor([input_tokens]).to(self.device)
+
+    def chat(self,
+             tokenizer,
+             messages: List[dict],
+             stream=False,
+             generation_config: Optional[GenerationConfig] = None,
+             max_new_tokens=100):
+
+
+        if generation_config is not None:
+            max_new_tokens = generation_config.max_new_tokens
+
+        input_ids = self.build_chat_input(tokenizer, messages, max_new_tokens)
+        if stream:
+            streamer = TextIterStreamer(tokenizer,
+                                        skip_prompt=True,
+                                        skip_special_tokens=True)
+            Thread(target=self.generate,
+                   kwargs=dict(inputs=input_ids,
+                               streamer=streamer,
+                               generation_config=generation_config)).start()
+
+            return streamer
+        else:
+            outputs = self.generate(input_ids,
+                                    generation_config=generation_config)
+            response = tokenizer.decode(outputs[0][len(input_ids[0]):],
+                                        skip_special_tokens=True)
+            return response
+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
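With the chat() method in place, the model can be driven with a plain list of role/content messages. A usage sketch, assuming the repository is loaded with trust_remote_code=True so this modeling file is used; the repo id below is a placeholder, not part of this commit:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "deeplang-ai/LingoWhale-8B"  # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)

messages = [{"role": "user", "content": "Hello, who are you?"}]

# Non-streaming: returns the decoded assistant reply as a single string.
print(model.chat(tokenizer, messages))

# Streaming: returns the TextIterStreamer added in this commit; iterating it
# yields the decoded reply so far while generation runs in a background thread.
for partial_text in model.chat(tokenizer, messages, stream=True):
    print(partial_text)

# Multi-turn: build_chat_input packs each round as
# user_token_ids + encode(user) + assistant_tokens + encode(assistant),
# and drops the oldest rounds first if the prompt would exceed
# config.model_max_length - max_new_tokens.
messages += [{"role": "assistant", "content": "Hi, I can answer questions."},
             {"role": "user", "content": "What can you do?"}]
print(model.chat(tokenizer, messages))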