AdithyaSK committed on
Commit bb60941 · 1 Parent(s): 0bb07b7

Refactor requirements.txt: streamline dependencies and update package sources

Files changed (3):
  1. README.md +48 -5
  2. app.py +868 -529
  3. requirements.txt +9 -14
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
- title: VARAG
- emoji: 🏠
+ title: NetraEmbed
+ emoji: 👁️
  colorFrom: yellow
  colorTo: purple
  sdk: gradio
@@ -8,9 +8,52 @@ sdk_version: 4.44.0
  app_file: app.py
  pinned: false
  license: mit
- short_description: Vision First RAG engine
+ short_description: Universal Multilingual Multimodal Document Retrieval
+ hardware: zero-gpu
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
- First commit
+ # NetraEmbed - Universal Multilingual Multimodal Document Retrieval
+
+ This Space demonstrates **NetraEmbed** and **ColNetraEmbed**, state-of-the-art multilingual multimodal document retrieval models based on the BiGemma3 and ColGemma3 architectures.
+
+ ## Features
+
+ - **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation for fast retrieval
+ - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction for high-quality retrieval with attention heatmaps
+ - **ZeroGPU Integration**: Efficient dynamic GPU allocation for on-demand model loading
+ - **PDF Document Support**: Upload PDFs and perform semantic search across pages
+ - **Side-by-side Comparison**: Compare both models simultaneously
+
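As a rough mental model of the difference between the two retrieval regimes, here is a minimal PyTorch sketch with illustrative shapes; it is not the models' actual API:

```python
import torch

# Single-vector (NetraEmbed-style): one normalized vector per page,
# so relevance is a single cosine similarity per page.
query_vec = torch.nn.functional.normalize(torch.randn(768), dim=-1)
page_vecs = torch.nn.functional.normalize(torch.randn(10, 768), dim=-1)
single_scores = page_vecs @ query_vec  # (10,) cosine similarities

# Multi-vector late interaction (ColNetraEmbed-style): one vector per token;
# each query token takes its best-matching page token (MaxSim), summed over the query.
query_toks = torch.randn(20, 128)      # 20 query-token vectors
page_toks = torch.randn(10, 700, 128)  # 10 pages x 700 patch vectors
late_scores = torch.einsum("qd,pnd->pqn", query_toks, page_toks).max(-1).values.sum(-1)  # (10,)
```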
+ ## Citation
+
+ If you use NetraEmbed or ColNetraEmbed in your research, please cite:
+
+ ```bibtex
+ @misc{kolavi2025m3druniversalmultilingualmultimodal,
+     title={M3DR: Towards Universal Multilingual Multimodal Document Retrieval},
+     author={Adithya S Kolavi and Vyoman Jain},
+     year={2025},
+     eprint={2512.03514},
+     archivePrefix={arXiv},
+     primaryClass={cs.IR},
+     url={https://arxiv.org/abs/2512.03514}
+ }
+ ```
+
+ ## Links
+
+ - 📄 [Paper](https://arxiv.org/abs/2512.03514)
+ - 💻 [GitHub](https://github.com/adithya-s-k/colpali)
+ - 🤗 [Models on Hugging Face](https://huggingface.co/Cognitive-Lab)
+ - 🌐 [CognitiveLab Website](https://www.cognitivelab.in)
+
+ ## Usage
+
+ 1. **Load Model**: Select your preferred model (NetraEmbed, ColNetraEmbed, or Both) and click "Load Model"
+ 2. **Upload PDF**: Upload a PDF document to index
+ 3. **Index Document**: Click "Index Document" to process and embed the pages
+ 4. **Query**: Enter your search query and click "Search" to retrieve relevant pages
+
+ This Space uses ZeroGPU for dynamic GPU allocation. Models are loaded on-demand when functions are called.
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
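The on-demand loading the README describes is what the `spaces` ZeroGPU decorator provides, as the new `app.py` below shows. A minimal sketch of the pattern, with a cached `torch.nn.Linear` standing in for a real model load:

```python
import spaces
import torch

_model = None  # cached at module level so only the first call pays the load cost

@spaces.GPU(duration=120)  # a ZeroGPU device is attached only while this function runs
def embed(batch: torch.Tensor) -> torch.Tensor:
    global _model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if _model is None:
        # stand-in for a real `from_pretrained(...)` load
        _model = torch.nn.Linear(512, 768).to(device)
    with torch.no_grad():
        return _model(batch.to(device))
```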
app.py CHANGED
@@ -1,577 +1,916 @@
- import gradio as gr
- import os
- import lancedb
- from sentence_transformers import SentenceTransformer
- from dotenv import load_dotenv
- from typing import List
- from PIL import Image
- import base64
  import io
- import time
- from collections import namedtuple
- import pandas as pd
- import concurrent.futures
- from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG
- from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
- from varag.chunking import FixedTokenChunker
- from varag.utils import get_model_colpali
- import argparse
- import spaces
  import torch
- from docling.document_converter import DocumentConverter
-
- load_dotenv()
-
- # Initialize shared database
- shared_db = lancedb.connect("~/rag_demo_db")
-
- # Initialize embedding models
- # text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
- text_embedding_model = SentenceTransformer(
-     "BAAI/bge-base-en", trust_remote_code=True
- )
- # text_embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", trust_remote_code=True)
- # text_embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5", trust_remote_code=True)
- image_embedding_model = SentenceTransformer(
-     "jinaai/jina-clip-v1", trust_remote_code=True
- )
- colpali_model, colpali_processor = get_model_colpali("vidore/colpali-v1.2")
-
- converter = DocumentConverter()
-
- # Initialize RAG instances
- simple_rag = SimpleRAG(
-     text_embedding_model=text_embedding_model, db=shared_db, table_name="simpleDemo"
- )
- vision_rag = VisionRAG(
-     image_embedding_model=image_embedding_model, db=shared_db, table_name="visionDemo"
- )
- colpali_rag = ColpaliRAG(
-     colpali_model=colpali_model,
-     colpali_processor=colpali_processor,
-     db=shared_db,
-     table_name="colpaliDemo",
- )
- hybrid_rag = HybridColpaliRAG(
-     colpali_model=colpali_model,
-     colpali_processor=colpali_processor,
-     image_embedding_model=image_embedding_model,
-     db=shared_db,
-     table_name="hybridDemo",
- )
-
-
- IngestResult = namedtuple("IngestResult", ["status_text", "progress_table"])
-
-
- # @spaces.GPU(duration=120)
- # def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
- #     file_paths = [pdf_file.name for pdf_file in pdf_files]
- #     total_start_time = time.time()
- #     progress_data = []
-
- #     # SimpleRAG
- #     yield IngestResult(
- #         status_text="Starting SimpleRAG ingestion...\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     start_time = time.time()
- #     simple_rag.index(
- #         file_paths,
- #         recursive=False,
- #         chunking_strategy=FixedTokenChunker(chunk_size=chunk_size),
- #         metadata={"source": "gradio_upload"},
- #         overwrite=True,
- #         verbose=True,
- #         ocr=use_ocr,
- #     )
- #     simple_time = time.time() - start_time
- #     progress_data.append(
- #         {"Technique": "SimpleRAG", "Time Taken (s)": f"{simple_time:.2f}"}
- #     )
- #     yield IngestResult(
- #         status_text=f"SimpleRAG ingestion complete. Time taken: {simple_time:.2f} seconds\n\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     # progress(0.25, desc="SimpleRAG complete")
-
- #     # VisionRAG
- #     yield IngestResult(
- #         status_text="Starting VisionRAG ingestion...\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     start_time = time.time()
- #     vision_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
- #     vision_time = time.time() - start_time
- #     progress_data.append(
- #         {"Technique": "VisionRAG", "Time Taken (s)": f"{vision_time:.2f}"}
- #     )
- #     yield IngestResult(
- #         status_text=f"VisionRAG ingestion complete. Time taken: {vision_time:.2f} seconds\n\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     # progress(0.5, desc="VisionRAG complete")
-
- #     # ColpaliRAG
- #     yield IngestResult(
- #         status_text="Starting ColpaliRAG ingestion...\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     start_time = time.time()
- #     colpali_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
- #     colpali_time = time.time() - start_time
- #     progress_data.append(
- #         {"Technique": "ColpaliRAG", "Time Taken (s)": f"{colpali_time:.2f}"}
- #     )
- #     yield IngestResult(
- #         status_text=f"ColpaliRAG ingestion complete. Time taken: {colpali_time:.2f} seconds\n\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     # progress(0.75, desc="ColpaliRAG complete")
-
- #     # HybridColpaliRAG
- #     yield IngestResult(
- #         status_text="Starting HybridColpaliRAG ingestion...\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     start_time = time.time()
- #     hybrid_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
- #     hybrid_time = time.time() - start_time
- #     progress_data.append(
- #         {"Technique": "HybridColpaliRAG", "Time Taken (s)": f"{hybrid_time:.2f}"}
- #     )
- #     yield IngestResult(
- #         status_text=f"HybridColpaliRAG ingestion complete. Time taken: {hybrid_time:.2f} seconds\n\n",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
- #     # progress(1.0, desc="HybridColpaliRAG complete")
-
- #     total_time = time.time() - total_start_time
- #     progress_data.append({"Technique": "Total", "Time Taken (s)": f"{total_time:.2f}"})
- #     yield IngestResult(
- #         status_text=f"Total ingestion time: {total_time:.2f} seconds",
- #         progress_table=pd.DataFrame(progress_data),
- #     )
-
-
- def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
-     file_paths = [pdf_file.name for pdf_file in pdf_files]
-     total_start_time = time.time()
-     progress_data = []
-
-     @spaces.GPU(duration=120)
-     def ingest_simple_rag():
-         yield IngestResult(
-             status_text="Starting SimpleRAG ingestion...\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-         start_time = time.time()
-         simple_rag.index(
-             file_paths,
-             recursive=False,
-             chunking_strategy=FixedTokenChunker(chunk_size=chunk_size),
-             metadata={"source": "gradio_upload"},
-             overwrite=True,
-             verbose=True,
-             ocr=use_ocr,
-         )
-         simple_time = time.time() - start_time
-         progress_data.append(
-             {"Technique": "SimpleRAG", "Time Taken (s)": f"{simple_time:.2f}"}
-         )
-         yield IngestResult(
-             status_text=f"SimpleRAG ingestion complete. Time taken: {simple_time:.2f} seconds\n\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-
-     @spaces.GPU(duration=120)
-     def ingest_vision_rag():
-         yield IngestResult(
-             status_text="Starting VisionRAG ingestion...\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-         start_time = time.time()
-         vision_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
-         vision_time = time.time() - start_time
-         progress_data.append(
-             {"Technique": "VisionRAG", "Time Taken (s)": f"{vision_time:.2f}"}
-         )
-         yield IngestResult(
-             status_text=f"VisionRAG ingestion complete. Time taken: {vision_time:.2f} seconds\n\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-
-     @spaces.GPU(duration=120)
-     def ingest_colpali_rag():
-         yield IngestResult(
-             status_text="Starting ColpaliRAG ingestion...\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-         start_time = time.time()
-         colpali_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
-         colpali_time = time.time() - start_time
-         progress_data.append(
-             {"Technique": "ColpaliRAG", "Time Taken (s)": f"{colpali_time:.2f}"}
-         )
-         yield IngestResult(
-             status_text=f"ColpaliRAG ingestion complete. Time taken: {colpali_time:.2f} seconds\n\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-
-     @spaces.GPU(duration=120)
-     def ingest_hybrid_rag():
-         yield IngestResult(
-             status_text="Starting HybridColpaliRAG ingestion...\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-         start_time = time.time()
-         hybrid_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
-         hybrid_time = time.time() - start_time
-         progress_data.append(
-             {"Technique": "HybridColpaliRAG", "Time Taken (s)": f"{hybrid_time:.2f}"}
-         )
-         yield IngestResult(
-             status_text=f"HybridColpaliRAG ingestion complete. Time taken: {hybrid_time:.2f} seconds\n\n",
-             progress_table=pd.DataFrame(progress_data),
-         )
-
-     # Call each ingestion function
-     yield from ingest_simple_rag()
-     yield from ingest_vision_rag()
-     yield from ingest_colpali_rag()
-     yield from ingest_hybrid_rag()
-
-     total_time = time.time() - total_start_time
-     progress_data.append({"Technique": "Total", "Time Taken (s)": f"{total_time:.2f}"})
-     yield IngestResult(
-         status_text=f"Total ingestion time: {total_time:.2f} seconds",
-         progress_table=pd.DataFrame(progress_data),
-     )
-
-
- @spaces.GPU(duration=120)
- def retrieve_data(query, top_k, sequential=False):
-     results = {}
-     timings = {}
-
-     def retrieve_simple():
-         start_time = time.time()
-         simple_results = simple_rag.search(query, k=top_k)
-
-         print(simple_results)
-
-         simple_context = []
-         for i, r in enumerate(simple_results, 1):
-             context_piece = f"Result {i}:\n"
-             context_piece += f"Source: {r.get('document_name', 'Unknown')}\n"
-             context_piece += f"Chunk Index: {r.get('chunk_index', 'Unknown')}\n"
-             context_piece += f"Content:\n{r['text']}\n"
-             context_piece += "-" * 40 + "\n"  # Separator
-             simple_context.append(context_piece)
-
-         simple_context = "\n".join(simple_context)
-         end_time = time.time()
-         return "SimpleRAG", simple_context, end_time - start_time
-
-     def retrieve_vision():
-         start_time = time.time()
-         vision_results = vision_rag.search(query, k=top_k)
-         vision_images = [r["image"] for r in vision_results]
-         end_time = time.time()
-         return "VisionRAG", vision_images, end_time - start_time
-
-     def retrieve_colpali():
-         start_time = time.time()
-         colpali_results = colpali_rag.search(query, k=top_k)
-         colpali_images = [r["image"] for r in colpali_results]
-         end_time = time.time()
-         return "ColpaliRAG", colpali_images, end_time - start_time
-
-     def retrieve_hybrid():
-         start_time = time.time()
-         hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
-         hybrid_images = [r["image"] for r in hybrid_results]
-         end_time = time.time()
-         return "HybridColpaliRAG", hybrid_images, end_time - start_time
-
-     retrieval_functions = [
-         retrieve_simple,
-         retrieve_vision,
-         retrieve_colpali,
-         retrieve_hybrid,
-     ]
-
-     if sequential:
-         for func in retrieval_functions:
-             rag_type, content, timing = func()
-             results[rag_type] = content
-             timings[rag_type] = timing
-     else:
-         with concurrent.futures.ThreadPoolExecutor() as executor:
-             future_results = [executor.submit(func) for func in retrieval_functions]
-             for future in concurrent.futures.as_completed(future_results):
-                 rag_type, content, timing = future.result()
-                 results[rag_type] = content
-                 timings[rag_type] = timing
-
-     return results, timings
-
-
- # @spaces.GPU
- # def query_data(query, retrieved_results):
- #     results = {}
-
- #     # SimpleRAG
- #     simple_context = retrieved_results["SimpleRAG"]
- #     simple_response = llm.query(
- #         context=simple_context,
- #         system_prompt="Given the below information answer the questions",
- #         query=query,
- #     )
- #     results["SimpleRAG"] = {"response": simple_response, "context": simple_context}
-
- #     # VisionRAG
- #     vision_images = retrieved_results["VisionRAG"]
- #     vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
- #         [f"Image {i+1}" for i in range(len(vision_images))]
- #     )
- #     vision_response = vlm.query(vision_context, vision_images, max_tokens=500)
- #     results["VisionRAG"] = {
- #         "response": vision_response,
- #         "context": vision_context,
- #         "images": vision_images,
- #     }
-
- #     # ColpaliRAG
- #     colpali_images = retrieved_results["ColpaliRAG"]
- #     colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
- #         [f"Image {i+1}" for i in range(len(colpali_images))]
- #     )
- #     colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500)
- #     results["ColpaliRAG"] = {
- #         "response": colpali_response,
- #         "context": colpali_context,
- #         "images": colpali_images,
- #     }
-
- #     # HybridColpaliRAG
- #     hybrid_images = retrieved_results["HybridColpaliRAG"]
- #     hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
- #         [f"Image {i+1}" for i in range(len(hybrid_images))]
- #     )
- #     hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500)
- #     results["HybridColpaliRAG"] = {
- #         "response": hybrid_response,
- #         "context": hybrid_context,
- #         "images": hybrid_images,
- #     }
-
- #     return results
-
-
- def update_api_key(api_key):
-     os.environ["OPENAI_API_KEY"] = api_key
-     return "API key updated successfully."
-
-
- def change_table(simple_table, vision_table, colpali_table, hybrid_table):
-     simple_rag.change_table(simple_table)
-     vision_rag.change_table(vision_table)
-     colpali_rag.change_table(colpali_table)
-     hybrid_rag.change_table(hybrid_table)
-     return "Table names updated successfully."
-
-
- def gradio_interface():
-     with gr.Blocks(
-         theme=gr.themes.Monochrome(radius_size=gr.themes.sizes.radius_none)
-     ) as demo:
-         gr.Markdown(
-             """
-             # 👁️👁️ Vision RAG Playground
-
-             ### Explore and Compare Vision-Augmented Retrieval Techniques
-             Built on [VARAG](https://github.com/adithya-s-k/VARAG) - Vision-Augmented Retrieval and Generation
-
-             **[⭐ Star the Repository](https://github.com/adithya-s-k/VARAG)** to support the project!
-
-             1. **Simple RAG**: Text-based retrieval with OCR support for scanned documents.
-             2. **Vision RAG**: Combines text and image retrieval using cross-modal embeddings.
-             3. **ColPali RAG**: Embeds entire document pages as images for layout-aware retrieval.
-             4. **Hybrid ColPali RAG**: Two-stage retrieval combining image embeddings and ColPali's token-level matching.
-             """
-         )
-
-         with gr.Tab("Ingest Data"):
-             gr.Markdown(
-                 """
-                 ## ⚠️ Important Note on Data Ingestion
-
-                 This Space has a maximum GPU-enabled time of 120 seconds. It's recommended to try ingesting only 1 or 2 pdfs at a time.
-
-                 If you want to ingest a larger amount of data, please try it out in a Google Colab notebook:
-
-                 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adithya-s-k/VARAG/blob/main/docs/demo.ipynb)
-                 """
-             )
-             pdf_input = gr.File(
-                 label="Upload PDF(s)", file_count="multiple", file_types=["pdf"]
-             )
-             use_ocr = gr.Checkbox(label="Use OCR (for SimpleRAG)")
-             chunk_size = gr.Slider(
-                 50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)"
-             )
-             ingest_button = gr.Button("Ingest PDFs")
-             ingest_output = gr.Markdown(
-                 label="Ingestion Status :",
-             )
-             progress_table = gr.DataFrame(
-                 label="Ingestion Progress", headers=["Technique", "Time Taken (s)"]
-             )
-
-         with gr.Tab("Retrieve and Query Data"):
-             query_input = gr.Textbox(label="Enter your query")
-             top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
-             sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False)
-             retrieve_button = gr.Button("Retrieve")
-             query_button = gr.Button("Query")
-
-             retrieval_timing = gr.DataFrame(
-                 label="Retrieval Timings", headers=["RAG Type", "Time (s)"]
-             )
-
-             with gr.Row():
-                 with gr.Column():
-                     with gr.Accordion("SimpleRAG", open=True):
-                         simple_content = gr.Textbox(
-                             label="SimpleRAG Content", lines=10, max_lines=10
-                         )
-                         simple_response = gr.Markdown(label="SimpleRAG Response")
-                 with gr.Column():
-                     with gr.Accordion("VisionRAG", open=True):
-                         vision_gallery = gr.Gallery(label="VisionRAG Images")
-                         vision_response = gr.Markdown(label="VisionRAG Response")
-
-             with gr.Row():
-                 with gr.Column():
-                     with gr.Accordion("ColpaliRAG", open=True):
-                         colpali_gallery = gr.Gallery(label="ColpaliRAG Images")
-                         colpali_response = gr.Markdown(label="ColpaliRAG Response")
-                 with gr.Column():
-                     with gr.Accordion("HybridColpaliRAG", open=True):
-                         hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images")
-                         hybrid_response = gr.Markdown(label="HybridColpaliRAG Response")
-
-         with gr.Tab("Settings"):
-             api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
-             update_api_button = gr.Button("Update API Key")
-             api_update_status = gr.Textbox(label="API Update Status")
-
-             simple_table_input = gr.Textbox(
-                 label="SimpleRAG Table Name", value="simpleDemo"
-             )
-             vision_table_input = gr.Textbox(
-                 label="VisionRAG Table Name", value="visionDemo"
-             )
-             colpali_table_input = gr.Textbox(
-                 label="ColpaliRAG Table Name", value="colpaliDemo"
-             )
-             hybrid_table_input = gr.Textbox(
-                 label="HybridColpaliRAG Table Name", value="hybridDemo"
-             )
-             update_table_button = gr.Button("Update Table Names")
-             table_update_status = gr.Textbox(label="Table Update Status")
-
-         retrieved_results = gr.State({})
-
-         def update_retrieval_results(query, top_k, sequential):
-             results, timings = retrieve_data(query, top_k, sequential)
-             timing_df = pd.DataFrame(
-                 list(timings.items()), columns=["RAG Type", "Time (s)"]
-             )
-             return (
-                 results["SimpleRAG"],
-                 results["VisionRAG"],
-                 results["ColpaliRAG"],
-                 results["HybridColpaliRAG"],
-                 timing_df,
-                 results,
-             )
-
-         retrieve_button.click(
-             update_retrieval_results,
-             inputs=[query_input, top_k_slider, sequential_checkbox],
-             outputs=[
-                 simple_content,
-                 vision_gallery,
-                 colpali_gallery,
-                 hybrid_gallery,
-                 retrieval_timing,
-                 retrieved_results,
-             ],
-         )
-
-         # def update_query_results(query, retrieved_results):
-         #     results = query_data(query, retrieved_results)
-         #     return (
-         #         results["SimpleRAG"]["response"],
-         #         results["VisionRAG"]["response"],
-         #         results["ColpaliRAG"]["response"],
-         #         results["HybridColpaliRAG"]["response"],
-         #     )
-
-         # query_button.click(
-         #     update_query_results,
-         #     inputs=[query_input, retrieved_results],
-         #     outputs=[
-         #         simple_response,
-         #         vision_response,
-         #         colpali_response,
-         #         hybrid_response,
-         #     ],
-         # )
-
-         ingest_button.click(
-             ingest_data,
-             inputs=[pdf_input, use_ocr, chunk_size],
-             outputs=[ingest_output, progress_table],
-         )
-
-         update_api_button.click(
-             update_api_key, inputs=[api_key_input], outputs=api_update_status
-         )
-
-         update_table_button.click(
-             change_table,
-             inputs=[
-                 simple_table_input,
-                 vision_table_input,
-                 colpali_table_input,
-                 hybrid_table_input,
-             ],
-             outputs=table_update_status,
-         )
-
-     return demo
-
-
- # Parse command-line arguments
- def parse_args():
-     parser = argparse.ArgumentParser(description="VisionRAG Gradio App")
-     parser.add_argument(
-         "--share", action="store_true", help="Enable Gradio share feature"
-     )
-     return parser.parse_args()
-
  # Launch the app
  if __name__ == "__main__":
-     args = parse_args()
-     app = gradio_interface()
-     app.launch(share=args.share)
+ """
+ Gradio Demo for Document Retrieval - Hugging Face Spaces with ZeroGPU
+
+ This script creates a Gradio interface for testing both BiGemma3 and ColGemma3 models
+ with PDF document upload, automatic conversion to images, and query-based retrieval.
+
+ Features:
+ - PDF upload with automatic conversion to images
+ - Model selection: NetraEmbed (BiGemma3), ColNetraEmbed (ColGemma3), or Both
+ - Query input with top-k selection (default: 5)
+ - Similarity score display
+ - Side-by-side comparison when both models are selected
+ - Progressive loading with real-time updates
+ - Proper error handling
+ - ZeroGPU integration for efficient GPU usage
+ """
+
  import io
+ import gc
+ import math
+ from typing import Iterator, List, Optional, Tuple
+
+ import gradio as gr
  import torch
+ import spaces
+ from pdf2image import convert_from_path
+ from PIL import Image
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import seaborn as sns
+ from einops import rearrange
+
+ # Import from colpali_engine
+ from colpali_engine.models import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGemmaProcessor3
+ from colpali_engine.interpretability import get_similarity_maps_from_embeddings
+ from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map
+
+ # Configuration
+ MAX_BATCH_SIZE = 32  # Maximum pages to process at once
+ DEFAULT_DURATION = 120  # Default GPU duration in seconds
+
+ # Global state for models and indexed documents
+ class DocumentIndex:
+     def __init__(self):
+         self.images: List[Image.Image] = []
+         self.bigemma_embeddings = None
+         self.colgemma_embeddings = None
+         self.bigemma_model = None
+         self.bigemma_processor = None
+         self.colgemma_model = None
+         self.colgemma_processor = None
+         self.models_loaded = {"bigemma": False, "colgemma": False}
+
+ doc_index = DocumentIndex()
+
+ # Helper functions
+ def get_loaded_models() -> List[str]:
+     """Get list of currently loaded models."""
+     loaded = []
+     if doc_index.bigemma_model is not None:
+         loaded.append("BiGemma3")
+     if doc_index.colgemma_model is not None:
+         loaded.append("ColGemma3")
+     return loaded
+
+ def get_model_choice_from_loaded() -> str:
+     """Determine model choice string based on what's loaded."""
+     loaded = get_loaded_models()
+     if "BiGemma3" in loaded and "ColGemma3" in loaded:
+         return "Both"
+     elif "BiGemma3" in loaded:
+         return "NetraEmbed (BiGemma3)"
+     elif "ColGemma3" in loaded:
+         return "ColNetraEmbed (ColGemma3)"
+     else:
+         return ""
+
+ @spaces.GPU(duration=DEFAULT_DURATION)
+ def load_bigemma_model():
+     """Load BiGemma3 model and processor."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if doc_index.bigemma_model is None:
+         print("Loading BiGemma3 (NetraEmbed)...")
+         try:
+             doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
+                 "Cognitive-Lab/NetraEmbed",
+                 use_fast=True,
+             )
+             doc_index.bigemma_model = BiGemma3.from_pretrained(
+                 "Cognitive-Lab/NetraEmbed",
+                 torch_dtype=torch.bfloat16,
+                 device_map=device,
+             )
+             doc_index.bigemma_model.eval()
+             doc_index.models_loaded["bigemma"] = True
+             print("✓ BiGemma3 loaded successfully")
+         except Exception as e:
+             print(f"❌ Failed to load BiGemma3: {str(e)}")
+             raise
+     return doc_index.bigemma_model, doc_index.bigemma_processor
+
+ @spaces.GPU(duration=DEFAULT_DURATION)
+ def load_colgemma_model():
+     """Load ColGemma3 model and processor."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     if doc_index.colgemma_model is None:
+         print("Loading ColGemma3 (ColNetraEmbed)...")
+         try:
+             doc_index.colgemma_model = ColGemma3.from_pretrained(
+                 "Cognitive-Lab/ColNetraEmbed",
+                 dtype=torch.bfloat16,
+                 device_map=device,
+             )
+             doc_index.colgemma_model.eval()
+             doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
+                 "Cognitive-Lab/ColNetraEmbed",
+                 use_fast=True,
+             )
+             doc_index.models_loaded["colgemma"] = True
+             print("✓ ColGemma3 loaded successfully")
+         except Exception as e:
+             print(f"❌ Failed to load ColGemma3: {str(e)}")
+             raise
+     return doc_index.colgemma_model, doc_index.colgemma_processor
+
+ def unload_models():
+     """Unload models and free GPU memory."""
+     try:
+         if doc_index.bigemma_model is not None:
+             del doc_index.bigemma_model
+             del doc_index.bigemma_processor
+             doc_index.bigemma_model = None
+             doc_index.bigemma_processor = None
+             doc_index.models_loaded["bigemma"] = False
+
+         if doc_index.colgemma_model is not None:
+             del doc_index.colgemma_model
+             del doc_index.colgemma_processor
+             doc_index.colgemma_model = None
+             doc_index.colgemma_processor = None
+             doc_index.models_loaded["colgemma"] = False
+
+         # Clear embeddings and images
+         doc_index.bigemma_embeddings = None
+         doc_index.colgemma_embeddings = None
+         doc_index.images = []
+
+         # Force garbage collection
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.synchronize()
+
+         return "✅ Models unloaded and GPU memory cleared"
+     except Exception as e:
+         return f"❌ Error unloading models: {str(e)}"
+
+ def clear_incompatible_embeddings(model_choice: str) -> str:
+     """Clear embeddings that are incompatible with currently loading models."""
+     cleared = []
+
+     # If loading only BiGemma3, clear ColGemma3 embeddings
+     if model_choice == "NetraEmbed (BiGemma3)":
+         if doc_index.colgemma_embeddings is not None:
+             doc_index.colgemma_embeddings = None
+             doc_index.images = []
+             cleared.append("ColGemma3")
+             print("Cleared ColGemma3 embeddings")
+
+     # If loading only ColGemma3, clear BiGemma3 embeddings
+     elif model_choice == "ColNetraEmbed (ColGemma3)":
+         if doc_index.bigemma_embeddings is not None:
+             doc_index.bigemma_embeddings = None
+             doc_index.images = []
+             cleared.append("BiGemma3")
+             print("Cleared BiGemma3 embeddings")
+
+     if cleared:
+         return f"Cleared {', '.join(cleared)} embeddings - please re-index"
+     return ""
+
+ def pdf_to_images(pdf_path: str) -> List[Image.Image]:
+     """Convert PDF to list of PIL Images with error handling."""
+     try:
+         print(f"Converting PDF to images: {pdf_path}")
+         images = convert_from_path(pdf_path, dpi=200)
+         print(f"Converted {len(images)} pages")
+         return images
+     except Exception as e:
+         print(f"❌ PDF conversion error: {str(e)}")
+         raise Exception(f"Failed to convert PDF: {str(e)}")
+
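Worth noting for this function: `pdf2image.convert_from_path` shells out to Poppler's `pdftoppm`/`pdftocairo` binaries, which pip cannot install; on Spaces they usually come from a `packages.txt` entry such as `poppler-utils`. A small pre-flight check, as a sketch:

```python
import shutil

def poppler_available() -> bool:
    """pdf2image needs Poppler binaries on PATH; fail early with a clear message."""
    return shutil.which("pdftoppm") is not None or shutil.which("pdftocairo") is not None

if not poppler_available():
    print("Poppler not found - install it (e.g. add poppler-utils to packages.txt on Spaces)")
```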
+ @spaces.GPU(duration=DEFAULT_DURATION)
+ def generate_colgemma_heatmap(
+     image: Image.Image,
+     query: str,
+     query_embedding: torch.Tensor,
+     image_embedding: torch.Tensor,
+     model,
+     processor,
+ ) -> Image.Image:
+     """Generate heatmap overlay for ColGemma3 results."""
+     try:
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Re-process the single image to get the proper batch_images dict for image mask
+         batch_images = processor.process_images([image]).to(device)
+
+         # Create image mask manually (ColGemmaProcessor3 doesn't have get_image_mask)
+         if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
+             image_token_id = model.config.image_token_id
+             image_mask = batch_images["input_ids"] == image_token_id
+         else:
+             # Fallback: all tokens are image tokens
+             image_mask = torch.ones(
+                 image_embedding.shape[0], image_embedding.shape[1], dtype=torch.bool, device=device
+             )
+
+         # Calculate n_patches from actual number of image tokens
+         num_image_tokens = image_mask.sum().item()
+         n_side = int(math.sqrt(num_image_tokens))
+
+         if n_side * n_side == num_image_tokens:
+             n_patches = (n_side, n_side)
+         else:
+             # Fallback: use default calculation
+             n_patches = (16, 16)
+
+         # Generate similarity maps (returns a list of tensors)
+         similarity_maps_list = get_similarity_maps_from_embeddings(
+             image_embeddings=image_embedding,
+             query_embeddings=query_embedding,
+             n_patches=n_patches,
+             image_mask=image_mask,
+         )
+
+         # Get the similarity map for our image (returns a list, get first element)
+         similarity_map = similarity_maps_list[0]  # (query_length, n_patches_x, n_patches_y)
+
+         # Aggregate across all query tokens (mean)
+         if similarity_map.dtype == torch.bfloat16:
+             similarity_map = similarity_map.float()
+         aggregated_map = torch.mean(similarity_map, dim=0)
+
+         # Convert the image to an array
+         img_array = np.array(image.convert("RGBA"))
+
+         # Normalize the similarity map and convert to numpy
+         similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()
+
+         # Reshape to match PIL convention
+         similarity_map_array = rearrange(similarity_map_array, "h w -> w h")
+
+         # Create PIL image from similarity map
+         similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
+             image.size, Image.Resampling.BICUBIC
+         )
+
+         # Create matplotlib figure
+         fig, ax = plt.subplots(figsize=(10, 10))
+         ax.imshow(img_array)
+         ax.imshow(
+             similarity_map_image,
+             cmap=sns.color_palette("mako", as_cmap=True),
+             alpha=0.5,
+         )
+         ax.set_axis_off()
+         plt.tight_layout()
+
+         # Convert to PIL Image
+         buffer = io.BytesIO()
+         plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight", pad_inches=0)
+         buffer.seek(0)
+         heatmap_image = Image.open(buffer).copy()
+         plt.close()
+
+         return heatmap_image
+
+     except Exception as e:
+         print(f"❌ Heatmap generation error: {str(e)}")
+         # Return original image if heatmap generation fails
+         return image
+
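The similarity map built above is per query token, shape `(query_length, n_patches_x, n_patches_y)`, and the app averages over tokens before overlaying. Individual token maps can also be inspected, for example to see which query word drives the match; a sketch of that (not app behavior):

```python
import torch

def per_token_peaks(similarity_map: torch.Tensor) -> torch.Tensor:
    """similarity_map: (query_length, n_patches_x, n_patches_y).
    Returns each query token's strongest patch activation."""
    return similarity_map.float().amax(dim=(1, 2))  # (query_length,)
```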
286
+ @spaces.GPU(duration=DEFAULT_DURATION)
287
+ def index_bigemma_images(images: List[Image.Image]) -> torch.Tensor:
288
+ """Index images with BiGemma3 model."""
289
+ device = "cuda" if torch.cuda.is_available() else "cpu"
290
+ model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
291
+
292
+ batch_images = processor.process_images(images).to(device)
293
+ embeddings = model(**batch_images, embedding_dim=768)
294
+
295
+ return embeddings
296
+
297
+ @spaces.GPU(duration=DEFAULT_DURATION)
298
+ def index_colgemma_images(images: List[Image.Image]) -> torch.Tensor:
299
+ """Index images with ColGemma3 model."""
300
+ device = "cuda" if torch.cuda.is_available() else "cpu"
301
+ model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
302
+
303
+ batch_images = processor.process_images(images).to(device)
304
+ embeddings = model(**batch_images)
305
+
306
+ return embeddings
307
+
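Note that `MAX_BATCH_SIZE` only drives a warning in `index_document` below; both indexing functions still encode every page in one call. If chunked encoding were wanted, a sketch along these lines would fit the same API (assuming equal per-page token counts so the chunks concatenate):

```python
import torch

def index_images_in_batches(images, encode_fn, batch_size=32):
    """Encode pages in chunks (e.g. MAX_BATCH_SIZE) to bound peak GPU memory."""
    chunks = [
        encode_fn(images[i : i + batch_size])
        for i in range(0, len(images), batch_size)
    ]
    # For BiGemma3 this is a plain (pages, dim) concat; for ColGemma3 the
    # per-page token counts must match across chunks for torch.cat to apply.
    return torch.cat(chunks, dim=0)
```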
+ def index_document(pdf_file, model_choice: str) -> Iterator[str]:
+     """Upload and index a PDF document with progress updates."""
+     if pdf_file is None:
+         yield "⚠️ Please upload a PDF document first."
+         return
+
+     try:
+         status_messages = []
+
+         # Convert PDF to images
+         status_messages.append("⏳ Converting PDF to images...")
+         yield "\n".join(status_messages)
+
+         doc_index.images = pdf_to_images(pdf_file.name)
+         num_pages = len(doc_index.images)
+
+         status_messages.append(f"✓ Converted PDF to {num_pages} images")
+
+         # Check if we need to batch process
+         if num_pages > MAX_BATCH_SIZE:
+             status_messages.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches of {MAX_BATCH_SIZE}...")
+             yield "\n".join(status_messages)
+
+         # Index with BiGemma3
+         if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
+             if doc_index.bigemma_model is None:
+                 status_messages.append("⏳ Loading BiGemma3 model...")
+                 yield "\n".join(status_messages)
+                 load_bigemma_model()
+                 status_messages.append("✓ BiGemma3 loaded")
+             else:
+                 status_messages.append("✓ Using cached BiGemma3 model")
+
+             yield "\n".join(status_messages)
+
+             status_messages.append("⏳ Encoding images with BiGemma3...")
+             yield "\n".join(status_messages)
+
+             doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
+
+             status_messages.append("✓ Indexed with BiGemma3 (shape: {})".format(doc_index.bigemma_embeddings.shape))
+             yield "\n".join(status_messages)
+
+         # Index with ColGemma3
+         if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
+             if doc_index.colgemma_model is None:
+                 status_messages.append("⏳ Loading ColGemma3 model...")
+                 yield "\n".join(status_messages)
+                 load_colgemma_model()
+                 status_messages.append("✓ ColGemma3 loaded")
+             else:
+                 status_messages.append("✓ Using cached ColGemma3 model")
+
+             yield "\n".join(status_messages)
+
+             status_messages.append("⏳ Encoding images with ColGemma3...")
+             yield "\n".join(status_messages)
+
+             doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
+
+             status_messages.append(
+                 "✓ Indexed with ColGemma3 (shape: {})".format(doc_index.colgemma_embeddings.shape)
+             )
+             yield "\n".join(status_messages)
+
+         final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
+         yield final_status
+
+     except Exception as e:
+         import traceback
+
+         error_details = traceback.format_exc()
+         print(f"Indexing error: {error_details}")
+         yield f"❌ Error indexing document: {str(e)}"
+
+ @spaces.GPU(duration=DEFAULT_DURATION)
+ def query_bigemma(query: str, top_k: int) -> Tuple[str, List]:
+     """Query indexed documents with BiGemma3."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
+
+     # Encode query
+     batch_query = processor.process_texts([query]).to(device)
+     query_embedding = model(**batch_query, embedding_dim=768)
+
+     # Compute scores (cosine similarity)
+     scores = processor.score(
+         qs=query_embedding,
+         ps=doc_index.bigemma_embeddings,
+     )
+
+     # Get top-k results
+     top_k_actual = min(top_k, len(doc_index.images))
+     top_indices = scores[0].argsort(descending=True)[:top_k_actual]
+
+     # Format results
+     results_text = "### BiGemma3 (NetraEmbed) Results\n\n"
+     gallery_images = []
+
+     for rank, idx in enumerate(top_indices):
+         score = scores[0, idx].item()
+         results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
+         gallery_images.append(
+             (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
+         )
+
+     return results_text, gallery_images
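The `embedding_dim=768` argument leans on the Matryoshka property the README mentions: a truncated prefix of the full embedding remains usable after re-normalization. Illustratively (the full width of 2048 here is hypothetical, not the model's actual size):

```python
import torch
import torch.nn.functional as F

def truncate_matryoshka(emb: torch.Tensor, dim: int) -> torch.Tensor:
    """Keep the first `dim` dimensions and re-normalize so cosine scores stay valid."""
    return F.normalize(emb[..., :dim], dim=-1)

docs = truncate_matryoshka(torch.randn(10, 2048), 768)   # cheaper page vectors
query = truncate_matryoshka(torch.randn(1, 2048), 768)
scores = query @ docs.T  # (1, 10) cosine similarities, the regime processor.score uses
```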
 
 
 
+ @spaces.GPU(duration=DEFAULT_DURATION)
+ def query_colgemma(query: str, top_k: int, show_heatmap: bool = False) -> Tuple[str, List]:
+     """Query indexed documents with ColGemma3."""
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
+
+     # Encode query
+     batch_query = processor.process_queries([query]).to(device)
+     query_embedding = model(**batch_query)
+
+     # Compute scores (MaxSim)
+     scores = processor.score_multi_vector(
+         qs=query_embedding,
+         ps=doc_index.colgemma_embeddings,
+     )
+
+     # Get top-k results
+     top_k_actual = min(top_k, len(doc_index.images))
+     top_indices = scores[0].argsort(descending=True)[:top_k_actual]
+
+     # Format results
+     results_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
+     gallery_images = []
+
+     for rank, idx in enumerate(top_indices):
+         score = scores[0, idx].item()
+         results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
+
+         # Generate heatmap if requested
+         if show_heatmap:
+             heatmap_image = generate_colgemma_heatmap(
+                 image=doc_index.images[idx.item()],
+                 query=query,
+                 query_embedding=query_embedding,
+                 image_embedding=doc_index.colgemma_embeddings[idx.item()].unsqueeze(0),
+                 model=model,
+                 processor=processor,
+             )
+             gallery_images.append(
+                 (heatmap_image, f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})")
+             )
+         else:
+             gallery_images.append(
+                 (
+                     doc_index.images[idx.item()],
+                     f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})",
+                 )
+             )
+
+     return results_text, gallery_images
+
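`score_multi_vector` computes the late-interaction (MaxSim) score the comment names. For same-length, unpadded embeddings, a minimal equivalent looks like this sketch (the library call additionally handles batching and padding):

```python
import torch

def maxsim(qs: torch.Tensor, ps: torch.Tensor) -> torch.Tensor:
    """qs: (num_queries, q_tokens, dim); ps: (num_pages, p_tokens, dim).
    Each query token keeps its best-matching page token; scores sum over query tokens."""
    sim = torch.einsum("bqd,pnd->bpqn", qs, ps)
    return sim.max(dim=-1).values.sum(dim=-1)  # (num_queries, num_pages)
```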
+ def query_documents(
+     query: str, model_choice: str, top_k: int, show_heatmap: bool = False
+ ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[List]]:
+     """Query the indexed documents."""
+     if not doc_index.images:
+         return "⚠️ Please upload and index a document first.", None, None, None
+
+     if not query.strip():
+         return "⚠️ Please enter a query.", None, None, None
+
+     try:
+         results_bi = None
+         results_col = None
+         gallery_images_bi = []
+         gallery_images_col = []
+
+         # Query with BiGemma3
+         if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
+             if doc_index.bigemma_embeddings is None:
+                 return "⚠️ Please index the document with BiGemma3 first.", None, None, None
+
+             results_bi, gallery_images_bi = query_bigemma(query, top_k)
+
+         # Query with ColGemma3
+         if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
+             if doc_index.colgemma_embeddings is None:
+                 return "⚠️ Please index the document with ColGemma3 first.", None, None, None
+
+             results_col, gallery_images_col = query_colgemma(query, top_k, show_heatmap)
+
+         # Return results based on model choice
+         if model_choice == "NetraEmbed (BiGemma3)":
+             return results_bi, None, gallery_images_bi, None
+         elif model_choice == "ColNetraEmbed (ColGemma3)":
+             return results_col, None, None, gallery_images_col
+         else:  # Both
+             return results_bi, results_col, gallery_images_bi, gallery_images_col
+
+     except Exception as e:
+         import traceback
+
+         error_details = traceback.format_exc()
+         print(f"Query error: {error_details}")
+         return f"❌ Error during query: {str(e)}", None, None, None
+
+ def load_models_with_progress(model_choice: str) -> Iterator[Tuple]:
+     """Load models with progress updates."""
+     if not model_choice:
+         yield (
+             "❌ Please select a model first.",
+             gr.update(visible=True),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(value="Load model first"),
+         )
+         return
+
+     try:
+         status_messages = []
+
+         # Clear incompatible embeddings
+         clear_msg = clear_incompatible_embeddings(model_choice)
+         if clear_msg:
+             status_messages.append(f"⚠️ {clear_msg}")
+
+         # Load BiGemma3
+         if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
+             status_messages.append("⏳ Loading BiGemma3 (NetraEmbed)...")
+             yield (
+                 "\n".join(status_messages),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(value="Loading models..."),
+             )
+
+             load_bigemma_model()
+             status_messages[-1] = "✅ BiGemma3 loaded successfully"
+             yield (
+                 "\n".join(status_messages),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(value="Loading models..."),
+             )
+
+         # Load ColGemma3
+         if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
+             status_messages.append("⏳ Loading ColGemma3 (ColNetraEmbed)...")
+             yield (
+                 "\n".join(status_messages),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(value="Loading models..."),
+             )
+
+             load_colgemma_model()
+             status_messages[-1] = "✅ ColGemma3 loaded successfully"
+             yield (
+                 "\n".join(status_messages),
+                 gr.update(visible=True),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(visible=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(interactive=False),
+                 gr.update(value="Loading models..."),
+             )
+
+         # Determine column visibility based on loaded models
+         show_bigemma = model_choice in ["NetraEmbed (BiGemma3)", "Both"]
+         show_colgemma = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
+         show_heatmap_checkbox = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
+
+         final_status = "\n".join(status_messages) + "\n\n✅ Ready!"
+         yield (
+             final_status,
+             gr.update(visible=False),
+             gr.update(visible=True),
+             gr.update(visible=show_bigemma),
+             gr.update(visible=show_colgemma),
+             gr.update(visible=show_heatmap_checkbox),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(interactive=True),
+             gr.update(value="Ready to index"),
+         )
+
+     except Exception as e:
+         import traceback
+
+         error_details = traceback.format_exc()
+         print(f"Model loading error: {error_details}")
+         yield (
+             f"❌ Failed to load models: {str(e)}",
+             gr.update(visible=True),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(visible=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(interactive=False),
+             gr.update(value="Load model first"),
+         )
+
+ def unload_models_and_hide_ui():
+     """Unload models and hide main UI."""
+     status = unload_models()
+     return (
+         status,
+         gr.update(visible=True),
+         gr.update(visible=False),
+         gr.update(visible=False),
+         gr.update(visible=False),
+         gr.update(visible=False),
+         gr.update(interactive=False),
+         gr.update(interactive=False),
+         gr.update(interactive=False),
+         gr.update(interactive=False),
+         gr.update(interactive=False),
+         gr.update(value="Load model first"),
+     )
+
+ # Create Gradio interface
+ with gr.Blocks(
+     title="NetraEmbed Demo",
+ ) as demo:
+     # Header section with model info and banner
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("# NetraEmbed")
+             gr.HTML(
+                 """
+                 <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
+                     <a href="https://arxiv.org/abs/2512.03514" target="_blank">
+                         <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
+                     </a>
+                     <a href="https://github.com/adithya-s-k/colpali" target="_blank">
+                         <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
+                     </a>
+                     <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
+                         <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
+                     </a>
+                     <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
+                         <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
+                     </a>
+                     <a href="https://cloud.cognitivelab.in" target="_blank">
+                         <img src="https://img.shields.io/badge/Demo-Try%20it%20out-green" alt="Demo">
+                     </a>
+                 </div>
+                 """
+             )
+             gr.Markdown(
+                 """
+                 **🚀 Universal Multilingual Multimodal Document Retrieval**
+
+                 Upload a PDF document, select your model(s), and query using semantic search.
+
+                 **Available Models:**
+                 - **NetraEmbed (BiGemma3)**: Single-vector embedding with Matryoshka representation
+                   Fast retrieval with cosine similarity
+                 - **ColNetraEmbed (ColGemma3)**: Multi-vector embedding with late interaction
+                   High-quality retrieval with MaxSim scoring and attention heatmaps
+                 """
+             )
+
+         with gr.Column(scale=1):
+             gr.HTML(
+                 """
+                 <div style="text-align: center;">
+                     <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
+                          alt="NetraEmbed Banner"
+                          style="width: 100%; height: auto; border-radius: 8px;">
+                 </div>
+                 """
+             )
+
+     gr.Markdown("---")
+
+     # Compact 3-column layout
+     with gr.Row():
+         # Column 1: Model Management
+         with gr.Column(scale=1):
+             gr.Markdown("### 🤖 Model Management")
+             model_select = gr.Radio(
+                 choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
+                 value="Both",
+                 label="Select Model(s)",
+             )
+
+             load_model_btn = gr.Button("🔄 Load Model", variant="primary", size="sm")
+             unload_model_btn = gr.Button("🗑️ Unload", variant="secondary", size="sm")
+
+             model_status = gr.Textbox(
+                 label="Status",
+                 lines=6,
+                 interactive=False,
+                 value="Select and load a model",
+             )
+
+             loading_info = gr.Markdown(
+                 """
+                 **First load:** 2-3 min
+                 **Cached:** ~30 sec
+                 """,
+                 visible=True,
+             )
+
+         # Column 2: Document Upload & Indexing
+         with gr.Column(scale=1):
+             gr.Markdown("### 📄 Upload & Index")
+             pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"], interactive=False)
+             index_btn = gr.Button("📥 Index Document", variant="primary", size="sm", interactive=False)
+
+             index_status = gr.Textbox(
+                 label="Indexing Status",
+                 lines=6,
+                 interactive=False,
+                 value="Load model first",
+             )
+
+         # Column 3: Query
+         with gr.Column(scale=1):
+             gr.Markdown("### 🔎 Query Document")
+             query_input = gr.Textbox(
+                 label="Enter Query",
+                 placeholder="e.g., financial report, organizational structure...",
+                 lines=2,
+                 interactive=False,
+             )
+
+             with gr.Row():
+                 top_k_slider = gr.Slider(
+                     minimum=1,
+                     maximum=10,
+                     value=5,
+                     step=1,
+                     label="Top K",
+                     scale=2,
+                     interactive=False,
+                 )
+                 heatmap_checkbox = gr.Checkbox(
+                     label="Heatmaps",
+                     value=False,
+                     visible=False,
+                     scale=1,
+                 )
+
+             query_btn = gr.Button("🔍 Search", variant="primary", size="sm", interactive=False)
+
+     gr.Markdown("---")
+
+     # Results section (always visible after model load)
+     with gr.Column(visible=False) as main_interface:
+         gr.Markdown("### 📊 Results")
+
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=1, visible=False) as bigemma_column:
+                 bigemma_results = gr.Markdown(
+                     value="*BiGemma3 results will appear here...*",
+                 )
+                 bigemma_gallery = gr.Gallery(
+                     label="BiGemma3 - Top Retrieved Pages",
+                     show_label=True,
+                     columns=2,
+                     height="auto",
+                     object_fit="contain",
+                 )
+             with gr.Column(scale=1, visible=False) as colgemma_column:
+                 colgemma_results = gr.Markdown(
+                     value="*ColGemma3 results will appear here...*",
+                 )
+                 colgemma_gallery = gr.Gallery(
+                     label="ColGemma3 - Top Retrieved Pages",
+                     show_label=True,
+                     columns=2,
+                     height="auto",
+                     object_fit="contain",
+                 )
+
+     # Tips
+     with gr.Accordion("💡 Tips", open=False):
+         gr.Markdown(
+             """
+             - **Both models**: Compare results side-by-side
+             - **Scores**: BiGemma3 uses cosine similarity (-1 to 1), ColGemma3 uses MaxSim (higher is better)
+             - **Heatmaps**: Enable to visualize ColGemma3 attention patterns (brighter = higher attention)
+             """
+         )
+
+     # Event handlers - Model Management
+     load_model_btn.click(
+         fn=load_models_with_progress,
+         inputs=[model_select],
+         outputs=[
+             model_status,
+             loading_info,
+             main_interface,
+             bigemma_column,
+             colgemma_column,
+             heatmap_checkbox,
+             pdf_upload,
+             index_btn,
+             query_input,
+             top_k_slider,
+             query_btn,
+             index_status,
+         ],
+     )
+
+     unload_model_btn.click(
+         fn=unload_models_and_hide_ui,
+         outputs=[
+             model_status,
+             loading_info,
+             main_interface,
+             bigemma_column,
+             colgemma_column,
+             heatmap_checkbox,
+             pdf_upload,
+             index_btn,
+             query_input,
+             top_k_slider,
+             query_btn,
+             index_status,
+         ],
+     )
+
+     # Event handlers - Main Interface
+     def index_with_current_models(pdf_file):
+         """Index document with currently loaded models."""
+         if pdf_file is None:
+             yield "⚠️ Please upload a PDF document first."
+             return
+
+         model_choice = get_model_choice_from_loaded()
+         if not model_choice:
+             yield "⚠️ No models loaded. Please load a model first."
+             return
+
+         # Use generator from index_document
+         for status in index_document(pdf_file, model_choice):
+             yield status
+
+     def query_with_current_models(query, top_k, show_heatmap):
+         """Query with currently loaded models."""
+         model_choice = get_model_choice_from_loaded()
+         if not model_choice:
+             return "⚠️ No models loaded. Please load a model first.", None, None, None
+
+         return query_documents(query, model_choice, top_k, show_heatmap)
+
+     index_btn.click(
+         fn=index_with_current_models,
+         inputs=[pdf_upload],
+         outputs=[index_status],
+     )
+
+     query_btn.click(
+         fn=query_with_current_models,
+         inputs=[query_input, top_k_slider, heatmap_checkbox],
+         outputs=[bigemma_results, colgemma_results, bigemma_gallery, colgemma_gallery],
+     )
+
+ # Enable queue for handling multiple requests
+ demo.queue(max_size=20)
+
  # Launch the app
  if __name__ == "__main__":
+     demo.launch()
requirements.txt CHANGED
@@ -1,17 +1,12 @@
- git+https://github.com/huggingface/transformers
- torch
- lancedb
- colpali-engine
- pypdf
- pymupdf
- timm
- sentence-transformers
- tiktoken
- docling==1.16.1
- pdf2image
- GPUtil
- accelerate==0.30.1
- git+https://github.com/adithya-s-k/VARAG
+ spaces
+ git+https://github.com/adithya-s-k/colpali.git
+ gradio
  pdf2image
+ Pillow
+ matplotlib
+ seaborn
  einops
+ numpy
+ torch
  torchvision
+ accelerate
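A quick way to confirm the trimmed dependency set resolves is an import smoke test; note that pip names differ from module names (`Pillow` imports as `PIL`, and the colpali checkout provides `colpali_engine`, which `app.py` above imports). A sketch:

```python
import importlib

# pip names vs import names: Pillow -> PIL; the colpali git dependency -> colpali_engine
for mod in ["gradio", "spaces", "torch", "torchvision", "pdf2image", "PIL",
            "matplotlib", "seaborn", "einops", "numpy", "colpali_engine", "accelerate"]:
    importlib.import_module(mod)
    print(f"ok: {mod}")
```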