68 lines
2.0 KiB
Python
68 lines
2.0 KiB
Python
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict
|
|
|
|
|
|
class BaseEmbeddingFactory(ABC):
    """Interface for objects that build an embedding model for one provider.

    ``create`` is abstract, so this class cannot be instantiated directly;
    each concrete subclass supplies its own implementation.
    """

    @abstractmethod
    def create(self, model: str, **kwargs: Any):
        """Build and return an embedding model instance.

        Args:
            model: Provider-specific model identifier.
            **kwargs: Extra keyword arguments for the underlying model
                constructor.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError
|
|
|
|
|
|
class OpenAIEmbeddingFactory(BaseEmbeddingFactory):
    """Factory for ``langchain_openai.OpenAIEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OpenAIEmbeddings`` built for *model*.

        Args:
            model: Embedding model name, passed as ``model=``.
            **kwargs: Forwarded unchanged to ``OpenAIEmbeddings``.

        Returns:
            A new ``OpenAIEmbeddings`` instance.

        Raises:
            ModuleNotFoundError: If ``langchain_openai`` is not installed;
                the message names the pip package to install.
        """
        # Imported inside the method so the package is only needed when
        # this factory's ``create`` is actually called.
        try:
            from langchain_openai import OpenAIEmbeddings
        except ModuleNotFoundError as exc:
            # Only rewrap when the provider package itself is missing; a
            # missing transitive dependency is re-raised as-is so its real
            # name stays visible.
            if exc.name != "langchain_openai":
                raise
            raise ModuleNotFoundError(
                "OpenAIEmbeddingFactory requires the 'langchain-openai' "
                "package; install it with `pip install langchain-openai`.",
                name=exc.name,
            ) from exc

        return OpenAIEmbeddings(model=model, **kwargs)
|
|
|
|
|
|
class OllamaEmbeddingFactory(BaseEmbeddingFactory):
    """Factory for ``langchain_ollama.OllamaEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return an ``OllamaEmbeddings`` built for *model*.

        Args:
            model: Embedding model name, passed as ``model=``.
            **kwargs: Forwarded unchanged to ``OllamaEmbeddings``.

        Returns:
            A new ``OllamaEmbeddings`` instance.

        Raises:
            ModuleNotFoundError: If ``langchain_ollama`` is not installed;
                the message names the pip package to install.
        """
        # Imported inside the method so the package is only needed when
        # this factory's ``create`` is actually called.
        try:
            from langchain_ollama import OllamaEmbeddings
        except ModuleNotFoundError as exc:
            # Only rewrap when the provider package itself is missing; a
            # missing transitive dependency is re-raised as-is so its real
            # name stays visible.
            if exc.name != "langchain_ollama":
                raise
            raise ModuleNotFoundError(
                "OllamaEmbeddingFactory requires the 'langchain-ollama' "
                "package; install it with `pip install langchain-ollama`.",
                name=exc.name,
            ) from exc

        return OllamaEmbeddings(model=model, **kwargs)
|
|
|
|
|
|
class BedrockEmbeddingFactory(BaseEmbeddingFactory):
    """Factory for ``langchain_aws.BedrockEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``BedrockEmbeddings`` built for *model*.

        Args:
            model: Bedrock model identifier. Note that ``BedrockEmbeddings``
                takes it as ``model_id=``, not ``model=``.
            **kwargs: Forwarded unchanged to ``BedrockEmbeddings``.

        Returns:
            A new ``BedrockEmbeddings`` instance.

        Raises:
            ModuleNotFoundError: If ``langchain_aws`` is not installed;
                the message names the pip package to install.
        """
        # Imported inside the method so the package is only needed when
        # this factory's ``create`` is actually called.
        try:
            from langchain_aws import BedrockEmbeddings
        except ModuleNotFoundError as exc:
            # Only rewrap when the provider package itself is missing; a
            # missing transitive dependency is re-raised as-is so its real
            # name stays visible.
            if exc.name != "langchain_aws":
                raise
            raise ModuleNotFoundError(
                "BedrockEmbeddingFactory requires the 'langchain-aws' "
                "package; install it with `pip install langchain-aws`.",
                name=exc.name,
            ) from exc

        return BedrockEmbeddings(model_id=model, **kwargs)
|
|
|
|
|
|
class HuggingFaceEmbeddingFactory(BaseEmbeddingFactory):
    """Factory for ``langchain_huggingface.HuggingFaceEmbeddings`` instances."""

    def create(self, model: str, **kwargs: Any):
        """Return a ``HuggingFaceEmbeddings`` built for *model*.

        Args:
            model: Model name. Note that ``HuggingFaceEmbeddings`` takes it
                as ``model_name=``, not ``model=``.
            **kwargs: Forwarded unchanged to ``HuggingFaceEmbeddings``.

        Returns:
            A new ``HuggingFaceEmbeddings`` instance.

        Raises:
            ModuleNotFoundError: If ``langchain_huggingface`` is not
                installed; the message names the pip package to install.
        """
        # Imported inside the method so the package is only needed when
        # this factory's ``create`` is actually called.
        try:
            from langchain_huggingface import HuggingFaceEmbeddings
        except ModuleNotFoundError as exc:
            # Only rewrap when the provider package itself is missing; a
            # missing transitive dependency is re-raised as-is so its real
            # name stays visible.
            if exc.name != "langchain_huggingface":
                raise
            raise ModuleNotFoundError(
                "HuggingFaceEmbeddingFactory requires the "
                "'langchain-huggingface' package; install it with "
                "`pip install langchain-huggingface`.",
                name=exc.name,
            ) from exc

        return HuggingFaceEmbeddings(model_name=model, **kwargs)
|
|
|
|
|
|
# Registry of supported embedding providers. Keys are lower-case provider
# names; each value is a single module-level factory instance.
EMBEDDING_FACTORIES: Dict[str, BaseEmbeddingFactory] = {
    "openai": OpenAIEmbeddingFactory(),  # langchain_openai.OpenAIEmbeddings
    "ollama": OllamaEmbeddingFactory(),  # langchain_ollama.OllamaEmbeddings
    "bedrock": BedrockEmbeddingFactory(),  # langchain_aws.BedrockEmbeddings
    "huggingface": HuggingFaceEmbeddingFactory(),  # langchain_huggingface.HuggingFaceEmbeddings
}
|
|
|
|
|
|
def create_embedding_model(provider: str, model: str, **kwargs: Any):
    """
    Create an embedding model instance for the given provider.

    The provider name has surrounding whitespace stripped and is lower-cased
    before being looked up in ``EMBEDDING_FACTORIES``.

    Args:
        provider: The provider name (openai, ollama, bedrock, huggingface).
        model: The model identifier.
        **kwargs: Additional keyword arguments passed to the model constructor.

    Returns:
        An embedding model instance, as returned by the selected factory's
        ``create`` method.

    Raises:
        ValueError: If the normalized provider name is not a registered key.
    """
    key = provider.strip().lower()

    # Single lookup (EAFP) instead of a membership test followed by indexing.
    try:
        factory = EMBEDDING_FACTORIES[key]
    except KeyError:
        # ``from None`` keeps the traceback focused on the unsupported
        # provider instead of the internal KeyError.
        raise ValueError(
            f"Unsupported embedding provider: {provider}. "
            f"Available providers: {list(EMBEDDING_FACTORIES.keys())}"
        ) from None

    return factory.create(model=model, **kwargs)
|