UPGRADE: New RAG functional
This commit is contained in:
parent
dfcbf43fa2
commit
1daac66f89
|
|
@ -25,6 +25,10 @@ RUN python -m grpc_tools.protoc \
|
|||
--grpc_python_out=./src \
|
||||
./protos/brunix.proto
|
||||
|
||||
EXPOSE 50051
|
||||
COPY entrypoint.sh /entrypoint.sh
|
||||
RUN chmod +x /entrypoint.sh
|
||||
|
||||
CMD ["python", "src/server.py"]
|
||||
EXPOSE 50051
|
||||
EXPOSE 8000
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
|
|
@ -6,6 +6,7 @@ services:
|
|||
container_name: brunix-assistance-engine
|
||||
ports:
|
||||
- "50052:50051"
|
||||
- "8000:8000"
|
||||
environment:
|
||||
ELASTICSEARCH_URL: ${ELASTICSEARCH_URL}
|
||||
ELASTICSEARCH_INDEX: ${ELASTICSEARCH_INDEX}
|
||||
|
|
@ -16,6 +17,7 @@ services:
|
|||
OLLAMA_URL: ${OLLAMA_URL}
|
||||
OLLAMA_MODEL_NAME: ${OLLAMA_MODEL_NAME}
|
||||
OLLAMA_EMB_MODEL_NAME: ${OLLAMA_EMB_MODEL_NAME}
|
||||
PROXY_THREAD_WORKERS: 10
|
||||
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/sh
# Entrypoint: run the gRPC engine and the HTTP proxy side by side and
# exit (non-zero) as soon as either of them dies, so the container
# restarts cleanly instead of limping along with one dead half.
set -e

echo "[entrypoint] Starting Brunix Engine (gRPC :50051)..."
python src/server.py &
ENGINE_PID=$!

echo "[entrypoint] Starting OpenAI Proxy (HTTP :8000)..."
uvicorn openai_proxy:app --host 0.0.0.0 --port 8000 --workers 4 --app-dir src &
PROXY_PID=$!

# Forward container stop signals to both children; without this, docker stop
# would wait out the grace period and SIGKILL them.
trap 'kill "$ENGINE_PID" "$PROXY_PID" 2>/dev/null; exit 0' TERM INT

wait_any() {
    # Poll both PIDs; kill -0 only tests existence, it sends no signal.
    while kill -0 "$ENGINE_PID" 2>/dev/null && kill -0 "$PROXY_PID" 2>/dev/null; do
        sleep 2
    done

    if ! kill -0 "$ENGINE_PID" 2>/dev/null; then
        echo "[entrypoint] Engine died — stopping proxy"
        kill "$PROXY_PID" 2>/dev/null
        exit 1
    fi

    if ! kill -0 "$PROXY_PID" 2>/dev/null; then
        echo "[entrypoint] Proxy died — stopping engine"
        kill "$ENGINE_PID" 2>/dev/null
        exit 1
    fi
}

wait_any
|
|
@ -3,16 +3,60 @@ syntax = "proto3";
|
|||
package brunix;
|
||||
|
||||
service AssistanceEngine {
  // Complete response — compatible with existing clients.
  rpc AskAgent (AgentRequest) returns (stream AgentResponse);

  // Real token-by-token streaming from Ollama.
  rpc AskAgentStream (AgentRequest) returns (stream AgentResponse);

  // RAGAS evaluation with Claude as judge.
  rpc EvaluateRAG (EvalRequest) returns (EvalResponse);
}

// ---------------------------------------------------------------------------
// AskAgent / AskAgentStream — same messages, two behaviors
// ---------------------------------------------------------------------------

message AgentRequest {
  string query = 1;
  string session_id = 2;
}

message AgentResponse {
  string text = 1;
  string avap_code = 2;
  bool is_final = 3;
}

// ---------------------------------------------------------------------------
// EvaluateRAG
// ---------------------------------------------------------------------------

message EvalRequest {
  string category = 1;   // optional category filter over golden questions
  int32 limit = 2;       // optional cap on questions evaluated
  string index = 3;      // Elasticsearch index to evaluate against
}

message EvalResponse {
  string status = 1;
  int32 questions_evaluated = 2;
  float elapsed_seconds = 3;
  string judge_model = 4;
  string index = 5;
  float faithfulness = 6;
  float answer_relevancy = 7;
  float context_recall = 8;
  float context_precision = 9;
  float global_score = 10;
  string verdict = 11;
  repeated QuestionDetail details = 12;
}

message QuestionDetail {
  string id = 1;
  string category = 2;
  string question = 3;
  string answer_preview = 4;
  int32 n_chunks = 5;
}
|
||||
|
|
|
|||
|
|
@ -316,3 +316,10 @@ yarl==1.22.0
|
|||
# via aiohttp
|
||||
zstandard==0.25.0
|
||||
# via langsmith
|
||||
|
||||
ragas
|
||||
datasets
|
||||
langchain-anthropic
|
||||
|
||||
fastapi>=0.111.0
|
||||
uvicorn[standard]>=0.29.0
|
||||
|
|
@ -0,0 +1,230 @@
|
|||
import os
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from ragas import evaluate as ragas_evaluate
|
||||
from ragas.metrics import ( faithfulness, answer_relevancy, context_recall, context_precision,)
|
||||
from ragas.llms import LangchainLLMWrapper
|
||||
from ragas.embeddings import LangchainEmbeddingsWrapper
|
||||
from datasets import Dataset
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GOLDEN_DATASET_PATH = Path(__file__).parent / "golden_dataset.json"
|
||||
CLAUDE_MODEL = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")
|
||||
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")
|
||||
K_RETRIEVE = 5
|
||||
|
||||
|
||||
|
||||
ANTHROPIC_AVAILABLE = True
|
||||
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.messages import SystemMessage, HumanMessage
|
||||
|
||||
def retrieve_context( es_client, embeddings, question, index, k = K_RETRIEVE,):
    """Retrieve up to *k* context chunks for *question* via hybrid search.

    Runs a BM25 query and (when embedding succeeds) a kNN query against
    *index*, fuses the two result lists with Reciprocal Rank Fusion, and
    returns the non-blank chunk texts in fused order.

    Args:
        es_client: Elasticsearch client.
        embeddings: embedding model exposing ``embed_query``.
        question: natural-language query.
        index: Elasticsearch index name.
        k: number of chunks to return.

    Returns:
        list[str]: chunk texts (empty when both searches fail).
    """
    query_vector = None
    try:
        query_vector = embeddings.embed_query(question)
    except Exception as e:
        logger.warning(f"[eval] embed_query fails: {e}")

    bm25_hits = []
    try:
        resp = es_client.search(
            index=index,
            body={
                "size": k,
                "query": {
                    "multi_match": {
                        "query": question,
                        "fields": ["content^2", "text^2"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                    }
                },
                # Embeddings are large and unused here — keep them out of _source.
                "_source": {"excludes": ["embedding"]},
            }
        )
        bm25_hits = resp["hits"]["hits"]
    except Exception as e:
        logger.warning(f"[eval] BM25 fails: {e}")

    knn_hits = []
    if query_vector:
        try:
            resp = es_client.search(
                index=index,
                body={
                    "size": k,
                    "knn": {
                        "field": "embedding",
                        "query_vector": query_vector,
                        "k": k,
                        "num_candidates": k * 5,
                    },
                    "_source": {"excludes": ["embedding"]},
                }
            )
            knn_hits = resp["hits"]["hits"]
        except Exception as e:
            logger.warning(f"[eval] kNN fails: {e}")

    # Reciprocal Rank Fusion: each list contributes 1/(rank+60) per doc.
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}

    for hit_list in (bm25_hits, knn_hits):
        for rank, hit in enumerate(hit_list):
            doc_id = hit["_id"]
            rrf_scores[doc_id] += 1.0 / (rank + 60)
            hit_by_id.setdefault(doc_id, hit)

    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]

    # Extract each chunk's text once; skip whitespace-only chunks.
    contexts = []
    for doc_id, _ in ranked:
        src = hit_by_id[doc_id]["_source"]
        text = src.get("content") or src.get("text", "")
        if text.strip():
            contexts.append(text)
    return contexts
|
||||
|
||||
|
||||
def generate_answer(llm, question: str, contexts: list[str]) -> str:
    """Produce an answer to *question* grounded in the numbered *contexts*.

    On any failure (missing prompt module, LLM error, ...) the problem is
    logged and an empty string is returned so the caller can skip the sample.
    """
    try:
        from prompts import GENERATE_PROMPT

        numbered = [f"[{idx}] {chunk}" for idx, chunk in enumerate(contexts, start=1)]
        system_msg = SystemMessage(
            content=GENERATE_PROMPT.content.format(context="\n\n".join(numbered))
        )
        reply = llm.invoke([system_msg, HumanMessage(content=question)])
        return reply.content.strip()
    except Exception as e:
        logger.warning(f"[eval] generate_answer fails: {e}")
        return ""
