Commit · 33241cf
Parent(s): c738e68

Finalize Request and Response schema and update model accordingly.

Browse files
- .gitignore                +2  -0
- app/api/models.py         +71 -10
- app/services/inference.py +21 -14
.gitignore ADDED

@@ -0,0 +1,2 @@
+models/
+venv/
app/api/models.py CHANGED

@@ -1,24 +1,85 @@
 """
 Pydantic models for request/response validation.
 """
-from pydantic import BaseModel
+import enum
+from typing import Optional
 
+import pydantic
 
+
-class ImageData(BaseModel):
+class ImageData(pydantic.BaseModel):
     """Image data model for base64 encoded images."""
     mediaType: str
     data: str
 
 
-class ImageRequest(BaseModel):
+class ImageRequest(pydantic.BaseModel):
     """Request model for image classification."""
     image: ImageData
 
 
-class PredictionResponse(BaseModel):
-
-
-
-
-
-
+class Labels(enum.IntEnum):
+    Natural = 0
+    FullySynthesized = 1
+    LocallyEdited = 2
+    LocallySynthesized = 3
+
+
+class LocalizationMask(pydantic.BaseModel):
+    """A bit mask indicating which pixels are manipulated / synthesized.
+
+    A bit value of ``1`` means that the model believes the corresponding pixel
+    has been edited or synthesized (i.e., its label would be non-zero).
+    A bit value of ``0`` means that the model believes the pixel is unaltered.
+
+    The mask ``.width`` and ``.height`` should be the same as the input image.
+    Extra bits at the end of ``.bitsRowMajor`` after the first
+    ``width * height`` bits are **ignored**; for simplicity/efficiency,
+    you should encode your bit mask into a byte array and not worry if the
+    final byte isn't "full", then convert the byte array to base64.
+    """
+
+    width: int = pydantic.Field(
+        description="The width of the mask."
+    )
+
+    height: int = pydantic.Field(
+        description="The height of the mask."
+    )
+
+    bitsRowMajor: str = pydantic.Field(
+        description="A base64 string encoding the bit mask in row-major order.",
+        # Canonical base64 encoding
+        # https://stackoverflow.com/a/64467300/3709935
+        pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
+    )
+
+
+class PredictionResponse(pydantic.BaseModel):
+    """Response model for synthetic image classification results.
+
+    Detector models will be scored primarily on their ability to classify the
+    entire image into 1 of the 4 label categories::
+
+        0: (Natural) The image is natural / unaltered.
+        1: (FullySynthesized) The entire image was synthesized by e.g., a
+           generative image model.
+        2: (LocallyEdited) The image is a natural image where a portion has
+           been edited using traditional photo editing techniques such as
+           splicing.
+        3: (LocallySynthesized) The image is a natural image where a portion
+           has been replaced by synthesized content.
+    """
+
+    logprobs: list[float] = pydantic.Field(
+        description="The log-probabilities for each of the 4 possible labels.",
+        min_length=4,
+        max_length=4,
+    )
+
+    localizationMask: Optional[LocalizationMask] = pydantic.Field(
+        default=None,
+        description="A bit mask localizing predicted edits. Models that are"
+        " not capable of localization may omit this field. It may also be"
+        " omitted if the predicted label is ``0`` or ``1``, in which case the"
+        " mask will be assumed to be all 0's or all 1's, as appropriate."
+    )
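The ``LocalizationMask`` docstring above prescribes packing the mask bits into a byte array and then base64-encoding it. Below is a minimal client-side sketch of that encoding; the ``encode_mask`` helper and its nested-list input are hypothetical (not part of this commit), and MSB-first bit order within each byte is an assumption, since the schema does not specify bit order within a byte.

import base64

from app.api.models import LocalizationMask


def encode_mask(mask: list[list[bool]]) -> LocalizationMask:
    """Pack a row-major boolean mask into bytes, then base64 (hypothetical helper)."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    bits = [pixel for row in mask for pixel in row]  # flatten in row-major order
    packed = bytearray((len(bits) + 7) // 8)         # trailing bits of the last byte are ignored
    for i, bit in enumerate(bits):
        if bit:
            packed[i // 8] |= 0x80 >> (i % 8)        # assumed MSB-first within each byte
    return LocalizationMask(
        width=width,
        height=height,
        bitsRowMajor=base64.b64encode(bytes(packed)).decode("ascii"),
    )

``base64.b64encode`` emits canonical base64, so the resulting string also satisfies the ``pattern`` constraint on ``bitsRowMajor``.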
app/services/inference.py CHANGED

@@ -1,15 +1,17 @@
 """ResNet inference service implementation."""
 
-import os
 import base64
+import os
+import random
 from io import BytesIO
+
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor, ResNetForImageClassification
+from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]
 
 from app.core.logging import logger
 from app.services.base import InferenceService
-from app.api.models import ImageRequest, PredictionResponse
+from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse
 
 
 class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):

@@ -45,13 +47,20 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
         self.model = ResNetForImageClassification.from_pretrained(
             self.model_path, local_files_only=True
         )
+        assert self.model is not None
 
         self._is_loaded = True
-        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
+        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore
 
     def predict(self, request: ImageRequest) -> PredictionResponse:
+        if not self.is_loaded:
+            raise RuntimeError("model is not loaded")
+        assert self.processor is not None
+        assert self.model is not None
+
         image_data = base64.b64decode(request.image.data)
         image = Image.open(BytesIO(image_data))
+        width, height = image.size
 
         if image.mode != 'RGB':
             image = image.convert('RGB')

@@ -59,19 +68,17 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
         inputs = self.processor(image, return_tensors="pt")
 
         with torch.no_grad():
-            logits = self.model(**inputs).logits
+            logits = self.model(**inputs).logits  # pyright: ignore
 
-
-
-
-        confidence = probabilities[0][predicted_label].item()
+        logprobs = torch.nn.functional.log_softmax(logits[0, :len(Labels)], dim=-1).tolist()
+        mask_bytes = random.randbytes((width*height + 7) // 8)
+        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")
 
         return PredictionResponse(
-
-
-
-
-            mediaType=request.image.mediaType
+            logprobs=logprobs,
+            localizationMask=LocalizationMask(
+                width=width, height=height, bitsRowMajor=mask_bits
+            )
         )
 
     @property
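On the consuming side, a response reduces to a whole-image label plus per-pixel bits. A sketch under the same assumptions: both helpers below are hypothetical, the MSB-first bit order matches the encoder sketch above, and how an omitted mask should be read for labels ``2``/``3`` is left open by the schema.

import base64

from app.api.models import Labels, PredictionResponse


def predicted_label(response: PredictionResponse) -> Labels:
    """Argmax over the 4 logprobs (hypothetical helper)."""
    return Labels(response.logprobs.index(max(response.logprobs)))


def pixel_is_edited(response: PredictionResponse, x: int, y: int) -> bool:
    """Read one bit of the row-major mask (MSB-first packing assumed)."""
    mask = response.localizationMask
    if mask is None:
        # Per the field description, an omitted mask is assumed all 0's for
        # label 0 and all 1's for label 1; other omitted cases are unspecified.
        return predicted_label(response) == Labels.FullySynthesized
    raw = base64.b64decode(mask.bitsRowMajor)
    i = y * mask.width + x
    return bool(raw[i // 8] & (0x80 >> (i % 8)))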