feat: implement graph-based agent with LLM integration and dynamic prompts
This commit is contained in:
parent 5024bde8fb
commit 6c8261f2b2
@ -0,0 +1,53 @@
+# Documentation
+*.md
+documentation/
+
+# Build and dependency files
+Makefile
+*.pyc
+__pycache__/
+*.egg-info/
+dist/
+build/
+
+# Development and testing
+.venv/
+venv/
+env/
+.pytest_cache/
+.coverage
+
+# Git and version control
+.git/
+.gitignore
+.gitattributes
+
+# IDE and editor files
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+.DS_Store
+
+# Environment files
+.env
+.env.local
+.env.*.local
+
+# Docker files (do not copy Docker files into the image)
+Dockerfile
+docker-compose.yaml
+
+# CI/CD
+.github/
+.gitlab-ci.yml
+
+# Temporary files
+*.tmp
+*.log
+scratches/
+
+# Node modules (if any)
+node_modules/
+npm-debug.log
@ -3,7 +3,6 @@ FROM python:3.11-slim
 ENV PYTHONDONTWRITEBYTECODE=1
 ENV PYTHONUNBUFFERED=1
 
-
 WORKDIR /app
 
 COPY ./requirements.txt .
@ -4,17 +4,18 @@ services:
   brunix-engine:
     build: .
     container_name: brunix-assistance-engine
     env_file: .env
     ports:
       - "50052:50051"
     environment:
-      - ELASTICSEARCH_URL=http://host.docker.internal:9200
-      - DATABASE_URL=postgresql://postgres:brunix_pass@host.docker.internal:5432/postgres
-
-      - LANGFUSE_HOST=http://45.77.119.180
-      - LANGFUSE_PUBLIC_KEY=${LANGFUSE_PUBLIC_KEY}
-      - LANGFUSE_SECRET_KEY=${LANGFUSE_SECRET_KEY}
-      - LLM_BASE_URL=http://host.docker.internal:11434
+      ELASTICSEARCH_URL: ${ELASTICSEARCH_URL}
+      ELASTICSEARCH_INDEX: ${ELASTICSEARCH_INDEX}
+      POSTGRES_URL: ${POSTGRES_URL}
+      LANGFUSE_HOST: ${LANGFUSE_HOST}
+      LANGFUSE_PUBLIC_KEY: ${LANGFUSE_PUBLIC_KEY}
+      LANGFUSE_SECRET_KEY: ${LANGFUSE_SECRET_KEY}
+      OLLAMA_URL: ${OLLAMA_URL}
+      OLLAMA_MODEL_NAME: ${OLLAMA_MODEL_NAME}
+      OLLAMA_EMB_MODEL_NAME: ${OLLAMA_EMB_MODEL_NAME}
 
     extra_hosts:
       - "host.docker.internal:host-gateway"
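Every setting now comes from the environment via `env_file: .env`. A minimal `.env` sketch for the variables above; all values are illustrative assumptions except avap_manuals, qwen2.5:1.5b, and nomic-embed-text, which appear in the code this commit replaces:

# .env (illustrative values only, not part of the commit)
ELASTICSEARCH_URL=http://host.docker.internal:9200
ELASTICSEARCH_INDEX=avap_manuals
POSTGRES_URL=postgresql://postgres:postgres@host.docker.internal:5432/postgres
LANGFUSE_HOST=https://cloud.langfuse.com
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_SECRET_KEY=sk-lf-...
OLLAMA_URL=http://host.docker.internal:11434
OLLAMA_MODEL_NAME=qwen2.5:1.5b
OLLAMA_EMB_MODEL_NAME=nomic-embed-text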
@ -1,5 +1,5 @@
 # This file was autogenerated by uv via the following command:
-# uv export --format requirements-txt --no-hashes --no-dev -o requirements.txt
+# uv export --format requirements-txt --no-hashes --no-dev -o Docker/requirements.txt
 aiohappyeyeballs==2.6.1
     # via aiohttp
 aiohttp==3.13.3
@ -12,6 +12,12 @@ anyio==4.12.1
     # via httpx
 attrs==25.4.0
     # via aiohttp
+boto3==1.42.58
+    # via langchain-aws
+botocore==1.42.58
+    # via
+    #   boto3
+    #   s3transfer
 certifi==2026.1.4
     # via
     #   elastic-transport
@ -20,8 +26,11 @@ certifi==2026.1.4
     #   requests
 charset-normalizer==3.4.4
     # via requests
+click==8.3.1
+    # via nltk
 colorama==0.4.6 ; sys_platform == 'win32'
     # via
+    #   click
     #   loguru
     #   tqdm
 dataclasses-json==0.6.7
@ -30,72 +39,98 @@ elastic-transport==8.17.1
     # via elasticsearch
 elasticsearch==8.19.3
     # via langchain-elasticsearch
+filelock==3.24.3
+    # via huggingface-hub
 frozenlist==1.8.0
     # via
     #   aiohttp
     #   aiosignal
-greenlet==3.3.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
+fsspec==2025.10.0
+    # via huggingface-hub
+greenlet==3.3.2 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
     # via sqlalchemy
-grpcio==1.78.0
+grpcio==1.78.1
     # via
     #   assistance-engine
     #   grpcio-reflection
     #   grpcio-tools
-grpcio-reflection==1.78.0
+grpcio-reflection==1.78.1
     # via assistance-engine
-grpcio-tools==1.78.0
+grpcio-tools==1.78.1
     # via assistance-engine
 h11==0.16.0
     # via httpcore
+hf-xet==1.3.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
+    # via huggingface-hub
 httpcore==1.0.9
     # via httpx
 httpx==0.28.1
     # via
     #   langgraph-sdk
     #   langsmith
+    #   ollama
 httpx-sse==0.4.3
     # via langchain-community
+huggingface-hub==0.36.2
+    # via
+    #   langchain-huggingface
+    #   tokenizers
 idna==3.11
     # via
     #   anyio
     #   httpx
     #   requests
     #   yarl
+jmespath==1.1.0
+    # via
+    #   boto3
+    #   botocore
+joblib==1.5.3
+    # via nltk
 jsonpatch==1.33
     # via langchain-core
 jsonpointer==3.0.0
     # via jsonpatch
 langchain==1.2.10
     # via assistance-engine
+langchain-aws==1.3.1
+    # via assistance-engine
 langchain-classic==1.0.1
     # via langchain-community
 langchain-community==0.4.1
     # via assistance-engine
-langchain-core==1.2.11
+langchain-core==1.2.15
     # via
     #   langchain
+    #   langchain-aws
     #   langchain-classic
     #   langchain-community
     #   langchain-elasticsearch
+    #   langchain-huggingface
+    #   langchain-ollama
     #   langchain-text-splitters
     #   langgraph
     #   langgraph-checkpoint
     #   langgraph-prebuilt
 langchain-elasticsearch==1.0.0
     # via assistance-engine
-langchain-text-splitters==1.1.0
+langchain-huggingface==1.2.0
+    # via assistance-engine
+langchain-ollama==1.0.1
+    # via assistance-engine
+langchain-text-splitters==1.1.1
     # via langchain-classic
-langgraph==1.0.8
+langgraph==1.0.9
     # via langchain
 langgraph-checkpoint==4.0.0
     # via
     #   langgraph
     #   langgraph-prebuilt
-langgraph-prebuilt==1.0.7
+langgraph-prebuilt==1.0.8
     # via langgraph
-langgraph-sdk==0.3.5
+langgraph-sdk==0.3.8
     # via langgraph
-langsmith==0.7.1
+langsmith==0.7.6
     # via
     #   langchain-classic
     #   langchain-community
@ -110,24 +145,30 @@ multidict==6.7.1
     #   yarl
 mypy-extensions==1.1.0
     # via typing-inspect
+nltk==3.9.3
+    # via assistance-engine
 numpy==2.4.2
     # via
     #   assistance-engine
     #   elasticsearch
+    #   langchain-aws
     #   langchain-community
     #   pandas
+ollama==0.6.1
+    # via langchain-ollama
 orjson==3.11.7
     # via
     #   langgraph-sdk
     #   langsmith
 ormsgpack==1.12.2
     # via langgraph-checkpoint
-packaging==26.0
+packaging==24.2
     # via
+    #   huggingface-hub
     #   langchain-core
     #   langsmith
     #   marshmallow
-pandas==3.0.0
+pandas==3.0.1
     # via assistance-engine
 propcache==0.4.1
     # via
@ -140,17 +181,20 @@ protobuf==6.33.5
 pydantic==2.12.5
     # via
     #   langchain
+    #   langchain-aws
     #   langchain-classic
     #   langchain-core
     #   langgraph
     #   langsmith
+    #   ollama
     #   pydantic-settings
 pydantic-core==2.41.5
     # via pydantic
-pydantic-settings==2.12.0
+pydantic-settings==2.13.1
     # via langchain-community
 python-dateutil==2.9.0.post0
     # via
+    #   botocore
     #   elasticsearch
     #   pandas
 python-dotenv==1.2.1
@ -159,20 +203,28 @@ python-dotenv==1.2.1
     #   pydantic-settings
 pyyaml==6.0.3
     # via
+    #   huggingface-hub
     #   langchain-classic
     #   langchain-community
     #   langchain-core
+rapidfuzz==3.14.3
+    # via assistance-engine
+regex==2026.2.19
+    # via nltk
 requests==2.32.5
     # via
+    #   huggingface-hub
     #   langchain-classic
     #   langchain-community
     #   langsmith
     #   requests-toolbelt
 requests-toolbelt==1.0.0
     # via langsmith
+s3transfer==0.16.0
+    # via boto3
 setuptools==82.0.0
     # via grpcio-tools
-simsimd==6.5.12
+simsimd==6.5.13
     # via elasticsearch
 six==1.17.0
     # via python-dateutil
@ -184,14 +236,20 @@ tenacity==9.1.4
     # via
     #   langchain-community
     #   langchain-core
+tokenizers==0.22.2
+    # via langchain-huggingface
 tqdm==4.67.3
-    # via assistance-engine
+    # via
+    #   assistance-engine
+    #   huggingface-hub
+    #   nltk
 typing-extensions==4.15.0
     # via
     #   aiosignal
     #   anyio
     #   elasticsearch
     #   grpcio
+    #   huggingface-hub
     #   langchain-core
     #   pydantic
     #   pydantic-core
@ -208,9 +266,10 @@ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
     # via pandas
 urllib3==2.6.3
     # via
+    #   botocore
     #   elastic-transport
     #   requests
-uuid-utils==0.14.0
+uuid-utils==0.14.1
     # via
     #   langchain-core
     #   langsmith
@ -0,0 +1,60 @@
+# graph.py
+import logging
+
+from langchain_core.documents import Document
+from langchain_core.messages import SystemMessage
+from langgraph.graph import END, StateGraph
+from langgraph.graph.state import CompiledStateGraph
+from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT
+from state import AgentState
+
+logger = logging.getLogger(__name__)
+
+
+def format_context(docs: list[Document]) -> str:
+    chunks = []
+    for i, doc in enumerate(docs, 1):
+        source = (doc.metadata or {}).get("source", "Untitled")
+        source_id = (doc.metadata or {}).get("id", f"chunk-{i}")
+        text = doc.page_content or ""
+        chunks.append(f"[{i}] id={source_id} source={source}\n{text}")
+    return "\n\n".join(chunks)
+
+
+def build_graph(llm, vector_store) -> CompiledStateGraph:
+    def reformulate(state: AgentState) -> AgentState:
+        user_msg = state["messages"][-1]
+        resp = llm.invoke([REFORMULATE_PROMPT, user_msg])
+        reformulated = resp.content.strip()
+        logger.info(f"[reformulate] '{user_msg.content}' → '{reformulated}'")
+        return {"reformulated_query": reformulated}
+
+    def retrieve(state: AgentState) -> AgentState:
+        query = state["reformulated_query"]
+        docs = vector_store.as_retriever(
+            search_type="similarity",
+            search_kwargs={"k": 3},
+        ).invoke(query)
+        context = format_context(docs)
+        logger.info(f"[retrieve] {len(docs)} docs fetched")
+        logger.info(context)
+        return {"context": context}
+
+    def generate(state: AgentState) -> AgentState:
+        prompt = SystemMessage(
+            content=GENERATE_PROMPT.content.format(context=state["context"])
+        )
+        resp = llm.invoke([prompt] + state["messages"])
+        return {"messages": [resp]}
+
+    graph_builder = StateGraph(AgentState)
+    graph_builder.add_node("reformulate", reformulate)
+    graph_builder.add_node("retrieve", retrieve)
+    graph_builder.add_node("generate", generate)
+
+    graph_builder.set_entry_point("reformulate")
+    graph_builder.add_edge("reformulate", "retrieve")
+    graph_builder.add_edge("retrieve", "generate")
+    graph_builder.add_edge("generate", END)
+
+    return graph_builder.compile()
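A minimal driver sketch for the compiled graph, mirroring how the gRPC servicer below wires it up. The model, index, and URLs are assumptions for local use, taken from the defaults in the code this commit replaces:

# sketch.py (hypothetical, not part of the commit)
from langchain_ollama import ChatOllama, OllamaEmbeddings
from langchain_elasticsearch import ElasticsearchStore
from graph import build_graph

llm = ChatOllama(model="qwen2.5:1.5b", base_url="http://localhost:11434", temperature=0)
store = ElasticsearchStore(
    es_url="http://localhost:9200",
    index_name="avap_manuals",
    embedding=OllamaEmbeddings(model="nomic-embed-text", base_url="http://localhost:11434"),
)

graph = build_graph(llm=llm, vector_store=store)
state = graph.invoke({"messages": [{"role": "user", "content": "What does AVAP stand for?"}]})
print(state["messages"][-1].content)

Each invoke runs reformulate, retrieve, and generate once, in that order; the messages list accumulates via the add_messages reducer defined in state.py.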
@ -0,0 +1,89 @@
+from langchain_core.messages import SystemMessage
+
+REFORMULATE_PROMPT = SystemMessage(
+    content=(
+        "You are a deterministic lexical query rewriter used for vector retrieval.\n"
+        "Your task is to rewrite user questions into optimized keyword search queries.\n\n"
+
+        "CRITICAL RULES (ABSOLUTE):\n"
+        "1. NEVER answer the question.\n"
+        "2. NEVER expand acronyms.\n"
+        "3. NEVER introduce new terms not present in the original query.\n"
+        "4. NEVER infer missing information.\n"
+        "5. NEVER add explanations, definitions, or interpretations.\n"
+        "6. Preserve all technical tokens exactly as written.\n"
+        "7. Only remove filler words (e.g., what, does, is, explain, tell me, please).\n"
+        "8. You may reorder terms for better retrieval.\n"
+        "9. Output must be a single-line plain keyword query.\n"
+        "10. If the query is already optimal, return it unchanged.\n"
+        "11. If you receive something that looks like code, do NOT attempt to rewrite it. Return it verbatim.\n\n"
+
+        "ALLOWED OPERATIONS:\n"
+        "- Remove interrogative phrasing.\n"
+        "- Remove stopwords.\n"
+        "- Reorder words.\n"
+        "- Convert to noun phrase form.\n\n"
+
+        "FORBIDDEN OPERATIONS:\n"
+        "- Expanding abbreviations.\n"
+        "- Paraphrasing into unseen vocabulary.\n"
+        "- Adding definitions.\n"
+        "- Answering implicitly.\n\n"
+
+        "Examples:\n"
+        "Input: What does AVAP stand for?\n"
+        "Output: AVAP stand for\n\n"
+
+        "Input: Hey, I'm trying to understand how AVAP handels a ZeroDivisionError when doing divison or modulus operatoins. Can you explane what situatoins cause a ZeroDivisionError to be raised and how I can catch it in my AVAP scripts?\n"
+        "Output: AVAP ZeroDivisionError division / modulus % catch try except\n\n"
+
+        "Input: Please explain how the import statement works in AVAP scripts.\n"
+        "Output: AVAP import statement syntax behavior\n\n"
+
+        "Return only the rewritten query."
+    )
+)
+
+GENERATE_PROMPT = SystemMessage(
+    content=(
+        """You are a retrieval-grounded assistant.
+
+CRITICAL SCOPE RULES
+- AVAP (Advanced Virtual API Programming) is a new/unknown programming language.
+- If the user message is AVAP-related (explicitly mentions "AVAP" case-insensitive OR the provided context is AVAP code/docs), you MUST answer using ONLY <CONTEXT>.
+- You MUST NOT use general programming knowledge to interpret AVAP.
+- You MUST NOT assume AVAP behaves like any other language even if syntax looks similar.
+- You MUST NOT infer missing details. Only state what is explicitly present in <CONTEXT>.
+
+GROUNDING REQUIREMENTS (AVAP-RELATED)
+1) Every non-trivial factual claim MUST be directly supported by an EXACT QUOTE from <CONTEXT>.
+2) If a claim is not supported by a quote, DO NOT include it.
+3) If <CONTEXT> does not contain enough information to answer, reply with EXACTLY:
+   "I don't have enough information in the provided context to answer that."
+
+WORKFLOW (AVAP-RELATED) — FOLLOW IN ORDER
+A) Identify the specific question(s) being asked.
+B) Extract the minimum necessary quotes from <CONTEXT> that answer those question(s).
+C) Write the answer using ONLY those quotes (paraphrase is allowed, but every statement must be backed by at least one quote).
+D) Verify: for EACH sentence in your answer, confirm there is a supporting quote. If any sentence lacks a quote, delete it or refuse.
+
+OUTPUT FORMAT (AVAP-RELATED ONLY)
+Answer:
+<short, direct answer; no extra speculation; no unrelated tips>
+
+Evidence:
+- "<exact quote 1>"
+- "<exact quote 2>"
+(Include only quotes you actually used. Prefer the smallest quotes that fully support the statements.)
+
+NON-AVAP QUESTIONS
+- If the question is clearly not AVAP-related, answer normally using general knowledge.
+
+<CONTEXT>
+{context}
+</CONTEXT>"""
+    )
+)
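A quick smoke-test sketch for the rewriter prompt; the model name is an assumption, and the expected output comes from the few-shot examples above:

# hypothetical smoke test, not part of the commit
from langchain_core.messages import HumanMessage
from langchain_ollama import ChatOllama
from prompts import REFORMULATE_PROMPT

llm = ChatOllama(model="qwen2.5:1.5b", temperature=0)
resp = llm.invoke([
    REFORMULATE_PROMPT,
    HumanMessage("Please explain how the import statement works in AVAP scripts."),
])
print(resp.content)  # expected per the examples: "AVAP import statement syntax behavior"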
@ -1,70 +1,75 @@
-import os
-import grpc
 import logging
+import os
 from concurrent import futures
-from grpc_reflection.v1alpha import reflection
 
 import brunix_pb2
 import brunix_pb2_grpc
-
-from langchain_community.llms import Ollama
-from langchain_community.embeddings import OllamaEmbeddings
+import grpc
+from grpc_reflection.v1alpha import reflection
 from langchain_elasticsearch import ElasticsearchStore
-from langchain_core.prompts import ChatPromptTemplate
 
+from utils.llm_factory import create_chat_model
+from utils.emb_factory import create_embedding_model
+from graph import build_graph
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("brunix-engine")
 
 
 class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
     def __init__(self):
-        self.base_url = os.getenv("LLM_BASE_URL", "http://ollama-light-service:11434")
-        self.model_name = os.getenv("LLM_MODEL", "qwen2.5:1.5b")
-
-        logger.info(f"Starting server")
-
-        self.llm = Ollama(base_url=self.base_url, model=self.model_name)
-
-        self.embeddings = OllamaEmbeddings(base_url=self.base_url, model="nomic-embed-text")
-
-        es_url = os.getenv("ELASTICSEARCH_URL", "http://elasticsearch:9200")
-        logger.info(f"ElasticSearch on: {es_url}")
-
-        self.vector_store = ElasticsearchStore(
-            es_url=es_url,
-            index_name="avap_manuals",
-            embedding=self.embeddings
+        self.llm = create_chat_model(
+            provider="ollama",
+            model=os.getenv("OLLAMA_MODEL_NAME"),
+            base_url=os.getenv("OLLAMA_URL"),
+            temperature=0,
+            validate_model_on_init=True,
         )
+        self.embeddings = create_embedding_model(
+            provider="ollama",
+            model=os.getenv("OLLAMA_EMB_MODEL_NAME"),
+            base_url=os.getenv("OLLAMA_URL"),
+        )
+        self.vector_store = ElasticsearchStore(
+            es_url=os.getenv("ELASTICSEARCH_URL"),
+            index_name=os.getenv("ELASTICSEARCH_INDEX"),
+            embedding=self.embeddings,
+            query_field="text",
+            vector_query_field="vector",
+        )
+        self.graph = build_graph(
+            llm=self.llm,
+            vector_store=self.vector_store
+        )
+        logger.info("Brunix Engine initializing.")
 
     def AskAgent(self, request, context):
         logger.info(f"request ({request.session_id}): {request.query[:50]}.")
 
         try:
-            context_text = "AVAP is a virtual programming language for API development."
-            # 4. Prompt Engineering
-            prompt = ChatPromptTemplate.from_template("""
-            You are Brunix, the 101OBEX artificial intelligence for the AVAP Sphere platform. Respond in a professional manner.
-
-            CONTEXT:
-            {context}
-
-            QUESTION:
-            {question}
-            """)
-
-            chain = prompt | self.llm
-
-            for chunk in chain.stream({"context": context_text, "question": request.query}):
-                yield brunix_pb2.AgentResponse(
-                    text=str(chunk),
-                    avap_code="AVAP-2026",
-                    is_final=False
-                )
-
-            yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True)
+            final_state = self.graph.invoke({"messages": [{"role": "user",
+                                                           "content": request.query}]})
+
+            messages = final_state.get("messages", [])
+            last_msg = messages[-1] if messages else None
+            result_text = getattr(last_msg, "content", str(last_msg)) if last_msg else ""
+
+            yield brunix_pb2.AgentResponse(
+                text=result_text,
+                avap_code="AVAP-2026",
+                is_final=True,
+            )
 
         except Exception as e:
-            logger.error(f"Error in AskAgent: {str(e)}")
-            yield brunix_pb2.AgentResponse(text=f"[Error Motor]: {str(e)}", is_final=True)
+            logger.error(f"Error in AskAgent: {str(e)}", exc_info=True)
+            yield brunix_pb2.AgentResponse(
+                text=f"[Error Motor]: {str(e)}",
+                is_final=True,
+            )
 
 
 def serve():
@ -73,15 +78,16 @@ def serve():
     brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), server)
 
     SERVICE_NAMES = (
-        brunix_pb2.DESCRIPTOR.services_by_name['AssistanceEngine'].full_name,
+        brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"].full_name,
         reflection.SERVICE_NAME,
     )
     reflection.enable_server_reflection(SERVICE_NAMES, server)
 
-    server.add_insecure_port('[::]:50051')
+    server.add_insecure_port("[::]:50051")
     logger.info("Brunix Engine on port 50051")
     server.start()
     server.wait_for_termination()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     serve()
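With reflection enabled and docker-compose mapping host port 50052 to 50051, the stream can be exercised from the host with a small client sketch. The request message name AgentRequest is an assumption (the .proto is not part of this diff); session_id and query follow the fields read in AskAgent:

# hypothetical client sketch, not part of the commit
import grpc

import brunix_pb2
import brunix_pb2_grpc

channel = grpc.insecure_channel("localhost:50052")
stub = brunix_pb2_grpc.AssistanceEngineStub(channel)
# AgentRequest is an assumed message name; check the generated brunix_pb2 module
request = brunix_pb2.AgentRequest(session_id="s1", query="What does AVAP stand for?")
for resp in stub.AskAgent(request):
    print(resp.text, resp.is_final)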
@ -0,0 +1,9 @@
+from typing import TypedDict, Annotated
+
+from langgraph.graph.message import add_messages
+
+
+class AgentState(TypedDict):
+    messages: Annotated[list, add_messages]
+    reformulated_query: str
+    context: str
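The Annotated[list, add_messages] reducer is what lets each graph node return only its delta: returned messages are appended to the running list rather than replacing it. A tiny illustration:

# illustration only, not part of the commit
from langgraph.graph.message import add_messages

existing = [{"role": "user", "content": "hi"}]
update = [{"role": "assistant", "content": "hello"}]
merged = add_messages(existing, update)
print(len(merged))  # 2: the reply is appended, not substituted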
@ -0,0 +1,67 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+
+class BaseEmbeddingFactory(ABC):
+    @abstractmethod
+    def create(self, model: str, **kwargs: Any):
+        raise NotImplementedError
+
+
+class OpenAIEmbeddingFactory(BaseEmbeddingFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_openai import OpenAIEmbeddings
+
+        return OpenAIEmbeddings(model=model, **kwargs)
+
+
+class OllamaEmbeddingFactory(BaseEmbeddingFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_ollama import OllamaEmbeddings
+
+        return OllamaEmbeddings(model=model, **kwargs)
+
+
+class BedrockEmbeddingFactory(BaseEmbeddingFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_aws import BedrockEmbeddings
+
+        return BedrockEmbeddings(model_id=model, **kwargs)
+
+
+class HuggingFaceEmbeddingFactory(BaseEmbeddingFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_huggingface import HuggingFaceEmbeddings
+
+        return HuggingFaceEmbeddings(model_name=model, **kwargs)
+
+
+EMBEDDING_FACTORIES: Dict[str, BaseEmbeddingFactory] = {
+    "openai": OpenAIEmbeddingFactory(),
+    "ollama": OllamaEmbeddingFactory(),
+    "bedrock": BedrockEmbeddingFactory(),
+    "huggingface": HuggingFaceEmbeddingFactory(),
+}
+
+
+def create_embedding_model(provider: str, model: str, **kwargs: Any):
+    """
+    Create an embedding model instance for the given provider.
+
+    Args:
+        provider: The provider name (openai, ollama, bedrock, huggingface).
+        model: The model identifier.
+        **kwargs: Additional keyword arguments passed to the model constructor.
+
+    Returns:
+        An embedding model instance.
+    """
+    key = provider.strip().lower()
+
+    if key not in EMBEDDING_FACTORIES:
+        raise ValueError(
+            f"Unsupported embedding provider: {provider}. "
+            f"Available providers: {list(EMBEDDING_FACTORIES.keys())}"
+        )
+
+    return EMBEDDING_FACTORIES[key].create(model=model, **kwargs)
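Usage sketch; the model name mirrors the embedding model from the code this commit replaces and should be treated as an assumption:

# hypothetical usage, not part of the commit
from utils.emb_factory import create_embedding_model

embeddings = create_embedding_model(
    provider="ollama",
    model="nomic-embed-text",
    base_url="http://localhost:11434",  # forwarded to OllamaEmbeddings via **kwargs
)
vector = embeddings.embed_query("AVAP import statement")
print(len(vector))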
@ -0,0 +1,72 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+
+
+class BaseProviderFactory(ABC):
+    @abstractmethod
+    def create(self, model: str, **kwargs: Any):
+        raise NotImplementedError
+
+
+class OpenAIChatFactory(BaseProviderFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_openai import ChatOpenAI
+
+        return ChatOpenAI(model=model, **kwargs)
+
+
+class OllamaChatFactory(BaseProviderFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_ollama import ChatOllama
+
+        return ChatOllama(model=model, **kwargs)
+
+
+class BedrockChatFactory(BaseProviderFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_aws import ChatBedrockConverse
+
+        return ChatBedrockConverse(model=model, **kwargs)
+
+
+class HuggingFaceChatFactory(BaseProviderFactory):
+    def create(self, model: str, **kwargs: Any):
+        from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
+
+        llm = HuggingFacePipeline.from_model_id(
+            model_id=model,
+            task="text-generation",
+            pipeline_kwargs=kwargs,
+        )
+        return ChatHuggingFace(llm=llm)
+
+
+CHAT_FACTORIES: Dict[str, BaseProviderFactory] = {
+    "openai": OpenAIChatFactory(),
+    "ollama": OllamaChatFactory(),
+    "bedrock": BedrockChatFactory(),
+    "huggingface": HuggingFaceChatFactory(),
+}
+
+
+def create_chat_model(provider: str, model: str, **kwargs: Any):
+    """
+    Create a chat model instance for the given provider.
+
+    Args:
+        provider: The provider name (openai, ollama, bedrock, huggingface).
+        model: The model identifier.
+        **kwargs: Additional keyword arguments passed to the model constructor.
+
+    Returns:
+        A chat model instance.
+    """
+    key = provider.strip().lower()
+
+    if key not in CHAT_FACTORIES:
+        raise ValueError(
+            f"Unsupported chat provider: {provider}. "
+            f"Available providers: {list(CHAT_FACTORIES.keys())}"
+        )
+
+    return CHAT_FACTORIES[key].create(model=model, **kwargs)
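Usage sketch for the chat side; the model and URL are illustrative assumptions, and only the "ollama" path is exercised by the servicer above:

# hypothetical usage, not part of the commit
from utils.llm_factory import create_chat_model

llm = create_chat_model(
    provider="ollama",
    model="qwen2.5:1.5b",
    base_url="http://localhost:11434",  # forwarded to ChatOllama via **kwargs
    temperature=0,
)
print(llm.invoke("ping").content)

The lazy imports inside each factory keep the optional providers (openai, bedrock, huggingface) from becoming hard dependencies; only the selected provider's package needs to be installed.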