|
|
|
|
|
|
|
|
from langchain.prompts import PromptTemplate |
|
|
from langchain.chains import LLMChain |
|
|
from pyvis.network import Network |
|
|
from pprint import pprint |
|
|
import networkx as nx |
|
|
import gradio as gr |
|
|
import re |
|
|
import datasets |
|
|
from huggingface_hub import login, HfApi |
|
|
from datasets import Dataset, load_dataset |
|
|
from rapidfuzz import fuzz, process |
|
|
import math |
|
|
import pandas as pd |
|
|
import gspread |
|
|
import torch |
|
|
import json |
|
|
from typing import Callable, Optional |
|
|
from dataclasses import dataclass |
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
AutoModelForSequenceClassification, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
BitsAndBytesConfig, |
|
|
pipeline |
|
|
) |
|
|
from peft import PeftModel, LoraConfig, get_peft_model, TaskType |
|
|
|
|
|
# --- HuggingFace Hub repositories -------------------------------------------
# Near field: Amazon products 2023.
REPO_ID_NEAR_FIELD_RAW = "milistu/AMAZON-Products-2023"  # raw source, cleaned by load_near_field_raw_from_huggingface()
REPO_ID_NEAR_FIELD = "aslan-ng/amazon_products_2023"     # cleaned copy read by load_near_field_from_huggingface()

# Far field: Amazon products 2025.
REPO_ID_FAR_FIELD = "aslan-ng/amazon_products_2025"

# LoRA adapter (and tokenizer) for the green-patents sustainability classifier.
REPO_ID_LORA_GREEN_PATENTS = "aslan-ng/lora-green-patents"
|
|
|
|
|
def product_quality_score(average_rating: float, rating_number: int,
                          prior_weight: float = 1, prior_mean: float = 3.5):
    """
    Bayesian average rating (Amazon-style).

    Shrinks a product's average rating toward ``prior_mean`` as if it had
    ``prior_weight`` extra reviews at that rating:

        score = n / (n + m) * average_rating + m / (n + m) * prior_mean

    Args:
        average_rating: product's average rating.
        rating_number: number of reviews.
        prior_weight: ``m`` above, the weight given to the prior (default 1).
        prior_mean: ``C`` above, the rating assumed for unrated products (default 3.5).

    Returns:
        The Bayesian average. ``prior_mean`` is returned when the rating or the
        review count is missing (None/NaN), or when the review count is <= 0.
    """
    def _missing(value):
        return value is None or (isinstance(value, float) and math.isnan(value))

    # Check for missing values first: comparing None with 0 raises TypeError.
    if _missing(average_rating) or _missing(rating_number) or rating_number <= 0:
        return prior_mean

    m = prior_weight
    return (rating_number / (rating_number + m)) * average_rating + (m / (rating_number + m)) * prior_mean
|
|
|
|
|
def load_near_field_raw_from_huggingface():
    """
    Load the raw near-field dataset from HuggingFace and clean it.

    Steps:
      1. Drop rows whose ``filename`` is one of the excluded main categories.
      2. Keep only the columns used downstream.
      3. Drop rows with a missing (None) or blank value in any kept column.
      4. Add a ``product_quality_score`` column (Bayesian average rating).

    Returns:
        pandas.DataFrame with columns title, description, main_category,
        average_rating, rating_number and product_quality_score.
    """
    ds = datasets.load_dataset(REPO_ID_NEAR_FIELD_RAW, split="train")
    print("Initial size: ", len(ds))

    # Categories that are not physical products of interest.
    main_categories_to_remove = {
        "meta_Books", "meta_CDs_and_Vinyl", "meta_Digital_Music", "meta_Gift_Cards",
        "meta_Grocery_and_Gourmet_Food", "meta_Magazine_Subscriptions", "meta_Software",
        "meta_Video_Games",
    }
    ds = ds.filter(lambda row: row["filename"] not in main_categories_to_remove)

    cols_to_keep = ["title", "description", "main_category", "average_rating", "rating_number"]
    ds = ds.remove_columns([c for c in ds.column_names if c not in cols_to_keep])

    def is_valid(v):
        """Return False for None and for blank strings."""
        if v is None:
            return False
        if isinstance(v, str) and v.strip() == "":
            return False
        return True

    def keep_row(row):
        """Keep only rows where every kept column holds a valid value."""
        return all(is_valid(row.get(col)) for col in cols_to_keep)

    # Filter BEFORE scoring: product_quality_score compares rating_number with 0,
    # which raises TypeError on a None rating_number.
    ds = ds.filter(keep_row)

    def add_quality_score(batch):
        return {
            "product_quality_score": [
                product_quality_score(r, n)
                for r, n in zip(batch["average_rating"], batch["rating_number"])
            ]
        }

    ds = ds.map(add_quality_score, batched=True)

    return ds.to_pandas()
