File size: 5,765 Bytes
f7a42c7
 
 
ec946c6
 
 
f7a42c7
ec946c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f7a42c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec946c6
f7a42c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ec946c6
 
f7a42c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fde6ed4
 
f7a42c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import os
from IPython.display import Image, display
from langchain_core.messages import SystemMessage
from langchain_openai import AzureChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import  tools_condition
from .state import State
from .custom_tool_node import CustomToolNode
from .tools import get_avaiable_tools


class CustomAgent:
    """LangGraph-backed agent: answers one question for a given task id."""

    # Length of the answer prefix the system prompt makes the model emit
    # (presumably "FINAL ANSWER: ", 14 chars — TODO confirm against
    # system_prompt.txt). The raw model output is sliced past it.
    _ANSWER_PREFIX_LEN = 14

    def __init__(self):
        print("CustomAgent initialized.")
        # Compile the state graph once; it is reused for every question.
        self.graph = build_graph()

    def __call__(self, question: str, task_id: str) -> str:
        """Run the graph on *question* and return the final answer text.

        The system prompt from ``system_prompt.txt`` is prepended to the
        conversation; the fixed-length answer prefix is stripped from the
        last message's content before returning.
        """
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        system_prompt = SystemMessage(content=get_prompt())
        result = self.graph.invoke({
            "messages": [
                system_prompt,
                {"role": "user", "content": question},
            ],
            "task_id": task_id,
        })
        answer = result["messages"][-1].content
        return answer[self._ANSWER_PREFIX_LEN:]


def get_prompt() -> str:
    """Load and return the system prompt from ``system_prompt.txt`` (UTF-8, CWD-relative)."""
    with open("system_prompt.txt", encoding="utf-8") as prompt_file:
        return prompt_file.read()


def build_graph():
    """Build and compile the ReAct-style state graph.

    Topology: START -> assistant -> (tools -> assistant)* -> END.
    Routing out of "assistant" is handled solely by ``tools_condition``,
    which goes to "tools" when the last message carries tool calls and
    to END otherwise.

    Returns:
        A compiled LangGraph runnable over ``State``.
    """
    # LLM backed by Azure OpenAI; all configuration comes from env vars.
    llm = AzureChatOpenAI(
        azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
        openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
        deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
        openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
        temperature=0.0,  # deterministic output for reproducible answers
    )
    available_tools = get_avaiable_tools()
    llm_with_tools = llm.bind_tools(available_tools)

    def assistant(state: State):
        """Assistant node: one LLM turn over the accumulated messages."""
        response = llm_with_tools.invoke(state["messages"])
        if response.content == '':
            # Empty content means the model requested a tool call; keep the
            # full AIMessage so tools_condition can inspect its tool_calls.
            messages = [response]
        else:
            # Final textual answer — append just the text.
            messages = [response.content]
        return {"messages": messages,
                "task_id": state["task_id"]
                }

    graph_builder = StateGraph(State)
    graph_builder.add_node("assistant", assistant)
    tools_dict = {tool.name: tool for tool in available_tools}
    graph_builder.add_node("tools", CustomToolNode(tools_dict))

    graph_builder.add_edge(START, "assistant")
    # NOTE: no unconditional assistant -> END edge here. tools_condition
    # already routes to END when the assistant made no tool call; the
    # original extra add_edge("assistant", END) fired in parallel with the
    # conditional route and could end the run while a tool call was pending.
    graph_builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    graph_builder.add_edge("tools", "assistant")
    return graph_builder.compile()

if __name__ == "__main__":
    # Smoke-test the graph directly, without going through CustomAgent.
    react_graph = build_graph()

    # Example task; swap in other questions/task ids to test manually.
    question = """What is the final numeric output from the attached Python code?"""
    task_id = "f918266a-b3e0-4914-865d-4faa564f1aef"

    system_prompt = SystemMessage(content=get_prompt())
    result = react_graph.invoke({
        "messages": [
            system_prompt,
            {"role": "user", "content": question}
        ],
        "task_id": task_id
    })

    # Dump the full conversation for inspection.
    for m in result["messages"]:
        m.pretty_print()

    # Strip the 14-char answer prefix (presumably "FINAL ANSWER: " from
    # system_prompt.txt — TODO confirm), mirroring CustomAgent.__call__.
    answer = result['messages'][-1].content
    print(f"Final Answer: {answer[14:]}")