Spaces:
Sleeping
Sleeping
| """ | |
| Farm Segmentation API - Gradio Interface | |
| SegFormer models for agricultural image segmentation | |
| """ | |
| import gradio as gr | |
| import torch | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import json | |
| import base64 | |
| import io | |
| import time | |
| from typing import List, Dict, Any | |
| # Import segmentation models | |
| try: | |
| from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation | |
| MODELS_AVAILABLE = True | |
| except ImportError: | |
| MODELS_AVAILABLE = False | |
class SegmentationAPI:
    """Run SegFormer semantic segmentation and summarise it in agricultural terms.

    Holds one processor/model pair per key of ``model_configs`` (loaded in
    ``__init__`` when ``transformers`` is importable), segments PIL images, and
    maps each predicted class label onto a coarse category from ``ag_classes``.
    """

    def __init__(self):
        # model_key -> loaded SegformerImageProcessor / SegformerForSemanticSegmentation.
        self.models = {}
        self.processors = {}
        # model_key -> Hugging Face checkpoint id passed to ``from_pretrained``.
        self.model_configs = {
            "segformer_b0": "nvidia/segformer-b0-finetuned-ade-512-512",
            "segformer_b1": "nvidia/segformer-b1-finetuned-ade-512-512",
            "segformer_b2": "nvidia/segformer-b2-finetuned-ade-512-512"
        }
        # Segmentation classes relevant to agriculture: category -> keywords.
        # Categories are tried in this order; a keyword matches a word of the
        # model's class label that starts with it (see classify_agricultural_segment).
        self.ag_classes = {
            "soil": ["dirt", "earth", "ground", "soil", "mud"],
            "vegetation": ["grass", "tree", "plant", "leaf", "crop", "vegetation"],
            "water": ["water", "river", "pond", "irrigation"],
            "sky": ["sky", "cloud"],
            "building": ["building", "structure", "barn", "greenhouse"],
            "road": ["road", "path", "walkway"],
            "equipment": ["machine", "tractor", "equipment"]
        }
        if MODELS_AVAILABLE:
            self.load_models()

    def load_models(self):
        """Load every checkpoint listed in ``model_configs``.

        Best effort: a checkpoint that fails to load is reported with ``print``
        and skipped, so it is simply absent from ``self.models``.
        """
        for model_key, model_name in self.model_configs.items():
            try:
                print(f"Loading {model_name}...")
                processor = SegformerImageProcessor.from_pretrained(model_name)
                model = SegformerForSemanticSegmentation.from_pretrained(model_name)
                self.processors[model_key] = processor
                self.models[model_key] = model
                print(f"β {model_name} loaded successfully")
            except Exception as e:
                print(f"β Failed to load {model_name}: {e}")

    def segment_image(self, image: "Image.Image", model_key: str = "segformer_b1") -> Dict[str, Any]:
        """Segment an agricultural image with the selected model.

        Args:
            image: Input PIL image. Non-RGB modes (RGBA, L, P, ...) are converted
                to RGB before preprocessing.
            model_key: Key into ``self.models``, e.g. ``"segformer_b1"``.

        Returns:
            On success, a dict with ``segments_detected``, ``segments`` (see
            ``analyze_segments``), ``segmentation_map`` (H x W x 3 uint8 array),
            ``processing_time`` (seconds, 2 d.p.) and ``model_used``.
            Otherwise ``{"error": <message>}``. Failures are reported in the
            returned dict rather than raised.
        """
        if not MODELS_AVAILABLE or model_key not in self.models:
            return {"error": "Model not available"}
        start_time = time.perf_counter()
        try:
            # The processor expects 3-channel input; alpha, greyscale or
            # palette-mode uploads would otherwise fail or be misread.
            if image.mode != "RGB":
                image = image.convert("RGB")
            processor = self.processors[model_key]
            model = self.models[model_key]
            inputs = processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = model(**inputs)
            # Resize logits to the input image; PIL size is (width, height)
            # while interpolate wants (height, width).
            upsampled_logits = torch.nn.functional.interpolate(
                outputs.logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            # Take batch item 0 explicitly: squeeze() would also drop a
            # height/width axis of size 1 and break the 2-D map downstream.
            predicted_segmentation = upsampled_logits.argmax(dim=1)[0].cpu().numpy()
            segments_info = self.analyze_segments(predicted_segmentation, model)
            colored_segmentation = self.create_colored_segmentation(predicted_segmentation, model)
            processing_time = time.perf_counter() - start_time
            return {
                "segments_detected": len(segments_info),
                "segments": segments_info,
                "segmentation_map": colored_segmentation,
                "processing_time": round(processing_time, 2),
                "model_used": model_key
            }
        except Exception as e:
            return {"error": str(e)}

    def analyze_segments(self, segmentation: np.ndarray, model) -> List[Dict[str, Any]]:
        """Summarise a label map into per-class statistics.

        Args:
            segmentation: Array of predicted integer class ids.
            model: Object exposing ``config.id2label`` (class id -> label name).

        Returns:
            One dict per class covering more than 1% of the pixels, with keys
            ``class``, ``agricultural_category``, ``pixel_count``,
            ``percentage`` (2 d.p.) and ``label_id``, sorted by ``percentage``
            descending. Values are plain Python ints/floats (JSON-friendly).
        """
        # One pass yields every label with its pixel count, instead of one
        # full boolean mask per label.
        labels, counts = np.unique(segmentation, return_counts=True)
        total_pixels = segmentation.size
        id2label = model.config.id2label
        segments_info = []
        for label, pixel_count in zip(labels.tolist(), counts.tolist()):
            percentage = (pixel_count / total_pixels) * 100
            if percentage > 1.0:  # Only include segments > 1%
                class_name = id2label.get(label, f"class_{label}")
                segments_info.append({
                    "class": class_name,
                    "agricultural_category": self.classify_agricultural_segment(class_name),
                    "pixel_count": pixel_count,
                    "percentage": round(percentage, 2),
                    "label_id": label,
                })
        return sorted(segments_info, key=lambda x: x["percentage"], reverse=True)

    def classify_agricultural_segment(self, class_name: str) -> str:
        """Map a model class label onto an agricultural category.

        The label is lower-cased and split into alphabetic words. A keyword
        matches a word that *starts* with it, so "trees" and "waterfall" still
        match, but a keyword buried mid-word does not ("tree" in "streetlight",
        "ground" in "background"). Categories are tried in ``ag_classes``
        order and the first hit wins.

        Returns:
            The category name, or ``"other"`` when no keyword matches.
        """
        import re
        words = re.findall(r"[a-z]+", class_name.lower())
        for ag_category, keywords in self.ag_classes.items():
            if any(word.startswith(keyword) for word in words for keyword in keywords):
                return ag_category
        return "other"

    def create_colored_segmentation(self, segmentation: np.ndarray, model) -> np.ndarray:
        """Render a label map as an RGB image with a fixed per-class palette.

        Args:
            segmentation: (H, W) array of integer class ids.
            model: Object exposing ``config.id2label``; its length sizes the palette.

        Returns:
            (H, W, 3) uint8 array where label ``k`` is drawn with palette colour
            ``k % num_classes``.
        """
        num_classes = len(model.config.id2label)
        palette = np.asarray(self.generate_colors(num_classes), dtype=np.uint8)
        # A single fancy-indexing pass instead of one boolean mask per label.
        return palette[np.asarray(segmentation) % len(palette)]

    def generate_colors(self, num_colors: int) -> List[List[int]]:
        """Return ``num_colors`` reproducible RGB triples, each channel in [50, 255].

        A private ``random.Random(42)`` yields the same palette on every call
        without reseeding (and thereby disturbing) the global ``random`` state.
        """
        import random
        rng = random.Random(42)
        return [[rng.randint(50, 255) for _ in range(3)] for _ in range(num_colors)]
# Single shared API instance used by the Gradio callback below. Constructing it
# loads the configured checkpoints when transformers could be imported.
api = SegmentationAPI()


def predict_segmentation(image, model_choice):
    """Gradio callback: segment the uploaded image and build the three outputs.

    Args:
        image: Uploaded image (PIL image, or numpy array which is wrapped as PIL).
        model_choice: Model key forwarded to ``api.segment_image``.

    Returns:
        ``(original image, segmentation image, summary text)`` on success, or
        ``(None, None, message)`` when no image was given or segmentation failed.
        The summary text uses markdown-style ``**bold**`` markers and lists at
        most the first 10 segments.
    """
    if image is None:
        return None, None, "Please upload an image"
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    results = api.segment_image(pil_image, model_choice)
    if "error" in results:
        return None, None, f"Error: {results['error']}"
    segmentation_img = Image.fromarray(results["segmentation_map"])
    header = f"""
ποΈ **Agricultural Segmentation Analysis**
π **Segments Detected**: {results['segments_detected']}
β±οΈ **Processing Time**: {results['processing_time']}s
π€ **Model**: {results['model_used']}
**πΎ Agricultural Composition**:
"""
    # Top 10 segments, one bullet per line.
    bullets = "".join(
        f"\nβ’ **{seg['class']}** ({seg['agricultural_category']}): {seg['percentage']:.1f}%"
        for seg in results["segments"][:10]
    )
    return pil_image, segmentation_img, header + bullets
# Gradio Interface: an analysis tab wired to predict_segmentation, plus a static
# API documentation tab.
with gr.Blocks(title="ποΈ Farm Segmentation API") as app:
    gr.Markdown("# ποΈ Farm Segmentation API")
    gr.Markdown("AI-powered agricultural image segmentation and land analysis")
    with gr.Tab("πΎ Field Analysis"):
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="pil", label="Upload Farm Image")
                # Offer exactly the keys SegmentationAPI is configured with, so
                # the UI cannot drift from the backend's model table.
                model_choice = gr.Dropdown(
                    choices=list(api.model_configs),
                    value="segformer_b1",
                    label="Select Model"
                )
                segment_btn = gr.Button("π Analyze Segments", variant="primary")
            with gr.Column():
                original_image = gr.Image(label="Original Image")
                segmented_image = gr.Image(label="Segmentation Map")
                results_text = gr.Textbox(label="Segmentation Analysis", lines=15)
        segment_btn.click(
            predict_segmentation,
            inputs=[image_input, model_choice],
            outputs=[original_image, segmented_image, results_text]
        )
    with gr.Tab("π‘ API Documentation"):
        # Category list mirrors SegmentationAPI.ag_classes (including sky and
        # road), plus the "other" fallback.
        gr.Markdown("""
## π API Endpoint
**POST** `/api/predict`
### Request Format
```json
{
"data": ["<base64_image>", "<model_choice>"]
}
```
### Model Options
- **segformer_b0**: Fastest, basic segmentation
- **segformer_b1**: Balanced speed and accuracy (recommended)
- **segformer_b2**: Higher accuracy, slower processing
### Response Format
```json
{
"segments_detected": 8,
"segments": [
{
"class": "grass",
"agricultural_category": "vegetation",
"pixel_count": 145632,
"percentage": 35.2,
"label_id": 9
}
],
"processing_time": 2.1
}
```
### Agricultural Categories
- **soil**: Ground, dirt, earth
- **vegetation**: Crops, grass, trees
- **water**: Irrigation, ponds, rivers
- **sky**: Sky, clouds
- **building**: Barns, greenhouses, structures
- **road**: Roads, paths, walkways
- **equipment**: Tractors, machinery
- **other**: Uncategorized segments
""")

if __name__ == "__main__":
    app.launch()