stvnnnnnn commited on
Commit
f2d0983
·
verified ·
1 Parent(s): 0524a3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -110,23 +110,25 @@ def nl2sql(q: NLQuery):
110
  if not text:
111
  raise ValueError("Consulta vacía.")
112
 
 
113
  lower = text.lower().strip()
114
  looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))
115
 
 
116
  query_en = text
117
  if not looks_like_sql:
118
  try:
119
  translated = GoogleTranslator(source="auto", target="en").translate(text)
120
- if translated: # si viene None o vacío, mantenemos original
121
  query_en = translated
122
- except Exception:
123
- query_en = text
124
 
125
- df = get_table(SPLIT, INDEX, MAX_ROWS)
 
126
  enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
127
  if torch.cuda.is_available():
128
  enc = {k: v.to("cuda") for k, v in enc.items()}
129
-
130
  out = model.generate(**enc, max_length=160, num_beams=1)
131
  sql = tok.batch_decode(out, skip_special_tokens=True)[0]
132
 
@@ -135,5 +137,6 @@ def nl2sql(q: NLQuery):
135
  "consulta_traducida": query_en,
136
  "sql_generado": sql
137
  }
 
138
  except Exception as e:
139
  raise HTTPException(status_code=500, detail=str(e))
 
110
  if not text:
111
  raise ValueError("Consulta vacía.")
112
 
113
+ # Detectar si parece SQL
114
  lower = text.lower().strip()
115
  looks_like_sql = lower.startswith(("select", "with", "insert", "update", "delete", "create", "drop", "alter"))
116
 
117
+ # Traducir a inglés si no es SQL
118
  query_en = text
119
  if not looks_like_sql:
120
  try:
121
  translated = GoogleTranslator(source="auto", target="en").translate(text)
122
+ if translated:
123
  query_en = translated
124
+ except Exception:
125
+ query_en = text # fallback seguro
126
 
127
+ # Procesar con TAPEX
128
+ df = get_table_from_wikisql(SPLIT, INDEX, MAX_ROWS)
129
  enc = tok(table=df, query=query_en, return_tensors="pt", truncation=True)
130
  if torch.cuda.is_available():
131
  enc = {k: v.to("cuda") for k, v in enc.items()}
 
132
  out = model.generate(**enc, max_length=160, num_beams=1)
133
  sql = tok.batch_decode(out, skip_special_tokens=True)[0]
134
 
 
137
  "consulta_traducida": query_en,
138
  "sql_generado": sql
139
  }
140
+
141
  except Exception as e:
142
  raise HTTPException(status_code=500, detail=str(e))