# assistance-engine/Docker/src/server.py
# (source listing: 343 lines, 12 KiB, Python)
import base64
import logging
import os
from concurrent import futures
from dotenv import load_dotenv
# Load .env BEFORE importing the project modules below — presumably they
# read environment variables at import time (TODO confirm).
load_dotenv()
import brunix_pb2
import brunix_pb2_grpc
import grpc
from grpc_reflection.v1alpha import reflection
from elasticsearch import Elasticsearch
from langchain_core.messages import AIMessage
from utils.llm_factory import create_chat_model
from utils.emb_factory import create_embedding_model
from graph import build_graph, build_prepare_graph, build_final_messages, session_store, classify_history_store, _load_layer2_model
from utils.classifier_export import maybe_export, force_export
from evaluate import run_evaluation
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("brunix-engine")
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
    """gRPC servicer for the Brunix assistance engine.

    Wires together the chat LLM(s), the embedding model and Elasticsearch,
    builds the LangGraph pipelines, and serves three RPCs:

    * ``AskAgent``       — runs the full graph, yields one final response.
    * ``AskAgentStream`` — runs the preparation graph, then streams tokens
      straight from the LLM.
    * ``EvaluateRAG``    — runs a RAG quality evaluation over the index.
    """

    def __init__(self):
        # --- Elasticsearch configuration (env-driven) ---
        es_url = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200")
        es_user = os.getenv("ELASTICSEARCH_USER")
        es_pass = os.getenv("ELASTICSEARCH_PASSWORD")
        es_apikey = os.getenv("ELASTICSEARCH_API_KEY")
        index = os.getenv("ELASTICSEARCH_INDEX", "avap-knowledge-v1")

        # Primary chat model; validated against the Ollama server on init.
        self.llm = create_chat_model(
            provider="ollama",
            model=os.getenv("OLLAMA_MODEL_NAME"),
            base_url=os.getenv("OLLAMA_URL"),
            temperature=0,
            validate_model_on_init=True,
        )

        # Optional dedicated model for conversational/platform queries;
        # falls back to the primary model when not configured.
        conv_model = os.getenv("OLLAMA_MODEL_NAME_CONVERSATIONAL")
        if conv_model:
            self.llm_conversational = create_chat_model(
                provider="ollama",
                model=conv_model,
                base_url=os.getenv("OLLAMA_URL"),
                temperature=0,
            )
            logger.info(f"[ENGINE] Conversational model: {conv_model}")
        else:
            self.llm_conversational = self.llm

        self.embeddings = create_embedding_model(
            provider="ollama",
            model=os.getenv("OLLAMA_EMB_MODEL_NAME"),
            base_url=os.getenv("OLLAMA_URL"),
        )

        # Auth precedence: API key > basic auth > unauthenticated.
        es_kwargs: dict = {"hosts": [es_url], "request_timeout": 60}
        if es_apikey:
            es_kwargs["api_key"] = es_apikey
        elif es_user and es_pass:
            es_kwargs["basic_auth"] = (es_user, es_pass)
        self.es_client = Elasticsearch(**es_kwargs)
        self.index_name = index
        if self.es_client.ping():
            info = self.es_client.info()
            logger.info(f"[ESEARCH] Connected: {info['version']['number']} — index: {index}")
        else:
            # Connection failure is logged but not fatal: the server still
            # starts so non-ES functionality stays reachable.
            logger.error("[ESEARCH] Can't connect")

        # Full pipeline (classify + retrieve + answer) used by AskAgent.
        self.graph = build_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
            llm_conversational=self.llm_conversational,
        )
        # Preparation-only pipeline used by AskAgentStream: the answer is
        # streamed directly from the LLM rather than produced in-graph.
        self.prepare_graph = build_prepare_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
        )
        _load_layer2_model()
        logger.info("Brunix Engine initialized.")

    @staticmethod
    def _decode_b64(raw, field_name, rpc_name):
        """Decode an optional base64-encoded request field to UTF-8 text.

        Returns "" when the field is empty or cannot be decoded; a decode
        failure is logged with the calling RPC's name (previously the
        stream RPC mislabeled these warnings as "[AskAgent]").
        """
        if not raw:
            return ""
        try:
            return base64.b64decode(raw).decode("utf-8")
        except Exception:
            logger.warning(f"[{rpc_name}] {field_name} base64 decode failed")
            return ""

    def _prepare_request(self, request, rpc_name):
        """Decode request fields and assemble the graph's initial state.

        Shared by AskAgent and AskAgentStream (the code was previously
        duplicated in both). Returns ``(session_id, classify_history,
        initial_state)``.
        """
        session_id = request.session_id or "default"
        query = request.query
        editor_content = self._decode_b64(request.editor_content, "editor_content", rpc_name)
        selected_text = self._decode_b64(request.selected_text, "selected_text", rpc_name)
        extra_context = self._decode_b64(request.extra_context, "extra_context", rpc_name)
        user_info = request.user_info or "{}"
        query_type = request.query_type or ""
        logger.info(
            f"[{rpc_name}] session={session_id} "
            f"editor={bool(editor_content)} selected={bool(selected_text)} "
            f"declared_type={query_type or 'none'} "
            f"query='{query[:80]}'"
        )
        history = list(session_store.get(session_id, []))
        classify_history = list(classify_history_store.get(session_id, []))
        logger.info(f"[{rpc_name}] conversation: {len(history)} previous messages.")
        initial_state = {
            "messages": history + [{"role": "user", "content": query}],
            "session_id": session_id,
            "reformulated_query": "",
            "context": "",
            "query_type": query_type,
            "classify_history": classify_history,
            "editor_content": editor_content,
            "selected_text": selected_text,
            "extra_context": extra_context,
            "user_info": user_info,
        }
        return session_id, classify_history, initial_state

    def AskAgent(self, request, context):
        """Run the full graph and yield a single final AgentResponse.

        Server-streaming signature, but emits exactly one message (plus a
        single error message on failure).
        """
        session_id, classify_history, initial_state = self._prepare_request(request, "AskAgent")
        try:
            final_state = self.graph.invoke(initial_state)
            messages = final_state.get("messages", [])
            last_msg = messages[-1] if messages else None
            result_text = getattr(last_msg, "content", str(last_msg)) \
                if last_msg else ""
            if session_id:
                # NOTE(review): unlike AskAgentStream, session_store is not
                # updated here — presumably the full graph persists the
                # conversation itself; confirm before changing.
                classify_history_store[session_id] = final_state.get("classify_history", classify_history)
                maybe_export(classify_history_store)
            logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
                        f"answer='{result_text[:100]}'")
            yield brunix_pb2.AgentResponse(
                text=result_text,
                avap_code="AVAP-2026",
                is_final=True,
            )
        except Exception as e:
            logger.error(f"[AskAgent] Error: {e}", exc_info=True)
            yield brunix_pb2.AgentResponse(
                text=f"[ENG] Error: {str(e)}",
                is_final=True,
            )

    def AskAgentStream(self, request, context):
        """Prepare context via the prepare graph, then stream LLM tokens.

        Yields one non-final AgentResponse per token, then an empty final
        marker; persists the conversation and classifier history.
        """
        session_id, classify_history, initial_state = self._prepare_request(request, "AskAgentStream")
        try:
            prepared = self.prepare_graph.invoke(initial_state)
            logger.info(
                f"[AskAgentStream] query_type={prepared.get('query_type')} "
                f"context_len={len(prepared.get('context', ''))}"
            )
            final_messages = build_final_messages(prepared)
            full_response = []
            query_type = prepared.get("query_type", "RETRIEVAL")
            # Route chit-chat/platform queries to the conversational model.
            active_llm = self.llm_conversational if query_type in ("CONVERSATIONAL", "PLATFORM") else self.llm
            for chunk in active_llm.stream(final_messages):
                token = chunk.content
                if token:
                    full_response.append(token)
                    yield brunix_pb2.AgentResponse(
                        text=token,
                        is_final=False,
                    )
            complete_text = "".join(full_response)
            if session_id:
                session_store[session_id] = (
                    list(prepared["messages"]) + [AIMessage(content=complete_text)]
                )
                classify_history_store[session_id] = prepared.get("classify_history", classify_history)
                maybe_export(classify_history_store)
            logger.info(
                f"[AskAgentStream] done — "
                f"chunks={len(full_response)} total_chars={len(complete_text)}"
            )
            # Explicit end-of-stream marker for the client.
            yield brunix_pb2.AgentResponse(text="", is_final=True)
        except Exception as e:
            logger.error(f"[AskAgentStream] Error: {e}", exc_info=True)
            yield brunix_pb2.AgentResponse(
                text=f"[ENG] Error: {str(e)}",
                is_final=True,
            )

    def EvaluateRAG(self, request, context):
        """Run a RAG evaluation and return an EvalResponse with the scores.

        Falls back to the engine's default index when the request does not
        name one; any failure is reported via the ``status`` field rather
        than a gRPC error.
        """
        category = request.category or None
        limit = request.limit or None
        index = request.index or self.index_name
        logger.info(f"[EvaluateRAG] category={category} limit={limit} index={index}")
        try:
            result = run_evaluation(
                es_client=self.es_client,
                llm=self.llm,
                embeddings=self.embeddings,
                index_name=index,
                category=category,
                limit=limit,
            )
        except Exception as e:
            logger.error(f"[EvaluateRAG] Error: {e}", exc_info=True)
            return brunix_pb2.EvalResponse(status=f"error: {e}")
        if result.get("status") != "ok":
            return brunix_pb2.EvalResponse(status=result.get("error", "unknown error"))
        details = [
            brunix_pb2.QuestionDetail(
                id=d["id"],
                category=d["category"],
                question=d["question"],
                answer_preview=d["answer_preview"],
                n_chunks=d["n_chunks"],
            )
            for d in result.get("details", [])
        ]
        scores = result["scores"]
        return brunix_pb2.EvalResponse(
            status="ok",
            questions_evaluated=result["questions_evaluated"],
            elapsed_seconds=result["elapsed_seconds"],
            judge_model=result["judge_model"],
            index=result["index"],
            faithfulness=scores["faithfulness"],
            answer_relevancy=scores["answer_relevancy"],
            context_recall=scores["context_recall"],
            context_precision=scores["context_precision"],
            global_score=result["global_score"],
            verdict=result["verdict"],
            details=details,
        )
def serve():
    """Start the gRPC server on port 50051 and block until termination.

    Enables server reflection for the AssistanceEngine service and, on
    shutdown, flushes the accumulated classifier labels.
    """
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), grpc_server)

    # Advertise the engine service plus the reflection service itself.
    reflected_services = (
        brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(reflected_services, grpc_server)

    grpc_server.add_insecure_port("[::]:50051")
    logger.info("[ENGINE] listen on 50051 (gRPC)")
    grpc_server.start()
    try:
        grpc_server.wait_for_termination()
    finally:
        # Persist classifier labels even when the wait is interrupted.
        force_export(classify_history_store)
        logger.info("[ENGINE] classifier labels flushed on shutdown")
# Script entry point: start the gRPC server when run directly.
if __name__ == "__main__":
    serve()