<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

[[open-in-colab]]
# Train a diffusion model

Unconditional image generation is a popular application of diffusion models that generates images resembling those in the dataset used for training. Typically, the best results are obtained by finetuning a pretrained model on a specific dataset. You can find many of these checkpoints on the [Hub](https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model), but if you can't find one you like, you can always train your own!

This tutorial will teach you how to train a [`UNet2DModel`] on a subset of the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset to generate your own ๐ฆ butterflies ๐ฆ.

<Tip>

๐ก This training tutorial is based on the [Training with ๐งจ Diffusers](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) notebook. For additional details about how diffusion models work, check out the notebook!

</Tip>

Before you begin, make sure you have ๐ค Datasets installed to load and preprocess the dataset, and ๐ค Accelerate to speed up training on GPUs. Also install [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize training metrics (you can also use [Weights & Biases](https://docs.wandb.ai/) to track your training).
```bash
!pip install diffusers[training]
```

We encourage you to share your model with the community. To do so, you'll need to log in to your Hugging Face account (create one [here](https://hf.co/join) if you don't already have one). You can log in from a notebook and enter your token when prompted:

```py
>>> from huggingface_hub import notebook_login

>>> notebook_login()
```

Or log in from the terminal:

```bash
huggingface-cli login
```

Since the model checkpoints are quite large, use [Git-LFS](https://git-lfs.com/) to version these large files:

```bash
!sudo apt -qq install git-lfs
!git config --global credential.helper store
```
## Training configuration

For convenience, create a `TrainingConfig` class containing the training hyperparameters (feel free to adjust them):

```py
>>> from dataclasses import dataclass

>>> @dataclass
... class TrainingConfig:
...     image_size = 128  # the generated image resolution
...     train_batch_size = 16
...     eval_batch_size = 16  # how many images to sample during evaluation
...     num_epochs = 50
...     gradient_accumulation_steps = 1
...     learning_rate = 1e-4
...     lr_warmup_steps = 500
...     save_image_epochs = 10
...     save_model_epochs = 30
...     mixed_precision = "fp16"  # `no` for float32, `fp16` for automatic mixed precision
...     output_dir = "ddpm-butterflies-128"  # the model name locally and on the HF Hub
...     push_to_hub = True  # whether to upload the saved model to the HF Hub
...     hub_private_repo = False
...     overwrite_output_dir = True  # overwrite the old model when re-running the notebook
...     seed = 0

>>> config = TrainingConfig()
```
## Load the dataset

You can easily load the [Smithsonian Butterflies](https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset) dataset with the ๐ค Datasets library:

```py
>>> from datasets import load_dataset

>>> config.dataset_name = "huggan/smithsonian_butterflies_subset"
>>> dataset = load_dataset(config.dataset_name, split="train")
```

๐ก You can find additional datasets from the [HugGan Community Event](https://huggingface.co/huggan), or you can use your own dataset by creating a local [`ImageFolder`](https://huggingface.co/docs/datasets/image_dataset#imagefolder). Set `config.dataset_name` to the repository id of the dataset if it is from the HugGan Community Event, or to `imagefolder` if you're using your own images, as sketched below.
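
For example, a minimal sketch of the local-folder option might look like this (the `data_dir` path is only a placeholder for wherever your images actually live):

```py
>>> from datasets import load_dataset

>>> # hypothetical local directory containing your own training images
>>> config.dataset_name = "imagefolder"
>>> dataset = load_dataset("imagefolder", data_dir="path/to/your/images", split="train")
```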
๐ค Datasets uses the [`~datasets.Image`] feature to automatically decode the image data and load it as a [`PIL.Image`](https://pillow.readthedocs.io/en/stable/reference/Image.html), which we can visualize:

```py
>>> import matplotlib.pyplot as plt

>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
>>> for i, image in enumerate(dataset[:4]["image"]):
...     axs[i].imshow(image)
...     axs[i].set_axis_off()
>>> fig.show()
```



The images are all different sizes, so you'll need to preprocess them first:

- `Resize` changes the image size to the one defined in `config.image_size`.
- `RandomHorizontalFlip` augments the dataset by randomly mirroring the images.
- `Normalize` is important to rescale the pixel values into a [-1, 1] range, which is what the model expects.
```py
>>> from torchvision import transforms

>>> preprocess = transforms.Compose(
...     [
...         transforms.Resize((config.image_size, config.image_size)),
...         transforms.RandomHorizontalFlip(),
...         transforms.ToTensor(),
...         transforms.Normalize([0.5], [0.5]),
...     ]
... )
```

Use ๐ค Datasets' [`~datasets.Dataset.set_transform`] method to apply the `preprocess` function on the fly during training:

```py
>>> def transform(examples):
...     images = [preprocess(image.convert("RGB")) for image in examples["image"]]
...     return {"images": images}

>>> dataset.set_transform(transform)
```
Feel free to visualize the images again to confirm that they've been resized (a quick sketch of such a check follows the cell below). Now you're ready to wrap the dataset in a [DataLoader](https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader) for training!

```py
>>> import torch

>>> train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
```
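
If you want to confirm the preprocessing visually, here is a minimal sketch (not part of the original notebook) that re-plots a few transformed samples. After `set_transform`, the tensors are channels-first and normalized to [-1, 1], so they're shifted back to [0, 1] for display:

```py
>>> import matplotlib.pyplot as plt

>>> # sanity check: each transformed sample has shape (3, config.image_size, config.image_size)
>>> fig, axs = plt.subplots(1, 4, figsize=(16, 4))
>>> for i in range(4):
...     img = dataset[i]["images"]
...     axs[i].imshow((img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy())
...     axs[i].set_axis_off()
>>> fig.show()
```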
## Create a UNet2DModel

Pretrained models in ๐งจ Diffusers are easily created from their model class with the parameters you want. For example, to create a [`UNet2DModel`]:

```py
>>> from diffusers import UNet2DModel

>>> model = UNet2DModel(
...     sample_size=config.image_size,  # the target image resolution
...     in_channels=3,  # the number of input channels, 3 for RGB images
...     out_channels=3,  # the number of output channels
...     layers_per_block=2,  # how many ResNet layers to use per UNet block
...     block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block
...     down_block_types=(
...         "DownBlock2D",  # a regular ResNet downsampling block
...         "DownBlock2D",
...         "DownBlock2D",
...         "DownBlock2D",
...         "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
...         "DownBlock2D",
...     ),
...     up_block_types=(
...         "UpBlock2D",  # a regular ResNet upsampling block
...         "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...         "UpBlock2D",
...     ),
... )
```

It's often a good idea to quickly check that the sample image shape matches the model output shape:

```py
>>> sample_image = dataset[0]["images"].unsqueeze(0)
>>> print("Input shape:", sample_image.shape)
Input shape: torch.Size([1, 3, 128, 128])

>>> print("Output shape:", model(sample_image, timestep=0).sample.shape)
Output shape: torch.Size([1, 3, 128, 128])
```
Great! Next, you'll need a scheduler to add some noise to the image.

## Create a scheduler

The scheduler behaves differently depending on whether you're using the model for training or inference. During inference, the scheduler generates an image from the noise. During training, the scheduler takes a model output - or a sample - from a specific point in the diffusion process and applies noise to the image according to a *noise schedule* and an *update rule*.

Let's take a look at the `DDPMScheduler` and use the `add_noise` method to add some random noise to the `sample_image` from before:

```py
>>> import torch
>>> from PIL import Image
>>> from diffusers import DDPMScheduler

>>> noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
>>> noise = torch.randn(sample_image.shape)
>>> timesteps = torch.LongTensor([50])
>>> noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

>>> Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).type(torch.uint8).numpy()[0])
```



The training objective of the model is to predict the noise added to the image. The loss at this step can be calculated by:

```py
>>> import torch.nn.functional as F

>>> noise_pred = model(noisy_image, timesteps).sample
>>> loss = F.mse_loss(noise_pred, noise)
```
## Train the model

By now, you have most of the pieces needed to start training the model, and all that's left is putting everything together.

First, you'll need an optimizer and a learning rate scheduler:

```py
>>> from diffusers.optimization import get_cosine_schedule_with_warmup

>>> optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
>>> lr_scheduler = get_cosine_schedule_with_warmup(
...     optimizer=optimizer,
...     num_warmup_steps=config.lr_warmup_steps,
...     num_training_steps=(len(train_dataloader) * config.num_epochs),
... )
```

Then, you'll need a way to evaluate the model. For evaluation, you can use the `DDPMPipeline` to generate a batch of sample images and save them as a grid:

```py
>>> from diffusers import DDPMPipeline
>>> import math
>>> import os

>>> def make_grid(images, rows, cols):
...     w, h = images[0].size
...     grid = Image.new("RGB", size=(cols * w, rows * h))
...     for i, image in enumerate(images):
...         grid.paste(image, box=(i % cols * w, i // cols * h))
...     return grid

>>> def evaluate(config, epoch, pipeline):
...     # Sample some images from random noise (this is the backward diffusion process).
...     # The default pipeline output type is `List[PIL.Image]`.
...     images = pipeline(
...         batch_size=config.eval_batch_size,
...         generator=torch.manual_seed(config.seed),
...     ).images
...     # Make a grid out of the images.
...     image_grid = make_grid(images, rows=4, cols=4)
...     # Save the images.
...     test_dir = os.path.join(config.output_dir, "samples")
...     os.makedirs(test_dir, exist_ok=True)
...     image_grid.save(f"{test_dir}/{epoch:04d}.png")
```
Now you can wrap all these components together in a training loop with ๐ค Accelerate for easy TensorBoard logging, gradient accumulation, and mixed precision training. To upload the model to the Hub, write a function to get your repository name and information and then push it to the Hub.

๐ก The training loop below may look intimidating and long, but it'll be worth it later when you can launch training in just one line of code! If you can't wait and want to start generating images, feel free to copy, paste, and run the code below. ๐ค

```py
>>> from accelerate import Accelerator
>>> from huggingface_hub import HfFolder, Repository, whoami
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os

>>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
...     if token is None:
...         token = HfFolder.get_token()
...     if organization is None:
...         username = whoami(token)["name"]
...         return f"{username}/{model_id}"
...     else:
...         return f"{organization}/{model_id}"

>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
...     # Initialize accelerator and tensorboard logging
...     accelerator = Accelerator(
...         mixed_precision=config.mixed_precision,
...         gradient_accumulation_steps=config.gradient_accumulation_steps,
...         log_with="tensorboard",
...         logging_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
...         if config.push_to_hub:
...             repo_name = get_full_repo_name(Path(config.output_dir).name)
...             repo = Repository(config.output_dir, clone_from=repo_name)
...         elif config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
...         accelerator.init_trackers("train_example")
...     # Prepare everything.
...     # There is no specific order to remember, you just need to unpack the
...     # objects in the same order you gave them to the prepare method.
...     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
...         model, optimizer, train_dataloader, lr_scheduler
...     )
...     global_step = 0
...     # Now you train the model
...     for epoch in range(config.num_epochs):
...         progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
...         progress_bar.set_description(f"Epoch {epoch}")
...         for step, batch in enumerate(train_dataloader):
...             clean_images = batch["images"]
...             # Sample noise to add to the images
...             noise = torch.randn(clean_images.shape).to(clean_images.device)
...             bs = clean_images.shape[0]
...             # Sample a random timestep for each image
...             timesteps = torch.randint(
...                 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
...             ).long()
...             # Add noise to the clean images according to the noise magnitude at each timestep
...             # (this is the forward diffusion process)
...             noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
...             with accelerator.accumulate(model):
...                 # Predict the noise residual
...                 noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
...                 loss = F.mse_loss(noise_pred, noise)
...                 accelerator.backward(loss)
...                 accelerator.clip_grad_norm_(model.parameters(), 1.0)
...                 optimizer.step()
...                 lr_scheduler.step()
...                 optimizer.zero_grad()
...             progress_bar.update(1)
...             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
...             progress_bar.set_postfix(**logs)
...             accelerator.log(logs, step=global_step)
...             global_step += 1
...         # After each epoch you optionally sample some demo images with evaluate() and save the model
...         if accelerator.is_main_process:
...             pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
...             if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
...                 evaluate(config, epoch, pipeline)
...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 if config.push_to_hub:
...                     repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
...                 else:
...                     pipeline.save_pretrained(config.output_dir)
```
Phew, that was quite a bit of code! But you're finally ready to launch training with ๐ค Accelerate's [`~accelerate.notebook_launcher`] function. Pass it the training loop, all the training arguments, and the number of processes to use for training (you can change this value to the number of GPUs available to you):

```py
>>> from accelerate import notebook_launcher

>>> args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

>>> notebook_launcher(train_loop, args, num_processes=1)
```

Once training is complete, take a look at the final ๐ฆ images ๐ฆ generated by your diffusion model!

```py
>>> import glob

>>> sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
>>> Image.open(sample_images[-1])
```


## Next steps

Unconditional image generation is one example of a task that can be trained. You can explore other tasks and training techniques on the [๐งจ Diffusers training examples](../training/overview) page. Here are some examples of what you can learn:

- [Textual Inversion](../training/text_inversion), an algorithm that teaches a model a specific visual concept and integrates it into the generated images.
- [DreamBooth](../training/dreambooth), a technique for generating personalized images of a subject given a few input images of that subject.
- [Guide](../training/text2image) to finetuning a Stable Diffusion model on your own dataset.
- [Guide](../training/lora) to using LoRA, a memory-efficient technique for finetuning really large models faster.