244 lines
8.2 KiB
Python
244 lines
8.2 KiB
Python
import logging
|
|
import os
|
|
from concurrent import futures
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
import brunix_pb2
|
|
import brunix_pb2_grpc
|
|
import grpc
|
|
from grpc_reflection.v1alpha import reflection
|
|
from elasticsearch import Elasticsearch
|
|
from langchain_core.messages import AIMessage
|
|
|
|
from utils.llm_factory import create_chat_model
|
|
from utils.emb_factory import create_embedding_model
|
|
from graph import build_graph, build_prepare_graph, build_final_messages, session_store
|
|
|
|
from evaluate import run_evaluation
|
|
|
|
# Configure root logging at INFO. Per the stdlib contract, basicConfig is a
# no-op if the root logger already has handlers attached.
logging.basicConfig(level="INFO")

# Named logger used by every component in this module.
logger = logging.getLogger(name="brunix-engine")
|
|
|
|
|
|
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
    """gRPC servicer implementing the ``AssistanceEngine`` service.

    On construction it builds an Ollama chat model and embedding model, an
    Elasticsearch client, and two graphs (``build_graph`` and
    ``build_prepare_graph``). RPCs:

    * ``AskAgent``       - run the full graph and yield a single final answer.
    * ``AskAgentStream`` - run the prepare graph, then stream LLM tokens.
    * ``EvaluateRAG``    - run ``run_evaluation`` and return its scores.
    """

    def __init__(self):
        """Create the models, the Elasticsearch client, and the graphs.

        Environment variables:
            ELASTICSEARCH_URL: cluster URL (default ``http://localhost:9200``).
            ELASTICSEARCH_API_KEY: API key; used in preference to basic auth.
            ELASTICSEARCH_USER / ELASTICSEARCH_PASSWORD: basic auth, used only
                when no API key is set and both values are present.
            ELASTICSEARCH_INDEX: index name (default ``avap-knowledge-v1``).
            OLLAMA_MODEL_NAME / OLLAMA_EMB_MODEL_NAME / OLLAMA_URL: Ollama
                chat model, embedding model, and server URL.

        If ``ping()`` returns False the failure is only logged; construction
        continues.
        """
        es_url = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200")
        es_user = os.getenv("ELASTICSEARCH_USER")
        es_pass = os.getenv("ELASTICSEARCH_PASSWORD")
        es_apikey = os.getenv("ELASTICSEARCH_API_KEY")
        index = os.getenv("ELASTICSEARCH_INDEX", "avap-knowledge-v1")

        self.llm = create_chat_model(
            provider="ollama",
            model=os.getenv("OLLAMA_MODEL_NAME"),
            base_url=os.getenv("OLLAMA_URL"),
            temperature=0,
            validate_model_on_init=True,
        )

        self.embeddings = create_embedding_model(
            provider="ollama",
            model=os.getenv("OLLAMA_EMB_MODEL_NAME"),
            base_url=os.getenv("OLLAMA_URL"),
        )

        es_kwargs: dict = {"hosts": [es_url], "request_timeout": 60}
        # API key wins over basic auth; with neither, connect unauthenticated.
        if es_apikey:
            es_kwargs["api_key"] = es_apikey
        elif es_user and es_pass:
            es_kwargs["basic_auth"] = (es_user, es_pass)

        self.es_client = Elasticsearch(**es_kwargs)
        self.index_name = index

        if self.es_client.ping():
            info = self.es_client.info()
            logger.info(f"[ESEARCH] Connected: {info['version']['number']} — index: {index}")
        else:
            logger.error("[ESEARCH] Cant Connect")

        # Full pipeline used by AskAgent.
        self.graph = build_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
        )

        # Pipeline used by AskAgentStream; the final LLM call is streamed
        # separately from this graph's output.
        self.prepare_graph = build_prepare_graph(
            llm=self.llm,
            embeddings=self.embeddings,
            es_client=self.es_client,
            index_name=self.index_name,
        )

        logger.info("Brunix Engine initialized.")

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _initial_state(session_id, query, log_tag):
        """Return the graph input: stored history for *session_id* plus the new user turn.

        *log_tag* (e.g. ``"[AskAgent]"``) prefixes the history-size log line.
        The stored history is copied so the graph never mutates
        ``session_store`` in place through this list.
        """
        history = list(session_store.get(session_id, []))
        logger.info(f"{log_tag} conversation: {len(history)} previous messages.")
        return {
            "messages": history + [{"role": "user", "content": query}],
            "session_id": session_id,
            "reformulated_query": "",
            "context": "",
            "query_type": "",
        }

    @staticmethod
    def _error_response(log_tag, exc):
        """Log *exc* with its traceback and wrap it in a final ``AgentResponse``."""
        logger.error(f"{log_tag} Error: {exc}", exc_info=exc)
        return brunix_pb2.AgentResponse(
            text=f"[ENG] Error: {str(exc)}",
            is_final=True,
        )

    # ------------------------------------------------------------------ #
    # RPCs
    # ------------------------------------------------------------------ #

    def AskAgent(self, request, context):
        """Answer ``request.query`` with the full graph and yield one final response.

        History is read from ``session_store`` under ``request.session_id``
        (``"default"`` when empty). The yielded ``AgentResponse`` carries the
        content of the last message in the graph's final state, with
        ``is_final=True``. Any exception is logged and returned to the client
        as an ``[ENG] Error: ...`` final response.

        NOTE(review): this method does not itself write to ``session_store``;
        confirm whether ``build_graph`` persists the turn.
        """
        session_id = request.session_id or "default"
        query = request.query
        logger.info(f"[AskAgent] session={session_id} query='{query[:80]}'")

        try:
            initial_state = self._initial_state(session_id, query, "[AskAgent]")
            final_state = self.graph.invoke(initial_state)

            messages = final_state.get("messages", [])
            last_msg = messages[-1] if messages else None
            # Message objects expose ``.content``; anything else is stringified.
            result_text = getattr(last_msg, "content", str(last_msg)) if last_msg else ""

            logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
                        f"answer='{result_text[:100]}'")

            yield brunix_pb2.AgentResponse(
                text=result_text,
                avap_code="AVAP-2026",
                is_final=True,
            )
        except Exception as e:
            yield self._error_response("[AskAgent]", e)

    def AskAgentStream(self, request, context):
        """Prepare context with ``prepare_graph``, then stream the LLM answer.

        Each non-empty chunk from ``self.llm.stream`` is yielded as an
        ``AgentResponse`` with ``is_final=False``. Once streaming completes,
        the prepared messages plus an ``AIMessage`` holding the full answer
        are stored in ``session_store[session_id]``, and an empty
        ``is_final=True`` response marks the end of the stream. Any exception
        is logged and yielded as an ``[ENG] Error: ...`` final response.
        """
        session_id = request.session_id or "default"
        query = request.query
        logger.info(f"[AskAgentStream] session={session_id} query='{query[:80]}'")

        try:
            initial_state = self._initial_state(session_id, query, "[AskAgentStream]")

            prepared = self.prepare_graph.invoke(initial_state)
            logger.info(
                f"[AskAgentStream] query_type={prepared.get('query_type')} "
                f"context_len={len(prepared.get('context', ''))}"
            )

            final_messages = build_final_messages(prepared)
            full_response = []

            for chunk in self.llm.stream(final_messages):
                token = chunk.content
                if token:
                    full_response.append(token)
                    yield brunix_pb2.AgentResponse(
                        text=token,
                        is_final=False,
                    )

            complete_text = "".join(full_response)
            # session_id is never empty here (it falls back to "default"), so
            # the turn is always persisted.
            session_store[session_id] = (
                list(prepared["messages"]) + [AIMessage(content=complete_text)]
            )

            logger.info(
                f"[AskAgentStream] done — "
                f"chunks={len(full_response)} total_chars={len(complete_text)}"
            )

            yield brunix_pb2.AgentResponse(text="", is_final=True)

        except Exception as e:
            yield self._error_response("[AskAgentStream]", e)

    @staticmethod
    def _eval_response(result):
        """Convert a successful ``run_evaluation`` result dict into an ``EvalResponse``.

        Raises ``KeyError`` / ``TypeError`` / ``ValueError`` when *result* is
        missing expected keys or holds values the protobuf fields reject.
        """
        details = [
            brunix_pb2.QuestionDetail(
                id=d["id"],
                category=d["category"],
                question=d["question"],
                answer_preview=d["answer_preview"],
                n_chunks=d["n_chunks"],
            )
            for d in result.get("details", [])
        ]

        scores = result["scores"]
        return brunix_pb2.EvalResponse(
            status="ok",
            questions_evaluated=result["questions_evaluated"],
            elapsed_seconds=result["elapsed_seconds"],
            judge_model=result["judge_model"],
            index=result["index"],
            faithfulness=scores["faithfulness"],
            answer_relevancy=scores["answer_relevancy"],
            context_recall=scores["context_recall"],
            context_precision=scores["context_precision"],
            global_score=result["global_score"],
            verdict=result["verdict"],
            details=details,
        )

    def EvaluateRAG(self, request, context):
        """Run ``run_evaluation`` and return its scores as an ``EvalResponse``.

        Empty ``request.category`` / zero ``request.limit`` are passed as
        ``None`` (presumably "no filter" / "no limit" — confirm in
        ``evaluate.run_evaluation``); an empty ``request.index`` falls back to
        ``self.index_name``.

        Failures are reported in ``status`` rather than as RPC errors:
        ``"error: <exc>"`` if evaluation raises, the result's ``"error"``
        value if its status is not ``"ok"``, and ``"error: malformed
        evaluation result (...)"`` if an ``"ok"`` result cannot be converted.
        """
        category = request.category or None
        limit = request.limit or None
        index = request.index or self.index_name

        logger.info(f"[EvaluateRAG] category={category} limit={limit} index={index}")

        try:
            result = run_evaluation(
                es_client=self.es_client,
                llm=self.llm,
                embeddings=self.embeddings,
                index_name=index,
                category=category,
                limit=limit,
            )
        except Exception as e:
            logger.error(f"[EvaluateRAG] Error: {e}", exc_info=True)
            return brunix_pb2.EvalResponse(status=f"error: {e}")

        if result.get("status") != "ok":
            return brunix_pb2.EvalResponse(status=result.get("error", "unknown error"))

        try:
            return self._eval_response(result)
        except (KeyError, TypeError, ValueError) as e:
            # Previously this escaped unhandled and aborted the RPC with an
            # opaque UNKNOWN status; report it through ``status`` like the
            # other failure paths instead.
            logger.error(f"[EvaluateRAG] Malformed evaluation result: {e!r}", exc_info=True)
            return brunix_pb2.EvalResponse(
                status=f"error: malformed evaluation result ({e!r})"
            )
|
|
|
|
|
|
def serve(port=50051, max_workers=10):
    """Start the Brunix gRPC server and block until it terminates.

    Registers a ``BrunixEngine`` as the ``AssistanceEngine`` servicer,
    enables server reflection for that service, and listens without TLS on
    ``[::]``.

    Args:
        port: TCP port to bind; defaults to 50051 as before.
        max_workers: size of the thread pool that runs RPC handlers.

    Raises:
        RuntimeError: if the port cannot be bound.
    """
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers))
    brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), server)

    service_names = (
        brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"].full_name,
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(service_names, server)

    bound_port = server.add_insecure_port(f"[::]:{port}")
    if not bound_port:
        # Some grpcio versions report a bind failure by returning 0 instead
        # of raising; without this check the server would start on no port.
        raise RuntimeError(f"failed to bind gRPC server to port {port}")

    logger.info(f"[ENGINE] listen on {bound_port} (gRPC)")
    server.start()
    server.wait_for_termination()


if __name__ == "__main__":
    serve()
|