import os

import numpy as np
import onnxruntime as ort
import requests
from transformers import AutoTokenizer
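# The script assumes these third-party packages are installed, e.g.:
#   pip install onnxruntime transformers numpy requests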
VERSION = "v0.1.1"

class LocationFinder:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
        model_url = f"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/{VERSION}/onnx/model_quantized.onnx"
        model_dir_path = "models"
        model_path = f"{model_dir_path}/distilbert-uncased-NER-LoRA"
        os.makedirs(model_dir_path, exist_ok=True)
        if not os.path.exists(model_path):
            print("Downloading ONNX model...")
            response = requests.get(model_url)
            response.raise_for_status()  # abort on HTTP errors instead of saving an error page as the model
            with open(model_path, "wb") as f:
                f.write(response.content)
            print("ONNX model downloaded.")
        # Load the ONNX model
        self.ort_session = ort.InferenceSession(model_path)
    def find_location(self, sequence, verbose=False):
        inputs = self.tokenizer(
            sequence,
            return_tensors="np",   # ONNX Runtime expects NumPy inputs
            padding="max_length",  # pad to max_length
            truncation=True,       # truncate text that is too long
            max_length=64,
        )
        input_feed = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }
        # Run inference with the ONNX model
        outputs = self.ort_session.run(None, input_feed)
        logits = outputs[0]  # per-token logits, shape (batch, seq_len, num_labels)
        # Numerically stable softmax: shift by the max logit before exponentiating
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        probabilities = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)
        predicted_ids = np.argmax(logits, axis=-1)
        predicted_probs = np.max(probabilities, axis=-1)
        # Minimum probability for a token to be kept as an entity
        threshold = 0.6
        # Label map for person, organization, city, state, and city_state entities
        label_map = {
            0: "O",             # outside any named entity
            1: "B-PER",         # beginning of a person entity
            2: "I-PER",         # inside a person entity
            3: "B-ORG",         # beginning of an organization entity
            4: "I-ORG",         # inside an organization entity
            5: "B-CITY",        # beginning of a city entity
            6: "I-CITY",        # inside a city entity
            7: "B-STATE",       # beginning of a state entity
            8: "I-STATE",       # inside a state entity
            9: "B-CITYSTATE",   # beginning of a city_state entity
            10: "I-CITYSTATE",  # inside a city_state entity
        }
        tokens = self.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        # Collect detected entities by label family
        city_entities = []
        state_entities = []
        city_state_entities = []
        for token, predicted_id, prob in zip(tokens, predicted_ids[0], predicted_probs[0]):
            if prob > threshold:
                if token in ["[CLS]", "[SEP]", "[PAD]"]:
                    continue
                if label_map[predicted_id] in ["B-CITY", "I-CITY"]:
                    # Merge WordPiece continuations ("##...") into the previous token
                    if token.startswith("##") and city_entities:
                        city_entities[-1] += token[2:]
                    else:
                        city_entities.append(token)
                elif label_map[predicted_id] in ["B-STATE", "I-STATE"]:
                    if token.startswith("##") and state_entities:
                        state_entities[-1] += token[2:]
                    else:
                        state_entities.append(token)
                elif label_map[predicted_id] in ["B-CITYSTATE", "I-CITYSTATE"]:
                    if token.startswith("##") and city_state_entities:
                        city_state_entities[-1] += token[2:]
                    else:
                        city_state_entities.append(token)
        # Prefer a combined city_state span; split on the comma into city and state
        if city_state_entities:
            city_state_str = " ".join(city_state_entities)
            city_state_split = city_state_str.split(",")
            city_res = city_state_split[0].strip() if city_state_split[0] else None
            state_res = city_state_split[1].strip() if len(city_state_split) > 1 else None
        else:
            # Otherwise fall back to separately detected city and state entities
            city_res = " ".join(city_entities).strip() if city_entities else None
            state_res = " ".join(state_entities).strip() if state_entities else None
        # Return the detected city and state as separate components
        return {
            "city": city_res,
            "state": state_res,
        }
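
# A convenience wrapper, sketched as an assumption (not part of the original
# script): reuse one LocationFinder across several queries so the tokenizer
# load and the one-time model download happen only once.
def find_locations(queries):
    finder = LocationFinder()
    return {query: finder.find_location(query) for query in queries}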

if __name__ == "__main__":
    query = "weather in san francisco, ca"
    loc_finder = LocationFinder()
    entities = loc_finder.find_location(query)
    print(f"query = {query} => {entities}")