import os import sys import json from pathlib import Path import gradio as gr import time # Make your repo importable (expecting a folder named causal-agent at repo root) sys.path.append(str(Path(__file__).parent / "causal-agent")) from auto_causal.agent import run_causal_analysis # uses env for provider/model # -------- LLM config (OpenAI only; key via HF Secrets) -------- os.environ.setdefault("LLM_PROVIDER", "openai") os.environ.setdefault("LLM_MODEL", "gpt-4o") # Lazy import to avoid import-time errors if key missing def _get_openai_client(): if os.getenv("LLM_PROVIDER", "openai") != "openai": raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.") if not os.getenv("OPENAI_API_KEY"): raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).") try: # OpenAI SDK v1+ from openai import OpenAI return OpenAI() except Exception as e: raise RuntimeError(f"OpenAI SDK not available: {e}") # -------- System prompt you asked for (verbatim) -------- SYSTEM_PROMPT = """You are an expert in statistics and causal inference. You will be given: 1) The original research question. 2) The analysis method used. 3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group. 4) A brief dataset description. Your task is to produce a clear, concise, and non-technical summary that: - Directly answers the research question. - States whether the effect is statistically significant. - Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change). - Mentions the method used in one sentence. - Optionally ranks the treatment effects from largest to smallest if multiple treatments exist. Formatting rules: - Use bullet points or short paragraphs. - Report effect sizes to two decimal places. - Clearly state the interpretation in plain English without technical jargon. Example Output Structure: - **Method:** [Name of method + 1-line rationale] - **Key Finding:** [Main answer to the research question] - **Details:** - [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 — [Significance comment] - … - **Rank Order of Effects:** [Largest → Smallest] """ def _extract_minimal_payload(agent_result: dict) -> dict: """ Extract the minimal, LLM-friendly payload from run_causal_analysis output. Falls back gracefully if any fields are missing. """ # Try both top-level and nested (your JSON showed both patterns) res = agent_result or {} results = res.get("results", {}) if isinstance(res.get("results"), dict) else {} inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {} vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {} dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {} # Pull best-available fields question = ( results.get("original_query") or dataset_analysis.get("original_query") or res.get("query") or "N/A" ) method = ( inner.get("method_used") or res.get("method_used") or results.get("method_used") or "N/A" ) effect_estimate = ( inner.get("effect_estimate") or res.get("effect_estimate") or {} ) confidence_interval = ( inner.get("confidence_interval") or res.get("confidence_interval") or {} ) standard_error = ( inner.get("standard_error") or res.get("standard_error") or {} ) p_value = ( inner.get("p_value") or res.get("p_value") or {} ) dataset_desc = ( results.get("dataset_description") or res.get("dataset_description") or "N/A" ) return { "original_question": question, "method_used": method, "estimates": { "effect_estimate": effect_estimate, "confidence_interval": confidence_interval, "standard_error": standard_error, "p_value": p_value, }, "dataset_description": dataset_desc, } def _format_effects_md(effect_estimate: dict) -> str: """ Minimal human-readable view of effect estimates for display. """ if not effect_estimate or not isinstance(effect_estimate, dict): return "_No effect estimates found._" # Render as bullet list lines = [] for k, v in effect_estimate.items(): try: lines.append(f"- **{k}**: {float(v):+.4f}") except Exception: lines.append(f"- **{k}**: {v}") return "\n".join(lines) def _summarize_with_llm(payload: dict) -> str: """ Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message. Returns the model's text, or raises on error. """ client = _get_openai_client() model_name = os.getenv("LLM_MODEL", "gpt-4o-mini") user_content = ( "Summarize the following causal analysis results:\n\n" + json.dumps(payload, indent=2, ensure_ascii=False) ) # Use Chat Completions for broad compatibility resp = client.chat.completions.create( model=model_name, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}, ], temperature=0 ) text = resp.choices[0].message.content.strip() return text def run_agent(query: str, csv_path: str, dataset_description: str): """ Modified to use yield for progressive updates and immediate feedback """ # Immediate feedback - show processing has started processing_html = """
🔄 Analysis in Progress...
This may take 1-2 minutes depending on dataset size
""" yield ( processing_html, # method_out processing_html, # effects_out processing_html, # explanation_out {"status": "Processing started..."} # raw_results ) # Input validation if not os.getenv("OPENAI_API_KEY"): error_html = "
⚠️ Set a Space Secret named OPENAI_API_KEY
" yield (error_html, "", "", {}) return if not csv_path: error_html = "
Please upload a CSV dataset.
" yield (error_html, "", "", {}) return try: # Update status to show causal analysis is running analysis_html = """
📊 Running Causal Analysis...
Analyzing dataset and selecting optimal method
""" yield ( analysis_html, analysis_html, analysis_html, {"status": "Running causal analysis..."} ) result = run_causal_analysis( query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(), dataset_path=csv_path, dataset_description=(dataset_description or "").strip(), ) # Update to show LLM summarization step llm_html = """
🤖 Generating Summary...
Creating human-readable interpretation
""" yield ( llm_html, llm_html, llm_html, {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}} ) except Exception as e: error_html = f"
❌ Error: {e}
" yield (error_html, "", "", {}) return try: payload = _extract_minimal_payload(result if isinstance(result, dict) else {}) method = payload.get("method_used", "N/A") # Format method output with simple styling method_html = f"""

Selected Method

{method}

""" # Format effects with simple styling effect_estimate = payload.get("estimates", {}).get("effect_estimate", {}) if effect_estimate: effects_html = "
" effects_html += "

Effect Estimates

" # for k, v in effect_estimate.items(): # try: # value = f"{float(v):+.4f}" # effects_html += f"
{k}: {value}
" # except: effects_html += f"
{effect_estimate}
" effects_html += "
" else: effects_html = "
No effect estimates found
" # Generate explanation and format it try: explanation = _summarize_with_llm(payload) explanation_html = f"""

Detailed Explanation

{explanation}
""" except Exception as e: explanation_html = f"
⚠️ LLM summary failed: {e}
" except Exception as e: error_html = f"
❌ Failed to parse results: {e}
" yield (error_html, "", "", {}) return # Final result yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {}) with gr.Blocks() as demo: gr.Markdown("# Causal Agent") gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.") with gr.Row(): query = gr.Textbox( label="Your causal question (natural language)", placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?", lines=2, ) with gr.Row(): csv_file = gr.File( label="Dataset (CSV)", file_types=[".csv"], type="filepath" ) dataset_description = gr.Textbox( label="Dataset description (optional)", placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.", lines=4, ) run_btn = gr.Button("Run analysis", variant="primary") with gr.Row(): with gr.Column(scale=1): method_out = gr.HTML(label="Selected Method") with gr.Column(scale=1): effects_out = gr.HTML(label="Effect Estimates") with gr.Row(): explanation_out = gr.HTML(label="Detailed Explanation") # Add the collapsible raw results section with gr.Accordion("Raw Results (Advanced)", open=False): raw_results = gr.JSON(label="Complete Analysis Output", show_label=False) run_btn.click( fn=run_agent, inputs=[query, csv_file, dataset_description], outputs=[method_out, effects_out, explanation_out, raw_results], show_progress=True ) gr.Markdown( """ **Tips:** - Be specific about your treatment, outcome, and control variables - Include relevant context in the dataset description - The analysis may take 1-2 minutes for complex datasets """ ) if __name__ == "__main__": demo.queue().launch()