343 lines
12 KiB
Python
343 lines
12 KiB
Python
import base64
|
|
import logging
|
|
import os
|
|
from concurrent import futures
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
import brunix_pb2
|
|
import brunix_pb2_grpc
|
|
import grpc
|
|
from grpc_reflection.v1alpha import reflection
|
|
from elasticsearch import Elasticsearch
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from utils.llm_factory import create_chat_model
|
|
from utils.emb_factory import create_embedding_model
|
|
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model
|
|
from utils.classifier_export import maybe_export, force_export
|
|
|
|
from evaluate import run_evaluation
|
|
|
|
# Configure the root logger at import time: INFO threshold, default stderr
# handler. (basicConfig is a no-op if the root logger already has handlers.)
logging.basicConfig(level="INFO")

# Module-wide logger shared by the servicer and serve(); all records are
# emitted under the "brunix-engine" name.
logger = logging.getLogger(name="brunix-engine")
|
|
|
|
|
|
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
    """gRPC servicer for the ``AssistanceEngine`` service.

    Owns the Ollama chat/embedding models and the Elasticsearch client, and
    drives two graphs built by the ``graph`` module:

    * ``self.graph`` runs the whole pipeline for :meth:`AskAgent`.
    * ``self.prepare_graph`` is used by :meth:`AskAgentStream`. That method
      turns the graph's output into the final prompt and streams the answer
      from the LLM itself.

    Per-session chat and classifier history are kept in the module-level
    ``session_store`` / ``classify_history_store`` imported from ``graph``.
    """

    def __init__(self):
        """Create the model clients, the Elasticsearch client and both graphs.

        Configuration comes from environment variables:

        * ``OLLAMA_URL``, ``OLLAMA_MODEL_NAME``: main chat model. It is
          validated on init.
        * ``OLLAMA_MODEL_NAME_CONVERSATIONAL``: optional separate chat model.
          When unset, the main model is reused.
        * ``OLLAMA_EMB_MODEL_NAME``: embedding model.
        * ``ELASTICSEARCH_URL`` (default ``http://localhost:9200``) and
          ``ELASTICSEARCH_INDEX`` (default ``avap-knowledge-v1``).
        * ``ELASTICSEARCH_API_KEY``, or ``ELASTICSEARCH_USER`` +
          ``ELASTICSEARCH_PASSWORD``, for authentication. The API key wins.

        An unreachable Elasticsearch is only logged as an error. Construction
        still succeeds.
        """
        ollama_url = os.getenv("OLLAMA_URL")

        self.llm = create_chat_model(
            provider="ollama",
            model=os.getenv("OLLAMA_MODEL_NAME"),
            base_url=ollama_url,
            temperature=0,
            validate_model_on_init=True,
        )

        conv_model = os.getenv("OLLAMA_MODEL_NAME_CONVERSATIONAL")
        if conv_model:
            self.llm_conversational = create_chat_model(
                provider="ollama",
                model=conv_model,
                base_url=ollama_url,
                temperature=0,
            )
            logger.info("[ENGINE] Conversational model: %s", conv_model)
        else:
            # No dedicated conversational model configured: reuse the main one.
            self.llm_conversational = self.llm

        self.embeddings = create_embedding_model(
            provider="ollama",
            model=os.getenv("OLLAMA_EMB_MODEL_NAME"),
            base_url=ollama_url,
        )

        self.es_client = Elasticsearch(**self._es_client_kwargs())
        self.index_name = os.getenv("ELASTICSEARCH_INDEX", "avap-knowledge-v1")

        if self.es_client.ping():
            info = self.es_client.info()
            logger.info(
                "[ESEARCH] Connected: %s — index: %s",
                info["version"]["number"],
                self.index_name,
            )
        else:
            # Deliberately non-fatal. Requests that need Elasticsearch will
            # presumably fail until it becomes reachable.
            logger.error("[ESEARCH] Cant Connect")

        self.graph = build_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
            llm_conversational=self.llm_conversational,
        )

        self.prepare_graph = build_prepare_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
        )

        _load_layer2_model()
        logger.info("Brunix Engine initialized.")

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _es_client_kwargs() -> dict:
        """Build ``Elasticsearch(...)`` keyword arguments from ELASTICSEARCH_* env vars.

        An API key takes precedence over user/password. With neither, the
        client is created without authentication.
        """
        kwargs: dict = {
            "hosts": [os.getenv("ELASTICSEARCH_URL", "http://localhost:9200")],
            "request_timeout": 60,
        }
        api_key = os.getenv("ELASTICSEARCH_API_KEY")
        user = os.getenv("ELASTICSEARCH_USER")
        password = os.getenv("ELASTICSEARCH_PASSWORD")
        if api_key:
            kwargs["api_key"] = api_key
        elif user and password:
            kwargs["basic_auth"] = (user, password)
        return kwargs

    @staticmethod
    def _decode_b64_field(raw, field_name: str, log_tag: str) -> str:
        """Return base64-encoded *raw* decoded to UTF-8 text.

        ``raw`` may be ``str`` or ``bytes``, since ``base64.b64decode``
        accepts both. Empty input yields ``""``.

        Decoding is best effort: a malformed optional field must not fail the
        whole request. On failure a warning tagged with *log_tag* is logged
        and ``""`` is returned.
        """
        if not raw:
            return ""
        try:
            return base64.b64decode(raw).decode("utf-8")
        except (ValueError, TypeError):
            # binascii.Error (bad padding) and UnicodeDecodeError are both
            # ValueError subclasses. TypeError covers an unexpected field type.
            logger.warning("[%s] %s base64 decode failed", log_tag, field_name)
            return ""

    def _read_request(self, request, log_tag: str) -> dict:
        """Extract the fields shared by AskAgent/AskAgentStream and log a summary.

        Returns a dict with these keys:

        * ``session_id``: defaults to ``"default"``.
        * ``query``
        * ``query_type``: defaults to ``""``.
        * ``user_info``: defaults to ``"{}"``.
        * ``editor_content``, ``selected_text``, ``extra_context``:
          base64-decoded, ``""`` when absent or undecodable.
        """
        fields = {
            "session_id": request.session_id or "default",
            "query": request.query,
            "editor_content": self._decode_b64_field(
                request.editor_content, "editor_content", log_tag
            ),
            "selected_text": self._decode_b64_field(
                request.selected_text, "selected_text", log_tag
            ),
            "extra_context": self._decode_b64_field(
                request.extra_context, "extra_context", log_tag
            ),
            "user_info": request.user_info or "{}",
            "query_type": request.query_type or "",
        }
        logger.info(
            "[%s] session=%s editor=%s selected=%s declared_type=%s query='%s'",
            log_tag,
            fields["session_id"],
            bool(fields["editor_content"]),
            bool(fields["selected_text"]),
            fields["query_type"] or "none",
            fields["query"][:80],
        )
        return fields

    @staticmethod
    def _load_session(session_id: str) -> tuple:
        """Return ``(chat_history, classify_history)`` for *session_id*.

        Both are shallow copies (empty lists for an unknown session), so
        in-place list mutation during the request does not touch the stores.
        """
        return (
            list(session_store.get(session_id, [])),
            list(classify_history_store.get(session_id, [])),
        )

    @staticmethod
    def _initial_state(fields: dict, history: list, classify_history: list) -> dict:
        """Build the graph input state.

        The state is the prior history plus the new user turn, together with
        the request's editor context.
        """
        return {
            "messages": history + [{"role": "user", "content": fields["query"]}],
            "session_id": fields["session_id"],
            "reformulated_query": "",
            "context": "",
            "query_type": fields["query_type"],
            "classify_history": classify_history,
            "editor_content": fields["editor_content"],
            "selected_text": fields["selected_text"],
            "extra_context": fields["extra_context"],
            "user_info": fields["user_info"],
        }

    @staticmethod
    def _message_text(message) -> str:
        """Return the text of a graph output message.

        Accepts objects exposing ``.content`` (e.g. LangChain messages) and
        ``{"role": ..., "content": ...}`` dicts, the shape this class itself
        puts into ``messages``. Other objects fall back to ``str(message)``.
        A falsy message or ``None`` content yields ``""``.
        """
        if not message:
            return ""
        if isinstance(message, dict):
            content = message.get("content", "")
        elif hasattr(message, "content"):
            content = message.content
        else:
            return str(message)
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _error_response(exc: Exception):
        """Build the final ``AgentResponse`` that reports *exc* to the client."""
        # NOTE(review): this exposes internal exception text to clients;
        # confirm that is acceptable outside development.
        return brunix_pb2.AgentResponse(text=f"[ENG] Error: {exc}", is_final=True)

    @staticmethod
    def _eval_response(result: dict):
        """Convert a successful ``run_evaluation`` result dict into an ``EvalResponse``.

        Raises ``KeyError`` on a missing field. Protobuf raises
        ``TypeError``/``ValueError`` on a field of the wrong type.
        """
        details = [
            brunix_pb2.QuestionDetail(
                id=d["id"],
                category=d["category"],
                question=d["question"],
                answer_preview=d["answer_preview"],
                n_chunks=d["n_chunks"],
            )
            for d in result.get("details", [])
        ]
        scores = result["scores"]
        return brunix_pb2.EvalResponse(
            status="ok",
            questions_evaluated=result["questions_evaluated"],
            elapsed_seconds=result["elapsed_seconds"],
            judge_model=result["judge_model"],
            index=result["index"],
            faithfulness=scores["faithfulness"],
            answer_relevancy=scores["answer_relevancy"],
            context_recall=scores["context_recall"],
            context_precision=scores["context_precision"],
            global_score=result["global_score"],
            verdict=result["verdict"],
            details=details,
        )

    # -------------------------------------------------------------------- RPCs

    def AskAgent(self, request, context):
        """Answer ``request.query`` with the full graph.

        Yields exactly one ``AgentResponse`` with ``is_final=True``. Errors are
        logged and returned as an ``[ENG] Error: ...`` text response; the RPC
        itself does not fail.
        """
        fields = self._read_request(request, "AskAgent")
        session_id = fields["session_id"]

        try:
            history, classify_history = self._load_session(session_id)
            logger.info("[AskAgent] conversation: %d previous messages.", len(history))

            final_state = self.graph.invoke(
                self._initial_state(fields, history, classify_history)
            )

            messages = final_state.get("messages") or []
            result_text = self._message_text(messages[-1] if messages else None)

            # NOTE(review): unlike AskAgentStream, this path never writes
            # session_store. Confirm the graph persists the conversation itself,
            # otherwise AskAgent sessions have no memory between calls.
            classify_history_store[session_id] = final_state.get(
                "classify_history", classify_history
            )
            maybe_export(classify_history_store)

            logger.info(
                "[AskAgent] query_type=%s answer='%s'",
                final_state.get("query_type"),
                result_text[:100],
            )

            yield brunix_pb2.AgentResponse(
                text=result_text,
                avap_code="AVAP-2026",
                is_final=True,
            )

        except Exception as e:
            # Top-level RPC boundary: report instead of failing the stream.
            logger.exception("[AskAgent] Error: %s", e)
            yield self._error_response(e)

    def AskAgentStream(self, request, context):
        """Answer ``request.query`` by streaming LLM tokens.

        Steps:

        1. Run ``prepare_graph``, then ``build_final_messages`` on its output.
        2. Stream the answer from the conversational model for
           CONVERSATIONAL/PLATFORM queries, or from the main model otherwise.
           Each non-empty token becomes an ``AgentResponse`` with
           ``is_final=False``.
        3. Yield an empty final response.

        Once streaming completes, the session history is stored with the full
        answer appended. Errors are reported as an ``[ENG] Error: ...`` final
        response.
        """
        fields = self._read_request(request, "AskAgentStream")
        session_id = fields["session_id"]

        try:
            history, classify_history = self._load_session(session_id)
            logger.info(
                "[AskAgentStream] conversation: %d previous messages.", len(history)
            )

            prepared = self.prepare_graph.invoke(
                self._initial_state(fields, history, classify_history)
            )
            logger.info(
                "[AskAgentStream] query_type=%s context_len=%d",
                prepared.get("query_type"),
                len(prepared.get("context") or ""),
            )

            final_messages = build_final_messages(prepared)

            query_type = prepared.get("query_type", "RETRIEVAL")
            active_llm = (
                self.llm_conversational
                if query_type in ("CONVERSATIONAL", "PLATFORM")
                else self.llm
            )

            tokens = []
            for chunk in active_llm.stream(final_messages):
                token = chunk.content
                if token:
                    tokens.append(token)
                    yield brunix_pb2.AgentResponse(text=token, is_final=False)

            complete_text = "".join(tokens)

            # Stores are written only after the stream finishes, so a failed
            # generation leaves the previous history untouched.
            # NOTE(review): handlers run on a thread pool; concurrent requests
            # for the same session_id can race on these updates.
            session_store[session_id] = list(prepared["messages"]) + [
                AIMessage(content=complete_text)
            ]
            classify_history_store[session_id] = prepared.get(
                "classify_history", classify_history
            )
            maybe_export(classify_history_store)

            logger.info(
                "[AskAgentStream] done — chunks=%d total_chars=%d",
                len(tokens),
                len(complete_text),
            )

            yield brunix_pb2.AgentResponse(text="", is_final=True)

        except Exception as e:
            # Top-level RPC boundary: report instead of failing the stream.
            logger.exception("[AskAgentStream] Error: %s", e)
            yield self._error_response(e)

    def EvaluateRAG(self, request, context):
        """Run the offline RAG evaluation and return its scores.

        A ``category`` or ``limit`` of ``""``/``0`` means "no filter".
        ``index`` defaults to the engine's configured index. Every failure is
        reported in the response ``status``; none escapes as a gRPC error:

        * ``run_evaluation`` raising gives ``error: <exc>``.
        * A non-"ok" result gives its ``error`` text.
        * A malformed result gives ``error: malformed evaluation result (...)``.
        """
        category = request.category or None
        limit = request.limit or None
        index = request.index or self.index_name

        logger.info(
            "[EvaluateRAG] category=%s limit=%s index=%s", category, limit, index
        )

        try:
            result = run_evaluation(
                es_client=self.es_client,
                llm=self.llm,
                embeddings=self.embeddings,
                index_name=index,
                category=category,
                limit=limit,
            )
        except Exception as e:
            logger.exception("[EvaluateRAG] Error: %s", e)
            return brunix_pb2.EvalResponse(status=f"error: {e}")

        if result.get("status") != "ok":
            return brunix_pb2.EvalResponse(status=result.get("error", "unknown error"))

        try:
            return self._eval_response(result)
        except (KeyError, TypeError, ValueError) as e:
            # A result that claims "ok" but lacks or mistypes fields previously
            # escaped as an opaque gRPC UNKNOWN error.
            logger.exception("[EvaluateRAG] malformed evaluation result: %s", e)
            return brunix_pb2.EvalResponse(
                status=f"error: malformed evaluation result ({e})"
            )
