# assistance-engine/scripts/pipelines/tasks/embeddings.py
#
# 125 lines
# 3.9 KiB
# Python

import requests
from typing import Any, Callable
import numpy as np
from chonkie.embeddings import BaseEmbeddings
from src.config import settings
class OllamaEmbeddings(BaseEmbeddings):
    """Chonkie embeddings adapter for a local Ollama embedding model.

    Every embedding and token-count request is sent as a POST to
    ``{base_url}/api/embed``. Transport failures and malformed responses
    are reported as ``RuntimeError``.
    """

    def __init__(
        self,
        model: str,
        base_url: str = settings.ollama_local_url,
        timeout: float = 60.0,
        truncate: bool = True,
        keep_alive: str = "5m",
    ) -> None:
        """Store the connection settings; no request is made here.

        Args:
            model: Model name sent in the ``model`` field of each request.
            base_url: Root URL of the Ollama server; trailing slashes are
                stripped so endpoint paths can be appended safely.
            timeout: Per-request timeout in seconds passed to ``requests``.
            truncate: Value sent in the ``truncate`` field of each request.
            keep_alive: Value sent in the ``keep_alive`` field of each request.
        """
        self.model = model
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.truncate = truncate
        self.keep_alive = keep_alive
        # Filled in lazily from the first embedding response.
        self._dimension: int | None = None

    @property
    def dimension(self) -> int:
        """Return the embedding size, probing the server once if unknown."""
        if self._dimension is None:
            # Lazy-load the dimension from a real embedding response.
            self._dimension = int(self.embed(" ").shape[0])
        return self._dimension

    def embed(self, text: str) -> np.ndarray:
        """Embed a single text and return it as a float32 vector.

        Raises:
            RuntimeError: If the request fails or no embedding is returned.
        """
        embeddings = self._embed_api(text)
        if not embeddings:
            # Guard the [0] lookup so callers get a descriptive error
            # instead of a bare IndexError.
            raise RuntimeError("Ollama returned no embeddings for the input text.")
        vector = np.asarray(embeddings[0], dtype=np.float32)
        if self._dimension is None:
            self._dimension = int(vector.shape[0])
        return vector

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed several texts in one request, preserving input order.

        Returns an empty list for empty input without calling the server.

        Raises:
            RuntimeError: If the request fails or the number of returned
                embeddings differs from the number of inputs.
        """
        if not texts:
            return []
        embeddings = self._embed_api(texts)
        if len(embeddings) != len(texts):
            # A mismatch would silently misalign texts and vectors downstream.
            raise RuntimeError(
                f"Ollama returned {len(embeddings)} embeddings "
                f"for {len(texts)} inputs."
            )
        vectors = [np.asarray(vector, dtype=np.float32) for vector in embeddings]
        if self._dimension is None:
            self._dimension = int(vectors[0].shape[0])
        return vectors

    def count_tokens(self, text: str) -> int:
        """Return the ``prompt_eval_count`` reported for embedding ``text``.

        This issues a full embed request, so it costs one server round trip.

        Raises:
            RuntimeError: If the request fails or the response carries no
                ``prompt_eval_count``.
        """
        # NOTE(review): with truncate=True the reported count is presumably
        # capped at the model's context length — confirm against Ollama docs.
        response = self._post_embed(self._build_payload(text))
        count = response.get("prompt_eval_count")
        if count is None:
            raise RuntimeError(
                "Ollama response did not include 'prompt_eval_count'. "
                f"Response keys: {list(response.keys())}"
            )
        return int(count)

    def count_tokens_batch(self, texts: list[str]) -> list[int]:
        """Return one token count per text (one request per text)."""
        # Ollama returns a single prompt_eval_count for the whole request,
        # not one count per input item, so we compute them individually.
        return [self.count_tokens(text) for text in texts]

    def get_tokenizer(self) -> Callable[[str], int]:
        """Return a callable that maps a text to its token count."""
        # Chonkie mainly needs something usable for token counting.
        return self.count_tokens

    @classmethod
    def is_available(cls) -> bool:
        """Return True if the configured Ollama server answers ``/api/tags``."""
        # Strip trailing slashes the same way __init__ does, so a configured
        # URL ending in "/" does not produce "//api/tags".
        root_url = settings.ollama_local_url.rstrip("/")
        try:
            response = requests.get(f"{root_url}/api/tags", timeout=5.0)
            response.raise_for_status()
        except requests.RequestException:
            return False
        return True

    def __repr__(self) -> str:
        return (
            f"OllamaEmbeddings("
            f"model={self.model!r}, "
            f"base_url={self.base_url!r}, "
            f"dimension={self._dimension!r}"
            f")"
        )

    def _build_payload(self, text_or_texts: str | list[str]) -> dict[str, Any]:
        """Build the JSON body for an ``/api/embed`` request."""
        return {
            "model": self.model,
            "input": text_or_texts,
            "truncate": self.truncate,
            "keep_alive": self.keep_alive,
        }

    def _post_embed(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST ``payload`` to ``/api/embed`` and return the decoded JSON object.

        Raises:
            RuntimeError: On transport/HTTP errors, a body that is not valid
                JSON, or a response without an ``embeddings`` field.
        """
        endpoint = f"{self.base_url}/api/embed"
        try:
            response = requests.post(endpoint, json=payload, timeout=self.timeout)
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers JSON decode failures on requests versions
            # whose JSONDecodeError is not a RequestException subclass.
            raise RuntimeError(
                f"Failed to call Ollama embeddings endpoint at "
                f"{self.base_url}/api/embed"
            ) from exc
        if not isinstance(data, dict):
            raise RuntimeError(
                "Ollama response was not a JSON object. "
                f"Response type: {type(data).__name__}"
            )
        if "embeddings" not in data:
            raise RuntimeError(
                "Ollama response did not include 'embeddings'. "
                f"Response keys: {list(data.keys())}"
            )
        return data

    def _embed_api(self, text_or_texts: str | list[str]) -> list[list[float]]:
        """Send one embed request and return the raw ``embeddings`` list."""
        payload = self._build_payload(text_or_texts)
        data = self._post_embed(payload)
        return data["embeddings"]