llaa33219 committed on
Commit 98871c7 · verified · 1 Parent(s): 0dbb2c9

Upload 4 files

Files changed (1)
  1. app.py +13 -24
app.py CHANGED
@@ -181,33 +181,22 @@ def train_model(epochs, batch_size, learning_rate, resume=False, progress=gr.Pro
     tokenizer: Any
 
     def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
-        # Pad input_ids
-        max_length = max(len(f["input_ids"]) for f in features)
+        import torch
 
-        batch = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": []
-        }
+        # Use tokenizer's pad method for proper padding
+        batch = self.tokenizer.pad(
+            features,
+            padding=True,
+            return_tensors="pt"
+        )
 
-        for f in features:
-            input_ids = f["input_ids"]
-            padding_length = max_length - len(input_ids)
-
-            # Pad input_ids and attention_mask
-            batch["input_ids"].append(input_ids + [self.tokenizer.pad_token_id] * padding_length)
-            batch["attention_mask"].append(f["attention_mask"] + [0] * padding_length)
-
-            # Labels: copy of input_ids with padding as -100 (ignored in loss)
-            batch["labels"].append(input_ids + [-100] * padding_length)
+        # Create labels from input_ids
+        # Replace padding token id with -100 so it's ignored in loss
+        labels = batch["input_ids"].clone()
+        labels[labels == self.tokenizer.pad_token_id] = -100
+        batch["labels"] = labels
 
-        # Convert to tensors
-        import torch
-        return {
-            "input_ids": torch.tensor(batch["input_ids"], dtype=torch.long),
-            "attention_mask": torch.tensor(batch["attention_mask"], dtype=torch.long),
-            "labels": torch.tensor(batch["labels"], dtype=torch.long)
-        }
+        return batch
 
 
     data_collator = CustomDataCollator(tokenizer=tokenizer)
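
For context, below is a minimal, self-contained sketch of the collator pattern this commit switches to, assuming the Hugging Face transformers tokenizer API. The "gpt2" checkpoint, the pad-token fallback, and the demo texts are illustrative assumptions only and are not part of the commit.

# Minimal sketch of the new collator logic in isolation (illustrative only).
from dataclasses import dataclass
from typing import Any, Dict, List

from transformers import AutoTokenizer


@dataclass
class CustomDataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Let the tokenizer pad input_ids / attention_mask and build PyTorch tensors.
        batch = self.tokenizer.pad(features, padding=True, return_tensors="pt")
        # Labels mirror input_ids; padding positions become -100 so the
        # cross-entropy loss ignores them.
        labels = batch["input_ids"].clone()
        labels[labels == self.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch


if __name__ == "__main__":
    # Assumption: "gpt2" is only an example checkpoint for the demo.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    collator = CustomDataCollator(tokenizer=tokenizer)
    features = [tokenizer(t) for t in ["short text", "a somewhat longer example sentence"]]
    batch = collator(features)
    print(batch["input_ids"].shape, batch["labels"][0])

One caveat of this pattern: if the pad token is aliased to the EOS token (as in the GPT-2 demo above), genuine end-of-sequence positions are also masked to -100; a tokenizer with a dedicated pad token avoids that.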