88hours committed · Commit aacd9e8 · 0 Parent(s)

First Demo with Gradio

Files changed (5)
  1. README.md +6 -0
  2. gradio_utils.py +483 -0
  3. main.py +10 -0
  4. requirements.txt +1 -0
  5. s1-lrn-gradio.py +12 -0
README.md ADDED
@@ -0,0 +1,6 @@
+ # Journey into learning - 4:00 pm
+ ## Step 1 - Learn Gradio
+ Gradio is great for making quick UIs in Python that run in the browser. It also has hot reloading. The key `gr.Interface` parameters:
+ - `fn`: the function to wrap a user interface (UI) around
+ - `inputs`: the Gradio component(s) to use for the input. The number of components should match the number of arguments in your function.
+ - `outputs`: the Gradio component(s) to use for the output. The number of components should match the number of return values from your function.
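A minimal sketch of how `fn`, `inputs`, and `outputs` fit together (it mirrors `s1-lrn-gradio.py`, added later in this commit):

```python
import gradio as gr

def greet(name, intensity):
    return "Hello, " + name + "!" * int(intensity)

demo = gr.Interface(
    fn=greet,                   # the function to wrap a UI around
    inputs=["text", "slider"],  # one component per function argument
    outputs=["text"],           # one component per return value
)

demo.launch()
```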
gradio_utils.py ADDED
@@ -0,0 +1,483 @@
+ import gradio as gr
+ import io
+ import sys
+ import time
+ import dataclasses
+ from pathlib import Path
+ import os
+ from enum import auto, Enum
+ from typing import List, Tuple, Any
+ import lancedb
+ from utils import load_json_file
+ from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
+ from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
+ from mm_rag.MLM.client import PredictionGuardClient
+ from mm_rag.MLM.lvlm import LVLM
+ from PIL import Image
+ from langchain_core.runnables import RunnableParallel, RunnablePassthrough, RunnableLambda
+ from moviepy.video.io.VideoFileClip import VideoFileClip
+ from utils import prediction_guard_llava_conv, encode_image, Conversation, lvlm_inference_with_conversation
+
+ server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+
+ # split a video at a timestamp, keeping a few seconds before and after it
+ def split_video(video_path, timestamp_in_ms, output_video_path: str = "./shared_data/splitted_videos", output_video_name: str = "video_tmp.mp4", play_before_sec: int = 3, play_after_sec: int = 3):
+     timestamp_in_sec = int(timestamp_in_ms / 1000)
+     # create the output folder if it does not exist
+     Path(output_video_path).mkdir(parents=True, exist_ok=True)
+     output_video = os.path.join(output_video_path, output_video_name)
+     with VideoFileClip(video_path) as video:
+         duration = video.duration
+         start_time = max(timestamp_in_sec - play_before_sec, 0)
+         end_time = min(timestamp_in_sec + play_after_sec, duration)
+         new = video.subclip(start_time, end_time)
+         new.write_videofile(output_video, audio_codec='aac')
+     return output_video
+
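A quick usage sketch (the input path here is hypothetical): `http_bot` below calls `split_video` with a retrieved segment's `video_path` and `mid_time_ms`, producing a roughly six-second clip around the timestamp:

```python
# hypothetical example inputs; in the app these come from the retrieved segment's metadata
clip_path = split_video("./shared_data/videos/astronaut.mp4", timestamp_in_ms=45_000)
print(clip_path)  # ./shared_data/splitted_videos/video_tmp.mp4
```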
+
+ prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
+
+ # define the default rag_chain
+ def get_default_rag_chain():
+     # declare the LanceDB host file
+     LANCEDB_HOST_FILE = "./shared_data/.lancedb"
+     # declare the table name
+     TBL_NAME = "demo_tbl"
+
+     # initialize the vectorstore
+     db = lancedb.connect(LANCEDB_HOST_FILE)
+
+     # initialize a BridgeTower embedder
+     embedder = BridgeTowerEmbeddings()
+
+     ## create a LanceDB vector store
+     vectorstore = MultimodalLanceDB(uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME)
+     ### create a retriever for the vector store
+     retriever_module = vectorstore.as_retriever(search_type='similarity', search_kwargs={"k": 1})
+
+     # initialize a PredictionGuardClient
+     client = PredictionGuardClient()
+     # initialize the LVLM with the given client
+     lvlm_inference_module = LVLM(client=client)
+
+     def prompt_processing(input):
+         # get the retrieved results and the user's query
+         retrieved_results, user_query = input['retrieved_results'], input['user_query']
+         # get the first retrieved result by default
+         retrieved_result = retrieved_results[0]
+
+         # get all metadata of the retrieved video segment
+         metadata_retrieved_video_segment = retrieved_result.metadata['metadata']
+
+         # get the transcript and the path to the extracted frame of the retrieved video segment
+         transcript = metadata_retrieved_video_segment['transcript']
+         frame_path = metadata_retrieved_video_segment['extracted_frame_path']
+         return {
+             'prompt': prompt_template.format(transcript=transcript, user_query=user_query),
+             'image': frame_path,
+             'metadata': metadata_retrieved_video_segment,
+         }
+
+     # wrap prompt_processing as a LangChain RunnableLambda
+     prompt_processing_module = RunnableLambda(prompt_processing)
+
+     # the output of this chain is a dictionary
+     mm_rag_chain_with_retrieved_image = (
+         RunnableParallel({"retrieved_results": retriever_module,
+                           "user_query": RunnablePassthrough()})
+         | prompt_processing_module
+         | RunnableParallel({'final_text_output': lvlm_inference_module,
+                             'input_to_lvlm': RunnablePassthrough()})
+     )
+     return mm_rag_chain_with_retrieved_image
+
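The chain is consumed in `http_bot` below; invoking it with a plain query string returns a dictionary whose keys the UI relies on:

```python
# minimal sketch of how the chain's output dictionary is used
chain = get_default_rag_chain()
response = chain.invoke("What does the astronaut say?")   # one of the sample queries below
answer = response['final_text_output']                    # the LVLM's answer text
metadata = response['input_to_lvlm']['metadata']          # transcript, extracted_frame_path, video_path, mid_time_ms
```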
+ class SeparatorStyle(Enum):
+     """Different separator styles."""
+     SINGLE = auto()
+
+ @dataclasses.dataclass
+ class GradioInstance:
+     """A class that keeps all conversation history."""
+     system: str
+     roles: List[str]
+     messages: List[List[str]]
+     offset: int
+     sep_style: SeparatorStyle = SeparatorStyle.SINGLE
+     sep: str = "\n"
+     sep2: str = None
+     version: str = "Unknown"
+     path_to_img: str = None
+     video_title: str = None
+     path_to_video: str = None
+     caption: str = None
+     mm_rag_chain: Any = None
+
+     skip_next: bool = False
+
+     def _template_caption(self):
+         out = ""
+         if self.caption is not None:
+             out = f"The caption associated with the image is '{self.caption}'. "
+         return out
+
+     def get_prompt_for_rag(self):
+         messages = self.messages
+         assert len(messages) == 2, "length of current conversation should be 2"
+         assert messages[1][1] is None, "the first response message of the current conversation should be None"
+         ret = messages[0][1]
+         return ret
+
+     def get_conversation_for_lvlm(self):
+         pg_conv = prediction_guard_llava_conv.copy()
+         image_path = self.path_to_img
+         b64_img = encode_image(image_path)
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if msg is None:
+                 break
+             if i == 0:
+                 pg_conv.append_message(prediction_guard_llava_conv.roles[0], [msg, b64_img])
+             elif i == len(self.messages[self.offset:]) - 2:
+                 pg_conv.append_message(role, [prompt_template.format(transcript=self.caption, user_query=msg)])
+             else:
+                 pg_conv.append_message(role, [msg])
+         return pg_conv
+
+     def append_message(self, role, message):
+         self.messages.append([role, message])
+
+     def get_images(self, return_pil=False):
+         images = []
+         if self.path_to_img is not None:
+             path_to_image = self.path_to_img
+             images.append(path_to_image)
+         return images
+
+     def to_gradio_chatbot(self):
+         ret = []
+         for i, (role, msg) in enumerate(self.messages[self.offset:]):
+             if i % 2 == 0:
+                 if type(msg) is tuple:
+                     import base64
+                     from io import BytesIO
+                     msg, image, image_process_mode = msg
+                     max_hw, min_hw = max(image.size), min(image.size)
+                     aspect_ratio = max_hw / min_hw
+                     max_len, min_len = 800, 400
+                     shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
+                     longest_edge = int(shortest_edge * aspect_ratio)
+                     W, H = image.size
+                     if H > W:
+                         H, W = longest_edge, shortest_edge
+                     else:
+                         H, W = shortest_edge, longest_edge
+                     image = image.resize((W, H))
+                     buffered = BytesIO()
+                     image.save(buffered, format="JPEG")
+                     img_b64_str = base64.b64encode(buffered.getvalue()).decode()
+                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
+                     msg = img_str + msg.replace('<image>', '').strip()
+                     ret.append([msg, None])
+                 else:
+                     ret.append([msg, None])
+             else:
+                 ret[-1][-1] = msg
+         return ret
+
+     def copy(self):
+         return GradioInstance(
+             system=self.system,
+             roles=self.roles,
+             messages=[[x, y] for x, y in self.messages],
+             offset=self.offset,
+             sep_style=self.sep_style,
+             sep=self.sep,
+             sep2=self.sep2,
+             version=self.version,
+             mm_rag_chain=self.mm_rag_chain,
+         )
+
+     def dict(self):
+         return {
+             "system": self.system,
+             "roles": self.roles,
+             "messages": self.messages,
+             "offset": self.offset,
+             "sep": self.sep,
+             "sep2": self.sep2,
+             "path_to_img": self.path_to_img,
+             "video_title": self.video_title,
+             "path_to_video": self.path_to_video,
+             "caption": self.caption,
+         }
+
+     def get_path_to_subvideos(self):
+         if self.video_title is not None and self.path_to_img is not None:
+             # video_helper_map is assumed to be defined elsewhere; it maps a video title to its path and filename prefix
+             info = video_helper_map[self.video_title]
+             path = info['path']
+             prefix = info['prefix']
+             vid_index = self.path_to_img.split('/')[-1]
+             vid_index = vid_index.split('_')[-1]
+             vid_index = vid_index.replace('.jpg', '')
+             ret = f"{prefix}{vid_index}.mp4"
+             ret = os.path.join(path, ret)
+             return ret
+         elif self.path_to_video is not None:
+             return self.path_to_video
+         return None
+
+ def get_gradio_instance(mm_rag_chain=None):
+     if mm_rag_chain is None:
+         mm_rag_chain = get_default_rag_chain()
+
+     instance = GradioInstance(
+         system="",
+         roles=prediction_guard_llava_conv.roles,
+         messages=[],
+         offset=0,
+         sep_style=SeparatorStyle.SINGLE,
+         sep="\n",
+         path_to_img=None,
+         video_title=None,
+         caption=None,
+         mm_rag_chain=mm_rag_chain,
+     )
+     return instance
+
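A sketch of the first conversation round as `add_text` and `http_bot` drive it: the user turn is appended, an empty assistant turn is reserved, and `get_prompt_for_rag` hands the raw query to the RAG chain:

```python
state = get_gradio_instance()                  # builds the default RAG chain if none is given
state.append_message(state.roles[0], "An astronaut's spacewalk")
state.append_message(state.roles[1], None)     # placeholder for the pending response
query = state.get_prompt_for_rag()             # returns the user message for retrieval
```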
+ gr.set_static_paths(paths=["./assets/"])
+ theme = gr.themes.Base(
+     primary_hue=gr.themes.Color(
+         c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#00377c", c700="#00377c", c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b"),
+     secondary_hue=gr.themes.Color(
+         c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd", c400="#60a5fa", c50="#eff6ff", c500="#0054ae", c600="#0054ae", c700="#0054ae", c800="#1e40af", c900="#1e3a8a", c950="#1d3660"),
+ ).set(
+     body_background_fill_dark='*primary_950',
+     body_text_color_dark='*neutral_300',
+     border_color_accent='*primary_700',
+     border_color_accent_dark='*neutral_800',
+     block_background_fill_dark='*primary_950',
+     block_border_width='2px',
+     block_border_width_dark='2px',
+     button_primary_background_fill_dark='*primary_500',
+     button_primary_border_color_dark='*primary_500'
+ )
+
+ css = '''
+ @font-face {
+     font-family: IntelOne;
+     src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
+ }
+ .gradio-container {background-color: #0a0c2b}
+ table {
+     border-collapse: collapse;
+     border: none;
+ }
+ '''
+
+ html_title = '''
+ <table style="bordercolor=#0a0c2b; border=0">
+     <tr style="height:150px; border:0">
+         <td style="border:0"><img src="/file=./assets/header.png"></td>
+     </tr>
+ </table>
+ '''
+
+ dropdown_list = [
+     "What is the name of one of the astronauts?",
+     "An astronaut's spacewalk",
+     "What does the astronaut say?",
+ ]
+
+ no_change_btn = gr.Button()
+ enable_btn = gr.Button(interactive=True)
+ disable_btn = gr.Button(interactive=False)
+
+ def clear_history(state, request: gr.Request):
+     state = get_gradio_instance(state.mm_rag_chain)
+     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1
+
+ def add_text(state, text, request: gr.Request):
+     if len(text) <= 0:
+         state.skip_next = True
+         return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 1
+
+     text = text[:1536]  # hard cut-off on query length
+
+     state.append_message(state.roles[0], text)
+     state.append_message(state.roles[1], None)
+     state.skip_next = False
+     return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1
+
+ def http_bot(state, request: gr.Request):
+     start_tstamp = time.time()
+
+     if state.skip_next:
+         # this generate call is skipped due to invalid inputs
+         path_to_sub_videos = state.get_path_to_subvideos()
+         yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
+         return
+
+     if len(state.messages) == state.offset + 2:
+         # first round of conversation
+         new_state = get_gradio_instance(state.mm_rag_chain)
+         new_state.append_message(new_state.roles[0], state.messages[-2][1])
+         new_state.append_message(new_state.roles[1], None)
+         state = new_state
+
+     all_images = state.get_images(return_pil=False)
+
+     # make the request
+     is_very_first_query = True
+     if len(all_images) == 0:
+         # the first query needs to do RAG; construct the prompt
+         prompt_or_conversation = state.get_prompt_for_rag()
+     else:
+         # subsequent queries need no retrieval
+         is_very_first_query = False
+         prompt_or_conversation = state.get_conversation_for_lvlm()
+
+     if is_very_first_query:
+         executor = state.mm_rag_chain
+     else:
+         executor = lvlm_inference_with_conversation
+
+     state.messages[-1][-1] = "▌"
+     path_to_sub_videos = state.get_path_to_subvideos()
+     yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1
+
+     try:
+         if is_very_first_query:
+             # get the response by invoking the executor chain
+             response = executor.invoke(prompt_or_conversation)
+             message = response['final_text_output']
+             if 'metadata' in response['input_to_lvlm']:
+                 metadata = response['input_to_lvlm']['metadata']
+                 if (state.path_to_img is None
+                         and 'input_to_lvlm' in response
+                         and 'image' in response['input_to_lvlm']):
+                     state.path_to_img = response['input_to_lvlm']['image']
+
+                 if state.path_to_video is None and 'video_path' in metadata:
+                     video_path = metadata['video_path']
+                     mid_time_ms = metadata['mid_time_ms']
+                     splited_video_path = split_video(video_path, mid_time_ms)
+                     state.path_to_video = splited_video_path
+
+                 if state.caption is None and 'transcript' in metadata:
+                     state.caption = metadata['transcript']
+             else:
+                 raise ValueError("Response's format has changed")
+         else:
+             # get the response message by calling the Prediction Guard API directly
+             message = executor(prompt_or_conversation)
+
+     except Exception as e:
+         print(e)
+         state.messages[-1][-1] = server_error_msg
+         yield (state, state.to_gradio_chatbot(), None) + (
+             enable_btn,
+         )
+         return
+
+     state.messages[-1][-1] = message
+     path_to_sub_videos = state.get_path_to_subvideos()
+     yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1
+
+     finish_tstamp = time.time()
+     return
+
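Note that `http_bot` is a generator: Gradio streams each `yield` to the UI, so the chatbot first shows the "▌" typing cursor alongside the retrieved clip, then the final answer (or `server_error_msg` if the request fails).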
+ def get_demo(rag_chain=None):
+     if rag_chain is None:
+         rag_chain = get_default_rag_chain()
+
+     with gr.Blocks(theme=theme, css=css) as demo:
+         instance = get_gradio_instance(rag_chain)
+         state = gr.State(instance)
+         # force the dark theme via a URL query parameter on page load
+         demo.load(
+             None,
+             None,
+             js="""
+             () => {
+                 const params = new URLSearchParams(window.location.search);
+                 if (!params.has('__theme')) {
+                     params.set('__theme', 'dark');
+                     window.location.search = params.toString();
+                 }
+             }""",
+         )
+         gr.HTML(value=html_title)
+         with gr.Row():
+             with gr.Column(scale=4):
+                 video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
+             with gr.Column(scale=7):
+                 chatbot = gr.Chatbot(
+                     elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
+                 )
+         with gr.Row():
+             with gr.Column(scale=8):
+                 textbox = gr.Dropdown(
+                     dropdown_list,
+                     allow_custom_value=True,
+                     label="Query",
+                     info="Enter your query here or choose a sample from the dropdown list!"
+                 )
+             with gr.Column(scale=1, min_width=50):
+                 submit_btn = gr.Button(
+                     value="Send", variant="primary", interactive=True
+                 )
+         with gr.Row(elem_id="buttons") as button_row:
+             clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
+
+         btn_list = [clear_btn]
+
+         clear_btn.click(
+             clear_history, [state], [state, chatbot, textbox, video] + btn_list
+         )
+         submit_btn.click(
+             add_text,
+             [state, textbox],
+             [state, chatbot, textbox] + btn_list,
+         ).then(
+             http_bot,
+             [state],
+             [state, chatbot, video] + btn_list,
+         )
+     return demo
+
main.py ADDED
@@ -0,0 +1,10 @@
+ ## Start the Gradio server
+ from gradio_utils import get_demo
+
+ # You will need to restart the kernel each time you rerun this cell;
+ # otherwise, the port will not be available.
+
+ debug = False  # change this to True if you want to debug
+
+ demo = get_demo()
+ demo.launch(server_name="0.0.0.0", server_port=9999, debug=debug)
requirements.txt ADDED
@@ -0,0 +1 @@
+ gradio
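Note: `requirements.txt` lists only `gradio`, while `gradio_utils.py` also imports `lancedb`, `moviepy`, `langchain-core`, `Pillow`, and the local `utils`/`mm_rag` modules; those presumably come from elsewhere in the project environment.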
s1-lrn-gradio.py ADDED
@@ -0,0 +1,12 @@
+ import gradio as gr
+
+ def greet(name, intensity):
+     return "Hello, " + name + "!" * int(intensity)
+
+ demo = gr.Interface(
+     fn=greet,
+     inputs=["text", "slider"],
+     outputs=["text"],
+ )
+
+ demo.launch()
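Here `inputs=["text", "slider"]` maps positionally onto `greet(name, intensity)`, and the single `"text"` output matches the one return value, which is exactly the `fn`/`inputs`/`outputs` contract described in README.md.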