|
||||
|
||||
def run_evaluation( es_client, llm, embeddings, index_name, category = None, limit = None,):
    """Evaluate the RAG pipeline over the golden dataset with RAGAS.

    For each golden question: retrieve hybrid context, generate an answer
    with *llm*, then score (question, answer, contexts, ground_truth)
    samples using Claude as the RAGAS judge.

    Args:
        es_client: Elasticsearch client used for retrieval.
        llm: chat model producing the answers under evaluation.
        embeddings: embedding model (retrieval kNN + answer_relevancy metric).
        index_name: Elasticsearch index to query.
        category: optional category filter for golden questions.
        limit: optional cap on the number of questions.

    Returns:
        dict: scores/verdict/details on success, or ``{"error": ...}`` when a
        precondition fails or no sample could be generated.
    """
    if not ANTHROPIC_AVAILABLE:
        return {"error": "langchain-anthropic no instalado. pip install langchain-anthropic"}
    if not ANTHROPIC_API_KEY:
        return {"error": "ANTHROPIC_API_KEY no configurada en .env"}
    if not GOLDEN_DATASET_PATH.exists():
        return {"error": f"Golden dataset no encontrado en {GOLDEN_DATASET_PATH}"}

    questions = json.loads(GOLDEN_DATASET_PATH.read_text(encoding="utf-8"))
    if category:
        questions = [q for q in questions if q.get("category") == category]
    if limit:
        questions = questions[:limit]
    if not questions:
        return {"error": "NO QUESTIONS WITH THESE FILTERS"}

    logger.info(f"[eval] running: {len(questions)} questions, index={index_name}")

    claude_judge = ChatAnthropic(
        model=CLAUDE_MODEL,
        api_key=ANTHROPIC_API_KEY,
        temperature=0,  # deterministic judging
        max_tokens=2048,
    )

    rows = {"question": [], "answer": [], "contexts": [], "ground_truth": []}
    details = []
    t_start = time.time()

    for item in questions:
        q_id = item["id"]
        question = item["question"]
        gt = item["ground_truth"]

        logger.info(f"[eval] {q_id}: {question[:60]}")

        contexts = retrieve_context(es_client, embeddings, question, index_name)
        if not contexts:
            logger.warning(f"[eval] No context for {q_id} — skipping")
            continue

        answer = generate_answer(llm, question, contexts)
        if not answer:
            logger.warning(f"[eval] No answers for {q_id} — skipping")
            continue

        rows["question"].append(question)
        rows["answer"].append(answer)
        rows["contexts"].append(contexts)
        rows["ground_truth"].append(gt)

        details.append({
            "id": q_id,
            "category": item.get("category", ""),
            "question": question,
            "answer_preview": answer[:300],
            "n_chunks": len(contexts),
        })

    if not rows["question"]:
        return {"error": "NO SAMPLES GENERATED"}

    dataset = Dataset.from_dict(rows)
    ragas_llm = LangchainLLMWrapper(claude_judge)
    ragas_emb = LangchainEmbeddingsWrapper(embeddings)

    # Point every metric at the Claude judge (and our embeddings where used).
    metrics = [faithfulness, answer_relevancy, context_recall, context_precision]
    for metric in metrics:
        metric.llm = ragas_llm
        if hasattr(metric, "embeddings"):
            metric.embeddings = ragas_emb

    logger.info("[eval] JUDGING BY CLAUDE...")
    result = ragas_evaluate(dataset, metrics=metrics)

    elapsed = time.time() - t_start

    scores = {
        name: round(float(result.get(name, 0)), 4)
        for name in ("faithfulness", "answer_relevancy",
                     "context_recall", "context_precision")
    }

    # Global score averages only the metrics that produced a signal.
    valid_scores = [v for v in scores.values() if v > 0]
    global_score = round(sum(valid_scores) / len(valid_scores), 4) if valid_scores else 0.0

    verdict = (
        "EXCELLENT" if global_score >= 0.8 else
        "ACCEPTABLE" if global_score >= 0.6 else
        "INSUFFICIENT"
    )

    logger.info(f"[eval] FINISHED — global={global_score} verdict={verdict} "
                f"elapsed={elapsed:.0f}s")

    return {
        "status": "ok",
        "questions_evaluated": len(rows["question"]),
        "elapsed_seconds": round(elapsed, 1),
        "judge_model": CLAUDE_MODEL,
        "index": index_name,
        "category_filter": category or "all",
        "scores": scores,
        "global_score": global_score,
        "verdict": verdict,
        "details": details,
    }
|
||||
|
|
@ -1,60 +1,391 @@
|
|||
# graph.py
|
||||
import logging
|
||||
|
||||
from collections import defaultdict
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.messages import AIMessage, SystemMessage, HumanMessage, BaseMessage
|
||||
from langgraph.graph import END, StateGraph
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
|
||||
|
||||
from prompts import (
|
||||
CLASSIFY_PROMPT_TEMPLATE,
|
||||
CODE_GENERATION_PROMPT,
|
||||
CONVERSATIONAL_PROMPT,
|
||||
GENERATE_PROMPT,
|
||||
REFORMULATE_PROMPT,
|
||||
)
|
||||
|
||||
from state import AgentState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
session_store: dict[str, list] = defaultdict(list)
|
||||
|
||||
def format_context(docs):
    """Render retrieved documents as a numbered, metadata-tagged context string.

    Each chunk gets a header line ``[i] id=... [type=...] [block=...]
    [section=...] source=...`` (optional parts only when present), with an
    ``[AVAP CODE]`` tag for code-like chunks, followed by the chunk text.
    Chunks are separated by blank lines.

    Note: the diff left the old implementation's lines interleaved here;
    this is the reconstructed new version.
    """
    chunks = []
    for i, doc in enumerate(docs, 1):
        meta = doc.metadata or {}
        chunk_id = meta.get("chunk_id", meta.get("id", f"chunk-{i}"))
        source = meta.get("source_file", meta.get("source", "unknown"))
        doc_type = meta.get("doc_type", "")
        block_type = meta.get("block_type", "")
        section = meta.get("section", "")

        text = (doc.page_content or "").strip()
        if not text:
            # Some indexed docs keep their text in metadata fields instead.
            text = meta.get("content") or meta.get("text") or ""

        header_parts = [f"[{i}]", f"id={chunk_id}"]
        if doc_type: header_parts.append(f"type={doc_type}")
        if block_type: header_parts.append(f"block={block_type}")
        if section: header_parts.append(f"section={section}")
        header_parts.append(f"source={source}")

        # Flag code-like chunks so the generation prompt can treat them specially.
        if doc_type in ("code", "code_example", "bnf") or \
           block_type in ("function", "if", "startLoop", "try"):
            header_parts.append("[AVAP CODE]")

        chunks.append(" ".join(header_parts) + "\n" + text)

    return "\n\n".join(chunks)
|
||||
|
||||
|
||||
def build_graph(llm, vector_store) -> CompiledStateGraph:
|
||||
def format_history_for_classify(messages):
    """Render the last six conversation turns as plain text for the classifier.

    Accepts LangChain message objects or plain role/content dicts; assistant
    and dict contents are truncated to 300 chars. Returns "(no history)" when
    nothing renders.
    """
    rendered = []
    for entry in messages[-6:]:
        if isinstance(entry, HumanMessage):
            rendered.append(f"User: {entry.content}")
        elif isinstance(entry, AIMessage):
            rendered.append(f"Assistant: {entry.content[:300]}")
        elif isinstance(entry, dict):
            speaker = entry.get("role", "user")
            snippet = entry.get("content", "")[:300]
            rendered.append(f"{speaker.capitalize()}: {snippet}")
    if not rendered:
        return "(no history)"
    return "\n".join(rendered)
