# assistance-engine/Docker/src/emb_factory.py
#
# 68 lines
# 2.0 KiB
# Python

from abc import ABC, abstractmethod
from typing import Any, Dict
class BaseEmbeddingFactory(ABC):
    """Abstract interface for provider-specific embedding-model builders.

    Concrete subclasses implement :meth:`create` to construct an embedding
    model for one provider. The class cannot be instantiated until
    ``create`` is overridden.
    """

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build and return an embedding model.

        Args:
            model: Model identifier understood by the provider.
            **kwargs: Extra keyword arguments for the model constructor.

        Raises:
            NotImplementedError: Always, if reached via ``super().create``;
                subclasses must supply their own implementation.
        """
        raise NotImplementedError
class OpenAIEmbeddingFactory(BaseEmbeddingFactory):
    """Builds ``langchain_openai.OpenAIEmbeddings`` models."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OpenAIEmbeddings`` instance for ``model``.

        ``model`` is passed as the ``model`` keyword and every entry of
        ``kwargs`` is forwarded unchanged to the constructor.
        """
        # Imported at call time, so langchain_openai only has to be
        # installed when this provider is actually used.
        from langchain_openai import OpenAIEmbeddings

        embeddings = OpenAIEmbeddings(model=model, **kwargs)
        return embeddings
class OllamaEmbeddingFactory(BaseEmbeddingFactory):
    """Builds ``langchain_ollama.OllamaEmbeddings`` models."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OllamaEmbeddings`` instance for ``model``.

        ``model`` is passed as the ``model`` keyword and every entry of
        ``kwargs`` is forwarded unchanged to the constructor.
        """
        # Imported at call time, so langchain_ollama only has to be
        # installed when this provider is actually used.
        from langchain_ollama import OllamaEmbeddings

        embeddings = OllamaEmbeddings(model=model, **kwargs)
        return embeddings
class BedrockEmbeddingFactory(BaseEmbeddingFactory):
    """Builds ``langchain_aws.BedrockEmbeddings`` models."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``BedrockEmbeddings`` instance for ``model``.

        Note that this constructor takes the identifier as ``model_id``
        (not ``model``); all ``kwargs`` are forwarded unchanged.
        """
        # Imported at call time, so langchain_aws only has to be
        # installed when this provider is actually used.
        from langchain_aws import BedrockEmbeddings

        embeddings = BedrockEmbeddings(model_id=model, **kwargs)
        return embeddings
class HuggingFaceEmbeddingFactory(BaseEmbeddingFactory):
    """Builds ``langchain_huggingface.HuggingFaceEmbeddings`` models."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``HuggingFaceEmbeddings`` instance for ``model``.

        Note that this constructor takes the identifier as ``model_name``
        (not ``model``); all ``kwargs`` are forwarded unchanged.
        """
        # Imported at call time, so langchain_huggingface only has to be
        # installed when this provider is actually used.
        from langchain_huggingface import HuggingFaceEmbeddings

        embeddings = HuggingFaceEmbeddings(model_name=model, **kwargs)
        return embeddings
# Registry of supported providers, keyed by the lower-case provider name that
# create_embedding_model() looks up. Insertion order is also the order shown
# in its "Available providers" error message.
EMBEDDING_FACTORIES: Dict[str, BaseEmbeddingFactory] = dict(
    openai=OpenAIEmbeddingFactory(),
    ollama=OllamaEmbeddingFactory(),
    bedrock=BedrockEmbeddingFactory(),
    huggingface=HuggingFaceEmbeddingFactory(),
)
def create_embedding_model(provider: str, model: str, **kwargs: Any):
    """Build an embedding model using the factory registered for ``provider``.

    The provider name is normalised (surrounding whitespace stripped, then
    lower-cased) before it is looked up in ``EMBEDDING_FACTORIES``, so
    ``" OpenAI "`` selects the ``"openai"`` factory.

    Args:
        provider: Provider name (openai, ollama, bedrock, huggingface).
        model: Model identifier handed to the selected factory.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            factory and from there to the model constructor.

    Returns:
        Whatever the selected factory's ``create`` returns.

    Raises:
        ValueError: If no factory is registered under the normalised name.
    """
    normalized = provider.strip().lower()

    if normalized in EMBEDDING_FACTORIES:
        factory = EMBEDDING_FACTORIES[normalized]
        return factory.create(model=model, **kwargs)

    raise ValueError(
        f"Unsupported embedding provider: {provider}. "
        f"Available providers: {list(EMBEDDING_FACTORIES.keys())}"
    )