=== app.py === import spaces import gradio as gr import torch import numpy as np from PIL import Image from transformers import SamModel, SamProcessor from datasets import load_dataset import requests from io import BytesIO import warnings warnings.filterwarnings("ignore") # Global model and processor - Using SAM (Segment Anything Model) device = "cuda" if torch.cuda.is_available() else "cpu" model = None processor = None def load_model(): """Load SAM model lazily""" global model, processor if model is None: print("Loading SAM model...") processor = SamProcessor.from_pretrained("facebook/sam-vit-base") model = SamModel.from_pretrained("facebook/sam-vit-base") if torch.cuda.is_available(): model = model.to(device) print("Model loaded successfully!") return model, processor # Public neuroimaging datasets on Hugging Face NEUROIMAGING_DATASETS = { "Brain Tumor MRI": { "dataset": "sartajbhuvaji/brain-tumor-classification", "description": "Brain MRI scans with tumor classifications", "split": "train" }, "Medical MNIST (Brain)": { "dataset": "alkzar90/NIH-Chest-X-ray-dataset", "description": "Medical imaging dataset", "split": "train" }, } # Sample neuroimaging URLs (publicly available brain MRI examples) SAMPLE_IMAGES = { "Brain MRI - Axial": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/MRI_of_Human_Brain.jpg/800px-MRI_of_Human_Brain.jpg", "Brain MRI - Sagittal": "https://upload.wikimedia.org/wikipedia/commons/1/1a/MRI_head_side.jpg", "Brain CT Scan": "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9a/CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png/800px-CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png", } # Neuroimaging-specific prompts and presets NEURO_PRESETS = { "Brain Structures": ["brain", "cerebrum", "cerebellum", "brainstem", "corpus callosum"], "Lobes": ["frontal lobe", "temporal lobe", "parietal lobe", "occipital lobe"], "Ventricles": ["ventricle", "lateral ventricle", "third ventricle", "fourth ventricle"], "Gray/White Matter": ["gray matter", "white matter", "cortex", "subcortical"], "Deep Structures": ["thalamus", "hypothalamus", "hippocampus", "amygdala", "basal ganglia"], "Lesions/Abnormalities": ["lesion", "tumor", "mass", "abnormality", "hyperintensity"], "Vascular": ["blood vessel", "artery", "vein", "sinus"], "Skull/Meninges": ["skull", "bone", "meninges", "dura"], } @spaces.GPU() def segment_with_points(image: Image.Image, points: list, labels: list, structure_name: str): """ Perform segmentation using SAM with point prompts. SAM uses point/box prompts, not text prompts. """ if image is None: return None, "❌ Please upload a neuroimaging scan." try: sam_model, sam_processor = load_model() # Ensure image is RGB if image.mode != "RGB": image = image.convert("RGB") # Prepare inputs with point prompts if points and len(points) > 0: input_points = [points] # Shape: (batch, num_points, 2) input_labels = [labels] # Shape: (batch, num_points) else: # Use center point as default w, h = image.size input_points = [[[w // 2, h // 2]]] input_labels = [[1]] # 1 = foreground inputs = sam_processor( image, input_points=input_points, input_labels=input_labels, return_tensors="pt" ) if torch.cuda.is_available(): inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = sam_model(**inputs) # Post-process masks masks = sam_processor.image_processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() ) scores = outputs.iou_scores.cpu().numpy()[0] # Get best mask masks_np = masks[0].numpy() if masks_np.shape[0] == 0: return (image, []), f"❌ No segmentation found for the selected points." # Format for AnnotatedImage annotations = [] for i in range(min(3, masks_np.shape[1])): # Top 3 masks mask = masks_np[0, i].astype(np.uint8) if mask.sum() > 0: # Only add non-empty masks score = scores[0, i] if i < scores.shape[1] else 0.0 label = f"{structure_name} (IoU: {score:.2f})" annotations.append((mask, label)) if not annotations: return (image, []), "❌ No valid masks generated." info = f"""✅ **Segmentation Complete** **Target:** {structure_name} **Masks Generated:** {len(annotations)} **Best IoU Score:** {scores.max():.3f} *SAM generates multiple mask proposals - showing top results*""" return (image, annotations), info except Exception as e: import traceback return (image, []), f"❌ Error: {str(e)}\n\n{traceback.format_exc()}" @spaces.GPU() def segment_with_box(image: Image.Image, x1: int, y1: int, x2: int, y2: int, structure_name: str): """Segment using bounding box prompt""" if image is None: return None, "❌ Please upload an image." try: sam_model, sam_processor = load_model() if image.mode != "RGB": image = image.convert("RGB") # Prepare box prompt input_boxes = [[[x1, y1, x2, y2]]] inputs = sam_processor( image, input_boxes=input_boxes, return_tensors="pt" ) if torch.cuda.is_available(): inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = sam_model(**inputs) masks = sam_processor.image_processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() ) scores = outputs.iou_scores.cpu().numpy()[0] masks_np = masks[0].numpy() annotations = [] for i in range(min(3, masks_np.shape[1])): mask = masks_np[0, i].astype(np.uint8) if mask.sum() > 0: score = scores[0, i] if i < scores.shape[1] else 0.0 label = f"{structure_name} (IoU: {score:.2f})" annotations.append((mask, label)) if not annotations: return (image, []), "❌ No valid masks generated from box." info = f"""✅ **Box Segmentation Complete** **Target:** {structure_name} **Box:** ({x1}, {y1}) to ({x2}, {y2}) **Masks Generated:** {len(annotations)}""" return (image, annotations), info except Exception as e: return (image, []), f"❌ Error: {str(e)}" @spaces.GPU() def auto_segment_grid(image: Image.Image, grid_size: int = 4): """Automatic segmentation using grid of points""" if image is None: return None, "❌ Please upload an image." try: sam_model, sam_processor = load_model() if image.mode != "RGB": image = image.convert("RGB") w, h = image.size # Create grid of points points = [] step_x = w // (grid_size + 1) step_y = h // (grid_size + 1) for i in range(1, grid_size + 1): for j in range(1, grid_size + 1): points.append([step_x * i, step_y * j]) all_annotations = [] # Process each point for idx, point in enumerate(points[:9]): # Limit to 9 points for speed input_points = [[point]] input_labels = [[1]] inputs = sam_processor( image, input_points=input_points, input_labels=input_labels, return_tensors="pt" ) if torch.cuda.is_available(): inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = sam_model(**inputs) masks = sam_processor.image_processor.post_process_masks( outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu() ) scores = outputs.iou_scores.cpu().numpy()[0] masks_np = masks[0].numpy() # Get best mask for this point if masks_np.shape[1] > 0: best_idx = scores[0].argmax() mask = masks_np[0, best_idx].astype(np.uint8) if mask.sum() > 100: # Minimum size threshold score = scores[0, best_idx] label = f"Region {idx + 1} (IoU: {score:.2f})" all_annotations.append((mask, label)) if not all_annotations: return (image, []), "❌ No regions found with auto-segmentation." info = f"""✅ **Auto-Segmentation Complete** **Grid Points:** {len(points)} **Regions Found:** {len(all_annotations)} *Automatic discovery of distinct regions in the image*""" return (image, all_annotations), info except Exception as e: return (image, []), f"❌ Error: {str(e)}" def load_sample_image(sample_name: str): """Load a sample neuroimaging image""" if sample_name not in SAMPLE_IMAGES: return None, "Sample not found" try: url = SAMPLE_IMAGES[sample_name] response = requests.get(url, timeout=10) image = Image.open(BytesIO(response.content)).convert("RGB") return image, f"✅ Loaded: {sample_name}" except Exception as e: return None, f"❌ Failed to load sample: {str(e)}" def load_from_hf_dataset(dataset_name: str, index: int = 0): """Load image from Hugging Face dataset""" try: if dataset_name == "Brain Tumor MRI": ds = load_dataset("sartajbhuvaji/brain-tumor-classification", split="train", streaming=True) for i, sample in enumerate(ds): if i == index: image = sample["image"] if image.mode != "RGB": image = image.convert("RGB") return image, f"✅ Loaded from Brain Tumor MRI dataset (index {index})" return None, "Dataset not available" except Exception as e: return None, f"❌ Error loading dataset: {str(e)}" def get_click_point(image, evt: gr.SelectData): """Get point coordinates from image click""" if evt is None: return [], [], "Click on the image to add points" x, y = evt.index return [[x, y]], [1], f"Point added at ({x}, {y})" # Store points for multi-point selection current_points = [] current_labels = [] def add_point(image, evt: gr.SelectData, points_state, labels_state, point_type): """Add a point to the current selection""" if evt is None or image is None: return points_state, labels_state, "Click on image to add points" x, y = evt.index label = 1 if point_type == "Foreground (+)" else 0 points_state = points_state + [[x, y]] labels_state = labels_state + [label] point_info = f"Added {'foreground' if label == 1 else 'background'} point at ({x}, {y})\n" point_info += f"Total points: {len(points_state)}" return points_state, labels_state, point_info def clear_points(): """Clear all selected points""" return [], [], "Points cleared" def clear_all(): """Clear all inputs and outputs""" return None, None, [], [], 0.5, "brain region", "📝 Upload a neuroimaging scan and click to add points for segmentation." # Gradio Interface with gr.Blocks( theme=gr.themes.Soft(), title="NeuroSAM - Neuroimaging Segmentation", css=""" .gradio-container {max-width: 1400px !important;} .neuro-header {background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;} .neuro-header h1 {color: white !important; margin: 0 !important;} .neuro-header p {color: rgba(255,255,255,0.9) !important;} .info-box {background: #e8f4f8; padding: 15px; border-radius: 8px; margin: 10px 0;} """, footer_links=[{"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}] ) as demo: # State for point selection points_state = gr.State([]) labels_state = gr.State([]) gr.HTML( """

