import logging
from collections import defaultdict
from elasticsearch import Elasticsearch
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, BaseMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from prompts import (
CLASSIFY_PROMPT_TEMPLATE,
CODE_GENERATION_PROMPT,
CONVERSATIONAL_PROMPT,
GENERATE_PROMPT,
REFORMULATE_PROMPT,
)
from state import AgentState
logger = logging.getLogger(__name__)
# In-memory per-session transcript store keyed by session_id.
# NOTE(review): process-local and unbounded — confirm this is acceptable
# for the target deployment (no eviction, lost on restart).
session_store: dict[str, list] = defaultdict(list)
def format_context(docs):
    """Render retrieved documents as one numbered, header-prefixed context string.

    Each document becomes ``[n] id=... type=... block=... section=... source=...``
    (empty fields omitted) followed by its text; code-like chunks are tagged
    ``[AVAP CODE]``. Documents are joined with blank lines.
    """
    rendered = []
    for idx, doc in enumerate(docs, 1):
        meta = doc.metadata or {}
        body = (doc.page_content or "").strip()
        if not body:
            # Some indexes keep the text only in metadata.
            body = meta.get("content") or meta.get("text") or ""

        doc_type = meta.get("doc_type", "")
        block_type = meta.get("block_type", "")
        section = meta.get("section", "")

        header = [
            f"[{idx}]",
            f"id={meta.get('chunk_id', meta.get('id', f'chunk-{idx}'))}",
        ]
        for label, value in (("type", doc_type),
                             ("block", block_type),
                             ("section", section)):
            if value:
                header.append(f"{label}={value}")
        header.append(f"source={meta.get('source_file', meta.get('source', 'unknown'))}")

        looks_like_code = (doc_type in ("code", "code_example", "bnf")
                           or block_type in ("function", "if", "startLoop", "try"))
        if looks_like_code:
            header.append("[AVAP CODE]")

        rendered.append(" ".join(header) + "\n" + body)
    return "\n\n".join(rendered)
def format_history_for_classify(messages):
    """Summarize the last six history turns as plain 'Role: text' lines.

    Assistant and dict contents are truncated to 300 chars; unknown message
    types are skipped. Returns "(no history)" when nothing is rendered.
    """
    rendered = []
    for msg in messages[-6:]:
        if isinstance(msg, HumanMessage):
            rendered.append(f"User: {msg.content}")
        elif isinstance(msg, AIMessage):
            rendered.append(f"Assistant: {msg.content[:300]}")
        elif isinstance(msg, dict):
            role = msg.get("role", "user").capitalize()
            rendered.append(f"{role}: {msg.get('content', '')[:300]}")
    if not rendered:
        return "(no history)"
    return "\n".join(rendered)
