"""Streamlit app for the Language Model Council Sandbox.

Collects responses to a user prompt from a council of LLMs, aggregates them
with an aggregator model (MoA), and then has the council judge each response,
parsing the judgments with structured output prompting.
"""

import os

import streamlit as st
import dotenv
import openai
from openai import OpenAI
import anthropic
from together import Together
import google.generativeai as genai
import time
from collections import defaultdict
from typing import List, Optional, Literal, Union, Dict

from constants import (
    LLM_COUNCIL_MEMBERS,
    PROVIDER_TO_AVATAR_MAP,
    AGGREGATORS,
    LLM_TO_UI_NAME_MAP,
)
from prompts import *
from judging_dataclasses import (
    DirectAssessmentCriterionScore,
    DirectAssessmentCriteriaScores,
)

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

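# The app expects credentials via environment variables (or a local `.env`
# file picked up by `dotenv.load_dotenv()` below). A minimal sketch, using only
# the variable names read in this module; the values are placeholders:
#
#   APP_PASSWORD=...
#   OPENAI_API_KEY=...
#   ANTHROPIC_API_KEY=...
#   GOOGLE_API_KEY=...
#   TOGETHER_API_KEY=...
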
dotenv.load_dotenv()

PASSWORD = os.getenv("APP_PASSWORD")

# Load API keys from environment variables.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")

# Initialize API clients.
together_client = Together(api_key=TOGETHER_API_KEY)
genai.configure(api_key=GOOGLE_API_KEY)

# Set up API clients for OpenAI and Anthropic.
openai.api_key = OPENAI_API_KEY
openai_client = OpenAI(
    organization="org-kUoRSK0nOw4W2nQYMVGWOt03",
    project="proj_zb6k1DdgnSEbiAEMWxSOVVu4",
)

anthropic_client = anthropic.Anthropic()

# Default OpenAI client, used for structured-output parsing of judgments.
client = OpenAI()

def anthropic_streamlit_streamer(stream, llm):
    """
    Process the Anthropic streaming response and yield text from the deltas.

    :param stream: Streaming object from the Anthropic API.
    :param llm: Model identifier, used to attribute token usage in session state.
    :return: Yields text content from the streaming response.
    """
    for event in stream:
        if hasattr(event, "type"):
            # Count token usage reported at the start of the message.
            if event.type == "message_start":
                st.session_state["input_token_usage"][
                    llm
                ] += event.message.usage.input_tokens
                st.session_state["output_token_usage"][
                    llm
                ] += event.message.usage.output_tokens
            # Count output token usage reported in deltas.
            if event.type == "message_delta":
                st.session_state["output_token_usage"][llm] += event.usage.output_tokens
            # Handle content blocks.
            if event.type == "content_block_delta" and hasattr(event, "delta"):
                # Extract the text delta from the event.
                text_delta = getattr(event.delta, "text", None)
                if text_delta:
                    yield text_delta
            # Handle message completion events.
            elif event.type == "message_stop":
                break  # End of message, stop streaming.

def get_ui_friendly_name(llm):
    if "agg__" in llm:
        return (
            "MoA ("
            + LLM_TO_UI_NAME_MAP.get(llm.split("__")[1], llm.split("__")[1])
            + ")"
        )
    return LLM_TO_UI_NAME_MAP.get(llm, llm)


def google_streamlit_streamer(stream):
    # TODO: Count token usage.
    for chunk in stream:
        yield chunk.text

def openai_streamlit_streamer(stream, llm):
    # https://platform.openai.com/docs/api-reference/streaming
    for event in stream:
        if event.usage:
            st.session_state["input_token_usage"][llm] += event.usage.prompt_tokens
            st.session_state["output_token_usage"][llm] += event.usage.completion_tokens
        if event.choices:
            if event.choices[0].delta.content:
                yield event.choices[0].delta.content

def together_streamlit_streamer(stream, llm):
    # https://docs.together.ai/docs/chat-overview#streaming-responses
    for chunk in stream:
        if chunk.usage:
            st.session_state["input_token_usage"][llm] += chunk.usage.prompt_tokens
            st.session_state["output_token_usage"][llm] += chunk.usage.completion_tokens
        if chunk.choices and chunk.choices[0].delta.content:
            yield chunk.choices[0].delta.content

# Helper functions for LLM council and aggregator selection.
def llm_council_selector():
    selected_council = st.radio(
        "Choose a council configuration", options=list(LLM_COUNCIL_MEMBERS.keys())
    )
    return LLM_COUNCIL_MEMBERS[selected_council]


def aggregator_selector():
    return st.radio("Choose an aggregator LLM", options=AGGREGATORS)


# API calls for different providers.
def get_openai_response(model_name, prompt):
    return openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
        stream_options={"include_usage": True},
    )


# https://docs.anthropic.com/en/api/messages-streaming
def get_anthropic_response(model_name, prompt):
    return anthropic_client.messages.create(
        max_tokens=1024,
        messages=[{"role": "user", "content": prompt}],
        model=model_name,
        stream=True,
    )


def get_together_response(model_name, prompt):
    return together_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )


# https://ai.google.dev/gemini-api/docs/text-generation?lang=python
def get_google_response(model_name, prompt):
    model = genai.GenerativeModel(model_name)
    return model.generate_content(prompt, stream=True)

def get_llm_response_stream(model_identifier, prompt):
    """Returns a Streamlit-friendly stream of response tokens from the LLM.

    `model_identifier` is expected to have the form "<provider>://<model_name>".
    """
    provider, model_name = model_identifier.split("://")
    if provider == "openai":
        return openai_streamlit_streamer(
            get_openai_response(model_name, prompt), model_identifier
        )
    elif provider == "anthropic":
        return anthropic_streamlit_streamer(
            get_anthropic_response(model_name, prompt), model_identifier
        )
    elif provider == "together":
        return together_streamlit_streamer(
            get_together_response(model_name, prompt), model_identifier
        )
    elif provider == "vertex":
        return google_streamlit_streamer(get_google_response(model_name, prompt))
    else:
        return None

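# Model identifiers throughout the app follow the "<provider>://<model_name>"
# convention, where the provider prefix selects one of the clients above. A
# minimal usage sketch (the identifier below is one that appears elsewhere in
# this file; any entry from LLM_COUNCIL_MEMBERS should work the same way):
#
#   stream = get_llm_response_stream(
#       "together://meta-llama/Llama-3.2-3B-Instruct-Turbo", "Say 'Hello World'"
#   )
#   text = st.empty().write_stream(stream)
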
def create_dataframe_for_direct_assessment_judging_response(
    response: DirectAssessmentCriteriaScores, judging_model: str
) -> pd.DataFrame:
    """Flattens a parsed judging response into a dataframe, one row per criterion."""
    data = []
    for criteria_score in response.criteria_scores:
        data.append(
            {
                "judging_model": judging_model,  # Passed in by the caller.
                "criteria": criteria_score.criterion,
                "score": criteria_score.score,
                "explanation": criteria_score.explanation,
            }
        )
    return pd.DataFrame(data)

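# Each parsed judging response therefore becomes a small long-format table,
# one row per criterion, roughly like this (illustrative values only):
#
#   judging_model          criteria      score  explanation
#   openai://gpt-4o-mini   helpfulness   5      ...
#   openai://gpt-4o-mini   conciseness   4      ...
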
# Streamlit form UI
def render_criteria_form(criteria_num):
    """Render a criteria input form."""
    with st.expander(f"Criteria {criteria_num + 1}"):
        name = st.text_input(
            f"Name for Criteria {criteria_num + 1}", key=f"criteria_name_{criteria_num}"
        )
        description = st.text_area(
            f"Description for Criteria {criteria_num + 1}",
            key=f"criteria_desc_{criteria_num}",
        )
        min_score = st.number_input(
            f"Min Score for Criteria {criteria_num + 1}",
            min_value=0,
            step=1,
            key=f"criteria_min_{criteria_num}",
        )
        max_score = st.number_input(
            f"Max Score for Criteria {criteria_num + 1}",
            min_value=0,
            step=1,
            key=f"criteria_max_{criteria_num}",
        )
        return Criteria(
            name=name, description=description, min_score=min_score, max_score=max_score
        )


def format_likert_comparison_options(options):
    return "\n".join([f"{i + 1}: {option}" for i, option in enumerate(options)])


def format_criteria_list(criteria_list):
    return "\n".join(
        [f"{criteria.name}: {criteria.description}" for criteria in criteria_list]
    )

def get_direct_assessment_prompt(
    direct_assessment_prompt, user_prompt, response, criteria_list, options
):
    return direct_assessment_prompt.format(
        user_prompt=user_prompt,
        response=response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )


def get_default_direct_assessment_prompt(user_prompt):
    return get_direct_assessment_prompt(
        direct_assessment_prompt=DEFAULT_DIRECT_ASSESSMENT_PROMPT,
        user_prompt=user_prompt,
        response="{response}",
        criteria_list=DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST,
        options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
    )

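# Note: the default prompt is built with the literal string "{response}", so
# the {response} slot survives into the template shown in the UI; it is filled
# in later, per response, by get_direct_assessment_prompt at judging time.
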
def get_aggregator_prompt(aggregator_prompt, user_prompt, llms):
    responses_from_other_llms = "\n\n".join(
        [
            f"{get_ui_friendly_name(model)} START\n{st.session_state['responses'][model]}\n\n{get_ui_friendly_name(model)} END\n\n\n"
            for model in llms
        ]
    )
    return aggregator_prompt.format(
        user_prompt=user_prompt,
        responses_from_other_llms=responses_from_other_llms,
    )


def get_default_aggregator_prompt(user_prompt, llms):
    return get_aggregator_prompt(
        DEFAULT_AGGREGATOR_PROMPT,
        user_prompt=user_prompt,
        llms=llms,
    )

def get_parse_judging_response_for_direct_assessment_prompt(
    judging_response: str,
    criteria_list,
    options,
) -> str:
    formatted_judging_response = (
        f"----- START -----\n\n\n{judging_response}\n\n\n----- END -----\n\n\n"
    )
    return PARSE_JUDGING_RESPONSE_FOR_DIRECT_ASSESSMENT_PROMPT.format(
        judging_response=formatted_judging_response,
        criteria_list=format_criteria_list(criteria_list),
        options=format_likert_comparison_options(options),
    )

def get_parsed_judging_response_obj_using_llm(
    prompt: str,
) -> DirectAssessmentCriteriaScores:
    """Uses structured output prompting to parse a judging response into scores."""
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {
                "role": "system",
                "content": "Parse the judging responses into structured data.",
            },
            {"role": "user", "content": prompt},
        ],
        response_format=DirectAssessmentCriteriaScores,
    )
    # Track token usage.
    st.session_state["input_token_usage"][
        "gpt-4o-mini"
    ] += completion.usage.prompt_tokens
    st.session_state["output_token_usage"][
        "gpt-4o-mini"
    ] += completion.usage.completion_tokens
    return completion.choices[0].message.parsed

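# `judging_dataclasses` is not shown in this file. Based on how the parsed
# object is consumed here and in
# create_dataframe_for_direct_assessment_judging_response, the models passed as
# `response_format` presumably look roughly like this (a sketch, not the actual
# definitions):
#
#   class DirectAssessmentCriterionScore(pydantic.BaseModel):
#       criterion: str
#       score: int
#       explanation: str
#
#   class DirectAssessmentCriteriaScores(pydantic.BaseModel):
#       criteria_scores: list[DirectAssessmentCriterionScore]
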
def get_llm_avatar(model_identifier):
    if "agg__" in model_identifier:
        return "img/council_icon.png"
    else:
        return PROVIDER_TO_AVATAR_MAP[model_identifier]

def plot_criteria_scores(df):
    # Group by criteria and calculate mean and std over all judges.
    grouped = df.groupby(["criteria"]).agg({"score": ["mean", "std"]}).reset_index()
    # Flatten the MultiIndex columns.
    grouped.columns = ["criteria", "mean_score", "std_score"]
    # Fill NaN std with zeros (in case there's only one score per group).
    grouped["std_score"] = grouped["std_score"].fillna(0)

    # Set up the plot.
    plt.figure(figsize=(8, 5))

    # Create a horizontal bar plot.
    ax = sns.barplot(
        data=grouped,
        x="mean_score",
        y="criteria",
        hue="criteria",
        errorbar=None,
        orient="h",
    )

    # Add error bars manually by iterating over the bars.
    for i, (mean, std) in enumerate(zip(grouped["mean_score"], grouped["std_score"])):
        # Get the current bar and the vertical center of the bar.
        bar = ax.patches[i]
        center = bar.get_y() + bar.get_height() / 2
        # Add the error bar.
        ax.errorbar(x=mean, y=center, xerr=std, ecolor="black", capsize=3, fmt="none")

    # Remove axis labels and tighten the layout.
    ax.set_xlabel("")
    ax.set_ylabel("")
    plt.tight_layout()

    # Display the plot in Streamlit.
    st.pyplot(plt.gcf())

def plot_overall_scores(overall_scores_df):
    # Calculate mean and standard deviation per response model.
    summary = (
        overall_scores_df.groupby("response_model")
        .agg({"score": ["mean", "std"]})
        .reset_index()
    )
    summary.columns = ["response_model", "mean_score", "std_score"]

    # Add UI-friendly names.
    summary["ui_friendly_name"] = summary["response_model"].apply(get_ui_friendly_name)

    # Sort the summary dataframe by mean_score in descending order.
    summary = summary.sort_values("mean_score", ascending=False)

    # Create the plot.
    plt.figure(figsize=(8, 5))

    # Plot bars with rainbow colors.
    ax = sns.barplot(
        x="ui_friendly_name",
        y="mean_score",
        hue="ui_friendly_name",
        data=summary,
        palette="rainbow",
        capsize=0.1,
        legend=False,
    )

    # Add error bars manually.
    x_coords = range(len(summary))
    plt.errorbar(
        x=x_coords,
        y=summary["mean_score"],
        yerr=summary["std_score"],
        fmt="none",
        c="black",
        capsize=5,
        zorder=10,  # Ensure error bars are drawn on top.
    )

    # Add text annotations using the actual positions of the bars.
    for patch, row in zip(ax.patches, summary.itertuples()):
        # Get the center (x) and top (y) of each bar.
        x = patch.get_x() + patch.get_width() / 2
        y = patch.get_height()
        # Add the text annotation.
        ax.text(
            x,
            y,
            f"{row.mean_score:.2f}",
            ha="center",
            va="bottom",
            color="black",
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.5),
        )

    # Customize the plot.
    plt.xlabel("")
    plt.ylabel("Overall Score")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()

    # Display the plot in Streamlit.
    st.pyplot(plt.gcf())

def plot_per_judge_overall_scores(df):
    # Find the overall score for each judge by averaging that judge's
    # per-criterion scores.
    grouped = df.groupby(["judging_model"]).agg({"score": ["mean"]}).reset_index()
    grouped.columns = ["judging_model", "overall_score"]

    # Create the vertical bar plot.
    plt.figure(figsize=(10, 6))
    ax = sns.barplot(
        data=grouped,
        x="judging_model",
        y="overall_score",
        hue="judging_model",
        orient="v",
        palette="rainbow",
    )

    # Customize the plot.
    plt.title("Overall Score from each LLM Judge")
    plt.xlabel("LLM Judge")
    plt.ylabel("Overall Score")

    # Adjust layout and display the plot.
    plt.tight_layout()
    st.pyplot(plt.gcf())

def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
    selected_models_to_streamlit_column_map = {
        model: st_columns[i % len(st_columns)]
        for i, model in enumerate(selected_models)
    }
    return selected_models_to_streamlit_column_map


def get_aggregator_key(llm_aggregator):
    return "agg__" + llm_aggregator

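# Aggregator (MoA) responses are stored in st.session_state.responses under an
# "agg__"-prefixed key, i.e. "agg__" + "<provider>://<model_name>", which
# get_ui_friendly_name renders as "MoA (<model name>)" and get_llm_avatar maps
# to the council icon.
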
def st_render_responses(user_prompt):
    """Renders the responses from the LLMs.

    Uses cached responses from the session state, if available. Otherwise,
    streams the responses anew.

    Assumes that the session state has already been set up with the selected
    models and the selected aggregator.
    """
    st.markdown("#### Responses")

    response_columns = st.columns(3)
    selected_models_to_streamlit_column_map = (
        get_selected_models_to_streamlit_column_map(
            response_columns, st.session_state.selected_models
        )
    )
    for response_model in st.session_state.selected_models:
        st_column = selected_models_to_streamlit_column_map.get(
            response_model, response_columns[0]
        )
        with st_column.chat_message(
            response_model,
            avatar=get_llm_avatar(response_model),
        ):
            st.write(get_ui_friendly_name(response_model))
            if response_model in st.session_state.responses:
                # Use the cached response from session state.
                st.write(st.session_state.responses[response_model])
            else:
                # Stream the response from the LLM.
                message_placeholder = st.empty()
                stream = get_llm_response_stream(response_model, user_prompt)
                st.session_state.responses[response_model] = (
                    message_placeholder.write_stream(stream)
                )

    # Render the aggregator response.
    aggregator_prompt = get_default_aggregator_prompt(
        user_prompt=user_prompt, llms=st.session_state.selected_models
    )

    # Stream (or re-render) the response from the aggregator.
    with st.chat_message(
        get_aggregator_key(st.session_state.selected_aggregator),
        avatar="img/council_icon.png",
    ):
        st.write(
            get_ui_friendly_name(
                get_aggregator_key(st.session_state.selected_aggregator)
            )
        )
        if (
            get_aggregator_key(st.session_state.selected_aggregator)
            in st.session_state.responses
        ):
            st.write(
                st.session_state.responses[
                    get_aggregator_key(st.session_state.selected_aggregator)
                ]
            )
        else:
            message_placeholder = st.empty()
            aggregator_stream = get_llm_response_stream(
                st.session_state.selected_aggregator, aggregator_prompt
            )
            if aggregator_stream:
                st.session_state.responses[
                    get_aggregator_key(st.session_state.selected_aggregator)
                ] = message_placeholder.write_stream(aggregator_stream)

    st.session_state.responses_collected = True

def st_direct_assessment_results(user_prompt, direct_assessment_prompt, criteria_list):
    """Renders the direct assessment results block.

    Uses session state to render results from LLMs. If the session state isn't
    set, then fetches the responses from the LLM services from scratch (and
    sets the session state).

    Assumes that the session state has already been set up with responses.
    """
    responses_for_judging = st.session_state.responses

    # Get judging responses.
    response_judging_columns = st.columns(3)
    responses_for_judging_to_streamlit_column_map = (
        get_selected_models_to_streamlit_column_map(
            response_judging_columns, responses_for_judging.keys()
        )
    )
    for response_model, response in responses_for_judging.items():
        st_column = responses_for_judging_to_streamlit_column_map[response_model]
        with st_column:
            st.write(f"Judging for {get_ui_friendly_name(response_model)}")
            judging_prompt = get_direct_assessment_prompt(
                direct_assessment_prompt=direct_assessment_prompt,
                user_prompt=user_prompt,
                response=response,
                criteria_list=criteria_list,
                options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
            )
            with st.expander("Final Judging Prompt"):
                st.code(judging_prompt)

            for judging_model in st.session_state.selected_models:
                with st.expander(get_ui_friendly_name(judging_model), expanded=True):
                    with st.chat_message(
                        judging_model,
                        avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
                    ):
                        if (
                            judging_model
                            in st.session_state.direct_assessment_judging_responses[
                                response_model
                            ]
                        ):
                            # Use the cached response from session state.
                            st.write(
                                st.session_state.direct_assessment_judging_responses[
                                    response_model
                                ][judging_model]
                            )
                        else:
                            message_placeholder = st.empty()
                            # Get the judging response from the LLM.
                            judging_stream = get_llm_response_stream(
                                judging_model, judging_prompt
                            )
                            st.session_state.direct_assessment_judging_responses[
                                response_model
                            ][judging_model] = message_placeholder.write_stream(
                                judging_stream
                            )

                    # Parse the judging response with structured output, unless
                    # the parsed results are already cached.
                    parse_judging_response_prompt = (
                        get_parse_judging_response_for_direct_assessment_prompt(
                            judging_response=st.session_state.direct_assessment_judging_responses[
                                response_model
                            ][judging_model],
                            criteria_list=criteria_list,
                            options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
                        )
                    )
                    st.write("Parse judging response prompt:")
                    st.write(parse_judging_response_prompt)

                    if (
                        response_model
                        not in st.session_state.direct_assessment_judging_by_response_and_judging_model_df
                        or judging_model
                        not in st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ]
                    ):
                        parsed_judging_response_obj = (
                            get_parsed_judging_response_obj_using_llm(
                                parse_judging_response_prompt
                            )
                        )
                        st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ][judging_model] = create_dataframe_for_direct_assessment_judging_response(
                            parsed_judging_response_obj, judging_model
                        )

                    st.write("Structured output parsing response:")
                    st.write(
                        st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                            response_model
                        ][judging_model]
                    )

            # Combine the dataframes from each judging model into a single
            # dataframe for this response model.
            if response_model not in st.session_state.direct_assessment_judging_df:
                combined_judging_df = pd.DataFrame()
                for judging_model in st.session_state.selected_models:
                    combined_judging_df = pd.concat(
                        [
                            combined_judging_df,
                            st.session_state.direct_assessment_judging_by_response_and_judging_model_df[
                                response_model
                            ][judging_model],
                        ]
                    )
                st.session_state.direct_assessment_judging_df[response_model] = (
                    combined_judging_df
                )

            with st.expander("Judging results from all judges"):
                st.write(st.session_state.direct_assessment_judging_df[response_model])

            # Use the session state to plot the criteria scores and per-judge
            # overall scores for this response model.
            plot_criteria_scores(
                st.session_state.direct_assessment_judging_df[response_model]
            )
            plot_per_judge_overall_scores(
                st.session_state.direct_assessment_judging_df[response_model]
            )

            grouped = (
                st.session_state.direct_assessment_judging_df[response_model]
                .groupby(["judging_model"])
                .agg({"score": ["mean"]})
                .reset_index()
            )
            grouped.columns = ["judging_model", "overall_score"]

            # Save the overall scores to the session state if they're not
            # already there.
            for record in grouped.to_dict(orient="records"):
                st.session_state.direct_assessment_overall_scores[
                    get_ui_friendly_name(response_model)
                ][get_ui_friendly_name(record["judging_model"])] = record[
                    "overall_score"
                ]

            overall_score = grouped["overall_score"].mean()
            controversy = grouped["overall_score"].std()
            st.write(f"Overall Score: {overall_score:.2f}")
            st.write(f"Controversy: {controversy:.2f}")

    # Mark judging as complete.
    st.session_state.judging_status = "complete"

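# Cached judging artifacts in session state, as populated above (a summary of
# the structures this function writes, for reference):
#
#   direct_assessment_judging_responses[response_model][judging_model]
#       -> raw judging text from each judge
#   direct_assessment_judging_by_response_and_judging_model_df[response_model][judging_model]
#       -> parsed per-criterion scores as a DataFrame
#   direct_assessment_judging_df[response_model]
#       -> all judges' scores for one response, concatenated
#   direct_assessment_overall_scores[response_model_ui_name][judging_model_ui_name]
#       -> mean score across criteria
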
# Main Streamlit App
def main():
    st.set_page_config(
        page_title="Language Model Council Sandbox", page_icon="🏛️", layout="wide"
    )

    # Custom CSS for the chat display.
    center_css = """
    <style>
    h1, h2, h3, h6 { text-align: center; }
    .chat-container {
        display: flex;
        align-items: flex-start;
        margin-bottom: 10px;
    }
    .avatar {
        width: 50px;
        margin-right: 10px;
    }
    .message {
        background-color: #f1f1f1;
        padding: 10px;
        border-radius: 10px;
        width: 100%;
    }
    </style>
    """
    st.markdown(center_css, unsafe_allow_html=True)

    # App title and description.
    st.title("Language Model Council Sandbox")
    st.markdown("###### Invoke a council of LLMs to judge each other's responses.")
    st.markdown("###### [Paper](https://arxiv.org/abs/2406.08598)")

    # Authentication system.
    if "authenticated" not in st.session_state:
        st.session_state.authenticated = False

    cols = st.columns([2, 1, 2])

    if not st.session_state.authenticated:
        with cols[1]:
            with st.form("login_form"):
                password = st.text_input("Password", type="password")
                submit_button = st.form_submit_button("Login", use_container_width=True)
                if submit_button:
                    if password == PASSWORD:
                        st.session_state.authenticated = True
                        st.success("Logged in successfully!")
                        st.rerun()
                    else:
                        st.error("Invalid credentials")

    if st.session_state.authenticated:
        if "responses_collected" not in st.session_state:
            st.session_state["responses_collected"] = False

        # Initialize session state for collecting responses.
        if "responses" not in st.session_state:
            st.session_state.responses = defaultdict(str)

        # Initialize session state for token usage.
        if "input_token_usage" not in st.session_state:
            st.session_state["input_token_usage"] = defaultdict(int)
        if "output_token_usage" not in st.session_state:
            st.session_state["output_token_usage"] = defaultdict(int)

        if "selected_models" not in st.session_state:
            st.session_state["selected_models"] = []
        if "selected_aggregator" not in st.session_state:
            st.session_state["selected_aggregator"] = None

        # Initialize session state for direct assessment judging.
        if "direct_assessment_overall_score" not in st.session_state:
            st.session_state.direct_assessment_overall_score = {}
        if "direct_assessment_judging_df" not in st.session_state:
            st.session_state.direct_assessment_judging_df = {}
        if (
            "direct_assessment_judging_by_response_and_judging_model_df"
            not in st.session_state
        ):
            st.session_state.direct_assessment_judging_by_response_and_judging_model_df = defaultdict(
                dict
            )
        if "direct_assessment_judging_responses" not in st.session_state:
            st.session_state.direct_assessment_judging_responses = defaultdict(dict)
        if "direct_assessment_overall_scores" not in st.session_state:
            st.session_state.direct_assessment_overall_scores = defaultdict(dict)
        if "judging_status" not in st.session_state:
            st.session_state.judging_status = "incomplete"
        if "direct_assessment_config" not in st.session_state:
            st.session_state.direct_assessment_config = {}
        if "pairwise_comparison_config" not in st.session_state:
            st.session_state.pairwise_comparison_config = {}
        if "assessment_type" not in st.session_state:
            st.session_state.assessment_type = None
        with st.form(key="prompt_form"):
            st.markdown("#### LLM Council Member Selection")

            # Council and aggregator selection.
            selected_models = llm_council_selector()
            selected_aggregator = aggregator_selector()

            # Prompt input and submission form.
            st.markdown("#### Enter your prompt")
            _, center_column, _ = st.columns([3, 5, 3])
            with center_column:
                user_prompt = st.text_area(
                    "Enter your prompt",
                    value="Say 'Hello World'",
                    key="user_prompt",
                    label_visibility="hidden",
                )
                submit_button = st.form_submit_button(
                    "Submit", use_container_width=True
                )
        if submit_button:
            # Update state.
            st.session_state.selected_models = selected_models
            st.session_state.selected_aggregator = selected_aggregator
            # Render the chats.
            st_render_responses(user_prompt)
        # Render the chats if they are already available, even when the submit
        # button wasn't clicked on this rerun.
        elif st.session_state.responses:
            st_render_responses(user_prompt)
        # Judging.
        if st.session_state.responses_collected:
            with st.form(key="judging_form"):
                st.markdown("#### Judging Configuration")

                # Choose the type of assessment.
                assessment_type = st.radio(
                    "Select the type of assessment",
                    options=["Direct Assessment", "Pairwise Comparison"],
                )
                _, center_column, _ = st.columns([3, 5, 3])

                # Depending on the assessment type, render different forms.
                if assessment_type == "Direct Assessment":
                    # Direct assessment prompt.
                    with center_column.expander("Direct Assessment Prompt"):
                        direct_assessment_prompt = st.text_area(
                            "Prompt for the Direct Assessment",
                            value=get_default_direct_assessment_prompt(
                                user_prompt=user_prompt
                            ),
                            height=500,
                            key="direct_assessment_prompt",
                        )
                    # TODO: Add an option to edit the criteria list with a basic text field.
                    criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST

                with center_column:
                    judging_submit_button = st.form_submit_button(
                        "Submit Judging", use_container_width=True
                    )

            if judging_submit_button:
                # Update session state.
                st.session_state.assessment_type = assessment_type
                if st.session_state.assessment_type == "Direct Assessment":
                    st.session_state.direct_assessment_config = {
                        "prompt": direct_assessment_prompt,
                        "criteria_list": criteria_list,
                    }
                    st_direct_assessment_results(
                        user_prompt=st.session_state.user_prompt,
                        direct_assessment_prompt=direct_assessment_prompt,
                        criteria_list=criteria_list,
                    )
            # If judging is complete but the submit button is cleared, still
            # render the results.
            elif st.session_state.judging_status == "complete":
                if st.session_state.assessment_type == "Direct Assessment":
                    st_direct_assessment_results(
                        user_prompt=st.session_state.user_prompt,
                        direct_assessment_prompt=direct_assessment_prompt,
                        criteria_list=criteria_list,
                    )
            # Judging is complete: render results that aren't stream-specific.
            # The session state now contains the overall scores for each
            # response from each judge.
            if st.session_state.judging_status == "complete":
                st.write("#### Results")

                overall_scores_df_raw = pd.DataFrame(
                    st.session_state.direct_assessment_overall_scores
                ).reset_index()
                overall_scores_df = pd.melt(
                    overall_scores_df_raw,
                    id_vars=["index"],
                    var_name="response_model",
                    value_name="score",
                ).rename(columns={"index": "judging_model"})

                # Print the overall winner.
                overall_winner = overall_scores_df.loc[
                    overall_scores_df["score"].idxmax()
                ]
                st.write(
                    f"**Overall Winner:** {get_ui_friendly_name(overall_winner['response_model'])}"
                )
                # Find how much the standard deviation overlaps with other models.
                # TODO: Calculate separability.
                st.write(f"**Confidence:** {overall_winner['score']:.2f}")

                left_column, right_column = st.columns([1, 1])
                with left_column:
                    plot_overall_scores(overall_scores_df)

                with right_column:
                    # All overall scores.
                    overall_scores_df = overall_scores_df[
                        ["response_model", "judging_model", "score"]
                    ]
                    overall_scores_df["response_model"] = overall_scores_df[
                        "response_model"
                    ].apply(get_ui_friendly_name)
                    with st.expander("Overall scores from all judges"):
                        st.write(st.session_state.direct_assessment_overall_scores)
                        st.dataframe(overall_scores_df_raw)
                        st.dataframe(overall_scores_df)
                # All criteria scores.
                with right_column:
                    all_scores_df = pd.DataFrame()
                    for (
                        response_model,
                        score_df,
                    ) in st.session_state.direct_assessment_judging_df.items():
                        score_df["response_model"] = response_model
                        all_scores_df = pd.concat([all_scores_df, score_df])
                    all_scores_df = all_scores_df.reset_index(drop=True)

                    # Reorder the columns.
                    all_scores_df = all_scores_df[
                        [
                            "response_model",
                            "judging_model",
                            "criteria",
                            "score",
                            "explanation",
                        ]
                    ]
                    with st.expander(
                        "Criteria-specific scores and explanations from all judges"
                    ):
                        st.dataframe(all_scores_df)
        # Token usage.
        if st.session_state.responses:
            st.divider()
            with st.expander("Token Usage"):
                st.write("Input tokens used.")
                st.write(st.session_state.input_token_usage)
                st.write(
                    f"Input Tokens Total: {sum(st.session_state.input_token_usage.values())}"
                )
                st.write("Output tokens used.")
                st.write(st.session_state.output_token_usage)
                st.write(
                    f"Output Tokens Total: {sum(st.session_state.output_token_usage.values())}"
                )
    else:
        with cols[1]:
            st.warning("Please log in to access this app.")

if __name__ == "__main__":
    main()
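# Note: this module is a Streamlit entry point; it is typically launched with
# `streamlit run <this file>` rather than `python <this file>`.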