import datasets
from torch.utils.data import DataLoader

from .batch_sampler import BatchSampler
from .collator import Collator
from .norm import normalize_dataset


def prepare_dataloaders(args, tokenizer, logger):
    """Prepare train, validation and test dataloaders."""
    # Load the dataset once and pull out each split, instead of calling
    # load_dataset three times.
    dataset_splits = datasets.load_dataset(args.dataset)
    train_dataset = dataset_splits['train']
    val_dataset = dataset_splits['validation']
    test_dataset = dataset_splits['test']

    # Per-sample amino-acid sequence lengths, used for token-based batching
    train_dataset_token_lengths = [len(item['aa_seq']) for item in train_dataset]
    val_dataset_token_lengths = [len(item['aa_seq']) for item in val_dataset]
    test_dataset_token_lengths = [len(item['aa_seq']) for item in test_dataset]

    if args.normalize is not None:
        train_dataset, val_dataset, test_dataset = normalize_dataset(
            train_dataset, val_dataset, test_dataset, args.normalize
        )
    # Log dataset statistics
    logger.info("Dataset Statistics:")
    logger.info("------------------------")
    logger.info(f"Dataset: {args.dataset}")
    logger.info(f"  Number of train samples: {len(train_dataset)}")
    logger.info(f"  Number of val samples: {len(val_dataset)}")
    logger.info(f"  Number of test samples: {len(test_dataset)}")

    # Log the first three training examples as a sanity check
    logger.info("Sample 3 data points from train dataset:")
    for i in range(3):
        logger.info(f"  Train data point {i + 1}: {train_dataset[i]}")
    logger.info("------------------------")
    collator = Collator(
        tokenizer=tokenizer,
        max_length=args.max_seq_len if args.max_seq_len > 0 else None,
        structure_seq=args.structure_seq,
        problem_type=args.problem_type,
        plm_model=args.plm_model,
        num_labels=args.num_labels,
    )
    # Common dataloader parameters. persistent_workers and prefetch_factor
    # are only valid with worker processes (recent PyTorch raises a
    # ValueError if prefetch_factor is set with num_workers=0), so gate
    # both on num_workers.
    dataloader_params = {
        'num_workers': args.num_workers,
        'collate_fn': collator,
        'pin_memory': True,
    }
    if args.num_workers > 0:
        dataloader_params['persistent_workers'] = True
        dataloader_params['prefetch_factor'] = 2
    # Create dataloaders based on batching strategy: a per-batch token
    # budget when batch_token is set, a fixed sample count otherwise.
    if args.batch_token is not None:
        train_loader = create_token_based_loader(
            train_dataset, train_dataset_token_lengths, args.batch_token, True, **dataloader_params
        )
        val_loader = create_token_based_loader(
            val_dataset, val_dataset_token_lengths, args.batch_token, False, **dataloader_params
        )
        test_loader = create_token_based_loader(
            test_dataset, test_dataset_token_lengths, args.batch_token, False, **dataloader_params
        )
    else:
        train_loader = create_size_based_loader(train_dataset, args.batch_size, True, **dataloader_params)
        val_loader = create_size_based_loader(val_dataset, args.batch_size, False, **dataloader_params)
        test_loader = create_size_based_loader(test_dataset, args.batch_size, False, **dataloader_params)

    return train_loader, val_loader, test_loader


def create_token_based_loader(dataset, token_lengths, batch_token, shuffle, **kwargs):
    """Create dataloader with token-based batching."""
    sampler = BatchSampler(token_lengths, batch_token, shuffle=shuffle)
    return DataLoader(dataset, batch_sampler=sampler, **kwargs)
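

# Illustrative sketch only (an assumption, not this repo's code): the real
# sampler lives in .batch_sampler and may differ, e.g. by sorting indices
# by length first. A token-budget sampler typically groups indices greedily
# so that the padded batch cost, (longest sequence in batch) * (batch size),
# stays under batch_token. The class below shows that technique with the
# same constructor signature used above.
import random  # only needed for this sketch


class _ExampleTokenBudgetSampler:
    def __init__(self, token_lengths, batch_token, shuffle=True):
        self.token_lengths = token_lengths
        self.batch_token = batch_token
        self.shuffle = shuffle

    def __iter__(self):
        order = list(range(len(self.token_lengths)))
        if self.shuffle:
            random.shuffle(order)
        batch, max_len = [], 0
        for idx in order:
            length = self.token_lengths[idx]
            # Adding this sequence re-pads the batch to max(max_len, length);
            # start a new batch if that would exceed the token budget.
            if batch and max(max_len, length) * (len(batch) + 1) > self.batch_token:
                yield batch
                batch, max_len = [], 0
            batch.append(idx)
            max_len = max(max_len, length)
        if batch:
            yield batch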


def create_size_based_loader(dataset, batch_size, shuffle, **kwargs):
    """Create dataloader with size-based batching."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs)
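

# Hedged usage sketch: shows the argument surface prepare_dataloaders
# expects. Every concrete value below (dataset id, tokenizer checkpoint,
# field settings) is a placeholder assumption, not something this module
# ships with.
if __name__ == "__main__":
    import logging
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    logging.basicConfig(level=logging.INFO)
    example_logger = logging.getLogger(__name__)

    example_args = SimpleNamespace(
        dataset="user/protein-task",              # placeholder HF dataset id
        normalize=None,                           # or a strategy understood by normalize_dataset
        max_seq_len=1024,                         # <= 0 disables truncation
        structure_seq=None,
        problem_type="regression",
        plm_model="facebook/esm2_t12_35M_UR50D",  # assumed protein LM checkpoint
        num_labels=1,
        num_workers=4,
        batch_token=16384,                        # per-batch token budget; None -> fixed batch_size
        batch_size=32,
    )
    example_tokenizer = AutoTokenizer.from_pretrained(example_args.plm_model)
    train_loader, val_loader, test_loader = prepare_dataloaders(
        example_args, example_tokenizer, example_logger
    )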