Spaces:
Running
Running
| # interactive_plot_generator.py | |
| # Generate interactive air pollution maps for India with hover information | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import geopandas as gpd | |
| from pathlib import Path | |
| from datetime import datetime | |
| from constants import INDIA_BOUNDS, COLOR_THEMES | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
class InteractiveIndiaMapPlotter:
    """Plot gridded air-pollution data over India as a Plotly figure.

    A figure consists of a heatmap of the data, the boundaries read from a
    shapefile drawn on top, per-cell hover text, a statistics box and a
    colour-theme box. Figures are exported as static JPG files in
    ``plots_dir``.
    """

    # Matplotlib colormap name -> Plotly colorscale name. Names missing from
    # this table fall back to 'Viridis' at lookup time.
    _COLORSCALE_BY_CMAP = {
        'viridis': 'Viridis',
        'plasma': 'Plasma',
        'inferno': 'Inferno',
        'magma': 'Magma',
        'cividis': 'Cividis',
        'YlOrRd': 'YlOrRd',
        'RdYlGn_r': 'RdYlGn_r',
        'coolwarm': 'RdBu_r',
        'Spectral_r': 'Spectral_r',
        'jet': 'Jet',
        'turbo': 'Turbo',
    }

    # Single-character substitutions applied to a display name before it is
    # used as part of a file name.
    _FILENAME_CHAR_MAP = str.maketrans(
        {'/': '_', ' ': '_', '₂': '2', '₃': '3', '.': '_'}
    )

    def __init__(self, plots_dir="plots", shapefile_path="shapefiles/India_State_Boundary.shp"):
        """
        Initialize the interactive map plotter

        Parameters:
            plots_dir (str): Directory to save plots. Created (including
                missing parent directories) if it does not exist.
            shapefile_path (str): Path to the India boundary shapefile

        Raises:
            FileNotFoundError: If the shapefile cannot be read or reprojected.
        """
        self.plots_dir = Path(plots_dir)
        # parents=True so that a nested output directory such as
        # "output/plots" works on a fresh checkout.
        self.plots_dir.mkdir(parents=True, exist_ok=True)
        try:
            india_map = gpd.read_file(shapefile_path)
            # Ensure it's in lat/lon (WGS84)
            if india_map.crs is not None and india_map.crs.to_epsg() != 4326:
                india_map = india_map.to_crs(epsg=4326)
        except Exception as e:
            # Chain the cause so the underlying reader error stays visible.
            raise FileNotFoundError(f"Could not read the shapefile at '{shapefile_path}'. "
                                    f"Please ensure the file exists. Error: {e}") from e
        self.india_map = india_map

    def create_india_map(self, data_values, metadata, color_theme=None, save_plot=True, custom_title=None):
        """
        Create interactive air pollution map over India with hover information

        Parameters:
            data_values (np.ndarray): 2D array of pollution data
            metadata (dict): Must contain 'lats', 'lons', 'variable_name',
                'display_name' and 'units'; 'pressure_level' and
                'timestamp_str' are optional.
            color_theme (str): Matplotlib-style colormap name. When None, the
                'cmap' entry for the variable in AIR_POLLUTION_VARIABLES is
                used, defaulting to 'viridis'.
            save_plot (bool): Whether to save the plot as JPG
            custom_title (str): Custom title for the plot; used verbatim
                instead of the generated title.

        Returns:
            str: Path to saved plot file, or None when save_plot is False

        Raises:
            Exception: Any failure while building or saving the figure,
                re-raised with an "Error creating interactive map" prefix and
                the original exception attached as the cause.
        """
        try:
            # Imported here, as in the original code, rather than at module
            # level; a single import serves both lookups below.
            from constants import AIR_POLLUTION_VARIABLES

            # Extract metadata
            lats = np.asarray(metadata['lats'])
            lons = np.asarray(metadata['lons'])
            var_name = metadata['variable_name']
            display_name = metadata['display_name']
            units = metadata['units']
            pressure_level = metadata.get('pressure_level')
            time_stamp = metadata.get('timestamp_str')

            var_config = AIR_POLLUTION_VARIABLES.get(var_name, {})

            # Determine color theme
            if color_theme is None:
                color_theme = var_config.get('cmap', 'viridis')
            plotly_colorscale = self._COLORSCALE_BY_CMAP.get(color_theme, 'Viridis')

            # Create mesh grid if needed
            if lons.ndim == 1 and lats.ndim == 1:
                lon_grid, lat_grid = np.meshgrid(lons, lats)
            else:
                lon_grid, lat_grid = lons, lats

            # The heatmap needs 1-D axis vectors. For 2-D coordinate grids
            # take the first row / first column, which assumes the same
            # layout np.meshgrid produces above (longitude varies along
            # axis 1, latitude along axis 0).
            x_axis = lon_grid[0, :] if lon_grid.ndim == 2 else lon_grid
            y_axis = lat_grid[:, 0] if lat_grid.ndim == 2 else lat_grid

            # Calculate statistics
            valid_data = data_values[~np.isnan(data_values)]
            if valid_data.size == 0:
                raise ValueError("All data values are NaN - cannot create plot")

            # Colour range: 5th percentile up to a per-variable upper
            # percentile (90 by default) so outliers do not flatten the scale.
            vmax_percentile = var_config.get('vmax_percentile', 90)
            vmin = np.nanpercentile(valid_data, 5)
            vmax = np.nanpercentile(valid_data, vmax_percentile)
            if vmax <= vmin:
                # Constant (or nearly constant) field: force a non-empty range.
                vmax = vmin + 1.0

            # Create hover text with detailed information
            hover_text = self._create_hover_text(lon_grid, lat_grid, data_values, display_name, units)

            # Create the figure
            fig = go.Figure()

            # Add pollution data as heatmap
            fig.add_trace(go.Heatmap(
                x=x_axis,
                y=y_axis,
                z=data_values,
                colorscale=plotly_colorscale,
                zmin=vmin,
                zmax=vmax,
                hovertext=hover_text,
                hoverinfo='text',
                colorbar=dict(
                    title=dict(
                        text=f"{display_name}" + (f"<br>({units})" if units else ""),
                        side="right"
                    ),
                    thickness=20,
                    len=0.6,
                    x=1.02
                )
            ))

            # Add India state boundaries
            for geometry in self.india_map.geometry:
                # Rows without a usable geometry have nothing to draw.
                if geometry is None or geometry.is_empty:
                    continue
                if geometry.geom_type == 'Polygon':
                    self._add_polygon_trace(fig, geometry)
                elif geometry.geom_type == 'MultiPolygon':
                    for polygon in geometry.geoms:
                        self._add_polygon_trace(fig, polygon)

            title = self._build_title(display_name, pressure_level, time_stamp, custom_title)

            # Calculate stats for annotation
            stats_text = self._create_stats_text(valid_data, units)
            theme_name = COLOR_THEMES.get(color_theme, color_theme)

            lon_range, lat_range = self._axis_ranges()

            # Update layout
            fig.update_layout(
                title=dict(
                    text=title,
                    x=0.5,
                    xanchor='center',
                    font=dict(size=18, weight='bold')
                ),
                xaxis=dict(
                    title='Longitude',
                    range=lon_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)'
                ),
                yaxis=dict(
                    title='Latitude',
                    range=lat_range,
                    showgrid=True,
                    gridcolor='rgba(128, 128, 128, 0.3)'
                ),
                width=1400,
                height=1000,
                plot_bgcolor='white',
                annotations=[
                    # Statistics box
                    dict(
                        text=stats_text.replace('\n', '<br>'),
                        xref='paper', yref='paper',
                        x=0.02, y=0.98,
                        xanchor='left', yanchor='top',
                        showarrow=False,
                        bgcolor='rgba(255, 255, 255, 0.9)',
                        bordercolor='black',
                        borderwidth=1,
                        borderpad=10,
                        font=dict(size=11)
                    ),
                    # Theme info box
                    dict(
                        text=f'Color Theme: {theme_name}',
                        xref='paper', yref='paper',
                        x=0.98, y=0.02,
                        xanchor='right', yanchor='bottom',
                        showarrow=False,
                        bgcolor='rgba(211, 211, 211, 0.8)',
                        bordercolor='gray',
                        borderwidth=1,
                        borderpad=8,
                        font=dict(size=10)
                    )
                ]
            )

            plot_path = None
            if save_plot:
                plot_path = self._save_plot(fig, var_name, display_name, pressure_level, color_theme, time_stamp)
            return plot_path
        except Exception as e:
            # The exception type is kept as plain Exception for existing
            # callers; the original error is attached as the cause.
            raise Exception(f"Error creating interactive map: {str(e)}") from e

    @staticmethod
    def _build_title(display_name, pressure_level, time_stamp, custom_title=None):
        """Return the figure title; a custom title is used verbatim."""
        if custom_title:
            return custom_title
        title = f'{display_name} Concentration over India'
        if pressure_level:
            title += f' at {pressure_level} hPa'
        # The timestamp is optional metadata; avoid rendering "on None".
        if time_stamp:
            title += f' on {time_stamp}'
        return title

    def _axis_ranges(self):
        """Return ([lon_min, lon_max], [lat_min, lat_max]) for the axes.

        INDIA_BOUNDS is used unless the western edge of the shapefile lies
        outside its longitude range, in which case the shapefile's own
        bounding box is used.
        """
        xmin, ymin, xmax, ymax = self.india_map.total_bounds
        if not (INDIA_BOUNDS['lon_min'] <= xmin <= INDIA_BOUNDS['lon_max']):
            return [xmin, xmax], [ymin, ymax]
        return (
            [INDIA_BOUNDS['lon_min'], INDIA_BOUNDS['lon_max']],
            [INDIA_BOUNDS['lat_min'], INDIA_BOUNDS['lat_max']],
        )

    def _add_polygon_trace(self, fig, polygon):
        """Add a polygon's exterior ring to the figure as a black line."""
        x, y = polygon.exterior.xy
        fig.add_trace(go.Scatter(
            x=list(x),
            y=list(y),
            mode='lines',
            line=dict(color='black', width=1),
            hoverinfo='skip',
            showlegend=False
        ))

    @staticmethod
    def _format_number(value):
        """Format a number with fewer decimals as its magnitude grows.

        |value| >= 1000 -> no decimals, >= 10 -> one decimal, else two.
        """
        magnitude = abs(value)
        if magnitude >= 1000:
            return f"{value:.0f}"
        if magnitude >= 10:
            return f"{value:.1f}"
        return f"{value:.2f}"

    def _create_hover_text(self, lon_grid, lat_grid, data_values, display_name, units):
        """Create formatted hover text for each point

        Parameters:
            lon_grid, lat_grid: Coordinates, either 2D arrays shaped like
                data_values or 1D axis vectors (longitude indexed by column,
                latitude indexed by row).
            data_values (np.ndarray): 2D data array; NaN cells show "N/A".
            display_name (str): Label shown in bold before the value.
            units (str): Unit suffix; omitted when empty.

        Returns:
            np.ndarray: Object array of HTML strings, same shape as data_values.
        """
        lat_grid = np.asarray(lat_grid)
        lon_grid = np.asarray(lon_grid)
        shape = data_values.shape
        # Expand 1D axes to the data shape so every cell can be addressed
        # with the same (row, column) index.
        if lat_grid.ndim != 2:
            lat_grid = np.broadcast_to(lat_grid[:, np.newaxis], shape)
        if lon_grid.ndim != 2:
            lon_grid = np.broadcast_to(lon_grid[np.newaxis, :], shape)

        units_str = f" {units}" if units else ""
        hover_text = np.empty(shape, dtype=object)
        for idx in np.ndindex(*shape):
            value = data_values[idx]
            if np.isnan(value):
                value_str = "N/A"
            else:
                value_str = f"{self._format_number(value)}{units_str}"
            hover_text[idx] = (
                f"<b>{display_name}</b>: {value_str}<br>"
                f"<b>Latitude</b>: {lat_grid[idx]:.3f}°<br>"
                f"<b>Longitude</b>: {lon_grid[idx]:.3f}°"
            )
        return hover_text

    def _create_stats_text(self, data, units):
        """Create statistics text for annotation

        Returns one "Name: value units" line each for Min, Max, Mean, Median
        and Std (NaN-aware), joined with newlines.
        """
        units_str = f" {units}" if units else ""
        stats = {
            'Min': np.nanmin(data),
            'Max': np.nanmax(data),
            'Mean': np.nanmean(data),
            'Median': np.nanmedian(data),
            'Std': np.nanstd(data),
        }
        return "\n".join(
            f"{name}: {self._format_number(val)}{units_str}"
            for name, val in stats.items()
        )

    @classmethod
    def _build_filename(cls, display_name, pressure_level, color_theme, time_stamp):
        """Return the JPG file name for a plot.

        Layout: <name>_India_interactive[_<level>hPa]_<theme>[_<timestamp>].jpg.
        The timestamp part is left out when no timestamp is available.
        """
        filename_parts = [f"{display_name.translate(cls._FILENAME_CHAR_MAP)}_India_interactive"]
        if pressure_level:
            filename_parts.append(f"{int(pressure_level)}hPa")
        filename_parts.append(color_theme)
        if time_stamp:
            filename_parts.append(
                time_stamp.replace('-', '').replace(':', '').replace(' ', '_')
            )
        return "_".join(filename_parts) + ".jpg"

    def _save_plot(self, fig, var_name, display_name, pressure_level, color_theme, time_stamp):
        """Save the plot as JPG

        var_name is accepted for interface compatibility but is not used in
        the file name. Returns the path of the written file as a string.
        """
        filename = self._build_filename(display_name, pressure_level, color_theme, time_stamp)
        plot_path = self.plots_dir / filename
        # Save as static JPG with high quality
        fig.write_image(str(plot_path), format='jpg', width=1400, height=1000, scale=2)
        print(f"Interactive plot saved as JPG: {plot_path}")
        return str(plot_path)

    def list_available_themes(self):
        """List available color themes"""
        return COLOR_THEMES
