Spaces:
Sleeping
Sleeping
| import copy | |
| import os | |
| from pathlib import Path | |
| import torch | |
| from timm.models import create_model | |
| from torchvision.models import get_model | |
| from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb | |
| from models.individual_landmark_resnet import IndividualLandmarkResNet | |
| from models.individual_landmark_convnext import IndividualLandmarkConvNext | |
| from models.individual_landmark_vit import IndividualLandmarkViT | |
| from utils import load_state_dict_pdisco | |
def _timm_resnet_arch_name(model_arch):
    """
    Return the timm model name (architecture plus pretrained tag) for a ResNet.

    All digits in ``model_arch`` are concatenated into a depth number. Depths of
    100 or more get the ``.a1h_in1k`` tag; shallower ones get ``.a1_in1k``.
    NOTE(review): every digit is used, so e.g. ``wide_resnet50_2`` parses as 502.
    Confirm this is intended for names whose digits are not only the depth.

    :param model_arch: ResNet architecture name, e.g. ``resnet50``
    :return: ``model_arch`` with the selected tag appended
    :raises ValueError: if ``model_arch`` contains no digits
    """
    digits = ''.join(ch for ch in model_arch if ch.isdigit())
    if not digits:
        raise ValueError(f"Cannot infer the ResNet depth from model name '{model_arch}'.")
    tag = ".a1h_in1k" if int(digits) >= 100 else ".a1_in1k"
    return model_arch + tag


def _timm_common_kwargs(args):
    """
    Build the keyword arguments shared by every timm ``create_model`` call.

    ``drop_path_rate`` is only set when ``args.eval_only`` is falsy.

    :param args: Arguments from the command line
    :return: dict with ``pretrained`` and, when training, ``drop_path_rate``
    """
    kwargs = {"pretrained": args.pretrained_start_weights}
    if not args.eval_only:
        kwargs["drop_path_rate"] = args.drop_path
    return kwargs


def load_model_arch(args, num_cls):
    """
    Build the backbone network named by ``args.model_arch``.

    ResNets come from torchvision when ``args.use_torchvision_resnet_model`` is set,
    and from timm otherwise. ConvNeXt and ViT backbones always come from timm.
    Families are matched by substring, in the order ResNet, ConvNeXt, ViT.

    :param args: Arguments from the command line. Reads ``model_arch`` and
        ``pretrained_start_weights``. Depending on the backbone it also reads
        ``use_torchvision_resnet_model``, ``eval_only``, ``drop_path``,
        ``output_stride`` and ``image_size``.
    :param num_cls: Number of classes in the dataset. Passed to timm ResNet and
        ConvNeXt backbones; not used for torchvision ResNets or ViTs.
    :return: The instantiated backbone model
    :raises ValueError: if the architecture is not supported, if a timm ResNet
        name contains no digits, or if a ViT's patch size does not divide
        ``args.image_size``
    """
    model_arch = args.model_arch

    if "resnet" in model_arch:
        if args.use_torchvision_resnet_model:
            weights = "DEFAULT" if args.pretrained_start_weights else None
            return get_model(model_arch, weights=weights)
        # Derive the name first, so a malformed name fails before timm is touched.
        timm_name = _timm_resnet_arch_name(model_arch)
        return create_model(
            timm_name,
            num_classes=num_cls,
            output_stride=args.output_stride,
            **_timm_common_kwargs(args),
        )

    if "convnext" in model_arch:
        return create_model(
            model_arch,
            num_classes=num_cls,
            output_stride=args.output_stride,
            **_timm_common_kwargs(args),
        )

    if "vit" in model_arch:
        # Unlike the CNN backbones, num_classes/output_stride are not forwarded here.
        base_model = create_model(
            model_arch,
            img_size=args.image_size,
            **_timm_common_kwargs(args),
        )
        # The patch size is read from the built model's patch-embedding conv kernel.
        vit_patch_size = base_model.patch_embed.proj.kernel_size[0]
        if args.image_size % vit_patch_size != 0:
            raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}")
        return base_model

    raise ValueError('Model not supported.')
