"""Gradio demo for HK-TransitFlow-Net, a KMB/CTB bus travel-time model.

At import time the script downloads the hkbus route/stop database and the
Keras model from the Hugging Face Hub, then builds a Gradio UI. The UI can
auto-fill distance and stop count for a chosen stop pair (computed from the
route's waypoint GeoJSON) before asking the model for an ETA.
"""

import math
from functools import lru_cache

import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from huggingface_hub import hf_hub_download

# --- Configuration ---
ROUTE_LIST_URL = "https://hkbus.github.io/hk-bus-crawling/routeFareList.min.json"
WAYPOINTS_URL = "https://hkbus.github.io/route-waypoints/{gtfs_id}-{bound}.json"
# Without a timeout, requests.get can block forever on a stalled connection,
# hanging app start-up or a UI callback.
ROUTE_LIST_TIMEOUT_S = 60
WAYPOINTS_TIMEOUT_S = 15
VALID_COMPANIES = ("kmb", "ctb")
# A manual distance within this many metres of the map distance (with an
# identical stop count) is still reported as "Computed Data".
DIST_TOLERANCE_M = 50

# --- Global Data Storage ---
ROUTE_DATA = {}      # route key -> route info dict, KMB/CTB routes only
STOP_DATA = {}       # stop id -> stop info dict (from 'stopList')
ALL_ROUTE_KEYS = []  # sorted keys of ROUTE_DATA, used by the search box

# --- 1. Fetch Data (Routes & Stops) ---
print("Fetching route and stop data...")
try:
    resp = requests.get(ROUTE_LIST_URL, timeout=ROUTE_LIST_TIMEOUT_S)
    if resp.status_code == 200:
        json_db = resp.json()
        raw_routes = json_db['routeList']
        STOP_DATA = json_db['stopList']
        for key, info in raw_routes.items():
            # Only the first listed operator decides whether a route is kept.
            companies = info.get('co') or []
            if companies and companies[0] in VALID_COMPANIES:
                ROUTE_DATA[key] = info
                ALL_ROUTE_KEYS.append(key)
        ALL_ROUTE_KEYS.sort()
    else:
        print("Failed to download route data")
except Exception as e:
    # Start-up boundary: the app still launches (manual input only) if the
    # route database is unavailable.
    print(f"Error fetching data: {e}")

print(f"Loaded {len(ALL_ROUTE_KEYS)} valid KMB/CTB routes.")

# --- 2. Download and Load Model ---
print("Downloading model...")
try:
    model_path = hf_hub_download(repo_id="WheelsTransit/HK-TransitFlow-Net", filename="hk_transit_flow_net.keras")
    print("Loading Keras model...")
    model = tf.keras.models.load_model(model_path)
except Exception as e:
    print(f"Model load failed: {e}")
    model = None  # predict_fn reports this instead of calling the model

# --- Helpers ---
# Day name -> integer fed to the model's 'day_of_week' input.
# NOTE(review): presumably matches the encoding used in training — confirm.
DAY_MAP = {
    "Sunday": 0, "Monday": 1, "Tuesday": 2, "Wednesday": 3,
    "Thursday": 4, "Friday": 5, "Saturday": 6,
}


def haversine_distance(coords):
    """Return the great-circle length, in metres, of a polyline.

    Args:
        coords: Sequence of GeoJSON positions ``[lon, lat, ...]``; any extra
            elements (e.g. altitude) are ignored.

    Returns:
        Sum of haversine distances between consecutive points; 0.0 when
        there are fewer than two points.
    """
    R = 6371000  # mean Earth radius in metres
    total_dist = 0.0
    for (lon1, lat1, *_), (lon2, lat2, *_) in zip(coords, coords[1:]):
        dlon = math.radians(lon2 - lon1)
        dlat = math.radians(lat2 - lat1)
        a = (math.sin(dlat / 2) ** 2
             + math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) * math.sin(dlon / 2) ** 2)
        # Float rounding can push `a` fractionally above 1 for near-antipodal
        # points, which would make asin raise a math domain error.
        c = 2 * math.asin(math.sqrt(min(1.0, a)))
        total_dist += R * c
    return total_dist


def _stop_index(label):
    """Return the 0-based stop index encoded in a '<n>. <name> (<id>)' label."""
    return int(label.split(".", 1)[0]) - 1


class _MapDownloadError(Exception):
    """Raised when a route's waypoint GeoJSON cannot be downloaded."""


