"""
🖼️→📝 Chest X-ray Report Generation + Attention Visualizer + Classification
- Loads the report-generation model (complete_model.safetensor)
- Loads the classification model (classification.pth)
- Generates a report and visualizes per-word attention over the image.
- Lists disease probabilities.
"""

import os
import re
import random
from typing import List, Tuple, Optional
import html  # escape generated words before embedding them in the HTML view

import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from safetensors.torch import load_model 
from transformers import AutoModel, AutoImageProcessor

# Optional: nicer colormap
try:
    import matplotlib as mpl
    _HAS_MPL = True
    _COLORMAP = mpl.colormaps["magma"]  # registry lookup (matplotlib >= 3.6)
except Exception:
    _HAS_MPL = False
    _COLORMAP = None

# ========= Your utilities & model =========
from utils.processing import image_transform, pil_from_path
from utils.complete_model import create_complete_model

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==============================================================================
# 1. CLASSIFIER LOGIC (Added)
# ==============================================================================
class EmbeddingClassifier(nn.Module):
    def __init__(self, embedding_dim, num_classes, custom_dims=(512, 256, 256),
                 activation="gelu", dropout=0.05, bn=False, use_layernorm=True):
        super().__init__()
        layers = []
        layers.append(nn.Linear(embedding_dim, custom_dims[0]))
        if use_layernorm: layers.append(nn.LayerNorm(custom_dims[0]))
        elif bn: layers.append(nn.BatchNorm1d(custom_dims[0]))
        layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
        if dropout > 0: layers.append(nn.Dropout(dropout))
        for i in range(len(custom_dims) - 1):
            layers.append(nn.Linear(custom_dims[i], custom_dims[i + 1]))
            if use_layernorm: layers.append(nn.LayerNorm(custom_dims[i + 1]))
            elif bn: layers.append(nn.BatchNorm1d(custom_dims[i + 1]))
            layers.append(nn.GELU() if activation.lower() == "gelu" else nn.ReLU())
            if dropout > 0: layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(custom_dims[-1], num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, embeddings):
        return self.classifier(embeddings)
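# With the defaults above (and, illustratively, a 384-dim DINOv3 ViT-S/16
# embedding with 6 labels) the stack this builds is:
#   Linear(384->512) -> LayerNorm -> GELU -> Dropout(0.05)
#   Linear(512->256) -> LayerNorm -> GELU -> Dropout(0.05)
#   Linear(256->256) -> LayerNorm -> GELU -> Dropout(0.05)
#   Linear(256->6)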

class ChestXrayPredictor:
    def __init__(self, base_model, classifier, processor, label_cols, device):
        self.base_model = base_model
        self.classifier = classifier
        self.processor = processor
        self.label_cols = label_cols
        self.device = device
        self.base_model.eval()
        self.classifier.eval()

    def predict(self, image_source):
        try:
            if isinstance(image_source, str):
                image = Image.open(image_source).convert('RGB')
            else:
                image = image_source.convert('RGB')
            inputs = self.processor(images=image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(self.device)
            with torch.no_grad():
                outputs = self.base_model(pixel_values=pixel_values)
                if hasattr(outputs, 'last_hidden_state'):
                    embeddings = outputs.last_hidden_state.mean(dim=1)
                else:
                    embeddings = outputs[0].mean(dim=1)
                logits = self.classifier(embeddings)
                probs = torch.sigmoid(logits).cpu().numpy()[0].tolist()
            return {label: prob for label, prob in zip(self.label_cols, probs)}
        except Exception as e:
            print(f"Prediction Error: {e}")
            return {}

def create_classifier(checkpoint_path, model_id="facebook/dinov3-vits16-pretrain-lvd1689m", device=None):
    device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Loading Classifier from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    label_cols = checkpoint.get('label_cols', [
        "Cardiomegaly", "Consolidation", "Edema",
        "Atelectasis", "Pleural Effusion", "No Findings"
    ])
    
    base_model = AutoModel.from_pretrained(model_id).to(device)
    if 'base_model_state_dict' in checkpoint:
        base_model.load_state_dict(checkpoint['base_model_state_dict'])
    
    processor = AutoImageProcessor.from_pretrained(model_id)
    
    # Detect dims
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224).to(device)
        out = base_model(pixel_values=dummy)
        embedding_dim = out.last_hidden_state.shape[-1]

    # Rebuild MLP
    model_state = checkpoint['model_state_dict']
    linear_layers = []
    for key, val in model_state.items():
        if 'classifier' in key and key.endswith('.weight') and len(val.shape) == 2:
            match = re.search(r'classifier\.(\d+)\.weight', key)
            if match:
                linear_layers.append((int(match.group(1)), val.shape[1], val.shape[0]))
    linear_layers.sort(key=lambda x: x[0])
    if not linear_layers:
        raise ValueError(f"No classifier Linear layers found in {checkpoint_path}")
    num_classes = linear_layers[-1][2]
    hidden_dims = tuple(x[2] for x in linear_layers[:-1])
    
    uses_bn = any('running_mean' in k for k in model_state.keys())
    has_norm = any(k.endswith('.weight') and len(model_state[k].shape) == 1 for k in model_state.keys() if 'classifier' in k)
    
    classifier = EmbeddingClassifier(embedding_dim, num_classes, custom_dims=hidden_dims, bn=uses_bn, use_layernorm=(has_norm and not uses_bn))
    classifier.load_state_dict(model_state)
    classifier.to(device)
    
    return ChestXrayPredictor(base_model, classifier, processor, label_cols, device)
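# Usage sketch for the factory above (paths illustrative; the real wiring is in
# section 2 below):
#   predictor = create_classifier("classification.pth")
#   scores = predictor.predict("examples/sample_xray.png")
#   # -> {"Cardiomegaly": 0.12, "Pleural Effusion": 0.87, ...} (independent sigmoids)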

# ==============================================================================
# 2. LOAD MODELS
# ==============================================================================

# A. Load Generator
print("Loading Generation Model...")
model = create_complete_model(device=DEVICE, attention_implementation="eager")
SAFETENSOR_PATH = "complete_model.safetensor"
try:
    load_model(model, SAFETENSOR_PATH)
except Exception as e:
    print(f"Error loading generation model: {e}")
model.eval()

# B. Load Classifier
print("Loading Classification Model...")
CLASSIFIER_PATH = "classification.pth"
classifier_model = None
try:
    if os.path.exists(CLASSIFIER_PATH):
        classifier_model = create_classifier(CLASSIFIER_PATH, device=DEVICE)
        print("✅ Classifier loaded.")
    else:
        print(f"⚠️ Classifier not found at {CLASSIFIER_PATH}")
except Exception as e:
    print(f"⚠️ Error loading classifier: {e}")

# --- Tokenizer setup ---
tokenizer = getattr(model, "tokenizer", None)
if tokenizer is None:
    raise ValueError("Expected `model.tokenizer` to exist.")

pad_id = getattr(tokenizer, "pad_token_id", None)
eos_id = getattr(tokenizer, "eos_token_id", None)
needs_resize = False
if pad_id is None or (eos_id is not None and pad_id == eos_id):
    tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
    needs_resize = True

if needs_resize:
    resize_fns = [
        getattr(getattr(model, "decoder", None), "resize_token_embeddings", None),
        getattr(model, "resize_token_embeddings", None),
    ]
    for fn in resize_fns:
        if callable(fn):
            try:
                fn(len(tokenizer))
                break
            except Exception:
                pass
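# Note: which sub-module owns the token embeddings depends on how
# create_complete_model wires the decoder, hence the duck-typed search above.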

WORD_RE = re.compile(r"\w+(?:'\w+)?|[^\w\s]")

# ========= Logic =========

def model_heads_layers():
    def _get(obj, *names, default=None):
        for n in names:
            if obj is None: return default
            if hasattr(obj, n): return int(getattr(obj, n))
        return default

    cfg_candidates = [
        getattr(model, "config", None),
        getattr(getattr(model, "decoder", None), "config", None),
        getattr(getattr(model, "lm_head", None), "config", None),
    ]
    L = H = None
    for cfg in cfg_candidates:
        if L is None: L = _get(cfg, "num_hidden_layers", "n_layer")
        if H is None: H = _get(cfg, "num_attention_heads", "n_head")
    return max(1, L or 12), max(1, H or 12)
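# Falls back to 12 layers / 12 heads (a common transformer default) when no
# candidate config exposes the counts, so the layer/head sliders always get a
# sane range.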

def get_attention_for_token_layer(attentions, token_index, layer_index, batch_index=0, head_index=0, mean_across_layers=True, mean_across_heads=True):
    token_attention = attentions[token_index]
    if mean_across_layers:
        layer_attention = torch.stack(token_attention).mean(dim=0)
    else:
        layer_attention = token_attention[int(layer_index)]
    batch_attention = layer_attention[int(batch_index)]
    if mean_across_heads:
        head_attention = batch_attention.mean(dim=0)
    else:
        head_attention = batch_attention[int(head_index)]
    return head_attention.squeeze(0)
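# Assumed layout: `attentions` is indexed per generated token; each entry is a
# sequence of per-layer tensors shaped (batch, heads, q_len=1, k_len). Averaging
# over layers/heads (or indexing one of each) and squeezing q_len leaves a 1-D
# vector of attention weights over the k_len keys.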

def _words_and_map_from_tokens_simple(token_ids: List[int]) -> Tuple[List[str], List[int]]:
    if not token_ids: return [], []
    toks = tokenizer.convert_ids_to_tokens(token_ids)
    detok = tokenizer.convert_tokens_to_string(toks)
    matches = list(re.finditer(WORD_RE, detok))
    words = [m.group(0) for m in matches]
    ends = [m.span()[1] for m in matches]
    word2tok: List[int] = []
    for we in ends:
        prefix_ids = tokenizer.encode(detok[:we], add_special_tokens=False)
        if not prefix_ids:
            word2tok.append(0)
            continue
        last_idx = len(prefix_ids) - 1
        last_idx = max(0, min(last_idx, len(token_ids) - 1))
        word2tok.append(last_idx)
    return words, word2tok
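# Illustrative example: if the ids detokenize to "mild effusion" and "effusion"
# splits into two sub-tokens, the result is words=["mild", "effusion"] with
# word2tok=[0, 2] (each word mapped to its last sub-token). Re-encoding every
# prefix is O(n^2) but cheap at report length (a few hundred tokens at most).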

def _strip_trailing_special(ids: List[int]) -> List[int]:
    specials = set(getattr(tokenizer, "all_special_ids", []) or [])
    j = len(ids)
    while j > 0 and ids[j - 1] in specials:
        j -= 1
    return ids[:j]

def generate_word_visualization_gen_only(words_gen, word_ends_rel, gen_attn_values, selected_token_rel_idx):
    if not words_gen or gen_attn_values is None or len(gen_attn_values) == 0:
        return "<div style='width:100%;'>No text attention values.</div>"
    starts = []
    for i, end in enumerate(word_ends_rel):
        if i == 0: starts.append(0)
        else: starts.append(min(word_ends_rel[i - 1] + 1, end))
    word_scores = []
    T = len(gen_attn_values)
    for i, end in enumerate(word_ends_rel):
        start = starts[i]
        if start > end: start = end
        s = max(0, min(start, T - 1))
        e = max(0, min(end,   T - 1))
        if e < s: s, e = e, s
        word_scores.append(float(gen_attn_values[s:e + 1].sum()))
    max_attn = max(0.1, float(max(word_scores)) if word_scores else 0.0)
    selected_word_idx = None
    for i, end in enumerate(word_ends_rel):
        if selected_token_rel_idx <= end:
            selected_word_idx = i
            break
    if selected_word_idx is None and word_ends_rel: selected_word_idx = len(word_ends_rel) - 1
    spans = []
    for i, w in enumerate(words_gen):
        alpha = min(1.0, word_scores[i] / max_attn) if max_attn > 0 else 0.0
        bg = f"rgba(66,133,244,{alpha:.3f})"
        border = "2px solid #fff" if i == selected_word_idx else "1px solid transparent"
        spans.append(f"<span style='display:inline-block;background:{bg};border:{border};border-radius:6px;padding:2px 6px;margin:2px 4px 4px 0;color:#fff;'>{html.escape(w)}</span>")
    return f"<div style='width:100%;'><div style='background:#444;border:1px solid #eee;border-radius:8px;padding:10px;'><div style='white-space:normal;line-height:1.8;'>{''.join(spans)}</div></div></div>"
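# Scoring note: a word's score is the *sum* of attention over its sub-token
# span (not the mean), so multi-token words accumulate more mass; the 0.1
# floor on max_attn keeps the alpha scaling stable when attention is flat.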

def _attention_to_heatmap_uint8(attn_1d: np.ndarray, img_token_len: int = 1024, side: int = 32) -> np.ndarray:
    if attn_1d.shape[0] < img_token_len:
        img_part = np.zeros(img_token_len, dtype=float)
        img_part[: attn_1d.shape[0]] = attn_1d
    else:
        img_part = attn_1d[:img_token_len]
    mn, mx = float(img_part.min()), float(img_part.max())
    denom = (mx - mn) if (mx - mn) > 1e-12 else 1.0
    norm = (img_part - mn) / denom
    return (norm.reshape(side, side) * 255.0).astype(np.uint8)
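# Example: a 1056-long key vector (1024 image keys + 32 generated-text keys)
# keeps only the first 1024 entries, min-max normalizes them, and reshapes to
# a 32x32 uint8 grid; shorter vectors are zero-padded up to 1024.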

def _colorize_heatmap(heatmap_u8: np.ndarray) -> Image.Image:
    if _HAS_MPL and _COLORMAP is not None:
        colored = (_COLORMAP(heatmap_u8.astype(np.float32) / 255.0)[:, :, :3] * 255.0).astype(np.uint8)
        return Image.fromarray(colored)
    else:
        g = heatmap_u8.astype(np.float32) / 255.0
        r = (g * 255.0).clip(0, 255).astype(np.uint8)
        g2 = (np.sqrt(g) * 255.0).clip(0, 255).astype(np.uint8)
        b = np.zeros_like(r, dtype=np.uint8)
        rgb = np.stack([r, g2, b], axis=-1)
        return Image.fromarray(rgb)

def _resize_like(img: Image.Image, target_size: Tuple[int, int]) -> Image.Image:
    return img.resize(target_size, resample=Image.BILINEAR)

def _make_overlay(orig: Image.Image, heatmap_rgb: Image.Image, alpha: float = 0.35) -> Image.Image:
    if heatmap_rgb.size != orig.size:
        heatmap_rgb = _resize_like(heatmap_rgb, orig.size)
    base = orig.convert("RGBA")
    overlay = heatmap_rgb.convert("RGBA")
    r, g, b = overlay.split()[:3]
    a = Image.new("L", overlay.size, int(alpha * 255))
    overlay = Image.merge("RGBA", (r, g, b, a))
    return Image.alpha_composite(base, overlay).convert("RGB")
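# The overlay uses a constant alpha (default 0.35) rather than per-pixel alpha,
# so the heatmap tints the X-ray uniformly; raise `alpha` for a stronger heat
# signal or lower it to favor the underlying anatomy.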

def _prepare_image_tensor(pil_img, img_size=512):
    tfm = image_transform(img_size=img_size)
    tens = tfm(pil_img).unsqueeze(0).to(DEVICE, non_blocking=True)
    return tens

def run_generation(pil_image, max_new_tokens, layer, head, mean_layers, mean_heads):
    if pil_image is None:
        blank = Image.new("RGB", (256, 256), "black")
        return (None, None, 1024, None, None, gr.update(choices=[], value=None), blank, blank, np.zeros((256, 256, 3), dtype=np.uint8), "<div style='text-align:center;'>Upload image first.</div>")
    
    pixel_values = _prepare_image_tensor(pil_image, img_size=512)
    with torch.no_grad():
        gen_ids, gen_text, attentions = model.generate(pixel_values=pixel_values, max_new_tokens=int(max_new_tokens), output_attentions=True)
    
    if isinstance(gen_ids, torch.Tensor): gen_ids = gen_ids[0].tolist()
    gen_ids = _strip_trailing_special(gen_ids)
    words_gen, gen_word2tok_rel = _words_and_map_from_tokens_simple(gen_ids)
    display_choices = [(w, i) for i, w in enumerate(words_gen)]
    
    if not display_choices:
        blank_hm = np.zeros((32, 32), dtype=np.uint8)
        hm_rgb = _colorize_heatmap(blank_hm).resize(pil_image.size, resample=Image.NEAREST)
        overlay = _make_overlay(pil_image, hm_rgb, alpha=0.35)
        return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel, gr.update(choices=[], value=None), pil_image, overlay, np.array(hm_rgb), "<div>No tokens.</div>")

    first_idx = 0
    hm_rgb_init, overlay_init, html_init = update_visualization(first_idx, attentions, gen_ids, layer, head, mean_layers, mean_heads, words_gen, gen_word2tok_rel, pil_image)
    return (attentions, gen_ids, 1024, words_gen, gen_word2tok_rel, gr.update(choices=display_choices, value=first_idx), pil_image, overlay_init, hm_rgb_init, html_init)

def update_visualization(selected_gen_index, attentions, gen_token_ids, layer, head, mean_layers, mean_heads, words_gen, gen_word2tok_rel, pil_image: Optional[Image.Image] = None):
    if selected_gen_index is None or attentions is None or gen_word2tok_rel is None:
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Generate first.</div>"
    
    gidx = int(selected_gen_index)
    if not (0 <= gidx < len(gen_word2tok_rel)):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>Invalid selection.</div>"

    step_index = int(gen_word2tok_rel[gidx])
    if not attentions or step_index >= len(attentions):
        blank = np.zeros((256, 256, 3), dtype=np.uint8)
        return Image.fromarray(blank), Image.fromarray(blank), "<div>No attention.</div>"

    token_attn = get_attention_for_token_layer(attentions, token_index=step_index, layer_index=int(layer), head_index=int(head), mean_across_layers=bool(mean_layers), mean_across_heads=bool(mean_heads))
    attn_vals = token_attn.detach().cpu().numpy()
    if attn_vals.ndim == 2: attn_vals = attn_vals[-1]

    heatmap_u8 = _attention_to_heatmap_uint8(attn_1d=attn_vals, img_token_len=1024, side=32)
    hm_rgb_pil = _colorize_heatmap(heatmap_u8)
    if pil_image is None: pil_image = Image.new("RGB", (256, 256), "black")
    hm_rgb_pil_up = hm_rgb_pil.resize(pil_image.size, resample=Image.NEAREST)
    overlay_pil = _make_overlay(pil_image, hm_rgb_pil_up, alpha=0.35)

    k_len = int(attn_vals.shape[0])
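    # Key layout assumption: keys 0..1023 are the image patch tokens; keys from
    # 1024 onward are previously generated text tokens, so `observed_gen` below
    # counts how many text keys were visible at this decoding step.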
    observed_gen = max(0, min(step_index + 1, max(0, k_len - 1024)))
    total_gen = len(gen_token_ids)
    gen_vec = np.zeros(total_gen, dtype=float)
    if observed_gen > 0:
        start = 1024
        end = min(1024 + observed_gen, k_len)
        gen_slice = attn_vals[start:end]
        gen_vec[: len(gen_slice)] = gen_slice

    html_words = generate_word_visualization_gen_only(words_gen, gen_word2tok_rel, gen_vec, step_index)
    return np.array(hm_rgb_pil_up), overlay_pil, html_words
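# update_visualization returns (heatmap, overlay, html); the Gradio callbacks
# below reorder these to match their specific output component lists.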

def toggle_slider(is_mean):
    return gr.update(interactive=not bool(is_mean))

# ========= Gradio UI =========
EXAMPLES_DIR = "examples"

with gr.Blocks() as demo:
    gr.Markdown("# 🖼️→📝 Chest X-ray Report Generation & Classification")
    
    # States
    state_attentions = gr.State(None)
    state_gen_token_ids = gr.State(None)
    state_img_token_len = gr.State(1024)
    state_words_gen = gr.State(None)
    state_gen_word2tok_rel = gr.State(None)
    state_last_image = gr.State(None)

    L, H = model_heads_layers()

    with gr.Row():
        # LEFT COLUMN
        with gr.Column(scale=1):
            gr.Markdown("### 1) Input")
            img_input = gr.Image(type="pil", label="Upload image", height=280)
            btn_load_sample = gr.Button("Load random sample", variant="secondary")
            sample_status = gr.Markdown("")

            gr.Markdown("### 2) Generation Settings")
            slider_max_tokens = gr.Slider(5, 200, value=100, step=5, label="Max New Tokens")
            btn_generate = gr.Button("GENERATE REPORT & CLASSIFY", variant="primary")

            gr.Markdown("### 3) Attention Visualization")
            check_mean_layers = gr.Checkbox(False, label="Mean Across Layers")
            check_mean_heads = gr.Checkbox(False, label="Mean Across Heads")
            slider_layer = gr.Slider(0, max(0, L - 1), value=0, step=1, label="Layer", interactive=True)
            slider_head  = gr.Slider(0, max(0, H - 1), value=0, step=1, label="Head",  interactive=True)

            # --- NEW CLASSIFICATION SECTION ---
            gr.Markdown("### 4) Disease Probability")
            classification_output = gr.Dataframe(
                headers=["Disease", "Probability"],
                datatype=["str", "str"],
                label="Predictions",
                interactive=False
            )
            # ----------------------------------

        # RIGHT COLUMN
        with gr.Column(scale=3):
            with gr.Row():
                img_original_view = gr.Image(label="Original", image_mode="RGB", height=256)
                img_overlay_view = gr.Image(label="Attention Overlay", image_mode="RGB", height=256)
                heatmap_view = gr.Image(label="Heatmap", image_mode="RGB", height=256)
            
            radio_word_selector = gr.Radio([], label="Select Generated Word", info="Shows attention for this word")
            html_visualization = gr.HTML("<div style='text-align:center;padding:20px;color:#888;'>Text attention visualization will appear here.</div>")

    # Sample loader
    def _load_sample_from_examples():
        try:
            files = [f for f in os.listdir(EXAMPLES_DIR) if not f.startswith(".")]
            if not files: return gr.update(), "No files."
            fp = os.path.join(EXAMPLES_DIR, random.choice(files))
            pil_img = pil_from_path(fp)
            return gr.update(value=pil_img), f"Loaded: {os.path.basename(fp)}"
        except Exception as e:
            return gr.update(), f"Error: {e}"

    btn_load_sample.click(_load_sample_from_examples, inputs=[], outputs=[img_input, sample_status])

    # MAIN RUN FUNCTION (GENERATION + CLASSIFICATION)
    def _run_all_logic(pil_image, *args):
        # 1. Run Generation (returns 10 items)
        gen_results = run_generation(pil_image, *args)
        
        # 2. Run Classification
        classification_data = []
        if pil_image is not None and classifier_model is not None:
            try:
                preds = classifier_model.predict(pil_image)
                # Sort by probability, descending
                sorted_preds = sorted(preds.items(), key=lambda x: x[1], reverse=True)
                # Format rows for the Gradio Dataframe, e.g. ["Cardiomegaly", "95.5%"].
                # predict() returns sigmoid scores in [0, 1], so render as percentages.
                classification_data = [[k, f"{v:.1%}"] for k, v in sorted_preds]
            except Exception as e:
                print(f"Classification runtime error: {e}")
                classification_data = [["Error", str(e)]]
        
        # Combine: gen_results + original_image (for state) + classification_data
        return (*gen_results, pil_image, classification_data)

    btn_generate.click(
        fn=_run_all_logic,
        inputs=[img_input, slider_max_tokens, slider_layer, slider_head, check_mean_layers, check_mean_heads],
        outputs=[
            state_attentions,
            state_gen_token_ids,
            state_img_token_len,
            state_words_gen,
            state_gen_word2tok_rel,
            radio_word_selector,
            img_original_view,
            img_overlay_view,
            heatmap_view,
            html_visualization,
            state_last_image,    # Added to outputs
            classification_output # Added to outputs
        ],
    )

    # UI updates for visualizer controls
    def _update_wrapper(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, last_img):
        # `html_out` avoids shadowing the stdlib `html` module imported above.
        hm_rgb, overlay, html_out = update_visualization(selected_gen_index, attn, gen_ids, lyr, hed, meanL, meanH, words, word2tok, pil_image=last_img)
        return overlay, hm_rgb, html_out

    for control in [radio_word_selector, slider_layer, slider_head, check_mean_layers, check_mean_heads]:
        control.change(
            fn=_update_wrapper,
            inputs=[radio_word_selector, state_attentions, state_gen_token_ids, slider_layer, slider_head, check_mean_layers, check_mean_heads, state_words_gen, state_gen_word2tok_rel, state_last_image],
            outputs=[img_overlay_view, heatmap_view, html_visualization],
        )

    check_mean_layers.change(toggle_slider, check_mean_layers, slider_layer)
    check_mean_heads.change(toggle_slider, check_mean_heads, slider_head)

if __name__ == "__main__":
    print(f"Device: {DEVICE}")
    demo.launch(debug=True)