# Gradio Space: causal-agent demo app (auto method selection + LLM summary)
import os
import sys
import json
from pathlib import Path
import gradio as gr
import time
# Make your repo importable (expecting a folder named causal-agent at repo root)
sys.path.append(str(Path(__file__).parent / "causal-agent"))
from auto_causal.agent import run_causal_analysis  # uses env for provider/model
# -------- LLM config (OpenAI only; key via HF Secrets) --------
# setdefault only fills these in when the environment (e.g. Space variables)
# did not already provide a value, so deployments can override them.
os.environ.setdefault("LLM_PROVIDER", "openai")
os.environ.setdefault("LLM_MODEL", "gpt-4o")
| # Lazy import to avoid import-time errors if key missing | |
| def _get_openai_client(): | |
| if os.getenv("LLM_PROVIDER", "openai") != "openai": | |
| raise RuntimeError("Only LLM_PROVIDER=openai is supported in this demo.") | |
| if not os.getenv("OPENAI_API_KEY"): | |
| raise RuntimeError("Missing OPENAI_API_KEY (set as a Space Secret).") | |
| try: | |
| # OpenAI SDK v1+ | |
| from openai import OpenAI | |
| return OpenAI() | |
| except Exception as e: | |
| raise RuntimeError(f"OpenAI SDK not available: {e}") | |
# -------- System prompt for the LLM summarizer (verbatim) --------
# NOTE: the arrows/ellipsis below were mojibake ("β", "β¦") from a UTF-8 ->
# cp1253 round-trip; restored to the intended "→" and "…".
SYSTEM_PROMPT = """You are an expert in statistics and causal inference.
You will be given:
1) The original research question.
2) The analysis method used.
3) The estimated effects, confidence intervals, standard errors, and p-values for each treatment group compared to the control group.
4) A brief dataset description.
Your task is to produce a clear, concise, and non-technical summary that:
- Directly answers the research question.
- States whether the effect is statistically significant.
- Quantifies the effect size and explains what it means in practical terms (e.g., percentage point change).
- Mentions the method used in one sentence.
- Optionally ranks the treatment effects from largest to smallest if multiple treatments exist.
Formatting rules:
- Use bullet points or short paragraphs.
- Report effect sizes to two decimal places.
- Clearly state the interpretation in plain English without technical jargon.
Example Output Structure:
- **Method:** [Name of method + 1-line rationale]
- **Key Finding:** [Main answer to the research question]
- **Details:**
- [Treatment name]: +X.XX percentage points (95% CI: [L, U]), p < 0.001 → [Significance comment]
- …
- **Rank Order of Effects:** [Largest → Smallest]
"""
| def _extract_minimal_payload(agent_result: dict) -> dict: | |
| """ | |
| Extract the minimal, LLM-friendly payload from run_causal_analysis output. | |
| Falls back gracefully if any fields are missing. | |
| """ | |
| # Try both top-level and nested (your JSON showed both patterns) | |
| res = agent_result or {} | |
| results = res.get("results", {}) if isinstance(res.get("results"), dict) else {} | |
| inner = results.get("results", {}) if isinstance(results.get("results"), dict) else {} | |
| vars_ = results.get("variables", {}) if isinstance(results.get("variables"), dict) else {} | |
| dataset_analysis = results.get("dataset_analysis", {}) if isinstance(results.get("dataset_analysis"), dict) else {} | |
| # Pull best-available fields | |
| question = ( | |
| results.get("original_query") | |
| or dataset_analysis.get("original_query") | |
| or res.get("query") | |
| or "N/A" | |
| ) | |
| method = ( | |
| inner.get("method_used") | |
| or res.get("method_used") | |
| or results.get("method_used") | |
| or "N/A" | |
| ) | |
| effect_estimate = ( | |
| inner.get("effect_estimate") | |
| or res.get("effect_estimate") | |
| or {} | |
| ) | |
| confidence_interval = ( | |
| inner.get("confidence_interval") | |
| or res.get("confidence_interval") | |
| or {} | |
| ) | |
| standard_error = ( | |
| inner.get("standard_error") | |
| or res.get("standard_error") | |
| or {} | |
| ) | |
| p_value = ( | |
| inner.get("p_value") | |
| or res.get("p_value") | |
| or {} | |
| ) | |
| dataset_desc = ( | |
| results.get("dataset_description") | |
| or res.get("dataset_description") | |
| or "N/A" | |
| ) | |
| return { | |
| "original_question": question, | |
| "method_used": method, | |
| "estimates": { | |
| "effect_estimate": effect_estimate, | |
| "confidence_interval": confidence_interval, | |
| "standard_error": standard_error, | |
| "p_value": p_value, | |
| }, | |
| "dataset_description": dataset_desc, | |
| } | |
| def _format_effects_md(effect_estimate: dict) -> str: | |
| """ | |
| Minimal human-readable view of effect estimates for display. | |
| """ | |
| if not effect_estimate or not isinstance(effect_estimate, dict): | |
| return "_No effect estimates found._" | |
| # Render as bullet list | |
| lines = [] | |
| for k, v in effect_estimate.items(): | |
| try: | |
| lines.append(f"- **{k}**: {float(v):+.4f}") | |
| except Exception: | |
| lines.append(f"- **{k}**: {v}") | |
| return "\n".join(lines) | |
def _summarize_with_llm(payload: dict) -> str:
    """
    Calls OpenAI with the provided SYSTEM_PROMPT and the JSON payload as the user message.

    Args:
        payload: Minimal analysis payload (see _extract_minimal_payload).

    Returns:
        The model's summary text, stripped; "" if the model returned no content.

    Raises:
        RuntimeError: if the client cannot be created (missing key / SDK).
    """
    client = _get_openai_client()
    # Fallback kept consistent with the module-level default ("gpt-4o");
    # previously this disagreed ("gpt-4o-mini"), though setdefault at import
    # time meant the fallback was never actually reached.
    model_name = os.getenv("LLM_MODEL", "gpt-4o")
    user_content = (
        "Summarize the following causal analysis results:\n\n"
        + json.dumps(payload, indent=2, ensure_ascii=False)
    )
    # Use Chat Completions for broad compatibility across SDK versions.
    resp = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        temperature=0,
    )
    # message.content may be None (e.g. refusals); avoid AttributeError on .strip().
    text = (resp.choices[0].message.content or "").strip()
    return text