|
||||
|
||||
|
||||
def hybrid_search_native(es_client, embeddings, query, index_name, k=8):
    """Hybrid retrieval: BM25 + kNN over *index_name*, fused with RRF.

    Both legs are best-effort (failures are logged and skipped). The fused
    top-*k* hits are wrapped as LangChain ``Document``s, carrying all source
    fields except the raw text/embedding as metadata plus ``id`` and
    ``rrf_score``.
    """
    query_vector = None
    try:
        query_vector = embeddings.embed_query(query)
    except Exception as e:
        logger.warning(f"[hybrid] embed_query fails: {e}")

    bm25_hits = []
    try:
        lexical_body = {
            "size": k,
            "query": {
                "multi_match": {
                    "query": query,
                    "fields": ["content^2", "text^2"],
                    "type": "best_fields",
                    "fuzziness": "AUTO",
                }
            },
            "_source": {"excludes": ["embedding"]},
        }
        resp = es_client.search(index=index_name, body=lexical_body)
        bm25_hits = resp["hits"]["hits"]
        logger.info(f"[hybrid] BM25 -> {len(bm25_hits)} hits")
    except Exception as e:
        logger.warning(f"[hybrid] BM25 fails: {e}")

    knn_hits = []
    if query_vector:
        try:
            vector_body = {
                "size": k,
                "knn": {
                    "field": "embedding",
                    "query_vector": query_vector,
                    "k": k,
                    "num_candidates": k * 5,
                },
                "_source": {"excludes": ["embedding"]},
            }
            resp = es_client.search(index=index_name, body=vector_body)
            knn_hits = resp["hits"]["hits"]
            logger.info(f"[hybrid] kNN -> {len(knn_hits)} hits")
        except Exception as e:
            logger.warning(f"[hybrid] kNN fails: {e}")

    # Reciprocal Rank Fusion: each list contributes 1/(rank+60) per doc id.
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}

    for hit_list in (bm25_hits, knn_hits):
        for rank, hit in enumerate(hit_list):
            doc_id = hit["_id"]
            rrf_scores[doc_id] += 1.0 / (rank + 60)
            hit_by_id.setdefault(doc_id, hit)

    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]

    docs = []
    for doc_id, score in ranked:
        src = hit_by_id[doc_id]["_source"]
        text = src.get("content") or src.get("text") or ""
        meta = {field: value for field, value in src.items()
                if field not in ("content", "text", "embedding")}
        meta["id"] = doc_id
        meta["rrf_score"] = score
        docs.append(Document(page_content=text, metadata=meta))

    logger.info(f"[hybrid] RRF -> {len(docs)} final docs")
    return docs
|
||||
|
||||
def build_graph(llm, embeddings, es_client, index_name):
    """Compile the full agent graph.

    Flow: classify -> (RETRIEVAL | CODE_GENERATION: reformulate -> retrieve
    -> generate|generate_code) or (CONVERSATIONAL: respond directly).
    Completed turns are persisted into the module-level ``session_store``.

    NOTE(review): the diff left the previous implementation's lines (vector
    store retriever, old edge wiring) interleaved here; this is the
    reconstructed new version.
    """

    def _persist(state: AgentState, response: BaseMessage):
        # Save the updated conversation for this session, if one is set.
        session_id = state.get("session_id", "")
        if session_id:
            session_store[session_id] = list(state["messages"]) + [response]

    def classify(state):
        # Decide the query type from the last user message (+ recent history).
        messages = state["messages"]
        user_msg = messages[-1]
        question = getattr(user_msg, "content",
                           user_msg.get("content", "")
                           if isinstance(user_msg, dict) else "")
        history_msgs = messages[:-1]

        if not history_msgs:
            prompt_content = (
                CLASSIFY_PROMPT_TEMPLATE
                .replace("{history}", "(no history)")
                .replace("{message}", question)
            )
            resp = llm.invoke([SystemMessage(content=prompt_content)])
            raw = resp.content.strip().upper()
            query_type = _parse_query_type(raw)
            logger.info(f"[classify] no historic content raw='{raw}' -> {query_type}")
            return {"query_type": query_type}

        history_text = format_history_for_classify(history_msgs)
        prompt_content = (
            CLASSIFY_PROMPT_TEMPLATE
            .replace("{history}", history_text)
            .replace("{message}", question)
        )
        resp = llm.invoke([SystemMessage(content=prompt_content)])
        raw = resp.content.strip().upper()
        query_type = _parse_query_type(raw)
        logger.info(f"[classify] raw='{raw}' -> {query_type}")
        return {"query_type": query_type}

    def _parse_query_type(raw: str) -> str:
        # Tolerant parse of the classifier output; default to RETRIEVAL.
        if raw.startswith("CODE_GENERATION") or "CODE" in raw:
            return "CODE_GENERATION"
        if raw.startswith("CONVERSATIONAL"):
            return "CONVERSATIONAL"
        return "RETRIEVAL"

    def reformulate(state: AgentState) -> AgentState:
        # Rewrite the user question into a retrieval-friendly query.
        user_msg = state["messages"][-1]
        resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
        reformulated = resp.content.strip()
        logger.info(f"[reformulate] -> '{reformulated}'")
        return {"reformulated_query": reformulated}

    def retrieve(state: AgentState) -> AgentState:
        query = state["reformulated_query"]
        docs = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=query,
            index_name=index_name,
            k=8,
        )
        context = format_context(docs)
        logger.info(f"[retrieve] {len(docs)} docs, context len={len(context)}")
        return {"context": context}

    def generate(state):
        prompt = SystemMessage(
            content=GENERATE_PROMPT.content.format(context=state["context"])
        )
        resp = llm.invoke([prompt] + state["messages"])
        logger.info(f"[generate] {len(resp.content)} chars")
        _persist(state, resp)
        return {"messages": [resp]}

    def generate_code(state):
        prompt = SystemMessage(
            content=CODE_GENERATION_PROMPT.content.format(context=state["context"])
        )
        resp = llm.invoke([prompt] + state["messages"])
        logger.info(f"[generate_code] {len(resp.content)} chars")
        _persist(state, resp)
        return {"messages": [resp]}

    def respond_conversational(state):
        # No retrieval needed — answer directly from the conversation.
        resp = llm.invoke([CONVERSATIONAL_PROMPT] + state["messages"])
        logger.info("[conversational] from conversation")
        _persist(state, resp)
        return {"messages": [resp]}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    def route_after_retrieve(state):
        qt = state.get("query_type", "RETRIEVAL")
        return "generate_code" if qt == "CODE_GENERATION" else "generate"

    graph_builder = StateGraph(AgentState)

    graph_builder.add_node("classify", classify)
    graph_builder.add_node("reformulate", reformulate)
    graph_builder.add_node("retrieve", retrieve)
    graph_builder.add_node("generate", generate)
    graph_builder.add_node("generate_code", generate_code)
    graph_builder.add_node("respond_conversational", respond_conversational)

    graph_builder.set_entry_point("classify")

    graph_builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "respond_conversational",
        }
    )

    graph_builder.add_edge("reformulate", "retrieve")

    graph_builder.add_conditional_edges(
        "retrieve",
        route_after_retrieve,
        {
            "generate": "generate",
            "generate_code": "generate_code",
        }
    )

    graph_builder.add_edge("generate", END)
    graph_builder.add_edge("generate_code", END)
    graph_builder.add_edge("respond_conversational", END)

    return graph_builder.compile()
|
||||
|
||||
|
||||
def build_prepare_graph(llm, embeddings, es_client, index_name):
    """Compile the 'prepare' sub-graph: classify, then retrieve or skip.

    Unlike ``build_graph`` this graph stops after context preparation — the
    caller performs generation itself (e.g. for token streaming).
    """

    def _parse_query_type(raw: str) -> str:
        # Tolerant parse of the classifier output; default to RETRIEVAL.
        if raw.startswith("CODE_GENERATION") or "CODE" in raw:
            return "CODE_GENERATION"
        if raw.startswith("CONVERSATIONAL"):
            return "CONVERSATIONAL"
        return "RETRIEVAL"

    def classify(state):
        msgs = state["messages"]
        last = msgs[-1]
        if isinstance(last, dict):
            question = last.get("content", "")
        else:
            question = getattr(last, "content", "")
        history = msgs[:-1]

        has_history = bool(history)
        history_text = format_history_for_classify(history) if has_history else "(no history)"
        prompt_text = (
            CLASSIFY_PROMPT_TEMPLATE
            .replace("{history}", history_text)
            .replace("{message}", question)
        )
        raw = llm.invoke([SystemMessage(content=prompt_text)]).content.strip().upper()
        query_type = _parse_query_type(raw)
        if has_history:
            logger.info(f"[prepare/classify] raw='{raw}' -> {query_type}")
        else:
            logger.info(f"[prepare/classify] no history raw='{raw}' -> {query_type}")
        return {"query_type": query_type}

    def reformulate(state: AgentState) -> AgentState:
        latest = state["messages"][-1]
        rewritten = llm.invoke([REFORMULATE_PROMPT, latest]).content.strip()
        logger.info(f"[prepare/reformulate] -> '{rewritten}'")
        return {"reformulated_query": rewritten}

    def retrieve(state: AgentState) -> AgentState:
        found = hybrid_search_native(
            es_client=es_client,
            embeddings=embeddings,
            query=state["reformulated_query"],
            index_name=index_name,
            k=8,
        )
        rendered = format_context(found)
        logger.info(f"[prepare/retrieve] {len(found)} docs, context len={len(rendered)}")
        return {"context": rendered}

    def skip_retrieve(state: AgentState) -> AgentState:
        # Conversational turns need no retrieved context.
        return {"context": ""}

    def route_by_type(state):
        return state.get("query_type", "RETRIEVAL")

    builder = StateGraph(AgentState)

    builder.add_node("classify", classify)
    builder.add_node("reformulate", reformulate)
    builder.add_node("retrieve", retrieve)
    builder.add_node("skip_retrieve", skip_retrieve)

    builder.set_entry_point("classify")

    builder.add_conditional_edges(
        "classify",
        route_by_type,
        {
            "RETRIEVAL": "reformulate",
            "CODE_GENERATION": "reformulate",
            "CONVERSATIONAL": "skip_retrieve",
        }
    )

    builder.add_edge("reformulate", "retrieve")
    builder.add_edge("retrieve", END)
    builder.add_edge("skip_retrieve", END)

    return builder.compile()
|
||||
|
||||
|
||||
def build_final_messages(state: AgentState) -> list:
    """Assemble the LLM message list for the prepared *state*.

    Conversational turns get the conversational system prompt; retrieval and
    code-generation turns get the matching context-formatted system prompt,
    each followed by the conversation messages.
    """
    qt = state.get("query_type", "RETRIEVAL")
    history = state.get("messages", [])

    if qt == "CONVERSATIONAL":
        return [CONVERSATIONAL_PROMPT] + history

    ctx = state.get("context", "")
    template = CODE_GENERATION_PROMPT if qt == "CODE_GENERATION" else GENERATE_PROMPT
    system_msg = SystemMessage(content=template.content.format(context=ctx))
    return [system_msg] + history
|
||||
|
|
@ -0,0 +1,420 @@
|
|||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from typing import AsyncIterator, Optional, Any, Literal, Union
|
||||
|
||||
import grpc
|
||||
import brunix_pb2
|
||||
import brunix_pb2_grpc
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("openai-proxy")
|
||||
|
||||
_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=int(os.getenv("PROXY_THREAD_WORKERS", "20"))
|
||||
)
|
||||
|
||||
GRPC_TARGET = os.getenv("BRUNIX_GRPC_TARGET", "localhost:50051")
|
||||
PROXY_MODEL = os.getenv("PROXY_MODEL_ID", "brunix")
|
||||
|
||||
_channel: Optional[grpc.Channel] = None
|
||||
_stub: Optional[brunix_pb2_grpc.AssistanceEngineStub] = None
|
||||
|
||||
|
||||
def get_stub() -> brunix_pb2_grpc.AssistanceEngineStub:
    """Return the shared gRPC stub, lazily opening the channel on first use."""
    global _channel, _stub
    if _stub is not None:
        return _stub
    _channel = grpc.insecure_channel(GRPC_TARGET)
    _stub = brunix_pb2_grpc.AssistanceEngineStub(_channel)
    logger.info(f"[gRPC] connected to {GRPC_TARGET}")
    return _stub
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Brunix OpenAI-Compatible Proxy",
|
||||
version="2.0.0",
|
||||
description="stream:false → AskAgent | stream:true → AskAgentStream",
|
||||
)
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: Literal["system", "user", "assistant", "function"] = "user"
|
||||
content: str = ""
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str = PROXY_MODEL
|
||||
messages: list[ChatMessage]
|
||||
stream: bool = False
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
session_id: Optional[str] = None # extensión Brunix
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
stop: Optional[Any] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
model: str = PROXY_MODEL
|
||||
prompt: Union[str, list[str]] = ""
|
||||
stream: bool = False
|
||||
temperature: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
session_id: Optional[str] = None
|
||||
suffix: Optional[str] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
stop: Optional[Any] = None
|
||||
user: Optional[str] = None
|
||||
|
||||
|
||||
# Ollama schemas
|
||||
class OllamaChatMessage(BaseModel):
|
||||
role: str = "user"
|
||||
content: str = ""
|
||||
|
||||
|
||||
class OllamaChatRequest(BaseModel):
|
||||
model: str = PROXY_MODEL
|
||||
messages: list[OllamaChatMessage]
|
||||
stream: bool = True # Ollama hace stream por defecto
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class OllamaGenerateRequest(BaseModel):
|
||||
model: str = PROXY_MODEL
|
||||
prompt: str = ""
|
||||
stream: bool = True
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
def _ts() -> int:
|
||||
return int(time.time())
|
||||
|
||||
|
||||
def _chat_response(content: str, req_id: str) -> dict:
    """Build a non-streaming OpenAI chat.completion payload for *content*."""
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": content},
        "finish_reason": "stop",
    }
    return {
        "id": req_id,
        "object": "chat.completion",
        "created": _ts(),
        "model": PROXY_MODEL,
        "choices": [choice],
        # Token accounting is not tracked by the proxy.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|
||||
|
||||
|
||||
def _completion_response(text: str, req_id: str) -> dict:
    """Build a non-streaming OpenAI text_completion payload for *text*."""
    choice = {
        "text": text,
        "index": 0,
        "logprobs": None,
        "finish_reason": "stop",
    }
    return {
        "id": req_id,
        "object": "text_completion",
        "created": _ts(),
        "model": PROXY_MODEL,
        "choices": [choice],
        # Token accounting is not tracked by the proxy.
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
    }
|
||||
|
||||
|
||||
def _chat_chunk(delta: str, req_id: str, finish: Optional[str] = None) -> dict:
    """Build one OpenAI chat.completion.chunk; empty delta when *delta* is falsy."""
    delta_obj = {"role": "assistant", "content": delta} if delta else {}
    return {
        "id": req_id,
        "object": "chat.completion.chunk",
        "created": _ts(),
        "model": PROXY_MODEL,
        "choices": [{"index": 0, "delta": delta_obj, "finish_reason": finish}],
    }
|
||||
|
||||
|
||||
def _completion_chunk(text: str, req_id: str, finish: Optional[str] = None) -> dict:
    """Build one streamed OpenAI text_completion chunk."""
    choice = {
        "text": text,
        "index": 0,
        "logprobs": None,
        "finish_reason": finish,
    }
    return {
        "id": req_id,
        "object": "text_completion",
        "created": _ts(),
        "model": PROXY_MODEL,
        "choices": [choice],
    }
|
||||
|
||||
|
||||
def _sse(data: dict) -> str:
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def _sse_done() -> str:
|
||||
return "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def _query_from_messages(messages: list[ChatMessage]) -> str:
|
||||
for m in reversed(messages):
|
||||
if m.role == "user":
|
||||
return m.content
|
||||
return ""
|
||||
|
||||
|
||||
async def _invoke_blocking(query: str, session_id: str) -> str:
    """Run a blocking AskAgent gRPC call off-loop and join the streamed parts.

    The synchronous gRPC iteration happens on the shared thread pool so the
    event loop stays responsive.
    """
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()

    def _call() -> str:
        stub = get_stub()
        req = brunix_pb2.AgentRequest(query=query, session_id=session_id)
        parts = [resp.text for resp in stub.AskAgent(req) if resp.text]
        return "".join(parts)

    return await loop.run_in_executor(_thread_pool, _call)
|
||||
|
||||
|
||||
async def _iter_stream(query: str, session_id: str) -> AsyncIterator[brunix_pb2.AgentResponse]:
    """Bridge the synchronous AskAgentStream gRPC stream into an async iterator.

    A producer thread feeds responses into an asyncio queue; ``None`` marks
    end-of-stream and an ``Exception`` item is re-raised on the consumer side.
    """
    # get_running_loop() is the correct API inside a coroutine;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def _producer():
        try:
            stub = get_stub()
            req = brunix_pb2.AgentRequest(query=query, session_id=session_id)
            for resp in stub.AskAgentStream(req):  # real token-level stream
                asyncio.run_coroutine_threadsafe(queue.put(resp), loop).result()
        except Exception as e:
            # Ship the exception across the thread boundary for re-raise.
            asyncio.run_coroutine_threadsafe(queue.put(e), loop).result()
        finally:
            asyncio.run_coroutine_threadsafe(queue.put(None), loop).result()  # sentinel

    _thread_pool.submit(_producer)

    while True:
        item = await queue.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item
|
||||
|
||||
|
||||
async def _stream_chat(query: str, session_id: str, req_id: str) -> AsyncIterator[str]:
    """Relay engine tokens as OpenAI chat SSE chunks, always ending with [DONE]."""
    try:
        async for resp in _iter_stream(query, session_id):
            if resp.is_final:
                # Final marker: emit the finish_reason-only chunk and stop.
                yield _sse(_chat_chunk("", req_id, finish="stop"))
                break
            if resp.text:
                yield _sse(_chat_chunk(resp.text, req_id))
    except Exception as exc:
        logger.error(f"[stream_chat] error: {exc}")
        yield _sse(_chat_chunk(f"[Error: {exc}]", req_id, finish="stop"))

    yield _sse_done()
|
||||
|
||||
|
||||
async def _stream_completion(query: str, session_id: str, req_id: str) -> AsyncIterator[str]:
    """Relay engine tokens as OpenAI text_completion SSE chunks, ending with [DONE]."""
    try:
        async for resp in _iter_stream(query, session_id):
            if resp.is_final:
                yield _sse(_completion_chunk("", req_id, finish="stop"))
                break
            if resp.text:
                yield _sse(_completion_chunk(resp.text, req_id))
    except Exception as exc:
        logger.error(f"[stream_completion] error: {exc}")
        yield _sse(_completion_chunk(f"[Error: {exc}]", req_id, finish="stop"))

    yield _sse_done()
|
||||
|
||||
|
||||
def _ollama_chat_chunk(token: str, done: bool) -> str:
    """Serialize one Ollama ``/api/chat`` NDJSON line."""
    payload = {
        "model": PROXY_MODEL,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "message": {"role": "assistant", "content": token},
        "done": done,
    }
    return json.dumps(payload) + "\n"
|
||||
|
||||
|
||||
def _ollama_generate_chunk(token: str, done: bool) -> str:
    """Serialize one Ollama ``/api/generate`` NDJSON line."""
    payload = {
        "model": PROXY_MODEL,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "response": token,
        "done": done,
    }
    return json.dumps(payload) + "\n"
|
||||
|
||||
|
||||
async def _stream_ollama_chat(query: str, session_id: str) -> AsyncIterator[str]:
    """Stream engine tokens as Ollama chat NDJSON lines.

    Fix: always terminate with a ``done: true`` line. The original only emitted
    it on an explicit ``is_final`` marker or an error — if the engine stream
    ended without either, Ollama clients would hang waiting for completion.
    """
    done_sent = False
    try:
        async for resp in _iter_stream(query, session_id):
            if resp.is_final:
                yield _ollama_chat_chunk("", done=True)
                done_sent = True
                break
            if resp.text:
                yield _ollama_chat_chunk(resp.text, done=False)
    except Exception as e:
        logger.error(f"[ollama_chat] error: {e}")
        yield _ollama_chat_chunk(f"[Error: {e}]", done=True)
        done_sent = True

    if not done_sent:
        # Defensive terminator for streams that end without is_final.
        yield _ollama_chat_chunk("", done=True)
|
||||
|
||||
|
||||
async def _stream_ollama_generate(query: str, session_id: str) -> AsyncIterator[str]:
    """Stream engine tokens as Ollama generate NDJSON lines.

    Fix: always terminate with a ``done: true`` line even if the engine stream
    ends without an ``is_final`` marker — otherwise clients hang.
    """
    done_sent = False
    try:
        async for resp in _iter_stream(query, session_id):
            if resp.is_final:
                yield _ollama_generate_chunk("", done=True)
                done_sent = True
                break
            if resp.text:
                yield _ollama_generate_chunk(resp.text, done=False)
    except Exception as e:
        logger.error(f"[ollama_generate] error: {e}")
        yield _ollama_generate_chunk(f"[Error: {e}]", done=True)
        done_sent = True

    if not done_sent:
        # Defensive terminator for streams that end without is_final.
        yield _ollama_generate_chunk("", done=True)
|
||||
|
||||
|
||||
@app.get("/v1/models")
async def list_models():
    """Advertise the single proxied model in OpenAI ``/v1/models`` format."""
    model_entry = {
        "id": PROXY_MODEL,
        "object": "model",
        "created": 1700000000,
        "owned_by": "brunix",
        "permission": [],
        "root": PROXY_MODEL,
        "parent": None,
    }
    return {"object": "list", "data": [model_entry]}
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """OpenAI-compatible chat endpoint.

    Streams SSE chunks when ``req.stream`` is set, otherwise performs one
    blocking engine call and returns a full chat.completion JSON body.

    Raises 400 when no user message is present and 502 on gRPC failure.
    """
    query = _query_from_messages(req.messages)
    session_id = req.session_id or req.user or "default"
    req_id = f"chatcmpl-{uuid.uuid4().hex}"

    logger.info(f"[chat] session={session_id} stream={req.stream} query='{query[:80]}'")

    if not query:
        raise HTTPException(status_code=400, detail="No user message found in messages.")

    if req.stream:
        return StreamingResponse(
            _stream_chat(query, session_id, req_id),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    try:
        text = await _invoke_blocking(query, session_id)
    except grpc.RpcError as e:
        # Fix: chain the cause so the original gRPC error survives in tracebacks.
        raise HTTPException(status_code=502, detail=f"gRPC error: {e.details()}") from e

    return JSONResponse(_chat_response(text, req_id))
|
||||
|
||||
|
||||
@app.post("/v1/completions")
async def completions(req: CompletionRequest):
    """OpenAI-compatible legacy completions endpoint (stream or blocking).

    Accepts ``prompt`` as a string or list of strings (joined with spaces).
    Raises 400 on empty prompt and 502 on gRPC failure.
    """
    query = req.prompt if isinstance(req.prompt, str) else " ".join(req.prompt)
    session_id = req.session_id or req.user or "default"
    req_id = f"cmpl-{uuid.uuid4().hex}"

    logger.info(f"[completion] session={session_id} stream={req.stream} prompt='{query[:80]}'")

    if not query:
        raise HTTPException(status_code=400, detail="prompt is required.")

    if req.stream:
        return StreamingResponse(
            _stream_completion(query, session_id, req_id),
            media_type="text/event-stream",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    try:
        text = await _invoke_blocking(query, session_id)
    except grpc.RpcError as e:
        # Fix: chain the cause so the original gRPC error survives in tracebacks.
        raise HTTPException(status_code=502, detail=f"gRPC error: {e.details()}") from e

    return JSONResponse(_completion_response(text, req_id))
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Liveness probe: report proxy status and the configured gRPC target."""
    return {"status": "ok", "grpc_target": GRPC_TARGET}
|
||||
|
||||
|
||||
@app.get("/api/tags")
async def ollama_tags():
    """Advertise the proxied model in Ollama ``/api/tags`` format."""
    details = {
        "format": "gguf",
        "family": "brunix",
        "parameter_size": "unknown",
        "quantization_level": "unknown",
    }
    model_entry = {
        "name": PROXY_MODEL,
        "model": PROXY_MODEL,
        "modified_at": "2024-01-01T00:00:00Z",
        "size": 0,
        "digest": "brunix",
        "details": details,
    }
    return {"models": [model_entry]}
|
||||
|
||||
|
||||
@app.post("/api/chat")
async def ollama_chat(req: OllamaChatRequest):
    """Ollama-compatible ``/api/chat`` endpoint (NDJSON stream or single JSON).

    Uses the most recent user message as the query. Raises 400 when no user
    message exists and 502 on gRPC failure.
    """
    query = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    session_id = req.session_id or "default"

    logger.info(f"[ollama/chat] session={session_id} stream={req.stream} query='{query[:80]}'")

    if not query:
        raise HTTPException(status_code=400, detail="No user message found.")

    if req.stream:
        return StreamingResponse(
            _stream_ollama_chat(query, session_id),
            media_type="application/x-ndjson",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    try:
        text = await _invoke_blocking(query, session_id)
    except grpc.RpcError as e:
        # Fix: chain the cause so the original gRPC error survives in tracebacks.
        raise HTTPException(status_code=502, detail=f"gRPC error: {e.details()}") from e

    return JSONResponse({
        "model": PROXY_MODEL,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "message": {"role": "assistant", "content": text},
        "done": True,
    })
|
||||
|
||||
|
||||
@app.post("/api/generate")
async def ollama_generate(req: OllamaGenerateRequest):
    """Ollama-compatible ``/api/generate`` endpoint (NDJSON stream or single JSON).

    Raises 400 on empty prompt and 502 on gRPC failure.
    """
    session_id = req.session_id or "default"

    logger.info(f"[ollama/generate] session={session_id} stream={req.stream} prompt='{req.prompt[:80]}'")

    if not req.prompt:
        raise HTTPException(status_code=400, detail="prompt is required.")

    if req.stream:
        return StreamingResponse(
            _stream_ollama_generate(req.prompt, session_id),
            media_type="application/x-ndjson",
            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
        )

    try:
        text = await _invoke_blocking(req.prompt, session_id)
    except grpc.RpcError as e:
        # Fix: chain the cause so the original gRPC error survives in tracebacks.
        raise HTTPException(status_code=502, detail=f"gRPC error: {e.details()}") from e

    return JSONResponse({
        "model": PROXY_MODEL,
        "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "response": text,
        "done": True,
    })
|
||||
|
|
@ -1,89 +1,250 @@
|
|||
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
CLASSIFY_PROMPT_TEMPLATE = (
|
||||
"<role>\n"
|
||||
"You are a query classifier for an AVAP language assistant. "
|
||||
"Your only job is to classify the user message into one of three categories.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
"<categories>\n"
|
||||
"RETRIEVAL — the user is asking about AVAP concepts, documentation, syntax rules, "
|
||||
"or how something works. They want an explanation, not code.\n"
|
||||
"Examples: 'What is addVar?', 'How does registerEndpoint work?', "
|
||||
"'What is the difference between if() modes?'\n\n"
|
||||
|
||||
"CODE_GENERATION — the user is asking to generate, write, create, build, or show "
|
||||
"an example of an AVAP script, function, API, or code snippet. "
|
||||
"They want working code as output.\n"
|
||||
"Examples: 'Write an API that returns hello world', "
|
||||
"'Generate a function that queries the DB', "
|
||||
"'Show me how to create an endpoint', "
|
||||
"'dame un ejemplo de codigo', 'escribeme un script', "
|
||||
"'dime como seria un API', 'genera un API', 'como haria'\n\n"
|
||||
|
||||
"CONVERSATIONAL — the user is following up on the previous answer. "
|
||||
"They want a reformulation, summary, or elaboration of what was already said.\n"
|
||||
"Examples: 'can you explain that?', 'en menos palabras', "
|
||||
"'describe it in your own words', 'what did you mean?'\n"
|
||||
"</categories>\n\n"
|
||||
|
||||
"<output_rule>\n"
|
||||
"Your entire response must be exactly one word: "
|
||||
"RETRIEVAL, CODE_GENERATION, or CONVERSATIONAL. Nothing else.\n"
|
||||
"</output_rule>\n\n"
|
||||
|
||||
"<conversation_history>\n"
|
||||
"{history}\n"
|
||||
"</conversation_history>\n\n"
|
||||
|
||||
"<user_message>{message}</user_message>"
|
||||
)
|
||||
|
||||
REFORMULATE_PROMPT = SystemMessage(
|
||||
content=(
|
||||
"You are a deterministic lexical query rewriter used for vector retrieval.\n"
|
||||
"Your task is to rewrite user questions into optimized keyword search queries.\n\n"
|
||||
"<role>\n"
|
||||
"You are a deterministic query rewriter whose sole purpose is to prepare "
|
||||
"user questions for vector similarity retrieval against an AVAP language "
|
||||
"knowledge base. You do not answer questions. You only transform phrasing "
|
||||
"into keyword queries that will find the right AVAP documentation chunks.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
"CRITICAL RULES (ABSOLUTE):\n"
|
||||
"1. NEVER answer the question.\n"
|
||||
"2. NEVER expand acronyms.\n"
|
||||
"3. NEVER introduce new terms not present in the original query.\n"
|
||||
"4. NEVER infer missing information.\n"
|
||||
"5. NEVER add explanations, definitions, or interpretations.\n"
|
||||
"6. Preserve all technical tokens exactly as written.\n"
|
||||
"7. Only remove filler words (e.g., what, does, is, explain, tell me, please).\n"
|
||||
"8. You may reorder terms for better retrieval.\n"
|
||||
"9. Output must be a single-line plain keyword query.\n"
|
||||
"10. If the query is already optimal, return it unchanged.\n\n"
|
||||
"11. If you receive something that looks like code, do NOT attempt to rewrite it. Return it verbatim.\n\n"
|
||||
"<task>\n"
|
||||
"Rewrite the user message into a compact keyword query for semantic search.\n\n"
|
||||
|
||||
"ALLOWED OPERATIONS:\n"
|
||||
"- Remove interrogative phrasing.\n"
|
||||
"- Remove stopwords.\n"
|
||||
"- Reorder words.\n"
|
||||
"- Convert to noun phrase form.\n\n"
|
||||
"SPECIAL RULE for code generation requests:\n"
|
||||
"When the user asks to generate/create/build/show AVAP code, expand the query "
|
||||
"with the AVAP commands typically needed. Use this mapping:\n\n"
|
||||
|
||||
"FORBIDDEN OPERATIONS:\n"
|
||||
"- Expanding abbreviations.\n"
|
||||
"- Paraphrasing into unseen vocabulary.\n"
|
||||
"- Adding definitions.\n"
|
||||
"- Answering implicitly.\n\n"
|
||||
"- API / endpoint / route / HTTP response\n"
|
||||
" expand to: AVAP registerEndpoint addResult _status\n\n"
|
||||
|
||||
"Examples:\n"
|
||||
"Input: What does AVAP stand for?\n"
|
||||
"Output: AVAP stand for\n"
|
||||
"- Read input / parameter\n"
|
||||
" expand to: AVAP addParam getQueryParamList\n\n"
|
||||
|
||||
"Input: Hey, I'm trying to understand how AVAP handels a ZeroDivisionError when doing divison or modulus operatoins. Can you explane what situatoins cause a ZeroDivisionError to be raised and how I can catch it in my AVAP scripts?\n"
|
||||
"Output: AVAP ZeroDivisionError division / modulus % catch try except\n"
|
||||
"- Database / ORM / query\n"
|
||||
" expand to: AVAP ormAccessSelect ormAccessInsert avapConnector\n\n"
|
||||
|
||||
"Input: What does AVAP stand for?\n"
|
||||
"Output: AVAP stand for\n"
|
||||
"- Error handling\n"
|
||||
" expand to: AVAP try exception end\n\n"
|
||||
|
||||
"Input: Please explain how the import statement works in AVAP scripts.\n"
|
||||
"Output: AVAP import statement syntax behavior\n\n"
|
||||
"- Loop / iterate\n"
|
||||
" expand to: AVAP startLoop endLoop itemFromList getListLen\n\n"
|
||||
|
||||
"Return only the rewritten query."
|
||||
"- HTTP request / call external\n"
|
||||
" expand to: AVAP RequestPost RequestGet\n"
|
||||
"</task>\n\n"
|
||||
|
||||
"<rules>\n"
|
||||
"- Preserve all AVAP identifiers verbatim.\n"
|
||||
"- Remove filler words.\n"
|
||||
"- Output a single line.\n"
|
||||
"- Never answer the question.\n"
|
||||
"</rules>\n\n"
|
||||
|
||||
"<examples>\n"
|
||||
"<example>\n"
|
||||
"<input>What does AVAP stand for?</input>\n"
|
||||
"<o>AVAP stand for</o>\n"
|
||||
"</example>\n\n"
|
||||
|
||||
"<example>\n"
|
||||
"<input>dime como seria un API que devuelva hello world con AVAP</input>\n"
|
||||
"<o>AVAP registerEndpoint addResult _status hello world example</o>\n"
|
||||
"</example>\n\n"
|
||||
|
||||
"<example>\n"
|
||||
"<input>generate an AVAP script that reads a parameter and queries the DB</input>\n"
|
||||
"<o>AVAP addParam ormAccessSelect avapConnector registerEndpoint addResult</o>\n"
|
||||
"</example>\n"
|
||||
"</examples>\n\n"
|
||||
|
||||
"Return only the rewritten query. No labels, no prefixes, no explanation."
|
||||
)
|
||||
)
|
||||
|
||||
CONFIDENCE_PROMPT_TEMPLATE = (
|
||||
"<role>\n"
|
||||
"You are a relevance evaluator. Decide whether the context contains "
|
||||
"useful information to address the user question.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
"<task>\n"
|
||||
"Answer YES if the context contains at least one relevant passage. "
|
||||
"Answer NO only if context is empty or completely unrelated.\n"
|
||||
"</task>\n\n"
|
||||
|
||||
"<output_rule>\n"
|
||||
"Exactly one word: YES or NO.\n"
|
||||
"</output_rule>\n\n"
|
||||
|
||||
"<question>{question}</question>\n\n"
|
||||
"<context>{context}</context>"
|
||||
)
|
||||
|
||||
|
||||
CODE_GENERATION_PROMPT = SystemMessage(
|
||||
content=(
|
||||
"<role>\n"
|
||||
"You are an expert AVAP programmer. AVAP (Advanced Virtual API Programming) "
|
||||
"is a domain-specific language for orchestrating microservices and HTTP I/O. "
|
||||
"Write correct, minimal, working AVAP code.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
"<critical_rules>\n"
|
||||
"1. AVAP is line-oriented: every statement on a single line.\n"
|
||||
"2. Use ONLY commands from <avap_syntax_reminder> or explicitly described in <context>.\n"
|
||||
"3. Do NOT copy code examples from <context> that solve a DIFFERENT problem. "
|
||||
"Context examples are syntax references only — ignore them if unrelated.\n"
|
||||
"4. Write the MINIMUM code needed. No extra connectors, no unrelated variables.\n"
|
||||
"5. Add brief inline comments explaining each part.\n"
|
||||
"6. Answer in the same language the user used.\n"
|
||||
"</critical_rules>\n\n"
|
||||
|
||||
"<avap_syntax_reminder>\n"
|
||||
"// Register an HTTP endpoint\n"
|
||||
"registerEndpoint(\"GET\", \"/path\", [], \"scope\", handlerFn, \"\")\n\n"
|
||||
"// Declare a function — uses curly braces, NOT end()\n"
|
||||
"function handlerFn() {{\n"
|
||||
" msg = \"Hello World\"\n"
|
||||
" addResult(msg)\n"
|
||||
"}}\n\n"
|
||||
"// Assign a value to a variable\n"
|
||||
"addVar(varName, \"value\") // or: varName = \"value\"\n\n"
|
||||
"// Add variable to HTTP JSON response body\n"
|
||||
"addResult(varName)\n\n"
|
||||
"// Set HTTP response status code\n"
|
||||
"_status = 200 // or: addVar(_status, 200)\n\n"
|
||||
"// Read a request parameter (URL, body, or form)\n"
|
||||
"addParam(\"paramName\", targetVar)\n\n"
|
||||
"// Conditional\n"
|
||||
"if(var, value, \"==\")\n"
|
||||
" // ...\n"
|
||||
"end()\n\n"
|
||||
"// Loop\n"
|
||||
"startLoop(i, 0, length)\n"
|
||||
" // ...\n"
|
||||
"endLoop()\n\n"
|
||||
"// Error handling\n"
|
||||
"try()\n"
|
||||
" // ...\n"
|
||||
"exception(errVar)\n"
|
||||
" // handle\n"
|
||||
"end()\n"
|
||||
"</avap_syntax_reminder>\n\n"
|
||||
|
||||
"<task>\n"
|
||||
"Generate a minimal, complete AVAP example for the user's request.\n\n"
|
||||
"Structure:\n"
|
||||
"1. One sentence describing what the code does.\n"
|
||||
"2. The AVAP code block — clean, minimal, with inline comments.\n"
|
||||
"3. Two or three lines explaining the key commands used.\n"
|
||||
"</task>\n\n"
|
||||
|
||||
"<context>\n"
|
||||
"{context}\n"
|
||||
"</context>"
|
||||
)
|
||||
)
|
||||
|
||||
CONVERSATIONAL_PROMPT = SystemMessage(
|
||||
content=(
|
||||
"<role>\n"
|
||||
"You are a helpful AVAP assistant continuing an ongoing conversation.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
"<task>\n"
|
||||
"The user is following up on something already discussed. "
|
||||
"Rephrase, summarize, or elaborate using the conversation history.\n"
|
||||
"</task>\n\n"
|
||||
|
||||
"<rules>\n"
|
||||
"- Base your answer on the conversation history.\n"
|
||||
"- Do not introduce new AVAP facts not in the history.\n"
|
||||
"- Keep the same language the user is using.\n"
|
||||
"- No Answer/Evidence format. Just answer naturally.\n"
|
||||
"</rules>"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
GENERATE_PROMPT = SystemMessage(
|
||||
content=(
|
||||
"""You are a retrieval-grounded assistant.
|
||||
"<role>\n"
|
||||
"You are a precise, retrieval-grounded assistant specialized in AVAP. "
|
||||
"Answers are honest, calibrated to evidence, and clearly structured.\n"
|
||||
"</role>\n\n"
|
||||
|
||||
CRITICAL SCOPE RULES
|
||||
- AVAP (Advanced Virtual API Programming) is a new/unknown programming language.
|
||||
- If the user message is AVAP-related (explicitly mentions "AVAP" case-insensitive OR the provided context is AVAP code/docs), you MUST answer using ONLY <CONTEXT>.
|
||||
- You MUST NOT use general programming knowledge to interpret AVAP.
|
||||
- You MUST NOT assume AVAP behaves like any other language even if syntax looks similar.
|
||||
- You MUST NOT infer missing details. Only state what is explicitly present in <CONTEXT>.
|
||||
"<critical_constraint>\n"
|
||||
"AVAP is a new proprietary language. Use ONLY content inside <context>. "
|
||||
"Treat any AVAP knowledge outside <context> as unreliable.\n"
|
||||
"</critical_constraint>\n\n"
|
||||
|
||||
GROUNDING REQUIREMENTS (AVAP-RELATED)
|
||||
1) Every non-trivial factual claim MUST be directly supported by an EXACT QUOTE from <CONTEXT>.
|
||||
2) If a claim is not supported by a quote, DO NOT include it.
|
||||
3) If <CONTEXT> does not contain enough information to answer, reply with EXACTLY:
|
||||
"I don't have enough information in the provided context to answer that."
|
||||
"<task>\n"
|
||||
"Answer using exclusively the information in <context>.\n"
|
||||
"</task>\n\n"
|
||||
|
||||
WORKFLOW (AVAP-RELATED) — FOLLOW IN ORDER
|
||||
A) Identify the specific question(s) being asked.
|
||||
B) Extract the minimum necessary quotes from <CONTEXT> that answer those question(s).
|
||||
C) Write the answer using ONLY those quotes (paraphrase is allowed, but every statement must be backed by at least one quote).
|
||||
D) Verify: for EACH sentence in your answer, confirm there is a supporting quote. If any sentence lacks a quote, delete it or refuse.
|
||||
"<thinking_steps>\n"
|
||||
"Step 1 — Find relevant passages in <context>.\n"
|
||||
"Step 2 — Assess if question can be fully or partially answered.\n"
|
||||
"Step 3 — Write a clear answer backed by those passages.\n"
|
||||
"Step 4 — If context contains relevant AVAP code, include it exactly.\n"
|
||||
"</thinking_steps>\n\n"
|
||||
|
||||
OUTPUT FORMAT (AVAP-RELATED ONLY)
|
||||
Answer:
|
||||
<short, direct answer; no extra speculation; no unrelated tips>
|
||||
"<output_format>\n"
|
||||
"Answer:\n"
|
||||
"<direct answer; include code blocks if context has relevant code>\n\n"
|
||||
|
||||
Evidence:
|
||||
- "<exact quote 1>"
|
||||
- "<exact quote 2>"
|
||||
(Include only quotes you actually used. Prefer the smallest quotes that fully support the statements.)
|
||||
"Evidence:\n"
|
||||
"- \"<exact quote from context>\"\n"
|
||||
"(only quotes you actually used)\n\n"
|
||||
|
||||
NON-AVAP QUESTIONS
|
||||
- If the question is clearly not AVAP-related, answer normally using general knowledge.
|
||||
"If context has no relevant information reply with exactly:\n"
|
||||
"\"I don't have enough information in the provided context to answer that.\"\n"
|
||||
"</output_format>\n\n"
|
||||
|
||||
<CONTEXT>
|
||||
{context}
|
||||
</CONTEXT>"""
|
||||
"<context>\n"
|
||||
"{context}\n"
|
||||
"</context>"
|
||||
)
|
||||
)
|
||||
|
|
@ -8,18 +8,28 @@ import brunix_pb2
|
|||
import brunix_pb2_grpc
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
from langchain_elasticsearch import ElasticsearchStore
|
||||
from elasticsearch import Elasticsearch
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from utils.llm_factory import create_chat_model
|
||||
from utils.emb_factory import create_embedding_model
|
||||
from graph import build_graph
|
||||
from graph import build_graph, build_prepare_graph, build_final_messages, session_store
|
||||
|
||||
from evaluate import run_evaluation
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("brunix-engine")
|
||||
|
||||
|
||||
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||
|
||||
def __init__(self):
|
||||
es_url = os.getenv("ELASTICSEARCH_URL", "http://localhost:9200")
|
||||
es_user = os.getenv("ELASTICSEARCH_USER")
|
||||
es_pass = os.getenv("ELASTICSEARCH_PASSWORD")
|
||||
es_apikey = os.getenv("ELASTICSEARCH_API_KEY")
|
||||
index = os.getenv("ELASTICSEARCH_INDEX", "avap-knowledge-v1")
|
||||
|
||||
self.llm = create_chat_model(
|
||||
provider="ollama",
|
||||
model=os.getenv("OLLAMA_MODEL_NAME"),
|
||||
|
|
@ -27,56 +37,194 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
|||
temperature=0,
|
||||
validate_model_on_init=True,
|
||||
)
|
||||
|
||||
self.embeddings = create_embedding_model(
|
||||
provider="ollama",
|
||||
model=os.getenv("OLLAMA_EMB_MODEL_NAME"),
|
||||
base_url=os.getenv("OLLAMA_URL"),
|
||||
)
|
||||
self.vector_store = ElasticsearchStore(
|
||||
es_url=os.getenv("ELASTICSEARCH_URL"),
|
||||
index_name=os.getenv("ELASTICSEARCH_INDEX"),
|
||||
embedding=self.embeddings,
|
||||
query_field="text",
|
||||
vector_query_field="embedding",
|
||||
)
|
||||
|
||||
es_kwargs: dict = {"hosts": [es_url], "request_timeout": 60}
|
||||
if es_apikey:
|
||||
es_kwargs["api_key"] = es_apikey
|
||||
elif es_user and es_pass:
|
||||
es_kwargs["basic_auth"] = (es_user, es_pass)
|
||||
|
||||
self.es_client = Elasticsearch(**es_kwargs)
|
||||
self.index_name = index
|
||||
|
||||
if self.es_client.ping():
|
||||
info = self.es_client.info()
|
||||
logger.info(f"[ESEARCH] Connected: {info['version']['number']} — index: {index}")
|
||||
else:
|
||||
logger.error("[ESEARCH] Cant Connect")
|
||||
|
||||
self.graph = build_graph(
|
||||
llm=self.llm,
|
||||
vector_store=self.vector_store
|
||||
llm = self.llm,
|
||||
embeddings = self.embeddings,
|
||||
es_client = self.es_client,
|
||||
index_name = self.index_name,
|
||||
)
|
||||
logger.info("Brunix Engine initializing.")
|
||||
|
||||
self.prepare_graph = build_prepare_graph(
|
||||
llm = self.llm,
|
||||
embeddings = self.embeddings,
|
||||
es_client = self.es_client,
|
||||
index_name = self.index_name,
|
||||
)
|
||||
|
||||
logger.info("Brunix Engine initialized.")
|
||||
|
||||
|
||||
def AskAgent(self, request, context):
|
||||
logger.info(f"request {request.session_id}): {request.query[:50]}.")
|
||||
session_id = request.session_id or "default"
|
||||
query = request.query
|
||||
logger.info(f"[AskAgent] session={session_id} query='{query[:80]}'")
|
||||
|
||||
try:
|
||||
final_state = self.graph.invoke({"messages": [{"role": "user",
|
||||
"content": request.query}]})
|
||||
history = list(session_store.get(session_id, []))
|
||||
logger.info(f"[AskAgent] conversation: {len(history)} previous messages.")
|
||||
|
||||
initial_state = {
|
||||
"messages": history + [{"role": "user", "content": query}],
|
||||
"session_id": session_id,
|
||||
"reformulated_query": "",
|
||||
"context": "",
|
||||
"query_type": "",
|
||||
}
|
||||
|
||||
final_state = self.graph.invoke(initial_state)
|
||||
messages = final_state.get("messages", [])
|
||||
last_msg = messages[-1] if messages else None
|
||||
result_text = getattr(last_msg, "content", str(last_msg)) if last_msg else ""
|
||||
result_text = getattr(last_msg, "content", str(last_msg)) \
|
||||
if last_msg else ""
|
||||
|
||||
logger.info(f"[AskAgent] query_type={final_state.get('query_type')} "
|
||||
f"answer='{result_text[:100]}'")
|
||||
|
||||
yield brunix_pb2.AgentResponse(
|
||||
text=result_text,
|
||||
avap_code="AVAP-2026",
|
||||
is_final=True,
|
||||
text = result_text,
|
||||
avap_code= "AVAP-2026",
|
||||
is_final = True,
|
||||
)
|
||||
|
||||
yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in AskAgent: {str(e)}", exc_info=True)
|
||||
logger.error(f"[AskAgent] Error: {e}", exc_info=True)
|
||||
yield brunix_pb2.AgentResponse(
|
||||
text=f"[Error Motor]: {str(e)}",
|
||||
is_final=True,
|
||||
text = f"[ENG] Error: {str(e)}",
|
||||
is_final = True,
|
||||
)
|
||||
|
||||
|
||||
def AskAgentStream(self, request, context):
    """Stream the agent answer token-by-token over gRPC.

    Runs the preparation graph (classification, reformulation, retrieval),
    then streams the LLM response directly, persisting the completed turn
    into the per-session store at the end.
    """
    session_id = request.session_id or "default"
    query = request.query
    logger.info(f"[AskAgentStream] session={session_id} query='{query[:80]}'")

    try:
        history = list(session_store.get(session_id, []))
        logger.info(f"[AskAgentStream] conversation: {len(history)} previous messages.")

        initial_state = {
            "messages": history + [{"role": "user", "content": query}],
            "session_id": session_id,
            "reformulated_query": "",
            "context": "",
            "query_type": "",
        }

        prepared = self.prepare_graph.invoke(initial_state)
        logger.info(
            f"[AskAgentStream] query_type={prepared.get('query_type')} "
            f"context_len={len(prepared.get('context', ''))}"
        )

        final_messages = build_final_messages(prepared)
        collected = []

        # Stream tokens straight from the model to the gRPC client.
        for chunk in self.llm.stream(final_messages):
            token = chunk.content
            if token:
                collected.append(token)
                yield brunix_pb2.AgentResponse(text=token, is_final=False)

        complete_text = "".join(collected)
        if session_id:
            # Persist the full exchange for conversational follow-ups.
            session_store[session_id] = (
                list(prepared["messages"]) + [AIMessage(content=complete_text)]
            )

        logger.info(
            f"[AskAgentStream] done — "
            f"chunks={len(collected)} total_chars={len(complete_text)}"
        )

        yield brunix_pb2.AgentResponse(text="", is_final=True)

    except Exception as e:
        logger.error(f"[AskAgentStream] Error: {e}", exc_info=True)
        yield brunix_pb2.AgentResponse(text=f"[ENG] Error: {str(e)}", is_final=True)
|
||||
|
||||
|
||||
def EvaluateRAG(self, request, context):
    """Run the offline RAG evaluation suite and map results onto EvalResponse.

    Empty request fields fall back to None (run_evaluation defaults) or to
    the engine's configured index.
    """
    category = request.category or None
    limit = request.limit or None
    index = request.index or self.index_name

    logger.info(f"[EvaluateRAG] category={category} limit={limit} index={index}")

    try:
        result = run_evaluation(
            es_client=self.es_client,
            llm=self.llm,
            embeddings=self.embeddings,
            index_name=index,
            category=category,
            limit=limit,
        )
    except Exception as e:
        logger.error(f"[EvaluateRAG] Error: {e}", exc_info=True)
        return brunix_pb2.EvalResponse(status=f"error: {e}")

    if result.get("status") != "ok":
        # Propagate the evaluator's own error message to the client.
        return brunix_pb2.EvalResponse(status=result.get("error", "unknown error"))

    details = [
        brunix_pb2.QuestionDetail(
            id=item["id"],
            category=item["category"],
            question=item["question"],
            answer_preview=item["answer_preview"],
            n_chunks=item["n_chunks"],
        )
        for item in result.get("details", [])
    ]

    scores = result["scores"]
    return brunix_pb2.EvalResponse(
        status="ok",
        questions_evaluated=result["questions_evaluated"],
        elapsed_seconds=result["elapsed_seconds"],
        judge_model=result["judge_model"],
        index=result["index"],
        faithfulness=scores["faithfulness"],
        answer_relevancy=scores["answer_relevancy"],
        context_recall=scores["context_recall"],
        context_precision=scores["context_precision"],
        global_score=result["global_score"],
        verdict=result["verdict"],
        details=details,
    )
|
||||
|
||||
|
||||
def serve():
|
||||
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
|
||||
brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), server)
|
||||
|
||||
SERVICE_NAMES = (
|
||||
|
|
@ -86,7 +234,7 @@ def serve():
|
|||
reflection.enable_server_reflection(SERVICE_NAMES, server)
|
||||
|
||||
server.add_insecure_port("[::]:50051")
|
||||
logger.info("Brunix Engine on port 50051")
|
||||
logger.info("[ENGINE] listen on 50051 (gRPC)")
|
||||
server.start()
|
||||
server.wait_for_termination()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
# state.py
|
||||
from typing import TypedDict, Annotated
|
||||
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
|
||||
class AgentState(TypedDict):
    """Shared LangGraph state threaded through the RAG pipeline."""
    # Conversation so far; merged by LangGraph's add_messages reducer.
    messages: Annotated[list, add_messages]
    # Keyword-optimized retrieval query produced by the reformulation node.
    reformulated_query: str
    # Retrieved documentation chunks used to ground the answer.
    context: str
    # Classifier output — presumably RETRIEVAL / CODE_GENERATION / CONVERSATIONAL
    # (matches the categories in the classify prompt); confirm against graph nodes.
    query_type: str
    # Key into the per-conversation session store.
    session_id: str
|
||||
Loading…
Reference in New Issue