Spaces · Runtime error

Commit 45c882e · Parent(s): c1d34f4

app.py CHANGED
@@ -44,15 +44,12 @@ def load_model(hf_token):
             token=hf_token
         )

-        # Load model with
+        # Load model with minimal configuration to avoid errors
         global_model = AutoModelForCausalLM.from_pretrained(
             model_name,
             torch_dtype=torch.float16,
             device_map="auto",
-            token=hf_token,
-            use_cache=True,
-            low_cpu_mem_usage=True,
-            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
+            token=hf_token
         )

         model_loaded = True

@@ -162,28 +159,15 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return "Please enter a prompt to generate text."

     try:
+        # Keep generation simple to avoid errors
         inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)

-        generation_config = {
-            "max_length": max_length,
-            "do_sample": True,
-            "pad_token_id": global_tokenizer.eos_token_id,
-        }
-
-        # Only add temperature if it's not too low (can cause probability issues)
-        if temperature >= 0.2:
-            generation_config["temperature"] = temperature
-        else:
-            generation_config["temperature"] = 0.2
-
-        # Only add top_p if it's valid
-        if 0 < top_p < 1:
-            generation_config["top_p"] = top_p
-
-        # Generate text with safer parameters
+        # Use simpler generation parameters that work reliably
         outputs = global_model.generate(
-            inputs.input_ids,
-            **generation_config
+            inputs.input_ids,
+            max_length=min(2048, max_length + len(inputs.input_ids[0])),
+            temperature=max(0.3, temperature),  # Prevent too low temperature
+            do_sample=True
         )

         # Decode and return the generated text

@@ -191,8 +175,9 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return generated_text
     except Exception as e:
         error_msg = str(e)
+        print(f"Generation error: {error_msg}")
         if "probability tensor" in error_msg:
-            return "Error: There was a problem with the generation parameters. Try using
+            return "Error: There was a problem with the generation parameters. Try using simpler parameters or a different prompt."
         else:
             return f"Error generating text: {error_msg}"

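
For reference, the simplified generation path introduced above can be exercised on its own. The snippet below is a standalone sketch, not part of app.py: it assumes the transformers package is installed and uses the tiny public model sshleifer/tiny-gpt2 purely as a fast placeholder, while the real Space uses its own model_name and the global tokenizer/model objects.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "sshleifer/tiny-gpt2"   # tiny placeholder model, illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    prompt = "Once upon a time"
    max_length = 50     # generation budget requested by the UI
    temperature = 0.1   # deliberately too low; clamped below

    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        # total length capped at 2048 tokens, otherwise prompt length plus the budget
        max_length=min(2048, max_length + len(inputs.input_ids[0])),
        # clamping to at least 0.3 avoids the "probability tensor" errors the old
        # conditional generation_config logic tried to work around
        temperature=max(0.3, temperature),
        do_sample=True,
    )
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
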
@@ -247,12 +232,27 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
             )

         with gr.Column(scale=1):
-            auth_button = gr.Button("Authenticate")
+            auth_button = gr.Button("Authenticate", variant="primary")

-            auth_status = gr.Markdown("Please authenticate to use the model.")
+            with gr.Group(visible=True) as auth_message_group:
+                auth_status = gr.Markdown("Please authenticate to use the model.")

+            def authenticate(token):
+                auth_message_group.visible = True
+                return "Loading model... Please wait, this may take a minute."
+
+            def auth_complete(token):
+                result = load_model(token)
+                return result
+
+            # Two-step authentication to show loading message
             auth_button.click(
-                fn=load_model,
+                fn=authenticate,
+                inputs=[hf_token],
+                outputs=[auth_status],
+                queue=False
+            ).then(
+                fn=auth_complete,
                 inputs=[hf_token],
                 outputs=[auth_status]
             )

@@ -1019,6 +1019,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:

 # Load default token if available
 if DEFAULT_HF_TOKEN:
-    demo.load(fn=load_model, inputs=[hf_token], outputs=[auth_status])
+    demo.load(fn=authenticate, inputs=[hf_token], outputs=[auth_status]).then(
+        fn=auth_complete, inputs=[hf_token], outputs=[auth_status]
+    )

-demo.launch()
+demo.launch(share=False)
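
The UI change above hinges on Gradio event chaining: Button.click() returns an event object whose .then() method registers a follow-up callback, so a fast handler can update the status text before a slow one (the model load) starts. Below is a minimal, self-contained sketch of that pattern, not taken from app.py; the names (status, run_button, show_loading, slow_setup) and the 5-second sleep are purely illustrative, and it assumes a reasonably recent Gradio release where .then() and queue=False are available.

    import time
    import gradio as gr

    def show_loading():
        # Step 1: returns immediately so the user sees feedback right away.
        return "Loading... please wait."

    def slow_setup():
        # Step 2: stands in for the slow work (loading the model in the real app).
        time.sleep(5)
        return "Ready."

    with gr.Blocks() as demo:
        status = gr.Markdown("Click the button to start.")
        run_button = gr.Button("Start", variant="primary")

        # click() registers the fast status update; then() chains the slow step,
        # which only starts after the first callback has returned.
        run_button.click(fn=show_loading, inputs=None, outputs=status, queue=False).then(
            fn=slow_setup, inputs=None, outputs=status
        )

    demo.launch()

The commit reuses the same chain for demo.load(), so a pre-configured DEFAULT_HF_TOKEN shows the loading message and then triggers the model load as soon as the page opens.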