@lru_cache(maxsize=256)
def _fetch_route_segments(gtfs_id, bound):
    """Download a route's waypoint GeoJSON and return its line segments.

    Selecting stops and then pressing Predict both need the same file, so
    results are cached. Failures raise instead of returning, so a transient
    error is never cached.

    Returns:
        Tuple of coordinate lists; segment ``i`` is assumed to run from stop
        ``i`` to stop ``i + 1``.

    Raises:
        _MapDownloadError: the server returned a non-200 status.
        requests.RequestException: network failure or timeout.
    """
    url = WAYPOINTS_URL.format(gtfs_id=gtfs_id, bound=bound)
    resp = requests.get(url, timeout=WAYPOINTS_TIMEOUT_S)
    if resp.status_code != 200:
        raise _MapDownloadError(url)
    features = resp.json().get('features', [])
    if not features:
        return ()
    if features[0]['geometry']['type'] == 'MultiLineString':
        return tuple(features[0]['geometry']['coordinates'])
    return tuple(f['geometry']['coordinates'] for f in features)


def _empty_dropdown():
    """Return a Dropdown update with no choices and no selection."""
    return gr.Dropdown(choices=[], value=None)


# --- Dynamic UI Logic ---
def filter_routes(search_text):
    """Return a route Dropdown update filtered by a case-insensitive substring.

    An empty search shows the first 20 routes; otherwise up to 100 matches.
    "UNKNOWN" is always offered first. The selection is reset to "UNKNOWN"
    in both cases: keeping an old selection that is no longer among the
    choices makes Gradio reject the Dropdown's value on the next event.
    """
    query = (search_text or "").strip().lower()
    if not query:
        return gr.Dropdown(choices=["UNKNOWN"] + ALL_ROUTE_KEYS[:20], value="UNKNOWN")
    filtered = [r for r in ALL_ROUTE_KEYS if query in r.lower()]
    return gr.Dropdown(choices=["UNKNOWN"] + filtered[:100], value="UNKNOWN")


def update_stop_dropdowns(route_key):
    """Return (start, end) Dropdown updates listing the route's stops in order.

    Labels have the form '<n>. <English name> (<stop id>)'; ``_stop_index``
    parses the leading 1-based ``n`` back out. Unknown routes clear both
    dropdowns.
    """
    if not route_key or route_key == "UNKNOWN" or route_key not in ROUTE_DATA:
        return _empty_dropdown(), _empty_dropdown()

    route_info = ROUTE_DATA[route_key]
    company = route_info['co'][0]
    stop_ids = route_info['stops'].get(company, [])

    stop_options = []
    for idx, sid in enumerate(stop_ids):
        # Tolerate stops missing from stopList or lacking an English name.
        name_en = STOP_DATA.get(sid, {}).get('name', {}).get('en') or "Unknown"
        stop_options.append(f"{idx+1}. {name_en} ({sid})")

    return gr.Dropdown(choices=stop_options, value=None), gr.Dropdown(choices=stop_options, value=None)


def calculate_real_metrics(route_key, start_str, end_str):
    """Compute map distance and stop count between two stops of a route.

    Args:
        route_key: Key into ROUTE_DATA, or "UNKNOWN".
        start_str, end_str: Stop labels produced by update_stop_dropdowns.

    Returns:
        ``(distance_m, num_stops, None)`` on success, or
        ``(None, None, message)`` when metrics cannot be computed. This
        function never raises.
    """
    if route_key == "UNKNOWN" or not start_str or not end_str:
        return None, None, "Wait"

    try:
        start_idx = _stop_index(start_str)
        end_idx = _stop_index(end_str)
        if start_idx >= end_idx:
            return None, None, "Start must be before End"

        route_info = ROUTE_DATA[route_key]
        gtfs_id = route_info.get('gtfsId')
        company = route_info['co'][0]
        bound = route_info['bound'].get(company)
        if not gtfs_id or not bound:
            return None, None, "No Map Data"

        try:
            segments = _fetch_route_segments(str(gtfs_id), str(bound))
        except _MapDownloadError:
            return None, None, "Map Download Fail"

        # Segments are assumed to map 1:1 onto stop gaps. If the map does not
        # cover the requested range, summing only what exists would silently
        # under-report the distance (even 0 m for an empty map), so report it.
        if end_idx > len(segments):
            return None, None, "Incomplete Map Data"

        total_dist = sum(haversine_distance(segments[i]) for i in range(start_idx, end_idx))
        return total_dist, end_idx - start_idx, None
    except Exception as e:
        # UI boundary: any lookup/parse/network failure becomes a status text.
        return None, None, str(e)


def auto_fill_metrics(route_key, start_str, end_str, current_dist, current_stops):
    """Updates boxes when stops change.

    Returns the computed (distance rounded to 0.1 m, stop count), or the
    current box values unchanged when metrics are unavailable.
    """
    dist, stops, _error = calculate_real_metrics(route_key, start_str, end_str)
    if dist is not None and stops is not None:
        return round(dist, 1), int(stops)
    return current_dist, current_stops