|
|
|
|
|
|
def serve(port: int = 50051, max_workers: int = 10, grace_seconds: float = 5.0) -> None:
    """Start the Brunix gRPC server and block until it terminates.

    Registers ``BrunixEngine`` and the gRPC server-reflection service, then
    listens without TLS on ``[::]:<port>``. On shutdown (SIGTERM, Ctrl-C, or
    normal termination), in-flight RPCs get up to ``grace_seconds`` to finish.
    After that, the classifier label store is flushed with ``force_export``.

    Args:
        port: TCP port to listen on.
        max_workers: Size of the thread pool that runs RPC handlers.
        grace_seconds: Grace period passed to ``server.stop`` on shutdown.

    Raises:
        RuntimeError: if the port cannot be bound.
    """
    import signal
    import threading

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), server)

    service_names = (
        brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(service_names, server)

    address = f"[::]:{port}"
    # Some grpcio versions report a bind failure by returning 0 instead of raising.
    if server.add_insecure_port(address) == 0:
        raise RuntimeError(f"[ENGINE] could not bind gRPC server to {address}")
    logger.info("[ENGINE] listen on %s (gRPC)", port)
    server.start()

    # Python's default SIGTERM action terminates the process without running
    # `finally` blocks, which would skip the flush below (e.g. on `docker stop`).
    # Turn SIGTERM into a graceful stop so wait_for_termination() returns.
    # signal.signal() may only be called from the main thread.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signal.SIGTERM, lambda signum, frame: server.stop(grace_seconds))

    try:
        server.wait_for_termination()
    finally:
        # Let in-flight requests finish (and record their labels) before the
        # final flush. stop() is a no-op on an already-stopped server.
        server.stop(grace_seconds).wait()
        force_export(classify_history_store)
        logger.info("[ENGINE] classifier labels flushed on shutdown")


if __name__ == "__main__":
    serve()
|