Spaces:
Runtime error
Runtime error
| from .utils import get_transformed_image | |
| import streamlit as st | |
| import numpy as np | |
| import pandas as pd | |
| import os | |
| import matplotlib.pyplot as plt | |
| import re | |
| from mtranslate import translate | |
| from .utils import ( | |
| read_markdown, | |
| tokenizer, | |
| language_mapping, | |
| code_to_name, | |
| voicerss_tts | |
| ) | |
| import requests | |
| from PIL import Image | |
| from .model.flax_clip_vision_mbart.modeling_clip_vision_mbart import ( | |
| FlaxCLIPVisionMBartForConditionalGeneration, | |
| ) | |
| from streamlit import caching | |
def app(state):
    """Render the multilingual image-captioning page.

    Builds the Streamlit UI: usage/intro text, a sidebar with generation
    parameters, image selection (a sample row from ``reference.tsv`` or an
    image fetched from a URL), the reference caption with an English
    translation, and caption generation in the selected language with an
    optional text-to-speech rendering of the result.

    Args:
        state: Session-state object shared across reruns. This function reads
            and writes its ``image_file``, ``caption``, ``lang_id``, ``image``
            and ``model`` attributes; ``image_file`` and ``model`` are expected
            to be ``None`` on the first run, which triggers initialisation.
    """
    mic_state = state
    with st.beta_expander("Usage"):
        st.write(read_markdown("usage.md"))
    st.write("\n")
    st.write(read_markdown("intro.md"))

    max_length = 64

    with st.sidebar.beta_expander('Generation Parameters'):
        do_sample = st.checkbox(
            "Sample",
            value=False,
            help="Sample from the model instead of using beam search.",
        )
        top_k = st.number_input(
            "Top K",
            min_value=10,
            max_value=200,
            value=50,
            step=1,
            help="The number of highest probability vocabulary tokens to keep for top-k-filtering.",
        )
        num_beams = st.number_input(
            label="Number of Beams",
            min_value=2,
            max_value=10,
            value=4,
            step=1,
            help="Number of beams to be used in beam search.",
        )
        # Temperature divides the logits, so 0.0 is not a usable value when
        # sampling; drop it but keep the remaining option values (incl. the
        # 1.0 default) exactly as np.arange produces them.
        temperature = st.select_slider(
            label="Temperature",
            options=[t for t in np.arange(0.0, 1.1, step=0.1) if t > 0],
            value=1.0,
            help="The value used to module the next token probabilities.",
            format_func=lambda x: f"{x:.2f}",
        )
        top_p = st.select_slider(
            label="Top-P",
            options=list(np.arange(0.0, 1.1, step=0.1)),
            value=1.0,
            help="Nucleus Sampling : If set to float < 1, only the most probable tokens with probabilities that add up to :obj:`top_p` or higher are kept for generation.",
            format_func=lambda x: f"{x:.2f}",
        )
        if st.button("Clear All Cache"):
            caching.clear_cache()

    def load_model(ckpt):
        """Load the CLIP-Vision + mBART captioning model from checkpoint ``ckpt``."""
        return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)

    def generate_sequence(pixel_values, lang_code, num_beams, temperature, top_p, do_sample, top_k, max_length):
        """Generate a caption for ``pixel_values`` in the language ``lang_code``.

        ``lang_code`` is a key of ``language_mapping``; the mapped value selects
        the tokenizer language token passed to ``generate`` as
        ``forced_bos_token_id``. Returns the list of strings produced by
        ``tokenizer.batch_decode``.
        """
        lang_code = language_mapping[lang_code]
        output_ids = mic_state.model.generate(
            input_ids=pixel_values,
            forced_bos_token_id=tokenizer.lang_code_to_id[lang_code],
            max_length=max_length,
            num_beams=num_beams,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            do_sample=do_sample,
        )
        # NOTE(review): output_ids[0] is presumably the generated token
        # sequences (first field of the generate() output) — confirm.
        output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True, max_length=max_length)
        return output_sequence

    mic_checkpoints = ["flax-community/clip-vit-base-patch32_mbart-large-50"]  # TODO: Maybe add more checkpoints?
    dummy_data = pd.read_csv("reference.tsv", sep="\t")
    first_index = 25

    # Init Session State: on the first run, show a fixed example row.
    if mic_state.image_file is None:
        mic_state.image_file = dummy_data.loc[first_index, "image_file"]
        mic_state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
        mic_state.lang_id = dummy_data.loc[first_index, "lang_id"]

        image_path = os.path.join("images", mic_state.image_file)
        image = plt.imread(image_path)
        mic_state.image = image

    # Load the model once and keep it in session state across reruns.
    if mic_state.model is None:
        with st.spinner("Loading model..."):
            mic_state.model = load_model(mic_checkpoints[0])

    query1 = st.text_input(
        "Enter a URL to an image",
        value="http://images.cocodataset.org/val2017/000000397133.jpg",
    )

    col1, col2, col3 = st.beta_columns([2, 1, 2])

    if col1.button(
        "Get a random example",
        help="Get a random example from the 100 `seeded` image-text pairs.",
    ):
        sample = dummy_data.sample(1).reset_index()
        mic_state.image_file = sample.loc[0, "image_file"]
        mic_state.caption = sample.loc[0, "caption"].strip("- ")
        mic_state.lang_id = sample.loc[0, "lang_id"]

        image_path = os.path.join("images", mic_state.image_file)
        image = plt.imread(image_path)
        mic_state.image = image

    col2.write("OR")

    if col3.button("Use above URL"):
        # NOTE(review): caption/lang_id are not updated here, so the
        # "Reference Caption" below still belongs to the previous example.
        try:
            response = requests.get(query1, stream=True, timeout=30)
            response.raise_for_status()
            # Let urllib3 undo any Content-Encoding before PIL reads the stream.
            response.raw.decode_content = True
            # Force three channels so grayscale/RGBA/palette images yield the
            # same HxWx3 layout (presumably what get_transformed_image expects).
            image = np.asarray(Image.open(response.raw).convert("RGB"))
        except (requests.RequestException, OSError) as err:
            # RequestException covers network/HTTP errors; OSError covers
            # PIL's UnidentifiedImageError for non-image responses.
            st.error(f"Could not load an image from the given URL: {err}")
        else:
            mic_state.image = image

    transformed_image = get_transformed_image(mic_state.image)

    new_col1, new_col2 = st.beta_columns([5, 5])

    # Display Image
    new_col1.image(mic_state.image, use_column_width="always")

    # Display Reference Caption
    with new_col1.beta_expander("Reference Caption"):
        st.write("**Reference Caption**: " + mic_state.caption)
        st.markdown(
            f"""**English Translation**: {mic_state.caption if mic_state.lang_id == "en" else translate(mic_state.caption, 'en')}"""
        )

    # Select Language
    options = list(code_to_name.keys())
    lang_id = new_col2.selectbox(
        "Language",
        index=options.index(mic_state.lang_id),
        options=options,
        format_func=lambda x: code_to_name[x],
        help="The language in which caption is to be generated."
    )

    sequence = ['']
    if new_col2.button("Generate Caption", help="Generate a caption in the specified language."):
        with st.spinner("Generating Sequence... This might take some time, you can read our Article meanwhile!"):
            sequence = generate_sequence(transformed_image, lang_id, num_beams, temperature, top_p, do_sample, top_k, max_length)

    if sequence != ['']:
        caption = sequence[0]
        new_col2.write(
            "**Generated Caption**: " + caption
        )
        # Translate explicitly into English, matching the reference-caption
        # translation above.
        new_col2.write(
            "**English Translation**: " + (caption if lang_id == "en" else translate(caption, 'en'))
        )
        with new_col2:
            # Strip punctuation/symbols for TTS but keep Unicode letters, since
            # the caption is in the selected (possibly non-English) language.
            clean_text = re.sub(r"[^\w ]+", "", caption)
            try:
                audio_bytes = voicerss_tts(clean_text, lang_id)
            except Exception:
                # Best-effort: TTS is an external service; degrade to a notice.
                st.info("Unable to generate audio. Please try again in some time.")
            else:
                st.markdown("**Audio for the generated caption**")
                st.audio(audio_bytes)