# app.py - US Population Health Dashboard (Gradio).
# Last change: "Update app.py" by LeonceNsh (commit b101d35, verified).
import pandas as pd
import geopandas as gpd
import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import squareform
# ========================
# Data Loading
# ========================
# Each input file is read once, at import time, relative to the current working
# directory: tabular data via pandas (CSV), boundary polygons via geopandas
# (GeoJSON). The tuple on the right is evaluated left to right, so the files are
# read in the order listed.
(
    conus_data,           # health data; its 'Percent_Person_*' columns are the metrics
    county_geojson,       # county polygons
    county_embeddings,    # NOTE(review): not referenced anywhere else in this file
    county_unemployment,  # wide table: 'place' plus one column per date
    zcta_poverty,         # wide table: 'place' plus one column per year
    zcta_geojson,         # ZCTA polygons
) = (
    pd.read_csv("conus27.csv"),
    gpd.read_file("county.geojson"),
    pd.read_csv("county_embeddings.csv"),
    pd.read_csv("county_unemployment.csv"),
    pd.read_csv("zcta_poverty.csv"),
    gpd.read_file("zcta.geojson"),
)
# ========================
# Data Preparation
# ========================
# --- County unemployment: wide (one column per date) -> long, then joined to
# the county polygons.
county_unemployment_melted = county_unemployment.melt(
    id_vars=['place'], var_name='date', value_name='unemployment_rate'
)
# Use str join keys on BOTH sides, mirroring the ZCTA handling below. Casting
# only the melted side makes pandas raise ("merge on int64 and object columns")
# whenever the GeoJSON 'place' property is read as a number.
county_unemployment_melted['place'] = county_unemployment_melted['place'].astype(str)
county_geojson['place'] = county_geojson['place'].astype(str)
county_geojson_unemployment = county_geojson.merge(
    county_unemployment_melted, on='place', how='left'
)

# --- ZCTA poverty: wide (one column per year) -> long, then joined to the ZCTA
# polygons.
zcta_poverty_melted = zcta_poverty.melt(
    id_vars=['place'], var_name='year', value_name='poverty_rate'
)
zcta_poverty_melted['place'] = zcta_poverty_melted['place'].astype(str)
zcta_geojson['place'] = zcta_geojson['place'].astype(str)
zcta_geojson_poverty = zcta_geojson.merge(
    zcta_poverty_melted, on='place', how='left'
)

# --- Health metrics: every 'Percent_Person_*' column of conus_data. The UI shows
# the name without the prefix; metric_mapping maps it back to the real column.
_HEALTH_PREFIX = 'Percent_Person_'
health_metrics = [c for c in conus_data.columns if c.startswith(_HEALTH_PREFIX)]
if not health_metrics:
    # The UI defaults to simplified_metrics[0]; fail here with a clear message
    # instead of an IndexError while building the interface.
    raise ValueError(f"No '{_HEALTH_PREFIX}*' columns found in the health data (conus_data).")
# Strip only the leading prefix (str.replace would also drop later occurrences).
simplified_metrics = [c[len(_HEALTH_PREFIX):] for c in health_metrics]
metric_mapping = dict(zip(simplified_metrics, health_metrics))

if 'place' in conus_data.columns:
    # county_geojson['place'] is str at this point, so the health key must match.
    conus_data['place'] = conus_data['place'].astype(str)
    merged_health = county_geojson.merge(conus_data, on='place', how='left')
elif 'GEOID' in county_geojson.columns and 'GEOID' in conus_data.columns:
    merged_health = county_geojson.merge(conus_data, on='GEOID', how='left')
else:
    raise ValueError("No matching key found to merge health data with geodata.")
# ========================
# Plotting Functions
# ========================
def plot_health_metric(metric):
    """Draw a county choropleth of a single health metric.

    Args:
        metric: Simplified metric name, i.e. a key of ``metric_mapping``
            (the column name without its ``Percent_Person_`` prefix).

    Returns:
        The matplotlib ``Figure``. Counties with no value are drawn light grey
        and labelled "No Data".

    Raises:
        KeyError: If ``metric`` is not a key of ``metric_mapping``.
    """
    metric_full = metric_mapping[metric]
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    merged_health.plot(
        column=metric_full, cmap='viridis', legend=True,
        ax=ax, alpha=0.7, edgecolor='black', linewidth=0.5,
        missing_kwds={"color": "lightgrey", "label": "No Data"},
    )
    ax.set_title(f'{metric}: Geographic Distribution', fontsize=15)
    ax.axis('off')
    # Lay out this figure explicitly. plt.tight_layout() targets pyplot's global
    # "current figure", which may belong to another request if handlers overlap.
    fig.tight_layout()
    # Remove the figure from pyplot's registry so per-request figures do not pile
    # up for the server's lifetime. The returned Figure can still be rendered.
    plt.close(fig)
    return fig