def run_agent(query: str, csv_path: str, dataset_description: str):
    """
    Drive one analysis run, yielding progressive UI updates.

    Generator wired to the "Run analysis" button. Yields 4-tuples of
    (method_html, effects_html, explanation_html, raw_results) so the UI
    shows immediate feedback, intermediate status cards, and finally the
    formatted results (or an error card, after which it returns early).

    Args:
        query: Natural-language causal question (a default is used if empty).
        csv_path: Filesystem path of the uploaded CSV (from gr.File).
        dataset_description: Optional free-text description of the dataset.
    """
    # Immediate feedback - show processing has started
    processing_html = """
    <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
        <div style='font-size: 16px; margin-bottom: 5px;'>🔍 Analysis in Progress...</div>
        <div style='font-size: 14px; color: #666;'>This may take 1-2 minutes depending on dataset size</div>
    </div>
    """
    yield (
        processing_html,                     # method_out
        processing_html,                     # effects_out
        processing_html,                     # explanation_out
        {"status": "Processing started..."}  # raw_results
    )
    # Input validation: fail fast with a styled message instead of raising.
    if not os.getenv("OPENAI_API_KEY"):
        error_html = "<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>⚠️ Set a Space Secret named OPENAI_API_KEY</div>"
        yield (error_html, "", "", {})
        return
    if not csv_path:
        error_html = "<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>Please upload a CSV dataset.</div>"
        yield (error_html, "", "", {})
        return
    try:
        # Update status to show causal analysis is running
        analysis_html = """
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <div style='font-size: 16px; margin-bottom: 5px;'>🔄 Running Causal Analysis...</div>
            <div style='font-size: 14px; color: #666;'>Analyzing dataset and selecting optimal method</div>
        </div>
        """
        yield (
            analysis_html,
            analysis_html,
            analysis_html,
            {"status": "Running causal analysis..."}
        )
        result = run_causal_analysis(
            query=(query or "What is the effect of treatment T on outcome Y controlling for X?").strip(),
            dataset_path=csv_path,
            dataset_description=(dataset_description or "").strip(),
        )
        # Update to show LLM summarization step
        llm_html = """
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <div style='font-size: 16px; margin-bottom: 5px;'>🤖 Generating Summary...</div>
            <div style='font-size: 14px; color: #666;'>Creating human-readable interpretation</div>
        </div>
        """
        yield (
            llm_html,
            llm_html,
            llm_html,
            {"status": "Generating explanation...", "raw_analysis": result if isinstance(result, dict) else {}}
        )
    except Exception as e:
        error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Error: {e}</div>"
        yield (error_html, "", "", {})
        return
    try:
        payload = _extract_minimal_payload(result if isinstance(result, dict) else {})
        method = payload.get("method_used", "N/A")
        # Format method output with simple styling
        method_html = f"""
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <h3 style='margin: 0 0 10px 0; font-size: 18px;'>Selected Method</h3>
            <p style='margin: 0; font-size: 16px;'>{method}</p>
        </div>
        """
        # Format effects with simple styling
        effect_estimate = payload.get("estimates", {}).get("effect_estimate", {})
        if effect_estimate:
            effects_html = "<div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>"
            effects_html += "<h3 style='margin: 0 0 10px 0; font-size: 18px;'>Effect Estimates</h3>"
            # Raw dict rendering; per-treatment formatting is delegated to the
            # LLM summary below.
            effects_html += f"<div style='margin: 8px 0; padding: 8px; border: 1px solid #eee; border-radius: 4px; background-color: #333333;'>{effect_estimate}</div>"
            effects_html += "</div>"
        else:
            effects_html = "<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; color: #666; font-style: italic; background-color: #333333;'>No effect estimates found</div>"
        # Generate explanation and format it
        try:
            explanation = _summarize_with_llm(payload)
            explanation_html = f"""
        <div style='padding: 15px; border: 1px solid #ddd; border-radius: 8px; margin: 5px 0; background-color: #333333;'>
            <h3 style='margin: 0 0 15px 0; font-size: 18px;'>Detailed Explanation</h3>
            <div style='line-height: 1.6; white-space: pre-wrap;'>{explanation}</div>
        </div>
        """
        except Exception as e:
            # Analysis succeeded but summarization failed: degrade gracefully
            # and still show method + effect estimates.
            explanation_html = f"<div style='padding: 10px; border: 1px solid #ffc107; border-radius: 5px; color: #856404; background-color: #333333;'>⚠️ LLM summary failed: {e}</div>"
    except Exception as e:
        error_html = f"<div style='padding: 10px; border: 1px solid #dc3545; border-radius: 5px; color: #dc3545; background-color: #333333;'>❌ Failed to parse results: {e}</div>"
        yield (error_html, "", "", {})
        return
    # Final result
    yield (method_html, effects_html, explanation_html, result if isinstance(result, dict) else {})
