File size: 4,193 Bytes
b1f0e98 c2feb3e 33241cf c2feb3e b1f0e98 c2feb3e 33241cf c2feb3e b1f0e98 c2feb3e be5bf87 33241cf b1f0e98 c2feb3e e9a47ca c2feb3e e9a47ca c2feb3e b1f0e98 33241cf b1f0e98 33241cf c2feb3e 33241cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
"""
Pydantic models for request/response validation.
"""
from __future__ import annotations
import base64
import enum
import io
import typing
from typing import Literal, Optional
import numpy as np
import pydantic
from PIL import Image
if typing.TYPE_CHECKING:
from numpy.typing import NDArray
class ImageData(pydantic.BaseModel):
    """A base64-encoded image payload paired with its media type."""

    # MIME type string describing the encoded payload (not validated here).
    mediaType: typing.Annotated[
        str,
        pydantic.Field(description="The IETF Media Type (MIME type) of the data"),
    ]

    # Payload bytes. The pattern admits only canonical base64: padded with
    # '='/'==', no whitespace, and unused trailing bits must be zero.
    # Source: https://stackoverflow.com/a/64467300/3709935
    data: typing.Annotated[
        str,
        pydantic.Field(
            description="A base64 string encoding of the data.",
            pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
        ),
    ]
class BinaryMask(pydantic.BaseModel):
    """A bit mask indicating which pixels are manipulated / synthesized. A
    pixel value of ``0`` means "no detection", and a value of ``1`` means
    "detection".

    The mask data must be encoded in PNG format, so that typical mask data is
    compressed effectively. The PNG encoding **should** use "bilevel" mode for
    maximum compactness. You can use the ``BinaryMask.from_numpy()``
    function to convert a 0-1 numpy array to a BinaryMask.
    """

    mediaType: str = pydantic.Field(
        description="The IETF Media Type (MIME type) of the data."
    )
    data: str = pydantic.Field(
        description="A base64 string encoding of the data.",
        # Canonical base64 encoding
        # https://stackoverflow.com/a/64467300/3709935
        pattern=r"^(?:[A-Za-z0-9+/]{4})*(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$",
    )

    @pydantic.field_validator("mediaType", mode="after")
    @classmethod
    def validate_mediaType(cls, value: str) -> str:
        """Reject any media type other than ``image/png``.

        Raises:
            ValueError: if ``value`` is not exactly ``"image/png"``.
        """
        if value != "image/png":
            raise ValueError(".mediaType must be 'image/png'")
        return value

    @classmethod
    def from_numpy(cls, mask: NDArray[np.generic]) -> BinaryMask:
        """Convert a 0-1 numpy array to a BinaryMask.

        The numpy data must be in row-major order. That means the first
        dimension corresponds to **height** and the second dimension
        corresponds to **width**.

        Any numeric or boolean dtype is accepted (``uint8``, ``bool``, wider
        integers, floats). Every nonzero element is treated as "detection"
        (``1``) and every zero element as "no detection" (``0``).

        Raises:
            ValueError: if ``mask`` is not two-dimensional.
        """
        arr = np.asarray(mask)
        if arr.ndim != 2:
            raise ValueError(
                f"mask must be a 2-D (height, width) array, got shape {arr.shape}"
            )
        # Normalize to uint8 0/255 explicitly. Computing ``mask * 255`` on a
        # bool/int/float array promotes it to a multi-byte dtype whose raw
        # bytes Pillow would misread as 8-bit pixels, corrupting the mask.
        pixels = np.where(arr != 0, 255, 0).astype(np.uint8)
        # Pillow infers mode "L" for a 2-D uint8 array. Converting to "1"
        # (bilevel) gives the most compact PNG. Dither.NONE is explicit
        # because ``dither=None`` selects Floyd-Steinberg in Pillow.
        mask_img = Image.fromarray(pixels).convert("1", dither=Image.Dither.NONE)
        mask_img_buffer = io.BytesIO()
        mask_img.save(mask_img_buffer, format="png")
        mask_data = base64.b64encode(mask_img_buffer.getvalue()).decode("ascii")
        return cls(mediaType="image/png", data=mask_data)
class ImageRequest(pydantic.BaseModel):
    """Request model for image classification."""

    # Described like every other field in this module so the generated
    # JSON schema documents it; the field remains required.
    image: ImageData = pydantic.Field(
        description="The image to be classified."
    )
@enum.unique
class Labels(enum.IntEnum):
    """Whole-image classification labels reported by ``PredictionResponse``."""

    # The image is natural / unaltered.
    Natural = 0
    # The entire image was synthesized (e.g. by a generative image model).
    FullySynthesized = 1
    # A natural image with a portion edited by traditional techniques.
    LocallyEdited = 2
    # A natural image with a portion replaced by synthesized content.
    LocallySynthesized = 3
class PredictionResponse(pydantic.BaseModel):
    """Response model for synthetic image classification results.

    Detector models will be scored primarily on their ability to classify the
    entire image into 1 of the 4 label categories::

        0: (Natural) The image is natural / unaltered.
        1: (FullySynthesized) The entire image was synthesized by e.g., a
           generative image model.
        2: (LocallyEdited) The image is a natural image where a portion has
           been edited using traditional photo editing techniques such as
           splicing.
        3: (LocallySynthesized) The image is a natural image where a portion
           has been replaced by synthesized content.
    """

    logprobs: list[float] = pydantic.Field(
        description="The log-probabilities for each of the 4 possible labels.",
        min_length=4,
        max_length=4,
    )
    # ``default=None`` is required for the field to be omittable: in
    # pydantic v2, ``Optional[...]`` only permits ``None`` as a value and
    # does not by itself make the field optional.
    localizationMask: Optional[BinaryMask] = pydantic.Field(
        default=None,
        description="A bit mask localizing predicted edits. Models that are"
        " not capable of localization may omit this field. It may also be"
        " omitted if the predicted label is ``0`` or ``1``, in which case the"
        " mask will be assumed to be all 0's or all 1's, as appropriate.",
    )