def plot_health_histogram(metric):
    """Draw a histogram (with KDE curve) of one health metric's values.

    Args:
        metric: Simplified metric name, i.e. a key of ``metric_mapping``.

    Returns:
        The matplotlib ``Figure``. Missing values are dropped before plotting.

    Raises:
        KeyError: If ``metric`` is not a key of ``metric_mapping``.
    """
    metric_full = metric_mapping[metric]
    data = conus_data[metric_full].dropna()
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.histplot(data, kde=True, color='teal', ax=ax)
    ax.set_title(f'{metric}: Value Distribution', fontsize=15)
    ax.set_xlabel(f'{metric} (%)')
    ax.set_ylabel('Number of Counties')
    # Operate on this figure, not pyplot's global current figure.
    fig.tight_layout()
    # Drop the figure from pyplot's registry so it does not leak across requests.
    plt.close(fig)
    return fig
def summarize_health_metrics(metric):
    """Tabulate descriptive statistics for one health metric.

    Args:
        metric: Simplified metric name, i.e. a key of ``metric_mapping``.

    Returns:
        A DataFrame with columns ``Statistic`` and ``Value``. It holds the rows of
        ``Series.describe()`` (count, mean, std, min, 25%, 50%, 75%, max) in that
        order, followed by ``Median`` and ``IQR`` (75th minus 25th percentile).
        Missing values are dropped before anything is computed.
    """
    values = conus_data[metric_mapping[metric]].dropna()
    stats = values.describe()
    lower, upper = values.quantile([0.25, 0.75])
    labels = [*stats.index, 'Median', 'IQR']
    numbers = [*stats.to_numpy(), values.median(), upper - lower]
    return pd.DataFrame({'Statistic': labels, 'Value': numbers})
def plot_correlation_matrix(metrics):
    """Draw a heatmap of pairwise correlations, ordered by hierarchical clustering.

    Args:
        metrics: Simplified metric names, i.e. keys of ``metric_mapping``.

    Returns:
        The matplotlib ``Figure``. When two or more metrics are given, rows and
        columns are reordered by average-linkage clustering on the distance
        ``1 - r``, so strongly correlated metrics end up next to each other.

    Raises:
        KeyError: If any name is not a key of ``metric_mapping``.
    """
    selected = [metric_mapping[m] for m in metrics]
    corr = conus_data[selected].corr()

    if len(selected) > 1:
        # Distance is 1 - r, in [0, 2]. A NaN correlation (e.g. from a constant
        # column) is treated as r = 0 because linkage() rejects non-finite input.
        # The clip guards against tiny negatives when floating-point r exceeds 1.
        dist = np.clip(1.0 - corr.fillna(0.0).to_numpy(), 0.0, None)
        np.fill_diagonal(dist, 0.0)
        # linkage() treats a 2-D array as raw observation vectors, not as a
        # distance matrix, so pass the condensed form of the square matrix.
        order = leaves_list(linkage(squareform(dist, checks=False), method='average'))
        corr = corr.iloc[order, order]

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        corr, annot=True, cmap='coolwarm', square=True, ax=ax,
        # Span the full correlation range so the diverging colormap's midpoint
        # always corresponds to r = 0, whatever metrics are selected.
        vmin=-1, vmax=1,
        xticklabels=corr.columns, yticklabels=corr.columns, cbar_kws={"shrink": .8},
    )
    ax.set_title('Health Metric Correlations', fontsize=15)
    # Rotate this axes' tick labels (plt.xticks/yticks act on the global current axes).
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    plt.setp(ax.get_yticklabels(), rotation=0)
    fig.tight_layout()
    # Drop the figure from pyplot's registry so it does not leak across requests.
    plt.close(fig)
    return fig
def plot_unemployment_map(date):
    """Draw a county choropleth of unemployment rates for a single date.

    Args:
        date: One of the date labels (the column headers of the wide
            unemployment table). It is compared as a string.

    Returns:
        The matplotlib ``Figure``. Counties with no value for ``date`` are drawn
        light grey and labelled "Missing values".
    """
    rows = county_unemployment_melted[county_unemployment_melted['date'] == str(date)]
    # Start from every county polygon. Filtering the pre-merged table would drop
    # counties with no row for this date, so they would vanish from the map
    # instead of being shown as missing.
    data = county_geojson.merge(rows, on='place', how='left')
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    data.plot(
        column='unemployment_rate', cmap='YlGnBu', linewidth=0.5,
        ax=ax, edgecolor='0.8', legend=True,
        missing_kwds={"color": "lightgrey", "label": "Missing values"},
    )
    ax.set_title(f'Unemployment Rate by County: {date}', fontsize=15)
    ax.axis('off')
    # Operate on this figure, then release it from pyplot's registry.
    fig.tight_layout()
    plt.close(fig)
    return fig
