# Spaces: Running on Zero
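Below is a complete `app.py` for a Space running on ZeroGPU: it serves `EssentialAI/rnj-1-instruct`, loading the model once at startup, and asks ZeroGPU for a GPU slice only while the `@spaces.GPU`-decorated generation function runs.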
```python
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

MODEL_ID = "EssentialAI/rnj-1-instruct"

# Load tokenizer & model once (ZeroGPU will put the model on GPU
# while a @spaces.GPU-decorated function runs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",  # ZeroGPU will map this to an H200 slice
)

SYSTEM_PROMPT = (
    "You are RNJ-1, a precise, math-friendly assistant. "
    "Solve problems step by step, avoid rambling, and give the final answer clearly."
)


def build_prompt(user_text: str) -> str:
    # Keep it simple -- RNJ handles plain text well
    return f"{SYSTEM_PROMPT}\n\nUser: {user_text}\nAssistant:"
| # <-- THIS is what ZeroGPU wants | |
| def generate_reply( | |
| user_text: str, | |
| max_new_tokens: int, | |
| temperature: float, | |
| ) -> str: | |
    prompt = build_prompt(user_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    do_sample = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    # Only pass temperature when sampling: 0.0 is invalid for sampling and
    # only triggers a warning under greedy decoding.
    if do_sample:
        gen_kwargs["temperature"] = float(temperature)
    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)
    full = tokenizer.decode(output[0], skip_special_tokens=True)
    # Strip everything before the last "Assistant:" so prompt text doesn't echo back
    if "Assistant:" in full:
        reply = full.split("Assistant:")[-1].strip()
    else:
        reply = full.strip()
    return reply

with gr.Blocks() as demo:
    gr.Markdown("# RNJ-1 Instruct on ZeroGPU (H200)")
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Input",
            placeholder="Ask RNJ-1 to solve an equation, prove something, etc.",
            lines=6,
        )
    with gr.Row():
        max_tokens_slider = gr.Slider(
            minimum=64,
            maximum=4096,  # allow long answers so they stop cutting off
            value=1024,
            step=64,
            label="Max new tokens",
        )
        temperature_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.2,
            step=0.05,
            label="Temperature",
        )
    output_box = gr.Textbox(
        label="RNJ-1 Output",
        lines=20,
    )
    run_btn = gr.Button("Generate")
    run_btn.click(
        fn=generate_reply,
        inputs=[prompt_box, max_tokens_slider, temperature_slider],
        outputs=[output_box],
    )

if __name__ == "__main__":
    demo.launch()
```
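For the Space to build, a `requirements.txt` next to `app.py` along these lines should work. This is a minimal sketch and the exact package set is an assumption: `accelerate` backs `device_map="auto"`, and the `spaces` package is preinstalled on ZeroGPU hardware but is harmless to list for local runs.

```text
gradio
torch
transformers
accelerate  # required for device_map="auto"
spaces      # ZeroGPU decorator; does nothing outside Spaces
```

Outside a ZeroGPU environment the `@spaces.GPU` decorator is effectively a no-op, so the same file also runs locally on any machine with a GPU (or, slowly, on CPU).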