working on agent in docker
This commit is contained in:
parent
48d280440c
commit
a5952c1a4d
|
|
@ -9,7 +9,7 @@ services:
|
||||||
environment:
|
environment:
|
||||||
ELASTICSEARCH_URL: ${ELASTICSEARCH_URL}
|
ELASTICSEARCH_URL: ${ELASTICSEARCH_URL}
|
||||||
DATABASE_URL: ${DATABASE_URL}
|
DATABASE_URL: ${DATABASE_URL}
|
||||||
LLM_BASE_URL: ${LLM_BASE_URL}
|
OLLAMA_URL: ${OLLAMA_URL}
|
||||||
LANGFUSE_HOST: ${LANGFUSE_HOST}
|
LANGFUSE_HOST: ${LANGFUSE_HOST}
|
||||||
LANGFUSE_PUBLIC_KEY: ${LANGFUSE_PUBLIC_KEY}
|
LANGFUSE_PUBLIC_KEY: ${LANGFUSE_PUBLIC_KEY}
|
||||||
LANGFUSE_SECRET_KEY: ${LANGFUSE_SECRET_KEY}
|
LANGFUSE_SECRET_KEY: ${LANGFUSE_SECRET_KEY}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
# This file was autogenerated by uv via the following command:
|
# This file was autogenerated by uv via the following command:
|
||||||
# uv export --format requirements-txt --no-hashes --no-dev -o requirements.txt
|
# uv export --format requirements-txt --no-hashes --no-dev -o Docker/requirements.txt
|
||||||
|
accelerate==1.12.0
|
||||||
|
# via assistance-engine
|
||||||
aiohappyeyeballs==2.6.1
|
aiohappyeyeballs==2.6.1
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
aiohttp==3.13.3
|
aiohttp==3.13.3
|
||||||
|
|
@ -12,6 +14,12 @@ anyio==4.12.1
|
||||||
# via httpx
|
# via httpx
|
||||||
attrs==25.4.0
|
attrs==25.4.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
|
boto3==1.42.58
|
||||||
|
# via langchain-aws
|
||||||
|
botocore==1.42.58
|
||||||
|
# via
|
||||||
|
# boto3
|
||||||
|
# s3transfer
|
||||||
certifi==2026.1.4
|
certifi==2026.1.4
|
||||||
# via
|
# via
|
||||||
# elastic-transport
|
# elastic-transport
|
||||||
|
|
@ -20,115 +28,211 @@ certifi==2026.1.4
|
||||||
# requests
|
# requests
|
||||||
charset-normalizer==3.4.4
|
charset-normalizer==3.4.4
|
||||||
# via requests
|
# via requests
|
||||||
|
click==8.3.1
|
||||||
|
# via nltk
|
||||||
colorama==0.4.6 ; sys_platform == 'win32'
|
colorama==0.4.6 ; sys_platform == 'win32'
|
||||||
# via
|
# via
|
||||||
|
# click
|
||||||
# loguru
|
# loguru
|
||||||
# tqdm
|
# tqdm
|
||||||
|
cuda-bindings==12.9.4 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
cuda-pathfinder==1.3.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via cuda-bindings
|
||||||
dataclasses-json==0.6.7
|
dataclasses-json==0.6.7
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
elastic-transport==8.17.1
|
elastic-transport==8.17.1
|
||||||
# via elasticsearch
|
# via elasticsearch
|
||||||
elasticsearch==8.19.3
|
elasticsearch==8.19.3
|
||||||
# via langchain-elasticsearch
|
# via langchain-elasticsearch
|
||||||
|
filelock==3.24.3
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
frozenlist==1.8.0
|
frozenlist==1.8.0
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# aiosignal
|
# aiosignal
|
||||||
greenlet==3.3.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
fsspec==2025.10.0
|
||||||
|
# via
|
||||||
|
# huggingface-hub
|
||||||
|
# torch
|
||||||
|
greenlet==3.3.2 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'
|
||||||
# via sqlalchemy
|
# via sqlalchemy
|
||||||
grpcio==1.78.0
|
grpcio==1.78.1
|
||||||
# via
|
# via
|
||||||
# assistance-engine
|
# assistance-engine
|
||||||
# grpcio-reflection
|
# grpcio-reflection
|
||||||
# grpcio-tools
|
# grpcio-tools
|
||||||
grpcio-reflection==1.78.0
|
grpcio-reflection==1.78.1
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
grpcio-tools==1.78.0
|
grpcio-tools==1.78.1
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
h11==0.16.0
|
h11==0.16.0
|
||||||
# via httpcore
|
# via httpcore
|
||||||
|
hf-xet==1.3.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'
|
||||||
|
# via huggingface-hub
|
||||||
httpcore==1.0.9
|
httpcore==1.0.9
|
||||||
# via httpx
|
# via httpx
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
# via
|
# via
|
||||||
# langgraph-sdk
|
# langgraph-sdk
|
||||||
# langsmith
|
# langsmith
|
||||||
|
# ollama
|
||||||
httpx-sse==0.4.3
|
httpx-sse==0.4.3
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
|
huggingface-hub==0.36.2
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# langchain-huggingface
|
||||||
|
# tokenizers
|
||||||
idna==3.11
|
idna==3.11
|
||||||
# via
|
# via
|
||||||
# anyio
|
# anyio
|
||||||
# httpx
|
# httpx
|
||||||
# requests
|
# requests
|
||||||
# yarl
|
# yarl
|
||||||
|
jinja2==3.1.6
|
||||||
|
# via torch
|
||||||
|
jmespath==1.1.0
|
||||||
|
# via
|
||||||
|
# boto3
|
||||||
|
# botocore
|
||||||
|
joblib==1.5.3
|
||||||
|
# via nltk
|
||||||
jsonpatch==1.33
|
jsonpatch==1.33
|
||||||
# via langchain-core
|
# via langchain-core
|
||||||
jsonpointer==3.0.0
|
jsonpointer==3.0.0
|
||||||
# via jsonpatch
|
# via jsonpatch
|
||||||
langchain==1.2.10
|
langchain==1.2.10
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
|
langchain-aws==1.3.1
|
||||||
|
# via assistance-engine
|
||||||
langchain-classic==1.0.1
|
langchain-classic==1.0.1
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
langchain-community==0.4.1
|
langchain-community==0.4.1
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
langchain-core==1.2.11
|
langchain-core==1.2.15
|
||||||
# via
|
# via
|
||||||
# langchain
|
# langchain
|
||||||
|
# langchain-aws
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# langchain-elasticsearch
|
# langchain-elasticsearch
|
||||||
|
# langchain-huggingface
|
||||||
|
# langchain-ollama
|
||||||
# langchain-text-splitters
|
# langchain-text-splitters
|
||||||
# langgraph
|
# langgraph
|
||||||
# langgraph-checkpoint
|
# langgraph-checkpoint
|
||||||
# langgraph-prebuilt
|
# langgraph-prebuilt
|
||||||
langchain-elasticsearch==1.0.0
|
langchain-elasticsearch==1.0.0
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
langchain-text-splitters==1.1.0
|
langchain-huggingface==1.2.0
|
||||||
|
# via assistance-engine
|
||||||
|
langchain-ollama==1.0.1
|
||||||
|
# via assistance-engine
|
||||||
|
langchain-text-splitters==1.1.1
|
||||||
# via langchain-classic
|
# via langchain-classic
|
||||||
langgraph==1.0.8
|
langgraph==1.0.9
|
||||||
# via langchain
|
# via langchain
|
||||||
langgraph-checkpoint==4.0.0
|
langgraph-checkpoint==4.0.0
|
||||||
# via
|
# via
|
||||||
# langgraph
|
# langgraph
|
||||||
# langgraph-prebuilt
|
# langgraph-prebuilt
|
||||||
langgraph-prebuilt==1.0.7
|
langgraph-prebuilt==1.0.8
|
||||||
# via langgraph
|
# via langgraph
|
||||||
langgraph-sdk==0.3.5
|
langgraph-sdk==0.3.8
|
||||||
# via langgraph
|
# via langgraph
|
||||||
langsmith==0.7.1
|
langsmith==0.7.6
|
||||||
# via
|
# via
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# langchain-core
|
# langchain-core
|
||||||
loguru==0.7.3
|
loguru==0.7.3
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
|
markupsafe==3.0.3
|
||||||
|
# via jinja2
|
||||||
marshmallow==3.26.2
|
marshmallow==3.26.2
|
||||||
# via dataclasses-json
|
# via dataclasses-json
|
||||||
|
mpmath==1.3.0
|
||||||
|
# via sympy
|
||||||
multidict==6.7.1
|
multidict==6.7.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
# yarl
|
# yarl
|
||||||
mypy-extensions==1.1.0
|
mypy-extensions==1.1.0
|
||||||
# via typing-inspect
|
# via typing-inspect
|
||||||
|
networkx==3.6.1
|
||||||
|
# via torch
|
||||||
|
nltk==3.9.3
|
||||||
|
# via assistance-engine
|
||||||
numpy==2.4.2
|
numpy==2.4.2
|
||||||
# via
|
# via
|
||||||
|
# accelerate
|
||||||
# assistance-engine
|
# assistance-engine
|
||||||
# elasticsearch
|
# elasticsearch
|
||||||
|
# langchain-aws
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# pandas
|
# pandas
|
||||||
|
# torchvision
|
||||||
|
nvidia-cublas-cu12==12.8.4.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via
|
||||||
|
# nvidia-cudnn-cu12
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# torch
|
||||||
|
nvidia-cuda-cupti-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cuda-runtime-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cudnn-cu12==9.10.2.21 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cufft-cu12==11.3.3.83 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cufile-cu12==1.13.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-curand-cu12==10.3.9.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cusolver-cu12==11.7.3.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-cusparse-cu12==12.5.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# torch
|
||||||
|
nvidia-cusparselt-cu12==0.7.1 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-nccl-cu12==2.27.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-nvjitlink-cu12==12.8.93 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via
|
||||||
|
# nvidia-cufft-cu12
|
||||||
|
# nvidia-cusolver-cu12
|
||||||
|
# nvidia-cusparse-cu12
|
||||||
|
# torch
|
||||||
|
nvidia-nvshmem-cu12==3.4.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
nvidia-nvtx-cu12==12.8.90 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
|
ollama==0.6.1
|
||||||
|
# via langchain-ollama
|
||||||
orjson==3.11.7
|
orjson==3.11.7
|
||||||
# via
|
# via
|
||||||
# langgraph-sdk
|
# langgraph-sdk
|
||||||
# langsmith
|
# langsmith
|
||||||
ormsgpack==1.12.2
|
ormsgpack==1.12.2
|
||||||
# via langgraph-checkpoint
|
# via langgraph-checkpoint
|
||||||
packaging==26.0
|
packaging==24.2
|
||||||
# via
|
# via
|
||||||
|
# accelerate
|
||||||
|
# huggingface-hub
|
||||||
# langchain-core
|
# langchain-core
|
||||||
# langsmith
|
# langsmith
|
||||||
# marshmallow
|
# marshmallow
|
||||||
pandas==3.0.0
|
pandas==3.0.1
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
|
pillow==12.1.1
|
||||||
|
# via torchvision
|
||||||
propcache==0.4.1
|
propcache==0.4.1
|
||||||
# via
|
# via
|
||||||
# aiohttp
|
# aiohttp
|
||||||
|
|
@ -137,20 +241,25 @@ protobuf==6.33.5
|
||||||
# via
|
# via
|
||||||
# grpcio-reflection
|
# grpcio-reflection
|
||||||
# grpcio-tools
|
# grpcio-tools
|
||||||
|
psutil==7.2.2
|
||||||
|
# via accelerate
|
||||||
pydantic==2.12.5
|
pydantic==2.12.5
|
||||||
# via
|
# via
|
||||||
# langchain
|
# langchain
|
||||||
|
# langchain-aws
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-core
|
# langchain-core
|
||||||
# langgraph
|
# langgraph
|
||||||
# langsmith
|
# langsmith
|
||||||
|
# ollama
|
||||||
# pydantic-settings
|
# pydantic-settings
|
||||||
pydantic-core==2.41.5
|
pydantic-core==2.41.5
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pydantic-settings==2.12.0
|
pydantic-settings==2.13.1
|
||||||
# via langchain-community
|
# via langchain-community
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
# via
|
# via
|
||||||
|
# botocore
|
||||||
# elasticsearch
|
# elasticsearch
|
||||||
# pandas
|
# pandas
|
||||||
python-dotenv==1.2.1
|
python-dotenv==1.2.1
|
||||||
|
|
@ -159,20 +268,33 @@ python-dotenv==1.2.1
|
||||||
# pydantic-settings
|
# pydantic-settings
|
||||||
pyyaml==6.0.3
|
pyyaml==6.0.3
|
||||||
# via
|
# via
|
||||||
|
# accelerate
|
||||||
|
# huggingface-hub
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# langchain-core
|
# langchain-core
|
||||||
|
rapidfuzz==3.14.3
|
||||||
|
# via assistance-engine
|
||||||
|
regex==2026.2.19
|
||||||
|
# via nltk
|
||||||
requests==2.32.5
|
requests==2.32.5
|
||||||
# via
|
# via
|
||||||
|
# huggingface-hub
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# langsmith
|
# langsmith
|
||||||
# requests-toolbelt
|
# requests-toolbelt
|
||||||
requests-toolbelt==1.0.0
|
requests-toolbelt==1.0.0
|
||||||
# via langsmith
|
# via langsmith
|
||||||
|
s3transfer==0.16.0
|
||||||
|
# via boto3
|
||||||
|
safetensors==0.7.0
|
||||||
|
# via accelerate
|
||||||
setuptools==82.0.0
|
setuptools==82.0.0
|
||||||
# via grpcio-tools
|
# via
|
||||||
simsimd==6.5.12
|
# grpcio-tools
|
||||||
|
# torch
|
||||||
|
simsimd==6.5.13
|
||||||
# via elasticsearch
|
# via elasticsearch
|
||||||
six==1.17.0
|
six==1.17.0
|
||||||
# via python-dateutil
|
# via python-dateutil
|
||||||
|
|
@ -180,22 +302,40 @@ sqlalchemy==2.0.46
|
||||||
# via
|
# via
|
||||||
# langchain-classic
|
# langchain-classic
|
||||||
# langchain-community
|
# langchain-community
|
||||||
|
sympy==1.14.0
|
||||||
|
# via torch
|
||||||
tenacity==9.1.4
|
tenacity==9.1.4
|
||||||
# via
|
# via
|
||||||
# langchain-community
|
# langchain-community
|
||||||
# langchain-core
|
# langchain-core
|
||||||
tqdm==4.67.3
|
tokenizers==0.22.2
|
||||||
|
# via langchain-huggingface
|
||||||
|
torch==2.10.0
|
||||||
|
# via
|
||||||
|
# accelerate
|
||||||
|
# assistance-engine
|
||||||
|
# torchvision
|
||||||
|
torchvision==0.25.0
|
||||||
# via assistance-engine
|
# via assistance-engine
|
||||||
|
tqdm==4.67.3
|
||||||
|
# via
|
||||||
|
# assistance-engine
|
||||||
|
# huggingface-hub
|
||||||
|
# nltk
|
||||||
|
triton==3.6.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
|
||||||
|
# via torch
|
||||||
typing-extensions==4.15.0
|
typing-extensions==4.15.0
|
||||||
# via
|
# via
|
||||||
# aiosignal
|
# aiosignal
|
||||||
# anyio
|
# anyio
|
||||||
# elasticsearch
|
# elasticsearch
|
||||||
# grpcio
|
# grpcio
|
||||||
|
# huggingface-hub
|
||||||
# langchain-core
|
# langchain-core
|
||||||
# pydantic
|
# pydantic
|
||||||
# pydantic-core
|
# pydantic-core
|
||||||
# sqlalchemy
|
# sqlalchemy
|
||||||
|
# torch
|
||||||
# typing-inspect
|
# typing-inspect
|
||||||
# typing-inspection
|
# typing-inspection
|
||||||
typing-inspect==0.9.0
|
typing-inspect==0.9.0
|
||||||
|
|
@ -208,9 +348,10 @@ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32'
|
||||||
# via pandas
|
# via pandas
|
||||||
urllib3==2.6.3
|
urllib3==2.6.3
|
||||||
# via
|
# via
|
||||||
|
# botocore
|
||||||
# elastic-transport
|
# elastic-transport
|
||||||
# requests
|
# requests
|
||||||
uuid-utils==0.14.0
|
uuid-utils==0.14.1
|
||||||
# via
|
# via
|
||||||
# langchain-core
|
# langchain-core
|
||||||
# langsmith
|
# langsmith
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbeddingFactory(ABC):
    """Interface implemented by every provider-specific embedding factory."""

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build an embedding model for ``model``.

        Args:
            model: Provider-specific model identifier.
            **kwargs: Extra options for the concrete implementation.

        Raises:
            NotImplementedError: When this base implementation is reached
                (for example through ``super().create(...)``).
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIEmbeddingFactory(BaseEmbeddingFactory):
    """Factory producing ``OpenAIEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OpenAIEmbeddings`` built with ``model`` and ``kwargs``.

        Args:
            model: Model identifier, passed as ``model=``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Raises:
            ModuleNotFoundError: If ``langchain_openai`` is not installed.
        """
        # The integration package is imported at call time, so a missing
        # install only surfaces when this provider is actually selected.
        # Re-raise the same exception type with an actionable message.
        try:
            from langchain_openai import OpenAIEmbeddings
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "The 'openai' embedding provider requires the "
                "'langchain-openai' package; install it or select a "
                "different provider.",
                name=err.name,
            ) from err

        return OpenAIEmbeddings(model=model, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaEmbeddingFactory(BaseEmbeddingFactory):
    """Factory producing ``OllamaEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OllamaEmbeddings`` built with ``model`` and ``kwargs``."""
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_ollama import OllamaEmbeddings as _Embeddings

        options = {"model": model, **kwargs}
        return _Embeddings(**options)
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockEmbeddingFactory(BaseEmbeddingFactory):
    """Factory producing ``BedrockEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``BedrockEmbeddings`` for ``model``.

        The identifier is passed under the ``model_id`` keyword; all other
        keyword arguments are forwarded unchanged.
        """
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_aws import BedrockEmbeddings as _Embeddings

        options = {"model_id": model, **kwargs}
        return _Embeddings(**options)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceEmbeddingFactory(BaseEmbeddingFactory):
    """Factory producing ``HuggingFaceEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``HuggingFaceEmbeddings`` for ``model``.

        The identifier is passed under the ``model_name`` keyword; all other
        keyword arguments are forwarded unchanged.
        """
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_huggingface import HuggingFaceEmbeddings as _Embeddings

        options = {"model_name": model, **kwargs}
        return _Embeddings(**options)
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping lowercase provider names to ready-made factory instances.
EMBEDDING_FACTORIES: Dict[str, BaseEmbeddingFactory] = dict(
    openai=OpenAIEmbeddingFactory(),
    ollama=OllamaEmbeddingFactory(),
    bedrock=BedrockEmbeddingFactory(),
    huggingface=HuggingFaceEmbeddingFactory(),
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_embedding_model(provider: str, model: str, **kwargs: Any):
    """
    Build an embedding model through the factory registered for ``provider``.

    The provider name is matched case-insensitively and surrounding
    whitespace is ignored.

    Args:
        provider: The provider name (openai, ollama, bedrock, huggingface).
        model: The model identifier.
        **kwargs: Additional keyword arguments passed to the factory.

    Returns:
        Whatever the selected factory's ``create`` returns.

    Raises:
        ValueError: If no factory is registered for ``provider``.
    """
    normalized = provider.strip().lower()
    factory = EMBEDDING_FACTORIES.get(normalized)

    # Guard clause: report the unknown provider together with the valid names.
    if factory is None:
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Available providers: {list(EMBEDDING_FACTORIES.keys())}"
        )

    return factory.create(model=model, **kwargs)
|
||||||
|
|
@ -0,0 +1,72 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProviderFactory(ABC):
    """Interface implemented by every provider-specific chat-model factory."""

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build a chat model for ``model``.

        Args:
            model: Provider-specific model identifier.
            **kwargs: Extra options for the concrete implementation.

        Raises:
            NotImplementedError: When this base implementation is reached
                (for example through ``super().create(...)``).
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIChatFactory(BaseProviderFactory):
    """Factory producing ``ChatOpenAI`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``ChatOpenAI`` built with ``model`` and ``kwargs``.

        Args:
            model: Model identifier, passed as ``model=``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Raises:
            ModuleNotFoundError: If ``langchain_openai`` is not installed.
        """
        # The integration package is imported at call time, so a missing
        # install only surfaces when this provider is actually selected.
        # Re-raise the same exception type with an actionable message.
        try:
            from langchain_openai import ChatOpenAI
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "The 'openai' chat provider requires the 'langchain-openai' "
                "package; install it or select a different provider.",
                name=err.name,
            ) from err

        return ChatOpenAI(model=model, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaChatFactory(BaseProviderFactory):
    """Factory producing ``ChatOllama`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``ChatOllama`` built with ``model`` and ``kwargs``."""
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_ollama import ChatOllama as _Chat

        options = {"model": model, **kwargs}
        return _Chat(**options)
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockChatFactory(BaseProviderFactory):
    """Factory producing ``ChatBedrockConverse`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``ChatBedrockConverse`` built with ``model`` and ``kwargs``."""
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_aws import ChatBedrockConverse as _Chat

        options = {"model": model, **kwargs}
        return _Chat(**options)
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceChatFactory(BaseProviderFactory):
    """Factory wrapping a Hugging Face text-generation pipeline as a chat model."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``ChatHuggingFace`` backed by a pipeline for ``model``.

        Args:
            model: Passed to ``HuggingFacePipeline.from_model_id`` as
                ``model_id``.
            **kwargs: Handed to the pipeline as ``pipeline_kwargs`` (as one
                dict, not spread into the constructor).
        """
        # Imported at call time so the package is only loaded when this
        # provider is used.
        from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

        pipeline = HuggingFacePipeline.from_model_id(
            model_id=model, task="text-generation", pipeline_kwargs=kwargs
        )
        return ChatHuggingFace(llm=pipeline)
|
||||||
|
|
||||||
|
|
||||||
|
# Registry mapping lowercase provider names to ready-made factory instances.
CHAT_FACTORIES: Dict[str, BaseProviderFactory] = dict(
    openai=OpenAIChatFactory(),
    ollama=OllamaChatFactory(),
    bedrock=BedrockChatFactory(),
    huggingface=HuggingFaceChatFactory(),
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_chat_model(provider: str, model: str, **kwargs: Any):
    """
    Build a chat model through the factory registered for ``provider``.

    The provider name is matched case-insensitively and surrounding
    whitespace is ignored.

    Args:
        provider: The provider name (openai, ollama, bedrock, huggingface).
        model: The model identifier.
        **kwargs: Additional keyword arguments passed to the factory.

    Returns:
        Whatever the selected factory's ``create`` returns.

    Raises:
        ValueError: If no factory is registered for ``provider``.
    """
    normalized = provider.strip().lower()
    factory = CHAT_FACTORIES.get(normalized)

    # Guard clause: report the unknown provider together with the valid names.
    if factory is None:
        raise ValueError(
            f"Unsupported chat provider: {provider}. "
            f"Available providers: {list(CHAT_FACTORIES.keys())}"
        )

    return factory.create(model=model, **kwargs)
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from concurrent import futures
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import brunix_pb2
|
||||||
|
import brunix_pb2_grpc
|
||||||
|
import grpc
|
||||||
|
from grpc_reflection.v1alpha import reflection
|
||||||
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_elasticsearch import ElasticsearchStore
|
||||||
|
|
||||||
|
# Directory two levels above the one containing this file. It is put at the
# front of the import path so the ``src.*`` imports used below can resolve.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
_project_root_entry = str(PROJECT_ROOT)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)

# Root logging at INFO; this module logs under the "brunix-engine" name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("brunix-engine")
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_kwargs(provider: str, base_url: str) -> dict[str, Any]:
|
||||||
|
if provider == "ollama":
|
||||||
|
return {"base_url": base_url}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
    """gRPC servicer that answers questions from retrieved documents.

    Documents are fetched from an Elasticsearch vector store and passed,
    together with the user's question, to the configured chat model.
    """

    def __init__(self):
        """Build the chat model, embedding model and vector store from env vars.

        Environment variables read:
            OLLAMA_LOCAL_URL / OLLAMA_URL: Base URL for Ollama-backed models
                (``OLLAMA_LOCAL_URL`` wins when both are set).
            CHAT_PROVIDER: Chat provider name; defaults to ``"ollama"``.
            EMBEDDING_PROVIDER: Embedding provider; defaults to the chat
                provider.
            OLLAMA_MODEL_NAME: Chat model identifier (required).
            OLLAMA_EMB_MODEL_NAME: Embedding model; defaults to the chat model.
            ELASTICSEARCH_URL: Elasticsearch endpoint.
            ELASTICSEARCH_INDEX: Index name handed to the vector store.

        Raises:
            ValueError: If ``OLLAMA_MODEL_NAME`` is unset or empty.
        """
        from src.emb_factory import create_embedding_model
        from src.llm_factory import create_chat_model

        # The compose file exports the endpoint as OLLAMA_URL, whereas this
        # service only read OLLAMA_LOCAL_URL, so the configured value was
        # never picked up. Accept both (legacy name first) and treat empty
        # values as unset.
        self.base_url = (
            os.getenv("OLLAMA_LOCAL_URL")
            or os.getenv("OLLAMA_URL")
            or "http://ollama-light-service:11434"
        )
        self.chat_provider = os.getenv("CHAT_PROVIDER", "ollama")
        self.embedding_provider = os.getenv("EMBEDDING_PROVIDER", self.chat_provider)
        self.chat_model_name = os.getenv("OLLAMA_MODEL_NAME")
        self.embedding_model_name = os.getenv(
            "OLLAMA_EMB_MODEL_NAME", self.chat_model_name
        )

        if not self.chat_model_name:
            raise ValueError("OLLAMA_MODEL_NAME is required")

        logger.info("Starting server")

        self.llm = create_chat_model(
            provider=self.chat_provider,
            model=self.chat_model_name,
            **_provider_kwargs(self.chat_provider, self.base_url),
        )

        self.embeddings = create_embedding_model(
            provider=self.embedding_provider,
            model=self.embedding_model_name,
            **_provider_kwargs(self.embedding_provider, self.base_url),
        )

        es_url = os.getenv("ELASTICSEARCH_URL", "http://elasticsearch:9200")
        logger.info("ElasticSearch on: %s", es_url)

        # NOTE(review): ELASTICSEARCH_INDEX has no default, so index_name is
        # None when the variable is unset - confirm the store tolerates that.
        self.vector_store = ElasticsearchStore(
            es_url=es_url,
            index_name=os.getenv("ELASTICSEARCH_INDEX"),
            embedding=self.embeddings,
            query_field="text",
            vector_query_field="embedding",
        )

    def format_context(self, docs) -> str:
        """Render retrieved documents as one numbered context string.

        Args:
            docs: Iterable of objects exposing ``metadata`` and
                ``page_content``.

        Returns:
            One section per document, a header line with source, doc_id and
            chunk_id followed by the content, joined by ``---`` separators.
            Missing metadata fields are shown as ``unknown``.
        """
        parts = []
        for i, d in enumerate(docs, start=1):
            meta = d.metadata or {}
            source = meta.get("source", "unknown")
            doc_id = meta.get("doc_id", "unknown")
            chunk_id = meta.get("chunk_id", "unknown")

            parts.append(
                f"[{i}] source={source} doc_id={doc_id} chunk_id={chunk_id}\n{d.page_content}"
            )
        return "\n\n---\n\n".join(parts)

    def AskAgent(self, request, context):
        """Answer ``request.query`` using retrieved context.

        Args:
            request: Message exposing ``session_id`` and ``query``.
            context: gRPC call context (unused here).

        Yields:
            An ``AgentResponse`` carrying the model's answer, then an empty
            ``AgentResponse``. If retrieval or generation fails, a single
            response whose text starts with ``[Error Motor]``.
        """
        logger.info("request (%s): %s.", request.session_id, request.query[:50])

        try:
            # Retrieval sits inside the try block so that a vector-store
            # failure is reported through the same error response as a
            # generation failure instead of escaping the handler.
            docs_and_scores = self.vector_store.similarity_search_with_score(
                request.query, k=4
            )
            context_text = self.format_context([doc for doc, _ in docs_and_scores])

            prompt = ChatPromptTemplate.from_template("""
You are a helpful assistant. Use the following retrieved documents to answer the question.
If you don't know the answer, say you don't know.

CONTEXT:
{context}

QUESTION:
{question}
""")

            chain = prompt | self.llm

            result = chain.invoke({"context": context_text, "question": request.query})
            # Chat models return a message object with ``content``; fall
            # back to ``str(result)`` for anything else.
            result_text = getattr(result, "content", str(result))

            # NOTE(review): both messages below are flagged is_final=True;
            # confirm whether the first one should be non-final.
            yield brunix_pb2.AgentResponse(
                text=str(result_text), avap_code="AVAP-2026", is_final=True
            )

            yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True)

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("Error in AskAgent: %s", e)
            yield brunix_pb2.AgentResponse(
                text=f"[Error Motor]: {str(e)}", is_final=True
            )
