clustering-test / app.py
xxmaranxx's picture
Update app.py
e624423 verified
# app.py
# -*- coding: utf-8 -*-
import os, pickle, numpy as np
from typing import Dict, Tuple
from fastapi import FastAPI
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# ---- Performance flags ----
os.environ["TOKENIZERS_PARALLELISM"] = "false"
try:
import torch
torch.set_num_threads(1) # evita thrashing en CPU básica
except Exception:
pass
# ---- Carga artefactos una vez ----
lw: Dict = pickle.load(open("predictor.pkl", "rb"))
sbert = SentenceTransformer(lw["model_name"])
# centroides normalizados
centroides = {int(k): np.array(v, dtype=np.float32) for k, v in lw["centroides"].items()}
for k, v in centroides.items():
n = np.linalg.norm(v) + 1e-12
centroides[k] = (v / n).astype(np.float32)
cids = sorted(centroides.keys())
meta = lw.get("meta", {}) or {}
# (OPCIONAL) umbrales por cluster guardados en el pickle, p.ej. similitud mínima histórica
# formato esperado: { "margenes": { "0": 0.50, "1": 0.58, ... } }
_margenes_raw = lw.get("margenes") or lw.get("margins") or {}
MARGENES = {int(k): float(v) for k, v in _margenes_raw.items()} if isinstance(_margenes_raw, dict) else {}
# Umbral global por defecto (puede sobreescribirse por env o por lw["tau_otros"])
TAU_OTROS = float(os.getenv("TAU_OTROS", str(lw.get("tau_otros", 0.7))))
# Sentimiento (modelo liviano; recorta a 256 tokens)
sentiment = pipeline(
"text-classification",
model="UMUTeam/roberta-spanish-sentiment-analysis",
device=-1
)
EMOTIONS = ["alegría", "tristeza", "ira", "asco", "miedo", "sorpresa", "neutral"]
HYP = "El texto expresa {}."
# Precompute embeddings de emociones con el mismo encoder (rápido)
_emotion_texts = [HYP.format(e) for e in EMOTIONS]
_emotion_embs = sbert.encode(
_emotion_texts, convert_to_numpy=True, normalize_embeddings=True
).astype(np.float32)
app = FastAPI(title="Predicción de clusters/sentimiento/emoción")
# -------- Helpers --------
def _encode(text: str) -> np.ndarray:
emb = sbert.encode(text, convert_to_numpy=True, normalize_embeddings=True).astype(np.float32)
return emb[None, :] if emb.ndim == 1 else emb
def _mejor_cluster_y_similitud(vec: np.ndarray) -> Tuple[int, float]:
"""
Devuelve (cid_mejor, similitud_cos). Como todo está normalizado, cos = dot.
"""
# apilamos centroides en una matriz para multiplicar de una
C = np.vstack([centroides[c] for c in cids]).astype(np.float32) # shape: (K, D)
sims = (C @ vec.reshape(-1, 1)).squeeze(-1) # shape: (K,)
i = int(np.argmax(sims))
return cids[i], float(sims[i])
def _truncate_for_classifier(text: str, max_chars: int = 1000) -> str:
return text if len(text) <= max_chars else text[:max_chars]
def _fast_emotion(emb: np.ndarray) -> str:
sims = (_emotion_embs @ emb.reshape(-1, 1)).squeeze(-1)
return EMOTIONS[int(np.argmax(sims))]
# -------- Schema de entrada --------
class Entrada(BaseModel):
# acepta "asunto" o "subject"
asunto: str = Field(default="", alias="subject")
# acepta "cuerpo" o "body"
cuerpo: str = Field(default="", alias="body")
class Config:
populate_by_name = True # permite usar los nombres sin alias también
# -------- Endpoint --------
@app.post("/predict")
def predict(item: Entrada):
subject = (item.asunto or "").strip()
body = (item.cuerpo or "").strip()
text = f"{subject}{body}".strip(" —")
emb = _encode(text)[0]
cid_mejor, sim = _mejor_cluster_y_similitud(emb)
# umbral: usa específico del cluster si existe; si no, global
tau = float(MARGENES.get(cid_mejor, TAU_OTROS))
# si la similitud no supera el umbral -> "otros"
es_otros = sim < tau
if es_otros:
cid = -1
nombre = "otros"
desc = "No se parece lo suficiente a ningún cluster conocido."
else:
cid = cid_mejor
m = meta.get(str(cid), meta.get(cid, {})) or {}
nombre = m.get("nombre")
desc = m.get("descripcion")
# RÁPIDO: sentimiento con truncado
s = sentiment(_truncate_for_classifier(text), truncation=True, max_length=256)[0]["label"]
# RÁPIDO: emoción por similitud con SBERT (sin segundo Transformer)
e = _fast_emotion(emb)
return {
"asunto": subject,
"cuerpo": body,
"cluster": cid,
"cluster_nombre": nombre,
"cluster_desc": desc,
"similitud_cluster": round(sim, 4),
"umbral_usado": round(tau, 4),
"sentimiento": s,
"emocion": e
}
# -------- Entrypoint opcional --------
if __name__ == "__main__":
import uvicorn
uvicorn.run("app:app", host="0.0.0.0", port=int(os.getenv("PORT", "8000")))