sachin sharma committed
Commit · ebbcd26
1 Parent(s): 4f88f85
added test case generation

Files changed:
- scripts/generate_test_datasets.py +411 -0
- scripts/test_datasets.py +382 -0
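
The intended workflow, as a minimal sketch (it assumes both scripts are run from the repository root and that the FastAPI service exposing POST /predict/resnet is already running; note that the generator writes to ./test_datasets by default while the tester's CLI default is scripts/test_datasets, so the directory is passed explicitly here):

import subprocess

# Generate the 100 Parquet datasets into ./test_datasets/ (the generator's default output_dir).
subprocess.run(["python", "scripts/generate_test_datasets.py"], check=True)

# Quick pass of the tester (5 samples per dataset) against the locally running service,
# pointed at the directory the generator actually wrote to.
subprocess.run(
    ["python", "scripts/test_datasets.py", "--quick", "--datasets-dir", "test_datasets"],
    check=True,
)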
scripts/generate_test_datasets.py
ADDED
@@ -0,0 +1,411 @@
#!/usr/bin/env python3
"""
PyArrow Dataset Generator for ML Inference Service

Generates test datasets for academic challenges and model validation.
Creates 100 PyArrow datasets with various image types and test scenarios.
"""

import base64
import json
import random
from pathlib import Path
from typing import Dict, List, Any, Tuple
import io

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
from PIL import Image, ImageDraw, ImageFont


class TestDatasetGenerator:
    def __init__(self, output_dir: str = "test_datasets"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True)

        # ImageNet class labels (sample for testing)
        self.imagenet_labels = [
            "tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead",
            "electric_ray", "stingray", "cock", "hen", "ostrich", "brambling",
            "goldfinch", "house_finch", "junco", "indigo_bunting", "robin",
            "bulbul", "jay", "magpie", "chickadee", "water_ouzel", "kite",
            "bald_eagle", "vulture", "great_grey_owl", "European_fire_salamander",
            "common_newt", "eft", "spotted_salamander", "axolotl", "bullfrog",
            "tree_frog", "tailed_frog", "loggerhead", "leatherback_turtle",
            "mud_turtle", "terrapin", "box_turtle", "banded_gecko", "common_iguana",
            "American_chameleon", "whiptail", "agama", "frilled_lizard", "alligator_lizard",
            "Gila_monster", "green_lizard", "African_chameleon", "Komodo_dragon",
            "African_crocodile", "American_alligator", "triceratops", "thunder_snake"
        ]

    def create_synthetic_image(self, width: int = 224, height: int = 224,
                               image_type: str = "random") -> Image.Image:
        """Create synthetic images for testing."""
        if image_type == "random":
            # Random noise image
            array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
            return Image.fromarray(array)

        elif image_type == "geometric":
            # Geometric patterns
            img = Image.new('RGB', (width, height), color='white')
            draw = ImageDraw.Draw(img)

            # Draw random shapes
            for _ in range(random.randint(3, 8)):
                color = tuple(random.randint(0, 255) for _ in range(3))
                shape_type = random.choice(['rectangle', 'ellipse'])
                x1, y1 = random.randint(0, width//2), random.randint(0, height//2)
                x2, y2 = x1 + random.randint(20, width//2), y1 + random.randint(20, height//2)

                if shape_type == 'rectangle':
                    draw.rectangle([x1, y1, x2, y2], fill=color)
                else:
                    draw.ellipse([x1, y1, x2, y2], fill=color)

            return img

        elif image_type == "gradient":
            array = np.zeros((height, width, 3), dtype=np.uint8)
            for i in range(height):
                for j in range(width):
                    array[i, j] = [i * 255 // height, j * 255 // width, (i + j) * 255 // (height + width)]
            return Image.fromarray(array)

        elif image_type == "text":
            img = Image.new('RGB', (width, height), color='white')
            draw = ImageDraw.Draw(img)

            try:
                font = ImageFont.load_default()
            except Exception:
                font = None

            text = f"Test Image {random.randint(1, 1000)}"
            draw.text((width//4, height//2), text, fill='black', font=font)
            return img

        else:
            color = tuple(random.randint(0, 255) for _ in range(3))
            return Image.new('RGB', (width, height), color=color)

    def image_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
        """Convert PIL image to base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format=format)
        image_bytes = buffer.getvalue()
        return base64.b64encode(image_bytes).decode('utf-8')

    def create_api_request(self, image_b64: str, media_type: str = "image/jpeg") -> Dict[str, Any]:
        """Create API request structure matching your service."""
        return {
            "image": {
                "mediaType": media_type,
                "data": image_b64
            }
        }

    def create_expected_response(self, model_name: str = "microsoft/resnet-18",
                                 media_type: str = "image/jpeg") -> Dict[str, Any]:
        """Create expected response structure."""
        prediction = random.choice(self.imagenet_labels)
        return {
            "prediction": prediction,
            "confidence": round(random.uniform(0.3, 0.99), 4),
            "predicted_label": random.randint(0, len(self.imagenet_labels) - 1),
            "model": model_name,
            "mediaType": media_type
        }

    def generate_standard_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
        """Generate standard test cases with normal images."""
        datasets = []

        for i in range(count):
            image_types = ["random", "geometric", "gradient", "text", "solid"]
            sizes = [(224, 224), (256, 256), (299, 299), (384, 384)]
            formats = [("JPEG", "image/jpeg"), ("PNG", "image/png")]

            records = []
            for j in range(random.randint(5, 20)):  # 5-20 images per dataset
                img_type = random.choice(image_types)
                size = random.choice(sizes)
                format_info = random.choice(formats)

                image = self.create_synthetic_image(size[0], size[1], img_type)
                image_b64 = self.image_to_base64(image, format_info[0])

                api_request = self.create_api_request(image_b64, format_info[1])
                expected_response = self.create_expected_response()

                record = {
                    "dataset_id": f"standard_{i:03d}",
                    "image_id": f"img_{j:03d}",
                    "image_type": img_type,
                    "image_size": f"{size[0]}x{size[1]}",
                    "format": format_info[0],
                    "media_type": format_info[1],
                    "api_request": json.dumps(api_request),
                    "expected_response": json.dumps(expected_response),
                    "test_category": "standard",
                    "difficulty": "normal"
                }
                records.append(record)

            datasets.append({
                "name": f"standard_test_{i:03d}",
                "category": "standard",
                "description": f"Standard test dataset {i+1} with {len(records)} images",
                "records": records
            })

        return datasets

    def generate_edge_case_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
        """Generate datasets for edge case scenarios."""
        datasets = []

        for i in range(count):
            records = []
            edge_cases = [
                {"type": "tiny", "size": (32, 32), "difficulty": "high"},
                {"type": "huge", "size": (2048, 2048), "difficulty": "high"},
                {"type": "extreme_aspect", "size": (1000, 50), "difficulty": "medium"},
                {"type": "single_pixel", "size": (1, 1), "difficulty": "extreme"},
                {"type": "corrupted_base64", "size": (224, 224), "difficulty": "extreme"}
            ]

            for j, edge_case in enumerate(edge_cases):
                if edge_case["type"] == "corrupted_base64":
                    image = self.create_synthetic_image(224, 224, "random")
                    image_b64 = self.image_to_base64(image, "JPEG")
                    corrupted_b64 = image_b64[:-20] + "CORRUPTED_DATA"
                    api_request = self.create_api_request(corrupted_b64)
                    expected_response = {
                        "error": "Invalid image data",
                        "status": "failed"
                    }
                else:
                    image = self.create_synthetic_image(
                        edge_case["size"][0], edge_case["size"][1], "random"
                    )
                    image_b64 = self.image_to_base64(image, "PNG")
                    api_request = self.create_api_request(image_b64, "image/png")
                    expected_response = self.create_expected_response()

                record = {
                    "dataset_id": f"edge_{i:03d}",
                    "image_id": f"edge_{j:03d}",
                    "image_type": edge_case["type"],
                    "image_size": f"{edge_case['size'][0]}x{edge_case['size'][1]}",
                    "format": "PNG",
                    "media_type": "image/png",
                    "api_request": json.dumps(api_request),
                    "expected_response": json.dumps(expected_response),
                    "test_category": "edge_case",
                    "difficulty": edge_case["difficulty"]
                }
                records.append(record)

            datasets.append({
                "name": f"edge_case_{i:03d}",
                "category": "edge_case",
                "description": f"Edge case dataset {i+1} with challenging scenarios",
                "records": records
            })

        return datasets

    def generate_performance_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
        """Generate performance benchmark datasets."""
        datasets = []

        for i in range(count):
            batch_sizes = [1, 5, 10, 25, 50, 100]
            batch_size = random.choice(batch_sizes)

            records = []
            for j in range(batch_size):
                image = self.create_synthetic_image(224, 224, "random")
                image_b64 = self.image_to_base64(image, "JPEG")
                api_request = self.create_api_request(image_b64)
                expected_response = self.create_expected_response()

                record = {
                    "dataset_id": f"perf_{i:03d}",
                    "image_id": f"batch_{j:03d}",
                    "image_type": "performance_test",
                    "image_size": "224x224",
                    "format": "JPEG",
                    "media_type": "image/jpeg",
                    "api_request": json.dumps(api_request),
                    "expected_response": json.dumps(expected_response),
                    "test_category": "performance",
                    "difficulty": "normal",
                    "batch_size": batch_size,
                    "expected_max_latency_ms": batch_size * 100
                }
                records.append(record)

            datasets.append({
                "name": f"performance_test_{i:03d}",
                "category": "performance",
                "description": f"Performance dataset {i+1} with batch size {batch_size}",
                "records": records
            })

        return datasets

    def generate_model_comparison_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
        """Generate datasets for comparing different models."""
        datasets = []

        model_types = [
            "microsoft/resnet-18", "microsoft/resnet-50", "google/vit-base-patch16-224",
            "facebook/convnext-tiny-224", "microsoft/swin-tiny-patch4-window7-224"
        ]

        for i in range(count):
            # Same images tested across different model types
            base_images = []
            for _ in range(10):  # 10 base images per comparison dataset
                image = self.create_synthetic_image(224, 224, "geometric")
                base_images.append(self.image_to_base64(image, "JPEG"))

            records = []
            for j, model in enumerate(model_types):
                for k, image_b64 in enumerate(base_images):
                    api_request = self.create_api_request(image_b64)
                    expected_response = self.create_expected_response(model)

                    record = {
                        "dataset_id": f"comparison_{i:03d}",
                        "image_id": f"img_{k:03d}_model_{j}",
                        "image_type": "comparison_base",
                        "image_size": "224x224",
                        "format": "JPEG",
                        "media_type": "image/jpeg",
                        "api_request": json.dumps(api_request),
                        "expected_response": json.dumps(expected_response),
                        "test_category": "model_comparison",
                        "difficulty": "normal",
                        "model_type": model,
                        "comparison_group": k
                    }
                    records.append(record)

            datasets.append({
                "name": f"model_comparison_{i:03d}",
                "category": "model_comparison",
                "description": f"Model comparison dataset {i+1} testing {len(model_types)} models",
                "records": records
            })

        return datasets

    def save_dataset_to_parquet(self, dataset: Dict[str, Any]):
        """Save a dataset to PyArrow Parquet format."""
        records = dataset["records"]

        # Convert to PyArrow table
        table = pa.table({
            "dataset_id": [r["dataset_id"] for r in records],
            "image_id": [r["image_id"] for r in records],
            "image_type": [r["image_type"] for r in records],
            "image_size": [r["image_size"] for r in records],
            "format": [r["format"] for r in records],
            "media_type": [r["media_type"] for r in records],
            "api_request": [r["api_request"] for r in records],
            "expected_response": [r["expected_response"] for r in records],
            "test_category": [r["test_category"] for r in records],
            "difficulty": [r["difficulty"] for r in records],
            # Optional fields with defaults
            "batch_size": [r.get("batch_size", 1) for r in records],
            "expected_max_latency_ms": [r.get("expected_max_latency_ms", 1000) for r in records],
            "model_type": [r.get("model_type", "microsoft/resnet-18") for r in records],
            "comparison_group": [r.get("comparison_group", 0) for r in records]
        })

        output_path = self.output_dir / f"{dataset['name']}.parquet"
        pq.write_table(table, output_path)

        # Save metadata as JSON
        metadata = {
            "name": dataset["name"],
            "category": dataset["category"],
            "description": dataset["description"],
            "record_count": len(records),
            "file_size_mb": round(output_path.stat().st_size / (1024 * 1024), 2),
            "schema": [field.name for field in table.schema]
        }

        metadata_path = self.output_dir / f"{dataset['name']}_metadata.json"
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

    def generate_all_datasets(self):
        """Generate all 100 datasets."""
        print("Starting dataset generation...")

        print("Generating standard test datasets (25)...")
        standard_datasets = self.generate_standard_datasets(25)
        for dataset in standard_datasets:
            self.save_dataset_to_parquet(dataset)

        print("Generating edge case datasets (25)...")
        edge_datasets = self.generate_edge_case_datasets(25)
        for dataset in edge_datasets:
            self.save_dataset_to_parquet(dataset)

        print("Generating performance datasets (25)...")
        performance_datasets = self.generate_performance_datasets(25)
        for dataset in performance_datasets:
            self.save_dataset_to_parquet(dataset)

        print("Generating model comparison datasets (25)...")
        comparison_datasets = self.generate_model_comparison_datasets(25)
        for dataset in comparison_datasets:
            self.save_dataset_to_parquet(dataset)

        print(f"Generated 100 datasets in {self.output_dir}/")

        self.generate_summary()

    def generate_summary(self):
        """Generate a summary of all datasets."""
        summary = {
            "total_datasets": 100,
            "categories": {
                "standard": 25,
                "edge_case": 25,
                "performance": 25,
                "model_comparison": 25
            },
            "dataset_info": [],
            "usage_instructions": {
                "loading": "Use pyarrow.parquet.read_table('dataset.parquet')",
                "testing": "Run python scripts/test_datasets.py",
                "api_endpoint": "POST /predict/resnet",
                "request_format": "See api_request column in datasets"
            }
        }

        # Add individual dataset info
        for parquet_file in self.output_dir.glob("*.parquet"):
            metadata_file = self.output_dir / f"{parquet_file.stem}_metadata.json"
            if metadata_file.exists():
                with open(metadata_file, 'r') as f:
                    metadata = json.load(f)
                summary["dataset_info"].append(metadata)

        summary_path = self.output_dir / "datasets_summary.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"Summary saved to {summary_path}")


if __name__ == "__main__":
    generator = TestDatasetGenerator()
    generator.generate_all_datasets()
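
For reference, a minimal sketch of reading one of the generated files back and decoding its first image. It assumes the generator has already been run with the default output_dir; the file name standard_test_000.parquet is simply the first dataset produced by the naming scheme above:

import base64
import io
import json

import pyarrow.parquet as pq
from PIL import Image

# Load the Parquet table and take the first record.
table = pq.read_table("test_datasets/standard_test_000.parquet")
row = table.to_pandas().iloc[0]

# The api_request column holds the JSON body expected by POST /predict/resnet.
request = json.loads(row["api_request"])   # {"image": {"mediaType": ..., "data": <base64>}}
image = Image.open(io.BytesIO(base64.b64decode(request["image"]["data"])))

print(row["image_id"], row["image_type"], row["image_size"], image.size)
print(json.loads(row["expected_response"]))   # synthetic expected response for this record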
scripts/test_datasets.py
ADDED
@@ -0,0 +1,382 @@
#!/usr/bin/env python3
"""
Dataset Tester for ML Inference Service

Tests the generated PyArrow datasets against the running ML inference service.
Validates API requests/responses and measures performance metrics.
"""

import json
import time
import asyncio
import statistics
from pathlib import Path
from typing import Dict, List, Any, Optional
import argparse

import pyarrow.parquet as pq
import requests
import pandas as pd


class DatasetTester:
    def __init__(self, base_url: str = "http://127.0.0.1:8000", datasets_dir: str = "test_datasets"):
        self.base_url = base_url.rstrip('/')
        self.datasets_dir = Path(datasets_dir)
        self.endpoint = f"{self.base_url}/predict/resnet"
        self.results = []

    def load_dataset(self, dataset_path: Path) -> pd.DataFrame:
        """Load a PyArrow dataset."""
        table = pq.read_table(dataset_path)
        return table.to_pandas()

    def test_api_connection(self) -> bool:
        """Test if the API is running and accessible."""
        try:
            response = requests.get(f"{self.base_url}/docs", timeout=5)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def send_prediction_request(self, api_request_json: str) -> Dict[str, Any]:
        """Send a single prediction request to the API."""
        try:
            request_data = json.loads(api_request_json)
            start_time = time.time()

            response = requests.post(
                self.endpoint,
                json=request_data,
                headers={"Content-Type": "application/json"},
                timeout=30
            )

            end_time = time.time()
            latency_ms = (end_time - start_time) * 1000

            return {
                "success": response.status_code == 200,
                "status_code": response.status_code,
                "response": response.json() if response.status_code == 200 else response.text,
                "latency_ms": round(latency_ms, 2),
                "error": None
            }

        except requests.RequestException as e:
            return {
                "success": False,
                "status_code": None,
                "response": None,
                "latency_ms": None,
                "error": str(e)
            }
        except json.JSONDecodeError as e:
            return {
                "success": False,
                "status_code": None,
                "response": None,
                "latency_ms": None,
                "error": f"JSON decode error: {str(e)}"
            }

    def validate_response(self, actual_response: Dict[str, Any],
                          expected_response_json: str) -> Dict[str, Any]:
        """Validate API response against expected response."""
        try:
            expected = json.loads(expected_response_json)

            validation = {
                "structure_valid": True,
                "field_errors": []
            }

            # Check required fields exist
            required_fields = ["prediction", "confidence", "predicted_label", "model", "mediaType"]
            for field in required_fields:
                if field not in actual_response:
                    validation["structure_valid"] = False
                    validation["field_errors"].append(f"Missing field: {field}")

            # Validate field types
            if "confidence" in actual_response:
                if not isinstance(actual_response["confidence"], (int, float)):
                    validation["field_errors"].append("confidence must be numeric")
                elif not (0 <= actual_response["confidence"] <= 1):
                    validation["field_errors"].append("confidence must be between 0 and 1")

            if "predicted_label" in actual_response:
                if not isinstance(actual_response["predicted_label"], int):
                    validation["field_errors"].append("predicted_label must be integer")

            return validation

        except json.JSONDecodeError:
            return {
                "structure_valid": False,
                "field_errors": ["Invalid expected response JSON"]
            }

    def test_dataset(self, dataset_path: Path, max_samples: Optional[int] = None) -> Dict[str, Any]:
        """Test a single dataset."""
        print(f"Testing dataset: {dataset_path.name}")

        try:
            df = self.load_dataset(dataset_path)
            if max_samples:
                df = df.head(max_samples)

            results = {
                "dataset_name": dataset_path.stem,
                "total_samples": len(df),
                "tested_samples": 0,
                "successful_requests": 0,
                "failed_requests": 0,
                "validation_errors": 0,
                "latencies_ms": [],
                "errors": [],
                "category": df['test_category'].iloc[0] if not df.empty else "unknown"
            }

            for idx, row in df.iterrows():
                print(f"  Testing sample {idx + 1}/{len(df)}", end="\r")

                # Send API request
                api_result = self.send_prediction_request(row['api_request'])
                results["tested_samples"] += 1

                if api_result["success"]:
                    results["successful_requests"] += 1
                    results["latencies_ms"].append(api_result["latency_ms"])

                    # Validate response structure
                    validation = self.validate_response(
                        api_result["response"],
                        row['expected_response']
                    )

                    if not validation["structure_valid"]:
                        results["validation_errors"] += 1
                        results["errors"].append({
                            "sample_id": row['image_id'],
                            "type": "validation_error",
                            "details": validation["field_errors"]
                        })

                else:
                    results["failed_requests"] += 1
                    results["errors"].append({
                        "sample_id": row['image_id'],
                        "type": "request_failed",
                        "status_code": api_result["status_code"],
                        "error": api_result["error"]
                    })

            # Calculate statistics
            if results["latencies_ms"]:
                results["avg_latency_ms"] = round(statistics.mean(results["latencies_ms"]), 2)
                results["min_latency_ms"] = round(min(results["latencies_ms"]), 2)
                results["max_latency_ms"] = round(max(results["latencies_ms"]), 2)
                results["median_latency_ms"] = round(statistics.median(results["latencies_ms"]), 2)
            else:
                results.update({
                    "avg_latency_ms": None,
                    "min_latency_ms": None,
                    "max_latency_ms": None,
                    "median_latency_ms": None
                })

            results["success_rate"] = round(
                results["successful_requests"] / results["tested_samples"] * 100, 2
            ) if results["tested_samples"] > 0 else 0

            print(f"\n  Completed: {results['success_rate']}% success rate")
            return results

        except Exception as e:
            print(f"\n  Failed to test dataset: {str(e)}")
            return {
                "dataset_name": dataset_path.stem,
                "error": str(e),
                "success_rate": 0
            }

    def test_all_datasets(self, max_samples_per_dataset: Optional[int] = None,
                          category_filter: Optional[str] = None) -> Dict[str, Any]:
        """Test all datasets or filtered by category."""
        if not self.test_api_connection():
            print("API is not accessible. Please start the service first:")
            print("  uvicorn main:app --reload")
            return {"error": "API not accessible"}

        print(f"Starting dataset testing against {self.endpoint}")

        parquet_files = list(self.datasets_dir.glob("*.parquet"))
        if not parquet_files:
            print(f"No datasets found in {self.datasets_dir}")
            return {"error": "No datasets found"}

        if category_filter:
            parquet_files = [f for f in parquet_files if category_filter in f.name]

        print(f"Found {len(parquet_files)} datasets to test")

        all_results = []
        start_time = time.time()

        for dataset_file in parquet_files:
            result = self.test_dataset(dataset_file, max_samples_per_dataset)
            all_results.append(result)

        end_time = time.time()
        total_time = end_time - start_time

        summary = self.generate_summary(all_results, total_time)

        self.save_results(summary, all_results)

        return summary

    def generate_summary(self, results: List[Dict[str, Any]], total_time: float) -> Dict[str, Any]:
        """Generate summary of all test results."""
        successful_datasets = [r for r in results if r.get("success_rate", 0) > 0]
        failed_datasets = [r for r in results if r.get("error") or r.get("success_rate", 0) == 0]

        total_samples = sum(r.get("tested_samples", 0) for r in results)
        total_successful = sum(r.get("successful_requests", 0) for r in results)
        total_failed = sum(r.get("failed_requests", 0) for r in results)

        all_latencies = []
        for r in results:
            all_latencies.extend(r.get("latencies_ms", []))

        summary = {
            "test_summary": {
                "total_datasets": len(results),
                "successful_datasets": len(successful_datasets),
                "failed_datasets": len(failed_datasets),
                "total_samples_tested": total_samples,
                "total_successful_requests": total_successful,
                "total_failed_requests": total_failed,
                "overall_success_rate": round(
                    total_successful / total_samples * 100, 2
                ) if total_samples > 0 else 0,
                "total_test_time_seconds": round(total_time, 2)
            },
            "performance_metrics": {
                "avg_latency_ms": round(statistics.mean(all_latencies), 2) if all_latencies else None,
                "median_latency_ms": round(statistics.median(all_latencies), 2) if all_latencies else None,
                "min_latency_ms": round(min(all_latencies), 2) if all_latencies else None,
                "max_latency_ms": round(max(all_latencies), 2) if all_latencies else None,
                "requests_per_second": round(
                    total_successful / total_time, 2
                ) if total_time > 0 else 0
            },
            "category_breakdown": {},
            "failed_datasets": [r["dataset_name"] for r in failed_datasets]
        }

        categories = {}
        for result in results:
            category = result.get("category", "unknown")
            if category not in categories:
                categories[category] = {
                    "count": 0,
                    "success_rates": [],
                    "avg_success_rate": 0
                }
            categories[category]["count"] += 1
            categories[category]["success_rates"].append(result.get("success_rate", 0))

        for category, data in categories.items():
            data["avg_success_rate"] = round(
                statistics.mean(data["success_rates"]), 2
            ) if data["success_rates"] else 0

        summary["category_breakdown"] = categories

        return summary

    def save_results(self, summary: Dict[str, Any], detailed_results: List[Dict[str, Any]]):
        """Save test results to files."""
        results_dir = Path("test_results")
        results_dir.mkdir(exist_ok=True)

        timestamp = int(time.time())

        # Save summary
        summary_path = results_dir / f"test_summary_{timestamp}.json"
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        # Save detailed results
        detailed_path = results_dir / f"test_detailed_{timestamp}.json"
        with open(detailed_path, 'w') as f:
            json.dump(detailed_results, f, indent=2)

        print("Results saved:")
        print(f"  Summary: {summary_path}")
        print(f"  Details: {detailed_path}")

    def print_summary(self, summary: Dict[str, Any]):
        """Print test summary to console."""
        print("\n" + "="*60)
        print("DATASET TESTING SUMMARY")
        print("="*60)

        ts = summary["test_summary"]
        print(f"Datasets tested: {ts['total_datasets']}")
        print(f"Successful datasets: {ts['successful_datasets']}")
        print(f"Failed datasets: {ts['failed_datasets']}")
        print(f"Total samples: {ts['total_samples_tested']}")
        print(f"Overall success rate: {ts['overall_success_rate']}%")
        print(f"Test duration: {ts['total_test_time_seconds']}s")

        pm = summary["performance_metrics"]
        if pm["avg_latency_ms"] is not None:
            print(f"\nPerformance:")
            print(f"  Avg latency: {pm['avg_latency_ms']}ms")
            print(f"  Median latency: {pm['median_latency_ms']}ms")
            print(f"  Min latency: {pm['min_latency_ms']}ms")
            print(f"  Max latency: {pm['max_latency_ms']}ms")
            print(f"  Requests/sec: {pm['requests_per_second']}")

        print(f"\nCategory breakdown:")
        for category, data in summary["category_breakdown"].items():
            print(f"  {category}: {data['count']} datasets, {data['avg_success_rate']}% avg success")

        if summary["failed_datasets"]:
            print(f"\nFailed datasets: {', '.join(summary['failed_datasets'])}")


def main():
    parser = argparse.ArgumentParser(description="Test PyArrow datasets against ML inference service")
    parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Base URL of the API")
    parser.add_argument("--datasets-dir", default="scripts/test_datasets", help="Directory containing datasets")
    parser.add_argument("--max-samples", type=int, help="Max samples per dataset to test")
    parser.add_argument("--category", help="Filter datasets by category (standard, edge_case, performance, model_comparison)")
    parser.add_argument("--quick", action="store_true", help="Quick test with max 5 samples per dataset")

    args = parser.parse_args()

    tester = DatasetTester(args.base_url, args.datasets_dir)

    max_samples = args.max_samples
    if args.quick:
        max_samples = 5

    results = tester.test_all_datasets(max_samples, args.category)

    if "error" not in results:
        tester.print_summary(results)

        if results["test_summary"]["overall_success_rate"] > 90:
            print("\nExcellent! API is working great with the datasets!")
        elif results["test_summary"]["overall_success_rate"] > 70:
            print("\nGood! API works well, minor issues detected.")
        else:
            print("\nWarning: Several issues detected. Check the detailed results.")


if __name__ == "__main__":
    main()
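
The tester can also be driven programmatically instead of through the CLI flags above; a minimal sketch (it assumes the inference service is already up at http://127.0.0.1:8000, that datasets were generated into test_datasets/, and that the import path matches a scripts/ package layout, which is an assumption about the repo):

from scripts.test_datasets import DatasetTester  # hypothetical import path; depends on repo layout

# Point the tester at the running service and the generated datasets.
tester = DatasetTester(base_url="http://127.0.0.1:8000", datasets_dir="test_datasets")

# Equivalent to --quick: at most 5 samples per dataset; returns the summary dict.
summary = tester.test_all_datasets(max_samples_per_dataset=5)

if "error" not in summary:
    tester.print_summary(summary)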