# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Zhenyu Li
import itertools
import math
import copy
import torch
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from mmengine import print_log
from mmengine.config import ConfigDict
from torchvision.ops import roi_align as torch_roi_align
from huggingface_hub import PyTorchModelHubMixin
from transformers import PretrainedConfig

from estimator.registry import MODELS
from estimator.models import build_model
from estimator.models.baseline_pretrain import BaselinePretrain
from estimator.models.utils import generatemask

from zoedepth.models.zoedepth import ZoeDepth
from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed
from zoedepth.models.layers.dist_layers import ConditionalLogBinomial
from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, SeedBinRegressorUnnormed)
from zoedepth.models.base_models.midas import Resize as ResizeZoe
from depth_anything.transform import Resize as ResizeDA


class PatchFusion(BaselinePretrain, PyTorchModelHubMixin):
    def __init__(
        self,
        config):
        """PatchFusion model: a frozen coarse (whole-image) branch and a frozen fine (patch)
        branch, both ZoeDepth-style networks, combined by a trainable guided-fusion decoder.
        """
        nn.Module.__init__(self)
        if isinstance(config, ConfigDict):
            # convert a ConfigDict to a PretrainedConfig for hf saving
            config = PretrainedConfig.from_dict(config.to_dict())
            config.load_branch = True
        else:
            # used when loading patchfusion from hf model space
            config = PretrainedConfig.from_dict(ConfigDict(**config).to_dict())
            config.load_branch = False
            config.coarse_branch.pretrained_resource = None
            config.fine_branch.pretrained_resource = None

        self.config = config
        self.min_depth = config.min_depth
        self.max_depth = config.max_depth
        self.patch_process_shape = config.patch_process_shape
        self.tile_cfg = self.prepare_tile_cfg(config.image_raw_shape, config.patch_split_num)
        self.coarse_branch_cfg = config.coarse_branch
        if config.coarse_branch.type == 'ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeZoe(config.patch_process_shape[1], config.patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=32, resize_method="minimal")
        elif config.coarse_branch.type == 'DA-ZoeDepth':
            self.coarse_branch = ZoeDepth.build(**config.coarse_branch)
            self.resizer = ResizeDA(config.patch_process_shape[1], config.patch_process_shape[0], keep_aspect_ratio=False, ensure_multiple_of=14, resize_method="minimal")
        else:
            raise NotImplementedError

        if config.fine_branch.type == 'ZoeDepth':
            self.fine_branch = ZoeDepth.build(**config.fine_branch)
        elif config.fine_branch.type == 'DA-ZoeDepth':
            self.fine_branch = ZoeDepth.build(**config.fine_branch)
        else:
            raise NotImplementedError
        if config.load_branch:
            print_log("Loading coarse_branch from {}".format(config.pretrain_model[0]), logger='current')
            print_log(self.coarse_branch.load_state_dict(torch.load(config.pretrain_model[0], map_location='cpu')['model_state_dict'], strict=True), logger='current') # coarse ckp
            print_log("Loading fine_branch from {}".format(config.pretrain_model[1]), logger='current')
            print_log(self.fine_branch.load_state_dict(torch.load(config.pretrain_model[1], map_location='cpu')['model_state_dict'], strict=True), logger='current')

        # freeze all these parameters
        for param in self.coarse_branch.parameters():
            param.requires_grad = False
        for param in self.fine_branch.parameters():
            param.requires_grad = False

        self.sigloss = build_model(config.sigloss)
        N_MIDAS_OUT = 32
        btlnck_features = self.fine_branch.core.output_channels[0]

        self.fusion_conv_list = nn.ModuleList()
        for i in range(6):
            if i == 5:
                layer = nn.Conv2d(N_MIDAS_OUT * 2, N_MIDAS_OUT, 3, 1, 1)
            else:
                layer = nn.Conv2d(btlnck_features * 2, btlnck_features, 3, 1, 1)
            self.fusion_conv_list.append(layer)

        self.guided_fusion = build_model(config.guided_fusion)
        # NOTE: a decoder head
        if self.coarse_branch_cfg.bin_centers_type == "normed":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayer
        elif self.coarse_branch_cfg.bin_centers_type == "softplus": # default
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayerUnnormed
        elif self.coarse_branch_cfg.bin_centers_type == "hybrid1":
            SeedBinRegressorLayer = SeedBinRegressor
            Attractor = AttractorLayerUnnormed
        elif self.coarse_branch_cfg.bin_centers_type == "hybrid2":
            SeedBinRegressorLayer = SeedBinRegressorUnnormed
            Attractor = AttractorLayer
        else:
            raise ValueError(
                "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'")

        N_MIDAS_OUT = 32
        btlnck_features = self.fine_branch.core.output_channels[0]
        num_out_features = self.fine_branch.core.output_channels[1:] # all of them are the same

        self.seed_bin_regressor = SeedBinRegressorLayer(
            btlnck_features, n_bins=self.coarse_branch_cfg.n_bins, min_depth=config.min_depth, max_depth=config.max_depth)
        self.seed_projector = Projector(btlnck_features, self.coarse_branch_cfg.bin_embedding_dim)
        self.projectors = nn.ModuleList([
            Projector(num_out, self.coarse_branch_cfg.bin_embedding_dim)
            for num_out in num_out_features
        ])
        # 1000, 2, inv, mean
        self.attractors = nn.ModuleList([
            Attractor(self.coarse_branch_cfg.bin_embedding_dim, self.coarse_branch_cfg.n_bins, n_attractors=self.coarse_branch_cfg.n_attractors[i], min_depth=config.min_depth, max_depth=config.max_depth,
                      alpha=self.coarse_branch_cfg.attractor_alpha, gamma=self.coarse_branch_cfg.attractor_gamma, kind=self.coarse_branch_cfg.attractor_kind, attractor_type=self.coarse_branch_cfg.attractor_type)
            for i in range(len(num_out_features))
        ])

        last_in = N_MIDAS_OUT + 1 # +1 for relative depth

        # use log binomial instead of softmax
        self.conditional_log_binomial = ConditionalLogBinomial(
            last_in, self.coarse_branch_cfg.bin_embedding_dim, n_classes=self.coarse_branch_cfg.n_bins, min_temp=self.coarse_branch_cfg.min_temp, max_temp=self.coarse_branch_cfg.max_temp)

        # NOTE: consistency training
        self.consistency_training = False

    def load_dict(self, dict):
        return self.load_state_dict(dict, strict=False)

    def get_save_dict(self):
        current_model_dict = self.state_dict()
        save_state_dict = {}
        for k, v in current_model_dict.items():
            if 'coarse_branch' in k or 'fine_branch' in k:
                pass
            else:
                save_state_dict[k] = v
        return save_state_dict
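
    # NOTE (added comment, not in the original file): get_save_dict filters out every
    # 'coarse_branch' / 'fine_branch' parameter, so checkpoints produced through it only hold
    # the fusion-specific weights; the two frozen branches are expected to be restored from
    # their own pretrained checkpoints (config.pretrain_model) when load_branch is True.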

    def coarse_forward(self, image_lr):
        with torch.no_grad():
            if self.coarse_branch.training:
                self.coarse_branch.eval()

            deep_model_output_dict = self.coarse_branch(image_lr, return_final_centers=True)
            deep_features = deep_model_output_dict['temp_features'] # x_d0 1/128, x_blocks_feat_0 1/64, x_blocks_feat_1 1/32, x_blocks_feat_2 1/16, x_blocks_feat_3 1/8, midas_final_feat 1/4 [based on 384x4, 512x4]
            coarse_prediction = deep_model_output_dict['metric_depth']

            coarse_features = [
                deep_features['x_d0'],
                deep_features['x_blocks_feat_0'],
                deep_features['x_blocks_feat_1'],
                deep_features['x_blocks_feat_2'],
                deep_features['x_blocks_feat_3'],
                deep_features['midas_final_feat']] # bs, c, h, w

            return coarse_prediction, coarse_features

    def fine_forward(self, image_hr_crop):
        with torch.no_grad():
            if self.fine_branch.training:
                self.fine_branch.eval()

            deep_model_output_dict = self.fine_branch(image_hr_crop, return_final_centers=True)
            deep_features = deep_model_output_dict['temp_features'] # x_d0 1/128, x_blocks_feat_0 1/64, x_blocks_feat_1 1/32, x_blocks_feat_2 1/16, x_blocks_feat_3 1/8, midas_final_feat 1/4 [based on 384x4, 512x4]
            fine_prediction = deep_model_output_dict['metric_depth']

            fine_features = [
                deep_features['x_d0'],
                deep_features['x_blocks_feat_0'],
                deep_features['x_blocks_feat_1'],
                deep_features['x_blocks_feat_2'],
                deep_features['x_blocks_feat_3'],
                deep_features['midas_final_feat']] # bs, c, h, w

            return fine_prediction, fine_features

    def coarse_postprocess_train(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        coarse_features_patch_area = []
        for idx, feat in enumerate(coarse_features):
            bs, _, h, w = feat.shape
            cur_lvl_feat = torch_roi_align(feat, bboxs_feat, (h, w), h/self.patch_process_shape[0], aligned=True)
            coarse_features_patch_area.append(cur_lvl_feat)

        coarse_prediction_roi = torch_roi_align(coarse_prediction, bboxs_feat, coarse_prediction.shape[-2:], coarse_prediction.shape[-2]/self.patch_process_shape[0], aligned=True)

        return coarse_prediction_roi, coarse_features_patch_area

    def coarse_postprocess_test(self, coarse_prediction, coarse_features, bboxs, bboxs_feat):
        patch_num = bboxs_feat.shape[0]

        coarse_features_patch_area = []
        for idx, feat in enumerate(coarse_features):
            bs, _, h, w = feat.shape
            feat_extend = feat.repeat(patch_num, 1, 1, 1)
            cur_lvl_feat = torch_roi_align(feat_extend, bboxs_feat, (h, w), h/self.patch_process_shape[0], aligned=True)
            coarse_features_patch_area.append(cur_lvl_feat)

        coarse_prediction = coarse_prediction.repeat(patch_num, 1, 1, 1)
        coarse_prediction_roi = torch_roi_align(coarse_prediction, bboxs_feat, coarse_prediction.shape[-2:], coarse_prediction.shape[-2]/self.patch_process_shape[0], aligned=True)

        return_dict = {
            'coarse_depth_roi': coarse_prediction_roi,
            'coarse_feats_roi': coarse_features_patch_area}

        return return_dict
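
    # NOTE (added comment, not in the original file): torchvision's roi_align takes boxes as a
    # (K, 5) tensor of [batch_index, x1, y1, x2, y2]. bboxs_feat is built in forward() in
    # patch_process_shape pixel coordinates, and spatial_scale = h / patch_process_shape[0]
    # rescales the boxes onto each (h, w) feature map. Purely for illustration, assuming a
    # 2160x3840 raw image, a (4, 4) patch split and a 384x512 process shape:
    #   raw-image patch [x1, y1, x2, y2] = [0, 0, 960, 540]
    #   -> process coords [0, 0, 960 * 512/3840, 540 * 384/2160] = [0, 0, 128, 96]
    #   -> feature coords: scaled again by h / 384 inside roi_align for each feature level.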

    def fusion_forward(self, fine_depth_pred, crop_input, coarse_model_midas_enc_feats, fine_model_midas_enc_feats, bbox_feat, coarse_depth_roi=None, coarse_feats_roi=None):
        feat_cat_list = []
        feat_plus_list = []
        for l_i, (f_ca, f_c_roi, f_f) in enumerate(zip(coarse_model_midas_enc_feats, coarse_feats_roi, fine_model_midas_enc_feats)):
            feat_cat = self.fusion_conv_list[l_i](torch.cat([f_c_roi, f_f], dim=1))
            feat_plus = f_c_roi + f_f
            feat_cat_list.append(feat_cat)
            feat_plus_list.append(feat_plus)

        input_tensor = torch.cat([coarse_depth_roi, fine_depth_pred, crop_input], dim=1)
        # HACK: hack for depth-anything
        # if self.coarse_branch_cfg.type == 'DA-ZoeDepth':
        #     input_tensor = F.interpolate(input_tensor, size=(448, 592), mode='bilinear', align_corners=True)

        output = self.guided_fusion(
            input_tensor=input_tensor,
            guide_plus=feat_plus_list,
            guide_cat=feat_cat_list,
            bbox=bbox_feat,
            fine_feat_crop=fine_model_midas_enc_feats,
            coarse_feat_whole=coarse_model_midas_enc_feats,
            coarse_feat_crop=coarse_feats_roi,
            coarse_feat_whole_hack=None)[::-1] # low -> high

        x_blocks = output
        x = x_blocks[0]
        x_blocks = x_blocks[1:]

        proj_feat_list = []
        if self.consistency_training:
            if self.consistency_target == 'unet_feat':
                proj_feat_list = []
                for idx, feat in enumerate(output):
                    proj_feat = self.consistency_projs[idx](feat)
                    proj_feat_list.append(proj_feat)

        # NOTE: below is ZoeDepth implementation
        last = x_blocks[-1] # have already been fused in x_blocks
        bs, c, h, w = last.shape
        rel_cond = torch.zeros((bs, 1, h, w), device=last.device)

        _, seed_b_centers = self.seed_bin_regressor(x)
        if self.coarse_branch_cfg.bin_centers_type == 'normed' or self.coarse_branch_cfg.bin_centers_type == 'hybrid2':
            b_prev = (seed_b_centers - self.min_depth) / \
                (self.max_depth - self.min_depth)
        else:
            b_prev = seed_b_centers
        prev_b_embedding = self.seed_projector(x)

        # unroll this loop for better performance
        for idx, (projector, attractor, x) in enumerate(zip(self.projectors, self.attractors, x_blocks)):
            b_embedding = projector(x)
            b, b_centers = attractor(
                b_embedding, b_prev, prev_b_embedding, interpolate=True)
            b_prev = b.clone()
            prev_b_embedding = b_embedding.clone()

        if self.consistency_training:
            if self.consistency_target == 'final_feat':
                proj_feat_1 = self.consistency_projs[0](b_centers)
                proj_feat_2 = self.consistency_projs[1](last)
                proj_feat_3 = self.consistency_projs[2](b_embedding)
                proj_feat_list = [proj_feat_1, proj_feat_2, proj_feat_3]

        rel_cond = nn.functional.interpolate(
            rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True)
        last = torch.cat([last, rel_cond], dim=1) # + self.coarse_depth_proj(whole_depth_roi_pred) + self.fine_depth_proj(fine_depth_pred)

        b_embedding = nn.functional.interpolate(
            b_embedding, last.shape[-2:], mode='bilinear', align_corners=True)

        # till here, we have features (attached with a relative depth prediction) and embeddings
        # post process
        # final_pred = out * self.blur_mask + whole_depth_roi_pred * (1-self.blur_mask)
        # out = F.interpolate(out, (540, 960), mode='bilinear', align_corners=True)

        x = self.conditional_log_binomial(last, b_embedding)
        b_centers = nn.functional.interpolate(
            b_centers, x.shape[-2:], mode='bilinear', align_corners=True)
        out = torch.sum(x * b_centers, dim=1, keepdim=True)

        return out, proj_feat_list
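
    # NOTE (added comment, not in the original file): the tail of fusion_forward follows the
    # ZoeDepth decoding scheme: seed_bin_regressor proposes initial depth-bin centers from the
    # lowest-resolution fused feature, each attractor stage refines those centers with the next
    # feature level, and conditional_log_binomial maps the last feature (concatenated with a
    # zero relative-depth channel here) to per-pixel probabilities over the n_bins centers.
    # The metric depth is then the per-pixel expectation over the bins,
    #   depth = sum_k prob_k * center_k,
    # which is the final torch.sum(x * b_centers, dim=1, keepdim=True) above.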

    def infer_forward(self, imgs_crop, bbox_feat_forward, tile_temp, coarse_temp_dict):
        fine_prediction, fine_features = self.fine_forward(imgs_crop)

        depth_prediction, consistency_target = \
            self.fusion_forward(
                fine_prediction,
                imgs_crop,
                tile_temp['coarse_features'],
                fine_features,
                bbox_feat_forward,
                **coarse_temp_dict)

        return depth_prediction

    def forward(
        self,
        mode,
        image_lr,
        image_hr,
        depth_gt=None,
        crops_image_hr=None,
        crop_depths=None,
        bboxs=None,
        tile_cfg=None,
        cai_mode='m1',
        process_num=4):

        if mode == 'train':
            bboxs_feat_factor = torch.tensor([
                1 / self.tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
                1 / self.tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0],
                1 / self.tile_cfg['image_raw_shape'][1] * self.patch_process_shape[1],
                1 / self.tile_cfg['image_raw_shape'][0] * self.patch_process_shape[0]], device=bboxs.device).unsqueeze(dim=0)
            bboxs_feat = bboxs * bboxs_feat_factor
            inds = torch.arange(bboxs.shape[0]).to(bboxs.device).unsqueeze(dim=-1)
            bboxs_feat = torch.cat((inds, bboxs_feat), dim=-1)

            coarse_prediction, coarse_features = self.coarse_forward(image_lr)
            fine_prediction, fine_features = self.fine_forward(crops_image_hr)

            coarse_prediction_roi, coarse_features_patch_area = self.coarse_postprocess_train(coarse_prediction, coarse_features, bboxs, bboxs_feat)

            depth_prediction, consistency_target = self.fusion_forward(
                fine_prediction,
                crops_image_hr,
                coarse_features,
                fine_features,
                bboxs_feat,
                coarse_depth_roi=coarse_prediction_roi,
                coarse_feats_roi=coarse_features_patch_area)

            loss_dict = {}
            loss_dict['sig_loss'] = self.sigloss(depth_prediction, crop_depths, self.min_depth, self.max_depth)
            loss_dict['total_loss'] = loss_dict['sig_loss']

            return loss_dict, {'rgb': crops_image_hr, 'depth_pred': depth_prediction, 'depth_gt': crop_depths}
        else:
            if tile_cfg is None:
                tile_cfg = self.tile_cfg
            else:
                tile_cfg = self.prepare_tile_cfg(tile_cfg['image_raw_shape'], tile_cfg['patch_split_num'])

            assert image_hr.shape[0] == 1

            coarse_prediction, coarse_features = self.coarse_forward(image_lr)

            tile_temp = {
                'coarse_prediction': coarse_prediction,
                'coarse_features': coarse_features,}

            blur_mask = generatemask((self.patch_process_shape[0], self.patch_process_shape[1])) + 1e-3
            blur_mask = torch.tensor(blur_mask, device=image_hr.device)

            avg_depth_map = self.regular_tile(
                offset=[0, 0],
                offset_process=[0, 0],
                image_hr=image_hr[0],
                init_flag=True,
                tile_temp=tile_temp,
                blur_mask=blur_mask,
                tile_cfg=tile_cfg,
                process_num=process_num)

            if cai_mode == 'm2' or cai_mode[0] == 'r':
                avg_depth_map = self.regular_tile(
                    offset=[0, tile_cfg['patch_raw_shape'][1]//2],
                    offset_process=[0, self.patch_process_shape[1]//2],
                    image_hr=image_hr[0], init_flag=False, tile_temp=tile_temp, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                avg_depth_map = self.regular_tile(
                    offset=[tile_cfg['patch_raw_shape'][0]//2, 0],
                    offset_process=[self.patch_process_shape[0]//2, 0],
                    image_hr=image_hr[0], init_flag=False, tile_temp=tile_temp, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)
                avg_depth_map = self.regular_tile(
                    offset=[tile_cfg['patch_raw_shape'][0]//2, tile_cfg['patch_raw_shape'][1]//2],
                    offset_process=[self.patch_process_shape[0]//2, self.patch_process_shape[1]//2],
                    init_flag=False, image_hr=image_hr[0], tile_temp=tile_temp, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

            if cai_mode[0] == 'r':
                blur_mask = generatemask((tile_cfg['patch_raw_shape'][0], tile_cfg['patch_raw_shape'][1])) + 1e-3
                blur_mask = torch.tensor(blur_mask, device=image_hr.device)
                avg_depth_map.resize(tile_cfg['image_raw_shape'])
                patch_num = int(cai_mode[1:]) // process_num
                for i in range(patch_num):
                    avg_depth_map = self.random_tile(
                        image_hr=image_hr[0], tile_temp=tile_temp, blur_mask=blur_mask, avg_depth_map=avg_depth_map, tile_cfg=tile_cfg, process_num=process_num)

            depth = avg_depth_map.average_map
            depth = depth.unsqueeze(dim=0).unsqueeze(dim=0)

            return depth, {'rgb': image_lr, 'depth_pred': depth, 'depth_gt': depth_gt}
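

# NOTE (added example, not part of the original file): a minimal inference sketch.
# It assumes a PatchFusion checkpoint published on the Hugging Face Hub (HUB_ID below is a
# placeholder, replace it with a real repo id) and that image_hr matches the configured
# image_raw_shape. Real usage should also apply the dataset normalization and self.resizer
# instead of the naive interpolation used here.
if __name__ == '__main__':
    HUB_ID = 'Zhyever/patchfusion_depth_anything_vitl14'  # placeholder repo id, adjust to an actual checkpoint

    model = PatchFusion.from_pretrained(HUB_ID).eval()

    raw_h, raw_w = model.tile_cfg['image_raw_shape']
    proc_h, proc_w = model.patch_process_shape

    # dummy inputs: one high-resolution RGB image plus its resized low-resolution copy
    image_hr = torch.rand(1, 3, raw_h, raw_w)
    image_lr = F.interpolate(image_hr, size=(proc_h, proc_w), mode='bilinear', align_corners=True)

    with torch.no_grad():
        depth, _ = model(mode='infer', image_lr=image_lr, image_hr=image_hr, cai_mode='m1', process_num=4)
    print('predicted depth map:', depth.shape)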