xxmaranxx commited on
Commit
deae03b
·
verified ·
1 Parent(s): 9fe5cef

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +86 -0
  2. requirements.txt +6 -0
app.py CHANGED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os, json, pickle
3
+ from typing import List, Dict, Any
4
+ import numpy as np
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel
7
+ from sentence_transformers import SentenceTransformer
8
+ from pydantic import BaseModel, ConfigDict
9
+
10
+ # ---- load artifacts once ----
11
+ DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
12
+ PRED_PATH = os.path.join(DATA_DIR, "predictor.pkl")
13
+ if not os.path.exists(PRED_PATH):
14
+ raise FileNotFoundError("Put your predictor.pkl under data/")
15
+
16
+ with open(PRED_PATH, "rb") as f:
17
+ lw = pickle.load(f)
18
+
19
+ model_name: str = lw["model_name"]
20
+ sbert = SentenceTransformer(model_name)
21
+
22
+ centroides = {int(k): np.array(v, dtype=np.float32) for k, v in lw["centroides"].items()}
23
+ # normalize centroid vectors (to match normalized embeddings)
24
+ for k in list(centroides.keys()):
25
+ c = centroides[k]
26
+ n = float(np.linalg.norm(c) + 1e-12)
27
+ centroides[k] = c / n
28
+
29
+ meta: Dict[int, Dict[str, Any]] = lw["meta"]
30
+ cids = sorted(centroides.keys())
31
+
32
+
33
+
34
+ class MailItem(BaseModel):
35
+ subject: str
36
+ body: str
37
+ # allow unknown fields
38
+ model_config = ConfigDict(extra="allow")
39
+
40
+ class PredictRequest(BaseModel):
41
+ data: List[MailItem]
42
+
43
+ class PredictResponseItem(BaseModel):
44
+ json: Dict[str, Any]
45
+
46
+ class PredictResponse(BaseModel):
47
+ results: List[PredictResponseItem]
48
+
49
+ # ---- core ----
50
+ def _encode(texts: List[str]) -> np.ndarray:
51
+ # normalized embeddings
52
+ emb = sbert.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
53
+ return emb.astype(np.float32)
54
+
55
+ def _assign(vec: np.ndarray) -> int:
56
+ dists = [np.linalg.norm(vec - centroides[c]) for c in cids]
57
+ return cids[int(np.argmin(dists))]
58
+
59
+ def _predict(records: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
60
+ texts = [f"{r.get('subject','')} — {r.get('body','')}" for r in records]
61
+ emb = _encode(texts)
62
+ out = []
63
+ for r, v in zip(records, emb):
64
+ cid = _assign(v)
65
+ j = dict(r)
66
+ j["cluster"] = cid
67
+ j["cluster_nombre"] = meta.get(cid, {}).get("nombre", f"cluster_{cid}")
68
+ j["cluster_desc"] = meta.get(cid, {}).get("descripcion", "")
69
+ out.append({"json": j})
70
+ return out
71
+
72
+ # ---- FastAPI app ----
73
+ app = FastAPI(title="Mail Cluster Inference", version="1.0.0")
74
+
75
+ @app.get("/healthz")
76
+ def healthz():
77
+ return {"ok": True, "clusters": len(cids), "model": model_name}
78
+
79
+ @app.post("/predict", response_model=PredictResponse)
80
+ def predict(req: PredictRequest):
81
+ try:
82
+ records = [m.dict() for m in req.data]
83
+ results = _predict(records)
84
+ return {"results": results}
85
+ except Exception as e:
86
+ raise HTTPException(status_code=400, detail=f"prediction error: {e}")
requirements.txt CHANGED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi>=0.115
2
+ uvicorn>=0.30
3
+ pydantic>=2.7
4
+ numpy>=2.0.0
5
+ sentence-transformers>=3.0.1
6
+ torch>=2.3