Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +86 -0
- requirements.txt +6 -0
app.py
CHANGED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
import os, json, pickle
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
import numpy as np
|
| 5 |
+
from fastapi import FastAPI, HTTPException
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from pydantic import BaseModel, ConfigDict
|
| 9 |
+
|
| 10 |
+
# ---- load artifacts once ----
|
| 11 |
+
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
| 12 |
+
PRED_PATH = os.path.join(DATA_DIR, "predictor.pkl")
|
| 13 |
+
if not os.path.exists(PRED_PATH):
|
| 14 |
+
raise FileNotFoundError("Put your predictor.pkl under data/")
|
| 15 |
+
|
| 16 |
+
with open(PRED_PATH, "rb") as f:
|
| 17 |
+
lw = pickle.load(f)
|
| 18 |
+
|
| 19 |
+
model_name: str = lw["model_name"]
|
| 20 |
+
sbert = SentenceTransformer(model_name)
|
| 21 |
+
|
| 22 |
+
centroides = {int(k): np.array(v, dtype=np.float32) for k, v in lw["centroides"].items()}
|
| 23 |
+
# normalize centroid vectors (to match normalized embeddings)
|
| 24 |
+
for k in list(centroides.keys()):
|
| 25 |
+
c = centroides[k]
|
| 26 |
+
n = float(np.linalg.norm(c) + 1e-12)
|
| 27 |
+
centroides[k] = c / n
|
| 28 |
+
|
| 29 |
+
meta: Dict[int, Dict[str, Any]] = lw["meta"]
|
| 30 |
+
cids = sorted(centroides.keys())
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MailItem(BaseModel):
|
| 35 |
+
subject: str
|
| 36 |
+
body: str
|
| 37 |
+
# allow unknown fields
|
| 38 |
+
model_config = ConfigDict(extra="allow")
|
| 39 |
+
|
| 40 |
+
class PredictRequest(BaseModel):
|
| 41 |
+
data: List[MailItem]
|
| 42 |
+
|
| 43 |
+
class PredictResponseItem(BaseModel):
|
| 44 |
+
json: Dict[str, Any]
|
| 45 |
+
|
| 46 |
+
class PredictResponse(BaseModel):
|
| 47 |
+
results: List[PredictResponseItem]
|
| 48 |
+
|
| 49 |
+
# ---- core ----
|
| 50 |
+
def _encode(texts: List[str]) -> np.ndarray:
|
| 51 |
+
# normalized embeddings
|
| 52 |
+
emb = sbert.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
|
| 53 |
+
return emb.astype(np.float32)
|
| 54 |
+
|
| 55 |
+
def _assign(vec: np.ndarray) -> int:
|
| 56 |
+
dists = [np.linalg.norm(vec - centroides[c]) for c in cids]
|
| 57 |
+
return cids[int(np.argmin(dists))]
|
| 58 |
+
|
| 59 |
+
def _predict(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 60 |
+
texts = [f"{r.get('subject','')} — {r.get('body','')}" for r in records]
|
| 61 |
+
emb = _encode(texts)
|
| 62 |
+
out = []
|
| 63 |
+
for r, v in zip(records, emb):
|
| 64 |
+
cid = _assign(v)
|
| 65 |
+
j = dict(r)
|
| 66 |
+
j["cluster"] = cid
|
| 67 |
+
j["cluster_nombre"] = meta.get(cid, {}).get("nombre", f"cluster_{cid}")
|
| 68 |
+
j["cluster_desc"] = meta.get(cid, {}).get("descripcion", "")
|
| 69 |
+
out.append({"json": j})
|
| 70 |
+
return out
|
| 71 |
+
|
| 72 |
+
# ---- FastAPI app ----
|
| 73 |
+
app = FastAPI(title="Mail Cluster Inference", version="1.0.0")
|
| 74 |
+
|
| 75 |
+
@app.get("/healthz")
|
| 76 |
+
def healthz():
|
| 77 |
+
return {"ok": True, "clusters": len(cids), "model": model_name}
|
| 78 |
+
|
| 79 |
+
@app.post("/predict", response_model=PredictResponse)
|
| 80 |
+
def predict(req: PredictRequest):
|
| 81 |
+
try:
|
| 82 |
+
records = [m.dict() for m in req.data]
|
| 83 |
+
results = _predict(records)
|
| 84 |
+
return {"results": results}
|
| 85 |
+
except Exception as e:
|
| 86 |
+
raise HTTPException(status_code=400, detail=f"prediction error: {e}")
|
requirements.txt
CHANGED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.115
|
| 2 |
+
uvicorn>=0.30
|
| 3 |
+
pydantic>=2.7
|
| 4 |
+
numpy>=2.0.0
|
| 5 |
+
sentence-transformers>=3.0.1
|
| 6 |
+
torch>=2.3
|