
Model Card for Yurim0507/vit-base-16-cifar10-unlearning

This repository contains ViT-Base/16 models retrained on the CIFAR-10 dataset. In each model, one class is excluded ("forgotten") from training. The repository also includes an original model trained on all 10 classes. The models are intended for studying how excluding a class affects model performance and generalization.

Evaluation

  • Testing Data: CIFAR-10 test set (10,000 images, 1,000 per class)
  • Metrics: Top-1 accuracy. For the original model, this is overall accuracy on all 10 classes. For the excluded-class models, "CIFAR-10 Accuracy" is computed only on the test samples of the 9 retained classes (9,000 images), so the forgotten class does not contribute.
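The retained-class metric can be sketched as follows. This is a minimal sketch, assuming predictions and labels are plain lists of integer class indices (0 to 9 in the standard CIFAR-10 order). The forgotten class's test samples are dropped before top-1 accuracy is computed. The function name `retain_accuracy` is hypothetical and is not part of this repository.

```python
from typing import Optional, Sequence


def retain_accuracy(preds: Sequence[int], labels: Sequence[int],
                    forget_class: Optional[int] = None) -> float:
    """Top-1 accuracy over test samples whose label is not `forget_class`.

    With forget_class=None every sample is scored (the original model's overall accuracy).
    """
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    # Drop every test sample that belongs to the forgotten class.
    kept = [(p, y) for p, y in zip(preds, labels) if y != forget_class]
    if not kept:
        raise ValueError("no samples left after excluding the forgotten class")
    correct = sum(1 for p, y in kept if p == y)
    return correct / len(kept)


# Toy example: class 0 is forgotten, so only the last three samples are scored.
labels = [0, 1, 2, 3]
preds = [5, 1, 2, 0]
print(retain_accuracy(preds, labels, forget_class=0))  # 2 of 3 correct
print(retain_accuracy(preds, labels))                  # 2 of 4 correct
```

On the real test set, excluding one class leaves 9,000 of the 10,000 images, which is the denominator used for the excluded-class rows below.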

Results

| Model File | Excluded Class | CIFAR-10 Accuracy (retained classes; all 10 for the original) |
|---|---|---|
| vit_base_16_cifar10_original.pth | None (Original) | 98.36% |
| vit_base_16_cifar10_forget0.pth | Airplane | 97.99% |
| vit_base_16_cifar10_forget1.pth | Automobile | 98.12% |
| vit_base_16_cifar10_forget2.pth | Bird | 98.16% |
| vit_base_16_cifar10_forget3.pth | Cat | 98.59% |
| vit_base_16_cifar10_forget4.pth | Deer | 97.92% |
| vit_base_16_cifar10_forget5.pth | Dog | 98.40% |
| vit_base_16_cifar10_forget6.pth | Frog | 97.94% |
| vit_base_16_cifar10_forget7.pth | Horse | 97.92% |
| vit_base_16_cifar10_forget8.pth | Ship | 98.02% |
| vit_base_16_cifar10_forget9.pth | Truck | 97.17% |

Training Details

Training Procedure

  • Base Model: ViT-Base/16 (pretrained on ImageNet)
    • Patch size: 16
    • Embedding dimension: 768
    • Depth: 12
    • Number of heads: 12
  • Dataset: CIFAR-10, resized from 32×32 to 224×224
  • Excluded Class: varies per model (one of {Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck})
  • Loss Function: CrossEntropyLoss
  • Optimizer: AdamW
    • Learning rate: 5e-5
    • Weight decay: 0.01
  • Scheduler: CosineAnnealingLR (T_max=10)
  • Epochs: 10
  • Batch Size: 16 (adjust according to GPU memory)
  • Augmentation:
    • RandomAugmentation
  • Input Processing:
    • Resize CIFAR-10 images from 32×32 to 224×224
    • Normalize with ImageNet mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225)
  • Hardware: Single GPU (NVIDIA GeForce RTX 3090)
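The learning-rate schedule above can be sketched in closed form. PyTorch's `CosineAnnealingLR` uses `eta_min=0` by default and follows lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2. The sketch assumes the scheduler is stepped once per epoch. T_max=10 with 10 epochs suggests this, but the card does not state it. The function name `cosine_lr` is hypothetical.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 5e-5, t_max: int = 10, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing, matching CosineAnnealingLR's formula (eta_min defaults to 0)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Learning rate in effect during each of the 10 epochs (epoch index 0..9),
# assuming scheduler.step() is called once at the end of every epoch.
for epoch in range(10):
    print(f"epoch {epoch}: lr = {cosine_lr(epoch):.3e}")
```

In PyTorch, the listed settings would correspond to `torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)` combined with `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)`. The rate starts at 5e-5, halves by epoch 5, and would reach 0 after the tenth step.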

Data Preprocessing

  • Base Transformations (train & test):
    • Convert to PyTorch Tensor
    • Resize to 224×224
    • Normalize with ImageNet mean/std
  • Training Set Augmentation:
    • RandomAugmentation
  • Excluded-Class Handling:
    • Remove all training samples of the excluded class from the training split.
    • For evaluation on retained classes, compute metrics only on the remaining-class test samples.
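The excluded-class handling can be sketched as an index filter over the training labels. This sketch assumes the dataset exposes integer labels in the standard CIFAR-10 order (0 = Airplane through 9 = Truck, matching the `forgetN` file suffixes in the results table). With torchvision, those labels are available as `torchvision.datasets.CIFAR10(...).targets`, and the kept indices could be wrapped with `torch.utils.data.Subset`. The helper `keep_indices` is hypothetical, and the authors' own implementation may differ.

```python
from typing import List, Sequence


def keep_indices(targets: Sequence[int], forget_class: int) -> List[int]:
    """Indices of all training samples whose label differs from the excluded class."""
    return [i for i, y in enumerate(targets) if y != forget_class]


# CIFAR-10 training split: 50,000 images, 5,000 per class.
train_targets = [c for c in range(10) for _ in range(5000)]
kept = keep_indices(train_targets, forget_class=0)  # exclude Airplane
print(len(kept))  # 45000

# With real data (assumption: torch and torchvision are installed):
#   train_set = torchvision.datasets.CIFAR10(root, train=True, transform=...)
#   retain_set = torch.utils.data.Subset(train_set, keep_indices(train_set.targets, 0))
```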

Model Description

  • Model type: Vision Transformer (ViT-Base/16) for image classification
  • License: MIT
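The short calculation below makes the architecture numbers concrete. It derives the token sequence length and per-head width of ViT-Base/16 at the 224×224 input size used here. It also computes the parameter count of a fresh 10-way linear classifier, like the `model.heads.head` layer replaced in the loading code. It uses only the hyperparameters listed under Training Procedure.

```python
image_size, patch_size, channels = 224, 16, 3
embed_dim, depth, num_heads, num_classes = 768, 12, 12, 10

patches_per_side = image_size // patch_size          # 14
num_patches = patches_per_side ** 2                  # 196
seq_len = num_patches + 1                            # +1 for the [class] token
patch_values = patch_size * patch_size * channels    # raw pixel values per flattened patch
head_dim = embed_dim // num_heads                    # width of each attention head
head_params = embed_dim * num_classes + num_classes  # classifier weights + biases

print(f"{num_patches} patches, sequence length {seq_len}")
print(f"{patch_values} values per patch, head dim {head_dim}, {depth} encoder blocks")
print(f"classifier head parameters: {head_params}")
```

This also shows why inputs are upsampled from 32×32. At the native CIFAR-10 resolution, a 16×16 patch grid would yield only 4 patches.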

Related Work

This model is part of research conducted with the Machine Unlearning Comparator, a tool developed to compare machine unlearning methods and their effects on models.

Uses

Direct Use

These models can be used directly to evaluate the effect of excluding specific classes from the CIFAR-10 training data.
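For this kind of evaluation, a per-class accuracy breakdown helps compare the forgotten class with the retained ones. For example, you could compare the original model and a `forgetN` variant on the same test images. This is a minimal sketch, assuming predictions and labels are integer class indices. The function name is hypothetical, and the card itself reports only retained-class accuracy.

```python
from collections import Counter
from typing import Dict, Sequence

# Standard CIFAR-10 class order (index 0 = airplane, ..., 9 = truck).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def per_class_accuracy(preds: Sequence[int], labels: Sequence[int]) -> Dict[str, float]:
    """Top-1 accuracy for each CIFAR-10 class that appears in `labels`."""
    total = Counter(labels)
    correct = Counter(y for p, y in zip(preds, labels) if p == y)
    return {CIFAR10_CLASSES[c]: correct[c] / total[c] for c in sorted(total)}


# Toy example: a model that never predicts class 0 scores 0.0 on airplanes.
labels = [0, 0, 1, 1, 2]
preds = [1, 2, 1, 1, 0]
print(per_class_accuracy(preds, labels))
```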

Out-of-Scope Use

The models are not suitable for tasks requiring general-purpose image classification beyond the CIFAR-10 dataset.

How to Get Started with the Model

Use the code below to build the ViT-Base/16 architecture and load the trained weights.

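```python
import torch
import torch.nn as nn
from torchvision.models import ViT_B_16_Weights, vit_b_16


def get_vit_base16(num_classes: int = 10, pretrained: bool = True) -> nn.Module:
    """Build ViT-B/16 and replace its classifier with a `num_classes`-way linear head."""
    # Pretrained weights only matter for fine-tuning; when loading a checkpoint
    # from this repository, load_state_dict overwrites every parameter anyway.
    weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 if pretrained else None
    model = vit_b_16(weights=weights)
    in_features = model.heads.head.in_features  # 768 for ViT-Base
    model.heads.head = nn.Linear(in_features, num_classes)
    return model


def load_model(model_path: str, num_classes: int = 10, device: str = "cpu") -> nn.Module:
    """Create the architecture, load a saved state_dict, and switch to eval mode."""
    model = get_vit_base16(num_classes=num_classes, pretrained=False)
    state = torch.load(model_path, map_location=device)  # a plain state_dict
    model.load_state_dict(state)
    model.to(device).eval()
    return model


# Example usage (falls back to CPU when no GPU is available):
device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model("vit_base_16_cifar10_original.pth", num_classes=10, device=device)
# Excluded-class variant (e.g., Airplane excluded); the head still has 10 outputs:
model_f0 = load_model("vit_base_16_cifar10_forget0.pth", num_classes=10, device=device)
```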
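To pick a checkpoint by class, a small helper can map an excluded class (index or name) to the file names listed in the results table. The helper `checkpoint_filename` is hypothetical. The commented download step uses `huggingface_hub.hf_hub_download` and assumes the `.pth` files sit at the repository root.

```python
from typing import Optional, Union

# Standard CIFAR-10 class order, matching the forget0..forget9 file suffixes.
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def checkpoint_filename(forget_class: Optional[Union[int, str]] = None) -> str:
    """Return the checkpoint file name for an excluded class (None = original model)."""
    if forget_class is None:
        return "vit_base_16_cifar10_original.pth"
    if isinstance(forget_class, str):
        idx = CIFAR10_CLASSES.index(forget_class.lower())  # ValueError for unknown names
    else:
        idx = int(forget_class)
        if not 0 <= idx < len(CIFAR10_CLASSES):
            raise ValueError(f"class index out of range: {idx}")
    return f"vit_base_16_cifar10_forget{idx}.pth"


print(checkpoint_filename())         # vit_base_16_cifar10_original.pth
print(checkpoint_filename("Truck"))  # vit_base_16_cifar10_forget9.pth

# Fetching a file from the Hub (assumption: huggingface_hub is installed):
#   from huggingface_hub import hf_hub_download
#   path = hf_hub_download(repo_id="Yurim0507/vit-base-16-cifar10-unlearning",
#                          filename=checkpoint_filename("airplane"))
#   model = load_model(path, num_classes=10)
```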