File size: 5,766 Bytes
2ae242d f1c1f42 2ae242d f1c1f42 2ae242d f1c1f42 e4599d1 f1c1f42 e4599d1 f1c1f42 e4599d1 f1c1f42 e4599d1 f1c1f42 e4599d1 f1c1f42 2ae242d f1c1f42 2ae242d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import argparse
import io
import json
import logging
import mimetypes
import os
import sys
import time
from typing import Optional

import requests
def configure_logging(verbose: bool) -> None:
    """Configure root logging with a timestamped single-line format.

    Args:
        verbose: If true, DEBUG messages are emitted; otherwise the
            threshold is INFO.
    """
    if verbose:
        threshold = logging.DEBUG
    else:
        threshold = logging.INFO
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=threshold,
    )
def get_base_url(cli_base_url: Optional[str]) -> str:
    """Resolve the API root URL.

    Precedence is: the command-line value, then the ``BASE_URL``
    environment variable, then ``http://localhost:7860``. Empty strings
    count as unset. Trailing slashes are stripped so callers can append
    paths such as ``/health`` directly.
    """
    for candidate in (cli_base_url, os.getenv("BASE_URL")):
        if candidate:
            return candidate.rstrip("/")
    # Nothing configured: assume a locally running server.
    return "http://localhost:7860"
def wait_for_model(base_url: str, timeout_seconds: int = 300, poll_interval: float = 3.0) -> None:
    """Poll ``{base_url}/health`` until it reports a truthy ``model_loaded``.

    Each attempt is a GET. A successful (``resp.ok``) response is parsed
    as JSON and logged. Non-OK statuses, network errors and undecodable
    bodies are logged as warnings, and polling continues.

    Args:
        base_url: API root without a trailing slash.
        timeout_seconds: Total time budget for waiting, in seconds.
        poll_interval: Seconds to sleep between attempts.

    Raises:
        RuntimeError: If the time budget runs out first. The message
            includes the last JSON health payload received, if any.
    """
    # Monotonic clock: wall-clock jumps (NTP, DST) cannot stretch or cut the wait.
    deadline = time.monotonic() + timeout_seconds
    health_url = f"{base_url}/health"
    logging.info("Waiting for model to load at %s", health_url)
    last_status = None
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    while time.monotonic() < deadline:
        remaining = deadline - time.monotonic()
        try:
            # Cap the request timeout so a slow server cannot push us far past the deadline.
            resp = requests.get(health_url, headers=headers, timeout=max(0.1, min(15.0, remaining)))
            if resp.ok:
                data = resp.json()
                last_status = data
                # Guard against a JSON payload that is not an object (e.g. a list).
                if isinstance(data, dict) and data.get("model_loaded"):
                    logging.info("Model loaded: %s", json.dumps(data))
                    return
                logging.info("Health: %s", json.dumps(data))
            else:
                logging.warning("Health check HTTP %s", resp.status_code)
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode failures from resp.json().
            logging.warning("Health check error: %s", str(e))
        # Never sleep beyond the deadline.
        time.sleep(max(0.0, min(poll_interval, deadline - time.monotonic())))
    raise RuntimeError("Model did not load before timeout. Last health: %s" % (last_status,))
def upload_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
    """POST an image file to ``{base_url}/upload`` as multipart field ``file``.

    Args:
        base_url: API root without a trailing slash.
        image_path: Path of the image to send.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: If the response is not OK (status >= 400).
    """
    url = f"{base_url}/upload"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check
    # Label the part with its real image type (PNG, WebP, ...) instead of always
    # claiming JPEG; unknown or non-image extensions keep the JPEG fallback.
    guessed_type = mimetypes.guess_type(image_path)[0]
    content_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=120)
    if not resp.ok:
        raise RuntimeError("Upload failed: HTTP %s %s" % (resp.status_code, resp.text))
    data = resp.json()
    logging.info("Upload response: %s", json.dumps(data))
    return data
