import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
from mtranslate import translate
from PIL import Image
from transformers import MBart50TokenizerFast

from model.flax_clip_vision_mbart.modeling_clip_vision_mbart import (
    FlaxCLIPVisionMBartForConditionalGeneration,
)
from session import _get_state
from utils import get_transformed_image

state = _get_state()
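
# Cache the model so the Flax checkpoint is read from disk only once per session.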
@st.cache(allow_output_mutation=True)
def load_model(ckpt):
    return FlaxCLIPVisionMBartForConditionalGeneration.from_pretrained(ckpt)

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50")

language_mapping = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX",
    "es": "es_XX",
}

code_to_name = {
    "en": "English",
    "fr": "French",
    "de": "German",
    "es": "Spanish",
}
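
# Beam-search a caption for the encoded image. Forcing the BOS token to the
# mBART-50 language code (e.g. "en_XX") makes the decoder generate in that language.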
def generate_sequence(pixel_values, lang_code, num_beams):
    lang_code = language_mapping[lang_code]
    output_ids = model.generate(
        input_ids=pixel_values,
        forced_bos_token_id=tokenizer.lang_code_to_id[lang_code],
        max_length=64,
        num_beams=num_beams,
    )
    # Decode the generated token ids back into text.
    output_sequence = tokenizer.batch_decode(output_ids[0], skip_special_tokens=True)
    return output_sequence

def read_markdown(path, parent="./sections/"):
    with open(os.path.join(parent, path)) as f:
        return f.read()


checkpoints = ["./ckpt/ckpt-22499"]  # TODO: Maybe add more checkpoints?
dummy_data = pd.read_csv("reference.tsv", sep="\t")

st.set_page_config(
    page_title="Multilingual Image Captioning",
    layout="wide",
    initial_sidebar_state="collapsed",
)

st.title("Multilingual Image Captioning")
st.write(
    "[Bhavitvya Malik](https://huggingface.co/bhavitvyamalik), [Gunjan Chhablani](https://huggingface.co/gchhablani)"
)

st.sidebar.title("Settings")
num_beams = st.sidebar.number_input(
    label="Number of Beams",
    min_value=2,
    max_value=10,
    value=4,
    step=1,
    help="Number of beams to be used in beam search.",
)

with st.beta_expander("Usage"):
    st.markdown(read_markdown("usage.md"))
first_index = 20

# Init Session State
if state.image_file is None:
    state.image_file = dummy_data.loc[first_index, "image_file"]
    state.caption = dummy_data.loc[first_index, "caption"].strip("- ")
    state.lang_id = dummy_data.loc[first_index, "lang_id"]

    image_path = os.path.join("images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

col1, col2 = st.beta_columns([6, 4])

if col2.button("Get a random example"):
    sample = dummy_data.sample(1).reset_index()
    state.image_file = sample.loc[0, "image_file"]
    state.caption = sample.loc[0, "caption"].strip("- ")
    state.lang_id = sample.loc[0, "lang_id"]

    image_path = os.path.join("images", state.image_file)
    image = plt.imread(image_path)
    state.image = image

col2.write("OR")

uploaded_file = col2.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    state.image_file = os.path.join("images", uploaded_file.name)
    state.image = np.array(Image.open(uploaded_file))
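
# get_transformed_image (from utils) is assumed to resize and normalize the image
# into the CLIP pixel-value format the vision encoder expects.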
transformed_image = get_transformed_image(state.image)

# Display Image
col1.image(state.image, use_column_width="auto")

# Display Reference Caption
col2.write("**Reference Caption**: " + state.caption)
col2.markdown(
    f"""**English Translation**: {state.caption if state.lang_id == "en" else translate(state.caption, 'en')}"""
)

# Select Language
options = list(code_to_name.keys())
lang_id = col2.selectbox(
    "Language",
    index=options.index(state.lang_id),
    options=options,
    format_func=lambda x: code_to_name[x],
)

# Load the captioning model (cached across reruns by load_model).
with st.spinner("Loading model..."):
    model = load_model(checkpoints[0])
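
# Placeholder; the caption display below is skipped until a sequence is generated.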
sequence = [""]
if col2.button("Generate Caption"):
    with st.spinner("Generating Sequence..."):
        sequence = generate_sequence(transformed_image, lang_id, num_beams)

if sequence != [""]:
    st.write("**Generated Caption**: " + sequence[0])
    st.write(
        "**English Translation**: "
        + (sequence[0] if lang_id == "en" else translate(sequence[0], "en"))
    )
| st.write(read_markdown("abstract.md")) | |
| st.write(read_markdown("caveats.md")) | |
| # st.write("# Methodology") | |
| # st.image( | |
| # "./misc/Multilingual-IC.png", caption="Seq2Seq model for Image-text Captioning." | |
| # ) | |
| st.markdown(read_markdown("pretraining.md")) | |
| st.write(read_markdown("challenges.md")) | |
| st.write(read_markdown("social_impact.md")) | |
| st.write(read_markdown("references.md")) | |
| # st.write(read_markdown("checkpoints.md")) | |
| st.write(read_markdown("acknowledgements.md")) | |