"""LangGraph agent pipelines for AVAP RAG: query classification, hybrid
Elasticsearch retrieval (BM25 + kNN fused with RRF), and answer generation."""
import logging
|
|
from collections import defaultdict
|
|
from elasticsearch import Elasticsearch
|
|
from langchain_core.documents import Document
|
|
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, BaseMessage
|
|
from langgraph.graph import END, StateGraph
|
|
from langgraph.graph.state import CompiledStateGraph
|
|
|
|
from prompts import (
|
|
CLASSIFY_PROMPT_TEMPLATE,
|
|
CODE_GENERATION_PROMPT,
|
|
CONVERSATIONAL_PROMPT,
|
|
GENERATE_PROMPT,
|
|
REFORMULATE_PROMPT,
|
|
)
|
|
|
|
from state import AgentState
|
|
|
|
# Module-level logger; handlers/levels are configured by the application entry point.
logger = logging.getLogger(__name__)

# In-process chat history keyed by session id; each value is the message list
# for that session (user turns plus model responses appended by _persist).
# NOTE(review): unbounded and process-local — presumably fine for a single
# worker process, but confirm before scaling to multiple workers.
session_store: dict[str, list] = defaultdict(list)
def format_context(docs):
    """Render retrieved documents as numbered, metadata-tagged context blocks.

    Each document becomes a one-line header ("[n] id=... type=... source=...")
    followed by its text; code-like chunks are additionally tagged [AVAP CODE].
    Blocks are joined with blank lines.
    """
    code_doc_types = ("code", "code_example", "bnf")
    code_block_types = ("function", "if", "startLoop", "try")

    rendered = []
    for idx, doc in enumerate(docs, 1):
        meta = doc.metadata or {}
        chunk_id = meta.get("chunk_id", meta.get("id", f"chunk-{idx}"))
        source = meta.get("source_file", meta.get("source", "unknown"))
        doc_type = meta.get("doc_type", "")
        block_type = meta.get("block_type", "")
        section = meta.get("section", "")

        body = (doc.page_content or "").strip()
        if not body:
            # Fall back to metadata fields when the page content is empty.
            body = meta.get("content") or meta.get("text") or ""

        header = [f"[{idx}]", f"id={chunk_id}"]
        for label, value in (("type", doc_type),
                             ("block", block_type),
                             ("section", section)):
            if value:
                header.append(f"{label}={value}")
        header.append(f"source={source}")

        if doc_type in code_doc_types or block_type in code_block_types:
            header.append("[AVAP CODE]")

        rendered.append(" ".join(header) + "\n" + body)

    return "\n\n".join(rendered)
def format_history_for_classify(messages):
    """Format up to the last six turns as "Role: content" lines for the classifier.

    Accepts both LangChain message objects and plain role/content dicts;
    assistant and dict content is truncated to 300 characters. Unrecognized
    message types are skipped. Returns "(no history)" when nothing renders.
    """
    def _render(msg):
        if isinstance(msg, HumanMessage):
            return f"User: {msg.content}"
        if isinstance(msg, AIMessage):
            return f"Assistant: {msg.content[:300]}"
        if isinstance(msg, dict):
            role = msg.get("role", "user")
            return f"{role.capitalize()}: {msg.get('content', '')[:300]}"
        return None

    lines = [line for line in map(_render, messages[-6:]) if line is not None]
    return "\n".join(lines) if lines else "(no history)"
