import os
import os.path as osp
import argparse

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader

from mmengine.utils import mkdir_or_exist
from mmengine.config import Config, DictAction
from mmengine.logging import MMLogger

from estimator.utils import RunnerInfo, setup_env, log_env, fix_random_seed
from estimator.models.builder import build_model
from estimator.datasets.builder import build_dataset
from estimator.trainer import Trainer


def parse_args():
    parser = argparse.ArgumentParser(description='Train an estimator')
    parser.add_argument('config', help='train config file path')
    parser.add_argument('--work-dir', help='the dir to save logs and models')
    parser.add_argument(
        '--resume',
        action='store_true',
        default=False,
        help='resume from the latest checkpoint in the work_dir automatically')
    parser.add_argument(
        '--debug',
        action='store_true',
        default=False,
        help='debug mode')
    parser.add_argument(
        '--log-name',
        type=str,
        default='',
        help='log name for wandb')
    parser.add_argument(
        '--tags',
        type=str,
        default='',
        help='comma-separated tags for wandb')
    parser.add_argument(
        '--amp',
        action='store_true',
        default=False,
        help='enable automatic-mixed-precision training')
    parser.add_argument(
        '--seed',
        type=int,
        default=621,
        help='random seed')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into the config file. If the value '
        'to be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # When using PyTorch >= 2.0.0, `torch.distributed.launch` passes the
    # `--local-rank` parameter to `tools/train.py` instead of `--local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    args = parser.parse_args()
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
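

# Example invocations (illustrative only; the config path and work dir below
# are placeholders, not files shipped with this script):
#   single GPU:
#     python tools/train.py configs/<your_config>.py \
#         --work-dir ./work_dirs --log-name exp1
#   multi-GPU on one node via torchrun (sets LOCAL_RANK per process,
#   requires --launcher pytorch):
#     torchrun --nproc_per_node=4 tools/train.py configs/<your_config>.py \
#         --work-dir ./work_dirs --log-name exp1 --launcher pytorch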


def main():
    args = parse_args()

    # if args.debug:
    #     torch.autograd.set_detect_anomaly(True)  # for debug

    # load config
    cfg = Config.fromfile(args.config)
    cfg.launcher = args.launcher
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    elif cfg.get('work_dir', None) is None:
        cfg.work_dir = osp.splitext(osp.basename(args.config))[0]
    cfg.work_dir = osp.join(cfg.work_dir, args.log_name)
    mkdir_or_exist(cfg.work_dir)

    cfg.debug = args.debug
    cfg.log_name = args.log_name
    # '--tags a,b' becomes ['a', 'b']; a single tag becomes a one-element list
    cfg.tags = args.tags.split(',')

    # fix seed
    seed = args.seed
    fix_random_seed(seed)

    # start dist training
    distributed = cfg.launcher != 'none'

    env_cfg = cfg.get('env_cfg', dict(dist_cfg=dict(backend='nccl')))
    rank, world_size, timestamp = setup_env(env_cfg, distributed, cfg.launcher)

    # prepare basic text logger
    log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
    log_cfg = dict(log_level='INFO', log_file=log_file)
    log_cfg.setdefault('name', timestamp)
    log_cfg.setdefault('logger_name', 'patchstitcher')
    # `torch.compile` in PyTorch 2.0 could close all user-defined handlers
    # unexpectedly. Using file mode 'a' can help prevent abnormal termination
    # of the FileHandler and ensure that the log file can be continuously
    # updated during the lifespan of the runner.
    log_cfg.setdefault('file_mode', 'a')
    logger = MMLogger.get_instance(**log_cfg)

    # save some information useful during the training
    runner_info = RunnerInfo()
    runner_info.config = cfg  # ideally, cfg should not be changed during training; temporary information should go in runner_info
    runner_info.logger = logger  # easier way: use print_log("infos", logger='current')
    runner_info.rank = rank
    runner_info.distributed = distributed
    runner_info.launcher = cfg.launcher
    runner_info.seed = seed
    runner_info.world_size = world_size
    runner_info.work_dir = cfg.work_dir
    runner_info.timestamp = timestamp

    # start wandb (rank 0 only, skipped in debug mode)
    if runner_info.rank == 0 and not cfg.debug:
        wandb.init(
            project=cfg.project,
            name=cfg.log_name + "_" + runner_info.timestamp,
            tags=cfg.tags,
            dir=runner_info.work_dir,
            config=cfg,
            settings=wandb.Settings(start_method="fork"))
        # log Train/* and Val/* metrics against their own step counters
        wandb.define_metric("Val/step")
        wandb.define_metric("Val/*", step_metric="Val/step")
        wandb.define_metric("Train/step")
        wandb.define_metric("Train/*", step_metric="Train/step")

    log_env(cfg, env_cfg, runner_info, logger)

    # resume training (future)
    cfg.resume = args.resume

    # build model
    model = build_model(cfg.model)
    if runner_info.distributed:
        torch.cuda.set_device(runner_info.rank)
        if cfg.get('convert_syncbn', False):
            model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = model.cuda(runner_info.rank)
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[runner_info.rank],
            output_device=runner_info.rank,
            find_unused_parameters=cfg.get('find_unused_parameters', False))
        logger.info(model)
    else:
        model = model.cuda(runner_info.rank)
        logger.info(model)

    # build dataloader
    dataset = build_dataset(cfg.train_dataloader.dataset)
    if runner_info.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        train_sampler = None
    train_dataloader = DataLoader(
        dataset,
        batch_size=cfg.train_dataloader.batch_size,
        shuffle=(train_sampler is None),
        num_workers=cfg.train_dataloader.num_workers,
        pin_memory=True,
        persistent_workers=True,
        sampler=train_sampler)
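    # NOTE: with a DistributedSampler, per-epoch reshuffling requires calling
    # train_sampler.set_epoch(epoch) every epoch; the sampler is handed to the
    # Trainer below, which is assumed to take care of that.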

    dataset = build_dataset(cfg.val_dataloader.dataset)
    if runner_info.distributed:
        val_sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=False)
    else:
        val_sampler = None
    val_dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=cfg.val_dataloader.num_workers,
        pin_memory=True,
        persistent_workers=True,
        sampler=val_sampler)

    # everything is ready, start training. But before that, save the config!
    cfg.dump(osp.join(cfg.work_dir, 'config.py'))

    # build trainer
    trainer = Trainer(
        config=cfg,
        runner_info=runner_info,
        train_sampler=train_sampler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        model=model)
    trainer.run()

    # only rank 0 started a wandb run, so only rank 0 needs to finish it
    if runner_info.rank == 0 and not cfg.debug:
        wandb.finish()


if __name__ == '__main__':
    main()