Spaces:
Running
Running
Commit
·
a0dca54
1
Parent(s):
279a804
Factor out LLM chat rendering so that it persists even when the submit button isn't active.
Browse files
app.py
CHANGED
|
@@ -516,6 +516,206 @@ def get_selected_models_to_streamlit_column_map(st_columns, selected_models):
|
|
| 516 |
return selected_models_to_streamlit_column_map
|
| 517 |
|
| 518 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 519 |
# Main Streamlit App
|
| 520 |
def main():
|
| 521 |
st.set_page_config(
|
|
@@ -632,71 +832,11 @@ def main():
|
|
| 632 |
st.session_state.selected_aggregator = selected_aggregator
|
| 633 |
|
| 634 |
# Render the chats.
|
| 635 |
-
|
| 636 |
-
|
| 637 |
-
selected_models_to_streamlit_column_map = (
|
| 638 |
-
get_selected_models_to_streamlit_column_map(
|
| 639 |
-
response_columns, selected_models
|
| 640 |
-
)
|
| 641 |
-
)
|
| 642 |
|
| 643 |
-
|
| 644 |
-
|
| 645 |
-
|
| 646 |
-
st.write(get_ui_friendly_name(selected_model))
|
| 647 |
-
with st.chat_message(
|
| 648 |
-
selected_model,
|
| 649 |
-
avatar=PROVIDER_TO_AVATAR_MAP[selected_model],
|
| 650 |
-
):
|
| 651 |
-
message_placeholder = st.empty()
|
| 652 |
-
stream = get_llm_response_stream(selected_model, user_prompt)
|
| 653 |
-
if stream:
|
| 654 |
-
st.session_state["responses"][selected_model] = (
|
| 655 |
-
message_placeholder.write_stream(stream)
|
| 656 |
-
)
|
| 657 |
-
|
| 658 |
-
# Get the aggregator prompt.
|
| 659 |
-
aggregator_prompt = get_default_aggregator_prompt(
|
| 660 |
-
user_prompt=user_prompt, llms=selected_models
|
| 661 |
-
)
|
| 662 |
-
|
| 663 |
-
# Fetching and streaming response from the aggregator
|
| 664 |
-
st.write(f"{get_ui_friendly_name(selected_aggregator)}")
|
| 665 |
-
with st.chat_message(
|
| 666 |
-
selected_aggregator,
|
| 667 |
-
avatar="img/council_icon.png",
|
| 668 |
-
):
|
| 669 |
-
message_placeholder = st.empty()
|
| 670 |
-
aggregator_stream = get_llm_response_stream(
|
| 671 |
-
selected_aggregator, aggregator_prompt
|
| 672 |
-
)
|
| 673 |
-
if aggregator_stream:
|
| 674 |
-
st.session_state.responses["agg__" + selected_aggregator] = (
|
| 675 |
-
message_placeholder.write_stream(aggregator_stream)
|
| 676 |
-
)
|
| 677 |
-
|
| 678 |
-
st.session_state.responses_collected = True
|
| 679 |
-
|
| 680 |
-
# Render chats generally?
|
| 681 |
-
if st.session_state.responses and not submit_button:
|
| 682 |
-
st.markdown("#### Responses")
|
| 683 |
-
|
| 684 |
-
response_columns = st.columns(3)
|
| 685 |
-
selected_models_to_streamlit_column_map = (
|
| 686 |
-
get_selected_models_to_streamlit_column_map(
|
| 687 |
-
response_columns, st.session_state.selected_models
|
| 688 |
-
)
|
| 689 |
-
)
|
| 690 |
-
for response_model, response in st.session_state.responses.items():
|
| 691 |
-
st_column = selected_models_to_streamlit_column_map.get(
|
| 692 |
-
response_model, response_columns[0]
|
| 693 |
-
)
|
| 694 |
-
with st_column.chat_message(
|
| 695 |
-
response_model,
|
| 696 |
-
avatar=get_llm_avatar(response_model),
|
| 697 |
-
):
|
| 698 |
-
st.write(get_ui_friendly_name(response_model))
|
| 699 |
-
st.write(response)
|
| 700 |
|
| 701 |
# Judging.
|
| 702 |
if st.session_state.responses_collected:
|
|
@@ -727,228 +867,41 @@ def main():
|
|
| 727 |
# TODO: Add option to edit criteria list with a basic text field.
|
| 728 |
criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
|
| 729 |
|
|
|
|
| 730 |
judging_submit_button = st.form_submit_button(
|
| 731 |
"Submit Judging", use_container_width=True
|
| 732 |
)
|
| 733 |
|
| 734 |
if judging_submit_button:
|
|
|
|
| 735 |
st.session_state.assessment_type = assessment_type
|
| 736 |
-
st.session_state.direct_assessment_config = {
|
| 737 |
-
"prompt": direct_assessment_prompt,
|
| 738 |
-
"criteria_list": criteria_list,
|
| 739 |
-
}
|
| 740 |
-
|
| 741 |
-
responses_for_judging = st.session_state.responses
|
| 742 |
-
|
| 743 |
-
# Get judging responses.
|
| 744 |
-
response_judging_columns = st.columns(3)
|
| 745 |
-
responses_for_judging_to_streamlit_column_map = (
|
| 746 |
-
get_selected_models_to_streamlit_column_map(
|
| 747 |
-
response_judging_columns, responses_for_judging.keys()
|
| 748 |
-
)
|
| 749 |
-
)
|
| 750 |
-
|
| 751 |
if st.session_state.assessment_type == "Direct Assessment":
|
| 752 |
-
|
| 753 |
-
|
| 754 |
-
|
| 755 |
-
|
| 756 |
-
|
| 757 |
-
|
| 758 |
-
|
| 759 |
-
|
| 760 |
-
|
| 761 |
-
judging_prompt = get_direct_assessment_prompt(
|
| 762 |
-
direct_assessment_prompt=direct_assessment_prompt,
|
| 763 |
-
user_prompt=user_prompt,
|
| 764 |
-
response=response,
|
| 765 |
-
criteria_list=criteria_list,
|
| 766 |
-
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 767 |
-
)
|
| 768 |
-
|
| 769 |
-
with st.expander("Final Judging Prompt"):
|
| 770 |
-
st.code(judging_prompt)
|
| 771 |
-
|
| 772 |
-
for judging_model in selected_models:
|
| 773 |
-
with st.expander(
|
| 774 |
-
get_ui_friendly_name(judging_model), expanded=True
|
| 775 |
-
):
|
| 776 |
-
with st.chat_message(
|
| 777 |
-
judging_model,
|
| 778 |
-
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
| 779 |
-
):
|
| 780 |
-
message_placeholder = st.empty()
|
| 781 |
-
judging_stream = get_llm_response_stream(
|
| 782 |
-
judging_model, judging_prompt
|
| 783 |
-
)
|
| 784 |
-
st.session_state[
|
| 785 |
-
"direct_assessment_judging_responses"
|
| 786 |
-
][response_model][
|
| 787 |
-
judging_model
|
| 788 |
-
] = message_placeholder.write_stream(
|
| 789 |
-
judging_stream
|
| 790 |
-
)
|
| 791 |
-
# When all of the judging is finished for the given response, get the actual
|
| 792 |
-
# values, parsed.
|
| 793 |
-
judging_responses = st.session_state[
|
| 794 |
-
"direct_assessment_judging_responses"
|
| 795 |
-
][response_model]
|
| 796 |
-
|
| 797 |
-
if not judging_responses:
|
| 798 |
-
st.error(f"No judging responses for {response_model}")
|
| 799 |
-
quit()
|
| 800 |
-
parse_judging_response_prompt = (
|
| 801 |
-
get_parse_judging_response_for_direct_assessment_prompt(
|
| 802 |
-
judging_responses,
|
| 803 |
-
criteria_list,
|
| 804 |
-
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 805 |
-
)
|
| 806 |
-
)
|
| 807 |
-
# Issue the prompt to openai mini with structured outputs
|
| 808 |
-
parsed_judging_responses = parse_judging_responses(
|
| 809 |
-
parse_judging_response_prompt, judging_responses
|
| 810 |
-
)
|
| 811 |
-
|
| 812 |
-
st.session_state["direct_assessment_judging_df"][
|
| 813 |
-
response_model
|
| 814 |
-
] = create_dataframe_for_direct_assessment_judging_response(
|
| 815 |
-
parsed_judging_responses
|
| 816 |
-
)
|
| 817 |
-
|
| 818 |
-
plot_criteria_scores(
|
| 819 |
-
st.session_state["direct_assessment_judging_df"][
|
| 820 |
-
response_model
|
| 821 |
-
]
|
| 822 |
-
)
|
| 823 |
-
|
| 824 |
-
# Find the overall score by finding the overall score for each judge, and then averaging
|
| 825 |
-
# over all judges.
|
| 826 |
-
plot_per_judge_overall_scores(
|
| 827 |
-
st.session_state["direct_assessment_judging_df"][
|
| 828 |
-
response_model
|
| 829 |
-
]
|
| 830 |
-
)
|
| 831 |
-
|
| 832 |
-
grouped = (
|
| 833 |
-
st.session_state["direct_assessment_judging_df"][
|
| 834 |
-
response_model
|
| 835 |
-
]
|
| 836 |
-
.groupby(["judging_model"])
|
| 837 |
-
.agg({"score": ["mean"]})
|
| 838 |
-
.reset_index()
|
| 839 |
-
)
|
| 840 |
-
grouped.columns = ["judging_model", "overall_score"]
|
| 841 |
-
|
| 842 |
-
# Save the overall scores to the session state.
|
| 843 |
-
for record in grouped.to_dict(orient="records"):
|
| 844 |
-
st.session_state["direct_assessment_overall_scores"][
|
| 845 |
-
response_model
|
| 846 |
-
][record["judging_model"]] = record["overall_score"]
|
| 847 |
-
|
| 848 |
-
overall_score = grouped["overall_score"].mean()
|
| 849 |
-
controversy = grouped["overall_score"].std()
|
| 850 |
-
st.write(f"Overall Score: {overall_score:.2f}")
|
| 851 |
-
st.write(f"Controversy: {controversy:.2f}")
|
| 852 |
-
|
| 853 |
-
st.session_state.judging_status = "complete"
|
| 854 |
# If judging is complete, but the submit button is cleared, still render the results.
|
| 855 |
elif st.session_state.judging_status == "complete":
|
| 856 |
if st.session_state.assessment_type == "Direct Assessment":
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
responses_for_judging_to_streamlit_column_map = (
|
| 862 |
-
get_selected_models_to_streamlit_column_map(
|
| 863 |
-
response_judging_columns, responses_for_judging.keys()
|
| 864 |
-
)
|
| 865 |
)
|
| 866 |
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
response_model
|
| 870 |
-
]
|
| 871 |
-
|
| 872 |
-
with st_column:
|
| 873 |
-
st.write(
|
| 874 |
-
f"Judging for {get_ui_friendly_name(response_model)}"
|
| 875 |
-
)
|
| 876 |
-
judging_prompt = get_direct_assessment_prompt(
|
| 877 |
-
direct_assessment_prompt=direct_assessment_prompt,
|
| 878 |
-
user_prompt=user_prompt,
|
| 879 |
-
response=response,
|
| 880 |
-
criteria_list=criteria_list,
|
| 881 |
-
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 882 |
-
)
|
| 883 |
-
|
| 884 |
-
with st.expander("Final Judging Prompt"):
|
| 885 |
-
st.code(judging_prompt)
|
| 886 |
-
|
| 887 |
-
for judging_model in selected_models:
|
| 888 |
-
with st.expander(
|
| 889 |
-
get_ui_friendly_name(judging_model), expanded=True
|
| 890 |
-
):
|
| 891 |
-
with st.chat_message(
|
| 892 |
-
judging_model,
|
| 893 |
-
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
| 894 |
-
):
|
| 895 |
-
st.write(
|
| 896 |
-
st.session_state.direct_assessment_judging_responses[
|
| 897 |
-
response_model
|
| 898 |
-
][
|
| 899 |
-
judging_model
|
| 900 |
-
]
|
| 901 |
-
)
|
| 902 |
-
# When all of the judging is finished for the given response, get the actual
|
| 903 |
-
# values, parsed.
|
| 904 |
-
judging_responses = (
|
| 905 |
-
st.session_state.direct_assessment_judging_responses[
|
| 906 |
-
response_model
|
| 907 |
-
]
|
| 908 |
-
)
|
| 909 |
-
|
| 910 |
-
parse_judging_response_prompt = (
|
| 911 |
-
get_parse_judging_response_for_direct_assessment_prompt(
|
| 912 |
-
judging_responses,
|
| 913 |
-
criteria_list,
|
| 914 |
-
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 915 |
-
)
|
| 916 |
-
)
|
| 917 |
-
|
| 918 |
-
plot_criteria_scores(
|
| 919 |
-
st.session_state.direct_assessment_judging_df[
|
| 920 |
-
response_model
|
| 921 |
-
]
|
| 922 |
-
)
|
| 923 |
-
|
| 924 |
-
plot_per_judge_overall_scores(
|
| 925 |
-
st.session_state.direct_assessment_judging_df[
|
| 926 |
-
response_model
|
| 927 |
-
]
|
| 928 |
-
)
|
| 929 |
-
|
| 930 |
-
grouped = (
|
| 931 |
-
st.session_state.direct_assessment_judging_df[
|
| 932 |
-
response_model
|
| 933 |
-
]
|
| 934 |
-
.groupby(["judging_model"])
|
| 935 |
-
.agg({"score": ["mean"]})
|
| 936 |
-
.reset_index()
|
| 937 |
-
)
|
| 938 |
-
grouped.columns = ["judging_model", "overall_score"]
|
| 939 |
-
|
| 940 |
-
overall_score = grouped["overall_score"].mean()
|
| 941 |
-
controversy = grouped["overall_score"].std()
|
| 942 |
-
st.write(f"Overall Score: {overall_score:.2f}")
|
| 943 |
-
st.write(f"Controversy: {controversy:.2f}")
|
| 944 |
-
|
| 945 |
-
# Judging is complete, stuff that would be rendered that's not stream-specific.
|
| 946 |
# The session state now contains the overall scores for each response from each judge.
|
| 947 |
if st.session_state.judging_status == "complete":
|
| 948 |
st.write("#### Results")
|
| 949 |
|
| 950 |
overall_scores_df_raw = pd.DataFrame(
|
| 951 |
-
st.session_state
|
| 952 |
).reset_index()
|
| 953 |
|
| 954 |
overall_scores_df = pd.melt(
|
|
|
|
| 516 |
return selected_models_to_streamlit_column_map
|
| 517 |
|
| 518 |
|
| 519 |
+
def get_aggregator_key(llm_aggregator):
|
| 520 |
+
return "agg__" + llm_aggregator
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def st_render_responses(user_prompt):
|
| 524 |
+
"""Renders the responses from the LLMs.
|
| 525 |
+
|
| 526 |
+
Uses cached responses from the session state, if available.
|
| 527 |
+
Otherwise, streams the responses anew.
|
| 528 |
+
|
| 529 |
+
Assumes that the session state has already been set up with selected models and selected aggregator.
|
| 530 |
+
"""
|
| 531 |
+
st.markdown("#### Responses")
|
| 532 |
+
|
| 533 |
+
response_columns = st.columns(3)
|
| 534 |
+
selected_models_to_streamlit_column_map = (
|
| 535 |
+
get_selected_models_to_streamlit_column_map(
|
| 536 |
+
response_columns, st.session_state.selected_models
|
| 537 |
+
)
|
| 538 |
+
)
|
| 539 |
+
for response_model in st.session_state.selected_models:
|
| 540 |
+
st_column = selected_models_to_streamlit_column_map.get(
|
| 541 |
+
response_model, response_columns[0]
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
with st_column.chat_message(
|
| 545 |
+
response_model,
|
| 546 |
+
avatar=get_llm_avatar(response_model),
|
| 547 |
+
):
|
| 548 |
+
st.write(get_ui_friendly_name(response_model))
|
| 549 |
+
if response_model in st.session_state.responses:
|
| 550 |
+
# Use the cached response from session state.
|
| 551 |
+
st.write(st.session_state.responses[response_model])
|
| 552 |
+
else:
|
| 553 |
+
# Stream the response from the LLM.
|
| 554 |
+
message_placeholder = st.empty()
|
| 555 |
+
stream = get_llm_response_stream(response_model, user_prompt)
|
| 556 |
+
st.session_state.responses[response_model] = (
|
| 557 |
+
message_placeholder.write_stream(stream)
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
# Render the aggregator response.
|
| 561 |
+
aggregator_prompt = get_default_aggregator_prompt(
|
| 562 |
+
user_prompt=user_prompt, llms=st.session_state.selected_models
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# Streaming response from the aggregator.
|
| 566 |
+
with st.chat_message(
|
| 567 |
+
get_aggregator_key(st.session_state.selected_aggregator),
|
| 568 |
+
avatar="img/council_icon.png",
|
| 569 |
+
):
|
| 570 |
+
st.write(
|
| 571 |
+
f"{get_ui_friendly_name(get_aggregator_key(st.session_state.selected_aggregator))}"
|
| 572 |
+
)
|
| 573 |
+
if (
|
| 574 |
+
get_aggregator_key(st.session_state.selected_aggregator)
|
| 575 |
+
in st.session_state.responses
|
| 576 |
+
):
|
| 577 |
+
st.write(
|
| 578 |
+
st.session_state.responses[
|
| 579 |
+
get_aggregator_key(st.session_state.selected_aggregator)
|
| 580 |
+
]
|
| 581 |
+
)
|
| 582 |
+
else:
|
| 583 |
+
message_placeholder = st.empty()
|
| 584 |
+
aggregator_stream = get_llm_response_stream(
|
| 585 |
+
selected_aggregator, aggregator_prompt
|
| 586 |
+
)
|
| 587 |
+
if aggregator_stream:
|
| 588 |
+
st.session_state.responses[get_aggregator_key(selected_aggregator)] = (
|
| 589 |
+
message_placeholder.write_stream(aggregator_stream)
|
| 590 |
+
)
|
| 591 |
+
|
| 592 |
+
st.session_state.responses_collected = True
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def st_direct_assessment_results(user_prompt, direct_assessment_prompt, criteria_list):
|
| 596 |
+
"""Renders the direct assessment results block.
|
| 597 |
+
|
| 598 |
+
Uses session state to render results from LLMs. If the session state isn't set, then fetches the
|
| 599 |
+
responses from the LLMs services from scratch (and sets the session state).
|
| 600 |
+
|
| 601 |
+
Assumes that the session state has already been set up with responses.
|
| 602 |
+
"""
|
| 603 |
+
responses_for_judging = st.session_state.responses
|
| 604 |
+
|
| 605 |
+
# Get judging responses.
|
| 606 |
+
response_judging_columns = st.columns(3)
|
| 607 |
+
responses_for_judging_to_streamlit_column_map = (
|
| 608 |
+
get_selected_models_to_streamlit_column_map(
|
| 609 |
+
response_judging_columns, responses_for_judging.keys()
|
| 610 |
+
)
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
for response_model, response in responses_for_judging.items():
|
| 614 |
+
st_column = responses_for_judging_to_streamlit_column_map[response_model]
|
| 615 |
+
|
| 616 |
+
with st_column:
|
| 617 |
+
st.write(f"Judging for {get_ui_friendly_name(response_model)}")
|
| 618 |
+
judging_prompt = get_direct_assessment_prompt(
|
| 619 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
| 620 |
+
user_prompt=user_prompt,
|
| 621 |
+
response=response,
|
| 622 |
+
criteria_list=criteria_list,
|
| 623 |
+
options=SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
with st.expander("Final Judging Prompt"):
|
| 627 |
+
st.code(judging_prompt)
|
| 628 |
+
|
| 629 |
+
for judging_model in st.session_state.selected_models:
|
| 630 |
+
with st.expander(get_ui_friendly_name(judging_model), expanded=True):
|
| 631 |
+
with st.chat_message(
|
| 632 |
+
judging_model,
|
| 633 |
+
avatar=PROVIDER_TO_AVATAR_MAP[judging_model],
|
| 634 |
+
):
|
| 635 |
+
if (
|
| 636 |
+
judging_model
|
| 637 |
+
in st.session_state.direct_assessment_judging_responses[
|
| 638 |
+
response_model
|
| 639 |
+
]
|
| 640 |
+
):
|
| 641 |
+
# Use the session state cached response.
|
| 642 |
+
st.write(
|
| 643 |
+
st.session_state.direct_assessment_judging_responses[
|
| 644 |
+
response_model
|
| 645 |
+
][judging_model]
|
| 646 |
+
)
|
| 647 |
+
else:
|
| 648 |
+
message_placeholder = st.empty()
|
| 649 |
+
# Get the judging response from the LLM.
|
| 650 |
+
judging_stream = get_llm_response_stream(
|
| 651 |
+
judging_model, judging_prompt
|
| 652 |
+
)
|
| 653 |
+
st.session_state.direct_assessment_judging_responses[
|
| 654 |
+
response_model
|
| 655 |
+
][judging_model] = message_placeholder.write_stream(
|
| 656 |
+
judging_stream
|
| 657 |
+
)
|
| 658 |
+
|
| 659 |
+
# Extract actual scores from open-ended responses using structured outputs.
|
| 660 |
+
# Since we're extracting structured data for the first time, we can save the dataframe
|
| 661 |
+
# to the session state so that it's cached.
|
| 662 |
+
if response_model not in st.session_state.direct_assessment_judging_df:
|
| 663 |
+
judging_responses = (
|
| 664 |
+
st.session_state.direct_assessment_judging_responses[response_model]
|
| 665 |
+
)
|
| 666 |
+
parse_judging_response_prompt = (
|
| 667 |
+
get_parse_judging_response_for_direct_assessment_prompt(
|
| 668 |
+
judging_responses,
|
| 669 |
+
criteria_list,
|
| 670 |
+
SEVEN_POINT_DIRECT_ASSESSMENT_OPTIONS,
|
| 671 |
+
)
|
| 672 |
+
)
|
| 673 |
+
parsed_judging_responses = parse_judging_responses(
|
| 674 |
+
parse_judging_response_prompt, judging_responses
|
| 675 |
+
)
|
| 676 |
+
st.session_state.direct_assessment_judging_df[response_model] = (
|
| 677 |
+
create_dataframe_for_direct_assessment_judging_response(
|
| 678 |
+
parsed_judging_responses
|
| 679 |
+
)
|
| 680 |
+
)
|
| 681 |
+
|
| 682 |
+
# Uses the session state to plot the criteria scores and graphs for a given response
|
| 683 |
+
# model.
|
| 684 |
+
plot_criteria_scores(
|
| 685 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
plot_per_judge_overall_scores(
|
| 689 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
grouped = (
|
| 693 |
+
st.session_state.direct_assessment_judging_df[response_model]
|
| 694 |
+
.groupby(["judging_model"])
|
| 695 |
+
.agg({"score": ["mean"]})
|
| 696 |
+
.reset_index()
|
| 697 |
+
)
|
| 698 |
+
grouped.columns = ["judging_model", "overall_score"]
|
| 699 |
+
|
| 700 |
+
# Save the overall scores to the session state if it's not already there.
|
| 701 |
+
for record in grouped.to_dict(orient="records"):
|
| 702 |
+
if (
|
| 703 |
+
response_model
|
| 704 |
+
not in st.session_state.direct_assessment_overall_scores
|
| 705 |
+
):
|
| 706 |
+
st.session_state.direct_assessment_overall_scores[response_model][
|
| 707 |
+
record["judging_model"]
|
| 708 |
+
] = record["overall_score"]
|
| 709 |
+
|
| 710 |
+
overall_score = grouped["overall_score"].mean()
|
| 711 |
+
controversy = grouped["overall_score"].std()
|
| 712 |
+
st.write(f"Overall Score: {overall_score:.2f}")
|
| 713 |
+
st.write(f"Controversy: {controversy:.2f}")
|
| 714 |
+
|
| 715 |
+
# Mark judging as complete.
|
| 716 |
+
st.session_state.judging_status = "complete"
|
| 717 |
+
|
| 718 |
+
|
| 719 |
# Main Streamlit App
|
| 720 |
def main():
|
| 721 |
st.set_page_config(
|
|
|
|
| 832 |
st.session_state.selected_aggregator = selected_aggregator
|
| 833 |
|
| 834 |
# Render the chats.
|
| 835 |
+
st_render_responses(user_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
|
| 837 |
+
# Render chats generally even they are available, if the submit button isn't clicked.
|
| 838 |
+
elif st.session_state.responses:
|
| 839 |
+
st_render_responses(user_prompt)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 840 |
|
| 841 |
# Judging.
|
| 842 |
if st.session_state.responses_collected:
|
|
|
|
| 867 |
# TODO: Add option to edit criteria list with a basic text field.
|
| 868 |
criteria_list = DEFAULT_DIRECT_ASSESSMENT_CRITERIA_LIST
|
| 869 |
|
| 870 |
+
with center_column:
|
| 871 |
judging_submit_button = st.form_submit_button(
|
| 872 |
"Submit Judging", use_container_width=True
|
| 873 |
)
|
| 874 |
|
| 875 |
if judging_submit_button:
|
| 876 |
+
# Update session state.
|
| 877 |
st.session_state.assessment_type = assessment_type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
if st.session_state.assessment_type == "Direct Assessment":
|
| 879 |
+
st.session_state.direct_assessment_config = {
|
| 880 |
+
"prompt": direct_assessment_prompt,
|
| 881 |
+
"criteria_list": criteria_list,
|
| 882 |
+
}
|
| 883 |
+
st_direct_assessment_results(
|
| 884 |
+
user_prompt=st.session_state.user_prompt,
|
| 885 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
| 886 |
+
criteria_list=criteria_list,
|
| 887 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
# If judging is complete, but the submit button is cleared, still render the results.
|
| 889 |
elif st.session_state.judging_status == "complete":
|
| 890 |
if st.session_state.assessment_type == "Direct Assessment":
|
| 891 |
+
st_direct_assessment_results(
|
| 892 |
+
user_prompt=st.session_state.user_prompt,
|
| 893 |
+
direct_assessment_prompt=direct_assessment_prompt,
|
| 894 |
+
criteria_list=criteria_list,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 895 |
)
|
| 896 |
|
| 897 |
+
# Judging is complete.
|
| 898 |
+
# Render stuff that would be rendered that's not stream-specific.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 899 |
# The session state now contains the overall scores for each response from each judge.
|
| 900 |
if st.session_state.judging_status == "complete":
|
| 901 |
st.write("#### Results")
|
| 902 |
|
| 903 |
overall_scores_df_raw = pd.DataFrame(
|
| 904 |
+
st.session_state.direct_assessment_overall_scores
|
| 905 |
).reset_index()
|
| 906 |
|
| 907 |
overall_scores_df = pd.melt(
|