vachaspathi committed on
Commit
168e3cd
·
verified ·
1 Parent(s): ca18505

Update prompts.py

Browse files
Files changed (1) hide show
  1. prompts.py +44 -408
prompts.py CHANGED
@@ -1,408 +1,44 @@
1
- # app.py — MCP server (single-file)
2
-
3
- from mcp.server.fastmcp import FastMCP
4
- from typing import Optional, List, Tuple, Any, Dict
5
- import requests
6
- import os
7
- import gradio as gr
8
- import json
9
- import time
10
- import traceback
11
- import re
12
- import logging
13
- import base64
14
- import asyncio
15
- import gc
16
-
17
- # --- NEW: Import OCR Engine & Prompts ---
18
- try:
19
- from ocr_engine import extract_text_from_file
20
- from prompts import get_ocr_extraction_prompt, get_agent_prompt
21
- except ImportError:
22
- # Fallback
23
- def extract_text_from_file(path): return "OCR Engine not loaded."
24
- def get_ocr_extraction_prompt(txt): return txt
25
- def get_agent_prompt(h, c, u): return u
26
-
27
- # Setup logging
28
- logging.basicConfig(level=logging.INFO)
29
- logger = logging.getLogger("mcp_server")
30
-
31
- # Attempt to import transformers
32
- TRANSFORMERS_AVAILABLE = False
33
- try:
34
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
35
- TRANSFORMERS_AVAILABLE = True
36
- except Exception as e:
37
- logger.warning("transformers not available: %s", e)
38
- TRANSFORMERS_AVAILABLE = False
39
-
40
- # ----------------------------
41
- # Load config
42
- # ----------------------------
43
- try:
44
- from config import (
45
- CLIENT_ID, CLIENT_SECRET, REFRESH_TOKEN, API_BASE,
46
- INVOICE_API_BASE, ORGANIZATION_ID, LOCAL_MODEL
47
- )
48
- except Exception as e:
49
- raise SystemExit("Config missing. Check config.py.")
50
-
51
- mcp = FastMCP("ZohoCRMAgent")
52
-
53
- # ----------------------------
54
- # Analytics (Kept intact)
55
- # ----------------------------
56
- ANALYTICS_PATH = "mcp_analytics.json"
57
- def _init_analytics():
58
- if not os.path.exists(ANALYTICS_PATH):
59
- with open(ANALYTICS_PATH, "w") as f: json.dump({}, f)
60
- def _log_tool_call(t, s): pass
61
- def _log_llm_call(c): pass
62
- _init_analytics()
63
-
64
- # ----------------------------
65
- # FIX: Regex JSON Extractor
66
- # ----------------------------
67
def extract_json_safely(text: str) -> Optional[Any]:
    """Return the first JSON value found in *text*, or None if there is none.

    The whole string is tried first. If that fails, the text is scanned for
    each ``{`` or ``[`` and a JSON value is decoded starting there, so
    conversational filler before or after the payload is ignored.

    Args:
        text: Raw model output. Non-string input yields None.

    Returns:
        The decoded JSON value (dict, list, or a scalar when the whole string
        is a bare JSON scalar), or None when nothing decodable is found.
    """
    if not isinstance(text, str):
        return None

    # 1. Direct parse: the model returned nothing but JSON.
    try:
        return json.loads(text)
    except ValueError:
        pass

    # 2. Scan for an embedded object/array. raw_decode stops at the end of the
    #    first complete value, so trailing text (even text containing braces)
    #    cannot corrupt the match the way a greedy regex would.
    decoder = json.JSONDecoder()
    for opener in re.finditer(r"[{\[]", text):
        try:
            value, _end = decoder.raw_decode(text, opener.start())
        except ValueError:
            continue
        return value
    return None
