Merge hf-origin/main into main
Browse files- .gitattributes +1 -1
- app.py +8 -0
- climateqa/engine/talk_to_data/main.py +1 -1
- climateqa/engine/talk_to_data/myVanna.py +13 -0
- climateqa/engine/talk_to_data/plot.py +418 -0
- climateqa/engine/talk_to_data/sql_query.py +114 -0
- climateqa/engine/talk_to_data/talk_to_drias.py +317 -0
- climateqa/engine/talk_to_data/utils.py +281 -0
- climateqa/engine/talk_to_data/vanna_class.py +325 -0
- requirements.txt +1 -1
- style.css +10 -1
.gitattributes
CHANGED
|
@@ -45,4 +45,4 @@ documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
|
|
| 45 |
climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
|
| 46 |
climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
|
| 47 |
data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
|
| 48 |
-
front/assets/*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 45 |
climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
|
| 46 |
climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
|
| 47 |
data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
front/assets/*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
|
@@ -16,7 +16,10 @@ from climateqa.chat import start_chat, chat_stream, finish_chat
|
|
| 16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
| 17 |
from front.tabs import MainTabPanel, ConfigPanel
|
| 18 |
from front.tabs.tab_drias import create_drias_tab
|
|
|
|
| 19 |
from front.tabs.tab_ipcc import create_ipcc_tab
|
|
|
|
|
|
|
| 20 |
from front.utils import process_figures
|
| 21 |
from gradio_modal import Modal
|
| 22 |
|
|
@@ -533,8 +536,13 @@ def main_ui():
|
|
| 533 |
with gr.Tabs():
|
| 534 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
| 535 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
|
|
|
| 536 |
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
| 537 |
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
create_about_tab()
|
| 539 |
|
| 540 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
|
|
|
| 16 |
from front.tabs import create_config_modal, cqa_tab, create_about_tab
|
| 17 |
from front.tabs import MainTabPanel, ConfigPanel
|
| 18 |
from front.tabs.tab_drias import create_drias_tab
|
| 19 |
+
<<<<<<< HEAD
|
| 20 |
from front.tabs.tab_ipcc import create_ipcc_tab
|
| 21 |
+
=======
|
| 22 |
+
>>>>>>> hf-origin/main
|
| 23 |
from front.utils import process_figures
|
| 24 |
from gradio_modal import Modal
|
| 25 |
|
|
|
|
| 536 |
with gr.Tabs():
|
| 537 |
cqa_components = cqa_tab(tab_name="ClimateQ&A")
|
| 538 |
local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
|
| 539 |
+
<<<<<<< HEAD
|
| 540 |
drias_components = create_drias_tab(share_client=share_client, user_id=user_id)
|
| 541 |
ipcc_components = create_ipcc_tab(share_client=share_client, user_id=user_id)
|
| 542 |
+
=======
|
| 543 |
+
create_drias_tab(share_client=share_client, user_id=user_id)
|
| 544 |
+
|
| 545 |
+
>>>>>>> hf-origin/main
|
| 546 |
create_about_tab()
|
| 547 |
|
| 548 |
event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
|
climateqa/engine/talk_to_data/main.py
CHANGED
|
@@ -121,4 +121,4 @@ async def ask_ipcc(query: str, index_state: int = 0, user_id: str | None = None)
|
|
| 121 |
|
| 122 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
| 123 |
|
| 124 |
-
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
|
|
|
| 121 |
|
| 122 |
log_drias_interaction_to_huggingface(query, sql_query, user_id)
|
| 123 |
|
| 124 |
+
return sql_query, dataframe, figure, plot_information, sql_queries, result_dataframes, figures, plot_informations, index_state, plot_title_list, ""
|
climateqa/engine/talk_to_data/myVanna.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
from climateqa.engine.talk_to_data.vanna_class import MyCustomVectorDB
|
| 3 |
+
from vanna.openai import OpenAI_Chat
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
OPENAI_API_KEY = os.getenv('THEO_API_KEY')
|
| 9 |
+
|
| 10 |
+
class MyVanna(MyCustomVectorDB, OpenAI_Chat):
|
| 11 |
+
def __init__(self, config=None):
|
| 12 |
+
MyCustomVectorDB.__init__(self, config=config)
|
| 13 |
+
OpenAI_Chat.__init__(self, config=config)
|
climateqa/engine/talk_to_data/plot.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, TypedDict
|
| 2 |
+
from matplotlib.figure import figaspect
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from plotly.graph_objects import Figure
|
| 5 |
+
import plotly.graph_objects as go
|
| 6 |
+
import plotly.express as px
|
| 7 |
+
|
| 8 |
+
from climateqa.engine.talk_to_data.sql_query import (
|
| 9 |
+
indicator_for_given_year_query,
|
| 10 |
+
indicator_per_year_at_location_query,
|
| 11 |
+
)
|
| 12 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_TO_UNIT
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Plot(TypedDict):
|
| 18 |
+
"""Represents a plot configuration in the DRIAS system.
|
| 19 |
+
|
| 20 |
+
This class defines the structure for configuring different types of plots
|
| 21 |
+
that can be generated from climate data.
|
| 22 |
+
|
| 23 |
+
Attributes:
|
| 24 |
+
name (str): The name of the plot type
|
| 25 |
+
description (str): A description of what the plot shows
|
| 26 |
+
params (list[str]): List of required parameters for the plot
|
| 27 |
+
plot_function (Callable[..., Callable[..., Figure]]): Function to generate the plot
|
| 28 |
+
sql_query (Callable[..., str]): Function to generate the SQL query for the plot
|
| 29 |
+
"""
|
| 30 |
+
name: str
|
| 31 |
+
description: str
|
| 32 |
+
params: list[str]
|
| 33 |
+
plot_function: Callable[..., Callable[..., Figure]]
|
| 34 |
+
sql_query: Callable[..., str]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def plot_indicator_evolution_at_location(params: dict) -> Callable[..., Figure]:
|
| 38 |
+
"""Generates a function to plot indicator evolution over time at a location.
|
| 39 |
+
|
| 40 |
+
This function creates a line plot showing how a climate indicator changes
|
| 41 |
+
over time at a specific location. It handles temperature, precipitation,
|
| 42 |
+
and other climate indicators.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
params (dict): Dictionary containing:
|
| 46 |
+
- indicator_column (str): The column name for the indicator
|
| 47 |
+
- location (str): The location to plot
|
| 48 |
+
- model (str): The climate model to use
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 52 |
+
|
| 53 |
+
Example:
|
| 54 |
+
>>> plot_func = plot_indicator_evolution_at_location({
|
| 55 |
+
... 'indicator_column': 'mean_temperature',
|
| 56 |
+
... 'location': 'Paris',
|
| 57 |
+
... 'model': 'ALL'
|
| 58 |
+
... })
|
| 59 |
+
>>> fig = plot_func(df)
|
| 60 |
+
"""
|
| 61 |
+
indicator = params["indicator_column"]
|
| 62 |
+
location = params["location"]
|
| 63 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 64 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 65 |
+
|
| 66 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
| 67 |
+
"""Generates the actual plot from the data.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
df (pd.DataFrame): DataFrame containing the data to plot
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Figure: A plotly Figure object showing the indicator evolution
|
| 74 |
+
"""
|
| 75 |
+
fig = go.Figure()
|
| 76 |
+
if df['model'].nunique() != 1:
|
| 77 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
| 78 |
+
|
| 79 |
+
# Transform to list to avoid pandas encoding
|
| 80 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
| 81 |
+
years = df_avg["year"].astype(int).tolist()
|
| 82 |
+
|
| 83 |
+
# Compute the 10-year rolling average
|
| 84 |
+
rolling_window = 10
|
| 85 |
+
sliding_averages = (
|
| 86 |
+
df_avg[indicator]
|
| 87 |
+
.rolling(window=rolling_window, min_periods=rolling_window)
|
| 88 |
+
.mean()
|
| 89 |
+
.astype(float)
|
| 90 |
+
.tolist()
|
| 91 |
+
)
|
| 92 |
+
model_label = "Model Average"
|
| 93 |
+
|
| 94 |
+
# Only add rolling average if we have enough data points
|
| 95 |
+
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
| 96 |
+
# Sliding average dashed line
|
| 97 |
+
fig.add_scatter(
|
| 98 |
+
x=years,
|
| 99 |
+
y=sliding_averages,
|
| 100 |
+
mode="lines",
|
| 101 |
+
name="10 years rolling average",
|
| 102 |
+
line=dict(dash="dash"),
|
| 103 |
+
marker=dict(color="#d62728"),
|
| 104 |
+
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
else:
|
| 108 |
+
df_model = df
|
| 109 |
+
|
| 110 |
+
# Transform to list to avoid pandas encoding
|
| 111 |
+
indicators = df_model[indicator].astype(float).tolist()
|
| 112 |
+
years = df_model["year"].astype(int).tolist()
|
| 113 |
+
|
| 114 |
+
# Compute the 10-year rolling average
|
| 115 |
+
rolling_window = 10
|
| 116 |
+
sliding_averages = (
|
| 117 |
+
df_model[indicator]
|
| 118 |
+
.rolling(window=rolling_window, min_periods=rolling_window)
|
| 119 |
+
.mean()
|
| 120 |
+
.astype(float)
|
| 121 |
+
.tolist()
|
| 122 |
+
)
|
| 123 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
| 124 |
+
|
| 125 |
+
# Only add rolling average if we have enough data points
|
| 126 |
+
if len([x for x in sliding_averages if pd.notna(x)]) > 0:
|
| 127 |
+
# Sliding average dashed line
|
| 128 |
+
fig.add_scatter(
|
| 129 |
+
x=years,
|
| 130 |
+
y=sliding_averages,
|
| 131 |
+
mode="lines",
|
| 132 |
+
name="10 years rolling average",
|
| 133 |
+
line=dict(dash="dash"),
|
| 134 |
+
marker=dict(color="#d62728"),
|
| 135 |
+
hovertemplate=f"10-year average: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# Indicator per year plot
|
| 139 |
+
fig.add_scatter(
|
| 140 |
+
x=years,
|
| 141 |
+
y=indicators,
|
| 142 |
+
name=f"Yearly {indicator_label}",
|
| 143 |
+
mode="lines",
|
| 144 |
+
marker=dict(color="#1f77b4"),
|
| 145 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 146 |
+
)
|
| 147 |
+
fig.update_layout(
|
| 148 |
+
title=f"Plot of {indicator_label} in {location} ({model_label})",
|
| 149 |
+
xaxis_title="Year",
|
| 150 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
| 151 |
+
template="plotly_white",
|
| 152 |
+
)
|
| 153 |
+
return fig
|
| 154 |
+
|
| 155 |
+
return plot_data
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
indicator_evolution_at_location: Plot = {
|
| 159 |
+
"name": "Indicator evolution at location",
|
| 160 |
+
"description": "Plot an evolution of the indicator at a certain location",
|
| 161 |
+
"params": ["indicator_column", "location", "model"],
|
| 162 |
+
"plot_function": plot_indicator_evolution_at_location,
|
| 163 |
+
"sql_query": indicator_per_year_at_location_query,
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def plot_indicator_number_of_days_per_year_at_location(
|
| 168 |
+
params: dict,
|
| 169 |
+
) -> Callable[..., Figure]:
|
| 170 |
+
"""Generates a function to plot the number of days per year for an indicator.
|
| 171 |
+
|
| 172 |
+
This function creates a bar chart showing the frequency of certain climate
|
| 173 |
+
events (like days above a temperature threshold) per year at a specific location.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
params (dict): Dictionary containing:
|
| 177 |
+
- indicator_column (str): The column name for the indicator
|
| 178 |
+
- location (str): The location to plot
|
| 179 |
+
- model (str): The climate model to use
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 183 |
+
"""
|
| 184 |
+
indicator = params["indicator_column"]
|
| 185 |
+
location = params["location"]
|
| 186 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 187 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 188 |
+
|
| 189 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
| 190 |
+
"""Generate the figure thanks to the dataframe
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
| 194 |
+
|
| 195 |
+
Returns:
|
| 196 |
+
Figure: Plotly figure
|
| 197 |
+
"""
|
| 198 |
+
fig = go.Figure()
|
| 199 |
+
if df['model'].nunique() != 1:
|
| 200 |
+
df_avg = df.groupby("year", as_index=False)[indicator].mean()
|
| 201 |
+
|
| 202 |
+
# Transform to list to avoid pandas encoding
|
| 203 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
| 204 |
+
years = df_avg["year"].astype(int).tolist()
|
| 205 |
+
model_label = "Model Average"
|
| 206 |
+
|
| 207 |
+
else:
|
| 208 |
+
df_model = df
|
| 209 |
+
# Transform to list to avoid pandas encoding
|
| 210 |
+
indicators = df_model[indicator].astype(float).tolist()
|
| 211 |
+
years = df_model["year"].astype(int).tolist()
|
| 212 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# Bar plot
|
| 216 |
+
fig.add_trace(
|
| 217 |
+
go.Bar(
|
| 218 |
+
x=years,
|
| 219 |
+
y=indicators,
|
| 220 |
+
width=0.5,
|
| 221 |
+
marker=dict(color="#1f77b4"),
|
| 222 |
+
hovertemplate=f"{indicator_label}: %{{y:.2f}} {unit}<br>Year: %{{x}}<extra></extra>"
|
| 223 |
+
)
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
fig.update_layout(
|
| 227 |
+
title=f"{indicator_label} in {location} ({model_label})",
|
| 228 |
+
xaxis_title="Year",
|
| 229 |
+
yaxis_title=f"{indicator_label} ({unit})",
|
| 230 |
+
yaxis=dict(range=[0, max(indicators)]),
|
| 231 |
+
bargap=0.5,
|
| 232 |
+
template="plotly_white",
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
return fig
|
| 236 |
+
|
| 237 |
+
return plot_data
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
indicator_number_of_days_per_year_at_location: Plot = {
|
| 241 |
+
"name": "Indicator number of days per year at location",
|
| 242 |
+
"description": "Plot a barchart of the number of days per year of a certain indicator at a certain location. It is appropriate for frequency indicator.",
|
| 243 |
+
"params": ["indicator_column", "location", "model"],
|
| 244 |
+
"plot_function": plot_indicator_number_of_days_per_year_at_location,
|
| 245 |
+
"sql_query": indicator_per_year_at_location_query,
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
def plot_distribution_of_indicator_for_given_year(
|
| 250 |
+
params: dict,
|
| 251 |
+
) -> Callable[..., Figure]:
|
| 252 |
+
"""Generates a function to plot the distribution of an indicator for a year.
|
| 253 |
+
|
| 254 |
+
This function creates a histogram showing the distribution of a climate
|
| 255 |
+
indicator across different locations for a specific year.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
params (dict): Dictionary containing:
|
| 259 |
+
- indicator_column (str): The column name for the indicator
|
| 260 |
+
- year (str): The year to plot
|
| 261 |
+
- model (str): The climate model to use
|
| 262 |
+
|
| 263 |
+
Returns:
|
| 264 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 265 |
+
"""
|
| 266 |
+
indicator = params["indicator_column"]
|
| 267 |
+
year = params["year"]
|
| 268 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 269 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 270 |
+
|
| 271 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
| 272 |
+
"""Generate the figure thanks to the dataframe
|
| 273 |
+
|
| 274 |
+
Args:
|
| 275 |
+
df (pd.DataFrame): pandas dataframe with the required data
|
| 276 |
+
|
| 277 |
+
Returns:
|
| 278 |
+
Figure: Plotly figure
|
| 279 |
+
"""
|
| 280 |
+
fig = go.Figure()
|
| 281 |
+
if df['model'].nunique() != 1:
|
| 282 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
| 283 |
+
indicator
|
| 284 |
+
].mean()
|
| 285 |
+
|
| 286 |
+
# Transform to list to avoid pandas encoding
|
| 287 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
| 288 |
+
model_label = "Model Average"
|
| 289 |
+
|
| 290 |
+
else:
|
| 291 |
+
df_model = df
|
| 292 |
+
|
| 293 |
+
# Transform to list to avoid pandas encoding
|
| 294 |
+
indicators = df_model[indicator].astype(float).tolist()
|
| 295 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
fig.add_trace(
|
| 299 |
+
go.Histogram(
|
| 300 |
+
x=indicators,
|
| 301 |
+
opacity=0.8,
|
| 302 |
+
histnorm="percent",
|
| 303 |
+
marker=dict(color="#1f77b4"),
|
| 304 |
+
hovertemplate=f"{indicator_label}: %{{x:.2f}} {unit}<br>Frequency: %{{y:.2f}}%<extra></extra>"
|
| 305 |
+
)
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
fig.update_layout(
|
| 309 |
+
title=f"Distribution of {indicator_label} in {year} ({model_label})",
|
| 310 |
+
xaxis_title=f"{indicator_label} ({unit})",
|
| 311 |
+
yaxis_title="Frequency (%)",
|
| 312 |
+
plot_bgcolor="rgba(0, 0, 0, 0)",
|
| 313 |
+
showlegend=False,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
return fig
|
| 317 |
+
|
| 318 |
+
return plot_data
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
distribution_of_indicator_for_given_year: Plot = {
|
| 322 |
+
"name": "Distribution of an indicator for a given year",
|
| 323 |
+
"description": "Plot an histogram of the distribution for a given year of the values of an indicator",
|
| 324 |
+
"params": ["indicator_column", "model", "year"],
|
| 325 |
+
"plot_function": plot_distribution_of_indicator_for_given_year,
|
| 326 |
+
"sql_query": indicator_for_given_year_query,
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def plot_map_of_france_of_indicator_for_given_year(
|
| 331 |
+
params: dict,
|
| 332 |
+
) -> Callable[..., Figure]:
|
| 333 |
+
"""Generates a function to plot a map of France for an indicator.
|
| 334 |
+
|
| 335 |
+
This function creates a choropleth map of France showing the spatial
|
| 336 |
+
distribution of a climate indicator for a specific year.
|
| 337 |
+
|
| 338 |
+
Args:
|
| 339 |
+
params (dict): Dictionary containing:
|
| 340 |
+
- indicator_column (str): The column name for the indicator
|
| 341 |
+
- year (str): The year to plot
|
| 342 |
+
- model (str): The climate model to use
|
| 343 |
+
|
| 344 |
+
Returns:
|
| 345 |
+
Callable[..., Figure]: A function that takes a DataFrame and returns a plotly Figure
|
| 346 |
+
"""
|
| 347 |
+
indicator = params["indicator_column"]
|
| 348 |
+
year = params["year"]
|
| 349 |
+
indicator_label = " ".join([word.capitalize() for word in indicator.split("_")])
|
| 350 |
+
unit = INDICATOR_TO_UNIT.get(indicator, "")
|
| 351 |
+
|
| 352 |
+
def plot_data(df: pd.DataFrame) -> Figure:
|
| 353 |
+
fig = go.Figure()
|
| 354 |
+
if df['model'].nunique() != 1:
|
| 355 |
+
df_avg = df.groupby(["latitude", "longitude"], as_index=False)[
|
| 356 |
+
indicator
|
| 357 |
+
].mean()
|
| 358 |
+
|
| 359 |
+
indicators = df_avg[indicator].astype(float).tolist()
|
| 360 |
+
latitudes = df_avg["latitude"].astype(float).tolist()
|
| 361 |
+
longitudes = df_avg["longitude"].astype(float).tolist()
|
| 362 |
+
model_label = "Model Average"
|
| 363 |
+
|
| 364 |
+
else:
|
| 365 |
+
df_model = df
|
| 366 |
+
|
| 367 |
+
# Transform to list to avoid pandas encoding
|
| 368 |
+
indicators = df_model[indicator].astype(float).tolist()
|
| 369 |
+
latitudes = df_model["latitude"].astype(float).tolist()
|
| 370 |
+
longitudes = df_model["longitude"].astype(float).tolist()
|
| 371 |
+
model_label = f"Model : {df['model'].unique()[0]}"
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
fig.add_trace(
|
| 375 |
+
go.Scattermapbox(
|
| 376 |
+
lat=latitudes,
|
| 377 |
+
lon=longitudes,
|
| 378 |
+
mode="markers",
|
| 379 |
+
marker=dict(
|
| 380 |
+
size=10,
|
| 381 |
+
color=indicators, # Color mapped to values
|
| 382 |
+
colorscale="Turbo", # Color scale (can be 'Plasma', 'Jet', etc.)
|
| 383 |
+
cmin=min(indicators), # Minimum color range
|
| 384 |
+
cmax=max(indicators), # Maximum color range
|
| 385 |
+
showscale=True, # Show colorbar
|
| 386 |
+
),
|
| 387 |
+
text=[f"{indicator_label}: {value:.2f} {unit}" for value in indicators], # Add hover text showing the indicator value
|
| 388 |
+
hoverinfo="text" # Only show the custom text on hover
|
| 389 |
+
)
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
fig.update_layout(
|
| 393 |
+
mapbox_style="open-street-map", # Use OpenStreetMap
|
| 394 |
+
mapbox_zoom=3,
|
| 395 |
+
mapbox_center={"lat": 46.6, "lon": 2.0},
|
| 396 |
+
coloraxis_colorbar=dict(title=f"{indicator_label} ({unit})"), # Add legend
|
| 397 |
+
title=f"{indicator_label} in {year} in France ({model_label}) " # Title
|
| 398 |
+
)
|
| 399 |
+
return fig
|
| 400 |
+
|
| 401 |
+
return plot_data
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
map_of_france_of_indicator_for_given_year: Plot = {
|
| 405 |
+
"name": "Map of France of an indicator for a given year",
|
| 406 |
+
"description": "Heatmap on the map of France of the values of an in indicator for a given year",
|
| 407 |
+
"params": ["indicator_column", "year", "model"],
|
| 408 |
+
"plot_function": plot_map_of_france_of_indicator_for_given_year,
|
| 409 |
+
"sql_query": indicator_for_given_year_query,
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
PLOTS = [
|
| 414 |
+
indicator_evolution_at_location,
|
| 415 |
+
indicator_number_of_days_per_year_at_location,
|
| 416 |
+
distribution_of_indicator_for_given_year,
|
| 417 |
+
map_of_france_of_indicator_for_given_year,
|
| 418 |
+
]
|
climateqa/engine/talk_to_data/sql_query.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 3 |
+
from typing import TypedDict
|
| 4 |
+
import duckdb
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
async def execute_sql_query(sql_query: str) -> pd.DataFrame:
|
| 8 |
+
"""Executes a SQL query on the DRIAS database and returns the results.
|
| 9 |
+
|
| 10 |
+
This function connects to the DuckDB database containing DRIAS climate data
|
| 11 |
+
and executes the provided SQL query. It handles the database connection and
|
| 12 |
+
returns the results as a pandas DataFrame.
|
| 13 |
+
|
| 14 |
+
Args:
|
| 15 |
+
sql_query (str): The SQL query to execute
|
| 16 |
+
|
| 17 |
+
Returns:
|
| 18 |
+
pd.DataFrame: A DataFrame containing the query results
|
| 19 |
+
|
| 20 |
+
Raises:
|
| 21 |
+
duckdb.Error: If there is an error executing the SQL query
|
| 22 |
+
"""
|
| 23 |
+
def _execute_query():
|
| 24 |
+
# Execute the query
|
| 25 |
+
con = duckdb.connect()
|
| 26 |
+
results = con.sql(sql_query).fetchdf()
|
| 27 |
+
# return fetched data
|
| 28 |
+
return results
|
| 29 |
+
|
| 30 |
+
# Run the query in a thread pool to avoid blocking
|
| 31 |
+
loop = asyncio.get_event_loop()
|
| 32 |
+
with ThreadPoolExecutor() as executor:
|
| 33 |
+
return await loop.run_in_executor(executor, _execute_query)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class IndicatorPerYearAtLocationQueryParams(TypedDict, total=False):
|
| 37 |
+
"""Parameters for querying an indicator's values over time at a location.
|
| 38 |
+
|
| 39 |
+
This class defines the parameters needed to query climate indicator data
|
| 40 |
+
for a specific location over multiple years.
|
| 41 |
+
|
| 42 |
+
Attributes:
|
| 43 |
+
indicator_column (str): The column name for the climate indicator
|
| 44 |
+
latitude (str): The latitude coordinate of the location
|
| 45 |
+
longitude (str): The longitude coordinate of the location
|
| 46 |
+
model (str): The climate model to use (optional)
|
| 47 |
+
"""
|
| 48 |
+
indicator_column: str
|
| 49 |
+
latitude: str
|
| 50 |
+
longitude: str
|
| 51 |
+
model: str
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def indicator_per_year_at_location_query(
|
| 55 |
+
table: str, params: IndicatorPerYearAtLocationQueryParams
|
| 56 |
+
) -> str:
|
| 57 |
+
"""SQL Query to get the evolution of an indicator per year at a certain location
|
| 58 |
+
|
| 59 |
+
Args:
|
| 60 |
+
table (str): sql table of the indicator
|
| 61 |
+
params (IndicatorPerYearAtLocationQueryParams) : dictionary with the required params for the query
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
str: the sql query
|
| 65 |
+
"""
|
| 66 |
+
indicator_column = params.get("indicator_column")
|
| 67 |
+
latitude = params.get("latitude")
|
| 68 |
+
longitude = params.get("longitude")
|
| 69 |
+
|
| 70 |
+
if indicator_column is None or latitude is None or longitude is None: # If one parameter is missing, returns an empty query
|
| 71 |
+
return ""
|
| 72 |
+
|
| 73 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
| 74 |
+
|
| 75 |
+
sql_query = f"SELECT year, {indicator_column}, model\nFROM {table}\nWHERE latitude = {latitude} \nAnd longitude = {longitude} \nOrder by Year"
|
| 76 |
+
|
| 77 |
+
return sql_query
|
| 78 |
+
|
| 79 |
+
class IndicatorForGivenYearQueryParams(TypedDict, total=False):
|
| 80 |
+
"""Parameters for querying an indicator's values across locations for a year.
|
| 81 |
+
|
| 82 |
+
This class defines the parameters needed to query climate indicator data
|
| 83 |
+
across different locations for a specific year.
|
| 84 |
+
|
| 85 |
+
Attributes:
|
| 86 |
+
indicator_column (str): The column name for the climate indicator
|
| 87 |
+
year (str): The year to query
|
| 88 |
+
model (str): The climate model to use (optional)
|
| 89 |
+
"""
|
| 90 |
+
indicator_column: str
|
| 91 |
+
year: str
|
| 92 |
+
model: str
|
| 93 |
+
|
| 94 |
+
def indicator_for_given_year_query(
|
| 95 |
+
table:str, params: IndicatorForGivenYearQueryParams
|
| 96 |
+
) -> str:
|
| 97 |
+
"""SQL Query to get the values of an indicator with their latitudes, longitudes and models for a given year
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
table (str): sql table of the indicator
|
| 101 |
+
params (IndicatorForGivenYearQueryParams): dictionarry with the required params for the query
|
| 102 |
+
|
| 103 |
+
Returns:
|
| 104 |
+
str: the sql query
|
| 105 |
+
"""
|
| 106 |
+
indicator_column = params.get("indicator_column")
|
| 107 |
+
year = params.get('year')
|
| 108 |
+
if year is None or indicator_column is None: # If one parameter is missing, returns an empty query
|
| 109 |
+
return ""
|
| 110 |
+
|
| 111 |
+
table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"
|
| 112 |
+
|
| 113 |
+
sql_query = f"Select {indicator_column}, latitude, longitude, model\nFrom {table}\nWhere year = {year}"
|
| 114 |
+
return sql_query
|
climateqa/engine/talk_to_data/talk_to_drias.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from typing import Any, Callable, TypedDict, Optional
|
| 4 |
+
from numpy import sort
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import asyncio
|
| 7 |
+
from plotly.graph_objects import Figure
|
| 8 |
+
from climateqa.engine.llm import get_llm
|
| 9 |
+
from climateqa.engine.talk_to_data import sql_query
|
| 10 |
+
from climateqa.engine.talk_to_data.config import INDICATOR_COLUMNS_PER_TABLE
|
| 11 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
| 12 |
+
from climateqa.engine.talk_to_data.sql_query import execute_sql_query
|
| 13 |
+
from climateqa.engine.talk_to_data.utils import (
|
| 14 |
+
detect_relevant_plots,
|
| 15 |
+
detect_year_with_openai,
|
| 16 |
+
loc2coords,
|
| 17 |
+
detect_location_with_openai,
|
| 18 |
+
nearestNeighbourSQL,
|
| 19 |
+
detect_relevant_tables,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
ROOT_PATH = os.path.dirname(os.path.dirname(os.getcwd()))
|
| 24 |
+
|
| 25 |
+
class TableState(TypedDict):
    """State of a single table in the DRIAS workflow.

    Tracks everything produced while processing one table: the query
    parameters, the generated SQL, the fetched data and the plotting
    callable.

    Attributes:
        table_name (str): The name of the table in the database.
        params (dict[str, Any]): Parameters used for querying the table.
        sql_query (Optional[str]): The SQL query used to fetch data, if built.
        dataframe (Optional[pd.DataFrame]): The resulting data, if fetched.
        figure (Optional[Callable[..., Figure]]): Function to generate the
            visualization, if built.
        status (str): The current status of the table processing
            ('OK' or 'ERROR').
    """
    table_name: str
    params: dict[str, Any]
    sql_query: Optional[str]
    # Optional[...] already allows None; the original annotation
    # Optional[pd.DataFrame | None] was redundant.
    dataframe: Optional[pd.DataFrame]
    figure: Optional[Callable[..., Figure]]
    status: str
|
| 45 |
+
|
| 46 |
+
class PlotState(TypedDict):
    """State of one plot in the DRIAS workflow.

    Attributes:
        plot_name (str): The name of the plot.
        tables (list[str]): List of tables used in the plot.
        table_states (dict[str, TableState]): States of the tables used in
            the plot, keyed by table name.
    """
    plot_name: str
    tables: list[str]
    table_states: dict[str, TableState]
|
| 60 |
+
|
| 61 |
+
class State(TypedDict):
    """Top-level state of the DRIAS workflow.

    Attributes:
        user_input (str): The original user question.
        plots (list[str]): Names of the plots selected as relevant.
        plot_states (dict[str, PlotState]): Per-plot processing state,
            keyed by plot name.
        error (Optional[str]): Human-readable error message; empty when
            the workflow succeeded.
    """
    user_input: str
    plots: list[str]
    plot_states: dict[str, PlotState]
    error: Optional[str]
|
| 66 |
+
|
| 67 |
+
async def find_relevant_plots(state: State, llm) -> list[str]:
    """Select the plot names relevant to the user's question.

    Thin logging wrapper around ``detect_relevant_plots``.

    Args:
        state (State): Workflow state; only ``user_input`` is read.
        llm: Language model used for the detection.

    Returns:
        list[str]: Names of the relevant plots.
    """
    print("---- Find relevant plots ----")
    return await detect_relevant_plots(state["user_input"], llm)
|
| 71 |
+
|
| 72 |
+
async def find_relevant_tables_per_plot(state: State, plot: Plot, llm) -> list[str]:
    """Select the tables relevant to a given plot for the user's question.

    Thin logging wrapper around ``detect_relevant_tables``.

    Args:
        state (State): Workflow state; only ``user_input`` is read.
        plot (Plot): Plot definition to find tables for.
        llm: Language model used for the detection.

    Returns:
        list[str]: Names of the relevant tables.
    """
    print(f"---- Find relevant tables for {plot['name']} ----")
    return await detect_relevant_tables(state["user_input"], plot, llm)
|
| 76 |
+
|
| 77 |
+
async def find_param(state: State, param_name:str, table: str) -> dict[str, Any] | None:
    """Resolve one query parameter from the user input.

    Dispatches to the dedicated extraction helper for the given
    parameter name.

    Args:
        state (State): Workflow state; only ``user_input`` is read.
        param_name (str): Name of the desired parameter
            ('location' or 'year').
        table (str): Table name, used when snapping a location to the
            table's coordinate grid.

    Returns:
        dict[str, Any] | None: The resolved parameter(s), or None when
        the parameter name is not supported.
    """
    user_input = state["user_input"]
    if param_name == "location":
        return await find_location(user_input, table)
    if param_name == "year":
        return {"year": await find_year(user_input)}
    return None
|
| 95 |
+
|
| 96 |
+
class Location(TypedDict):
    """A detected location and its snapped grid coordinates.

    Attributes:
        location (str): The detected location name ("" when none found).
        latitude (Optional[str]): Snapped latitude from the table grid.
        longitude (Optional[str]): Snapped longitude from the table grid.
    """
    location: str
    # NOTE(review): find_location builds this dict with only the 'location'
    # key when no location is detected, so latitude/longitude may be absent
    # entirely (not just None) — consider total=False / NotRequired; confirm
    # against callers.
    latitude: Optional[str]
    longitude: Optional[str]
|
| 100 |
+
|
| 101 |
+
async def find_location(user_input: str, table: str) -> Location:
    """Extract a location from the question and snap it to the table grid.

    Uses the LLM to detect a location name, geocodes it, then looks up
    the closest (latitude, longitude) pair actually present in *table*.

    Args:
        user_input (str): The user's question.
        table (str): Table whose coordinate grid is searched.

    Returns:
        Location: The detected location; the latitude/longitude keys are
        only present when a location was detected.
    """
    print(f"---- Find location in table {table} ----")
    detected = await detect_location_with_openai(user_input)
    result: Location = {"location": detected}
    if detected:
        latitude, longitude = nearestNeighbourSQL(loc2coords(detected), table)
        result["latitude"] = latitude
        result["longitude"] = longitude
    return result
|
| 113 |
+
|
| 114 |
+
async def find_year(user_input: str) -> str:
    """Extract year information from user input using the LLM.

    The extracted year is used to filter data in subsequent queries.

    Args:
        user_input (str): The user's query text.

    Returns:
        str: The extracted year, or an empty string if no year is found.
    """
    print(f"---- Find year ---")
    return await detect_year_with_openai(user_input)
|
| 129 |
+
|
| 130 |
+
def find_indicator_column(table: str) -> str:
    """Look up the indicator column name for a database table.

    Uses the predefined mapping INDICATOR_COLUMNS_PER_TABLE.

    Args:
        table (str): Name of the table in the database.

    Returns:
        str: Name of the indicator column for the specified table.

    Raises:
        KeyError: If *table* has no entry in the mapping.
    """
    print(f"---- Find indicator column in table {table} ----")
    return INDICATOR_COLUMNS_PER_TABLE[table]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
async def process_table(
    table: str,
    params: dict[str, Any],
    plot: Plot,
) -> TableState:
    """Build and run the SQL query for one table and prepare its figure.

    Retrieves the SQL query for the given table from the plot definition,
    executes it, and builds the visualization callable from the results.

    Args:
        table (str): The name of the table to process.
        params (dict[str, Any]): Query parameters (copied, not mutated).
        plot (Plot): Plot definition providing the SQL builder and the
            plotting function.

    Returns:
        TableState: The resulting state; ``status`` is 'ERROR' when no
        SQL query could be built for this table.
    """
    result: TableState = {
        "table_name": table,
        "params": params.copy(),
        "status": "OK",
        "dataframe": None,
        "sql_query": None,
        "figure": None,
    }

    # Each table has its own indicator column; resolve it before building SQL.
    result["params"]["indicator_column"] = find_indicator_column(table)
    query = plot["sql_query"](table, result["params"])

    if query == "":
        result["status"] = "ERROR"
        return result

    result["sql_query"] = query
    result["dataframe"] = await execute_sql_query(query)
    result["figure"] = plot["plot_function"](result["params"])
    return result
|
| 189 |
+
|
| 190 |
+
async def drias_workflow(user_input: str) -> State:
    """Performs the complete workflow of Talk To Drias : from user input to sql queries, dataframes and figures generated

    Args:
        user_input (str): initial user input

    Returns:
        State: Final state with all the results
    """
    # Fresh state; 'error' stays "" unless a failure is detected below.
    state: State = {
        'user_input': user_input,
        'plots': [],
        'plot_states': {},
        'error': ''
    }

    llm = get_llm(provider="openai")

    # Step 1: pick the plots that can answer the question.
    plots = await find_relevant_plots(state, llm)

    state['plots'] = plots

    if len(state['plots']) < 1:
        state['error'] = 'There is no plot to answer to the question'
        return state

    # Flags used for the final error diagnosis.
    have_relevant_table = False
    have_sql_query = False
    have_dataframe = False

    for plot_name in state['plots']:

        plot = next((p for p in PLOTS if p['name'] == plot_name), None) # Find the associated plot object
        if plot is None:
            continue

        plot_state: PlotState = {
            'plot_name': plot_name,
            'tables': [],
            'table_states': {}
        }

        # NOTE(review): redundant — 'plot_name' was just set in the literal above.
        plot_state['plot_name'] = plot_name

        # Step 2: pick the tables relevant to this plot.
        relevant_tables = await find_relevant_tables_per_plot(state, plot, llm)

        if len(relevant_tables) > 0 :
            have_relevant_table = True

        plot_state['tables'] = relevant_tables

        # Step 3: resolve the plot's parameters (location, year, ...) once,
        # using the first relevant table for coordinate snapping.
        # NOTE(review): relevant_tables[0] raises IndexError when the LLM
        # returned an empty list — confirm whether that can happen upstream.
        params = {}
        for param_name in plot['params']:
            param = await find_param(state, param_name, relevant_tables[0])
            if param:
                params.update(param)

        # Step 4: process up to 3 tables concurrently.
        tasks = [process_table(table, params, plot) for table in plot_state['tables'][:3]]
        results = await asyncio.gather(*tasks)

        # Store results back in plot_state
        # NOTE(review): these two flags are reset on every loop iteration
        # (unlike have_relevant_table), so the final diagnosis reflects only
        # the LAST processed plot — confirm whether this is intended.
        have_dataframe = False
        have_sql_query = False
        for table_state in results:
            if table_state['sql_query']:
                have_sql_query = True
            if table_state['dataframe'] is not None and len(table_state['dataframe']) > 0:
                have_dataframe = True
            plot_state['table_states'][table_state['table_name']] = table_state

        state['plot_states'][plot_name] = plot_state

    # Final diagnosis: report the first failing stage, if any.
    if not have_relevant_table:
        state['error'] = "There is no relevant table in our database to answer your question"
    elif not have_sql_query:
        state['error'] = "There is no relevant sql query on our database that can help to answer your question"
    elif not have_dataframe:
        state['error'] = "There is no data in our table that can answer to your question"

    return state
|
| 270 |
+
|
| 271 |
+
# def make_write_query_node():
|
| 272 |
+
|
| 273 |
+
# def write_query(state):
|
| 274 |
+
# print("---- Write query ----")
|
| 275 |
+
# for table in state["tables"]:
|
| 276 |
+
# sql_query = QUERIES[state[table]['query_type']](
|
| 277 |
+
# table=table,
|
| 278 |
+
# indicator_column=state[table]["columns"],
|
| 279 |
+
# longitude=state[table]["longitude"],
|
| 280 |
+
# latitude=state[table]["latitude"],
|
| 281 |
+
# )
|
| 282 |
+
# state[table].update({"sql_query": sql_query})
|
| 283 |
+
|
| 284 |
+
# return state
|
| 285 |
+
|
| 286 |
+
# return write_query
|
| 287 |
+
|
| 288 |
+
# def make_fetch_data_node(db_path):
|
| 289 |
+
|
| 290 |
+
# def fetch_data(state):
|
| 291 |
+
# print("---- Fetch data ----")
|
| 292 |
+
# for table in state["tables"]:
|
| 293 |
+
# results = execute_sql_query(db_path, state[table]['sql_query'])
|
| 294 |
+
# state[table].update(results)
|
| 295 |
+
|
| 296 |
+
# return state
|
| 297 |
+
|
| 298 |
+
# return fetch_data
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
## V2
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# def make_fetch_data_node(db_path: str, llm):
|
| 306 |
+
# def fetch_data(state):
|
| 307 |
+
# print("---- Fetch data ----")
|
| 308 |
+
# db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
|
| 309 |
+
# output = {}
|
| 310 |
+
# sql_query = write_sql_query(state["query"], db, state["tables"], llm)
|
| 311 |
+
# # TO DO : Add query checker
|
| 312 |
+
# print(f"SQL query : {sql_query}")
|
| 313 |
+
# output["sql_query"] = sql_query
|
| 314 |
+
# output.update(fetch_data_from_sql_query(db_path, sql_query))
|
| 315 |
+
# return output
|
| 316 |
+
|
| 317 |
+
# return fetch_data
|
climateqa/engine/talk_to_data/utils.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Annotated, TypedDict
|
| 3 |
+
import duckdb
|
| 4 |
+
from geopy.geocoders import Nominatim
|
| 5 |
+
import ast
|
| 6 |
+
from climateqa.engine.llm import get_llm
|
| 7 |
+
from climateqa.engine.talk_to_data.config import DRIAS_TABLES
|
| 8 |
+
from climateqa.engine.talk_to_data.plot import PLOTS, Plot
|
| 9 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def detect_location_with_openai(sentence):
    """Detect the first location mentioned in a sentence using the LLM.

    Asks the model for a Python-list answer, parses it, and returns the
    first entry.

    Args:
        sentence (str): The sentence to analyze.

    Returns:
        str: The first detected location, or "" when none is found.

    Raises:
        ValueError, SyntaxError: If the model's answer is not a valid
            Python literal.
    """
    llm = get_llm()

    prompt = f"""
    Extract all locations (cities, countries, states, or geographical areas) mentioned in the following sentence.
    Return the result as a Python list. If no locations are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    response = await llm.ainvoke(prompt)
    # Bug fix: str.strip("```python\n") strips any of those *characters*
    # from both ends, not the fence as a prefix/suffix. Remove a markdown
    # code fence explicitly instead.
    answer = response.content.strip()
    if answer.startswith("```"):
        answer = answer.removeprefix("```python").removeprefix("```")
        answer = answer.removesuffix("```").strip()
    location_list = ast.literal_eval(answer)
    if location_list:
        return location_list[0]
    else:
        return ""
|
| 31 |
+
|
| 32 |
+
class ArrayOutput(TypedDict):
    """Structured-output schema: the LLM must answer with a Python array.

    Used with ``llm.with_structured_output`` so the model's reply is a
    single string field containing a literal Python list.

    Attributes:
        array (str): A syntactically valid Python array string.
    """
    array: Annotated[str, "Syntactically valid python array."]
|
| 42 |
+
|
| 43 |
+
async def detect_year_with_openai(sentence: str) -> str:
    """Detect the first year mentioned in a sentence using the LLM.

    Uses structured output (ArrayOutput) so the model answers with a
    Python-list string, which is then parsed.

    Args:
        sentence (str): The sentence to analyze.

    Returns:
        str: The first year found, or "" when none is mentioned.
    """
    llm = get_llm()

    prompt = """
    Extract all years mentioned in the following sentence.
    Return the result as a Python list. If no year are mentioned, return an empty list.

    Sentence: "{sentence}"
    """

    prompt = ChatPromptTemplate.from_template(prompt)
    structured_llm = llm.with_structured_output(ArrayOutput)
    chain = prompt | structured_llm
    response: ArrayOutput = await chain.ainvoke({"sentence": sentence})
    # Security fix: literal_eval only accepts Python literals; unlike
    # eval() it cannot execute arbitrary code coming back from the model.
    years_list = ast.literal_eval(response['array'])
    if len(years_list) > 0:
        return years_list[0]
    else:
        return ""
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def detectTable(sql_query: str) -> list[str]:
    """List the table names referenced in FROM clauses of a SQL query.

    Matches optionally quoted and optionally schema-qualified identifiers
    after each (case-insensitive) FROM keyword.

    Args:
        sql_query (str): The SQL query to analyze.

    Returns:
        list[str]: Table names found, in order of appearance.

    Example:
        >>> detectTable("SELECT * FROM temperature_data WHERE year > 2000")
        ['temperature_data']
    """
    from_target = r'(?i)\bFROM\s+((?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+)(?:\.(?:`[^`]+`|"[^"]+"|\'[^\']+\'|\w+))*)'
    return re.findall(from_target, sql_query)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def loc2coords(location: str) -> tuple[float, float]:
    """Geocode a location name into (latitude, longitude).

    Uses the Nominatim geocoding service (network access required).

    Args:
        location (str): The name of the location to geocode.

    Returns:
        tuple[float, float]: (latitude, longitude) of the best match.

    Raises:
        AttributeError: If Nominatim returns no match (result is None).
    """
    geocoder = Nominatim(user_agent="city_to_latlong")
    match = geocoder.geocode(location)
    return (match.latitude, match.longitude)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def coords2loc(coords: tuple[float, float]) -> str:
    """Reverse-geocode (latitude, longitude) into an address string.

    Uses the Nominatim reverse geocoding service (network access required).

    Args:
        coords (tuple[float, float]): (latitude, longitude) pair.

    Returns:
        str: The resolved address, or "Unknown Location" on any failure.

    Example:
        >>> coords2loc((48.8566, 2.3522))
        'Paris, France'
    """
    geocoder = Nominatim(user_agent="coords_to_city")
    try:
        match = geocoder.reverse(coords)
        return match.address
    except Exception as e:
        print(f"Error: {e}")
        return "Unknown Location"
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def nearestNeighbourSQL(location: tuple, table: str) -> tuple[str, str]:
    """Find the grid point of *table* closest to the given coordinates.

    Searches a +/-0.3 degree bounding box around the requested point in a
    HF-hosted parquet file and returns the closest (latitude, longitude)
    pair present in the table.

    Args:
        location (tuple): (latitude, longitude) of the requested point.
        table (str): Table name (lowercased to build the parquet path).

    Returns:
        tuple[str, str]: The nearest (latitude, longitude) found, or
        ("", "") when no grid point lies inside the bounding box.
    """
    long = round(location[1], 3)
    lat = round(location[0], 3)

    table = f"'hf://datasets/timeki/drias_db/{table.lower()}.parquet'"

    # Bug fix: order by squared euclidean distance so the returned row is
    # the actual nearest neighbour; previously an arbitrary row from the
    # bounding box was returned. All interpolated values are numeric
    # (rounded floats), so the f-string query is injection-safe here.
    results = duckdb.sql(
        f"SELECT latitude, longitude FROM {table} "
        f"WHERE latitude BETWEEN {lat - 0.3} AND {lat + 0.3} "
        f"AND longitude BETWEEN {long - 0.3} AND {long + 0.3} "
        f"ORDER BY (latitude - {lat}) * (latitude - {lat}) "
        f"+ (longitude - {long}) * (longitude - {long}) "
        f"LIMIT 1"
    ).fetchdf()

    if len(results) == 0:
        return "", ""
    return results['latitude'].iloc[0], results['longitude'].iloc[0]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
async def detect_relevant_tables(user_question: str, plot: Plot, llm) -> list[str]:
    """Identifies relevant tables for a plot based on user input.

    Uses an LLM to analyze the user's question and the plot description
    to determine which DRIAS tables are most relevant for generating the
    requested visualization.

    Args:
        user_question (str): The user's question about climate data.
        plot (Plot): The plot configuration object.
        llm: The language model instance to use for analysis.

    Returns:
        list[str]: A list of table names relevant for the plot.

    Example:
        >>> detect_relevant_tables(
        ...     "What will the temperature be like in Paris?",
        ...     indicator_evolution_at_location,
        ...     llm
        ... )
        ['mean_annual_temperature', 'mean_summer_temperature']
    """
    # Get all table names
    table_names_list = DRIAS_TABLES

    prompt = (
        f"You are helping to build a plot following this description : {plot['description']}."
        f"You are given a list of tables and a user question."
        f"Based on the description of the plot, which table are appropriate for that kind of plot."
        f"Write the 3 most relevant tables to use. Answer only a python list of table name."
        f"### List of tables : {table_names_list}"
        f"### User question : {user_question}"
        f"### List of table name : "
    )

    answer = (await llm.ainvoke(prompt)).content.strip()
    # Bug fix: str.strip("```python\n") strips a *character set* from both
    # ends, not the fence string; remove a markdown code fence explicitly.
    if answer.startswith("```"):
        answer = answer.removeprefix("```python").removeprefix("```")
        answer = answer.removesuffix("```").strip()
    return ast.literal_eval(answer)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def replace_coordonates(coords, query, coords_tables):
    """Substitute placeholder coordinates in a query, one table at a time.

    The i-th occurrence of the original latitude (and longitude) in
    *query* is replaced by the i-th pair from *coords_tables*.

    Args:
        coords: Original (latitude, longitude) pair appearing in the query.
        query (str): Query text containing the original coordinates.
        coords_tables: Per-table replacement (latitude, longitude) pairs.

    Returns:
        str: The query with coordinates substituted.
    """
    lat_token = str(coords[0])
    lon_token = str(coords[1])
    occurrences = query.count(lat_token)

    for index in range(occurrences):
        query = query.replace(lat_token, str(coords_tables[index][0]), 1)
        query = query.replace(lon_token, str(coords_tables[index][1]), 1)
    return query
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
async def detect_relevant_plots(user_question: str, llm):
    """Select the plots relevant to a user question using the LLM.

    Builds a catalog of all plot names/descriptions, asks the model which
    ones answer the question, and parses its Python-list reply.

    Args:
        user_question (str): The user's question about climate data.
        llm: The language model instance to use for analysis.

    Returns:
        list[str]: Names of the relevant plots.
    """
    plots_description = ""
    for plot in PLOTS:
        plots_description += "Name: " + plot["name"]
        plots_description += " - Description: " + plot["description"] + "\n"

    prompt = (
        f"You are helping to answer a quesiton with insightful visualizations."
        f"You are given an user question and a list of plots with their name and description."
        f"Based on the descriptions of the plots, which plot is appropriate to answer to this question."
        f"Write the most relevant tables to use. Answer only a python list of plot name."
        f"### Descriptions of the plots : {plots_description}"
        f"### User question : {user_question}"
        f"### Name of the plot : "
    )

    answer = (await llm.ainvoke(prompt)).content.strip()
    # Bug fix: str.strip("```python\n") strips a *character set* from both
    # ends, not the fence string; remove a markdown code fence explicitly.
    if answer.startswith("```"):
        answer = answer.removeprefix("```python").removeprefix("```")
        answer = answer.removesuffix("```").strip()
    return ast.literal_eval(answer)
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# Next Version
|
| 232 |
+
# class QueryOutput(TypedDict):
|
| 233 |
+
# """Generated SQL query."""
|
| 234 |
+
|
| 235 |
+
# query: Annotated[str, ..., "Syntactically valid SQL query."]
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# class PlotlyCodeOutput(TypedDict):
|
| 239 |
+
# """Generated Plotly code"""
|
| 240 |
+
|
| 241 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
| 242 |
+
# def write_sql_query(user_input: str, db: SQLDatabase, relevant_tables: list[str], llm):
|
| 243 |
+
# """Generate SQL query to fetch information."""
|
| 244 |
+
# prompt_params = {
|
| 245 |
+
# "dialect": db.dialect,
|
| 246 |
+
# "table_info": db.get_table_info(),
|
| 247 |
+
# "input": user_input,
|
| 248 |
+
# "relevant_tables": relevant_tables,
|
| 249 |
+
# "model": "ALADIN63_CNRM-CM5",
|
| 250 |
+
# }
|
| 251 |
+
|
| 252 |
+
# prompt = ChatPromptTemplate.from_template(query_prompt_template)
|
| 253 |
+
# structured_llm = llm.with_structured_output(QueryOutput)
|
| 254 |
+
# chain = prompt | structured_llm
|
| 255 |
+
# result = chain.invoke(prompt_params)
|
| 256 |
+
|
| 257 |
+
# return result["query"]
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# def fetch_data_from_sql_query(db: str, sql_query: str):
|
| 261 |
+
# conn = sqlite3.connect(db)
|
| 262 |
+
# cursor = conn.cursor()
|
| 263 |
+
# cursor.execute(sql_query)
|
| 264 |
+
# column_names = [desc[0] for desc in cursor.description]
|
| 265 |
+
# values = cursor.fetchall()
|
| 266 |
+
# return {"column_names": column_names, "data": values}
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# def generate_chart_code(user_input: str, sql_query: list[str], llm):
|
| 270 |
+
# """ "Generate plotly python code for the chart based on the sql query and the user question"""
|
| 271 |
+
|
| 272 |
+
# class PlotlyCodeOutput(TypedDict):
|
| 273 |
+
# """Generated Plotly code"""
|
| 274 |
+
|
| 275 |
+
# code: Annotated[str, ..., "Synatically valid Plotly python code."]
|
| 276 |
+
|
| 277 |
+
# prompt = ChatPromptTemplate.from_template(plot_prompt_template)
|
| 278 |
+
# structured_llm = llm.with_structured_output(PlotlyCodeOutput)
|
| 279 |
+
# chain = prompt | structured_llm
|
| 280 |
+
# result = chain.invoke({"input": user_input, "sql_query": sql_query})
|
| 281 |
+
# return result["code"]
|
climateqa/engine/talk_to_data/vanna_class.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from vanna.base import VannaBase
|
| 2 |
+
from pinecone import Pinecone
|
| 3 |
+
from climateqa.engine.embeddings import get_embeddings_function
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import hashlib
|
| 6 |
+
|
| 7 |
+
class MyCustomVectorDB(VannaBase):
|
| 8 |
+
|
| 9 |
+
"""
|
| 10 |
+
VectorDB class for storing and retrieving vectors from Pinecone.
|
| 11 |
+
|
| 12 |
+
args :
|
| 13 |
+
config (dict) : Configuration dictionary containing the Pinecone API key and the index name :
|
| 14 |
+
- pc_api_key (str) : Pinecone API key
|
| 15 |
+
- index_name (str) : Pinecone index name
|
| 16 |
+
- top_k (int) : Number of top results to return (default = 2)
|
| 17 |
+
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self, config):
    """Initialize the Pinecone-backed vector store.

    Args:
        config (dict): Configuration dictionary. Must contain
            'pc_api_key' and 'index_name'; may contain 'top_k'
            (default 2).

    Raises:
        Exception: If the Pinecone API key or the index name is missing.
    """
    super().__init__(config=config)
    # Bug fix: dict.get never raises, so the original try/except could not
    # fire and missing keys silently became None; validate explicitly.
    self.api_key = config.get('pc_api_key')
    self.index_name = config.get('index_name')
    if self.api_key is None or self.index_name is None:
        raise Exception("Please provide the Pinecone API key and the index name")

    self.pc = Pinecone(api_key=self.api_key)
    self.index = self.pc.Index(self.index_name)
    self.top_k = config.get('top_k', 2)
    self.embeddings = get_embeddings_function()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def check_embedding(self, id, namespace):
    """Return True if a vector with this id already exists in *namespace*."""
    fetched = self.index.fetch(ids=[id], namespace=namespace)
    return fetched['vectors'] != {}
|
| 39 |
+
|
| 40 |
+
def generate_hash_id(self, data: str) -> str:
    """Generate a unique, deterministic hash ID for the given data.

    Args:
        data (str): The input data to hash (e.g., a concatenated string
            of attributes).

    Returns:
        str: A SHA-256 digest as a hexadecimal string.
    """
    return hashlib.sha256(data.encode('utf-8')).hexdigest()
|
| 56 |
+
|
| 57 |
+
def add_ddl(self, ddl: str, **kwargs) -> str:
    """Store a DDL statement embedding in the 'ddl' namespace.

    Skips the upsert when an entry with the same content hash already
    exists.

    Returns:
        str: The deterministic id of the (possibly pre-existing) entry.
    """
    entry_id = self.generate_hash_id(ddl) + '_ddl'

    if self.check_embedding(entry_id, 'ddl'):
        print(f"DDL having id {entry_id} already exists")
        return entry_id

    vector = self.embeddings.embed_query(ddl)
    self.index.upsert(
        vectors=[(entry_id, vector, {'ddl': ddl})],
        namespace='ddl',
    )
    return entry_id
|
| 70 |
+
|
| 71 |
+
def add_documentation(self, doc: str, **kwargs) -> str:
    """Store a documentation snippet embedding in the 'documentation' namespace.

    Skips the upsert when an entry with the same content hash already
    exists.

    Returns:
        str: The deterministic id of the (possibly pre-existing) entry.
    """
    entry_id = self.generate_hash_id(doc) + '_doc'

    if self.check_embedding(entry_id, 'documentation'):
        print(f"Documentation having id {entry_id} already exists")
        return entry_id

    vector = self.embeddings.embed_query(doc)
    self.index.upsert(
        vectors=[(entry_id, vector, {'doc': doc})],
        namespace='documentation',
    )
    return entry_id
|
| 84 |
+
|
| 85 |
+
def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Store a question/SQL pair embedding in the 'question_sql' namespace.

    The embedding is computed over the question and SQL concatenated;
    the upsert is skipped when the question's content hash already exists.

    Returns:
        str: The deterministic id of the (possibly pre-existing) entry.
    """
    entry_id = self.generate_hash_id(question) + '_sql'

    if self.check_embedding(entry_id, 'question_sql'):
        print(f"Question-SQL pair having id {entry_id} already exists")
        return entry_id

    vector = self.embeddings.embed_query(question + sql)
    self.index.upsert(
        vectors=[(entry_id, vector, {'question': question, 'sql': sql})],
        namespace='question_sql',
    )
    return entry_id
|
| 98 |
+
|
| 99 |
+
def get_related_ddl(self, question: str, **kwargs) -> list:
    """Return the stored DDL statements most similar to *question*."""
    matches = self.index.query(
        vector=self.embeddings.embed_query(question),
        top_k=self.top_k,
        namespace='ddl',
        include_metadata=True,
    )['matches']
    return [m['metadata']['ddl'] for m in matches]
|
| 108 |
+
|
| 109 |
+
def get_related_documentation(self, question: str, **kwargs) -> list:
    """Return the documentation snippets most similar to *question* (top_k matches)."""
    query_vector = self.embeddings.embed_query(question)
    matches = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='documentation',
        include_metadata=True,
    )['matches']

    return [m['metadata']['doc'] for m in matches]
|
| 118 |
+
|
| 119 |
+
def get_similar_question_sql(self, question: str, **kwargs) -> list:
    """Return (question, sql) pairs whose stored question is most similar to *question*."""
    query_vector = self.embeddings.embed_query(question)
    response = self.index.query(
        vector=query_vector,
        top_k=self.top_k,
        namespace='question_sql',
        include_metadata=True,
    )

    pairs = []
    for match in response['matches']:
        meta = match['metadata']
        pairs.append((meta['question'], meta['sql']))
    return pairs
|
| 128 |
+
|
| 129 |
+
def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Collect the stored metadata from every namespace into one DataFrame.

    Queries each of the three namespaces with a large top_k (10000) and
    concatenates all match metadata rows.
    """
    records = []

    for ns in ('ddl', 'documentation', 'question_sql'):
        response = self.index.query(
            top_k=10000,
            namespace=ns,
            include_metadata=True,
            include_values=False,
        )
        records.extend(match['metadata'] for match in response['matches'])

    return pd.DataFrame(records)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete a stored training vector by its id.

    The id suffix (appended by add_ddl / add_question_sql /
    add_documentation) tells us which namespace the vector was upserted
    into, so the delete must target that same namespace.

    Returns:
        True when a delete was issued, False for an unrecognized suffix.
    """
    # Bug fixes vs the previous version:
    #  - `self.Index` (capital I) raised AttributeError; the attribute is
    #    `self.index` everywhere else in this class.
    #  - Deletes targeted namespaces '_ddl'/'_sql'/'_doc', which never
    #    matched the upsert namespaces 'ddl'/'question_sql'/'documentation',
    #    so removal silently did nothing.
    namespace_by_suffix = {
        '_ddl': 'ddl',
        '_sql': 'question_sql',
        '_doc': 'documentation',
    }

    for suffix, namespace in namespace_by_suffix.items():
        if id.endswith(suffix):
            self.index.delete(ids=[id], namespace=namespace)
            return True

    return False
|
| 164 |
+
|
| 165 |
+
def generate_embedding(self, text, **kwargs):
    """Intentionally unimplemented stub.

    Embeddings are produced via self.embeddings.embed_query in the add_*
    and get_* methods instead, so this returns None.
    """
    return None
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def get_sql_prompt(
    self,
    initial_prompt: str,
    question: str,
    question_sql_list: list,
    ddl_list: list,
    doc_list: list,
    **kwargs,
):
    """
    Build the LLM message log used to generate SQL for *question*.

    Example:
    ```python
    vn.get_sql_prompt(
        question="What are the top 10 customers by sales?",
        question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
        ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
        doc_list=["The customers table contains information about customers and their sales."],
    )
    ```

    Args:
        initial_prompt (str): System prompt to start from; a default is built when None.
        question (str): The question to generate SQL for.
        question_sql_list (list): Few-shot examples as dicts with "question" and "sql" keys.
        ddl_list (list): DDL statements to ground the prompt in.
        doc_list (list): Documentation snippets to ground the prompt in.

    Returns:
        list: The message log (system/user/assistant messages) for the LLM.
    """
    if initial_prompt is None:
        initial_prompt = (
            f"You are a {self.dialect} expert. "
            + "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
        )

    # Ground the prompt in the retrieved schema, truncated to the token budget.
    initial_prompt = self.add_ddl_to_prompt(
        initial_prompt, ddl_list, max_tokens=self.max_tokens
    )

    # NOTE(review): this appends to the caller's list in place; kept to
    # preserve existing behavior.
    if self.static_documentation != "":
        doc_list.append(self.static_documentation)

    initial_prompt = self.add_documentation_to_prompt(
        initial_prompt, doc_list, max_tokens=self.max_tokens
    )

    initial_prompt += (
        "===Response Guidelines \n"
        "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
        "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
        "3. If the provided context is insufficient, please give a sql query based on your knowledge and the context provided. \n"
        "4. Please use the most relevant table(s). \n"
        "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
        f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
        f"7. Add a description of the table in the result of the sql query, if relevant. \n"
        "8. Make sure to include the relevant KPI in the SQL query. The query should return impactful data \n"
    )

    message_log = [self.system_message(initial_prompt)]

    # Few-shot examples: skip None entries and malformed dicts instead of
    # failing. (Previous version re-checked `example is not None` inside
    # the else branch, which could never be False there.)
    for example in question_sql_list:
        if example is None:
            print("example is None")
        elif "question" in example and "sql" in example:
            message_log.append(self.user_message(example["question"]))
            message_log.append(self.assistant_message(example["sql"]))

    message_log.append(self.user_message(question))

    return message_log
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# def get_sql_prompt(
|
| 256 |
+
# self,
|
| 257 |
+
# initial_prompt : str,
|
| 258 |
+
# question: str,
|
| 259 |
+
# question_sql_list: list,
|
| 260 |
+
# ddl_list: list,
|
| 261 |
+
# doc_list: list,
|
| 262 |
+
# **kwargs,
|
| 263 |
+
# ):
|
| 264 |
+
# """
|
| 265 |
+
# Example:
|
| 266 |
+
# ```python
|
| 267 |
+
# vn.get_sql_prompt(
|
| 268 |
+
# question="What are the top 10 customers by sales?",
|
| 269 |
+
# question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}],
|
| 270 |
+
# ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"],
|
| 271 |
+
# doc_list=["The customers table contains information about customers and their sales."],
|
| 272 |
+
# )
|
| 273 |
+
|
| 274 |
+
# ```
|
| 275 |
+
|
| 276 |
+
# This method is used to generate a prompt for the LLM to generate SQL.
|
| 277 |
+
|
| 278 |
+
# Args:
|
| 279 |
+
# question (str): The question to generate SQL for.
|
| 280 |
+
# question_sql_list (list): A list of questions and their corresponding SQL statements.
|
| 281 |
+
# ddl_list (list): A list of DDL statements.
|
| 282 |
+
# doc_list (list): A list of documentation.
|
| 283 |
+
|
| 284 |
+
# Returns:
|
| 285 |
+
# any: The prompt for the LLM to generate SQL.
|
| 286 |
+
# """
|
| 287 |
+
|
| 288 |
+
# if initial_prompt is None:
|
| 289 |
+
# initial_prompt = f"You are a {self.dialect} expert. " + \
|
| 290 |
+
# "Please help to generate a SQL query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions. "
|
| 291 |
+
|
| 292 |
+
# initial_prompt = self.add_ddl_to_prompt(
|
| 293 |
+
# initial_prompt, ddl_list, max_tokens=self.max_tokens
|
| 294 |
+
# )
|
| 295 |
+
|
| 296 |
+
# if self.static_documentation != "":
|
| 297 |
+
# doc_list.append(self.static_documentation)
|
| 298 |
+
|
| 299 |
+
# initial_prompt = self.add_documentation_to_prompt(
|
| 300 |
+
# initial_prompt, doc_list, max_tokens=self.max_tokens
|
| 301 |
+
# )
|
| 302 |
+
|
| 303 |
+
# initial_prompt += (
|
| 304 |
+
# "===Response Guidelines \n"
|
| 305 |
+
# "1. If the provided context is sufficient, please generate a valid SQL query without any explanations for the question. \n"
|
| 306 |
+
# "2. If the provided context is almost sufficient but requires knowledge of a specific string in a particular column, please generate an intermediate SQL query to find the distinct strings in that column. Prepend the query with a comment saying intermediate_sql \n"
|
| 307 |
+
# "3. If the provided context is insufficient, please explain why it can't be generated. \n"
|
| 308 |
+
# "4. Please use the most relevant table(s). \n"
|
| 309 |
+
# "5. If the question has been asked and answered before, please repeat the answer exactly as it was given before. \n"
|
| 310 |
+
# f"6. Ensure that the output SQL is {self.dialect}-compliant and executable, and free of syntax errors. \n"
|
| 311 |
+
# )
|
| 312 |
+
|
| 313 |
+
# message_log = [self.system_message(initial_prompt)]
|
| 314 |
+
|
| 315 |
+
# for example in question_sql_list:
|
| 316 |
+
# if example is None:
|
| 317 |
+
# print("example is None")
|
| 318 |
+
# else:
|
| 319 |
+
# if example is not None and "question" in example and "sql" in example:
|
| 320 |
+
# message_log.append(self.user_message(example["question"]))
|
| 321 |
+
# message_log.append(self.assistant_message(example["sql"]))
|
| 322 |
+
|
| 323 |
+
# message_log.append(self.user_message(question))
|
| 324 |
+
|
| 325 |
+
# return message_log
|
requirements.txt
CHANGED
|
@@ -26,4 +26,4 @@ duckdb==1.2.1
|
|
| 26 |
openai==1.61.1
|
| 27 |
pydantic==2.9.2
|
| 28 |
pydantic-settings==2.2.1
|
| 29 |
-
geojson==3.2.0
|
|
|
|
| 26 |
openai==1.61.1
|
| 27 |
pydantic==2.9.2
|
| 28 |
pydantic-settings==2.2.1
|
| 29 |
+
geojson==3.2.0
|
style.css
CHANGED
|
@@ -656,11 +656,20 @@ a {
|
|
| 656 |
/* overflow-y: scroll; */
|
| 657 |
}
|
| 658 |
#sql-query{
|
|
|
|
| 659 |
max-height: 100%;
|
| 660 |
}
|
| 661 |
|
| 662 |
#sql-query textarea{
|
| 663 |
min-height: 200px !important;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 664 |
}
|
| 665 |
|
| 666 |
#sql-query span{
|
|
@@ -741,4 +750,4 @@ div#tab-vanna{
|
|
| 741 |
#example-img-container {
|
| 742 |
flex-direction: column;
|
| 743 |
align-items: left;
|
| 744 |
-
}
|
|
|
|
| 656 |
/* overflow-y: scroll; */
|
| 657 |
}
|
| 658 |
#sql-query{
|
| 659 |
+
max-height: 300px;
overflow-y: scroll;
}

#sql-query textarea{
min-height: 100px !important;
|
| 673 |
}
|
| 674 |
|
| 675 |
#sql-query span{
|
|
|
|
| 750 |
#example-img-container {
|
| 751 |
flex-direction: column;
|
| 752 |
align-items: left;
|
| 753 |
+
}
|