# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import gc
import io
import json
import os
import random
import re

import cv2
import numpy as np
import torch
import yaml
from glob import glob
from PIL import Image, ImageDraw
from torchvision import transforms as T
from webdataset.filters import default_collation_fn, pipelinefilter

def get_rank_and_worldsize():
    try:
        local_rank = int(os.environ.get("LOCAL_RANK"))
        global_rank = int(os.environ.get("RANK"))
        world_size = int(os.getenv('WORLD_SIZE', 1))
    except (TypeError, ValueError):
        # Env vars unset (single-process run) or malformed; fall back to defaults.
        local_rank = 0
        global_rank = 0
        world_size = 1
    return local_rank, global_rank, world_size
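
# Usage sketch: under a distributed launcher such as torchrun, LOCAL_RANK, RANK,
# and WORLD_SIZE are set automatically; in a plain `python` run the fallback
# (0, 0, 1) is returned.
#   local_rank, global_rank, world_size = get_rank_and_worldsize()
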
def get_train_config(config_path=None):
    if config_path is None:
        config_path = os.environ.get("XFL_CONFIG")
    assert config_path is not None, "Please set the XFL_CONFIG environment variable"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    return config
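
# Usage sketch (hypothetical config path and key; the YAML schema is defined by
# the training scripts, not here):
#   os.environ["XFL_CONFIG"] = "configs/train.yaml"
#   config = get_train_config()
#   resolution = config["resolution"]
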
def calculate_aspect_ratios(resolution):
    # Base buckets defined at a 256-pixel budget; values are [height, width].
    ASPECT_RATIO = {
        '0.25': [128.0, 512.0], '0.26': [128.0, 496.0], '0.27': [128.0, 480.0], '0.28': [128.0, 464.0],
        '0.32': [144.0, 448.0], '0.33': [144.0, 432.0], '0.35': [144.0, 416.0], '0.4': [160.0, 400.0],
        '0.42': [160.0, 384.0], '0.48': [176.0, 368.0], '0.5': [176.0, 352.0], '0.52': [176.0, 336.0],
        '0.57': [192.0, 336.0], '0.6': [192.0, 320.0], '0.68': [208.0, 304.0], '0.72': [208.0, 288.0],
        '0.78': [224.0, 288.0], '0.82': [224.0, 272.0], '0.88': [240.0, 272.0], '0.94': [240.0, 256.0],
        '1.0': [256.0, 256.0], '1.07': [256.0, 240.0], '1.13': [272.0, 240.0], '1.21': [272.0, 224.0],
        '1.29': [288.0, 224.0], '1.38': [288.0, 208.0], '1.46': [304.0, 208.0], '1.67': [320.0, 192.0],
        '1.75': [336.0, 192.0], '2.0': [352.0, 176.0], '2.09': [368.0, 176.0], '2.4': [384.0, 160.0],
        '2.5': [400.0, 160.0], '2.89': [416.0, 144.0], '3.0': [432.0, 144.0], '3.11': [448.0, 144.0],
        '3.62': [464.0, 128.0], '3.75': [480.0, 128.0], '3.88': [496.0, 128.0], '4.0': [512.0, 128.0]
    }
    NEW_ASPECT_RATIO = {}
    for ratio in ASPECT_RATIO:
        height, width = ASPECT_RATIO[ratio]
        width = round(width / 256 * resolution)
        height = round(height / 256 * resolution)
        # Drop buckets whose scaled sides are not divisible by 8.
        if width % 8 != 0 or height % 8 != 0:
            print(f"skip train resolution {width}, {height}")
            continue
        NEW_ASPECT_RATIO[ratio] = [height, width]
    return NEW_ASPECT_RATIO

ASPECT_RATIO_256 = calculate_aspect_ratios(256)
ASPECT_RATIO_384 = calculate_aspect_ratios(384)
ASPECT_RATIO_512 = calculate_aspect_ratios(512)
ASPECT_RATIO_768 = calculate_aspect_ratios(768)
ASPECT_RATIO_1024 = calculate_aspect_ratios(1024)

def get_closest_ratio(height: float, width: float, ratios: dict):
    aspect_ratio = height / width
    closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
    return ratios[closest_ratio], closest_ratio
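
# Example with the precomputed 512 buckets: a 320x448 (height x width) image has
# aspect ratio ~0.714, so the closest bucket key is '0.72', whose 512-scale
# bucket is [416, 576]:
#   target_hw, ratio_key = get_closest_ratio(320, 448, ASPECT_RATIO_512)
#   # target_hw == [416, 576], ratio_key == '0.72'
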
def _aspect_ratio_batched(
    data,
    batchsize=20,
    aspect_ratios=ASPECT_RATIO_512,
    batch_cross=False,
    collation_fn=default_collation_fn,
    partial=True,
):
    """Create batches of the given size, bucketing samples by closest aspect ratio.

    :param data: iterator over samples carrying 'original_sizes' and 'has_cross' keys
    :param batchsize: target batch size
    :param aspect_ratios: dict mapping ratio keys to [height, width] buckets
    :param batch_cross: if True, batch "cross" samples separately from "no_cross" ones
    :param collation_fn: function used to collate a full bucket into a batch
    :param partial: return partial batches once the input is exhausted
    :returns: iterator over batches
    """
    assert collation_fn is not None
    buckets = {
        ratio: {"cross": [], "no_cross": []} for ratio in aspect_ratios.keys()
    }

    def check(buckets):
        # Sanity check: no bucket should ever reach batchsize between iterations.
        for ratio in buckets:
            for bucket_name in buckets[ratio]:
                bucket = buckets[ratio][bucket_name]
                assert len(bucket) < batchsize

    for sample in data:
        check(buckets)
        height, width = sample['original_sizes']
        (new_height, new_width), closest_ratio = get_closest_ratio(height, width, aspect_ratios)
        bucket_name = "cross" if sample["has_cross"] and batch_cross else "no_cross"
        bucket = buckets[closest_ratio][bucket_name]
        bucket.append(sample)
        if len(bucket) >= batchsize:
            try:
                batch = collation_fn(bucket)
                yield batch
                del batch
            except Exception as e:
                print(f"[aspect_ratio_batched] collation_fn batch failed due to error {e}")
                for sample in bucket:
                    if "__key__" in sample:
                        print("error sample key in batch:", sample["__key__"])
                    if "__url__" in sample:
                        print("error sample url in batch:", sample["__url__"])
            buckets[closest_ratio][bucket_name] = []
            del bucket
            gc.collect()

    # yield the remaining data and reset the buckets
    for ratio in buckets.keys():
        for bucket_name in ["cross", "no_cross"]:
            bucket = buckets[ratio][bucket_name]
            if len(bucket) > 0:
                if len(bucket) == batchsize or partial:
                    batch = collation_fn(bucket)
                    yield batch
                    del batch
                buckets[ratio][bucket_name] = []
                del bucket

aspect_ratio_batched = pipelinefilter(_aspect_ratio_batched)

def apply_aspect_ratio_batched(dataset, batchsize, aspect_ratios, batch_cross, collation_fn, partial=True):
    return dataset.compose(
        aspect_ratio_batched(
            batchsize,
            aspect_ratios=aspect_ratios,
            batch_cross=batch_cross,
            collation_fn=collation_fn,
            partial=partial,
        )
    )
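
# Usage sketch with a webdataset pipeline (hypothetical shard pattern; each
# sample must carry the 'original_sizes' and 'has_cross' keys that
# _aspect_ratio_batched reads, and the dataset must support .compose):
#   import webdataset as wds
#   dataset = wds.WebDataset("shards-{000000..000099}.tar")
#   dataset = apply_aspect_ratio_batched(
#       dataset,
#       batchsize=8,
#       aspect_ratios=get_aspect_ratios(True, 512),
#       batch_cross=False,
#       collation_fn=default_collation_fn,
#   )
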
def get_aspect_ratios(enable_aspect_ratio, resolution):
    if enable_aspect_ratio:
        # print("[Dataset] Multi Aspect Ratio Training Enabled")
        if resolution == 256:
            aspect_ratios = ASPECT_RATIO_256
        elif resolution == 384:
            aspect_ratios = ASPECT_RATIO_384
        elif resolution == 512:
            aspect_ratios = ASPECT_RATIO_512
        elif resolution == 768:
            aspect_ratios = ASPECT_RATIO_768
        elif resolution == 1024:
            aspect_ratios = ASPECT_RATIO_1024
        else:
            aspect_ratios = calculate_aspect_ratios(resolution)
    else:
        # print("[Dataset] Multi Aspect Ratio Training Disabled")
        aspect_ratios = {
            '1.0': [resolution, resolution]
        }
    return aspect_ratios

def bbox_to_grid(bbox, image_size, output_size=(224, 224)):
    """
    Convert a bounding box to a grid of sampling points.

    Args:
        bbox (list of float): [xmin, ymin, xmax, ymax]
        image_size (tuple of int): (height, width) of the source image
        output_size (tuple of int): (height, width) of the output grid

    Returns:
        torch.Tensor: Grid of points with shape (output_height, output_width, 2)
    """
    xmin, ymin, xmax, ymax = bbox
    # Create a meshgrid for the output grid
    h, w = output_size
    yy, xx = torch.meshgrid(
        torch.linspace(ymin, ymax, h),
        torch.linspace(xmin, xmax, w),
        indexing="ij",
    )
    grid = torch.stack((xx, yy), -1)
    # Normalize grid to range [-1, 1]
    H, W = image_size
    grid[..., 0] = grid[..., 0] / (W - 1) * 2 - 1  # Normalize x to [-1, 1]
    grid[..., 1] = grid[..., 1] / (H - 1) * 2 - 1  # Normalize y to [-1, 1]
    return grid
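
# Usage sketch: the [-1, 1] layout matches torch.nn.functional.grid_sample, so a
# box can be cropped and resampled in one call (hypothetical shapes):
#   import torch.nn.functional as F
#   image = torch.randn(1, 3, 480, 640)  # (N, C, H, W)
#   grid = bbox_to_grid([100, 50, 400, 350], image_size=(480, 640))
#   crop = F.grid_sample(image, grid.unsqueeze(0), align_corners=True)  # (1, 3, 224, 224)
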
def random_crop_instance(instance, min_crop_ratio):
    assert 0 < min_crop_ratio <= 1
    crop_width_ratio = random.uniform(min_crop_ratio, 1)
    crop_height_ratio = random.uniform(min_crop_ratio, 1)
    orig_width, orig_height = instance.size
    crop_width = int(orig_width * crop_width_ratio)
    crop_height = int(orig_height * crop_height_ratio)
    crop_left = random.randint(0, orig_width - crop_width)
    crop_top = random.randint(0, orig_height - crop_height)
    crop_box = (crop_left, crop_top, crop_left + crop_width, crop_top + crop_height)  # (left, upper, right, lower)
    return instance.crop(crop_box), crop_box
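
# Example (hypothetical image): each side keeps at least 80% of its length, and
# the crop position is chosen uniformly at random:
#   img = Image.new("RGB", (640, 480), "white")
#   cropped, crop_box = random_crop_instance(img, min_crop_ratio=0.8)
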
pil2tensor = T.ToTensor()
tensor2pil = T.ToPILImage()
# BGR (OpenCV) <-> RGB (PIL) conversion helpers
cv2pil = lambda x: Image.fromarray(cv2.cvtColor(x, cv2.COLOR_BGR2RGB))
pil2cv2 = lambda x: cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR)

def compute_psnr(x, y):
    y = y.resize(x.size)
    x = pil2tensor(x) * 255.
    y = pil2tensor(y) * 255.
    mse = torch.mean((x - y) ** 2)
    return 20 * torch.log10(255.0 / torch.sqrt(mse)).item()
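
# compute_psnr implements the standard 8-bit PSNR, 20 * log10(255 / sqrt(MSE)),
# in decibels; identical inputs give MSE == 0 and hence an infinite PSNR.
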
def replace_first_occurrence(sentence, word_or_phrase, replace_with):
    # Escape special characters in word_or_phrase for exact matching
    escaped_word_or_phrase = re.escape(word_or_phrase)
    pattern = r'\b' + escaped_word_or_phrase + r'\b'
    # Find the first match
    match = next(re.finditer(pattern, sentence), None)
    if match:
        # Perform the replacement
        result = re.sub(pattern, replace_with, sentence, count=1)
        replaced = True
        index = match.start()
    else:
        # No match found
        result = sentence
        replaced = False
        index = -1
    return result, replaced, index

def decode_base64_to_image(base64_str):
    # Decode the base64 string to bytes
    img_bytes = base64.b64decode(base64_str)
    # Create a BytesIO buffer from the bytes
    img_buffer = io.BytesIO(img_bytes)
    # Open the image using Pillow
    image = Image.open(img_buffer)
    return image

def jpeg_compression(pil_image, quality):
    buffer = io.BytesIO()
    pil_image.save(buffer, format="JPEG", quality=quality)
    return Image.open(io.BytesIO(buffer.getvalue()))
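
# Example round trip: encoding at a low quality then decoding yields a degraded
# PIL image, useful as a compression-artifact augmentation (hypothetical value):
#   degraded = jpeg_compression(img, quality=25)
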
def pad_to_square(pil_image):
    new_size = max(pil_image.width, pil_image.height)
    square_image = Image.new("RGB", (new_size, new_size), "white")
    left = (new_size - pil_image.width) // 2
    top = (new_size - pil_image.height) // 2
    square_image.paste(pil_image, (left, top))
    return square_image

def pad_to_target(pil_image, target_size):
    original_width, original_height = pil_image.size
    target_width, target_height = target_size
    original_aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height
    # Pad the image to the target aspect ratio
    if original_aspect_ratio > target_aspect_ratio:
        new_width = original_width
        new_height = int(new_width / target_aspect_ratio)
    else:
        new_height = original_height
        new_width = int(new_height * target_aspect_ratio)
    pad_image = Image.new("RGB", (new_width, new_height), "white")
    left = (new_width - original_width) // 2
    top = (new_height - original_height) // 2
    pad_image.paste(pil_image, (left, top))
    # Resize the image to the target size
    resized_image = pad_image.resize(target_size)
    return resized_image
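
# Example: padding a 640x480 image to a square 512x512 target first whitens it
# to 640x640 (80 px bands top and bottom), then resizes:
#   out = pad_to_target(img, (512, 512))
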
def image_grid(imgs, rows, cols):
    # assert len(imgs) == rows * cols
    w, h = imgs[0].size
    if imgs[0].mode == 'L':
        grid = Image.new('L', size=(cols * w, rows * h))
    else:
        grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def split_grid(image):
    width = image.width // 2
    height = image.height // 2
    crop_tuples_list = [
        (0, 0, width, height),
        (width, 0, width * 2, height),
        (0, height, width, height * 2),
        (width, height, width * 2, height * 2),
    ]

    def crop_image(input_image, crop_tuple=None):
        if crop_tuple is None:
            return input_image
        return input_image.crop((crop_tuple[0], crop_tuple[1], crop_tuple[2], crop_tuple[3]))

    return [crop_image(image, crop_tuple) for crop_tuple in crop_tuples_list]
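
# Note: split_grid inverts a 2x2 image_grid; it returns the four tiles of
# image_grid(imgs, rows=2, cols=2) in the same row-major order.
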
def add_border(img, border_color, border_thickness):
    """
    Add a colored border to an image without changing its size.

    Parameters:
        img (PIL.Image.Image): Input image.
        border_color (tuple): Border color in RGB (e.g., (255, 0, 0) for red).
        border_thickness (int): Thickness of the border in pixels.
    """
    width, height = img.size
    img = img.copy()
    draw = ImageDraw.Draw(img)
    draw.rectangle((0, 0, width, border_thickness), fill=border_color)
    draw.rectangle((0, height - border_thickness, width, height), fill=border_color)
    draw.rectangle((0, 0, border_thickness, height), fill=border_color)
    draw.rectangle((width - border_thickness, 0, width, height), fill=border_color)
    return img

def merge_bboxes(bboxes):
    if not bboxes:
        return None  # Handle empty input
    # Extract all coordinates
    x_mins = [b[0] for b in bboxes]
    y_mins = [b[1] for b in bboxes]
    x_maxs = [b[2] for b in bboxes]
    y_maxs = [b[3] for b in bboxes]
    # Compute the merged box
    merged_box = (
        min(x_mins),  # x_min
        min(y_mins),  # y_min
        max(x_maxs),  # x_max
        max(y_maxs)   # y_max
    )
    return merged_box
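
# Example: two overlapping boxes merge into their common bounding box:
#   merge_bboxes([(10, 20, 50, 60), (30, 10, 80, 40)])  # -> (10, 10, 80, 60)
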
def flip_bbox_left_right(bbox, image_width):
    """
    Flips the bounding box horizontally on an image.

    Parameters:
        bbox (list of float): [x_min, y_min, x_max, y_max]
        image_width (int): The width of the image

    Returns:
        list of float: New bounding box after horizontal flip [x_min', y_min', x_max', y_max']
    """
    x_min, y_min, x_max, y_max = bbox
    new_x_min = image_width - x_max
    new_x_max = image_width - x_min
    new_bbox = [new_x_min, y_min, new_x_max, y_max]
    return new_bbox
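
# Example: on a 640-pixel-wide image, [100, 50, 300, 200] flips to [340, 50, 540, 200]:
#   flip_bbox_left_right([100, 50, 300, 200], 640)
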
def json_load(path, encoding='ascii'):
    with open(path, 'r', encoding=encoding) as file:
        return json.load(file)

def json_dump(obj, path, encoding='ascii', indent=4, create_dir=True, verbose=True, **kwargs):
    if create_dir and os.path.dirname(path) != '':
        os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w', encoding=encoding) as file:
        # ensure_ascii=False writes non-ASCII characters verbatim; pass a
        # Unicode-capable encoding (e.g. 'utf-8') when dumping such data.
        json.dump(obj, file, indent=indent, ensure_ascii=False, **kwargs)
    if verbose:
        print(type(obj), 'saved to', path)