Spaces (status: Runtime error)

Commit 9b0c0fa · Parent: 13af073 · ONS2

app.py CHANGED
@@ -35,14 +35,12 @@ def load_model(hf_token):
 
     try:
         # Try different model versions from smallest to largest
-        # Prioritize instruction-tuned models
         model_options = [
             "google/gemma-2b-it",
             "google/gemma-7b-it",
             "google/gemma-2b",
             "google/gemma-7b",
-            #
-            "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+            "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Fallback
         ]
 
         print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
@@ -51,10 +49,8 @@ def load_model(hf_token):
             try:
                 print(f"\n--- Attempting to load model: {model_name} ---")
                 is_gemma = "gemma" in model_name.lower()
-
-                current_token = hf_token if is_gemma else None # Only use token for Gemma models
+                current_token = hf_token if is_gemma else None
 
-                # Load tokenizer
                 print("Loading tokenizer...")
                 global_tokenizer = AutoTokenizer.from_pretrained(
                     model_name,
@@ -62,13 +58,11 @@ def load_model(hf_token):
                 )
                 print("Tokenizer loaded successfully.")
 
-                # Load model
                 print(f"Loading model {model_name}...")
                 global_model = AutoModelForCausalLM.from_pretrained(
                     model_name,
-                    # torch_dtype=torch.bfloat16, # Use bfloat16 for better performance/compatibility if available - fallback to float16 if needed
                     torch_dtype=torch.float16, # Using float16 for broader compatibility
-                    device_map="auto",
+                    device_map="auto",
                     token=current_token
                 )
                 print(f"Model {model_name} loaded successfully!")
@@ -76,34 +70,26 @@ def load_model(hf_token):
                 model_loaded = True
                 loaded_model_name = model_name
                 loaded_successfully = True
-                tabs_update = gr.Tabs.update(visible=True)
-
-
-
-
+                tabs_update = gr.Tabs.update(visible=True)
+                status_msg = f"✅ Model '{model_name}' loaded successfully!"
+                if "tinyllama" in model_name.lower():
+                    status_msg = f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma."
+                return status_msg, tabs_update
 
             except ImportError as import_err:
-
-
-                continue # Try next model
+                print(f"Import Error loading {model_name}: {import_err}. Check dependencies (e.g., bitsandbytes, accelerate).")
+                continue
             except Exception as specific_e:
                 print(f"Failed to load {model_name}: {specific_e}")
-
-                if "401 Client Error" in str(specific_e) and is_gemma:
+                if "401 Client Error" in str(specific_e) or "requires you to be logged in" in str(specific_e) and is_gemma:
                     print("Authentication error likely. Check token and license agreement.")
-                    # Don't immediately fail, try next model
-                elif "requires you to be logged in" in str(specific_e) and is_gemma:
-                    print("Authentication error likely. Check token and license agreement.")
-                    # Don't immediately fail, try next model
-                # Continue to the next model option
                 continue
 
-        # If loop finishes without loading
         if not loaded_successfully:
            model_loaded = False
            loaded_model_name = "None"
            print("Could not load any model version.")
-            return "❌ Could not load any model. Please check your token
+            return "❌ Could not load any model. Please check your token, license acceptance, dependencies, and network connection.", initial_tabs_update
 
    except Exception as e:
        model_loaded = False
@@ -111,16 +97,14 @@ def load_model(hf_token):
        error_msg = str(e)
        print(f"Error in load_model: {error_msg}")
        traceback.print_exc()
-
        if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg :
-            return "❌ Authentication failed.
+            return "❌ Authentication failed. Check token/license.", initial_tabs_update
        else:
-            return f"❌
+            return f"❌ Unexpected error during model loading: {error_msg}", initial_tabs_update
 
 
 def generate_prompt(task_type, **kwargs):
     """Generate appropriate prompts based on task type and parameters"""
-    # Using a dictionary-based approach for cleaner prompt generation
     prompts = {
         "creative": "Write a {style} about {topic}. Be creative and engaging.",
         "informational": "Write an {format_type} about {topic}. Be clear, factual, and informative.",
@@ -138,24 +122,17 @@ def generate_prompt(task_type, **kwargs):
         "classify": "Classify the following text into one of these categories: {categories}\n\nText: {text}\n\nCategory:",
         "data_extract": "Extract the following data points ({data_points}) from the text below:\n\nText: {text}\n\nExtracted Data:",
     }
-
     prompt_template = prompts.get(task_type)
     if prompt_template:
         try:
-            # Prepare kwargs safely for formatting
-            # Find placeholders like {key}
             keys_in_template = [k[1:-1] for k in prompt_template.split('{') if '}' in k for k in [k.split('}')[0]]]
-            final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template}
-
-            # Add any extra kwargs provided that weren't in the template (e.g., for 'custom' type)
-            final_kwargs.update(kwargs)
-
+            final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template}
+            final_kwargs.update(kwargs) # Add extras
             return prompt_template.format(**final_kwargs)
         except KeyError as e:
             print(f"Warning: Missing key for prompt template '{task_type}': {e}")
-            return kwargs.get("prompt", f"Generate text based on: {kwargs}")
+            return kwargs.get("prompt", f"Generate text based on: {kwargs}")
     else:
-        # Fallback for custom or undefined task types
         return kwargs.get("prompt", "Generate text based on the input.")
 
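For reference, a standalone sketch (illustrative calls, not part of the commit) of how the kept helper behaves. Note that the placeholder scan's `k[1:-1]` slices the first and last character off each extracted name (`style` becomes `tyl`), so the bracketed defaults never match a real placeholder and a genuinely missing key falls through to the `except KeyError` fallback instead:

# Illustrative sketch, not from the commit: observed behavior of generate_prompt.
generate_prompt("creative", style="poem", topic="the sea")
# -> 'Write a poem about the sea. Be creative and engaging.'
generate_prompt("creative", style="poem")      # 'topic' not supplied
# -> KeyError caught; returns "Generate text based on: {'style': 'poem'}"
generate_prompt("custom", prompt="Say hi")     # no template registered for 'custom'
# -> 'Say hi'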
@@ -169,689 +146,435 @@ def generate_text(prompt, max_new_tokens=1024, temperature=0.7, top_p=0.9):
     print(f"Prompt (start): {prompt[:150]}...")
 
     if not model_loaded or global_model is None or global_tokenizer is None:
-        print("Model not loaded error.")
         return "⚠️ Model not loaded. Please authenticate first."
-
     if not prompt:
         return "⚠️ Please enter a prompt or configure a task."
 
     try:
-
-        # Simple check based on model name conventions
+        chat_prompt = prompt # Default to raw prompt
         if loaded_model_name and ("it" in loaded_model_name.lower() or "instruct" in loaded_model_name.lower() or "chat" in loaded_model_name.lower()):
-            # Simple chat structure assumed by many instruction models
-            # Using Gemma's specific format if it's a Gemma IT model
             if "gemma" in loaded_model_name.lower():
+                # Use Gemma's specific format
                 chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
+            elif "tinyllama" in loaded_model_name.lower():
+                # Use TinyLlama's chat format
+                chat_prompt = f"<|system|>\nYou are a friendly chatbot.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
             else: # Generic instruction format
                 chat_prompt = f"User: {prompt}\nAssistant:"
-        else:
-            # Base models might not need specific turn indicators
-            chat_prompt = prompt
 
         inputs = global_tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=True).to(global_model.device)
         input_length = inputs.input_ids.shape[1]
         print(f"Input token length: {input_length}")
 
-
-
-
+        effective_max_new_tokens = min(int(max_new_tokens), 2048)
+
+        # Handle potential None for eos_token_id
+        eos_token_id = global_tokenizer.eos_token_id
+        if eos_token_id is None:
+            print("Warning: eos_token_id is None, using default 50256.")
+            eos_token_id = 50256 # A common default EOS token ID
 
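The hardcoded turn markers above mirror each model's published chat template. On recent transformers releases the tokenizer can produce the same string itself via apply_chat_template; a minimal sketch, assuming a tokenizer that bundles a chat template (the model id is the gated one used in this commit, so it needs an HF token):

# Sketch only: build the Gemma prompt from the tokenizer's bundled chat template
# instead of hardcoding <start_of_turn> markers.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-2b-it")  # gated repo: requires auth
messages = [{"role": "user", "content": "Why is the sky blue?"}]
chat_prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# chat_prompt ends with "<start_of_turn>model\n", matching the literal above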
         generation_args = {
             "input_ids": inputs.input_ids,
-            "attention_mask": inputs.attention_mask,
+            "attention_mask": inputs.attention_mask,
             "max_new_tokens": effective_max_new_tokens,
             "do_sample": True,
-            "temperature": float(temperature),
-            "top_p": float(top_p),
-            "pad_token_id":
+            "temperature": float(temperature),
+            "top_p": float(top_p),
+            "pad_token_id": eos_token_id # Use determined EOS or default
         }
 
         print(f"Generation args: {generation_args}")
 
-
-        with torch.no_grad(): # Disable gradient calculation for inference
+        with torch.no_grad():
             outputs = global_model.generate(**generation_args)
 
-        # Decode only the newly generated tokens
         generated_ids = outputs[0, input_length:]
         generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
 
         print(f"Generated text length: {len(generated_text)}")
         print(f"Generated text (start): {generated_text[:150]}...")
-        return generated_text.strip()
+        return generated_text.strip()
 
     except Exception as e:
         error_msg = str(e)
         print(f"Generation error: {error_msg}")
-        print(f"Error type: {type(e)}")
         traceback.print_exc()
-        # Check for common CUDA errors
         if "CUDA out of memory" in error_msg:
-            return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model
-        elif "probability tensor contains nan" in error_msg:
-            return f"❌ Error: Generation failed (
+            return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model."
+        elif "probability tensor contains nan" in error_msg or "invalid value encountered" in error_msg:
+            return f"❌ Error: Generation failed (invalid probability). Try adjusting Temperature/Top-P or modifying the prompt."
         else:
-            return f"❌ Error during text generation: {error_msg}
+            return f"❌ Error during text generation: {error_msg}"
+
+# --- UI Components & Layout ---
 
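One note on the new pad_token_id handling: many causal LMs ship without a dedicated pad token, so generate() is handed the EOS id instead, and 50256 is specifically GPT-2's EOS id. A quick illustrative check (assumes transformers is installed; GPT-2 is a stand-in model, not one used by this Space):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # stand-in model for illustration
print(tok.pad_token_id)  # None -- no dedicated pad token
print(tok.eos_token_id)  # 50256 -- the hard-coded fallback used above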
-# Create parameters UI component (reusable function)
 def create_parameter_ui():
     with gr.Accordion("✨ Generation Parameters", open=False):
         with gr.Row():
-
-
-
-                maximum=2048, # Set a reasonable max limit
-                value=512, # Default to a moderate length
-                step=64,
-                label="Max New Tokens",
-                info="Max number of tokens to generate.",
-                elem_id="max_new_tokens_slider"
-            )
-            temperature = gr.Slider(
-                minimum=0.1, # Avoid 0 which disables sampling
-                maximum=1.5,
-                value=0.7,
-                step=0.1,
-                label="Temperature",
-                info="Controls randomness. Lower is focused, higher is diverse.",
-                elem_id="temperature_slider"
-            )
-            top_p = gr.Slider(
-                minimum=0.1,
-                maximum=1.0, # Can be 1.0
-                value=0.9,
-                step=0.05,
-                label="Top-P (Nucleus Sampling)",
-                info="Considers tokens with cumulative probability >= top_p.",
-                elem_id="top_p_slider"
-            )
+            max_new_tokens = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max New Tokens", info="Max tokens to generate.")
+            temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature", info="Controls randomness.")
+            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
     return [max_new_tokens, temperature, top_p]
 
+# Language map (defined once)
+lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
+
 # --- Gradio Interface ---
-# Use the soft theme for a clean look, allow light/dark switching
 with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, title="Gemma Capabilities Demo") as demo:
 
     # Header
     gr.Markdown(
         """
-        <div style="text-align: center; margin-bottom: 20px;">
-
-
-        </h1>
-        <p style="font-size: 1.1em; color: #555;">
-            Explore the text generation capabilities of Google's Gemma models (or a fallback).
-        </p>
-        <p style="font-size: 0.9em; color: #777;">
-            Requires a Hugging Face token with access to Gemma models.
-            <a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a>
-        </p>
-        </div>
-        """
+        <div style="text-align: center; margin-bottom: 20px;"><h1><span style="font-size: 1.5em;">🤖</span> Gemma Capabilities Demo</h1>
+        <p>Explore text generation with Google's Gemma models (or a fallback).</p>
+        <p style="font-size: 0.9em;"><a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a></p></div>"""
     )
 
-    # --- Authentication
-    #
-
-    gr.Markdown("### 🔑 Authentication") # Added heading inside group
+    # --- Authentication ---
+    with gr.Group(): # Removed variant="panel"
+        gr.Markdown("### 🔑 Authentication")
     with gr.Row():
         with gr.Column(scale=4):
-            hf_token = gr.Textbox(
-                label="Hugging Face Token",
-                placeholder="Paste your HF token here (hf_...)",
-                type="password",
-                value=DEFAULT_HF_TOKEN,
-                info="Get your token from https://huggingface.co/settings/tokens",
-                elem_id="hf_token_input"
-            )
+            hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Paste token (hf_...)", type="password", value=DEFAULT_HF_TOKEN, info="Needed for Gemma models.")
         with gr.Column(scale=1, min_width=150):
-            auth_button = gr.Button("Load Model", variant="primary"
-
-    auth_status = gr.Markdown("ℹ️ Enter your Hugging Face token and click 'Load Model'. This might take a minute.", elem_id="auth_status")
-    # Add instructions on getting token inside the auth group
+            auth_button = gr.Button("Load Model", variant="primary")
+    auth_status = gr.Markdown("ℹ️ Enter token & click 'Load Model'. May take time.")
     gr.Markdown(
-        ""
-
-        1. Go to [Hugging Face Token Settings](https://huggingface.co/settings/tokens)
-        2. Create a new token with **read** access.
-        3. Ensure you've accepted the [Gemma model license](https://huggingface.co/google/gemma-7b-it) on the model page.
-        """
+        "**Token Info:** Get from [HF Settings](https://huggingface.co/settings/tokens) (read access). Ensure Gemma license is accepted.",
+        elem_id="token-info" # Optional ID for styling if needed later
     )
 
-
-    # --- Main Content Tabs (Initially Hidden) ---
-    # Define the tabs variable here
+    # --- Main Content Tabs ---
     with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
 
         # --- Text Generation Tab ---
-        with gr.TabItem("📝 Creative & Informational"
-            with gr.Row(
-                # Input Column
+        with gr.TabItem("📝 Creative & Informational"):
+            with gr.Row():
                 with gr.Column(scale=1):
-                    gr.Markdown("
-                    text_gen_type = gr.Radio(
-
-                        label="
-                        value="
-
-
-
-
-
-                        style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story", elem_id="creative_style")
-                        creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut on Mars", value="a robot discovering music", elem_id="creative_topic", lines=2)
-
-                    with gr.Group(visible=False, elem_id="info_options") as info_options:
-                        format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article", elem_id="info_format")
-                        info_topic = gr.Textbox(label="Topic", placeholder="e.g., the basics of quantum physics", value="the impact of AI on healthcare", elem_id="info_topic", lines=2)
-
-                    with gr.Group(visible=False, elem_id="custom_prompt_group") as custom_prompt_group:
-                        custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter your full prompt here...", lines=5, elem_id="custom_prompt")
-
-                    # Show/hide logic (using gr.update for better practice)
-                    def update_text_gen_visibility(choice):
-                        is_creative = choice == "Creative Writing"
-                        is_info = choice == "Informational Writing"
-                        is_custom = choice == "Custom Prompt"
-                        return {
-                            creative_options: gr.update(visible=is_creative),
-                            info_options: gr.update(visible=is_info),
-                            custom_prompt_group: gr.update(visible=is_custom)
-                        }
-                    text_gen_type.change(update_text_gen_visibility, inputs=text_gen_type, outputs=[creative_options, info_options, custom_prompt_group], queue=False)
-
-                    # Parameters
+                    gr.Markdown("#### Configure Task")
+                    text_gen_type = gr.Radio(["Creative Writing", "Informational Writing", "Custom Prompt"], label="Writing Type", value="Creative Writing")
+                    with gr.Group(visible=True) as creative_options:
+                        style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story")
+                        creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut", value="a robot discovering music", lines=2)
+                    with gr.Group(visible=False) as info_options:
+                        format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article")
+                        info_topic = gr.Textbox(label="Topic", placeholder="e.g., quantum physics basics", value="AI impact on healthcare", lines=2)
+                    with gr.Group(visible=False) as custom_prompt_group:
+                        custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter full prompt...", lines=5)
                     text_gen_params = create_parameter_ui()
-                    gr.Spacer
-                    generate_text_btn = gr.Button("Generate Text", variant="primary"
-
-                # Output Column
+                    # Removed gr.Spacer
+                    generate_text_btn = gr.Button("Generate Text", variant="primary")
                 with gr.Column(scale=1):
-                    gr.Markdown("
-                    text_output = gr.Textbox(label="Result", lines=25, interactive=False,
-
-            #
-            def
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                elif task_type == "custom":
-                    kwargs["prompt"] = safe_value(custom_prompt_text, "Write something interesting.")
-
-
+                    gr.Markdown("#### Generated Output")
+                    text_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
+
+            # Visibility logic
+            def update_text_gen_visibility(choice):
+                return { creative_options: gr.update(visible=choice == "Creative Writing"),
+                         info_options: gr.update(visible=choice == "Informational Writing"),
+                         custom_prompt_group: gr.update(visible=choice == "Custom Prompt") }
+            text_gen_type.change(update_text_gen_visibility, text_gen_type, [creative_options, info_options, custom_prompt_group], queue=False)
+
+            # Click handler
+            def text_gen_click(gen_type, style, c_topic, fmt_type, i_topic, custom_pr, *params):
+                task_map = {"Creative Writing": ("creative", {"style": style, "topic": c_topic}),
+                            "Informational Writing": ("informational", {"format_type": fmt_type, "topic": i_topic}),
+                            "Custom Prompt": ("custom", {"prompt": custom_pr})}
+                task_type, kwargs = task_map.get(gen_type, ("custom", {"prompt": custom_pr}))
+                # Apply safe_value inside handler where needed
+                if task_type == "creative": kwargs = {"style": safe_value(style, "story"), "topic": safe_value(c_topic, "[topic]")}
+                elif task_type == "informational": kwargs = {"format_type": safe_value(fmt_type, "article"), "topic": safe_value(i_topic, "[topic]")}
+                else: kwargs = {"prompt": safe_value(custom_pr, "Write something.")}
                 final_prompt = generate_prompt(task_type, **kwargs)
-                return generate_text(final_prompt,
-
-            generate_text_btn.click(
-                text_generation_handler,
-                inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params],
-                outputs=text_output
-            )
+                return generate_text(final_prompt, *params)
+            generate_text_btn.click(text_gen_click, [text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params], text_output)
 
             # Examples
-
-
-
-
-
-                ["Custom Prompt", "", "", "", "", "Write a short dialogue between a cat and a dog discussing their humans.", 512, 0.8, 0.95],
-            ],
-            # Ensure the order matches the handler's inputs
-            inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass only the UI elements needed
-            outputs=text_output,
-            label="Try these examples...",
-            #fn=text_generation_handler # fn is deprecated, click event handles execution
-            )
+            gr.Examples( examples=[ ["Creative Writing", "poem", "sound of rain", "", "", "", 512, 0.7, 0.9],
+                                    ["Informational Writing", "", "", "explanation", "photosynthesis", "", 768, 0.6, 0.9],
+                                    ["Custom Prompt", "", "", "", "", "Dialogue: cat and dog discuss humans.", 512, 0.8, 0.95] ],
+                         inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass UI elements
+                         outputs=text_output, label="Try examples...")
 
 
         # --- Brainstorming Tab ---
-        with gr.TabItem("🧠 Brainstorming"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            gr.
-
-
-                ["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
-                ["creative", "themes for a fantasy novel", 512, 0.85, 0.95],
-            ],
-            inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]],
-            outputs=brainstorm_output,
-            label="Try these examples...",
-            )
-
+        with gr.TabItem("🧠 Brainstorming"):
+            with gr.Row():
+                with gr.Column(scale=1):
+                    gr.Markdown("#### Setup")
+                    brainstorm_category = gr.Dropdown(["project", "business", "creative", "solution", "content", "feature", "product name"], label="Category", value="project")
+                    brainstorm_topic = gr.Textbox(label="Topic/Problem", placeholder="e.g., reducing plastic waste", value="unique mobile app ideas", lines=3)
+                    brainstorm_params = create_parameter_ui()
+                    # Removed gr.Spacer
+                    brainstorm_btn = gr.Button("Generate Ideas", variant="primary")
+                with gr.Column(scale=1):
+                    gr.Markdown("#### Generated Ideas")
+                    brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
+
+            def brainstorm_click(category, topic, *params):
+                prompt = generate_prompt("brainstorm", category=safe_value(category, "project"), topic=safe_value(topic, "ideas"))
+                return generate_text(prompt, *params)
+            brainstorm_btn.click(brainstorm_click, [brainstorm_category, brainstorm_topic, *brainstorm_params], brainstorm_output)
+            gr.Examples([ ["solution", "engaging online learning", 768, 0.8, 0.9],
+                          ["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
+                          ["creative", "fantasy novel themes", 512, 0.85, 0.95] ],
+                        inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]], outputs=brainstorm_output, label="Try examples...")
|
| 456 |
-
# --- Code Capabilities Tab ---
|
| 457 |
-
with gr.TabItem("💻 Code", id="tab_code"):
|
| 458 |
-
# Language mapping for syntax highlighting (defined once)
|
| 459 |
-
lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
|
| 460 |
-
|
| 461 |
-
with gr.Tabs() as code_tabs:
|
| 462 |
-
# --- Code Generation ---
|
| 463 |
-
with gr.TabItem("Generate Code", id="subtab_code_gen"):
|
| 464 |
-
with gr.Row(equal_height=False):
|
| 465 |
-
# Input Column
|
| 466 |
with gr.Column(scale=1):
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
# Output Column
|
| 475 |
with gr.Column(scale=1):
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
result = generate_text(prompt, max_tokens, temp, top_p_val)
|
| 485 |
-
# Try to extract code block if markdown is used
|
| 486 |
-
if "```" in result:
|
| 487 |
parts = result.split("```")
|
| 488 |
if len(parts) >= 2:
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
# Update output language display based on dropdown
|
| 506 |
-
def update_code_language_display(lang):
|
| 507 |
-
return gr.Code.update(language=lang_map.get(lang, "plaintext")) # Use update method
|
| 508 |
-
|
| 509 |
-
code_language_gen.change(update_code_language_display, inputs=code_language_gen, outputs=code_output, queue=False)
|
| 510 |
-
code_gen_btn.click(code_gen_handler, inputs=[code_language_gen, code_task, *code_gen_params], outputs=code_output)
|
| 511 |
-
|
| 512 |
-
gr.Examples(
|
| 513 |
-
examples=[
|
| 514 |
-
["JavaScript", "function to validate an email address using regex", 768, 0.6, 0.9],
|
| 515 |
-
["SQL", "query to select users older than 30 from a 'users' table", 512, 0.5, 0.8],
|
| 516 |
-
["HTML", "basic structure for a personal portfolio website", 1024, 0.7, 0.9],
|
| 517 |
-
],
|
| 518 |
-
inputs=[code_language_gen, code_task, *code_gen_params[:3]],
|
| 519 |
-
outputs=code_output,
|
| 520 |
-
label="Try these examples...",
|
| 521 |
-
)
|
| 522 |
-
|
| 523 |
-
# --- Code Explanation ---
|
| 524 |
-
with gr.TabItem("Explain Code", id="subtab_code_explain"):
|
| 525 |
-
with gr.Row(equal_height=False):
|
| 526 |
-
# Input Column
|
| 527 |
-
with gr.Column(scale=1):
|
| 528 |
-
gr.Markdown("### Code Explanation Setup")
|
| 529 |
-
code_language_explain = gr.Dropdown(list(lang_map.keys()), label="Code Language (for context)", value="Python", elem_id="code_language_explain")
|
| 530 |
-
code_to_explain = gr.Code(label="Paste Code Here", language="python", lines=15, elem_id="code_to_explain")
|
| 531 |
-
explain_code_params = create_parameter_ui()
|
| 532 |
-
gr.Spacer(height=15)
|
| 533 |
-
explain_code_btn = gr.Button("Explain Code", variant="primary", elem_id="explain_code_btn")
|
| 534 |
-
|
| 535 |
-
# Output Column
|
| 536 |
-
with gr.Column(scale=1):
|
| 537 |
-
gr.Markdown("### Explanation")
|
| 538 |
-
code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, elem_id="code_explanation", show_copy_button=True)
|
| 539 |
-
|
| 540 |
-
# Update code input language display
|
| 541 |
-
def update_explain_language_display(lang):
|
| 542 |
-
return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
| 543 |
-
code_language_explain.change(update_explain_language_display, inputs=code_language_explain, outputs=code_to_explain, queue=False)
|
| 544 |
-
|
| 545 |
-
# Handler
|
| 546 |
-
def explain_code_handler(language, code, max_tokens, temp, top_p_val):
|
| 547 |
-
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add code here") # Handle potential dict input from gr.Code
|
| 548 |
-
language = safe_value(language, "code") # Use selected language in prompt
|
| 549 |
-
prompt = generate_prompt("code_explain", language=language, code=code_content)
|
| 550 |
-
return generate_text(prompt, max_tokens, temp, top_p_val)
|
| 551 |
-
|
| 552 |
-
explain_code_btn.click(explain_code_handler, inputs=[code_language_explain, code_to_explain, *explain_code_params], outputs=code_explanation)
|
| 553 |
-
|
-                # --- Code Debugging ---
-                with gr.TabItem("Debug Code", id="subtab_code_debug"):
-                    with gr.Row(equal_height=False):
-                        # Input Column
+                with gr.TabItem("Debug"):
+                    with gr.Row():
                        with gr.Column(scale=1):
-
-
-
-
-
-
-                            value="def calculate_average(numbers):\n    sum = 0\n    for n in numbers:\n        sum += n\n    # Bug: potential division by zero if numbers is empty\n    return sum / len(numbers)", # Example with potential bug
-                            elem_id="code_to_debug"
-                            )
-                            debug_code_params = create_parameter_ui()
-                            gr.Spacer(height=15)
-                            debug_code_btn = gr.Button("Debug Code", variant="primary", elem_id="debug_code_btn")
-
-                        # Output Column
+                            gr.Markdown("#### Setup")
+                            code_lang_debug = gr.Dropdown(list(lang_map.keys()), label="Language", value="Python")
+                            code_to_debug = gr.Code(label="Buggy Code", language="python", lines=15, value="def avg(nums):\n    # Potential div by zero\n    return sum(nums)/len(nums)")
+                            debug_code_params = create_parameter_ui()
+                            # Removed gr.Spacer
+                            debug_code_btn = gr.Button("Debug Code", variant="primary")
                        with gr.Column(scale=1):
-
-
 
-
-
-
-
 
-                    # Handler
-                    def debug_code_handler(language, code, max_tokens, temp, top_p_val):
-                        code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Add potentially buggy code here")
-                        language = safe_value(language, "code")
-                        prompt = generate_prompt("code_debug", language=language, code=code_content)
-                        return generate_text(prompt, max_tokens, temp, top_p_val)
+                            gr.Markdown("#### Debugging Analysis")
+                            debug_result = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
+
+                    def debug_code_click(lang, code, *params):
+                        code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Empty code")
+                        prompt = generate_prompt("code_debug", language=safe_value(lang, "code"), code=code_content)
+                        return generate_text(prompt, *params)
+                    def update_debug_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
+                    code_lang_debug.change(update_debug_lang_display, code_lang_debug, code_to_debug, queue=False)
+                    debug_code_btn.click(debug_code_click, [code_lang_debug, code_to_debug, *debug_code_params], debug_result)
 
-
-
-
-
-
-
-
-
-
-
-
+
+        # --- Comprehension Tab ---
+        with gr.TabItem("📚 Comprehension"):
+            with gr.Tabs():
+                with gr.TabItem("Summarize"):
+                    with gr.Row():
                         with gr.Column(scale=1):
-                            gr.Markdown("
-                            summarize_text = gr.Textbox(label="Text to Summarize", placeholder="Paste long text
+                            gr.Markdown("#### Setup")
+                            summarize_text = gr.Textbox(label="Text to Summarize", lines=15, placeholder="Paste long text...")
                             summarize_params = create_parameter_ui()
-                            gr.Spacer
-                            summarize_btn = gr.Button("Summarize Text", variant="primary"
-                        # Output Column
+                            # Removed gr.Spacer
+                            summarize_btn = gr.Button("Summarize Text", variant="primary")
                         with gr.Column(scale=1):
-                            gr.Markdown("
-                            summary_output = gr.Textbox(label="Result", lines=15, interactive=False,
-
-
-
-
-
-
-
-
-
-
-
+                            gr.Markdown("#### Summary")
+                            summary_output = gr.Textbox(label="Result", lines=15, interactive=False, show_copy_button=True)
+                        def summarize_click(text, *params):
+                            prompt = generate_prompt("summarize", text=safe_value(text, "[empty text]"))
+                            # Adjust max tokens for summary specifically if needed
+                            p_list = list(params); p_list[0] = min(max(int(p_list[0]), 64), 512)
+                            return generate_text(prompt, *p_list)
+                        summarize_btn.click(summarize_click, [summarize_text, *summarize_params], summary_output)
 
-                # --- Question Answering ---
-                with gr.TabItem("Q & A", id="subtab_qa"):
-                    with gr.Row(equal_height=False):
-                        # Input Column
+                with gr.TabItem("Q & A"):
+                    with gr.Row():
                        with gr.Column(scale=1):
-                            gr.Markdown("
-                            qa_text = gr.Textbox(label="Context Text", placeholder="Paste
-                            qa_question = gr.Textbox(label="Question", placeholder="Ask
+                            gr.Markdown("#### Setup")
+                            qa_text = gr.Textbox(label="Context Text", lines=10, placeholder="Paste text containing answer...")
+                            qa_question = gr.Textbox(label="Question", placeholder="Ask question about text...")
                            qa_params = create_parameter_ui()
-                            gr.Spacer
-                            qa_btn = gr.Button("Get Answer", variant="primary"
-                        # Output Column
+                            # Removed gr.Spacer
+                            qa_btn = gr.Button("Get Answer", variant="primary")
                        with gr.Column(scale=1):
-                            gr.Markdown("
-                            qa_output = gr.Textbox(label="Result", lines=10, interactive=False,
-
-
-
-
-
-
-
-
-
-
-                        qa_btn.click(qa_handler, inputs=[qa_text, qa_question, *qa_params], outputs=qa_output)
+                            gr.Markdown("#### Answer")
+                            qa_output = gr.Textbox(label="Result", lines=10, interactive=False, show_copy_button=True)
+                        def qa_click(text, question, *params):
+                            prompt = generate_prompt("qa", text=safe_value(text, "[context]"), question=safe_value(question,"[question]"))
+                            p_list = list(params); p_list[0] = min(max(int(p_list[0]), 32), 256)
+                            return generate_text(prompt, *p_list)
+                        qa_btn.click(qa_click, [qa_text, qa_question, *qa_params], qa_output)
 
-                # --- Translation ---
-                with gr.TabItem("Translate", id="subtab_translate"):
-                    with gr.Row(equal_height=False):
-                        # Input Column
+                with gr.TabItem("Translate"):
+                    with gr.Row():
                        with gr.Column(scale=1):
-                            gr.Markdown("
-                            translate_text = gr.Textbox(label="Text to Translate", placeholder="Enter text
-                            target_lang = gr.Dropdown(
-                                ["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"],
-                                label="Translate To", value="French", elem_id="target_lang"
-                            )
+                            gr.Markdown("#### Setup")
+                            translate_text = gr.Textbox(label="Text to Translate", lines=8, placeholder="Enter text...")
+                            target_lang = gr.Dropdown(["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"], label="Translate To", value="French")
                            translate_params = create_parameter_ui()
-                            gr.Spacer
-                            translate_btn = gr.Button("Translate Text", variant="primary"
-                        # Output Column
+                            # Removed gr.Spacer
+                            translate_btn = gr.Button("Translate Text", variant="primary")
                        with gr.Column(scale=1):
-                            gr.Markdown("
-                            translation_output = gr.Textbox(label="Result", lines=8, interactive=False,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                            gr.Markdown("#### Translation")
+                            translation_output = gr.Textbox(label="Result", lines=8, interactive=False, show_copy_button=True)
+                        def translate_click(text, lang, *params):
+                            prompt = generate_prompt("translate", text=safe_value(text,"[text]"), target_lang=safe_value(lang,"French"))
+                            p_list = list(params); p_list[0] = max(int(p_list[0]), 64)
+                            return generate_text(prompt, *p_list)
+                        translate_btn.click(translate_click, [translate_text, target_lang, *translate_params], translation_output)
+
+
+        # --- More Tasks Tab ---
+        with gr.TabItem("🛠️ More Tasks"):
+            with gr.Tabs():
+                with gr.TabItem("Content Creation"):
+                    with gr.Row():
                         with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-
-                            gr.
-                            content_btn = gr.Button("Generate Content", variant="primary", elem_id="content_btn")
+                            gr.Markdown("#### Setup")
+                            content_type = gr.Dropdown(["blog post outline", "social media post (Twitter)", "social media post (LinkedIn)", "marketing email subject line", "product description", "press release intro"], label="Content Type", value="blog post outline")
+                            content_topic = gr.Textbox(label="Topic", value="sustainable travel tips", lines=2)
+                            content_audience = gr.Textbox(label="Audience", value="eco-conscious millennials")
+                            content_params = create_parameter_ui()
+                            # Removed gr.Spacer
+                            content_btn = gr.Button("Generate Content", variant="primary")
                        with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-
-
-
-
-
-
-
-
-
-
+                            gr.Markdown("#### Generated Content")
+                            content_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)
+                        def content_click(c_type, topic, audience, *params):
+                            prompt = generate_prompt("content_creation", content_type=safe_value(c_type,"text"), topic=safe_value(topic,"[topic]"), audience=safe_value(audience,"[audience]"))
+                            return generate_text(prompt, *params)
+                        content_btn.click(content_click, [content_type, content_topic, content_audience, *content_params], content_output)
+
+                with gr.TabItem("Email Drafting"):
+                    with gr.Row():
                         with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-                            gr.Spacer
-
+                            gr.Markdown("#### Setup")
+                            email_type = gr.Dropdown(["job inquiry", "meeting request", "follow-up", "thank you", "support response", "sales outreach"], label="Email Type", value="meeting request")
+                            email_context = gr.Textbox(label="Context/Points", lines=5, value="Request meeting next week re: project X. Suggest Tue/Wed afternoon.")
+                            email_params = create_parameter_ui()
+                            # Removed gr.Spacer
+                            email_btn = gr.Button("Generate Draft", variant="primary")
                        with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-
-
-
-
-
-
-
-
+                            gr.Markdown("#### Generated Draft")
+                            email_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)
+                        def email_click(e_type, context, *params):
+                            prompt = generate_prompt("email_draft", email_type=safe_value(e_type,"email"), context=safe_value(context,"[context]"))
+                            return generate_text(prompt, *params)
+                        email_btn.click(email_click, [email_type, email_context, *email_params], email_output)
+
-                with gr.TabItem("Document Editing", id="tab_edit"):
-                    with gr.Row(equal_height=False):
+                with gr.TabItem("Doc Editing"):
+                    with gr.Row():
                        with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-                            gr.Spacer
-
                        with gr.Column(scale=1):
-                            gr.Markdown("
-
-
-
-
-
-                            prompt = generate_prompt("document_edit", text=text, edit_type=e_type)
-                            # Editing might expand text, give it reasonable token count based on input + max_new
-                            input_tokens_estimate = len(text.split()) # Rough estimate
-                            max_tok = max(int(max_tok), input_tokens_estimate + 64) # Ensure enough room
-                            return generate_text(prompt, max_tok, temp, top_p_val)
-                        edit_btn.click(edit_handler, inputs=[edit_text, edit_type, *edit_params], outputs=edit_output)
-
-
-
# --- Classification ---
|
| 755 |
-
with gr.TabItem("Classification", id="tab_classify"):
|
| 756 |
-
with gr.Row(equal_height=False):
|
| 757 |
-
with gr.Column(scale=1):
|
| 758 |
-
gr.Markdown("### Classification Setup")
|
| 759 |
-
classify_text = gr.Textbox(label="Text to Classify", placeholder="Enter text...", lines=8, value="This new sci-fi movie explores themes of AI consciousness and interstellar travel.")
|
| 760 |
-
classify_categories = gr.Textbox(label="Categories (comma-separated)", placeholder="e.g., positive, negative, neutral", value="Technology, Entertainment, Science, Politics, Sports, Health")
|
| 761 |
-
classify_params = create_parameter_ui()
|
| 762 |
-
gr.Spacer(height=15)
|
| 763 |
-
classify_btn = gr.Button("Classify Text", variant="primary")
|
| 764 |
-
with gr.Column(scale=1):
|
| 765 |
-
gr.Markdown("### Classification Result")
|
| 766 |
-
classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False, show_copy_button=True)
|
| 767 |
-
|
| 768 |
-
def classify_handler(text, cats, max_tok, temp, top_p_val):
|
| 769 |
-
text = safe_value(text, "Text to classify needed.")
|
| 770 |
-
cats = safe_value(cats, "category1, category2")
|
| 771 |
-
# Classification usually needs short output
|
| 772 |
-
max_tok = min(max(int(max_tok), 16), 128) # Ensure int, constrain tightly
|
| 773 |
-
prompt = generate_prompt("classify", text=text, categories=cats)
|
| 774 |
-
# Often the model just outputs the category, so we might not need the prompt structure removal
|
| 775 |
-
raw_output = generate_text(prompt, max_tok, temp, top_p_val)
|
| 776 |
-
# Post-process to get just the category if possible
|
| 777 |
-
lines = raw_output.split('\n')
|
| 778 |
-
if lines:
|
| 779 |
-
last_line = lines[-1].strip()
|
| 780 |
-
# Check if the last line seems like one of the categories
|
| 781 |
-
possible_cats = [c.strip().lower() for c in cats.split(',')]
|
| 782 |
-
if last_line.lower() in possible_cats:
|
| 783 |
-
return last_line
|
| 784 |
-
# Fallback to raw output
|
| 785 |
-
return raw_output
|
| 786 |
-
|
| 787 |
-
classify_btn.click(classify_handler, inputs=[classify_text, classify_categories, *classify_params], outputs=classify_output)
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
# --- Data Extraction ---
|
| 791 |
-
with gr.TabItem("Data Extraction", id="tab_extract"):
|
| 792 |
-
with gr.Row(equal_height=False):
|
| 793 |
-
with gr.Column(scale=1):
|
| 794 |
-
gr.Markdown("### Extraction Setup")
|
| 795 |
-
extract_text = gr.Textbox(label="Source Text", placeholder="Paste text containing data...", lines=10, value="Order #12345 placed on 2024-03-15 by Jane Doe (jane.d@email.com). Total amount: $99.95. Shipping to 123 Main St, Anytown, USA.")
|
| 796 |
-
extract_data_points = gr.Textbox(label="Data to Extract (comma-separated)", placeholder="e.g., name, email, order number", value="order number, date, customer name, email, total amount, address")
|
| 797 |
-
extract_params = create_parameter_ui()
|
| 798 |
-
gr.Spacer(height=15)
|
| 799 |
-
extract_btn = gr.Button("Extract Data", variant="primary")
|
| 800 |
-
with gr.Column(scale=1):
|
| 801 |
-
gr.Markdown("### Extracted Data")
|
| 802 |
-
extract_output = gr.Textbox(label="Result (e.g., JSON or key-value pairs)", lines=10, interactive=False, show_copy_button=True)
|
| 803 |
|
| 804 |
-
def extract_handler(text, points, max_tok, temp, top_p_val):
|
| 805 |
-
text = safe_value(text, "Provide text for extraction.")
|
| 806 |
-
points = safe_value(points, "key information")
|
| 807 |
-
prompt = generate_prompt("data_extract", text=text, data_points=points)
|
| 808 |
-
return generate_text(prompt, max_tok, temp, top_p_val)
|
| 809 |
-
extract_btn.click(extract_handler, inputs=[extract_text, extract_data_points, *extract_params], outputs=extract_output)
|
| 810 |
|
|
|
|
|
|
|
 
-    # Define authentication handler AFTER tabs is defined
     def handle_auth(token):
-
-        yield "⏳ Authenticating and loading model... Please wait.", gr.Tabs.update(visible=False)
-        # Call the actual model loading function
         status_message, tabs_update = load_model(token)
         yield status_message, tabs_update
 
-    #
-
-
-
-
-        queue=True # Run in queue for potentially long operation
-    )
-
-    # --- Footer ---
-    footer_status = gr.Markdown( # Use a separate Markdown for dynamic updates
-        f"""
-        ---
-        <div style="text-align: center; font-size: 0.9em; color: #777;">
-            <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
-            <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
-            <p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
-        </div>
-        """
-    )
 
-
-
-
-
-
-        <div style="text-align: center; font-size: 0.9em; color: #777;">
-            <p>Powered by Google's Gemma models via Hugging Face 🤗 Transformers & Gradio.</p>
-            <p>Remember to review generated content. Model outputs may be inaccurate or incomplete.</p>
-            <p>Model Loaded: <strong>{loaded_model_name if model_loaded else 'None'}</strong></p>
-        </div>
-        """)
-    auth_status.change(fn=update_footer_status, inputs=auth_status, outputs=footer_status, queue=False)
 
 
 # --- Launch App ---
-# Allow built-in theme switching
-# Use queue() to handle multiple requests better
 demo.launch(share=False, allowed_themes=["light", "dark"])
|
|
|
| 35 |
|
| 36 |
try:
|
| 37 |
# Try different model versions from smallest to largest
|
|
|
|
| 38 |
model_options = [
|
| 39 |
"google/gemma-2b-it",
|
| 40 |
"google/gemma-7b-it",
|
| 41 |
"google/gemma-2b",
|
| 42 |
"google/gemma-7b",
|
| 43 |
+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0" # Fallback
|
|
|
|
| 44 |
]
|
| 45 |
|
| 46 |
print(f"Attempting to load models with token starting with: {hf_token[:5]}...")
|
|
|
|
| 49 |
try:
|
| 50 |
print(f"\n--- Attempting to load model: {model_name} ---")
|
| 51 |
is_gemma = "gemma" in model_name.lower()
|
| 52 |
+
current_token = hf_token if is_gemma else None
|
|
|
|
| 53 |
|
|
|
|
| 54 |
print("Loading tokenizer...")
|
| 55 |
global_tokenizer = AutoTokenizer.from_pretrained(
|
| 56 |
model_name,
|
|
|
|
| 58 |
)
|
| 59 |
print("Tokenizer loaded successfully.")
|
| 60 |
|
|
|
|
| 61 |
print(f"Loading model {model_name}...")
|
| 62 |
global_model = AutoModelForCausalLM.from_pretrained(
|
| 63 |
model_name,
|
|
|
|
| 64 |
torch_dtype=torch.float16, # Using float16 for broader compatibility
|
| 65 |
+
device_map="auto",
|
| 66 |
token=current_token
|
| 67 |
)
|
| 68 |
print(f"Model {model_name} loaded successfully!")
|
|
|
|
| 70 |
model_loaded = True
|
| 71 |
loaded_model_name = model_name
|
| 72 |
loaded_successfully = True
|
| 73 |
+
tabs_update = gr.Tabs.update(visible=True)
|
| 74 |
+
status_msg = f"✅ Model '{model_name}' loaded successfully!"
|
| 75 |
+
if "tinyllama" in model_name.lower():
|
| 76 |
+
status_msg = f"✅ Fallback model '{model_name}' loaded successfully! Limited capabilities compared to Gemma."
|
| 77 |
+
return status_msg, tabs_update
|
| 78 |
|
| 79 |
except ImportError as import_err:
|
| 80 |
+
print(f"Import Error loading {model_name}: {import_err}. Check dependencies (e.g., bitsandbytes, accelerate).")
|
| 81 |
+
continue
|
|
|
|
| 82 |
except Exception as specific_e:
|
| 83 |
print(f"Failed to load {model_name}: {specific_e}")
|
| 84 |
+
if "401 Client Error" in str(specific_e) or "requires you to be logged in" in str(specific_e) and is_gemma:
|
|
|
|
| 85 |
print("Authentication error likely. Check token and license agreement.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
continue
|
| 87 |
|
|
|
|
| 88 |
if not loaded_successfully:
|
| 89 |
model_loaded = False
|
| 90 |
loaded_model_name = "None"
|
| 91 |
print("Could not load any model version.")
|
| 92 |
+
return "❌ Could not load any model. Please check your token, license acceptance, dependencies, and network connection.", initial_tabs_update
|
| 93 |
|
| 94 |
except Exception as e:
|
| 95 |
model_loaded = False
|
|
|
|
| 97 |
error_msg = str(e)
|
| 98 |
print(f"Error in load_model: {error_msg}")
|
| 99 |
traceback.print_exc()
|
|
|
|
| 100 |
if "401 Client Error" in error_msg or "requires you to be logged in" in error_msg :
|
| 101 |
+
return "❌ Authentication failed. Check token/license.", initial_tabs_update
|
| 102 |
else:
|
| 103 |
+
return f"❌ Unexpected error during model loading: {error_msg}", initial_tabs_update
|
| 104 |
|
| 105 |
|
| 106 |
def generate_prompt(task_type, **kwargs):
|
| 107 |
"""Generate appropriate prompts based on task type and parameters"""
|
|
|
|
| 108 |
prompts = {
|
| 109 |
"creative": "Write a {style} about {topic}. Be creative and engaging.",
|
| 110 |
"informational": "Write an {format_type} about {topic}. Be clear, factual, and informative.",
|
|
|
|
| 122 |
"classify": "Classify the following text into one of these categories: {categories}\n\nText: {text}\n\nCategory:",
|
| 123 |
"data_extract": "Extract the following data points ({data_points}) from the text below:\n\nText: {text}\n\nExtracted Data:",
|
| 124 |
}
|
|
|
|
| 125 |
prompt_template = prompts.get(task_type)
|
| 126 |
if prompt_template:
|
| 127 |
try:
|
|
|
|
|
|
|
| 128 |
keys_in_template = [k[1:-1] for k in prompt_template.split('{') if '}' in k for k in [k.split('}')[0]]]
|
| 129 |
+
final_kwargs = {key: kwargs.get(key, f"[{key}]") for key in keys_in_template}
|
| 130 |
+
final_kwargs.update(kwargs) # Add extras
|
|
|
|
|
|
|
|
|
|
| 131 |
return prompt_template.format(**final_kwargs)
|
| 132 |
except KeyError as e:
|
| 133 |
print(f"Warning: Missing key for prompt template '{task_type}': {e}")
|
| 134 |
+
return kwargs.get("prompt", f"Generate text based on: {kwargs}")
|
| 135 |
else:
|
|
|
|
| 136 |
return kwargs.get("prompt", "Generate text based on the input.")
|
| 137 |
|
| 138 |
|
|
|
|
| 146 |
print(f"Prompt (start): {prompt[:150]}...")
|
| 147 |
|
| 148 |
if not model_loaded or global_model is None or global_tokenizer is None:
|
|
|
|
| 149 |
return "⚠️ Model not loaded. Please authenticate first."
|
|
|
|
| 150 |
if not prompt:
|
| 151 |
return "⚠️ Please enter a prompt or configure a task."
|
| 152 |
|
| 153 |
try:
|
| 154 |
+
chat_prompt = prompt # Default to raw prompt
|
|
|
|
| 155 |
if loaded_model_name and ("it" in loaded_model_name.lower() or "instruct" in loaded_model_name.lower() or "chat" in loaded_model_name.lower()):
|
|
|
|
|
|
|
| 156 |
if "gemma" in loaded_model_name.lower():
|
| 157 |
+
# Use Gemma's specific format
|
| 158 |
chat_prompt = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
|
| 159 |
+
elif "tinyllama" in loaded_model_name.lower():
|
| 160 |
+
# Use TinyLlama's chat format
|
| 161 |
+
chat_prompt = f"<|system|>\nYou are a friendly chatbot.</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"
|
| 162 |
else: # Generic instruction format
|
| 163 |
chat_prompt = f"User: {prompt}\nAssistant:"
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
inputs = global_tokenizer(chat_prompt, return_tensors="pt", add_special_tokens=True).to(global_model.device)
|
| 166 |
input_length = inputs.input_ids.shape[1]
|
| 167 |
print(f"Input token length: {input_length}")
|
| 168 |
|
| 169 |
+
effective_max_new_tokens = min(int(max_new_tokens), 2048)
|
| 170 |
+
|
| 171 |
+
# Handle potential None for eos_token_id
|
| 172 |
+
eos_token_id = global_tokenizer.eos_token_id
|
| 173 |
+
if eos_token_id is None:
|
| 174 |
+
print("Warning: eos_token_id is None, using default 50256.")
|
| 175 |
+
eos_token_id = 50256 # A common default EOS token ID
|
| 176 |
|
| 177 |
generation_args = {
|
| 178 |
"input_ids": inputs.input_ids,
|
| 179 |
+
"attention_mask": inputs.attention_mask,
|
| 180 |
"max_new_tokens": effective_max_new_tokens,
|
| 181 |
"do_sample": True,
|
| 182 |
+
"temperature": float(temperature),
|
| 183 |
+
"top_p": float(top_p),
|
| 184 |
+
"pad_token_id": eos_token_id # Use determined EOS or default
|
| 185 |
}
|
| 186 |
|
| 187 |
print(f"Generation args: {generation_args}")
|
| 188 |
|
| 189 |
+
with torch.no_grad():
|
|
|
|
| 190 |
outputs = global_model.generate(**generation_args)
|
| 191 |
|
|
|
|
| 192 |
generated_ids = outputs[0, input_length:]
|
| 193 |
generated_text = global_tokenizer.decode(generated_ids, skip_special_tokens=True)
|
| 194 |
|
| 195 |
print(f"Generated text length: {len(generated_text)}")
|
| 196 |
print(f"Generated text (start): {generated_text[:150]}...")
|
| 197 |
+
return generated_text.strip()
|
| 198 |
|
| 199 |
except Exception as e:
|
| 200 |
error_msg = str(e)
|
| 201 |
print(f"Generation error: {error_msg}")
|
|
|
|
| 202 |
traceback.print_exc()
|
|
|
|
| 203 |
if "CUDA out of memory" in error_msg:
|
| 204 |
+
return f"❌ Error: CUDA out of memory. Try reducing 'Max New Tokens' or using a smaller model."
|
| 205 |
+
elif "probability tensor contains nan" in error_msg or "invalid value encountered" in error_msg:
|
| 206 |
+
return f"❌ Error: Generation failed (invalid probability). Try adjusting Temperature/Top-P or modifying the prompt."
|
| 207 |
else:
|
| 208 |
+
return f"❌ Error during text generation: {error_msg}"
|
| 209 |
+
|
| 210 |
+
# --- UI Components & Layout ---
|
| 211 |
|
|
|
|
| 212 |
def create_parameter_ui():
|
| 213 |
with gr.Accordion("✨ Generation Parameters", open=False):
|
| 214 |
with gr.Row():
|
| 215 |
+
max_new_tokens = gr.Slider(minimum=64, maximum=2048, value=512, step=64, label="Max New Tokens", info="Max tokens to generate.")
|
| 216 |
+
temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature", info="Controls randomness.")
|
| 217 |
+
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
return [max_new_tokens, temperature, top_p]
|
| 219 |
|
| 220 |
+
# Language map (defined once)
|
| 221 |
+
lang_map = {"Python": "python", "JavaScript": "javascript", "Java": "java", "C++": "cpp", "HTML": "html", "CSS": "css", "SQL": "sql", "Bash": "bash", "Rust": "rust", "Other": "plaintext"}
|
| 222 |
+
|
| 223 |
# --- Gradio Interface ---
|
|
|
|
| 224 |
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, title="Gemma Capabilities Demo") as demo:
|
| 225 |
|
| 226 |
# Header
|
| 227 |
gr.Markdown(
|
| 228 |
"""
|
| 229 |
+
<div style="text-align: center; margin-bottom: 20px;"><h1><span style="font-size: 1.5em;">🤖</span> Gemma Capabilities Demo</h1>
|
| 230 |
+
<p>Explore text generation with Google's Gemma models (or a fallback).</p>
|
| 231 |
+
<p style="font-size: 0.9em;"><a href="https://huggingface.co/google/gemma-7b-it" target="_blank">[Accept Gemma License Here]</a></p></div>"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
)
|
| 233 |
|
| 234 |
+
# --- Authentication ---
|
| 235 |
+
with gr.Group(): # Removed variant="panel"
|
| 236 |
+
gr.Markdown("### 🔑 Authentication")
|
|
|
|
| 237 |
with gr.Row():
|
| 238 |
with gr.Column(scale=4):
|
| 239 |
+
hf_token = gr.Textbox(label="Hugging Face Token", placeholder="Paste token (hf_...)", type="password", value=DEFAULT_HF_TOKEN, info="Needed for Gemma models.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
with gr.Column(scale=1, min_width=150):
|
| 241 |
+
auth_button = gr.Button("Load Model", variant="primary")
|
| 242 |
+
auth_status = gr.Markdown("ℹ️ Enter token & click 'Load Model'. May take time.")
|
|
|
|
|
|
|
| 243 |
gr.Markdown(
|
| 244 |
+
"**Token Info:** Get from [HF Settings](https://huggingface.co/settings/tokens) (read access). Ensure Gemma license is accepted.",
|
| 245 |
+
elem_id="token-info" # Optional ID for styling if needed later
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
)
|
| 247 |
|
| 248 |
+
# --- Main Content Tabs ---
|
|
|
|
|
|
|
| 249 |
with gr.Tabs(elem_id="main_tabs", visible=False) as tabs:
|
| 250 |
|
| 251 |
# --- Text Generation Tab ---
|
| 252 |
+
with gr.TabItem("📝 Creative & Informational"):
|
| 253 |
+
with gr.Row():
|
|
|
|
| 254 |
with gr.Column(scale=1):
|
| 255 |
+
gr.Markdown("#### Configure Task")
|
| 256 |
+
text_gen_type = gr.Radio(["Creative Writing", "Informational Writing", "Custom Prompt"], label="Writing Type", value="Creative Writing")
|
| 257 |
+
with gr.Group(visible=True) as creative_options:
|
| 258 |
+
style = gr.Dropdown(["short story", "poem", "script", "song lyrics", "joke", "dialogue"], label="Style", value="short story")
|
| 259 |
+
creative_topic = gr.Textbox(label="Topic", placeholder="e.g., a lonely astronaut", value="a robot discovering music", lines=2)
|
| 260 |
+
with gr.Group(visible=False) as info_options:
|
| 261 |
+
format_type = gr.Dropdown(["article", "summary", "explanation", "report", "comparison"], label="Format", value="article")
|
| 262 |
+
info_topic = gr.Textbox(label="Topic", placeholder="e.g., quantum physics basics", value="AI impact on healthcare", lines=2)
|
| 263 |
+
with gr.Group(visible=False) as custom_prompt_group:
|
| 264 |
+
custom_prompt = gr.Textbox(label="Custom Prompt", placeholder="Enter full prompt...", lines=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
text_gen_params = create_parameter_ui()
|
| 266 |
+
# Removed gr.Spacer
|
| 267 |
+
generate_text_btn = gr.Button("Generate Text", variant="primary")
|
|
|
|
|
|
|
| 268 |
with gr.Column(scale=1):
|
| 269 |
+
gr.Markdown("#### Generated Output")
|
| 270 |
+
text_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
| 271 |
+
|
| 272 |
+
# Visibility logic
|
| 273 |
+
def update_text_gen_visibility(choice):
|
| 274 |
+
return { creative_options: gr.update(visible=choice == "Creative Writing"),
|
| 275 |
+
info_options: gr.update(visible=choice == "Informational Writing"),
|
| 276 |
+
custom_prompt_group: gr.update(visible=choice == "Custom Prompt") }
|
| 277 |
+
text_gen_type.change(update_text_gen_visibility, text_gen_type, [creative_options, info_options, custom_prompt_group], queue=False)
|
| 278 |
+
|
| 279 |
+
# Click handler
|
| 280 |
+
def text_gen_click(gen_type, style, c_topic, fmt_type, i_topic, custom_pr, *params):
|
| 281 |
+
task_map = {"Creative Writing": ("creative", {"style": style, "topic": c_topic}),
|
| 282 |
+
"Informational Writing": ("informational", {"format_type": fmt_type, "topic": i_topic}),
|
| 283 |
+
"Custom Prompt": ("custom", {"prompt": custom_pr})}
|
| 284 |
+
task_type, kwargs = task_map.get(gen_type, ("custom", {"prompt": custom_pr}))
|
| 285 |
+
# Apply safe_value inside handler where needed
|
| 286 |
+
if task_type == "creative": kwargs = {"style": safe_value(style, "story"), "topic": safe_value(c_topic, "[topic]")}
|
| 287 |
+
elif task_type == "informational": kwargs = {"format_type": safe_value(fmt_type, "article"), "topic": safe_value(i_topic, "[topic]")}
|
| 288 |
+
else: kwargs = {"prompt": safe_value(custom_pr, "Write something.")}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
final_prompt = generate_prompt(task_type, **kwargs)
|
| 290 |
+
return generate_text(final_prompt, *params)
|
| 291 |
+
generate_text_btn.click(text_gen_click, [text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params], text_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
# Examples
|
| 294 |
+
gr.Examples( examples=[ ["Creative Writing", "poem", "sound of rain", "", "", "", 512, 0.7, 0.9],
|
| 295 |
+
["Informational Writing", "", "", "explanation", "photosynthesis", "", 768, 0.6, 0.9],
|
| 296 |
+
["Custom Prompt", "", "", "", "", "Dialogue: cat and dog discuss humans.", 512, 0.8, 0.95] ],
|
| 297 |
+
inputs=[text_gen_type, style, creative_topic, format_type, info_topic, custom_prompt, *text_gen_params[:3]], # Pass UI elements
|
| 298 |
+
outputs=text_output, label="Try examples...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
|
| 300 |
|
| 301 |
# --- Brainstorming Tab ---
|
| 302 |
+
with gr.TabItem("🧠 Brainstorming"):
|
| 303 |
+
with gr.Row():
|
| 304 |
+
with gr.Column(scale=1):
|
| 305 |
+
gr.Markdown("#### Setup")
|
| 306 |
+
brainstorm_category = gr.Dropdown(["project", "business", "creative", "solution", "content", "feature", "product name"], label="Category", value="project")
|
| 307 |
+
brainstorm_topic = gr.Textbox(label="Topic/Problem", placeholder="e.g., reducing plastic waste", value="unique mobile app ideas", lines=3)
|
| 308 |
+
brainstorm_params = create_parameter_ui()
|
| 309 |
+
# Removed gr.Spacer
|
| 310 |
+
brainstorm_btn = gr.Button("Generate Ideas", variant="primary")
|
| 311 |
+
with gr.Column(scale=1):
|
| 312 |
+
gr.Markdown("#### Generated Ideas")
|
| 313 |
+
brainstorm_output = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
| 314 |
+
|
| 315 |
+
def brainstorm_click(category, topic, *params):
|
| 316 |
+
prompt = generate_prompt("brainstorm", category=safe_value(category, "project"), topic=safe_value(topic, "ideas"))
|
| 317 |
+
return generate_text(prompt, *params)
|
| 318 |
+
brainstorm_btn.click(brainstorm_click, [brainstorm_category, brainstorm_topic, *brainstorm_params], brainstorm_output)
|
| 319 |
+
gr.Examples([ ["solution", "engaging online learning", 768, 0.8, 0.9],
|
| 320 |
+
["business", "eco-friendly subscription boxes", 768, 0.75, 0.9],
|
| 321 |
+
["creative", "fantasy novel themes", 512, 0.85, 0.95] ],
|
| 322 |
+
inputs=[brainstorm_category, brainstorm_topic, *brainstorm_params[:3]], outputs=brainstorm_output, label="Try examples...")
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# --- Code Tab ---
|
| 326 |
+
with gr.TabItem("💻 Code"):
|
| 327 |
+
with gr.Tabs():
|
| 328 |
+
with gr.TabItem("Generate"):
|
| 329 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
with gr.Column(scale=1):
|
| 331 |
+
gr.Markdown("#### Setup")
|
| 332 |
+
code_lang_gen = gr.Dropdown(list(lang_map.keys())[:-1], label="Language", value="Python")
|
| 333 |
+
code_task = gr.Textbox(label="Task", placeholder="e.g., function for factorial", value="Python class for calculator", lines=4)
|
| 334 |
+
code_gen_params = create_parameter_ui()
|
| 335 |
+
# Removed gr.Spacer
|
| 336 |
+
code_gen_btn = gr.Button("Generate Code", variant="primary")
|
|
|
|
|
|
|
| 337 |
with gr.Column(scale=1):
|
| 338 |
+
gr.Markdown("#### Generated Code")
|
| 339 |
+
code_output = gr.Code(label="Result", language="python", lines=25, interactive=False)
|
| 340 |
+
|
| 341 |
+
def gen_code_click(lang, task, *params):
|
| 342 |
+
prompt = generate_prompt("code_generate", language=safe_value(lang, "Python"), task=safe_value(task, "hello world"))
|
| 343 |
+
result = generate_text(prompt, *params)
|
| 344 |
+
# Basic code block extraction
|
| 345 |
+
if "```" in result:
|
|
|
|
|
|
|
|
|
|
| 346 |
parts = result.split("```")
|
| 347 |
if len(parts) >= 2:
|
| 348 |
+
block = parts[1]
|
| 349 |
+
if '\n' in block: first_line, rest = block.split('\n', 1); return rest.strip() if first_line.strip().lower() == lang.lower() else block.strip()
|
| 350 |
+
else: return block.strip()
|
| 351 |
+
return result.strip()
|
| 352 |
+
def update_gen_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
| 353 |
+
code_lang_gen.change(update_gen_lang_display, code_lang_gen, code_output, queue=False)
|
| 354 |
+
code_gen_btn.click(gen_code_click, [code_lang_gen, code_task, *code_gen_params], code_output)
|
| 355 |
+
gr.Examples([ ["JavaScript", "email validation regex function", 768, 0.6, 0.9],
|
| 356 |
+
["SQL", "select users > 30 yrs old", 512, 0.5, 0.8],
|
| 357 |
+
["HTML", "basic portfolio structure", 1024, 0.7, 0.9] ],
|
| 358 |
+
inputs=[code_lang_gen, code_task, *code_gen_params[:3]], outputs=code_output, label="Try examples...")
|
| 359 |
+
|
| 360 |
+
with gr.TabItem("Explain"):
|
| 361 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
with gr.Column(scale=1):
|
| 363 |
+
gr.Markdown("#### Setup")
|
| 364 |
+
code_lang_explain = gr.Dropdown(list(lang_map.keys()), label="Language", value="Python")
|
| 365 |
+
code_to_explain = gr.Code(label="Code to Explain", language="python", lines=15)
|
| 366 |
+
explain_code_params = create_parameter_ui()
|
| 367 |
+
# Removed gr.Spacer
|
| 368 |
+
explain_code_btn = gr.Button("Explain Code", variant="primary")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
with gr.Column(scale=1):
|
| 370 |
+
gr.Markdown("#### Explanation")
|
| 371 |
+
code_explanation = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
| 372 |
|
| 373 |
+
def explain_code_click(lang, code, *params):
|
| 374 |
+
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Empty code")
|
| 375 |
+
prompt = generate_prompt("code_explain", language=safe_value(lang, "code"), code=code_content)
|
| 376 |
+
return generate_text(prompt, *params)
|
| 377 |
+
def update_explain_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
| 378 |
+
code_lang_explain.change(update_explain_lang_display, code_lang_explain, code_to_explain, queue=False)
|
| 379 |
+
explain_code_btn.click(explain_code_click, [code_lang_explain, code_to_explain, *explain_code_params], code_explanation)
|
| 380 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 381 |
|
| 382 |
+
with gr.TabItem("Debug"):
|
| 383 |
+
with gr.Row():
|
| 384 |
+
with gr.Column(scale=1):
|
| 385 |
+
gr.Markdown("#### Setup")
|
| 386 |
+
code_lang_debug = gr.Dropdown(list(lang_map.keys()), label="Language", value="Python")
|
| 387 |
+
code_to_debug = gr.Code(label="Buggy Code", language="python", lines=15, value="def avg(nums):\n # Potential div by zero\n return sum(nums)/len(nums)")
|
| 388 |
+
debug_code_params = create_parameter_ui()
|
| 389 |
+
# Removed gr.Spacer
|
| 390 |
+
debug_code_btn = gr.Button("Debug Code", variant="primary")
|
| 391 |
+
with gr.Column(scale=1):
|
| 392 |
+
gr.Markdown("#### Debugging Analysis")
|
| 393 |
+
debug_result = gr.Textbox(label="Result", lines=25, interactive=False, show_copy_button=True)
|
| 394 |
+
|
| 395 |
+
def debug_code_click(lang, code, *params):
|
| 396 |
+
code_content = safe_value(code['code'] if isinstance(code, dict) else code, "# Empty code")
|
| 397 |
+
prompt = generate_prompt("code_debug", language=safe_value(lang, "code"), code=code_content)
|
| 398 |
+
return generate_text(prompt, *params)
|
| 399 |
+
def update_debug_lang_display(lang): return gr.Code.update(language=lang_map.get(lang, "plaintext"))
|
| 400 |
+
code_lang_debug.change(update_debug_lang_display, code_lang_debug, code_to_debug, queue=False)
|
| 401 |
+
debug_code_btn.click(debug_code_click, [code_lang_debug, code_to_debug, *debug_code_params], debug_result)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# --- Comprehension Tab ---
|
| 405 |
+
with gr.TabItem("📚 Comprehension"):
|
| 406 |
+
with gr.Tabs():
|
| 407 |
+
with gr.TabItem("Summarize"):
|
| 408 |
+
with gr.Row():
|
| 409 |
with gr.Column(scale=1):
|
| 410 |
+
gr.Markdown("#### Setup")
|
| 411 |
+
summarize_text = gr.Textbox(label="Text to Summarize", lines=15, placeholder="Paste long text...")
|
| 412 |
summarize_params = create_parameter_ui()
|
| 413 |
+
# Removed gr.Spacer
|
| 414 |
+
summarize_btn = gr.Button("Summarize Text", variant="primary")
|
|
|
|
| 415 |
with gr.Column(scale=1):
|
| 416 |
+
gr.Markdown("#### Summary")
|
| 417 |
+
summary_output = gr.Textbox(label="Result", lines=15, interactive=False, show_copy_button=True)
|
| 418 |
+
def summarize_click(text, *params):
|
| 419 |
+
prompt = generate_prompt("summarize", text=safe_value(text, "[empty text]"))
|
| 420 |
+
# Adjust max tokens for summary specifically if needed
|
| 421 |
+
p_list = list(params); p_list[0] = min(max(int(p_list[0]), 64), 512)
|
| 422 |
+
return generate_text(prompt, *p_list)
|
| 423 |
+
summarize_btn.click(summarize_click, [summarize_text, *summarize_params], summary_output)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
with gr.TabItem("Q & A"):
|
| 427 |
+
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
with gr.Column(scale=1):
|
| 429 |
+
gr.Markdown("#### Setup")
|
| 430 |
+
qa_text = gr.Textbox(label="Context Text", lines=10, placeholder="Paste text containing answer...")
|
| 431 |
+
qa_question = gr.Textbox(label="Question", placeholder="Ask question about text...")
|
| 432 |
qa_params = create_parameter_ui()
|
| 433 |
+
# Removed gr.Spacer
|
| 434 |
+
qa_btn = gr.Button("Get Answer", variant="primary")
|
|
|
|
| 435 |
with gr.Column(scale=1):
|
| 436 |
+
gr.Markdown("#### Answer")
|
| 437 |
+
qa_output = gr.Textbox(label="Result", lines=10, interactive=False, show_copy_button=True)
|
| 438 |
+
def qa_click(text, question, *params):
|
| 439 |
+
prompt = generate_prompt("qa", text=safe_value(text, "[context]"), question=safe_value(question,"[question]"))
|
| 440 |
+
p_list = list(params); p_list[0] = min(max(int(p_list[0]), 32), 256)
|
| 441 |
+
return generate_text(prompt, *p_list)
|
| 442 |
+
qa_btn.click(qa_click, [qa_text, qa_question, *qa_params], qa_output)
                    with gr.TabItem("Translate"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                translate_text = gr.Textbox(label="Text to Translate", lines=8, placeholder="Enter text...")
                                target_lang = gr.Dropdown(["French", "Spanish", "German", "Japanese", "Chinese", "Russian", "Arabic", "Hindi", "Portuguese", "Italian"], label="Translate To", value="French")
                                translate_params = create_parameter_ui()
                                # Removed gr.Spacer
                                translate_btn = gr.Button("Translate Text", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Translation")
                                translation_output = gr.Textbox(label="Result", lines=8, interactive=False, show_copy_button=True)

                        def translate_click(text, lang, *params):
                            prompt = generate_prompt("translate", text=safe_value(text, "[text]"), target_lang=safe_value(lang, "French"))
                            p_list = list(params)
                            p_list[0] = max(int(p_list[0]), 64)
                            return generate_text(prompt, *p_list)

                        translate_btn.click(translate_click, [translate_text, target_lang, *translate_params], translation_output)
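                        # Unlike the other tasks, translation clamps only from below (no upper
                        # bound), presumably because a faithful translation needs roughly as many
                        # tokens as the input, however long that is.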
            # --- More Tasks Tab ---
            with gr.TabItem("🛠️ More Tasks"):
                with gr.Tabs():
                    with gr.TabItem("Content Creation"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                content_type = gr.Dropdown(["blog post outline", "social media post (Twitter)", "social media post (LinkedIn)", "marketing email subject line", "product description", "press release intro"], label="Content Type", value="blog post outline")
                                content_topic = gr.Textbox(label="Topic", value="sustainable travel tips", lines=2)
                                content_audience = gr.Textbox(label="Audience", value="eco-conscious millennials")
                                content_params = create_parameter_ui()
                                # Removed gr.Spacer
                                content_btn = gr.Button("Generate Content", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Generated Content")
                                content_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)

                        def content_click(c_type, topic, audience, *params):
                            prompt = generate_prompt("content_creation", content_type=safe_value(c_type, "text"), topic=safe_value(topic, "[topic]"), audience=safe_value(audience, "[audience]"))
                            return generate_text(prompt, *params)

                        content_btn.click(content_click, [content_type, content_topic, content_audience, *content_params], content_output)
                    with gr.TabItem("Email Drafting"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                email_type = gr.Dropdown(["job inquiry", "meeting request", "follow-up", "thank you", "support response", "sales outreach"], label="Email Type", value="meeting request")
                                email_context = gr.Textbox(label="Context/Points", lines=5, value="Request meeting next week re: project X. Suggest Tue/Wed afternoon.")
                                email_params = create_parameter_ui()
                                # Removed gr.Spacer
                                email_btn = gr.Button("Generate Draft", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Generated Draft")
                                email_output = gr.Textbox(label="Result", lines=20, interactive=False, show_copy_button=True)

                        def email_click(e_type, context, *params):
                            prompt = generate_prompt("email_draft", email_type=safe_value(e_type, "email"), context=safe_value(context, "[context]"))
                            return generate_text(prompt, *params)

                        email_btn.click(email_click, [email_type, email_context, *email_params], email_output)
                    with gr.TabItem("Doc Editing"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                edit_text = gr.Textbox(label="Text to Edit", lines=10, placeholder="Paste text...")
                                edit_type = gr.Dropdown(["improve clarity", "fix grammar/spelling", "make concise", "make formal", "make casual", "simplify"], label="Improve For", value="improve clarity")
                                edit_params = create_parameter_ui()
                                # Removed gr.Spacer
                                edit_btn = gr.Button("Edit Text", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Edited Text")
                                edit_output = gr.Textbox(label="Result", lines=10, interactive=False, show_copy_button=True)

                        def edit_click(text, e_type, *params):
                            prompt = generate_prompt("document_edit", text=safe_value(text, "[text]"), edit_type=safe_value(e_type, "clarity"))
                            p_list = list(params)
                            input_tokens = len(safe_value(text, "").split())
                            p_list[0] = max(int(p_list[0]), input_tokens + 64)
                            return generate_text(prompt, *p_list)

                        edit_btn.click(edit_click, [edit_text, edit_type, *edit_params], edit_output)
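                        # The input_tokens + 64 floor counts whitespace-separated words, not real
                        # tokenizer tokens, so it can undershoot for long words or non-Latin text;
                        # something like len(global_tokenizer.encode(text)) would be a more
                        # accurate budget once a model is loaded (a sketch, not used here).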
                    with gr.TabItem("Classification"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                classify_text = gr.Textbox(label="Text to Classify", lines=8, value="Sci-fi movie explores AI consciousness.")
                                classify_categories = gr.Textbox(label="Categories (comma-sep)", value="Tech, Entertainment, Science, Politics")
                                classify_params = create_parameter_ui()
                                # Removed gr.Spacer
                                classify_btn = gr.Button("Classify Text", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Classification")
                                classify_output = gr.Textbox(label="Predicted Category", lines=2, interactive=False, show_copy_button=True)

                        def classify_click(text, cats, *params):
                            prompt = generate_prompt("classify", text=safe_value(text, "[text]"), categories=safe_value(cats, "cat1, cat2"))
                            p_list = list(params)
                            p_list[0] = min(max(int(p_list[0]), 16), 128)
                            raw = generate_text(prompt, *p_list)
                            # Basic post-processing attempt: accept the last line if it names a category
                            lines = raw.split('\n')
                            last = lines[-1].strip()
                            possible = [c.strip().lower() for c in safe_value(cats, "").split(',')]
                            return last if last.lower() in possible else raw

                        classify_btn.click(classify_click, [classify_text, classify_categories, *classify_params], classify_output)
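                        # The heuristic above only trusts the model's final output line, and only
                        # when it exactly matches one of the supplied categories (case-insensitive);
                        # anything else falls back to returning the raw generation untouched.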
                    with gr.TabItem("Data Extraction"):
                        with gr.Row():
                            with gr.Column(scale=1):
                                gr.Markdown("#### Setup")
                                extract_text = gr.Textbox(label="Source Text", lines=10, value="Order #123 by Jane (j@ex.com). Total: $99. Shipped: 123 Main St.")
                                extract_data_points = gr.Textbox(label="Data Points (comma-sep)", value="order num, name, email, total, address")
                                extract_params = create_parameter_ui()
                                # Removed gr.Spacer
                                extract_btn = gr.Button("Extract Data", variant="primary")
                            with gr.Column(scale=1):
                                gr.Markdown("#### Extracted Data")
                                extract_output = gr.Textbox(label="Result (JSON or Key-Value)", lines=10, interactive=False, show_copy_button=True)

                        def extract_click(text, points, *params):
                            prompt = generate_prompt("data_extract", text=safe_value(text, "[text]"), data_points=safe_value(points, "info"))
                            return generate_text(prompt, *params)

                        extract_btn.click(extract_click, [extract_text, extract_data_points, *extract_params], extract_output)
    # --- Authentication Handler & Footer ---
    footer_status = gr.Markdown("...", elem_id="footer-status-md")  # Placeholder; filled by update_footer_status below

    def handle_auth(token):
        yield "⏳ Authenticating & loading model...", gr.Tabs.update(visible=False)
        # load_model returns (status, tabs_update) on success but a bare error string on failure
        result = load_model(token)
        status_message, tabs_update = result if isinstance(result, tuple) else (result, gr.Tabs.update(visible=False))
        yield status_message, tabs_update
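    # Gradio runs generator handlers as streaming updates: each yield pushes an
    # intermediate value to the bound outputs, so the "Authenticating..." message is
    # shown while load_model() runs and is then replaced by the final status.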
    def update_footer_status(status_text):
        # status_text is unused; it only triggers the update. The footer reads the
        # module-level model_loaded / loaded_model_name set by load_model.
        return gr.Markdown.update(value=f"""
        <hr><div style="text-align: center; font-size: 0.9em; color: #777;">
        <p>Powered by Hugging Face 🤗 Transformers & Gradio. Model: <strong>{loaded_model_name if model_loaded else 'None'}</strong>.</p>
        <p>Review outputs carefully. Models may generate inaccurate information.</p></div>""")

    auth_button.click(handle_auth, hf_token, [auth_status, tabs], queue=True)
    # Update footer whenever auth status text changes
    auth_status.change(update_footer_status, auth_status, footer_status, queue=False)
    # Initial footer update on load
    demo.load(update_footer_status, auth_status, footer_status, queue=False)
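    # Event chain: auth_button.click runs handle_auth (writing auth_status and tabs);
    # that auth_status change fires update_footer_status, and demo.load seeds the
    # footer once when the page first renders.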

# --- Launch App ---
# gr.Blocks.launch() does not accept an allowed_themes argument (themes are set via
# gr.Blocks(theme=...)), so only supported options are passed here.
demo.launch(share=False)