"""LangGraph agent graphs for an AVAP-focused RAG assistant.

Built on a shared hybrid (BM25 + kNN, RRF-fused) Elasticsearch retriever:

* ``build_graph``          -- full classify -> reformulate -> retrieve -> generate graph.
* ``build_prepare_graph``  -- retrieval-only graph (classify/reformulate/retrieve).
* ``build_final_messages`` -- assembles the final LLM message list from a prepared state.
"""

import logging
from collections import defaultdict

from elasticsearch import Elasticsearch
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, BaseMessage
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph

from prompts import (
    CLASSIFY_PROMPT_TEMPLATE,
    CODE_GENERATION_PROMPT,
    CONVERSATIONAL_PROMPT,
    GENERATE_PROMPT,
    REFORMULATE_PROMPT,
)
from state import AgentState

logger = logging.getLogger(__name__)

# In-memory, per-process conversation store keyed by session id.
# NOTE(review): grows without bound across sessions -- fine for a single-process
# dev server; revisit (TTL / external store) before multi-tenant deployment.
session_store: dict[str, list] = defaultdict(list)


def format_context(docs: list[Document]) -> str:
    """Render retrieved documents as a numbered, metadata-tagged context string.

    Each chunk gets a one-line header (``[i] id=... type=... source=...``) so
    the LLM can reference chunks; chunks whose metadata marks them as
    code/BNF are additionally tagged ``[AVAP CODE]``.

    Args:
        docs: Documents returned by the retriever (metadata may be ``None``).

    Returns:
        All chunks joined by blank lines; empty string for an empty list.
    """
    chunks = []
    for i, doc in enumerate(docs, 1):
        meta = doc.metadata or {}
        chunk_id = meta.get("chunk_id", meta.get("id", f"chunk-{i}"))
        source = meta.get("source_file", meta.get("source", "unknown"))
        doc_type = meta.get("doc_type", "")
        block_type = meta.get("block_type", "")
        section = meta.get("section", "")
        text = (doc.page_content or "").strip()
        if not text:
            # Some index mappings store the body under "content"/"text" instead
            # of the Document's page_content.
            text = meta.get("content") or meta.get("text") or ""
        header_parts = [f"[{i}]", f"id={chunk_id}"]
        if doc_type:
            header_parts.append(f"type={doc_type}")
        if block_type:
            header_parts.append(f"block={block_type}")
        if section:
            header_parts.append(f"section={section}")
        header_parts.append(f"source={source}")
        if doc_type in ("code", "code_example", "bnf") or \
                block_type in ("function", "if", "startLoop", "try"):
            header_parts.append("[AVAP CODE]")
        chunks.append(" ".join(header_parts) + "\n" + text)
    return "\n\n".join(chunks)


def format_history_for_classify(messages) -> str:
    """Format the last six history messages as ``Role: content`` lines.

    Assistant content (and dict-message content) is truncated to 300 chars to
    keep the classification prompt small. Accepts both LangChain message
    objects and plain ``{"role": ..., "content": ...}`` dicts.
    """
    lines = []
    for msg in messages[-6:]:
        if isinstance(msg, HumanMessage):
            lines.append(f"User: {msg.content}")
        elif isinstance(msg, AIMessage):
            lines.append(f"Assistant: {msg.content[:300]}")
        elif isinstance(msg, dict):
            role = msg.get("role", "user")
            content = msg.get("content", "")[:300]
            lines.append(f"{role.capitalize()}: {content}")
    return "\n".join(lines) if lines else "(no history)"


def _message_content(msg) -> str:
    """Return the text content of a LangChain message object or a plain dict."""
    if isinstance(msg, dict):
        return msg.get("content", "")
    return getattr(msg, "content", "")


