# assistance-engine/Docker/src/graph.py
# graph.py
import logging
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
from state import AgentState
logger = logging.getLogger(__name__)
def format_context(docs: list[Document]) -> str:
    """Render retrieved documents as a numbered, source-tagged context string.

    Each document becomes a ``[i] id=<id> source=<source>`` header line
    followed by its page content; entries are separated by blank lines.
    Missing metadata falls back to "Untitled" / a positional chunk id,
    and a missing page content is treated as the empty string.
    """

    def render(index: int, doc: Document) -> str:
        # Metadata may legitimately be None on some loaders; normalize to {}.
        meta = doc.metadata or {}
        label = meta.get("source", "Untitled")
        chunk_id = meta.get("id", f"chunk-{index}")
        body = doc.page_content or ""
        return f"[{index}] id={chunk_id} source={label}\n{body}"

    # 1-based numbering so the citations read naturally ([1], [2], ...).
    return "\n\n".join(render(i, doc) for i, doc in enumerate(docs, 1))
def build_graph(llm, vector_store) -> CompiledStateGraph:
    """Compile the three-step RAG pipeline: reformulate -> retrieve -> generate.

    Args:
        llm: A chat model exposing ``invoke(messages)`` (LangChain interface).
        vector_store: A vector store exposing ``as_retriever(...)``.

    Returns:
        The compiled LangGraph state graph, ready to be invoked with an
        ``AgentState`` containing at least a ``messages`` list.
    """
    # Build the retriever once: its configuration is invariant, so there is
    # no reason to reconstruct it on every graph invocation.
    retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )

    def reformulate(state: AgentState) -> AgentState:
        """Rewrite the latest user message into a standalone search query."""
        user_msg = state["messages"][-1]
        resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
        reformulated = resp.content.strip()
        # Lazy %-args: no string formatting work when INFO is disabled.
        # (The original f-string had lost its separator: '{a}''{b}'.)
        logger.info("[reformulate] %r -> %r", user_msg.content, reformulated)
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        """Fetch the top-k similar chunks and format them as LLM context."""
        query = state["reformulated_query"]
        docs = retriever.invoke(query)
        context = format_context(docs)
        logger.info("[retrieve] %d docs fetched", len(docs))
        logger.info("%s", context)
        return {"context": context}

    def generate(state: AgentState) -> AgentState:
        """Answer the conversation using the retrieved context as a system prompt."""
        prompt = SystemMessage(
            content=GENERATE_PROMPT.content.format(context=state["context"])
        )
        resp = llm.invoke([prompt] + state["messages"])
        return {"messages": [resp]}

    # Linear pipeline: reformulate -> retrieve -> generate -> END.
    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)
    graph_builder.set_entry_point("reformulate")
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", "generate")
    graph_builder.add_edge("generate", END)
    return graph_builder.compile()