from datetime import datetime
from tqdm import tqdm
from datasets import load_dataset
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor

from items import Item

CHUNK_SIZE = 1000
MIN_PRICE = 0.5
MAX_PRICE = 999.49
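
# MAX_PRICE is presumably 999.49 rather than a round 1000 so that any price which
# rounds to $999 or less survives the filter in from_datapoint below.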


class ItemLoader:

    def __init__(self, name):
        self.name = name
        self.dataset = None

    def from_datapoint(self, datapoint):
        """
        Try to create an Item from this datapoint.
        Return the Item if successful, or None if it shouldn't be included.
        """
        try:
            price_str = datapoint['price']
            if price_str:
                price = float(price_str)
                if MIN_PRICE <= price <= MAX_PRICE:
                    item = Item(datapoint, price)
                    return item if item.include else None
        except ValueError:
            pass
        # Explicitly return None when the price is missing, unparseable, or out of range
        return None

    def from_chunk(self, chunk):
        """
        Create a list of Items from this chunk of elements from the Dataset.
        """
        batch = []
        for datapoint in chunk:
            result = self.from_datapoint(datapoint)
            if result:
                batch.append(result)
        return batch

    def chunk_generator(self):
        """
        Iterate over the Dataset, yielding chunks of datapoints at a time.
        """
        size = len(self.dataset)
        for i in range(0, size, CHUNK_SIZE):
            yield self.dataset.select(range(i, min(i + CHUNK_SIZE, size)))

    def load_in_parallel(self, workers):
        """
        Use concurrent.futures to farm out the work to process chunks of datapoints -
        this speeds up processing significantly, but will tie up your computer while it's doing so!
        """
        results = []
        # Ceiling division, so tqdm's total matches the actual number of chunks
        chunk_count = (len(self.dataset) + CHUNK_SIZE - 1) // CHUNK_SIZE
        with ProcessPoolExecutor(max_workers=workers) as pool:
            for batch in tqdm(pool.map(self.from_chunk, self.chunk_generator()), total=chunk_count):
                results.extend(batch)
        for result in results:
            result.category = self.name
        return results

    def load(self, workers=8):
        """
        Load in this dataset; the workers parameter specifies how many processes
        should work on loading and scrubbing the data.
        """
        start = datetime.now()
        print(f"Loading dataset {self.name}", flush=True)
        self.dataset = load_dataset(
            "McAuley-Lab/Amazon-Reviews-2023",
            f"raw_meta_{self.name}",
            split="full",
            trust_remote_code=True,
        )
        results = self.load_in_parallel(workers)
        finish = datetime.now()
        print(f"Completed {self.name} with {len(results):,} datapoints in {(finish-start).total_seconds()/60:.1f} mins", flush=True)
        return results
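

# A minimal usage sketch, not part of the class above. It assumes the local items
# module providing the Item class is importable, and that "Appliances" is one of the
# raw_meta_* configurations of McAuley-Lab/Amazon-Reviews-2023 (loading a full
# category downloads a sizeable dataset).
# The __main__ guard matters: ProcessPoolExecutor re-imports this module in each
# worker on platforms that spawn processes, so top-level code must be guarded.
if __name__ == "__main__":
    loader = ItemLoader("Appliances")
    items = loader.load(workers=8)
    print(f"Sample category: {items[0].category}" if items else "No items loaded")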