def test_interactive_plot_generator():
    """Smoke-test the interactive plot generator with synthetic data.

    Builds a 50 x 60 synthetic field over a lat/lon grid, then asks
    InteractiveIndiaMapPlotter to render and save it with the 'YlOrRd' theme.

    Returns:
        bool: True when the plot was created; False when the shapefile is
        missing or any step of plotter construction or plotting raised.
    """
    print("Testing interactive plot generator...")

    # Create test data
    lats = np.linspace(6, 38, 50)
    lons = np.linspace(68, 98, 60)
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    data = np.sin(lat_grid * 0.1) * np.cos(lon_grid * 0.1) * 100 + 50
    # Seeded generator so repeated runs plot the same field.
    rng = np.random.default_rng(0)
    data += rng.normal(0, 10, data.shape)

    metadata = {
        'variable_name': 'pm25',
        'display_name': 'PM2.5',
        'units': 'µg/m³',
        'lats': lats,
        'lons': lons,
        'pressure_level': None,
        'timestamp_str': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    }

    shapefile_path = "shapefiles/India_State_Boundary.shp"
    if not Path(shapefile_path).exists():
        print(f"❌ Test failed: Shapefile not found at '{shapefile_path}'.")
        print("Please make sure you have unzipped 'India_State_Boundary.zip' into a 'shapefiles' folder.")
        return False

    try:
        # Constructing the plotter reads the shapefile and can fail too, so
        # it sits inside the try to honour the "return False on failure"
        # contract.
        plotter = InteractiveIndiaMapPlotter(shapefile_path=shapefile_path)
        plot_path = plotter.create_india_map(data, metadata, color_theme='YlOrRd')
    except Exception as e:
        print(f"❌ Test failed: {str(e)}")
        import traceback
        traceback.print_exc()
        return False

    print(f"✅ Test interactive plot created successfully: {plot_path}")
    return True


if __name__ == "__main__":
    test_interactive_plot_generator()