diff --git a/Docker/.DS_Store b/Docker/.DS_Store index 1347c15..361fe8f 100644 Binary files a/Docker/.DS_Store and b/Docker/.DS_Store differ diff --git a/Docker/protos/brunix.proto b/Docker/protos/brunix.proto index cef3cb2..5192a20 100644 --- a/Docker/protos/brunix.proto +++ b/Docker/protos/brunix.proto @@ -50,9 +50,10 @@ message AgentRequest { } message AgentResponse { - string text = 1; - string avap_code = 2; - bool is_final = 3; + string text = 1; + string avap_code = 2; + bool is_final = 3; + string validation_status = 4; // "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE" } // --------------------------------------------------------------------------- diff --git a/Docker/src/graph.py b/Docker/src/graph.py index 1789277..faeb60a 100644 --- a/Docker/src/graph.py +++ b/Docker/src/graph.py @@ -1,6 +1,9 @@ import logging import os import re as _re +import time +import threading +import requests as _requests from collections import defaultdict from pathlib import Path @@ -20,6 +23,7 @@ except ImportError: from prompts import ( CLASSIFY_PROMPT_TEMPLATE, CODE_GENERATION_PROMPT, + CONFIDENCE_PROMPT_TEMPLATE, CONVERSATIONAL_PROMPT, GENERATE_PROMPT, PLATFORM_PROMPT, @@ -30,6 +34,106 @@ from state import AgentState, ClassifyEntry logger = logging.getLogger(__name__) +# ── AVAP Parser client — ADR-0009 (PTVL) ────────────────────────────────────── + +_PARSER_URL = os.getenv("AVAP_PARSER_URL", "") +_PARSER_TIMEOUT = int(os.getenv("AVAP_PARSER_TIMEOUT", "2")) +_CB_THRESHOLD = int(os.getenv("PARSER_CB_THRESHOLD", "3")) +_CB_COOLDOWN = int(os.getenv("PARSER_CB_COOLDOWN", "30")) + + +class _CircuitBreaker: + CLOSED = "CLOSED" + OPEN = "OPEN" + HALF_OPEN = "HALF_OPEN" + + def __init__(self, threshold: int, cooldown: float): + self._state = self.CLOSED + self._failures = 0 + self._opened_at = 0.0 + self._threshold = threshold + self._cooldown = cooldown + self._lock = threading.Lock() + + def allow(self) -> bool: + with self._lock: + if self._state == self.CLOSED: + return True + if self._state == self.OPEN: + if time.time() - self._opened_at >= self._cooldown: + self._state = self.HALF_OPEN + logger.info("[ptvl] circuit HALF-OPEN — probing parser availability") + return True + return False + # HALF_OPEN: allow the probe + return True + + def success(self): + with self._lock: + if self._state != self.CLOSED: + logger.info("[ptvl] circuit CLOSED — parser reachable") + self._state = self.CLOSED + self._failures = 0 + + def failure(self): + with self._lock: + self._failures += 1 + if self._failures >= self._threshold or self._state == self.HALF_OPEN: + self._state = self.OPEN + self._opened_at = time.time() + logger.warning("[ptvl] circuit OPEN — skipping parser, returning unvalidated") + + +_parser_cb = _CircuitBreaker(_CB_THRESHOLD, _CB_COOLDOWN) + + +def _extract_avap_code(text: str) -> str: + """Return the first AVAP code block found in an LLM response.""" + for pattern in (r'```avap\s*\n(.*?)```', r'```\s*\n(.*?)```', r'```(.*?)```'): + m = _re.search(pattern, text, _re.DOTALL) + if m: + return m.group(1).strip() + return text + + +def _call_parser(text: str) -> tuple: + """Call AVAP Parser REST API. + + Returns: + (True, "") — code valid + (False, trace) — code invalid, trace contains the error + (None, "") — parser unavailable or circuit open + """ + if not _PARSER_URL or _PARSER_TIMEOUT == 0: + return None, "" + + if not _parser_cb.allow(): + return None, "" + + code = _extract_avap_code(text) + if not code.strip(): + return None, "" + + try: + resp = _requests.post( + f"{_PARSER_URL.rstrip('/')}/parse", + json={"code": code}, + timeout=_PARSER_TIMEOUT, + ) + data = resp.json() + if data.get("valid", False): + _parser_cb.success() + return True, "" + _parser_cb.success() # parser responded — it is healthy + return False, data.get("error", "parse error") + except Exception as exc: + _parser_cb.failure() + logger.warning(f"[ptvl] parser call failed: {exc}") + return None, "" + + +# ── Session stores ───────────────────────────────────────────────────────────── + session_store: dict[str, list] = defaultdict(list) classify_history_store: dict[str, list] = defaultdict(list) @@ -520,22 +624,128 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None) _persist(state, resp) return {"messages": [resp]} + # ── PTVL nodes (ADR-0009) ───────────────────────────────────────────────── + + def validate_code(state: AgentState) -> AgentState: + last_msg = state["messages"][-1] + content = getattr(last_msg, "content", str(last_msg)) + valid, trace = _call_parser(content) + if valid is None: + logger.warning("[ptvl] parser unavailable — returning unvalidated") + return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""} + if valid: + logger.info("[ptvl] code VALID on first attempt") + return {"validation_status": "", "parser_trace": ""} + logger.info(f"[ptvl] code INVALID — trace: {str(trace)[:120]}") + return {"validation_status": "INVALID", "parser_trace": str(trace)} + + def generate_code_retry(state: AgentState) -> AgentState: + parser_trace = state.get("parser_trace", "") + use_editor = state.get("use_editor_context", False) + + feedback = ( + "\n\n\n" + "The previous attempt produced invalid AVAP code. Specific failures:\n\n" + f"{parser_trace}\n\n" + "Correct these errors. Do not repeat the same constructs.\n" + "" + ) if parser_trace else "" + + base = _build_generation_prompt( + template_prompt=CODE_GENERATION_PROMPT, + context=state.get("context", ""), + editor_content=state.get("editor_content", "") if use_editor else "", + selected_text=state.get("selected_text", "") if use_editor else "", + extra_context=state.get("extra_context", ""), + ) + retry_prompt = SystemMessage(content=base.content + feedback) + resp = llm.invoke([retry_prompt] + state["messages"]) + logger.info(f"[generate_code_retry] {len(resp.content)} chars") + _persist(state, resp) + return {"messages": [resp]} + + def validate_code_after_retry(state: AgentState) -> AgentState: + last_msg = state["messages"][-1] + content = getattr(last_msg, "content", str(last_msg)) + valid, trace = _call_parser(content) + if valid is None: + logger.warning("[ptvl] parser unavailable after retry") + return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""} + if valid: + logger.info("[ptvl] retry code VALID") + return {"validation_status": "", "parser_trace": ""} + logger.info("[ptvl] retry code still INVALID → INVALID_UNRESOLVED") + return {"validation_status": "INVALID_UNRESOLVED", "parser_trace": str(trace)} + + def check_context_relevance(state: AgentState) -> AgentState: + user_msg = state["messages"][-1] + question = getattr(user_msg, "content", + user_msg.get("content", "") if isinstance(user_msg, dict) else "") + context = state.get("context", "") + prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context) + resp = llm.invoke([SystemMessage(content=prompt)]) + relevant = resp.content.strip().upper().startswith("YES") + logger.info(f"[check_context_relevance] relevant={relevant}") + return {"context_relevant": relevant} + + def reformulate_with_hint(state: AgentState) -> AgentState: + user_msg = state["messages"][-1] + question = getattr(user_msg, "content", + user_msg.get("content", "") if isinstance(user_msg, dict) else "") + hint = ( + f"[CONTEXT_INSUFFICIENT]\n" + f"The previous retrieval did not return relevant context for: \"{question}\"\n" + f"Reformulate this query using broader terms or alternative phrasing." + ) + resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)]) + reformulated = resp.content.strip() + logger.info(f"[reformulate_with_hint] -> '{reformulated}'") + return {"reformulated_query": reformulated} + + def retrieve_retry(state: AgentState) -> AgentState: + query = state["reformulated_query"] + docs = hybrid_search_native( + es_client=es_client, embeddings=embeddings, + query=query, index_name=index_name, k=8, + ) + context = format_context(docs) + logger.info(f"[retrieve_retry] {len(docs)} docs, context len={len(context)}") + return {"context": context} + + # ── Routing ─────────────────────────────────────────────────────────────── + def route_by_type(state): return state.get("query_type", "RETRIEVAL") def route_after_retrieve(state): qt = state.get("query_type", "RETRIEVAL") - return "generate_code" if qt == "CODE_GENERATION" else "generate" + return "generate_code" if qt == "CODE_GENERATION" else "check_context_relevance" + + def route_after_validate(state): + if state.get("validation_status", "") == "INVALID": + return "generate_code_retry" + return END + + def route_after_context_check(state): + if state.get("context_relevant", True): + return "generate" + return "reformulate_with_hint" graph_builder = StateGraph(AgentState) - graph_builder.add_node("classify", classify) - graph_builder.add_node("reformulate", reformulate) - graph_builder.add_node("retrieve", retrieve) - graph_builder.add_node("generate", generate) - graph_builder.add_node("generate_code", generate_code) - graph_builder.add_node("respond_conversational", respond_conversational) - graph_builder.add_node("respond_platform", respond_platform) + graph_builder.add_node("classify", classify) + graph_builder.add_node("reformulate", reformulate) + graph_builder.add_node("retrieve", retrieve) + graph_builder.add_node("generate", generate) + graph_builder.add_node("generate_code", generate_code) + graph_builder.add_node("validate_code", validate_code) + graph_builder.add_node("generate_code_retry", generate_code_retry) + graph_builder.add_node("validate_code_after_retry",validate_code_after_retry) + graph_builder.add_node("check_context_relevance", check_context_relevance) + graph_builder.add_node("reformulate_with_hint", reformulate_with_hint) + graph_builder.add_node("retrieve_retry", retrieve_retry) + graph_builder.add_node("respond_conversational", respond_conversational) + graph_builder.add_node("respond_platform", respond_platform) graph_builder.set_entry_point("classify") @@ -543,10 +753,10 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None) "classify", route_by_type, { - "RETRIEVAL": "reformulate", + "RETRIEVAL": "reformulate", "CODE_GENERATION": "reformulate", - "CONVERSATIONAL": "respond_conversational", - "PLATFORM": "respond_platform", + "CONVERSATIONAL": "respond_conversational", + "PLATFORM": "respond_platform", } ) @@ -556,13 +766,37 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None) "retrieve", route_after_retrieve, { - "generate": "generate", - "generate_code": "generate_code", + "generate_code": "generate_code", + "check_context_relevance": "check_context_relevance", } ) + # CODE_GENERATION path: generate → validate → (retry if invalid) → END + graph_builder.add_edge("generate_code", "validate_code") + graph_builder.add_conditional_edges( + "validate_code", + route_after_validate, + { + "generate_code_retry": "generate_code_retry", + END: END, + } + ) + graph_builder.add_edge("generate_code_retry", "validate_code_after_retry") + graph_builder.add_edge("validate_code_after_retry", END) + + # RETRIEVAL path: context check → generate or reformulate+retry → generate + graph_builder.add_conditional_edges( + "check_context_relevance", + route_after_context_check, + { + "generate": "generate", + "reformulate_with_hint": "reformulate_with_hint", + } + ) + graph_builder.add_edge("reformulate_with_hint", "retrieve_retry") + graph_builder.add_edge("retrieve_retry", "generate") + graph_builder.add_edge("generate", END) - graph_builder.add_edge("generate_code", END) graph_builder.add_edge("respond_conversational", END) graph_builder.add_edge("respond_platform", END) @@ -666,15 +900,64 @@ def build_prepare_graph(llm, embeddings, es_client, index_name): def skip_retrieve(state: AgentState) -> AgentState: return {"context": ""} + def check_context_relevance(state: AgentState) -> AgentState: + user_msg = state["messages"][-1] + question = getattr(user_msg, "content", + user_msg.get("content", "") if isinstance(user_msg, dict) else "") + context = state.get("context", "") + prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context) + resp = llm.invoke([SystemMessage(content=prompt)]) + relevant = resp.content.strip().upper().startswith("YES") + logger.info(f"[prepare/check_context_relevance] relevant={relevant}") + return {"context_relevant": relevant} + + def reformulate_with_hint(state: AgentState) -> AgentState: + user_msg = state["messages"][-1] + question = getattr(user_msg, "content", + user_msg.get("content", "") if isinstance(user_msg, dict) else "") + hint = ( + f"[CONTEXT_INSUFFICIENT]\n" + f"The previous retrieval did not return relevant context for: \"{question}\"\n" + f"Reformulate this query using broader terms or alternative phrasing." + ) + resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)]) + reformulated = resp.content.strip() + logger.info(f"[prepare/reformulate_with_hint] -> '{reformulated}'") + return {"reformulated_query": reformulated} + + def retrieve_retry(state: AgentState) -> AgentState: + query = state["reformulated_query"] + docs = hybrid_search_native( + es_client=es_client, embeddings=embeddings, + query=query, index_name=index_name, k=8, + ) + context = format_context(docs) + logger.info(f"[prepare/retrieve_retry] {len(docs)} docs, context len={len(context)}") + return {"context": context} + def route_by_type(state): return state.get("query_type", "RETRIEVAL") + def route_after_retrieve(state): + qt = state.get("query_type", "RETRIEVAL") + if qt == "RETRIEVAL": + return "check_context_relevance" + return END # CODE_GENERATION goes straight to END + + def route_after_context_check(state): + if state.get("context_relevant", True): + return END + return "reformulate_with_hint" + graph_builder = StateGraph(AgentState) - graph_builder.add_node("classify", classify) - graph_builder.add_node("reformulate", reformulate) - graph_builder.add_node("retrieve", retrieve) - graph_builder.add_node("skip_retrieve", skip_retrieve) + graph_builder.add_node("classify", classify) + graph_builder.add_node("reformulate", reformulate) + graph_builder.add_node("retrieve", retrieve) + graph_builder.add_node("skip_retrieve", skip_retrieve) + graph_builder.add_node("check_context_relevance", check_context_relevance) + graph_builder.add_node("reformulate_with_hint", reformulate_with_hint) + graph_builder.add_node("retrieve_retry", retrieve_retry) graph_builder.set_entry_point("classify") @@ -682,15 +965,35 @@ def build_prepare_graph(llm, embeddings, es_client, index_name): "classify", route_by_type, { - "RETRIEVAL": "reformulate", + "RETRIEVAL": "reformulate", "CODE_GENERATION": "reformulate", - "CONVERSATIONAL": "skip_retrieve", - "PLATFORM": "skip_retrieve", + "CONVERSATIONAL": "skip_retrieve", + "PLATFORM": "skip_retrieve", } ) graph_builder.add_edge("reformulate", "retrieve") - graph_builder.add_edge("retrieve", END) + + graph_builder.add_conditional_edges( + "retrieve", + route_after_retrieve, + { + "check_context_relevance": "check_context_relevance", + END: END, + } + ) + + graph_builder.add_conditional_edges( + "check_context_relevance", + route_after_context_check, + { + END: END, + "reformulate_with_hint": "reformulate_with_hint", + } + ) + graph_builder.add_edge("reformulate_with_hint", "retrieve_retry") + graph_builder.add_edge("retrieve_retry", END) + graph_builder.add_edge("skip_retrieve", END) return graph_builder.compile() diff --git a/Docker/src/server.py b/Docker/src/server.py index 0014edf..3c1a223 100644 --- a/Docker/src/server.py +++ b/Docker/src/server.py @@ -10,11 +10,11 @@ import brunix_pb2_grpc import grpc from grpc_reflection.v1alpha import reflection from elasticsearch import Elasticsearch -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, SystemMessage from utils.llm_factory import create_chat_model from utils.emb_factory import create_embedding_model -from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model +from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model, _call_parser, _extract_avap_code from utils.classifier_export import maybe_export, force_export from evaluate import run_evaluation @@ -141,7 +141,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): "editor_content": editor_content, "selected_text": selected_text, "extra_context": extra_context, - "user_info": user_info + "user_info": user_info, + + # PTVL fields (ADR-0009) + "parser_trace": "", + "validation_status": "", + "context_relevant": True, } final_state = self.graph.invoke(initial_state) @@ -150,17 +155,21 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): result_text = getattr(last_msg, "content", str(last_msg)) \ if last_msg else "" + validation_status = final_state.get("validation_status", "") + if session_id: classify_history_store[session_id] = final_state.get("classify_history", classify_history) maybe_export(classify_history_store) logger.info(f"[AskAgent] query_type={final_state.get('query_type')} " + f"validation_status={validation_status!r} " f"answer='{result_text[:100]}'") yield brunix_pb2.AgentResponse( - text = result_text, - avap_code= "AVAP-2026", - is_final = True, + text = result_text, + avap_code = "AVAP-2026", + is_final = True, + validation_status = validation_status, ) except Exception as e: @@ -220,7 +229,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): "editor_content": editor_content, "selected_text": selected_text, "extra_context": extra_context, - "user_info": user_info + "user_info": user_info, + + # PTVL fields (ADR-0009) + "parser_trace": "", + "validation_status": "", + "context_relevant": True, } prepared = self.prepare_graph.invoke(initial_state) @@ -235,14 +249,108 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): query_type = prepared.get("query_type", "RETRIEVAL") active_llm = self.llm_conversational if query_type in ("CONVERSATIONAL", "PLATFORM") else self.llm - for chunk in active_llm.stream(final_messages): - token = chunk.content - if token: - full_response.append(token) - yield brunix_pb2.AgentResponse( - text = token, - is_final = False, - ) + validation_status = "" + + if query_type == "CODE_GENERATION": + # ── Streaming state machine (ADR-0009) ──────────────────────── + # TEXT mode → yield tokens immediately to the client + # CODE mode → buffer the code block, validate, fix if needed, + # then yield the validated block before continuing + stream_state = "text" + lookahead = "" # pending chars in TEXT mode (fence detection) + code_buffer = "" # accumulated chars in CODE mode + + for chunk in active_llm.stream(final_messages): + token = chunk.content + if not token: + continue + + if stream_state == "text": + lookahead += token + fence_pos = lookahead.find("```") + + if fence_pos != -1: + # Yield text before the fence immediately + text_before = lookahead[:fence_pos] + if text_before: + full_response.append(text_before) + yield brunix_pb2.AgentResponse(text=text_before, is_final=False) + stream_state = "code" + code_buffer = lookahead[fence_pos:] + lookahead = "" + else: + # Keep last 2 chars — a fence might span the next token + safe_len = max(0, len(lookahead) - 2) + if safe_len > 0: + safe = lookahead[:safe_len] + full_response.append(safe) + yield brunix_pb2.AgentResponse(text=safe, is_final=False) + lookahead = lookahead[safe_len:] + + else: # stream_state == "code" + code_buffer += token + # Look for closing ``` only after the first newline + # (to skip the opening fence + optional language tag) + first_nl = code_buffer.find("\n") + if first_nl == -1: + continue + close_pos = code_buffer.find("```", first_nl + 1) + if close_pos == -1: + continue + + # ── Complete code block captured ────────────────────── + complete_block = code_buffer[:close_pos + 3] + rest = code_buffer[close_pos + 3:] + + valid, trace = _call_parser(complete_block) + + if valid is False: + # Ask LLM to fix only the code block + logger.info("[stream/ptvl] INVALID — requesting fix from LLM") + fence_open = "```avap" if complete_block.startswith("```avap") else "```" + fix_resp = self.llm.invoke([SystemMessage(content=( + "Fix this invalid AVAP code. " + "Return ONLY the corrected code block with the same opening " + "and closing fences. No explanation, no extra text.\n\n" + f"\n{trace}\n\n\n" + f"\n{complete_block}\n" + ))]) + fixed_code = _extract_avap_code(fix_resp.content) + fixed_block = f"{fence_open}\n{fixed_code}\n```" + + valid2, _ = _call_parser(fixed_block) + if valid2 is False: + to_yield = fixed_block + validation_status = "INVALID_UNRESOLVED" + else: + to_yield = fixed_block + validation_status = "" if valid2 else "PARSER_UNAVAILABLE" + elif valid is None: + to_yield = complete_block + validation_status = "PARSER_UNAVAILABLE" + else: + to_yield = complete_block # validation_status stays "" + + full_response.append(to_yield) + yield brunix_pb2.AgentResponse(text=to_yield, is_final=False) + + stream_state = "text" + lookahead = rest + code_buffer = "" + + # Flush anything remaining (text after last block, or unclosed fence) + remainder = lookahead + code_buffer + if remainder: + full_response.append(remainder) + yield brunix_pb2.AgentResponse(text=remainder, is_final=False) + + else: + # RETRIEVAL / CONVERSATIONAL / PLATFORM — stream normally + for chunk in active_llm.stream(final_messages): + token = chunk.content + if token: + full_response.append(token) + yield brunix_pb2.AgentResponse(text=token, is_final=False) complete_text = "".join(full_response) if session_id: @@ -254,10 +362,11 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): logger.info( f"[AskAgentStream] done — " - f"chunks={len(full_response)} total_chars={len(complete_text)}" + f"chunks={len(full_response)} total_chars={len(complete_text)} " + f"validation_status={validation_status!r}" ) - yield brunix_pb2.AgentResponse(text="", is_final=True) + yield brunix_pb2.AgentResponse(text="", is_final=True, validation_status=validation_status) except Exception as e: logger.error(f"[AskAgentStream] Error: {e}", exc_info=True) diff --git a/Docker/src/state.py b/Docker/src/state.py index bfd8742..51153e0 100644 --- a/Docker/src/state.py +++ b/Docker/src/state.py @@ -24,4 +24,8 @@ class AgentState(TypedDict): selected_text: str extra_context: str user_info: str - use_editor_context: bool \ No newline at end of file + use_editor_context: bool + # -- PTVL (ADR-0009) + parser_trace: str # raw parser error trace from first validation (empty if valid) + validation_status: str # "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE" + context_relevant: bool # result of CONFIDENCE_PROMPT check (RETRIEVAL only) \ No newline at end of file