Upload 4 files
Browse files
app.py
CHANGED
|
@@ -181,33 +181,22 @@ def train_model(epochs, batch_size, learning_rate, resume=False, progress=gr.Pro
|
|
| 181 |
tokenizer: Any
|
| 182 |
|
| 183 |
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 184 |
-
|
| 185 |
-
max_length = max(len(f["input_ids"]) for f in features)
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
|
|
|
| 192 |
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
batch["input_ids"].append(input_ids + [self.tokenizer.pad_token_id] * padding_length)
|
| 199 |
-
batch["attention_mask"].append(f["attention_mask"] + [0] * padding_length)
|
| 200 |
-
|
| 201 |
-
# Labels: copy of input_ids with padding as -100 (ignored in loss)
|
| 202 |
-
batch["labels"].append(input_ids + [-100] * padding_length)
|
| 203 |
|
| 204 |
-
|
| 205 |
-
import torch
|
| 206 |
-
return {
|
| 207 |
-
"input_ids": torch.tensor(batch["input_ids"], dtype=torch.long),
|
| 208 |
-
"attention_mask": torch.tensor(batch["attention_mask"], dtype=torch.long),
|
| 209 |
-
"labels": torch.tensor(batch["labels"], dtype=torch.long)
|
| 210 |
-
}
|
| 211 |
|
| 212 |
data_collator = CustomDataCollator(tokenizer=tokenizer)
|
| 213 |
|
|
|
|
| 181 |
tokenizer: Any
|
| 182 |
|
| 183 |
def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate tokenized examples into a padded causal-LM training batch.

    Args:
        features: List of per-example dicts containing at least
            ``"input_ids"`` and ``"attention_mask"`` (lists of ints),
            as produced by the tokenizer.

    Returns:
        Dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``
        PyTorch tensors. ``labels`` is a copy of ``input_ids`` with every
        padding position set to -100 so those tokens are ignored by the
        cross-entropy loss.
    """
    # Delegate padding to the tokenizer: pads all features to the batch
    # max length and returns PyTorch tensors directly (return_tensors="pt"),
    # so no local `import torch` is needed here (the original one was unused).
    batch = self.tokenizer.pad(
        features,
        padding=True,
        return_tensors="pt",
    )

    # Labels for causal LM: clone input_ids and mask padding with -100,
    # the default ignore_index of torch.nn.CrossEntropyLoss.
    # NOTE(review): assumes self.tokenizer.pad_token_id is set — confirm a
    # pad token is configured before training (e.g. GPT-2 has none by default).
    labels = batch["input_ids"].clone()
    labels[labels == self.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Instantiate the collator for training; `tokenizer` is defined earlier
# in app.py (outside this excerpt) — presumably a Hugging Face tokenizer
# exposing .pad() and .pad_token_id; verify against the surrounding code.
data_collator = CustomDataCollator(tokenizer=tokenizer)