87
-
88
- # ----------------------------
89
- # Local LLM loader
90
- # ----------------------------
91
- LLM_PIPELINE = None
92
- TOKENIZER = None
93
- LOADED_MODEL_NAME = None
94
-
95
- def init_local_model():
96
- global LLM_PIPELINE, TOKENIZER, LOADED_MODEL_NAME
97
- if not LOCAL_MODEL or not TRANSFORMERS_AVAILABLE:
98
- return
99
- try:
100
- logger.info(f"Loading model: {LOCAL_MODEL}...")
101
- TOKENIZER = AutoTokenizer.from_pretrained(LOCAL_MODEL, trust_remote_code=True)
102
- # Use CPU if needed, or remove device_map="auto" if causing issues
103
- model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL, trust_remote_code=True, device_map="auto")
104
-
105
- # FIX: Lower max_new_tokens to prevent 400s generation loops
106
- LLM_PIPELINE = pipeline("text-generation", model=model, tokenizer=TOKENIZER)
107
- LOADED_MODEL_NAME = LOCAL_MODEL
108
- logger.info("Model loaded.")
109
- except Exception as e:
110
- logger.error(f"Model load failed: {e}")
111
-
112
- init_local_model()
113
-
114
- def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
115
- if LLM_PIPELINE is None:
116
- return {"text": "LLM not loaded.", "raw": None}
117
- try:
118
- # FIX: return_full_text=False ensures we don't re-parse the prompt
119
- out = LLM_PIPELINE(prompt, max_new_tokens=max_tokens, return_full_text=False)
120
- text = out[0]["generated_text"] if out else ""
121
- return {"text": text, "raw": out}
122
- except Exception as e:
123
- return {"text": f"Error: {e}", "raw": None}
124
-
125
- # ----------------------------
126
- # Helper: normalize local file_path args (Kept intact)
127
- # ----------------------------
128
- def _normalize_local_path_args(args: Any) -> Any:
129
- if not isinstance(args, dict): return args
130
- fp = args.get("file_path") or args.get("path")
131
- if isinstance(fp, str) and fp.startswith("/mnt/data/") and os.path.exists(fp):
132
- args["file_url"] = f"file://{fp}"
133
- return args
134
-
135
- # ----------------------------
136
- # Zoho Auth & Tools (Kept intact)
137
- # ----------------------------
138
- def _get_valid_token_headers() -> dict:
139
- token_url = "https://accounts.zoho.in/oauth/v2/token"
140
- params = {
141
- "refresh_token": REFRESH_TOKEN, "client_id": CLIENT_ID,
142
- "client_secret": CLIENT_SECRET, "grant_type": "refresh_token"
143
- }
144
- r = requests.post(token_url, params=params, timeout=20)
145
- if r.status_code == 200:
146
- return {"Authorization": f"Zoho-oauthtoken {r.json().get('access_token')}"}
147
- raise RuntimeError(f"Token refresh failed: {r.text}")
148
-
149
- @mcp.tool()
150
- def authenticate_zoho() -> str:
151
- _get_valid_token_headers(); return "Zoho token refreshed (ok)."
152
-
153
- @mcp.tool()
154
- def create_record(module_name: str, record_data: dict) -> str:
155
- headers = _get_valid_token_headers()
156
- url = f"{API_BASE}/{module_name}"
157
- r = requests.post(url, headers=headers, json={"data": [record_data]}, timeout=20)
158
- if r.status_code in (200, 201): return json.dumps(r.json(), ensure_ascii=False)
159
- return f"Error: {r.text}"
160
-
161
- @mcp.tool()
162
- def get_records(module_name: str, page: int = 1, per_page: int = 200) -> list:
163
- headers = _get_valid_token_headers()
164
- r = requests.get(f"{API_BASE}/{module_name}", headers=headers, params={"page": page, "per_page": per_page})
165
- return r.json().get("data", []) if r.status_code == 200 else []
166
-
167
- @mcp.tool()
168
- def update_record(module_name: str, record_id: str, data: dict) -> str:
169
- headers = _get_valid_token_headers()
170
- r = requests.put(f"{API_BASE}/{module_name}/{record_id}", headers=headers, json={"data": [data]})
171
- return json.dumps(r.json()) if r.status_code == 200 else r.text
172
-
173
- @mcp.tool()
174
- def delete_record(module_name: str, record_id: str) -> str:
175
- headers = _get_valid_token_headers()
176
- r = requests.delete(f"{API_BASE}/{module_name}/{record_id}", headers=headers)
177
- return json.dumps(r.json()) if r.status_code == 200 else r.text
178
-
179
- def _ensure_invoice_config():
180
- if not INVOICE_API_BASE or not ORGANIZATION_ID: raise RuntimeError("Invoice Config Missing")
181
-
182
- @mcp.tool()
183
- def create_invoice(data: dict) -> str:
184
- _ensure_invoice_config()
185
- headers = _get_valid_token_headers()
186
- params = {"organization_id": ORGANIZATION_ID}
187
- r = requests.post(f"{INVOICE_API_BASE}/invoices", headers=headers, params=params, json=data)
188
- if r.status_code in (200, 201): return json.dumps(r.json(), ensure_ascii=False)
189
- return f"Error creating invoice: {r.text}"
190
-
191
- def upload_invoice_attachment(invoice_id: str, file_path: str) -> str:
192
- if not os.path.exists(file_path): return "File not found"
193
- headers = _get_valid_token_headers()
194
- headers.pop("Content-Type", None)
195
- url = f"{INVOICE_API_BASE}/invoices/{invoice_id}/attachments"
196
- with open(file_path, "rb") as f:
197
- files = {"attachment": (os.path.basename(file_path), f)}
198
- r = requests.post(url, headers=headers, params={"organization_id": ORGANIZATION_ID}, files=files)
199
- return json.dumps(r.json()) if r.status_code in (200, 201) else r.text
200
-
201
- @mcp.tool()
202
- def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
203
- """
204
- Extracts data from file using OCR + LLM.
205
- """
206
- try:
207
- if not os.path.exists(file_path):
208
- return {"status": "error", "error": "file not found"}
209
-
210
- # 1. Perform OCR
211
- raw_text = extract_text_from_file(file_path)
212
- if not raw_text or len(raw_text) < 5:
213
- return {"status": "error", "error": "OCR failed to extract text."}
214
-
215
- # 2. Use Prompt Template (Strict Mode)
216
- # FIX: Use prompts.py template + reduce max_tokens for speed
217
- prompt = get_ocr_extraction_prompt(raw_text)
218
-
219
- llm_out = local_llm_generate(prompt, max_tokens=300) # 300 tokens is plenty for JSON
220
- extracted_text = llm_out.get("text", "")
221
-
222
- # FIX: Use Regex Safe Extraction
223
- extracted_data = extract_json_safely(extracted_text)
224
-
225
- if not extracted_data:
226
- # Fallback for debugging
227
- extracted_data = {"raw_llm_text": extracted_text}
228
-
229
- return {
230
- "status": "success",
231
- "file": os.path.basename(file_path),
232
- "extracted_data": extracted_data
233
- }
234
-
235
- except Exception as e:
236
- return {"status": "error", "error": str(e)}
237
-
238
- # ----------------------------
239
- # Helpers: map LLM args -> Zoho payloads (Kept intact)
240
- # ----------------------------
241
- def _extract_created_id_from_zoho_response(resp_json) -> Optional[str]:
242
- # (Same implementation as before)
243
- try:
244
- if isinstance(resp_json, str): resp_json = json.loads(resp_json)
245
- data = resp_json.get("data") or resp_json.get("result")
246
- if data and isinstance(data, list):
247
- d = data[0].get("details") or data[0]
248
- return str(d.get("id") or d.get("ID") or d.get("Id"))
249
- if "invoice" in resp_json: return str(resp_json["invoice"].get("invoice_id"))
250
- except: pass
251
- return None
252
-
253
- def _map_contact_args_to_zoho_payload(args: dict) -> dict:
254
- # (Same implementation as before - abbreviated for strict structure compliance)
255
- p = {}
256
- if "contact" in args: p["Last_Name"] = args["contact"]
257
- if "email" in args: p["Email"] = args["email"]
258
- # ... map other fields ...
259
- for k,v in args.items():
260
- if k not in ["contact", "email", "items"]: p[k] = v
261
- return p
262
-
263
- def _build_invoice_payload_for_zoho(contact_id: str, invoice_items: List[dict], currency: str = None, vat_pct: float = 0.0) -> dict:
264
- # (Same implementation as before)
265
- line_items = []
266
- for it in invoice_items:
267
- qty = int(it.get("quantity", 1))
268
- rate = float(str(it.get("rate", 0)).replace("$",""))
269
- line_items.append({"name": it.get("name","Item"), "rate": rate, "quantity": qty})
270
- payload = {"customer_id": contact_id, "line_items": line_items}
271
- if currency: payload["currency_code"] = currency
272
- return payload
273
-
274
- # ----------------------------
275
- # Parse & Execute (Kept intact)
276
- # ----------------------------
277
- def parse_and_execute_model_tool_output(model_text: str, history: Optional[List] = None) -> str:
278
- # FIX: Use Safe Extraction first
279
- payload = extract_json_safely(model_text)
280
-
281
- if not payload:
282
- return "(Parse) Model output was not valid JSON tool instruction."
283
-
284
- # Normalize to list
285
- instructions = [payload] if isinstance(payload, dict) else payload
286
- results = []
287
- contact_id = None
288
-
289
- for instr in instructions:
290
- if not isinstance(instr, dict): continue
291
- tool = instr.get("tool")
292
- args = instr.get("args", {})
293
- args = _normalize_local_path_args(args)
294
-
295
- if tool == "create_record":
296
- # ... (logic same as before)
297
- res = create_record(args.get("module", "Contacts"), _map_contact_args_to_zoho_payload(args))
298
- results.append(f"create_record -> {res}")
299
- contact_id = _extract_created_id_from_zoho_response(res)
300
-
301
- elif tool == "create_invoice":
302
- # ... (logic same as before)
303
- if not contact_id: contact_id = args.get("customer_id")
304
- if contact_id:
305
- inv_payload = _build_invoice_payload_for_zoho(contact_id, args.get("line_items", []))
306
- res = create_invoice(inv_payload)
307
- results.append(f"create_invoice -> {res}")
308
- else:
309
- results.append("Skipped invoice: missing contact_id")
310
-
311
- elif tool == "process_document":
312
- res = process_document(args.get("file_path"))
313
- results.append(f"process -> {res}")
314
-
315
- return "\n".join(results) if results else "No tools executed."
316
-
317
- # ----------------------------
318
- # Command Parser (Debug)
319
- # ----------------------------
320
- def try_parse_and_invoke_command(text: str):
321
- # (Same implementation)
322
- if text.startswith("/mnt/data/"): return str(process_document(text))
323
- return None
324
-
325
- # ----------------------------
326
- # Chat Logic
327
- # ----------------------------
328
def deepseek_response(message: str, file_path: Optional[str] = None, history: Optional[list] = None) -> str:
    """Produce the chat reply for one user turn, running a tool call if the model asks for one.

    Steps: (1) if a file is given, run ``process_document`` on it and keep the
    extracted data as JSON context; (2) build the agent prompt from history,
    context and message; (3) generate with the local LLM; (4) if the output
    contains a JSON object/array, execute it as a tool instruction.

    Args:
        message: User text. When empty and a file was processed, a default
            request to create the contact and invoice is substituted.
        file_path: Optional path of an uploaded document.
        history: Prior turns, indexed as ``turn[0]`` (user) and ``turn[1]``
            (bot). Defaults to no history.

    Returns:
        The tool-execution result, an error string, or the raw model text.
    """
    # None default instead of `[]`: a mutable default is one shared list
    # reused by every call.
    turns = [] if history is None else history

    # 1. Handle file (OCR)
    ocr_context = ""
    if file_path:
        logger.info(f"Processing file: {file_path}")
        doc_result = process_document(file_path)
        if doc_result.get("status") != "success":
            return f"Error processing file: {doc_result.get('error')}"
        ocr_context = json.dumps(doc_result["extracted_data"], ensure_ascii=False)
        if not message:
            message = "I uploaded a file. Create the contact and invoice."

    # 2. Build prompt from flattened history
    history_text = "\n".join(f"User: {h[0]}\nBot: {h[1]}" for h in turns)
    prompt = get_agent_prompt(history_text, ocr_context, message)

    # 3. Generate
    gen = local_llm_generate(prompt, max_tokens=256)
    response_text = gen["text"]

    # 4. Execute a JSON tool call if the model produced one
    tool_json = extract_json_safely(response_text)
    if tool_json and isinstance(tool_json, (dict, list)):
        try:
            # The executor takes text, so hand it the cleaned JSON re-serialised.
            return parse_and_execute_model_tool_output(json.dumps(tool_json), turns)
        except Exception as e:
            return f"(Execute) Error: {e}"

    return response_text
