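"""Batch-run auto_causal's run_causal_analysis over a metadata CSV and save predictions to JSON."""
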
import argparse
import json
import logging
import os
import time
from typing import Any, Dict

import pandas as pd

from auto_causal.agent import run_causal_analysis

# Seconds to pause between rows: a crude guard against LLM provider rate limits
RATE_LIMIT_SECONDS = 2

def run_caia(desc: str, question: str, data_path: str) -> Dict[str, Any]:
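    """Run the auto_causal agent on a single query; `data_path` is a CSV file path, not a DataFrame."""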
    return run_causal_analysis(query=question, dataset_path=data_path, dataset_description=desc)

def parse_args():
parser = argparse.ArgumentParser(description="Run batch causal analysis.")
parser.add_argument("--csv_path", type=str, required=True, help="CSV file with queries, descriptions, and file names.")
parser.add_argument("--data_folder", type=str, required=True, help="Folder containing data CSVs.")
parser.add_argument("--data_category", type=str, required=True, help="Dataset category (e.g., real, qrdata, synthetic).")
parser.add_argument("--output_folder", type=str, required=True, help="Folder to save output.")
parser.add_argument("--llm_name", type=str, required=True, help="Name of the LLM used.")
return parser.parse_args()
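
# Example invocation (script name, paths, and model name are illustrative):
#   python run_batch.py --csv_path meta/queries.csv --data_folder data/ \
#       --data_category real --output_folder outputs/ --llm_name gpt-4o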

def main():
    args = parse_args()
    csv_meta = args.csv_path
    data_dir = args.data_folder
    # Output file name is an assumed convention combining category and model name;
    # adjust if your evaluation pipeline expects a different name.
    output_json = os.path.join(args.output_folder, f"{args.data_category}_{args.llm_name}.json")
    os.environ["LLM_MODEL"] = args.llm_name
print("[main] Starting batch processing…")
if not os.path.exists(csv_meta):
logging.error(f"Meta file not found: {csv_meta}")
return
meta_df = pd.read_csv(csv_meta)
print(f"[main] Loaded metadata CSV with {len(meta_df)} rows.")
results: Dict[int, Dict[str, Any]] = {}
for idx, row in meta_df.iterrows():
data_path = os.path.join(data_dir, str(row["data_files"]))
print(f"\n[main] Row {idx+1}/{len(meta_df)} → Dataset: {data_path}")
try:
            res = run_caia(
                desc=row["data_description"],
                question=row["natural_language_query"],
                data_path=data_path,
            )
            # Format the result according to the expected output structure
            estimates = res["results"]["results"]
            variables = res["results"]["variables"]
            formatted_result = {
                "query": row["natural_language_query"],
                "method": row["method"],
                "answer": row["answer"],
                "dataset_description": row["data_description"],
                "dataset_path": data_path,
                "keywords": row.get("keywords", "Causality, Average treatment effect"),
                "final_result": {
                    "method": estimates.get("method_used"),
                    "causal_effect": estimates.get("effect_estimate"),
                    "standard_deviation": estimates.get("standard_error"),
                    "treatment_variable": variables.get("treatment_variable"),
                    "outcome_variable": variables.get("outcome_variable"),
                    "covariates": variables.get("covariates", []),
                    "instrument_variable": variables.get("instrument_variable"),
                    "running_variable": variables.get("running_variable"),
                    "temporal_variable": variables.get("time_variable"),
                    "statistical_test_results": res.get("summary", ""),
                    "explanation_for_model_choice": res.get("explanation", ""),
                    "regression_equation": res.get("regression_equation", ""),
                },
            }
results[idx] = formatted_result
            print(f"[main] Formatted result for row {idx+1}:", formatted_result)
except Exception as e:
logging.error(f"[{idx+1}] Error: {e}")
results[idx] = {"answer": str(e)}
time.sleep(RATE_LIMIT_SECONDS)
os.makedirs(os.path.dirname(output_json), exist_ok=True)
with open(output_json, "w") as f:
json.dump(results, f, indent=2)
print(f"[main] Done. Predictions saved to {output_json}")

if __name__ == "__main__":
    main()