File size: 2,187 Bytes
31086ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import pandas as pd

from graphgen.bases import BaseLLMWrapper, BaseOperator
from graphgen.common import init_llm
from graphgen.models import (
    AggregatedGenerator,
    AtomicGenerator,
    CoTGenerator,
    MultiHopGenerator,
    VQAGenerator,
)
from graphgen.utils import logger, run_concurrent


class GenerateService(BaseOperator):
    """
    Operator that produces question-answer pairs from nodes and edges.

    The generator implementation is picked once, at construction time, from
    the ``method`` name; ``data_format`` is handed to that generator when the
    raw results are formatted.
    """

    def __init__(
        self,
        working_dir: str = "cache",
        method: str = "aggregated",
        data_format: str = "ChatML",
    ):
        """
        Create the LLM client and the generator that matches ``method``.

        :param working_dir: forwarded to the base operator
        :param method: "atomic", "aggregated", "multi_hop", "cot" or "vqa"
        :param data_format: output format passed on when formatting results
        :raises ValueError: if ``method`` is none of the supported names
        """
        super().__init__(working_dir=working_dir, op_name="generate_service")
        self.llm_client: BaseLLMWrapper = init_llm("synthesizer")

        self.method = method
        self.data_format = data_format

        # Ordered (mode name, generator class) pairs. They are scanned with
        # ``==`` so matching behaves exactly like a chain of comparisons.
        known_generators = (
            ("atomic", AtomicGenerator),
            ("aggregated", AggregatedGenerator),
            ("multi_hop", MultiHopGenerator),
            ("cot", CoTGenerator),
            ("vqa", VQAGenerator),
        )
        for mode_name, generator_cls in known_generators:
            if self.method == mode_name:
                self.generator = generator_cls(self.llm_client)
                break
        else:
            # No entry matched: the requested mode is not supported.
            raise ValueError(f"Unsupported generation mode: {method}")

    def process(self, batch: pd.DataFrame) -> pd.DataFrame:
        """
        Run generation over a DataFrame batch.

        Each row is converted to a dict, passed through :meth:`generate`, and
        the returned records are wrapped in a new DataFrame.
        """
        records = batch.to_dict(orient="records")
        generated = self.generate(records)
        return pd.DataFrame(generated)

    def generate(self, items: list[dict]) -> list[dict]:
        """
        Generate question-answer pairs based on nodes and edges.

        :param items: records that each provide a "nodes" and an "edges" entry
        :return: QA pairs as formatted by the selected generator
        """
        logger.info("[Generation] mode: %s, batches: %d", self.method, len(items))

        # The generator is fed (nodes, edges) tuples, one per input record.
        node_edge_pairs = [(record["nodes"], record["edges"]) for record in items]

        raw_results = run_concurrent(
            self.generator.generate,
            node_edge_pairs,
            desc="[4/4]Generating QAs",
            unit="batch",
        )

        return self.generator.format_generation_results(
            raw_results, output_data_format=self.data_format
        )