|
|
import argparse
import io
import json
import logging
import mimetypes
import os
import sys
import time
from datetime import datetime, timedelta, timezone
from typing import Optional

import requests
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
|
|
|
|
|
|
|
def configure_logging(verbose: bool) -> None:
    """Configure root logging for the script.

    Args:
        verbose: When true, log at DEBUG level; otherwise log at INFO level.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=(logging.DEBUG if verbose else logging.INFO),
    )
|
|
|
|
|
|
|
|
def get_base_url(cli_base_url: Optional[str]) -> str:
    """Resolve the API base URL.

    Precedence is: the command-line value, then the ``BASE_URL`` environment
    variable, then ``http://localhost:7860``. Empty values are skipped, and
    trailing slashes are stripped from whichever value is chosen.
    """
    for candidate in (cli_base_url, os.getenv("BASE_URL")):
        if candidate:
            return candidate.rstrip("/")
    return "http://localhost:7860"
|
|
|
|
|
|
|
|
def wait_for_model(base_url: str, timeout_seconds: int = 300) -> None:
    """Poll ``{base_url}/health`` until the server reports ``model_loaded``.

    Args:
        base_url: API base URL without a trailing slash.
        timeout_seconds: Maximum total time to keep polling.

    Raises:
        RuntimeError: if the health payload never has a truthy ``model_loaded``
            before the timeout. The message includes the last payload seen.
    """
    health_url = f"{base_url}/health"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    poll_interval = 3
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments
    # during a long wait.
    deadline = time.monotonic() + timeout_seconds
    last_status = None

    logging.info("Waiting for model to load at %s", health_url)

    while time.monotonic() < deadline:
        try:
            resp = requests.get(health_url, headers=headers, timeout=15)
            if resp.ok:
                data = resp.json()
                last_status = data
                # Guard against a JSON body that is not an object (e.g. a list).
                if isinstance(data, dict) and data.get("model_loaded"):
                    logging.info("Model loaded: %s", json.dumps(data))
                    return
                logging.info("Health: %s", json.dumps(data))
            else:
                logging.warning("Health check HTTP %s", resp.status_code)
        except (requests.RequestException, ValueError) as e:
            # Network errors and undecodable JSON are expected while the
            # service is starting. Anything else is a real bug and propagates.
            logging.warning("Health check error: %s", e)

        # Never sleep past the deadline.
        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(min(poll_interval, remaining))

    raise RuntimeError("Model did not load before timeout. Last health: %s" % (last_status,))
|
|
|
|
|
|
|
|
def upload_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
    """Upload an image to ``{base_url}/upload`` as the multipart field ``file``.

    Args:
        base_url: API base URL without a trailing slash.
        image_path: Local path of the image to send.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Returns:
        The decoded JSON response body.

    Raises:
        RuntimeError: if the server returns an error status (>= 400) or a body
            that is not valid JSON.
        OSError: if the image file cannot be opened.
        requests.RequestException: on network failures or timeouts.
    """
    url = f"{base_url}/upload"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    # Label the part with the file's actual type instead of always claiming
    # JPEG. Fall back to the previous default for unknown extensions.
    content_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=120)

    if not resp.ok:
        raise RuntimeError("Upload failed: HTTP %s %s" % (resp.status_code, resp.text))

    try:
        data = resp.json()
    except ValueError as e:
        raise RuntimeError("Upload returned a non-JSON body: %s" % resp.text[:500]) from e

    logging.info("Upload response: %s", json.dumps(data))
    return data
|
|
|
|
|
|
|
|
def colorize_image(base_url: str, image_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> dict:
    """Send an image to ``{base_url}/colorize`` as the multipart field ``file``.

    Uses a long (900 s) request timeout, because the request waits for the
    colorization to finish.

    Args:
        base_url: API base URL without a trailing slash.
        image_path: Local path of the image to send.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Returns:
        The decoded JSON response body. The caller reads ``result_id`` from it.

    Raises:
        RuntimeError: if the server returns an error status (>= 400) or a body
            that is not valid JSON.
        OSError: if the image file cannot be opened.
        requests.RequestException: on network failures or timeouts.
    """
    url = f"{base_url}/colorize"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    # Label the part with the file's actual type instead of always claiming
    # JPEG. Fall back to the previous default for unknown extensions.
    content_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"

    with open(image_path, "rb") as f:
        files = {"file": (os.path.basename(image_path), f, content_type)}
        resp = requests.post(url, files=files, headers=headers, timeout=900)

    if not resp.ok:
        raise RuntimeError("Colorize failed: HTTP %s %s" % (resp.status_code, resp.text))

    try:
        data = resp.json()
    except ValueError as e:
        raise RuntimeError("Colorize returned a non-JSON body: %s" % resp.text[:500]) from e

    logging.info("Colorize response: %s", json.dumps(data))
    return data
|
|
|
|
|
|
|
|
def download_result(base_url: str, result_id: str, output_path: str, auth_bearer: Optional[str], app_check: Optional[str]) -> None:
    """Stream ``{base_url}/download/{result_id}`` into *output_path*.

    Args:
        base_url: API base URL without a trailing slash.
        result_id: Identifier returned by the colorize endpoint.
        output_path: Local file to write. It is overwritten if it exists.
        auth_bearer: Optional token, sent as ``Authorization: Bearer <token>``.
        app_check: Optional token, sent as the ``X-Firebase-AppCheck`` header.

    Raises:
        RuntimeError: if the server returns an error status (>= 400).
        OSError: if *output_path* cannot be written.
        requests.RequestException: on network failures or timeouts.
    """
    url = f"{base_url}/download/{result_id}"
    headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
    if auth_bearer:
        headers["Authorization"] = f"Bearer {auth_bearer}"
    if app_check:
        headers["X-Firebase-AppCheck"] = app_check

    # The context manager releases the streamed connection on every exit path,
    # including the error branch below.
    with requests.get(url, headers=headers, stream=True, timeout=300) as resp:
        if not resp.ok:
            raise RuntimeError("Download failed: HTTP %s %s" % (resp.status_code, resp.text))

        out = open(output_path, "wb")
        try:
            with out:
                for chunk in resp.iter_content(chunk_size=8192):
                    if chunk:
                        out.write(chunk)
        except BaseException:
            # Don't leave a truncated file that looks like a valid result.
            # Removal is best-effort; the original error is re-raised either way.
            try:
                os.remove(output_path)
            except OSError:
                pass
            raise

    logging.info("Saved colorized image to: %s", output_path)
|
|
|
|
|
|
|
|
def _recent_documents(collection, query: dict, sort_field: str, sample_size: int) -> tuple:
    """Return ``(count, sample)``: the number of documents matching *query* and
    up to *sample_size* of them, newest first by *sort_field*."""
    count = collection.count_documents(query)
    if not count:
        return 0, []
    return count, list(collection.find(query).sort(sort_field, -1).limit(sample_size))


def _format_seconds(value) -> str:
    """Format a duration for logging. Non-numeric values are shown unchanged."""
    if isinstance(value, (int, float)):
        return "%.2fs" % value
    return str(value)


def verify_mongodb_storage(mongodb_uri: str, db_name: str = "colorization_db",
                           test_endpoint: Optional[str] = None, wait_seconds: int = 5) -> bool:
    """Verify that recent API activity was recorded in MongoDB.

    Checks documents from the last five minutes in three collections:
    ``api_calls`` (by ``timestamp``), ``image_uploads`` (by ``uploaded_at``)
    and ``colorizations`` (by ``created_at``). Logs a short sample of each.

    Args:
        mongodb_uri: MongoDB connection string.
        db_name: Database name.
        test_endpoint: If given, only ``api_calls`` documents whose
            ``endpoint`` field equals this value are considered.
        wait_seconds: Seconds to sleep before querying, so that asynchronous
            writes have time to land.

    Returns:
        True if any of the three collections has recent documents. False
        otherwise, including on connection or query errors.
    """
    logging.info("Waiting %d seconds for MongoDB writes to complete...", wait_seconds)
    time.sleep(wait_seconds)

    client = None
    try:
        client = MongoClient(mongodb_uri, serverSelectionTimeoutMS=5000)
        client.admin.command('ping')
        logging.info("✅ Connected to MongoDB successfully")

        db = client[db_name]
        # One timezone-aware cutoff so every count and sample covers the same
        # window. This also avoids the deprecated datetime.utcnow().
        cutoff = datetime.now(timezone.utc) - timedelta(minutes=5)

        api_query = {"timestamp": {"$gte": cutoff}}
        if test_endpoint:
            api_query["endpoint"] = test_endpoint
        api_calls_count, recent_calls = _recent_documents(db["api_calls"], api_query, "timestamp", 5)
        logging.info("Found %d API calls in the last 5 minutes", api_calls_count)
        if recent_calls:
            logging.info("Recent API calls:")
            for call in recent_calls:
                # %s rather than %d: a missing/None status must not break logging.
                logging.info(" - %s %s at %s (status: %s)",
                             call.get("method", "N/A"),
                             call.get("endpoint", "N/A"),
                             call.get("timestamp", "N/A"),
                             call.get("status_code", 0))

        uploads_count, recent_uploads = _recent_documents(
            db["image_uploads"], {"uploaded_at": {"$gte": cutoff}}, "uploaded_at", 3)
        logging.info("Found %d image uploads in the last 5 minutes", uploads_count)
        if recent_uploads:
            logging.info("Recent uploads:")
            for upload in recent_uploads:
                logging.info(" - Image ID: %s, Size: %s bytes, Uploaded at: %s",
                             upload.get("image_id", "N/A"),
                             upload.get("file_size", 0),
                             upload.get("uploaded_at", "N/A"))

        colorizations_count, recent_colorizations = _recent_documents(
            db["colorizations"], {"created_at": {"$gte": cutoff}}, "created_at", 3)
        logging.info("Found %d colorizations in the last 5 minutes", colorizations_count)
        if recent_colorizations:
            logging.info("Recent colorizations:")
            for colorization in recent_colorizations:
                logging.info(" - Result ID: %s, Model: %s, Time: %s, Created at: %s",
                             colorization.get("result_id", "N/A"),
                             colorization.get("model_type", "N/A"),
                             _format_seconds(colorization.get("processing_time", 0)),
                             colorization.get("created_at", "N/A"))

    except (ConnectionFailure, ServerSelectionTimeoutError) as e:
        logging.error("❌ Failed to connect to MongoDB: %s", e)
        return False
    except Exception as e:
        logging.error("❌ MongoDB verification error: %s", e)
        return False
    finally:
        # Close the client on every path, including after errors.
        if client is not None:
            client.close()

    if api_calls_count or uploads_count or colorizations_count:
        logging.info("✅ MongoDB storage verification PASSED")
        return True

    logging.warning("⚠️ No recent data found in MongoDB (this might be normal if no API calls were made)")
    return False
|
|
|
|
|
|
|
|
def main() -> int:
    """Run the end-to-end workflow.

    Steps: optional health wait, upload, colorize, download, then an optional
    MongoDB verification.

    Returns:
        Process exit code. 0 on success. 1 if the image is missing, if any API
        step fails, or if a requested MongoDB verification does not pass.
    """
    parser = argparse.ArgumentParser(description="End-to-end test for Colorize API")
    parser.add_argument("--base-url", type=str, help="API base URL, e.g. https://<space>.hf.space")
    parser.add_argument("--image", type=str, required=True, help="Path to input image")
    parser.add_argument("--out", type=str, default="colorized_result.jpg", help="Path to save colorized image")
    parser.add_argument("--auth", type=str, default=os.getenv("ID_TOKEN", ""), help="Optional Firebase id_token")
    parser.add_argument("--app-check", type=str, default=os.getenv("APP_CHECK_TOKEN", ""), help="Optional App Check token")
    parser.add_argument("--skip-wait", action="store_true", help="Skip waiting for model to load")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")
    parser.add_argument("--verify-mongodb", action="store_true", help="Verify MongoDB storage after API calls")
    parser.add_argument("--mongodb-uri", type=str, default=os.getenv("MONGODB_URI", ""), help="MongoDB connection string for verification")
    parser.add_argument("--mongodb-db", type=str, default=os.getenv("MONGODB_DB_NAME", "colorization_db"), help="MongoDB database name")
    args = parser.parse_args()

    configure_logging(args.verbose)
    base_url = get_base_url(args.base_url)
    image_path = args.image

    # isfile (not exists) so a directory is rejected here with a clear message
    # instead of failing later inside the upload.
    if not os.path.isfile(image_path):
        logging.error("Image not found: %s", image_path)
        return 1

    if not args.skip_wait:
        try:
            wait_for_model(base_url, timeout_seconds=600)
        except Exception as e:
            # Deliberately non-fatal: the subsequent calls report the real error.
            logging.warning("Continuing despite health wait failure: %s", str(e))

    # Empty or whitespace-only tokens mean "not provided".
    auth_bearer = args.auth.strip() or None
    app_check = args.app_check.strip() or None

    try:
        upload_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Upload error: %s", str(e))
        return 1

    try:
        colorize_resp = colorize_image(base_url, image_path, auth_bearer, app_check)
    except Exception as e:
        logging.error("Colorize error: %s", str(e))
        return 1

    result_id = colorize_resp.get("result_id")
    if not result_id:
        logging.error("No result_id in response: %s", json.dumps(colorize_resp))
        return 1

    try:
        download_result(base_url, result_id, args.out, auth_bearer, app_check)
    except Exception as e:
        logging.error("Download error: %s", str(e))
        return 1

    logging.info("Test workflow completed successfully.")

    if args.verify_mongodb:
        if not args.mongodb_uri:
            logging.warning("⚠️ MongoDB URI not provided. Skipping MongoDB verification.")
            logging.info("Set MONGODB_URI environment variable or use --mongodb-uri flag")
        else:
            logging.info("=" * 60)
            logging.info("Verifying MongoDB storage...")
            logging.info("=" * 60)
            # A verification the user explicitly requested must be reflected in
            # the exit code; otherwise CI cannot detect a failure.
            if not verify_mongodb_storage(args.mongodb_uri, args.mongodb_db):
                logging.error("MongoDB verification failed.")
                return 1

    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the workflow and use main()'s return value as the exit status.
    raise SystemExit(main())
|
|
|
|
|
|