jessehostetler committed
Commit 33241cf · 1 Parent(s): c738e68

Finalize Request and Response schema and update model accordingly.

Files changed (3)
  1. .gitignore +2 -0
  2. app/api/models.py +71 -10
  3. app/services/inference.py +21 -14
.gitignore ADDED
@@ -0,0 +1,2 @@
+models/
+venv/
app/api/models.py CHANGED
@@ -1,24 +1,86 @@
 """
 Pydantic models for request/response validation.
 """
-from pydantic import BaseModel
+import enum
+from typing import Optional
+
+import pydantic


-class ImageData(BaseModel):
+class ImageData(pydantic.BaseModel):
     """Image data model for base64 encoded images."""
     mediaType: str
     data: str


-class ImageRequest(BaseModel):
+class ImageRequest(pydantic.BaseModel):
     """Request model for image classification."""
     image: ImageData


-class PredictionResponse(BaseModel):
-    """Response model for image classification results."""
-    prediction: str
-    confidence: float
-    model: str
-    predicted_label: int
-    mediaType: str
+class Labels(enum.IntEnum):
+    Natural = 0
+    FullySynthesized = 1
+    LocallyEdited = 2
+    LocallySynthesized = 3
+
+
+class LocalizationMask(pydantic.BaseModel):
+    """A bit mask indicating which pixels are manipulated / synthesized.
+
+    A bit value of ``1`` means that the model believes the corresponding pixel
+    has been edited or synthesized (i.e., its label would be non-zero).
+    A bit value of ``0`` means that the model believes the pixel is unaltered.
+
+    The mask ``.width`` and ``.height`` should be the same as the input image.
+    Extra bits at the end of ``.bitsRowMajor`` after the first
+    ``width * height`` bits are **ignored**; for simplicity/efficiency,
+    you should encode your bit mask into a byte array and not worry if the
+    final byte isn't "full", then convert the byte array to base64.
+    """
+
+    width: int = pydantic.Field(
+        description="The width of the mask."
+    )
+
+    height: int = pydantic.Field(
+        description="The height of the mask."
+    )
+
+    bitsRowMajor: str = pydantic.Field(
+        description="A base64 string encoding the bit mask in row-major order.",
+        # Canonical base64 encoding
+        # https://stackoverflow.com/a/64467300/3709935
+        pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
+    )
+
+
+class PredictionResponse(pydantic.BaseModel):
+    """Response model for synthetic image classification results.
+
+    Detector models will be scored primarily on their ability to classify the
+    entire image into 1 of the 4 label categories::
+
+        0: (Natural) The image is natural / unaltered.
+        1: (FullySynthesized) The entire image was synthesized by e.g., a
+           generative image model.
+        2: (LocallyEdited) The image is a natural image where a portion has
+           been edited using traditional photo editing techniques such as
+           splicing.
+        3: (LocallySynthesized) The image is a natural image where a portion
+           has been replaced by synthesized content.
+    """
+
+    logprobs: list[float] = pydantic.Field(
+        description="The log-probabilities for each of the 4 possible labels.",
+        min_length=4,
+        max_length=4,
+    )
+
+    localizationMask: Optional[LocalizationMask] = pydantic.Field(
+        default=None,
+        description="A bit mask localizing predicted edits. Models that are"
+        " not capable of localization may omit this field. It may also be"
+        " omitted if the predicted label is ``0`` or ``1``, in which case the"
+        " mask will be assumed to be all 0's or all 1's, as appropriate."
+    )
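
The ``LocalizationMask`` docstring above prescribes packing the mask bits row-major into bytes and base64-encoding the result. Below is a minimal sketch of a conforming encoder, assuming ``numpy`` is available and most-significant-bit-first packing within each byte (the schema does not pin down the in-byte bit order); the helper name ``encode_mask`` is illustrative and not part of this commit:

import base64

import numpy as np

from app.api.models import LocalizationMask


def encode_mask(mask: np.ndarray) -> str:
    """Pack a 2-D boolean mask (height x width) into row-major bits, base64-encoded.

    ``np.packbits`` flattens in C (row-major) order and zero-pads the final
    byte, which is fine: bits past ``width * height`` are ignored by the schema.
    """
    packed = np.packbits(mask.astype(np.uint8), axis=None)
    return base64.b64encode(packed.tobytes()).decode("ascii")


# Example: an all-zeros ("nothing edited") mask for a 640x480 image.
mask = np.zeros((480, 640), dtype=bool)
payload = LocalizationMask(width=640, height=480, bitsRowMajor=encode_mask(mask))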
app/services/inference.py CHANGED
@@ -1,15 +1,17 @@
 """ResNet inference service implementation."""

-import os
 import base64
+import os
+import random
 from io import BytesIO
+
 import torch
 from PIL import Image
-from transformers import AutoImageProcessor, ResNetForImageClassification
+from transformers import AutoImageProcessor, ResNetForImageClassification  # type: ignore[import-untyped]

 from app.core.logging import logger
 from app.services.base import InferenceService
-from app.api.models import ImageRequest, PredictionResponse
+from app.api.models import ImageRequest, Labels, LocalizationMask, PredictionResponse


 class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse]):
@@ -45,13 +47,20 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         self.model = ResNetForImageClassification.from_pretrained(
             self.model_path, local_files_only=True
         )
+        assert self.model is not None

         self._is_loaded = True
-        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")
+        logger.info(f"Model loaded: {len(self.model.config.id2label)} classes")  # pyright: ignore

     def predict(self, request: ImageRequest) -> PredictionResponse:
+        if not self.is_loaded:
+            raise RuntimeError("model is not loaded")
+        assert self.processor is not None
+        assert self.model is not None
+
         image_data = base64.b64decode(request.image.data)
         image = Image.open(BytesIO(image_data))
+        width, height = image.size

         if image.mode != 'RGB':
             image = image.convert('RGB')
@@ -59,19 +68,17 @@ class ResNetInferenceService(InferenceService[ImageRequest, PredictionResponse])
         inputs = self.processor(image, return_tensors="pt")

         with torch.no_grad():
-            logits = self.model(**inputs).logits
+            logits = self.model(**inputs).logits  # pyright: ignore

-        predicted_label = logits.argmax(-1).item()
-        predicted_class = self.model.config.id2label[predicted_label]
-        probabilities = torch.nn.functional.softmax(logits, dim=-1)
-        confidence = probabilities[0][predicted_label].item()
+        logprobs = torch.nn.functional.log_softmax(logits[0, :len(Labels)], dim=-1).tolist()
+        mask_bytes = random.randbytes((width * height + 7) // 8)
+        mask_bits = base64.b64encode(mask_bytes).decode("utf-8")

         return PredictionResponse(
-            prediction=predicted_class,
-            confidence=round(confidence, 4),
-            model=self.model_name,
-            predicted_label=predicted_label,
-            mediaType=request.image.mediaType
+            logprobs=logprobs,
+            localizationMask=LocalizationMask(
+                width=width, height=height, bitsRowMajor=mask_bits
+            )
         )

     @property
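
For reference, the inverse direction: a sketch of how a client might consume a ``PredictionResponse``, recovering the predicted label via argmax over ``logprobs`` and unpacking the optional mask back into booleans. This assumes ``numpy`` and the same MSB-first bit packing as the encoder sketch above; ``decode_response`` is an illustrative helper, not part of this commit:

import base64

import numpy as np

from app.api.models import Labels, PredictionResponse


def decode_response(response: PredictionResponse):
    """Return (predicted Labels member, optional 2-D boolean mask)."""
    label = Labels(int(np.argmax(response.logprobs)))
    mask = None
    if response.localizationMask is not None:
        m = response.localizationMask
        raw = base64.b64decode(m.bitsRowMajor)
        bits = np.unpackbits(np.frombuffer(raw, dtype=np.uint8))
        # Padding bits beyond width * height are ignored, per the schema.
        mask = bits[: m.width * m.height].reshape(m.height, m.width).astype(bool)
    return label, mask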