Merge branch 'mrh-online-dev' of github.com:BRUNIX-AI/assistance-engine into mrh-online-dev
This commit is contained in:
commit
f15266f345
|
|
@ -1,12 +1,15 @@
|
||||||
# graph.py
|
# graph.py
|
||||||
|
import logging
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.messages import SystemMessage
|
from langchain_core.messages import SystemMessage
|
||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import END, StateGraph
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
|
||||||
from prompts import REFORMULATE_PROMPT, GENERATE_PROMPT
|
|
||||||
from state import AgentState
|
from state import AgentState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def format_context(docs: list[Document]) -> str:
|
def format_context(docs: list[Document]) -> str:
|
||||||
chunks = []
|
chunks = []
|
||||||
|
|
@ -23,7 +26,7 @@ def build_graph(llm, vector_store) -> CompiledStateGraph:
|
||||||
user_msg = state["messages"][-1]
|
user_msg = state["messages"][-1]
|
||||||
resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
|
resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
|
||||||
reformulated = resp.content.strip()
|
reformulated = resp.content.strip()
|
||||||
print(f"[reformulate] '{user_msg.content}' → '{reformulated}'")
|
logger.info(f"[reformulate] '{user_msg.content}' → '{reformulated}'")
|
||||||
return {"reformulated_query": reformulated}
|
return {"reformulated_query": reformulated}
|
||||||
|
|
||||||
def retrieve(state: AgentState) -> AgentState:
|
def retrieve(state: AgentState) -> AgentState:
|
||||||
|
|
@ -33,8 +36,8 @@ def build_graph(llm, vector_store) -> CompiledStateGraph:
|
||||||
search_kwargs={"k": 3},
|
search_kwargs={"k": 3},
|
||||||
).invoke(query)
|
).invoke(query)
|
||||||
context = format_context(docs)
|
context = format_context(docs)
|
||||||
print(f"[retrieve] {len(docs)} docs fetched")
|
logger.info(f"[retrieve] {len(docs)} docs fetched")
|
||||||
print(context)
|
logger.info(context)
|
||||||
return {"context": context}
|
return {"context": context}
|
||||||
|
|
||||||
def generate(state: AgentState) -> AgentState:
|
def generate(state: AgentState) -> AgentState:
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue