Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
49fed2a
1
Parent(s):
2d67881
Initialize models
Browse filesSigned-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
app.py
CHANGED
|
@@ -11,8 +11,10 @@ from gliner import GLiNER
|
|
| 11 |
|
| 12 |
_MODEL = {}
|
| 13 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
|
|
|
| 14 |
LABELS = ["country", "year", "statistical indicator", "geographic region"]
|
| 15 |
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
|
|
|
|
| 16 |
|
| 17 |
print(f"Cache directory: {_CACHE_DIR}")
|
| 18 |
|
|
@@ -36,6 +38,13 @@ def get_model(model_name: str = None):
|
|
| 36 |
return _MODEL[model_name]
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def get_country(country_name: str):
|
| 40 |
try:
|
| 41 |
return pycountry.countries.search_fuzzy(country_name)
|
|
@@ -43,7 +52,7 @@ def get_country(country_name: str):
|
|
| 43 |
return None
|
| 44 |
|
| 45 |
|
| 46 |
-
@spaces.GPU(enable_queue=True, duration=
|
| 47 |
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
|
| 48 |
start = datetime.now()
|
| 49 |
model = get_model(model_name)
|
|
@@ -99,7 +108,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
|
|
| 99 |
)
|
| 100 |
with gr.Row() as row:
|
| 101 |
model_name = gr.Radio(
|
| 102 |
-
choices=
|
| 103 |
value="urchade/gliner_base",
|
| 104 |
label="Model",
|
| 105 |
)
|
|
@@ -112,7 +121,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
|
|
| 112 |
threshold = gr.Slider(
|
| 113 |
0,
|
| 114 |
1,
|
| 115 |
-
value=
|
| 116 |
step=0.01,
|
| 117 |
label="Threshold",
|
| 118 |
info="Lower threshold may extract more false-positive entities from the query.",
|
|
|
|
| 11 |
|
| 12 |
_MODEL = {}
|
| 13 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
| 14 |
+
THRESHOLD = 0.3
|
| 15 |
LABELS = ["country", "year", "statistical indicator", "geographic region"]
|
| 16 |
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
|
| 17 |
+
MODELS = ["urchade/gliner_base", "urchade/gliner_medium-v2.1"]
|
| 18 |
|
| 19 |
print(f"Cache directory: {_CACHE_DIR}")
|
| 20 |
|
|
|
|
| 38 |
return _MODEL[model_name]
|
| 39 |
|
| 40 |
|
| 41 |
+
# Initialize model here.
|
| 42 |
+
print("Initializing models...")
|
| 43 |
+
for model_name in MODELS:
|
| 44 |
+
model = get_model(model_name=model_name)
|
| 45 |
+
model.predict_entities(QUERY, LABELS, threshold=THRESHOLD)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def get_country(country_name: str):
|
| 49 |
try:
|
| 50 |
return pycountry.countries.search_fuzzy(country_name)
|
|
|
|
| 52 |
return None
|
| 53 |
|
| 54 |
|
| 55 |
+
@spaces.GPU(enable_queue=True, duration=5)
|
| 56 |
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
|
| 57 |
start = datetime.now()
|
| 58 |
model = get_model(model_name)
|
|
|
|
| 108 |
)
|
| 109 |
with gr.Row() as row:
|
| 110 |
model_name = gr.Radio(
|
| 111 |
+
choices=MODELS,
|
| 112 |
value="urchade/gliner_base",
|
| 113 |
label="Model",
|
| 114 |
)
|
|
|
|
| 121 |
threshold = gr.Slider(
|
| 122 |
0,
|
| 123 |
1,
|
| 124 |
+
value=THRESHOLD,
|
| 125 |
step=0.01,
|
| 126 |
label="Threshold",
|
| 127 |
info="Lower threshold may extract more false-positive entities from the query.",
|