# wan2gp/shared/utils/process_locks.py
# Uploaded by vidfom with huggingface_hub (commit 618f472, verified).
import time
import threading
import torch
# Module-level lock held by the functions below whenever they read or write
# the GPU bookkeeping entries ("process_status", "process_hierarchy",
# "pause_msg") of a state's "gen" dict.
gen_lock = threading.Lock()
def get_gen_info(state):
    """Return the ``"gen"`` entry of ``state``, creating it when absent.

    Args:
        state: Mutable mapping supporting ``get`` and item assignment.

    Returns:
        The object stored under ``state["gen"]``. When the key is missing or
        holds ``None``, a new empty dict is stored there and returned.
    """
    cache = state.get("gen", None)
    # Identity check: only a genuinely missing/None entry is replaced, never
    # an existing object that merely compares equal to None.
    if cache is None:
        cache = dict()
        state["gen"] = cache
    return cache
def any_GPU_process_running(state, process_id):
    """Tell whether a GPU owner is currently recorded for ``state``.

    Args:
        state: Mutable mapping passed on to ``get_gen_info``.
        process_id: Accepted for interface compatibility; not used here.

    Returns:
        ``True`` when the ``"process_status"`` entry of the gen info is not
        ``None``, ``False`` otherwise.
    """
    shared_info = get_gen_info(state)
    # Read the status under the lock so a concurrent update is not observed
    # half-way through.
    with gen_lock:
        return shared_info.get("process_status") is not None
def acquire_GPU_ressources(state, process_id, process_name, gr=None):
    """Block until ``process_id`` owns the GPU slot tracked in the gen info.

    The slot is the ``"process_status"`` entry of ``get_gen_info(state)``:

    * ``None``: the slot is free and is claimed immediately as
      ``"process:<process_id>"``.
    * ``"process:main"``: the status is replaced by
      ``"request:<process_id>"`` and a ``"pause_msg"`` is stored. The call
      then polls until the status becomes ``"process:<process_id>"``
      (presumably written by the main generation loop, which is not part of
      this module -- confirm against the caller) or ``None``.
    * ``"process:<process_id>"``: the caller already owns the slot.
    * anything else: poll every 0.1 s until one of the cases above applies.

    The status that release must restore is recorded in
    ``gen["process_hierarchy"][process_id]``.

    Args:
        state: Mutable mapping passed on to ``get_gen_info``.
        process_id: String identifier appended to ``"process:"`` /
            ``"request:"`` to build the status values.
        process_name: Human-readable name used in the pause and wait messages.
        gr: Optional object exposing ``Info(message)`` (presumably the gradio
            module -- confirm against callers). When given, it is notified
            once after about 5 seconds of waiting on the main process.

    Returns:
        None.
    """
    gen = get_gen_info(state)
    owner_status = "process:" + process_id
    original_process_status = None
    already_owner = False

    # Phase 1: claim the slot, or post a request to the main process.
    while True:
        with gen_lock:
            process_hierarchy = gen.get("process_hierarchy", None)
            if process_hierarchy is None:
                process_hierarchy = dict()
                gen["process_hierarchy"] = process_hierarchy

            process_status = gen.get("process_status", None)
            if process_status is None:
                gen["process_status"] = owner_status
                break
            elif process_status == "process:main":
                original_process_status = process_status
                gen["process_status"] = "request:" + process_id
                gen["pause_msg"] = f"Generation Suspended while using {process_name}"
                break
            elif process_status == owner_status:
                already_owner = True
                break
        time.sleep(0.1)

    # Phase 2: a request was posted; wait until the slot is handed over.
    if original_process_status is not None:
        total_wait = 0
        wait_time = 0.1
        wait_msg_displayed = False
        while True:
            with gen_lock:
                process_status = gen.get("process_status", None)
                if process_status == owner_status:
                    break
                if process_status is None:
                    # The main process finished between the last check and
                    # now: take the slot directly. Nothing is left to hand
                    # the GPU back to, so release must not restore the stale
                    # "process:main" status.
                    gen["process_status"] = owner_status
                    original_process_status = None
                    break

            total_wait += wait_time
            if round(total_wait, 2) >= 5 and gr is not None and not wait_msg_displayed:
                wait_msg_displayed = True
                gr.Info(f"Process {process_name} is Suspended while waiting that GPU Ressources become available")
            time.sleep(wait_time)

    # A re-entrant acquire must keep the entry written by the first acquire;
    # overwriting it with None would lose the status to restore on release.
    if not already_owner:
        with gen_lock:
            process_hierarchy[process_id] = original_process_status
    torch.cuda.synchronize()
def release_GPU_ressources(state, process_id):
    """Give up the GPU slot held by ``process_id``.

    Sets the ``"process_status"`` entry of the gen info to the value stored
    for ``process_id`` in ``"process_hierarchy"``, or to ``None`` when no
    value is stored.

    Args:
        state: Mutable mapping passed on to ``get_gen_info``.
        process_id: Identifier used as the key in ``"process_hierarchy"``.

    Returns:
        None.

    Raises:
        Whatever ``torch.cuda.synchronize()`` raises; the status is restored
        before the exception propagates.
    """
    gen = get_gen_info(state)
    try:
        torch.cuda.synchronize()
    finally:
        # Restore the status even when the synchronize call fails, otherwise
        # the slot would stay marked as owned and block every later acquirer.
        with gen_lock:
            process_hierarchy = gen.get("process_hierarchy", None) or {}
            gen["process_status"] = process_hierarchy.get(process_id, None)