# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li
import itertools
import math
import copy
import random

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from mmengine import print_log

from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.utils import get_activation, generatemask, RunningAverageMap

from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA

class BaselinePretrain(nn.Module):
    def __init__(self,
                 coarse_branch,
                 fine_branch,
                 sigloss,
                 min_depth,
                 max_depth,
                 image_raw_shape=(2160, 3840),
                 patch_process_shape=(384, 512),
                 patch_split_num=(4, 4),
                 target='coarse',
                 coarse_branch_zoe=None):
        """Baseline pretraining model.

        Builds either the coarse branch (whole image at low resolution) or the
        fine branch (high-resolution patch crops), each backed by a ZoeDepth
        network, depending on ``target``.
        """
        super().__init__()

        self.patch_process_shape = patch_process_shape
        self.tile_cfg = self.prepare_tile_cfg(image_raw_shape, patch_split_num)

        self.min_depth = min_depth
        self.max_depth = max_depth

        self.coarse_branch_cfg = coarse_branch
        self.fine_branch_cfg = fine_branch

        if target == 'coarse':
            if self.coarse_branch_cfg.type == 'ZoeDepth':
                self.coarse_branch = ZoeDepth.build(**coarse_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
            elif self.coarse_branch_cfg.type == 'DA-ZoeDepth':
                self.coarse_branch = ZoeDepth.build(**coarse_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.coarse_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")

        elif target == 'fine':
            if self.fine_branch_cfg.type == 'ZoeDepth':
                self.fine_branch = ZoeDepth.build(**fine_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeZoe(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
            elif self.fine_branch_cfg.type == 'DA-ZoeDepth':
                self.fine_branch = ZoeDepth.build(**fine_branch)
                print_log("Current zoedepth.core.prep.resizer is {}".format(type(self.fine_branch.core.prep.resizer)), logger='current')
                self.resizer = ResizeDA(patch_process_shape[1], patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")

        self.sigloss = build_model(sigloss)
        self.target = target

    def prepare_tile_cfg(self, image_raw_shape, patch_split_num):
        # bookkeeping for the tiling process
        patch_reensemble_shape = (self.patch_process_shape[0] * patch_split_num[0], self.patch_process_shape[1] * patch_split_num[1])
        patch_raw_shape = (image_raw_shape[0] // patch_split_num[0], image_raw_shape[1] // patch_split_num[1])

        raw_h_split_point = []
        raw_w_split_point = []
        for i in range(patch_split_num[0]):
            raw_h_split_point.append(int(patch_raw_shape[0] * i))
        for i in range(patch_split_num[1]):
            raw_w_split_point.append(int(patch_raw_shape[1] * i))

        tile_cfg = {
            'patch_split_num': patch_split_num,
            'patch_reensemble_shape': patch_reensemble_shape,
            'patch_raw_shape': patch_raw_shape,
            'image_raw_shape': image_raw_shape,
            'raw_h_split_point': raw_h_split_point,
            'raw_w_split_point': raw_w_split_point}
        return tile_cfg
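
    # Worked example with the defaults (image_raw_shape=(2160, 3840),
    # patch_split_num=(4, 4), patch_process_shape=(384, 512)):
    #   patch_raw_shape        = (2160 // 4, 3840 // 4) = (540, 960)
    #   patch_reensemble_shape = (384 * 4, 512 * 4)     = (1536, 2048)
    #   raw_h_split_point      = [0, 540, 1080, 1620]
    #   raw_w_split_point      = [0, 960, 1920, 2880]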

    def load_dict(self, state_dict):
        # only one branch may exist at a time; `state_dict` is loaded into it
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            return self.coarse_branch.load_state_dict(state_dict, strict=True)
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            return self.fine_branch.load_state_dict(state_dict, strict=True)
        else:
            raise NotImplementedError('Loading coarse and fine branches together is not supported')

    def get_save_dict(self):
        model_state_dict = {}
        if hasattr(self, 'coarse_branch') and not hasattr(self, 'fine_branch'):
            model_state_dict.update(self.coarse_branch.state_dict())
        elif hasattr(self, 'fine_branch') and not hasattr(self, 'coarse_branch'):
            model_state_dict.update(self.fine_branch.state_dict())
        else:
            raise NotImplementedError('Training coarse and fine branches together is not supported')
        return model_state_dict
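
    # A typical checkpoint round-trip with the two helpers above (the file
    # name is illustrative only):
    #   torch.save(model.get_save_dict(), 'coarse_pretrain.pth')
    #   model.load_dict(torch.load('coarse_pretrain.pth', map_location='cpu'))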

    def infer_forward(self, imgs_crop):
        # NOTE: subclasses are expected to override this with extra tile/bbox
        # arguments (see the four-argument calls below); within this base
        # class tile_temp is always None, so only this signature is exercised.
        output_dict = self.fine_branch(imgs_crop)
        return output_dict['metric_depth']

    def random_tile(
            self,
            image_hr,
            tile_temp=None,
            blur_mask=None,
            avg_depth_map=None,
            tile_cfg=None,
            process_num=4,):

        ## setting
        height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
        # sample process_num random patch locations; each patch gets its own
        # (h_start, w_start) pair
        h_start_list = [random.randint(0, tile_cfg['image_raw_shape'][0] - height - 1) for _ in range(process_num)]
        w_start_list = [random.randint(0, tile_cfg['image_raw_shape'][1] - width - 1) for _ in range(process_num)]

        ## prepare data
        imgs_crop = []
        bboxs = []
        for h_start, w_start in zip(h_start_list, w_start_list):
            crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
            crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0)  # resize to patch_process_shape
            bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
            imgs_crop.append(crop_image_resized)
            bboxs.append(bbox)

        imgs_crop = torch.stack(imgs_crop, dim=0)
        bboxs = torch.stack(bboxs, dim=0)
        imgs_crop = imgs_crop.to(image_hr.device)
        bboxs = bboxs.to(image_hr.device).int()

        bboxs_feat_factor = torch.tensor([
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
        bboxs_feat = bboxs * bboxs_feat_factor
        inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
        bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)
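        # bboxs_feat rows are roi_align-style boxes [batch_idx, x1, y1, x2, y2],
        # with pixel coordinates rescaled from image_raw_shape to
        # patch_process_shape coordinates.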

        if tile_temp is not None:
            coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)

        prediction_list = []
        if tile_temp is not None:
            coarse_temp_dict = {}
            for k, v in coarse_postprocess_dict.items():
                if k == 'coarse_feats_roi':
                    coarse_temp_dict[k] = [f for f in v]
                else:
                    coarse_temp_dict[k] = v
            # clone so that zeroing the batch index does not mutate bboxs_feat in place
            bbox_feat_forward = bboxs_feat.clone()
            bbox_feat_forward[:, 0] = 0
            prediction = self.infer_forward(imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict)
        else:
            prediction = self.infer_forward(imgs_crop)
        prediction_list.append(prediction)
        predictions = torch.cat(prediction_list, dim=0)
        predictions = F.interpolate(predictions, tile_cfg['patch_raw_shape'])  # back to raw patch resolution

        patch_select_idx = 0
        for h_start, w_start in zip(h_start_list, w_start_list):
            temp_depth = predictions[patch_select_idx]
            count_map = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            pred_depth = torch.zeros(tile_cfg['image_raw_shape'], device=temp_depth.device)
            count_map[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = blur_mask
            pred_depth[h_start: h_start+tile_cfg['patch_raw_shape'][0], w_start: w_start+tile_cfg['patch_raw_shape'][1]] = temp_depth * blur_mask
            avg_depth_map.update(pred_depth, count_map)
            patch_select_idx += 1

        return avg_depth_map

    def regular_tile(
            self,
            offset,
            offset_process,
            image_hr,
            init_flag=False,
            tile_temp=None,
            blur_mask=None,
            avg_depth_map=None,
            tile_cfg=None,
            process_num=4,):

        ## setting
        height, width = tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1]
        offset_h, offset_w = offset[0], offset[1]
        assert offset_w >= 0 and offset_h >= 0
        tile_num_h = (tile_cfg['image_raw_shape'][0] - offset_h) // height
        tile_num_w = (tile_cfg['image_raw_shape'][1] - offset_w) // width
        h_start_list = [height * h + offset_h for h in range(tile_num_h)]
        w_start_list = [width * w + offset_w for w in range(tile_num_w)]

        height_process, width_process = self.patch_process_shape[0], self.patch_process_shape[1]
        offset_h_process, offset_w_process = offset_process[0], offset_process[1]
        assert offset_h_process >= 0 and offset_w_process >= 0
        tile_num_h_process = (tile_cfg['patch_reensemble_shape'][0] - offset_h_process) // height_process
        tile_num_w_process = (tile_cfg['patch_reensemble_shape'][1] - offset_w_process) // width_process
        h_start_list_process = [height_process * h + offset_h_process for h in range(tile_num_h_process)]
        w_start_list_process = [width_process * w + offset_w_process for w in range(tile_num_w_process)]

        ## prepare data
        imgs_crop = []
        bboxs = []
        for h_start in h_start_list:
            for w_start in w_start_list:
                crop_image = image_hr[:, h_start: h_start+height, w_start: w_start+width]
                crop_image_resized = self.resizer(crop_image.unsqueeze(dim=0)).squeeze(dim=0)  # resize to patch_process_shape
                bbox = torch.tensor([w_start, h_start, w_start+width, h_start+height])
                imgs_crop.append(crop_image_resized)
                bboxs.append(bbox)

        imgs_crop = torch.stack(imgs_crop, dim=0)
        bboxs = torch.stack(bboxs, dim=0)
        imgs_crop = imgs_crop.to(image_hr.device)
        bboxs = bboxs.to(image_hr.device).int()

        bboxs = bboxs.squeeze()  # HACK: during inference, 1, 16, 4 -> 16, 4
        if len(bboxs.shape) == 1:
            bboxs = bboxs.unsqueeze(dim=0)

        bboxs_feat_factor = torch.tensor([
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
            1 / tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
            1 / tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
        bboxs_feat = bboxs * bboxs_feat_factor
        inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
        bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)
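        # Same roi_align-style [batch_idx, x1, y1, x2, y2] encoding as in
        # random_tile above.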

        # post_process
        if tile_temp is not None:
            coarse_postprocess_dict = self.coarse_postprocess_test(bboxs=bboxs, bboxs_feat=bboxs_feat, **tile_temp)

        count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)
        pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=image_hr.device)

        prediction_list = []
        split_rebatch_image = torch.split(imgs_crop, process_num, dim=0)
        for idx, rebatch_image in enumerate(split_rebatch_image):
            if tile_temp is not None:
                coarse_temp_dict = {}
                for k, v in coarse_postprocess_dict.items():
                    if k == 'coarse_feats_roi':
                        coarse_temp_dict[k] = [f[idx*process_num:(idx+1)*process_num, :, :, :] for f in v]
                    else:
                        coarse_temp_dict[k] = v[idx*process_num:(idx+1)*process_num, :, :, :]
                # clone so that zeroing the batch index does not mutate bboxs_feat in place
                bbox_feat_forward = bboxs_feat[idx*process_num:(idx+1)*process_num, :].clone()
                bbox_feat_forward[:, 0] = 0
                prediction = self.infer_forward(rebatch_image, bbox_feat_forward, tile_temp, coarse_temp_dict)
            else:
                prediction = self.infer_forward(rebatch_image)
            prediction_list.append(prediction)
        predictions = torch.cat(prediction_list, dim=0)

        patch_select_idx = 0
        for h_start in h_start_list_process:
            for w_start in w_start_list_process:
                temp_depth = predictions[patch_select_idx]
                if init_flag:
                    # first pass: accumulate all tiles into one shared map pair
                    count_map[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = blur_mask
                    pred_depth[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = temp_depth * blur_mask
                else:
                    # later passes: blend each tile into the running average
                    count_map = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                    pred_depth = torch.zeros(tile_cfg['patch_reensemble_shape'], device=temp_depth.device)
                    count_map[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = blur_mask
                    pred_depth[h_start: h_start+self.patch_process_shape[0], w_start: w_start+self.patch_process_shape[1]] = temp_depth * blur_mask
                    avg_depth_map.update(pred_depth, count_map)
                patch_select_idx += 1

        if init_flag:
            avg_depth_map = RunningAverageMap(pred_depth, count_map)
        return avg_depth_map
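
    # Each regular_tile pass lays a full grid of patch predictions over the
    # image; forward() below calls it several times with half-tile offsets
    # (e.g. offset=[270, 480] for the default 540x960 patches) so that the
    # patch seams of one grid fall in the interior of another and get blended
    # away by the blur masks.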

    def forward(
            self,
            mode,
            image_lr,
            image_hr,
            depth_gt,
            crop_depths=None,
            crops_image_hr=None,
            bboxs=None,
            tile_cfg=None,
            cai_mode='m1',
            process_num=4,
            **kwargs):

        if mode == 'train':
            loss_dict = {}
            if self.target == 'coarse':
                model_output_dict = self.coarse_branch(image_lr)
                depth_prediction = model_output_dict['metric_depth']
                loss_dict['coarse_loss'] = self.sigloss(depth_prediction, depth_gt, self.min_depth, self.max_depth)
                loss_dict['total_loss'] = loss_dict['coarse_loss']
                return loss_dict, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
            elif self.target == 'fine':
                model_output_dict = self.fine_branch(crops_image_hr)  # 1/2 res, 1/4 res, 1/8 res, 1/16 res
                depth_prediction = model_output_dict['metric_depth']
                loss_dict['fine_loss'] = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
                loss_dict['total_loss'] = loss_dict['fine_loss']
                return loss_dict, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}
            else:
                raise NotImplementedError

        else:
            if self.target == 'coarse':
                model_output_dict = self.coarse_branch(image_lr)
                depth_prediction = model_output_dict['metric_depth']

            elif self.target == 'fine':
                if tile_cfg is None:
                    tile_cfg = self.tile_cfg
                else:
                    tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])

                assert image_hr.shape[0] == 1, "patch-wise inference supports batch size 1 only"

                blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
                blur_mask = torch.tensor(blur_mask, device=image_hr.device)
                avg_depth_map = self.regular_tile(
                    offset=[0, 0],
                    offset_process=[0, 0],
                    image_hr=image_hr[0],
                    init_flag=True,
                    tile_temp=None,
                    blur_mask=blur_mask,
                    tile_cfg=tile_cfg,
                    process_num=process_num)
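
                # cai_mode selects the ensembling budget (semantics follow the
                # calls below): 'm1' keeps only the aligned pass above; 'm2'
                # adds three half-tile-shifted passes; 'rN' (e.g. 'r32') runs
                # the 'm2' passes and then N rounds of randomly placed tiles.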
                if cai_mode == 'm2' or cai_mode[0] == 'r':
                    avg_depth_map = self.regular_tile(
                        offset=[0, tile_cfg['patch_raw_shape'][1]//2],
                        offset_process=[0, self.patch_process_shape[1]//2],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                    avg_depth_map = self.regular_tile(
                        offset=[tile_cfg['patch_raw_shape'][0]//2, 0],
                        offset_process=[self.patch_process_shape[0]//2, 0],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                    avg_depth_map = self.regular_tile(
                        offset=[tile_cfg['patch_raw_shape'][0]//2, tile_cfg['patch_raw_shape'][1]//2],
                        offset_process=[self.patch_process_shape[0]//2, self.patch_process_shape[1]//2],
                        image_hr=image_hr[0], init_flag=False, tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

                if cai_mode[0] == 'r':
                    # switch to a blur mask at raw patch resolution and blend
                    # randomly placed patches on top of the regular grids
                    blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
                    blur_mask = torch.tensor(blur_mask, device=image_hr.device)
                    avg_depth_map.resize(tile_cfg['image_raw_shape'])
                    patch_num = int(cai_mode[1:])
                    for i in range(patch_num):
                        avg_depth_map = self.random_tile(
                            image_hr=image_hr[0], tile_temp=None, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

                depth = avg_depth_map.average_map
                depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)
                return depth, {}
            else:
                raise NotImplementedError

            return depth_prediction, {'rgb': image_lr, 'depth_pred': depth_prediction, 'depth_gt': depth_gt}
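
if __name__ == '__main__':
    # Minimal sketch of the tile-blending bookkeeping used above. The
    # _ToyRunningAverageMap below is a stand-in written for this demo on the
    # assumption that estimator.models.utils.RunningAverageMap keeps a
    # weighted prediction sum and a weight (count) sum; it is illustrative
    # only, not the project's actual implementation.
    class _ToyRunningAverageMap:
        def __init__(self, pred_map, count_map):
            self.pred_sum = pred_map.clone()
            self.count_sum = count_map.clone()

        def update(self, pred_map, count_map):
            # accumulate another masked tile prediction
            self.pred_sum = self.pred_sum + pred_map
            self.count_sum = self.count_sum + count_map

        @property
        def average_map(self):
            # weighted mean of everything blended in so far
            return self.pred_sum / self.count_sum.clamp(min=1e-6)

    # Two overlapping "tiles" on a 2x3 canvas with uniform blur masks: the
    # left tile predicts depth 1.0, the right tile 3.0, so the shared middle
    # column should average to 2.0.
    pred, count = torch.zeros(2, 3), torch.zeros(2, 3)
    pred[:, :2], count[:, :2] = 1.0, 1.0
    toy_map = _ToyRunningAverageMap(pred, count)

    pred2, count2 = torch.zeros(2, 3), torch.zeros(2, 3)
    pred2[:, 1:], count2[:, 1:] = 3.0, 1.0
    toy_map.update(pred2, count2)
    print(toy_map.average_map)  # columns: 1.0, 2.0, 3.0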