Spaces:
Runtime error
Runtime error
Commit
·
e496e33
1
Parent(s):
e556404
minor
Browse files
app.py
CHANGED
|
@@ -76,7 +76,7 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
|
| 76 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
| 77 |
|
| 78 |
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
| 79 |
-
if
|
| 80 |
index_sample = index_sample[:len(class_labels)]
|
| 81 |
t2 = time.time()
|
| 82 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
|
|
|
| 76 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
| 77 |
|
| 78 |
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
| 79 |
+
if cfg_scale > 1.0:
|
| 80 |
index_sample = index_sample[:len(class_labels)]
|
| 81 |
t2 = time.time()
|
| 82 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|