# UI layout: question + dataset inputs on top, three HTML result panes below,
# and a collapsible raw-JSON view for debugging.
with gr.Blocks() as demo:
    gr.Markdown("# Causal Agent")
    gr.Markdown("Upload your dataset and ask causal questions in natural language. The system will automatically select the appropriate causal inference method and provide clear explanations.")
    with gr.Row():
        query = gr.Textbox(
            label="Your causal question (natural language)",
            placeholder="e.g., What is the effect of attending the program (T) on income (Y), controlling for education and age?",
            lines=2,
        )
    with gr.Row():
        # type="filepath" means run_agent receives a path string, not bytes.
        csv_file = gr.File(
            label="Dataset (CSV)",
            file_types=[".csv"],
            type="filepath"
        )
        dataset_description = gr.Textbox(
            label="Dataset description (optional)",
            placeholder="Brief schema, how it was collected, time period, units, treatment/outcome variables, etc.",
            lines=4,
        )
    run_btn = gr.Button("Run analysis", variant="primary")
    with gr.Row():
        with gr.Column(scale=1):
            method_out = gr.HTML(label="Selected Method")
        with gr.Column(scale=1):
            effects_out = gr.HTML(label="Effect Estimates")
    with gr.Row():
        explanation_out = gr.HTML(label="Detailed Explanation")
    # Add the collapsible raw results section
    with gr.Accordion("Raw Results (Advanced)", open=False):
        raw_results = gr.JSON(label="Complete Analysis Output", show_label=False)
    # run_agent is a generator; each yield updates these four outputs in order.
    run_btn.click(
        fn=run_agent,
        inputs=[query, csv_file, dataset_description],
        outputs=[method_out, effects_out, explanation_out, raw_results],
        show_progress=True
    )
    gr.Markdown(
        """
        **Tips:**
        - Be specific about your treatment, outcome, and control variables
        - Include relevant context in the dataset description
        - The analysis may take 1-2 minutes for complex datasets
        """
    )
if __name__ == "__main__":
    # queue() enables request queuing (presumably required here for the
    # generator-based progressive yields — verify against the Gradio version
    # pinned for this Space), then launch the app.
    demo.queue().launch()