github-actions[bot] committed
Commit 4b2a9c2 · 1 Parent(s): 2c0627c

Auto-sync from demo at Wed Sep 10 08:57:26 UTC 2025

app.py CHANGED
@@ -1,4 +1,3 @@
-# pylint: skip-file
 import json
 import os
 import sys
@@ -6,6 +5,7 @@ import tempfile
 
 import gradio as gr
 import pandas as pd
+from dotenv import load_dotenv
 
 from webui.base import GraphGenParams
 from webui.cache_utils import cleanup_workspace, setup_workspace
@@ -19,10 +19,12 @@ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(root_dir)
 
 from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 
+load_dotenv()
+
 css = """
 .center-row {
     display: flex;
@@ -36,8 +38,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
 
-    set_logger(log_file, if_stream=False)
-    graph_gen = GraphGen(working_dir=working_dir)
+    set_logger(log_file, if_stream=True)
+    graph_gen = GraphGen(working_dir=working_dir, config=config)
 
     # Set up LLM clients
     graph_gen.synthesizer_llm_client = OpenAIModel(
@@ -60,19 +62,6 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
     graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
 
-    strategy_config = config.get("traverse_strategy", {})
-    graph_gen.traverse_strategy = TraverseStrategy(
-        qa_form=strategy_config.get("qa_form"),
-        expand_method=strategy_config.get("expand_method"),
-        bidirectional=strategy_config.get("bidirectional"),
-        max_extra_edges=strategy_config.get("max_extra_edges"),
-        max_tokens=strategy_config.get("max_tokens"),
-        max_depth=strategy_config.get("max_depth"),
-        edge_sampling=strategy_config.get("edge_sampling"),
-        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
-        loss_strategy=str(strategy_config.get("loss_strategy")),
-    )
-
     return graph_gen
 
 
@@ -84,10 +73,15 @@ def run_graphgen(params, progress=gr.Progress()):
     config = {
        "if_trainee_model": params.if_trainee_model,
        "input_file": params.input_file,
+       "output_data_type": params.output_data_type,
+       "output_data_format": params.output_data_format,
        "tokenizer": params.tokenizer,
-       "quiz_samples": params.quiz_samples,
+       "search": {"enabled": False},
+       "quiz_and_judge_strategy": {
+           "enabled": params.if_trainee_model,
+           "quiz_samples": params.quiz_samples,
+       },
        "traverse_strategy": {
-           "qa_form": params.qa_form,
            "bidirectional": params.bidirectional,
            "expand_method": params.expand_method,
            "max_extra_edges": params.max_extra_edges,
@@ -122,6 +116,35 @@ def run_graphgen(params, progress=gr.Progress()):
            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
        )
 
+    # Load input data
+    file = config["input_file"]
+    if isinstance(file, list):
+        file = file[0]
+
+    data = []
+
+    if file.endswith(".jsonl"):
+        config["input_data_type"] = "raw"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.loads(line) for line in f)
+    elif file.endswith(".json"):
+        config["input_data_type"] = "chunked"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.load(f))
+    elif file.endswith(".txt"):
+        # Read the file, then split it into raw-format chunks by chunk_size
+        config["input_data_type"] = "raw"
+        content = ""
+        with open(file, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                content += line.strip() + " "
+        size = int(config.get("chunk_size", 512))
+        chunks = [content[i : i + size] for i in range(0, len(content), size)]
+        data.extend([{"content": chunk} for chunk in chunks])
+    else:
+        raise ValueError(f"Unsupported file type: {file}")
+
     # Initialize GraphGen
     graph_gen = init_graph_gen(config, env)
     graph_gen.clear()
@@ -129,51 +152,20 @@ def run_graphgen(params, progress=gr.Progress()):
     graph_gen.progress_bar = progress
 
     try:
-        # Load input data
-        file = config["input_file"]
-        if isinstance(file, list):
-            file = file[0]
-
-        data = []
-
-        if file.endswith(".jsonl"):
-            data_type = "raw"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.loads(line) for line in f)
-        elif file.endswith(".json"):
-            data_type = "chunked"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.load(f))
-        elif file.endswith(".txt"):
-            # Read the file, then split it into raw-format chunks by chunk_size
-            data_type = "raw"
-            content = ""
-            with open(file, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-                for line in lines:
-                    content += line.strip() + " "
-            size = int(config.get("chunk_size", 512))
-            chunks = [content[i : i + size] for i in range(0, len(content), size)]
-            data.extend([{"content": chunk} for chunk in chunks])
-        else:
-            raise ValueError(f"Unsupported file type: {file}")
-
         # Process the data
-        graph_gen.insert(data, data_type)
+        graph_gen.insert()
 
         if config["if_trainee_model"]:
            # Generate quiz
-           graph_gen.quiz(max_samples=config["quiz_samples"])
+           graph_gen.quiz()
 
            # Judge statements
            graph_gen.judge()
        else:
            graph_gen.traverse_strategy.edge_sampling = "random"
-           # Skip judge statements
-           graph_gen.judge(skip=True)
 
        # Traverse graph
-       graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
+       graph_gen.traverse()
 
        # Save output
        output_data = graph_gen.qa_storage.data
@@ -328,12 +320,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                tokenizer = gr.Textbox(
                    label="Tokenizer", value="cl100k_base", interactive=True
                )
-               qa_form = gr.Radio(
+               output_data_type = gr.Radio(
                    choices=["atomic", "multi_hop", "aggregated"],
-                   label="QA Form",
+                   label="Output Data Type",
                    value="aggregated",
                    interactive=True,
                )
+               output_data_format = gr.Radio(
+                   choices=["Alpaca", "Sharegpt", "ChatML"],
+                   label="Output Data Format",
+                   value="Alpaca",
+                   interactive=True,
+               )
                quiz_samples = gr.Number(
                    label="Quiz Samples",
                    value=2,
@@ -533,33 +531,35 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                if_trainee_model=args[0],
                input_file=args[1],
                tokenizer=args[2],
-               qa_form=args[3],
-               bidirectional=args[4],
-               expand_method=args[5],
-               max_extra_edges=args[6],
-               max_tokens=args[7],
-               max_depth=args[8],
-               edge_sampling=args[9],
-               isolated_node_strategy=args[10],
-               loss_strategy=args[11],
-               synthesizer_url=args[12],
-               synthesizer_model=args[13],
-               trainee_model=args[14],
-               api_key=args[15],
-               chunk_size=args[16],
-               rpm=args[17],
-               tpm=args[18],
-               quiz_samples=args[19],
-               trainee_url=args[20],
-               trainee_api_key=args[21],
-               token_counter=args[22],
+               output_data_type=args[3],
+               output_data_format=args[4],
+               bidirectional=args[5],
+               expand_method=args[6],
+               max_extra_edges=args[7],
+               max_tokens=args[8],
+               max_depth=args[9],
+               edge_sampling=args[10],
+               isolated_node_strategy=args[11],
+               loss_strategy=args[12],
+               synthesizer_url=args[13],
+               synthesizer_model=args[14],
+               trainee_model=args[15],
+               api_key=args[16],
+               chunk_size=args[17],
+               rpm=args[18],
+               tpm=args[19],
+               quiz_samples=args[20],
+               trainee_url=args[21],
+               trainee_api_key=args[22],
+               token_counter=args[23],
            )
        ),
        inputs=[
            if_trainee_model,
            upload_file,
            tokenizer,
-           qa_form,
+           output_data_type,
+           output_data_format,
            bidirectional,
            expand_method,
            max_extra_edges,
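The input loader that moved out of the try block above chunks plain-text files into fixed-size character windows before insertion. A minimal standalone sketch of that .txt branch, assuming only what the diff shows (chunk_text is an illustrative name, not part of the codebase):

def chunk_text(path: str, chunk_size: int = 512) -> list:
    # Mirror the .txt branch: strip each line, join with spaces, then cut the
    # result into chunk_size-character windows wrapped as {"content": ...}.
    content = ""
    with open(path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            content += line.strip() + " "
    return [
        {"content": content[i : i + chunk_size]}
        for i in range(0, len(content), chunk_size)
    ]

# e.g. a 1300-character file yields three chunks of 512, 512, and 276 characters.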
graphgen/configs/multi_hop_config.yaml CHANGED
@@ -7,7 +7,7 @@ search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
 quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-  enabled: true
+  enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
 traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
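For context, this flag gates the quiz-and-judge stage, and the webui above now mirrors it by setting enabled to if_trainee_model. A hedged sketch of how such a gate can be read from this config file (run_quiz and run_judge are hypothetical stand-ins, not GraphGen's actual API):

import yaml

def run_quiz(max_samples: int) -> None:  # hypothetical stand-in
    print(f"quiz: {max_samples} samples")

def run_judge(re_judge: bool) -> None:  # hypothetical stand-in
    print(f"judge: re_judge={re_judge}")

with open("graphgen/configs/multi_hop_config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

qj = cfg.get("quiz_and_judge_strategy", {})
if qj.get("enabled", False):  # now false for the multi_hop config
    run_quiz(max_samples=qj.get("quiz_samples", 2))
    run_judge(re_judge=qj.get("re_judge", False))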
graphgen/graphgen.py CHANGED
@@ -23,8 +23,8 @@ from .operators import (
     judge_statement,
     quiz,
     search_all,
-    traverse_graph_atomically,
-    traverse_graph_by_edge,
+    traverse_graph_for_aggregated,
+    traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
 from .utils import (
@@ -69,6 +69,7 @@ class GraphGen:
         self.tokenizer_instance: Tokenizer = Tokenizer(
             model_name=self.config["tokenizer"]
         )
+        print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY"))
         self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
             model_name=os.getenv("SYNTHESIZER_MODEL"),
             api_key=os.getenv("SYNTHESIZER_API_KEY"),
@@ -326,7 +327,7 @@
         output_data_type = self.config["output_data_type"]
 
         if output_data_type == "atomic":
-            results = await traverse_graph_atomically(
+            results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
@@ -344,7 +345,7 @@
                 self.progress_bar,
             )
         elif output_data_type == "aggregated":
-            results = await traverse_graph_by_edge(
+            results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
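After the renames, traverse() still selects one of three coroutines by output_data_type. A table-driven sketch of that dispatch under stated assumptions (the stand-in coroutines below replace the real operators, whose full argument lists appear in the calls above):

import asyncio

async def traverse_graph_for_atomic(**kwargs):      # stand-in for the real operator
    return "atomic results"

async def traverse_graph_for_aggregated(**kwargs):  # stand-in for the real operator
    return "aggregated results"

async def traverse_graph_for_multi_hop(**kwargs):   # stand-in for the real operator
    return "multi-hop results"

# Equivalent of the if/elif chain in GraphGen.traverse, keyed on the
# "output_data_type" config value used above.
TRAVERSERS = {
    "atomic": traverse_graph_for_atomic,
    "aggregated": traverse_graph_for_aggregated,
    "multi_hop": traverse_graph_for_multi_hop,
}

async def traverse(config: dict):
    try:
        op = TRAVERSERS[config["output_data_type"]]
    except KeyError as exc:
        raise ValueError(f"Unknown output_data_type: {exc}") from exc
    return await op()

print(asyncio.run(traverse({"output_data_type": "aggregated"})))  # aggregated results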
graphgen/operators/__init__.py CHANGED
@@ -5,8 +5,8 @@ from graphgen.operators.search.search_all import search_all
 from .judge import judge_statement
 from .quiz import quiz
 from .traverse_graph import (
-    traverse_graph_atomically,
-    traverse_graph_by_edge,
+    traverse_graph_for_aggregated,
+    traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
 
@@ -15,8 +15,8 @@
     "quiz",
     "judge_statement",
     "search_all",
-    "traverse_graph_by_edge",
-    "traverse_graph_atomically",
+    "traverse_graph_for_aggregated",
+    "traverse_graph_for_atomic",
     "traverse_graph_for_multi_hop",
     "generate_cot",
 ]
graphgen/operators/traverse_graph.py CHANGED
@@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
         ) / (len(batch[0]) + len(batch[1]))
         raise ValueError("Invalid loss strategy")
     except Exception as e:  # pylint: disable=broad-except
-        logger.error("Error calculating average loss: %s", e)
+        logger.warning(
+            "Loss not found in some nodes or edges, setting loss to -1.0: %s", e
+        )
         return -1.0
 
 
@@ -158,7 +160,7 @@ def _post_process_synthetic_data(data):
     return qas
 
 
-async def traverse_graph_by_edge(
+async def traverse_graph_for_aggregated(
     llm_client: OpenAIModel,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
@@ -251,7 +253,6 @@ async def traverse_graph_by_edge(
         qas = _post_process_synthetic_data(content)
 
         if len(qas) == 0:
-            print(content)
             logger.error(
                 "Error occurred while processing batch, question or answer is None"
             )
@@ -307,7 +308,8 @@ async def traverse_graph_by_edge(
     return results
 
 
-async def traverse_graph_atomically(
+# pylint: disable=too-many-branches, too-many-statements
+async def traverse_graph_for_atomic(
     llm_client: OpenAIModel,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
@@ -328,17 +330,28 @@ async def traverse_graph_atomically(
     :param max_concurrent
     :return: question and answer
     """
-    assert traverse_strategy.qa_form == "atomic"
 
+    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)
 
+    def _parse_qa(qa: str) -> tuple:
+        if "Question:" in qa and "Answer:" in qa:
+            question = qa.split("Question:")[1].split("Answer:")[0].strip()
+            answer = qa.split("Answer:")[1].strip()
+        elif "问题:" in qa and "答案:" in qa:
+            question = qa.split("问题:")[1].split("答案:")[0].strip()
+            answer = qa.split("答案:")[1].strip()
+        else:
+            return None, None
+        return question.strip('"'), answer.strip('"')
+
     async def _generate_question(node_or_edge: tuple):
         if len(node_or_edge) == 2:
             des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
-            loss = node_or_edge[1]["loss"]
+            loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
         else:
             des = node_or_edge[2]["description"]
-            loss = node_or_edge[2]["loss"]
+            loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
 
         async with semaphore:
             try:
@@ -350,13 +363,8 @@ async def traverse_graph_atomically(
                     )
                 )
 
-                if "Question:" in qa and "Answer:" in qa:
-                    question = qa.split("Question:")[1].split("Answer:")[0].strip()
-                    answer = qa.split("Answer:")[1].strip()
-                elif "问题:" in qa and "答案:" in qa:
-                    question = qa.split("问题:")[1].split("答案:")[0].strip()
-                    answer = qa.split("答案:")[1].strip()
-                else:
+                question, answer = _parse_qa(qa)
+                if question is None or answer is None:
                     return {}
 
                 question = question.strip('"')
@@ -386,16 +394,18 @@ async def traverse_graph_atomically(
         if "<SEP>" in node[1]["description"]:
             description_list = node[1]["description"].split("<SEP>")
             for item in description_list:
-                tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
+                tasks.append((node[0], {"description": item}))
+                if "loss" in node[1]:
+                    tasks[-1][1]["loss"] = node[1]["loss"]
         else:
             tasks.append((node[0], node[1]))
     for edge in edges:
         if "<SEP>" in edge[2]["description"]:
            description_list = edge[2]["description"].split("<SEP>")
            for item in description_list:
-               tasks.append(
-                   (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
-               )
+               tasks.append((edge[0], edge[1], {"description": item}))
+               if "loss" in edge[2]:
+                   tasks[-1][2]["loss"] = edge[2]["loss"]
        else:
            tasks.append((edge[0], edge[1], edge[2]))
 
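The extracted _parse_qa helper recognizes both English and Chinese Question/Answer markers and reports failure as (None, None) instead of falling through. A self-contained copy of the logic with made-up sample inputs:

def _parse_qa(qa: str) -> tuple:
    # Same logic as the helper added above: split on English or Chinese
    # Question/Answer markers, or report failure as (None, None).
    if "Question:" in qa and "Answer:" in qa:
        question = qa.split("Question:")[1].split("Answer:")[0].strip()
        answer = qa.split("Answer:")[1].strip()
    elif "问题:" in qa and "答案:" in qa:
        question = qa.split("问题:")[1].split("答案:")[0].strip()
        answer = qa.split("答案:")[1].strip()
    else:
        return None, None
    return question.strip('"'), answer.strip('"')

print(_parse_qa("Question: What does <SEP> mark? Answer: Alternative descriptions."))
# -> ('What does <SEP> mark?', 'Alternative descriptions.')
print(_parse_qa("no markers here"))  # -> (None, None)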
 
webui/app.py CHANGED
@@ -1,4 +1,3 @@
-# pylint: skip-file
 import json
 import os
 import sys
@@ -6,6 +5,7 @@ import tempfile
 
 import gradio as gr
 import pandas as pd
+from dotenv import load_dotenv
 
 from webui.base import GraphGenParams
 from webui.cache_utils import cleanup_workspace, setup_workspace
@@ -19,10 +19,12 @@ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(root_dir)
 
 from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 
+load_dotenv()
+
 css = """
 .center-row {
     display: flex;
@@ -36,8 +38,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
 
-    set_logger(log_file, if_stream=False)
-    graph_gen = GraphGen(working_dir=working_dir)
+    set_logger(log_file, if_stream=True)
+    graph_gen = GraphGen(working_dir=working_dir, config=config)
 
     # Set up LLM clients
     graph_gen.synthesizer_llm_client = OpenAIModel(
@@ -60,19 +62,6 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
     graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
 
-    strategy_config = config.get("traverse_strategy", {})
-    graph_gen.traverse_strategy = TraverseStrategy(
-        qa_form=strategy_config.get("qa_form"),
-        expand_method=strategy_config.get("expand_method"),
-        bidirectional=strategy_config.get("bidirectional"),
-        max_extra_edges=strategy_config.get("max_extra_edges"),
-        max_tokens=strategy_config.get("max_tokens"),
-        max_depth=strategy_config.get("max_depth"),
-        edge_sampling=strategy_config.get("edge_sampling"),
-        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
-        loss_strategy=str(strategy_config.get("loss_strategy")),
-    )
-
     return graph_gen
 
 
@@ -84,10 +73,15 @@ def run_graphgen(params, progress=gr.Progress()):
     config = {
        "if_trainee_model": params.if_trainee_model,
        "input_file": params.input_file,
+       "output_data_type": params.output_data_type,
+       "output_data_format": params.output_data_format,
        "tokenizer": params.tokenizer,
-       "quiz_samples": params.quiz_samples,
+       "search": {"enabled": False},
+       "quiz_and_judge_strategy": {
+           "enabled": params.if_trainee_model,
+           "quiz_samples": params.quiz_samples,
+       },
        "traverse_strategy": {
-           "qa_form": params.qa_form,
            "bidirectional": params.bidirectional,
            "expand_method": params.expand_method,
            "max_extra_edges": params.max_extra_edges,
@@ -122,6 +116,35 @@ def run_graphgen(params, progress=gr.Progress()):
            env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
        )
 
+    # Load input data
+    file = config["input_file"]
+    if isinstance(file, list):
+        file = file[0]
+
+    data = []
+
+    if file.endswith(".jsonl"):
+        config["input_data_type"] = "raw"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.loads(line) for line in f)
+    elif file.endswith(".json"):
+        config["input_data_type"] = "chunked"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.load(f))
+    elif file.endswith(".txt"):
+        # Read the file, then split it into raw-format chunks by chunk_size
+        config["input_data_type"] = "raw"
+        content = ""
+        with open(file, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                content += line.strip() + " "
+        size = int(config.get("chunk_size", 512))
+        chunks = [content[i : i + size] for i in range(0, len(content), size)]
+        data.extend([{"content": chunk} for chunk in chunks])
+    else:
+        raise ValueError(f"Unsupported file type: {file}")
+
     # Initialize GraphGen
     graph_gen = init_graph_gen(config, env)
     graph_gen.clear()
@@ -129,51 +152,20 @@ def run_graphgen(params, progress=gr.Progress()):
     graph_gen.progress_bar = progress
 
     try:
-        # Load input data
-        file = config["input_file"]
-        if isinstance(file, list):
-            file = file[0]
-
-        data = []
-
-        if file.endswith(".jsonl"):
-            data_type = "raw"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.loads(line) for line in f)
-        elif file.endswith(".json"):
-            data_type = "chunked"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.load(f))
-        elif file.endswith(".txt"):
-            # Read the file, then split it into raw-format chunks by chunk_size
-            data_type = "raw"
-            content = ""
-            with open(file, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-                for line in lines:
-                    content += line.strip() + " "
-            size = int(config.get("chunk_size", 512))
-            chunks = [content[i : i + size] for i in range(0, len(content), size)]
-            data.extend([{"content": chunk} for chunk in chunks])
-        else:
-            raise ValueError(f"Unsupported file type: {file}")
-
         # Process the data
-        graph_gen.insert(data, data_type)
+        graph_gen.insert()
 
        if config["if_trainee_model"]:
            # Generate quiz
-           graph_gen.quiz(max_samples=config["quiz_samples"])
+           graph_gen.quiz()
 
            # Judge statements
            graph_gen.judge()
        else:
            graph_gen.traverse_strategy.edge_sampling = "random"
-           # Skip judge statements
-           graph_gen.judge(skip=True)
 
        # Traverse graph
-       graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy)
+       graph_gen.traverse()
 
        # Save output
        output_data = graph_gen.qa_storage.data
@@ -328,12 +320,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                tokenizer = gr.Textbox(
                    label="Tokenizer", value="cl100k_base", interactive=True
                )
-               qa_form = gr.Radio(
+               output_data_type = gr.Radio(
                    choices=["atomic", "multi_hop", "aggregated"],
-                   label="QA Form",
+                   label="Output Data Type",
                    value="aggregated",
                    interactive=True,
                )
+               output_data_format = gr.Radio(
+                   choices=["Alpaca", "Sharegpt", "ChatML"],
+                   label="Output Data Format",
+                   value="Alpaca",
+                   interactive=True,
+               )
                quiz_samples = gr.Number(
                    label="Quiz Samples",
                    value=2,
@@ -533,33 +531,35 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                if_trainee_model=args[0],
                input_file=args[1],
                tokenizer=args[2],
-               qa_form=args[3],
-               bidirectional=args[4],
-               expand_method=args[5],
-               max_extra_edges=args[6],
-               max_tokens=args[7],
-               max_depth=args[8],
-               edge_sampling=args[9],
-               isolated_node_strategy=args[10],
-               loss_strategy=args[11],
-               synthesizer_url=args[12],
-               synthesizer_model=args[13],
-               trainee_model=args[14],
-               api_key=args[15],
-               chunk_size=args[16],
-               rpm=args[17],
-               tpm=args[18],
-               quiz_samples=args[19],
-               trainee_url=args[20],
-               trainee_api_key=args[21],
-               token_counter=args[22],
+               output_data_type=args[3],
+               output_data_format=args[4],
+               bidirectional=args[5],
+               expand_method=args[6],
+               max_extra_edges=args[7],
+               max_tokens=args[8],
+               max_depth=args[9],
+               edge_sampling=args[10],
+               isolated_node_strategy=args[11],
+               loss_strategy=args[12],
+               synthesizer_url=args[13],
+               synthesizer_model=args[14],
+               trainee_model=args[15],
+               api_key=args[16],
+               chunk_size=args[17],
+               rpm=args[18],
+               tpm=args[19],
+               quiz_samples=args[20],
+               trainee_url=args[21],
+               trainee_api_key=args[22],
+               token_counter=args[23],
            )
        ),
        inputs=[
            if_trainee_model,
            upload_file,
            tokenizer,
-           qa_form,
+           output_data_type,
+           output_data_format,
            bidirectional,
            expand_method,
            max_extra_edges,
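Note the positional contract in the lambda above: inserting output_data_type and output_data_format at indices 3 and 4 shifts every later args[i] by two, and the inputs list must stay in exactly the same order. A hedged sketch of a less brittle mapping (the dataclass is abbreviated to its first five fields for illustration; the full field list lives in webui/base.py):

import dataclasses

@dataclasses.dataclass
class GraphGenParams:  # abbreviated; see webui/base.py for the full dataclass
    if_trainee_model: bool
    input_file: str
    tokenizer: str
    output_data_type: str
    output_data_format: str

# Zip positional args onto field names instead of hand-numbering args[0..23],
# so adding a Gradio input only requires inserting it once, in order.
FIELDS = [f.name for f in dataclasses.fields(GraphGenParams)]

def to_params(*args) -> GraphGenParams:
    return GraphGenParams(**dict(zip(FIELDS, args)))

params = to_params(True, "corpus.jsonl", "cl100k_base", "aggregated", "Alpaca")
print(params.output_data_format)  # Alpaca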
webui/base.py CHANGED
@@ -1,15 +1,18 @@
 from dataclasses import dataclass
 from typing import Any
 
+
 @dataclass
 class GraphGenParams:
     """
     GraphGen parameters
     """
+
     if_trainee_model: bool
     input_file: str
     tokenizer: str
-    qa_form: str
+    output_data_type: str
+    output_data_format: str
     bidirectional: bool
     expand_method: str
     max_extra_edges: int
webui/i18n.py CHANGED
@@ -1,3 +1,4 @@
+# pylint: skip-file
 import functools
 import inspect
 import json