def hybrid_search_native(es_client, embeddings, query, index_name, k=8, rrf_k=60):
    """Hybrid BM25 + kNN retrieval fused with Reciprocal Rank Fusion (RRF).

    Runs a BM25 ``multi_match`` query and — when the query can be embedded —
    an approximate kNN query against the same index, then merges the two
    rankings with RRF: ``score(d) = sum(1 / (rank + rrf_k))`` over rankings.
    Each leg is best-effort: failures are logged and the other leg still runs.

    Args:
        es_client: Elasticsearch client.
        embeddings: model exposing ``embed_query(str) -> vector``.
        query: the (reformulated) user query string.
        index_name: target Elasticsearch index.
        k: number of documents to return (also the per-leg result size).
        rrf_k: RRF smoothing constant (60 is the conventional default).

    Returns:
        list[Document]: top-``k`` fused documents, metadata carrying the
        original ``_source`` fields plus ``id`` and ``rrf_score``.
    """
    # -- Embed the query; the vector leg is skipped if this fails. --
    query_vector = None
    try:
        query_vector = embeddings.embed_query(query)
    except Exception as e:
        logger.warning("[hybrid] embed_query fails: %s", e)

    # -- Lexical leg: BM25 over the textual fields. --
    bm25_hits = []
    try:
        resp = es_client.search(
            index=index_name,
            body={
                "size": k,
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": ["content^2", "text^2"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                    }
                },
                # Embedding vectors are large; never ship them back.
                "_source": {"excludes": ["embedding"]},
            },
        )
        bm25_hits = resp["hits"]["hits"]
        logger.info("[hybrid] BM25 -> %d hits", len(bm25_hits))
    except Exception as e:
        logger.warning("[hybrid] BM25 fails: %s", e)

    # -- Vector leg: approximate kNN on the embedding field. --
    knn_hits = []
    # `is not None` rather than truthiness: the guard means "embedding
    # succeeded", and array-like vectors have ambiguous truth values.
    if query_vector is not None:
        try:
            resp = es_client.search(
                index=index_name,
                body={
                    "size": k,
                    "knn": {
                        "field": "embedding",
                        "query_vector": query_vector,
                        "k": k,
                        "num_candidates": k * 5,
                    },
                    "_source": {"excludes": ["embedding"]},
                },
            )
            knn_hits = resp["hits"]["hits"]
            logger.info("[hybrid] kNN -> %d hits", len(knn_hits))
        except Exception as e:
            logger.warning("[hybrid] kNN fails: %s", e)

    # -- Fuse both rankings with Reciprocal Rank Fusion. --
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}

    for rank, hit in enumerate(bm25_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + rrf_k)
        hit_by_id[doc_id] = hit

    for rank, hit in enumerate(knn_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + rrf_k)
        if doc_id not in hit_by_id:
            hit_by_id[doc_id] = hit

    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]

    # -- Materialise LangChain Documents from the fused hits. --
    docs = []
    for doc_id, score in ranked:
        src = hit_by_id[doc_id]["_source"]
        text = src.get("content") or src.get("text") or ""
        # Comprehension variables renamed so they no longer shadow parameter `k`.
        meta = {field: value for field, value in src.items()
                if field not in ("content", "text", "embedding")}
        meta["id"] = doc_id
        meta["rrf_score"] = score
        docs.append(Document(page_content=text, metadata=meta))

    logger.info("[hybrid] RRF -> %d final docs", len(docs))
    return docs
def _build_classify_prompt(question: str, history_text: str, selected_text: str) -> str:
    """Fill the classification template and optionally splice in the editor selection.

    When the user has code selected in the editor, an <editor_selection> block
    is inserted immediately before the <user_message> tag in the template.
    """
    filled = CLASSIFY_PROMPT_TEMPLATE.replace("{history}", history_text)
    filled = filled.replace("{message}", question)

    if not selected_text:
        return filled

    editor_section = (
        "\n\n<editor_selection>\n"
        "The user currently has the following AVAP code selected in their editor. "
        "If the question refers to 'this', 'here', 'the code above', or similar, "
        "it is about this selection.\n"
        f"{selected_text}\n"
        "</editor_selection>"
    )
    # Assumes the template wraps the message in <user_message> tags —
    # the replace is a no-op if it does not (TODO confirm against prompts.py).
    marker = f"<user_message>{question}</user_message>"
    return filled.replace(marker, f"{editor_section}\n\n{marker}")