def hybrid_search_native(es_client: Elasticsearch, embeddings, query: str,
                         index_name: str, k: int = 8,
                         rrf_k: int = 60) -> list[Document]:
    """Hybrid retrieval: BM25 + vector kNN, fused with Reciprocal Rank Fusion.

    Both searches are best-effort: a failure in either leg is logged and the
    other leg's results are still used.

    Args:
        es_client: Elasticsearch client.
        embeddings: Object with ``embed_query(str) -> list[float]``.
        query: Natural-language search query.
        index_name: Target index; documents carry body in ``content``/``text``
            and a dense vector in ``embedding``.
        k: Number of hits per leg and of fused results returned.
        rrf_k: RRF rank constant (score += 1 / (rank + rrf_k)); 60 is the
            conventional default from the RRF paper.

    Returns:
        Up to ``k`` Documents ranked by fused score, with ``id`` and
        ``rrf_score`` added to each document's metadata.
    """
    query_vector = None
    try:
        query_vector = embeddings.embed_query(query)
    except Exception as e:
        logger.warning(f"[hybrid] embed_query fails: {e}")

    # Leg 1: lexical BM25 over the text fields.
    bm25_hits = []
    try:
        resp = es_client.search(
            index=index_name,
            body={
                "size": k,
                "query": {
                    "multi_match": {
                        "query": query,
                        "fields": ["content^2", "text^2"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                    }
                },
                "_source": {"excludes": ["embedding"]},
            }
        )
        bm25_hits = resp["hits"]["hits"]
        logger.info(f"[hybrid] BM25 -> {len(bm25_hits)} hits")
    except Exception as e:
        logger.warning(f"[hybrid] BM25 fails: {e}")

    # Leg 2: dense kNN, only if embedding the query succeeded.
    knn_hits = []
    if query_vector:
        try:
            resp = es_client.search(
                index=index_name,
                body={
                    "size": k,
                    "knn": {
                        "field": "embedding",
                        "query_vector": query_vector,
                        "k": k,
                        "num_candidates": k * 5,
                    },
                    "_source": {"excludes": ["embedding"]},
                }
            )
            knn_hits = resp["hits"]["hits"]
            logger.info(f"[hybrid] kNN -> {len(knn_hits)} hits")
        except Exception as e:
            logger.warning(f"[hybrid] kNN fails: {e}")

    # Reciprocal Rank Fusion across both result lists.
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}
    for rank, hit in enumerate(bm25_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + rrf_k)
        hit_by_id[doc_id] = hit
    for rank, hit in enumerate(knn_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + rrf_k)
        if doc_id not in hit_by_id:
            hit_by_id[doc_id] = hit

    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]
    docs = []
    for doc_id, score in ranked:
        src = hit_by_id[doc_id]["_source"]
        text = src.get("content") or src.get("text") or ""
        # "key" (not "k") to avoid shadowing the k parameter above.
        meta = {key: value for key, value in src.items()
                if key not in ("content", "text", "embedding")}
        meta["id"] = doc_id
        meta["rrf_score"] = score
        docs.append(Document(page_content=text, metadata=meta))
    logger.info(f"[hybrid] RRF -> {len(docs)} final docs")
    return docs


def _build_classify_prompt(question: str, history_text: str, selected_text: str) -> str:
    """Fill the classification template with history/question, injecting any editor selection.

    When code is selected in the editor, a note plus the selection is spliced
    in directly before the question so the classifier can resolve deictic
    references ("this", "here", ...).
    """
    prompt = (
        CLASSIFY_PROMPT_TEMPLATE
        .replace("{history}", history_text)
        .replace("{message}", question)
    )
    if selected_text:
        editor_section = (
            "\n\n\n"
            "The user currently has the following AVAP code selected in their editor. "
            "If the question refers to 'this', 'here', 'the code above', or similar, "
            "it is about this selection.\n"
            f"{selected_text}\n"
            ""
        )
        # NOTE(review): str.replace substitutes every occurrence of the
        # question text in the filled template; fragile if the question string
        # also appears in the history -- confirm template shape.
        prompt = prompt.replace(
            f"{question}",
            f"{editor_section}\n\n{question}"
        )
    return prompt


def _build_reformulate_query(question: str, selected_text: str) -> str:
    """Anchor the reformulation query on the editor selection when one exists."""
    if not selected_text:
        return question
    return f"{selected_text}\n\nUser question about the above: {question}"


def _build_generation_prompt(template_prompt: SystemMessage, context: str,
                             editor_content: str, selected_text: str,
                             extra_context: str) -> SystemMessage:
    """Build the system prompt for generation from template + editor state.

    Args:
        template_prompt: System prompt whose content contains a ``{context}``
            placeholder.
        context: RAG context string (output of :func:`format_context`).
        editor_content: Full text of the active editor file ("" to omit).
        selected_text: Currently selected editor code ("" to omit).
        extra_context: Caller-supplied extra context ("" to omit).

    Returns:
        A SystemMessage with the editor sections (if any) prepended to the
        filled template.
    """
    base = template_prompt.content.format(context=context)
    sections = []
    if selected_text:
        sections.append(
            "\n"
            "The user has the following AVAP code selected in their editor. "
            "Ground your answer in this code first. "
            "Use the RAG context as supplementary reference only.\n"
            f"{selected_text}\n"
            ""
        )
    if editor_content:
        sections.append(
            "\n"
            "Full content of the active file open in the editor "
            "(use for broader context if needed):\n"
            f"{editor_content}\n"
            ""
        )
    if extra_context:
        sections.append(
            "\n"
            f"{extra_context}\n"
            ""
        )
    if sections:
        editor_block = "\n\n".join(sections)
        base = editor_block + "\n\n" + base
    return SystemMessage(content=base)


