import warnings
warnings.filterwarnings("ignore")  # ignore all warnings
import diffusers.utils.logging as diffusion_logging
diffusion_logging.set_verbosity_error()  # ignore diffusers warnings

from typing import *

from torch.nn.parallel import DistributedDataParallel
from accelerate.optimizer import AcceleratedOptimizer
from accelerate.scheduler import AcceleratedScheduler
from accelerate.data_loader import DataLoaderShard

import os
import argparse
import types
import logging
import math
import gc

from tqdm import tqdm
import wandb
import torch
import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger as get_accelerate_logger
from accelerate import DataLoaderConfiguration, DeepSpeedPlugin
from lpips import LPIPS

from src.options import opt_dict
from src.data import GObjaverseParquetDataset, ParquetChunkDataSource, MultiEpochsChunkedDataLoader, yield_forever
from src.models import GSAutoencoderKL, GSRecon, get_optimizer, get_lr_scheduler
import src.utils.util as util
import src.utils.vis_util as vis_util

from extensions.diffusers_diffsplat import MyEMAModel


def main():
    PROJECT_NAME = "GSVAE"

    parser = argparse.ArgumentParser(
        description="Train an AutoencoderKL for 3DGS properties"
    )
    parser.add_argument(
        "--config_file",
        type=str,
        required=True,
        help="Path to the config file"
    )
    parser.add_argument(
        "--tag",
        type=str,
        required=True,
        help="Tag that refers to the current experiment"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="out",
        help="Path to the output directory"
    )
    parser.add_argument(
        "--hdfs_dir",
        type=str,
        default=None,
        help="Path to the HDFS directory to save checkpoints"
    )
    parser.add_argument(
        "--wandb_token_path",
        type=str,
        default="wandb/token",
        help="Path to the WandB login token"
    )
    parser.add_argument(
        "--resume_from_iter",
        type=int,
        default=None,
        help="The iteration to load the checkpoint from"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the PRNG"
    )
    parser.add_argument(
        "--offline_wandb",
        action="store_true",
        help="Use offline WandB for experiment tracking"
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="The max iteration step for training"
    )
    parser.add_argument(
        "--max_val_steps",
        type=int,
        default=5,
        help="The max iteration step for validation"
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=32,
        help="The number of worker processes spawned by the batch provider"
    )
    parser.add_argument(
        "--pin_memory",
        action="store_true",
        help="Pin memory for the data loader"
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Use EMA model for training"
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale lr with total batch size (base batch size: 256)"
    )
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=1.,
        help="Max gradient norm for gradient clipping"
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["no", "fp16", "bf16"],
        help="Type of mixed precision training"
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help="Enable TF32 for faster training on Ampere GPUs"
    )
    parser.add_argument(
        "--use_deepspeed",
        action="store_true",
        help="Use DeepSpeed for training"
    )
    parser.add_argument(
        "--zero_stage",
        type=int,
        default=1,
        choices=[1, 2, 3],  # https://huggingface.co/docs/accelerate/usage_guides/deepspeed
        help="ZeRO stage type for DeepSpeed"
    )
    parser.add_argument(
        "--load_pretrained_gsrecon",
        type=str,
        default="gsrecon_gobj265k_cnp_even4",
        help="Tag of a pretrained GSRecon in this project"
    )
    parser.add_argument(
        "--load_pretrained_gsrecon_ckpt",
        type=int,
        default=-1,
        help="Iteration of the pretrained GSRecon checkpoint"
    )
    parser.add_argument(
        "--load_pretrained_model",
        type=str,
        default=None,
        help="Tag of the model pretrained in this project"
    )
    parser.add_argument(
        "--load_pretrained_model_ckpt",
        type=int,
        default=-1,
        help="Iteration of the pretrained model checkpoint"
    )

    # Parse the arguments
    args, extras = parser.parse_known_args()

    # Parse the config file
    configs = util.get_configs(args.config_file, extras)  # change yaml configs by `extras`

    # Parse the option dict
    opt = opt_dict[configs["opt_type"]]
    if "opt" in configs:
        for k, v in configs["opt"].items():
            setattr(opt, k, v)
        opt.__post_init__()

    # Create an experiment directory using the `tag`
    exp_dir = os.path.join(args.output_dir, args.tag)
    ckpt_dir = os.path.join(exp_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    if args.hdfs_dir is not None:
        args.project_hdfs_dir = args.hdfs_dir
        args.hdfs_dir = os.path.join(args.hdfs_dir, args.tag)
        os.system(f"hdfs dfs -mkdir -p {args.hdfs_dir}")

    # Initialize the logger
    logging.basicConfig(
        format="%(asctime)s - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        level=logging.INFO
    )
    logger = get_accelerate_logger(__name__, log_level="INFO")
    file_handler = logging.FileHandler(os.path.join(exp_dir, "log.txt"))  # output to file
    file_handler.setFormatter(logging.Formatter(
        fmt="%(asctime)s - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S"
    ))
    logger.logger.addHandler(file_handler)
    logger.logger.propagate = True  # propagate to the root logger (console)

    # Set DeepSpeed config
    if args.use_deepspeed:
        deepspeed_plugin = DeepSpeedPlugin(
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            gradient_clipping=args.max_grad_norm,
            zero_stage=int(args.zero_stage),
            offload_optimizer_device="cpu",  # hard-coded here, TODO: make it configurable
        )
    else:
        deepspeed_plugin = None

    # Initialize the accelerator
    accelerator = Accelerator(
        project_dir=exp_dir,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        split_batches=False,  # batch size per GPU
        dataloader_config=DataLoaderConfiguration(non_blocking=args.pin_memory),
        deepspeed_plugin=deepspeed_plugin,
    )
    logger.info(f"Accelerator state:\n{accelerator.state}\n")

    # Set the random seed
    if args.seed >= 0:
        accelerate.utils.set_seed(args.seed)
        logger.info(f"You have chosen to seed([{args.seed}]) the experiment [{args.tag}]\n")

    # Enable TF32 for faster training on Ampere GPUs
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Prepare dataset
    if accelerator.is_local_main_process:
        if not os.path.exists("/tmp/test_dataset"):
            os.system(opt.dataset_setup_script)
    accelerator.wait_for_everyone()  # other processes wait for the main process

    # Load the training and validation dataset
    assert opt.file_dir_train is not None and opt.file_name_train is not None and \
        opt.file_dir_test is not None and opt.file_name_test is not None
    train_dataset = GObjaverseParquetDataset(
        data_source=ParquetChunkDataSource(opt.file_dir_train, opt.file_name_train),
        shuffle=True,
        shuffle_buffer_size=-1,  # `-1`: no actual shuffling
        chunks_queue_max_size=1,  # number of preloading chunks
        # GObjaverse
        opt=opt,
        training=True,
    )
    val_dataset = GObjaverseParquetDataset(
        data_source=ParquetChunkDataSource(opt.file_dir_test, opt.file_name_test),
        shuffle=True,  # shuffle for more varied visualization
        shuffle_buffer_size=-1,  # `-1`: no actual shuffling
        chunks_queue_max_size=1,  # number of preloading chunks
        # GObjaverse
        opt=opt,
        training=False,
    )

    train_loader = MultiEpochsChunkedDataLoader(
        train_dataset,
        batch_size=configs["train"]["batch_size_per_gpu"],
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=args.pin_memory,
    )
    val_loader = MultiEpochsChunkedDataLoader(
        val_dataset,
        batch_size=configs["val"]["batch_size_per_gpu"],
        num_workers=args.num_workers,
        drop_last=True,
        pin_memory=args.pin_memory,
    )
    logger.info(f"Load [{len(train_dataset)}] training samples and [{len(val_dataset)}] validation samples\n")

    # Compute the effective batch size and scale learning rate
    total_batch_size = configs["train"]["batch_size_per_gpu"] * \
        accelerator.num_processes * args.gradient_accumulation_steps
    configs["train"]["total_batch_size"] = total_batch_size
    if args.scale_lr:
        configs["optimizer"]["lr"] *= (total_batch_size / 256)
        configs["lr_scheduler"]["max_lr"] = configs["optimizer"]["lr"]
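        # Illustrative numbers (not from any config): with a base lr of 1e-4 defined for the reference
        # batch size of 256, running on 8 GPUs with batch_size_per_gpu=16 and gradient_accumulation_steps=2
        # gives total_batch_size = 16 * 8 * 2 = 256, so the lr stays 1e-4; doubling the per-GPU batch
        # size to 32 would scale it linearly to 2e-4.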

    # LPIPS loss
    if accelerator.is_main_process:
        _ = LPIPS(net="vgg")
        del _
    accelerator.wait_for_everyone()  # wait for pretrained backbone weights to be downloaded
    lpips_loss = LPIPS(net="vgg").to(accelerator.device)
    lpips_loss = lpips_loss.requires_grad_(False)
    lpips_loss.eval()

    # GSRecon
    gsrecon = GSRecon(opt).to(accelerator.device)
    gsrecon = gsrecon.requires_grad_(False)
    gsrecon = gsrecon.eval()

    # For DeepSpeed bug: model inputs could be `torch.nn.Module` (e.g., `lpips_loss`)
    def is_floating_point(self):
        return True
    lpips_loss.is_floating_point = types.MethodType(is_floating_point, lpips_loss)
    gsrecon.is_floating_point = types.MethodType(is_floating_point, gsrecon)
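    # Note: this patch presumably exists because DeepSpeed inspects forward inputs (e.g. via
    # `.is_floating_point()`) when deciding whether to cast them to the engine dtype; `lpips_loss` and
    # `gsrecon` are passed as forward inputs but are `nn.Module`s without that tensor method, so a dummy
    # method is attached to keep the check from failing.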

    # Initialize the model, optimizer and lr scheduler
    model = GSAutoencoderKL(opt)
    if opt.use_tiny_decoder:
        model.vae.requires_grad_(False)
    else:
        model.tiny_decoder.requires_grad_(False)
    # Materialize the trainable parameters in a list so they can be reused for gradient clipping later
    params_to_optimize = list(filter(lambda p: p.requires_grad, model.parameters()))
    optimizer = get_optimizer(params=params_to_optimize, **configs["optimizer"])

    configs["lr_scheduler"]["total_steps"] = configs["train"]["epochs"] * math.ceil(
        len(train_loader) // accelerator.num_processes / args.gradient_accumulation_steps)  # only count optimizer update steps
    configs["lr_scheduler"]["total_steps"] *= accelerator.num_processes  # for lr scheduler setting
    if "num_warmup_steps" in configs["lr_scheduler"]:
        configs["lr_scheduler"]["num_warmup_steps"] *= accelerator.num_processes  # for lr scheduler setting
    lr_scheduler = get_lr_scheduler(optimizer=optimizer, **configs["lr_scheduler"])
    configs["lr_scheduler"]["total_steps"] //= accelerator.num_processes  # reset for multi-gpu
    if "num_warmup_steps" in configs["lr_scheduler"]:
        configs["lr_scheduler"]["num_warmup_steps"] //= accelerator.num_processes  # reset for multi-gpu
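    # The `*=` / `//=` dance above accounts for Accelerate's `AcceleratedScheduler`: with
    # `split_batches=False` it steps the wrapped scheduler once per process on every optimizer update,
    # so the scheduler is built with `num_processes`-scaled step counts and the config values are then
    # restored for bookkeeping and logging.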

    # Load pretrained reconstruction and vae models
    if args.load_pretrained_model is not None:
        logger.info(f"Load GSVAE checkpoint [{args.load_pretrained_model}] from iteration [{args.load_pretrained_model_ckpt:06d}]\n")
        model = util.load_ckpt(
            os.path.join(args.output_dir, args.load_pretrained_model, "checkpoints"),
            args.load_pretrained_model_ckpt,
            None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_model),
            model, accelerator
        )
    logger.info(f"Load GSRecon checkpoint [{args.load_pretrained_gsrecon}] from iteration [{args.load_pretrained_gsrecon_ckpt:06d}]\n")
    gsrecon = util.load_ckpt(
        os.path.join(args.output_dir, args.load_pretrained_gsrecon, "checkpoints"),
        args.load_pretrained_gsrecon_ckpt,
        None if args.hdfs_dir is None else os.path.join(args.project_hdfs_dir, args.load_pretrained_gsrecon),
        gsrecon, accelerator
    )

    # Initialize the EMA model to save moving average states
    if args.use_ema:
        logger.info("Use exponential moving average (EMA) for model parameters\n")
        ema_states = MyEMAModel(
            model.parameters(),
            **configs["train"]["ema_kwargs"]
        )
        ema_states.to(accelerator.device)

    # Prepare everything with `accelerator`
    model, optimizer, lr_scheduler, train_loader, val_loader = accelerator.prepare(
        model, optimizer, lr_scheduler, train_loader, val_loader
    )
    # Set classes explicitly for everything
    model: DistributedDataParallel
    optimizer: AcceleratedOptimizer
    lr_scheduler: AcceleratedScheduler
    train_loader: DataLoaderShard
    val_loader: DataLoaderShard

    # Cast input dataset to the appropriate dtype
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
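    # Note: with Accelerate's native mixed precision the model weights themselves stay in fp32 and autocast
    # handles the compute dtype; `weight_dtype` is only passed to the forward calls below so that batch
    # tensors can be cast to the matching dtype (presumably handled inside the model's forward).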

    # Training configs after distribution and accumulation setup
    updated_steps_per_epoch = math.ceil(len(train_loader) / args.gradient_accumulation_steps)
    total_updated_steps = configs["lr_scheduler"]["total_steps"]
    if args.max_train_steps is None:
        args.max_train_steps = total_updated_steps
    assert configs["train"]["epochs"] * updated_steps_per_epoch == total_updated_steps
    logger.info(f"Total batch size: [{total_batch_size}]")
    logger.info(f"Learning rate: [{configs['optimizer']['lr']}]")
    logger.info(f"Gradient Accumulation steps: [{args.gradient_accumulation_steps}]")
    logger.info(f"Total epochs: [{configs['train']['epochs']}]")
    logger.info(f"Total steps: [{total_updated_steps}]")
    logger.info(f"Steps for updating per epoch: [{updated_steps_per_epoch}]")
    logger.info(f"Steps for validation: [{len(val_loader)}]\n")

    # (Optional) Load checkpoint
    global_update_step = 0
    if args.resume_from_iter is not None:
        logger.info(f"Load checkpoint from iteration [{args.resume_from_iter}]\n")

        # Download from HDFS
        if not os.path.exists(os.path.join(ckpt_dir, f'{args.resume_from_iter:06d}')):
            args.resume_from_iter = util.load_ckpt(
                ckpt_dir,
                args.resume_from_iter,
                args.hdfs_dir,
                None,  # `None`: not load model ckpt here
                accelerator,  # manage the process states
            )

        # Load everything
        accelerator.load_state(os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}"))  # torch < 2.4.0 here for `weights_only=False`
        if args.use_ema:
            ema_states.load_state_dict(torch.load(
                os.path.join(ckpt_dir, f"{args.resume_from_iter:06d}", "ema_states.pth"),
                map_location=accelerator.device
            ))
        global_update_step = int(args.resume_from_iter)
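        # `accelerator.load_state` restores the model, optimizer, lr scheduler and RNG states saved by
        # `save_state`; the EMA weights are not registered with Accelerate, so they are loaded separately
        # from `ema_states.pth` above.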

    # Save all experimental parameters and model architecture of this run to a file (args and configs)
    if accelerator.is_main_process:
        exp_params = util.save_experiment_params(args, configs, opt, exp_dir)
        util.save_model_architecture(accelerator.unwrap_model(model), exp_dir)

    # WandB logger
    if accelerator.is_main_process:
        if args.offline_wandb:
            os.environ["WANDB_MODE"] = "offline"
        with open(args.wandb_token_path, "r") as f:
            os.environ["WANDB_API_KEY"] = f.read().strip()
        wandb.init(
            project=PROJECT_NAME, name=args.tag,
            config=exp_params, dir=exp_dir,
            resume=True
        )
        # Wandb artifact for logging experiment information
        arti_exp_info = wandb.Artifact(args.tag, type="exp_info")
        arti_exp_info.add_file(os.path.join(exp_dir, "params.yaml"))
        arti_exp_info.add_file(os.path.join(exp_dir, "model.txt"))
        arti_exp_info.add_file(os.path.join(exp_dir, "log.txt"))  # only save the log before training
        wandb.log_artifact(arti_exp_info)

    # Start training
    logger.logger.propagate = False  # not propagate to the root logger (console)
    progress_bar = tqdm(
        range(total_updated_steps),
        initial=global_update_step,
        desc="Training",
        ncols=125,
        disable=not accelerator.is_main_process
    )

    for batch in yield_forever(train_loader):
        if global_update_step == args.max_train_steps:
            progress_bar.close()
            logger.logger.propagate = True  # propagate to the root logger (console)
            if accelerator.is_main_process:
                wandb.finish()
            logger.info("Training finished!\n")
            return

        model.train()

        with accelerator.accumulate(model):
            outputs = model(batch, lpips_loss, gsrecon, step=global_update_step+1, dtype=weight_dtype,
                use_tiny_decoder=opt.use_tiny_decoder)  # `step` starts from 1

            psnr = outputs["psnr"]
            ssim = outputs["ssim"]
            lpips = outputs["lpips"]
            kl = outputs["kl"]
            loss = outputs["loss"]
            if "coord_mse" in outputs:
                coord_mse = outputs["coord_mse"]
            else:
                coord_mse = None
            if "normal_cosim" in outputs:
                normal_cosim = outputs["normal_cosim"]
            else:
                normal_cosim = None

            # Backpropagate
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
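            # Under `accelerator.accumulate(model)`, gradient synchronization is skipped on accumulation
            # micro-steps and the wrapped optimizer/scheduler turn `step()` into a no-op until
            # `accelerator.sync_gradients` becomes True, so calling them on every micro-batch is safe.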

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            # Gather the losses across all processes for logging (if we use distributed training)
            psnr = accelerator.gather(psnr.detach()).mean()
            ssim = accelerator.gather(ssim.detach()).mean()
            lpips = accelerator.gather(lpips.detach()).mean()
            kl = accelerator.gather(kl.detach()).mean()
            loss = accelerator.gather(loss.detach()).mean()
            if coord_mse is not None:
                coord_mse = accelerator.gather(coord_mse.detach()).mean()
            if normal_cosim is not None:
                normal_cosim = accelerator.gather(normal_cosim.detach()).mean()

            logs = {
                "psnr": psnr.item(),
                "ssim": ssim.item(),
                "lpips": lpips.item(),
                "loss": loss.item(),
                "lr": lr_scheduler.get_last_lr()[0]
            }
            if args.use_ema:
                ema_states.step(model.parameters())
                logs.update({"ema": ema_states.cur_decay_value})
            progress_bar.set_postfix(**logs)
            progress_bar.update(1)
            global_update_step += 1

            logger.info(
                f"[{global_update_step:06d} / {total_updated_steps:06d}] " +
                f"psnr: {logs['psnr']:.4f}, ssim: {logs['ssim']:.4f}, lpips: {logs['lpips']:.4f}, " +
                f"loss: {logs['loss']:.4f}, lr: {logs['lr']:.2e}" +
                (f", ema: {logs['ema']:.4f}" if args.use_ema else "")  # parenthesized so the conditional only affects the EMA suffix
            )

            # Log the training progress
            if global_update_step % configs["train"]["log_freq"] == 0 or global_update_step == 1 \
                    or global_update_step % updated_steps_per_epoch == 0:  # last step of an epoch
                if accelerator.is_main_process:
                    wandb.log({
                        "training/psnr": psnr.item(),
                        "training/ssim": ssim.item(),
                        "training/lpips": lpips.item(),
                        "training/kl": kl.item(),
                        "training/loss": loss.item(),
                        "training/lr": lr_scheduler.get_last_lr()[0]
                    }, step=global_update_step)
                    if coord_mse is not None:
                        wandb.log({
                            "training/coord_mse": coord_mse.item()
                        }, step=global_update_step)
                    if normal_cosim is not None:
                        wandb.log({
                            "training/normal_cosim": normal_cosim.item()
                        }, step=global_update_step)
                    if args.use_ema:
                        wandb.log({
                            "training/ema": ema_states.cur_decay_value
                        }, step=global_update_step)

            # Save checkpoint
            if (global_update_step % configs["train"]["save_freq"] == 0  # 1. every `save_freq` steps
                    or global_update_step % (configs["train"]["save_freq_epoch"] * updated_steps_per_epoch) == 0  # 2. every `save_freq_epoch` epochs
                    or global_update_step == total_updated_steps):  # 3. the last training step
                gc.collect()
                if accelerator.distributed_type == accelerate.utils.DistributedType.DEEPSPEED:
                    # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues
                    accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
                elif accelerator.is_main_process:
                    accelerator.save_state(os.path.join(ckpt_dir, f"{global_update_step:06d}"))
                accelerator.wait_for_everyone()  # ensure all processes have finished saving
                if accelerator.is_main_process:
                    if args.use_ema:
                        torch.save(ema_states.state_dict(),
                                   os.path.join(ckpt_dir, f"{global_update_step:06d}", "ema_states.pth"))
                    if args.hdfs_dir is not None:
                        util.save_ckpt(ckpt_dir, global_update_step, args.hdfs_dir)
                gc.collect()

            # Evaluate on the validation set
            if (global_update_step == 1
                    or (global_update_step % configs["train"]["early_eval_freq"] == 0 and
                        global_update_step < configs["train"]["early_eval"])  # 1. more frequently at the beginning
                    or global_update_step % configs["train"]["eval_freq"] == 0  # 2. every `eval_freq` steps
                    or global_update_step % (configs["train"]["eval_freq_epoch"] * updated_steps_per_epoch) == 0  # 3. every `eval_freq_epoch` epochs
                    or global_update_step == total_updated_steps):  # 4. the last training step
                torch.cuda.empty_cache()
                gc.collect()

                # Use EMA parameters for evaluation
                if args.use_ema:
                    # Store the model parameters temporarily and load the EMA parameters to perform inference
                    ema_states.store(model.parameters())
                    ema_states.copy_to(model.parameters())

                with torch.no_grad():
                    with torch.autocast("cuda", torch.bfloat16):
                        model.eval()

                        all_val_matrics, val_steps = {}, 0
                        val_progress_bar = tqdm(
                            range(len(val_loader)) if args.max_val_steps is None
                            else range(args.max_val_steps),
                            desc="Validation",
                            ncols=125,
                            disable=not accelerator.is_main_process
                        )
                        for val_batch in val_loader:
                            val_outputs = model(val_batch, lpips_loss, gsrecon, step=global_update_step, dtype=weight_dtype,
                                use_tiny_decoder=opt.use_tiny_decoder)

                            val_psnr = val_outputs["psnr"]
                            val_ssim = val_outputs["ssim"]
                            val_lpips = val_outputs["lpips"]
                            val_kl = val_outputs["kl"]
                            val_loss = val_outputs["loss"]
                            if "coord_mse" in val_outputs:
                                val_coord_mse = val_outputs["coord_mse"]
                            else:
                                val_coord_mse = None
                            if "normal_cosim" in val_outputs:
                                val_normal_cosim = val_outputs["normal_cosim"]
                            else:
                                val_normal_cosim = None

                            val_psnr = accelerator.gather_for_metrics(val_psnr).mean()
                            val_ssim = accelerator.gather_for_metrics(val_ssim).mean()
                            val_lpips = accelerator.gather_for_metrics(val_lpips).mean()
                            val_kl = accelerator.gather_for_metrics(val_kl).mean()
                            val_loss = accelerator.gather_for_metrics(val_loss).mean()
                            if val_coord_mse is not None:
                                val_coord_mse = accelerator.gather_for_metrics(val_coord_mse).mean()
                            if val_normal_cosim is not None:
                                val_normal_cosim = accelerator.gather_for_metrics(val_normal_cosim).mean()

                            val_logs = {
                                "psnr": val_psnr.item(),
                                "ssim": val_ssim.item(),
                                "lpips": val_lpips.item(),
                                "loss": val_loss.item()
                            }
                            val_progress_bar.set_postfix(**val_logs)
                            val_progress_bar.update(1)
                            val_steps += 1

                            all_val_matrics.setdefault("psnr", []).append(val_psnr)
                            all_val_matrics.setdefault("ssim", []).append(val_ssim)
                            all_val_matrics.setdefault("lpips", []).append(val_lpips)
                            all_val_matrics.setdefault("kl", []).append(val_kl)
                            all_val_matrics.setdefault("loss", []).append(val_loss)
                            if val_coord_mse is not None:
                                all_val_matrics.setdefault("coord_mse", []).append(val_coord_mse)
                            if val_normal_cosim is not None:
                                all_val_matrics.setdefault("normal_cosim", []).append(val_normal_cosim)

                            if args.max_val_steps is not None and val_steps == args.max_val_steps:
                                break

                        torch.cuda.empty_cache()
                        gc.collect()
                        val_progress_bar.close()

                if args.use_ema:
                    # Switch back to the original model parameters
                    ema_states.restore(model.parameters())

                for k, v in all_val_matrics.items():
                    all_val_matrics[k] = torch.tensor(v).mean()
                logger.info(
                    f"Eval [{global_update_step:06d} / {total_updated_steps:06d}] " +
                    f"psnr: {all_val_matrics['psnr'].item():.4f}, " +
                    f"ssim: {all_val_matrics['ssim'].item():.4f}, " +
                    f"lpips: {all_val_matrics['lpips'].item():.4f}, " +
                    f"kl: {all_val_matrics['kl'].item():.4f}, " +
                    f"loss: {all_val_matrics['loss'].item():.4f}\n"
                )

                if accelerator.is_main_process:
                    wandb.log({
                        "validation/psnr": all_val_matrics["psnr"].item(),
                        "validation/ssim": all_val_matrics["ssim"].item(),
                        "validation/lpips": all_val_matrics["lpips"].item(),
                        "validation/kl": all_val_matrics["kl"].item(),
                        "validation/loss": all_val_matrics["loss"].item()
                    }, step=global_update_step)
                    if "coord_mse" in all_val_matrics:
                        wandb.log({
                            "validation/coord_mse": all_val_matrics["coord_mse"].item()
                        }, step=global_update_step)
                    if "normal_cosim" in all_val_matrics:
                        wandb.log({
                            "validation/normal_cosim": all_val_matrics["normal_cosim"].item()
                        }, step=global_update_step)

                    # Visualize rendering
                    wandb.log({
                        "images/training": vis_util.wandb_mvimage_log(outputs)
                    }, step=global_update_step)
                    wandb.log({
                        "images/validation": vis_util.wandb_mvimage_log(val_outputs)
                    }, step=global_update_step)

                torch.cuda.empty_cache()
                gc.collect()


if __name__ == "__main__":
    main()