p-christ commited on
Commit
fa9baae
·
1 Parent(s): 0eb34d8

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +140 -10
pipeline.py CHANGED
@@ -1,4 +1,6 @@
 
1
  from typing import Dict, List, Any
 
2
 
3
  class PreTrainedPipeline():
4
  def __init__(self, path=""):
@@ -6,15 +8,143 @@ class PreTrainedPipeline():
6
  # Preload all the elements you are going to need at inference.
7
  # For instance your model, processors, tokenizer that might be needed.
8
  # This function is only called once, so do all the heavy processing I/O here"""
9
- pass
 
 
 
10
 
11
  def __call__(self, inputs: str):
12
- """
13
- Args:
14
- inputs (:obj:`str`):
15
- a string to get the features from.
16
- Return:
17
- A :obj:`list` of floats: The features computed by the model.
18
- """
19
- # IMPLEMENT_THIS
20
- return ["hello"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import itertools
import re
from typing import Dict, List, Any

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 
5
class PreTrainedPipeline():
    """Answer-aware question generation pipeline.

    Given a passage of text, extracts candidate answer spans sentence by
    sentence with a T5-style seq2seq model, then generates one question
    per extracted answer by highlighting it in the passage.
    """

    def __init__(self, path=""):
        # Preload all the elements you are going to need at inference.
        # For instance your model, processors, tokenizer that might be needed.
        # This function is only called once, so do all the heavy processing I/O here.
        self.model_name = "p-christ/QuestionGenerationModelOnly"
        self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # BUGFIX: self.device and self.model_type were read by the helper
        # methods but never assigned, raising AttributeError on first call.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model_type = "t5"  # controls the trailing " </s>" marker below

    def __call__(self, inputs: str):
        """Generate question/answer pairs from a passage.

        Args:
            inputs (:obj:`str`): the passage to generate questions from.
        Return:
            A :obj:`list` of ``{'answer': str, 'question': str}`` dicts;
            empty when the input is empty or no answers are extracted.
        """
        if not inputs:
            return []
        inputs = " ".join(inputs.split())  # collapse all whitespace runs
        sents, answers = self._extract_answers(inputs)
        flat_answers = list(itertools.chain(*answers))
        if not flat_answers:
            return []

        qg_examples = self._prepare_inputs_for_qg_from_answers_hl(sents, answers)
        qg_inputs = [example['source_text'] for example in qg_examples]
        questions = self._generate_questions(qg_inputs)
        output = [{'answer': example['answer'], 'question': que}
                  for example, que in zip(qg_examples, questions)]
        return self.clean_generated_QAs(output)

    def _extract_answers(self, context):
        """Run the model in answer-extraction mode, one pass per sentence.

        Returns (sentences, answers) where answers[i] is the list of raw
        answer strings extracted for sentence i (may be empty).
        """
        sents, inputs = self._prepare_inputs_for_ans_extraction(context)
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
        )

        # Keep special tokens so the '<sep>' answer delimiters survive decoding;
        # any '<pad>' residue is stripped later by remove_pad().
        dec = [self.tokenizer.decode(ids, skip_special_tokens=False) for ids in outs]
        answers = [item.split('<sep>') for item in dec]
        # Drop the fragment after the final '<sep>' of each decoded string.
        answers = [i[:-1] for i in answers]

        return sents, answers

    @staticmethod
    def _split_sentences(text):
        """Split *text* into sentences on terminal punctuation.

        BUGFIX: the original called nltk's ``sent_tokenize`` without ever
        importing it (NameError at runtime); this stdlib-only splitter on
        '.', '!' and '?' followed by whitespace replaces it.
        """
        return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

    def _prepare_inputs_for_ans_extraction(self, text):
        """Build one 'extract answers:' prompt per sentence, highlighting it."""
        sents = self._split_sentences(text)

        inputs = []
        for i in range(len(sents)):
            source_text = "extract answers:"
            for j, sent in enumerate(sents):
                if i == j:
                    sent = "<hl> %s <hl>" % sent  # highlight the target sentence
                source_text = "%s %s" % (source_text, sent)
            source_text = source_text.strip()

            if self.model_type == "t5":
                source_text = source_text + " </s>"
            inputs.append(source_text)

        return sents, inputs

    def _tokenize(self,
                  inputs,
                  padding=True,
                  truncation=True,
                  add_special_tokens=True,
                  max_length=512
                  ):
        """Batch-encode *inputs* to fixed-length PyTorch tensors."""
        # BUGFIX: dropped the deprecated `pad_to_max_length` flag, which was
        # redundant with (and superseded by) the `padding` argument.
        return self.tokenizer.batch_encode_plus(
            inputs,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            truncation=truncation,
            padding="max_length" if padding else False,
            return_tensors="pt"
        )

    def _generate_questions(self, inputs):
        """Beam-search one question per prepared 'generate question:' prompt."""
        inputs = self._tokenize(inputs, padding=True, truncation=True)

        outs = self.model.generate(
            input_ids=inputs['input_ids'].to(self.device),
            attention_mask=inputs['attention_mask'].to(self.device),
            max_length=32,
            num_beams=4,
        )

        return [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in outs]

    def _prepare_inputs_for_qg_from_answers_hl(self, sents, answers):
        """Build QG prompts by wrapping each answer span in <hl> markers.

        Answers whose text cannot be found (case-insensitively) in their
        source sentence are skipped.
        """
        inputs = []
        for i, answer in enumerate(answers):
            if not answer:
                continue
            for answer_text in answer:
                sent = sents[i]
                sents_copy = sents[:]
                answer_text = self.remove_pad(answer_text).strip()

                try:
                    ans_start_idx = sent.lower().index(answer_text.lower())
                except ValueError:
                    # The answer is not in the sentence, so skip this one.
                    continue

                sent = f"{sent[:ans_start_idx]} <hl> {answer_text} <hl> {sent[ans_start_idx + len(answer_text): ]}"
                sents_copy[i] = sent

                source_text = " ".join(sents_copy)
                source_text = f"generate question: {source_text}"
                if self.model_type == "t5":
                    source_text = source_text + " </s>"

                inputs.append({"answer": answer_text, "source_text": source_text})

        return inputs

    def clean_generated_QAs(self, generated_QAs):
        """Deduplicate QAs: keep only the first question for each answer."""
        clean_QAs = []
        answers_used = set()
        for qa in generated_QAs:
            if qa['answer'] in answers_used:
                # BUGFIX: was `break`, which discarded every QA after the
                # first duplicate instead of just skipping the duplicate.
                continue
            answers_used.add(qa['answer'])
            clean_QAs.append(qa)
        return clean_QAs

    def remove_pad(self, text):
        """Strip '<pad>' tokens left over from non-skipped special decoding."""
        # BUGFIX: parameter was named `str`, shadowing the builtin; the
        # membership pre-check was also redundant with replace().
        return text.replace("<pad>", "")