github-actions[bot] committed · Commit 4b2a9c2 · Parent(s): 2c0627c

Auto-sync from demo at Wed Sep 10 08:57:26 UTC 2025
Files changed:
- app.py +76 -76
- graphgen/configs/multi_hop_config.yaml +1 -1
- graphgen/graphgen.py +5 -4
- graphgen/operators/__init__.py +4 -4
- graphgen/operators/traverse_graph.py +28 -18
- webui/app.py +76 -76
- webui/base.py +4 -1
- webui/i18n.py +1 -0
app.py CHANGED

@@ -1,4 +1,3 @@
-# pylint: skip-file
 import json
 import os
 import sys
@@ -6,6 +5,7 @@ import tempfile
 
 import gradio as gr
 import pandas as pd
+from dotenv import load_dotenv
 
 from webui.base import GraphGenParams
 from webui.cache_utils import cleanup_workspace, setup_workspace
@@ -19,10 +19,12 @@ root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 sys.path.append(root_dir)
 
 from graphgen.graphgen import GraphGen
-from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy
+from graphgen.models import OpenAIModel, Tokenizer
 from graphgen.models.llm.limitter import RPM, TPM
 from graphgen.utils import set_logger
 
+load_dotenv()
+
 css = """
 .center-row {
     display: flex;
@@ -36,8 +38,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
     # Set up working directory
     log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache"))
 
-    set_logger(log_file, if_stream=False)
-    graph_gen = GraphGen(working_dir=working_dir)
+    set_logger(log_file, if_stream=True)
+    graph_gen = GraphGen(working_dir=working_dir, config=config)
 
     # Set up LLM clients
     graph_gen.synthesizer_llm_client = OpenAIModel(
@@ -60,19 +62,6 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen:
 
     graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base"))
 
-    strategy_config = config.get("traverse_strategy", {})
-    graph_gen.traverse_strategy = TraverseStrategy(
-        qa_form=strategy_config.get("qa_form"),
-        expand_method=strategy_config.get("expand_method"),
-        bidirectional=strategy_config.get("bidirectional"),
-        max_extra_edges=strategy_config.get("max_extra_edges"),
-        max_tokens=strategy_config.get("max_tokens"),
-        max_depth=strategy_config.get("max_depth"),
-        edge_sampling=strategy_config.get("edge_sampling"),
-        isolated_node_strategy=strategy_config.get("isolated_node_strategy"),
-        loss_strategy=str(strategy_config.get("loss_strategy")),
-    )
-
     return graph_gen
 
 
@@ -84,10 +73,15 @@ def run_graphgen(params, progress=gr.Progress()):
     config = {
         "if_trainee_model": params.if_trainee_model,
         "input_file": params.input_file,
+        "output_data_type": params.output_data_type,
+        "output_data_format": params.output_data_format,
         "tokenizer": params.tokenizer,
-        "quiz_samples": params.quiz_samples,
+        "search": {"enabled": False},
+        "quiz_and_judge_strategy": {
+            "enabled": params.if_trainee_model,
+            "quiz_samples": params.quiz_samples,
+        },
         "traverse_strategy": {
-            "qa_form": params.qa_form,
             "bidirectional": params.bidirectional,
             "expand_method": params.expand_method,
             "max_extra_edges": params.max_extra_edges,
@@ -122,6 +116,35 @@ def run_graphgen(params, progress=gr.Progress()):
         env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]
     )
 
+    # Load input data
+    file = config["input_file"]
+    if isinstance(file, list):
+        file = file[0]
+
+    data = []
+
+    if file.endswith(".jsonl"):
+        config["input_data_type"] = "raw"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.loads(line) for line in f)
+    elif file.endswith(".json"):
+        config["input_data_type"] = "chunked"
+        with open(file, "r", encoding="utf-8") as f:
+            data.extend(json.load(f))
+    elif file.endswith(".txt"):
+        # Read the file, then convert it to raw-format data according to chunk_size
+        config["input_data_type"] = "raw"
+        content = ""
+        with open(file, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+        for line in lines:
+            content += line.strip() + " "
+        size = int(config.get("chunk_size", 512))
+        chunks = [content[i : i + size] for i in range(0, len(content), size)]
+        data.extend([{"content": chunk} for chunk in chunks])
+    else:
+        raise ValueError(f"Unsupported file type: {file}")
+
     # Initialize GraphGen
     graph_gen = init_graph_gen(config, env)
     graph_gen.clear()
@@ -129,51 +152,20 @@ def run_graphgen(params, progress=gr.Progress()):
     graph_gen.progress_bar = progress
 
     try:
-        # Load input data
-        file = config["input_file"]
-        if isinstance(file, list):
-            file = file[0]
-
-        data = []
-
-        if file.endswith(".jsonl"):
-            data_type = "raw"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.loads(line) for line in f)
-        elif file.endswith(".json"):
-            data_type = "chunked"
-            with open(file, "r", encoding="utf-8") as f:
-                data.extend(json.load(f))
-        elif file.endswith(".txt"):
-            # Read the file, then convert it to raw-format data according to chunk_size
-            data_type = "raw"
-            content = ""
-            with open(file, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-            for line in lines:
-                content += line.strip() + " "
-            size = int(config.get("chunk_size", 512))
-            chunks = [content[i : i + size] for i in range(0, len(content), size)]
-            data.extend([{"content": chunk} for chunk in chunks])
-        else:
-            raise ValueError(f"Unsupported file type: {file}")
-
         # Process the data
-        graph_gen.insert(
+        graph_gen.insert()
 
         if config["if_trainee_model"]:
             # Generate quiz
-            graph_gen.quiz(
+            graph_gen.quiz()
 
             # Judge statements
             graph_gen.judge()
         else:
             graph_gen.traverse_strategy.edge_sampling = "random"
-            # Skip judge statements
-            graph_gen.judge(skip=True)
 
         # Traverse graph
-        graph_gen.traverse(
+        graph_gen.traverse()
 
         # Save output
         output_data = graph_gen.qa_storage.data
@@ -328,12 +320,18 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 tokenizer = gr.Textbox(
                     label="Tokenizer", value="cl100k_base", interactive=True
                 )
-                qa_form = gr.Radio(
+                output_data_type = gr.Radio(
                     choices=["atomic", "multi_hop", "aggregated"],
-                    label="
+                    label="Output Data Type",
                     value="aggregated",
                     interactive=True,
                 )
+                output_data_format = gr.Radio(
+                    choices=["Alpaca", "Sharegpt", "ChatML"],
+                    label="Output Data Format",
+                    value="Alpaca",
+                    interactive=True,
+                )
                 quiz_samples = gr.Number(
                     label="Quiz Samples",
                     value=2,
@@ -533,33 +531,35 @@ with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo:
                 if_trainee_model=args[0],
                 input_file=args[1],
                 tokenizer=args[2],
-                qa_form=args[3],
-                bidirectional=args[4],
-                expand_method=args[5],
-                max_extra_edges=args[6],
-                max_tokens=args[7],
-                max_depth=args[8],
-                edge_sampling=args[9],
-                isolated_node_strategy=args[10],
-                loss_strategy=args[11],
-                synthesizer_url=args[12],
-                synthesizer_model=args[13],
-                trainee_model=args[14],
-                api_key=args[15],
-                chunk_size=args[16],
-                rpm=args[17],
-                tpm=args[18],
-                quiz_samples=args[19],
-                trainee_url=args[20],
-                trainee_api_key=args[21],
-                token_counter=args[22],
+                output_data_type=args[3],
+                output_data_format=args[4],
+                bidirectional=args[5],
+                expand_method=args[6],
+                max_extra_edges=args[7],
+                max_tokens=args[8],
+                max_depth=args[9],
+                edge_sampling=args[10],
+                isolated_node_strategy=args[11],
+                loss_strategy=args[12],
+                synthesizer_url=args[13],
+                synthesizer_model=args[14],
+                trainee_model=args[15],
+                api_key=args[16],
+                chunk_size=args[17],
+                rpm=args[18],
+                tpm=args[19],
+                quiz_samples=args[20],
+                trainee_url=args[21],
+                trainee_api_key=args[22],
+                token_counter=args[23],
             )
         ),
        inputs=[
            if_trainee_model,
            upload_file,
            tokenizer,
-           qa_form,
+           output_data_type,
+           output_data_format,
            bidirectional,
            expand_method,
            max_extra_edges,
graphgen/configs/multi_hop_config.yaml CHANGED

@@ -7,7 +7,7 @@ search: # web search configuration
   enabled: false # whether to enable web search
   search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia
 quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points
-  enabled:
+  enabled: false
   quiz_samples: 2 # number of quiz samples to generate
   re_judge: false # whether to re-judge the existing quiz samples
 traverse_strategy: # strategy for clustering sub-graphs using comprehension loss
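This pairs with the webui change above, which now writes a `quiz_and_judge_strategy` block into the runtime config. The old bare `enabled:` parses as YAML null, which is falsy but easy to misread; the commit pins it to an explicit `false`. A sketch of reading the flag (assumes PyYAML; how GraphGen itself consumes the key is not shown in this diff):

```python
# Sketch: reading the repaired flag; assumes PyYAML is installed (pip install pyyaml).
import yaml

with open("graphgen/configs/multi_hop_config.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

strategy = cfg["quiz_and_judge_strategy"]
if strategy["enabled"]:
    print("quiz/judge enabled, samples:", strategy["quiz_samples"])
else:
    print("quiz/judge skipped")  # the default after this commit
```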
graphgen/graphgen.py CHANGED

@@ -23,8 +23,8 @@ from .operators import (
     judge_statement,
     quiz,
     search_all,
-    traverse_graph_atomically,
-    traverse_graph_by_edge,
+    traverse_graph_for_aggregated,
+    traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
 from .utils import (
@@ -69,6 +69,7 @@ class GraphGen:
         self.tokenizer_instance: Tokenizer = Tokenizer(
             model_name=self.config["tokenizer"]
         )
+        print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY"))
         self.synthesizer_llm_client: OpenAIModel = OpenAIModel(
             model_name=os.getenv("SYNTHESIZER_MODEL"),
             api_key=os.getenv("SYNTHESIZER_API_KEY"),
@@ -326,7 +327,7 @@ class GraphGen:
         output_data_type = self.config["output_data_type"]
 
         if output_data_type == "atomic":
-            results = await traverse_graph_atomically(
+            results = await traverse_graph_for_atomic(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
@@ -344,7 +345,7 @@ class GraphGen:
                 self.progress_bar,
             )
         elif output_data_type == "aggregated":
-            results = await traverse_graph_by_edge(
+            results = await traverse_graph_for_aggregated(
                 self.synthesizer_llm_client,
                 self.tokenizer_instance,
                 self.graph_storage,
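The renamed entry points now mirror the `output_data_type` values. For illustration only, the if/elif chain in `GraphGen.traverse` is equivalent to a table dispatch like the sketch below (the stub bodies and signatures are placeholders, not GraphGen's real ones):

```python
# Illustrative dict dispatch equivalent to the if/elif chain in GraphGen.traverse.
# The traverse_graph_for_* names are from this commit; the *args stubs are placeholders.
import asyncio

async def traverse_graph_for_atomic(*args):
    return ["atomic results"]

async def traverse_graph_for_aggregated(*args):
    return ["aggregated results"]

async def traverse_graph_for_multi_hop(*args):
    return ["multi_hop results"]

TRAVERSERS = {
    "atomic": traverse_graph_for_atomic,
    "aggregated": traverse_graph_for_aggregated,
    "multi_hop": traverse_graph_for_multi_hop,
}

async def traverse(output_data_type: str, *args):
    if output_data_type not in TRAVERSERS:
        raise ValueError(f"Unknown output_data_type: {output_data_type}")
    return await TRAVERSERS[output_data_type](*args)

print(asyncio.run(traverse("aggregated")))  # ['aggregated results']
```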
graphgen/operators/__init__.py CHANGED

@@ -5,8 +5,8 @@ from graphgen.operators.search.search_all import search_all
 from .judge import judge_statement
 from .quiz import quiz
 from .traverse_graph import (
-    traverse_graph_atomically,
-    traverse_graph_by_edge,
+    traverse_graph_for_aggregated,
+    traverse_graph_for_atomic,
     traverse_graph_for_multi_hop,
 )
 
@@ -15,8 +15,8 @@ __all__ = [
     "quiz",
     "judge_statement",
     "search_all",
-    "traverse_graph_atomically",
-    "traverse_graph_by_edge",
+    "traverse_graph_for_aggregated",
+    "traverse_graph_for_atomic",
     "traverse_graph_for_multi_hop",
     "generate_cot",
 ]
graphgen/operators/traverse_graph.py CHANGED

@@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float:
         ) / (len(batch[0]) + len(batch[1]))
         raise ValueError("Invalid loss strategy")
     except Exception as e:  # pylint: disable=broad-except
-        logger.
+        logger.warning(
+            "Loss not found in some nodes or edges, setting loss to -1.0: %s", e
+        )
         return -1.0
 
 
@@ -158,7 +160,7 @@ def _post_process_synthetic_data(data):
     return qas
 
 
-async def traverse_graph_by_edge(
+async def traverse_graph_for_aggregated(
     llm_client: OpenAIModel,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
@@ -251,7 +253,6 @@ async def traverse_graph_by_edge(
         qas = _post_process_synthetic_data(content)
 
         if len(qas) == 0:
-            print(content)
             logger.error(
                 "Error occurred while processing batch, question or answer is None"
             )
@@ -307,7 +308,8 @@ async def traverse_graph_by_edge(
     return results
 
 
-async def traverse_graph_atomically(
+# pylint: disable=too-many-branches, too-many-statements
+async def traverse_graph_for_atomic(
     llm_client: OpenAIModel,
     tokenizer: Tokenizer,
     graph_storage: NetworkXStorage,
@@ -328,17 +330,28 @@ async def traverse_graph_atomically(
     :param max_concurrent
     :return: question and answer
     """
-    assert traverse_strategy.qa_form == "atomic"
 
+    assert traverse_strategy.qa_form == "atomic"
     semaphore = asyncio.Semaphore(max_concurrent)
 
+    def _parse_qa(qa: str) -> tuple:
+        if "Question:" in qa and "Answer:" in qa:
+            question = qa.split("Question:")[1].split("Answer:")[0].strip()
+            answer = qa.split("Answer:")[1].strip()
+        elif "问题:" in qa and "答案:" in qa:
+            question = qa.split("问题:")[1].split("答案:")[0].strip()
+            answer = qa.split("答案:")[1].strip()
+        else:
+            return None, None
+        return question.strip('"'), answer.strip('"')
+
     async def _generate_question(node_or_edge: tuple):
         if len(node_or_edge) == 2:
             des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
-            loss = node_or_edge[1]["loss"]
+            loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
         else:
             des = node_or_edge[2]["description"]
-            loss = node_or_edge[2]["loss"]
+            loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
 
         async with semaphore:
             try:
@@ -350,13 +363,8 @@ async def traverse_graph_atomically(
                     )
                 )
 
-                if "Question:" in qa and "Answer:" in qa:
-                    question = qa.split("Question:")[1].split("Answer:")[0].strip()
-                    answer = qa.split("Answer:")[1].strip()
-                elif "问题:" in qa and "答案:" in qa:
-                    question = qa.split("问题:")[1].split("答案:")[0].strip()
-                    answer = qa.split("答案:")[1].strip()
-                else:
+                question, answer = _parse_qa(qa)
+                if question is None or answer is None:
                     return {}
 
                 question = question.strip('"')
@@ -386,16 +394,18 @@ async def traverse_graph_atomically(
         if "<SEP>" in node[1]["description"]:
             description_list = node[1]["description"].split("<SEP>")
             for item in description_list:
-                tasks.append((node[0], {"description": item, "loss": node[1]["loss"]}))
+                tasks.append((node[0], {"description": item}))
+                if "loss" in node[1]:
+                    tasks[-1][1]["loss"] = node[1]["loss"]
         else:
             tasks.append((node[0], node[1]))
     for edge in edges:
         if "<SEP>" in edge[2]["description"]:
             description_list = edge[2]["description"].split("<SEP>")
             for item in description_list:
-                tasks.append(
-                    (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]})
-                )
+                tasks.append((edge[0], edge[1], {"description": item}))
+                if "loss" in edge[2]:
+                    tasks[-1][2]["loss"] = edge[2]["loss"]
         else:
             tasks.append((edge[0], edge[1], edge[2]))
 
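The new `_parse_qa` helper centralizes the previously duplicated split logic and accepts both English and Chinese QA markers (note the full-width colons in the Chinese branch). A quick standalone check on made-up model outputs:

```python
# _parse_qa as added in this commit, exercised on made-up outputs.
def _parse_qa(qa: str) -> tuple:
    if "Question:" in qa and "Answer:" in qa:
        question = qa.split("Question:")[1].split("Answer:")[0].strip()
        answer = qa.split("Answer:")[1].strip()
    elif "问题:" in qa and "答案:" in qa:
        question = qa.split("问题:")[1].split("答案:")[0].strip()
        answer = qa.split("答案:")[1].strip()
    else:
        return None, None
    return question.strip('"'), answer.strip('"')

print(_parse_qa("Question: What is a graph? Answer: Nodes and edges."))
# ('What is a graph?', 'Nodes and edges.')
print(_parse_qa("问题:什么是图?答案:节点和边。"))
# ('什么是图?', '节点和边。')
print(_parse_qa("no recognizable markers"))
# (None, None)
```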
webui/app.py CHANGED

(Identical to the app.py diff above; the commit applies the same +76 -76 change to both copies of the file.)
webui/base.py CHANGED

@@ -1,15 +1,18 @@
 from dataclasses import dataclass
 from typing import Any
 
+
 @dataclass
 class GraphGenParams:
     """
     GraphGen parameters
     """
+
     if_trainee_model: bool
     input_file: str
     tokenizer: str
-    qa_form: str
+    output_data_type: str
+    output_data_format: str
     bidirectional: bool
     expand_method: str
     max_extra_edges: int
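With the two new fields, a caller fills them alongside the existing ones. A trimmed sketch (the real `GraphGenParams` declares many more fields, and the sample values here are illustrative):

```python
from dataclasses import dataclass

# Trimmed illustration of the updated dataclass; the real GraphGenParams
# declares more fields than shown in this hunk.
@dataclass
class GraphGenParams:
    if_trainee_model: bool
    input_file: str
    tokenizer: str
    output_data_type: str    # new: "atomic", "multi_hop", or "aggregated"
    output_data_format: str  # new: "Alpaca", "Sharegpt", or "ChatML"

params = GraphGenParams(
    if_trainee_model=False,
    input_file="upload/demo.jsonl",  # illustrative path
    tokenizer="cl100k_base",
    output_data_type="aggregated",
    output_data_format="Alpaca",
)
print(params.output_data_type, params.output_data_format)  # aggregated Alpaca
```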
webui/i18n.py CHANGED

@@ -1,3 +1,4 @@
+# pylint: skip-file
 import functools
 import inspect
 import json