import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces

MODEL_ID = "EssentialAI/rnj-1-instruct"

# Load tokenizer & model once (ZeroGPU moves them onto the GPU when a @spaces.GPU fn runs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",  # ZeroGPU maps this onto the allocated H200 slice
)

SYSTEM_PROMPT = (
    "You are RNJ-1, a precise, math-friendly assistant. "
    "Solve problems step by step, avoid rambling, and give the final answer clearly."
)


def build_prompt(user_text: str) -> str:
    # Keep it simple: RNJ handles plain text well.
    return f"{SYSTEM_PROMPT}\n\nUser: {user_text}\nAssistant:"


@spaces.GPU  # ZeroGPU requires this decorator on any function that needs the GPU
def generate_reply(
    user_text: str,
    max_new_tokens: int,
    temperature: float,
) -> str:
    prompt = build_prompt(user_text)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Greedy decoding at temperature 0; nucleus sampling otherwise.
    do_sample = temperature > 0.0
    gen_kwargs = dict(
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Only pass sampling params when sampling; transformers warns about
        # temperature/top_p when do_sample=False.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = 0.9

    with torch.no_grad():
        output = model.generate(**inputs, **gen_kwargs)

    full = tokenizer.decode(output[0], skip_special_tokens=True)

    # Strip everything before the last "Assistant:" so prompt text doesn't echo back.
    if "Assistant:" in full:
        reply = full.split("Assistant:")[-1].strip()
    else:
        reply = full.strip()
    return reply


with gr.Blocks() as demo:
    gr.Markdown("# RNJ-1 Instruct on ZeroGPU (H200)")

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Input",
            placeholder="Ask RNJ-1 to solve an equation, prove something, etc.",
            lines=6,
        )

    with gr.Row():
        max_tokens_slider = gr.Slider(
            minimum=64,
            maximum=4096,  # allow long answers so output stops cutting off
            value=1024,
            step=64,
            label="Max new tokens",
        )
        temperature_slider = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.2,
            step=0.05,
            label="Temperature",
        )

    output_box = gr.Textbox(
        label="RNJ-1 Output",
        lines=20,
    )

    run_btn = gr.Button("Generate")
    run_btn.click(
        fn=generate_reply,
        inputs=[prompt_box, max_tokens_slider, temperature_slider],
        outputs=[output_box],
    )


if __name__ == "__main__":
    demo.launch()
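
# --- Optional: chat-template prompting (a hedged sketch, not wired into the app above) ---
# If EssentialAI/rnj-1-instruct ships a chat template in its tokenizer config, the
# hand-rolled build_prompt() above could be replaced with tokenizer.apply_chat_template,
# which formats the system/user turns the way the model was fine-tuned on. Whether this
# model defines a template is an assumption; check that `tokenizer.chat_template` is set
# before switching.
#
# def build_prompt_chat(user_text: str) -> str:
#     messages = [
#         {"role": "system", "content": SYSTEM_PROMPT},
#         {"role": "user", "content": user_text},
#     ]
#     # add_generation_prompt=True appends the assistant header so the model
#     # starts generating the reply rather than another user turn.
#     return tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=True,
#     )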
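
# --- Optional: calling the Space programmatically (a hedged sketch) ---
# Once the Space is live, it can be driven from another process with gradio_client.
# The api_name "/generate_reply" follows Gradio's default of naming the route after
# the event function, but verify it via the Space's "Use via API" panel; the Space id
# below is hypothetical.
#
# from gradio_client import Client
#
# client = Client("your-username/your-space-name")  # hypothetical Space id
# reply = client.predict(
#     "Solve x^2 - 5x + 6 = 0.",  # user_text
#     512,                        # max_new_tokens
#     0.2,                        # temperature
#     api_name="/generate_reply",
# )
# print(reply)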