def _parse_query_type(raw: str) -> tuple[str, bool]:
    """Parse the classifier's raw output into ``(query_type, use_editor)``.

    The first token selects the type (default ``RETRIEVAL``); a second token
    ``EDITOR`` enables editor context. Unrecognized input falls back to
    ``("RETRIEVAL", False)``.
    """
    parts = raw.strip().upper().split()
    query_type = "RETRIEVAL"
    use_editor = False
    if parts:
        first = parts[0]
        # "CODE" in first already covers "CODE_GENERATION..." prefixes.
        if "CODE" in first:
            query_type = "CODE_GENERATION"
        elif first.startswith("CONVERSATIONAL"):
            query_type = "CONVERSATIONAL"
        if len(parts) > 1 and parts[1] == "EDITOR":
            use_editor = True
    return query_type, use_editor


def _persist(state: AgentState, response: BaseMessage) -> None:
    """Snapshot the conversation (including *response*) into session_store."""
    session_id = state.get("session_id", "")
    if session_id:
        session_store[session_id] = list(state["messages"]) + [response]


def _run_classify(llm, state: AgentState, log_tag: str) -> dict:
    """Shared classify node body: label the latest user message.

    Returns state updates ``query_type`` and ``use_editor_context``.
    """
    messages = state["messages"]
    question = _message_content(messages[-1])
    history_msgs = messages[:-1]
    selected_text = state.get("selected_text", "")
    history_text = format_history_for_classify(history_msgs) if history_msgs else "(no history)"
    prompt_content = _build_classify_prompt(question, history_text, selected_text)
    resp = llm.invoke([SystemMessage(content=prompt_content)])
    raw = resp.content.strip().upper()
    query_type, use_editor_ctx = _parse_query_type(raw)
    logger.info(f"[{log_tag}] selected={bool(selected_text)} raw='{raw}' -> {query_type} editor={use_editor_ctx}")
    return {"query_type": query_type, "use_editor_context": use_editor_ctx}


def _run_reformulate(llm, state: AgentState, log_tag: str) -> dict:
    """Shared reformulate node body: rewrite the question into a search query.

    With an editor selection, the selection anchors the query; otherwise the
    classified mode is passed as a hint. Returns ``reformulated_query``.
    """
    user_msg = state["messages"][-1]
    selected_text = state.get("selected_text", "")
    question = _message_content(user_msg)
    anchor = _build_reformulate_query(question, selected_text)
    if selected_text:
        resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=anchor)])
    else:
        query_type = state.get("query_type", "RETRIEVAL")
        mode_hint = HumanMessage(content=f"[MODE: {query_type}]\n{question}")
        resp = llm.invoke([REFORMULATE_PROMPT, mode_hint])
    reformulated = resp.content.strip()
    logger.info(f"[{log_tag}] selected={bool(selected_text)} -> '{reformulated}'")
    return {"reformulated_query": reformulated}


def _run_retrieve(es_client, embeddings, index_name, state: AgentState, log_tag: str) -> dict:
    """Shared retrieve node body: hybrid search + context formatting."""
    query = state["reformulated_query"]
    docs = hybrid_search_native(
        es_client=es_client,
        embeddings=embeddings,
        query=query,
        index_name=index_name,
        k=8,
    )
    context = format_context(docs)
    logger.info(f"[{log_tag}] {len(docs)} docs, context len={len(context)}")
    return {"context": context}


def _route_by_type(state: AgentState) -> str:
    """Conditional-edge router keyed on the classified query type."""
    return state.get("query_type", "RETRIEVAL")