|
|
|
|
|
def load_near_field_from_huggingface():
    """Download the cleaned near-field dataset (``train`` split) as a pandas DataFrame."""
    return load_dataset(REPO_ID_NEAR_FIELD, split="train").to_pandas()
|
|
|
|
|
def save_near_field_to_huggingface():
    """
    Rebuild the cleaned near-field dataset from the raw source and push it to HuggingFace.

    The result is uploaded to ``REPO_ID_NEAR_FIELD`` and the row count is printed.
    """
    cleaned = Dataset.from_pandas(load_near_field_raw_from_huggingface())
    cleaned.push_to_hub(REPO_ID_NEAR_FIELD)
    print(f"✅ Pushed {len(cleaned)} rows to {REPO_ID_NEAR_FIELD}")


# Near-field products, downloaded once at import time.
dataset_near_field = load_near_field_from_huggingface()
|
|
|
|
|
def load_far_field_from_sheet(sheet_id: Optional[str] = None):
    """
    Load the far-field dataset from Google Sheets.

    Every worksheet of the spreadsheet is read. Its ``title``, ``description``,
    ``average_rating`` and ``rating_number`` columns are kept, and the worksheet
    title is stored as ``main_category``. A ``product_quality_score`` column is
    added; it is NaN where the rating or the rating count is missing.

    Args:
        sheet_id: Google Sheets key. Defaults to a module-level
            ``SHEET_ID_FAR_FIELD`` if one is defined, otherwise the
            ``SHEET_ID_FAR_FIELD`` environment variable.

    Returns:
        pandas.DataFrame with the columns listed above.

    Raises:
        ValueError: if no sheet id can be determined.
        KeyError: if a non-empty worksheet lacks one of the required columns.
    """
    import os
    from google.auth import default

    try:
        # Interactive user authentication exists only inside Google Colab.
        # Elsewhere google.auth.default() relies on Application Default Credentials.
        from google.colab import auth
    except ImportError:
        auth = None
    if auth is not None:
        auth.authenticate_user()

    # SHEET_ID_FAR_FIELD was referenced but never defined, so fall back gracefully.
    if sheet_id is None:
        sheet_id = globals().get("SHEET_ID_FAR_FIELD") or os.environ.get("SHEET_ID_FAR_FIELD")
    if not sheet_id:
        raise ValueError("No Google Sheet id: pass sheet_id or set SHEET_ID_FAR_FIELD.")

    cols = ["title", "description", "average_rating", "rating_number"]
    spreadsheet = gspread.authorize(default()[0]).open_by_key(sheet_id)

    frames = []
    for ws in spreadsheet.worksheets():
        rows = ws.get_all_records()
        if not rows:
            continue
        frame = pd.DataFrame(rows)[cols].copy()
        frame["main_category"] = ws.title  # one worksheet per category
        frames.append(frame)

    if frames:
        df = pd.concat(frames, ignore_index=True)
    else:
        df = pd.DataFrame(columns=cols + ["main_category"])

    # get_all_records() returns "" for blank cells. Coerce so blanks become NaN
    # instead of strings; strings would crash the score computation and mix
    # types inside a numeric column.
    for col in ("average_rating", "rating_number"):
        df[col] = pd.to_numeric(df[col], errors="coerce")

    df["product_quality_score"] = [
        product_quality_score(rating, count) if pd.notna(rating) and pd.notna(count) else float("nan")
        for rating, count in zip(df["average_rating"], df["rating_number"])
    ]

    return df
|
|
|
|
|
def load_far_field_from_huggingface():
    """Download the far-field dataset (``train`` split) as a pandas DataFrame."""
    far_field = load_dataset(REPO_ID_FAR_FIELD, split="train")
    return far_field.to_pandas()
