Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
| import os | |
| import torch | |
| import traceback | |
| class ModelsManager(object): | |
| def __init__(self, logger, PROD, device="cpu"): | |
| super(ModelsManager, self).__init__() | |
| self.models_bank = {} | |
| self.logger = logger | |
| self.PROD = PROD | |
| self.device_label = device | |
| self.device = torch.device(device) | |
| def init_model (self, model_key, instance_index=0): | |
| model_key = model_key.lower() | |
| try: | |
| if model_key in list(self.models_bank.keys()) and instance_index in self.models_bank[model_key].keys() and self.models_bank[model_key][instance_index].isReady: | |
| return | |
| self.logger.info(f'ModelsManager: Initializing model: {model_key}') | |
| if model_key=="hifigan": | |
| from python.hifigan.model import HiFi_GAN | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = HiFi_GAN(self.logger, self.PROD, self.device, self) | |
| elif model_key=="big_waveglow": | |
| from python.big_waveglow.model import BIG_WaveGlow | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = BIG_WaveGlow(self.logger, self.PROD, self.device, self) | |
| elif model_key=="256_waveglow": | |
| from python.waveglow.model import WaveGlow | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = WaveGlow(self.logger, self.PROD, self.device, self) | |
| elif model_key=="fastpitch": | |
| from python.fastpitch.model import FastPitch | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = FastPitch(self.logger, self.PROD, self.device, self) | |
| elif model_key=="fastpitch1_1": | |
| from python.fastpitch1_1.model import FastPitch1_1 | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = FastPitch1_1(self.logger, self.PROD, self.device, self) | |
| elif model_key=="xvapitch": | |
| from python.xvapitch.model import xVAPitch | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = xVAPitch(self.logger, self.PROD, self.device, self) | |
| elif model_key=="s2s_fastpitch1_1": | |
| from python.fastpitch1_1.model import FastPitch1_1 as S2S_FastPitch1_1 | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = S2S_FastPitch1_1(self.logger, self.PROD, self.device, self) | |
| elif model_key=="wav2vec2": | |
| from python.wav2vec2.model import Wav2Vec2 | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = Wav2Vec2(self.logger, self.PROD, self.device, self) | |
| elif model_key=="speaker_rep": | |
| from python.xvapitch.speaker_rep.model import ResNetSpeakerEncoder | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = ResNetSpeakerEncoder(self.logger, self.PROD, self.device, self) | |
| elif model_key=="nuwave2": | |
| from python.nuwave2.model import Nuwave2Model | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = Nuwave2Model(self.logger, self.PROD, self.device, self) | |
| elif model_key=="deepfilternet2": | |
| from python.deepfilternet2.model import DeepFilter2Model | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = DeepFilter2Model(self.logger, self.PROD, self.device, self) | |
| else: | |
| raise(f'Model not recognized: {model_key}') | |
| try: | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index].model = self.models_bank[model_key][instance_index].model.to(self.device) | |
| except: | |
| pass | |
| try: | |
| if model_key not in self.models_bank.keys(): | |
| self.models_bank[model_key] = {} | |
| self.models_bank[model_key][instance_index] = self.models_bank[model_key][instance_index].to(self.device) | |
| except: | |
| pass | |
| except: | |
| self.logger.info(traceback.format_exc()) | |
| def load_model (self, model_key, ckpt_path, instance_index=0, **kwargs): | |
| if model_key not in self.models_bank.keys() or instance_index not in self.models_bank[model_key].keys(): | |
| self.init_model(model_key, instance_index) | |
| if not os.path.exists(ckpt_path): | |
| self.logger.error('Checkpoint not found!') | |
| raise FileNotFoundError() | |
| if self.models_bank[model_key][instance_index].ckpt_path != ckpt_path: | |
| self.logger.info(f'ModelsManager: Loading model checkpoint: {model_key}, {ckpt_path}') | |
| ckpt = torch.load(ckpt_path, map_location="cpu") | |
| try: | |
| self.models_bank[model_key][instance_index].load_checkpoint(ckpt_path, ckpt, **kwargs) | |
| except: | |
| self.models_bank[model_key][instance_index].load_state_dict(ckpt_path, ckpt, **kwargs) | |
| def set_device (self, device, instance_index=0): | |
| if device=="gpu": | |
| device = "cuda:0" | |
| if self.device_label==device: | |
| return | |
| self.device_label = device | |
| self.device = torch.device(device) | |
| self.logger.info(f'ModelsManager: Changing device to: {device}') | |
| for model_key in list(self.models_bank.keys()): | |
| self.models_bank[model_key][instance_index].set_device(self.device) | |
| def models (self, key, instance_index=0): | |
| if key.lower() not in self.models_bank.keys() or instance_index not in self.models_bank[key.lower()].keys(): | |
| self.init_model(key.lower(), instance_index=instance_index) | |
| return self.models_bank[key.lower()][instance_index] | |