sachin sharma committed
Commit ebbcd26 · 1 Parent(s): 4f88f85

added test case generation
scripts/generate_test_datasets.py ADDED
@@ -0,0 +1,411 @@
+ #!/usr/bin/env python3
+ """
+ PyArrow Dataset Generator for ML Inference Service
+
+ Generates test datasets for academic challenges and model validation.
+ Creates 100 PyArrow datasets with various image types and test scenarios.
+ """
+
+ import base64
+ import json
+ import random
+ from pathlib import Path
+ from typing import Dict, List, Any, Tuple
+ import io
+
+ import numpy as np
+ import pyarrow as pa
+ import pyarrow.parquet as pq
+ from PIL import Image, ImageDraw, ImageFont
+
+
+ class TestDatasetGenerator:
+     def __init__(self, output_dir: str = "test_datasets"):
+         self.output_dir = Path(output_dir)
+         self.output_dir.mkdir(exist_ok=True)
+
+         # ImageNet class labels (sample for testing)
+         self.imagenet_labels = [
+             "tench", "goldfish", "great_white_shark", "tiger_shark", "hammerhead",
+             "electric_ray", "stingray", "cock", "hen", "ostrich", "brambling",
+             "goldfinch", "house_finch", "junco", "indigo_bunting", "robin",
+             "bulbul", "jay", "magpie", "chickadee", "water_ouzel", "kite",
+             "bald_eagle", "vulture", "great_grey_owl", "European_fire_salamander",
+             "common_newt", "eft", "spotted_salamander", "axolotl", "bullfrog",
+             "tree_frog", "tailed_frog", "loggerhead", "leatherback_turtle",
+             "mud_turtle", "terrapin", "box_turtle", "banded_gecko", "common_iguana",
+             "American_chameleon", "whiptail", "agama", "frilled_lizard", "alligator_lizard",
+             "Gila_monster", "green_lizard", "African_chameleon", "Komodo_dragon",
+             "African_crocodile", "American_alligator", "triceratops", "thunder_snake"
+         ]
+
+     def create_synthetic_image(self, width: int = 224, height: int = 224,
+                                image_type: str = "random") -> Image.Image:
+         """Create synthetic images for testing."""
+         if image_type == "random":
+             # Random noise image
+             array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
+             return Image.fromarray(array)
+
+         elif image_type == "geometric":
+             # Geometric patterns
+             img = Image.new('RGB', (width, height), color='white')
+             draw = ImageDraw.Draw(img)
+
+             # Draw random shapes
+             for _ in range(random.randint(3, 8)):
+                 color = tuple(random.randint(0, 255) for _ in range(3))
+                 shape_type = random.choice(['rectangle', 'ellipse'])
+                 x1, y1 = random.randint(0, width//2), random.randint(0, height//2)
+                 x2, y2 = x1 + random.randint(20, width//2), y1 + random.randint(20, height//2)
+
+                 if shape_type == 'rectangle':
+                     draw.rectangle([x1, y1, x2, y2], fill=color)
+                 else:
+                     draw.ellipse([x1, y1, x2, y2], fill=color)
+
+             return img
+
+         elif image_type == "gradient":
+             array = np.zeros((height, width, 3), dtype=np.uint8)
+             for i in range(height):
+                 for j in range(width):
+                     array[i, j] = [i * 255 // height, j * 255 // width, (i + j) * 255 // (height + width)]
+             return Image.fromarray(array)
+
+         elif image_type == "text":
+             img = Image.new('RGB', (width, height), color='white')
+             draw = ImageDraw.Draw(img)
+
+             try:
+                 font = ImageFont.load_default()
+             except Exception:
+                 font = None
+
+             text = f"Test Image {random.randint(1, 1000)}"
+             draw.text((width//4, height//2), text, fill='black', font=font)
+             return img
+
+         else:
+             color = tuple(random.randint(0, 255) for _ in range(3))
+             return Image.new('RGB', (width, height), color=color)
+
+     def image_to_base64(self, image: Image.Image, format: str = "JPEG") -> str:
+         """Convert PIL image to base64 string."""
+         buffer = io.BytesIO()
+         image.save(buffer, format=format)
+         image_bytes = buffer.getvalue()
+         return base64.b64encode(image_bytes).decode('utf-8')
+
+     def create_api_request(self, image_b64: str, media_type: str = "image/jpeg") -> Dict[str, Any]:
+         """Create API request structure matching your service."""
+         return {
+             "image": {
+                 "mediaType": media_type,
+                 "data": image_b64
+             }
+         }
+
+     def create_expected_response(self, model_name: str = "microsoft/resnet-18",
+                                  media_type: str = "image/jpeg") -> Dict[str, Any]:
+         """Create expected response structure."""
+         prediction = random.choice(self.imagenet_labels)
+         return {
+             "prediction": prediction,
+             "confidence": round(random.uniform(0.3, 0.99), 4),
+             "predicted_label": random.randint(0, len(self.imagenet_labels) - 1),
+             "model": model_name,
+             "mediaType": media_type
+         }
+
+     def generate_standard_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate standard test cases with normal images."""
+         datasets = []
+
+         for i in range(count):
+             image_types = ["random", "geometric", "gradient", "text", "solid"]
+             sizes = [(224, 224), (256, 256), (299, 299), (384, 384)]
+             formats = [("JPEG", "image/jpeg"), ("PNG", "image/png")]
+
+             records = []
+             for j in range(random.randint(5, 20)):  # 5-20 images per dataset
+                 img_type = random.choice(image_types)
+                 size = random.choice(sizes)
+                 format_info = random.choice(formats)
+
+                 image = self.create_synthetic_image(size[0], size[1], img_type)
+                 image_b64 = self.image_to_base64(image, format_info[0])
+
+                 api_request = self.create_api_request(image_b64, format_info[1])
+                 expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"standard_{i:03d}",
+                     "image_id": f"img_{j:03d}",
+                     "image_type": img_type,
+                     "image_size": f"{size[0]}x{size[1]}",
+                     "format": format_info[0],
+                     "media_type": format_info[1],
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "standard",
+                     "difficulty": "normal"
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"standard_test_{i:03d}",
+                 "category": "standard",
+                 "description": f"Standard test dataset {i+1} with {len(records)} images",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_edge_case_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate datasets for edge case scenarios."""
+         datasets = []
+
+         for i in range(count):
+             records = []
+             edge_cases = [
+                 {"type": "tiny", "size": (32, 32), "difficulty": "high"},
+                 {"type": "huge", "size": (2048, 2048), "difficulty": "high"},
+                 {"type": "extreme_aspect", "size": (1000, 50), "difficulty": "medium"},
+                 {"type": "single_pixel", "size": (1, 1), "difficulty": "extreme"},
+                 {"type": "corrupted_base64", "size": (224, 224), "difficulty": "extreme"}
+             ]
+
+             for j, edge_case in enumerate(edge_cases):
+                 if edge_case["type"] == "corrupted_base64":
+                     image = self.create_synthetic_image(224, 224, "random")
+                     image_b64 = self.image_to_base64(image, "JPEG")
+                     corrupted_b64 = image_b64[:-20] + "CORRUPTED_DATA"
+                     api_request = self.create_api_request(corrupted_b64)
+                     expected_response = {
+                         "error": "Invalid image data",
+                         "status": "failed"
+                     }
+                 else:
+                     image = self.create_synthetic_image(
+                         edge_case["size"][0], edge_case["size"][1], "random"
+                     )
+                     image_b64 = self.image_to_base64(image, "PNG")
+                     api_request = self.create_api_request(image_b64, "image/png")
+                     expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"edge_{i:03d}",
+                     "image_id": f"edge_{j:03d}",
+                     "image_type": edge_case["type"],
+                     "image_size": f"{edge_case['size'][0]}x{edge_case['size'][1]}",
+                     "format": "PNG",
+                     "media_type": "image/png",
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "edge_case",
+                     "difficulty": edge_case["difficulty"]
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"edge_case_{i:03d}",
+                 "category": "edge_case",
+                 "description": f"Edge case dataset {i+1} with challenging scenarios",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_performance_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate performance benchmark datasets."""
+         datasets = []
+
+         for i in range(count):
+             batch_sizes = [1, 5, 10, 25, 50, 100]
+             batch_size = random.choice(batch_sizes)
+
+             records = []
+             for j in range(batch_size):
+                 image = self.create_synthetic_image(224, 224, "random")
+                 image_b64 = self.image_to_base64(image, "JPEG")
+                 api_request = self.create_api_request(image_b64)
+                 expected_response = self.create_expected_response()
+
+                 record = {
+                     "dataset_id": f"perf_{i:03d}",
+                     "image_id": f"batch_{j:03d}",
+                     "image_type": "performance_test",
+                     "image_size": "224x224",
+                     "format": "JPEG",
+                     "media_type": "image/jpeg",
+                     "api_request": json.dumps(api_request),
+                     "expected_response": json.dumps(expected_response),
+                     "test_category": "performance",
+                     "difficulty": "normal",
+                     "batch_size": batch_size,
+                     "expected_max_latency_ms": batch_size * 100
+                 }
+                 records.append(record)
+
+             datasets.append({
+                 "name": f"performance_test_{i:03d}",
+                 "category": "performance",
+                 "description": f"Performance dataset {i+1} with batch size {batch_size}",
+                 "records": records
+             })
+
+         return datasets
+
+     def generate_model_comparison_datasets(self, count: int = 25) -> List[Dict[str, Any]]:
+         """Generate datasets for comparing different models."""
+         datasets = []
+
+         model_types = [
+             "microsoft/resnet-18", "microsoft/resnet-50", "google/vit-base-patch16-224",
+             "facebook/convnext-tiny-224", "microsoft/swin-tiny-patch4-window7-224"
+         ]
+
+         for i in range(count):
+             # Same images tested across different model types
+             base_images = []
+             for _ in range(10):  # 10 base images per comparison dataset
+                 image = self.create_synthetic_image(224, 224, "geometric")
+                 base_images.append(self.image_to_base64(image, "JPEG"))
+
+             records = []
+             for j, model in enumerate(model_types):
+                 for k, image_b64 in enumerate(base_images):
+                     api_request = self.create_api_request(image_b64)
+                     expected_response = self.create_expected_response(model)
+
+                     record = {
+                         "dataset_id": f"comparison_{i:03d}",
+                         "image_id": f"img_{k:03d}_model_{j}",
+                         "image_type": "comparison_base",
+                         "image_size": "224x224",
+                         "format": "JPEG",
+                         "media_type": "image/jpeg",
+                         "api_request": json.dumps(api_request),
+                         "expected_response": json.dumps(expected_response),
+                         "test_category": "model_comparison",
+                         "difficulty": "normal",
+                         "model_type": model,
+                         "comparison_group": k
+                     }
+                     records.append(record)
+
+             datasets.append({
+                 "name": f"model_comparison_{i:03d}",
+                 "category": "model_comparison",
+                 "description": f"Model comparison dataset {i+1} testing {len(model_types)} models",
+                 "records": records
+             })
+
+         return datasets
+
+     def save_dataset_to_parquet(self, dataset: Dict[str, Any]):
+         """Save a dataset to PyArrow Parquet format."""
+         records = dataset["records"]
+
+         # Convert to PyArrow table
+         table = pa.table({
+             "dataset_id": [r["dataset_id"] for r in records],
+             "image_id": [r["image_id"] for r in records],
+             "image_type": [r["image_type"] for r in records],
+             "image_size": [r["image_size"] for r in records],
+             "format": [r["format"] for r in records],
+             "media_type": [r["media_type"] for r in records],
+             "api_request": [r["api_request"] for r in records],
+             "expected_response": [r["expected_response"] for r in records],
+             "test_category": [r["test_category"] for r in records],
+             "difficulty": [r["difficulty"] for r in records],
+             # Optional fields with defaults
+             "batch_size": [r.get("batch_size", 1) for r in records],
+             "expected_max_latency_ms": [r.get("expected_max_latency_ms", 1000) for r in records],
+             "model_type": [r.get("model_type", "microsoft/resnet-18") for r in records],
+             "comparison_group": [r.get("comparison_group", 0) for r in records]
+         })
+
+         output_path = self.output_dir / f"{dataset['name']}.parquet"
+         pq.write_table(table, output_path)
+
+         # Save metadata as JSON
+         metadata = {
+             "name": dataset["name"],
+             "category": dataset["category"],
+             "description": dataset["description"],
+             "record_count": len(records),
+             "file_size_mb": round(output_path.stat().st_size / (1024 * 1024), 2),
+             "schema": [field.name for field in table.schema]
+         }
+
+         metadata_path = self.output_dir / f"{dataset['name']}_metadata.json"
+         with open(metadata_path, 'w') as f:
+             json.dump(metadata, f, indent=2)
+
+     def generate_all_datasets(self):
+         """Generate all 100 datasets."""
+         print(" Starting dataset generation...")
+
+         print("📊 Generating standard test datasets (25)...")
+         standard_datasets = self.generate_standard_datasets(25)
+         for dataset in standard_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("⚡ Generating edge case datasets (25)...")
+         edge_datasets = self.generate_edge_case_datasets(25)
+         for dataset in edge_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("Generating performance datasets (25)...")
+         performance_datasets = self.generate_performance_datasets(25)
+         for dataset in performance_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print("🔄 Generating model comparison datasets (25)...")
+         comparison_datasets = self.generate_model_comparison_datasets(25)
+         for dataset in comparison_datasets:
+             self.save_dataset_to_parquet(dataset)
+
+         print(f"✅ Generated 100 datasets in {self.output_dir}/")
+
+         self.generate_summary()
+
+     def generate_summary(self):
+         """Generate a summary of all datasets."""
+         summary = {
+             "total_datasets": 100,
+             "categories": {
+                 "standard": 25,
+                 "edge_case": 25,
+                 "performance": 25,
+                 "model_comparison": 25
+             },
+             "dataset_info": [],
+             "usage_instructions": {
+                 "loading": "Use pyarrow.parquet.read_table('dataset.parquet')",
+                 "testing": "Run python scripts/test_datasets.py",
+                 "api_endpoint": "POST /predict/resnet",
+                 "request_format": "See api_request column in datasets"
+             }
+         }
+
+         # Add individual dataset info
+         for parquet_file in self.output_dir.glob("*.parquet"):
+             metadata_file = self.output_dir / f"{parquet_file.stem}_metadata.json"
+             if metadata_file.exists():
+                 with open(metadata_file, 'r') as f:
+                     metadata = json.load(f)
+                 summary["dataset_info"].append(metadata)
+
+         summary_path = self.output_dir / "datasets_summary.json"
+         with open(summary_path, 'w') as f:
+             json.dump(summary, f, indent=2)
+
+         print(f"📋 Summary saved to {summary_path}")
+
+
+ if __name__ == "__main__":
+     generator = TestDatasetGenerator()
+     generator.generate_all_datasets()
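A minimal sketch of inspecting the generator's output with pyarrow, assuming the default test_datasets/ output directory and the standard_test_000 file name produced by the naming scheme above (illustrative only):

import base64
import json
from pathlib import Path

import pyarrow.parquet as pq

# Assumed location: the generator's default output directory and its first standard dataset.
dataset_path = Path("test_datasets/standard_test_000.parquet")

table = pq.read_table(dataset_path)
rows = table.to_pylist()

# Each row stores the full API payload as a JSON string in the api_request column.
first_request = json.loads(rows[0]["api_request"])
image_bytes = base64.b64decode(first_request["image"]["data"])
print(rows[0]["image_id"], rows[0]["media_type"], len(image_bytes), "bytes")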
scripts/test_datasets.py ADDED
@@ -0,0 +1,382 @@
+ #!/usr/bin/env python3
+ """
+ Dataset Tester for ML Inference Service
+
+ Tests the generated PyArrow datasets against the running ML inference service.
+ Validates API requests/responses and measures performance metrics.
+ """
+
+ import json
+ import time
+ import asyncio
+ import statistics
+ from pathlib import Path
+ from typing import Dict, List, Any, Optional
+ import argparse
+
+ import pyarrow.parquet as pq
+ import requests
+ import pandas as pd
+
+
+ class DatasetTester:
+     def __init__(self, base_url: str = "http://127.0.0.1:8000", datasets_dir: str = "test_datasets"):
+         self.base_url = base_url.rstrip('/')
+         self.datasets_dir = Path(datasets_dir)
+         self.endpoint = f"{self.base_url}/predict/resnet"
+         self.results = []
+
+     def load_dataset(self, dataset_path: Path) -> pd.DataFrame:
+         """Load a PyArrow dataset."""
+         table = pq.read_table(dataset_path)
+         return table.to_pandas()
+
+     def test_api_connection(self) -> bool:
+         """Test if the API is running and accessible."""
+         try:
+             response = requests.get(f"{self.base_url}/docs", timeout=5)
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+
+     def send_prediction_request(self, api_request_json: str) -> Dict[str, Any]:
+         """Send a single prediction request to the API."""
+         try:
+             request_data = json.loads(api_request_json)
+             start_time = time.time()
+
+             response = requests.post(
+                 self.endpoint,
+                 json=request_data,
+                 headers={"Content-Type": "application/json"},
+                 timeout=30
+             )
+
+             end_time = time.time()
+             latency_ms = (end_time - start_time) * 1000
+
+             return {
+                 "success": response.status_code == 200,
+                 "status_code": response.status_code,
+                 "response": response.json() if response.status_code == 200 else response.text,
+                 "latency_ms": round(latency_ms, 2),
+                 "error": None
+             }
+
+         except requests.RequestException as e:
+             return {
+                 "success": False,
+                 "status_code": None,
+                 "response": None,
+                 "latency_ms": None,
+                 "error": str(e)
+             }
+         except json.JSONDecodeError as e:
+             return {
+                 "success": False,
+                 "status_code": None,
+                 "response": None,
+                 "latency_ms": None,
+                 "error": f"JSON decode error: {str(e)}"
+             }
+
+     def validate_response(self, actual_response: Dict[str, Any],
+                           expected_response_json: str) -> Dict[str, Any]:
+         """Validate API response against expected response."""
+         try:
+             expected = json.loads(expected_response_json)
+
+             validation = {
+                 "structure_valid": True,
+                 "field_errors": []
+             }
+
+             # Check required fields exist
+             required_fields = ["prediction", "confidence", "predicted_label", "model", "mediaType"]
+             for field in required_fields:
+                 if field not in actual_response:
+                     validation["structure_valid"] = False
+                     validation["field_errors"].append(f"Missing field: {field}")
+
+             # Validate field types
+             if "confidence" in actual_response:
+                 if not isinstance(actual_response["confidence"], (int, float)):
+                     validation["field_errors"].append("confidence must be numeric")
+                 elif not (0 <= actual_response["confidence"] <= 1):
+                     validation["field_errors"].append("confidence must be between 0 and 1")
+
+             if "predicted_label" in actual_response:
+                 if not isinstance(actual_response["predicted_label"], int):
+                     validation["field_errors"].append("predicted_label must be integer")
+
+             return validation
+
+         except json.JSONDecodeError:
+             return {
+                 "structure_valid": False,
+                 "field_errors": ["Invalid expected response JSON"]
+             }
+
+     def test_dataset(self, dataset_path: Path, max_samples: Optional[int] = None) -> Dict[str, Any]:
+         """Test a single dataset."""
+         print(f"📊 Testing dataset: {dataset_path.name}")
+
+         try:
+             df = self.load_dataset(dataset_path)
+             if max_samples:
+                 df = df.head(max_samples)
+
+             results = {
+                 "dataset_name": dataset_path.stem,
+                 "total_samples": len(df),
+                 "tested_samples": 0,
+                 "successful_requests": 0,
+                 "failed_requests": 0,
+                 "validation_errors": 0,
+                 "latencies_ms": [],
+                 "errors": [],
+                 "category": df['test_category'].iloc[0] if not df.empty else "unknown"
+             }
+
+             for idx, row in df.iterrows():
+                 print(f" Testing sample {idx + 1}/{len(df)}", end="\r")
+
+                 # Send API request
+                 api_result = self.send_prediction_request(row['api_request'])
+                 results["tested_samples"] += 1
+
+                 if api_result["success"]:
+                     results["successful_requests"] += 1
+                     results["latencies_ms"].append(api_result["latency_ms"])
+
+                     # Validate response structure
+                     validation = self.validate_response(
+                         api_result["response"],
+                         row['expected_response']
+                     )
+
+                     if not validation["structure_valid"]:
+                         results["validation_errors"] += 1
+                         results["errors"].append({
+                             "sample_id": row['image_id'],
+                             "type": "validation_error",
+                             "details": validation["field_errors"]
+                         })
+
+                 else:
+                     results["failed_requests"] += 1
+                     results["errors"].append({
+                         "sample_id": row['image_id'],
+                         "type": "request_failed",
+                         "status_code": api_result["status_code"],
+                         "error": api_result["error"]
+                     })
+
+             # Calculate statistics
+             if results["latencies_ms"]:
+                 results["avg_latency_ms"] = round(statistics.mean(results["latencies_ms"]), 2)
+                 results["min_latency_ms"] = round(min(results["latencies_ms"]), 2)
+                 results["max_latency_ms"] = round(max(results["latencies_ms"]), 2)
+                 results["median_latency_ms"] = round(statistics.median(results["latencies_ms"]), 2)
+             else:
+                 results.update({
+                     "avg_latency_ms": None,
+                     "min_latency_ms": None,
+                     "max_latency_ms": None,
+                     "median_latency_ms": None
+                 })
+
+             results["success_rate"] = round(
+                 results["successful_requests"] / results["tested_samples"] * 100, 2
+             ) if results["tested_samples"] > 0 else 0
+
+             print(f"\n ✅ Completed: {results['success_rate']}% success rate")
+             return results
+
+         except Exception as e:
+             print(f"\n ❌ Failed to test dataset: {str(e)}")
+             return {
+                 "dataset_name": dataset_path.stem,
+                 "error": str(e),
+                 "success_rate": 0
+             }
+
+     def test_all_datasets(self, max_samples_per_dataset: Optional[int] = None,
+                           category_filter: Optional[str] = None) -> Dict[str, Any]:
+         """Test all datasets or filtered by category."""
+         if not self.test_api_connection():
+             print("❌ API is not accessible. Please start the service first:")
+             print(" uvicorn main:app --reload")
+             return {"error": "API not accessible"}
+
+         print(f" Starting dataset testing against {self.endpoint}")
+
+         parquet_files = list(self.datasets_dir.glob("*.parquet"))
+         if not parquet_files:
+             print(f"❌ No datasets found in {self.datasets_dir}")
+             return {"error": "No datasets found"}
+
+         if category_filter:
+             parquet_files = [f for f in parquet_files if category_filter in f.name]
+
+         print(f" Found {len(parquet_files)} datasets to test")
+
+         all_results = []
+         start_time = time.time()
+
+         for dataset_file in parquet_files:
+             result = self.test_dataset(dataset_file, max_samples_per_dataset)
+             all_results.append(result)
+
+         end_time = time.time()
+         total_time = end_time - start_time
+
+         summary = self.generate_summary(all_results, total_time)
+
+         self.save_results(summary, all_results)
+
+         return summary
+
+     def generate_summary(self, results: List[Dict[str, Any]], total_time: float) -> Dict[str, Any]:
+         """Generate summary of all test results."""
+         successful_datasets = [r for r in results if r.get("success_rate", 0) > 0]
+         failed_datasets = [r for r in results if r.get("error") or r.get("success_rate", 0) == 0]
+
+         total_samples = sum(r.get("tested_samples", 0) for r in results)
+         total_successful = sum(r.get("successful_requests", 0) for r in results)
+         total_failed = sum(r.get("failed_requests", 0) for r in results)
+
+         all_latencies = []
+         for r in results:
+             all_latencies.extend(r.get("latencies_ms", []))
+
+         summary = {
+             "test_summary": {
+                 "total_datasets": len(results),
+                 "successful_datasets": len(successful_datasets),
+                 "failed_datasets": len(failed_datasets),
+                 "total_samples_tested": total_samples,
+                 "total_successful_requests": total_successful,
+                 "total_failed_requests": total_failed,
+                 "overall_success_rate": round(
+                     total_successful / total_samples * 100, 2
+                 ) if total_samples > 0 else 0,
+                 "total_test_time_seconds": round(total_time, 2)
+             },
+             "performance_metrics": {
+                 "avg_latency_ms": round(statistics.mean(all_latencies), 2) if all_latencies else None,
+                 "median_latency_ms": round(statistics.median(all_latencies), 2) if all_latencies else None,
+                 "min_latency_ms": round(min(all_latencies), 2) if all_latencies else None,
+                 "max_latency_ms": round(max(all_latencies), 2) if all_latencies else None,
+                 "requests_per_second": round(
+                     total_successful / total_time, 2
+                 ) if total_time > 0 else 0
+             },
+             "category_breakdown": {},
+             "failed_datasets": [r["dataset_name"] for r in failed_datasets]
+         }
+
+         categories = {}
+         for result in results:
+             category = result.get("category", "unknown")
+             if category not in categories:
+                 categories[category] = {
+                     "count": 0,
+                     "success_rates": [],
+                     "avg_success_rate": 0
+                 }
+             categories[category]["count"] += 1
+             categories[category]["success_rates"].append(result.get("success_rate", 0))
+
+         for category, data in categories.items():
+             data["avg_success_rate"] = round(
+                 statistics.mean(data["success_rates"]), 2
+             ) if data["success_rates"] else 0
+
+         summary["category_breakdown"] = categories
+
+         return summary
+
+     def save_results(self, summary: Dict[str, Any], detailed_results: List[Dict[str, Any]]):
+         """Save test results to files."""
+         results_dir = Path("test_results")
+         results_dir.mkdir(exist_ok=True)
+
+         timestamp = int(time.time())
+
+         # Save summary
+         summary_path = results_dir / f"test_summary_{timestamp}.json"
+         with open(summary_path, 'w') as f:
+             json.dump(summary, f, indent=2)
+
+         # Save detailed results
+         detailed_path = results_dir / f"test_detailed_{timestamp}.json"
+         with open(detailed_path, 'w') as f:
+             json.dump(detailed_results, f, indent=2)
+
+         print(f" Results saved:")
+         print(f" Summary: {summary_path}")
+         print(f" Details: {detailed_path}")
+
+     def print_summary(self, summary: Dict[str, Any]):
+         """Print test summary to console."""
+         print("\n" + "="*60)
+         print("DATASET TESTING SUMMARY")
+         print("="*60)
+
+         ts = summary["test_summary"]
+         print(f"Datasets tested: {ts['total_datasets']}")
+         print(f"Successful datasets: {ts['successful_datasets']}")
+         print(f"Failed datasets: {ts['failed_datasets']}")
+         print(f"Total samples: {ts['total_samples_tested']}")
+         print(f"Overall success rate: {ts['overall_success_rate']}%")
+         print(f"Test duration: {ts['total_test_time_seconds']}s")
+
+         pm = summary["performance_metrics"]
+         if pm["avg_latency_ms"]:
+             print(f"\nPerformance:")
+             print(f" Avg latency: {pm['avg_latency_ms']}ms")
+             print(f" Median latency: {pm['median_latency_ms']}ms")
+             print(f" Min latency: {pm['min_latency_ms']}ms")
+             print(f" Max latency: {pm['max_latency_ms']}ms")
+             print(f" Requests/sec: {pm['requests_per_second']}")
+
+         print(f"\nCategory breakdown:")
+         for category, data in summary["category_breakdown"].items():
+             print(f" {category}: {data['count']} datasets, {data['avg_success_rate']}% avg success")
+
+         if summary["failed_datasets"]:
+             print(f"\nFailed datasets: {', '.join(summary['failed_datasets'])}")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Test PyArrow datasets against ML inference service")
+     parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Base URL of the API")
+     parser.add_argument("--datasets-dir", default="scripts/test_datasets", help="Directory containing datasets")
+     parser.add_argument("--max-samples", type=int, help="Max samples per dataset to test")
+     parser.add_argument("--category", help="Filter datasets by category (standard, edge_case, performance, model_comparison)")
+     parser.add_argument("--quick", action="store_true", help="Quick test with max 5 samples per dataset")
+
+     args = parser.parse_args()
+
+     tester = DatasetTester(args.base_url, args.datasets_dir)
+
+     max_samples = args.max_samples
+     if args.quick:
+         max_samples = 5
+
+     results = tester.test_all_datasets(max_samples, args.category)
+
+     if "error" not in results:
+         tester.print_summary(results)
+
+         if results["test_summary"]["overall_success_rate"] > 90:
+             print("\n🎉 Excellent! API is working great with the datasets!")
+         elif results["test_summary"]["overall_success_rate"] > 70:
+             print("\nGood! API works well, minor issues detected.")
+         else:
+             print("\n⚠️ Warning: Several issues detected. Check the detailed results.")
+
+
+ if __name__ == "__main__":
+     main()
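A minimal sketch of driving the tester programmatically against a single dataset, assuming the service is already running at http://127.0.0.1:8000 and that test_datasets.py is importable from the working directory (normally the script is run through its CLI instead):

from pathlib import Path

from test_datasets import DatasetTester  # assumes scripts/ is the working directory or on sys.path

tester = DatasetTester(base_url="http://127.0.0.1:8000", datasets_dir="test_datasets")

if tester.test_api_connection():
    # Test only the first few samples of one generated dataset.
    result = tester.test_dataset(Path("test_datasets/standard_test_000.parquet"), max_samples=3)
    print(result["success_rate"], result.get("avg_latency_ms"))
else:
    print("Service not reachable; start it first, e.g. with: uvicorn main:app --reload")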