Spaces:
Runtime error
| import random | |
| from typing import List, Union, Optional, Tuple | |
| import torch | |
| from PIL import Image | |
| from sample import (arg_parse, | |
| sampling, | |
| load_fontdiffuer_pipeline) | |
def run_fontdiffuer(source_image,
                    character,
                    reference_image,
                    sampling_step,
                    guidance_scale,
                    batch_size=1,
                    seed: Optional[int] = None):
    """Run one FontDiffuser sampling pass and return the generated glyph.

    This function reads and mutates the module-level ``args`` namespace and
    passes the module-level ``pipe`` to ``sampling``. Both globals are created
    in the ``__main__`` block, so they must exist before this is called.

    NOTE(review): because ``args`` is shared global state, concurrent calls
    (e.g. from a web UI worker pool) would overwrite each other's settings.

    Args:
        source_image: Content glyph image, or ``None`` to take the content
            from ``character`` instead (sets ``args.character_input``).
        character: Content character, stored in ``args.content_character``.
        reference_image: Style reference image, passed as ``style_image``.
        sampling_step: Number of sampling steps (``args.sampling_step``).
        guidance_scale: Guidance scale (``args.guidance_scale``).
        batch_size: Batch size (``args.batch_size``). Defaults to 1.
        seed: Seed stored in ``args.seed``. ``None`` (the default) draws a
            fresh seed from ``random.randint(0, 10000)`` on every call, as
            before. Pass an int to fix it, presumably making the output
            reproducible if ``sampling`` seeds its RNG from ``args.seed``.

    Returns:
        The value returned by ``sampling``. When it is not ``None``, its
        ``format`` attribute is set to ``'PNG'``.
    """
    # Content is taken from the character (not an image) exactly when no
    # source image was supplied.
    args.character_input = source_image is None
    args.content_character = character
    args.sampling_step = sampling_step
    args.guidance_scale = guidance_scale
    args.batch_size = batch_size
    # Keep the original "new random seed per call" default, but allow callers
    # to pin the seed.
    args.seed = random.randint(0, 10000) if seed is None else seed

    out_image = sampling(
        args=args,
        pipe=pipe,
        content_image=source_image,
        style_image=reference_image)
    if out_image is not None:
        # Presumably a PIL image; tagging the format lets consumers that
        # inspect ``.format`` treat it as PNG.
        out_image.format = 'PNG'
    return out_image
def _load_rgb(image_path: Optional[str]):
    """Load ``image_path`` as an RGB PIL image, or return ``None`` if no path."""
    if image_path is None:
        return None
    # ``Image.open`` is lazy and keeps the file handle open. ``convert`` loads
    # the pixels and returns a new image, so the handle can be closed
    # immediately by the context manager instead of leaking until GC.
    with Image.open(image_path) as img:
        return img.convert('RGB')


def run_inference(
    source_image_path: Union[str, None],
    character: Union[str, None],
    reference_image_path: Optional[str],
    sampling_step: int = 50,
    guidance_scale: float = 7.5,
    batch_size: int = 1,
):
    """Load the input images from disk and run FontDiffuser on them.

    Args:
        source_image_path: Path to the content glyph image, or ``None`` to
            use ``character`` as the content instead.
        character: Content character, used when ``source_image_path`` is
            ``None``.
        reference_image_path: Path to the style reference image, or ``None``.
        sampling_step: Number of sampling steps. Defaults to 50.
        guidance_scale: Guidance scale. Defaults to 7.5.
        batch_size: Batch size forwarded to ``run_fontdiffuer``. Defaults
            to 1.

    Returns:
        The result of ``run_fontdiffuer``. This relies on the module-level
        ``args``/``pipe`` globals being initialised.
    """
    source_image = _load_rgb(source_image_path)
    reference_image = _load_rgb(reference_image_path)
    return run_fontdiffuer(
        source_image=source_image,
        character=character,
        reference_image=reference_image,
        sampling_step=sampling_step,
        guidance_scale=guidance_scale,
        batch_size=batch_size,
    )
if __name__ == '__main__':
    # ``args`` and ``pipe`` are module-level globals: ``run_fontdiffuer``
    # reads them directly, so they must be created before any inference call.
    args = arg_parse()
    args.demo = True
    args.ckpt_dir = 'ckpt'
    args.ttf_path = 'ttf/KaiXinSongA.ttf'
    # Hard-coding 'cuda' fails on hosts without a GPU; fall back to CPU there.
    # NOTE(review): confirm load_fontdiffuer_pipeline supports CPU inference.
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load the FontDiffuser pipeline once; every run_fontdiffuer call reuses it.
    pipe = load_fontdiffuer_pipeline(args=args)

    # Render the content glyph from ref_壤.jpg in the style of ref_欟.jpg.
    image = run_inference(
        character=None,
        source_image_path="figures/ref_imgs/ref_壤.jpg",
        reference_image_path="figures/ref_imgs/ref_欟.jpg",
    )
    print(image)