Enferlain committed on
Commit ffa73e4 · verified · 1 Parent(s): b9b2f8c

Update app.py

Files changed (1): app.py (+150 -174)
app.py CHANGED
@@ -1,31 +1,33 @@
 import os
 import json
 import traceback
-from typing import Optional, Tuple, Union, List

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from PIL import Image, PngImagePlugin
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
-from transformers import AutoProcessor, AutoModel, AutoImageProcessor
 import gradio as gr
-import math  # Added math

 # --- Device Setup ---
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-# Use float16 for vision model on CUDA for speed/memory, but head expects float32
-VISION_DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
-HEAD_DTYPE = torch.float32  # Head usually trained/stable in float32

 print(f"Using device: {DEVICE}")
-print(f"Vision model dtype: {VISION_DTYPE}")
 print(f"Head model dtype: {HEAD_DTYPE}")


 # --- Model Definitions (Copied from hybrid_model.py) ---
-
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
@@ -36,8 +38,6 @@ class RMSNorm(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight
-    def extra_repr(self) -> str:
-        return f"{tuple(self.weight.shape)}, eps={self.eps}"

 class SwiGLUFFN(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
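As a quick reference, the forward above computes classic RMSNorm; a minimal standalone sketch, assuming `_norm` (elided from this hunk) is the standard x * rsqrt(mean(x^2) + eps):

import torch

# RMSNorm by hand: y = x / sqrt(mean(x^2, dim=-1) + eps) * weight
x = torch.randn(2, 8)
eps = 1e-6
rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
weight = torch.ones(8)  # learned scale, initialized to ones
y = (x / rms) * weight
print(y.shape)  # torch.Size([2, 8])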
@@ -71,22 +71,17 @@ class HybridHeadModel(nn.Module):
         super().__init__()
         self.features = features; self.hidden_dim = hidden_dim; self.num_classes = num_classes
         self.use_attention = use_attention; self.output_mode = output_mode.lower()
-        # --- Optional Self-Attention Layer ---
         self.attention = None; self.norm_attn = None
         if self.use_attention:
-            actual_num_heads = num_attn_heads  # Adjust head logic needed here if features != 1152
-            # Simple head adjustment:
             if features % num_attn_heads != 0:
-                possible_heads = [h for h in [1, 2, 4, 8, 16] if features % h == 0]
-                if not possible_heads: actual_num_heads = 1  # Fallback to 1 head if no divisors found
                 else: actual_num_heads = min(possible_heads, key=lambda x: abs(x-num_attn_heads))
-            if actual_num_heads != num_attn_heads: print(f"HybridHead Warning: Adjusting heads {num_attn_heads}->{actual_num_heads}")
-
             self.attention = nn.MultiheadAttention(features, actual_num_heads, dropout=attn_dropout, batch_first=True, bias=True)
             self.norm_attn = RMSNorm(features, eps=rms_norm_eps)
-        # --- MLP Head ---
-        mlp_layers = []
-        mlp_layers.append(nn.Linear(features, hidden_dim)); mlp_layers.append(RMSNorm(hidden_dim, eps=rms_norm_eps))
         for _ in range(num_res_blocks): mlp_layers.append(ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps))
         mlp_layers.append(RMSNorm(hidden_dim, eps=rms_norm_eps))
         down_proj_hidden = hidden_dim // 2
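The divisor-based head adjustment in this hunk can be read in isolation; a small sketch of the same selection rule (`pick_num_heads` is an illustrative name, not part of the app):

def pick_num_heads(features: int, requested: int, candidates=(1, 2, 4, 8, 16)) -> int:
    # Keep the requested count when it divides the feature dimension.
    if features % requested == 0:
        return requested
    # Otherwise fall back to the divisor closest to the request (1 if none divide).
    divisors = [h for h in candidates if features % h == 0]
    return min(divisors, key=lambda h: abs(h - requested)) if divisors else 1

print(pick_num_heads(1152, 16))  # 16 (1152 = 16 * 72, so no adjustment)
print(pick_num_heads(1000, 16))  # 8  (closest divisor of 1000 in the list)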
@@ -94,192 +89,173 @@
         mlp_layers.append(RMSNorm(down_proj_hidden, eps=rms_norm_eps))
         mlp_layers.append(nn.Linear(down_proj_hidden, num_classes))
         self.mlp_head = nn.Sequential(*mlp_layers)
-        # --- Validate Output Mode ---
-        # (Warnings can be added here if desired, but functionality handled in forward)

     def forward(self, x: torch.Tensor):
         if self.use_attention and self.attention is not None:
             x_seq = x.unsqueeze(1); attn_output, _ = self.attention(x_seq, x_seq, x_seq); x = self.norm_attn(x + attn_output.squeeze(1))
-        logits = self.mlp_head(x.to(HEAD_DTYPE))  # Ensure input to MLP has correct dtype
-        # --- Apply Final Activation ---
-        output = None
-        if self.output_mode == 'linear': output = logits
-        elif self.output_mode == 'sigmoid': output = torch.sigmoid(logits)
-        elif self.output_mode == 'softmax': output = F.softmax(logits, dim=-1)
-        elif self.output_mode == 'tanh_scaled': output = (torch.tanh(logits) + 1.0) / 2.0
-        else: raise RuntimeError(f"Invalid output_mode '{self.output_mode}'.")
         if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1: output = output.squeeze(-1)
         return output
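For intuition, the four `output_mode` branches map the same logits onto different ranges; a tiny sketch:

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.5, -0.5]])
print(logits)                            # 'linear': raw scores, unbounded
print(torch.sigmoid(logits))             # 'sigmoid': each value in (0, 1), independent
print(F.softmax(logits, dim=-1))         # 'softmax': values sum to 1 across classes
print((torch.tanh(logits) + 1.0) / 2.0)  # 'tanh_scaled': squashed into (0, 1)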
 
-# --- Constants and Model Loading ---
-
-# Option 1: Files are in the Space repo (e.g., in a 'model' folder)
-# MODEL_DIR = "model"
-# HEAD_MODEL_FILENAME = "AnatomyFlaws-v11.3_adabelief_fl_naflex_3000_s9K.safetensors"
-# CONFIG_FILENAME = "AnatomyFlaws-v11.3_adabelief_fl_naflex_3000.config.json"  # Assuming config matches base name
-# HEAD_MODEL_PATH = os.path.join(MODEL_DIR, HEAD_MODEL_FILENAME)
-# CONFIG_PATH = os.path.join(MODEL_DIR, CONFIG_FILENAME)
-
-# Option 2: Download from Hub
-# Replace with your HF username and repo name
-HUB_REPO_ID = "Enferlain/lumi-classifier"  # Or wherever you uploaded the model
-# Use the specific checkpoint you want (e.g., s9K or the best_val one)
-HEAD_MODEL_FILENAME = "AnatomyFlaws-v11.3_adabelief_fl_naflex_3000_s6K_best_val.safetensors"
-# Usually config corresponds to the base run name, not a specific step
-CONFIG_FILENAME = "AnatomyFlaws-v11.3_adabelief_fl_naflex_3000.config.json"
-
-print("Downloading model files if necessary...")
-try:
-    HEAD_MODEL_PATH = hf_hub_download(repo_id=HUB_REPO_ID, filename=HEAD_MODEL_FILENAME)
-    CONFIG_PATH = hf_hub_download(repo_id=HUB_REPO_ID, filename=CONFIG_FILENAME)
-    print("Files downloaded/found successfully.")
-except Exception as e:
-    print(f"ERROR downloading files from {HUB_REPO_ID}: {e}")
-    print("Please ensure the files exist on the Hub or place them in a local 'model' folder.")
-    # Optionally exit or fallback
-    exit(1)  # Exit if essential files aren't available
-
-# --- Load Config ---
-print(f"Loading config from: {CONFIG_PATH}")
-config = {}
-try:
-    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
-        config = json.load(f)
-except Exception as e:
-    print(f"ERROR loading config file: {e}"); exit(1)
-
-# --- Load Vision Model ---
-BASE_VISION_MODEL_NAME = config.get("base_vision_model", "google/siglip2-so400m-patch16-naflex")
-print(f"Loading vision model: {BASE_VISION_MODEL_NAME}")
-try:
-    hf_processor = AutoProcessor.from_pretrained(BASE_VISION_MODEL_NAME)
-    vision_model = AutoModel.from_pretrained(
-        BASE_VISION_MODEL_NAME, torch_dtype=VISION_DTYPE
-    ).to(DEVICE).eval()
-    print("Vision model loaded.")
-except Exception as e:
-    print(f"ERROR loading vision model: {e}"); exit(1)
-
-# --- Load HybridHeadModel ---
-print(f"Loading head model: {HEAD_MODEL_PATH}")
-head_model = None
-try:
-    state_dict = load_file(HEAD_MODEL_PATH, device='cpu')
-    # Infer details from config - use defaults matching the successful run
-    features = config.get("features", 1152)
-    num_classes = config.get("num_classes", 2)  # Should be 2 for focal loss run
-    output_mode = config.get("output_mode", "linear")  # Should be linear
-    hidden_dim = config.get("hidden_dim", 1280)
-    num_res_blocks = config.get("num_res_blocks", 3)
-    dropout_rate = config.get("dropout_rate", 0.3)  # Use the high dropout from best run
-    use_attention = config.get("use_attention", True)  # Use attention was likely True
-    num_attn_heads = config.get("num_attn_heads", 16)
-    attn_dropout = config.get("attn_dropout", 0.3)  # Use the high dropout
-    rms_norm_eps = config.get("rms_norm_eps", 1e-6)
-
-    head_model = HybridHeadModel(
-        features=features, hidden_dim=hidden_dim, num_classes=num_classes,
-        use_attention=use_attention, num_attn_heads=num_attn_heads, attn_dropout=attn_dropout,
-        num_res_blocks=num_res_blocks, dropout_rate=dropout_rate, rms_norm_eps=rms_norm_eps,
-        output_mode=output_mode
-    )
-    missing, unexpected = head_model.load_state_dict(state_dict, strict=False)
-    if missing: print(f"Warning: Missing keys loading head: {missing}")
-    if unexpected: print(f"Warning: Unexpected keys loading head: {unexpected}")
-    head_model.to(DEVICE).eval()
-    print("Head model loaded.")
-except Exception as e:
-    print(f"ERROR loading head model: {e}"); exit(1)
-
-# --- Label Mapping ---
-# Assume labels are '0': Bad, '1': Good from config or default
-LABELS = config.get("labels", {'0': 'Bad Anatomy', '1': 'Good Anatomy'})
-LABEL_NAMES = {
-    0: LABELS.get('0', 'Class 0'),
-    1: LABELS.get('1', 'Class 1')
 }
-print(f"Using Labels: {LABEL_NAMES}")
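Worth noting about the download step: `hf_hub_download` returns a local cache path and only hits the network when the file is missing from the cache, so repeated startups are cheap. A minimal sketch using the same repo and config filename as above:

from huggingface_hub import hf_hub_download

# Returns a local path; downloads only if the file isn't cached yet.
path = hf_hub_download(
    repo_id="Enferlain/lumi-classifier",
    filename="AnatomyFlaws-v11.3_adabelief_fl_naflex_3000.config.json",
)
print(path)  # e.g. ~/.cache/huggingface/hub/models--Enferlain--lumi-classifier/...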
 
-# --- Prediction Function ---
-def predict_anatomy(image: Image.Image):
-    """Takes PIL Image, returns dict of class probabilities."""
     if image is None: return {"Error": "No image provided"}
     try:
         pil_image = image.convert("RGB")

-        # 1. Extract SigLIP NaFlex Embedding
         with torch.no_grad():
-            inputs = hf_processor(images=[pil_image], return_tensors="pt", max_num_patches=1024)
-            pixel_values = inputs.get("pixel_values").to(device=DEVICE, dtype=VISION_DTYPE)
-            attention_mask = inputs.get("pixel_attention_mask").to(device=DEVICE)
-            spatial_shapes = inputs.get("spatial_shapes")
-            model_call_kwargs = {"pixel_values": pixel_values, "attention_mask": attention_mask,
-                                 "spatial_shapes": torch.tensor(spatial_shapes, dtype=torch.long).to(DEVICE)}
-
-            vision_model_component = getattr(vision_model, 'vision_model', vision_model)  # Handle potential nesting
-            emb = vision_model_component(**model_call_kwargs).pooler_output
         if emb is None: raise ValueError("Failed to get embedding.")
-
-        # L2 Norm
         norm = torch.linalg.norm(emb.float(), dim=-1, keepdim=True).clamp(min=1e-8)
         emb_normalized = emb / norm.to(emb.dtype)
-
-        # 2. Obtain Prediction from HybridHeadModel Head
         with torch.no_grad():
-            prediction = head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))
-
-        # 3. Format Output Probabilities
         output_probs = {}
-        output_mode = getattr(head_model, 'output_mode', 'linear')
-
-        if head_model.num_classes == 1:
-            logit = prediction.squeeze().item()
-            prob_good = torch.sigmoid(torch.tensor(logit)).item() if output_mode == 'linear' else logit
-            output_probs[LABEL_NAMES[0]] = 1.0 - prob_good
-            output_probs[LABEL_NAMES[1]] = prob_good
-        elif head_model.num_classes == 2:
-            if output_mode == 'linear':
-                probs = F.softmax(prediction.squeeze().float(), dim=-1)  # Use float for softmax stability
-            else:  # Assume sigmoid or already softmax
-                probs = prediction.squeeze().float()
-            output_probs[LABEL_NAMES[0]] = probs[0].item()
-            output_probs[LABEL_NAMES[1]] = probs[1].item()
         else:
-            output_probs["Error"] = f"Unsupported num_classes: {head_model.num_classes}"
-
-        # Convert to percentage strings for gr.Label maybe? Or keep floats? Keep floats.
-        # output_formatted = {k: f"{v:.1%}" for k, v in output_probs.items()}
         return output_probs
-
     except Exception as e:
         print(f"Error during prediction: {e}\n{traceback.format_exc()}")
         return {"Error": str(e)}

 # --- Gradio Interface ---
 DESCRIPTION = """
-## Anatomy Flaw Classifier Demo ✨ (Based on SigLIP Naflex + Hybrid Head)
-Upload an image to classify its anatomy as 'Good' or 'Bad'.
-This model uses embeddings from **google/siglip2-so400m-patch16-naflex**
-and a custom **HybridHeadModel** fine-tuned for anatomy classification.
 """
-
-# Add example images if you have some in an 'examples' folder in the Space repo
 EXAMPLE_DIR = "examples"
 examples = []
 if os.path.isdir(EXAMPLE_DIR):
     examples = [os.path.join(EXAMPLE_DIR, fname) for fname in sorted(os.listdir(EXAMPLE_DIR)) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]

-interface = gr.Interface(
-    fn=predict_anatomy,
-    inputs=gr.Image(type="pil", label="Input Image"),
-    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),  # Show top 2 classes
-    title="Lumi's Anatomy Classifier Demo",
-    description=DESCRIPTION,
-    examples=examples if examples else None,
-    allow_flagging="never",
-    cache_examples=False  # Disable caching if examples change or loading is fast
-)

 if __name__ == "__main__":
     interface.launch()
 
app.py (new version):

 import os
 import json
 import traceback
+from typing import Dict, Any

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from PIL import Image
 from safetensors.torch import load_file
 from huggingface_hub import hf_hub_download
+from transformers import AutoProcessor, AutoModel
 import gradio as gr

 # --- Device Setup ---
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# For 8-bit models, the vision dtype is handled by bitsandbytes
+# We still need HEAD_DTYPE for our classifier head
+HEAD_DTYPE = torch.float32
+
+# --- DINOv3 Specific Constants ---
+DINOV3_PATCH_SIZE = 16
+MAX_DINOV3_RESOLUTION = 4096

 print(f"Using device: {DEVICE}")
 print(f"Head model dtype: {HEAD_DTYPE}")


 # --- Model Definitions (Copied from hybrid_model.py) ---
+# (RMSNorm, SwiGLUFFN, ResBlockRMS, HybridHeadModel classes are unchanged and go here)
 class RMSNorm(nn.Module):
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         output = self._norm(x.float()).type_as(x)
         return output * self.weight

 class SwiGLUFFN(nn.Module):
     def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, act_layer: nn.Module = nn.SiLU, dropout: float = 0.):
 
         super().__init__()
         self.features = features; self.hidden_dim = hidden_dim; self.num_classes = num_classes
         self.use_attention = use_attention; self.output_mode = output_mode.lower()
         self.attention = None; self.norm_attn = None
         if self.use_attention:
+            actual_num_heads = num_attn_heads
             if features % num_attn_heads != 0:
+                possible_heads = [h for h in [1, 2, 4, 8, 16, 32] if features % h == 0]  # Expanded list
+                if not possible_heads: actual_num_heads = 1
                 else: actual_num_heads = min(possible_heads, key=lambda x: abs(x-num_attn_heads))
+            if actual_num_heads != num_attn_heads: print(f"HybridHead Warning: Adjusting heads {num_attn_heads}->{actual_num_heads} for features={features}")
             self.attention = nn.MultiheadAttention(features, actual_num_heads, dropout=attn_dropout, batch_first=True, bias=True)
             self.norm_attn = RMSNorm(features, eps=rms_norm_eps)
+        mlp_layers = [nn.Linear(features, hidden_dim), RMSNorm(hidden_dim, eps=rms_norm_eps)]
         for _ in range(num_res_blocks): mlp_layers.append(ResBlockRMS(hidden_dim, dropout=dropout_rate, rms_norm_eps=rms_norm_eps))
         mlp_layers.append(RMSNorm(hidden_dim, eps=rms_norm_eps))
         down_proj_hidden = hidden_dim // 2
         mlp_layers.append(RMSNorm(down_proj_hidden, eps=rms_norm_eps))
         mlp_layers.append(nn.Linear(down_proj_hidden, num_classes))
         self.mlp_head = nn.Sequential(*mlp_layers)

     def forward(self, x: torch.Tensor):
         if self.use_attention and self.attention is not None:
             x_seq = x.unsqueeze(1); attn_output, _ = self.attention(x_seq, x_seq, x_seq); x = self.norm_attn(x + attn_output.squeeze(1))
+        logits = self.mlp_head(x.to(HEAD_DTYPE))
+        output_mode = self.output_mode
+        if output_mode == 'linear': output = logits
+        elif output_mode == 'sigmoid': output = torch.sigmoid(logits)
+        elif output_mode == 'softmax': output = F.softmax(logits, dim=-1)
+        elif output_mode == 'tanh_scaled': output = (torch.tanh(logits) + 1.0) / 2.0
+        else: raise RuntimeError(f"Invalid output_mode '{output_mode}'.")
         if self.num_classes == 1 and output.ndim == 2 and output.shape[1] == 1: output = output.squeeze(-1)
         return output
 
+# --- Model Catalog ---
+MODEL_CATALOG = {
+    "AnatomyFlaws-v15.5 (DINOv3 7b 8-bit)": {  # <-- Renamed for clarity
+        "repo_id": "Enferlain/lumi-classifier",
+        "config_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl.config.json",
+        "head_filename": "AnatomyFlaws-v15.5_dinov3_7b_bnb_fl_s4K.safetensors"
+    },
+    "AnatomyFlaws-v14.7 (SigLIP naflex)": {
+        "repo_id": "Enferlain/lumi-classifier",
+        "config_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670.config.json",
+        "head_filename": "AnatomyFlaws-v14.7_adabelief_fl_naflex_4670_s2K.safetensors"
+    },
 }
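A hedged sanity-check sketch for the catalog shape (assumes `MODEL_CATALOG` as defined above; `REQUIRED_KEYS` is an illustrative name):

# Each entry must resolve to a repo plus two filenames before it is wired into the UI.
REQUIRED_KEYS = {"repo_id", "config_filename", "head_filename"}
for name, info in MODEL_CATALOG.items():
    missing = REQUIRED_KEYS - info.keys()
    assert not missing, f"{name} is missing {missing}"
print(f"{len(MODEL_CATALOG)} models registered: {list(MODEL_CATALOG)}")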
 
+# --- Model Manager Class ---
+class ModelManager:
+    def __init__(self, catalog: Dict[str, Dict[str, str]]):
+        self.catalog = catalog; self.current_model_name: str = None; self.vision_model: nn.Module = None
+        self.hf_processor: Any = None; self.head_model: HybridHeadModel = None
+        self.labels: Dict[int, str] = None; self.config: Dict[str, Any] = None
+
+    def load_model(self, model_name: str):
+        if model_name == self.current_model_name: return
+        if model_name not in self.catalog: raise ValueError(f"Model '{model_name}' not found.")
+        print(f"Switching to model: {model_name}...")
+        model_info = self.catalog[model_name]
+        repo_id, config_filename, head_filename = model_info["repo_id"], model_info["config_filename"], model_info["head_filename"]
+        try:
+            config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
+            with open(config_path, 'r', encoding='utf-8') as f: self.config = json.load(f)
+
+            base_vision_model_name = self.config.get("base_vision_model")
+            print(f"Loading vision model: {base_vision_model_name}")
+
+            # --- UPDATED LOADING LOGIC ---
+            is_dinov3_8bit = "dinov3" in base_vision_model_name and "8bit" in base_vision_model_name
+
+            if is_dinov3_8bit:
+                # Use your 8-bit model from the Hub
+                self.hf_processor = AutoProcessor.from_pretrained("facebook/dinov3-base")  # Processor is usually from the base model
+                self.vision_model = AutoModel.from_pretrained(
+                    base_vision_model_name,
+                    load_in_8bit=True,
+                    trust_remote_code=True
+                ).eval()
+            else:  # For SigLIP or other non-8bit models
+                self.hf_processor = AutoProcessor.from_pretrained(base_vision_model_name)
+                self.vision_model = AutoModel.from_pretrained(
+                    base_vision_model_name,
+                    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32  # Use a dynamic dtype
+                ).to(DEVICE).eval()
+
+            head_model_path = hf_hub_download(repo_id=repo_id, filename=head_filename)
+            print(f"Loading head model: {head_filename}")
+            state_dict = load_file(head_model_path, device='cpu')
+            head_params = self.config.get("predictor_params", self.config)
+            self.head_model = HybridHeadModel(
+                features=head_params.get("features"), hidden_dim=head_params.get("hidden_dim"),
+                num_classes=self.config.get("num_classes"), use_attention=head_params.get("use_attention"),
+                num_attn_heads=head_params.get("num_attn_heads"), attn_dropout=head_params.get("attn_dropout"),
+                num_res_blocks=head_params.get("num_res_blocks"), dropout_rate=head_params.get("dropout_rate"),
+                output_mode=head_params.get("output_mode", "linear"))
+            self.head_model.load_state_dict(state_dict, strict=True)
+            self.head_model.to(DEVICE).eval()
+            raw_labels = self.config.get("labels", {'0': 'Bad', '1': 'Good'})
+            self.labels = {int(k): (v['name'] if isinstance(v, dict) else v) for k, v in raw_labels.items()}
+            self.current_model_name = model_name
+            print(f"Successfully loaded '{model_name}'.")
+        except Exception as e:
+            self.current_model_name = None
+            raise RuntimeError(f"Failed to load model '{model_name}': {e}\n{traceback.format_exc()}")
+
+# --- Global Model Manager Instance ---
+model_manager = ModelManager(MODEL_CATALOG)
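A hedged usage sketch for the manager above: `load_model` is a no-op when the requested model is already current, raises `ValueError` for unknown names, and wraps loading failures in `RuntimeError` (model names follow the catalog keys):

# Switch between catalog entries; repeating the current name returns immediately.
model_manager.load_model("AnatomyFlaws-v14.7 (SigLIP naflex)")
model_manager.load_model("AnatomyFlaws-v14.7 (SigLIP naflex)")  # no-op

try:
    model_manager.load_model("nonexistent-model")
except ValueError as e:
    print(e)  # Model 'nonexistent-model' not found.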
+# --- Prediction Function (v3 from before) ---
+def predict_anatomy_v3(image: Image.Image, model_name: str):
     if image is None: return {"Error": "No image provided"}
     try:
+        model_manager.load_model(model_name)
         pil_image = image.convert("RGB")
+        emb = None

         with torch.no_grad():
+            base_model_type = model_manager.config.get("base_vision_model", "")
+            if "dinov3" in base_model_type.lower():
+                current_w, current_h = pil_image.size
+                img_to_process = pil_image
+                if max(current_w, current_h) > MAX_DINOV3_RESOLUTION:
+                    scale = MAX_DINOV3_RESOLUTION / max(current_w, current_h)
+                    current_w, current_h = int(current_w * scale), int(current_h * scale)
+                    img_to_process = pil_image.resize((current_w, current_h), Image.Resampling.LANCZOS)
+                new_w = ((current_w + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
+                new_h = ((current_h + DINOV3_PATCH_SIZE - 1) // DINOV3_PATCH_SIZE) * DINOV3_PATCH_SIZE
+                if new_w != current_w or new_h != current_h:
+                    img_to_process = img_to_process.resize((new_w, new_h), Image.Resampling.LANCZOS)
+                inputs = model_manager.hf_processor(images=[img_to_process], return_tensors="pt")
+                # For 8-bit, send inputs to the same device as the model
+                pixel_values = inputs.pixel_values.to(model_manager.vision_model.device)
+                outputs = model_manager.vision_model(pixel_values=pixel_values)
+                last_hidden_state = outputs.last_hidden_state
+                nreg = getattr(model_manager.vision_model.config, 'num_register_tokens', 0)
+                patch_embeddings = last_hidden_state[:, 1 + nreg:]
+                emb = torch.mean(patch_embeddings, dim=1)
+            elif "siglip" in base_model_type.lower():
+                inputs = model_manager.hf_processor(images=[pil_image], return_tensors="pt")
+                pixel_values = inputs.get("pixel_values").to(device=DEVICE, dtype=torch.float16)
+                if "naflex" in base_model_type.lower():
+                    attention_mask = inputs.get("pixel_attention_mask").to(device=DEVICE)
+                    spatial_shapes = inputs.get("spatial_shapes")
+                    model_call_kwargs = {"pixel_values": pixel_values, "attention_mask": attention_mask,
+                                         "spatial_shapes": torch.tensor(spatial_shapes, dtype=torch.long).to(DEVICE)}
+                    vision_model_component = getattr(model_manager.vision_model, 'vision_model', model_manager.vision_model)
+                    emb = vision_model_component(**model_call_kwargs).pooler_output
+                else: emb = model_manager.vision_model.get_image_features(pixel_values=pixel_values)
+            else: raise ValueError(f"Unknown base model type for embedding: {base_model_type}")
         if emb is None: raise ValueError("Failed to get embedding.")
         norm = torch.linalg.norm(emb.float(), dim=-1, keepdim=True).clamp(min=1e-8)
         emb_normalized = emb / norm.to(emb.dtype)
         with torch.no_grad():
+            prediction = model_manager.head_model(emb_normalized.to(DEVICE, dtype=HEAD_DTYPE))
         output_probs = {}
+        if model_manager.head_model.num_classes == 2:
+            probs = F.softmax(prediction.squeeze().float(), dim=-1)
+            output_probs[model_manager.labels[0]] = probs[0].item()
+            output_probs[model_manager.labels[1]] = probs[1].item()
         else:
+            prob_good = torch.sigmoid(prediction.squeeze()).item()
+            output_probs[model_manager.labels[0]] = 1.0 - prob_good
+            output_probs[model_manager.labels[1]] = prob_good
         return output_probs
     except Exception as e:
         print(f"Error during prediction: {e}\n{traceback.format_exc()}")
         return {"Error": str(e)}

 # --- Gradio Interface ---
+# (Unchanged)
 DESCRIPTION = """
+## Lumi's Anatomy Flaw Classifier Demo ✨
+Select a model from the dropdown, then upload an image to classify its anatomy/structure.
 """
 EXAMPLE_DIR = "examples"
 examples = []
 if os.path.isdir(EXAMPLE_DIR):
     examples = [os.path.join(EXAMPLE_DIR, fname) for fname in sorted(os.listdir(EXAMPLE_DIR)) if fname.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))]

+default_model = list(MODEL_CATALOG.keys())[0]
+interface = gr.Interface(
+    fn=predict_anatomy_v3,
+    inputs=[gr.Image(type="pil", label="Input Image"),
+            gr.Dropdown(choices=list(MODEL_CATALOG.keys()), value=default_model, label="Classifier Model")],
+    outputs=gr.Label(label="Class Probabilities", num_top_classes=2),
+    title="Lumi's Anatomy Classifier",
+    description=DESCRIPTION,
+    examples=examples if examples else None,
+    allow_flagging="never",
+    cache_examples=False,
+)

 if __name__ == "__main__":
+    try:
+        print("Pre-loading default model...")
+        model_manager.load_model(default_model)
+    except Exception as e:
+        print(f"WARNING: Could not pre-load default model. Error: {e}")
     interface.launch()
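For a quick smoke test outside the Gradio UI, `predict_anatomy_v3` can be called directly with a PIL image (the file path here is a placeholder):

from PIL import Image

# Hypothetical local file; swap in any test image.
img = Image.open("examples/test.png")
probs = predict_anatomy_v3(img, default_model)
print(probs)  # e.g. {'Bad': 0.12, 'Good': 0.88}, or {'Error': ...} on failure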