Merge branch 'mrh-online-dev' of github.com:BRUNIX-AI/assistance-engine into mrh-online-dev

This commit is contained in:
pseco 2026-03-04 13:58:43 +01:00
commit f15266f345
2 changed files with 711 additions and 491 deletions

View File

@ -1,12 +1,15 @@
# graph.py # graph.py
import logging
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.messages import SystemMessage from langchain_core.messages import SystemMessage
from langgraph.graph import StateGraph, END from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
from prompts import REFORMULATE_PROMPT, GENERATE_PROMPT
from state import AgentState from state import AgentState
logger = logging.getLogger(__name__)
def format_context(docs: list[Document]) -> str: def format_context(docs: list[Document]) -> str:
chunks = [] chunks = []
@ -23,7 +26,7 @@ def build_graph(llm, vector_store) -> CompiledStateGraph:
user_msg = state["messages"][-1] user_msg = state["messages"][-1]
resp = llm.invoke([REFORMULATE_PROMPT, user_msg]) resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
reformulated = resp.content.strip() reformulated = resp.content.strip()
print(f"[reformulate] '{user_msg.content}''{reformulated}'") logger.info(f"[reformulate] '{user_msg.content}''{reformulated}'")
return {"reformulated_query": reformulated} return {"reformulated_query": reformulated}
def retrieve(state: AgentState) -> AgentState: def retrieve(state: AgentState) -> AgentState:
@ -33,8 +36,8 @@ def build_graph(llm, vector_store) -> CompiledStateGraph:
search_kwargs={"k": 3}, search_kwargs={"k": 3},
).invoke(query) ).invoke(query)
context = format_context(docs) context = format_context(docs)
print(f"[retrieve] {len(docs)} docs fetched") logger.info(f"[retrieve] {len(docs)} docs fetched")
print(context) logger.info(context)
return {"context": context} return {"context": context}
def generate(state: AgentState) -> AgentState: def generate(state: AgentState) -> AgentState:
@ -54,4 +57,4 @@ def build_graph(llm, vector_store) -> CompiledStateGraph:
graph_builder.add_edge("retrieve", "generate") graph_builder.add_edge("retrieve", "generate")
graph_builder.add_edge("generate", END) graph_builder.add_edge("generate", END)
return graph_builder.compile() return graph_builder.compile()

File diff suppressed because one or more lines are too long