# --- Prediction Logic ---
def predict_fn(manual_dist, manual_stops, hour, day_name, route_id, start_str, end_str):
    """Run the model on the form values and format the ETA.

    The result is tagged "Computed Data" when the typed distance/stops match
    the map metrics for the selected stops (within DIST_TOLERANCE_M metres,
    same stop count), otherwise "Manual Inputs". The model always receives
    the typed values.
    """
    if model is None:
        return "Model Error: model failed to load at start-up"
    if manual_dist is None or manual_stops is None:
        return "Model Error: Distance and Stops Count are required"
    route_id = route_id or "UNKNOWN"

    try:
        dist = float(manual_dist)
        stops = float(manual_stops)

        # Check whether the inputs match the map data (validation only).
        map_dist, map_stops, _ = calculate_real_metrics(route_id, start_str, end_str)
        matches_map = (
            map_dist is not None
            and abs(map_dist - dist) <= DIST_TOLERANCE_M
            and map_stops == stops
        )
        status_tag = "Computed Data" if matches_map else "Manual Inputs"

        inputs = {
            'distance': np.array([[dist]]),
            'num_stops': np.array([[stops]]),
            'hour': np.array([[int(hour)]]),
            'day_of_week': np.array([[int(DAY_MAP[day_name])]]),
            'route_id': tf.constant([[str(route_id)]], dtype=tf.string),
        }
        prediction = model.predict(inputs, verbose=0)
        # Nothing visible here constrains the model output to be
        # non-negative, and a negative ETA is meaningless.
        seconds = max(0.0, float(prediction[0][0]))
        minutes, rem_seconds = divmod(int(seconds), 60)
        return f"{status_tag}\n\n⏱️ ETA: {minutes} min {rem_seconds} sec"
    except Exception as e:
        return f"Model Error: {str(e)}"


# --- 3. Build the UI ---
with gr.Blocks(title="HK-TransitFlow-Net") as demo:
    gr.Markdown("# HK-TransitFlow-Net Demo")
    gr.Markdown("Predicts KMB/CTB bus travel time.")
    gr.Markdown("Model: https://huggingface.co/WheelsTransit/HK-TransitFlow-Net")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### 1. Route Selection (Optional)")
            gr.Markdown("Select a route to auto-fill distance, or skip to type manually.")
            route_search = gr.Textbox(label="Search Route", placeholder="Type e.g. '968'")
            route_dropdown = gr.Dropdown(label="Select Route ID", choices=["UNKNOWN"], value="UNKNOWN", interactive=True)
            with gr.Row():
                start_dropdown = gr.Dropdown(label="Start Stop", choices=[], interactive=True)
                end_dropdown = gr.Dropdown(label="End Stop", choices=[], interactive=True)

            gr.Markdown("---")
            gr.Markdown("### 2. Time & Details")
            with gr.Row():
                hour_input = gr.Slider(minimum=0, maximum=23, step=1, label="Hour (0-23)", value=9)
                day_input = gr.Dropdown(choices=list(DAY_MAP.keys()), label="Day", value="Monday")
            with gr.Row():
                dist_input = gr.Number(label="Distance (m)", value=5000)
                stops_input = gr.Number(label="Stops Count", value=10)

            predict_btn = gr.Button("Predict ETA", variant="primary")

        with gr.Column():
            gr.Markdown("### Result")
            output_text = gr.Textbox(label="Prediction", lines=3)
            gr.Markdown("*Tip: If you modify the Distance/Stops boxes manually, the model will use your typed values.*")

    # --- Event Wiring ---
    route_search.change(fn=filter_routes, inputs=route_search, outputs=route_dropdown)

    route_dropdown.change(
        fn=update_stop_dropdowns,
        inputs=route_dropdown,
        outputs=[start_dropdown, end_dropdown],
    )

    # Auto-fill triggers: either stop changing recomputes the metrics.
    autofill_inputs = [route_dropdown, start_dropdown, end_dropdown, dist_input, stops_input]
    for stop_dropdown in (start_dropdown, end_dropdown):
        stop_dropdown.change(
            fn=auto_fill_metrics,
            inputs=autofill_inputs,
            outputs=[dist_input, stops_input],
        )

    # Predict trigger (the dropdowns are passed only for the validation check).
    predict_btn.click(
        fn=predict_fn,
        inputs=[dist_input, stops_input, hour_input, day_input, route_dropdown, start_dropdown, end_dropdown],
        outputs=output_text,
    )

if __name__ == "__main__":
    demo.launch()