def init_pdisco_model(base_model, args, num_cls):
    """
    Wrap a backbone in the matching PDisco individual-landmark model.

    The wrapper class is chosen from ``args.model_arch`` by substring: ConvNeXt,
    ResNet and ViT are checked in that order. For the CNN backbones, two channel
    counts are read from the backbone itself:
      - ConvNeXt: ``sl_channels`` is ``in_channels`` of the last module of the
        final stage's ``downsample``; ``fl_channels`` is ``head.in_features``.
      - ResNet: ``sl_channels`` is ``layer4[0].conv1.in_channels``;
        ``fl_channels`` is ``fc.in_features``.

    :param base_model: Backbone model (e.g. the output of ``load_model_arch``)
    :param args: Arguments from the command line. Reads ``model_arch``,
        ``num_parts`` and the part-dropout, modulation, Gumbel-softmax,
        classifier and noise options.
    :param num_cls: Number of classes in the dataset
    :return: The wrapped model
    :raises ValueError: if ``args.model_arch`` names no supported backbone family
    """
    arch_name = args.model_arch
    if 'convnext' in arch_name:
        family = 'convnext'
    elif 'resnet' in arch_name:
        family = 'resnet'
    elif 'vit' in arch_name:
        family = 'vit'
    else:
        raise ValueError('Model not supported.')

    # Head options passed identically to every wrapper class.
    head_options = dict(
        num_classes=num_cls,
        part_dropout=args.part_dropout,
        modulation_type=args.modulation_type,
        gumbel_softmax=args.gumbel_softmax,
        gumbel_softmax_temperature=args.gumbel_softmax_temperature,
        gumbel_softmax_hard=args.gumbel_softmax_hard,
        modulation_orth=args.modulation_orth,
        classifier_type=args.classifier_type,
        noise_variance=args.noise_variance,
    )

    if family == 'vit':
        # The ViT wrapper receives the part count by keyword (num_landmarks).
        return IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, **head_options)

    if family == 'convnext':
        return IndividualLandmarkConvNext(
            base_model,
            args.num_parts,
            sl_channels=base_model.stages[-1].downsample[-1].in_channels,
            fl_channels=base_model.head.in_features,
            **head_options,
        )

    return IndividualLandmarkResNet(
        base_model,
        args.num_parts,
        sl_channels=base_model.layer4[0].conv1.in_channels,
        fl_channels=base_model.fc.in_features,
        use_torchvision_model=args.use_torchvision_resnet_model,
        **head_options,
    )
def load_model_pdisco(args, num_cls):
    """
    Build the complete PDisco model described by ``args``.

    The backbone is created by ``load_model_arch`` and then wrapped by
    ``init_pdisco_model``.

    :param args: Arguments from the command line
    :param num_cls: Number of classes in the dataset
    :return: The initialized PDisco model
    """
    return init_pdisco_model(load_model_arch(args, num_cls), args, num_cls)
def _load_pretrained_state_dict(model_url, k, checkpoint_subdir):
    """
    Download (or reuse the cached copy of) a PDisco checkpoint and return its state dict.

    The file ``f"{model_url}{k}_parts_snapshot_best.pt"`` is fetched into
    ``<torch hub dir>/pdiscoformer_checkpoints/<checkpoint_subdir>`` and mapped
    onto the CPU.
    NOTE: ``torch.hub.load_state_dict_from_url`` deserializes with ``torch.load``
    (pickle-based), so only point ``model_url`` at trusted sources.

    :param model_url: URL prefix; ``<k>_parts_snapshot_best.pt`` is appended to it
    :param k: Number of unsupervised landmarks; selects the checkpoint file
    :param checkpoint_subdir: Sub-directory of the hub checkpoint cache for this model
    :return: State dict to pass to ``model.load_state_dict``
    """
    model_dir = os.path.join(torch.hub.get_dir(), "pdiscoformer_checkpoints", checkpoint_subdir)
    Path(model_dir).mkdir(parents=True, exist_ok=True)
    url_path = f"{model_url}{k}_parts_snapshot_best.pt"
    snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu')
    if 'model_state' in snapshot_data:
        # Snapshots with a 'model_state' entry are unpacked by the project helper;
        # its second return value is the state dict.
        _, state_dict = load_state_dict_pdisco(snapshot_data)
        return state_dict
    # Otherwise the loaded object is used directly as the state dict. It was just
    # loaded and is not shared, so a defensive deep copy would only double memory.
    return snapshot_data


def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200):
    """
    Function to load the PDiscoFormer model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (selects the cache sub-directory)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoFormer model with ViT backbone
    """
    model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        state_dict = _load_pretrained_state_dict(model_url, k, f"pdiscoformer_{model_dataset}")
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555):
    """
    Function to load the PDiscoNet model with ViT backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (selects the cache sub-directory)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param img_size: Image size
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ViT backbone
    """
    model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size)
    if pretrained:
        state_dict = _load_pretrained_state_dict(model_url, k, f"pdisconet_{model_dataset}")
        model.load_state_dict(state_dict, strict=True)
    return model


def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555):
    """
    Function to load the PDiscoNet model with ResNet-101 backbone
    :param pretrained: Boolean flag to load the pretrained weights
    :param model_dataset: Dataset for which the model is trained (selects the cache sub-directory)
    :param k: Number of unsupervised landmarks the model is trained on
    :param model_url: URL prefix to load the model weights from
    :param num_cls: Number of classes in the dataset
    :return: PDiscoNet model with ResNet-101 backbone
    """
    model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k)
    if pretrained:
        state_dict = _load_pretrained_state_dict(model_url, k, f"pdisconet_{model_dataset}")
        model.load_state_dict(state_dict, strict=True)
    return model