=== app.py ===
import spaces
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamModel, SamProcessor
from datasets import load_dataset
import requests
from io import BytesIO
import warnings
warnings.filterwarnings("ignore")
# Global model and processor - Using SAM (Segment Anything Model)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = None
processor = None
def load_model():
"""Load SAM model lazily"""
global model, processor
if model is None:
print("Loading SAM model...")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base")
if torch.cuda.is_available():
model = model.to(device)
print("Model loaded successfully!")
return model, processor
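# Note: heavier SAM checkpoints ("facebook/sam-vit-large", "facebook/sam-vit-huge")
# also exist on the Hub and trade inference speed for mask quality; swapping the
# model ID above is the only change needed.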
# Public neuroimaging datasets on Hugging Face (reference metadata only; the loader below hardcodes its own repo)
NEUROIMAGING_DATASETS = {
"Brain Tumor MRI": {
"dataset": "sartajbhuvaji/brain-tumor-classification",
"description": "Brain MRI scans with tumor classifications",
"split": "train"
},
"Medical MNIST (Brain)": {
"dataset": "alkzar90/NIH-Chest-X-ray-dataset",
"description": "Medical imaging dataset",
"split": "train"
},
}
# Sample neuroimaging URLs (publicly available brain MRI examples)
SAMPLE_IMAGES = {
"Brain MRI - Axial": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/MRI_of_Human_Brain.jpg/800px-MRI_of_Human_Brain.jpg",
"Brain MRI - Sagittal": "https://upload.wikimedia.org/wikipedia/commons/1/1a/MRI_head_side.jpg",
"Brain CT Scan": "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9a/CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png/800px-CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png",
}
# Neuroimaging label presets (suggested names for the structure-label box; SAM itself takes only visual prompts)
NEURO_PRESETS = {
"Brain Structures": ["brain", "cerebrum", "cerebellum", "brainstem", "corpus callosum"],
"Lobes": ["frontal lobe", "temporal lobe", "parietal lobe", "occipital lobe"],
"Ventricles": ["ventricle", "lateral ventricle", "third ventricle", "fourth ventricle"],
"Gray/White Matter": ["gray matter", "white matter", "cortex", "subcortical"],
"Deep Structures": ["thalamus", "hypothalamus", "hippocampus", "amygdala", "basal ganglia"],
"Lesions/Abnormalities": ["lesion", "tumor", "mass", "abnormality", "hyperintensity"],
"Vascular": ["blood vessel", "artery", "vein", "sinus"],
"Skull/Meninges": ["skull", "bone", "meninges", "dura"],
}
@spaces.GPU()
def segment_with_points(image: Image.Image, points: list, labels: list, structure_name: str):
"""
Perform segmentation using SAM with point prompts.
SAM uses point/box prompts, not text prompts.
"""
if image is None:
return None, "❌ Please upload a neuroimaging scan."
try:
sam_model, sam_processor = load_model()
# Ensure image is RGB
if image.mode != "RGB":
image = image.convert("RGB")
# Prepare inputs with point prompts
if points and len(points) > 0:
input_points = [points] # Shape: (batch, num_points, 2)
input_labels = [labels] # Shape: (batch, num_points)
else:
# Use center point as default
w, h = image.size
input_points = [[[w // 2, h // 2]]]
input_labels = [[1]] # 1 = foreground
inputs = sam_processor(
image,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt"
)
if torch.cuda.is_available():
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = sam_model(**inputs)
# Post-process masks
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores.cpu().numpy()[0]
# Get best mask
masks_np = masks[0].numpy()
if masks_np.shape[0] == 0:
            return (image, []), "❌ No segmentation found for the selected points."
# Format for AnnotatedImage
annotations = []
        for i in range(min(3, masks_np.shape[1])):  # SAM's three mask proposals (not score-sorted)
mask = masks_np[0, i].astype(np.uint8)
if mask.sum() > 0: # Only add non-empty masks
score = scores[0, i] if i < scores.shape[1] else 0.0
label = f"{structure_name} (IoU: {score:.2f})"
annotations.append((mask, label))
if not annotations:
return (image, []), "❌ No valid masks generated."
info = f"""✅ **Segmentation Complete**
**Target:** {structure_name}
**Masks Generated:** {len(annotations)}
**Best IoU Score:** {scores.max():.3f}
*SAM generates multiple mask proposals - showing top results*"""
return (image, annotations), info
except Exception as e:
import traceback
return (image, []), f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
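# Illustrative standalone call (coordinates are made up for the example):
#   overlay, info = segment_with_points(img, [[120, 95]], [1], "hippocampus")
# Points are (x, y) pixel coordinates; labels use 1 = foreground, 0 = background.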
@spaces.GPU()
def segment_with_box(image: Image.Image, x1: int, y1: int, x2: int, y2: int, structure_name: str):
"""Segment using bounding box prompt"""
if image is None:
return None, "❌ Please upload an image."
try:
sam_model, sam_processor = load_model()
if image.mode != "RGB":
image = image.convert("RGB")
# Prepare box prompt
input_boxes = [[[x1, y1, x2, y2]]]
inputs = sam_processor(
image,
input_boxes=input_boxes,
return_tensors="pt"
)
if torch.cuda.is_available():
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = sam_model(**inputs)
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores.cpu().numpy()[0]
masks_np = masks[0].numpy()
annotations = []
for i in range(min(3, masks_np.shape[1])):
mask = masks_np[0, i].astype(np.uint8)
if mask.sum() > 0:
score = scores[0, i] if i < scores.shape[1] else 0.0
label = f"{structure_name} (IoU: {score:.2f})"
annotations.append((mask, label))
if not annotations:
return (image, []), "❌ No valid masks generated from box."
info = f"""✅ **Box Segmentation Complete**
**Target:** {structure_name}
**Box:** ({x1}, {y1}) to ({x2}, {y2})
**Masks Generated:** {len(annotations)}"""
return (image, annotations), info
except Exception as e:
return (image, []), f"❌ Error: {str(e)}"
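# Illustrative standalone call (corner values are made up for the example):
#   overlay, info = segment_with_box(img, 40, 30, 180, 160, "tumor")
# Boxes are (x1, y1, x2, y2) pixel corners; a tight box usually segments more
# cleanly than a loose one.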
@spaces.GPU()
def auto_segment_grid(image: Image.Image, grid_size: int = 4):
"""Automatic segmentation using grid of points"""
if image is None:
return None, "❌ Please upload an image."
try:
sam_model, sam_processor = load_model()
if image.mode != "RGB":
image = image.convert("RGB")
w, h = image.size
# Create grid of points
points = []
step_x = w // (grid_size + 1)
step_y = h // (grid_size + 1)
for i in range(1, grid_size + 1):
for j in range(1, grid_size + 1):
points.append([step_x * i, step_y * j])
all_annotations = []
# Process each point
for idx, point in enumerate(points[:9]): # Limit to 9 points for speed
input_points = [[point]]
input_labels = [[1]]
inputs = sam_processor(
image,
input_points=input_points,
input_labels=input_labels,
return_tensors="pt"
)
if torch.cuda.is_available():
inputs = {k: v.to(device) for k, v in inputs.items()}
with torch.no_grad():
outputs = sam_model(**inputs)
masks = sam_processor.image_processor.post_process_masks(
outputs.pred_masks.cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu()
)
scores = outputs.iou_scores.cpu().numpy()[0]
masks_np = masks[0].numpy()
# Get best mask for this point
if masks_np.shape[1] > 0:
best_idx = scores[0].argmax()
mask = masks_np[0, best_idx].astype(np.uint8)
if mask.sum() > 100: # Minimum size threshold
score = scores[0, best_idx]
label = f"Region {idx + 1} (IoU: {score:.2f})"
all_annotations.append((mask, label))
if not all_annotations:
return (image, []), "❌ No regions found with auto-segmentation."
info = f"""✅ **Auto-Segmentation Complete**
**Grid Points:** {len(points)}
**Regions Found:** {len(all_annotations)}
*Automatic discovery of distinct regions in the image*"""
return (image, all_annotations), info
except Exception as e:
return (image, []), f"❌ Error: {str(e)}"
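# For denser whole-image proposals than the simple grid above, transformers also
# provides a "mask-generation" pipeline that sweeps SAM over a full point grid
# (a heavier alternative, sketched here for reference):
#   from transformers import pipeline
#   generator = pipeline("mask-generation", model="facebook/sam-vit-base")
#   results = generator(image, points_per_batch=64)  # dict with "masks" and "scores"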
def load_sample_image(sample_name: str):
"""Load a sample neuroimaging image"""
if sample_name not in SAMPLE_IMAGES:
return None, "Sample not found"
try:
url = SAMPLE_IMAGES[sample_name]
response = requests.get(url, timeout=10)
image = Image.open(BytesIO(response.content)).convert("RGB")
return image, f"✅ Loaded: {sample_name}"
except Exception as e:
return None, f"❌ Failed to load sample: {str(e)}"
def load_from_hf_dataset(dataset_name: str, index: int = 0):
"""Load image from Hugging Face dataset"""
try:
if dataset_name == "Brain Tumor MRI":
ds = load_dataset("sartajbhuvaji/brain-tumor-classification", split="train", streaming=True)
for i, sample in enumerate(ds):
if i == index:
image = sample["image"]
if image.mode != "RGB":
image = image.convert("RGB")
return image, f"✅ Loaded from Brain Tumor MRI dataset (index {index})"
return None, "Dataset not available"
except Exception as e:
return None, f"❌ Error loading dataset: {str(e)}"
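# Alternative sketch (assumes a full, non-streaming download is acceptable):
# datasets supports direct indexing, which is faster for repeated random access:
#   ds = load_dataset("sartajbhuvaji/brain-tumor-classification", split="train")
#   image = ds[int(index)]["image"]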
def get_click_point(image, evt: gr.SelectData):
"""Get point coordinates from image click"""
if evt is None:
return [], [], "Click on the image to add points"
x, y = evt.index
return [[x, y]], [1], f"Point added at ({x}, {y})"
# Multi-point selection state lives in gr.State components inside the Blocks
# context below; no module-level mutable lists are needed.
def add_point(image, evt: gr.SelectData, points_state, labels_state, point_type):
"""Add a point to the current selection"""
if evt is None or image is None:
return points_state, labels_state, "Click on image to add points"
x, y = evt.index
label = 1 if point_type == "Foreground (+)" else 0
points_state = points_state + [[x, y]]
labels_state = labels_state + [label]
point_info = f"Added {'foreground' if label == 1 else 'background'} point at ({x}, {y})\n"
point_info += f"Total points: {len(points_state)}"
return points_state, labels_state, point_info
def clear_points():
"""Clear all selected points"""
return [], [], "Points cleared"
def clear_all():
"""Clear all inputs and outputs"""
    return None, None, [], [], 3, "brain region", "📝 Upload a neuroimaging scan and click to add points for segmentation."
# Gradio Interface
with gr.Blocks(
theme=gr.themes.Soft(),
title="NeuroSAM - Neuroimaging Segmentation",
css="""
.gradio-container {max-width: 1400px !important;}
.neuro-header {background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;}
.neuro-header h1 {color: white !important; margin: 0 !important;}
.neuro-header p {color: rgba(255,255,255,0.9) !important;}
.info-box {background: #e8f4f8; padding: 15px; border-radius: 8px; margin: 10px 0;}
""",
footer_links=[{"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}]
) as demo:
# State for point selection
points_state = gr.State([])
labels_state = gr.State([])
gr.HTML(
"""
<div class="neuro-header">
<h1>🧠 NeuroSAM - Neuroimaging Segmentation</h1>
<p>Interactive segmentation using Meta's Segment Anything Model (SAM)</p>
<p style="font-size: 0.9em;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" style="color: #FFD700;">anycoder</a> | Model: facebook/sam-vit-base</p>
</div>
"""
)
gr.Markdown("""
### ℹ️ About SAM (Segment Anything Model)
**SAM** is a foundation model for image segmentation by Meta AI. Unlike text-based models, SAM uses **visual prompts**:
- **Point prompts**: Click on the region you want to segment
- **Box prompts**: Draw a bounding box around the region
- **Automatic mode**: Discovers all segmentable regions
*Note: SAM is a general-purpose segmentation model, not specifically trained on medical images. For clinical use, specialized medical imaging models should be used.*
""")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### 📤 Input")
image_input = gr.Image(
label="Neuroimaging Scan (Click to add points)",
type="pil",
height=400,
interactive=True,
)
with gr.Accordion("📂 Load Sample Images", open=True):
sample_dropdown = gr.Dropdown(
label="Sample Neuroimaging Images",
choices=list(SAMPLE_IMAGES.keys()),
value=None,
info="Load publicly available brain imaging examples"
)
load_sample_btn = gr.Button("Load Sample", size="sm")
gr.Markdown("**Or load from Hugging Face Datasets:**")
with gr.Row():
hf_dataset = gr.Dropdown(
label="Dataset",
choices=["Brain Tumor MRI"],
value="Brain Tumor MRI"
)
                    hf_index = gr.Number(label="Image Index", value=0, minimum=0, maximum=100, precision=0)
load_hf_btn = gr.Button("Load from HF", size="sm")
gr.Markdown("### 🎯 Segmentation Mode")
with gr.Tab("Point Prompt"):
gr.Markdown("**Click on the image to add points, then segment**")
point_type = gr.Radio(
choices=["Foreground (+)", "Background (-)"],
value="Foreground (+)",
label="Point Type",
info="Foreground = include region, Background = exclude region"
)
structure_name = gr.Textbox(
label="Structure Label",
value="brain region",
placeholder="e.g., hippocampus, ventricle, tumor...",
info="Label for the segmented region"
)
points_info = gr.Textbox(
label="Selected Points",
value="Click on image to add points",
interactive=False
)
with gr.Row():
clear_points_btn = gr.Button("Clear Points", variant="secondary")
segment_points_btn = gr.Button("🎯 Segment", variant="primary")
with gr.Tab("Box Prompt"):
gr.Markdown("**Define a bounding box around the region**")
with gr.Row():
box_x1 = gr.Number(label="X1 (left)", value=50)
box_y1 = gr.Number(label="Y1 (top)", value=50)
with gr.Row():
box_x2 = gr.Number(label="X2 (right)", value=200)
box_y2 = gr.Number(label="Y2 (bottom)", value=200)
box_structure = gr.Textbox(
label="Structure Label",
value="selected region"
)
segment_box_btn = gr.Button("🎯 Segment Box", variant="primary")
with gr.Tab("Auto Segment"):
gr.Markdown("**Automatically discover all segmentable regions**")
grid_size = gr.Slider(
minimum=2,
maximum=5,
value=3,
step=1,
label="Grid Density",
info="Higher = more points sampled"
)
auto_segment_btn = gr.Button("🔍 Auto-Segment All", variant="primary")
clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
with gr.Column(scale=1):
gr.Markdown("### 📊 Output")
image_output = gr.AnnotatedImage(
label="Segmented Result",
height=450,
show_legend=True,
)
info_output = gr.Markdown(
value="📝 Upload a neuroimaging scan and click to add points for segmentation.",
label="Results"
)
gr.Markdown("### 📚 Available Datasets on Hugging Face")
with gr.Row():
with gr.Column():
gr.Markdown("""
**🧠 Brain/Neuro Imaging**
- `sartajbhuvaji/brain-tumor-classification` - Brain MRI with tumor labels
- `keremberke/brain-tumor-object-detection` - Brain tumor detection
- `TrainingDataPro/brain-mri-dataset` - Brain MRI scans
""")
with gr.Column():
gr.Markdown("""
**🏥 Medical Imaging**
- `alkzar90/NIH-Chest-X-ray-dataset` - Chest X-rays
- `marmal88/skin_cancer` - Dermatology images
- `hf-vision/chest-xray-pneumonia` - Pneumonia detection
""")
with gr.Column():
gr.Markdown("""
**🔬 Specialized**
- `Francesco/cell-segmentation` - Cell microscopy
- `segments/sidewalk-semantic` - Semantic segmentation
- `detection-datasets/coco` - General objects
""")
gr.Markdown("""
### 💡 How to Use
1. **Load an image**: Upload your own or select from samples/HuggingFace datasets
2. **Choose segmentation mode**:
- **Point Prompt**: Click on regions to include (Foreground points) or exclude (Background points)
- **Box Prompt**: Define coordinates for a bounding box
- **Auto Segment**: Let SAM discover all distinct regions automatically
3. **View results**: Segmented regions appear with colored overlays
### ⚠️ Important Notes
- SAM is a **general-purpose** model, not specifically trained for medical imaging
- For clinical applications, use validated medical imaging AI tools
- Results should be reviewed by qualified medical professionals
""")
# Event handlers
load_sample_btn.click(
fn=load_sample_image,
inputs=[sample_dropdown],
outputs=[image_input, info_output]
)
load_hf_btn.click(
fn=load_from_hf_dataset,
inputs=[hf_dataset, hf_index],
outputs=[image_input, info_output]
)
# Point selection on image click
image_input.select(
fn=add_point,
inputs=[image_input, points_state, labels_state, point_type],
outputs=[points_state, labels_state, points_info]
)
clear_points_btn.click(
fn=clear_points,
outputs=[points_state, labels_state, points_info]
)
segment_points_btn.click(
fn=segment_with_points,
inputs=[image_input, points_state, labels_state, structure_name],
outputs=[image_output, info_output]
)
segment_box_btn.click(
fn=segment_with_box,
inputs=[image_input, box_x1, box_y1, box_x2, box_y2, box_structure],
outputs=[image_output, info_output]
)
auto_segment_btn.click(
fn=auto_segment_grid,
inputs=[image_input, grid_size],
outputs=[image_output, info_output]
)
clear_btn.click(
fn=clear_all,
outputs=[image_input, image_output, points_state, labels_state, grid_size, structure_name, info_output]
)
if __name__ == "__main__":
demo.launch()
=== utils.py ===
"""
Utility functions for neuroimaging preprocessing and analysis
"""
import numpy as np
from PIL import Image, ImageOps
def normalize_medical_image(image_array: np.ndarray) -> np.ndarray:
"""
Normalize medical image intensities to 0-255 range
Handles various bit depths common in medical imaging
"""
img = image_array.astype(np.float32)
# Handle different intensity ranges
if img.max() > 255:
# Likely 12-bit or 16-bit image
p1, p99 = np.percentile(img, [1, 99])
img = np.clip(img, p1, p99)
# Normalize to 0-255
img_min, img_max = img.min(), img.max()
if img_max > img_min:
img = (img - img_min) / (img_max - img_min) * 255
return img.astype(np.uint8)
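# Example (illustrative variable names): a 16-bit DICOM-derived pixel array
# becomes a display-ready uint8 image:
#   slice_u8 = normalize_medical_image(dicom_pixels)
# The 1st/99th percentile clip above suppresses outlier hot pixels before rescaling.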
def apply_window_level(image_array: np.ndarray, window: float, level: float) -> np.ndarray:
"""
Apply window/level (contrast/brightness) adjustment
Common in CT viewing
Args:
image_array: Input image
window: Window width (contrast)
level: Window center (brightness)
"""
    img = image_array.astype(np.float32)
    window = max(float(window), 1e-6)  # guard against a zero-width window (division by zero)
    min_val = level - window / 2
max_val = level + window / 2
img = np.clip(img, min_val, max_val)
img = (img - min_val) / (max_val - min_val) * 255
return img.astype(np.uint8)
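# Illustrative preset (conventional radiology defaults, not defined in this repo):
# brain CT is commonly viewed at roughly window=80, level=40 Hounsfield units:
#   brain_view = apply_window_level(ct_slice, window=80, level=40)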
def enhance_brain_contrast(image: Image.Image) -> Image.Image:
"""
Enhance contrast specifically for brain MRI visualization
"""
img_array = np.array(image)
# Convert to grayscale if needed
if len(img_array.shape) == 3:
gray = np.mean(img_array, axis=2)
else:
gray = img_array
    # Apply histogram equalization (ImageOps is imported at module top)
    enhanced = ImageOps.equalize(Image.fromarray(gray.astype(np.uint8)))
# Convert back to RGB
enhanced_array = np.array(enhanced)
rgb_array = np.stack([enhanced_array] * 3, axis=-1)
return Image.fromarray(rgb_array)
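# Example (illustrative): pre-enhance a low-contrast scan before prompting SAM:
#   enhanced = enhance_brain_contrast(Image.open("scan.png"))
# Histogram equalization can amplify noise, so treat it as a viewing aid only.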
# Common neuroimaging structure mappings
STRUCTURE_ALIASES = {
"hippocampus": ["hippocampal formation", "hippocampal", "medial temporal"],
"ventricle": ["ventricular system", "lateral ventricle", "CSF space"],
"white matter": ["WM", "cerebral white matter", "deep white matter"],
"gray matter": ["GM", "cortical gray matter", "cortex"],
"tumor": ["mass", "lesion", "neoplasm", "growth"],
"thalamus": ["thalamic", "diencephalon"],
"basal ganglia": ["striatum", "caudate", "putamen", "globus pallidus"],
}
def get_structure_aliases(structure: str) -> list:
"""Get alternative names for a neuroanatomical structure"""
    structure_lower = structure.lower()
    for key, aliases in STRUCTURE_ALIASES.items():
        # Compare case-insensitively; aliases such as "WM" are stored uppercase
        if structure_lower == key or structure_lower in (a.lower() for a in aliases):
return [key] + aliases
return [structure]
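# Example: with the case-insensitive match above,
#   get_structure_aliases("WM")     -> ["white matter", "WM", "cerebral white matter", "deep white matter"]
#   get_structure_aliases("pineal") -> ["pineal"]  (unknown names pass through)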
# Hugging Face datasets for neuroimaging
HF_NEUROIMAGING_DATASETS = {
"brain-tumor-classification": {
"repo": "sartajbhuvaji/brain-tumor-classification",
"description": "Brain MRI scans classified by tumor type (glioma, meningioma, pituitary, no tumor)",
"image_key": "image",
"label_key": "label"
},
"brain-tumor-detection": {
"repo": "keremberke/brain-tumor-object-detection",
"description": "Brain MRI with bounding box annotations for tumors",
"image_key": "image",
"label_key": "objects"
},
"chest-xray": {
"repo": "alkzar90/NIH-Chest-X-ray-dataset",
"description": "Chest X-ray images with disease labels",
"image_key": "image",
"label_key": "labels"
}
}
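# Minimal self-check on synthetic data (illustrative; not exercised by the app).
if __name__ == "__main__":
    # Simulate a 12-bit scan and confirm each helper yields a displayable result.
    fake_scan = (np.random.rand(64, 64) * 4095).astype(np.uint16)
    norm = normalize_medical_image(fake_scan)
    assert norm.dtype == np.uint8 and norm.max() <= 255
    windowed = apply_window_level(fake_scan.astype(np.float32), window=400, level=1000)
    assert windowed.dtype == np.uint8
    rgb = enhance_brain_contrast(Image.fromarray(norm))
    assert rgb.mode == "RGB"
    print("utils.py self-check passed:", get_structure_aliases("hippocampus"))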