update gradio_web_server.py and model_worker.py
Files changed:
- gradio_web_server.py +6 -1
- model_worker.py +28 -17
gradio_web_server.py
CHANGED

@@ -818,7 +818,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="0.0.0.0")
     parser.add_argument("--port", type=int, default=11000)
-    parser.add_argument("--controller-url", type=str, default=
+    parser.add_argument("--controller-url", type=str, default=None)
     parser.add_argument("--concurrency-count", type=int, default=10)
     parser.add_argument(
         "--model-list-mode", type=str, default="once", choices=["once", "reload"]
@@ -829,6 +829,11 @@ if __name__ == "__main__":
     parser.add_argument("--embed", action="store_true")
     args = parser.parse_args()
     logger.info(f"args: {args}")
+    if not args.controller_url:
+        args.controller_url = os.environ.get("CONTROLLER_URL", None)
+
+    if not args.controller_url:
+        raise ValueError("controller-url is required.")
 
     models = get_model_list()
 
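This change makes --controller-url optional on the command line: when the flag is omitted, the server falls back to the CONTROLLER_URL environment variable and fails fast if neither is set. A minimal, self-contained sketch of the same fallback pattern (it assumes os is already imported in gradio_web_server.py; the example URL in the comment is illustrative, not from this commit):

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--controller-url", type=str, default=None)
args = parser.parse_args()

# The flag wins; otherwise fall back to the environment, e.g.
#   CONTROLLER_URL=http://localhost:21001 python gradio_web_server.py
if not args.controller_url:
    args.controller_url = os.environ.get("CONTROLLER_URL", None)

# Fail fast at startup instead of erroring on the first request.
if not args.controller_url:
    raise ValueError("controller-url is required.")

print(f"controller: {args.controller_url}")
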
model_worker.py
CHANGED

@@ -160,6 +160,25 @@ def split_model(model_name):
     return device_map
 
 
+@spaces.GPU(duration=120)
+def multi_thread_infer(
+    model, tokenizer, pixel_values, question, history, generation_config
+):
+    with torch.no_grad():
+        thread = Thread(
+            target=model.chat,
+            kwargs=dict(
+                tokenizer=tokenizer,
+                pixel_values=pixel_values,
+                question=question,
+                history=history,
+                return_history=False,
+                generation_config=generation_config,
+            ),
+        )
+        thread.start()
+
+
 class ModelWorker:
     def __init__(
         self,
@@ -325,8 +344,6 @@ class ModelWorker:
             "queue_length": self.get_queue_length(),
         }
 
-    # @torch.inference_mode()
-    @spaces.GPU(duration=120)
     def generate_stream(self, params):
         system_message = params["prompt"][0]["content"]
         send_messages = params["prompt"][1:]
@@ -428,20 +445,14 @@ class ModelWorker:
             streamer=streamer,
         )
         logger.info(f"Generation config: {generation_config}")
-        thread = Thread(
-            target=self.model.chat,
-            kwargs=dict(
-                tokenizer=self.tokenizer,
-                pixel_values=pixel_values,
-                question=question,
-                history=history,
-                return_history=False,
-                generation_config=generation_config,
-            ),
-        )
-        thread.start()
+        multi_thread_infer(
+            self.model,
+            self.tokenizer,
+            pixel_values,
+            question,
+            history,
+            generation_config,
+        )
 
         generated_text = ""
         for new_text in streamer:
@@ -541,4 +552,4 @@ if __name__ == "__main__":
         args.load_8bit,
         args.device,
     )
-    uvicorn.run(app, host=args.host, port=args.port, log_level="info"
+    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
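generate_stream now delegates thread startup to multi_thread_infer and then drains the streamer for partial text (the generated_text loop kept below the call); the last hunk also restores the closing parenthesis that was missing from the uvicorn.run line. A minimal, self-contained sketch of the thread-plus-streamer pattern, using gpt2 as a stand-in for the worker's actual model:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, world", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

# generate() blocks until finished, so it runs on a side thread; the
# streamer hands decoded chunks to this thread as they are produced.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32),
)
thread.start()

generated_text = ""
for new_text in streamer:       # yields text incrementally
    generated_text += new_text  # a real worker would emit each chunk here
thread.join()
print(generated_text)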