import os
import cv2
import wandb
import numpy as np
import torch
import mmengine
from mmengine.optim import build_optim_wrapper
import torch.optim as optim
import matplotlib.pyplot as plt
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu
from mmengine import print_log
from estimator.utils import colorize, colorize_infer_pfv1, colorize_rescale
import torch.nn.functional as F
from tqdm import tqdm
from mmengine.utils import mkdir_or_exist
import copy
from skimage import io
import kornia
from PIL import Image


class Tester:
    """
    Tester: runs inference over a dataloader, optionally saves colorized and
    raw 16-bit depth predictions, and collects evaluation metrics.
    """
    def __init__(
        self,
        config,
        runner_info,
        dataloader,
        model):

        self.config = config
        self.runner_info = runner_info
        self.dataloader = dataloader
        self.model = model

        # keys listed in config.collect_input_args are moved to GPU for inference
        self.collect_input_args = config.collect_input_args

    def collect_input(self, batch_data):
        # keep only the tensors the model expects and move them to GPU
        collect_batch_data = dict()
        for k, v in batch_data.items():
            if isinstance(v, torch.Tensor):
                if k in self.collect_input_args:
                    collect_batch_data[k] = v.cuda()
        return collect_batch_data
    def run(self, cai_mode='p16', process_num=4):
        results = []
        dataset = self.dataloader.dataset
        loader_indices = self.dataloader.batch_sampler

        rank, world_size = get_dist_info()

        if self.runner_info.rank == 0:
            prog_bar = mmengine.utils.ProgressBar(len(dataset))

        for idx, (batch_indices, batch_data) in enumerate(zip(loader_indices, self.dataloader)):

            batch_data_collect = self.collect_input(batch_data)
            result, log_dict = self.model(
                mode='infer',
                cai_mode=cai_mode,
                process_num=process_num,
                **batch_data_collect)  # might use test/val to split cases

            if self.runner_info.save:
                # colorized visualization (channels flipped RGB -> BGR for cv2.imwrite)
                color_pred = colorize(result, cmap='magma_r')[:, :, [2, 1, 0]]
                cv2.imwrite(
                    os.path.join(self.runner_info.work_dir, '{}.png'.format(batch_data['img_file_basename'][0])),
                    color_pred)

                # save raw depth as a 16-bit PNG (depth scaled by 256)
                raw_depth = Image.fromarray((result.clone().squeeze().detach().cpu().numpy() * 256).astype('uint16'))
                raw_depth.save(
                    os.path.join(self.runner_info.work_dir, '{}_uint16.png'.format(batch_data['img_file_basename'][0])))

            if batch_data_collect.get('depth_gt', None) is not None:
                metrics = dataset.get_metrics(
                    batch_data_collect['depth_gt'],
                    result,
                    seg_image=batch_data_collect.get('seg_image', None),
                    disp_gt_edges=batch_data.get('boundary', None),
                    image_hr=batch_data.get('image_hr', None))
                results.extend([metrics])

            if self.runner_info.rank == 0:
                # one progress update per sample across all ranks
                batch_size = len(result) * world_size
                for _ in range(batch_size):
                    prog_bar.update()

        if batch_data_collect.get('depth_gt', None) is not None:
            # gather per-sample metrics from all ranks and evaluate on rank 0
            results = collect_results_gpu(results, len(dataset))
            if self.runner_info.rank == 0:
                ret_dict = dataset.evaluate(results)
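
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The builder
# helpers and config path below are hypothetical stand-ins for whatever the
# surrounding project provides; `runner_info` is assumed to come from the
# launcher script and to carry the `rank`, `save`, and `work_dir` attributes
# used above.
#
#   from mmengine.config import Config
#   from torch.utils.data import DataLoader
#
#   cfg = Config.fromfile('configs/example_config.py')       # hypothetical path
#   model = build_model(cfg.model).cuda().eval()              # hypothetical builder
#   dataset = build_dataset(cfg.val_dataloader)               # hypothetical builder
#   loader = DataLoader(dataset, batch_size=1, shuffle=False)
#
#   tester = Tester(cfg, runner_info, loader, model)
#   tester.run(cai_mode='p16', process_num=4)
# ---------------------------------------------------------------------------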