import json
import os
import time

import boto3
import openai
import requests  # used only by the commented-out Ollama branch below
from dotenv import load_dotenv
from model_config import MODEL_TO_PROVIDER, MODEL_TO_INFERENCE_PROFILE_ARN
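
# Expected environment variables (a sketch inferred from the lookups below; set
# them in a local .env file or as Hugging Face Spaces secrets):
#   MODEL_API_KEY                             - OpenAI API key
#   GOOGLE_API_KEY                            - Google Gemini API key
#   AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY - AWS Bedrock credentials
#   AWS_DEFAULT_REGION                        - Bedrock region (defaults to us-east-1)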

# Lazy initialization of Google Gemini client
_google_client = None

def get_google_client():
    """Get or create the Google Gemini client with proper error handling."""
    global _google_client
    if _google_client is None:
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "google-generativeai package not installed. "
                "Please add 'google-generativeai' to requirements.txt"
            )
        
        google_api_key = os.getenv("GOOGLE_API_KEY", "").strip()
        if not google_api_key:
            raise ValueError(
                "Google API key not found. Please set GOOGLE_API_KEY "
                "as a secret in Hugging Face Spaces settings."
            )
        
        try:
            genai.configure(api_key=google_api_key)
            _google_client = genai
        except Exception as e:
            raise ValueError(
                f"Failed to initialize Google Gemini client: {str(e)}. "
                "Please verify your GOOGLE_API_KEY is correct."
            ) from e
    
    return _google_client
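
# Usage sketch (assumes GOOGLE_API_KEY is set; "gemini-1.5-flash" is only an
# illustrative model ID):
#   genai = get_google_client()
#   model = genai.GenerativeModel("gemini-1.5-flash")
#   print(model.generate_content("ping").text)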

# ──────────────────────────────────────────────────────────────
# Load environment variables
load_dotenv()
# ──────────────────────────────────────────────────────────────

# ──────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────
MODEL_STRING = "gpt-4.1-mini"  # we default on gpt-4.1-mini
api_key = os.getenv("MODEL_API_KEY")
client = openai.OpenAI(api_key=api_key)

# Lazy initialization of the Bedrock client to avoid errors if credentials are missing
_bedrock_runtime = None

def get_bedrock_client():
    """Get or create the Bedrock runtime client with proper error handling."""
    global _bedrock_runtime
    if _bedrock_runtime is None:
        aws_access_key = os.getenv("AWS_ACCESS_KEY_ID", "").strip()
        aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "").strip()
        aws_region = os.getenv("AWS_DEFAULT_REGION", "us-east-1").strip()
        
        if not aws_access_key or not aws_secret_key:
            raise ValueError(
                "AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
                "as secrets in Hugging Face Spaces settings. "
                f"Current values: AWS_ACCESS_KEY_ID={'***' if aws_access_key else 'EMPTY'}, "
                f"AWS_SECRET_ACCESS_KEY={'***' if aws_secret_key else 'EMPTY'}"
            )
        
        try:
            _bedrock_runtime = boto3.client(
                "bedrock-runtime", 
                region_name=aws_region,
                aws_access_key_id=aws_access_key,
                aws_secret_access_key=aws_secret_key
            )
        except Exception as e:
            raise ValueError(
                f"Failed to initialize AWS Bedrock client: {str(e)}. "
                "Please verify your AWS credentials are valid and have Bedrock access."
            ) from e
    
    return _bedrock_runtime
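
# Usage sketch (assumes valid AWS credentials with Bedrock access):
#   runtime = get_bedrock_client()
#   response = runtime.invoke_model(modelId="...", contentType="application/json",
#                                   accept="application/json", body=json.dumps({...}))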

# ──────────────────────────────────────────────────────────────
# Model switcher
# ──────────────────────────────────────────────────────────────
def set_model(model_id: str) -> None:
    """Switch the active model; model_id should be a key in MODEL_TO_PROVIDER."""
    global MODEL_STRING
    MODEL_STRING = model_id
    print(f"Model changed to: {model_id}")


def set_provider(provider: str) -> None:
    """Override the provider (normally derived from MODEL_TO_PROVIDER in chat())."""
    global PROVIDER
    PROVIDER = provider
    print(f"Provider changed to: {provider}")


# ──────────────────────────────────────────────────────────────
# High-level Chat wrapper
# ──────────────────────────────────────────────────────────────
def chat(messages, persona=None):
    """Route a chat request to the provider mapped to the current model.

    Returns (text, latency_seconds, total_tokens, tokens_per_second).
    `persona` is currently unused and kept for interface compatibility.
    """
    provider = MODEL_TO_PROVIDER[MODEL_STRING]

    if provider == "openai":
        print("Using openai: ", MODEL_STRING)
        t0 = time.time()
        
        # System prompt (currently an empty placeholder)
        system_prompt = ""
        
        # Prepare messages with system prompt
        chat_messages = [{"role": "system", "content": system_prompt}]
        for msg in messages:
            chat_messages.append({
                "role": msg["role"],
                "content": msg["content"]
            })

        request_kwargs = {
            "model": MODEL_STRING,
            "messages": chat_messages,
            "max_completion_tokens": 4000,
        }
        # Some newer OpenAI models only support the default temperature.
        if MODEL_STRING not in {"gpt-5", "gpt-5-nano", "gpt-5-mini"}:
            request_kwargs["temperature"] = 0.3

        response = client.chat.completions.create(**request_kwargs)

        dt = time.time() - t0
        text = response.choices[0].message.content.strip()
        
        # Calculate tokens
        total_tok = response.usage.total_tokens if response.usage else len(text.split())
        
        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    elif provider == "anthropic":
        print("Using anthropic: ", MODEL_STRING)
        t0 = time.time()

        # System prompt (currently an empty placeholder)
        system_prompt = ""
        
        claude_messages = [
            {"role": m["role"], "content": m["content"]} for m in messages
        ]

        try:
            bedrock_runtime = get_bedrock_client()
            
            # Use inference profile ARN if available (for provisioned throughput models)
            # Otherwise use modelId (for on-demand models)
            invoke_kwargs = {
                "contentType": "application/json",
                "accept": "application/json",
                "body": json.dumps(
                    {
                        "anthropic_version": "bedrock-2023-05-31",
                        "system": system_prompt,
                        "messages": claude_messages,
                        "max_tokens": 4000,  # Much higher limit for longer responses
                        "temperature": 0.3,  # Lower temperature for more focused responses
                    }
                ),
            }
            
            # Check if this model has an inference profile ARN (provisioned throughput)
            # For provisioned throughput, use the ARN as the modelId
            if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
                invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
            else:
                invoke_kwargs["modelId"] = MODEL_STRING
            
            response = bedrock_runtime.invoke_model(**invoke_kwargs)

            dt = time.time() - t0
            body = json.loads(response["body"].read())
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
        except Exception as e:
            error_msg = str(e)
            if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
                raise ValueError(
                    f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
                    f"Error: {error_msg}. "
                    "Please verify the model ID is correct and the model is available in your AWS region. "
                    "Common Claude model IDs: 'anthropic.claude-3-5-sonnet-20241022-v2' or 'anthropic.claude-3-haiku-20240307-v1'"
                ) from e
            elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
                raise ValueError(
                    f"AWS Bedrock authentication failed: {error_msg}. "
                    "Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
                    "are correct and have Bedrock access permissions."
                ) from e
            raise

        text = "".join(
            part["text"] for part in body["content"] if part["type"] == "text"
        ).strip()
        # Approximate token count (whitespace split), mirroring the other providers
        total_tok = len(text.split())

        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    elif provider == "google":
        print("Using google (Gemini): ", MODEL_STRING)
        t0 = time.time()

        try:
            genai = get_google_client()
            
            # Get the model
            model = genai.GenerativeModel(MODEL_STRING)
            
            # Convert messages to Gemini format
            # Gemini API expects a chat history format with "user" and "model" roles
            chat_history = []
            for msg in messages:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                # Skip system messages (Gemini chat history has no system role; they are dropped)
                if role == "system":
                    continue
                # Gemini uses "model" instead of "assistant"
                if role == "assistant":
                    role = "model"
                chat_history.append({
                    "role": role,
                    "parts": [content]
                })
            
            # Separate history from the last user message
            if chat_history and chat_history[-1]["role"] == "user":
                history = chat_history[:-1]
                last_user_message = chat_history[-1]["parts"][0]
            else:
                history = []
                last_user_message = chat_history[-1]["parts"][0] if chat_history else ""
            
            # Start a chat session with history (named chat_session to avoid
            # shadowing the enclosing chat() function)
            chat_session = model.start_chat(history=history)

            # Send the last message
            response = chat_session.send_message(
                last_user_message,
                generation_config=genai.types.GenerationConfig(
                    max_output_tokens=4000,
                    temperature=0.3,
                )
            )

            dt = time.time() - t0
            text = response.text.strip()
            
            # Calculate tokens (approximate)
            total_tok = len(text.split())

            return text, dt, total_tok, (total_tok / dt if dt else total_tok)
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
        except Exception as e:
            error_msg = str(e)
            if "API key" in error_msg or "invalid" in error_msg.lower() or "401" in error_msg or "403" in error_msg:
                raise ValueError(
                    f"Google API authentication failed: {error_msg}. "
                    "Please verify your GOOGLE_API_KEY secret is correct and has Gemini API access."
                ) from e
            elif "not found" in error_msg.lower() or "404" in error_msg:
                raise ValueError(
                    f"Invalid Gemini model ID: '{MODEL_STRING}'. "
                    f"Error: {error_msg}. "
                    "Please verify the model ID is correct. "
                    "Common Gemini model IDs: 'gemini-1.5-pro', 'gemini-1.5-flash', 'gemini-2.0-flash-exp', 'gemini-pro'"
                ) from e
            raise
    elif provider == "deepseek":
        print("Using deepseek: ", MODEL_STRING)
        t0 = time.time()

        system_prompt = ""  # placeholder system prompt (currently empty)

        ds_messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}],
            }
        ]
        for msg in messages:
            role = msg.get("role", "user")
            ds_messages.append(
                {
                    "role": role,
                    "content": [{"type": "text", "text": msg["content"]}],
                }
            )

        try:
            bedrock_runtime = get_bedrock_client()
            response = bedrock_runtime.invoke_model(
                modelId=MODEL_STRING,
                contentType="application/json",
                accept="application/json",
                body=json.dumps(
                    {
                        "messages": ds_messages,
                        "max_completion_tokens": 500,
                        "temperature": 0.5,
                        "top_p": 0.9,
                    }
                ),
            )

            dt = time.time() - t0
            body = json.loads(response["body"].read())
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
        except Exception as e:
            error_msg = str(e)
            if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
                raise ValueError(
                    f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
                    f"Error: {error_msg}. "
                    "Please verify the model ID is correct and the model is available in your AWS region."
                ) from e
            elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
                raise ValueError(
                    f"AWS Bedrock authentication failed: {error_msg}. "
                    "Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
                    "are correct and have Bedrock access permissions."
                ) from e
            raise

        outputs = body.get("output", [])
        text_chunks = []
        for item in outputs:
            for content in item.get("content", []):
                chunk_text = content.get("text") or content.get("output_text")
                if chunk_text:
                    text_chunks.append(chunk_text)
        text = "".join(text_chunks).strip()
        if not text and "response" in body:
            text = body["response"].get("output_text", "").strip()
        total_tok = len(text.split())

        return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    # elif provider == "meta":
    #     print("Using meta (LLaMA): ", MODEL_STRING)
    #     t0 = time.time()

    #     # Add system prompt for better behavior
    #     system_prompt = ""
        
    #     # Format conversation properly for Llama3
    #     formatted_prompt = "<|begin_of_text|>"
        
    #     # Add system prompt
    #     formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
        
    #     # Add conversation history
    #     for msg in messages:
    #         if msg["role"] == "user":
    #             formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
    #         elif msg["role"] == "assistant":
    #             formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
        
    #     # Add final assistant prompt
    #     formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"

    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {
    #                 "prompt": formatted_prompt, 
    #                 "max_gen_len": 512,  # Shorter responses
    #                 "temperature": 0.3,  # Lower temperature for more focused responses
    #             }
    #         ),
    #     )

    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())
    #     text = body.get("generation", "").strip()
    #     total_tok = len(text.split())

    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    # elif provider == "mistral":
    #     print("Using mistral: ", MODEL_STRING)
    #     t0 = time.time()

    #     prompt = messages[-1]["content"]
    #     formatted_prompt = f"<s>[INST] {prompt} [/INST]"

    #     response = bedrock_runtime.invoke_model(
    #         modelId=MODEL_STRING,
    #         contentType="application/json",
    #         accept="application/json",
    #         body=json.dumps(
    #             {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
    #         ),
    #     )

    #     dt = time.time() - t0
    #     body = json.loads(response["body"].read())

    #     text = body["outputs"][0]["text"].strip()
    #     total_tok = len(text.split())

    #     return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    # elif provider == "ollama":
    #     print("Using ollama: ", MODEL_STRING)
    #     t0 = time.time()
        
    #     # Format messages for Ollama API with system prompt
    #     ollama_messages = []
        
    #     # Add system prompt for better behavior
    #     system_prompt = ""
    #     ollama_messages.append({
    #         "role": "system",
    #         "content": system_prompt
    #     })
        
    #     for msg in messages:
    #         ollama_messages.append({
    #             "role": msg["role"],
    #             "content": msg["content"]
    #         })
        
    #     # Make request to Ollama API
    #     response = requests.post(
    #         f"{OLLAMA_BASE_URL}/api/chat",
    #         json={
    #             "model": MODEL_STRING,
    #             "messages": ollama_messages,
    #             "stream": False,
    #             "options": {
    #                 "temperature": 0.3,  # Lower temperature for more focused responses
    #                 # "num_predict": 4000,  # Much higher limit for longer responses
    #                 "top_p": 0.9,
    #                 "repeat_penalty": 1.1
    #             }
    #         },
    #         timeout=60
    #     )
        
    #     dt = time.time() - t0
        
    #     if response.status_code == 200:
    #         result = response.json()
    #         text = result["message"]["content"].strip()
    #         total_tok = len(text.split())
    #         return text, dt, total_tok, (total_tok / dt if dt else total_tok)
    #     else:
    #         raise Exception(f"Ollama API error: {response.status_code} - {response.text}")

    # No matching provider: fail loudly instead of implicitly returning None
    raise ValueError(f"Unsupported provider '{provider}' for model '{MODEL_STRING}'")

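# Usage sketch for chat() (assumes credentials for the current model's provider
# are configured, e.g. verified via check_credentials() below):
#   reply, latency, tokens, tps = chat([{"role": "user", "content": "Hello"}])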

# ──────────────────────────────────────────────────────────────
# Diagnostics / CLI test
# ──────────────────────────────────────────────────────────────
def check_credentials():
    """Verify that credentials for the current model's provider are usable."""
    # # Check if using Ollama (no API key required)
    # if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
    #     # Test Ollama connection
    #     try:
    #         response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
    #         if response.status_code == 200:
    #             print("Ollama connection successful")
    #             return True
    #         else:
    #             print(f"Ollama connection failed: {response.status_code}")
    #             return False
    #     except Exception as e:
    #         print(f"Ollama connection failed: {e}")
    #         return False
    
    # Check Bedrock-backed providers (meta/mistral are commented out above)
    bedrock_providers = ["anthropic", "deepseek"]
    if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
        # Test AWS Bedrock connection by trying to invoke a simple model
        try:
            bedrock_runtime = get_bedrock_client()
            # Try a simple test invocation to verify credentials
            test_model = "anthropic.claude-haiku-4-5-20251001-v1:0"
            test_kwargs = {
                "contentType": "application/json",
                "accept": "application/json",
                "body": json.dumps({
                    "anthropic_version": "bedrock-2023-05-31",
                    "messages": [{"role": "user", "content": "test"}],
                    "max_tokens": 10,
                    "temperature": 0.1
                })
            }
            
            # Use inference profile ARN if available (use ARN as modelId for provisioned throughput)
            if test_model in MODEL_TO_INFERENCE_PROFILE_ARN:
                test_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[test_model]
            else:
                test_kwargs["modelId"] = test_model
            
            bedrock_runtime.invoke_model(**test_kwargs)
            print("Bedrock connection successful")
            return True
        except Exception as e:
            print(f"Bedrock connection failed: {e}")
            print("Make sure AWS credentials are configured and you have access to Bedrock")
            return False
    
    # For OpenAI, check API key
    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "openai":
        required = ["MODEL_API_KEY"]
        missing = [var for var in required if not os.getenv(var)]
        if missing:
            print(f"Missing environment variables: {missing}")
            return False
        return True
    
    # For Google Gemini, check API key
    if MODEL_TO_PROVIDER.get(MODEL_STRING) == "google":
        required = ["GOOGLE_API_KEY"]
        missing = [var for var in required if not os.getenv(var)]
        if missing:
            print(f"Missing environment variables: {missing}")
            return False
        # Try to initialize the client to verify the key works
        try:
            get_google_client()
            return True
        except Exception as e:
            print(f"Google API client initialization failed: {e}")
            return False
    
    return True


def test_chat():
    """Send a one-line prompt through chat() and report latency and throughput."""
    print("Testing chat...")
    try:
        test_messages = [
            {
                "role": "user",
                "content": "Hello! Please respond with just 'Test successful'.",
            }
        ]
        text, latency, tokens, tps = chat(test_messages)
        print(f"Test passed!  {text}  {latency:.2f}s  {tokens} ⚑ {tps:.1f} tps")
    except Exception as e:
        print(f"Test failed: {e}")


if __name__ == "__main__":
    print("running diagnostics")
    if check_credentials():
        test_chat()
    print("\nDone.")