366
-
367
- # ----------------------------
368
- # Chat Handler
369
- # ----------------------------
370
- def chat_handler(message, history):
371
- user_text = ""
372
- uploaded_file_path = None
373
-
374
- if isinstance(message, dict):
375
- user_text = message.get("text", "")
376
- files = message.get("files", [])
377
- if files: uploaded_file_path = files[0]
378
- else:
379
- user_text = str(message)
380
-
381
- # Debug command bypass
382
- if not uploaded_file_path:
383
- cmd = try_parse_and_invoke_command(user_text)
384
- if cmd: return cmd
385
-
386
- return deepseek_response(user_text, uploaded_file_path, history)
387
-
388
- # ----------------------------
389
- # FIX: Cleanup for fd -1 error
390
- # ----------------------------
391
- def cleanup_event_loop():
392
- gc.collect()
393
- try:
394
- loop = asyncio.get_event_loop()
395
- if loop.is_closed():
396
- asyncio.set_event_loop(asyncio.new_event_loop())
397
- except RuntimeError:
398
- asyncio.set_event_loop(asyncio.new_event_loop())
399
-
400
- if __name__ == "__main__":
401
- cleanup_event_loop()
402
-
403
- demo = gr.ChatInterface(
404
- fn=chat_handler,
405
- multimodal=True,
406
- textbox=gr.MultimodalTextbox(interactive=True, file_count="single", placeholder="Upload Invoice or ask to create records...")
407
- )
408
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
1
+ # prompts.py
2
+ # Qwen-2.5 Compatible Prompts (ChatML format)
3
+
4
def get_ocr_extraction_prompt(raw_text: str, max_chars: int = 3000) -> str:
    """Build the ChatML prompt that asks the model to turn OCR text into JSON.

    Args:
        raw_text: Text extracted from the document. None is treated as empty.
        max_chars: Maximum number of characters of *raw_text* to embed
            (default 3000, the previously hard-coded limit).

    Returns:
        A prompt with system and user turns, ending with an open assistant
        turn for the model to complete.
    """
    text = "" if raw_text is None else str(raw_text)
    # Document text is untrusted: break any "<|...|>" control token so it
    # cannot close the user turn or open a forged system turn.
    text = text.replace("<|", "< |")[:max_chars]

    return (
        "<|im_start|>system\n"
        "You are a precise Data Extraction Engine.\n"
        "Extract data from the text below and return a JSON object.\n"
        "Fields: contact_name, total_amount, currency, invoice_date, "
        "line_items (name, quantity, rate).\n"
        "Output ONLY JSON. No markdown.\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        "Input Text:\n"
        f"{text}\n"
        "\n"
        "Return the JSON:\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
19
+
20
def get_agent_prompt(history_text: str, ocr_context: str, user_message: str) -> str:
    """Build the ChatML prompt for the tool-calling assistant.

    Args:
        history_text: Flattened prior conversation. None is treated as empty.
        ocr_context: Data extracted from an uploaded file; when empty or None
            the "CONTEXT FROM FILE" section is omitted.
        user_message: The current request. None is treated as empty.

    Returns:
        A prompt with system and user turns, ending with an open assistant
        turn for the model to complete.
    """

    def _clean(value) -> str:
        # Inputs are untrusted: break any "<|...|>" control token so user or
        # document text cannot end the turn or forge a system turn.
        return "" if value is None else str(value).replace("<|", "< |")

    history = _clean(history_text)
    context = _clean(ocr_context)
    request = _clean(user_message)

    context_block = f"CONTEXT FROM FILE:\n{context}\n" if context else ""

    return (
        "<|im_start|>system\n"
        "You are Zoho Assistant. Tools:\n"
        "1. create_record(module_name, record_data)\n"
        "2. create_invoice(data)\n"
        "3. process_document(file_path)\n"
        "\n"
        'If user wants an action, return JSON: {"tool": "name", "args": {...}}\n'
        "Use CONTEXT FROM FILE to fill args.\n"
        "Return ONLY JSON.\n"
        "<|im_end|>\n"
        "<|im_start|>user\n"
        f"{context_block}\n"
        "HISTORY:\n"
        f"{history}\n"
        "\n"
        "REQUEST:\n"
        f"{request}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )