# interactive_plot_generator.py
# Generate interactive air pollution maps for India with hover information
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
import geopandas as gpd
from pathlib import Path
from datetime import datetime
from constants import INDIA_BOUNDS, COLOR_THEMES
import plotly.io as pio
import warnings
warnings.filterwarnings('ignore')
class InteractiveIndiaMapPlotter:
    """Render gridded air-pollution fields over India as interactive Plotly maps.

    State boundaries are read once from a shapefile in ``__init__`` and drawn as
    outline traces on every map produced by :meth:`create_india_map`. Output files
    are written to ``plots_dir``.
    """

    # Matplotlib colormap name -> Plotly colorscale name. Names not listed here
    # fall back to 'Viridis'. Keep in sync with the copy in test_color_themes().
    COLORMAP_MAPPING = {
        # Sequential color schemes
        'viridis': 'Viridis',
        'plasma': 'Plasma',
        'inferno': 'Inferno',
        'magma': 'Magma',
        'cividis': 'Cividis',
        # Single-hue sequential schemes
        'YlOrRd': 'YlOrRd',
        'Oranges': 'Oranges',
        'Reds': 'Reds',
        'Purples': 'Purples',
        'Blues': 'Blues',
        'Greens': 'Greens',
        # Diverging schemes
        'coolwarm': 'RdBu_r',
        'RdYlBu': 'RdYlBu',
        'Spectral': 'Spectral',
        'Spectral_r': 'Spectral_r',
        'RdYlGn_r': 'RdYlGn_r',
        # Other schemes
        'jet': 'Jet',
        'turbo': 'Turbo',
    }

    # Filename sanitisation: path/space/dot become '_', subscript digits become ASCII.
    _NAME_TRANSLATION = str.maketrans({'/': '_', ' ': '_', '₂': '2', '₃': '3', '.': '_'})
    # Timestamp compaction: '2024-01-02 03:04:05' -> '20240102_030405'.
    _STAMP_TRANSLATION = str.maketrans({'-': None, ':': None, ' ': '_'})

    def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
        """
        Initialize the interactive map plotter

        Parameters:
        plots_dir (str): Directory to save plots (created, with parents, if missing)
        shapefile_path (str): Path to the India boundary shapefile

        Raises:
        FileNotFoundError: if the shapefile cannot be read or reprojected
        """
        self.plots_dir = Path(plots_dir)
        self.plots_dir.mkdir(parents=True, exist_ok=True)
        try:
            self.india_map = gpd.read_file(shapefile_path)
            # Plot axes are in degrees, so reproject to lat/lon (WGS84) if needed.
            if self.india_map.crs is not None and self.india_map.crs.to_epsg() != 4326:
                self.india_map = self.india_map.to_crs(epsg=4326)
        except Exception as e:
            raise FileNotFoundError(f"Could not read the shapefile at '{shapefile_path}'. "
                                    f"Please ensure the file exists. Error: {e}") from e

    def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
        """
        Create interactive air pollution map over India with hover information

        Parameters:
        data_values (np.ndarray): 2D array of pollution data; row i pairs with
            lats[i] and column j with lons[j] when 1-D axes are given
        metadata (dict): Must contain 'lats', 'lons', 'variable_name',
            'display_name' and 'units'; may contain 'pressure_level' and
            'timestamp_str'
        color_theme (str): Color theme name from COLOR_THEMES. Defaults to the
            variable's 'cmap' entry in AIR_POLLUTION_VARIABLES, else 'viridis'
        save_plot (bool): Whether to save the plot as HTML and PNG
        custom_title (str): Custom title for the plot

        Returns:
        dict: Dictionary containing paths to saved files and HTML content
            - 'html_path': Path to interactive HTML file (None if not saved)
            - 'png_path': Path to static PNG file (None if not saved or PNG export failed)
            - 'html_content': HTML <div> fragment for embedding (always set)

        Raises:
        Exception: wraps (and chains) any error raised while building or saving
            the map, including the ValueError raised when every value is NaN
        """
        try:
            from constants import AIR_POLLUTION_VARIABLES

            # Extract metadata
            lats = np.asarray(metadata['lats'])
            lons = np.asarray(metadata['lons'])
            var_name = metadata['variable_name']
            display_name = metadata['display_name']
            units = metadata['units']
            pressure_level = metadata.get('pressure_level')
            time_stamp = metadata.get('timestamp_str')
            var_config = AIR_POLLUTION_VARIABLES.get(var_name, {})

            # Determine color theme and translate it to a Plotly colorscale
            if color_theme is None:
                color_theme = var_config.get('cmap', 'viridis')
            plotly_colorscale = self.COLORMAP_MAPPING.get(color_theme, 'Viridis')

            # Hover text needs a coordinate pair per cell, so expand 1-D axes.
            if lons.ndim == 1 and lats.ndim == 1:
                lon_grid, lat_grid = np.meshgrid(lons, lats)
            else:
                lon_grid, lat_grid = lons, lats

            valid_data = data_values[~np.isnan(data_values)]
            if len(valid_data) == 0:
                raise ValueError("All data values are NaN - cannot create plot")

            # Percentile-based colour limits keep outliers from washing out the map.
            vmin, vmax = self._color_limits(valid_data, var_config.get('vmax_percentile', 90))

            hover_text = self._create_hover_text(lon_grid, lat_grid, data_values, display_name, units)

            fig = go.Figure()
            # NOTE(review): lons/lats are passed through as given; for 2-D
            # coordinate grids this relies on Plotly's Heatmap accepting them.
            fig.add_trace(go.Heatmap(
                x=lons,
                y=lats,
                z=data_values,
                colorscale=plotly_colorscale,
                zmin=vmin,
                zmax=vmax,
                hovertext=hover_text,
                hoverinfo='text',
                colorbar=dict(
                    title=dict(
                        text=f"{display_name}" + (f"<br>({units})" if units else ""),
                        side="right"
                    ),
                    thickness=20,
                    len=0.6,
                    x=1.02
                )
            ))
            self._add_boundaries(fig)

            # Title: custom title wins; otherwise include pressure level / time when known.
            if custom_title:
                title = custom_title
            else:
                title = f'{display_name} Concentration over India (Interactive)'
                if pressure_level:
                    title += f' at {pressure_level} hPa'
                if time_stamp:
                    title += f' on {time_stamp}'

            stats_text = self._create_stats_text(valid_data, units)
            theme_name = COLOR_THEMES.get(color_theme, color_theme)
            lon_range, lat_range = self._axis_ranges()

            fig.update_layout(
                title=dict(
                    text=title,
                    x=0.5,
                    xanchor='center',
                    font=dict(size=18, weight='bold')
                ),
                xaxis=dict(
                    title='Longitude',
                    range=lon_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False
                ),
                yaxis=dict(
                    title='Latitude',
                    range=lat_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)',
                    zeroline=False,
                    scaleanchor="x",
                    scaleratio=1  # 1 degree lat == 1 degree lon, matching the static plot
                ),
                width=1400,
                height=1000,
                plot_bgcolor='white',
                # Enable zoom, pan and other interactive features
                dragmode='zoom',
                showlegend=False,
                hovermode='closest',
                modebar=dict(
                    bgcolor='rgba(255, 255, 255, 0.8)',
                    activecolor='rgb(0, 123, 255)',
                    orientation='h'
                ),
                annotations=[
                    # Statistics box (Plotly annotations use <br> for line breaks)
                    dict(
                        text=stats_text.replace('\n', '<br>'),
                        xref='paper', yref='paper',
                        x=0.02, y=0.98,
                        xanchor='left', yanchor='top',
                        showarrow=False,
                        bgcolor='rgba(255, 255, 255, 0.9)',
                        bordercolor='black',
                        borderwidth=1,
                        borderpad=10,
                        font=dict(size=11)
                    ),
                    # Theme info box
                    dict(
                        text=f'Color Theme: {theme_name}',
                        xref='paper', yref='paper',
                        x=0.98, y=0.02,
                        xanchor='right', yanchor='bottom',
                        showarrow=False,
                        bgcolor='rgba(211, 211, 211, 0.8)',
                        bordercolor='gray',
                        borderwidth=1,
                        borderpad=8,
                        font=dict(size=10)
                    ),
                    # Instructions
                    dict(
                        text='🔍 Zoom: Mouse wheel or zoom tool | 👆 Hover: Show coordinates & values | 📥 Download: Camera icon',
                        xref='paper', yref='paper',
                        x=0.5, y=0.02,
                        xanchor='center', yanchor='bottom',
                        showarrow=False,
                        bgcolor='rgba(173, 216, 230, 0.8)',
                        bordercolor='steelblue',
                        borderwidth=1,
                        borderpad=8,
                        font=dict(size=10, color='darkblue')
                    )
                ]
            )

            # Modebar / download configuration shared by the embedded and saved HTML
            config = {
                'displayModeBar': True,
                'displaylogo': False,
                'modeBarButtonsToAdd': [
                    'drawline',
                    'drawopenpath',
                    'drawclosedpath',
                    'drawcircle',
                    'drawrect',
                    'eraseshape'
                ],
                'modeBarButtonsToRemove': ['lasso2d', 'select2d'],
                'toImageButtonOptions': {
                    'format': 'png',
                    'filename': f'india_pollution_map_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
                    'height': 1000,
                    'width': 1400,
                    'scale': 2
                },
                'responsive': True
            }

            # The embeddable fragment is produced whether or not files are saved.
            html_content = pio.to_html(
                fig,
                config=config,
                include_plotlyjs='cdn',
                div_id='interactive-plot',
                full_html=False
            )
            result = {'html_content': html_content, 'html_path': None, 'png_path': None}
            if save_plot:
                result['html_path'] = self._save_html_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp, config)
                # PNG export may fail (e.g. no kaleido); _save_png_plot then returns None.
                result['png_path'] = self._save_png_plot(
                    fig, var_name, display_name, pressure_level, color_theme, time_stamp)
            return result
        except Exception as e:
            raise Exception(f"Error creating interactive map: {str(e)}") from e

    @staticmethod
    def _color_limits(valid_data, vmax_percentile):
        """Return (vmin, vmax): the 5th and ``vmax_percentile``-th percentiles.

        If the range collapses (vmax <= vmin), vmax is set to vmin + 1.0 so the
        colour scale stays non-degenerate.
        """
        vmin = np.nanpercentile(valid_data, 5)
        vmax = np.nanpercentile(valid_data, vmax_percentile)
        if vmax <= vmin:
            vmax = vmin + 1.0
        return vmin, vmax

    def _axis_ranges(self):
        """Return ([lon_min, lon_max], [lat_min, lat_max]) for the map axes.

        INDIA_BOUNDS is used unless the shapefile's minimum longitude lies outside
        INDIA_BOUNDS' longitude span; then the shapefile's own bounds are used.
        """
        xmin, ymin, xmax, ymax = self.india_map.total_bounds
        if INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max']:
            return ([INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max']],
                    [INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max']])
        return [xmin, xmax], [ymin, ymax]

    def _add_boundaries(self, fig):
        """Draw the exterior ring of every (Multi)Polygon in the shapefile onto ``fig``."""
        for geometry in self.india_map.geometry:
            if geometry is None:
                continue  # rows without geometry have nothing to draw
            if geometry.geom_type == 'Polygon':
                self._add_polygon_trace(fig, geometry)
            elif geometry.geom_type == 'MultiPolygon':
                for polygon in geometry.geoms:
                    self._add_polygon_trace(fig, polygon)

    def _add_polygon_trace(self, fig, polygon):
        """Add a polygon's exterior boundary to the figure as a black line (no hover)."""
        x, y = polygon.exterior.xy
        fig.add_trace(go.Scatter(
            x=list(x),
            y=list(y),
            mode='lines',
            line=dict(color='black', width=1),
            hoverinfo='skip',
            showlegend=False
        ))

    @staticmethod
    def _format_number(val):
        """Format a finite number: no decimals if |val| >= 1000, one if >= 10, else two."""
        if abs(val) >= 1000:
            return f"{val:.0f}"
        if abs(val) >= 10:
            return f"{val:.1f}"
        return f"{val:.2f}"

    def _create_hover_text(self, lon_grid, lat_grid, data_values, display_name, units):
        """Build a per-cell hover label array (dtype=object, same shape as data_values).

        Each label shows the value (with units, or 'N/A' for NaN), latitude and
        longitude, separated by Plotly '<br>' line breaks. Grids may be 2-D
        (indexed [i, j]) or 1-D (lat indexed by row i, lon by column j).
        """
        hover_text = np.empty(data_values.shape, dtype=object)
        units_str = f" {units}" if units else ""
        for i, j in np.ndindex(data_values.shape):
            lat = lat_grid[i, j] if lat_grid.ndim == 2 else lat_grid[i]
            lon = lon_grid[i, j] if lon_grid.ndim == 2 else lon_grid[j]
            value = data_values[i, j]
            if np.isnan(value):
                value_str = "N/A"
            else:
                value_str = f"{self._format_number(value)}{units_str}"
            hover_text[i, j] = (
                f"{display_name}: {value_str}<br>"
                f"Latitude: {lat:.3f}°<br>"
                f"Longitude: {lon:.3f}°"
            )
        return hover_text

    def _create_stats_text(self, data, units):
        """Return newline-separated Min/Max/Mean/Median/Std lines (NaN-ignoring) with units."""
        units_str = f" {units}" if units else ""
        stats = {
            'Min': np.nanmin(data),
            'Max': np.nanmax(data),
            'Mean': np.nanmean(data),
            'Median': np.nanmedian(data),
            'Std': np.nanstd(data)
        }
        return "\n".join(f"{name}: {self._format_number(val)}{units_str}" for name, val in stats.items())

    @classmethod
    def _build_filename(cls, var_name, display_name, pressure_level, color_theme, time_stamp, kind, suffix):
        """Build a filesystem-safe plot filename.

        Pattern: ``<name>_India_<kind>[_<level>hPa]_<theme>_<timestamp><suffix>``.
        The name falls back to ``var_name`` then 'Unknown'; a missing timestamp
        becomes 'Unknown_Time'. The pressure level is omitted when falsy.
        """
        safe_name = (display_name or var_name or 'Unknown').translate(cls._NAME_TRANSLATION)
        safe_stamp = (time_stamp or 'Unknown_Time').translate(cls._STAMP_TRANSLATION)
        parts = [f"{safe_name}_India_{kind}"]
        if pressure_level:
            parts.append(f"{int(pressure_level)}hPa")
        parts.extend([color_theme, safe_stamp])
        return "_".join(parts) + suffix

    def _save_html_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp, config):
        """Write ``fig`` as a standalone HTML file (Plotly JS from CDN) and return its path."""
        filename = self._build_filename(
            var_name, display_name, pressure_level, color_theme, time_stamp, 'interactive', '.html')
        plot_path = self.plots_dir / filename
        fig.write_html(str(plot_path), config=config, include_plotlyjs='cdn')
        print(f"Interactive HTML plot saved: {plot_path}")
        return str(plot_path)

    def _save_png_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Write ``fig`` as a 1400x1000 (scale 2) PNG; return its path, or None on failure.

        Export failures (e.g. the static-image backend is unavailable) are reported
        as a warning rather than raised, so the HTML output is still usable.
        """
        filename = self._build_filename(
            var_name, display_name, pressure_level, color_theme, time_stamp, 'static', '.png')
        plot_path = self.plots_dir / filename
        try:
            fig.write_image(str(plot_path), format='png', width=1400, height=1000, scale=2)
        except Exception as e:
            print(f"Warning: Could not save PNG: {e}")
            return None
        print(f"Static PNG plot saved: {plot_path}")
        return str(plot_path)

    def list_available_themes(self):
        """Return the COLOR_THEMES mapping from constants."""
        return COLOR_THEMES
def test_interactive_plot_generator():
    """Smoke-test InteractiveIndiaMapPlotter on synthetic PM2.5 data.

    Builds a 50x60 lat/lon grid covering India with a smooth pattern plus
    Gaussian noise, then renders it with the 'YlOrRd' theme.

    Returns:
    bool: True if the map was created, False if the shapefile is missing or
        constructing the plotter / creating the map raised.
    """
    print("Testing interactive plot generator...")
    # Create test data
    lats = np.linspace(6, 38, 50)
    lons = np.linspace(68, 98, 60)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
    data += np.random.normal(0, 10, data.shape)
    metadata = {
        'variable_name': 'pm25',
        'display_name': 'PM2.5',
        'units': 'µg/m³',
        'lats': lats,
        'lons': lons,
        'pressure_level': None,
        'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }
    shapefile_path = "shapefiles/India_State_Boundary.shp"
    if not Path(shapefile_path).exists():
        print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
        print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
        return False
    try:
        # Constructing the plotter reads the shapefile and can fail too.
        plotter = InteractiveIndiaMapPlotter(shapefile_path=shapefile_path)
        result = plotter.create_india_map(data, metadata, color_theme='YlOrRd')
        if result.get('html_path'):
            print(f"✅ Test interactive HTML plot created successfully: {result['html_path']}")
        if result.get('png_path'):
            print(f"✅ Test static PNG plot created successfully: {result['png_path']}")
        return True
    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False
def test_color_themes():
    """Check that every key in constants.COLOR_THEMES has a Plotly colorscale mapping.

    Prints a table of theme -> Plotly colorscale -> status.

    Returns:
    bool: True when all themes are mapped, False if any are missing.
    """
    from constants import COLOR_THEMES

    # Must mirror the colormap mapping InteractiveIndiaMapPlotter uses in create_india_map.
    colormap_mapping = {
        # Sequential color schemes
        'viridis': 'Viridis',
        'plasma': 'Plasma',
        'inferno': 'Inferno',
        'magma': 'Magma',
        'cividis': 'Cividis',
        # Single-hue sequential schemes
        'YlOrRd': 'YlOrRd',
        'Oranges': 'Oranges',
        'Reds': 'Reds',
        'Purples': 'Purples',
        'Blues': 'Blues',
        'Greens': 'Greens',
        # Diverging schemes
        'coolwarm': 'RdBu_r',
        'RdYlBu': 'RdYlBu',
        'Spectral': 'Spectral',
        'Spectral_r': 'Spectral_r',
        'RdYlGn_r': 'RdYlGn_r',
        # Other schemes
        'jet': 'Jet',
        'turbo': 'Turbo',
    }
    print("🎨 Testing color theme mappings:")
    print(f"{'Color Theme':<15} {'Plotly Colorscale':<20} {'Status'}")
    print("-" * 50)
    for theme_key in COLOR_THEMES:
        if theme_key in colormap_mapping:
            plotly_scale = colormap_mapping[theme_key]
            status = "✅ Mapped"
        else:
            plotly_scale = "Viridis (default)"
            status = "⚠️ Missing"
        print(f"{theme_key:<15} {plotly_scale:<20} {status}")
    # Sorted so the report is stable between runs (set order is arbitrary).
    missing_themes = sorted(set(COLOR_THEMES) - set(colormap_mapping))
    if missing_themes:
        print(f"\n❌ Missing mappings for: {', '.join(missing_themes)}")
        return False
    print(f"\n✅ All {len(COLOR_THEMES)} color themes are properly mapped!")
    return True
if __name__ == "__main__":
    # Exit non-zero when the smoke test fails so shells/CI can detect it.
    raise SystemExit(0 if test_interactive_plot_generator() else 1)