# graph.py
|
|
import logging
|
|
|
|
from langchain_core.documents import Document
|
|
from langchain_core.messages import SystemMessage
|
|
from langgraph.graph import END, StateGraph
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
|
|
from state import AgentState
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def format_context(docs: list[Document]) -> str:
    """Render documents as a numbered context string.

    Each document becomes a two-part entry: a header line of the form
    ``[i] id=<id> source=<source>`` (``i`` counts from 1) followed by the
    document text on the next line. Entries are separated by a blank line.

    Fallbacks, applied when a value is missing *or* ``None``:
      * ``metadata`` -> empty mapping
      * ``source``   -> ``"Untitled"``
      * ``id``       -> ``"chunk-<i>"``
      * ``page_content`` (or any falsy value) -> ``""``

    Args:
        docs: Documents to render, in the order they should be numbered.

    Returns:
        The joined context string; ``""`` when ``docs`` is empty.
    """

    def _render(index: int, doc: Document) -> str:
        """Format one document as its header line plus text."""
        meta = doc.metadata or {}

        # `.get(key, default)` only covers absent keys; a key stored with an
        # explicit None would otherwise be rendered as the text "None".
        source = meta.get("source")
        if source is None:
            source = "Untitled"

        # Compare against None (not truthiness) so ids such as 0 are kept.
        source_id = meta.get("id")
        if source_id is None:
            source_id = f"chunk-{index}"

        text = doc.page_content or ""
        return f"[{index}] id={source_id} source={source}\n{text}"

    return "\n\n".join(_render(i, doc) for i, doc in enumerate(docs, 1))
|
|
|
|
|
|
def build_graph(llm, vector_store) -> CompiledStateGraph:
    """Build and compile the reformulate -> retrieve -> generate graph.

    The graph has three nodes wired in a straight line:
    ``reformulate`` -> ``retrieve`` -> ``generate`` -> ``END``.

    Args:
        llm: Model object; only ``llm.invoke(messages)`` is called and the
            ``.content`` of its result is read.
        vector_store: Object exposing ``as_retriever(...)``; queried with a
            similarity search returning ``k=3`` documents.

    Returns:
        The compiled state graph.
    """
    # The retriever arguments are constants, so build it once here instead of
    # on every `retrieve` call.
    retriever = vector_store.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 3},
    )

    def reformulate(state: AgentState) -> AgentState:
        """Ask the LLM to rewrite the latest message as a search query.

        Returns a dict holding only the ``reformulated_query`` key.
        """
        user_msg = state["messages"][-1]
        resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
        reformulated = resp.content.strip()
        # Lazy %-style args: the message is only formatted if it is emitted.
        logger.info("[reformulate] '%s' → '%s'", user_msg.content, reformulated)
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        """Fetch documents for the reformulated query and format them.

        Returns a dict holding only the ``context`` key.
        """
        query = state["reformulated_query"]
        docs = retriever.invoke(query)
        context = format_context(docs)
        logger.info("[retrieve] %d docs fetched", len(docs))
        logger.info(context)
        return {"context": context}

    def generate(state: AgentState) -> AgentState:
        """Answer using the conversation plus the retrieved context.

        The context is substituted into ``GENERATE_PROMPT`` and sent as a
        system message ahead of the existing messages. Returns a dict
        holding only the ``messages`` key, with the LLM response as its
        single element.
        """
        prompt = SystemMessage(
            content=GENERATE_PROMPT.content.format(context=state["context"])
        )
        resp = llm.invoke([prompt] + state["messages"])
        return {"messages": [resp]}

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)

    graph_builder.set_entry_point("reformulate")
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", "generate")
    graph_builder.add_edge("generate", END)

    return graph_builder.compile()
|