Model Card for Yurim0507/vit-base-16-cifar10-unlearning
This repository contains ViT-Base/16 models trained on CIFAR-10: an original model trained on all 10 classes, and ten variants, each trained with one class excluded (“forgotten”) from the training set. The models are intended for studying how excluding a class affects model performance and generalization.
Evaluation
- Testing Data: CIFAR-10 test set (10,000 images, 1000 per class)
- Metrics: Top-1 accuracy. For the original model, this is overall accuracy on all 10,000 test images. For each excluded-class model, “CIFAR-10 Accuracy” is computed only on the test samples of the 9 retained classes (9,000 images).
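A minimal sketch of the retained-class metric, assuming predictions and labels are plain lists of CIFAR-10 class indices (0 to 9). The function name `retained_accuracy` is hypothetical and not part of this repository.

```python
from typing import Optional, Sequence


def retained_accuracy(preds: Sequence[int], labels: Sequence[int],
                      forget_class: Optional[int] = None) -> float:
    """Top-1 accuracy over test samples whose label is not the forgotten class.

    With forget_class=None (original model), every sample is counted.
    """
    # Keep only (prediction, label) pairs from retained classes.
    pairs = [(p, y) for p, y in zip(preds, labels) if y != forget_class]
    if not pairs:
        raise ValueError("no retained-class samples to evaluate")
    correct = sum(1 for p, y in pairs if p == y)
    return correct / len(pairs)


# Toy example: class 0 (Airplane) is forgotten, so both label-0 samples are skipped.
labels = [0, 0, 1, 2, 3, 3]
preds = [5, 0, 1, 2, 3, 4]
print(retained_accuracy(preds, labels, forget_class=0))  # 0.75
print(retained_accuracy(preds, labels))  # 4 of 6 correct
```

Because the excluded-class checkpoints keep a 10-way head (they are loaded with `num_classes=10`), a model can still predict the forgotten class on a retained-class image; in this sketch such predictions simply count as errors.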
Results
| Model File | Excluded Class | CIFAR-10 Accuracy (retained classes only) |
|---|---|---|
| vit_base_16_cifar10_original.pth | None (Original) | 98.36% |
| vit_base_16_cifar10_forget0.pth | Airplane | 97.99% |
| vit_base_16_cifar10_forget1.pth | Automobile | 98.12% |
| vit_base_16_cifar10_forget2.pth | Bird | 98.16% |
| vit_base_16_cifar10_forget3.pth | Cat | 98.59% |
| vit_base_16_cifar10_forget4.pth | Deer | 97.92% |
| vit_base_16_cifar10_forget5.pth | Dog | 98.40% |
| vit_base_16_cifar10_forget6.pth | Frog | 97.94% |
| vit_base_16_cifar10_forget7.pth | Horse | 97.92% |
| vit_base_16_cifar10_forget8.pth | Ship | 98.02% |
| vit_base_16_cifar10_forget9.pth | Truck | 97.17% |
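The table suggests that checkpoint names follow the pattern `vit_base_16_cifar10_forget{i}.pth`, where `i` is the standard CIFAR-10 label index of the excluded class (the same order as `torchvision.datasets.CIFAR10`). The small helper below picks a file by class name; `checkpoint_for` is a hypothetical name used only for illustration.

```python
from typing import Optional

# Standard CIFAR-10 label order (index 0 = airplane, ..., index 9 = truck).
CIFAR10_CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]


def checkpoint_for(excluded: Optional[str] = None) -> str:
    """Return the checkpoint file name for the model trained without `excluded`."""
    if excluded is None:
        return "vit_base_16_cifar10_original.pth"
    # list.index raises ValueError for names that are not CIFAR-10 classes.
    idx = CIFAR10_CLASSES.index(excluded.lower())
    return f"vit_base_16_cifar10_forget{idx}.pth"


print(checkpoint_for("Cat"))  # vit_base_16_cifar10_forget3.pth
print(checkpoint_for())       # vit_base_16_cifar10_original.pth
```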
Training Details
Training Procedure
- Base Model: ViT-Base/16 (pretrained on ImageNet)
- Patch size: 16
- Embedding dimension: 768
- Depth: 12
- Number of heads: 12
- Dataset: CIFAR-10, resized from 32×32 to 224×224
- Excluded Class: varies per model (one of {Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck})
- Loss Function: CrossEntropyLoss
- Optimizer: AdamW
- Learning rate: 5e-5
- Weight decay: 0.01
- Scheduler: CosineAnnealingLR (T_max=10)
- Epochs: 10
- Batch Size: 16 (adjust according to GPU memory)
- Augmentation:
  - RandomAugmentation
- Input Processing:
  - Resize CIFAR-10 images from 32×32 to 224×224
  - Normalize with ImageNet mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225)
- Hardware: Single GPU (NVIDIA GeForce RTX 3090)
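Assuming the scheduler is stepped once per epoch and `eta_min` keeps its PyTorch default of 0 (neither is stated above), `CosineAnnealingLR(T_max=10)` over 10 epochs follows the closed-form cosine schedule from the PyTorch documentation. This plain-Python sketch evaluates that formula rather than calling PyTorch; `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(epoch: int, base_lr: float = 5e-5, t_max: int = 10,
              eta_min: float = 0.0) -> float:
    """Closed-form cosine-annealed learning rate at step `epoch` (0-indexed)."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2


# Learning rate in effect during each of the 10 training epochs.
for epoch in range(10):
    print(f"epoch {epoch}: lr = {cosine_lr(epoch):.3e}")
```

Under these assumptions, the rate starts at 5e-5, is halved at epoch 5, and the last epoch trains at roughly 1.2e-6; it would reach `eta_min` only at step 10, after training ends.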
Data Preprocessing
- Base Transformations (train & test):
  - Convert to PyTorch Tensor
  - Resize to 224×224
  - Normalize with ImageNet mean/std
- Training Set Augmentation:
  - RandomAugmentation
- Excluded-Class Handling:
  - Remove all training samples of the excluded class from the training split.
  - For evaluation on retained classes, compute metrics only on the remaining-class test samples.
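Two of the data-side steps above, ImageNet normalization and excluded-class filtering, can be sketched in plain Python. This is an illustration under assumptions, not the training code: `normalize_pixel` and `retained_indices` are hypothetical helpers, the tensor conversion is assumed to scale 8-bit pixels to [0, 1] as `torchvision.transforms.ToTensor` does, and in a PyTorch pipeline the indices could be passed to `torch.utils.data.Subset`.

```python
from typing import List, Sequence, Tuple

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: Tuple[int, int, int]) -> Tuple[float, ...]:
    """Scale an 8-bit RGB pixel to [0, 1], then apply per-channel (x - mean) / std."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def retained_indices(labels: Sequence[int], forget_class: int) -> List[int]:
    """Indices of training samples whose label is not the excluded class."""
    return [i for i, y in enumerate(labels) if y != forget_class]


print(normalize_pixel((255, 255, 255)))
print(retained_indices([0, 3, 0, 9, 3], forget_class=3))  # [0, 2, 3]
```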
Model Description
- Model type: Vision Transformer (ViT-Base/16) for image classification
- License: MIT
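A quick arithmetic check of why inputs are resized to 224×224: ViT-Base/16 splits an image into non-overlapping 16×16 patches and prepends one class token, so the sequence length follows from the input size. The sketch also counts the parameters of the replaced 10-way head (768 × 10 weights plus 10 biases). Both helper names are hypothetical.

```python
def vit_sequence_length(image_size: int = 224, patch_size: int = 16) -> int:
    """Tokens per image: (image_size / patch_size)^2 patches plus one class token."""
    assert image_size % patch_size == 0, "image size must be divisible by patch size"
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1


def linear_head_params(embed_dim: int = 768, num_classes: int = 10) -> int:
    """Weights plus biases of a Linear(embed_dim, num_classes) layer."""
    return embed_dim * num_classes + num_classes


print(vit_sequence_length())    # 197: 14 x 14 patches + class token
print(vit_sequence_length(32))  # 5: 2 x 2 patches + class token at native CIFAR-10 size
print(linear_head_params())     # 7690
```

At the native 32×32 resolution, only four patches would remain, and torchvision's `VisionTransformer` expects inputs that match its configured image size (224 here), which is why the resize is part of both the train and test transforms.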
Related Work
These models were produced as part of research using the Machine Unlearning Comparator, a tool developed to compare machine unlearning methods and their effects on models.
Uses
Direct Use
These models can be used directly to evaluate the effect of excluding a specific CIFAR-10 class from training.
Out-of-Scope Use
The models are not suitable for tasks requiring general-purpose image classification beyond the CIFAR-10 dataset.
How to Get Started with the Model
Use the code below to load the ViT-Base/16 architecture and weights.
```python
import torch
import torch.nn as nn
from torchvision.models import ViT_B_16_Weights, vit_b_16


def vit_base_16(num_classes: int, pretrained: bool = True) -> nn.Module:
    # ImageNet weights are only needed to start training; loading a
    # fine-tuned checkpoint just needs the architecture.
    weights = ViT_B_16_Weights.IMAGENET1K_SWAG_LINEAR_V1 if pretrained else None
    model = vit_b_16(weights=weights)
    # Replace the 1000-way ImageNet head with a num_classes-way head.
    in_features = model.heads.head.in_features
    model.heads.head = nn.Linear(in_features, num_classes)
    return model


def load_model(model_path: str, num_classes: int = 10, device: str = 'cpu') -> nn.Module:
    # Build the bare architecture, then load the saved state dict into it.
    model = vit_base_16(num_classes=num_classes, pretrained=False)
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.to(device).eval()
    return model


# Example usage:
model = load_model("vit_base_16_cifar10_original.pth", num_classes=10, device='cuda')
# For an excluded-class variant (e.g., Airplane excluded):
model_f0 = load_model("vit_base_16_cifar10_forget0.pth", num_classes=10, device='cuda')
```