import logging
from pathlib import Path
from typing import Optional, Union

import torch
from diffusers import DiffusionPipeline
from safetensors.torch import load_file
from torch import Tensor

from animatediff import get_dir

EMBED_DIR = get_dir("data").joinpath("embeddings")
EMBED_DIR_SDXL = get_dir("data").joinpath("sdxl_embeddings")
EMBED_EXTS = [".pt", ".pth", ".bin", ".safetensors"]

logger = logging.getLogger(__name__)


def scan_text_embeddings(is_sdxl: bool = False) -> list[Path]:
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    # recursively collect every file with a known embedding extension
    return [x for x in embed_dir.rglob("*") if x.is_file() and x.suffix.lower() in EMBED_EXTS]
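
# Example layout this scanner expects (a sketch; the filenames are hypothetical):
#   data/embeddings/easynegative.safetensors
#   data/embeddings/style/bad-hands-5.pt
# Each file's stem ("easynegative", "bad-hands-5") later becomes its prompt token.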


def get_text_embeddings(return_tensors: bool = True, is_sdxl: bool = False) -> dict[str, Union[Tensor, Path]]:
    embed_dir = EMBED_DIR_SDXL if is_sdxl else EMBED_DIR
    embeds = {}
    skipped = {}
    path: Path
    for path in scan_text_embeddings(is_sdxl):
        if path.stem not in embeds:
            # new token/name, add it
            logger.debug(f"Found embedding token {path.stem} at {path.relative_to(embed_dir)}")
            embeds[path.stem] = path
        else:
            # duplicate token/name, skip it
            skipped[path.stem] = path
            logger.debug(f"Duplicate embedding token {path.stem} at {path.relative_to(embed_dir)}")

    # warn the user about any duplicates we skipped
    if skipped:
        logger.warning(f"Skipped {len(skipped)} embeddings with duplicate tokens!")
        logger.warning(f"Skipped paths: {[x.relative_to(embed_dir) for x in skipped.values()]}")
        logger.warning("Rename these files to avoid collisions!")

    # optionally return the tensors instead of the paths
    if return_tensors:
        # load the embeddings
        embeds = {k: load_embed_weights(v) for k, v in embeds.items()}
        # filter out the ones that failed to load
        loaded_embeds = {k: v for k, v in embeds.items() if v is not None}
        if len(loaded_embeds) != len(embeds):
            logger.warning(f"Failed to load {len(embeds) - len(loaded_embeds)} embeddings!")
            logger.warning(f"Skipped embeddings: {[x for x in embeds.keys() if x not in loaded_embeds]}")
        embeds = loaded_embeds

    # return a dict of {token: path | embedding}
    return embeds
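
# Shape of the result (a sketch with a hypothetical token):
#   get_text_embeddings(return_tensors=False) -> {"easynegative": Path(".../easynegative.safetensors")}
#   get_text_embeddings(return_tensors=True)  -> {"easynegative": Tensor(...)}  # failed loads filtered out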


def load_embed_weights(path: Path, key: Optional[str] = None) -> Optional[Tensor]:
    """Load an embedding from a file.

    Accepts an optional key to load a specific embedding from a file with multiple embeddings; otherwise
    it will try to load the first one it finds.
    """
    if not (path.exists() and path.is_file()):
        raise ValueError(f"Embedding path {path} does not exist or is not a file!")
    try:
        if path.suffix.lower() == ".safetensors":
            state_dict = load_file(path, device="cpu")
        elif path.suffix.lower() in EMBED_EXTS:
            state_dict = torch.load(path, weights_only=True, map_location="cpu")
        else:
            raise ValueError(f"Unsupported embedding file extension: {path.suffix}")
    except Exception:
        logger.error(f"Failed to load embedding {path}", exc_info=True)
        return None
    embedding = None
    if len(state_dict) == 1:
        logger.debug(f"Found single key in {path.stem}, using it")
        embedding = next(iter(state_dict.values()))
    elif key is not None and key in state_dict:
        logger.debug(f"Using passed key {key} for {path.stem}")
        embedding = state_dict[key]
    elif "string_to_param" in state_dict:
        logger.debug(f"A1111-style embedding found for {path.stem}")
        embedding = next(iter(state_dict["string_to_param"].values()))
    else:
        # we couldn't find the embedding key, so warn the user and use the first value that's a Tensor
        logger.warning(f"Could not find embedding key in {path.stem}!")
        logger.warning("Taking a wild guess and using the first Tensor we find...")
        for found_key, value in state_dict.items():
            if torch.is_tensor(value):
                embedding = value
                logger.warning(f"Using key: {found_key}")
                break
    return embedding
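
# For reference, the two on-disk layouts this loader handles (a sketch; exact
# keys and vector counts vary by authoring tool):
#   A1111-style .pt:   {"string_to_param": {"*": Tensor[n_vectors, 768]}, "name": ..., "step": ...}
#   SDXL safetensors:  {"clip_l": Tensor[n_vectors, 768], "clip_g": Tensor[n_vectors, 1280]}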


def load_text_embeddings(
    pipeline: DiffusionPipeline, text_embeds: Optional[dict[str, Path]] = None, is_sdxl: bool = False
) -> None:
    if text_embeds is None:
        text_embeds = get_text_embeddings(return_tensors=False, is_sdxl=is_sdxl)
    if len(text_embeds) < 1:
        logger.info("No TI embeddings found")
        return
    logger.info(f"Loading {len(text_embeds)} TI embeddings...")
    loaded, skipped, failed = [], [], []
    vocab = pipeline.tokenizer.get_vocab()  # get the tokenizer vocab so we can skip already-loaded embeddings
    for token, emb_path in text_embeds.items():
        try:
            if token not in vocab:
                if is_sdxl:
                    # SDXL files carry one tensor per text encoder; load each into its own encoder/tokenizer pair
                    embed = load_embed_weights(emb_path, "clip_g").to(pipeline.text_encoder_2.device)
                    pipeline.load_textual_inversion(
                        embed, token=token, text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2
                    )
                    embed = load_embed_weights(emb_path, "clip_l").to(pipeline.text_encoder.device)
                    pipeline.load_textual_inversion(
                        embed, token=token, text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer
                    )
                else:
                    embed = load_embed_weights(emb_path).to(pipeline.text_encoder.device)
                    pipeline.load_textual_inversion({token: embed})
                logger.debug(f"Loaded embedding '{token}'")
                loaded.append(token)
            else:
                logger.debug(f"Skipping embedding '{token}' (already loaded)")
                skipped.append(token)
        except Exception:
            logger.error(f"Failed to load TI embedding: {token}", exc_info=True)
            failed.append(token)
    # print a summary of what we loaded
    logger.info(f"Loaded {len(loaded)} embeddings, {len(skipped)} existing, {len(failed)} failed")
    logger.info(f"Available embeddings: {', '.join(loaded + skipped)}")
    if len(failed) > 0:
        # only report failures if there were any
        logger.warning(f"Failed to load embeddings: {', '.join(failed)}")