def build_graph(llm, embeddings, es_client: Elasticsearch, index_name: str) -> CompiledStateGraph:
    """Compile the full agent graph.

    Flow: classify -> (reformulate -> retrieve -> generate|generate_code)
    or classify -> respond_conversational. Every terminal node persists the
    conversation into :data:`session_store`.
    """

    def _make_generate_node(template_prompt, log_tag):
        """Factory for the two generation nodes (answer vs. code)."""
        def node(state):
            use_editor = state.get("use_editor_context", False)
            prompt = _build_generation_prompt(
                template_prompt=template_prompt,
                context=state.get("context", ""),
                editor_content=state.get("editor_content", "") if use_editor else "",
                selected_text=state.get("selected_text", "") if use_editor else "",
                extra_context=state.get("extra_context", ""),
            )
            resp = llm.invoke([prompt] + state["messages"])
            logger.info(f"[{log_tag}] {len(resp.content)} chars")
            _persist(state, resp)
            return {"messages": [resp]}
        return node

    def classify(state):
        return _run_classify(llm, state, "classify")

    def reformulate(state: AgentState) -> AgentState:
        return _run_reformulate(llm, state, "reformulate")

    def retrieve(state: AgentState) -> AgentState:
        return _run_retrieve(es_client, embeddings, index_name, state, "retrieve")

    def respond_conversational(state):
        resp = llm.invoke([CONVERSATIONAL_PROMPT] + state["messages"])
        logger.info("[conversational] from conversation")
        _persist(state, resp)
        return {"messages": [resp]}

    def route_after_retrieve(state):
        qt = state.get("query_type", "RETRIEVAL")
        return "generate_code" if qt == "CODE_GENERATION" else "generate"

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", _make_generate_node(GENERATE_PROMPT, "generate"))
    graph_builder.add_node("generate_code", _make_generate_node(CODE_GENERATION_PROMPT, "generate_code"))
    graph_builder.add_node("respond_conversational", respond_conversational)

    graph_builder.set_entry_point("classify")
    graph_builder.add_conditional_edges(
        "classify",
        _route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "respond_conversational",
        }
    )
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_conditional_edges(
        "retrieve",
        route_after_retrieve,
        {
            "generate": "generate",
            "generate_code": "generate_code",
        }
    )
    graph_builder.add_edge("generate", END)
    graph_builder.add_edge("generate_code", END)
    graph_builder.add_edge("respond_conversational", END)
    return graph_builder.compile()


def build_prepare_graph(llm, embeddings, es_client: Elasticsearch, index_name: str) -> CompiledStateGraph:
    """Compile the retrieval-only graph (no generation, no persistence).

    Conversational queries skip retrieval and yield an empty context; the
    caller finishes the turn with :func:`build_final_messages`.
    """

    def classify(state):
        return _run_classify(llm, state, "prepare/classify")

    def reformulate(state: AgentState) -> AgentState:
        return _run_reformulate(llm, state, "prepare/reformulate")

    def retrieve(state: AgentState) -> AgentState:
        return _run_retrieve(es_client, embeddings, index_name, state, "prepare/retrieve")

    def skip_retrieve(state: AgentState) -> AgentState:
        # Conversational turns need no RAG context.
        return {"context": ""}

    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("skip_retrieve", skip_retrieve)

    graph_builder.set_entry_point("classify")
    graph_builder.add_conditional_edges(
        "classify",
        _route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "skip_retrieve",
        }
    )
    graph_builder.add_edge("reformulate", "retrieve")
    graph_builder.add_edge("retrieve", END)
    graph_builder.add_edge("skip_retrieve", END)
    return graph_builder.compile()


def build_final_messages(state: AgentState) -> list:
    """Assemble the final LLM message list from a prepared state.

    Conversational turns get the conversational prompt; retrieval and
    code-generation turns get the matching generation prompt built from the
    retrieved context and (if enabled) editor state.
    """
    query_type = state.get("query_type", "RETRIEVAL")
    messages = state.get("messages", [])

    if query_type == "CONVERSATIONAL":
        return [CONVERSATIONAL_PROMPT] + messages

    use_editor = state.get("use_editor_context", False)
    # The two non-conversational branches differ only in the template.
    template = CODE_GENERATION_PROMPT if query_type == "CODE_GENERATION" else GENERATE_PROMPT
    prompt = _build_generation_prompt(
        template_prompt=template,
        context=state.get("context", ""),
        editor_content=state.get("editor_content", "") if use_editor else "",
        selected_text=state.get("selected_text", "") if use_editor else "",
        extra_context=state.get("extra_context", ""),
    )
    return [prompt] + messages