def hybrid_search_native(es_client, embeddings, query, index_name, k=8):
    """Hybrid BM25 + kNN retrieval fused with Reciprocal Rank Fusion (RRF).

    Runs a lexical ``multi_match`` query and, when embedding succeeds, a kNN
    query against ``index_name``; both rankings are fused with RRF and the
    top-k results are returned as LangChain Documents. Each leg is
    best-effort: a failure is logged and the other leg still contributes.

    Args:
        es_client: Elasticsearch client.
        embeddings: model exposing ``embed_query(str) -> vector``.
        query: search query string.
        index_name: target Elasticsearch index.
        k: number of fused documents to return.

    Returns:
        list[Document]: fused results; metadata carries "id" and "rrf_score".
    """
    query_vector = None
    try:
        query_vector = embeddings.embed_query(query)
    except Exception as e:
        logger.warning(f"[hybrid] embed_query fails: {e}")
    bm25_hits = []
    try:
        resp = es_client.search(
            index=index_name,
            body={
                "size": k,
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": ["content^2", "text^2"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                    }
                },
                "_source": {"excludes": ["embedding"]},
            }
        )
        bm25_hits = resp["hits"]["hits"]
        logger.info(f"[hybrid] BM25 -> {len(bm25_hits)} hits")
    except Exception as e:
        logger.warning(f"[hybrid] BM25 fails: {e}")
    knn_hits = []
    # `is not None` instead of truthiness: embed_query may return a numpy
    # array, whose truth value is ambiguous and would raise here.
    if query_vector is not None:
        try:
            resp = es_client.search(
                index=index_name,
                body={
                    "size": k,
                    "knn": {
                        "field": "embedding",
                        "query_vector": query_vector,
                        "k": k,
                        "num_candidates": k * 5,
                    },
                    "_source": {"excludes": ["embedding"]},
                }
            )
            knn_hits = resp["hits"]["hits"]
            logger.info(f"[hybrid] kNN -> {len(knn_hits)} hits")
        except Exception as e:
            logger.warning(f"[hybrid] kNN fails: {e}")
    # RRF: each leg contributes 1/(rank + 60); 60 is the conventional
    # damping constant. BM25 hits win ties for the representative hit.
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}
    for rank, hit in enumerate(bm25_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + 60)
        hit_by_id[doc_id] = hit
    for rank, hit in enumerate(knn_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + 60)
        hit_by_id.setdefault(doc_id, hit)
    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]
    docs = []
    for doc_id, score in ranked:
        src = hit_by_id[doc_id]["_source"]
        text = src.get("content") or src.get("text") or ""
        # `field`, not `k`: the original shadowed the k parameter here.
        meta = {field: value for field, value in src.items()
                if field not in ("content", "text", "embedding")}
        meta["id"] = doc_id
        meta["rrf_score"] = score
        docs.append(Document(page_content=text, metadata=meta))
    logger.info(f"[hybrid] RRF -> {len(docs)} final docs")
    return docs
def _build_classify_prompt(question: str, history_text: str, selected_text: str) -> str:
    """Build the classifier prompt, injecting the editor selection if present.

    The editor section is prepended to the user message *before* template
    substitution. The previous implementation substituted ``{message}`` first
    and then string-replaced the question text, which rewrote every occurrence
    of the question (including inside the history) and, for an empty question,
    inserted the editor section between every character of the prompt.

    Args:
        question: the latest user message.
        history_text: pre-formatted conversation history.
        selected_text: AVAP code currently selected in the editor, if any.

    Returns:
        The fully substituted classification prompt.
    """
    message = question
    if selected_text:
        editor_section = (
            "\n\n\n"
            "The user currently has the following AVAP code selected in their editor. "
            "If the question refers to 'this', 'here', 'the code above', or similar, "
            "it is about this selection.\n"
            f"{selected_text}\n"
            ""
        )
        message = f"{editor_section}\n\n{question}"
    return (
        CLASSIFY_PROMPT_TEMPLATE
        .replace("{history}", history_text)
        .replace("{message}", message)
    )
