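"""Gradio demo that streams image+text chat responses from SmolVLM2."""
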
import html
import time
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TextIteratorStreamer,
)
from transformers.image_utils import load_image


def progress_bar_html(label: str) -> str:
    """
    Return an HTML snippet showing a label next to a thin, indeterminate
    progress bar animated with a CSS keyframe.
    """
return f"""
<div style="display: flex; align-items: center;">
<span style="margin-right: 10px; font-size: 14px;">{label}</span>
<div style="width: 110px; height: 5px; background-color: #9370DB; border-radius: 2px; overflow: hidden;">
<div style="width: 100%; height: 100%; background-color: #4B0082; animation: loading 1.5s linear infinite;"></div>
</div>
</div>
<style>
@keyframes loading {{
0% {{ transform: translateX(-100%); }}
100% {{ transform: translateX(100%); }}
}}
</style>
"""
model_name = "HuggingFaceTB/SmolVLM2-256M-Video-Instruct"
model = AutoModelForImageTextToText.from_pretrained(
model_name, dtype=torch.float32, device_map="auto"
).eval()
processor = AutoProcessor.from_pretrained(model_name)
print(f"Successfully load the model: {model}")
def model_inference(input_dict, history):
text = input_dict["text"]
files = input_dict["files"]
    images = [load_image(image) for image in files]
if text == "" and not images:
gr.Error("Please input a query and optionally image(s).")
return
if text == "" and images:
gr.Error("Please input a text query along with the image(s).")
return
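    # Build a single-turn chat message: any images first, then the text prompt.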
messages = [
{
"role": "user",
"content": [
*[{"type": "image", "image": image} for image in images],
{"type": "text", "text": text},
],
}
]
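    # apply_chat_template tokenizes the prompt and preprocesses the images in
    # one call; .to(...) with a dtype only casts floating-point tensors (the
    # pixel values), leaving integer token ids untouched.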
inputs = processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
).to(model.device, dtype=model.dtype)
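    # The streamer exposes generated text as an iterator of decoded chunks,
    # skipping the prompt and special tokens; the processor handles decoding.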
streamer = TextIteratorStreamer(
processor, skip_prompt=True, skip_special_tokens=True
)
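    # generate() runs on a background thread so the main thread can consume the
    # streamer and yield partial text to the UI as tokens arrive.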
generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=2048)
# start timer just before generation begins
start_time = time.time()
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
buffer = "Baseline Model Response: "
yield progress_bar_html("Processing...")
    for new_text in streamer:
        # Escape model output so it renders as plain text rather than HTML.
        buffer += html.escape(new_text)
        # Tiny pause keeps the event loop responsive between updates.
        time.sleep(0.001)
        yield buffer
# Ensure generation thread has finished and measure elapsed time
thread.join()
elapsed = time.time() - start_time
elapsed_text = f"\nBaseline Generation Time: {elapsed:.2f} s"
buffer += html.escape(elapsed_text)
yield buffer
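
# Example prompts preloaded into the chat UI.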
examples = [
[
{
"text": "Write a descriptive caption for this image in a formal tone.",
"files": ["example_images/example.png"],
}
],
[
{
"text": "What are the characters wearing?",
"files": ["example_images/example.png"],
}
],
]
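
# Wire the handler into a multimodal ChatInterface; cache_examples=False keeps
# Gradio from running the model over the examples at startup.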
demo = gr.ChatInterface(
fn=model_inference,
description="# **Smolvlm2-500M-illustration-description** \n (running on CPU) The model only sees the last input, it ignores the previous conversation history.",
examples=examples,
fill_height=True,
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"]),
stop_btn="Stop Generation",
multimodal=True,
cache_examples=False,
)
demo.launch(debug=True)