# graph.py
import logging

from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
from state import AgentState

logger = logging.getLogger(__name__)


def format_context(docs: list[Document]) -> str:
    """Render retrieved documents as a numbered, source-attributed block."""
    chunks = []
    for i, doc in enumerate(docs, 1):
        source = (doc.metadata or {}).get("source", "Untitled")
        source_id = (doc.metadata or {}).get("id", f"chunk-{i}")
        text = doc.page_content or ""
        chunks.append(f"[{i}] id={source_id} source={source}\n{text}")
    return "\n\n".join(chunks)


def build_graph(llm, vector_store) -> CompiledStateGraph:
    """Compile the linear RAG pipeline: reformulate -> retrieve -> generate."""

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the latest user message into a standalone search query.
        user_msg = state["messages"][-1]
        resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
        reformulated = resp.content.strip()
        logger.info(f"[reformulate] '{user_msg.content}' → '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        # Fetch the top-3 most similar chunks for the reformulated query.
        query = state["reformulated_query"]
        docs = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 3},
        ).invoke(query)
        context = format_context(docs)
        logger.info(f"[retrieve] {len(docs)} docs fetched")
        logger.info(context)
        return {"context": context}

    def generate(state: AgentState) -> AgentState:
        # Answer from the retrieved context plus the full message history.
        prompt = SystemMessage(
            content=GENERATE_PROMPT.content.format(context=state["context"])
        )
        resp = llm.invoke([prompt] + state["messages"])
        return {"messages": [resp]}

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)
    graph_builder.set_entry_point("reformulate")
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", "generate")
    graph_builder.add_edge("generate", END)
    return graph_builder.compile()
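

if __name__ == "__main__":
    # Usage sketch, not part of the original module: a minimal way to drive
    # the compiled graph. ChatOpenAI, OpenAIEmbeddings, and the seeded
    # InMemoryVectorStore below are illustrative assumptions; any LangChain
    # chat model and any VectorStore exposing as_retriever() would work.
    from langchain_core.messages import HumanMessage
    from langchain_core.vectorstores import InMemoryVectorStore
    from langchain_openai import ChatOpenAI, OpenAIEmbeddings

    # Seed a throwaway in-memory store with one document so retrieval
    # has something to return.
    store = InMemoryVectorStore(embedding=OpenAIEmbeddings())
    store.add_documents([
        Document(
            page_content="LangGraph compiles a set of nodes into a runnable graph.",
            metadata={"source": "notes.md", "id": "chunk-demo"},
        )
    ])

    graph = build_graph(llm=ChatOpenAI(model="gpt-4o-mini"), vector_store=store)
    result = graph.invoke(
        {"messages": [HumanMessage(content="What does LangGraph do?")]}
    )
    print(result["messages"][-1].content)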