mmrech committed
Commit 764838e · verified · 1 Parent(s): f0ffcc0

Update app.py

Files changed (1)
  1. app.py +690 -2
app.py CHANGED
@@ -1,3 +1,691 @@
- It looks like the existing Gradio 6 app is complete and already follows the required syntax (including the `footer_links` and `api_visibility` usage). Since your request was simply to “run it,” there’s no additional modification needed in the code itself—just execute `python app.py` in your environment to launch the NeuroSAM application.

- Let me know if you’d like any adjustments or new features added!
+
+ # === app.py ===
+ import spaces
+ import gradio as gr
+ import torch
+ import numpy as np
+ from PIL import Image
+ from transformers import SamModel, SamProcessor
+ from datasets import load_dataset
+ import requests
+ from io import BytesIO
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Global model and processor - Using SAM (Segment Anything Model)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = None
+ processor = None
+
+ def load_model():
+     """Load SAM model lazily"""
+     global model, processor
+     if model is None:
+         print("Loading SAM model...")
+         processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+         model = SamModel.from_pretrained("facebook/sam-vit-base")
+         if torch.cuda.is_available():
+             model = model.to(device)
+         print("Model loaded successfully!")
+     return model, processor
+
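+ # NOTE: with the `spaces` library (ZeroGPU), functions decorated with
+ # @spaces.GPU() are given a GPU on demand, so deferring the weight download
+ # to the first request (rather than import time) keeps Space startup light.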
+ # Public neuroimaging datasets on Hugging Face
+ NEUROIMAGING_DATASETS = {
+     "Brain Tumor MRI": {
+         "dataset": "sartajbhuvaji/brain-tumor-classification",
+         "description": "Brain MRI scans with tumor classifications",
+         "split": "train"
+     },
+     "Medical MNIST (Brain)": {
+         "dataset": "alkzar90/NIH-Chest-X-ray-dataset",
+         "description": "Medical imaging dataset",
+         "split": "train"
+     },
+ }
+
+ # Sample neuroimaging URLs (publicly available brain MRI examples)
+ SAMPLE_IMAGES = {
+     "Brain MRI - Axial": "https://upload.wikimedia.org/wikipedia/commons/thumb/5/5e/MRI_of_Human_Brain.jpg/800px-MRI_of_Human_Brain.jpg",
+     "Brain MRI - Sagittal": "https://upload.wikimedia.org/wikipedia/commons/1/1a/MRI_head_side.jpg",
+     "Brain CT Scan": "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9a/CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png/800px-CT_of_brain_of_Mikael_H%C3%A4ggstr%C3%B6m_%28montage%29.png",
+ }
+
+ # Neuroimaging-specific prompts and presets
+ NEURO_PRESETS = {
+     "Brain Structures": ["brain", "cerebrum", "cerebellum", "brainstem", "corpus callosum"],
+     "Lobes": ["frontal lobe", "temporal lobe", "parietal lobe", "occipital lobe"],
+     "Ventricles": ["ventricle", "lateral ventricle", "third ventricle", "fourth ventricle"],
+     "Gray/White Matter": ["gray matter", "white matter", "cortex", "subcortical"],
+     "Deep Structures": ["thalamus", "hypothalamus", "hippocampus", "amygdala", "basal ganglia"],
+     "Lesions/Abnormalities": ["lesion", "tumor", "mass", "abnormality", "hyperintensity"],
+     "Vascular": ["blood vessel", "artery", "vein", "sinus"],
+     "Skull/Meninges": ["skull", "bone", "meninges", "dura"],
+ }
+
+ @spaces.GPU()
+ def segment_with_points(image: Image.Image, points: list, labels: list, structure_name: str):
+     """
+     Perform segmentation using SAM with point prompts.
+     SAM uses point/box prompts, not text prompts.
+     """
+     if image is None:
+         return None, "❌ Please upload a neuroimaging scan."
+
+     try:
+         sam_model, sam_processor = load_model()
+
+         # Ensure image is RGB
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         # Prepare inputs with point prompts
+         if points and len(points) > 0:
+             input_points = [points]  # Shape: (batch, num_points, 2)
+             input_labels = [labels]  # Shape: (batch, num_points)
+         else:
+             # Use center point as default
+             w, h = image.size
+             input_points = [[[w // 2, h // 2]]]
+             input_labels = [[1]]  # 1 = foreground
+
+         inputs = sam_processor(
+             image,
+             input_points=input_points,
+             input_labels=input_labels,
+             return_tensors="pt"
+         )
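+         # The processor resizes the image to SAM's fixed 1024-px input and
+         # rescales the point coordinates to match; it also returns
+         # original_sizes and reshaped_input_sizes, which post_process_masks
+         # uses to upsample the low-resolution mask logits back to image size.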
+
+         if torch.cuda.is_available():
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = sam_model(**inputs)
+
+         # Post-process masks
+         masks = sam_processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu()
+         )
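+         # With multimask_output left at its default (True), SAM proposes three
+         # candidate masks per prompt, each with a predicted IoU score in
+         # outputs.iou_scores; the loop below keeps the non-empty ones.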
+
+         scores = outputs.iou_scores.cpu().numpy()[0]
+
+         # Get best mask
+         masks_np = masks[0].numpy()
+
+         if masks_np.shape[0] == 0:
+             return (image, []), "❌ No segmentation found for the selected points."
+
+         # Format for AnnotatedImage
+         annotations = []
+         for i in range(min(3, masks_np.shape[1])):  # Top 3 masks
+             mask = masks_np[0, i].astype(np.uint8)
+             if mask.sum() > 0:  # Only add non-empty masks
+                 score = scores[0, i] if i < scores.shape[1] else 0.0
+                 label = f"{structure_name} (IoU: {score:.2f})"
+                 annotations.append((mask, label))
+
+         if not annotations:
+             return (image, []), "❌ No valid masks generated."
+
+         info = f"""✅ **Segmentation Complete**
+
+ **Target:** {structure_name}
+ **Masks Generated:** {len(annotations)}
+ **Best IoU Score:** {scores.max():.3f}
+
+ *SAM generates multiple mask proposals - showing top results*"""
+
+         return (image, annotations), info
+
+     except Exception as e:
+         import traceback
+         return (image, []), f"❌ Error: {str(e)}\n\n{traceback.format_exc()}"
+
+ @spaces.GPU()
+ def segment_with_box(image: Image.Image, x1: int, y1: int, x2: int, y2: int, structure_name: str):
+     """Segment using bounding box prompt"""
+     if image is None:
+         return None, "❌ Please upload an image."
+
+     try:
+         sam_model, sam_processor = load_model()
+
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         # Prepare box prompt
+         input_boxes = [[[x1, y1, x2, y2]]]
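+         # Boxes are nested as (batch, boxes_per_image, 4) in [x1, y1, x2, y2]
+         # pixel coordinates; the processor rescales them with the image.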
+
+         inputs = sam_processor(
+             image,
+             input_boxes=input_boxes,
+             return_tensors="pt"
+         )
+
+         if torch.cuda.is_available():
+             inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         with torch.no_grad():
+             outputs = sam_model(**inputs)
+
+         masks = sam_processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu()
+         )
+
+         scores = outputs.iou_scores.cpu().numpy()[0]
+         masks_np = masks[0].numpy()
+
+         annotations = []
+         for i in range(min(3, masks_np.shape[1])):
+             mask = masks_np[0, i].astype(np.uint8)
+             if mask.sum() > 0:
+                 score = scores[0, i] if i < scores.shape[1] else 0.0
+                 label = f"{structure_name} (IoU: {score:.2f})"
+                 annotations.append((mask, label))
+
+         if not annotations:
+             return (image, []), "❌ No valid masks generated from box."
+
+         info = f"""✅ **Box Segmentation Complete**
+
+ **Target:** {structure_name}
+ **Box:** ({x1}, {y1}) to ({x2}, {y2})
+ **Masks Generated:** {len(annotations)}"""
+
+         return (image, annotations), info
+
+     except Exception as e:
+         return (image, []), f"❌ Error: {str(e)}"
+
+ @spaces.GPU()
+ def auto_segment_grid(image: Image.Image, grid_size: int = 4):
+     """Automatic segmentation using grid of points"""
+     if image is None:
+         return None, "❌ Please upload an image."
+
+     try:
+         sam_model, sam_processor = load_model()
+
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         w, h = image.size
+         grid_size = int(grid_size)  # slider values may arrive as floats; range() needs ints
+
+         # Create grid of points
+         points = []
+         step_x = w // (grid_size + 1)
+         step_y = h // (grid_size + 1)
+
+         for i in range(1, grid_size + 1):
+             for j in range(1, grid_size + 1):
+                 points.append([step_x * i, step_y * j])
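+         # A lightweight stand-in for SAM's automatic mask generator: sample an
+         # evenly spaced point grid and keep the best-scoring mask per point,
+         # without the cross-point deduplication the full generator performs.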
+
+         all_annotations = []
+
+         # Process each point
+         for idx, point in enumerate(points[:9]):  # Limit to 9 points for speed
+             input_points = [[point]]
+             input_labels = [[1]]
+
+             inputs = sam_processor(
+                 image,
+                 input_points=input_points,
+                 input_labels=input_labels,
+                 return_tensors="pt"
+             )
+
+             if torch.cuda.is_available():
+                 inputs = {k: v.to(device) for k, v in inputs.items()}
+
+             with torch.no_grad():
+                 outputs = sam_model(**inputs)
+
+             masks = sam_processor.image_processor.post_process_masks(
+                 outputs.pred_masks.cpu(),
+                 inputs["original_sizes"].cpu(),
+                 inputs["reshaped_input_sizes"].cpu()
+             )
+
+             scores = outputs.iou_scores.cpu().numpy()[0]
+             masks_np = masks[0].numpy()
+
+             # Get best mask for this point
+             if masks_np.shape[1] > 0:
+                 best_idx = scores[0].argmax()
+                 mask = masks_np[0, best_idx].astype(np.uint8)
+                 if mask.sum() > 100:  # Minimum size threshold
+                     score = scores[0, best_idx]
+                     label = f"Region {idx + 1} (IoU: {score:.2f})"
+                     all_annotations.append((mask, label))
+
+         if not all_annotations:
+             return (image, []), "❌ No regions found with auto-segmentation."
+
+         info = f"""✅ **Auto-Segmentation Complete**
+
+ **Grid Points:** {len(points)}
+ **Regions Found:** {len(all_annotations)}
+
+ *Automatic discovery of distinct regions in the image*"""
+
+         return (image, all_annotations), info
+
+     except Exception as e:
+         return (image, []), f"❌ Error: {str(e)}"
+
+ def load_sample_image(sample_name: str):
+     """Load a sample neuroimaging image"""
+     if sample_name not in SAMPLE_IMAGES:
+         return None, "Sample not found"
+
+     try:
+         url = SAMPLE_IMAGES[sample_name]
+         response = requests.get(url, timeout=10)
+         response.raise_for_status()  # surface HTTP errors instead of a confusing PIL failure
+         image = Image.open(BytesIO(response.content)).convert("RGB")
+         return image, f"✅ Loaded: {sample_name}"
+     except Exception as e:
+         return None, f"❌ Failed to load sample: {str(e)}"
+
+ def load_from_hf_dataset(dataset_name: str, index: int = 0):
+     """Load image from Hugging Face dataset"""
+     try:
+         if dataset_name == "Brain Tumor MRI":
+             ds = load_dataset("sartajbhuvaji/brain-tumor-classification", split="train", streaming=True)
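+             # streaming=True iterates records lazily without downloading the
+             # whole dataset, which is enough to fetch a single example.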
+             for i, sample in enumerate(ds):
+                 if i == int(index):  # gr.Number passes floats; compare as int
+                     image = sample["image"]
+                     if image.mode != "RGB":
+                         image = image.convert("RGB")
+                     return image, f"✅ Loaded from Brain Tumor MRI dataset (index {index})"
+         return None, "Dataset not available"
+     except Exception as e:
+         return None, f"❌ Error loading dataset: {str(e)}"
+
+ def get_click_point(image, evt: gr.SelectData):
+     """Get point coordinates from image click"""
+     if evt is None:
+         return [], [], "Click on the image to add points"
+
+     x, y = evt.index
+     return [[x, y]], [1], f"Point added at ({x}, {y})"
+
+ # Store points for multi-point selection
+ current_points = []
+ current_labels = []
+
+ def add_point(image, evt: gr.SelectData, points_state, labels_state, point_type):
+     """Add a point to the current selection"""
+     if evt is None or image is None:
+         return points_state, labels_state, "Click on image to add points"
+
+     x, y = evt.index
+     label = 1 if point_type == "Foreground (+)" else 0
+
+     points_state = points_state + [[x, y]]
+     labels_state = labels_state + [label]
+
+     point_info = f"Added {'foreground' if label == 1 else 'background'} point at ({x}, {y})\n"
+     point_info += f"Total points: {len(points_state)}"
+
+     return points_state, labels_state, point_info
+
+ def clear_points():
+     """Clear all selected points"""
+     return [], [], "Points cleared"
+
+ def clear_all():
+     """Clear all inputs and outputs"""
+     # grid_size slider expects an integer in [2, 5]; reset to its default of 3
+     return None, None, [], [], 3, "brain region", "📝 Upload a neuroimaging scan and click to add points for segmentation."
+
+ # Gradio Interface
+ with gr.Blocks(
+     theme=gr.themes.Soft(),
+     title="NeuroSAM - Neuroimaging Segmentation",
+     css="""
+     .gradio-container {max-width: 1400px !important;}
+     .neuro-header {background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; margin-bottom: 20px;}
+     .neuro-header h1 {color: white !important; margin: 0 !important;}
+     .neuro-header p {color: rgba(255,255,255,0.9) !important;}
+     .info-box {background: #e8f4f8; padding: 15px; border-radius: 8px; margin: 10px 0;}
+     """,
+     footer_links=[{"label": "Built with anycoder", "url": "https://huggingface.co/spaces/akhaliq/anycoder"}]
+ ) as demo:
+
+     # State for point selection
+     points_state = gr.State([])
+     labels_state = gr.State([])
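+     # gr.State values are per-session, so concurrent visitors each get their
+     # own click-point list instead of sharing the module-level globals above.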
+
+     gr.HTML(
+         """
+         <div class="neuro-header">
+             <h1>🧠 NeuroSAM - Neuroimaging Segmentation</h1>
+             <p>Interactive segmentation using Meta's Segment Anything Model (SAM)</p>
+             <p style="font-size: 0.9em;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" style="color: #FFD700;">anycoder</a> | Model: facebook/sam-vit-base</p>
+         </div>
+         """
+     )
+
+     gr.Markdown("""
+     ### ℹ️ About SAM (Segment Anything Model)
+
+     **SAM** is a foundation model for image segmentation by Meta AI. Unlike text-based models, SAM uses **visual prompts**:
+     - **Point prompts**: Click on the region you want to segment
+     - **Box prompts**: Draw a bounding box around the region
+     - **Automatic mode**: Discovers all segmentable regions
+
+     *Note: SAM is a general-purpose segmentation model, not specifically trained on medical images. For clinical use, specialized medical imaging models should be used.*
+     """)
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 📤 Input")
+
+             image_input = gr.Image(
+                 label="Neuroimaging Scan (Click to add points)",
+                 type="pil",
+                 height=400,
+                 interactive=True,
+             )
+
+             with gr.Accordion("📂 Load Sample Images", open=True):
+                 sample_dropdown = gr.Dropdown(
+                     label="Sample Neuroimaging Images",
+                     choices=list(SAMPLE_IMAGES.keys()),
+                     value=None,
+                     info="Load publicly available brain imaging examples"
+                 )
+                 load_sample_btn = gr.Button("Load Sample", size="sm")
+
+                 gr.Markdown("**Or load from Hugging Face Datasets:**")
+                 with gr.Row():
+                     hf_dataset = gr.Dropdown(
+                         label="Dataset",
+                         choices=["Brain Tumor MRI"],
+                         value="Brain Tumor MRI"
+                     )
+                     hf_index = gr.Number(label="Image Index", value=0, minimum=0, maximum=100)
+                     load_hf_btn = gr.Button("Load from HF", size="sm")
+
+             gr.Markdown("### 🎯 Segmentation Mode")
+
+             with gr.Tab("Point Prompt"):
+                 gr.Markdown("**Click on the image to add points, then segment**")
+
+                 point_type = gr.Radio(
+                     choices=["Foreground (+)", "Background (-)"],
+                     value="Foreground (+)",
+                     label="Point Type",
+                     info="Foreground = include region, Background = exclude region"
+                 )
+
+                 structure_name = gr.Textbox(
+                     label="Structure Label",
+                     value="brain region",
+                     placeholder="e.g., hippocampus, ventricle, tumor...",
+                     info="Label for the segmented region"
+                 )
+
+                 points_info = gr.Textbox(
+                     label="Selected Points",
+                     value="Click on image to add points",
+                     interactive=False
+                 )
+
+                 with gr.Row():
+                     clear_points_btn = gr.Button("Clear Points", variant="secondary")
+                     segment_points_btn = gr.Button("🎯 Segment", variant="primary")
+
+             with gr.Tab("Box Prompt"):
+                 gr.Markdown("**Define a bounding box around the region**")
+
+                 with gr.Row():
+                     box_x1 = gr.Number(label="X1 (left)", value=50)
+                     box_y1 = gr.Number(label="Y1 (top)", value=50)
+                 with gr.Row():
+                     box_x2 = gr.Number(label="X2 (right)", value=200)
+                     box_y2 = gr.Number(label="Y2 (bottom)", value=200)
+
+                 box_structure = gr.Textbox(
+                     label="Structure Label",
+                     value="selected region"
+                 )
+
+                 segment_box_btn = gr.Button("🎯 Segment Box", variant="primary")
+
+             with gr.Tab("Auto Segment"):
+                 gr.Markdown("**Automatically discover all segmentable regions**")
+
+                 grid_size = gr.Slider(
+                     minimum=2,
+                     maximum=5,
+                     value=3,
+                     step=1,
+                     label="Grid Density",
+                     info="Higher = more points sampled"
+                 )
+
+                 auto_segment_btn = gr.Button("🔍 Auto-Segment All", variant="primary")
+
+             clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 📊 Output")
+
+             image_output = gr.AnnotatedImage(
+                 label="Segmented Result",
+                 height=450,
+                 show_legend=True,
+             )
+
+             info_output = gr.Markdown(
+                 value="📝 Upload a neuroimaging scan and click to add points for segmentation.",
+                 label="Results"
+             )
+
+     gr.Markdown("### 📚 Available Datasets on Hugging Face")
+
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("""
+             **🧠 Brain/Neuro Imaging**
+             - `sartajbhuvaji/brain-tumor-classification` - Brain MRI with tumor labels
+             - `keremberke/brain-tumor-object-detection` - Brain tumor detection
+             - `TrainingDataPro/brain-mri-dataset` - Brain MRI scans
+             """)
+         with gr.Column():
+             gr.Markdown("""
+             **🏥 Medical Imaging**
+             - `alkzar90/NIH-Chest-X-ray-dataset` - Chest X-rays
+             - `marmal88/skin_cancer` - Dermatology images
+             - `hf-vision/chest-xray-pneumonia` - Pneumonia detection
+             """)
+         with gr.Column():
+             gr.Markdown("""
+             **🔬 Specialized**
+             - `Francesco/cell-segmentation` - Cell microscopy
+             - `segments/sidewalk-semantic` - Semantic segmentation
+             - `detection-datasets/coco` - General objects
+             """)
+
+     gr.Markdown("""
+     ### 💡 How to Use
+
+     1. **Load an image**: Upload your own or select from samples/HuggingFace datasets
+     2. **Choose segmentation mode**:
+        - **Point Prompt**: Click on regions you want to segment (green = include, red = exclude)
+        - **Box Prompt**: Define coordinates for a bounding box
+        - **Auto Segment**: Let SAM discover all distinct regions automatically
+     3. **View results**: Segmented regions appear with colored overlays
+
+     ### ⚠️ Important Notes
+
+     - SAM is a **general-purpose** model, not specifically trained for medical imaging
+     - For clinical applications, use validated medical imaging AI tools
+     - Results should be reviewed by qualified medical professionals
+     """)
+
+     # Event handlers
+     load_sample_btn.click(
+         fn=load_sample_image,
+         inputs=[sample_dropdown],
+         outputs=[image_input, info_output]
+     )
+
+     load_hf_btn.click(
+         fn=load_from_hf_dataset,
+         inputs=[hf_dataset, hf_index],
+         outputs=[image_input, info_output]
+     )
+
+     # Point selection on image click
+     image_input.select(
+         fn=add_point,
+         inputs=[image_input, points_state, labels_state, point_type],
+         outputs=[points_state, labels_state, points_info]
+     )
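+     # .select fires on image clicks with a gr.SelectData whose .index holds the
+     # (x, y) pixel coordinates that add_point appends to the session state.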
+
+     clear_points_btn.click(
+         fn=clear_points,
+         outputs=[points_state, labels_state, points_info]
+     )
+
+     segment_points_btn.click(
+         fn=segment_with_points,
+         inputs=[image_input, points_state, labels_state, structure_name],
+         outputs=[image_output, info_output]
+     )
+
+     segment_box_btn.click(
+         fn=segment_with_box,
+         inputs=[image_input, box_x1, box_y1, box_x2, box_y2, box_structure],
+         outputs=[image_output, info_output]
+     )
+
+     auto_segment_btn.click(
+         fn=auto_segment_grid,
+         inputs=[image_input, grid_size],
+         outputs=[image_output, info_output]
+     )
+
+     clear_btn.click(
+         fn=clear_all,
+         outputs=[image_input, image_output, points_state, labels_state, grid_size, structure_name, info_output]
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
+
+ # === utils.py ===
+ """
+ Utility functions for neuroimaging preprocessing and analysis
+ """
+ import numpy as np
+ from PIL import Image, ImageOps
+
+ def normalize_medical_image(image_array: np.ndarray) -> np.ndarray:
+     """
+     Normalize medical image intensities to 0-255 range
+     Handles various bit depths common in medical imaging
+     """
+     img = image_array.astype(np.float32)
+
+     # Handle different intensity ranges
+     if img.max() > 255:
+         # Likely 12-bit or 16-bit image
+         p1, p99 = np.percentile(img, [1, 99])
+         img = np.clip(img, p1, p99)
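+         # Percentile clipping discards extreme outlier intensities (e.g. metal
+         # or scanner artifacts) so they don't compress the useful dynamic range.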
+
+     # Normalize to 0-255
+     img_min, img_max = img.min(), img.max()
+     if img_max > img_min:
+         img = (img - img_min) / (img_max - img_min) * 255
+
+     return img.astype(np.uint8)
+
+ def apply_window_level(image_array: np.ndarray, window: float, level: float) -> np.ndarray:
+     """
+     Apply window/level (contrast/brightness) adjustment
+     Common in CT viewing
+
+     Args:
+         image_array: Input image
+         window: Window width (contrast)
+         level: Window center (brightness)
+     """
+     img = image_array.astype(np.float32)
+
+     min_val = level - window / 2
+     max_val = level + window / 2
+
+     img = np.clip(img, min_val, max_val)
+     img = (img - min_val) / (max_val - min_val) * 255
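+     # Example: a typical brain CT window of width 80 and level 40 maps
+     # attenuation values from 0 to 80 HU linearly onto display values 0-255.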
+
+     return img.astype(np.uint8)
+
+ def enhance_brain_contrast(image: Image.Image) -> Image.Image:
+     """
+     Enhance contrast specifically for brain MRI visualization
+     """
+     img_array = np.array(image)
+
+     # Convert to grayscale if needed
+     if len(img_array.shape) == 3:
+         gray = np.mean(img_array, axis=2)
+     else:
+         gray = img_array
+
+     # Apply histogram equalization (ImageOps is imported at the top)
+     enhanced = ImageOps.equalize(Image.fromarray(gray.astype(np.uint8)))
+
+     # Convert back to RGB
+     enhanced_array = np.array(enhanced)
+     rgb_array = np.stack([enhanced_array] * 3, axis=-1)
+
+     return Image.fromarray(rgb_array)
+
+ # Common neuroimaging structure mappings
+ STRUCTURE_ALIASES = {
+     "hippocampus": ["hippocampal formation", "hippocampal", "medial temporal"],
+     "ventricle": ["ventricular system", "lateral ventricle", "CSF space"],
+     "white matter": ["WM", "cerebral white matter", "deep white matter"],
+     "gray matter": ["GM", "cortical gray matter", "cortex"],
+     "tumor": ["mass", "lesion", "neoplasm", "growth"],
+     "thalamus": ["thalamic", "diencephalon"],
+     "basal ganglia": ["striatum", "caudate", "putamen", "globus pallidus"],
+ }
+
+ def get_structure_aliases(structure: str) -> list:
+     """Get alternative names for a neuroanatomical structure"""
+     structure_lower = structure.lower()
+
+     for key, aliases in STRUCTURE_ALIASES.items():
+         # compare case-insensitively so e.g. "WM" matches the "white matter" entry
+         if structure_lower == key or structure_lower in (a.lower() for a in aliases):
+             return [key] + aliases
+
+     return [structure]
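+
+ # Example: get_structure_aliases("GM") -> ["gray matter", "GM",
+ # "cortical gray matter", "cortex"]; unknown names fall through unchanged.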
+
+ # Hugging Face datasets for neuroimaging
+ HF_NEUROIMAGING_DATASETS = {
+     "brain-tumor-classification": {
+         "repo": "sartajbhuvaji/brain-tumor-classification",
+         "description": "Brain MRI scans classified by tumor type (glioma, meningioma, pituitary, no tumor)",
+         "image_key": "image",
+         "label_key": "label"
+     },
+     "brain-tumor-detection": {
+         "repo": "keremberke/brain-tumor-object-detection",
+         "description": "Brain MRI with bounding box annotations for tumors",
+         "image_key": "image",
+         "label_key": "objects"
+     },
+     "chest-xray": {
+         "repo": "alkzar90/NIH-Chest-X-ray-dataset",
+         "description": "Chest X-ray images with disease labels",
+         "image_key": "image",
+         "label_key": "labels"
+     }
+ }