ADR0008 finished
This commit is contained in:
parent
2c99f6ea40
commit
0b9c19d61f
Binary file not shown.
|
|
@ -50,9 +50,10 @@ message AgentRequest {
|
||||||
}
|
}
|
||||||
|
|
||||||
message AgentResponse {
|
message AgentResponse {
|
||||||
string text = 1;
|
string text = 1;
|
||||||
string avap_code = 2;
|
string avap_code = 2;
|
||||||
bool is_final = 3;
|
bool is_final = 3;
|
||||||
|
string validation_status = 4; // "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re as _re
|
import re as _re
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import requests as _requests
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -20,6 +23,7 @@ except ImportError:
|
||||||
from prompts import (
|
from prompts import (
|
||||||
CLASSIFY_PROMPT_TEMPLATE,
|
CLASSIFY_PROMPT_TEMPLATE,
|
||||||
CODE_GENERATION_PROMPT,
|
CODE_GENERATION_PROMPT,
|
||||||
|
CONFIDENCE_PROMPT_TEMPLATE,
|
||||||
CONVERSATIONAL_PROMPT,
|
CONVERSATIONAL_PROMPT,
|
||||||
GENERATE_PROMPT,
|
GENERATE_PROMPT,
|
||||||
PLATFORM_PROMPT,
|
PLATFORM_PROMPT,
|
||||||
|
|
@ -30,6 +34,106 @@ from state import AgentState, ClassifyEntry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── AVAP Parser client — ADR-0009 (PTVL) ──────────────────────────────────────
|
||||||
|
|
||||||
|
_PARSER_URL = os.getenv("AVAP_PARSER_URL", "")
|
||||||
|
_PARSER_TIMEOUT = int(os.getenv("AVAP_PARSER_TIMEOUT", "2"))
|
||||||
|
_CB_THRESHOLD = int(os.getenv("PARSER_CB_THRESHOLD", "3"))
|
||||||
|
_CB_COOLDOWN = int(os.getenv("PARSER_CB_COOLDOWN", "30"))
|
||||||
|
|
||||||
|
|
||||||
|
class _CircuitBreaker:
|
||||||
|
CLOSED = "CLOSED"
|
||||||
|
OPEN = "OPEN"
|
||||||
|
HALF_OPEN = "HALF_OPEN"
|
||||||
|
|
||||||
|
def __init__(self, threshold: int, cooldown: float):
|
||||||
|
self._state = self.CLOSED
|
||||||
|
self._failures = 0
|
||||||
|
self._opened_at = 0.0
|
||||||
|
self._threshold = threshold
|
||||||
|
self._cooldown = cooldown
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def allow(self) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
if self._state == self.CLOSED:
|
||||||
|
return True
|
||||||
|
if self._state == self.OPEN:
|
||||||
|
if time.time() - self._opened_at >= self._cooldown:
|
||||||
|
self._state = self.HALF_OPEN
|
||||||
|
logger.info("[ptvl] circuit HALF-OPEN — probing parser availability")
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
# HALF_OPEN: allow the probe
|
||||||
|
return True
|
||||||
|
|
||||||
|
def success(self):
|
||||||
|
with self._lock:
|
||||||
|
if self._state != self.CLOSED:
|
||||||
|
logger.info("[ptvl] circuit CLOSED — parser reachable")
|
||||||
|
self._state = self.CLOSED
|
||||||
|
self._failures = 0
|
||||||
|
|
||||||
|
def failure(self):
|
||||||
|
with self._lock:
|
||||||
|
self._failures += 1
|
||||||
|
if self._failures >= self._threshold or self._state == self.HALF_OPEN:
|
||||||
|
self._state = self.OPEN
|
||||||
|
self._opened_at = time.time()
|
||||||
|
logger.warning("[ptvl] circuit OPEN — skipping parser, returning unvalidated")
|
||||||
|
|
||||||
|
|
||||||
|
_parser_cb = _CircuitBreaker(_CB_THRESHOLD, _CB_COOLDOWN)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_avap_code(text: str) -> str:
|
||||||
|
"""Return the first AVAP code block found in an LLM response."""
|
||||||
|
for pattern in (r'```avap\s*\n(.*?)```', r'```\s*\n(.*?)```', r'```(.*?)```'):
|
||||||
|
m = _re.search(pattern, text, _re.DOTALL)
|
||||||
|
if m:
|
||||||
|
return m.group(1).strip()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def _call_parser(text: str) -> tuple:
|
||||||
|
"""Call AVAP Parser REST API.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(True, "") — code valid
|
||||||
|
(False, trace) — code invalid, trace contains the error
|
||||||
|
(None, "") — parser unavailable or circuit open
|
||||||
|
"""
|
||||||
|
if not _PARSER_URL or _PARSER_TIMEOUT == 0:
|
||||||
|
return None, ""
|
||||||
|
|
||||||
|
if not _parser_cb.allow():
|
||||||
|
return None, ""
|
||||||
|
|
||||||
|
code = _extract_avap_code(text)
|
||||||
|
if not code.strip():
|
||||||
|
return None, ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = _requests.post(
|
||||||
|
f"{_PARSER_URL.rstrip('/')}/parse",
|
||||||
|
json={"code": code},
|
||||||
|
timeout=_PARSER_TIMEOUT,
|
||||||
|
)
|
||||||
|
data = resp.json()
|
||||||
|
if data.get("valid", False):
|
||||||
|
_parser_cb.success()
|
||||||
|
return True, ""
|
||||||
|
_parser_cb.success() # parser responded — it is healthy
|
||||||
|
return False, data.get("error", "parse error")
|
||||||
|
except Exception as exc:
|
||||||
|
_parser_cb.failure()
|
||||||
|
logger.warning(f"[ptvl] parser call failed: {exc}")
|
||||||
|
return None, ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Session stores ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
session_store: dict[str, list] = defaultdict(list)
|
session_store: dict[str, list] = defaultdict(list)
|
||||||
classify_history_store: dict[str, list] = defaultdict(list)
|
classify_history_store: dict[str, list] = defaultdict(list)
|
||||||
|
|
||||||
|
|
@ -520,22 +624,128 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
||||||
_persist(state, resp)
|
_persist(state, resp)
|
||||||
return {"messages": [resp]}
|
return {"messages": [resp]}
|
||||||
|
|
||||||
|
# ── PTVL nodes (ADR-0009) ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def validate_code(state: AgentState) -> AgentState:
|
||||||
|
last_msg = state["messages"][-1]
|
||||||
|
content = getattr(last_msg, "content", str(last_msg))
|
||||||
|
valid, trace = _call_parser(content)
|
||||||
|
if valid is None:
|
||||||
|
logger.warning("[ptvl] parser unavailable — returning unvalidated")
|
||||||
|
return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""}
|
||||||
|
if valid:
|
||||||
|
logger.info("[ptvl] code VALID on first attempt")
|
||||||
|
return {"validation_status": "", "parser_trace": ""}
|
||||||
|
logger.info(f"[ptvl] code INVALID — trace: {str(trace)[:120]}")
|
||||||
|
return {"validation_status": "INVALID", "parser_trace": str(trace)}
|
||||||
|
|
||||||
|
def generate_code_retry(state: AgentState) -> AgentState:
|
||||||
|
parser_trace = state.get("parser_trace", "")
|
||||||
|
use_editor = state.get("use_editor_context", False)
|
||||||
|
|
||||||
|
feedback = (
|
||||||
|
"\n\n<parser_feedback>\n"
|
||||||
|
"The previous attempt produced invalid AVAP code. Specific failures:\n\n"
|
||||||
|
f"{parser_trace}\n\n"
|
||||||
|
"Correct these errors. Do not repeat the same constructs.\n"
|
||||||
|
"</parser_feedback>"
|
||||||
|
) if parser_trace else ""
|
||||||
|
|
||||||
|
base = _build_generation_prompt(
|
||||||
|
template_prompt=CODE_GENERATION_PROMPT,
|
||||||
|
context=state.get("context", ""),
|
||||||
|
editor_content=state.get("editor_content", "") if use_editor else "",
|
||||||
|
selected_text=state.get("selected_text", "") if use_editor else "",
|
||||||
|
extra_context=state.get("extra_context", ""),
|
||||||
|
)
|
||||||
|
retry_prompt = SystemMessage(content=base.content + feedback)
|
||||||
|
resp = llm.invoke([retry_prompt] + state["messages"])
|
||||||
|
logger.info(f"[generate_code_retry] {len(resp.content)} chars")
|
||||||
|
_persist(state, resp)
|
||||||
|
return {"messages": [resp]}
|
||||||
|
|
||||||
|
def validate_code_after_retry(state: AgentState) -> AgentState:
|
||||||
|
last_msg = state["messages"][-1]
|
||||||
|
content = getattr(last_msg, "content", str(last_msg))
|
||||||
|
valid, trace = _call_parser(content)
|
||||||
|
if valid is None:
|
||||||
|
logger.warning("[ptvl] parser unavailable after retry")
|
||||||
|
return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""}
|
||||||
|
if valid:
|
||||||
|
logger.info("[ptvl] retry code VALID")
|
||||||
|
return {"validation_status": "", "parser_trace": ""}
|
||||||
|
logger.info("[ptvl] retry code still INVALID → INVALID_UNRESOLVED")
|
||||||
|
return {"validation_status": "INVALID_UNRESOLVED", "parser_trace": str(trace)}
|
||||||
|
|
||||||
|
def check_context_relevance(state: AgentState) -> AgentState:
|
||||||
|
user_msg = state["messages"][-1]
|
||||||
|
question = getattr(user_msg, "content",
|
||||||
|
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||||
|
context = state.get("context", "")
|
||||||
|
prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context)
|
||||||
|
resp = llm.invoke([SystemMessage(content=prompt)])
|
||||||
|
relevant = resp.content.strip().upper().startswith("YES")
|
||||||
|
logger.info(f"[check_context_relevance] relevant={relevant}")
|
||||||
|
return {"context_relevant": relevant}
|
||||||
|
|
||||||
|
def reformulate_with_hint(state: AgentState) -> AgentState:
|
||||||
|
user_msg = state["messages"][-1]
|
||||||
|
question = getattr(user_msg, "content",
|
||||||
|
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||||
|
hint = (
|
||||||
|
f"[CONTEXT_INSUFFICIENT]\n"
|
||||||
|
f"The previous retrieval did not return relevant context for: \"{question}\"\n"
|
||||||
|
f"Reformulate this query using broader terms or alternative phrasing."
|
||||||
|
)
|
||||||
|
resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)])
|
||||||
|
reformulated = resp.content.strip()
|
||||||
|
logger.info(f"[reformulate_with_hint] -> '{reformulated}'")
|
||||||
|
return {"reformulated_query": reformulated}
|
||||||
|
|
||||||
|
def retrieve_retry(state: AgentState) -> AgentState:
|
||||||
|
query = state["reformulated_query"]
|
||||||
|
docs = hybrid_search_native(
|
||||||
|
es_client=es_client, embeddings=embeddings,
|
||||||
|
query=query, index_name=index_name, k=8,
|
||||||
|
)
|
||||||
|
context = format_context(docs)
|
||||||
|
logger.info(f"[retrieve_retry] {len(docs)} docs, context len={len(context)}")
|
||||||
|
return {"context": context}
|
||||||
|
|
||||||
|
# ── Routing ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def route_by_type(state):
|
def route_by_type(state):
|
||||||
return state.get("query_type", "RETRIEVAL")
|
return state.get("query_type", "RETRIEVAL")
|
||||||
|
|
||||||
def route_after_retrieve(state):
|
def route_after_retrieve(state):
|
||||||
qt = state.get("query_type", "RETRIEVAL")
|
qt = state.get("query_type", "RETRIEVAL")
|
||||||
return "generate_code" if qt == "CODE_GENERATION" else "generate"
|
return "generate_code" if qt == "CODE_GENERATION" else "check_context_relevance"
|
||||||
|
|
||||||
|
def route_after_validate(state):
|
||||||
|
if state.get("validation_status", "") == "INVALID":
|
||||||
|
return "generate_code_retry"
|
||||||
|
return END
|
||||||
|
|
||||||
|
def route_after_context_check(state):
|
||||||
|
if state.get("context_relevant", True):
|
||||||
|
return "generate"
|
||||||
|
return "reformulate_with_hint"
|
||||||
|
|
||||||
graph_builder = StateGraph(AgentState)
|
graph_builder = StateGraph(AgentState)
|
||||||
|
|
||||||
graph_builder.add_node("classify", classify)
|
graph_builder.add_node("classify", classify)
|
||||||
graph_builder.add_node("reformulate", reformulate)
|
graph_builder.add_node("reformulate", reformulate)
|
||||||
graph_builder.add_node("retrieve", retrieve)
|
graph_builder.add_node("retrieve", retrieve)
|
||||||
graph_builder.add_node("generate", generate)
|
graph_builder.add_node("generate", generate)
|
||||||
graph_builder.add_node("generate_code", generate_code)
|
graph_builder.add_node("generate_code", generate_code)
|
||||||
graph_builder.add_node("respond_conversational", respond_conversational)
|
graph_builder.add_node("validate_code", validate_code)
|
||||||
graph_builder.add_node("respond_platform", respond_platform)
|
graph_builder.add_node("generate_code_retry", generate_code_retry)
|
||||||
|
graph_builder.add_node("validate_code_after_retry",validate_code_after_retry)
|
||||||
|
graph_builder.add_node("check_context_relevance", check_context_relevance)
|
||||||
|
graph_builder.add_node("reformulate_with_hint", reformulate_with_hint)
|
||||||
|
graph_builder.add_node("retrieve_retry", retrieve_retry)
|
||||||
|
graph_builder.add_node("respond_conversational", respond_conversational)
|
||||||
|
graph_builder.add_node("respond_platform", respond_platform)
|
||||||
|
|
||||||
graph_builder.set_entry_point("classify")
|
graph_builder.set_entry_point("classify")
|
||||||
|
|
||||||
|
|
@ -543,10 +753,10 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
||||||
"classify",
|
"classify",
|
||||||
route_by_type,
|
route_by_type,
|
||||||
{
|
{
|
||||||
"RETRIEVAL": "reformulate",
|
"RETRIEVAL": "reformulate",
|
||||||
"CODE_GENERATION": "reformulate",
|
"CODE_GENERATION": "reformulate",
|
||||||
"CONVERSATIONAL": "respond_conversational",
|
"CONVERSATIONAL": "respond_conversational",
|
||||||
"PLATFORM": "respond_platform",
|
"PLATFORM": "respond_platform",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -556,13 +766,37 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
||||||
"retrieve",
|
"retrieve",
|
||||||
route_after_retrieve,
|
route_after_retrieve,
|
||||||
{
|
{
|
||||||
"generate": "generate",
|
"generate_code": "generate_code",
|
||||||
"generate_code": "generate_code",
|
"check_context_relevance": "check_context_relevance",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# CODE_GENERATION path: generate → validate → (retry if invalid) → END
|
||||||
|
graph_builder.add_edge("generate_code", "validate_code")
|
||||||
|
graph_builder.add_conditional_edges(
|
||||||
|
"validate_code",
|
||||||
|
route_after_validate,
|
||||||
|
{
|
||||||
|
"generate_code_retry": "generate_code_retry",
|
||||||
|
END: END,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph_builder.add_edge("generate_code_retry", "validate_code_after_retry")
|
||||||
|
graph_builder.add_edge("validate_code_after_retry", END)
|
||||||
|
|
||||||
|
# RETRIEVAL path: context check → generate or reformulate+retry → generate
|
||||||
|
graph_builder.add_conditional_edges(
|
||||||
|
"check_context_relevance",
|
||||||
|
route_after_context_check,
|
||||||
|
{
|
||||||
|
"generate": "generate",
|
||||||
|
"reformulate_with_hint": "reformulate_with_hint",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph_builder.add_edge("reformulate_with_hint", "retrieve_retry")
|
||||||
|
graph_builder.add_edge("retrieve_retry", "generate")
|
||||||
|
|
||||||
graph_builder.add_edge("generate", END)
|
graph_builder.add_edge("generate", END)
|
||||||
graph_builder.add_edge("generate_code", END)
|
|
||||||
graph_builder.add_edge("respond_conversational", END)
|
graph_builder.add_edge("respond_conversational", END)
|
||||||
graph_builder.add_edge("respond_platform", END)
|
graph_builder.add_edge("respond_platform", END)
|
||||||
|
|
||||||
|
|
@ -666,15 +900,64 @@ def build_prepare_graph(llm, embeddings, es_client, index_name):
|
||||||
def skip_retrieve(state: AgentState) -> AgentState:
|
def skip_retrieve(state: AgentState) -> AgentState:
|
||||||
return {"context": ""}
|
return {"context": ""}
|
||||||
|
|
||||||
|
def check_context_relevance(state: AgentState) -> AgentState:
|
||||||
|
user_msg = state["messages"][-1]
|
||||||
|
question = getattr(user_msg, "content",
|
||||||
|
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||||
|
context = state.get("context", "")
|
||||||
|
prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context)
|
||||||
|
resp = llm.invoke([SystemMessage(content=prompt)])
|
||||||
|
relevant = resp.content.strip().upper().startswith("YES")
|
||||||
|
logger.info(f"[prepare/check_context_relevance] relevant={relevant}")
|
||||||
|
return {"context_relevant": relevant}
|
||||||
|
|
||||||
|
def reformulate_with_hint(state: AgentState) -> AgentState:
|
||||||
|
user_msg = state["messages"][-1]
|
||||||
|
question = getattr(user_msg, "content",
|
||||||
|
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||||
|
hint = (
|
||||||
|
f"[CONTEXT_INSUFFICIENT]\n"
|
||||||
|
f"The previous retrieval did not return relevant context for: \"{question}\"\n"
|
||||||
|
f"Reformulate this query using broader terms or alternative phrasing."
|
||||||
|
)
|
||||||
|
resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)])
|
||||||
|
reformulated = resp.content.strip()
|
||||||
|
logger.info(f"[prepare/reformulate_with_hint] -> '{reformulated}'")
|
||||||
|
return {"reformulated_query": reformulated}
|
||||||
|
|
||||||
|
def retrieve_retry(state: AgentState) -> AgentState:
|
||||||
|
query = state["reformulated_query"]
|
||||||
|
docs = hybrid_search_native(
|
||||||
|
es_client=es_client, embeddings=embeddings,
|
||||||
|
query=query, index_name=index_name, k=8,
|
||||||
|
)
|
||||||
|
context = format_context(docs)
|
||||||
|
logger.info(f"[prepare/retrieve_retry] {len(docs)} docs, context len={len(context)}")
|
||||||
|
return {"context": context}
|
||||||
|
|
||||||
def route_by_type(state):
|
def route_by_type(state):
|
||||||
return state.get("query_type", "RETRIEVAL")
|
return state.get("query_type", "RETRIEVAL")
|
||||||
|
|
||||||
|
def route_after_retrieve(state):
|
||||||
|
qt = state.get("query_type", "RETRIEVAL")
|
||||||
|
if qt == "RETRIEVAL":
|
||||||
|
return "check_context_relevance"
|
||||||
|
return END # CODE_GENERATION goes straight to END
|
||||||
|
|
||||||
|
def route_after_context_check(state):
|
||||||
|
if state.get("context_relevant", True):
|
||||||
|
return END
|
||||||
|
return "reformulate_with_hint"
|
||||||
|
|
||||||
graph_builder = StateGraph(AgentState)
|
graph_builder = StateGraph(AgentState)
|
||||||
|
|
||||||
graph_builder.add_node("classify", classify)
|
graph_builder.add_node("classify", classify)
|
||||||
graph_builder.add_node("reformulate", reformulate)
|
graph_builder.add_node("reformulate", reformulate)
|
||||||
graph_builder.add_node("retrieve", retrieve)
|
graph_builder.add_node("retrieve", retrieve)
|
||||||
graph_builder.add_node("skip_retrieve", skip_retrieve)
|
graph_builder.add_node("skip_retrieve", skip_retrieve)
|
||||||
|
graph_builder.add_node("check_context_relevance", check_context_relevance)
|
||||||
|
graph_builder.add_node("reformulate_with_hint", reformulate_with_hint)
|
||||||
|
graph_builder.add_node("retrieve_retry", retrieve_retry)
|
||||||
|
|
||||||
graph_builder.set_entry_point("classify")
|
graph_builder.set_entry_point("classify")
|
||||||
|
|
||||||
|
|
@ -682,15 +965,35 @@ def build_prepare_graph(llm, embeddings, es_client, index_name):
|
||||||
"classify",
|
"classify",
|
||||||
route_by_type,
|
route_by_type,
|
||||||
{
|
{
|
||||||
"RETRIEVAL": "reformulate",
|
"RETRIEVAL": "reformulate",
|
||||||
"CODE_GENERATION": "reformulate",
|
"CODE_GENERATION": "reformulate",
|
||||||
"CONVERSATIONAL": "skip_retrieve",
|
"CONVERSATIONAL": "skip_retrieve",
|
||||||
"PLATFORM": "skip_retrieve",
|
"PLATFORM": "skip_retrieve",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_builder.add_edge("reformulate", "retrieve")
|
graph_builder.add_edge("reformulate", "retrieve")
|
||||||
graph_builder.add_edge("retrieve", END)
|
|
||||||
|
graph_builder.add_conditional_edges(
|
||||||
|
"retrieve",
|
||||||
|
route_after_retrieve,
|
||||||
|
{
|
||||||
|
"check_context_relevance": "check_context_relevance",
|
||||||
|
END: END,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
graph_builder.add_conditional_edges(
|
||||||
|
"check_context_relevance",
|
||||||
|
route_after_context_check,
|
||||||
|
{
|
||||||
|
END: END,
|
||||||
|
"reformulate_with_hint": "reformulate_with_hint",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
graph_builder.add_edge("reformulate_with_hint", "retrieve_retry")
|
||||||
|
graph_builder.add_edge("retrieve_retry", END)
|
||||||
|
|
||||||
graph_builder.add_edge("skip_retrieve", END)
|
graph_builder.add_edge("skip_retrieve", END)
|
||||||
|
|
||||||
return graph_builder.compile()
|
return graph_builder.compile()
|
||||||
|
|
|
||||||
|
|
@ -10,11 +10,11 @@ import brunix_pb2_grpc
|
||||||
import grpc
|
import grpc
|
||||||
from grpc_reflection.v1alpha import reflection
|
from grpc_reflection.v1alpha import reflection
|
||||||
from elasticsearch import Elasticsearch
|
from elasticsearch import Elasticsearch
|
||||||
from langchain_core.messages import AIMessage
|
from langchain_core.messages import AIMessage, SystemMessage
|
||||||
|
|
||||||
from utils.llm_factory import create_chat_model
|
from utils.llm_factory import create_chat_model
|
||||||
from utils.emb_factory import create_embedding_model
|
from utils.emb_factory import create_embedding_model
|
||||||
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model
|
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model, _call_parser, _extract_avap_code
|
||||||
from utils.classifier_export import maybe_export, force_export
|
from utils.classifier_export import maybe_export, force_export
|
||||||
|
|
||||||
from evaluate import run_evaluation
|
from evaluate import run_evaluation
|
||||||
|
|
@ -141,7 +141,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
"editor_content": editor_content,
|
"editor_content": editor_content,
|
||||||
"selected_text": selected_text,
|
"selected_text": selected_text,
|
||||||
"extra_context": extra_context,
|
"extra_context": extra_context,
|
||||||
"user_info": user_info
|
"user_info": user_info,
|
||||||
|
|
||||||
|
# PTVL fields (ADR-0009)
|
||||||
|
"parser_trace": "",
|
||||||
|
"validation_status": "",
|
||||||
|
"context_relevant": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
final_state = self.graph.invoke(initial_state)
|
final_state = self.graph.invoke(initial_state)
|
||||||
|
|
@ -150,17 +155,21 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
result_text = getattr(last_msg, "content", str(last_msg)) \
|
result_text = getattr(last_msg, "content", str(last_msg)) \
|
||||||
if last_msg else ""
|
if last_msg else ""
|
||||||
|
|
||||||
|
validation_status = final_state.get("validation_status", "")
|
||||||
|
|
||||||
if session_id:
|
if session_id:
|
||||||
classify_history_store[session_id] = final_state.get("classify_history", classify_history)
|
classify_history_store[session_id] = final_state.get("classify_history", classify_history)
|
||||||
maybe_export(classify_history_store)
|
maybe_export(classify_history_store)
|
||||||
|
|
||||||
logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
|
logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
|
||||||
|
f"validation_status={validation_status!r} "
|
||||||
f"answer='{result_text[:100]}'")
|
f"answer='{result_text[:100]}'")
|
||||||
|
|
||||||
yield brunix_pb2.AgentResponse(
|
yield brunix_pb2.AgentResponse(
|
||||||
text = result_text,
|
text = result_text,
|
||||||
avap_code= "AVAP-2026",
|
avap_code = "AVAP-2026",
|
||||||
is_final = True,
|
is_final = True,
|
||||||
|
validation_status = validation_status,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -220,7 +229,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
"editor_content": editor_content,
|
"editor_content": editor_content,
|
||||||
"selected_text": selected_text,
|
"selected_text": selected_text,
|
||||||
"extra_context": extra_context,
|
"extra_context": extra_context,
|
||||||
"user_info": user_info
|
"user_info": user_info,
|
||||||
|
|
||||||
|
# PTVL fields (ADR-0009)
|
||||||
|
"parser_trace": "",
|
||||||
|
"validation_status": "",
|
||||||
|
"context_relevant": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
prepared = self.prepare_graph.invoke(initial_state)
|
prepared = self.prepare_graph.invoke(initial_state)
|
||||||
|
|
@ -235,14 +249,108 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
query_type = prepared.get("query_type", "RETRIEVAL")
|
query_type = prepared.get("query_type", "RETRIEVAL")
|
||||||
active_llm = self.llm_conversational if query_type in ("CONVERSATIONAL", "PLATFORM") else self.llm
|
active_llm = self.llm_conversational if query_type in ("CONVERSATIONAL", "PLATFORM") else self.llm
|
||||||
|
|
||||||
for chunk in active_llm.stream(final_messages):
|
validation_status = ""
|
||||||
token = chunk.content
|
|
||||||
if token:
|
if query_type == "CODE_GENERATION":
|
||||||
full_response.append(token)
|
# ── Streaming state machine (ADR-0009) ────────────────────────
|
||||||
yield brunix_pb2.AgentResponse(
|
# TEXT mode → yield tokens immediately to the client
|
||||||
text = token,
|
# CODE mode → buffer the code block, validate, fix if needed,
|
||||||
is_final = False,
|
# then yield the validated block before continuing
|
||||||
)
|
stream_state = "text"
|
||||||
|
lookahead = "" # pending chars in TEXT mode (fence detection)
|
||||||
|
code_buffer = "" # accumulated chars in CODE mode
|
||||||
|
|
||||||
|
for chunk in active_llm.stream(final_messages):
|
||||||
|
token = chunk.content
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if stream_state == "text":
|
||||||
|
lookahead += token
|
||||||
|
fence_pos = lookahead.find("```")
|
||||||
|
|
||||||
|
if fence_pos != -1:
|
||||||
|
# Yield text before the fence immediately
|
||||||
|
text_before = lookahead[:fence_pos]
|
||||||
|
if text_before:
|
||||||
|
full_response.append(text_before)
|
||||||
|
yield brunix_pb2.AgentResponse(text=text_before, is_final=False)
|
||||||
|
stream_state = "code"
|
||||||
|
code_buffer = lookahead[fence_pos:]
|
||||||
|
lookahead = ""
|
||||||
|
else:
|
||||||
|
# Keep last 2 chars — a fence might span the next token
|
||||||
|
safe_len = max(0, len(lookahead) - 2)
|
||||||
|
if safe_len > 0:
|
||||||
|
safe = lookahead[:safe_len]
|
||||||
|
full_response.append(safe)
|
||||||
|
yield brunix_pb2.AgentResponse(text=safe, is_final=False)
|
||||||
|
lookahead = lookahead[safe_len:]
|
||||||
|
|
||||||
|
else: # stream_state == "code"
|
||||||
|
code_buffer += token
|
||||||
|
# Look for closing ``` only after the first newline
|
||||||
|
# (to skip the opening fence + optional language tag)
|
||||||
|
first_nl = code_buffer.find("\n")
|
||||||
|
if first_nl == -1:
|
||||||
|
continue
|
||||||
|
close_pos = code_buffer.find("```", first_nl + 1)
|
||||||
|
if close_pos == -1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── Complete code block captured ──────────────────────
|
||||||
|
complete_block = code_buffer[:close_pos + 3]
|
||||||
|
rest = code_buffer[close_pos + 3:]
|
||||||
|
|
||||||
|
valid, trace = _call_parser(complete_block)
|
||||||
|
|
||||||
|
if valid is False:
|
||||||
|
# Ask LLM to fix only the code block
|
||||||
|
logger.info("[stream/ptvl] INVALID — requesting fix from LLM")
|
||||||
|
fence_open = "```avap" if complete_block.startswith("```avap") else "```"
|
||||||
|
fix_resp = self.llm.invoke([SystemMessage(content=(
|
||||||
|
"Fix this invalid AVAP code. "
|
||||||
|
"Return ONLY the corrected code block with the same opening "
|
||||||
|
"and closing fences. No explanation, no extra text.\n\n"
|
||||||
|
f"<parser_feedback>\n{trace}\n</parser_feedback>\n\n"
|
||||||
|
f"<code>\n{complete_block}\n</code>"
|
||||||
|
))])
|
||||||
|
fixed_code = _extract_avap_code(fix_resp.content)
|
||||||
|
fixed_block = f"{fence_open}\n{fixed_code}\n```"
|
||||||
|
|
||||||
|
valid2, _ = _call_parser(fixed_block)
|
||||||
|
if valid2 is False:
|
||||||
|
to_yield = fixed_block
|
||||||
|
validation_status = "INVALID_UNRESOLVED"
|
||||||
|
else:
|
||||||
|
to_yield = fixed_block
|
||||||
|
validation_status = "" if valid2 else "PARSER_UNAVAILABLE"
|
||||||
|
elif valid is None:
|
||||||
|
to_yield = complete_block
|
||||||
|
validation_status = "PARSER_UNAVAILABLE"
|
||||||
|
else:
|
||||||
|
to_yield = complete_block # validation_status stays ""
|
||||||
|
|
||||||
|
full_response.append(to_yield)
|
||||||
|
yield brunix_pb2.AgentResponse(text=to_yield, is_final=False)
|
||||||
|
|
||||||
|
stream_state = "text"
|
||||||
|
lookahead = rest
|
||||||
|
code_buffer = ""
|
||||||
|
|
||||||
|
# Flush anything remaining (text after last block, or unclosed fence)
|
||||||
|
remainder = lookahead + code_buffer
|
||||||
|
if remainder:
|
||||||
|
full_response.append(remainder)
|
||||||
|
yield brunix_pb2.AgentResponse(text=remainder, is_final=False)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# RETRIEVAL / CONVERSATIONAL / PLATFORM — stream normally
|
||||||
|
for chunk in active_llm.stream(final_messages):
|
||||||
|
token = chunk.content
|
||||||
|
if token:
|
||||||
|
full_response.append(token)
|
||||||
|
yield brunix_pb2.AgentResponse(text=token, is_final=False)
|
||||||
|
|
||||||
complete_text = "".join(full_response)
|
complete_text = "".join(full_response)
|
||||||
if session_id:
|
if session_id:
|
||||||
|
|
@ -254,10 +362,11 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[AskAgentStream] done — "
|
f"[AskAgentStream] done — "
|
||||||
f"chunks={len(full_response)} total_chars={len(complete_text)}"
|
f"chunks={len(full_response)} total_chars={len(complete_text)} "
|
||||||
|
f"validation_status={validation_status!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
yield brunix_pb2.AgentResponse(text="", is_final=True)
|
yield brunix_pb2.AgentResponse(text="", is_final=True, validation_status=validation_status)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[AskAgentStream] Error: {e}", exc_info=True)
|
logger.error(f"[AskAgentStream] Error: {e}", exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -25,3 +25,7 @@ class AgentState(TypedDict):
|
||||||
extra_context: str
|
extra_context: str
|
||||||
user_info: str
|
user_info: str
|
||||||
use_editor_context: bool
|
use_editor_context: bool
|
||||||
|
# -- PTVL (ADR-0009)
|
||||||
|
parser_trace: str # raw parser error trace from first validation (empty if valid)
|
||||||
|
validation_status: str # "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE"
|
||||||
|
context_relevant: bool # result of CONFIDENCE_PROMPT check (RETRIEVAL only)
|
||||||
Loading…
Reference in New Issue