# Spaces: Running
import math
from collections import Counter
from dataclasses import dataclass, field
from typing import List, Union

from pydantic import BaseModel, Field, field_validator
@dataclass
class Chunk:
    """A keyed piece of content plus free-form metadata.

    Attributes:
        id: Identifier of the chunk.
        content: The chunk's content string.
        type: Content type label (``from_dict`` defaults it to ``"text"``).
        metadata: Extra key/value data; each instance gets its own fresh
            dict by default.
    """

    id: str
    content: str
    type: str
    metadata: dict = field(default_factory=dict)

    @staticmethod
    def from_dict(key: str, data: dict) -> "Chunk":
        """Build a ``Chunk`` from a raw mapping.

        Args:
            key: Value used as the chunk ``id``.
            data: Raw record. ``"content"`` and ``"type"`` are read with the
                defaults ``""`` and ``"text"`` respectively.

        Returns:
            A new ``Chunk``. Its ``metadata`` contains every entry of
            ``data`` except ``"content"``; note that ``"type"`` is kept
            in ``metadata`` as well as being copied to ``Chunk.type``.
        """
        return Chunk(
            id=key,
            content=data.get("content", ""),
            type=data.get("type", "text"),
            # Everything except the content body is retained as metadata.
            metadata={k: v for k, v in data.items() if k != "content"},
        )
@dataclass
class QAPair:
    """
    A pair of question and answer.

    Attributes:
        question: The question text.
        answer: The answer text.
    """

    question: str
    answer: str
@dataclass
class Token:
    """A text token together with its probability.

    Attributes:
        text: The token text.
        prob: Probability assigned to the token.
        top_candidates: Alternative candidates for this position; element
            type is not constrained here (presumably other ``Token``s —
            confirm against callers). Fresh empty list per instance.
        ppl: Optional perplexity value; ``None`` when not set.
    """

    text: str
    prob: float
    top_candidates: List = field(default_factory=list)
    ppl: Union[float, None] = field(default=None)

    def logprob(self) -> float:
        """Return the natural logarithm of ``prob``.

        Raises:
            ValueError: If ``prob`` is not strictly positive (``math.log``
                domain error).
        """
        # NOTE(review): if callers read ``token.logprob`` as an attribute,
        # this was likely a @property whose decorator was lost — confirm.
        return math.log(self.prob)
@dataclass
class Community:
    """A grouping of node ids and edges under a common identifier.

    Attributes:
        id: Community identifier (int or str).
        nodes: Node ids belonging to the community; fresh empty list by
            default.
        edges: Edges as tuples (presumably (source, target) pairs — confirm
            against callers); fresh empty list by default.
        metadata: Extra key/value data; fresh empty dict by default.
    """

    id: Union[int, str]
    nodes: List[str] = field(default_factory=list)
    edges: List[tuple] = field(default_factory=list)
    metadata: dict = field(default_factory=dict)
class Node(BaseModel):
    """A single operator node in a computation-graph configuration.

    ``type`` is validated against a fixed set of task types; any other value
    raises during model validation.
    """

    id: str = Field(..., description="unique node id")
    op_name: str = Field(..., description="operator name")
    type: str = Field(
        ..., description="task type, e.g., map, filter, flatmap, aggregate, map_batch"
    )
    params: dict = Field(default_factory=dict, description="operator parameters")
    dependencies: List[str] = Field(
        default_factory=list, description="list of dependent node ids"
    )
    execution_params: dict = Field(
        default_factory=dict, description="execution parameters like replicas, batch_size"
    )

    # Without @field_validator pydantic never invokes this method, so the
    # type check would be dead code.
    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        """Ensure ``type`` is one of the supported task types.

        Args:
            v: Candidate value for the ``type`` field.

        Returns:
            ``v`` unchanged when valid.

        Raises:
            ValueError: If ``v`` is not a supported task type.
        """
        valid_types = {"map", "filter", "flatmap", "aggregate", "map_batch"}
        if v not in valid_types:
            raise ValueError(f"Invalid node type: {v}. Must be one of {valid_types}.")
        return v
class Config(BaseModel):
    """Top-level computation-graph configuration.

    Holds global parameters plus a non-empty list of ``Node``s whose ids
    must be unique.
    """

    global_params: dict = Field(
        default_factory=dict, description="global context for the computation graph"
    )
    nodes: List[Node] = Field(
        ..., min_length=1, description="list of nodes in the computation graph"
    )

    # Without @field_validator pydantic never invokes this method, so
    # duplicate ids would be accepted silently.
    @field_validator("nodes")
    @classmethod
    def validate_unique_ids(cls, v: List[Node]) -> List[Node]:
        """Reject node lists in which any ``id`` appears more than once.

        Args:
            v: Candidate value for the ``nodes`` field.

        Returns:
            ``v`` unchanged when all ids are unique.

        Raises:
            ValueError: If duplicate ids are present; the message lists them.
        """
        # Single counting pass instead of list.count per element (O(n^2)).
        counts = Counter(node.id for node in v)
        duplicates = {id_ for id_, n in counts.items() if n > 1}
        if duplicates:
            raise ValueError(f"Duplicate node ids found: {duplicates}")
        return v