# app.py — MCP server (refined)
# Key improvements:
# - Robust JSON extraction & repair
# - Detailed debug logging; raw LLM output is written to /tmp when parsing fails
# - Defensive LLM handling
# - Uses the local ocr_engine.extract_text_and_conf

from mcp.server.fastmcp import FastMCP
from typing import Optional, Any, Dict
import requests
import os
import gradio as gr
import json
import re
import logging
import gc
import time

# Imports from local modules (these must exist).
from ocr_engine import extract_text_and_conf
from prompts import get_ocr_extraction_prompt, get_agent_prompt

# Config (must exist).
try:
    from config import CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_BASE, INVOICE_API_BASE, ORGANIZATION_ID, LOCAL_MODEL
except Exception as e:
    raise SystemExit("Missing config.py or required keys. Error: " + str(e))

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mcp_server")

mcp = FastMCP("ZohoCRMAgent")

LLM_PIPELINE = None
TOKENIZER = None
# ---------------- JSON extraction helpers ----------------

def _try_json_loads(text: str) -> Optional[Any]:
    try:
        return json.loads(text)
    except Exception:
        return None


def _remove_code_fences(s: str) -> str:
    s = re.sub(r"```(?:json)?\s*", "", s, flags=re.IGNORECASE)
    s = re.sub(r"\s*```$", "", s, flags=re.IGNORECASE)
    return s.strip()


def _attempt_simple_repairs(s: str) -> str:
    # Keep only printable characters (plus newlines and tabs).
    s = "".join(ch for ch in s if (ch == "\n" or ch == "\t" or (32 <= ord(ch) <= 0x10FFFF)))
    # Remove trailing commas before a closing brace or bracket.
    s = re.sub(r",\s*(\}|])", r"\1", s)
    # Convert single quotes to double quotes when no double quotes are present.
    if '"' not in s and "'" in s:
        s = s.replace("'", '"')
    return s
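
# Illustrative behaviour of the repair helper; the example inputs below are
# hypothetical, not taken from real model output:
#   _attempt_simple_repairs('{"total": 100,}')  -> '{"total": 100}'
#   _attempt_simple_repairs("{'total': 100}")   -> '{"total": 100}'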

def _dump_raw_llm_output(text: str) -> str:
    """Dump raw LLM output to a timestamped file for debugging and return the path."""
    try:
        ts = int(time.time())
        path = f"/tmp/llm_output_{ts}.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
        logger.info("Wrote raw LLM output to %s for debugging", path)
        return path
    except Exception as e:
        logger.exception("Failed to write raw LLM output: %s", e)
        return ""

def extract_json_safely(text: str) -> Optional[Any]:
    """
    Robustly extract JSON from LLM output:
      1) try a direct json.loads;
      2) try marker extraction (<<<JSON>>> ... <<<END_JSON>>>);
      3) try the largest balanced { ... } block;
      4) try an array [...].
    On failure, write the raw text to /tmp and return None.
    """
    if not text:
        return None

    # Direct parse.
    parsed = _try_json_loads(text)
    if parsed is not None:
        return parsed

    # Marker-based extraction.
    marker_re = re.compile(r"<<<JSON>>>\s*([\s\S]*?)\s*<<<END_JSON>>>", re.IGNORECASE)
    m = marker_re.search(text)
    if m:
        cand = _remove_code_fences(m.group(1))
        p = _try_json_loads(cand)
        if p is not None:
            return p
        cand2 = _attempt_simple_repairs(cand)
        try:
            return json.loads(cand2)
        except Exception as e:
            logger.warning("Marker JSON repair failed: %s", e)

    # Fallback: largest balanced {...} block.
    stack = []
    spans = []
    for i, ch in enumerate(text):
        if ch == "{":
            stack.append(i)
        elif ch == "}" and stack:
            start = stack.pop()
            spans.append((start, i))
    spans = sorted(spans, key=lambda t: t[1] - t[0], reverse=True)
    for start, end in spans:
        cand = text[start:end + 1].strip()
        if len(cand) < 20:  # skip tiny fragments unlikely to be the payload
            continue
        cand = _remove_code_fences(cand)
        p = _try_json_loads(cand)
        if p is not None:
            return p
        cand2 = _attempt_simple_repairs(cand)
        try:
            return json.loads(cand2)
        except Exception:
            continue

    # Last resort: a top-level array.
    arr = re.search(r"(\[[\s\S]*\])", text)
    if arr:
        cand = _remove_code_fences(arr.group(1))
        p = _try_json_loads(cand)
        if p is not None:
            return p
        cand2 = _attempt_simple_repairs(cand)
        try:
            return json.loads(cand2)
        except Exception:
            pass

    # All strategies failed: dump the raw text for offline debugging.
    dump_path = _dump_raw_llm_output(text)
    logger.error("extract_json_safely: failed to parse JSON. Raw output saved to: %s", dump_path)
    return None
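
# Quick illustration of the extraction ladder (hypothetical inputs, shown only
# as documentation):
#   extract_json_safely('<<<JSON>>> {"vendor": "Acme"} <<<END_JSON>>>')
#     -> {"vendor": "Acme"}                      via the explicit markers
#   extract_json_safely('Sure! {"vendor": "Acme Corp", "total": 42} done')
#     -> {"vendor": "Acme Corp", "total": 42}    via the balanced-brace scan
#   extract_json_safely("no json here")
#     -> None, after dumping the raw text to /tmp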

# ---------------- Model helpers (defensive) ----------------

def init_local_model():
    global LLM_PIPELINE, TOKENIZER
    if LLM_PIPELINE is not None:
        return
    try:
        from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
        import torch
        TOKENIZER = AutoTokenizer.from_pretrained(LOCAL_MODEL)
        # Choose dtype depending on CUDA availability.
        dtype = torch.float16 if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL, device_map="auto", torch_dtype=dtype)
        LLM_PIPELINE = pipeline("text-generation", model=model, tokenizer=TOKENIZER)
        logger.info("Local model initialized.")
    except Exception as e:
        logger.exception("Failed to load local model: %s", e)
        LLM_PIPELINE = None


def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
    if LLM_PIPELINE is None:
        init_local_model()
    if LLM_PIPELINE is None:
        return {"text": "Model not loaded.", "raw": None}
    try:
        out = LLM_PIPELINE(prompt, max_new_tokens=max_tokens, return_full_text=False, do_sample=False)
        # Defensively extract the generated text, whatever shape the pipeline returns.
        text = ""
        if isinstance(out, list) and out:
            first = out[0]
            if isinstance(first, dict) and "generated_text" in first:
                text = first["generated_text"]
            elif isinstance(first, str):
                text = first
            else:
                text = str(first)
        elif isinstance(out, str):
            text = out
        return {"text": text, "raw": out}
    except Exception as e:
        logger.exception("LLM generation error: %s", e)
        return {"text": f"LLM error: {e}", "raw": None}

# ---------------- Zoho token utility ----------------

def _get_valid_token_headers() -> dict:
    try:
        r = requests.post("https://accounts.zoho.in/oauth/v2/token", params={
            "refresh_token": REFRESH_TOKEN, "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET, "grant_type": "refresh_token"
        }, timeout=15)
        if r.status_code == 200:
            tok = r.json().get("access_token")
            return {"Authorization": f"Zoho-oauthtoken {tok}"}
        logger.error("Token refresh failed: %s", r.text)
        return {}
    except Exception as e:
        logger.exception("Token refresh exception: %s", e)
        return {}

# ---------------- MCP tool implementations ----------------
# Registered on the FastMCP instance; the decorator returns the function
# unchanged, so both tools stay directly callable from parse_and_execute.

@mcp.tool()
def create_record(module_name: str, record_data: dict) -> str:
    """Create a record in the given Zoho CRM module; returns a JSON string."""
    headers = _get_valid_token_headers()
    if not headers:
        return json.dumps({"status": "error", "message": "Auth failed"})
    try:
        r = requests.post(f"{API_BASE}/{module_name}", headers=headers,
                          json={"data": [record_data]}, timeout=15)
        if r.status_code in (200, 201):
            return json.dumps(r.json())
        return json.dumps({"status": "error", "http_status": r.status_code, "text": r.text})
    except Exception as e:
        logger.exception("create_record failed: %s", e)
        return json.dumps({"status": "error", "message": str(e)})


@mcp.tool()
def create_invoice(data: dict) -> str:
    """Create an invoice via the Zoho invoice API; returns a JSON string."""
    headers = _get_valid_token_headers()
    if not headers:
        return json.dumps({"status": "error", "message": "Auth failed"})
    try:
        r = requests.post(f"{INVOICE_API_BASE}/invoices", headers=headers,
                          params={"organization_id": ORGANIZATION_ID}, json=data, timeout=15)
        if r.status_code in (200, 201):
            return json.dumps(r.json())
        return json.dumps({"status": "error", "http_status": r.status_code, "text": r.text})
    except Exception as e:
        logger.exception("create_invoice failed: %s", e)
        return json.dumps({"status": "error", "message": str(e)})

# ---------------- Document processing ----------------

def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
    """Full flow: OCR -> LLM extraction -> KPIs -> result dict (keeps raw LLM text for debugging)."""
    if not os.path.exists(file_path):
        return {"status": "error", "error": f"File not found: {file_path}"}
    raw_text, ocr_score = extract_text_and_conf(file_path)
    if not raw_text:
        return {"status": "error", "error": "OCR returned empty text."}
    prompt = get_ocr_extraction_prompt(raw_text, page_count=1)
    llm_res = local_llm_generate(prompt, max_tokens=512)
    llm_text = llm_res.get("text", "")
    parsed = extract_json_safely(llm_text)

    kpis = {"score": 0, "rating": "Fail", "issues": ["Extraction failed"]}
    if parsed:
        # Compute KPIs with simple heuristics.
        try:
            total = parsed.get("totals", {}).get("grand_total")
            semantic_ok = 1 if total else 0
            kpis = {
                "score": 80 if semantic_ok else 40,
                "rating": "High" if semantic_ok else "Low",
                "ocr_score": ocr_score,
                "issues": [] if semantic_ok else ["grand_total missing"],
            }
        except Exception:
            kpis["issues"].append("Error computing KPIs")

    # If parsing failed, persist the raw LLM output and report the path.
    raw_dump = None
    if not parsed:
        raw_dump = _dump_raw_llm_output(llm_text)

    return {
        "status": "success" if parsed else "partial",
        "file": os.path.basename(file_path),
        "extracted_data": parsed if parsed else None,
        "raw_llm_output": llm_text,
        "raw_llm_dump_path": raw_dump,
        "kpis": kpis,
    }

# ---------------- Agent orchestration and chat ----------------

def parse_and_execute(model_text: str, history: list) -> str:
    payload = extract_json_safely(model_text)
    if not payload:
        return "No valid tool JSON found in model output. Raw output saved for debugging."
    cmds = [payload] if isinstance(payload, dict) else payload
    results = []
    last_contact_id = None
    for cmd in cmds:
        if not isinstance(cmd, dict):
            continue
        tool = cmd.get("tool")
        args = cmd.get("args", {})
        if tool == "create_record":
            module = args.get("module_name", "Contacts")
            record = args.get("record_data", {})
            res = create_record(module, record)
            results.append(f"create_record -> {res}")
            # Attempt to capture the new record's id for later tool calls.
            try:
                rj = json.loads(res)
                if isinstance(rj, dict) and "data" in rj and isinstance(rj["data"], list) and rj["data"]:
                    last_contact_id = rj["data"][0].get("details", {}).get("id")
            except Exception:
                pass
        elif tool == "create_invoice":
            invoice_payload = args
            if not invoice_payload.get("customer_id") and last_contact_id:
                invoice_payload["customer_id"] = last_contact_id
            res = create_invoice(invoice_payload)
            results.append(f"create_invoice -> {res}")
        else:
            results.append(f"Unknown tool: {tool}")
    return "\n".join(results) if results else "No actionable tool calls executed."

def chat_logic(message: str, file_path: Optional[str], history: list) -> str:
    if file_path:
        logger.info("chat_logic: processing file %s", file_path)
        doc = process_document(file_path)
        status = doc.get("status")
        if status in ("success", "partial"):
            extracted = doc.get("extracted_data")
            dump_path = doc.get("raw_llm_dump_path")
            kpis = doc.get("kpis", {})
            extracted_pretty = json.dumps(extracted, indent=2) if extracted else "(no structured JSON parsed)"
            msg = (
                f"### 📄 Extraction Result for **{doc.get('file')}**\n"
                f"Status: {status}\n"
                f"KPI Score: {kpis.get('score')} Rating: {kpis.get('rating')}\n"
                f"OCR Confidence: {kpis.get('ocr_score', 'N/A')}\n\n"
                f"Extracted JSON:\n```json\n{extracted_pretty}\n```\n"
            )
            if dump_path:
                msg += (
                    f"\n⚠️ The model output could not be parsed into strict JSON. "
                    f"Raw LLM output saved to: `{dump_path}`\n"
                    "You can inspect that file to debug the model response or prompt."
                )
            msg += "\nType 'Create Invoice' to persist when ready."
            return msg
        return f"Error during processing: {doc.get('error')}"

    # Text-only interaction: build a flat history transcript and ask the agent.
    hist_txt = "\n".join(f"U: {h[0]}\nA: {h[1]}" for h in history) if history else ""
    prompt = get_agent_prompt(hist_txt, message)
    gen = local_llm_generate(prompt, max_tokens=256)
    gen_text = gen.get("text", "")
    tool_payload = extract_json_safely(gen_text)
    if tool_payload:
        return parse_and_execute(gen_text, history)
    # Not a tool call: return the LLM text, or a clear error if empty.
    return gen_text if gen_text else "No response from model."

# ---------------- Gradio wrapper ----------------

def chat_handler(msg, hist):
    # With multimodal=True, Gradio passes a dict holding "text" and "files".
    txt = msg.get("text", "")
    files = msg.get("files", [])
    path = files[0] if files else None
    return chat_logic(txt, path, hist)


if __name__ == "__main__":
    gc.collect()
    demo = gr.ChatInterface(fn=chat_handler, multimodal=True)
    demo.launch(server_name="0.0.0.0", server_port=7860)
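
# Note: the FastMCP tools registered above are only exercised through the
# Gradio chat in this entry point. To serve them over an MCP transport
# instead, the SDK's mcp.run() (e.g. transport="stdio") could be called here;
# that is a suggested alternative, not part of the original flow.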