Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -110,23 +110,25 @@ def nl2sql(q: NLQuery):
|
|
| 110 |
if not text:
|
| 111 |
raise ValueError("Consulta vacía.")
|
| 112 |
|
|
|
|
| 113 |
lower = text.lower().strip()
|
| 114 |
looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))
|
| 115 |
|
|
|
|
| 116 |
query_en = text
|
| 117 |
if not looks_like_sql:
|
| 118 |
try:
|
| 119 |
translated = GoogleTranslator(source="auto", target="en").translate(text)
|
| 120 |
-
if translated:
|
| 121 |
query_en = translated
|
| 122 |
-
|
| 123 |
-
|
| 124 |
|
| 125 |
-
|
|
|
|
| 126 |
enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
|
| 127 |
if torch.cuda.is_available():
|
| 128 |
enc = {k: v.to("cuda") for k, v in enc.items()}
|
| 129 |
-
|
| 130 |
out = model.generate(**enc, max_length=160, num_beams=1)
|
| 131 |
sql = tok.batch_decode(out, skip_special_tokens=True)[0]
|
| 132 |
|
|
@@ -135,5 +137,6 @@ def nl2sql(q: NLQuery):
|
|
| 135 |
"consulta_traducida": query_en,
|
| 136 |
"sql_generado": sql
|
| 137 |
}
|
|
|
|
| 138 |
except Exception as e:
|
| 139 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
| 110 |
if not text:
|
| 111 |
raise ValueError("Consulta vacía.")
|
| 112 |
|
| 113 |
+
# Detectar si parece SQL
|
| 114 |
lower = text.lower().strip()
|
| 115 |
looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))
|
| 116 |
|
| 117 |
+
# Traducir a inglés si no es SQL
|
| 118 |
query_en = text
|
| 119 |
if not looks_like_sql:
|
| 120 |
try:
|
| 121 |
translated = GoogleTranslator(source="auto", target="en").translate(text)
|
| 122 |
+
if translated:
|
| 123 |
query_en = translated
|
| 124 |
+
except Exception:
|
| 125 |
+
query_en = text # fallback seguro
|
| 126 |
|
| 127 |
+
# Procesar con TAPEX
|
| 128 |
+
df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
|
| 129 |
enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
|
| 130 |
if torch.cuda.is_available():
|
| 131 |
enc = {k: v.to("cuda") for k, v in enc.items()}
|
|
|
|
| 132 |
out = model.generate(**enc, max_length=160, num_beams=1)
|
| 133 |
sql = tok.batch_decode(out, skip_special_tokens=True)[0]
|
| 134 |
|
|
|
|
| 137 |
"consulta_traducida": query_en,
|
| 138 |
"sql_generado": sql
|
| 139 |
}
|
| 140 |
+
|
| 141 |
except Exception as e:
|
| 142 |
raise HTTPException(status_code=500, detail=str(e))
|