File size: 1,649 Bytes
59cdc8b
 
17b1867
715110f
 
 
 
 
c6509f9
715110f
8c6d7e5
715110f
c6509f9
 
 
b22328e
479cd19
c520eb1
 
c6509f9
59cdc8b
 
c6509f9
b22328e
715110f
 
c4e9d8d
59cdc8b
 
 
b22328e
479cd19
59cdc8b
b22328e
 
59cdc8b
b22328e
59cdc8b
 
c6509f9
 
4ec308a
 
 
c6509f9
 
715110f
59cdc8b
 
678c4c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import logging

import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from peft import PeftModel
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# ASGI application object that the routes below are registered on.
app = FastAPI()

# Hugging Face Hub identifiers: the base checkpoint, and the PEFT adapter
# that is attached on top of it further down.
base_model_path = "NousResearch/Hermes-3-Llama-3.2-3B"
adapter_path = "thinkingnew/llama_invs_adapter"

# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Half precision roughly halves GPU memory; on CPU many kernels have no
# float16 implementation, so fall back to float32 there.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# `device_map="auto"` lets accelerate decide where each layer lives. The model
# is deliberately NOT moved again with `.to(device)` afterwards: that overrides
# accelerate's placement and breaks when any layers were offloaded to CPU/disk.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_path, torch_dtype=model_dtype, device_map="auto"
)

# Attach the PEFT adapter from `adapter_path` to the already-placed base model.
model = PeftModel.from_pretrained(base_model, adapter_path)

# Tokenizer comes from the base checkpoint (the adapter does not ship one here).
tokenizer = AutoTokenizer.from_pretrained(base_model_path)


class GenerateRequest(BaseModel):
    """JSON request body accepted by ``POST /generate/``."""

    # Text that is inserted into the instruction template and sent to the model.
    prompt: str

# Call `model.generate()` directly rather than going through `pipeline()`.
def generate_text_from_model(prompt: str, max_new_tokens: int = 512) -> str:
    """Run the adapter-equipped model on *prompt* and return the decoded text.

    The prompt is wrapped as ``<s>[INST] {prompt} [/INST]`` before
    tokenization. The returned string is the whole decoded output sequence
    (prompt text followed by the model's continuation) with special tokens
    stripped.

    Args:
        prompt: User text placed inside the instruction template.
        max_new_tokens: Maximum number of tokens to generate after the
            prompt. The budget does not depend on prompt length, so long
            prompts still leave room for an answer.

    Returns:
        The decoded output sequence.

    Raises:
        HTTPException: With status 500 if tokenization, generation or
            decoding fails; the original error is logged and chained.
    """
    # NOTE(review): `[INST] ... [/INST]` is the Llama-2 chat format, while the
    # base checkpoint is Hermes-3 (Llama 3.2). Presumably this matches the
    # format the adapter was fine-tuned on - confirm before changing it.
    templated = f"<s>[INST] {prompt} [/INST]"
    try:
        # Keep the attention mask together with input_ids, and place the
        # tensors on the device that holds the model's (first) weights.
        inputs = tokenizer(templated, return_tensors="pt").to(model.device)
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            # Llama-family tokenizers often ship without a pad token; reuse EOS.
            pad_id = tokenizer.eos_token_id
        output_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens, pad_token_id=pad_id
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
    except Exception as e:
        # FastAPI converts HTTPException into a response without logging it,
        # so record the traceback here or it is lost.
        logging.getLogger(__name__).exception("Text generation failed")
        # NOTE(review): str(e) exposes internal error text to API clients.
        raise HTTPException(status_code=500, detail=str(e)) from e

# Root route - a quick way to check that the server is up.
@app.get("/")
async def root():
    """Report that the service is running and point callers at /generate/."""
    status = {"message": "Model is running! Use /generate/ for text generation."}
    return status

# Text generation endpoint
@app.post("/generate/")
async def generate_text(request: GenerateRequest):
    """Generate text for ``request.prompt`` and return ``{"response": <text>}``.

    Model generation is blocking compute. Calling it directly inside this
    coroutine would stall the event loop, and with it every other request
    (including ``/``), until generation finished. It is therefore run in
    FastAPI's worker thread pool and awaited.
    """
    # NOTE(review): several requests may now generate concurrently on the
    # shared model; add a lock/semaphore if GPU memory cannot handle that.
    response = await run_in_threadpool(generate_text_from_model, request.prompt)
    return {"response": response}