fancyfeast committed · 5d57e40 · Parent(s): f73cf3f
Improve handling caption tone special case. Also, derp, forgot to format the prompt string.
app.py CHANGED
@@ -144,12 +144,20 @@ image_adapter.to("cuda")
 def stream_chat(input_image: Image.Image, caption_type: str, caption_tone: str, caption_length: str | int) -> str:
     torch.cuda.empty_cache()
 
+    # 'any' means no length specified
     length = None if caption_length == "any" else caption_length
+
+    # 'rng-tags' and 'training_prompt' don't have formal/informal tones
+    if caption_type == "rng-tags" or caption_type == "training_prompt":
+        caption_tone = "formal"
+
+    # Build prompt
     prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))
     if prompt_key not in CAPTION_TYPE_MAP:
         raise ValueError(f"Invalid caption type: {prompt_key}")
 
-    prompt_str = CAPTION_TYPE_MAP[prompt_key][0]
+    prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
+    print(f"Prompt: {prompt_str}")
 
     # Preprocess image
     #image = clip_processor(images=input_image, return_tensors='pt').pixel_values
@@ -230,6 +238,8 @@ with gr.Blocks() as demo:
                 value="any",
             )
 
+            gr.Markdown("**Note:** Caption tone doesn't affect `rng-tags` and `training_prompt`.")
+
             run_button = gr.Button("Caption")
 
         with gr.Column():
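For context on the "forgot to format the prompt string" part of this commit, here is a minimal sketch of why the .format() call matters. The CAPTION_TYPE_MAP entries below are hypothetical stand-ins, not the Space's actual map; only the (caption_type, caption_tone, length-is-str, length-is-int) key scheme and the .format(length=..., word_count=...) call come from the diff above. Without the fix, the template would be sent to the model with its placeholders left unfilled.

# Minimal sketch; these CAPTION_TYPE_MAP entries are hypothetical stand-ins.
# Only the (type, tone, length-is-str, length-is-int) key scheme comes from the diff above.
CAPTION_TYPE_MAP = {
    ("descriptive", "formal", False, True): ["Write a formal description of the image in about {word_count} words."],
    ("rng-tags", "formal", False, False): ["Write a list of booru-like tags for the image."],
}

caption_type, caption_tone, length = "descriptive", "formal", 40

prompt_key = (caption_type, caption_tone, isinstance(length, str), isinstance(length, int))

# Before this commit the template was used as-is, so "{word_count}" stayed in the prompt:
unformatted = CAPTION_TYPE_MAP[prompt_key][0]
# With the fix, the placeholders are filled in before the prompt is used:
prompt_str = CAPTION_TYPE_MAP[prompt_key][0].format(length=length, word_count=length)
print(unformatted)  # "... in about {word_count} words."
print(prompt_str)   # "... in about 40 words."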