|
||||||
|
|
||||||
|
|
||||||
|
def serve():
    """Start the gRPC server on port 50051 and block until it terminates.

    Registers a ``BrunixEngine`` servicer and enables server reflection for
    the ``AssistanceEngine`` service and the reflection service itself.
    """
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    grpc_server = grpc.server(worker_pool)

    engine = BrunixEngine()
    brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(engine, grpc_server)

    # Names advertised through reflection: the engine service plus the
    # reflection service.
    engine_service = brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"]
    reflected_names = (engine_service.full_name, reflection.SERVICE_NAME)
    reflection.enable_server_reflection(reflected_names, grpc_server)

    grpc_server.add_insecure_port("[::]:50051")
    logger.info("Brunix Engine on port 50051")
    grpc_server.start()
    grpc_server.wait_for_termination()


if __name__ == "__main__":
    serve()
|
||||||
|
|
@ -1,36 +1,63 @@
|
||||||
import os
|
|
||||||
import grpc
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
from grpc_reflection.v1alpha import reflection
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import brunix_pb2
|
import brunix_pb2
|
||||||
import brunix_pb2_grpc
|
import brunix_pb2_grpc
|
||||||
|
import grpc
|
||||||
from langchain_community.llms import Ollama
|
from grpc_reflection.v1alpha import reflection
|
||||||
from langchain_community.embeddings import OllamaEmbeddings
|
|
||||||
from langchain_elasticsearch import ElasticsearchStore
|
|
||||||
from langchain_core.prompts import ChatPromptTemplate
|
from langchain_core.prompts import ChatPromptTemplate
|
||||||
|
from langchain_elasticsearch import ElasticsearchStore
|
||||||
|
|
||||||
|
# PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
# if str(PROJECT_ROOT) not in sys.path:
|
||||||
|
# sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger("brunix-engine")
|
logger = logging.getLogger("brunix-engine")
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_kwargs(provider: str, base_url: str) -> dict[str, Any]:
|
||||||
|
if provider == "ollama":
|
||||||
|
return {"base_url": base_url}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
from emb_factory import create_embedding_model
|
||||||
|
from llm_factory import create_chat_model
|
||||||
|
|
||||||
self.base_url = os.getenv("LLM_BASE_URL", "http://ollama-light-service:11434")
|
self.base_url = os.getenv("OLLAMA_LOCAL_URL", "http://ollama-light-service:11434")
|
||||||
self.model_name = os.getenv("OLLAMA_MODEL_NAME")
|
self.chat_provider = os.getenv("CHAT_PROVIDER", "ollama")
|
||||||
|
self.embedding_provider = os.getenv("EMBEDDING_PROVIDER", self.chat_provider)
|
||||||
|
self.chat_model_name = os.getenv("OLLAMA_MODEL_NAME")
|
||||||
|
self.embedding_model_name = os.getenv(
|
||||||
|
"OLLAMA_EMB_MODEL_NAME", self.chat_model_name
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"Starting server")
|
if not self.chat_model_name:
|
||||||
|
raise ValueError("OLLAMA_MODEL_NAME is required")
|
||||||
|
|
||||||
self.llm = Ollama(base_url=self.base_url, model=self.model_name)
|
logger.info("Starting server")
|
||||||
|
|
||||||
self.embeddings = OllamaEmbeddings(
|
self.llm = create_chat_model(
|
||||||
base_url=self.base_url, model=self.model_name
|
provider=self.chat_provider,
|
||||||
|
model=self.chat_model_name,
|
||||||
|
**_provider_kwargs(self.chat_provider, self.base_url),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.embeddings = create_embedding_model(
|
||||||
|
provider=self.embedding_provider,
|
||||||
|
model=self.embedding_model_name,
|
||||||
|
**_provider_kwargs(self.embedding_provider, self.base_url),
|
||||||
)
|
)
|
||||||
|
|
||||||
es_url = os.getenv("ELASTICSEARCH_URL", "http://elasticsearch:9200")
|
es_url = os.getenv("ELASTICSEARCH_URL", "http://elasticsearch:9200")
|
||||||
logger.info(f"ElasticSearch on: {es_url}")
|
logger.info("ElasticSearch on: %s", es_url)
|
||||||
|
|
||||||
self.vector_store = ElasticsearchStore(
|
self.vector_store = ElasticsearchStore(
|
||||||
es_url=es_url,
|
es_url=es_url,
|
||||||
|
|
@ -77,8 +104,9 @@ class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer):
|
||||||
chain = prompt | self.llm
|
chain = prompt | self.llm
|
||||||
|
|
||||||
result = chain.invoke({"context": context_text, "question": request.query})
|
result = chain.invoke({"context": context_text, "question": request.query})
|
||||||
|
result_text = getattr(result, "content", str(result))
|
||||||
yield brunix_pb2.AgentResponse(
|
yield brunix_pb2.AgentResponse(
|
||||||
text=str(result), avap_code="AVAP-2026", is_final=True
|
text=str(result_text), avap_code="AVAP-2026", is_final=True
|
||||||
)
|
)
|
||||||
|
|
||||||
yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True)
|
yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True)
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,9 @@ if [ ! -f "$KUBECONFIG_PATH" ]; then
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# 1. AI Model Tunnel (Ollama)
|
# 1. AI Model Tunnel (Ollama)
|
||||||
echo -e "${YELLOW}[1/3]${NC} Starting Ollama Light Service tunnel (localhost:11434)..."
|
# echo -e "${YELLOW}[1/3]${NC} Starting Ollama Light Service tunnel (localhost:11434)..."
|
||||||
kubectl port-forward --address 0.0.0.0 svc/ollama-light-service 11434:11434 -n brunix --kubeconfig "$KUBECONFIG_PATH" &
|
# kubectl port-forward --address 0.0.0.0 svc/ollama-light-service 11434:11434 -n brunix --kubeconfig "$KUBECONFIG_PATH" &
|
||||||
OLLAMA_PID=$!
|
# OLLAMA_PID=$!
|
||||||
|
|
||||||
# 2. Knowledge Base Tunnel (Elasticsearch)
|
# 2. Knowledge Base Tunnel (Elasticsearch)
|
||||||
echo -e "${YELLOW}[2/3]${NC} Starting Elasticsearch Vector DB tunnel (localhost:9200)..."
|
echo -e "${YELLOW}[2/3]${NC} Starting Elasticsearch Vector DB tunnel (localhost:9200)..."
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue