Spaces:
Running
on
Zero
Running
on
Zero
| from typing import * | |
| from torch.nn import Parameter | |
| from torch.optim import Optimizer | |
| from torch.optim.lr_scheduler import LRScheduler | |
| from torch import optim | |
| from torch.optim import lr_scheduler | |
| from diffusers.optimization import get_scheduler | |
| from src.models.elevest import ElevEst | |
| from src.models.gsrecon import GSRecon | |
| from src.models.gsvae import GSAutoencoderKL | |
def get_optimizer(
    name: str,
    params: Union[Iterable[Parameter], Iterable[Dict[str, Any]]],
    **kwargs,
) -> Optimizer:
    """Construct an optimizer selected by name.

    Args:
        name: Optimizer identifier. Only ``"adamw"`` is currently supported.
        params: Parameters to optimize. This is either an iterable of
            parameters or an iterable of parameter-group dicts, the same
            input ``torch.optim`` optimizers accept. A single ``Parameter``
            is not valid here; torch rejects a bare tensor.
        **kwargs: Forwarded unchanged to the optimizer constructor
            (e.g. ``lr``, ``weight_decay``).

    Returns:
        The constructed ``torch.optim.AdamW`` instance.

    Raises:
        NotImplementedError: If ``name`` is not a supported optimizer.
    """
    if name == "adamw":
        return optim.AdamW(params=params, **kwargs)
    raise NotImplementedError(f"Not implemented optimizer: {name}")
def get_lr_scheduler(name: str, optimizer: Optimizer, **kwargs) -> LRScheduler:
    """Return the learning-rate scheduler registered under ``name``.

    Each name reads only its own keys from ``kwargs``:

    * ``"one_cycle"``: reads ``max_lr``, ``total_steps`` and ``pct_start``.
      Builds ``torch.optim.lr_scheduler.OneCycleLR``.
    * ``"cosine_warmup"``: reads ``num_warmup_steps`` and ``total_steps``.
      Delegates to diffusers ``get_scheduler("cosine", ...)``.
    * ``"constant_warmup"``: reads ``num_warmup_steps`` and ``total_steps``.
      Delegates to diffusers ``get_scheduler("constant_with_warmup", ...)``.
    * ``"constant"``: reads nothing. Builds a ``LambdaLR`` whose multiplier
      is always 1.
    * ``"linear_decay"``: reads ``total_epochs``. Builds a ``LambdaLR`` whose
      multiplier is ``max(0, 1 - epoch / total_epochs)``. The multiplier
      falls linearly to 0 and is clamped there.

    Args:
        name: Scheduler identifier (see above).
        optimizer: Optimizer whose learning rate is scheduled.
        **kwargs: Scheduler settings, read as listed above.

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: If ``name`` is not recognised.
        KeyError: If a key that ``name`` reads is missing from ``kwargs``.
    """
    if name == "one_cycle":
        return lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=kwargs["max_lr"],
            total_steps=kwargs["total_steps"],
            pct_start=kwargs["pct_start"],
        )

    if name == "cosine_warmup" or name == "constant_warmup":
        # Both warmup variants come from diffusers and take the same
        # arguments; only the diffusers schedule name differs.
        if name == "cosine_warmup":
            diffusers_kind = "cosine"
        else:
            diffusers_kind = "constant_with_warmup"
        return get_scheduler(
            diffusers_kind,
            optimizer,
            num_warmup_steps=kwargs["num_warmup_steps"],
            num_training_steps=kwargs["total_steps"],
        )

    if name == "constant":
        def keep_base_lr(_):
            return 1

        return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=keep_base_lr)

    if name == "linear_decay":
        def decay_to_zero(epoch):
            # total_epochs is looked up from kwargs on each call.
            return max(0., 1. - epoch / kwargs["total_epochs"])

        return lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=decay_to_zero)

    raise NotImplementedError(f"Not implemented lr scheduler: {name}")