# === app.py ===
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamModel, SamProcessor
from datasets import load_dataset
import requests
from io import BytesIO
import warnings

warnings.filterwarnings("ignore")

# Global model and processor - using SAM (Segment Anything Model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None


def load_model():
    """Load the SAM model lazily on first use."""
    global model, processor
    if model is None:
        print("Loading SAM model...")
        processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
        model = SamModel.from_pretrained("facebook/sam-vit-base")
        if torch.cuda.is_available():
            model = model.to(device)
        print("Model loaded successfully!")
    return model, processor


# Public neuroimaging datasets on Hugging Face.
# Note: only "Brain Tumor MRI" is wired up in load_from_hf_dataset() below,
# and the second entry points at a chest X-ray dataset, not a brain one.
NEUROIMAGING_DATASETS = {
    "Brain Tumor MRI": {
        "dataset": "sartajbhuvaji/brain-tumor-classification",
        "description": "Brain MRI scans with tumor classifications",
        "split": "train",
    },
    "Medical MNIST (Brain)": {
        "dataset": "alkzar90/NIH-Chest-X-ray-dataset",
        "description": "General medical imaging dataset (chest X-rays)",
        "split": "train",
    },
}

# Sample neuroimaging URLs (publicly available brain MRI examples)
SAMPLE_IMAGES = {
    "Brain MRI - Axial": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/MRI_of_Human_Brain.jpg/800px-MRI_of_Human_Brain.jpg",
    "Brain MRI - Sagittal": "https://upload.wikimedia.org/wikipedia/commons/1/1a/MRI_head_side.jpg",
    "Brain CT Scan": "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9a/CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png/800px-CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png",
}

# Neuroimaging-specific label presets. SAM takes point/box prompts, not text,
# so these strings are only used to name the segmented structures.
NEURO_PRESETS = {
    "Brain Structures": ["brain", "cerebrum", "cerebellum", "brainstem", "corpus callosum"],
    "Lobes": ["frontal lobe", "temporal lobe", "parietal lobe", "occipital lobe"],
    "Ventricles": ["ventricle", "lateral ventricle", "third ventricle", "fourth ventricle"],
    "Gray/White Matter": ["gray matter", "white matter", "cortex", "subcortical"],
    "Deep Structures": ["thalamus", "hypothalamus", "hippocampus", "amygdala", "basal ganglia"],
    "Lesions/Abnormalities": ["lesion", "tumor", "mass", "abnormality", "hyperintensity"],
    "Vascular": ["blood vessel", "artery", "vein", "sinus"],
    "Skull/Meninges": ["skull", "bone", "meninges", "dura"],
}


@spaces.GPU()
def segment_with_points(image: Image.Image, points: list, labels: list, structure_name: str):
    """
    Perform segmentation using SAM with point prompts.
    SAM uses point/box prompts, not text prompts.
    """
    if image is None:
        return None, "❌ Please upload a neuroimaging scan."
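    # SamProcessor expects nested lists: input_points shaped
    # (batch, num_points, [x, y]) and input_labels shaped (batch, num_points),
    # where label 1 marks a foreground point and 0 marks a background point
    # the predicted mask should exclude.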
    try:
        sam_model, sam_processor = load_model()

        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare inputs with point prompts
        if points and len(points) > 0:
            input_points = [points]  # Shape: (batch, num_points, 2)
            input_labels = [labels]  # Shape: (batch, num_points)
        else:
            # Use the image center as a default foreground point
            w, h = image.size
            input_points = [[[w // 2, h // 2]]]
            input_labels = [[1]]  # 1 = foreground

        inputs = sam_processor(
            image,
            input_points=input_points,
            input_labels=input_labels,
            return_tensors="pt",
        )
        if torch.cuda.is_available():
            inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = sam_model(**inputs)

        # Post-process masks back to the original image resolution
        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores.cpu().numpy()[0]

        # Get best mask
        masks_np = masks[0].numpy()
        if masks_np.shape[0] == 0:
            return (image, []), "❌ No segmentation found for the selected points."

        # Format for AnnotatedImage
        annotations = []
        for i in range(min(3, masks_np.shape[1])):  # Top 3 masks
            mask = masks_np[0, i].astype(np.uint8)
            if mask.sum() > 0:  # Only add non-empty masks
                score = scores[0, i] if i < scores.shape[1] else 0.0
                label = f"{structure_name} (IoU: {score:.2f})"
                annotations.append((mask, label))

        if not annotations:
            return (image, []), "❌ No valid masks generated."

        info = f"""✅ **Segmentation Complete**

**Target:** {structure_name}
**Masks Generated:** {len(annotations)}
**Best IoU Score:** {scores.max():.3f}

*SAM generates multiple mask proposals - showing top results*"""
        return (image, annotations), info

    except Exception as e:
        import traceback
        return (image, []), f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"


@spaces.GPU()
def segment_with_box(image: Image.Image, x1: int, y1: int, x2: int, y2: int, structure_name: str):
    """Segment using a bounding-box prompt."""
    if image is None:
        return None, "❌ Please upload an image."

    try:
        sam_model, sam_processor = load_model()
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Prepare box prompt: (batch, num_boxes, 4) as [x1, y1, x2, y2]
        input_boxes = [[[x1, y1, x2, y2]]]
        inputs = sam_processor(image, input_boxes=input_boxes, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = sam_model(**inputs)

        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu(),
        )
        scores = outputs.iou_scores.cpu().numpy()[0]
        masks_np = masks[0].numpy()

        annotations = []
        for i in range(min(3, masks_np.shape[1])):
            mask = masks_np[0, i].astype(np.uint8)
            if mask.sum() > 0:
                score = scores[0, i] if i < scores.shape[1] else 0.0
                label = f"{structure_name} (IoU: {score:.2f})"
                annotations.append((mask, label))

        if not annotations:
            return (image, []), "❌ No valid masks generated from box."

        info = f"""✅ **Box Segmentation Complete**

**Target:** {structure_name}
**Box:** ({x1}, {y1}) to ({x2}, {y2})
**Masks Generated:** {len(annotations)}"""
        return (image, annotations), info

    except Exception as e:
        return (image, []), f"❌ Error: {str(e)}"


@spaces.GPU()
def auto_segment_grid(image: Image.Image, grid_size: int = 4):
    """Automatic segmentation using a grid of points."""
    if image is None:
        return None, "❌ Please upload an image."
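    # Grid strategy: sample an evenly spaced lattice of candidate points, run
    # SAM once per point, and keep each point's highest-IoU mask proposal,
    # discarding masks under 100 pixels. Overlapping masks from neighboring
    # points are not deduplicated.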
    try:
        sam_model, sam_processor = load_model()
        if image.mode != "RGB":
            image = image.convert("RGB")

        w, h = image.size

        # Create an evenly spaced grid of points
        points = []
        step_x = w // (grid_size + 1)
        step_y = h // (grid_size + 1)
        for i in range(1, grid_size + 1):
            for j in range(1, grid_size + 1):
                points.append([step_x * i, step_y * j])

        all_annotations = []

        # Process each point (limit to 9 points for speed)
        for idx, point in enumerate(points[:9]):
            input_points = [[point]]
            input_labels = [[1]]

            inputs = sam_processor(
                image,
                input_points=input_points,
                input_labels=input_labels,
                return_tensors="pt",
            )
            if torch.cuda.is_available():
                inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = sam_model(**inputs)

            masks = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu(),
            )
            scores = outputs.iou_scores.cpu().numpy()[0]
            masks_np = masks[0].numpy()

            # Keep the best mask for this point
            if masks_np.shape[1] > 0:
                best_idx = scores[0].argmax()
                mask = masks_np[0, best_idx].astype(np.uint8)
                if mask.sum() > 100:  # Minimum size threshold
                    score = scores[0, best_idx]
                    label = f"Region {idx + 1} (IoU: {score:.2f})"
                    all_annotations.append((mask, label))

        if not all_annotations:
            return (image, []), "❌ No regions found with auto-segmentation."

        info = f"""✅ **Auto-Segmentation Complete**

**Grid Points:** {len(points)}
**Regions Found:** {len(all_annotations)}

*Automatic discovery of distinct regions in the image*"""
        return (image, all_annotations), info

    except Exception as e:
        return (image, []), f"❌ Error: {str(e)}"


def load_sample_image(sample_name: str):
    """Load a sample neuroimaging image from a public URL."""
    if sample_name not in SAMPLE_IMAGES:
        return None, "Sample not found"
    try:
        url = SAMPLE_IMAGES[sample_name]
        response = requests.get(url, timeout=10)
        image = Image.open(BytesIO(response.content)).convert("RGB")
        return image, f"✅ Loaded: {sample_name}"
    except Exception as e:
        return None, f"❌ Failed to load sample: {str(e)}"


def load_from_hf_dataset(dataset_name: str, index: int = 0):
    """Load an image from a Hugging Face dataset (streaming, so nothing is downloaded up front)."""
    try:
        if dataset_name == "Brain Tumor MRI":
            ds = load_dataset("sartajbhuvaji/brain-tumor-classification", split="train", streaming=True)
            for i, sample in enumerate(ds):
                if i == index:
                    image = sample["image"]
                    if image.mode != "RGB":
                        image = image.convert("RGB")
                    return image, f"✅ Loaded from Brain Tumor MRI dataset (index {index})"
        return None, "Dataset not available"
    except Exception as e:
        return None, f"❌ Error loading dataset: {str(e)}"


def get_click_point(image, evt: gr.SelectData):
    """Get point coordinates from a single image click (single-point variant of add_point)."""
    if evt is None:
        return [], [], "Click on the image to add points"
    x, y = evt.index
    return [[x, y]], [1], f"Point added at ({x}, {y})"


# Module-level point stores (unused; multi-point selection is tracked with gr.State below)
current_points = []
current_labels = []


def add_point(image, evt: gr.SelectData, points_state, labels_state, point_type):
    """Add a point to the current selection."""
    if evt is None or image is None:
        return points_state, labels_state, "Click on image to add points"
    x, y = evt.index
    label = 1 if point_type == "Foreground (+)" else 0
    points_state = points_state + [[x, y]]
    labels_state = labels_state + [label]
    point_info = f"Added {'foreground' if label == 1 else 'background'} point at ({x}, {y})\n"
    point_info += f"Total points: {len(points_state)}"
    return points_state, labels_state, point_info


def clear_points():
    """Clear all selected points."""
    return [], [], "Points cleared"


def clear_all():
    """Clear all inputs and outputs."""
    return None, None, [], [], 0.5, "brain region", "📝 Upload a neuroimaging scan and click to add points for segmentation."


# Gradio Interface
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="NeuroSAM - Neuroimaging Segmentation",
    css="""
    .gradio-container {max-width: 1400px !important;}
    .neuro-header {background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;}
    .neuro-header h1 {color: white !important; margin: 0 !important;}
    .neuro-header p {color: rgba(255,255,255,0.9) !important;}
    .info-box {background: #e8f4f8; padding: 15px; border-radius: 8px; margin: 10px 0;}
    """,
    footer_links=[{"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}],
) as demo:
    # State for point selection
    points_state = gr.State([])
    labels_state = gr.State([])

    # Header banner (markup was stripped in the source; h1 text assumed from
    # the gr.Blocks title)
    gr.HTML(
        """
        <div class="neuro-header">
            <h1>NeuroSAM - Neuroimaging Segmentation</h1>
            <p>Interactive segmentation using Meta's Segment Anything Model (SAM)</p>
        </div>
        """
    )
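    # --- Assumed layout below: the original UI definition was truncated in
    # the source after the header. This is a minimal sketch wiring up the
    # functions defined above; the component names (input_image,
    # structure_name, point_type, threshold, output_image, status) are
    # hypothetical, chosen to match clear_all()'s seven return values.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Neuroimaging Scan")
            structure_name = gr.Textbox(value="brain region", label="Structure Name")
            point_type = gr.Radio(
                ["Foreground (+)", "Background (-)"],
                value="Foreground (+)",
                label="Point Type",
            )
            threshold = gr.Slider(0.0, 1.0, value=0.5, label="Threshold")
            with gr.Row():
                segment_btn = gr.Button("Segment", variant="primary")
                auto_btn = gr.Button("Auto-Segment")
                clear_btn = gr.Button("Clear")
        with gr.Column():
            output_image = gr.AnnotatedImage(label="Segmentation")
            status = gr.Markdown("📝 Upload a neuroimaging scan and click to add points for segmentation.")

    # Clicking the image accumulates prompt points in the gr.State holders
    input_image.select(
        add_point,
        inputs=[input_image, points_state, labels_state, point_type],
        outputs=[points_state, labels_state, status],
    )
    segment_btn.click(
        segment_with_points,
        inputs=[input_image, points_state, labels_state, structure_name],
        outputs=[output_image, status],
    )
    auto_btn.click(auto_segment_grid, inputs=[input_image], outputs=[output_image, status])
    clear_btn.click(
        clear_all,
        outputs=[input_image, output_image, points_state, labels_state, threshold, structure_name, status],
    )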
    gr.HTML(
        """
        <div style="text-align: center;">
            Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder">anycoder</a> | Model: facebook/sam-vit-base
        </div>
        """
    )
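
# Entry point (assumed; the original file was truncated here, and a Spaces
# app.py conventionally ends by launching the Blocks demo)
if __name__ == "__main__":
    demo.launch()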