import os
import wandb
import numpy as np
import torch
import mmengine
from mmengine.optim import build_optim_wrapper
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.distributed as dist
from mmengine.dist import get_dist_info, collect_results_cpu, collect_results_gpu
from mmengine import print_log
import torch.nn.functional as F
from tqdm import tqdm

from estimator.utils import colorize

class Trainer:
    """
    Training/validation loop for depth estimation: builds the mmengine optimizer
    wrapper and a OneCycleLR schedule, runs DDP-aware train/val epochs, and logs
    scalars and images to wandb.
    """
    def __init__(
        self,
        config,
        runner_info,
        train_sampler,
        train_dataloader,
        val_dataloader,
        model):

        self.config = config
        self.runner_info = runner_info
        self.train_sampler = train_sampler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.model = model

        # build optimizer and LR schedule
        self.optimizer_wrapper = build_optim_wrapper(self.model, config.optim_wrapper)
        self.scheduler = optim.lr_scheduler.OneCycleLR(
            self.optimizer_wrapper.optimizer,
            [group['lr'] for group in self.optimizer_wrapper.optimizer.param_groups],
            epochs=self.config.train_cfg.max_epochs,
            steps_per_epoch=len(self.train_dataloader),
            cycle_momentum=config.param_scheduler.cycle_momentum,
            base_momentum=config.param_scheduler.get('base_momentum', 0.85),
            max_momentum=config.param_scheduler.get('max_momentum', 0.95),
            div_factor=config.param_scheduler.div_factor,
            final_div_factor=config.param_scheduler.final_div_factor,
            pct_start=config.param_scheduler.pct_start,
            three_phase=config.param_scheduler.three_phase)

        # step counters used as the wandb log step
        self.train_step = 0  # for training
        self.val_step = 0  # for validation

        self.iters_per_train_epoch = len(self.train_dataloader)
        self.iters_per_val_epoch = len(self.val_dataloader)

        self.grad_scaler = torch.cuda.amp.GradScaler()
        self.collect_input_args = config.collect_input_args
        print_log('successfully initialized trainer', logger='current')
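
    # log_images: pushes qualitative results (RGB, colorized depth GT/prediction,
    # optional segmentation masks and pseudo GT) to wandb under the given prefix.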
    def log_images(self, log_dict, prefix="", scalar_cmap="turbo_r", min_depth=1e-3, max_depth=80, step=0):
        # Custom image logging: add more items to the log_dict returned from the model to log them here.
        wimages = dict()
        wimages['{}/step'.format(prefix)] = step

        rgb = log_dict.get('rgb')[0]
        _, h_rgb, w_rgb = rgb.shape

        if 'depth_pred' in log_dict.keys():
            depth_pred = log_dict.get('depth_pred')[0]
            depth_pred = depth_pred.squeeze()
            depth_gt = log_dict.get('depth_gt')[0]
            depth_gt = depth_gt.squeeze()

            invalid_mask = torch.logical_or(depth_gt <= min_depth, depth_gt >= max_depth).detach().cpu().squeeze().numpy()  # (h, w)
            if np.sum(np.logical_not(invalid_mask)) == 0:  # all pixels in gt are invalid
                return

            depth_gt_color = colorize(depth_gt, vmin=None, vmax=None, invalid_mask=invalid_mask, cmap=scalar_cmap)
            depth_pred_color = colorize(depth_pred, vmin=None, vmax=None)

            depth_gt_img = wandb.Image(depth_gt_color, caption='depth_gt')
            depth_pred_img = wandb.Image(depth_pred_color, caption='depth_pred')
            rgb = wandb.Image(rgb, caption='rgb')
            wimages['{}/LogImageDepth'.format(prefix)] = [rgb, depth_gt_img, depth_pred_img]

        if 'seg_pred' in log_dict.keys():
            seg_pred = log_dict.get('seg_pred')[0]
            seg_pred = seg_pred.squeeze()
            seg_gt = log_dict.get('seg_gt')[0]
            seg_gt = seg_gt.squeeze()

            # class_labels = {0: "good", 1: "refine", 2: "oor", 3: "sky"}
            class_labels = {0: "bg", 1: "edge"}
            mask_img = wandb.Image(
                rgb,
                masks={
                    "predictions": {"mask_data": seg_pred.detach().cpu().numpy(), "class_labels": class_labels},
                    "ground_truth": {"mask_data": seg_gt.detach().cpu().numpy(), "class_labels": class_labels},
                },
                caption='segmentation')
            wimages['{}/LogImageSeg'.format(prefix)] = [mask_img]

        if 'mask' in log_dict.keys():
            mask = log_dict.get('mask')[0]
            mask = mask.squeeze().float() * 255
            mask_img = wandb.Image(
                mask.detach().cpu().numpy(),
                caption='segmentation')
            # assumes the depth panel above was populated
            cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            cur_log.append(mask_img)
            wimages['{}/LogImageDepth'.format(prefix)] = cur_log

        # other optional items
        if 'pseudo_gt' in log_dict.keys():
            pseudo_gt = log_dict.get('pseudo_gt')[0]
            pseudo_gt = pseudo_gt.squeeze()
            pseudo_gt_color = colorize(pseudo_gt, vmin=None, vmax=None, cmap=scalar_cmap)
            pseudo_gt_img = wandb.Image(pseudo_gt_color, caption='pseudo_gt')
            cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            cur_log.append(pseudo_gt_img)
            # pseudo_gt = log_dict.get('pseudo_gt')[0][0]
            # pseudo_gt = pseudo_gt * 255
            # pseudo_gt = pseudo_gt.astype(np.uint8)
            # pseudo_gt_img = wandb.Image(pseudo_gt, caption='pseudo_gt')
            # cur_log = wimages['{}/LogImageDepth'.format(prefix)]
            # cur_log.append(pseudo_gt_img)

        wandb.log(wimages)
    def collect_input(self, batch_data):
        # keep only the whitelisted tensor inputs and move them to GPU
        collect_batch_data = dict()
        for k, v in batch_data.items():
            if isinstance(v, torch.Tensor):
                if k in self.collect_input_args:
                    collect_batch_data[k] = v.cuda()
        return collect_batch_data
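
    # val_epoch: each rank runs inference on its shard of the validation set,
    # per-sample metrics are gathered across ranks with collect_results_gpu,
    # and rank 0 aggregates them via dataset.evaluate and logs to wandb.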
    def val_epoch(self):
        results = []
        results_list = [[] for _ in range(8)]  # holds up to 8 result sets when the model returns a list

        self.model.eval()
        dataset = self.val_dataloader.dataset
        loader_indices = self.val_dataloader.batch_sampler

        rank, world_size = get_dist_info()
        if self.runner_info.rank == 0:
            prog_bar = mmengine.utils.ProgressBar(len(dataset))

        for idx, (batch_indices, batch_data) in enumerate(zip(loader_indices, self.val_dataloader)):
            self.val_step += 1

            batch_data_collect = self.collect_input(batch_data)
            # result, log_dict = self.model(mode='infer', **batch_data_collect)
            result, log_dict = self.model(mode='infer', cai_mode='m1', process_num=4, **batch_data_collect)  # might use test/val to split cases

            if isinstance(result, list):
                # in case the model returns multiple results
                for num_res in range(len(result)):
                    metrics = dataset.get_metrics(
                        batch_data_collect['depth_gt'],
                        result[num_res],
                        disp_gt_edges=batch_data.get('boundary', None),
                        additional_mask=log_dict.get('mask', None),
                        image_hr=batch_data.get('image_hr', None))
                    results_list[num_res].extend([metrics])
            else:
                metrics = dataset.get_metrics(
                    batch_data_collect['depth_gt'],
                    result,
                    seg_image=batch_data_collect.get('seg_image', None),
                    disp_gt_edges=batch_data.get('boundary', None),
                    additional_mask=log_dict.get('mask', None),
                    image_hr=batch_data.get('image_hr', None))
                results.extend([metrics])

            if self.runner_info.rank == 0:
                if isinstance(result, list):
                    batch_size = len(result[0]) * world_size
                else:
                    batch_size = len(result) * world_size
                for _ in range(batch_size):
                    prog_bar.update()

            if self.runner_info.rank == 0 and not self.config.debug and (idx + 1) % self.config.train_cfg.val_log_img_interval == 0:
                self.log_images(log_dict=log_dict, prefix="Val", min_depth=self.config.model.min_depth, max_depth=self.config.model.max_depth, step=self.val_step)

        # collect results from all ranks
        if isinstance(result, list):
            results_collect = []
            for res in results_list:
                res = collect_results_gpu(res, len(dataset))
                results_collect.append(res)
        else:
            results = collect_results_gpu(results, len(dataset))

        if self.runner_info.rank == 0:
            if isinstance(result, list):
                for num_refine in range(len(result)):
                    ret_dict = dataset.evaluate(results_collect[num_refine])
            else:
                ret_dict = dataset.evaluate(results)

        if self.runner_info.rank == 0 and not self.config.debug:
            wdict = dict()
            for k, v in ret_dict.items():
                wdict["Val/{}".format(k)] = v.item()
            wdict['Val/step'] = self.val_step
            wandb.log(wdict)

        torch.cuda.empty_cache()
        if self.runner_info.distributed:
            torch.distributed.barrier()

        self.model.train()  # restore train mode after validation
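
    # train_epoch: for every batch, move the whitelisted inputs to GPU, run the
    # model in 'train' mode, update parameters through the mmengine optim wrapper,
    # and step the OneCycleLR schedule once per iteration. Scalar losses go to
    # wandb, and validation can optionally run on an iteration basis ('iter_base').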
    def train_epoch(self, epoch_idx):
        self.model.train()
        if self.runner_info.distributed:
            dist.barrier()

        pbar = tqdm(enumerate(self.train_dataloader), desc=f"Epoch: {epoch_idx + 1}/{self.config.train_cfg.max_epochs}. Loop: Train",
                    total=self.iters_per_train_epoch) if self.runner_info.rank == 0 else enumerate(self.train_dataloader)

        for idx, batch_data in pbar:
            self.train_step += 1

            batch_data_collect = self.collect_input(batch_data)
            loss_dict, log_dict = self.model(mode='train', **batch_data_collect)

            total_loss = loss_dict['total_loss']
            # total_loss = self.grad_scaler.scale(loss_dict['total_loss'])
            self.optimizer_wrapper.update_params(total_loss)
            self.scheduler.step()

            # logging
            if self.runner_info.rank == 0:
                log_info = 'Epoch: [{:02d}/{:02d}] - Step: [{:05d}/{:05d}]'.format(epoch_idx + 1, self.config.train_cfg.max_epochs, idx + 1, len(self.train_dataloader))
                for k, v in loss_dict.items():
                    log_info += ' - {}: {:.2f}'.format(k, v.item())
                pbar.set_description(log_info)

                if (idx + 1) % self.config.train_cfg.log_interval == 0:
                    log_info = 'Epoch: [{:02d}/{:02d}] - Step: [{:05d}/{:05d}] - Total Loss: {:.4f}'.format(
                        epoch_idx + 1, self.config.train_cfg.max_epochs, idx + 1, len(self.train_dataloader), total_loss.item())
                    for k, v in loss_dict.items():
                        if k != 'total_loss':
                            log_info += ' - {}: {}'.format(k, v)
                    print_log(log_info, logger='current')

            if self.runner_info.rank == 0 and not self.config.debug:
                wdict = dict()
                wdict['Train/total_loss'] = total_loss.item()
                wdict['Train/LR'] = self.optimizer_wrapper.get_lr()['lr'][0]
                wdict['Train/momentum'] = self.optimizer_wrapper.get_momentum()['momentum'][0]
                wdict['Train/step'] = self.train_step
                for k, v in loss_dict.items():
                    if k != 'total_loss':
                        if isinstance(v, torch.Tensor):
                            wdict['Train/{}'.format(k)] = v.item()
                        else:
                            wdict['Train/{}'.format(k)] = v
                wandb.log(wdict)

            if self.runner_info.rank == 0 and not self.config.debug and (idx + 1) % self.config.train_cfg.train_log_img_interval == 0:
                self.log_images(log_dict=log_dict, prefix="Train", min_depth=self.config.model.min_depth, max_depth=self.config.model.max_depth, step=self.train_step)

            if self.config.train_cfg.val_type == 'iter_base':
                if (self.train_step + 1) % self.config.train_cfg.val_interval == 0 and (self.train_step + 1) >= self.config.train_cfg.get('eval_start', 0):
                    self.val_epoch()
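
    # save_checkpoint: rank 0 writes model/optimizer/scheduler state to work_dir.
    # Models exposing get_save_dict() can save a reduced state dict.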
    def save_checkpoint(self, epoch_idx):
        # By default the model is wrapped by DDP, so even on a single GPU please launch with dist_train.sh
        if hasattr(self.model.module, 'get_save_dict'):
            print_log("Saving checkpoint using the model's get_save_dict function to build model_dict", logger='current')
            # print_log('For saving space. Would you like to save base model several times? :>', logger='current')
            model_dict = self.model.module.get_save_dict()
        else:
            model_dict = self.model.module.state_dict()

        checkpoint_dict = {
            'epoch': epoch_idx,
            'model_state_dict': model_dict,
            'optim_state_dict': self.optimizer_wrapper.state_dict(),
            'schedule_state_dict': self.scheduler.state_dict()}

        if self.runner_info.rank == 0:
            torch.save(checkpoint_dict, os.path.join(self.runner_info.work_dir, 'checkpoint_{:02d}.pth'.format(epoch_idx + 1)))
            log_info = 'save checkpoint_{:02d}.pth at {}'.format(epoch_idx + 1, self.runner_info.work_dir)
            print_log(log_info, logger='current')
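
    # run: the top-level loop over epochs, handling distributed sampler epochs,
    # epoch-based validation, checkpointing, and an optional early stop.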
    def run(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print_log('training param: {}'.format(name), logger='current')

        # self.val_epoch()  # uncomment to debug the validation step before training

        for epoch_idx in range(self.config.train_cfg.max_epochs):
            if self.runner_info.distributed:
                self.train_sampler.set_epoch(epoch_idx)
            self.train_epoch(epoch_idx)

            if (epoch_idx + 1) % self.config.train_cfg.val_interval == 0 and (epoch_idx + 1) >= self.config.train_cfg.get('eval_start', 0) and self.config.train_cfg.val_type == 'epoch_base':
                self.val_epoch()

            if (epoch_idx + 1) % self.config.train_cfg.save_checkpoint_interval == 0:
                self.save_checkpoint(epoch_idx)

            if (epoch_idx + 1) % self.config.train_cfg.get('early_stop_epoch', 9999999) == 0:  # default effectively disables early stopping
                print_log('early stop at epoch: {}'.format(epoch_idx), logger='current')
                break

        if self.config.train_cfg.val_type == 'iter_base':
            self.val_epoch()
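
# A minimal usage sketch (hypothetical; the actual entry point builds the config,
# runner_info, dataloaders, and the DDP-wrapped model elsewhere, and is expected
# to be launched via dist_train.sh):
#
#   trainer = Trainer(config, runner_info, train_sampler,
#                     train_dataloader, val_dataloader, model)
#   trainer.run()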