|
|
|
|
|
def save_far_field_to_huggingface():
    """
    Read the far-field dataset from Google Sheets and push it to HuggingFace.

    The result is uploaded to ``REPO_ID_FAR_FIELD`` and the row count is printed.
    """
    uploaded = Dataset.from_pandas(load_far_field_from_sheet())
    uploaded.push_to_hub(REPO_ID_FAR_FIELD)
    print(f"✅ Pushed {len(uploaded)} rows to {REPO_ID_FAR_FIELD}")


# Far-field products, downloaded once at import time.
dataset_far_field = load_far_field_from_huggingface()
|
|
|
|
|
def product_score(product_quality_score: float, fuzzy_score: float):
    """
    Blend a product's quality score with its fuzzy-match score.

    Uses the geometric mean ``sqrt(quality * fuzzy)``, so a product has to do
    reasonably well on both criteria to rank highly. NaN inputs propagate.
    """
    combined = product_quality_score * fuzzy_score
    return math.sqrt(combined)
|
|
|
|
|
def query_near_field(input: str, top_k: int=1):
    """
    Return the ``top_k`` near-field products that best match ``input``.

    Every title is fuzzy-matched against ``input`` (rapidfuzz token-set ratio,
    0-100). Each candidate is ranked by ``product_score``, which combines its
    ``product_quality_score`` with that fuzzy score.

    Args:
        input: free-text query.
        top_k: number of rows to return. 0 returns an empty frame. Values larger
            than the dataset are clamped to its size, with a printed warning.

    Returns:
        DataFrame with the dataset columns plus ``data_source``, ``fuzzy_score``
        and ``score``, sorted by ``score`` descending (NaN scores last).

    Raises:
        ValueError: if ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be >= 0, got {top_k}")

    n = len(dataset_near_field)
    if top_k > n:
        print(f"Warning: top_k ({top_k}) is greater than the number of examples in the near-field dataset ({n}). Returning all examples.")
        top_k = n

    extra_cols = ["data_source", "fuzzy_score", "score"]
    if top_k == 0:
        return pd.DataFrame(columns=list(dataset_near_field.columns) + extra_cols)

    titles = dataset_near_field["title"].fillna("").astype(str).tolist()
    matches = process.extract(input, titles, scorer=fuzz.token_set_ratio, limit=None)

    # Score every candidate cheaply, then materialise only the top_k rows
    # instead of building a dict for every row of the dataset.
    quality = dataset_near_field["product_quality_score"].tolist()
    ranked = sorted(
        ((product_score(quality[idx], fuzzy), fuzzy, idx) for _title, fuzzy, idx in matches),
        key=lambda item: (not math.isnan(item[0]), item[0]),
        reverse=True,
    )[:top_k]

    result = dataset_near_field.iloc[[idx for _score, _fuzzy, idx in ranked]].reset_index(drop=True)
    result["data_source"] = "near_field"
    result["fuzzy_score"] = [fuzzy for _score, fuzzy, _idx in ranked]
    result["score"] = [score for score, _fuzzy, _idx in ranked]
    return result
|
|
|
|
|
def query_far_field(input: str, top_k: int, random_state: Optional[int] = None):
    """
    Return ``top_k`` randomly sampled far-field products as a DataFrame.

    Row selection ignores ``input``: the rows are a random sample. ``input`` is
    still fuzzy-matched against each sampled title to fill ``fuzzy_score``, and
    the combined ``score`` is computed with ``product_score``.

    Args:
        input: free-text query, used only for scoring.
        top_k: number of rows to sample. Values larger than the dataset are
            clamped to its size, with a printed warning.
        random_state: optional seed for reproducible sampling (default: random).

    Returns:
        The sampled rows plus ``fuzzy_score``, ``score`` and ``data_source`` columns.

    Raises:
        ValueError: if ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be >= 0, got {top_k}")

    n = len(dataset_far_field)
    if top_k > n:
        # Clamp rather than return early so the result always carries the score columns.
        print(f"Warning: top_k ({top_k}) is greater than the number of examples in the far-field dataset ({n}). Returning all examples.")
        top_k = n

    sampled = dataset_far_field.sample(n=top_k, random_state=random_state).reset_index(drop=True)

    sampled["fuzzy_score"] = [
        fuzz.token_set_ratio(str(t) if pd.notna(t) else "", input)
        for t in sampled["title"]
    ]
    sampled["score"] = [
        product_score(quality, fuzzy)
        for quality, fuzzy in zip(sampled["product_quality_score"], sampled["fuzzy_score"])
    ]
    sampled["data_source"] = "far_field"

    return sampled
