fuvty committed
Commit 68b944b
1 Parent(s): 7b0d224

[update] support zeroGPU

Files changed (2):
  1. app.py +52 -9
  2. requirements.txt +3 -0
app.py CHANGED
@@ -6,7 +6,10 @@ This creates a web interface to compare three inference modes simultaneously:
 2. T2T: Two-stage inference (shows context + answer)
 3. C2C: Rosetta model with projectors
 
-All models are loaded at startup and respond to the same input in parallel.
+ZeroGPU Support:
+- Models are loaded to CPU at startup
+- @spaces.GPU decorator moves models to GPU on-demand for each inference
+- Works seamlessly on both ZeroGPU and regular GPU environments
 """
 
 import os
@@ -19,6 +22,20 @@ from typing import Optional, Generator
 from queue import Queue
 from threading import Thread
 
+# ZeroGPU support
+try:
+    import spaces
+    ZEROGPU_AVAILABLE = True
+except ImportError:
+    ZEROGPU_AVAILABLE = False
+    # Create a no-op decorator for non-ZeroGPU environments
+    class spaces:
+        @staticmethod
+        def GPU(duration=None):
+            def decorator(func):
+                return func
+            return decorator if duration else lambda f: f
+
 from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from rosetta.utils.evaluate import load_rosetta_model, load_hf_model, set_default_chat_template
 from rosetta.model.wrapper import RosettaModel
@@ -46,8 +63,13 @@ class ModelManager:
             c2c_checkpoint_path: Path to C2C checkpoint directory
             device: Device to use (cuda, cpu, or auto)
         """
+        # For ZeroGPU, load models to CPU and move to GPU in decorated functions
         if device == "auto":
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            if ZEROGPU_AVAILABLE:
+                self.device = torch.device("cpu")
+                print("ZeroGPU detected: Loading models to CPU (will move to GPU on-demand)")
+            else:
+                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         else:
             self.device = torch.device(device)
         print(f"Using device: {self.device}")
@@ -208,13 +230,19 @@ class ModelManager:
 
         return kwargs
 
+    @spaces.GPU(duration=60)
     def generate_single(self, user_input: str) -> Generator[str, None, None]:
         """Generate response from single model with streaming."""
+        # Move model to GPU for ZeroGPU
+        device = torch.device("cuda" if ZEROGPU_AVAILABLE else self.device)
+        if ZEROGPU_AVAILABLE and self.single_model.device.type != "cuda":
+            self.single_model.to(device)
+
         messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
         text = self.single_tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
         )
-        inputs = self.single_tokenizer(text, return_tensors="pt").to(self.device)
+        inputs = self.single_tokenizer(text, return_tensors="pt").to(device)
 
         # Setup streamer
         streamer = TextIteratorStreamer(
@@ -241,8 +269,17 @@ class ModelManager:
             generated_text += token
             yield generated_text
 
+    @spaces.GPU(duration=90)
     def generate_t2t(self, user_input: str) -> Generator[tuple[str, str], None, None]:
         """Generate response from T2T model with streaming (returns context, answer)."""
+        # Move models to GPU for ZeroGPU
+        device = torch.device("cuda" if ZEROGPU_AVAILABLE else self.device)
+        if ZEROGPU_AVAILABLE:
+            if self.t2t_model.context_model.device.type != "cuda":
+                self.t2t_model.context_model.to(device)
+            if self.t2t_model.answer_model.device.type != "cuda":
+                self.t2t_model.answer_model.to(device)
+
         # Stage 1: Context generation
         context_streamer = TextIteratorStreamer(
             self.t2t_model.context_tokenizer,
@@ -257,7 +294,7 @@ class ModelManager:
             add_generation_prompt=True,
             return_tensors="pt",
             enable_thinking=False
-        ).to(self.device)
+        ).to(device)
 
         generation_kwargs = {
             'input_ids': inputs,
@@ -306,7 +343,7 @@ class ModelManager:
             add_generation_prompt=True,
             return_tensors="pt",
             enable_thinking=False
-        ).to(self.device)
+        ).to(device)
 
         generation_kwargs = {
             'input_ids': inputs,
@@ -324,13 +361,19 @@ class ModelManager:
             answer_text += token
             yield context_text, answer_text
 
+    @spaces.GPU(duration=60)
     def generate_c2c(self, user_input: str) -> Generator[str, None, None]:
        """Generate response from C2C model with streaming."""
+        # Move model to GPU for ZeroGPU
+        device = torch.device("cuda" if ZEROGPU_AVAILABLE else self.device)
+        if ZEROGPU_AVAILABLE and self.c2c_model.device.type != "cuda":
+            self.c2c_model.to(device)
+
         messages = [{"role": "system", "content": ""}, {"role": "user", "content": user_input}]
         text = self.c2c_tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
         )
-        inputs = self.c2c_tokenizer(text, return_tensors="pt").to(self.device)
+        inputs = self.c2c_tokenizer(text, return_tensors="pt").to(device)
 
         # Setup streamer
         streamer = TextIteratorStreamer(
@@ -343,12 +386,12 @@ class ModelManager:
         full_length = inputs.input_ids.shape[1]
         instruction_index = torch.tensor([1, 0], dtype=torch.long).repeat(
             full_length - 1, 1
-        ).unsqueeze(0).to(self.device)
+        ).unsqueeze(0).to(device)
         label_index = torch.tensor([-1, 0], dtype=torch.long).repeat(
             1, 1
-        ).unsqueeze(0).to(self.device)
+        ).unsqueeze(0).to(device)
         position_ids = inputs.attention_mask.long().cumsum(-1) - 1 if inputs.attention_mask is not None else \
-            torch.arange(full_length, dtype=torch.long).unsqueeze(0).to(self.device)
+            torch.arange(full_length, dtype=torch.long).unsqueeze(0).to(device)
 
         # Generation parameters
         generation_kwargs = {
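
The core of the change is a load-on-CPU / attach-GPU-per-call pattern: models stay on CPU at startup, and each generate_* method is wrapped in @spaces.GPU so a GPU is requested only for the duration of that call. A minimal, self-contained sketch of the same pattern (the nn.Linear stand-in, run(), and duration value are illustrative, not part of the repository):

# Minimal sketch of the ZeroGPU pattern above; the Linear model and run() are
# stand-ins for the repo's Rosetta/Qwen models.
import torch
import torch.nn as nn

try:
    import spaces                              # present on HuggingFace Spaces
    ZEROGPU_AVAILABLE = True
except ImportError:                            # local run: use a no-op decorator
    ZEROGPU_AVAILABLE = False

    class spaces:
        @staticmethod
        def GPU(duration=None):
            def decorator(func):
                return func
            return decorator

model = nn.Linear(8, 8)                        # loaded on CPU at startup

@spaces.GPU(duration=60)                       # on ZeroGPU, a GPU is attached only for this call
def run(x: torch.Tensor) -> torch.Tensor:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)                           # move weights on demand
    return model(x.to(device)).cpu()           # hand results back on CPU

print(run(torch.randn(1, 8)).shape)            # torch.Size([1, 8])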
requirements.txt CHANGED
@@ -8,6 +8,9 @@ gradio==5.9.1
 # HuggingFace Hub for checkpoint downloads
 huggingface-hub>=0.26.0
 
+# ZeroGPU support for HuggingFace Spaces
+spaces>=0.30.0
+
 # Configuration file parsing
 pyyaml>=6.0
 
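
The new spaces pin only matters on HuggingFace Spaces; locally, the try/except fallback added to app.py keeps the demo importable without it. A quick check of which path is active (sketch only; the fallback behaviour comes from the changes above):

try:
    import spaces  # noqa: F401
    print("spaces installed: @spaces.GPU will request a GPU per call")
except ImportError:
    print("spaces not installed: app.py falls back to its no-op GPU decorator")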