Spaces:
Running
Running
update model_worker.py
Browse files - model_worker.py (+15 -14)
model_worker.py
CHANGED
|
@@ -325,7 +325,8 @@ class ModelWorker:
|
|
| 325 |
"queue_length": self.get_queue_length(),
|
| 326 |
}
|
| 327 |
|
| 328 |
-
@torch.inference_mode()
|
|
|
|
| 329 |
def generate_stream(self, params):
|
| 330 |
system_message = params["prompt"][0]["content"]
|
| 331 |
send_messages = params["prompt"][1:]
|
|
@@ -428,18 +429,19 @@ class ModelWorker:
|
|
| 428 |
)
|
| 429 |
logger.info(f"Generation config: {generation_config}")
|
| 430 |
|
| 431 |
-
|
| 432 |
-
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
|
|
|
| 443 |
|
| 444 |
generated_text = ""
|
| 445 |
for new_text in streamer:
|
|
@@ -453,7 +455,6 @@ class ModelWorker:
|
|
| 453 |
)
|
| 454 |
self.model.system_message = old_system_message
|
| 455 |
|
| 456 |
-
@spaces.GPU(duration=120)
|
| 457 |
def generate_stream_gate(self, params):
|
| 458 |
try:
|
| 459 |
for x in self.generate_stream(params):
|
|
|
|
| 325 |
"queue_length": self.get_queue_length(),
|
| 326 |
}
|
| 327 |
|
| 328 |
+
# @torch.inference_mode()
|
| 329 |
+
@spaces.GPU(duration=120)
|
| 330 |
def generate_stream(self, params):
|
| 331 |
system_message = params["prompt"][0]["content"]
|
| 332 |
send_messages = params["prompt"][1:]
|
|
|
|
| 429 |
)
|
| 430 |
logger.info(f"Generation config: {generation_config}")
|
| 431 |
|
| 432 |
+
with torch.no_grad():
|
| 433 |
+
thread = Thread(
|
| 434 |
+
target=self.model.chat,
|
| 435 |
+
kwargs=dict(
|
| 436 |
+
tokenizer=self.tokenizer,
|
| 437 |
+
pixel_values=pixel_values,
|
| 438 |
+
question=question,
|
| 439 |
+
history=history,
|
| 440 |
+
return_history=False,
|
| 441 |
+
generation_config=generation_config,
|
| 442 |
+
),
|
| 443 |
+
)
|
| 444 |
+
thread.start()
|
| 445 |
|
| 446 |
generated_text = ""
|
| 447 |
for new_text in streamer:
|
|
|
|
| 455 |
)
|
| 456 |
self.model.system_message = old_system_message
|
| 457 |
|
|
|
|
| 458 |
def generate_stream_gate(self, params):
|
| 459 |
try:
|
| 460 |
for x in self.generate_stream(params):
|