", unsafe_allow_html=True)
st.write(
"GeoMate helps geotechnical engineers: classify soils (USCS/AASHTO), "
"plot grain size distributions (GSD), fetch Earth Engine data, chat with a RAG-backed LLM, "
"run OCR on site logs, and generate professional reports."
)
st.markdown("
", unsafe_allow_html=True)
st.markdown("### ๐ Quick Actions")
c1, c2, c3 = st.columns(3)
if c1.button("🧪 Classifier"):
    st.session_state["page"] = "Classifier"; st.rerun()
if c2.button("🔍 Soil Recognizer"):
    st.session_state["page"] = "Soil recognizer"; st.rerun()
if c3.button("📍 Locator"):
    st.session_state["page"] = "Locator"; st.rerun()
c4, c5, c6 = st.columns(3)
if c4.button("🤖 Ask GeoMate"):
    st.session_state["page"] = "RAG"; st.rerun()
if c5.button("📷 OCR"):
    st.session_state["page"] = "OCR"; st.rerun()
if c6.button("📄 Reports"):
    st.session_state["page"] = "Reports"; st.rerun()
with col2:
st.markdown("
", unsafe_allow_html=True)
st.markdown("
๐ Live Site Summary
", unsafe_allow_html=True)
if "sites" in st.session_state and st.session_state.get("active_site") is not None:
site = st.session_state["sites"][st.session_state["active_site"]]
st.write(f"**Site:** {site.get('Site Name','N/A')}")
st.write(f"USCS: {site.get('USCS','-')} | AASHTO: {site.get('AASHTO','-')}")
st.write(f"GSD Saved: {'โ ' if site.get('GSD') else 'โ'}")
else:
st.info("No active site selected.")
st.markdown("
", unsafe_allow_html=True)
# Soil Classifier page (conversational, step-by-step)
def soil_classifier_page():
st.header("๐งช Soil Classifier โ Conversational (USCS & AASHTO)")
site = st.session_state["sites"][st.session_state["active_site"]]
# conversation state machine: steps list
steps = [
{"id":"intro", "bot":"Hello โ I am the GeoMate Soil Classifier. Ready to start?"},
{"id":"organic", "bot":"Is the soil at this site organic (contains high organic matter, feels spongy or has odour)?", "type":"choice", "choices":["No","Yes"]},
{"id":"P2", "bot":"Please enter the percentage passing the #200 sieve (0.075 mm). Example: 12", "type":"number"},
{"id":"P4", "bot":"What is the percentage passing the sieve no. 4 (4.75 mm)? (enter 0 if unknown)", "type":"number"},
{"id":"hasD", "bot":"Do you know the D10, D30 and D60 diameters (in mm)?", "type":"choice","choices":["No","Yes"]},
{"id":"D60", "bot":"Enter D60 (diameter in mm corresponding to 60% passing).", "type":"number"},
{"id":"D30", "bot":"Enter D30 (diameter in mm corresponding to 30% passing).", "type":"number"},
{"id":"D10", "bot":"Enter D10 (diameter in mm corresponding to 10% passing).", "type":"number"},
{"id":"LL", "bot":"What is the liquid limit (LL)?", "type":"number"},
{"id":"PL", "bot":"What is the plastic limit (PL)?", "type":"number"},
{"id":"dry", "bot":"Select the observed dry strength of the fine soil (if applicable).", "type":"select", "options":DRY_STRENGTH_OPTIONS},
{"id":"dilat", "bot":"Select the observed dilatancy behaviour.", "type":"select", "options":DILATANCY_OPTIONS},
{"id":"tough", "bot":"Select the observed toughness.", "type":"select", "options":TOUGHNESS_OPTIONS},
{"id":"confirm", "bot":"Would you like me to classify now?", "type":"choice", "choices":["No","Yes"]}
]
if "classifier_step" not in st.session_state:
st.session_state["classifier_step"] = 0
if "classifier_inputs" not in st.session_state:
st.session_state["classifier_inputs"] = dict(site.get("classifier_inputs", {}))
step_idx = st.session_state["classifier_step"]
# chat history display
st.markdown("
", unsafe_allow_html=True)
st.markdown("
{}
".format("GeoMate: Hello โ soil classifier ready. Use the controls below to answer step-by-step."), unsafe_allow_html=True)
# Show stored user answers sequentially for context
# render question up to current step
for i in range(step_idx+1):
s = steps[i]
# show bot prompt
st.markdown(f"
{s['bot']}
", unsafe_allow_html=True)
# show user answer if exists in classifier_inputs
key = s["id"]
val = st.session_state["classifier_inputs"].get(key)
if val is not None:
st.markdown(f"
{val}
", unsafe_allow_html=True)
st.markdown("
", unsafe_allow_html=True)
# Render input widget for current step
current = steps[step_idx]
step_id = current["id"]
proceed = False
user_answer = None
cols = st.columns([1,1,1])
with cols[0]:
if current.get("type") == "choice":
choice = st.radio(current["bot"], options=current["choices"], index=0, key=f"cls_{step_id}")
user_answer = choice
elif current.get("type") == "number":
# numeric input without +/- spinner (we use text_input and validate)
raw = st.text_input(current["bot"], value=str(st.session_state["classifier_inputs"].get(step_id,"")), key=f"cls_{step_id}_num")
# validate numeric
try:
if raw.strip() == "":
user_answer = None
else:
user_answer = float(raw)
except ValueError:
st.warning("Please enter a valid number (e.g., 12 or 0).")
user_answer = None
elif current.get("type") == "select":
opts = current.get("options", [])
sel = st.selectbox(current["bot"], options=opts, index=0, key=f"cls_{step_id}_sel")
user_answer = sel
else:
# just a message step โ proceed
user_answer = None
# controls: Next / Back
coln, colb, colsave = st.columns([1,1,1])
with coln:
if st.button("➡️ Next", key=f"next_{step_id}"):
# store answer if provided
if current.get("type") == "number":
if user_answer is None:
st.warning("Please enter a numeric value or enter 0 if unknown.")
else:
st.session_state["classifier_inputs"][step_id] = user_answer
st.session_state["classifier_step"] = min(step_idx+1, len(steps)-1)
st.rerun()
elif current.get("type") in ("choice","select"):
st.session_state["classifier_inputs"][step_id] = user_answer
st.session_state["classifier_step"] = min(step_idx+1, len(steps)-1)
st.rerun()
else:
# message-only step
st.session_state["classifier_step"] = min(step_idx+1, len(steps)-1)
st.rerun()
with colb:
if st.button("⬅️ Back", key=f"back_{step_id}"):
st.session_state["classifier_step"] = max(0, step_idx-1)
st.rerun()
with colsave:
if st.button("💾 Save & Classify now", key="save_and_classify"):
# prepare inputs in required format for classify_uscs_aashto
ci = st.session_state["classifier_inputs"].copy()
# Normalize choices into expected codes
if isinstance(ci.get("dry"), str):
ci["nDS"] = DRY_STRENGTH_MAP.get(ci.get("dry"), 5)
if isinstance(ci.get("dilat"), str):
ci["nDIL"] = DILATANCY_MAP.get(ci.get("dilat"), 6)
if isinstance(ci.get("tough"), str):
ci["nTG"] = TOUGHNESS_MAP.get(ci.get("tough"), 6)
# map the 'Yes'/'No' answer from the 'organic' step to the 'opt' flag expected by the classifier
ci["opt"] = "y" if ci.get("organic") == "Yes" else "n"
# map D entries: D60 etc may be present
# call classification
try:
res_text, aashto, GI, chars, uscs = classify_uscs_aashto(ci)
except Exception as e:
st.error(f"Classification error: {e}")
res_text = f"Error during classification: {e}"
aashto = "N/A"; GI = 0; chars = {}; uscs = "N/A"
# save into active site
site["USCS"] = uscs
site["AASHTO"] = aashto
site["GI"] = GI
site["classifier_inputs"] = ci
site["classifier_decision"] = res_text
st.success("Classification complete. Results saved to site.")
st.write("### Classification Results")
st.markdown(res_text)
# Keep classifier_step at end so user can review
st.session_state["classifier_step"] = len(steps)-1
# GSD Curve Page
def gsd_page():
st.header("๐ Grain Size Distribution (GSD) Curve")
site = st.session_state["sites"][st.session_state["active_site"]]
st.markdown("Enter diameters (mm) and % passing (comma-separated). Use descending diameters (largest to smallest).")
default_diams = "75,50,37.5,25,19,12.5,9.5,4.75,2,0.85,0.425,0.25,0.18,0.15,0.075"
default_pass = "100,98,96,90,85,78,72,65,55,45,35,25,18,14,8"
gsd = site.get("GSD") or {}
# re-serialize any previously saved lists back to comma-separated text
diam_default = ",".join(str(d) for d in gsd["diameters"]) if gsd.get("diameters") else default_diams
pass_default = ",".join(str(p) for p in gsd["passing"]) if gsd.get("passing") else default_pass
diam_input = st.text_area("Diameters (mm) comma-separated", value=diam_default)
pass_input = st.text_area("% Passing comma-separated", value=pass_default)
if st.button("Compute GSD & Save"):
try:
diams = [float(x.strip()) for x in diam_input.split(",") if x.strip()]
passing = [float(x.strip()) for x in pass_input.split(",") if x.strip()]
if len(diams) != len(passing):
    raise ValueError("Diameters and % passing lists must be the same length.")
metrics = compute_gsd_metrics(diams, passing)
# plot
fig, ax = plt.subplots(figsize=(7,4))
ax.semilogx(diams, passing, marker='o')
ax.set_xlabel("Particle size (mm)")
ax.set_ylabel("% Passing")
ax.invert_xaxis()
ax.grid(True, which='both', linestyle='--', linewidth=0.5)
ax.set_title("Grain Size Distribution")
st.pyplot(fig)
# save into site
site["GSD"] = {"diameters":diams, "passing":passing, **metrics}
st.success(f"Saved GSD for site. D10={metrics['D10']:.4g} mm, D30={metrics['D30']:.4g} mm, D60={metrics['D60']:.4g} mm")
except Exception as e:
st.error(f"GSD error: {e}")
# OCR Page
def ocr_page():
st.header("๐ท OCR โ extract values from an image")
site = st.session_state["sites"][st.session_state["active_site"]]
if not OCR_AVAILABLE:
st.warning("OCR dependencies not available (pytesseract/PIL). Add pytesseract and pillow to requirements to enable OCR.")
uploaded = st.file_uploader("Upload an image (photo of textbook question or sieve data)", type=["png","jpg","jpeg"])
if uploaded:
if OCR_AVAILABLE:
try:
img = Image.open(uploaded)
st.image(img, caption="Uploaded", use_column_width=True)
text = pytesseract.image_to_string(img)
st.text_area("Extracted text", value=text, height=180)
# Basic parsing: try to find LL, PL, D10 etc via regex
import re
found = {}
for key in ["LL","PL","D10","D30","D60","P2","P4","CBR"]:
pattern = re.compile(rf"{key}[:=]?\s*([0-9]+\.?[0-9]*)", re.I)
m = pattern.search(text)
if m:
found[key] = float(m.group(1))
site.setdefault("classifier_inputs",{})[key] = float(m.group(1))
if found:
st.success(f"Parsed values: {found}")
st.write("Values saved into classifier inputs.")
else:
st.info("No clear numeric matches found automatically.")
except Exception as e:
st.error(f"OCR failed: {e}")
else:
st.warning("OCR not available in this deployment.")
# Locator Page (with Earth Engine auth at top)
import os
import json
import streamlit as st
import geemap.foliumap as geemap
import ee
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
from streamlit_folium import st_folium
# =====================================================
# EE Init Helper
# =====================================================
def initialize_ee():
EARTHENGINE_TOKEN = os.getenv("EARTHENGINE_TOKEN")
SERVICE_ACCOUNT = os.getenv("SERVICE_ACCOUNT")
if "ee_initialized" in st.session_state and st.session_state["ee_initialized"]:
return True
if EARTHENGINE_TOKEN and SERVICE_ACCOUNT:
try:
creds = ee.ServiceAccountCredentials(email=SERVICE_ACCOUNT, key_data=EARTHENGINE_TOKEN)
ee.Initialize(creds)
st.session_state["ee_initialized"] = True
return True
except Exception as e:
st.warning(f"Service account init failed: {e}, falling back...")
try:
ee.Initialize()
st.session_state["ee_initialized"] = True
return True
except Exception:
try:
ee.Authenticate()
ee.Initialize()
st.session_state["ee_initialized"] = True
return True
except Exception as e:
st.error(f"Earth Engine auth failed: {e}")
return False
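# Assumed deployment secrets (e.g., HF Spaces): SERVICE_ACCOUNT holds the
# service-account email, and EARTHENGINE_TOKEN the JSON key contents passed
# as key_data above. Hypothetical example:
#   SERVICE_ACCOUNT   = "geomate-sa@my-project.iam.gserviceaccount.com"
#   EARTHENGINE_TOKEN = '{"type": "service_account", ...}'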
# =====================================================
# Safe reducers
# =====================================================
def safe_get_reduce(region, image, band, scale=1000, default=None, max_pixels=int(1e7)):
try:
rr = image.reduceRegion(ee.Reducer.mean(), region, scale=scale, maxPixels=max_pixels)
val = rr.get(band)
info = rr.get(band).getInfo()  # ee objects are always truthy; check the fetched value instead
return float(info) if info is not None else default
except Exception:
return default
def safe_reduce_histogram(region, image, band, scale=1000, max_pixels=int(1e7)):
try:
rr = image.reduceRegion(ee.Reducer.frequencyHistogram(), region, scale=scale, maxPixels=max_pixels)
hist = rr.get(band)
return hist.getInfo() if hist else {}
except Exception:
return {}
def safe_time_series(region, collection, band, start, end,
reducer=None, scale=1000, max_pixels=int(1e7)):
try:
if reducer is None:
reducer = ee.Reducer.mean()  # default assigned at call time, not in the signature
def per_image(img):
date = img.date().format("YYYY-MM-dd")
val = img.reduceRegion(reducer, region, scale=scale, maxPixels=max_pixels).get(band)
return ee.Feature(None, {"date": date, "val": val})
feats = collection.filterDate(start, end).map(per_image).filter(ee.Filter.notNull(["val"])).getInfo()
pts = []
for f in feats.get("features", []):
p = f.get("properties", {})
if p.get("val") is not None:
pts.append((p.get("date"), float(p.get("val"))))
return pts
except Exception:
return []
# =====================================================
# Map snapshot (in-memory, no disk bloat)
# =====================================================
def export_map_snapshot(m, width=800, height=600):
"""Return PNG snapshot bytes of geemap Map."""
try:
from io import BytesIO
buf = BytesIO()
# screenshot support depends on the geemap version (may require selenium);
# failures are caught below and surfaced as a warning
m.screenshot(filename=None, region=None, dimensions=(width, height), out_file=buf)
buf.seek(0)
return buf.read()
except Exception as e:
st.warning(f"Map snapshot failed: {e}")
return None
# =====================================================
# Locator page
# =====================================================
def locator_page():
st.title("๐ GeoMate Interactive Earth Explorer")
st.markdown(
"Draw a polygon (or rectangle) on the map using the drawing tool. "
"Then press **Compute Summaries** to compute soil, elevation, seismic, flood, landcover, NDVI, and atmospheric data."
)
# --- Auth
if not initialize_ee():
st.stop()
# --- Map setup
m = geemap.Map(center=[28.0, 72.0], zoom=5, plugin_Draw=True, draw_export=True, locate_control=True)
# Add basemaps explicitly
m.add_basemap("HYBRID") # Google Satellite Hybrid
m.add_basemap("ROADMAP") # Google Roads
m.add_basemap("Esri.WorldImagery")
m.add_basemap("OpenStreetMap")
# Restore ROI (if available)
if "roi_geojson" in st.session_state:
import folium
try:
saved = st.session_state["roi_geojson"]
folium.GeoJson(saved, name="Saved ROI",
style_function=lambda x: {"color": "red", "weight": 2, "fillOpacity": 0.1}).add_to(m)
except Exception as e:
st.warning(f"Could not re-add saved ROI: {e}")
# --- Datasets
try:
dem = ee.Image("NASA/NASADEM_HGT/001"); dem_band_name = "elevation"
except Exception:
dem, dem_band_name = None, None
soil_img, chosen_soil_band = None, None
try:
soil_img = ee.Image("OpenLandMap/SOL/SOL_CLAY-WFRACTION_USDA-3A1A1A_M/v02")
bands = soil_img.bandNames().getInfo()
chosen_soil_band = st.selectbox("Select soil clay band", options=bands, index=bands.index("b200") if "b200" in bands else 0)
except Exception:
soil_img, chosen_soil_band = None, None
try:
seismic_img = ee.Image("SEDAC/GSHAPSeismicHazard"); seismic_band = "gshap"
except Exception:
seismic_img, seismic_band = None, None
try:
water = ee.Image("JRC/GSW1_4/GlobalSurfaceWater"); water_band = "occurrence"
except Exception:
water, water_band = None, None
try:
landcover = ee.Image("ESA/WorldCover/v200"); lc_band = "Map"
except Exception:
landcover, lc_band = None, None
try:
ndvi_col = ee.ImageCollection("MODIS/061/MOD13A2").select("NDVI")
except Exception:
ndvi_col = None
# Atmospheric datasets
try:
precip_col = ee.ImageCollection("UCSB-CHG/CHIRPS/DAILY").select("precipitation")
except Exception:
precip_col = None
try:
temp_col = ee.ImageCollection("MODIS/061/MOD11A2").select("LST_Day_1km")
except Exception:
temp_col = None
try:
pm25_img = ee.ImageCollection("COPERNICUS/S5P/OFFL/L3_AER_AI").select("absorbing_aerosol_index").mean()
except Exception:
pm25_img = None
# --- Render & capture ROI
result = st_folium(m, width=800, height=600, returned_objects=["last_active_drawing"])
roi, flat_coords = None, None
if result and "last_active_drawing" in result and result["last_active_drawing"]:
feat = result["last_active_drawing"]
geom = feat.get("geometry")
if geom:
try:
roi = ee.Geometry(geom)
coords = geom.get("coordinates", None)
st.session_state["roi_geojson"] = feat
if coords:
if geom["type"] in ["Polygon", "MultiPolygon"]:
flat_coords = [(lat, lon) for ring in coords for lon, lat in ring]
elif geom["type"] == "Point":
lon, lat = coords; flat_coords = [(lat, lon)]
elif geom["type"] == "LineString":
flat_coords = [(lat, lon) for lon, lat in coords]
if flat_coords: st.session_state["roi_coords"] = flat_coords
st.success("โ ROI captured!")
except Exception as e:
st.error(f"Failed to convert geometry: {e}")
if roi is None and "roi_geojson" in st.session_state:
try:
geom = st.session_state["roi_geojson"].get("geometry")
if geom:
roi = ee.Geometry(geom)
coords = geom.get("coordinates", None)
if coords:
if geom["type"] in ["Polygon", "MultiPolygon"]:
flat_coords = [(lat, lon) for ring in coords for lon, lat in ring]
elif geom["type"] == "Point":
lon, lat = coords; flat_coords = [(lat, lon)]
elif geom["type"] == "LineString":
flat_coords = [(lat, lon) for lon, lat in coords]
if flat_coords: st.session_state["roi_coords"] = flat_coords
st.info("โป๏ธ ROI restored from session")
except Exception as e:
st.warning(f"Could not restore ROI: {e}")
# Show coordinates
if "roi_coords" in st.session_state:
st.markdown("### ๐ ROI Coordinates (Lat, Lon)")
st.write(st.session_state["roi_coords"])
# --- Compute summaries
if st.button("Compute Summaries"):
if roi is None:
st.error("โ ๏ธ No ROI found. Please draw first.")
else:
st.success("ROI ready โ computing...")
soil_val = safe_get_reduce(roi, soil_img.select(chosen_soil_band), chosen_soil_band, 1000) if soil_img and chosen_soil_band else None
elev_val = safe_get_reduce(roi, dem, dem_band_name, 1000) if dem else None
seismic_val = safe_get_reduce(roi, seismic_img, seismic_band, 5000) if seismic_img else None
flood_val = safe_get_reduce(roi, water.select(water_band), water_band, 30) if water else None
lc_stats = safe_reduce_histogram(roi, landcover, lc_band, 30) if landcover else {}
ndvi_ts = []
if ndvi_col:
end = datetime.utcnow().strftime("%Y-%m-%d")
start = (datetime.utcnow() - timedelta(days=365*2)).strftime("%Y-%m-%d")
ndvi_ts = safe_time_series(roi, ndvi_col, "NDVI", start, end)
precip_ts, temp_ts, pm25_val = [], [], None
if precip_col:
end = datetime.utcnow().strftime("%Y-%m-%d")
start = (datetime.utcnow() - timedelta(days=365)).strftime("%Y-%m-%d")
precip_ts = safe_time_series(roi, precip_col, "precipitation", start, end, scale=5000)
if temp_col:
end = datetime.utcnow().strftime("%Y-%m-%d")
start = (datetime.utcnow() - timedelta(days=365)).strftime("%Y-%m-%d")
temp_ts = safe_time_series(roi, temp_col, "LST_Day_1km", start, end, scale=1000)
if pm25_img:
pm25_val = safe_get_reduce(roi, pm25_img, "absorbing_aerosol_index", 10000)
# Save to site
active = st.session_state.get("active_site", 0)
if "sites" in st.session_state:
site = st.session_state["sites"][active]
try:
site["ROI"] = roi.getInfo()
except Exception:
site["ROI"] = "Not available"
site["Soil Profile"] = f"{soil_val} ({chosen_soil_band})" if soil_val else "N/A"
site["Topo Data"] = f"{elev_val} m" if elev_val else "N/A"
site["Seismic Data"] = seismic_val if seismic_val else "N/A"
site["Flood Data"] = flood_val if flood_val else "N/A"
site["Environmental Data"] = {"Landcover": lc_stats, "NDVI": ndvi_ts}
site["Atmospheric Data"] = {"Precipitation": precip_ts, "Temperature": temp_ts, "PM2.5": pm25_val}
st.session_state["soil_json"] = {
"Soil": soil_val, "Soil Band": chosen_soil_band,
"Elevation": elev_val, "Seismic": seismic_val,
"Flood": flood_val, "Landcover Stats": lc_stats,
"NDVI TS": ndvi_ts,
"Precipitation TS": precip_ts,
"Temperature TS": temp_ts,
"PM2.5": pm25_val
}
# Snapshot
map_bytes = export_map_snapshot(m)
if map_bytes:
st.session_state["last_map_snapshot"] = map_bytes
if "sites" in st.session_state:
st.session_state["sites"][active]["map_snapshot"] = map_bytes
st.image(map_bytes, caption="Map Snapshot", use_column_width=True)
import plotly.express as px
import plotly.graph_objects as go
# -------------------------------
# ๐ Display Summaries (Locator)
# -------------------------------
st.subheader("๐ Summary")
# --- Metric Cards
c1, c2, c3 = st.columns(3)
# use explicit None checks so a legitimate 0.0 value is not displayed as "N/A"
c1.metric("🟤 Soil (%)", f"{soil_val:.2f}" if soil_val is not None else "N/A", help="Soil clay content")
c2.metric("⛰️ Elevation (m)", f"{elev_val:.1f}" if elev_val is not None else "N/A", help="Mean elevation")
c3.metric("🌪️ Seismic PGA", f"{seismic_val:.3f}" if seismic_val is not None else "N/A", help="Seismic hazard index")
c4, c5, c6 = st.columns(3)
c4.metric("🌊 Flood Occurrence", f"{flood_val:.2f}" if flood_val is not None else "N/A")
c5.metric("💨 PM2.5 Index", f"{pm25_val:.2f}" if pm25_val is not None else "N/A")
c6.metric("🟢 NDVI Count", f"{len(ndvi_ts)} pts" if ndvi_ts else "0")
# --- Donut Chart for Landcover
if lc_stats:
labels = list(map(str, lc_stats.keys()))
values = list(lc_stats.values())
fig = go.Figure(data=[go.Pie(
labels=labels,
values=values,
hole=0.5, # donut
textinfo="percent+label", # show % and class
insidetextorientation="radial",
marker=dict(colors=px.colors.sequential.Oranges_r)
)])
fig.update_layout(
title="๐ Landcover Distribution",
template="plotly_dark",
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig, use_container_width=True)
# --- Time Series with Plotly
def plot_timeseries(ts, title, ylab, color):
if ts:
dates, vals = zip(*ts)
fig = go.Figure()
fig.add_trace(go.Scatter(x=dates, y=vals, mode="lines+markers", line=dict(color=color)))
fig.update_layout(
title=title,
xaxis_title="Date",
yaxis_title=ylab,
template="plotly_dark",
hovermode="x unified",
margin=dict(l=20, r=20, t=40, b=20)
)
st.plotly_chart(fig, use_container_width=True)
plot_timeseries(ndvi_ts, "🌿 NDVI Trend (2 years)", "NDVI", "#FF7A00")
plot_timeseries(precip_ts, "🌧️ Precipitation Trend (1 year)", "mm/day", "#00BFFF")
plot_timeseries(temp_ts, "🌡️ Land Surface Temp (1 year)", "K", "#FF3333")
# GeoMate Ask (RAG): simple chat with memory per site and auto-extraction of numeric values
import re, json
import streamlit as st
from langchain.vectorstores import FAISS  # newer LangChain releases: langchain_community.vectorstores
from langchain.embeddings import HuggingFaceEmbeddings  # newer: langchain_community.embeddings
# -------------------
# Load FAISS DB once
# -------------------
@st.cache_resource
def load_faiss():
# Adjust path to where you unzip faiss_books_db.zip
faiss_dir = "faiss_books_db"
# embeddings must match the model used when the index was created
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
# load_local reads index.faiss / index.pkl itself; no manual pickle load is needed
vectorstore = FAISS.load_local(faiss_dir, embeddings, allow_dangerous_deserialization=True)
return vectorstore
try:
    vectorstore = load_faiss()
except Exception as e:
    vectorstore = None
    st.warning(f"FAISS index could not be loaded: {e}")
# -------------------
# RAG Chat Page
# -------------------
def rag_page():
st.header("๐ค GeoMate Ask (RAG + LLM)")
site = st.session_state["sites"][st.session_state["active_site"]]
# --- Ensure Site ID exists ---
if site.get("Site ID") is None:
site_id = st.session_state["sites"].index(site)
site["Site ID"] = site_id
else:
site_id = site["Site ID"]
# --- Initialize rag_history properly ---
if "rag_history" not in st.session_state:
st.session_state["rag_history"] = {}
if site_id not in st.session_state["rag_history"]:
st.session_state["rag_history"][site_id] = []
# --- Display chat history ---
hist = st.session_state["rag_history"][site_id]
for entry in hist:
who, text = entry.get("who"), entry.get("text")
if who == "bot":
st.markdown(f"
{text}
", unsafe_allow_html=True)
else:
st.markdown(f"
{text}
", unsafe_allow_html=True)
# --- User input ---
user_msg = st.text_input("You:", key=f"rag_input_{site_id}")
if st.button("Send", key=f"rag_send_{site_id}"):
if not user_msg.strip():
st.warning("Enter a message.")
else:
# Save user msg
st.session_state["rag_history"][site_id].append(
{"who": "user", "text": user_msg}
)
# --- Retrieve from FAISS ---
docs = vectorstore.similarity_search(user_msg, k=3) if vectorstore is not None else []
context_text = "\n".join(d.page_content for d in docs)
# --- Build context for LLM ---
context = {
"site": {
k: v
for k, v in site.items()
if k in [
"Site Name",
"lat",
"lon",
"USCS",
"AASHTO",
"GI",
"Load Bearing Capacity",
"Soil Profile",
"Flood Data",
"Seismic Data",
]
},
"chat_history": st.session_state["rag_history"][site_id],
}
prompt = (
f"You are GeoMate AI, an expert geotechnical assistant.\n\n"
f"Relevant references:\n{context_text}\n\n"
f"Site context: {json.dumps(context)}\n\n"
f"User: {user_msg}\n\n"
f"Answer concisely, include citations [ref:source]. "
f"If user provides numeric engineering values, return them in the format: [[FIELD: value unit]]."
)
# Call the unified LLM function
resp = llm_generate(prompt, model=st.session_state["llm_model"], max_tokens=500)
# Save bot reply
st.session_state["rag_history"][site_id].append({"who": "bot", "text": resp})
# Display reply
st.markdown(f"
{resp}
", unsafe_allow_html=True)
# Extract bracketed numeric values
matches = re.findall(
r"\[\[([A-Za-z0-9 _/-]+):\s*([0-9.+-eE]+)\s*([A-Za-z%\/]*)\]\]", resp
)
for m in matches:
field, val, unit = m[0].strip(), m[1].strip(), m[2].strip()
if "bearing" in field.lower():
site["Load Bearing Capacity"] = f"{val} {unit}"
elif "skin" in field.lower():
site["Skin Shear Strength"] = f"{val} {unit}"
elif "compaction" in field.lower():
site["Relative Compaction"] = f"{val} {unit}"
st.success(
    "Response saved; citations and any recognized numeric fields were auto-stored in the site data."
)
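# Example: a reply containing "[[Load Bearing Capacity: 150 kPa]]" is captured
# by the regex above and stored as site["Load Bearing Capacity"] = "150 kPa".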
# -------------------
# Report fields (still needed in reports_page)
# -------------------
REPORT_FIELDS = [
("Load Bearing Capacity", "kPa or psf"),
("Skin Shear Strength", "kPa"),
("Relative Compaction", "%"),
("Rate of Consolidation", "mm/yr or days"),
("Nature of Construction", "text"),
("Borehole Count", "number"),
("Max Depth (m)", "m"),
("SPT N (avg)", "blows/ft"),
("CBR (%)", "%"),
("Allowable Bearing (kPa)", "kPa"),
]
# -------------------------------
# Imports
# -------------------------------
import io, re, json, tempfile
from datetime import datetime
from typing import Dict, Any, Optional, List
import streamlit as st
from reportlab.platypus import (
SimpleDocTemplate, Paragraph, Spacer, PageBreak, Table, TableStyle, Image as RLImage
)
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib import colors
from reportlab.lib.pagesizes import A4
from reportlab.lib.units import mm
# =============================
# LLM Helper (Groq API)
# =============================
import requests  # json, os, streamlit, datetime, tempfile, typing and reportlab are already imported above
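# ---------------------------------------------------------------------
# Hedged fallback: llm_generate() is the app's unified LLM wrapper and is
# expected to be defined earlier in the file. If it is missing (e.g., when
# running this section standalone), this minimal sketch shows the assumed
# signature. Only the Groq chat-completions path is sketched; the endpoint
# and payload follow the public Groq API and are an assumption here.
# ---------------------------------------------------------------------
if "llm_generate" not in globals():
    def llm_generate(prompt: str, model: str = "groq/compound", max_tokens: int = 500) -> str:
        api_key = os.getenv("GROQ_API_KEY")  # hypothetical secret name
        if not api_key:
            return "[LLM error: GROQ_API_KEY not set]"
        try:
            r = requests.post(
                "https://api.groq.com/openai/v1/chat/completions",
                headers={"Authorization": f"Bearer {api_key}"},
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": max_tokens,
                },
                timeout=60,
            )
            r.raise_for_status()
            return r.json()["choices"][0]["message"]["content"]
        except Exception as e:
            return f"[LLM error: {e}]"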
# ----------------------------
# LLM Helper (Reports Analysis with Orchestration)
# ----------------------------
def groq_llm_analyze(prompt: str, section_title: str, max_tokens: int = 500) -> str:
"""
Query the selected model (Groq / Gemini / DeepSeek).
If token limit or error occurs, automatically switch to backup models
and continue the analysis seamlessly.
"""
# Primary model (user choice from sidebar)
model_chain = [st.session_state.get("llm_model", "groq/compound")]
# Add fallback chain (priority order: DeepSeek -> Gemini -> Groq).
# Use substring checks across the chain; plain `in model_chain` would only
# match an exact list element, not a model name containing "deepseek".
if not any("deepseek" in m for m in model_chain):
    model_chain.append("deepseek-r1-distill-llama-70b")
if not any("gemini" in m for m in model_chain):
    model_chain.append("gemini-1.5-pro")
if "groq/compound" not in model_chain:
    model_chain.append("groq/compound")
system_message = (
"You are GeoMate, a geotechnical engineering assistant. "
"Respond professionally with concise analysis and insights."
)
full_prompt = (
f"{system_message}\n\n"
f"Section: {section_title}\n\n"
f"Input: {prompt}\n\n"
f"Write a professional engineering analysis for this section."
)
final_response = ""
remaining_prompt = full_prompt
# Try each model in the chain until completion
for model_name in model_chain:
try:
response = llm_generate(remaining_prompt, model=model_name, max_tokens=max_tokens)
if not response or "[LLM error" in response:
# If failed, continue to next model
continue
final_response += response.strip()
# If the response length approaches max_tokens, assume it was cut off and
# hand off to the next model (word count is a rough proxy for tokens)
if len(response.split()) >= (max_tokens - 20):
# Add continuation instruction for the next model
remaining_prompt = (
f"Continue the analysis from where the last model stopped. "
f"So far the draft is:\n\n{final_response}\n\n"
f"Continue writing professionally without repeating."
)
continue
else:
break # Finished properly, exit loop
except Exception as e:
final_response += f"\n[LLM orchestration error @ {model_name}: {e}]\n"
continue
return final_response if final_response else "[LLM error: All models failed]"
# =============================
# Build Full Geotechnical Report
# =============================
def build_full_geotech_pdf(
site: Dict[str, Any],
filename: str,
include_map_image: Optional[bytes] = None,
ext_refs: Optional[List[str]] = None
):
styles = getSampleStyleSheet()
title_style = ParagraphStyle("title", parent=styles["Title"], fontSize=22,
alignment=1, textColor=colors.HexColor("#FF6600"), spaceAfter=12)
h1 = ParagraphStyle("h1", parent=styles["Heading1"], fontSize=14,
textColor=colors.HexColor("#1F4E79"), spaceBefore=10, spaceAfter=6)
body = ParagraphStyle("body", parent=styles["BodyText"], fontSize=10.5, leading=13)
doc = SimpleDocTemplate(filename, pagesize=A4,
leftMargin=18*mm, rightMargin=18*mm,
topMargin=18*mm, bottomMargin=18*mm)
elems = []
# Title Page
elems.append(Paragraph("GEOTECHNICAL INVESTIGATION REPORT", title_style))
elems.append(Spacer(1, 12))
elems.append(Paragraph(f"Client: {site.get('Company Name','-')}", body))
elems.append(Paragraph(f"Contact: {site.get('Company Contact','-')}", body))
elems.append(Paragraph(f"Project: {site.get('Project Name','-')}", body))
elems.append(Paragraph(f"Site: {site.get('Site Name','-')}", body))
elems.append(Paragraph(f"Date: {datetime.today().strftime('%Y-%m-%d')}", body))
elems.append(PageBreak())
# TOC
elems.append(Paragraph("TABLE OF CONTENTS", h1))
toc_items = [
"1.0 Summary", "2.0 Introduction", "3.0 Site Description & Geology",
"4.0 Field & Laboratory Testing", "5.0 Evaluation of Geotechnical Properties",
"6.0 Provisional Classification", "7.0 Recommendations",
"8.0 LLM Analysis", "9.0 Figures & Tables", "10.0 Appendices & References"
]
for t in toc_items:  # items already carry their section numbers
    elems.append(Paragraph(t, body))
elems.append(PageBreak())
# Sections with LLM calls
elems.append(Paragraph("1.0 SUMMARY", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(site, indent=2), "Summary"), body))
elems.append(PageBreak())
elems.append(Paragraph("2.0 INTRODUCTION", h1))
elems.append(Paragraph(groq_llm_analyze(site.get("Project Description",""), "Introduction"), body))
elems.append(Paragraph("3.0 SITE DESCRIPTION & GEOLOGY", h1))
geology_text = f"Topo: {site.get('Topography')}, Drainage: {site.get('Drainage')}, Land Use: {site.get('Current Land Use')}, Geology: {site.get('Regional Geology')}"
elems.append(Paragraph(groq_llm_analyze(geology_text, "Geology & Site Description"), body))
elems.append(PageBreak())
elems.append(Paragraph("4.0 FIELD & LABORATORY TESTING", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(site.get('Laboratory Results',[]), indent=2), "Field & Lab Testing"), body))
elems.append(PageBreak())
elems.append(Paragraph("5.0 EVALUATION OF GEOTECHNICAL PROPERTIES", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(site, indent=2), "Evaluation of Properties"), body))
elems.append(Paragraph("6.0 PROVISIONAL CLASSIFICATION", h1))
class_text = f"USCS={site.get('USCS')}, AASHTO={site.get('AASHTO')}"
elems.append(Paragraph(groq_llm_analyze(class_text, "Soil Classification"), body))
elems.append(Paragraph("7.0 RECOMMENDATIONS", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(site, indent=2), "Recommendations"), body))
elems.append(Paragraph("8.0 LLM ANALYSIS (GeoMate)", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(site, indent=2), "LLM Insights"), body))
# Map snapshot
if include_map_image:
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
tmp.write(include_map_image)
tmp.flush()
tmp.close()  # close so ReportLab can reopen the file by path (required on Windows)
elems.append(PageBreak())
elems.append(Paragraph("9.0 MAP SNAPSHOT", h1))
elems.append(RLImage(tmp.name, width=160*mm, height=90*mm))
# References
elems.append(PageBreak())
elems.append(Paragraph("10.0 REFERENCES", h1))
if ext_refs:
for r in ext_refs:
elems.append(Paragraph(f"- {r}", body))
else:
elems.append(Paragraph("No external references provided.", body))
doc.build(elems)
return filename
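# Hypothetical usage:
#   build_full_geotech_pdf(site, "Full_Report.pdf",
#                          include_map_image=site.get("map_snapshot"),
#                          ext_refs=["ASTM D2487-17"])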
# =============================
# Build Classification Report
# =============================
def build_classification_pdf(
site: Dict[str, Any],
classification: Dict[str, Any],
filename: str
):
styles = getSampleStyleSheet()
title_style = ParagraphStyle("title", parent=styles["Title"], fontSize=18,
textColor=colors.HexColor("#FF6600"), alignment=1)
h1 = ParagraphStyle("h1", parent=styles["Heading1"], fontSize=12, textColor=colors.HexColor("#1F4E79"))
body = ParagraphStyle("body", parent=styles["BodyText"], fontSize=10)
doc = SimpleDocTemplate(filename, pagesize=A4,
leftMargin=18*mm, rightMargin=18*mm,
topMargin=18*mm, bottomMargin=18*mm)
elems = []
# Title Page
elems.append(Paragraph("SOIL CLASSIFICATION REPORT", title_style))
elems.append(Spacer(1, 12))
elems.append(Paragraph(f"Site: {site.get('Site Name','Unnamed')}", body))
elems.append(Paragraph(f"Date: {datetime.today().strftime('%Y-%m-%d')}", body))
elems.append(PageBreak())
# Sections
elems.append(Paragraph("1.0 DETERMINISTIC RESULTS", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(classification, indent=2), "Deterministic Results"), body))
elems.append(PageBreak())
elems.append(Paragraph("2.0 ENGINEERING CHARACTERISTICS", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(classification.get('engineering_characteristics',{}), indent=2), "Engineering Characteristics"), body))
elems.append(PageBreak())
elems.append(Paragraph("3.0 DECISION PATHS", h1))
dp_text = f"USCS path: {classification.get('USCS_decision_path')}, AASHTO path: {classification.get('AASHTO_decision_path')}"
elems.append(Paragraph(groq_llm_analyze(dp_text, "Decision Paths"), body))
elems.append(PageBreak())
elems.append(Paragraph("4.0 LLM ANALYSIS", h1))
elems.append(Paragraph(groq_llm_analyze(json.dumps(classification, indent=2), "LLM Analysis"), body))
doc.build(elems)
return filename
# -------------------------------
# Reports Page
# -------------------------------
def reports_page():
st.header("๐ Reports โ Classification & Full Geotechnical")
site = st.session_state["sites"][st.session_state["active_site"]]
# =====================================================
# Classification Report
# =====================================================
st.subheader("๐ Classification-only Report")
if site.get("classifier_decision"):
st.markdown("You have a saved classification for this site.")
if st.button("Generate Classification PDF"):
fname = f"classification_{site['Site Name'].replace(' ','_')}.pdf"
# Collect references from rag_history (currently informational; the classification PDF does not embed them)
refs = []
if "rag_history" in st.session_state and site.get("Site ID") in st.session_state["rag_history"]:
for h in st.session_state["rag_history"][site["Site ID"]]:
if h["who"] == "bot" and "[ref:" in h["text"]:
for m in re.findall(r"\[ref:([^\]]+)\]", h["text"]):
refs.append(m)
# Build classification PDF; wrap the saved results in a dict
# (build_classification_pdf expects a dict, while classifier_decision is stored as text)
buffer = io.BytesIO()
classification = {"USCS": site.get("USCS"), "AASHTO": site.get("AASHTO"), "GI": site.get("GI"), "decision_text": site.get("classifier_decision")}
build_classification_pdf(site, classification, buffer)
buffer.seek(0)
st.download_button("⬇️ Download Classification PDF", buffer, file_name=fname, mime="application/pdf")
else:
st.info("No classification saved for this site yet. Use the Classifier page.")
# =====================================================
# Quick Report Form
# =====================================================
st.markdown("### โ๏ธ Quick report form (edit values and request LLM analysis)")
with st.form(key="report_quick_form"):
cols = st.columns([2, 1, 1])
cols[0].markdown("**Parameter**")
cols[1].markdown("**Value**")
cols[2].markdown("**Unit / Notes**")
inputs = {}
for (fld, unit) in REPORT_FIELDS:
c1, c2, c3 = st.columns([2, 1, 1])
c1.markdown(f"**{fld}**")
default_val = site.get(fld, "")
inputs[fld] = c2.text_input(fld, value=str(default_val), label_visibility="collapsed", key=f"quick_{fld}")
c3.markdown(unit)
submitted = st.form_submit_button("Save values to site")
if submitted:
for fld, _ in REPORT_FIELDS:
val = inputs.get(fld, "").strip()
site[fld] = val if val != "" else "Not provided"
st.success("โ Saved quick report values to active site.")
# =====================================================
# LLM Analysis (Humanized Report Text)
# =====================================================
st.markdown("#### ๐ค LLM-powered analysis")
if st.button("Ask GeoMate (generate analysis & recommendations)"):
context = {
"site_name": site.get("Site Name"),
"project": site.get("Project Name"),
"site_summary": {
"USCS": site.get("USCS"), "AASHTO": site.get("AASHTO"), "GI": site.get("GI"),
"Soil Profile": site.get("Soil Profile"),
"Key lab results": [r.get("sampleId") for r in site.get("Laboratory Results", [])]
},
"inputs": {fld: site.get(fld, "Not provided") for fld, _ in REPORT_FIELDS}
}
prompt = (
"You are GeoMate AI, an engineering assistant. Given the following site context and "
"engineering parameters (some may be 'Not provided'), produce:\n1) short executive summary, "
"2) geotechnical interpretation (classification, key risks), 3) recommended remedial/improvement "
"options and 4) short design notes. Provide any numeric outputs in the format [[FIELD: value unit]].\n\n"
f"Context: {json.dumps(context)}"
)
resp = groq_llm_analyze(prompt, section_title="GeoMate Analysis")
st.markdown("**GeoMate analysis**")
st.markdown(resp)
# Extract structured values from [[FIELD: value unit]]
matches = re.findall(r"\[\[([A-Za-z0-9 _/-]+):\s*([0-9.+-eE]+)\s*([A-Za-z%\/]*)\]\]", resp)
for m in matches:
field, val, unit = m[0].strip(), m[1].strip(), m[2].strip()
if "bearing" in field.lower():
site["Load Bearing Capacity"] = f"{val} {unit}"
elif "skin" in field.lower():
site["Skin Shear Strength"] = f"{val} {unit}"
elif "compaction" in field.lower():
site["Relative Compaction"] = f"{val} {unit}"
site["LLM_Report_Text"] = resp
st.success("โ LLM analysis saved to site under 'LLM_Report_Text'.")
# =====================================================
# Full Geotechnical Report
# =====================================================
st.markdown("---")
st.subheader("๐ Full Geotechnical Report")
ext_ref_text = st.text_area("Optional: External references (one per line)", value="")
ext_refs = [r.strip() for r in ext_ref_text.splitlines() if r.strip()]
# Add FAISS / rag references
faiss_refs = []
if "rag_history" in st.session_state and site.get("Site ID") in st.session_state["rag_history"]:
for h in st.session_state["rag_history"][site["Site ID"]]:
if h["who"] == "bot" and "[ref:" in h["text"]:
for m in re.findall(r"\[ref:([^\]]+)\]", h["text"]):
faiss_refs.append(m)
all_refs = list(set(ext_refs + faiss_refs))
if st.button("Generate Full Geotechnical Report PDF"):
outname = f"Full_Geotech_Report_{site.get('Site Name','site')}.pdf"
mapimg = site.get("map_snapshot")
# Classification results are also included inside the full report
build_full_geotech_pdf(site, outname, include_map_image=mapimg, ext_refs=all_refs)
with open(outname, "rb") as f:
st.download_button("⬇️ Download Full Geotechnical Report", f, file_name=outname, mime="application/pdf")
# 8) Page router
if "page" not in st.session_state:
st.session_state["page"] = "Home"
page = st.session_state["page"]
# Option menu top (main nav)
# ===============================
# Navigation (Option Menu)
# ===============================
from streamlit_option_menu import option_menu
# Define all pages
PAGES = ["Home", "Soil recognizer", "Classifier", "GSD", "OCR", "Locator", "RAG", "Reports"]
# ===============================
# Page Menu (Horizontal)
# ===============================
selected = option_menu(
None,
PAGES,
icons=[
"house", "chart", "journal-code", "bar-chart", "camera",
"geo-alt", "robot", "file-earmark-text"
],
menu_icon="cast",
default_index=PAGES.index(st.session_state["page"]) if st.session_state["page"] in PAGES else 0,
orientation="horizontal",
styles={
"container": {"padding": "0px", "background-color": "#0b0b0b"},
"nav-link": {"font-size": "14px", "color": "#cfcfcf"},
"nav-link-selected": {"background-color": "#FF7A00", "color": "white"},
}
)
# Save selection into session_state
st.session_state["page"] = selected
# ===============================
# Page Routing
# ===============================
# The page functions dispatched below render their own headers and content;
# here we only surface which LLM model is active where it matters.
if selected == "Home":
    st.info(f"Currently using **{st.session_state['llm_model']}** for analysis.")
elif selected in ("RAG", "Reports"):
    st.caption(f"Model in use: {st.session_state['llm_model']}")
# Display page content (dispatch on the current menu selection)
page = selected
if page == "Home":
landing_page()
elif page == "Classifier":
soil_classifier_page()
elif page == "GSD":
gsd_page()
elif page == "OCR":
ocr_page()
elif page == "Locator":
locator_page()
elif page == "RAG":
rag_page()
elif page == "Reports":
reports_page()
elif page == "Soil recognizer":
soil_recognizer_page()
else:
landing_page()
# Footer
st.markdown("", unsafe_allow_html=True)
st.markdown("
GeoMate V2 โข AI geotechnical copilot โข Built for HF Spaces