Spaces:
Build error
Build error
| # import streamlit as st | |
| # x = st.slider('Select a value') | |
| # st.write(x, 'squared is', x * x) | |
| import streamlit as st | |
| import random | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| import tarfile | |
| import os | |
| import sys | |
| import yaml | |
st.title("PrithviWxC Model Inference")
st.write("Setting up environment...")

# Enable oneDNN fusion for the TorchScript JIT.
torch.jit.enable_onednn_fusion(True)

# Fixed seeds for the Python, NumPy and torch RNGs so repeated runs match.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    device = torch.device("cuda")
    torch.cuda.manual_seed(42)
    # NOTE(review): PyTorch's reproducibility notes recommend
    # cudnn.benchmark = False together with deterministic = True, because
    # benchmark autotuning may select different algorithms between runs.
    # Kept as in the original configuration.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    # Name the actual GPU rather than just "cuda"; reported once.
    st.write(f"Using device: {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
    st.write(f"Using device: {device}")
# Download and extract PrithviWxC module
st.write("Downloading and setting up PrithviWxC module...")


@st.cache_resource(show_spinner=False)
def _fetch_prithviwxc_module() -> str:
    """Download and unpack the PrithviWxC source tarball once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching this step keeps the tarball from being re-downloaded and
    re-extracted over the already-imported package on each rerun.

    Returns:
        Local path of the downloaded ``PrithviWxC.tar.gz``.
    """
    tar_path = hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename="PrithviWxC.tar.gz",
        local_dir=".",
    )
    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Extraction filters (3.12+, backported to recent 3.8-3.11
            # patch releases) reject absolute paths, ".." members and links
            # pointing outside the destination directory.
            tar.extractall(path=".", filter="data")
        else:
            tar.extractall(path=".")
    return tar_path


module_tar_path = _fetch_prithviwxc_module()

# Add the module path to sys.path, once: reruns would otherwise append the
# same entry again on every interaction.
# NOTE(review): the imports below are `PrithviWxC.*`, which resolve via the
# directory containing ./PrithviWxC; confirm this extra entry is needed.
_module_dir = os.path.abspath("./PrithviWxC")
if _module_dir not in sys.path:
    sys.path.append(_module_dir)

# Now import the module
from PrithviWxC.dataloaders.merra2 import (  # noqa: E402
    Merra2Dataset,
    input_scalers,
    output_scalers,
    preproc,
    static_input_scalers,
)
from PrithviWxC.model import PrithviWxC  # noqa: E402

st.write("PrithviWxC module imported successfully.")
# Variable and level selections passed to Merra2Dataset and the scaler
# loaders. Order presumably fixes the channel layout (the plotting code
# reads T2M from channel 12, its position in surface_vars).
surface_vars = (
    "EFLUX GWETROOT HFLUX LAI LWGAB LWGEM LWTUP PS QV2M SLP "
    "SWGNT SWTNT T2M TQI TQL TQV TS U10M V10M Z0M"
).split()
static_surface_vars = "FRACI FRLAND FROCEAN PHIS".split()
vertical_vars = "CLOUD H OMEGA PL QI QL QV T U V".split()

# Vertical model levels, stored as floats.
levels = [
    float(lvl)
    for lvl in (34, 39, 41, 43, 44, 45, 48, 51, 53, 56, 63, 68, 71, 72)
]

# Per-axis [before, after] padding handed to preproc; the -1 on "lat"
# presumably trims one latitude row — confirm against preproc.
padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
st.write("Setting up dataset parameters...")

# The forecasting task is one lead time and one input-time offset, both in
# hours and chosen in the UI. Merra2Dataset takes lists, so each value is
# wrapped in a single-element list.
# NOTE(review): max_value=0 permits a zero input offset — confirm
# Merra2Dataset accepts it.
lead_time = st.number_input(
    "Lead Time (hours)",
    min_value=1,
    max_value=24,
    value=6,
)
input_time = st.number_input(
    "Input Time Difference (hours)",
    min_value=-24,
    max_value=0,
    value=-6,
)
lead_times, input_times = [lead_time], [input_time]
# Reference-time window for the dataset: the single day 2020-01-01.
time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
st.write("Downloading data files...")

# MERRA-2 surface / pressure-level files and climatology share two folders.
surf_dir = Path("./merra-2")
vert_dir = Path("./merra-2")
surf_clim_dir = Path("./climatology")
vert_clim_dir = Path("./climatology")

# One snapshot request covers all four file groups (a file is fetched when it
# matches any pattern). No force_download: Streamlit reruns this script on
# every widget interaction, and huggingface_hub reuses up-to-date files
# already present in local_dir instead of fetching them again.
snapshot_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    allow_patterns=[
        "merra-2/MERRA2_sfc_2020010[1].nc",
        "merra-2/MERRA_pres_2020010[1].nc",
        "climatology/climate_surface_doy00[1]*.nc",
        "climatology/climate_vertical_doy00[1]*.nc",
    ],
    local_dir=".",
)
st.write("Setting positional encoding...")
positional_encoding = "fourier"