def colorize_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
    """POST an image file to ``{base_url}/colorize`` as multipart field ``file``.

    The 900-second timeout leaves room for slow processing on the server.

    Args:
        base_url: API root without a trailing slash.
        image_path: Path of the image to send.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Returns:
        The decoded JSON response body. ``main`` reads ``result_id`` from it.

    Raises:
        RuntimeError: If the response is not OK (status >= 400).
    """
    url = f"{base_url}/colorize"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check
    # Label the part with its real image type (PNG, WebP, ...) instead of always
    # claiming JPEG; unknown or non-image extensions keep the JPEG fallback.
    guessed_type = mimetypes.guess_type(image_path)[0]
    content_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=900)
    if not resp.ok:
        raise RuntimeError("Colorize failed: HTTP %s %s" % (resp.status_code, resp.text))
    data = resp.json()
    logging.info("Colorize response: %s", json.dumps(data))
    return data
def download_result(base_url: str, result_id: str, output_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> None:
    """Stream ``{base_url}/download/{result_id}`` into ``output_path``.

    Args:
        base_url: API root without a trailing slash.
        result_id: Identifier returned by the colorize call.
        output_path: Destination file. It is overwritten if it exists.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Raises:
        RuntimeError: If the response is not OK (status >= 400).
        Any error raised while streaming or writing is re-raised after the
        partially written output file has been removed.
    """
    url = f"{base_url}/download/{result_id}"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check
    # With stream=True the body is read lazily; the context manager guarantees
    # the connection is released on every path, including the error raise below.
    with requests.get(url, headers=headers, stream=True, timeout=300) as resp:
        if not resp.ok:
            raise RuntimeError("Download failed: HTTP %s %s" % (resp.status_code, resp.text))
        opened = False
        try:
            with open(output_path, "wb") as out:
                opened = True
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        out.write(chunk)
        except BaseException:
            # Don't leave a truncated image that looks like a valid result.
            # Only remove the file if we actually opened (and truncated) it.
            if opened:
                try:
                    os.remove(output_path)
                except OSError:
                    pass  # best-effort cleanup; the original error is what matters
            raise
    logging.info("Saved colorized image to: %s", output_path)
def main() -> int:
    """Run the health-wait -> upload -> colorize -> download workflow.

    Returns:
        Process exit code: 0 if every step succeeds, 1 if the image path is
        invalid or any of upload/colorize/download fails. A failed health
        wait is logged as a warning and does not abort the run.
    """
    parser = argparse.ArgumentParser(description="End-to-end test for Colorize API")
    parser.add_argument("--base-url", type=str, help="API base URL, e.g. https://<space>.hf.space")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--out", type=str, default="colorized_result.jpg", help="Path to save colorized image")
    parser.add_argument("--auth", type=str, default=os.getenv("ID_TOKEN", ""), help="Optional Firebase id_token")
    parser.add_argument("--app-check", type=str, default=os.getenv("APP_CHECK_TOKEN", ""), help="Optional App Check token")
    parser.add_argument("--skip-wait", action="store_true", help="Skip waiting for model to load")
    parser.add_argument("--wait-timeout", type=int, default=600, help="Seconds to wait for the model to load (default: 600)")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    args = parser.parse_args()
    configure_logging(args.verbose)
    base_url = get_base_url(args.base_url)
    image_path = args.image
    # isfile rather than exists: a directory would pass exists() and only
    # fail later, confusingly, inside the upload step.
    if not os.path.isfile(image_path):
        logging.error("Image not found: %s", image_path)
        return 1
    if not args.skip_wait:
        try:
            wait_for_model(base_url, timeout_seconds=args.wait_timeout)
        except Exception as e:
            # Deliberately best-effort: the later requests report real failures.
            logging.warning("Continuing despite health wait failure: %s", str(e))
    auth_bearer = args.auth.strip() or None
    app_check = args.app_check.strip() or None
    try:
        # The upload response is not used below; presumably the call is kept to
        # exercise the /upload endpoint as part of the end-to-end run.
        upload_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Upload error: %s", str(e))
        return 1
    try:
        colorize_resp = colorize_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Colorize error: %s", str(e))
        return 1
    # A non-object JSON body (e.g. a list) would otherwise crash on .get().
    result_id = colorize_resp.get("result_id") if isinstance(colorize_resp, dict) else None
    if not result_id:
        logging.error("No result_id in response: %s", json.dumps(colorize_resp))
        return 1
    try:
        download_result(base_url, result_id, args.out, auth_bearer, app_check)
    except Exception as e:
        logging.error("Download error: %s", str(e))
        return 1
    logging.info("Test workflow completed successfully.")
    return 0
if __name__ == "__main__":
    # Use main()'s integer return value as the process exit status.
    raise SystemExit(main())
|