from __future__ import annotations

import json
import os
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
from loguru import logger
from sentence_transformers import SentenceTransformer
from sklearn.neighbors import NearestNeighbors

from app.utils.config import settings
from app.utils.helpers import clean_diagnosis, normalize_gender


@dataclass
class SimilarCase:
    summary_id: str
    diagnosis: Optional[str]
    age: Optional[float]
    gender: Optional[str]
    summary_snippet: str
    similarity_score: float


class CameroonMedicalData:
    """
    Load, clean, analyze, and search medical summaries specialized for the
    Cameroonian context.

    Designed for ~45k rows. Caches embeddings and lightweight stats.
    """

    def __init__(self, csv_path: Optional[str] = None):
        self.csv_path = csv_path or settings.CAMEROON_DATA_CSV
        if not self.csv_path or not os.path.exists(self.csv_path):
            logger.warning(
                "CameroonMedicalData: CSV path missing or not found. "
                "Set CAMEROON_DATA_CSV in .env"
            )
            self.df = pd.DataFrame()
        else:
            self.df = self._load_csv(self.csv_path, settings.CAMEROON_MAX_ROWS)
        self._cleaned: bool = False
        self._model: Optional[SentenceTransformer] = None
        self._embeddings: Optional[np.ndarray] = None
        self._nn: Optional[NearestNeighbors] = None
        self._cache_dir = settings.CAMEROON_CACHE_DIR
        os.makedirs(self._cache_dir, exist_ok=True)

    # ----------------------- Data Loading & Cleaning -----------------------

    def _load_csv(self, path: str, limit: Optional[int]) -> pd.DataFrame:
        df = pd.read_csv(path)
        if limit and limit > 0:
            df = df.head(limit)
        return df

    def clean(self) -> None:
        if self.df.empty:
            self._cleaned = True
            return

        df = self.df.copy()

        # Validate that all expected columns are present before cleaning.
        expected_cols = [
            "summary_id", "patient_id", "patient_age", "patient_gender",
            "diagnosis", "body_temp_c", "blood_pressure_systolic",
            "heart_rate", "summary_text", "date_recorded",
        ]
        missing = [c for c in expected_cols if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # Parse dates; unparseable values become NaT.
        df["date_recorded"] = pd.to_datetime(df["date_recorded"], errors="coerce")

        # Handle missing values.
        df["patient_gender"] = df["patient_gender"].fillna("")
        df["diagnosis"] = df["diagnosis"].fillna("")
        df["summary_text"] = df["summary_text"].fillna("")

        # Normalize gender and diagnosis.
        df["patient_gender_norm"] = df["patient_gender"].apply(lambda v: normalize_gender(str(v)))
        df["diagnosis_norm"] = df["diagnosis"].apply(lambda v: clean_diagnosis(str(v)))

        # Coerce numeric vitals; invalid values become NaN.
        for col in ["patient_age", "body_temp_c", "blood_pressure_systolic", "heart_rate"]:
            df[col] = pd.to_numeric(df[col], errors="coerce")

        # Drop rows with neither summary text nor a diagnosis.
        df = df[~((df["summary_text"].str.len() == 0) & (df["diagnosis_norm"].isna()))]

        self.df = df.reset_index(drop=True)
        self._cleaned = True

    # ----------------------------- Statistics -----------------------------

    def stats_overview(self) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"total_rows": 0}

        df = self.df
        top_diagnoses = df["diagnosis_norm"].value_counts(dropna=True).head(20).to_dict()
        age_desc = df["patient_age"].describe().fillna(0).to_dict()
        return {
            "total_rows": int(len(df)),
            "top_diagnoses": top_diagnoses,
            "age_stats": age_desc,
            "gender_distribution": df["patient_gender_norm"].value_counts(dropna=True).to_dict(),
        }
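    # Illustrative shape of the overview payload (keys mirror the return
    # statement above; values elided, not real data):
    #
    #   data = CameroonMedicalData()
    #   data.stats_overview()
    #   # -> {"total_rows": ..., "top_diagnoses": {...}, "age_stats": {...},
    #   #     "gender_distribution": {...}}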
    def stats_disease(self, disease_name: str) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"disease": disease_name, "total_cases": 0}

        df = self.df
        mask = df["diagnosis_norm"] == disease_name.lower()
        subset = df[mask]
        total = int(len(subset))

        # Age buckets.
        bins = [-1, 18, 35, 60, 200]
        labels = ["0-18", "19-35", "36-60", "60+"]
        ages = pd.cut(subset["patient_age"], bins=bins, labels=labels)
        age_dist = ages.value_counts().reindex(labels, fill_value=0).to_dict()

        gender_dist = subset["patient_gender_norm"].value_counts().to_dict()

        # Common symptom terms (very simple proxy: frequent tokens in summary_text).
        common_symptoms = self._extract_common_terms(subset["summary_text"].tolist(), top_k=15)

        return {
            "disease": disease_name,
            "total_cases": total,
            "age_distribution": age_dist,
            "gender_distribution": gender_dist,
            "common_symptoms": common_symptoms,
        }

    def seasonal_patterns(self) -> Dict[str, int]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {}

        df = self.df.dropna(subset=["date_recorded"]).copy()
        df["month"] = df["date_recorded"].dt.month
        counts = df["month"].value_counts().sort_index()

        # Map month numbers to English lowercase names for consistency.
        months = [
            "january", "february", "march", "april", "may", "june",
            "july", "august", "september", "october", "november", "december",
        ]
        return {months[i - 1]: int(counts.get(i, 0)) for i in range(1, 13)}

    def age_gender_distribution(self) -> Dict[str, Any]:
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            return {"age_buckets": {}, "gender_distribution": {}}

        df = self.df
        bins = [-1, 18, 35, 60, 200]
        labels = ["0-18", "19-35", "36-60", "60+"]
        ages = pd.cut(df["patient_age"], bins=bins, labels=labels)
        age_dist = ages.value_counts().reindex(labels, fill_value=0).to_dict()
        gender_dist = df["patient_gender_norm"].value_counts().to_dict()
        return {"age_buckets": age_dist, "gender_distribution": gender_dist}

    # --------------------------- Semantic Similarity ---------------------------

    def _ensure_embeddings(self) -> None:
        if self._embeddings is not None and self._nn is not None:
            return
        if not self._cleaned:
            self.clean()
        if self.df.empty:
            # Empty corpus: keep a zero-row placeholder; the width is arbitrary here.
            self._embeddings = np.zeros((0, 384), dtype=np.float32)
            self._nn = None
            return

        # Load the model lazily.
        if self._model is None:
            model_name = settings.CAMEROON_EMBEDDINGS_MODEL
            logger.info(f"Loading sentence-transformers model: {model_name}")
            self._model = SentenceTransformer(model_name)

        cache_file = os.path.join(self._cache_dir, "embeddings.npy")
        if os.path.exists(cache_file):
            try:
                self._embeddings = np.load(cache_file)
            except Exception:
                self._embeddings = None

        # Re-encode if the cache is missing or stale (row count changed).
        if self._embeddings is None or len(self._embeddings) != len(self.df):
            texts = self.df["summary_text"].astype(str).tolist()
            self._embeddings = self._model.encode(
                texts, batch_size=64, show_progress_bar=False, normalize_embeddings=True
            )
            np.save(cache_file, self._embeddings)

        # Build the nearest-neighbor index.
        self._nn = NearestNeighbors(n_neighbors=10, metric="cosine")
        self._nn.fit(self._embeddings)
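    # Scoring note (a small sketch of the identity search_similar_cases relies
    # on): with normalize_embeddings=True all vectors are unit-length, so the
    # cosine distance returned by sklearn reduces to 1 - dot(a, b), and the
    # reported similarity is simply the dot product:
    #
    #   a = np.array([0.6, 0.8]); b = np.array([1.0, 0.0])   # both unit-norm
    #   cos_dist = 1.0 - a @ b        # 0.4
    #   similarity = 1.0 - cos_dist   # 0.6, i.e. a @ b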
    def search_similar_cases(self, query_text: str, top_k: int = 10) -> List[SimilarCase]:
        if not query_text or query_text.strip() == "":
            return []
        self._ensure_embeddings()
        if self._model is None or self._nn is None or self._embeddings is None or self.df.empty:
            return []

        q = self._model.encode([query_text], normalize_embeddings=True)
        distances, indices = self._nn.kneighbors(q, n_neighbors=min(top_k, len(self.df)))
        distances = distances[0]
        indices = indices[0]

        results: List[SimilarCase] = []
        for dist, idx in zip(distances, indices):
            row = self.df.iloc[int(idx)]
            # Similarity = 1 - cosine distance.
            sim = float(1.0 - dist)
            text = str(row.get("summary_text", ""))
            snippet = text[:140] + ("..." if len(text) > 140 else "")
            results.append(SimilarCase(
                summary_id=str(row.get("summary_id", "")),
                diagnosis=row.get("diagnosis_norm"),
                age=float(row.get("patient_age")) if pd.notna(row.get("patient_age")) else None,
                gender=row.get("patient_gender_norm"),
                summary_snippet=snippet,
                similarity_score=sim,
            ))
        return results

    # ----------------------------- Utils -----------------------------

    def _extract_common_terms(self, texts: List[str], top_k: int = 20) -> List[str]:
        # Very naive bag-of-words; in production, consider medical entity extraction.
        tokens: List[str] = []
        for t in texts:
            for w in str(t).lower().replace(",", " ").replace(".", " ").split():
                if len(w) >= 3 and w.isalpha():
                    tokens.append(w)
        return [w for w, _ in Counter(tokens).most_common(top_k)]


# Singleton accessor
_singleton: Optional[CameroonMedicalData] = None


def get_cameroon_data() -> CameroonMedicalData:
    global _singleton
    if _singleton is None:
        _singleton = CameroonMedicalData()
    return _singleton
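
# Minimal usage sketch (illustrative only; assumes CAMEROON_DATA_CSV,
# CAMEROON_CACHE_DIR, CAMEROON_MAX_ROWS, and CAMEROON_EMBEDDINGS_MODEL are
# configured in the .env consumed by app.utils.config.settings):
if __name__ == "__main__":
    data = get_cameroon_data()
    print(json.dumps(data.stats_overview(), indent=2, default=str))
    for case in data.search_similar_cases("fever and chills for three days", top_k=5):
        print(f"{case.similarity_score:.2f}  {case.diagnosis}  {case.summary_snippet}")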