st.write("Initializing dataset...")
dataset = Merra2Dataset(
    time_range=time_range,
    lead_times=lead_times,
    input_times=input_times,
    data_path_surface=surf_dir,
    data_path_vertical=vert_dir,
    climatology_path_surface=surf_clim_dir,
    climatology_path_vertical=vert_clim_dir,
    surface_vars=surface_vars,
    static_surface_vars=static_surface_vars,
    vertical_vars=vertical_vars,
    levels=levels,
    positional_encoding=positional_encoding,
)

# Report an empty dataset in the UI and halt this run rather than using
# `assert`, which is stripped under `python -O` and would otherwise crash the
# app. This likely happens when the chosen lead/input times reach outside the
# downloaded time range.
if len(dataset) == 0:
    st.error("There doesn't seem to be any valid data.")
    st.stop()
st.write("Loading scalers...")
surf_in_scal_path = Path("./climatology/musigma_surface.nc")
vert_in_scal_path = Path("./climatology/musigma_vertical.nc")
surf_out_scal_path = Path("./climatology/anomaly_variance_surface.nc")
vert_out_scal_path = Path("./climatology/anomaly_variance_vertical.nc")

# Fetch all four scaler files into ./climatology. No force_download:
# Streamlit reruns the script on every widget interaction, and
# hf_hub_download reuses an up-to-date local copy instead of re-fetching.
for _scaler_path in (
    surf_in_scal_path,
    vert_in_scal_path,
    surf_out_scal_path,
    vert_out_scal_path,
):
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=f"climatology/{_scaler_path.name}",
        local_dir=".",
    )

# Input mu/sigma, output scalers, and static-input mu/sigma for the model.
in_mu, in_sig = input_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_in_scal_path,
    vert_in_scal_path,
)
output_sig = output_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_out_scal_path,
    vert_out_scal_path,
)
static_mu, static_sig = static_input_scalers(
    surf_in_scal_path,
    static_surface_vars,
)
st.write("Setting up model...")
residual = "climate"
masking_mode = "local"
decoder_shifting = True
masking_ratio = 0.99

# Load model config. No force_download: Streamlit reruns this script on every
# widget interaction, and hf_hub_download reuses an up-to-date local copy.
config_path = hf_hub_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    filename="config.yaml",
    local_dir=".",
)
with open(config_path, "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

weights_path = Path("./weights/prithvi.wxc.2300m.v1.pt")


@st.cache_resource(show_spinner="Building model and loading weights...")
def _load_prithviwxc(
    params: dict,
    _in_mu,
    _in_sig,
    _static_mu,
    _static_sig,
    _output_sig,
    positional_encoding: str,
    residual: str,
    masking_mode: str,
    decoder_shifting: bool,
    masking_ratio: float,
    device_name: str,
):
    """Construct PrithviWxC, load its pretrained weights, and move it to a device.

    Cached with ``st.cache_resource`` so the model is built and its checkpoint
    loaded once per server process instead of on every Streamlit rerun.
    Arguments whose names start with ``_`` (the scaler tensors) are excluded
    from Streamlit's cache key; they are read from fixed scaler files.

    Args:
        params: The ``params`` section of ``config.yaml``.
        _in_mu, _in_sig: Input scaler mu / sigma.
        _static_mu, _static_sig: Static-input scaler mu / sigma.
        _output_sig: Output scalers as loaded; their square root is passed
            to the model as ``output_scalers``.
        positional_encoding, residual, masking_mode, decoder_shifting,
        masking_ratio: Model options forwarded unchanged.
        device_name: Target device string, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        The PrithviWxC model with weights loaded, on ``device_name``.
    """
    target = torch.device(device_name)
    net = PrithviWxC(
        in_channels=params["in_channels"],
        input_size_time=params["input_size_time"],
        in_channels_static=params["in_channels_static"],
        input_scalers_mu=_in_mu,
        input_scalers_sigma=_in_sig,
        input_scalers_epsilon=params["input_scalers_epsilon"],
        static_input_scalers_mu=_static_mu,
        static_input_scalers_sigma=_static_sig,
        static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
        output_scalers=_output_sig**0.5,
        n_lats_px=params["n_lats_px"],
        n_lons_px=params["n_lons_px"],
        patch_size_px=params["patch_size_px"],
        mask_unit_size_px=params["mask_unit_size_px"],
        mask_ratio_inputs=masking_ratio,
        embed_dim=params["embed_dim"],
        n_blocks_encoder=params["n_blocks_encoder"],
        n_blocks_decoder=params["n_blocks_decoder"],
        mlp_multiplier=params["mlp_multiplier"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        drop_path=params["drop_path"],
        parameter_dropout=params["parameter_dropout"],
        residual=residual,
        masking_mode=masking_mode,
        decoder_shifting=decoder_shifting,
        positional_encoding=positional_encoding,
        checkpoint_encoder=[],
        checkpoint_decoder=[],
    )
    checkpoint_file = hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=weights_path.name,
        local_dir="./weights",
    )
    # NOTE: torch.load unpickles the file; acceptable only because the
    # checkpoint comes from the model's own Hub repository.
    state_dict = torch.load(checkpoint_file, map_location=target)
    if "model_state" in state_dict:
        state_dict = state_dict["model_state"]
    net.load_state_dict(state_dict, strict=True)
    # load_state_dict copies the values into the parameters; release the
    # loaded copy before moving the model so two full sets of weights are
    # not held on the device at once.
    del state_dict
    return net.to(target)


st.write("Loading model weights...")
model = _load_prithviwxc(
    config["params"],
    in_mu,
    in_sig,
    static_mu,
    static_sig,
    output_sig,
    positional_encoding=positional_encoding,
    residual=residual,
    masking_mode=masking_mode,
    decoder_shifting=decoder_shifting,
    masking_ratio=masking_ratio,
    device_name=str(device),
)
st.write("Model loaded and ready.")
if st.button("Run Inference"):
    st.write("Running inference...")
    # Take the first sample and batch it with the configured padding.
    data = next(iter(dataset))
    batch = preproc([data], padding)
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)

    model.eval()
    with torch.no_grad():
        out = model(batch)

    st.write("Inference completed. Generating plot...")
    # Look the channel up by name instead of hard-coding 12. This assumes the
    # output channels follow surface_vars order — confirm against the model.
    t2m = out[0, surface_vars.index("T2M")].cpu().numpy()
    lat = np.linspace(-90, 90, out.shape[-2])
    lon = np.linspace(-180, 180, out.shape[-1])
    X, Y = np.meshgrid(lon, lat)

    # Use the object-oriented API. Streamlit serves sessions from threads, and
    # pyplot's implicit "current figure" state is shared across them.
    fig, ax = plt.subplots()
    cs = ax.contourf(X, Y, t2m, 100)
    ax.set_aspect("equal")
    fig.colorbar(cs, ax=ax)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this, each
    # button press leaks another figure.
    plt.close(fig)
    st.write("Plot generated.")