|
|
|
|
|
def split_near_and_far_fields(total_examples: int, near_far_ratio: float = 0.5):
    """
    Split ``total_examples`` between the near and the far field.

    Args:
        total_examples: total number of examples (must be >= 2).
        near_far_ratio: fraction of examples that go to the near field, in [0, 1].
            The near-field count is rounded down; the far field gets the rest.

    Returns:
        ``(near_field_examples, far_field_examples)``; the two counts sum to ``total_examples``.

    Raises:
        ValueError: if the ratio is outside [0, 1] or ``total_examples`` < 2.
    """
    ratio = near_far_ratio

    if ratio < 0 or ratio > 1:
        raise ValueError("Ratio must be between 0 and 1")
    if total_examples < 2:
        raise ValueError("Total examples must be at least 2")

    # Round away float noise before flooring: e.g. 100 * 0.29 == 28.999999999999996,
    # which plain int() would truncate to 28 instead of 29.
    near_field_examples = int(round(total_examples * ratio, 9))
    far_field_examples = total_examples - near_field_examples

    return near_field_examples, far_field_examples
|
|
|
|
|
def query(input: str, total_examples: int, near_far_ratio: float = 0.5):
    """
    Retrieve candidate products from both the near- and the far-field datasets.

    ``total_examples`` is divided with ``split_near_and_far_fields``. The result
    lists near-field rows first, followed by far-field rows.

    A side that is allotted 0 examples is not queried at all: there is nothing
    to fetch, and skipping it keeps this function safe regardless of how each
    query function treats ``top_k == 0`` (e.g. a ratio of 0.0 from the UI slider).
    """
    near_field_examples, far_field_examples = split_near_and_far_fields(total_examples, near_far_ratio)

    far_field_result = query_far_field(input, far_field_examples) if far_field_examples > 0 else None
    near_field_result = query_near_field(input, near_field_examples) if near_field_examples > 0 else None

    frames = [frame for frame in (near_field_result, far_field_result) if frame is not None]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
|
|
|
|
|
def lora_load():
    """
    Build the green-patents text-classification pipeline.

    The LoRA adapter at ``REPO_ID_LORA_GREEN_PATENTS`` is applied on top of a
    two-label ``distilbert-base-uncased`` sequence classifier. The tokenizer is
    taken from the adapter repository.

    Returns:
        A ``transformers`` text-classification pipeline.
    """
    base_checkpoint = "distilbert-base-uncased"
    adapter_repo = REPO_ID_LORA_GREEN_PATENTS

    classifier_tokenizer = AutoTokenizer.from_pretrained(adapter_repo)
    peft_model = PeftModel.from_pretrained(
        AutoModelForSequenceClassification.from_pretrained(base_checkpoint, num_labels=2),
        adapter_repo,
    )
    return pipeline("text-classification", model=peft_model, tokenizer=classifier_tokenizer)


# Classifier used by sustainability_filter(); loaded once at import time.
clf = lora_load()
|
|
|
|
|
def sustainability_filter(input: str, total_examples: int, near_far_ratio: float = 0.5):
    """
    Retrieve candidate products and keep those the LoRA classifier does not label ``LABEL_0``.

    Candidates come from ``query``. Each product's description is classified
    with ``clf``, and rows predicted as ``LABEL_0`` are dropped.
    NOTE(review): LABEL_0 is presumably the "not sustainable" class; confirm
    against the adapter's label mapping.

    Returns:
        The kept rows, with the same columns and index as ``query``'s result.
    """
    initial_products = query(input, total_examples, near_far_ratio)
    if initial_products.empty:
        return initial_products.copy()

    # Missing descriptions would make the pipeline fail, so send "" instead.
    descriptions = initial_products["description"].fillna("").astype(str).tolist()

    # truncation=True: product descriptions can exceed the encoder's maximum
    # sequence length, which otherwise raises at inference time.
    predictions = clf(descriptions, truncation=True)
    labels = [item["label"] for item in predictions]

    keep_mask = [label != "LABEL_0" for label in labels]
    return initial_products.loc[keep_mask].copy()
