jessehostetler committed
Commit a12ee73 · 1 Parent(s): c2feb3e

Clean up docs. Fix incorrect path in test script.

.gitignore CHANGED
@@ -1,4 +1,7 @@
  dyff-outputs/
  models/
+ test_datasets/
+ test_results/
  venv/
  **/__pycache__
+ *.tmp
README.md CHANGED
@@ -39,9 +39,6 @@ make docker-build
 
  # Run
  make docker-run
-
- # Check logs
- docker logs -f safe-challenge-2025/example-submission
  ```
 
  ## Testing the API
@@ -76,16 +73,16 @@ example-submission/
  ├── main.py                  # Entry point
  ├── app/
  │   ├── core/
- │   │   ├── app.py           # App factory, config, DI, lifecycle
+ │   │   ├── app.py           # <= INSTANTIATE YOUR DETECTOR HERE
  │   │   └── logging.py       # Logging setup
  │   ├── api/
  │   │   ├── models.py        # Request/response schemas
- │   │   ├── controllers.py   # <= IMPLEMENT YOUR DETECTOR HERE
+ │   │   ├── controllers.py   # Business logic
  │   │   └── routes/
  │   │       └── prediction.py  # POST /predict
  │   └── services/
- │       ├── base.py          # Abstract InferenceService class
- │       └── inference.py     # ResNet implementation
+ │       ├── base.py          # <= YOUR DETECTOR IMPLEMENTS THIS INTERFACE
+ │       └── inference.py     # Example service based on ResNet-18
  ├── models/
  │   └── microsoft/
  │       └── resnet-18/       # Model weights and config
@@ -97,17 +94,16 @@ example-submission/
  ├── .env.example             # Environment config template
  ├── cat.json                 # An example /predict request object
  ├── makefile
+ ├── prompt.sh                # Script that makes a /predict request
  ├── requirements.in
  ├── requirements.txt
- ├── response.json            # An example /predict response object
+ ├── response.json            # An example /predict response object
  └──
  ```
 
- The key design decision here is that `app/core/app.py` consolidates everything: config, dependency injection, lifecycle, and the app factory. This avoids the mess of managing global state across multiple files.
-
  ## How to Plug In Your Own Model
 
- The whole service is built around one abstract base class: `InferenceService`. Implement it for your model, and everything else just works.
+ To integrate your model, implement the `InferenceService` abstract class defined in `app/services/base.py`. You can follow the example implementation in `app/services/inference.py`, which is based on ResNet-18. After implementing the required interface, instantiate your model in the `lifespan()` function in `app/core/app.py`, replacing the `ResNetInferenceService` instance.
 
  ### Step 1: Create Your Service Class
 
@@ -115,7 +111,6 @@ The whole service is built around one abstract base class: `InferenceService`. I
  # app/services/your_model_service.py
  from app.services.base import InferenceService
  from app.api.models import ImageRequest, PredictionResponse
- import asyncio
 
  class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
      def __init__(self, model_name: str):
@@ -124,26 +119,22 @@ class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
          self.model = None
          self._is_loaded = False
 
-     async def load_model(self) -> None:
+     def load_model(self) -> None:
          """Load your model here. Called once at startup."""
          self.model = load_your_model(self.model_path)
          self._is_loaded = True
 
-     async def predict(self, request: ImageRequest) -> PredictionResponse:
-         """Run inference. Offload heavy work to thread pool."""
-         return await asyncio.to_thread(self._predict_sync, request)
-
-     def _predict_sync(self, request: ImageRequest) -> PredictionResponse:
+     def predict(self, request: ImageRequest) -> PredictionResponse:
          """Actual inference happens here."""
          image = decode_base64_image(request.image.data)
         result = self.model(image)
 
+         logprobs = ...
+         mask = ...
+
          return PredictionResponse(
-             prediction=result.label,
-             confidence=result.confidence,
-             predicted_label=result.class_id,
-             model=self.model_name,
-             mediaType=request.image.mediaType
+             logprobs=logprobs,
+             localizationMask=mask,
          )
 
      @property
@@ -151,8 +142,6 @@ class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
          return self._is_loaded
  ```
 
- **Important:** Use `asyncio.to_thread()` to run CPU-heavy inference in a background thread. This keeps the server responsive while your model is working.
-
  ### Step 2: Register Your Service
 
  Open `app/core/app.py` and find the lifespan function:
@@ -162,14 +151,14 @@ Open `app/core/app.py` and find the lifespan function:
  service = ResNetInferenceService(model_name="microsoft/resnet-18")
 
  # To this:
- service = YourModelService(model_name="your-org/your-model")
+ service = YourModelService(...)
  ```
 
  That's it. The `/predict` endpoint now serves your model.
 
  ### Model Files
 
- Put your model files under `models/` with the full org/model structure:
+ Put your model files under the `models/` directory:
 
  ```
  models/
@@ -180,8 +169,6 @@ models/
          └── (other files)
  ```
 
- No renaming, no dropping the org prefix: it just mirrors the Hugging Face structure.
-
  ## Configuration
 
  Settings are managed via environment variables or a `.env` file. See `.env.example` for all available options.
@@ -251,7 +238,7 @@ If you see "Model directory not found", check that your model files exist at the
  {
    "image": {
      "mediaType": "image/jpeg",          // or "image/png"
-     "data": "<base64-encoded-image>"
+     "data": "<base64 string>"
    }
  }
  ```
@@ -259,11 +246,11 @@ If you see "Model directory not found", check that your model files exist at the
  **Response:**
  ```json
  {
-   "prediction": "string",        // Human-readable label
-   "confidence": 0.0,             // Softmax probability
-   "predicted_label": 0,          // Numeric class index
-   "model": "org/model-name",     // Model identifier
-   "mediaType": "image/jpeg"      // Echoed from request
+   "logprobs": [float],           // Log-probabilities of each label
+   "localizationMask": {          // [Optional] binary mask
+     "mediaType": "image/png",    // Always png
+     "data": "<base64 string>"    // Image data
+   }
  }
  ```
@@ -291,8 +278,12 @@ This creates:
 
  ```bash
  # Start your service first
- uvicorn main:app --reload
+ make serve
+ ```
+
+ In another terminal:
 
+ ```bash
  # Quick test (5 samples per dataset)
  python scripts/test_datasets.py --quick
 
@@ -367,7 +358,7 @@ uvicorn main:app --port 8080
 
  **Model not loading:**
  - Check the path: models should be in `models/<org>/<model-name>/`
- - Make sure you ran `bash scripts/model_download.bash`
+ - If you're trying to run the example ResNet-based model, make sure you ran `make download` to fetch the model weights.
  - Check logs for the exact error
 
  **Slow inference:**
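The updated README intentionally leaves `logprobs = ...` and `mask = ...` as placeholders. Below is a minimal sketch of how a concrete service might fill them in, assuming the `InferenceService` and `PredictionResponse` interfaces behave as the README describes; `load_your_model`, the per-label scores returned by `self.model(image)`, the `_encode_mask` helper, and the `is_loaded` property name are illustrative placeholders rather than names taken from the repository.

```python
# Sketch only: app/services/your_model_service.py per the README example; helper names are assumptions.
import base64
from io import BytesIO

import numpy as np
from PIL import Image

from app.services.base import InferenceService
from app.api.models import ImageRequest, PredictionResponse


class YourModelService(InferenceService[ImageRequest, PredictionResponse]):
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.model = None
        self._is_loaded = False

    def load_model(self) -> None:
        """Load weights once at startup (called from lifespan())."""
        self.model = load_your_model(self.model_name)  # placeholder, as in the README
        self._is_loaded = True

    def predict(self, request: ImageRequest) -> PredictionResponse:
        # Decode the base64 payload into an RGB image, mirroring the ResNet example.
        image = Image.open(BytesIO(base64.b64decode(request.image.data))).convert("RGB")

        # Assumption: the model returns one raw score per label.
        scores = np.asarray(self.model(image), dtype=np.float64)

        # Max-shifted log-softmax so the response carries finite log-probabilities.
        shifted = scores - scores.max()
        logprobs = shifted - np.log(np.exp(shifted).sum())

        # Optional localization mask: all zeros here, encoded as a base64 PNG.
        mask = self._encode_mask(np.zeros((image.height, image.width), dtype=np.uint8))

        return PredictionResponse(
            logprobs=logprobs.tolist(),
            localizationMask=mask,
        )

    @staticmethod
    def _encode_mask(mask_array: np.ndarray) -> dict:
        """Encode a 0/1 mask as {"mediaType": "image/png", "data": <base64>}."""
        buffer = BytesIO()
        Image.fromarray(mask_array * 255, mode="L").save(buffer, format="PNG")
        return {
            "mediaType": "image/png",
            "data": base64.b64encode(buffer.getvalue()).decode("ascii"),
        }

    @property
    def is_loaded(self) -> bool:  # property name assumed; match the base class
        return self._is_loaded
```

Returning log-probabilities through a max-shifted log-softmax keeps the values finite even when the raw scores are large, and the mask helper matches the `localizationMask` shape shown in the response schema above.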
app/core/app.py CHANGED
@@ -3,7 +3,7 @@
  import asyncio
  import warnings
  from contextlib import asynccontextmanager
- from typing import AsyncGenerator, Optional
+ from typing import AsyncGenerator
 
  from fastapi import FastAPI
  from pydantic import Field
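Step 2 of the README points at the `lifespan()` hook in this module, which the commit touches only to drop an unused import. For orientation, here is a rough sketch of what that registration typically looks like in a FastAPI app; the real `app/core/app.py` also wires in settings and dependency injection, and `YourModelService`, its module path, and the `app.state.inference_service` attribute are assumptions for illustration.

```python
# Sketch only: the real app/core/app.py also handles settings and dependency injection.
from contextlib import asynccontextmanager
from typing import AsyncGenerator

from fastapi import FastAPI

from app.services.your_model_service import YourModelService  # hypothetical module from the README example


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    # Swap in your own InferenceService implementation here (README, Step 2).
    service = YourModelService(model_name="your-org/your-model")
    service.load_model()                    # load weights before serving traffic
    app.state.inference_service = service   # attribute name is an assumption
    yield                                   # the application runs while suspended here
    # Optional teardown (release GPU memory, close handles, ...) goes after the yield.


app = FastAPI(lifespan=lifespan)
```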
app/services/inference.py CHANGED
@@ -2,7 +2,6 @@
 
  import base64
  import os
- import random
  from io import BytesIO
 
  import numpy as np
@@ -61,7 +60,6 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
 
          image_data = base64.b64decode(request.image.data)
          image = Image.open(BytesIO(image_data))
-         width, height = image.size
 
          if image.mode != 'RGB':
              image = image.convert('RGB')
mask.png ADDED
scripts/test_datasets.py CHANGED
@@ -23,7 +23,7 @@ class DatasetTester:
      def __init__(self, base_url: str = "http://127.0.0.1:8000", datasets_dir: str = "test_datasets"):
          self.base_url = base_url.rstrip('/')
          self.datasets_dir = Path(datasets_dir)
-         self.endpoint = f"{self.base_url}/predict/resnet"
+         self.endpoint = f"{self.base_url}/predict"
          self.results = []
 
      def load_dataset(self, dataset_path: Path) -> pd.DataFrame:
@@ -352,7 +352,7 @@ class DatasetTester:
  def main():
      parser = argparse.ArgumentParser(description="Test PyArrow datasets against ML inference service")
      parser.add_argument("--base-url", default="http://127.0.0.1:8000", help="Base URL of the API")
-     parser.add_argument("--datasets-dir", default="scripts/test_datasets", help="Directory containing datasets")
+     parser.add_argument("--datasets-dir", default="test_datasets", help="Directory containing datasets")
      parser.add_argument("--max-samples", type=int, help="Max samples per dataset to test")
      parser.add_argument("--category", help="Filter datasets by category (standard, edge_case, performance, model_comparison)")
      parser.add_argument("--quick", action="store_true", help="Quick test with max 5 samples per dataset")