ADR0008 finished
This commit is contained in:
parent
2c99f6ea40
commit
0b9c19d61f
Binary file not shown.
|
|
@ -50,9 +50,10 @@ message AgentRequest {
|
|||
}
|
||||
|
||||
message AgentResponse {
|
||||
string text = 1;
|
||||
string avap_code = 2;
|
||||
bool is_final = 3;
|
||||
string text = 1;
|
||||
string avap_code = 2;
|
||||
bool is_final = 3;
|
||||
string validation_status = 4; // "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE"
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
import logging
|
||||
import os
|
||||
import re as _re
|
||||
import time
|
||||
import threading
|
||||
import requests as _requests
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -20,6 +23,7 @@ except ImportError:
|
|||
from prompts import (
|
||||
CLASSIFY_PROMPT_TEMPLATE,
|
||||
CODE_GENERATION_PROMPT,
|
||||
CONFIDENCE_PROMPT_TEMPLATE,
|
||||
CONVERSATIONAL_PROMPT,
|
||||
GENERATE_PROMPT,
|
||||
PLATFORM_PROMPT,
|
||||
|
|
@ -30,6 +34,106 @@ from state import AgentState, ClassifyEntry
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── AVAP Parser client — ADR-0009 (PTVL) ──────────────────────────────────────
|
||||
|
||||
_PARSER_URL = os.getenv("AVAP_PARSER_URL", "")
|
||||
_PARSER_TIMEOUT = int(os.getenv("AVAP_PARSER_TIMEOUT", "2"))
|
||||
_CB_THRESHOLD = int(os.getenv("PARSER_CB_THRESHOLD", "3"))
|
||||
_CB_COOLDOWN = int(os.getenv("PARSER_CB_COOLDOWN", "30"))
|
||||
|
||||
|
||||
class _CircuitBreaker:
|
||||
CLOSED = "CLOSED"
|
||||
OPEN = "OPEN"
|
||||
HALF_OPEN = "HALF_OPEN"
|
||||
|
||||
def __init__(self, threshold: int, cooldown: float):
|
||||
self._state = self.CLOSED
|
||||
self._failures = 0
|
||||
self._opened_at = 0.0
|
||||
self._threshold = threshold
|
||||
self._cooldown = cooldown
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def allow(self) -> bool:
|
||||
with self._lock:
|
||||
if self._state == self.CLOSED:
|
||||
return True
|
||||
if self._state == self.OPEN:
|
||||
if time.time() - self._opened_at >= self._cooldown:
|
||||
self._state = self.HALF_OPEN
|
||||
logger.info("[ptvl] circuit HALF-OPEN — probing parser availability")
|
||||
return True
|
||||
return False
|
||||
# HALF_OPEN: allow the probe
|
||||
return True
|
||||
|
||||
def success(self):
|
||||
with self._lock:
|
||||
if self._state != self.CLOSED:
|
||||
logger.info("[ptvl] circuit CLOSED — parser reachable")
|
||||
self._state = self.CLOSED
|
||||
self._failures = 0
|
||||
|
||||
def failure(self):
|
||||
with self._lock:
|
||||
self._failures += 1
|
||||
if self._failures >= self._threshold or self._state == self.HALF_OPEN:
|
||||
self._state = self.OPEN
|
||||
self._opened_at = time.time()
|
||||
logger.warning("[ptvl] circuit OPEN — skipping parser, returning unvalidated")
|
||||
|
||||
|
||||
_parser_cb = _CircuitBreaker(_CB_THRESHOLD, _CB_COOLDOWN)
|
||||
|
||||
|
||||
def _extract_avap_code(text: str) -> str:
|
||||
"""Return the first AVAP code block found in an LLM response."""
|
||||
for pattern in (r'```avap\s*\n(.*?)```', r'```\s*\n(.*?)```', r'```(.*?)```'):
|
||||
m = _re.search(pattern, text, _re.DOTALL)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
return text
|
||||
|
||||
|
||||
def _call_parser(text: str) -> tuple:
|
||||
"""Call AVAP Parser REST API.
|
||||
|
||||
Returns:
|
||||
(True, "") — code valid
|
||||
(False, trace) — code invalid, trace contains the error
|
||||
(None, "") — parser unavailable or circuit open
|
||||
"""
|
||||
if not _PARSER_URL or _PARSER_TIMEOUT == 0:
|
||||
return None, ""
|
||||
|
||||
if not _parser_cb.allow():
|
||||
return None, ""
|
||||
|
||||
code = _extract_avap_code(text)
|
||||
if not code.strip():
|
||||
return None, ""
|
||||
|
||||
try:
|
||||
resp = _requests.post(
|
||||
f"{_PARSER_URL.rstrip('/')}/parse",
|
||||
json={"code": code},
|
||||
timeout=_PARSER_TIMEOUT,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("valid", False):
|
||||
_parser_cb.success()
|
||||
return True, ""
|
||||
_parser_cb.success() # parser responded — it is healthy
|
||||
return False, data.get("error", "parse error")
|
||||
except Exception as exc:
|
||||
_parser_cb.failure()
|
||||
logger.warning(f"[ptvl] parser call failed: {exc}")
|
||||
return None, ""
|
||||
|
||||
|
||||
# ── Session stores ─────────────────────────────────────────────────────────────
|
||||
|
||||
session_store: dict[str, list] = defaultdict(list)
|
||||
classify_history_store: dict[str, list] = defaultdict(list)
|
||||
|
||||
|
|
@ -520,22 +624,128 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
|||
_persist(state, resp)
|
||||
return {"messages": [resp]}
|
||||
|
||||
# ── PTVL nodes (ADR-0009) ─────────────────────────────────────────────────
|
||||
|
||||
def validate_code(state: AgentState) -> AgentState:
|
||||
last_msg = state["messages"][-1]
|
||||
content = getattr(last_msg, "content", str(last_msg))
|
||||
valid, trace = _call_parser(content)
|
||||
if valid is None:
|
||||
logger.warning("[ptvl] parser unavailable — returning unvalidated")
|
||||
return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""}
|
||||
if valid:
|
||||
logger.info("[ptvl] code VALID on first attempt")
|
||||
return {"validation_status": "", "parser_trace": ""}
|
||||
logger.info(f"[ptvl] code INVALID — trace: {str(trace)[:120]}")
|
||||
return {"validation_status": "INVALID", "parser_trace": str(trace)}
|
||||
|
||||
def generate_code_retry(state: AgentState) -> AgentState:
|
||||
parser_trace = state.get("parser_trace", "")
|
||||
use_editor = state.get("use_editor_context", False)
|
||||
|
||||
feedback = (
|
||||
"\n\n<parser_feedback>\n"
|
||||
"The previous attempt produced invalid AVAP code. Specific failures:\n\n"
|
||||
f"{parser_trace}\n\n"
|
||||
"Correct these errors. Do not repeat the same constructs.\n"
|
||||
"</parser_feedback>"
|
||||
) if parser_trace else ""
|
||||
|
||||
base = _build_generation_prompt(
|
||||
template_prompt=CODE_GENERATION_PROMPT,
|
||||
context=state.get("context", ""),
|
||||
editor_content=state.get("editor_content", "") if use_editor else "",
|
||||
selected_text=state.get("selected_text", "") if use_editor else "",
|
||||
extra_context=state.get("extra_context", ""),
|
||||
)
|
||||
retry_prompt = SystemMessage(content=base.content + feedback)
|
||||
resp = llm.invoke([retry_prompt] + state["messages"])
|
||||
logger.info(f"[generate_code_retry] {len(resp.content)} chars")
|
||||
_persist(state, resp)
|
||||
return {"messages": [resp]}
|
||||
|
||||
def validate_code_after_retry(state: AgentState) -> AgentState:
|
||||
last_msg = state["messages"][-1]
|
||||
content = getattr(last_msg, "content", str(last_msg))
|
||||
valid, trace = _call_parser(content)
|
||||
if valid is None:
|
||||
logger.warning("[ptvl] parser unavailable after retry")
|
||||
return {"validation_status": "PARSER_UNAVAILABLE", "parser_trace": ""}
|
||||
if valid:
|
||||
logger.info("[ptvl] retry code VALID")
|
||||
return {"validation_status": "", "parser_trace": ""}
|
||||
logger.info("[ptvl] retry code still INVALID → INVALID_UNRESOLVED")
|
||||
return {"validation_status": "INVALID_UNRESOLVED", "parser_trace": str(trace)}
|
||||
|
||||
def check_context_relevance(state: AgentState) -> AgentState:
|
||||
user_msg = state["messages"][-1]
|
||||
question = getattr(user_msg, "content",
|
||||
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||
context = state.get("context", "")
|
||||
prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context)
|
||||
resp = llm.invoke([SystemMessage(content=prompt)])
|
||||
relevant = resp.content.strip().upper().startswith("YES")
|
||||
logger.info(f"[check_context_relevance] relevant={relevant}")
|
||||
return {"context_relevant": relevant}
|
||||
|
||||
def reformulate_with_hint(state: AgentState) -> AgentState:
|
||||
user_msg = state["messages"][-1]
|
||||
question = getattr(user_msg, "content",
|
||||
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||
hint = (
|
||||
f"[CONTEXT_INSUFFICIENT]\n"
|
||||
f"The previous retrieval did not return relevant context for: \"{question}\"\n"
|
||||
f"Reformulate this query using broader terms or alternative phrasing."
|
||||
)
|
||||
resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)])
|
||||
reformulated = resp.content.strip()
|
||||
logger.info(f"[reformulate_with_hint] -> '{reformulated}'")
|
||||
return {"reformulated_query": reformulated}
|
||||
|
||||
def retrieve_retry(state: AgentState) -> AgentState:
|
||||
query = state["reformulated_query"]
|
||||
docs = hybrid_search_native(
|
||||
es_client=es_client, embeddings=embeddings,
|
||||
query=query, index_name=index_name, k=8,
|
||||
)
|
||||
context = format_context(docs)
|
||||
logger.info(f"[retrieve_retry] {len(docs)} docs, context len={len(context)}")
|
||||
return {"context": context}
|
||||
|
||||
# ── Routing ───────────────────────────────────────────────────────────────
|
||||
|
||||
def route_by_type(state):
|
||||
return state.get("query_type", "RETRIEVAL")
|
||||
|
||||
def route_after_retrieve(state):
|
||||
qt = state.get("query_type", "RETRIEVAL")
|
||||
return "generate_code" if qt == "CODE_GENERATION" else "generate"
|
||||
return "generate_code" if qt == "CODE_GENERATION" else "check_context_relevance"
|
||||
|
||||
def route_after_validate(state):
|
||||
if state.get("validation_status", "") == "INVALID":
|
||||
return "generate_code_retry"
|
||||
return END
|
||||
|
||||
def route_after_context_check(state):
|
||||
if state.get("context_relevant", True):
|
||||
return "generate"
|
||||
return "reformulate_with_hint"
|
||||
|
||||
graph_builder = StateGraph(AgentState)
|
||||
|
||||
graph_builder.add_node("classify", classify)
|
||||
graph_builder.add_node("reformulate", reformulate)
|
||||
graph_builder.add_node("retrieve", retrieve)
|
||||
graph_builder.add_node("generate", generate)
|
||||
graph_builder.add_node("generate_code", generate_code)
|
||||
graph_builder.add_node("respond_conversational", respond_conversational)
|
||||
graph_builder.add_node("respond_platform", respond_platform)
|
||||
graph_builder.add_node("classify", classify)
|
||||
graph_builder.add_node("reformulate", reformulate)
|
||||
graph_builder.add_node("retrieve", retrieve)
|
||||
graph_builder.add_node("generate", generate)
|
||||
graph_builder.add_node("generate_code", generate_code)
|
||||
graph_builder.add_node("validate_code", validate_code)
|
||||
graph_builder.add_node("generate_code_retry", generate_code_retry)
|
||||
graph_builder.add_node("validate_code_after_retry",validate_code_after_retry)
|
||||
graph_builder.add_node("check_context_relevance", check_context_relevance)
|
||||
graph_builder.add_node("reformulate_with_hint", reformulate_with_hint)
|
||||
graph_builder.add_node("retrieve_retry", retrieve_retry)
|
||||
graph_builder.add_node("respond_conversational", respond_conversational)
|
||||
graph_builder.add_node("respond_platform", respond_platform)
|
||||
|
||||
graph_builder.set_entry_point("classify")
|
||||
|
||||
|
|
@ -543,10 +753,10 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
|||
"classify",
|
||||
route_by_type,
|
||||
{
|
||||
"RETRIEVAL": "reformulate",
|
||||
"RETRIEVAL": "reformulate",
|
||||
"CODE_GENERATION": "reformulate",
|
||||
"CONVERSATIONAL": "respond_conversational",
|
||||
"PLATFORM": "respond_platform",
|
||||
"CONVERSATIONAL": "respond_conversational",
|
||||
"PLATFORM": "respond_platform",
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -556,13 +766,37 @@ def build_graph(llm, embeddings, es_client, index_name, llm_conversational=None)
|
|||
"retrieve",
|
||||
route_after_retrieve,
|
||||
{
|
||||
"generate": "generate",
|
||||
"generate_code": "generate_code",
|
||||
"generate_code": "generate_code",
|
||||
"check_context_relevance": "check_context_relevance",
|
||||
}
|
||||
)
|
||||
|
||||
# CODE_GENERATION path: generate → validate → (retry if invalid) → END
|
||||
graph_builder.add_edge("generate_code", "validate_code")
|
||||
graph_builder.add_conditional_edges(
|
||||
"validate_code",
|
||||
route_after_validate,
|
||||
{
|
||||
"generate_code_retry": "generate_code_retry",
|
||||
END: END,
|
||||
}
|
||||
)
|
||||
graph_builder.add_edge("generate_code_retry", "validate_code_after_retry")
|
||||
graph_builder.add_edge("validate_code_after_retry", END)
|
||||
|
||||
# RETRIEVAL path: context check → generate or reformulate+retry → generate
|
||||
graph_builder.add_conditional_edges(
|
||||
"check_context_relevance",
|
||||
route_after_context_check,
|
||||
{
|
||||
"generate": "generate",
|
||||
"reformulate_with_hint": "reformulate_with_hint",
|
||||
}
|
||||
)
|
||||
graph_builder.add_edge("reformulate_with_hint", "retrieve_retry")
|
||||
graph_builder.add_edge("retrieve_retry", "generate")
|
||||
|
||||
graph_builder.add_edge("generate", END)
|
||||
graph_builder.add_edge("generate_code", END)
|
||||
graph_builder.add_edge("respond_conversational", END)
|
||||
graph_builder.add_edge("respond_platform", END)
|
||||
|
||||
|
|
@ -666,15 +900,64 @@ def build_prepare_graph(llm, embeddings, es_client, index_name):
|
|||
def skip_retrieve(state: AgentState) -> AgentState:
|
||||
return {"context": ""}
|
||||
|
||||
def check_context_relevance(state: AgentState) -> AgentState:
|
||||
user_msg = state["messages"][-1]
|
||||
question = getattr(user_msg, "content",
|
||||
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||
context = state.get("context", "")
|
||||
prompt = CONFIDENCE_PROMPT_TEMPLATE.format(question=question, context=context)
|
||||
resp = llm.invoke([SystemMessage(content=prompt)])
|
||||
relevant = resp.content.strip().upper().startswith("YES")
|
||||
logger.info(f"[prepare/check_context_relevance] relevant={relevant}")
|
||||
return {"context_relevant": relevant}
|
||||
|
||||
def reformulate_with_hint(state: AgentState) -> AgentState:
|
||||
user_msg = state["messages"][-1]
|
||||
question = getattr(user_msg, "content",
|
||||
user_msg.get("content", "") if isinstance(user_msg, dict) else "")
|
||||
hint = (
|
||||
f"[CONTEXT_INSUFFICIENT]\n"
|
||||
f"The previous retrieval did not return relevant context for: \"{question}\"\n"
|
||||
f"Reformulate this query using broader terms or alternative phrasing."
|
||||
)
|
||||
resp = llm.invoke([REFORMULATE_PROMPT, HumanMessage(content=hint)])
|
||||
reformulated = resp.content.strip()
|
||||
logger.info(f"[prepare/reformulate_with_hint] -> '{reformulated}'")
|
||||
return {"reformulated_query": reformulated}
|
||||
|
||||
def retrieve_retry(state: AgentState) -> AgentState:
|
||||
query = state["reformulated_query"]
|
||||
docs = hybrid_search_native(
|
||||
es_client=es_client, embeddings=embeddings,
|
||||
query=query, index_name=index_name, k=8,
|
||||
)
|
||||
context = format_context(docs)
|
||||
logger.info(f"[prepare/retrieve_retry] {len(docs)} docs, context len={len(context)}")
|
||||
return {"context": context}
|
||||
|
||||
def route_by_type(state):
|
||||
return state.get("query_type", "RETRIEVAL")
|
||||
|
||||
def route_after_retrieve(state):
|
||||
qt = state.get("query_type", "RETRIEVAL")
|
||||
if qt == "RETRIEVAL":
|
||||
return "check_context_relevance"
|
||||
return END # CODE_GENERATION goes straight to END
|
||||
|
||||
def route_after_context_check(state):
|
||||
if state.get("context_relevant", True):
|
||||
return END
|
||||
return "reformulate_with_hint"
|
||||
|
||||
graph_builder = StateGraph(AgentState)
|
||||
|
||||
graph_builder.add_node("classify", classify)
|
||||
graph_builder.add_node("reformulate", reformulate)
|
||||
graph_builder.add_node("retrieve", retrieve)
|
||||
graph_builder.add_node("skip_retrieve", skip_retrieve)
|
||||
graph_builder.add_node("classify", classify)
|
||||
graph_builder.add_node("reformulate", reformulate)
|
||||
graph_builder.add_node("retrieve", retrieve)
|
||||
graph_builder.add_node("skip_retrieve", skip_retrieve)
|
||||
graph_builder.add_node("check_context_relevance", check_context_relevance)
|
||||
graph_builder.add_node("reformulate_with_hint", reformulate_with_hint)
|
||||
graph_builder.add_node("retrieve_retry", retrieve_retry)
|
||||
|
||||
graph_builder.set_entry_point("classify")
|
||||
|
||||
|
|
@ -682,15 +965,35 @@ def build_prepare_graph(llm, embeddings, es_client, index_name):
|
|||
"classify",
|
||||
route_by_type,
|
||||
{
|
||||
"RETRIEVAL": "reformulate",
|
||||
"RETRIEVAL": "reformulate",
|
||||
"CODE_GENERATION": "reformulate",
|
||||
"CONVERSATIONAL": "skip_retrieve",
|
||||
"PLATFORM": "skip_retrieve",
|
||||
"CONVERSATIONAL": "skip_retrieve",
|
||||
"PLATFORM": "skip_retrieve",
|
||||
}
|
||||
)
|
||||
|
||||
graph_builder.add_edge("reformulate", "retrieve")
|
||||
graph_builder.add_edge("retrieve", END)
|
||||
|
||||
graph_builder.add_conditional_edges(
|
||||
"retrieve",
|
||||
route_after_retrieve,
|
||||
{
|
||||
"check_context_relevance": "check_context_relevance",
|
||||
END: END,
|
||||
}
|
||||
)
|
||||
|
||||
graph_builder.add_conditional_edges(
|
||||
"check_context_relevance",
|
||||
route_after_context_check,
|
||||
{
|
||||
END: END,
|
||||
"reformulate_with_hint": "reformulate_with_hint",
|
||||
}
|
||||
)
|
||||
graph_builder.add_edge("reformulate_with_hint", "retrieve_retry")
|
||||
graph_builder.add_edge("retrieve_retry", END)
|
||||
|
||||
graph_builder.add_edge("skip_retrieve", END)
|
||||
|
||||
return graph_builder.compile()
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ import brunix_pb2_grpc
|
|||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, SystemMessage
|
||||
|
||||
from utils.llm_factory import create_chat_model
|
||||
from utils.emb_factory import create_embedding_model
|
||||
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model
|
||||
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model, _call_parser, _extract_avap_code
|
||||
from utils.classifier_export import maybe_export, force_export
|
||||
|
||||
from evaluate import run_evaluation
|
||||
|
|
@ -141,7 +141,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
"editor_content": editor_content,
|
||||
"selected_text": selected_text,
|
||||
"extra_context": extra_context,
|
||||
"user_info": user_info
|
||||
"user_info": user_info,
|
||||
|
||||
# PTVL fields (ADR-0009)
|
||||
"parser_trace": "",
|
||||
"validation_status": "",
|
||||
"context_relevant": True,
|
||||
}
|
||||
|
||||
final_state = self.graph.invoke(initial_state)
|
||||
|
|
@ -150,17 +155,21 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
result_text = getattr(last_msg, "content", str(last_msg)) \
|
||||
if last_msg else ""
|
||||
|
||||
validation_status = final_state.get("validation_status", "")
|
||||
|
||||
if session_id:
|
||||
classify_history_store[session_id] = final_state.get("classify_history", classify_history)
|
||||
maybe_export(classify_history_store)
|
||||
|
||||
logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
|
||||
f"validation_status={validation_status!r} "
|
||||
f"answer='{result_text[:100]}'")
|
||||
|
||||
yield brunix_pb2.AgentResponse(
|
||||
text = result_text,
|
||||
avap_code= "AVAP-2026",
|
||||
is_final = True,
|
||||
text = result_text,
|
||||
avap_code = "AVAP-2026",
|
||||
is_final = True,
|
||||
validation_status = validation_status,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -220,7 +229,12 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
"editor_content": editor_content,
|
||||
"selected_text": selected_text,
|
||||
"extra_context": extra_context,
|
||||
"user_info": user_info
|
||||
"user_info": user_info,
|
||||
|
||||
# PTVL fields (ADR-0009)
|
||||
"parser_trace": "",
|
||||
"validation_status": "",
|
||||
"context_relevant": True,
|
||||
}
|
||||
|
||||
prepared = self.prepare_graph.invoke(initial_state)
|
||||
|
|
@ -235,14 +249,108 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
query_type = prepared.get("query_type", "RETRIEVAL")
|
||||
active_llm = self.llm_conversational if query_type in ("CONVERSATIONAL", "PLATFORM") else self.llm
|
||||
|
||||
for chunk in active_llm.stream(final_messages):
|
||||
token = chunk.content
|
||||
if token:
|
||||
full_response.append(token)
|
||||
yield brunix_pb2.AgentResponse(
|
||||
text = token,
|
||||
is_final = False,
|
||||
)
|
||||
validation_status = ""
|
||||
|
||||
if query_type == "CODE_GENERATION":
|
||||
# ── Streaming state machine (ADR-0009) ────────────────────────
|
||||
# TEXT mode → yield tokens immediately to the client
|
||||
# CODE mode → buffer the code block, validate, fix if needed,
|
||||
# then yield the validated block before continuing
|
||||
stream_state = "text"
|
||||
lookahead = "" # pending chars in TEXT mode (fence detection)
|
||||
code_buffer = "" # accumulated chars in CODE mode
|
||||
|
||||
for chunk in active_llm.stream(final_messages):
|
||||
token = chunk.content
|
||||
if not token:
|
||||
continue
|
||||
|
||||
if stream_state == "text":
|
||||
lookahead += token
|
||||
fence_pos = lookahead.find("```")
|
||||
|
||||
if fence_pos != -1:
|
||||
# Yield text before the fence immediately
|
||||
text_before = lookahead[:fence_pos]
|
||||
if text_before:
|
||||
full_response.append(text_before)
|
||||
yield brunix_pb2.AgentResponse(text=text_before, is_final=False)
|
||||
stream_state = "code"
|
||||
code_buffer = lookahead[fence_pos:]
|
||||
lookahead = ""
|
||||
else:
|
||||
# Keep last 2 chars — a fence might span the next token
|
||||
safe_len = max(0, len(lookahead) - 2)
|
||||
if safe_len > 0:
|
||||
safe = lookahead[:safe_len]
|
||||
full_response.append(safe)
|
||||
yield brunix_pb2.AgentResponse(text=safe, is_final=False)
|
||||
lookahead = lookahead[safe_len:]
|
||||
|
||||
else: # stream_state == "code"
|
||||
code_buffer += token
|
||||
# Look for closing ``` only after the first newline
|
||||
# (to skip the opening fence + optional language tag)
|
||||
first_nl = code_buffer.find("\n")
|
||||
if first_nl == -1:
|
||||
continue
|
||||
close_pos = code_buffer.find("```", first_nl + 1)
|
||||
if close_pos == -1:
|
||||
continue
|
||||
|
||||
# ── Complete code block captured ──────────────────────
|
||||
complete_block = code_buffer[:close_pos + 3]
|
||||
rest = code_buffer[close_pos + 3:]
|
||||
|
||||
valid, trace = _call_parser(complete_block)
|
||||
|
||||
if valid is False:
|
||||
# Ask LLM to fix only the code block
|
||||
logger.info("[stream/ptvl] INVALID — requesting fix from LLM")
|
||||
fence_open = "```avap" if complete_block.startswith("```avap") else "```"
|
||||
fix_resp = self.llm.invoke([SystemMessage(content=(
|
||||
"Fix this invalid AVAP code. "
|
||||
"Return ONLY the corrected code block with the same opening "
|
||||
"and closing fences. No explanation, no extra text.\n\n"
|
||||
f"<parser_feedback>\n{trace}\n</parser_feedback>\n\n"
|
||||
f"<code>\n{complete_block}\n</code>"
|
||||
))])
|
||||
fixed_code = _extract_avap_code(fix_resp.content)
|
||||
fixed_block = f"{fence_open}\n{fixed_code}\n```"
|
||||
|
||||
valid2, _ = _call_parser(fixed_block)
|
||||
if valid2 is False:
|
||||
to_yield = fixed_block
|
||||
validation_status = "INVALID_UNRESOLVED"
|
||||
else:
|
||||
to_yield = fixed_block
|
||||
validation_status = "" if valid2 else "PARSER_UNAVAILABLE"
|
||||
elif valid is None:
|
||||
to_yield = complete_block
|
||||
validation_status = "PARSER_UNAVAILABLE"
|
||||
else:
|
||||
to_yield = complete_block # validation_status stays ""
|
||||
|
||||
full_response.append(to_yield)
|
||||
yield brunix_pb2.AgentResponse(text=to_yield, is_final=False)
|
||||
|
||||
stream_state = "text"
|
||||
lookahead = rest
|
||||
code_buffer = ""
|
||||
|
||||
# Flush anything remaining (text after last block, or unclosed fence)
|
||||
remainder = lookahead + code_buffer
|
||||
if remainder:
|
||||
full_response.append(remainder)
|
||||
yield brunix_pb2.AgentResponse(text=remainder, is_final=False)
|
||||
|
||||
else:
|
||||
# RETRIEVAL / CONVERSATIONAL / PLATFORM — stream normally
|
||||
for chunk in active_llm.stream(final_messages):
|
||||
token = chunk.content
|
||||
if token:
|
||||
full_response.append(token)
|
||||
yield brunix_pb2.AgentResponse(text=token, is_final=False)
|
||||
|
||||
complete_text = "".join(full_response)
|
||||
if session_id:
|
||||
|
|
@ -254,10 +362,11 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
|
||||
logger.info(
|
||||
f"[AskAgentStream] done — "
|
||||
f"chunks={len(full_response)} total_chars={len(complete_text)}"
|
||||
f"chunks={len(full_response)} total_chars={len(complete_text)} "
|
||||
f"validation_status={validation_status!r}"
|
||||
)
|
||||
|
||||
yield brunix_pb2.AgentResponse(text="", is_final=True)
|
||||
yield brunix_pb2.AgentResponse(text="", is_final=True, validation_status=validation_status)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[AskAgentStream] Error: {e}", exc_info=True)
|
||||
|
|
|
|||
|
|
@ -24,4 +24,8 @@ class AgentState(TypedDict):
|
|||
selected_text: str
|
||||
extra_context: str
|
||||
user_info: str
|
||||
use_editor_context: bool
|
||||
use_editor_context: bool
|
||||
# -- PTVL (ADR-0009)
|
||||
parser_trace: str # raw parser error trace from first validation (empty if valid)
|
||||
validation_status: str # "" | "INVALID_UNRESOLVED" | "PARSER_UNAVAILABLE"
|
||||
context_relevant: bool # result of CONFIDENCE_PROMPT check (RETRIEVAL only)
|
||||
Loading…
Reference in New Issue