def _build_reformulate_query(question: str, selected_text: str) -> str:
if not selected_text:
return question
return f"{selected_text}\n\nUser question about the above: {question}"
def _build_generation_prompt(template_prompt: SystemMessage, context: str,
editor_content: str, selected_text: str,
extra_context: str) -> SystemMessage:
base = template_prompt.content.format(context=context)
sections = []
if selected_text:
sections.append(
"\n"
"The user has the following AVAP code selected in their editor. "
"Ground your answer in this code first. "
"Use the RAG context as supplementary reference only.\n"
f"{selected_text}\n"
""
)
if editor_content:
sections.append(
"\n"
"Full content of the active file open in the editor "
"(use for broader context if needed):\n"
f"{editor_content}\n"
""
)
if extra_context:
sections.append(
"\n"
f"{extra_context}\n"
""
)
if sections:
editor_block = "\n\n".join(sections)
base = editor_block + "\n\n" + base
return SystemMessage(content=base)
def _parse_query_type(raw: str) -> tuple[str, bool]:
parts = raw.strip().upper().split()
query_type = "RETRIEVAL"
use_editor = False
if parts:
first = parts[0]
if first.startswith("CODE_GENERATION") or "CODE" in first:
query_type = "CODE_GENERATION"
elif first.startswith("CONVERSATIONAL"):
query_type = "CONVERSATIONAL"
if len(parts) > 1 and parts[1] == "EDITOR":
use_editor = True
return query_type, use_editor
def build_graph(llm, embeddings, es_client, index_name):
    """Compile the full agent graph.

    Flow: classify -> (reformulate -> retrieve -> generate | generate_code)
    or classify -> respond_conversational. Final responses are mirrored into
    the module-level ``session_store``.

    Args:
        llm: chat model used for classification, reformulation and generation.
        embeddings: embedding model for the kNN leg of hybrid search.
        es_client: Elasticsearch client.
        index_name: index queried by hybrid search.

    Returns:
        The compiled LangGraph state graph.
    """

    def _question_of(msg) -> str:
        # Accept both BaseMessage objects and raw {"role", "content"} dicts.
        if isinstance(msg, dict):
            return msg.get("content", "")
        return getattr(msg, "content", "")

    def _persist(state: AgentState, response: BaseMessage):
        # Mirror the final transcript into the in-memory session store.
        session_id = state.get("session_id", "")
        if session_id:
            session_store[session_id] = list(state["messages"]) + [response]

    def classify(state):
        # Decide the query type (and whether editor context is relevant).
        messages = state["messages"]
        question = _question_of(messages[-1])
        history_msgs = messages[:-1]
        selected_text = state.get("selected_text", "")
        history_text = format_history_for_classify(history_msgs) if history_msgs else "(no history)"
        prompt_content = _build_classify_prompt(question, history_text, selected_text)
        resp = llm.invoke([SystemMessage(content=prompt_content)])
        raw = resp.content.strip().upper()
        query_type, use_editor_ctx = _parse_query_type(raw)
        logger.info(f"[classify] selected={bool(selected_text)} raw='{raw}' -> {query_type} editor={use_editor_ctx}")
        return {"query_type": query_type, "use_editor_context": use_editor_ctx}

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the question into a standalone retrieval query.
        selected_text = state.get("selected_text", "")
        question = _question_of(state["messages"][-1])
        if selected_text:
            # Anchor the rewrite to the selection so pronouns resolve to it.
            anchor = _build_reformulate_query(question, selected_text)
            resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=anchor)])
        else:
            query_type = state.get("query_type", "RETRIEVAL")
            mode_hint = HumanMessage(content=f"[MODE: {query_type}]\n{question}")
            resp = llm.invoke([REFORMULATE_PROMPT, mode_hint])
        reformulated = resp.content.strip()
        logger.info(f"[reformulate] selected={bool(selected_text)} -> '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        # Hybrid search with the reformulated query; store formatted context.
        query = state["reformulated_query"]
        docs = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=query,
            index_name=index_name,
            k=8,
        )
        context = format_context(docs)
        logger.info(f"[retrieve] {len(docs)} docs, context len={len(context)}")
        return {"context": context}

    def _generate_with(state, template, tag):
        # Shared body of generate/generate_code; only the template differs.
        use_editor = state.get("use_editor_context", False)
        prompt = _build_generation_prompt(
            template_prompt=template,
            context=state.get("context", ""),
            editor_content=state.get("editor_content", "") if use_editor else "",
            selected_text=state.get("selected_text", "") if use_editor else "",
            extra_context=state.get("extra_context", ""),
        )
        resp = llm.invoke([prompt] + state["messages"])
        logger.info(f"[{tag}] {len(resp.content)} chars")
        _persist(state, resp)
        return {"messages": [resp]}

    def generate(state):
        return _generate_with(state, GENERATE_PROMPT, "generate")

    def generate_code(state):
        return _generate_with(state, CODE_GENERATION_PROMPT, "generate_code")

    def respond_conversational(state):
        resp = llm.invoke([CONVERSATIONAL_PROMPT] + state["messages"])
        logger.info("[conversational] from conversation")
        _persist(state, resp)
        return {"messages": [resp]}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    def route_after_retrieve(state):
        qt = state.get("query_type", "RETRIEVAL")
        return "generate_code" if qt == "CODE_GENERATION" else "generate"

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)
    graph_builder.add_node("generate_code", generate_code)
    graph_builder.add_node("respond_conversational", respond_conversational)
    graph_builder.set_entry_point("classify")
    graph_builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "respond_conversational",
        }
    )
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_conditional_edges(
        "retrieve",
        route_after_retrieve,
        {
            "generate": "generate",
            "generate_code": "generate_code",
        }
    )
    graph_builder.add_edge("generate", END)
    graph_builder.add_edge("generate_code", END)
    graph_builder.add_edge("respond_conversational", END)
    return graph_builder.compile()
