Spaces:
Runtime error
Runtime error
| import time | |
| import torch | |
| from mmengine.dist import (broadcast, get_dist_info, init_dist, is_distributed, get_local_rank) | |
| from mmengine.utils.dl_utils import (set_multi_processing) | |
def _raise_open_file_limit(requested):
    """Raise the soft ``RLIMIT_NOFILE`` limit to ``requested``, never lowering it.

    The new soft limit is capped at the hard limit (unless the hard limit is
    unlimited). This is a no-op on platforms without the POSIX ``resource``
    module, e.g. Windows.
    """
    try:
        import resource  # POSIX-only stdlib module
    except ImportError:
        return
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    unlimited = resource.RLIM_INFINITY
    if soft == unlimited or soft >= requested:
        # Already at least as permissive as requested; never lower it.
        return
    new_soft = requested if hard == unlimited else min(requested, hard)
    if new_soft > soft:
        resource.setrlimit(resource.RLIMIT_NOFILE, (new_soft, hard))


def setup_env(env_cfg, distributed, launcher):
    """Set up the runtime environment and return rank, world size and timestamp.

    An example of ``env_cfg``::

        env_cfg = dict(
            cudnn_benchmark=True,
            mp_cfg=dict(
                mp_start_method='fork',
                opencv_num_threads=0
            ),
            dist_cfg=dict(backend='nccl', timeout=1800),
            resource_limit=4096
        )

    Args:
        env_cfg (dict | None): Config for setting environment. Keys read:
            ``cudnn_benchmark`` (truthy enables
            ``torch.backends.cudnn.benchmark``), ``mp_cfg`` (keyword
            arguments for ``set_multi_processing``), ``dist_cfg`` (keyword
            arguments for ``init_dist``) and ``resource_limit`` (minimum soft
            limit on open file descriptors; only applied when present and
            never lowers the current limit). ``None`` is treated as ``{}``.
        distributed (bool): Whether to run distributed. Passed to
            ``set_multi_processing``. When true and no process group is
            initialized yet, ``init_dist`` is called.
        launcher: Launcher identifier forwarded to ``init_dist`` as its
            first argument.

    Returns:
        tuple: ``(rank, world_size, timestamp)``. ``rank`` and ``world_size``
        come from ``get_dist_info()``. ``timestamp`` is a
        ``'%Y%m%d_%H%M%S'`` string in local time, taken on process 0 and
        broadcast so every process receives the same value.
    """
    if env_cfg is None:
        env_cfg = {}

    if env_cfg.get('cudnn_benchmark'):
        torch.backends.cudnn.benchmark = True

    mp_cfg: dict = env_cfg.get('mp_cfg', {})
    set_multi_processing(**mp_cfg, distributed=distributed)

    # Init the distributed env first, since the logger depends on the dist info.
    if distributed and not is_distributed():
        dist_cfg: dict = env_cfg.get('dist_cfg', {})
        init_dist(launcher, **dist_cfg)

    rank, world_size = get_dist_info()

    # Broadcast process 0's clock so all processes format the same timestamp.
    # float64 is required: float32 cannot represent current epoch seconds
    # exactly.
    timestamp = torch.tensor(time.time(), dtype=torch.float64)
    broadcast(timestamp)
    timestamp_str = time.strftime('%Y%m%d_%H%M%S',
                                  time.localtime(timestamp.item()))

    # Honor the documented ``resource_limit`` key (previously ignored).
    resource_limit = env_cfg.get('resource_limit')
    if resource_limit is not None:
        _raise_open_file_limit(int(resource_limit))

    return rank, world_size, timestamp_str