assistance-engine/Docker/src/utils/llm_factory.py

79 lines
2.3 KiB
Python

from abc import ABC, abstractmethod
from typing import Any, Dict
class BaseProviderFactory(ABC):
    """Abstract interface for an object that builds chat models for one provider.

    Concrete subclasses must override :meth:`create`. Because ``create`` is
    declared abstract, this base class itself cannot be instantiated.
    """

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build and return a chat model for the identifier *model*.

        Args:
            model: Provider-specific model identifier.
            **kwargs: Extra options; how they are used is up to the subclass.

        Raises:
            NotImplementedError: If this base implementation is reached,
                e.g. through ``super().create(...)``.
        """
        raise NotImplementedError
class OpenAIChatFactory(BaseProviderFactory):
    """Factory producing ``langchain_openai.ChatOpenAI`` chat models."""

    def create(self, model: str, **kwargs: Any):
        """Return ``ChatOpenAI(model=model, **kwargs)``.

        The import happens inside this method, so ``langchain_openai`` is only
        needed when this provider is actually requested.
        """
        from langchain_openai import ChatOpenAI

        constructor_args = {"model": model, **kwargs}
        return ChatOpenAI(**constructor_args)
class AnthropicChatFactory(BaseProviderFactory):
    """Factory producing ``langchain_anthropic.ChatAnthropic`` chat models."""

    def create(self, model: str, **kwargs: Any):
        """Return ``ChatAnthropic(model=model, **kwargs)``.

        The import happens inside this method, so ``langchain_anthropic`` is
        only needed when this provider is actually requested.
        """
        from langchain_anthropic import ChatAnthropic

        constructor_args = {"model": model, **kwargs}
        return ChatAnthropic(**constructor_args)
class OllamaChatFactory(BaseProviderFactory):
    """Factory producing ``langchain_ollama.ChatOllama`` chat models."""

    def create(self, model: str, **kwargs: Any):
        """Return ``ChatOllama(model=model, **kwargs)``.

        The import happens inside this method, so ``langchain_ollama`` is only
        needed when this provider is actually requested.
        """
        from langchain_ollama import ChatOllama

        constructor_args = {"model": model, **kwargs}
        return ChatOllama(**constructor_args)
class BedrockChatFactory(BaseProviderFactory):
    """Factory producing ``langchain_aws.ChatBedrockConverse`` chat models."""

    def create(self, model: str, **kwargs: Any):
        """Return ``ChatBedrockConverse(model=model, **kwargs)``.

        The import happens inside this method, so ``langchain_aws`` is only
        needed when this provider is actually requested.
        """
        from langchain_aws import ChatBedrockConverse

        constructor_args = {"model": model, **kwargs}
        return ChatBedrockConverse(**constructor_args)
class HuggingFaceChatFactory(BaseProviderFactory):
    """Factory wrapping a ``HuggingFacePipeline`` in ``ChatHuggingFace``.

    Unlike the other provider factories, extra keyword arguments are NOT
    passed to the chat model constructor. They are forwarded as
    ``pipeline_kwargs`` to ``HuggingFacePipeline.from_model_id``.
    """

    def create(self, model: str, **kwargs: Any):
        """Build a text-generation pipeline for *model* and wrap it for chat.

        Args:
            model: Identifier passed as ``model_id`` to
                ``HuggingFacePipeline.from_model_id``.
            **kwargs: Forwarded unchanged as ``pipeline_kwargs``.

        Returns:
            A ``ChatHuggingFace`` whose ``llm`` is the constructed pipeline.
        """
        from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

        pipeline_options = dict(
            model_id=model,
            task="text-generation",
            pipeline_kwargs=kwargs,
        )
        text_pipeline = HuggingFacePipeline.from_model_id(**pipeline_options)
        return ChatHuggingFace(llm=text_pipeline)
# Registry of supported chat providers. Lookups in create_chat_model use the
# stripped, lowercased provider name, so keys here are lowercase.
# Each factory is instantiated once, when this module is imported.
# Keep the insertion order: it is the order of names listed in the
# "Unsupported chat provider" error message.
CHAT_FACTORIES: Dict[str, BaseProviderFactory] = dict(
    openai=OpenAIChatFactory(),
    ollama=OllamaChatFactory(),
    bedrock=BedrockChatFactory(),
    huggingface=HuggingFaceChatFactory(),
    anthropic=AnthropicChatFactory(),
)
def create_chat_model(provider: str, model: str, **kwargs: Any) -> Any:
    """
    Create a chat model instance for the given provider.

    Args:
        provider: The provider name (openai, ollama, bedrock, huggingface,
            anthropic). Surrounding whitespace and letter case are ignored.
        model: The model identifier.
        **kwargs: Additional keyword arguments forwarded to the provider's
            factory. Most factories pass them to the chat model constructor;
            the huggingface factory passes them as ``pipeline_kwargs`` to the
            underlying text-generation pipeline instead.

    Returns:
        A chat model instance.

    Raises:
        ValueError: If ``provider`` is not a registered provider name.
    """
    key = provider.strip().lower()
    try:
        factory = CHAT_FACTORIES[key]
    except KeyError:
        # ``from None`` hides the internal KeyError so callers see only the
        # descriptive ValueError below.
        raise ValueError(
            f"Unsupported chat provider: {provider}. "
            f"Available providers: {list(CHAT_FACTORIES.keys())}"
        ) from None
    # Called outside the ``try`` so a KeyError raised while building the model
    # is not mistaken for an unknown provider.
    return factory.create(model=model, **kwargs)