def plot_poverty_map(year):
    """Draw a ZCTA choropleth of poverty rates for a single year.

    Args:
        year: One of the year labels (the column headers of the wide poverty
            table). It is compared as a string.

    Returns:
        The matplotlib ``Figure``. ZCTAs with no value for ``year`` are drawn
        light grey and labelled "Missing values".
    """
    rows = zcta_poverty_melted[zcta_poverty_melted['year'] == str(year)]
    # Start from every ZCTA polygon so areas with no row for this year are shown
    # as missing rather than silently dropped from the map.
    data = zcta_geojson.merge(rows, on='place', how='left')
    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    data.plot(
        column='poverty_rate', cmap='YlOrRd', linewidth=0.5,
        ax=ax, edgecolor='0.8', legend=True,
        missing_kwds={"color": "lightgrey", "label": "Missing values"},
    )
    ax.set_title(f'Poverty Rate by ZCTA: {year}', fontsize=15)
    ax.axis('off')
    # Operate on this figure, then release it from pyplot's registry.
    fig.tight_layout()
    plt.close(fig)
    return fig
# ========================
# Gradio Interface Logic
# ========================
def health_metric_interface(metric):
    """Produce all three outputs of the "Health Metrics" tab for one metric.

    Returns:
        A ``(map_figure, summary_table, histogram_figure)`` tuple, in the order
        the tab's output components are wired.
    """
    choropleth = plot_health_metric(metric)
    summary = summarize_health_metrics(metric)
    histogram = plot_health_histogram(metric)
    return choropleth, summary, histogram
def correlation_interface(metrics):
    """Callback for the "Metric Correlations" tab.

    Args:
        metrics: The selected simplified metric names; may be empty or ``None``.

    Returns:
        The clustered correlation heatmap. Returns ``None``, which clears the
        plot, when fewer than two metrics are selected.
    """
    if not metrics or len(metrics) < 2:
        # gr.Plot cannot render a plain string, so show the hint as a
        # notification and clear the plot instead of returning the message.
        gr.Warning("Please select at least two metrics to see a correlation matrix.")
        return None
    return plot_correlation_matrix(metrics)
# ========================
# Gradio App Setup
# ========================
with gr.Blocks(title="US Population Health Dashboard") as demo:
    gr.Markdown("# US Population Health Dashboard")
    gr.Markdown("""
A comprehensive visualization platform for analyzing county-level and ZCTA-level health and socioeconomic indicators across the United States.
""")
    with gr.Tab("Health Metrics"):
        gr.Markdown("### Analyze Individual Health Metrics")
        gr.Markdown("Select a health metric to view its geographic distribution, summary statistics, and value distribution.")
        m = gr.Dropdown(label="Select a Health Metric", choices=simplified_metrics, value=simplified_metrics[0])
        m_plot = gr.Plot()
        m_summary = gr.Dataframe(headers=["Statistic", "Value"])
        m_hist = gr.Plot()
        m.change(health_metric_interface, m, [m_plot, m_summary, m_hist])
    with gr.Tab("Metric Correlations"):
        gr.Markdown("### Explore Relationships Between Metrics")
        gr.Markdown("Select multiple health metrics to view their correlations. Metrics are reordered using hierarchical clustering.")
        cm = gr.CheckboxGroup(label="Select Health Metrics", choices=simplified_metrics, value=simplified_metrics[:5])
        cm_plot = gr.Plot()
        cm.change(correlation_interface, cm, cm_plot)
    with gr.Tab("Unemployment Trends"):
        gr.Markdown("### Track County Unemployment Over Time")
        gr.Markdown("Select a date to see unemployment rate distribution across counties.")
        d = sorted(county_unemployment_melted['date'].unique())
        ud = gr.Dropdown(label="Select a Date", choices=d, value=d[0])
        up = gr.Plot()
        ud.change(plot_unemployment_map, ud, up)
    with gr.Tab("Poverty Trends"):
        gr.Markdown("### Track ZCTA Poverty Over Time")
        gr.Markdown("Select a year to see poverty rate distribution across ZIP Code Tabulation Areas.")
        y = sorted(zcta_poverty_melted['year'].unique())
        py = gr.Dropdown(label="Select a Year", choices=y, value=y[0])
        pp = gr.Plot()
        py.change(plot_poverty_map, py, pp)

    # .change() fires only on later edits, not for a component's initial value.
    # Render each tab for its default selection when the page loads so the plots
    # do not start out empty.
    demo.load(health_metric_interface, m, [m_plot, m_summary, m_hist])
    demo.load(correlation_interface, cm, cm_plot)
    demo.load(plot_unemployment_map, ud, up)
    demo.load(plot_poverty_map, py, pp)

if __name__ == "__main__":
    demo.launch()