79 lines
2.3 KiB
Python
79 lines
2.3 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict
|
|
|
|
|
|
class BaseProviderFactory(ABC):
    """Abstract interface for a per-provider chat-model builder.

    Concrete subclasses implement :meth:`create`. It takes a model
    identifier plus arbitrary keyword options and returns a chat model
    object for that provider.
    """

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build and return a chat model for ``model``.

        Args:
            model: The provider-specific model identifier.
            **kwargs: Provider-specific construction options.

        Raises:
            NotImplementedError: If a subclass delegates here via
                ``super().create(...)``.
        """
        raise NotImplementedError()
|
|
|
|
|
|
class OpenAIChatFactory(BaseProviderFactory):
    """Factory for OpenAI chat models (``langchain_openai.ChatOpenAI``)."""

    def create(self, model: str, **kwargs: Any):
        """Instantiate ``ChatOpenAI`` for ``model``.

        ``langchain_openai`` is imported inside this method, so it is only
        loaded when this factory is actually used.

        Args:
            model: Model name, forwarded as ``model=``.
            **kwargs: Extra options forwarded unchanged to ``ChatOpenAI``.

        Returns:
            The constructed ``ChatOpenAI`` instance.
        """
        from langchain_openai import ChatOpenAI

        chat = ChatOpenAI(model=model, **kwargs)
        return chat
|
|
|
|
class AnthropicChatFactory(BaseProviderFactory):
    """Factory for Anthropic chat models (``langchain_anthropic.ChatAnthropic``)."""

    def create(self, model: str, **kwargs: Any):
        """Return ``ChatAnthropic(model=model, **kwargs)``.

        ``langchain_anthropic`` is imported lazily, inside this method.

        Args:
            model: Model name, forwarded as ``model=``.
            **kwargs: Extra options forwarded unchanged to ``ChatAnthropic``.
        """
        from langchain_anthropic import ChatAnthropic

        # ``model`` is a named parameter, so ``kwargs`` can never contain a
        # "model" key; merging here is equivalent to passing both separately.
        options = {"model": model, **kwargs}
        return ChatAnthropic(**options)
|
|
|
|
class OllamaChatFactory(BaseProviderFactory):
    """Factory for Ollama chat models (``langchain_ollama.ChatOllama``)."""

    def create(self, model: str, **kwargs: Any):
        """Construct a ``ChatOllama`` for ``model``.

        Args:
            model: Model name, forwarded as ``model=``.
            **kwargs: Extra options forwarded unchanged to ``ChatOllama``.

        Returns:
            The constructed ``ChatOllama`` instance.
        """
        # Deferred import: the dependency is only needed for this provider.
        from langchain_ollama import ChatOllama as chat_cls

        return chat_cls(model=model, **kwargs)
|
|
|
|
|
|
class BedrockChatFactory(BaseProviderFactory):
    """Factory for AWS Bedrock chat models (``langchain_aws.ChatBedrockConverse``)."""

    def create(self, model: str, **kwargs: Any):
        """Construct a ``ChatBedrockConverse`` for ``model``.

        ``langchain_aws`` is imported inside this method rather than at
        module level.

        Args:
            model: Bedrock model identifier, forwarded as ``model=``.
            **kwargs: Extra options forwarded unchanged to
                ``ChatBedrockConverse``.

        Returns:
            The constructed ``ChatBedrockConverse`` instance.
        """
        from langchain_aws import ChatBedrockConverse

        bedrock_chat = ChatBedrockConverse(model=model, **kwargs)
        return bedrock_chat
|
|
|
|
|
|
class HuggingFaceChatFactory(BaseProviderFactory):
    """Factory for Hugging Face chat models built on a text-generation pipeline."""

    def create(self, model: str, **kwargs: Any):
        """Wrap a ``HuggingFacePipeline`` in ``ChatHuggingFace``.

        Unlike the other factories in this module, ``kwargs`` are NOT passed
        to the chat-model constructor. They are handed to the pipeline as
        ``pipeline_kwargs``.

        Args:
            model: Model identifier, forwarded as ``model_id`` to
                ``HuggingFacePipeline.from_model_id``.
            **kwargs: Forwarded as ``pipeline_kwargs`` to
                ``HuggingFacePipeline.from_model_id``.

        Returns:
            A ``ChatHuggingFace`` instance whose ``llm`` is the pipeline.
        """
        from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

        pipeline_settings = {
            "model_id": model,
            "task": "text-generation",
            "pipeline_kwargs": kwargs,
        }
        backend = HuggingFacePipeline.from_model_id(**pipeline_settings)
        return ChatHuggingFace(llm=backend)
|
|
|
|
|
|
# Registry of supported chat providers. Keys are the lower-case names that
# ``create_chat_model`` looks up after stripping and lower-casing its input.
# Insertion order is kept stable because it is the order listed in the
# "Available providers" error message.
CHAT_FACTORIES: Dict[str, BaseProviderFactory] = dict(
    openai=OpenAIChatFactory(),
    ollama=OllamaChatFactory(),
    bedrock=BedrockChatFactory(),
    huggingface=HuggingFaceChatFactory(),
    anthropic=AnthropicChatFactory(),
)
|
|
|
|
|
|
def create_chat_model(provider: str, model: str, **kwargs: Any):
    """Create a chat model instance for the given provider.

    The provider name is matched after stripping surrounding whitespace and
    lower-casing, so ``" OpenAI "`` selects the ``"openai"`` factory.

    Args:
        provider: The provider name: one of the keys of ``CHAT_FACTORIES``
            (openai, ollama, bedrock, huggingface, anthropic).
        model: The model identifier, forwarded to the selected factory.
        **kwargs: Additional keyword arguments forwarded to the factory's
            ``create`` method. Most factories pass them to the chat-model
            constructor; the huggingface factory passes them to the
            underlying pipeline as ``pipeline_kwargs`` instead.

    Returns:
        The chat model instance built by the selected factory.

    Raises:
        ValueError: If ``provider`` does not name a registered factory.
    """
    key = provider.strip().lower()

    # Single lookup instead of a membership test followed by indexing.
    factory = CHAT_FACTORIES.get(key)
    if factory is None:
        raise ValueError(
            f"Unsupported chat provider: {provider}. "
            f"Available providers: {list(CHAT_FACTORIES.keys())}"
        )

    return factory.create(model=model, **kwargs)
|