Spaces:
Runtime error
Runtime error
Commit
·
30049a9
1
Parent(s):
d4dd3c5
fix: fixing mistral answering and prompt formatting
Browse files- backend/controller.py +14 -11
- explanation/interpret_captum.py +1 -1
- explanation/interpret_shap.py +23 -23
- model/mistral.py +23 -12
backend/controller.py
CHANGED
|
@@ -59,13 +59,15 @@ def interference(
|
|
| 59 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
| 60 |
|
| 61 |
# call the explained chat function with the model instance
|
| 62 |
-
prompt_output, history_output,
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
)
|
| 70 |
# if no XAI approach is selected call the vanilla chat function
|
| 71 |
else:
|
|
@@ -78,16 +80,17 @@ def interference(
|
|
| 78 |
knowledge=knowledge,
|
| 79 |
)
|
| 80 |
# set XAI outputs to disclaimer html/none
|
| 81 |
-
|
| 82 |
"""
|
| 83 |
<div style="text-align: center"><h4>Without Selected XAI Approach,
|
| 84 |
no graphic will be displayed</h4></div>
|
| 85 |
""",
|
| 86 |
[("", "")],
|
|
|
|
| 87 |
)
|
| 88 |
|
| 89 |
# return the outputs
|
| 90 |
-
return prompt_output, history_output,
|
| 91 |
|
| 92 |
|
| 93 |
# simple chat function that calls the model
|
|
@@ -121,10 +124,10 @@ def explained_chat(
|
|
| 121 |
prompt = model.format_prompt(message, history, system_prompt, knowledge)
|
| 122 |
|
| 123 |
# generating an answer using the methods chat function
|
| 124 |
-
answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
|
| 125 |
|
| 126 |
# updating the chat history with the new answer
|
| 127 |
history.append((message, answer))
|
| 128 |
|
| 129 |
# returning the updated history, xai graphic and xai plot elements
|
| 130 |
-
return "", history, xai_graphic, xai_markup
|
|
|
|
| 59 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
| 60 |
|
| 61 |
# call the explained chat function with the model instance
|
| 62 |
+
prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
|
| 63 |
+
explained_chat(
|
| 64 |
+
model=model,
|
| 65 |
+
xai=xai,
|
| 66 |
+
message=prompt,
|
| 67 |
+
history=history,
|
| 68 |
+
system_prompt=system_prompt,
|
| 69 |
+
knowledge=knowledge,
|
| 70 |
+
)
|
| 71 |
)
|
| 72 |
# if no XAI approach is selected call the vanilla chat function
|
| 73 |
else:
|
|
|
|
| 80 |
knowledge=knowledge,
|
| 81 |
)
|
| 82 |
# set XAI outputs to disclaimer html/none
|
| 83 |
+
xai_interactive, xai_markup, xai_plot = (
|
| 84 |
"""
|
| 85 |
<div style="text-align: center"><h4>Without Selected XAI Approach,
|
| 86 |
no graphic will be displayed</h4></div>
|
| 87 |
""",
|
| 88 |
[("", "")],
|
| 89 |
+
None,
|
| 90 |
)
|
| 91 |
|
| 92 |
# return the outputs
|
| 93 |
+
return prompt_output, history_output, xai_interactive, xai_markup, xai_plot
|
| 94 |
|
| 95 |
|
| 96 |
# simple chat function that calls the model
|
|
|
|
| 124 |
prompt = model.format_prompt(message, history, system_prompt, knowledge)
|
| 125 |
|
| 126 |
# generating an answer using the methods chat function
|
| 127 |
+
answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
|
| 128 |
|
| 129 |
# updating the chat history with the new answer
|
| 130 |
history.append((message, answer))
|
| 131 |
|
| 132 |
# returning the updated history, xai graphic and xai plot elements
|
| 133 |
+
return "", history, xai_graphic, xai_markup, xai_plot
|
explanation/interpret_captum.py
CHANGED
|
@@ -52,4 +52,4 @@ def chat_explained(model, prompt):
|
|
| 52 |
marked_text = markup_text(input_tokens, values, variant="captum")
|
| 53 |
|
| 54 |
# return response, graphic and marked_text array
|
| 55 |
-
return response_text, graphic, marked_text
|
|
|
|
| 52 |
marked_text = markup_text(input_tokens, values, variant="captum")
|
| 53 |
|
| 54 |
# return response, graphic and marked_text array
|
| 55 |
+
return response_text, graphic, marked_text, None
|
explanation/interpret_shap.py
CHANGED
|
@@ -23,29 +23,6 @@ def extract_seq_att(shap_values):
|
|
| 23 |
return list(zip(shap_values.data[0], values))
|
| 24 |
|
| 25 |
|
| 26 |
-
# main explain function that returns a chat with explanations
|
| 27 |
-
def chat_explained(model, prompt):
|
| 28 |
-
model.set_config({})
|
| 29 |
-
|
| 30 |
-
# create the shap explainer
|
| 31 |
-
shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
|
| 32 |
-
|
| 33 |
-
# get the shap values for the prompt
|
| 34 |
-
shap_values = shap_explainer([prompt])
|
| 35 |
-
|
| 36 |
-
# create the explanation graphic and marked text array
|
| 37 |
-
graphic = create_graphic(shap_values)
|
| 38 |
-
marked_text = markup_text(
|
| 39 |
-
shap_values.data[0], shap_values.values[0], variant="shap"
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
# create the response text
|
| 43 |
-
response_text = fmt.format_output_text(shap_values.output_names)
|
| 44 |
-
|
| 45 |
-
# return response, graphic and marked_text array
|
| 46 |
-
return response_text, graphic, marked_text
|
| 47 |
-
|
| 48 |
-
|
| 49 |
# function used to wrap the model with a shap model
|
| 50 |
def wrap_shap(model):
|
| 51 |
# calling global variants
|
|
@@ -80,3 +57,26 @@ def create_graphic(shap_values):
|
|
| 80 |
|
| 81 |
# return the html graphic as string to display in iFrame
|
| 82 |
return str(graphic_html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
return list(zip(shap_values.data[0], values))
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
# function used to wrap the model with a shap model
|
| 27 |
def wrap_shap(model):
|
| 28 |
# calling global variants
|
|
|
|
| 57 |
|
| 58 |
# return the html graphic as string to display in iFrame
|
| 59 |
return str(graphic_html)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# main explain function that returns a chat with explanations
|
| 63 |
+
def chat_explained(model, prompt):
|
| 64 |
+
model.set_config({})
|
| 65 |
+
|
| 66 |
+
# create the shap explainer
|
| 67 |
+
shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
|
| 68 |
+
|
| 69 |
+
# get the shap values for the prompt
|
| 70 |
+
shap_values = shap_explainer([prompt])
|
| 71 |
+
|
| 72 |
+
# create the explanation graphic and marked text array
|
| 73 |
+
graphic = create_graphic(shap_values)
|
| 74 |
+
marked_text = markup_text(
|
| 75 |
+
shap_values.data[0], shap_values.values[0], variant="shap"
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# create the response text
|
| 79 |
+
response_text = fmt.format_output_text(shap_values.output_names)
|
| 80 |
+
|
| 81 |
+
# return response, graphic and marked_text array
|
| 82 |
+
return response_text, graphic, marked_text, None
|
model/mistral.py
CHANGED
|
@@ -58,8 +58,8 @@ def set_config(config_dict: dict):
|
|
| 58 |
|
| 59 |
|
| 60 |
# advanced formatting function that takes into a account a conversation history
|
| 61 |
-
# CREDIT:
|
| 62 |
-
|
| 63 |
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
|
| 64 |
prompt = ""
|
| 65 |
|
|
@@ -83,8 +83,13 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
|
|
| 83 |
# adds conversation history to the prompt
|
| 84 |
for conversation in history[1:]:
|
| 85 |
# takes all the following conversations and adds them as context
|
| 86 |
-
prompt += "".join(
|
|
|
|
|
|
|
| 87 |
|
|
|
|
|
|
|
|
|
|
| 88 |
return prompt
|
| 89 |
|
| 90 |
|
|
@@ -93,16 +98,22 @@ def format_answer(answer: str):
|
|
| 93 |
# empty answer string
|
| 94 |
formatted_answer = ""
|
| 95 |
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
if len(parts) >= 3:
|
| 99 |
-
# Return the text after the second occurrence of [/INST]
|
| 100 |
-
formatted_answer = parts[2].strip()
|
| 101 |
-
else:
|
| 102 |
-
# Return an empty string if there are fewer than two occurrences of [/INST]
|
| 103 |
-
formatted_answer = ""
|
| 104 |
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
return formatted_answer
|
| 107 |
|
| 108 |
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
# advanced formatting function that takes into a account a conversation history
|
| 61 |
+
# CREDIT: adapated from the Mistral AI Instruct chat template
|
| 62 |
+
# see https://github.com/chujiezheng/chat_templates/blob/main/chat_templates/mistral-instruct.jinja
|
| 63 |
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
|
| 64 |
prompt = ""
|
| 65 |
|
|
|
|
| 83 |
# adds conversation history to the prompt
|
| 84 |
for conversation in history[1:]:
|
| 85 |
# takes all the following conversations and adds them as context
|
| 86 |
+
prompt += "".join(
|
| 87 |
+
f"\n[INST] {conversation[0]} [/INST] {conversation[1]}</s>"
|
| 88 |
+
)
|
| 89 |
|
| 90 |
+
prompt += """\n[INST] {message} [/INST]"""
|
| 91 |
+
|
| 92 |
+
# returns full prompt
|
| 93 |
return prompt
|
| 94 |
|
| 95 |
|
|
|
|
| 98 |
# empty answer string
|
| 99 |
formatted_answer = ""
|
| 100 |
|
| 101 |
+
# splitting answer by instruction tokens
|
| 102 |
+
segments = answer.split("[/INST]")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
+
# checking if proper history got returned
|
| 105 |
+
if len(segments) > 1:
|
| 106 |
+
# return text after the last ['/INST'] - reponse to last message
|
| 107 |
+
formatted_answer = segments[-1].strip()
|
| 108 |
+
else:
|
| 109 |
+
# return warning and full answer if not enough [/INST] tokens found
|
| 110 |
+
gr.Warning("""
|
| 111 |
+
There was an issue with answer formatting...\n
|
| 112 |
+
returning the full answer.
|
| 113 |
+
""")
|
| 114 |
+
formatted_answer = answer
|
| 115 |
+
|
| 116 |
+
print(f"CUT:\n {answer}\nINTO:\n{formatted_answer}")
|
| 117 |
return formatted_answer
|
| 118 |
|
| 119 |
|