def _build_reformulate_query(question: str, selected_text: str) -> str:
|
|
if not selected_text:
|
|
return question
|
|
return f"{selected_text}\n\nUser question about the above: {question}"
|
|
|
|
|
|
def _build_generation_prompt(template_prompt: SystemMessage, context: str,
                             editor_content: str, selected_text: str,
                             extra_context: str) -> SystemMessage:
    """Render the system prompt: RAG context plus optional editor sections.

    Non-empty editor sections — selection first, then the full open file,
    then any extra context — are prepended ahead of the templated base prompt
    so the model reads them before the RAG context.
    """
    base = template_prompt.content.format(context=context)

    # (flag, rendered-section) pairs; only sections whose flag is non-empty
    # make it into the final prompt, in this fixed order.
    candidate_sections = (
        (selected_text,
         "<selected_code>\n"
         "The user has the following AVAP code selected in their editor. "
         "Ground your answer in this code first. "
         "Use the RAG context as supplementary reference only.\n"
         f"{selected_text}\n"
         "</selected_code>"),
        (editor_content,
         "<editor_file>\n"
         "Full content of the active file open in the editor "
         "(use for broader context if needed):\n"
         f"{editor_content}\n"
         "</editor_file>"),
        (extra_context,
         "<extra_context>\n"
         f"{extra_context}\n"
         "</extra_context>"),
    )
    sections = [text for flag, text in candidate_sections if flag]

    if sections:
        base = "\n\n".join(sections) + "\n\n" + base

    return SystemMessage(content=base)
def _parse_query_type(raw: str) -> tuple[str, bool]:
|
|
parts = raw.strip().upper().split()
|
|
query_type = "RETRIEVAL"
|
|
use_editor = False
|
|
if parts:
|
|
first = parts[0]
|
|
if first.startswith("CODE_GENERATION") or "CODE" in first:
|
|
query_type = "CODE_GENERATION"
|
|
elif first.startswith("CONVERSATIONAL"):
|
|
query_type = "CONVERSATIONAL"
|
|
if len(parts) > 1 and parts[1] == "EDITOR":
|
|
use_editor = True
|
|
return query_type, use_editor
|
|
|
|
def build_graph(llm, embeddings, es_client, index_name) -> CompiledStateGraph:
    """Compile the full agent graph.

    Flow: classify -> (reformulate -> retrieve -> generate | generate_code)
    for retrieval/code questions, or classify -> respond_conversational for
    small talk. Every terminal node persists the conversation to
    ``session_store``.

    Args:
        llm: chat model used for classification, reformulation and generation.
        embeddings: embedding model for the kNN retrieval leg.
        es_client: Elasticsearch client.
        index_name: index holding the AVAP documentation chunks.

    Returns:
        The compiled LangGraph state graph.
    """

    def _persist(state: AgentState, response: BaseMessage):
        # Mirror the updated conversation into the in-process session store.
        session_id = state.get("session_id", "")
        if session_id:
            session_store[session_id] = list(state["messages"]) + [response]

    def _latest_question(messages) -> str:
        # The last message may be a LangChain message or a plain dict.
        user_msg = messages[-1]
        if isinstance(user_msg, dict):
            return user_msg.get("content", "")
        return getattr(user_msg, "content", "")

    def classify(state):
        # One LLM call decides query_type and whether editor context is needed.
        messages = state["messages"]
        question = _latest_question(messages)
        history_msgs = messages[:-1]
        selected_text = state.get("selected_text", "")

        history_text = format_history_for_classify(history_msgs) if history_msgs else "(no history)"
        prompt_content = _build_classify_prompt(question, history_text, selected_text)

        resp = llm.invoke([SystemMessage(content=prompt_content)])
        raw = resp.content.strip().upper()
        query_type, use_editor_ctx = _parse_query_type(raw)
        logger.info(f"[classify] selected={bool(selected_text)} raw='{raw}' -> {query_type} editor={use_editor_ctx}")
        return {"query_type": query_type, "use_editor_context": use_editor_ctx}

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the user turn into a retrieval-friendly standalone query.
        selected_text = state.get("selected_text", "")
        question = _latest_question(state["messages"])

        anchor = _build_reformulate_query(question, selected_text)

        if selected_text:
            # HumanMessage is imported at module level; the previous
            # function-local re-import was redundant.
            resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=anchor)])
        else:
            query_type = state.get("query_type", "RETRIEVAL")
            mode_hint = HumanMessage(content=f"[MODE: {query_type}]\n{question}")
            resp = llm.invoke([REFORMULATE_PROMPT, mode_hint])

        reformulated = resp.content.strip()
        logger.info(f"[reformulate] selected={bool(selected_text)} -> '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        # Hybrid search on the reformulated query, then format the context.
        query = state["reformulated_query"]
        docs = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=query,
            index_name=index_name,
            k=8,
        )
        context = format_context(docs)
        logger.info(f"[retrieve] {len(docs)} docs, context len={len(context)}")
        return {"context": context}

    def _generate_with(template_prompt: SystemMessage, tag: str, state):
        # Shared body for generate / generate_code (previously duplicated):
        # build the system prompt — with editor context only when the
        # classifier requested it — invoke the LLM, persist, return.
        use_editor = state.get("use_editor_context", False)
        prompt = _build_generation_prompt(
            template_prompt=template_prompt,
            context=state.get("context", ""),
            editor_content=state.get("editor_content", "") if use_editor else "",
            selected_text=state.get("selected_text", "") if use_editor else "",
            extra_context=state.get("extra_context", ""),
        )
        resp = llm.invoke([prompt] + state["messages"])
        logger.info(f"[{tag}] {len(resp.content)} chars")
        _persist(state, resp)
        return {"messages": [resp]}

    def generate(state):
        # Textual answer grounded in the retrieved context.
        return _generate_with(GENERATE_PROMPT, "generate", state)

    def generate_code(state):
        # AVAP code answer grounded in the retrieved context.
        return _generate_with(CODE_GENERATION_PROMPT, "generate_code", state)

    def respond_conversational(state):
        # Small talk: no retrieval, just the conversational system prompt.
        resp = llm.invoke([CONVERSATIONAL_PROMPT] + state["messages"])
        logger.info("[conversational] from conversation")
        _persist(state, resp)
        return {"messages": [resp]}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    def route_after_retrieve(state):
        qt = state.get("query_type", "RETRIEVAL")
        return "generate_code" if qt == "CODE_GENERATION" else "generate"

    graph_builder = StateGraph(AgentState)

    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)
    graph_builder.add_node("generate_code", generate_code)
    graph_builder.add_node("respond_conversational", respond_conversational)

    graph_builder.set_entry_point("classify")

    graph_builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "respond_conversational",
        },
    )

    graph_builder.add_edge("reformulate", "retrieve")

    graph_builder.add_conditional_edges(
        "retrieve",
        route_after_retrieve,
        {
            "generate": "generate",
            "generate_code": "generate_code",
        },
    )

    graph_builder.add_edge("generate", END)
    graph_builder.add_edge("generate_code", END)
    graph_builder.add_edge("respond_conversational", END)

    return graph_builder.compile()