def build_prepare_graph(llm, embeddings, es_client, index_name):
    """Compile the retrieval-only "prepare" graph.

    Flow: classify -> (reformulate -> retrieve) or classify -> skip_retrieve.
    No answer is generated here; the graph ends once context has (or has not)
    been retrieved, leaving generation to the caller (see build_final_messages).

    Args:
        llm: chat model used for classification and reformulation.
        embeddings: embedding model for the kNN leg of hybrid search.
        es_client: Elasticsearch client.
        index_name: index queried by hybrid search.

    Returns:
        The compiled LangGraph state graph.
    """

    def _question_of(msg) -> str:
        # Accept both BaseMessage objects and raw {"role", "content"} dicts.
        if isinstance(msg, dict):
            return msg.get("content", "")
        return getattr(msg, "content", "")

    def classify(state):
        # Decide the query type (and whether editor context is relevant).
        messages = state["messages"]
        question = _question_of(messages[-1])
        history_msgs = messages[:-1]
        selected_text = state.get("selected_text", "")
        history_text = format_history_for_classify(history_msgs) if history_msgs else "(no history)"
        prompt_content = _build_classify_prompt(question, history_text, selected_text)
        resp = llm.invoke([SystemMessage(content=prompt_content)])
        raw = resp.content.strip().upper()
        query_type, use_editor_ctx = _parse_query_type(raw)
        logger.info(f"[prepare/classify] selected={bool(selected_text)} raw='{raw}' -> {query_type} editor={use_editor_ctx}")
        return {"query_type": query_type, "use_editor_context": use_editor_ctx}

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the question into a standalone retrieval query.
        selected_text = state.get("selected_text", "")
        question = _question_of(state["messages"][-1])
        if selected_text:
            # Anchor the rewrite to the selection so pronouns resolve to it.
            anchor = _build_reformulate_query(question, selected_text)
            resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=anchor)])
        else:
            query_type = state.get("query_type", "RETRIEVAL")
            mode_hint = HumanMessage(content=f"[MODE: {query_type}]\n{question}")
            resp = llm.invoke([REFORMULATE_PROMPT, mode_hint])
        reformulated = resp.content.strip()
        logger.info(f"[prepare/reformulate] selected={bool(selected_text)} -> '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        # Hybrid search with the reformulated query; store formatted context.
        query = state["reformulated_query"]
        docs = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=query,
            index_name=index_name,
            k=8,
        )
        context = format_context(docs)
        logger.info(f"[prepare/retrieve] {len(docs)} docs, context len={len(context)}")
        return {"context": context}

    def skip_retrieve(state: AgentState) -> AgentState:
        # Conversational turns need no RAG context.
        return {"context": ""}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("skip_retrieve", skip_retrieve)
    graph_builder.set_entry_point("classify")
    graph_builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "skip_retrieve",
        }
    )
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", END)
    graph_builder.add_edge("skip_retrieve", END)
    return graph_builder.compile()
def build_final_messages(state: AgentState) -> list:
    """Build the final LLM message list (system prompt + history) from prepared state.

    Conversational turns get the plain conversational prompt; otherwise a
    generation prompt is assembled from the retrieved context and, when the
    classifier requested it, the editor selection/content.
    """
    messages = state.get("messages", [])
    query_type = state.get("query_type", "RETRIEVAL")
    if query_type == "CONVERSATIONAL":
        return [CONVERSATIONAL_PROMPT] + messages

    use_editor = state.get("use_editor_context", False)
    template = (CODE_GENERATION_PROMPT
                if query_type == "CODE_GENERATION"
                else GENERATE_PROMPT)
    prompt = _build_generation_prompt(
        template_prompt=template,
        context=state.get("context", ""),
        editor_content=state.get("editor_content", "") if use_editor else "",
        selected_text=state.get("selected_text", "") if use_editor else "",
        extra_context=state.get("extra_context", ""),
    )
    return [prompt] + messages