|
|
|
|
|
# System prompt for the product analyst LLM. The reply is parsed with
# response_to_triplets(), which expects a JSON object, so the prompt asks for
# strict JSON rather than a generic "dictionary".
SYSTEM_PROMPT = """
You are a product analyst. You'll receive a product description as input and extract some product functionalities and some product values. Each functionality and value should be 1-5 keywords only.
Product functionality refers to what the product does: its features, technical capabilities, and performance characteristics. It answers the question: “What can this product do?”
Product value refers to the benefit the customer gains from using the product: how it improves their life, solves their problem, or helps them achieve goals. It answers the question: “Why does this matter to the customer?”
Do **not** duplicate an item in both lists. Keep **functionalities** as concrete features. Keep **values** as clear user benefits. Keep the short_title of the product shorter than 5 words.

Your output is a single valid JSON object (double-quoted strings, no Markdown code fences, no comments). Here is the format:

# Your Input:
<product_description>
# Your Output:
{
  "short_title": "<short_title>",
  "values": [
    "<value1>",
    "<value2>",
    ...
  ],
  "functionalities": [
    "<function1>",
    "<function2>",
    ...
  ]
}

Select and return only the 5 most relevant values and only the 5 most relevant functionalities for each product.
Don't return anything outside the output format.
"""
|
|
|
|
|
@dataclass
class LLMConfig:
    """
    Settings for the chat LLM built by ``create_llm``.

    ``create_llm`` fills every field from its keyword arguments and reads them
    back when formatting prompts and calling ``model.generate``.
    """
    model_id: str  # HuggingFace model id passed to from_pretrained
    system_prompt: str = ""  # sent as a system message when non-empty
    max_new_tokens: int = 256  # default cap on generated tokens per call
    temperature: float = 0.2  # default sampling temperature; create_llm samples only when > 0
    top_p: float = 0.9  # default nucleus-sampling threshold
    repetition_penalty: float = 1.05  # default repetition penalty for generate()
    use_4bit: bool = True  # request 4-bit bitsandbytes loading; create_llm applies it only when CUDA is available