def build_prepare_graph(llm, embeddings, es_client, index_name) -> CompiledStateGraph:
    """Compile the retrieval-only "prepare" graph.

    Flow: classify -> (reformulate -> retrieve) for retrieval/code questions,
    or classify -> skip_retrieve for conversational turns. Unlike
    ``build_graph`` it stops after context retrieval, so a caller can run the
    generation step itself (e.g. to stream tokens) via ``build_final_messages``.

    Args:
        llm: chat model used for classification and reformulation.
        embeddings: embedding model for the kNN retrieval leg.
        es_client: Elasticsearch client.
        index_name: index holding the AVAP documentation chunks.

    Returns:
        The compiled LangGraph state graph.
    """

    def _latest_question(messages) -> str:
        # The last message may be a LangChain message or a plain dict.
        user_msg = messages[-1]
        if isinstance(user_msg, dict):
            return user_msg.get("content", "")
        return getattr(user_msg, "content", "")

    def classify(state):
        # One LLM call decides query_type and whether editor context is needed.
        messages = state["messages"]
        question = _latest_question(messages)
        history_msgs = messages[:-1]
        selected_text = state.get("selected_text", "")

        history_text = format_history_for_classify(history_msgs) if history_msgs else "(no history)"
        prompt_content = _build_classify_prompt(question, history_text, selected_text)

        resp = llm.invoke([SystemMessage(content=prompt_content)])
        raw = resp.content.strip().upper()
        query_type, use_editor_ctx = _parse_query_type(raw)
        logger.info(f"[prepare/classify] selected={bool(selected_text)} raw='{raw}' -> {query_type} editor={use_editor_ctx}")
        return {"query_type": query_type, "use_editor_context": use_editor_ctx}

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the user turn into a retrieval-friendly standalone query.
        selected_text = state.get("selected_text", "")
        question = _latest_question(state["messages"])

        anchor = _build_reformulate_query(question, selected_text)

        if selected_text:
            # HumanMessage is imported at module level; the previous
            # function-local re-import was redundant.
            resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=anchor)])
        else:
            query_type = state.get("query_type", "RETRIEVAL")
            mode_hint = HumanMessage(content=f"[MODE: {query_type}]\n{question}")
            resp = llm.invoke([REFORMULATE_PROMPT, mode_hint])

        reformulated = resp.content.strip()
        logger.info(f"[prepare/reformulate] selected={bool(selected_text)} -> '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        # Hybrid search on the reformulated query, then format the context.
        query = state["reformulated_query"]
        docs = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=query,
            index_name=index_name,
            k=8,
        )
        context = format_context(docs)
        logger.info(f"[prepare/retrieve] {len(docs)} docs, context len={len(context)}")
        return {"context": context}

    def skip_retrieve(state: AgentState) -> AgentState:
        # Conversational turns need no RAG context.
        return {"context": ""}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    graph_builder = StateGraph(AgentState)

    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("skip_retrieve", skip_retrieve)

    graph_builder.set_entry_point("classify")

    graph_builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "skip_retrieve",
        },
    )

    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", END)
    graph_builder.add_edge("skip_retrieve", END)

    return graph_builder.compile()
def build_final_messages(state: AgentState) -> list:
    """Assemble the final LLM message list from a prepared agent state.

    Intended to be used after ``build_prepare_graph``: conversational states
    get only the conversational system prompt; retrieval and code-generation
    states get the matching generation prompt built from the retrieved context
    (plus editor context when the classifier requested it).

    Args:
        state: agent state produced by the prepare graph.

    Returns:
        list: ``[system_prompt] + conversation messages``, ready for the LLM.
    """
    query_type = state.get("query_type", "RETRIEVAL")
    messages = state.get("messages", [])

    if query_type == "CONVERSATIONAL":
        return [CONVERSATIONAL_PROMPT] + messages

    use_editor = state.get("use_editor_context", False)
    # The two non-conversational branches previously duplicated the whole
    # call and differed only in the template; select it once instead.
    template = (CODE_GENERATION_PROMPT if query_type == "CODE_GENERATION"
                else GENERATE_PROMPT)
    prompt = _build_generation_prompt(
        template_prompt=template,
        context=state.get("context", ""),
        editor_content=state.get("editor_content", "") if use_editor else "",
        selected_text=state.get("selected_text", "") if use_editor else "",
        extra_context=state.get("extra_context", ""),
    )
    return [prompt] + messages