import streamlit as st
from annotated_text import annotated_text
from refined.inference.processor import Refined
import requests
import json
import spacy
import spacy.cli
import warnings
import logging
import ast
from transformers import AutoTokenizer
import os

# Suppress torch warnings
warnings.filterwarnings("ignore", message=".*torch.classes.*")
warnings.filterwarnings("ignore", message=".*__path__._path.*")

# Set the logging level to reduce noise
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
# Page config
st.set_page_config(
    page_title="Entity Linking by WordLift",
    page_icon="fav-ico.png",
    layout="wide",
    initial_sidebar_state="collapsed",
    menu_items={
        'Get Help': 'https://wordlift.io/book-a-demo/',
        'About': "# This is a demo app for NEL/NED/NER and SEO"
    }
)
# Sidebar
st.sidebar.image("logo-wordlift.png")
# Use a list (not a set) so the option order is deterministic and index=0 is "English"
language_options = ["English", "English - spaCy", "German"]
selected_language = st.sidebar.selectbox("Select the Language", language_options, index=0)
# Based on the selected language, configure the model, entity set, and citation options
if selected_language in ("German", "English - spaCy"):
    selected_model_name = None
    selected_entity_set = None
    entity_fishing_citation = """
    @misc{entity-fishing,
        title = {entity-fishing},
        publisher = {GitHub},
        year = {2016--2023},
        archivePrefix = {swh},
        eprint = {1:dir:cb0ba3379413db12b0018b7c3af8d0d2d864139c}
    }
    """
    with st.sidebar.expander('Citations'):
        st.markdown(entity_fishing_citation)
else:
    model_options = ["aida_model", "wikipedia_model_with_numbers"]
    entity_set_options = ["wikidata", "wikipedia"]
    selected_model_name = st.sidebar.selectbox("Select the Model", model_options)
    selected_entity_set = st.sidebar.selectbox("Select the Entity Set", entity_set_options)
    refined_citation = """
    @inproceedings{ayoola-etal-2022-refined,
        title = "{R}e{F}in{ED}: An Efficient Zero-shot-capable Approach to End-to-End Entity Linking",
        author = "Tom Ayoola, Shubhi Tyagi, Joseph Fisher, Christos Christodoulopoulos, Andrea Pierleoni",
        booktitle = "NAACL",
        year = "2022"
    }
    """
    with st.sidebar.expander('Citations'):
        st.markdown(refined_citation)
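# Two pipelines are used below: German and "English - spaCy" run a spaCy model with
# the entity-fishing component (which links mentions to Wikidata), while plain
# "English" uses ReFinED for end-to-end entity linking.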
# Cache the loaded model so it is built once and reused across reruns
@st.cache_resource
def load_model(selected_language, model_name=None, entity_set=None):
    # Suppress warnings during model loading
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        try:
            # This block handles the spaCy models for German and English
            if selected_language == "German":
                try:
                    nlp_model_de = spacy.load("de_core_news_lg")
                except OSError:
                    st.info("Downloading German language model... This may take a moment.")
                    spacy.cli.download("de_core_news_lg")
                    nlp_model_de = spacy.load("de_core_news_lg")
                if "entityfishing" not in nlp_model_de.pipe_names:
                    try:
                        nlp_model_de.add_pipe("entityfishing")
                    except Exception as e:
                        st.warning(f"Entity-fishing not available, using basic NER only: {e}")
                return nlp_model_de
            elif selected_language == "English - spaCy":
                try:
                    nlp_model_en = spacy.load("en_core_web_sm")
                except OSError:
                    st.info("Downloading English language model... This may take a moment.")
                    spacy.cli.download("en_core_web_sm")
                    nlp_model_en = spacy.load("en_core_web_sm")
                if "entityfishing" not in nlp_model_en.pipe_names:
                    try:
                        nlp_model_en.add_pipe("entityfishing")
                    except Exception as e:
                        st.warning(f"Entity-fishing not available, using basic NER only: {e}")
                return nlp_model_en
            # This block handles the ReFinED model and the "add_special_tokens" error
            else:
                try:
                    # First, attempt to load the model as usual
                    return Refined.from_pretrained(model_name=model_name, entity_set=entity_set)
                except Exception as e:
                    # If the specific "add_special_tokens" error occurs, apply the fix
                    if "add_special_tokens" in str(e):
                        st.warning("Conflict detected. Applying fix by modifying the tokenizer config...")
                        # Define a local path to save/load the fixed model
                        local_model_path = f"./{model_name}-{entity_set}-fixed"
                        # Download the tokenizer, modify its config, and save it locally
                        tokenizer = AutoTokenizer.from_pretrained(model_name)
                        tokenizer.save_pretrained(local_model_path)
                        config_path = os.path.join(local_model_path, "tokenizer_config.json")
                        with open(config_path, "r") as f:
                            config_data = json.load(f)
                        # Remove the conflicting parameter
                        config_data.pop("add_special_tokens", None)
                        with open(config_path, "w") as f:
                            json.dump(config_data, f, indent=2)
                        # Now load the model from the local, fixed path
                        st.success("Fix applied. Loading model from local cache.")
                        return Refined.from_pretrained(model_name=local_model_path, entity_set=entity_set)
                    else:
                        # A different error: re-raise with the original traceback
                        raise
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return None
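# st.cache_resource keeps one model instance per argument combination, so reruns
# triggered by widget interactions reuse the loaded model instead of rebuilding it.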
# Use the cached model
model = load_model(selected_language, selected_model_name, selected_entity_set)

# Helper functions
def get_wikidata_id(entity_string):
    entity_list = entity_string.split("=")
    entity_id = str(entity_list[1])
    entity_link = "http://www.wikidata.org/entity/" + entity_id
    return {"id": entity_id, "link": entity_link}
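# Example (assuming the entity string carries a "wikidata=<QID>" pair, as the
# parsing further below expects):
#   get_wikidata_id("wikidata=Q64")
#   -> {"id": "Q64", "link": "http://www.wikidata.org/entity/Q64"}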
def get_entity_data(entity_link):
    try:
        # Reformat the link the way the WordLift id endpoint expects ("http://" -> "http/")
        formatted_link = entity_link.replace("http://", "http/")
        response = requests.get(f'https://api.wordlift.io/id/{formatted_link}')
        return response.json()
    except Exception as e:
        print(f"Exception when fetching data for entity: {entity_link}. Exception: {e}")
        return None
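# Example: for "http://www.wikidata.org/entity/Q64" the request above goes to
#   https://api.wordlift.io/id/http/www.wikidata.org/entity/Q64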
# Create the form
with st.form(key='my_form'):
    text_input = st.text_area(label='Enter a sentence')
    submit_button = st.form_submit_button(label='Analyze')

# Initialization
entities_map = {}
entities_data = {}

if text_input and model is not None:
    try:
        if selected_language in ["German", "English - spaCy"]:
            # Process the text with error handling
            doc = model(text_input)
            # entity-fishing exposes the Wikidata QID and URL as custom span attributes
            entities = []
            for ent in doc.ents:
                try:
                    # getattr falls back to None when the entityfishing attributes are missing
                    kb_qid = getattr(ent._, 'kb_qid', None)
                    url_wikidata = getattr(ent._, 'url_wikidata', None)
                    entities.append((ent.text, ent.label_, kb_qid, url_wikidata))
                except AttributeError:
                    # Without linking data, keep the basic entity info
                    entities.append((ent.text, ent.label_, None, None))
            for entity in entities:
                entity_string, entity_type, wikidata_id, wikidata_url = entity
                if wikidata_url:
                    # Ensure the correct URI format for the German and English models
                    formatted_wikidata_url = wikidata_url.replace("https://www.wikidata.org/wiki/", "http://www.wikidata.org/entity/")
                    entities_map[entity_string] = {"id": wikidata_id, "link": formatted_wikidata_url}
                    entity_data = get_entity_data(formatted_wikidata_url)
                    if entity_data is not None:
                        entities_data[entity_string] = entity_data
        else:
            entities = model.process_text(text_input)
            for entity in entities:
                # Parse the string representation of the ReFinED entity; it is expected
                # to contain a "wikidata=<QID>" pair
                single_entity_list = str(entity).strip('][').replace("\'", "").split(', ')
                if len(single_entity_list) >= 2 and "wikidata" in single_entity_list[1]:
                    entity_text = single_entity_list[0].strip()
                    entities_map[entity_text] = get_wikidata_id(single_entity_list[1])
                    entity_data = get_entity_data(entities_map[entity_text]["link"])
                    if entity_data is not None:
                        entities_data[entity_text] = entity_data
    except Exception as e:
        st.error(f"Error processing text: {e}")
        if "entityfishing" in str(e).lower():
            st.error("This appears to be an entity-fishing related error. Please ensure:")
            st.error("1. The entity-fishing service is running")
            st.error("2. The spacyfishing package is properly installed")
            st.error("3. There is network connectivity to the entity-fishing service")

# Combine the entity information
combined_entity_info_dictionary = {k: [entities_map[k], entities_data.get(k)] for k in entities_map}
if submit_button and entities_map:
    # Prepare a list to hold the final output
    final_text = []
    # JSON-LD data
    json_ld_data = {
        "@context": "https://schema.org",
        "@type": "WebPage",
        "mentions": []
    }
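    # Note: "mentions" is filled below with the raw entity payloads returned by the
    # WordLift API, so the exact shape of each mention depends on that response.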
    # Replace each entity in the text with its annotated version
    for entity_string, entity_info in entities_map.items():
        # Skip entities without a valid Wikidata link
        if entity_info["link"] is None or entity_info["link"] == "None":
            continue
        entity_data = entities_data.get(entity_string, None)
        entity_type = None
        if entity_data is not None:
            entity_type = entity_data.get("@type", None)
        # Use different colors based on the entity's type
        color = "#8ef"  # Default color
        if entity_type == "Place":
            color = "#8AC7DB"
        elif entity_type == "Organization":
            color = "#ADD8E6"
        elif entity_type == "Person":
            color = "#67B7D1"
        elif entity_type == "Product":
            color = "#2ea3f2"
        elif entity_type == "CreativeWork":
            color = "#00BFFF"
        elif entity_type == "Event":
            color = "#1E90FF"
        entity_annotation = (entity_string, entity_info["id"], color)
        # Wrap the first occurrence of the entity in a "{...}" placeholder for the split step below
        text_input = text_input.replace(entity_string, f'{{{str(entity_annotation)}}}', 1)
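        # Example: if entity_string is "Paris" with id "Q90" and type "Place", the text
        # now contains the placeholder {('Paris', 'Q90', '#8AC7DB')} at its first occurrence.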
        # Add the entity to the JSON-LD data
        entity_json_ld = combined_entity_info_dictionary[entity_string][1]
        if entity_json_ld and entity_json_ld.get("link") != "None":
            json_ld_data["mentions"].append(entity_json_ld)

    # Split the modified text_input on the placeholders, rebuilding it as a mix of
    # plain strings and (text, id, color) tuples for annotated_text
    text_list = text_input.split("{")
    for item in text_list:
        if "}" in item:
            item_list = item.split("}")
            try:
                # ast.literal_eval safely parses the "(text, id, color)" placeholder tuple;
                # unlike eval, it cannot execute arbitrary code hidden in the user's input
                final_text.append(ast.literal_eval(item_list[0]))
            except (ValueError, SyntaxError):
                final_text.append(item_list[0])
            if len(item_list) > 1 and len(item_list[1]) > 0:
                final_text.append(item_list[1])
        else:
            final_text.append(item)

    # Pass the final_text to the annotated_text function
    annotated_text(*final_text)

    with st.expander("See annotations"):
        st.write(combined_entity_info_dictionary)
    with st.expander("Here is the final JSON-LD"):
        st.json(json_ld_data)  # Output JSON-LD

elif submit_button and not entities_map:
    st.warning("No entities were found in the text. Please try different text or check whether the model is working correctly.")