vachaspathi committed on
Commit ad08316 · verified · 1 Parent(s): c5b3162

Update app.py

Files changed (1)
  1. app.py +17 -17
app.py CHANGED
@@ -11,6 +11,7 @@ import re
 import logging
 import asyncio
 import gc
+import shutil
 
 # --- Import OCR Engine & Prompts ---
 try:
@@ -54,7 +55,7 @@ def _normalize_local_path_args(args: Any) -> Any:
         args["file_url"] = f"file://{fp}"
     return args
 
-# --- Model Loading (Lazy & Light) ---
+# --- Model Loading ---
 def init_local_model():
     global LLM_PIPELINE, TOKENIZER
     if LLM_PIPELINE is not None: return
@@ -64,14 +65,11 @@ def init_local_model():
 
         logger.info(f"Loading lighter model: {LOCAL_MODEL}...")
         TOKENIZER = AutoTokenizer.from_pretrained(LOCAL_MODEL)
-
-        # Load model (Standard load is fine for Qwen on CPU)
         model = AutoModelForCausalLM.from_pretrained(
             LOCAL_MODEL,
             device_map="auto",
             torch_dtype="auto"
         )
-
         LLM_PIPELINE = pipeline("text-generation", model=model, tokenizer=TOKENIZER)
         logger.info("Model loaded.")
     except Exception as e:
@@ -85,13 +83,12 @@ def local_llm_generate(prompt: str, max_tokens: int = 512) -> Dict[str, Any]:
         return {"text": "Model not loaded.", "raw": None}
 
     try:
-        # Standard generation (Qwen is robust, no cache hacks needed)
+        # FIX: Removed invalid flags 'temperature', 'top_p', etc. when do_sample is False
         out = LLM_PIPELINE(
             prompt,
             max_new_tokens=max_tokens,
             return_full_text=False,
-            do_sample=False, # Deterministic for tools
-            temperature=0.0
+            do_sample=False  # Deterministic
         )
         text = out[0]["generated_text"] if out else ""
         return {"text": text, "raw": out}
@@ -114,7 +111,6 @@ def create_record(module_name: str, record_data: dict) -> str:
     if not h: return "Auth Failed"
     r = requests.post(f"{API_BASE}/{module_name}", headers=h, json={"data": [record_data]})
     if r.status_code in (200, 201):
-        # Extract ID for downstream use
         try:
             d = r.json().get("data", [{}])[0].get("details", {})
             return json.dumps({"status": "success", "id": d.get("id"), "response": r.json()})
@@ -132,7 +128,9 @@ def create_invoice(data: dict) -> str:
 
 @mcp.tool()
 def process_document(file_path: str, target_module: Optional[str] = "Contacts") -> dict:
-    if not os.path.exists(file_path): return {"error": "File not found"}
+    if not os.path.exists(file_path):
+        logger.error(f"process_document: File not found at {file_path}")
+        return {"error": f"File not found at path: {file_path}"}
 
     # 1. OCR
     raw_text = extract_text_from_file(file_path)
@@ -154,11 +152,9 @@ def parse_and_execute(model_text: str, history: list) -> str:
     payload = extract_json_safely(model_text)
     if not payload: return "No valid tool call found."
 
-    # Normalize
     cmds = [payload] if isinstance(payload, dict) else payload
     results = []
 
-    # Context State
     last_contact_id = None
 
     for cmd in cmds:
@@ -169,7 +165,6 @@ def parse_and_execute(model_text: str, history: list) -> str:
         if tool == "create_record":
             res = create_record(args.get("module", "Contacts"), args)
             results.append(f"Record: {res}")
-            # Try capture ID
             try:
                 rj = json.loads(res)
                 if isinstance(rj, dict) and "id" in rj:
@@ -177,11 +172,9 @@ def parse_and_execute(model_text: str, history: list) -> str:
             except: pass
 
         elif tool == "create_invoice":
-            # Auto-fill contact_id if we just created one
             if not args.get("customer_id") and last_contact_id:
                 args["customer_id"] = last_contact_id
 
-            # Map Items
             items = []
             for it in args.get("line_items", []):
                 items.append({
@@ -197,6 +190,7 @@ def parse_and_execute(model_text: str, history: list) -> str:
             results.append(f"Invoice: {res}")
 
         elif tool == "process_document":
+            # NOTE: Prompts try to prevent this, but if it happens, we rely on args being correct
             res = process_document(args.get("file_path"))
             results.append(f"Processed: {res}")
 
@@ -204,9 +198,11 @@ def parse_and_execute(model_text: str, history: list) -> str:
 
 # --- Chat Core ---
 def chat_logic(message: str, file_path: str, history: list) -> str:
-    # 1. Ingest File
+
+    # 1. Ingest File IMMEDIATELY
     file_context = ""
     if file_path:
+        logger.info(f"Ingesting file from path: {file_path}")
         doc = process_document(file_path)
         if doc.get("status") == "success":
             file_context = json.dumps(doc["extracted_data"])
@@ -214,12 +210,14 @@ def chat_logic(message: str, file_path: str, history: list) -> str:
         else:
            return f"OCR Failed: {doc}"
 
-    # 2. Decision
+    # 2. Decision Prompt (With context injected)
    hist_txt = "\n".join([f"U: {h[0]}\nA: {h[1]}" for h in history])
    prompt = get_agent_prompt(hist_txt, file_context, message)
 
    # 3. Gen & Execute
    gen = local_llm_generate(prompt, max_tokens=200)
+    logger.info(f"LLM Decision: {gen['text']}")
+
    tool_data = extract_json_safely(gen["text"])
 
    if tool_data:
@@ -233,6 +231,9 @@ def chat_handler(msg, hist):
    files = msg.get("files", [])
    path = files[0] if files else None
 
+    if path:
+        logger.info(f"UI received file: {path}")
+
    # Direct path bypass for debugging
    if not path and txt.startswith("/mnt/data"):
        return str(process_document(txt))
@@ -241,6 +242,5 @@ def chat_handler(msg, hist):
 
 if __name__ == "__main__":
    gc.collect()
-    # Lazy init will happen on first request, saving startup memory
    demo = gr.ChatInterface(fn=chat_handler, multimodal=True)
    demo.launch(server_name="0.0.0.0", server_port=7860)
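A quick note on the generation change: with do_sample=False the transformers text-generation pipeline does greedy, deterministic decoding, and sampling-only knobs such as temperature and top_p are unused, so recent transformers releases typically warn when they are passed anyway. Below is a minimal sketch of the corrected call, assuming a pipeline built the same way as LLM_PIPELINE in init_local_model; the model id is illustrative and not necessarily what LOCAL_MODEL points to.

from transformers import pipeline

# Illustrative model id (assumption); app.py loads whatever LOCAL_MODEL is configured to.
llm = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")

out = llm(
    "Return a JSON tool call that creates a contact named ACME.",
    max_new_tokens=200,
    return_full_text=False,  # only the completion, matching local_llm_generate
    do_sample=False,         # greedy decoding; no temperature/top_p needed
)
text = out[0]["generated_text"] if out else ""
print(text)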
 
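For reference on the chat_handler plumbing this commit touches: with multimodal=True, gr.ChatInterface delivers each user turn to the handler as a dict carrying "text" and "files" (a list of uploaded file paths), which is what the files/path unpacking above relies on. A minimal self-contained sketch with a stub reply in place of the app's chat_logic pipeline:

import gradio as gr

def chat_handler(msg, hist):
    # With multimodal=True, msg is a dict: {"text": "...", "files": ["/path/to/upload", ...]}
    txt = msg.get("text", "")
    files = msg.get("files", [])
    path = files[0] if files else None
    return f"Got text={txt!r}, file={path!r}"  # stub; the real app calls chat_logic here

if __name__ == "__main__":
    demo = gr.ChatInterface(fn=chat_handler, multimodal=True)
    demo.launch(server_name="0.0.0.0", server_port=7860)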