|
|
diff --git a/app.py b/app.py |
|
|
index 0000000..1111111 100644 |
|
|
--- a/app.py |
|
|
+++ b/app.py |
|
|
@@ -1,16 +1,28 @@ |
|
|
import gradio as gr |
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
-import spaces |
|
|
|
|
|
|
|
|
MID = "apple/FastVLM-0.5B" |
|
|
IMAGE_TOKEN_INDEX = -200 |
|
|
|
|
|
|
|
|
tok = None |
|
|
model = None |
|
|
def load_model(): |
|
|
global tok, model |
|
|
if tok is None or model is None: |
|
|
print("Loading model...") |
|
|
tok = AutoTokenizer.from_pretrained(MID, trust_remote_code=True) |
|
|
- model = AutoModelForCausalLM.from_pretrained( |
|
|
- MID, |
|
|
- torch_dtype=torch.float16, |
|
|
- device_map="cuda", |
|
|
- trust_remote_code=True, |
|
|
- ) |
|
|
+ |
|
|
+ use_cuda = torch.cuda.is_available() |
|
|
+ device_map = "cuda" if use_cuda else "cpu" |
|
|
+ |
|
|
+ dtype = torch.float16 if use_cuda else torch.float32 |
|
|
+ |
|
|
+ model = AutoModelForCausalLM.from_pretrained( |
|
|
+ MID, |
|
|
+ torch_dtype=dtype, |
|
|
+ device_map=device_map, |
|
|
+ trust_remote_code=True, |
|
|
+ ) |
|
|
print("Model loaded successfully!") |
|
|
return tok, model |
|
|
- |
|
|
-@spaces.GPU(duration=60) |
|
|
+ |
|
|
+ |
|
|
def caption_image(image, custom_prompt=None): |
|
|
@@ -66,16 +78,23 @@ def caption_image(image, custom_prompt=None): |
|
|
|
|
|
- img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype) |
|
|
- input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device) |
|
|
- attention_mask = torch.ones_like(input_ids, device=model.device) |
|
|
+ |
|
|
+ model_device = next(model.parameters()).device |
|
|
+ model_dtype = next(model.parameters()).dtype |
|
|
+ |
|
|
+ img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype, device=model_device) |
|
|
+ input_ids = torch.cat([pre_ids.to(model_device), img_tok, post_ids.to(model_device)], dim=1) |
|
|
+ attention_mask = torch.ones_like(input_ids, device=model_device) |
|
|
|
|
|
|
|
|
px = model.get_vision_tower().image_processor( |
|
|
images=image, return_tensors="pt" |
|
|
)["pixel_values"] |
|
|
- px = px.to(model.device, dtype=model.dtype) |
|
|
+ px = px.to(model_device, dtype=model_dtype) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
out = model.generate( |
|
|
inputs=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
images=px, |
|
|
max_new_tokens=128, |
|
|
do_sample=False, |
|
|
- temperature=1.0, |
|
|
+ |
|
|
) |
|
|
|