# upsample / handler.py
# Last change: "Update handler.py" by jayyap (commit 0edcf54)
from typing import Dict, List, Any
from diffusers import DiffusionPipeline
import torch
from io import BytesIO
import requests
from PIL import Image
import base64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# This handler only supports GPU execution: fail fast at import time otherwise.
if device.type != 'cuda':
    raise ValueError("need to run on GPU")
# Mixed-precision dtype for the pipeline weights. bfloat16 is used on GPUs whose
# major compute capability is 8 or newer (Ampere/Ada are 8.x, Hopper is 9.x, ...);
# older GPUs fall back to float16. The previous `== 8` check wrongly sent
# newer-than-Ampere GPUs down the float16 path.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
class EndpointHandler:
    """Inference-endpoint handler that upscales a base64-encoded image.

    Uses the ``CompVis/ldm-super-resolution-4x-openimages`` diffusers pipeline,
    loaded with the module-level ``dtype`` onto the module-level ``device``.
    The ``path`` argument supplied by the serving runtime is not used.
    """

    def __init__(self, path=""):
        """Load the super-resolution pipeline once for all requests."""
        self.pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/ldm-super-resolution-4x-openimages", torch_dtype=dtype
        ).to(device)
        # To reduce GPU memory at the cost of speed, components could instead be
        # loaded on demand via `self.pipeline.enable_model_cpu_offload()`.

    def __call__(self, data: Dict[str, Any]) -> "Image.Image":
        """Decode the request image, upscale it and return the result.

        Args:
            data: request payload. ``data["image"]`` must hold the base64-encoded
                input image; a ``data:<mime>;base64,`` URL prefix is accepted.
                An optional ``data["parameters"]`` dict may override
                ``num_inference_steps`` (default 100) and ``eta`` (default 1).
                The payload is not modified.

        Returns:
            The first image of the pipeline output's ``images``.

        Raises:
            ValueError: if the payload has no ``"image"`` entry.
        """
        image_string = data.get("image")
        if image_string is None:
            # Previously this surfaced as an opaque TypeError from b64decode(None).
            raise ValueError('request payload must contain an "image" field')
        parameters = data.get("parameters") or {}

        # NOTE: normalise to 3-channel RGB so RGBA / greyscale / palette uploads
        # are not fed to the model with an unexpected channel layout.
        low_res_img = self.decode_base64_image(image_string).convert("RGB")

        with torch.no_grad():
            output = self.pipeline(
                low_res_img,
                num_inference_steps=parameters.get("num_inference_steps", 100),
                eta=parameters.get("eta", 1),
            )
        return output.images[0]

    def decode_base64_image(self, image_string):
        """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

        Args:
            image_string: base64 text or bytes, optionally prefixed with a
                ``data:<mime>;base64,`` header.

        Returns:
            The image opened from the decoded bytes (opened lazily by PIL).
        """
        if isinstance(image_string, str) and image_string.startswith("data:"):
            # Strip the data-URL header. With validate=False, b64decode would drop
            # only ':', ';' and ',' and decode the rest of the header as payload.
            image_string = image_string.partition(",")[2]
        buffer = BytesIO(base64.b64decode(image_string))
        return Image.open(buffer)