🧠 NeuroSAM - Neuroimaging Segmentation

Interactive segmentation using Meta's Segment Anything Model (SAM)

Built with anycoder | Model: facebook/sam-vit-base

""" ) gr.Markdown(""" ### ℹ️ About SAM (Segment Anything Model) **SAM** is a foundation model for image segmentation by Meta AI. Unlike text-based models, SAM uses **visual prompts**: - **Point prompts**: Click on the region you want to segment - **Box prompts**: Draw a bounding box around the region - **Automatic mode**: Discovers all segmentable regions *Note: SAM is a general-purpose segmentation model, not specifically trained on medical images. For clinical use, specialized medical imaging models should be used.* """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### 📤 Input") image_input = gr.Image( label="Neuroimaging Scan (Click to add points)", type="pil", height=400, interactive=True, ) with gr.Accordion("📂 Load Sample Images", open=True): sample_dropdown = gr.Dropdown( label="Sample Neuroimaging Images", choices=list(SAMPLE_IMAGES.keys()), value=None, info="Load publicly available brain imaging examples" ) load_sample_btn = gr.Button("Load Sample", size="sm") gr.Markdown("**Or load from Hugging Face Datasets:**") with gr.Row(): hf_dataset = gr.Dropdown( label="Dataset", choices=["Brain Tumor MRI"], value="Brain Tumor MRI" ) hf_index = gr.Number(label="Image Index", value=0, minimum=0, maximum=100) load_hf_btn = gr.Button("Load from HF", size="sm") gr.Markdown("### 🎯 Segmentation Mode") with gr.Tab("Point Prompt"): gr.Markdown("**Click on the image to add points, then segment**") point_type = gr.Radio( choices=["Foreground (+)", "Background (-)"], value="Foreground (+)", label="Point Type", info="Foreground = include region, Background = exclude region" ) structure_name = gr.Textbox( label="Structure Label", value="brain region", placeholder="e.g., hippocampus, ventricle, tumor...", info="Label for the segmented region" ) points_info = gr.Textbox( label="Selected Points", value="Click on image to add points", interactive=False ) with gr.Row(): clear_points_btn = gr.Button("Clear Points", variant="secondary") segment_points_btn = gr.Button("🎯 Segment", variant="primary") with gr.Tab("Box Prompt"): gr.Markdown("**Define a bounding box around the region**") with gr.Row(): box_x1 = gr.Number(label="X1 (left)", value=50) box_y1 = gr.Number(label="Y1 (top)", value=50) with gr.Row(): box_x2 = gr.Number(label="X2 (right)", value=200) box_y2 = gr.Number(label="Y2 (bottom)", value=200) box_structure = gr.Textbox( label="Structure Label", value="selected region" ) segment_box_btn = gr.Button("🎯 Segment Box", variant="primary") with gr.Tab("Auto Segment"): gr.Markdown("**Automatically discover all segmentable regions**") grid_size = gr.Slider( minimum=2, maximum=5, value=3, step=1, label="Grid Density", info="Higher = more points sampled" ) auto_segment_btn = gr.Button("🔍 Auto-Segment All", variant="primary") clear_btn = gr.Button("🗑️ Clear All", variant="secondary") with gr.Column(scale=1): gr.Markdown("### 📊 Output") image_output = gr.AnnotatedImage( label="Segmented Result", height=450, show_legend=True, ) info_output = gr.Markdown( value="📝 Upload a neuroimaging scan and click to add points for segmentation.", label="Results" ) gr.Markdown("### 📚 Available Datasets on Hugging Face") with gr.Row(): with gr.Column(): gr.Markdown(""" **🧠 Brain/Neuro Imaging** - `sartajbhuvaji/brain-tumor-classification` - Brain MRI with tumor labels - `keremberke/brain-tumor-object-detection` - Brain tumor detection - `TrainingDataPro/brain-mri-dataset` - Brain MRI scans """) with gr.Column(): gr.Markdown(""" **🏥 Medical Imaging** - `alkzar90/NIH-Chest-X-ray-dataset` - Chest X-rays - `marmal88/skin_cancer` - Dermatology images - `hf-vision/chest-xray-pneumonia` - Pneumonia detection """) with gr.Column(): gr.Markdown(""" **🔬 Specialized** - `Francesco/cell-segmentation` - Cell microscopy - `segments/sidewalk-semantic` - Semantic segmentation - `detection-datasets/coco` - General objects """) gr.Markdown(""" ### 💡 How to Use 1. **Load an image**: Upload your own or select from samples/HuggingFace datasets 2. **Choose segmentation mode**: - **Point Prompt**: Click on regions you want to segment (green = include, red = exclude) - **Box Prompt**: Define coordinates for a bounding box - **Auto Segment**: Let SAM discover all distinct regions automatically 3. **View results**: Segmented regions appear with colored overlays ### ⚠️ Important Notes - SAM is a **general-purpose** model, not specifically trained for medical imaging - For clinical applications, use validated medical imaging AI tools - Results should be reviewed by qualified medical professionals """) # Event handlers load_sample_btn.click( fn=load_sample_image, inputs=[sample_dropdown], outputs=[image_input, info_output] ) load_hf_btn.click( fn=load_from_hf_dataset, inputs=[hf_dataset, hf_index], outputs=[image_input, info_output] ) # Point selection on image click image_input.select( fn=add_point, inputs=[image_input, points_state, labels_state, point_type], outputs=[points_state, labels_state, points_info] ) clear_points_btn.click( fn=clear_points, outputs=[points_state, labels_state, points_info] ) segment_points_btn.click( fn=segment_with_points, inputs=[image_input, points_state, labels_state, structure_name], outputs=[image_output, info_output] ) segment_box_btn.click( fn=segment_with_box, inputs=[image_input, box_x1, box_y1, box_x2, box_y2, box_structure], outputs=[image_output, info_output] ) auto_segment_btn.click( fn=auto_segment_grid, inputs=[image_input, grid_size], outputs=[image_output, info_output] ) clear_btn.click( fn=clear_all, outputs=[image_input, image_output, points_state, labels_state, grid_size, structure_name, info_output] ) if __name__ == "__main__": demo.launch() === utils.py === """ Utility functions for neuroimaging preprocessing and analysis """ import numpy as np from PIL import Image def normalize_medical_image(image_array: np.ndarray) -> np.ndarray: """ Normalize medical image intensities to 0-255 range Handles various bit depths common in medical imaging """ img = image_array.astype(np.float32) # Handle different intensity ranges if img.max() > 255: # Likely 12-bit or 16-bit image p1, p99 = np.percentile(img, [1, 99]) img = np.clip(img, p1, p99) # Normalize to 0-255 img_min, img_max = img.min(), img.max() if img_max > img_min: img = (img - img_min) / (img_max - img_min) * 255 return img.astype(np.uint8) def apply_window_level(image_array: np.ndarray, window: float, level: float) -> np.ndarray: """ Apply window/level (contrast/brightness) adjustment Common in CT viewing Args: image_array: Input image window: Window width (contrast) level: Window center (brightness) """ img = image_array.astype(np.float32) min_val = level - window / 2 max_val = level + window / 2 img = np.clip(img, min_val, max_val) img = (img - min_val) / (max_val - min_val) * 255 return img.astype(np.uint8) def enhance_brain_contrast(image: Image.Image) -> Image.Image: """ Enhance contrast specifically for brain MRI visualization """ img_array = np.array(image) # Convert to grayscale if needed if len(img_array.shape) == 3: gray = np.mean(img_array, axis=2) else: gray = img_array # Apply histogram equalization from PIL import ImageOps enhanced = ImageOps.equalize(Image.fromarray(gray.astype(np.uint8))) # Convert back to RGB enhanced_array = np.array(enhanced) rgb_array = np.stack([enhanced_array] * 3, axis=-1) return Image.fromarray(rgb_array) # Common neuroimaging structure mappings STRUCTURE_ALIASES = { "hippocampus": ["hippocampal formation", "hippocampal", "medial temporal"], "ventricle": ["ventricular system", "lateral ventricle", "CSF space"], "white matter": ["WM", "cerebral white matter", "deep white matter"], "gray matter": ["GM", "cortical gray matter", "cortex"], "tumor": ["mass", "lesion", "neoplasm", "growth"], "thalamus": ["thalamic", "diencephalon"], "basal ganglia": ["striatum", "caudate", "putamen", "globus pallidus"], } def get_structure_aliases(structure: str) -> list: """Get alternative names for a neuroanatomical structure""" structure_lower = structure.lower() for key, aliases in STRUCTURE_ALIASES.items(): if structure_lower == key or structure_lower in aliases: return [key] + aliases return [structure] # Hugging Face datasets for neuroimaging HF_NEUROIMAGING_DATASETS = { "brain-tumor-classification": { "repo": "sartajbhuvaji/brain-tumor-classification", "description": "Brain MRI scans classified by tumor type (glioma, meningioma, pituitary, no tumor)", "image_key": "image", "label_key": "label" }, "brain-tumor-detection": { "repo": "keremberke/brain-tumor-object-detection", "description": "Brain MRI with bounding box annotations for tumors", "image_key": "image", "label_key": "objects" }, "chest-xray": { "repo": "alkzar90/NIH-Chest-X-ray-dataset", "description": "Chest X-ray images with disease labels", "image_key": "image", "label_key": "labels" } }