from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import numpy as np
import copy
from torch.utils.data import IterableDataset

from longcat_image.utils.dist_utils import get_world_size, get_rank, get_local_rank

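class MultiResolutionDistributedSampler(torch.utils.data.Sampler):
    """Distributed sampler that groups samples into aspect-ratio buckets.

    Each global batch (``batch_size * num_replicas`` indices) is drawn from a
    single bucket, so every rank receives samples that share one target
    resolution. Rank ``r`` takes the ``r``-th slice of ``batch_size`` indices.
    Iterating yields ``(data_index, resolution)`` pairs.
    """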
    def __init__(self,
                 batch_size: int,
                 dataset: IterableDataset,
                 data_resolution_infos: List,
                 bucket_info: dict,
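                 num_replicas: Optional[int] = None,
                 rank: Optional[int] = None,
                 seed: int = 888,
                 epoch: int = 0,
                 shuffle: bool = True):

        # Use explicitly passed values; otherwise fall back to the process group
        # (or to a single replica when torch.distributed is unavailable).
        if num_replicas is None:
            num_replicas = get_world_size() if dist.is_available() else 1
        if rank is None:
            rank = get_rank() if dist.is_available() else 0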

        self.len_items = len(dataset)
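        # Normalize bucket keys (aspect ratios) to float and sort them; each
        # aspect ratio maps to its target resolution.
        bucket_info = {float(b): bucket_info[b] for b in bucket_info.keys()}
        self.aspect_ratios = np.array(sorted(list(bucket_info.keys())))
        self.resolutions = np.array([bucket_info[aspect] for aspect in self.aspect_ratios])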

        self.batch_size = batch_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = epoch
        self.shuffle = shuffle
        self.seed = seed
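        self.cur_rank_index = []
        # Ranks must share the same seed so they all draw the same bucket sequence.
        self.rng = np.random.RandomState(seed + self.epoch)
        self.global_batch_size = batch_size * num_replicas
        self.data_resolution_infos = np.array(data_resolution_infos, dtype=np.float32)
        print(f'num_replicas {num_replicas}, current rank {rank}')

        self.split_to_buckets()
        # Upper bound per rank: samples that cannot fill a whole global batch are skipped.
        self.num_samples = len(dataset) // num_replicas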

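    def split_to_buckets(self):
        """Assign every sample to the bucket whose aspect ratio is closest to its own."""
        self.buckets = {}
        self._buckets_bak = {}
        # Column 0 divided by column 1 must follow the same convention as the bucket keys.
        data_aspect_ratio = self.data_resolution_infos[:, 0] / self.data_resolution_infos[:, 1]
        bucket_id = np.abs(data_aspect_ratio[:, None] - self.aspect_ratios).argmin(axis=1)
        for i in range(len(self.aspect_ratios)):
            indices = np.where(bucket_id == i)[0]
            self._buckets_bak[i] = indices.copy()  # unshuffled copy, restored every epoch
            self.buckets[i] = indices
            if self.shuffle:
                # Same RNG order as shuffle_bucker_index, so epoch 0 is shuffled too.
                self.rng.shuffle(self.buckets[i])
        for k, v in self.buckets.items():
            print(f'bucket {k}, resolution {self.resolutions[k]}, samples {len(v)}')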
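
    # Worked example of the nearest-bucket rule above (illustrative numbers only,
    # not from any shipped config): with bucket aspect ratios [0.5, 1.0, 2.0], a
    # sample whose columns 0 and 1 are 700 and 1000 has ratio 0.7. Because
    # |0.7 - 0.5| = 0.2 < |0.7 - 1.0| = 0.3, it is assigned to bucket 0 (ratio 0.5).
    # Distances are linear in the ratio. Since argmin returns the first minimum
    # and self.aspect_ratios is sorted ascending, an exact tie goes to the lower ratio.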

    def get_batch_index(self):
        """Draw one global batch (indices for all ranks) from a single bucket.

        A bucket is chosen with probability proportional to its remaining size.
        All ranks share the same seed, so they choose the same bucket and batch.
        Raises StopIteration once no bucket can fill a whole global batch.
        """
        # Drop buckets that can no longer fill a whole global batch; their
        # leftover samples are skipped for the rest of this epoch.
        for bucket_id in [b for b, idx in self.buckets.items()
                          if len(idx) < self.global_batch_size]:
            del self.buckets[bucket_id]
        if not self.buckets:
            raise StopIteration('There is not enough data left in the current epoch.')

        bucket_ids = list(self.buckets.keys())
        bucket_probs = np.array([len(self.buckets[b]) for b in bucket_ids], dtype=np.float64)
        bucket_probs /= bucket_probs.sum()
        chosen_id = bucket_ids[self.rng.choice(len(bucket_ids), p=bucket_probs)]

        batch_index = self.buckets[chosen_id][:self.global_batch_size]
        self.buckets[chosen_id] = self.buckets[chosen_id][self.global_batch_size:]
        return batch_index, self.resolutions[chosen_id]

    def shuffle_bucket_index(self):
        """Restore the full buckets and reshuffle them deterministically for ``self.epoch``."""
        self.rng = np.random.RandomState(self.seed + self.epoch)
        self.buckets = copy.deepcopy(self._buckets_bak)
        if self.shuffle:
            for bucket_id in self.buckets.keys():
                self.rng.shuffle(self.buckets[bucket_id])

    # Backward-compatible alias for the original (misspelled) method name.
    shuffle_bucker_index = shuffle_bucket_index

    def __iter__(self):
        return self

    def __next__(self):
        if len(self.cur_rank_index) == 0:
            try:
                global_batch_index, target_resolution = self.get_batch_index()
            except StopIteration:
                # End of epoch: advance the epoch and refill the buckets so the
                # next pass over the sampler starts from a fresh shuffle.
                self.epoch += 1
                self.shuffle_bucket_index()
                raise
            # Each rank takes its own contiguous slice of the global batch.
            start = self.batch_size * self.rank
            self.cur_rank_index = [int(i) for i in global_batch_index[start:start + self.batch_size]]
            self.resolution = [int(r) for r in target_resolution]
        data_index = self.cur_rank_index.pop(0)
        return data_index, self.resolution

    def __len__(self):
        # Upper bound: incomplete global batches at the end of each bucket are dropped.
        return self.num_samples

    def set_epoch(self, epoch):
        """Set the epoch and rebuild the shuffled buckets for it."""
        self.epoch = epoch
        self.cur_rank_index = []
        self.shuffle_bucket_index()
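

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by MultiResolutionDistributedSampler. This is a
# small, self-contained re-implementation of one epoch of its scheduling rules,
# meant for reading and testing in isolation. It assumes the same rules as the
# class:
#   * each sample goes to the nearest aspect-ratio bucket (column 0 / column 1);
#   * a bucket is drawn with probability proportional to its remaining size;
#   * global batches of batch_size * num_replicas indices never mix buckets;
#   * incomplete global batches are dropped;
#   * rank r keeps the r-th slice of batch_size indices.
# The name `_sketch_epoch_schedule` is hypothetical. Its exact index sequence
# for a given seed may differ from the class's.
# ---------------------------------------------------------------------------
import numpy as np


def _sketch_epoch_schedule(sizes, aspect_ratios, batch_size, num_replicas, rank, seed=0):
    """Return the per-rank index batches one epoch would yield (sketch only)."""
    sizes = np.asarray(sizes, dtype=np.float64)
    ratios = np.asarray(sorted(aspect_ratios), dtype=np.float64)
    # Nearest bucket by absolute difference in aspect ratio.
    bucket_id = np.abs(sizes[:, 0:1] / sizes[:, 1:2] - ratios).argmin(axis=1)
    rng = np.random.RandomState(seed)  # identical seed on every rank
    buckets = {i: rng.permutation(np.where(bucket_id == i)[0]) for i in range(len(ratios))}
    global_bs = batch_size * num_replicas
    batches = []
    while True:
        # Keep only buckets that can still fill a whole global batch.
        buckets = {k: v for k, v in buckets.items() if len(v) >= global_bs}
        if not buckets:
            return batches
        keys = list(buckets)
        probs = np.array([len(buckets[k]) for k in keys], dtype=np.float64)
        chosen = keys[rng.choice(len(keys), p=probs / probs.sum())]
        global_batch = buckets[chosen][:global_bs]
        buckets[chosen] = buckets[chosen][global_bs:]
        # Rank r keeps its own contiguous slice of the shared global batch.
        batches.append([int(i) for i in global_batch[rank * batch_size:(rank + 1) * batch_size]])


# Usage (toy sizes): 6 samples of 512x512 (ratio 1.0) and 4 of 512x1024 (ratio 0.5),
# with 2 ranks and batch_size 1, give 5 single-index batches per rank. At every step
# both ranks receive indices from the same bucket.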