qingshan777 committed
Commit 1d2ccf3 · verified · 1 Parent(s): fa92425

Create app.py

Files changed (1): app.py (+138, -0)
app.py ADDED
@@ -0,0 +1,138 @@
import ast

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
)

model_root = "qihoo360/fg-clip2-base"

model = AutoModelForCausalLM.from_pretrained(model_root, trust_remote_code=True)
device = model.device

tokenizer = AutoTokenizer.from_pretrained(model_root)
image_processor = AutoImageProcessor.from_pretrained(model_root)


def determine_max_value(image):
    """Choose the patch budget for the image processor from the image resolution."""
    w, h = image.size
    max_val = (w // 16) * (h // 16)

    if max_val > 784:
        return 1024
    elif max_val > 576:
        return 784
    elif max_val > 256:
        return 576
    elif max_val > 128:
        return 256
    else:
        return 128


def postprocess_result(probs, labels):
    """Map each label to its score for display in gr.Label."""
    return {labels[i]: probs[i] for i in range(len(labels))}


def Retrieval(image, candidate_labels, text_type):
    """
    Takes an image and a list of candidate labels,
    and returns the per-label retrieval scores.
    """
    image = image.convert("RGB")
    image_input = image_processor(
        images=image,
        max_num_patches=determine_max_value(image),
        return_tensors="pt",
    ).to(device)

    candidate_labels = [label.lower() for label in candidate_labels]

    # Long captions get a longer token budget than short labels or box phrases.
    max_length = 196 if text_type == "long" else 64

    caption_input = tokenizer(
        candidate_labels,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        image_feature = model.get_image_features(**image_input)
        text_feature = model.get_text_features(**caption_input, walk_type=text_type)
        image_feature = image_feature / image_feature.norm(p=2, dim=-1, keepdim=True)
        text_feature = text_feature / text_feature.norm(p=2, dim=-1, keepdim=True)

        # Scaled and shifted cosine similarities, then a softmax over the candidate labels.
        logits_per_image = image_feature @ text_feature.T
        logit_scale = model.logit_scale.to(text_feature.device)
        logit_bias = model.logit_bias.to(text_feature.device)
        logits_per_image = logits_per_image * logit_scale.exp() + logit_bias
        # probs = torch.sigmoid(logits_per_image)
        probs = logits_per_image.softmax(dim=1)

    return probs[0].tolist()


def infer(image, candidate_labels, text_type):
    assert text_type in ["short", "long", "box"]
    candidate_labels = ast.literal_eval(candidate_labels)
    fg_probs = Retrieval(image, candidate_labels, text_type)
    return postprocess_result(fg_probs, candidate_labels)


with gr.Blocks() as demo:
    gr.Markdown("# FG-CLIP 2 Retrieval")
    gr.Markdown(
        "This app uses the FG-CLIP 2 model (qihoo360/fg-clip2-base) for retrieval on CPU."
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(label="Input a list of labels, e.g. ['a', 'b', 'c']")
            text_type = gr.Textbox(label="Text type: select one of [short, long, box]")
            run_button = gr.Button("Run Retrieval", visible=True)
        with gr.Column():
            fg_output = gr.Label(label="FG-CLIP 2 Output", num_top_classes=11)

    # Each example supplies [image, labels string, text type] to match `inputs` below.
    examples = [
        ["./000093.jpg", str([
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双浅色鞋子,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件红色和蓝色的衣物,下方架子放着两双黑色高跟鞋,旁边是一盆绿植,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个简约风格的卧室角落,黑色金属衣架上挂着多件米色和白色的衣物,下方架子放着两双运动鞋,旁边是一盆仙人掌,左侧可见一张铺有白色床单和灰色枕头的床。",
            "一个繁忙的街头市场,摊位上摆满水果,背景是高楼大厦,人们在喧闹中购物。"
        ]), "long"],
        ["./000093.jpg", str([
            "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of light-colored shoes on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A minimalist-style bedroom corner with a black metal clothing rack holding several red and blue garments, two pairs of black high heels on the shelf below, a potted green plant nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A minimalist-style bedroom corner with a black metal clothing rack holding several beige and white garments, two pairs of sneakers on the shelf below, a potted cactus nearby, and to the left, a bed made with white sheets and gray pillows.",
            "A bustling street market with fruit-filled stalls, skyscrapers in the background, and people shopping amid the noise and activity."
        ]), "long"],
    ]
    gr.Examples(
        examples=examples,
        inputs=[image_input, text_input, text_type],
        outputs=fg_output,
        fn=infer,
    )
    run_button.click(fn=infer, inputs=[image_input, text_input, text_type], outputs=fg_output)

# demo.launch(server_name="0.0.0.0", server_port=7861, share=True)
demo.launch()
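
For a quick sanity check outside the Gradio UI, the same pipeline can be exercised by calling infer() directly once the model is loaded. This is a minimal sketch; "cat.jpg" and the two labels are placeholder inputs that do not ship with this Space.

# Minimal sketch: call infer() directly, bypassing the UI.
# "cat.jpg" and the labels below are placeholders, not files from this repo.
from PIL import Image

test_image = Image.open("cat.jpg")
labels = str(["a photo of a cat", "a photo of a dog"])  # infer() parses this string with ast.literal_eval
scores = infer(test_image, labels, "short")
print(scores)  # dict mapping each label to its score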