Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| import json | |
| from sentence_transformers import SentenceTransformer, util | |
| import time | |
# Configure the page so that the sidebar starts out collapsed.
st.set_page_config(initial_sidebar_state="collapsed")

# Load the question/answer records from data/qa_data.json.
with open("data/qa_data.json", "r", encoding="utf-8") as qa_file:
    data = json.load(qa_file)

# Split the records into two parallel lists: the same index in `questions`
# and `answers` refers to the same record in `data`.
questions = []
answers = []
for entry in data:
    questions.append(entry["question"])
    answers.append(entry["answer"])
# Cache the model at app level: Streamlit re-executes this script on every
# user interaction, and without caching the model would be rebuilt each time.
@st.cache_resource
def load_model():
    """Load the sentence-embedding model used to encode user queries.

    Decorated with ``st.cache_resource`` so that one model instance is
    created on the first call and the same object is returned afterwards.

    Returns:
        SentenceTransformer: The ``pkshatech/GLuCoSE-base-ja`` model.
    """
    return SentenceTransformer("pkshatech/GLuCoSE-base-ja")
# Cache the embedding arrays so the .npy files are not re-read from disk on
# every Streamlit rerun.
@st.cache_data
def load_embeddings():
    """Load the stored question and answer embedding arrays.

    Decorated with ``st.cache_data`` so the files are read only on the first
    call; later calls receive a copy of the cached arrays.

    NOTE(review): the arrays are presumably row-aligned with the records in
    ``data/qa_data.json`` (row i <-> record i) -- confirm against the script
    that generated them.

    Returns:
        tuple: ``(question_embeddings, answer_embeddings)`` as loaded by
        ``np.load`` from ``data/question_embeddings.npy`` and
        ``data/answer_embeddings.npy``.
    """
    return (
        np.load("data/question_embeddings.npy"),
        np.load("data/answer_embeddings.npy"),
    )
# Load the model and the embedding arrays that search_answer() reads.
model = load_model()
question_embeddings, answer_embeddings = load_embeddings()

# Sidebar settings
settings_panel = st.sidebar.expander("⚙️ 設定", expanded=False)
with settings_panel:
    # Minimum similarity score for accepting a match against a stored
    # question (threshold_q) or a stored answer (threshold_a).
    threshold_q = st.slider(
        "質問の類似度しきい値",
        min_value=0.0,
        max_value=1.0,
        value=0.7,
        step=0.01,
    )
    threshold_a = st.slider(
        "回答の類似度しきい値",
        min_value=0.0,
        max_value=1.0,
        value=0.65,
        step=0.01,
    )

    # "New chat" button: clear the stored conversation and rerun the script.
    # NOTE(review): the button is assumed to live inside the expander; the
    # original indentation was lost -- confirm.
    start_new_chat = st.button("新しいチャット", use_container_width=True)
    if start_new_chat:
        st.session_state.messages = []
        st.rerun()
def _best_match(query_embedding, candidate_embeddings):
    """Return ``(index, score)`` of the candidate most similar to the query."""
    scores = util.cos_sim(query_embedding, candidate_embeddings)[0]
    best_idx = int(np.argmax(scores))
    # Convert to plain Python numbers so callers can index lists, compare
    # against thresholds and format the score without tensor semantics.
    return best_idx, float(scores[best_idx])


def search_answer(user_input):
    """Find the stored answer that best matches ``user_input``.

    The input is encoded with the module-level ``model`` and compared by
    cosine similarity, first against ``question_embeddings`` and, only if
    that score is below ``threshold_q``, against ``answer_embeddings``.

    Args:
        user_input: The text typed by the user.

    Returns:
        tuple[str, str]: ``(answer_text, match_info)``. ``answer_text`` is the
        matched entry of ``answers`` with each newline turned into a Markdown
        hard line break, or a fixed apology message when neither score
        reaches its threshold. ``match_info`` says which comparison matched
        (with the score) or that nothing matched.
    """
    # A single short input: no batching and no progress bar needed.
    user_embedding = model.encode(
        [user_input],
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=False,
        normalize_embeddings=True,  # normalise the query vector up front
    )

    # First try to match the input against the stored questions.
    best_q_idx, score_q = _best_match(user_embedding, question_embeddings)
    if score_q >= threshold_q:
        return (
            # Markdown renders a hard line break only when the newline is
            # preceded by at least two spaces; one space is not enough.
            answers[best_q_idx].replace("\n", "  \n"),
            f"質問にマッチ ({score_q:.2f})",
        )

    # Otherwise fall back to matching against the stored answers.
    best_a_idx, score_a = _best_match(user_embedding, answer_embeddings)
    if score_a >= threshold_a:
        return (
            answers[best_a_idx].replace("\n", "  \n"),
            f"回答にマッチ ({score_a:.2f})",
        )

    return "申し訳ありませんが、ご質問の答えを見つけることができませんでした。もう少し詳しく説明していただけますか?", "一致なし"
def stream_response(response, delay=0.05):
    """Yield ``response`` one character at a time for streamed display.

    Args:
        response: The full text to stream.
        delay: Seconds to pause after each yielded chunk. Defaults to 0.05,
            the pause that was previously hard-coded.

    Yields:
        str: Each character of ``response`` in order, except that a newline
        is yielded as ``"  \\n"`` (two spaces + newline), which Markdown
        renders as a hard line break.
    """
    for char in response:
        # A single trailing space is not a Markdown line break; two are.
        yield "  \n" if char == "\n" else char
        time.sleep(delay)
# Streamlit chat interface
st.title("🤖 よくある質問チャットボット")

# Create an empty conversation history if the session does not have one yet.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Re-draw every stored message, since the script reruns from the top.
for past_message in st.session_state.messages:
    with st.chat_message(past_message["role"]):
        st.markdown(past_message["content"])

user_input = st.chat_input("質問を入力してください:")
if user_input:
    # Record and display the user's message.
    user_message = {"role": "user", "content": user_input}
    st.session_state.messages.append(user_message)
    with st.chat_message("user"):
        st.markdown(user_input)

    # Look up the answer while a spinner is shown.
    with st.spinner("考え中... お待ちください。"):
        answer, info = search_answer(user_input)
        print(info)

    # Stream the answer into the assistant bubble, then store it in history.
    with st.chat_message("assistant"):
        response_placeholder = st.empty()
        response_placeholder.write_stream(stream_response(answer))
    assistant_message = {"role": "assistant", "content": answer}
    st.session_state.messages.append(assistant_message)