|
|
|
|
|
def create_llm(
    *,
    model_id: str,
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    top_p: float = 0.9,
    repetition_penalty: float = 1.05,
    use_4bit: bool = True,
    system_prompt: Optional[str] = None,
) -> Callable[[str], str]:
    """
    Load a HuggingFace causal chat LLM and return a callable ``llm(prompt) -> str``.

    Args:
        model_id: HuggingFace model id.
        max_new_tokens, temperature, top_p, repetition_penalty: default
            generation settings. Each can be overridden per call of ``llm``.
        use_4bit: load the model with 4-bit bitsandbytes quantisation. This is
            only applied when CUDA is available.
        system_prompt: system message sent with every prompt. Defaults to
            ``SYSTEM_PROMPT``.

    Returns:
        ``llm(prompt, max_new_tokens=None, temperature=None, top_p=None,
        repetition_penalty=None) -> str``. It returns only the newly generated
        text. Arguments left as None fall back to the defaults above, and
        ``temperature=0`` means greedy decoding.
    """
    cfg = LLMConfig(
        model_id=model_id,
        system_prompt=SYSTEM_PROMPT if system_prompt is None else system_prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        use_4bit=use_4bit,
    )

    has_cuda = torch.cuda.is_available()
    qconfig: Optional[BitsAndBytesConfig] = None
    if has_cuda and cfg.use_4bit:
        qconfig = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16 if has_cuda else torch.float32,
        quantization_config=qconfig,
    ).eval()

    def _format_messages(user_text: str) -> str:
        """Render the system and user messages with the model's chat template."""
        msgs = []
        if cfg.system_prompt:
            msgs.append({"role": "system", "content": cfg.system_prompt})
        msgs.append({"role": "user", "content": user_text})

        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template is not None:
            return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

        # Plain-text fallback for tokenizers without a chat template.
        system_part = f"System: {cfg.system_prompt}\n\n" if cfg.system_prompt else ""
        return f"{system_part}User: {user_text}\nAssistant:"

    def _or_default(value, default):
        # `value or default` would silently replace an explicit 0 / 0.0.
        return default if value is None else value

    @torch.inference_mode()
    def llm(prompt: str,
            max_new_tokens: Optional[int] = None,
            temperature: Optional[float] = None,
            top_p: Optional[float] = None,
            repetition_penalty: Optional[float] = None) -> str:
        temp = _or_default(temperature, cfg.temperature)
        do_sample = temp > 0.0

        gen_kwargs = {
            "max_new_tokens": _or_default(max_new_tokens, cfg.max_new_tokens),
            "do_sample": do_sample,
            "repetition_penalty": _or_default(repetition_penalty, cfg.repetition_penalty),
            "pad_token_id": tokenizer.eos_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling-only knobs; greedy decoding ignores them.
            gen_kwargs["temperature"] = temp
            gen_kwargs["top_p"] = _or_default(top_p, cfg.top_p)

        text = _format_messages(prompt)
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        out = model.generate(**inputs, **gen_kwargs)

        # Drop the prompt tokens; keep only what the model generated.
        gen = out[0][inputs["input_ids"].shape[-1]:]
        return tokenizer.decode(gen, skip_special_tokens=True).strip()

    print(f"Loaded: {cfg.model_id} | 4-bit: {bool(qconfig)} | Device: {model.device}")
    return llm
|
|
|
|
|
# A whole reply wrapped in a Markdown code fence, e.g. ```json ... ```.
_CODE_FENCE_RE = re.compile(r"^```[A-Za-z0-9_-]*\s*(.*?)\s*```$", re.DOTALL)


def _parse_llm_dict(response: str):
    """
    Parse an LLM reply that should contain a single JSON object.

    Tries, in order:
      1. strict JSON on the reply, after removing a surrounding code fence;
      2. strict JSON on the outermost ``{...}`` span (drops surrounding chatter);
      3. ``ast.literal_eval`` on that span, for Python-style dicts with single quotes.
    ``ast.literal_eval`` only evaluates literals, so it is safe on model output.

    Raises:
        json.JSONDecodeError: (a ValueError) if no strategy yields a dict.
    """
    import ast

    text = response.strip()
    fenced = _CODE_FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)

    try:
        return json.loads(text)
    except json.JSONDecodeError as err:
        first_error = err

    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidate = text[start:end + 1]
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass
        try:
            literal = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError):
            literal = None
        if isinstance(literal, dict):
            return literal

    raise first_error


def response_to_triplets(response: str):
    """
    Convert the LLM's product analysis into ``[subject, predicate, object]`` triplets.

    Args:
        response: LLM reply containing ``short_title``, ``values`` and
            ``functionalities`` (see SYSTEM_PROMPT).

    Returns:
        One ``[short_title, "HAS_VALUE", value]`` per value, followed by one
        ``[short_title, "HAS_FUNCTIONALITY", functionality]`` per functionality.

    Raises:
        ValueError: if the reply cannot be parsed.
        KeyError: if ``short_title`` is missing.
    """
    data = _parse_llm_dict(response)
    product_title = data["short_title"]

    triples_list = [[product_title, "HAS_VALUE", value] for value in data.get("values", [])]
    triples_list.extend(
        [product_title, "HAS_FUNCTIONALITY", func] for func in data.get("functionalities", [])
    )
    return triples_list
|
|
|
|
|
# Module-level extractor used by main(); the model is downloaded and loaded at import time.
llm = create_llm(
    model_id="Qwen/Qwen2.5-3B-Instruct",
    use_4bit=True,  # only takes effect when CUDA is available
    max_new_tokens=200,
    temperature=0.2,
    top_p=0.9,
    repetition_penalty=1.05,
)
|
|
|
|
|
def _as_text(value) -> str:
    """Return ``value`` as a string, mapping None and NaN to ""."""
    if value is None:
        return ""
    if isinstance(value, float) and math.isnan(value):
        return ""
    return str(value)


