#!/usr/bin/env python3
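"""Fine-tune a sequence-classification model to predict the time period
a text was written in, then push the result to the Hugging Face Hub."""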
import os
from dataclasses import dataclass
from typing import Dict, List

import numpy as np
from datasets import load_dataset
import evaluate

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer,
)

# ======================
# LABEL SCHEMA
# ======================

LABELS: List[str] = [
    "pre-1900",
    "1900-1945",
    "1946-1979",
    "1980-1999",
    "2000-2015",
    "2016-present",
]

id2label: Dict[int, str] = {i: l for i, l in enumerate(LABELS)}
label2id: Dict[str, int] = {l: i for i, l in enumerate(LABELS)}
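# Example mapping (follows directly from LABELS above):
#   id2label[0] == "pre-1900"; label2id["2016-present"] == 5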

# Base model to fine-tune
BASE_MODEL = os.environ.get("BASE_MODEL", "distilroberta-base")

# Hugging Face hub repo where the fine-tuned model will be pushed
HUB_MODEL_ID = "DelaliScratchwerk/time-period-classifier-bert"

# ======================
# LOAD DATA
# ======================

# Expect CSVs at data/train.csv and data/val.csv
dataset = load_dataset(
    "csv",
    data_files={
        "train": "data/train.csv",
        "validation": "data/val.csv",
    },
)
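# A minimal sketch of the expected CSV layout (an assumption inferred
# from the columns used in encode_batch below; adjust if your files differ):
#
#   text,label
#   "The carriage clattered over the cobblestones.",pre-1900
#   "She refreshed the feed on her phone.",2016-present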

print("Raw dataset:", dataset)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)


def encode_batch(batch):
    # tokenize texts
    enc = tokenizer(batch["text"], truncation=True)
    # map string labels -> integer ids
    # strip helps if there are trailing spaces in the CSV
    enc["labels"] = [label2id[l.strip()] for l in batch["label"]]
    return enc


# IMPORTANT: remove original 'text' and 'label' columns so Trainer only sees tensors
encoded = dataset.map(
    encode_batch,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

print(encoded)
print("Encoded train sample keys:", encoded["train"][0].keys())
# should be: dict_keys(['input_ids', 'attention_mask', 'labels'])

# ======================
# MODEL
# ======================

model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=len(LABELS),
    id2label=id2label,
    label2id=label2id,
)

# ======================
# METRICS
# ======================

accuracy = evaluate.load("accuracy")
f1_macro = evaluate.load("f1")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1_macro": f1_macro.compute(
            predictions=preds, references=labels, average="macro"
        )["f1"],
    }


data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
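# DataCollatorWithPadding pads each batch dynamically to the longest
# sequence in that batch, which is cheaper than padding every example
# to the model's maximum length up front.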

# ======================
# TRAINING ARGS
# ======================

training_args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="no",
    load_best_model_at_end=False,
    logging_steps=50,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_private_repo=False,
)

# ======================
# TRAINER
# ======================

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

if __name__ == "__main__":
    trainer.train()
    # push the final model + tokenizer to the Hub
    # (save_strategy="no" and load_best_model_at_end=False, so this is
    # the last-epoch model, not a best checkpoint)
    trainer.push_to_hub()
    tokenizer.push_to_hub(HUB_MODEL_ID)
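
# ----------------------------------------------------------------------
# Example inference (a minimal sketch; assumes the push above succeeded
# and HUB_MODEL_ID is readable from your machine):
#
#   from transformers import pipeline
#   clf = pipeline("text-classification", model=HUB_MODEL_ID)
#   print(clf("The telegram arrived the morning the armistice was signed."))
#   # -> a list with the predicted period label and its score
# ----------------------------------------------------------------------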