From 6c8261f2b2b492ec8c40fdd6fe6b62585f3277fb Mon Sep 17 00:00:00 2001 From: acano Date: Tue, 3 Mar 2026 17:56:39 +0100 Subject: [PATCH] feat: implement graph-based agent with LLM integration and dynamic prompts --- Docker/.dockerignore | 53 +++++++++++++++ Docker/Dockerfile | 1 - Docker/docker-compose.yaml | 17 ++--- Docker/requirements.txt | 93 ++++++++++++++++++++++----- Docker/src/graph.py | 60 +++++++++++++++++ Docker/src/prompts.py | 89 ++++++++++++++++++++++++++ Docker/src/server.py | 110 +++++++++++++++++--------------- Docker/src/state.py | 9 +++ Docker/src/utils/emb_factory.py | 67 +++++++++++++++++++ Docker/src/utils/llm_factory.py | 72 +++++++++++++++++++++ 10 files changed, 493 insertions(+), 78 deletions(-) create mode 100644 Docker/.dockerignore create mode 100644 Docker/src/graph.py create mode 100644 Docker/src/prompts.py create mode 100644 Docker/src/state.py create mode 100644 Docker/src/utils/emb_factory.py create mode 100644 Docker/src/utils/llm_factory.py diff --git a/Docker/.dockerignore b/Docker/.dockerignore new file mode 100644 index 0000000..b7acc73 --- /dev/null +++ b/Docker/.dockerignore @@ -0,0 +1,53 @@ +# Documentation +*.md +documentation/ + +# Build and dependency files +Makefile +*.pyc +__pycache__/ +*.egg-info/ +dist/ +build/ + +# Development and testing +.venv/ +venv/ +env/ +.pytest_cache/ +.coverage + +# Git and version control +.git/ +.gitignore +.gitattributes + +# IDE and editor files +.vscode/ +.idea/ +*.swp +*.swo +*~ +.DS_Store + +# Environment files +.env +.env.local +.env.*.local + +# Docker files (no copy Docker files into the image) +Dockerfile +docker-compose.yaml + +# CI/CD +.github/ +.gitlab-ci.yml + +# Temporary files +*.tmp +*.log +scratches/ + +# Node modules (if any) +node_modules/ +npm-debug.log diff --git a/Docker/Dockerfile b/Docker/Dockerfile index ade55a8..240cf50 100644 --- a/Docker/Dockerfile +++ b/Docker/Dockerfile @@ -3,7 +3,6 @@ FROM python:3.11-slim ENV PYTHONDONTWRITEBYTECODE=1 ENV PYTHONUNBUFFERED=1 - WORKDIR /app COPY ./requirements.txt . diff --git a/Docker/docker-compose.yaml b/Docker/docker-compose.yaml index 53e6cd1..a1df028 100644 --- a/Docker/docker-compose.yaml +++ b/Docker/docker-compose.yaml @@ -4,17 +4,18 @@ services: brunix-engine: build: . container_name: brunix-assistance-engine - env_file: .env ports: - "50052:50051" environment: - - ELASTICSEARCH_URL=http://host.docker.internal:9200 - - DATABASE_URL=postgresql://postgres:brunix_pass@host.docker.internal:5432/postgres - - - LANGFUSE_HOST=http://45.77.119.180 - - LANGFUSE_PUBLIC_KEY=${LANGFUSE_PUBLIC_KEY} - - LANGFUSE_SECRET_KEY=${LANGFUSE_SECRET_KEY} - - LLM_BASE_URL=http://host.docker.internal:11434 + ELASTICSEARCH_URL: ${ELASTICSEARCH_URL} + ELASTICSEARCH_INDEX: ${ELASTICSEARCH_INDEX} + POSTGRES_URL: ${POSTGRES_URL} + LANGFUSE_HOST: ${LANGFUSE_HOST} + LANGFUSE_PUBLIC_KEY: ${LANGFUSE_PUBLIC_KEY} + LANGFUSE_SECRET_KEY: ${LANGFUSE_SECRET_KEY} + OLLAMA_URL: ${OLLAMA_URL} + OLLAMA_MODEL_NAME: ${OLLAMA_MODEL_NAME} + OLLAMA_EMB_MODEL_NAME: ${OLLAMA_EMB_MODEL_NAME} extra_hosts: - "host.docker.internal:host-gateway" diff --git a/Docker/requirements.txt b/Docker/requirements.txt index 872b430..cbbd2bf 100644 --- a/Docker/requirements.txt +++ b/Docker/requirements.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv export --format requirements-txt --no-hashes --no-dev -o requirements.txt +# uv export --format requirements-txt --no-hashes --no-dev -o Docker/requirements.txt aiohappyeyeballs==2.6.1 # via aiohttp aiohttp==3.13.3 @@ -12,6 +12,12 @@ anyio==4.12.1 # via httpx attrs==25.4.0 # via aiohttp +boto3==1.42.58 + # via langchain-aws +botocore==1.42.58 + # via + # boto3 + # s3transfer certifi==2026.1.4 # via # elastic-transport @@ -20,8 +26,11 @@ certifi==2026.1.4 # requests charset-normalizer==3.4.4 # via requests +click==8.3.1 + # via nltk colorama==0.4.6 ; sys_platform == 'win32' # via + # click # loguru # tqdm dataclasses-json==0.6.7 @@ -30,72 +39,98 @@ elastic-transport==8.17.1 # via elasticsearch elasticsearch==8.19.3 # via langchain-elasticsearch +filelock==3.24.3 + # via huggingface-hub frozenlist==1.8.0 # via # aiohttp # aiosignal -greenlet==3.3.1 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64' +fsspec==2025.10.0 + # via huggingface-hub +greenlet==3.3.2 ; platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64' # via sqlalchemy -grpcio==1.78.0 +grpcio==1.78.1 # via # assistance-engine # grpcio-reflection # grpcio-tools -grpcio-reflection==1.78.0 +grpcio-reflection==1.78.1 # via assistance-engine -grpcio-tools==1.78.0 +grpcio-tools==1.78.1 # via assistance-engine h11==0.16.0 # via httpcore +hf-xet==1.3.0 ; platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64' + # via huggingface-hub httpcore==1.0.9 # via httpx httpx==0.28.1 # via # langgraph-sdk # langsmith + # ollama httpx-sse==0.4.3 # via langchain-community +huggingface-hub==0.36.2 + # via + # langchain-huggingface + # tokenizers idna==3.11 # via # anyio # httpx # requests # yarl +jmespath==1.1.0 + # via + # boto3 + # botocore +joblib==1.5.3 + # via nltk jsonpatch==1.33 # via langchain-core jsonpointer==3.0.0 # via jsonpatch langchain==1.2.10 # via assistance-engine +langchain-aws==1.3.1 + # via assistance-engine langchain-classic==1.0.1 # via langchain-community langchain-community==0.4.1 # via assistance-engine -langchain-core==1.2.11 +langchain-core==1.2.15 # via # langchain + # langchain-aws # langchain-classic # langchain-community # langchain-elasticsearch + # langchain-huggingface + # langchain-ollama # langchain-text-splitters # langgraph # langgraph-checkpoint # langgraph-prebuilt langchain-elasticsearch==1.0.0 # via assistance-engine -langchain-text-splitters==1.1.0 +langchain-huggingface==1.2.0 + # via assistance-engine +langchain-ollama==1.0.1 + # via assistance-engine +langchain-text-splitters==1.1.1 # via langchain-classic -langgraph==1.0.8 +langgraph==1.0.9 # via langchain langgraph-checkpoint==4.0.0 # via # langgraph # langgraph-prebuilt -langgraph-prebuilt==1.0.7 +langgraph-prebuilt==1.0.8 # via langgraph -langgraph-sdk==0.3.5 +langgraph-sdk==0.3.8 # via langgraph -langsmith==0.7.1 +langsmith==0.7.6 # via # langchain-classic # langchain-community @@ -110,24 +145,30 @@ multidict==6.7.1 # yarl mypy-extensions==1.1.0 # via typing-inspect +nltk==3.9.3 + # via assistance-engine numpy==2.4.2 # via # assistance-engine # elasticsearch + # langchain-aws # langchain-community # pandas +ollama==0.6.1 + # via langchain-ollama orjson==3.11.7 # via # langgraph-sdk # langsmith ormsgpack==1.12.2 # via langgraph-checkpoint -packaging==26.0 +packaging==24.2 # via + # huggingface-hub # langchain-core # langsmith # marshmallow -pandas==3.0.0 +pandas==3.0.1 # via assistance-engine propcache==0.4.1 # via @@ -140,17 +181,20 @@ protobuf==6.33.5 pydantic==2.12.5 # via # langchain + # langchain-aws # langchain-classic # langchain-core # langgraph # langsmith + # ollama # pydantic-settings pydantic-core==2.41.5 # via pydantic -pydantic-settings==2.12.0 +pydantic-settings==2.13.1 # via langchain-community python-dateutil==2.9.0.post0 # via + # botocore # elasticsearch # pandas python-dotenv==1.2.1 @@ -159,20 +203,28 @@ python-dotenv==1.2.1 # pydantic-settings pyyaml==6.0.3 # via + # huggingface-hub # langchain-classic # langchain-community # langchain-core +rapidfuzz==3.14.3 + # via assistance-engine +regex==2026.2.19 + # via nltk requests==2.32.5 # via + # huggingface-hub # langchain-classic # langchain-community # langsmith # requests-toolbelt requests-toolbelt==1.0.0 # via langsmith +s3transfer==0.16.0 + # via boto3 setuptools==82.0.0 # via grpcio-tools -simsimd==6.5.12 +simsimd==6.5.13 # via elasticsearch six==1.17.0 # via python-dateutil @@ -184,14 +236,20 @@ tenacity==9.1.4 # via # langchain-community # langchain-core +tokenizers==0.22.2 + # via langchain-huggingface tqdm==4.67.3 - # via assistance-engine + # via + # assistance-engine + # huggingface-hub + # nltk typing-extensions==4.15.0 # via # aiosignal # anyio # elasticsearch # grpcio + # huggingface-hub # langchain-core # pydantic # pydantic-core @@ -208,9 +266,10 @@ tzdata==2025.3 ; sys_platform == 'emscripten' or sys_platform == 'win32' # via pandas urllib3==2.6.3 # via + # botocore # elastic-transport # requests -uuid-utils==0.14.0 +uuid-utils==0.14.1 # via # langchain-core # langsmith diff --git a/Docker/src/graph.py b/Docker/src/graph.py new file mode 100644 index 0000000..0ee0cf5 --- /dev/null +++ b/Docker/src/graph.py @@ -0,0 +1,60 @@ +# graph.py +import logging + +from langchain_core.documents import Document +from langchain_core.messages import SystemMessage +from langgraph.graph import END, StateGraph +from langgraph.graph.state import CompiledStateGraph +from prompts import GENERATE_PROMPT, REFORMULATE_PROMPT +from state import AgentState + +logger = logging.getLogger(__name__) + + +def format_context(docs: list[Document]) -> str: + chunks = [] + for i, doc in enumerate(docs, 1): + source = (doc.metadata or {}).get("source", "Untitled") + source_id = (doc.metadata or {}).get("id", f"chunk-{i}") + text = doc.page_content or "" + chunks.append(f"[{i}] id={source_id} source={source}\n{text}") + return "\n\n".join(chunks) + + +def build_graph(llm, vector_store) -> CompiledStateGraph: + def reformulate(state: AgentState) -> AgentState: + user_msg = state["messages"][-1] + resp = llm.invoke([REFORMULATE_PROMPT, user_msg]) + reformulated = resp.content.strip() + logger.info(f"[reformulate] '{user_msg.content}' → '{reformulated}'") + return {"reformulated_query": reformulated} + + def retrieve(state: AgentState) -> AgentState: + query = state["reformulated_query"] + docs = vector_store.as_retriever( + search_type="similarity", + search_kwargs={"k": 3}, + ).invoke(query) + context = format_context(docs) + logger.info(f"[retrieve] {len(docs)} docs fetched") + logger.info(context) + return {"context": context} + + def generate(state: AgentState) -> AgentState: + prompt = SystemMessage( + content=GENERATE_PROMPT.content.format(context=state["context"]) + ) + resp = llm.invoke([prompt] + state["messages"]) + return {"messages": [resp]} + + graph_builder = StateGraph(AgentState) + graph_builder.add_node("reformulate", reformulate) + graph_builder.add_node("retrieve", retrieve) + graph_builder.add_node("generate", generate) + + graph_builder.set_entry_point("reformulate") + graph_builder.add_edge("reformulate", "retrieve") + graph_builder.add_edge("retrieve", "generate") + graph_builder.add_edge("generate", END) + + return graph_builder.compile() diff --git a/Docker/src/prompts.py b/Docker/src/prompts.py new file mode 100644 index 0000000..f8938d8 --- /dev/null +++ b/Docker/src/prompts.py @@ -0,0 +1,89 @@ +from langchain_core.messages import SystemMessage + +REFORMULATE_PROMPT = SystemMessage( + content=( + "You are a deterministic lexical query rewriter used for vector retrieval.\n" + "Your task is to rewrite user questions into optimized keyword search queries.\n\n" + + "CRITICAL RULES (ABSOLUTE):\n" + "1. NEVER answer the question.\n" + "2. NEVER expand acronyms.\n" + "3. NEVER introduce new terms not present in the original query.\n" + "4. NEVER infer missing information.\n" + "5. NEVER add explanations, definitions, or interpretations.\n" + "6. Preserve all technical tokens exactly as written.\n" + "7. Only remove filler words (e.g., what, does, is, explain, tell me, please).\n" + "8. You may reorder terms for better retrieval.\n" + "9. Output must be a single-line plain keyword query.\n" + "10. If the query is already optimal, return it unchanged.\n\n" + "11. If you receive something that looks like code, do NOT attempt to rewrite it. Return it verbatim.\n\n" + + "ALLOWED OPERATIONS:\n" + "- Remove interrogative phrasing.\n" + "- Remove stopwords.\n" + "- Reorder words.\n" + "- Convert to noun phrase form.\n\n" + + "FORBIDDEN OPERATIONS:\n" + "- Expanding abbreviations.\n" + "- Paraphrasing into unseen vocabulary.\n" + "- Adding definitions.\n" + "- Answering implicitly.\n\n" + + "Examples:\n" + "Input: What does AVAP stand for?\n" + "Output: AVAP stand for\n" + + "Input: Hey, I'm trying to understand how AVAP handels a ZeroDivisionError when doing divison or modulus operatoins. Can you explane what situatoins cause a ZeroDivisionError to be raised and how I can catch it in my AVAP scripts?\n" + "Output: AVAP ZeroDivisionError division / modulus % catch try except\n" + + "Input: What does AVAP stand for?\n" + "Output: AVAP stand for\n" + + "Input: Please explain how the import statement works in AVAP scripts.\n" + "Output: AVAP import statement syntax behavior\n\n" + + "Return only the rewritten query." + ) +) + +GENERATE_PROMPT = SystemMessage( + content=( + """You are a retrieval-grounded assistant. + + CRITICAL SCOPE RULES + - AVAP (Advanced Virtual API Programming) is a new/unknown programming language. + - If the user message is AVAP-related (explicitly mentions "AVAP" case-insensitive OR the provided context is AVAP code/docs), you MUST answer using ONLY . + - You MUST NOT use general programming knowledge to interpret AVAP. + - You MUST NOT assume AVAP behaves like any other language even if syntax looks similar. + - You MUST NOT infer missing details. Only state what is explicitly present in . + + GROUNDING REQUIREMENTS (AVAP-RELATED) + 1) Every non-trivial factual claim MUST be directly supported by an EXACT QUOTE from . + 2) If a claim is not supported by a quote, DO NOT include it. + 3) If does not contain enough information to answer, reply with EXACTLY: + "I don't have enough information in the provided context to answer that." + + WORKFLOW (AVAP-RELATED) — FOLLOW IN ORDER + A) Identify the specific question(s) being asked. + B) Extract the minimum necessary quotes from that answer those question(s). + C) Write the answer using ONLY those quotes (paraphrase is allowed, but every statement must be backed by at least one quote). + D) Verify: for EACH sentence in your answer, confirm there is a supporting quote. If any sentence lacks a quote, delete it or refuse. + + OUTPUT FORMAT (AVAP-RELATED ONLY) + Answer: + + + Evidence: + - "" + - "" + (Include only quotes you actually used. Prefer the smallest quotes that fully support the statements.) + + NON-AVAP QUESTIONS + - If the question is clearly not AVAP-related, answer normally using general knowledge. + + + {context} + """ + ) +) \ No newline at end of file diff --git a/Docker/src/server.py b/Docker/src/server.py index 52452a1..128a582 100644 --- a/Docker/src/server.py +++ b/Docker/src/server.py @@ -1,87 +1,93 @@ -import os -import grpc import logging +import os from concurrent import futures -from grpc_reflection.v1alpha import reflection + import brunix_pb2 import brunix_pb2_grpc - -from langchain_community.llms import Ollama -from langchain_community.embeddings import OllamaEmbeddings +import grpc +from grpc_reflection.v1alpha import reflection from langchain_elasticsearch import ElasticsearchStore -from langchain_core.prompts import ChatPromptTemplate + +from utils.llm_factory import create_chat_model +from utils.emb_factory import create_embedding_model +from graph import build_graph + logging.basicConfig(level=logging.INFO) logger = logging.getLogger("brunix-engine") class BrunixEngine(brunix_pb2_grpc.AssistanceEngineServicer): def __init__(self): - - self.base_url = os.getenv("LLM_BASE_URL", "http://ollama-light-service:11434") - self.model_name = os.getenv("LLM_MODEL", "qwen2.5:1.5b") - - logger.info(f"Starting server") - - self.llm = Ollama(base_url=self.base_url, model=self.model_name) - - self.embeddings = OllamaEmbeddings(base_url=self.base_url, model="nomic-embed-text") - - es_url = os.getenv("ELASTICSEARCH_URL", "http://elasticsearch:9200") - logger.info(f"ElasticSearch on: {es_url}") - - self.vector_store = ElasticsearchStore( - es_url=es_url, - index_name="avap_manuals", - embedding=self.embeddings + self.llm = create_chat_model( + provider="ollama", + model=os.getenv("OLLAMA_MODEL_NAME"), + base_url=os.getenv("OLLAMA_URL"), + temperature=0, + validate_model_on_init=True, ) + self.embeddings = create_embedding_model( + provider="ollama", + model=os.getenv("OLLAMA_EMB_MODEL_NAME"), + base_url=os.getenv("OLLAMA_URL"), + ) + self.vector_store = ElasticsearchStore( + es_url=os.getenv("ELASTICSEARCH_URL"), + index_name=os.getenv("ELASTICSEARCH_INDEX"), + embedding=self.embeddings, + query_field="text", + vector_query_field="vector", + ) + self.graph = build_graph( + llm=self.llm, + vector_store=self.vector_store + ) + logger.info("Brunix Engine initializing.") + def AskAgent(self, request, context): logger.info(f"request {request.session_id}): {request.query[:50]}.") try: - context_text = "AVAP is a virtual programming language for API development." - # 4. Prompt Engineering - prompt = ChatPromptTemplate.from_template(""" - You are Brunix, the 101OBEX artificial intelligence for the AVAP Sphere platform. Respond in a professional manner. - - CONTEXT: - {context} + final_state = self.graph.invoke({"messages": [{"role": "user", + "content": request.query}]}) - QUESTION: - {question} - """) + messages = final_state.get("messages", []) + last_msg = messages[-1] if messages else None + result_text = getattr(last_msg, "content", str(last_msg)) if last_msg else "" + + yield brunix_pb2.AgentResponse( + text=result_text, + avap_code="AVAP-2026", + is_final=True, + ) - chain = prompt | self.llm - - for chunk in chain.stream({"context": context_text, "question": request.query}): - yield brunix_pb2.AgentResponse( - text=str(chunk), - avap_code="AVAP-2026", - is_final=False - ) - yield brunix_pb2.AgentResponse(text="", avap_code="", is_final=True) except Exception as e: - logger.error(f"Error in AskAgent: {str(e)}") - yield brunix_pb2.AgentResponse(text=f"[Error Motor]: {str(e)}", is_final=True) + logger.error(f"Error in AskAgent: {str(e)}", exc_info=True) + yield brunix_pb2.AgentResponse( + text=f"[Error Motor]: {str(e)}", + is_final=True, + ) + def serve(): - + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - + brunix_pb2_grpc.add_AssistanceEngineServicer_to_server(BrunixEngine(), server) SERVICE_NAMES = ( - brunix_pb2.DESCRIPTOR.services_by_name['AssistanceEngine'].full_name, + brunix_pb2.DESCRIPTOR.services_by_name["AssistanceEngine"].full_name, reflection.SERVICE_NAME, ) reflection.enable_server_reflection(SERVICE_NAMES, server) - - server.add_insecure_port('[::]:50051') + + server.add_insecure_port("[::]:50051") logger.info("Brunix Engine on port 50051") server.start() server.wait_for_termination() -if __name__ == '__main__': - serve() \ No newline at end of file + +if __name__ == "__main__": + serve() diff --git a/Docker/src/state.py b/Docker/src/state.py new file mode 100644 index 0000000..2e04d99 --- /dev/null +++ b/Docker/src/state.py @@ -0,0 +1,9 @@ +from typing import TypedDict, Annotated + +from langgraph.graph.message import add_messages + + +class AgentState(TypedDict): + messages: Annotated[list, add_messages] + reformulated_query: str + context: str \ No newline at end of file diff --git a/Docker/src/utils/emb_factory.py b/Docker/src/utils/emb_factory.py new file mode 100644 index 0000000..d9fb9de --- /dev/null +++ b/Docker/src/utils/emb_factory.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class BaseEmbeddingFactory(ABC): + @abstractmethod + def create(self, model: str, **kwargs: Any): + raise NotImplementedError + + +class OpenAIEmbeddingFactory(BaseEmbeddingFactory): + def create(self, model: str, **kwargs: Any): + from langchain_openai import OpenAIEmbeddings + + return OpenAIEmbeddings(model=model, **kwargs) + + +class OllamaEmbeddingFactory(BaseEmbeddingFactory): + def create(self, model: str, **kwargs: Any): + from langchain_ollama import OllamaEmbeddings + + return OllamaEmbeddings(model=model, **kwargs) + + +class BedrockEmbeddingFactory(BaseEmbeddingFactory): + def create(self, model: str, **kwargs: Any): + from langchain_aws import BedrockEmbeddings + + return BedrockEmbeddings(model_id=model, **kwargs) + + +class HuggingFaceEmbeddingFactory(BaseEmbeddingFactory): + def create(self, model: str, **kwargs: Any): + from langchain_huggingface import HuggingFaceEmbeddings + + return HuggingFaceEmbeddings(model_name=model, **kwargs) + + +EMBEDDING_FACTORIES: Dict[str, BaseEmbeddingFactory] = { + "openai": OpenAIEmbeddingFactory(), + "ollama": OllamaEmbeddingFactory(), + "bedrock": BedrockEmbeddingFactory(), + "huggingface": HuggingFaceEmbeddingFactory(), +} + + +def create_embedding_model(provider: str, model: str, **kwargs: Any): + """ + Create an embedding model instance for the given provider. + + Args: + provider: The provider name (openai, ollama, bedrock, huggingface). + model: The model identifier. + **kwargs: Additional keyword arguments passed to the model constructor. + + Returns: + An embedding model instance. + """ + key = provider.strip().lower() + + if key not in EMBEDDING_FACTORIES: + raise ValueError( + f"Unsupported embedding provider: {provider}. " + f"Available providers: {list(EMBEDDING_FACTORIES.keys())}" + ) + + return EMBEDDING_FACTORIES[key].create(model=model, **kwargs) diff --git a/Docker/src/utils/llm_factory.py b/Docker/src/utils/llm_factory.py new file mode 100644 index 0000000..8b1c13c --- /dev/null +++ b/Docker/src/utils/llm_factory.py @@ -0,0 +1,72 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class BaseProviderFactory(ABC): + @abstractmethod + def create(self, model: str, **kwargs: Any): + raise NotImplementedError + + +class OpenAIChatFactory(BaseProviderFactory): + def create(self, model: str, **kwargs: Any): + from langchain_openai import ChatOpenAI + + return ChatOpenAI(model=model, **kwargs) + + +class OllamaChatFactory(BaseProviderFactory): + def create(self, model: str, **kwargs: Any): + from langchain_ollama import ChatOllama + + return ChatOllama(model=model, **kwargs) + + +class BedrockChatFactory(BaseProviderFactory): + def create(self, model: str, **kwargs: Any): + from langchain_aws import ChatBedrockConverse + + return ChatBedrockConverse(model=model, **kwargs) + + +class HuggingFaceChatFactory(BaseProviderFactory): + def create(self, model: str, **kwargs: Any): + from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline + + llm = HuggingFacePipeline.from_model_id( + model_id=model, + task="text-generation", + pipeline_kwargs=kwargs, + ) + return ChatHuggingFace(llm=llm) + + +CHAT_FACTORIES: Dict[str, BaseProviderFactory] = { + "openai": OpenAIChatFactory(), + "ollama": OllamaChatFactory(), + "bedrock": BedrockChatFactory(), + "huggingface": HuggingFaceChatFactory(), +} + + +def create_chat_model(provider: str, model: str, **kwargs: Any): + """ + Create a chat model instance for the given provider. + + Args: + provider: The provider name (openai, ollama, bedrock, huggingface). + model: The model identifier. + **kwargs: Additional keyword arguments passed to the model constructor. + + Returns: + A chat model instance. + """ + key = provider.strip().lower() + + if key not in CHAT_FACTORIES: + raise ValueError( + f"Unsupported chat provider: {provider}. " + f"Available providers: {list(CHAT_FACTORIES.keys())}" + ) + + return CHAT_FACTORIES[key].create(model=model, **kwargs)