def main(input: str, total_examples: int, near_far_ratio: float = 0.5):
    """
    Run retrieval, sustainability filtering and LLM extraction for one query.

    For each product kept by ``sustainability_filter``, the title and
    description are sent to ``llm``, and the reply is converted into triplets.
    A product whose reply cannot be parsed is skipped with a printed warning,
    so one bad generation does not abort the whole request.

    Returns:
        A flat list of ``[subject, predicate, object]`` triplets across all products.
    """
    all_triplets_list = []

    sustainable_results = sustainability_filter(input, total_examples=total_examples, near_far_ratio=near_far_ratio)

    for _, product in sustainable_results.iterrows():
        # Far-field rows are not filtered for blanks, so guard against NaN text.
        product_title = _as_text(product["title"])
        product_full_description = f"{product_title} {_as_text(product['description'])}"

        response = llm(product_full_description)

        try:
            triplets_list = response_to_triplets(response)
        except (ValueError, KeyError, TypeError) as err:
            print(f"Warning: skipping product {product_title!r}: could not parse LLM response ({err!r})")
            continue

        print(triplets_list)
        all_triplets_list.extend(triplets_list)

    return all_triplets_list
|
|
|
|
|
# "subject, UPPER_SNAKE_PREDICATE, object". The subject is matched non-greedily,
# so commas inside a title stay with the subject as long as the predicate looks
# like HAS_VALUE / HAS_FUNCTIONALITY.
_TRIPLET_RE = re.compile(r"^(.*?),\s*([A-Z][A-Z0-9]*(?:_[A-Z0-9]+)+)\s*,\s*(.*)$", re.DOTALL)


def _split_triplet(triplet):
    """
    Split one triplet into ``(subject, predicate, object)``, or return None.

    Accepts a 3-item list/tuple, or a string such as ``"a, PRED, b"``,
    ``"(a, PRED, b)"`` or ``"a | pred | b"``.
    """
    if isinstance(triplet, (list, tuple)):
        if len(triplet) != 3:
            return None
        return tuple(str(part).strip() for part in triplet)

    line = str(triplet).strip()
    if not line:
        return None

    match = _TRIPLET_RE.match(line)
    if match:
        return tuple(part.strip(" ()") for part in match.groups())

    # Fallbacks: first two commas, then a pipe-separated form.
    parts = [p.strip(" ()") for p in line.split(",", 2)]
    if len(parts) != 3:
        parts = [p.strip(" ()") for p in line.split("|")]
        if len(parts) != 3:
            return None
    return tuple(parts)


def create_graph_from_triplets(triplets):
    """
    Build a directed NetworkX graph from triplets.

    Each valid triplet adds an edge ``subject -> object`` whose ``label``
    attribute is the predicate. Unparsable or incomplete triplets are skipped.
    """
    G = nx.DiGraph()
    for triplet in triplets:
        parts = _split_triplet(triplet)
        if parts is None:
            continue
        subject, predicate, obj = parts
        if subject and predicate and obj:
            G.add_edge(subject, obj, label=predicate)
    return G
|
|
|
|
|
def nx_to_pyvis(networkx_graph):
    """
    Convert a NetworkX triplet graph into a styled pyvis ``Network``.

    Node style is chosen from the labels of the edges touching the node, in
    this priority order:
      * target of a HAS_VALUE edge            -> green ellipse
      * target of a HAS_FUNCTIONALITY edge    -> blue diamond
      * source of any edge (the product)      -> orange box
      * anything else                         -> grey dot
    Edges are coloured by label in the same way. The misspelt
    ``HAS_FUNTIONALITY`` is also accepted as a functionality label.
    """
    from pyvis.network import Network

    vis = Network(notebook=True, cdn_resources="remote")

    value_labels = {"HAS_VALUE"}
    functionality_labels = {"HAS_FUNCTIONALITY", "HAS_FUNTIONALITY"}

    def edge_label(attrs):
        return str(attrs.get("label", "")).strip()

    edges = list(networkx_graph.edges(data=True))
    product_nodes = {src for src, _dst, _attrs in edges}
    value_nodes = {dst for _src, dst, attrs in edges if edge_label(attrs) in value_labels}
    functionality_nodes = {dst for _src, dst, attrs in edges if edge_label(attrs) in functionality_labels}

    for node in networkx_graph.nodes():
        if node in value_nodes:
            background, shape = "#4CAF50", "ellipse"
        elif node in functionality_nodes:
            background, shape = "#2196F3", "diamond"
        elif node in product_nodes:
            background, shape = "#FFAA00", "box"
        else:
            background, shape = "#CCCCCC", "dot"
        vis.add_node(
            node,
            label=str(node),
            color={"background": background, "border": "#333333"},
            shape=shape,
        )

    for src, dst, attrs in edges:
        label = edge_label(attrs)
        if label in value_labels:
            colour = "#81C784"
        elif label in functionality_labels:
            colour = "#64B5F6"
        else:
            colour = "#999999"
        vis.add_edge(src, dst, label=label, title=label, color=colour)

    return vis
