Fixes to image resolution
- app.py  +5 -6
- gill/models.py  +3 -3
app.py
CHANGED
@@ -115,14 +115,13 @@ def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperat
     elif type(p) == dict:
       # Decide whether to generate or retrieve.
       if p['decision'] is not None and p['decision'][0] == 'gen':
-        image = p['gen'][0][0]
+        image = p['gen'][0][0]#.resize((224, 224))
         filename = save_image_to_local(image)
-        response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Generated)</p>'
+        response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Generated)</p>'
       else:
-        image = p['ret'][0][0]
+        image = p['ret'][0][0]#.resize((224, 224))
         filename = save_image_to_local(image)
-        response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555;">(Retrieved)</p>'
+        response += f'<img src="./file={filename}" style="display: inline-block;"><p style="font-size: 12px; color: #555; margin-top: 0;">(Retrieved)</p>'
-
 
   chat_history = model_inputs + \
       [' '.join([s for s in model_outputs if type(s) == str]) + '\n']

@@ -180,7 +179,7 @@ with gr.Blocks(css=css) as demo:
       share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
 
     with gr.Column(scale=0.3, min_width=400):
-      ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.
+      ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.2, step=0.1, interactive=True,
                                    label="Frequency multiplier for returning images (higher means more frequent)")
      # max_ret_images = gr.Number(
      #   minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
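
Note: save_image_to_local is a helper defined elsewhere in app.py and is not part of this diff. A minimal sketch, assuming a helper of that shape, of what it typically has to do so Gradio can serve the saved image through the ./file= route used in the HTML above; the output directory and file naming below are illustrative, not the Space's actual implementation:

import os
import uuid

from PIL import Image


def save_image_to_local(image: Image.Image, out_dir: str = 'temp') -> str:
  # Write the PIL image returned by the model under a unique name and
  # return the relative path that the <img src="./file=..."> tag points at.
  os.makedirs(out_dir, exist_ok=True)
  filename = os.path.join(out_dir, f'{uuid.uuid4()}.png')
  image.save(filename)
  return filename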
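Note: the slider's name matches the ret_scale_factor argument of generate_for_prompt in the first hunk. A minimal sketch, under assumptions, of how a gr.Slider value flows into a Gradio callback; the callback and layout below are illustrative only, not the Space's actual UI:

import gradio as gr


def run(prompt: str, ret_scale_factor: float) -> str:
  # Stand-in for generate_for_prompt: just echo the slider value.
  return f'{prompt} (ret_scale_factor={ret_scale_factor})'


with gr.Blocks() as demo:
  prompt = gr.Textbox(label='Prompt')
  ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.2, step=0.1,
                               interactive=True, label='Frequency multiplier for returning images')
  output = gr.Textbox(label='Output')
  gr.Button('Run').click(run, [prompt, ret_scale_factor], output)

# demo.launch()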
gill/models.py
CHANGED
@@ -878,10 +878,10 @@ def load_gill(embeddings_dir: str, model_args_path: str, model_ckpt_path: str, d
   model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,
                load_sd=not debug, num_gen_images=1, decision_model_path=decision_model_path)
   model = model.eval()
-  if
-    model = model.bfloat16()
-    model = model.cuda()
+  if torch.cuda.is_available():
+    model = model.bfloat16().cuda()
 
+  if not debug:
   # Load pretrained linear mappings and [IMG] embeddings.
   checkpoint = torch.load(model_ckpt_path)
   state_dict = {}
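
Note: the cast to bfloat16 and the move to the GPU are now guarded by torch.cuda.is_available(), so loading no longer fails on .cuda() on CPU-only machines, and the checkpoint loading sits under if not debug. A minimal sketch of the same CPU-fallback pattern on a stand-in module; the nn.Linear layer and checkpoint path are placeholders, not GILL's actual loading code, and on CPU-only machines map_location in torch.load is the usual companion to such a guard:

import torch
import torch.nn as nn

# Placeholder module standing in for the GILL model.
model = nn.Linear(4, 4).eval()

if torch.cuda.is_available():
  # Cast to bfloat16 and move to the GPU only when one is present,
  # mirroring the guarded block added in load_gill above.
  model = model.bfloat16().cuda()

# map_location keeps torch.load from trying to materialize CUDA tensors
# when a GPU-saved checkpoint is loaded on a CPU-only machine.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.save(model.state_dict(), 'ckpt.pth')  # stand-in checkpoint
state_dict = torch.load('ckpt.pth', map_location=device)
model.load_state_dict(state_dict)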