"""Pydantic models for request/response validation."""

from __future__ import annotations

import base64
import enum
import io
import typing
from typing import Literal, Optional

import numpy as np
import pydantic
from PIL import Image

if typing.TYPE_CHECKING:
    from numpy.typing import NDArray

# Canonical base64 encoding: standard alphabet, mandatory "=" padding, and
# zero-valued unused bits in the final quantum (hence the restricted
# character classes before the padding).
# https://stackoverflow.com/a/64467300/3709935
_BASE64_PATTERN = (
    r"^(?:[A-Za-z0-9+/]{4})*"
    r"(?:[A-Za-z0-9+/][AQgw]==|[A-Za-z0-9+/]{2}[AEIMQUYcgkosw048]=)?$"
)


class ImageData(pydantic.BaseModel):
    """Image data model for base64 encoded images."""

    mediaType: str = pydantic.Field(
        description="The IETF Media Type (MIME type) of the data"
    )
    data: str = pydantic.Field(
        description="A base64 string encoding of the data.",
        pattern=_BASE64_PATTERN,
    )


class BinaryMask(pydantic.BaseModel):
    """A bit mask indicating which pixels are manipulated / synthesized.

    A pixel value of ``0`` means "no detection", and a value of ``1`` means
    "detection".

    The mask data must be encoded in PNG format, so that typical mask data
    is compressed effectively. The PNG encoding **should** use "bilevel"
    mode for maximum compactness. You can use the ``BinaryMask.from_numpy()``
    function to convert a 0-1 numpy array to a BinaryMask.
    """

    mediaType: str = pydantic.Field(
        description="The IETF Media Type (MIME type) of the data."
    )
    data: str = pydantic.Field(
        description="A base64 string encoding of the data.",
        pattern=_BASE64_PATTERN,
    )

    @pydantic.field_validator("mediaType", mode="after")
    @classmethod
    def validate_mediaType(cls, value: str) -> str:
        """Reject any media type other than PNG.

        Raises:
            ValueError: If ``value`` is not exactly ``"image/png"``.
        """
        if value != "image/png":
            raise ValueError(".mediaType must be 'image/png'")
        return value

    @classmethod
    def from_numpy(cls, mask: NDArray[np.uint8]) -> BinaryMask:
        """Convert a 0-1 numpy array to a BinaryMask.

        The numpy data must be in row-major order. That means the first
        dimension corresponds to **height** and the second dimension
        corresponds to **width**.

        Any nonzero element is treated as ``1`` ("detection"), so boolean
        arrays and integer arrays of any dtype are accepted.

        Args:
            mask: 2-D array of shape ``(height, width)``.

        Returns:
            A ``BinaryMask`` (or subclass) holding a base64-encoded bilevel
            PNG of the mask.

        Raises:
            ValueError: If ``mask`` is not two-dimensional.
        """
        mask = np.asarray(mask)
        if mask.ndim != 2:
            raise ValueError(
                f"mask must be 2-dimensional (height, width); got shape {mask.shape}"
            )
        # Build an explicit uint8 0/255 image. Pillow then infers mode "L"
        # from the dtype, so the deprecated ``mode=`` argument is not needed.
        # Wider dtypes are never reinterpreted as raw bytes, and stray
        # values cannot wrap around under uint8 multiplication.
        gray = np.where(mask != 0, 255, 0).astype(np.uint8)
        # Threshold to bilevel mode "1" for compact PNG output. Dithering is
        # disabled explicitly: ``dither=None`` would select Floyd-Steinberg.
        mask_img = Image.fromarray(gray).convert("1", dither=Image.Dither.NONE)
        mask_img_buffer = io.BytesIO()
        mask_img.save(mask_img_buffer, format="png")
        mask_data = base64.b64encode(mask_img_buffer.getvalue()).decode("utf-8")
        return cls(mediaType="image/png", data=mask_data)


class ImageRequest(pydantic.BaseModel):
    """Request model for image classification."""

    image: ImageData


class Labels(enum.IntEnum):
    """Whole-image classification labels (see ``PredictionResponse``)."""

    Natural = 0
    FullySynthesized = 1
    LocallyEdited = 2
    LocallySynthesized = 3


class PredictionResponse(pydantic.BaseModel):
    """Response model for synthetic image classification results.

    Detector models will be scored primarily on their ability to classify
    the entire image into 1 of the 4 label categories::

        0: (Natural) The image is natural / unaltered.
        1: (FullySynthesized) The entire image was synthesized by e.g., a
           generative image model.
        2: (LocallyEdited) The image is a natural image where a portion has
           been edited using traditional photo editing techniques such as
           splicing.
        3: (LocallySynthesized) The image is a natural image where a portion
           has been replaced by synthesized content.
    """

    logprobs: list[float] = pydantic.Field(
        description="The log-probabilities for each of the 4 possible labels.",
        # Exactly one entry per member of ``Labels``.
        min_length=len(Labels),
        max_length=len(Labels),
    )
    # ``default=None`` is required for the field to be omittable: in
    # pydantic v2 an ``Optional[...]`` annotation without a default is still
    # a *required* field (it merely accepts ``None``).
    localizationMask: Optional[BinaryMask] = pydantic.Field(
        default=None,
        description="A bit mask localizing predicted edits. Models that are"
        " not capable of localization may omit this field. It may also be"
        " omitted if the predicted label is ``0`` or ``1``, in which case the"
        " mask will be assumed to be all 0's or all 1's, as appropriate.",
    )