"""
Farm Segmentation API - Gradio Interface
SegFormer models for agricultural image segmentation
"""
import gradio as gr
import torch
import numpy as np
from PIL import Image
import time
from typing import List, Dict, Any
# Import segmentation models
try:
    from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
    MODELS_AVAILABLE = True
except ImportError:
    MODELS_AVAILABLE = False
class SegmentationAPI:
    def __init__(self):
        self.models = {}
        self.processors = {}
        self.model_configs = {
            "segformer_b0": "nvidia/segformer-b0-finetuned-ade-512-512",
            "segformer_b1": "nvidia/segformer-b1-finetuned-ade-512-512",
            "segformer_b2": "nvidia/segformer-b2-finetuned-ade-512-512"
        }
        # Segmentation classes relevant to agriculture
        self.ag_classes = {
            "soil": ["dirt", "earth", "ground", "soil", "mud"],
            "vegetation": ["grass", "tree", "plant", "leaf", "crop", "vegetation"],
            "water": ["water", "river", "pond", "irrigation"],
            "sky": ["sky", "cloud"],
            "building": ["building", "structure", "barn", "greenhouse"],
            "road": ["road", "path", "walkway"],
            "equipment": ["machine", "tractor", "equipment"]
        }
        if MODELS_AVAILABLE:
            self.load_models()
    def load_models(self):
        """Load segmentation models"""
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                processor = SegformerImageProcessor.from_pretrained(model_name)
                model = SegformerForSemanticSegmentation.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"✅ {model_name} loaded successfully")
            except Exception as e:
                print(f"❌ Failed to load {model_name}: {e}")
    def segment_image(self, image: Image.Image, model_key: str = "segformer_b1") -> Dict[str, Any]:
        """Segment agricultural image"""
        if not MODELS_AVAILABLE or model_key not in self.models:
            return {"error": "Model not available"}
        start_time = time.time()
        try:
            # Preprocess image
            processor = self.processors[model_key]
            model = self.models[model_key]
            inputs = processor(images=image, return_tensors="pt")
            # Run inference
            with torch.no_grad():
                outputs = model(**inputs)
            # Post-process segmentation: upsample logits back to the input resolution
            logits = outputs.logits
            upsampled_logits = torch.nn.functional.interpolate(
                logits,
                size=image.size[::-1],  # PIL size is (width, height); interpolate expects (height, width)
                mode="bilinear",
                align_corners=False,
            )
            predicted_segmentation = upsampled_logits.argmax(dim=1).squeeze().cpu().numpy()
            # Analyze segments
            segments_info = self.analyze_segments(predicted_segmentation, model)
            # Create colored segmentation map
            colored_segmentation = self.create_colored_segmentation(predicted_segmentation, model)
            processing_time = time.time() - start_time
            return {
                "segments_detected": len(segments_info),
                "segments": segments_info,
                "segmentation_map": colored_segmentation,
                "processing_time": round(processing_time, 2),
                "model_used": model_key
            }
        except Exception as e:
            return {"error": str(e)}
    def analyze_segments(self, segmentation: np.ndarray, model) -> List[Dict[str, Any]]:
        """Analyze segmentation results"""
        unique_labels = np.unique(segmentation)
        segments_info = []
        total_pixels = segmentation.size
        for label in unique_labels:
            mask = segmentation == label
            pixel_count = np.sum(mask)
            percentage = (pixel_count / total_pixels) * 100
            if percentage > 1.0:  # Only include segments > 1%
                class_name = model.config.id2label.get(int(label), f"class_{label}")
                ag_category = self.classify_agricultural_segment(class_name)
                segments_info.append({
                    "class": class_name,
                    "agricultural_category": ag_category,
                    "pixel_count": int(pixel_count),
                    "percentage": round(percentage, 2),
                    "label_id": int(label)
                })
        return sorted(segments_info, key=lambda x: x["percentage"], reverse=True)
    def classify_agricultural_segment(self, class_name: str) -> str:
        """Classify segment into agricultural categories"""
        class_lower = class_name.lower()
        for ag_category, keywords in self.ag_classes.items():
            if any(keyword in class_lower for keyword in keywords):
                return ag_category
        return "other"
    def create_colored_segmentation(self, segmentation: np.ndarray, model) -> np.ndarray:
        """Create colored segmentation visualization"""
        # Create color palette
        num_classes = len(model.config.id2label)
        colors = self.generate_colors(num_classes)
        # Create colored image
        h, w = segmentation.shape
        colored = np.zeros((h, w, 3), dtype=np.uint8)
        for label in np.unique(segmentation):
            mask = segmentation == label
            colored[mask] = colors[label % len(colors)]
        return colored
    def generate_colors(self, num_colors: int) -> List[List[int]]:
        """Generate distinct colors for segmentation classes"""
        import random
        random.seed(42)  # For consistent colors
        colors = []
        for _ in range(num_colors):
            colors.append([
                random.randint(50, 255),
                random.randint(50, 255),
                random.randint(50, 255)
            ])
        return colors
# Initialize API
api = SegmentationAPI()
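# The API class can also be used programmatically, outside the Gradio UI.
# A rough sketch (the file name "field.jpg" is illustrative, not part of this repo):
#
#     result = api.segment_image(Image.open("field.jpg"), model_key="segformer_b0")
#     for seg in result.get("segments", []):
#         print(seg["class"], seg["agricultural_category"], seg["percentage"])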
def predict_segmentation(image, model_choice):
    """Gradio prediction function"""
    if image is None:
        return None, None, "Please upload an image"
    # Convert to PIL Image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Run segmentation
    results = api.segment_image(image, model_choice)
    if "error" in results:
        return None, None, f"Error: {results['error']}"
    # Create visualization
    segmentation_img = Image.fromarray(results["segmentation_map"])
    # Format results text
    results_text = f"""
🏞️ **Agricultural Segmentation Analysis**
📊 **Segments Detected**: {results['segments_detected']}
⏱️ **Processing Time**: {results['processing_time']}s
🤖 **Model**: {results['model_used']}
**🌾 Agricultural Composition**:
"""
    for segment in results["segments"][:10]:  # Top 10 segments
        results_text += f"\n• **{segment['class']}** ({segment['agricultural_category']}): {segment['percentage']:.1f}%"
    return image, segmentation_img, results_text
# Gradio Interface
with gr.Blocks(title="🏞️ Farm Segmentation API") as app:
    gr.Markdown("# 🏞️ Farm Segmentation API")
    gr.Markdown("AI-powered agricultural image segmentation and land analysis")
    with gr.Tab("🌾 Field Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Farm Image")
                model_choice = gr.Dropdown(
                    choices=["segformer_b0", "segformer_b1", "segformer_b2"],
                    value="segformer_b1",
                    label="Select Model"
                )
                segment_btn = gr.Button("🔍 Analyze Segments", variant="primary")
            with gr.Column():
                original_image = gr.Image(label="Original Image")
                segmented_image = gr.Image(label="Segmentation Map")
                results_text = gr.Textbox(label="Segmentation Analysis", lines=15)
        segment_btn.click(
            predict_segmentation,
            inputs=[image_input, model_choice],
            outputs=[original_image, segmented_image, results_text]
        )
    with gr.Tab("📑 API Documentation"):
        gr.Markdown("""
## 🚀 API Endpoint
**POST** `/api/predict`
### Request Format
```json
{
"data": ["<base64_image>", "<model_choice>"]
}
```
### Model Options
- **segformer_b0**: Fastest, basic segmentation
- **segformer_b1**: Balanced speed and accuracy (recommended)
- **segformer_b2**: Higher accuracy, slower processing
### Response Format
```json
{
"segments_detected": 8,
"segments": [
{
"class": "grass",
"agricultural_category": "vegetation",
"pixel_count": 145632,
"percentage": 35.2,
"label_id": 9
}
],
"processing_time": 2.1
}
```
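### Example Request (Python)
A minimal client sketch based on the request format above; the base URL is a placeholder for wherever this app is running, and `field.jpg` is an illustrative file name.
```python
import base64
import requests

BASE_URL = "http://localhost:7860"  # placeholder: replace with your deployment URL

# Read and base64-encode the input image (data-URI prefix as Gradio endpoints commonly expect)
with open("field.jpg", "rb") as f:
    img_b64 = "data:image/jpeg;base64," + base64.b64encode(f.read()).decode()

# POST to the endpoint documented above: data = [image, model_choice]
resp = requests.post(
    f"{BASE_URL}/api/predict",
    json={"data": [img_b64, "segformer_b1"]},
    timeout=120,
)
print(resp.json())
```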
### Agricultural Categories
- **soil**: Ground, dirt, earth
- **vegetation**: Crops, grass, trees
- **water**: Irrigation, ponds, rivers
- **sky**: Sky, clouds
- **building**: Barns, greenhouses, structures
- **road**: Roads, paths, walkways
- **equipment**: Tractors, machinery
- **other**: Uncategorized segments
""")
if __name__ == "__main__":
    app.launch()