|
|
|
|
|
def generateGraph(triples_list):
    """
    Render triplets as an interactive pyvis graph wrapped in an ``<iframe>``.

    Args:
        triples_list: triplets as strings (``"subject, PREDICATE, object"``) or
            3-item sequences. Blank strings are ignored.

    Returns:
        HTML for an iframe whose ``srcdoc`` holds the pyvis page.
    """
    from html import escape

    triplets = []
    for t in triples_list:
        if isinstance(t, str):
            t = t.strip()
            if not t:
                continue
        triplets.append(t)

    graph = create_graph_from_triplets(triplets)
    pyvis_network = nx_to_pyvis(graph)

    pyvis_network.toggle_hide_edges_on_drag(True)
    pyvis_network.toggle_physics(False)
    pyvis_network.set_edge_smooth('discrete')

    page = pyvis_network.generate_html()

    # srcdoc is an HTML attribute. Escape the page properly instead of swapping
    # ' for ": the swap corrupted JS string literals whenever a node label
    # contained an apostrophe (e.g. "Men's shoes").
    srcdoc = escape(page, quote=True)

    return f"""<iframe style="width: 100%; height: 600px;margin:0 auto" name="result" allow="midi; geolocation; microphone; camera;
display-capture; encrypted-media;" sandbox="allow-modals allow-forms
allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc="{srcdoc}"></iframe>"""
|
|
|
|
|
def pipeline_fn(user_text: str, total_examples: int, near_far_ratio: float):
    """
    Gradio callback: run the full pipeline for ``user_text`` and return HTML.

    Returns the knowledge-graph iframe on success, or a short notice when the
    input is empty. If anything fails, it returns the HTML-escaped traceback,
    so errors are shown in the UI instead of being raised.
    """
    import traceback
    from html import escape

    try:
        if not user_text or not user_text.strip():
            return "<div style='padding:12px;color:#b00;'>Please enter some text.</div>"

        # gr.Number may deliver a float; DataFrame.sample() requires an int.
        triples = main(
            user_text,
            total_examples=int(total_examples),
            near_far_ratio=near_far_ratio,
        ) or []

        # Normalise triplets to "subject, predicate, object" strings.
        triples_list = [
            f"{t[0]}, {t[1]}, {t[2]}" if isinstance(t, (tuple, list)) and len(t) == 3 else str(t)
            for t in triples
        ]

        return generateGraph(triples_list)

    except Exception:
        # Escape: tracebacks contain '<module>', quotes and user text that would
        # otherwise be interpreted as HTML.
        return "<pre style='white-space: pre-wrap; font-size:12px; color:#b00;'>" + escape(traceback.format_exc()) + "</pre>"
|
|
|
|
|
# Gradio UI: query text, number of examples and near/far ratio in, graph HTML out.
demo = gr.Interface(
    pipeline_fn,
    [
        gr.Textbox(label="Enter your query / text", value="", lines=6),
        gr.Number(label="Number of examples", value=6, precision=0),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.8, step=0.05, label="Near/Far Ratio"),
    ],
    gr.HTML(),
    title="Knowledge Graph",
    live=False,
    allow_flagging="never",
    css="""
#component-0, #component-1, #component-2, #component-3, #component-4 {
display: flex;
justify-content: center;
align-items: center;
flex-direction: column;
}
.gradio-container {
justify-content: center !important;
align-items: center !important;
text-align: center;
}
textarea, iframe {
margin: 0 auto;
display: block;
}
""",
)


if __name__ == "__main__":
    demo.launch(quiet=True)