from typing import Any, Callable, Dict, List, Optional, Union
import copy
import os

import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

try:
    from longcat_image.utils.dist_utils import get_world_size, get_rank, get_local_rank
except ModuleNotFoundError as _exc:
    # Only fall back when the longcat_image package itself is missing; a
    # missing dependency *inside* that package must still surface.
    if not (_exc.name or "").startswith("longcat_image"):
        raise

    # NOTE(review): these fallbacks follow plain torch.distributed semantics;
    # confirm they match longcat_image.utils.dist_utils if that matters.
    def _dist_ready() -> bool:
        return dist.is_available() and dist.is_initialized()

    def get_world_size() -> int:
        return dist.get_world_size() if _dist_ready() else 1

    def get_rank() -> int:
        return dist.get_rank() if _dist_ready() else 0

    def get_local_rank() -> int:
        return int(os.environ.get("LOCAL_RANK", 0))


class _EpochExhausted(ValueError):
    """No bucket can supply another full global batch in the current epoch.

    Subclasses ``ValueError`` because that is what the underlying
    ``RandomState.choice`` call used to raise in this situation.
    """


class MultiResolutionDistributedSampler(torch.utils.data.Sampler):
    """Sampler that yields ``(index, resolution)`` pairs grouped by aspect ratio.

    Every sample is assigned to the bucket whose aspect ratio is closest to the
    sample's own ratio (first column of ``data_resolution_infos`` divided by
    the second).  Each step a bucket is drawn with probability proportional to
    its remaining size, ``batch_size * num_replicas`` indices are taken from
    it, and this rank emits its ``batch_size`` slice one index at a time.
    All indices of one batch therefore share the same target resolution.

    The object is its own iterator: when the epoch is exhausted it bumps
    ``epoch``, rebuilds the buckets and raises ``StopIteration``, so it can be
    iterated again immediately.
    """

    def __init__(self,
                 batch_size: int,
                 dataset: IterableDataset,
                 data_resolution_infos: List,
                 bucket_info: dict,
                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 seed: int = 888,
                 epoch: int = 0,
                 shuffle: bool = True):
        """Build the buckets and prepare the first epoch.

        Args:
            batch_size: Number of indices emitted per batch on this rank.
            dataset: Only ``len(dataset)`` is used.
            data_resolution_infos: One two-element entry per sample; the ratio
                ``entry[0] / entry[1]`` selects the bucket.
            bucket_info: Maps an aspect ratio (anything ``float()`` accepts)
                to the resolution emitted for that bucket.
            num_replicas: Number of participating processes.  Detected via
                ``get_world_size()`` when ``None``.
            rank: Rank of this process.  Detected via ``get_rank()`` when
                ``None``.
            seed: Base seed; the RNG of an epoch is seeded with
                ``seed + epoch``.
            epoch: Starting epoch.
            shuffle: Whether to shuffle the indices inside every bucket at the
                start of each epoch (including the first one).

        Raises:
            ValueError: On a non-positive ``batch_size``, an empty
                ``bucket_info`` or an inconsistent ``num_replicas``/``rank``.
        """
        if batch_size <= 0:
            raise ValueError(f'batch_size must be positive, got {batch_size}.')
        if not bucket_info:
            raise ValueError('bucket_info must contain at least one bucket.')

        # Explicit values win; only auto-detect what the caller left unset.
        if num_replicas is None:
            num_replicas = get_world_size() if dist.is_available() else 1
        if rank is None:
            rank = get_rank() if dist.is_available() else 0
        if num_replicas < 1 or not 0 <= rank < num_replicas:
            raise ValueError(
                f'Invalid rank {rank} for num_replicas {num_replicas}.')

        self.len_items = len(dataset)

        ratio_to_resolution = {float(k): v for k, v in bucket_info.items()}
        sorted_ratios = sorted(ratio_to_resolution)
        self.aspect_ratios = np.array(sorted_ratios)
        self.resolutions = np.array(
            [ratio_to_resolution[ratio] for ratio in sorted_ratios])

        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = epoch
        self.shuffle = shuffle
        self.seed = seed
        self.cur_rank_index = []
        self.global_batch_size = batch_size * num_replicas
        self.data_resolution_infos = np.array(
            data_resolution_infos, dtype=np.float32)

        print(f'num_replicas {num_replicas}, cur rank {rank}!!!')
        self.split_to_buckets()
        # Seeds self.rng with seed + epoch and applies the per-bucket shuffle,
        # so the first epoch is prepared exactly like every later one.
        self.shuffle_bucker_index()
        self.num_samples = len(dataset) // num_replicas

    def split_to_buckets(self):
        """Assign every sample index to its nearest aspect-ratio bucket.

        Fills ``self._buckets_bak`` (pristine copy used to restart epochs) and
        ``self.buckets`` (working copy consumed while iterating).
        """
        infos = self.data_resolution_infos
        data_aspect_ratio = infos[:, 0] * 1.0 / infos[:, 1]
        # (num_samples, num_buckets) distance matrix -> closest bucket per sample.
        bucket_id = np.abs(
            data_aspect_ratio[:, None] - self.aspect_ratios).argmin(axis=1)

        self._buckets_bak = {
            i: np.where(bucket_id == i)[0]
            for i in range(len(self.aspect_ratios))
        }
        self.buckets = {i: idx.copy() for i, idx in self._buckets_bak.items()}

        for k, v in self.buckets.items():
            print(f'bucket {k}, resolutions {self.resolutions[k]}, sampler nums {len(v)}!!!')

    def get_batch_index(self):
        """Draw the next global batch.

        Returns:
            ``(indices, resolution)`` where ``indices`` holds
            ``global_batch_size`` sample indices from a single bucket and
            ``resolution`` is that bucket's entry of ``self.resolutions``.

        Raises:
            _EpochExhausted: When no remaining bucket holds a full global
                batch.
        """
        while True:
            bucket_ids = np.array(list(self.buckets), dtype=np.int64)
            # float32 is kept on purpose so the selection probabilities (and
            # hence the RNG-driven bucket order) stay what they always were.
            bucket_sizes = np.array(
                [len(self.buckets[b]) for b in self.buckets], dtype=np.float32)
            total = bucket_sizes.sum()
            if total <= 0:
                raise _EpochExhausted(
                    'There is not enough data in the current epoch.')

            chosen_id = int(
                self.rng.choice(bucket_ids, 1, p=bucket_sizes / total)[0])
            members = self.buckets[chosen_id]
            if len(members) < self.global_batch_size:
                # Leftover too small for a full global batch: drop it.
                del self.buckets[chosen_id]
                continue

            batch_index = members[:self.global_batch_size]
            remaining = members[self.global_batch_size:]
            if len(remaining) == 0:
                del self.buckets[chosen_id]
            else:
                self.buckets[chosen_id] = remaining
            return batch_index, self.resolutions[chosen_id]

    def shuffle_bucker_index(self):
        """Reset the buckets for ``self.epoch``.

        Re-seeds ``self.rng`` with ``seed + epoch``, restores every bucket from
        the pristine copy and, when ``self.shuffle`` is set, shuffles the
        indices inside each bucket.
        """
        self.rng = np.random.RandomState(self.seed + self.epoch)
        self.buckets = copy.deepcopy(self._buckets_bak)
        if self.shuffle:
            for bucket_id in self.buckets:
                self.rng.shuffle(self.buckets[bucket_id])

    def __iter__(self):
        """Return ``self``; iteration state lives on the sampler itself."""
        return self

    def __next__(self):
        """Return the next ``(data_index, resolution)`` pair for this rank.

        When the epoch is exhausted the epoch counter is advanced, the buckets
        are rebuilt for the next epoch and ``StopIteration`` is raised.
        """
        if not self.cur_rank_index:
            try:
                global_batch_index, target_resolutions = self.get_batch_index()
            except _EpochExhausted as e:
                self.epoch += 1
                self.shuffle_bucker_index()
                print(f'get error {e}.')
                raise StopIteration from None

            # This rank's contiguous slice of the global batch.
            start = self.batch_size * self.rank
            self.cur_rank_index = [
                int(i)
                for i in global_batch_index[start:start + self.batch_size]
            ]
            self.resolution = [int(r) for r in target_resolutions]

        data_index = self.cur_rank_index.pop(0)
        return (data_index, self.resolution)

    def __len__(self):
        """Return ``len(dataset) // num_replicas``.

        Only full global batches are emitted, so the number of indices
        actually yielded in an epoch can be smaller than this value.
        """
        return self.num_samples

    def set_epoch(self, epoch):
        """Switch to ``epoch`` and rebuild the buckets for it.

        Calling this with the epoch the sampler is already on is a no-op, so
        the usual ``set_epoch(e)`` call at the top of every training epoch
        does not disturb the automatically advanced state.
        """
        if epoch == self.epoch:
            return
        self.epoch = epoch
        # Discard any partially consumed batch from the previous epoch.
        self.cur_rank_index = []
        self.shuffle_bucker_index()