assistance-engine/scripts/pipelines/wrappers/chonkie_wrappers.py

55 lines
2.1 KiB
Python

from typing import Any, Union
from chonkie import Chunk, ElasticHandshake
from loguru import logger
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.

    Every indexed document holds the chunk's ``text``, ``embedding``,
    ``start_index``, ``end_index`` and ``token_count``, plus any key/value
    pairs supplied as that chunk's ``extra_metadata``.

    Input items may be plain ``Chunk`` objects or dicts of the form
    ``{"chunk": Chunk, "extra_metadata": Mapping | None}``.
    """

    @staticmethod
    def _unpack_item(item: Union[Chunk, dict[str, Any]]) -> tuple[Chunk, dict[str, Any]]:
        """Return ``(chunk, metadata)`` for a ``Chunk`` or a ``{"chunk", "extra_metadata"}`` dict."""
        if isinstance(item, Chunk):
            # A bare Chunk carries no extra metadata.
            return item, {}
        # Missing or falsy metadata (None, {}) is treated as "no metadata".
        return item["chunk"], dict(item.get("extra_metadata") or {})

    def _create_bulk_actions(
        self, chunks: list[Union[Chunk, dict[str, Any]]]
    ) -> list[dict[str, Any]]:
        """Generate bulk index actions, one per chunk, including metadata.

        Args:
            chunks: Items that are either ``Chunk`` objects or dicts with a
                ``"chunk"`` key and an optional ``"extra_metadata"`` mapping.

        Returns:
            One action dict (``_index``, ``_id``, ``_source``) per input item,
            in input order.

        Raises:
            KeyError: If a dict item has no ``"chunk"`` key.
        """
        items = [self._unpack_item(item) for item in chunks]
        if not items:
            # Nothing to index: skip the embedding call for an empty batch.
            return []

        # Embed all texts in a single batch call; embeddings[i] belongs to items[i].
        embeddings = self.embedding_model.embed_batch([chunk.text for chunk, _ in items])

        actions: list[dict[str, Any]] = []
        for i, (chunk, metadata) in enumerate(items):
            source: dict[str, Any] = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }
            # Metadata must never overwrite the chunk's own fields (a metadata
            # key "text" or "embedding" would otherwise corrupt the document).
            clobbered = [key for key in metadata if key in source]
            if clobbered:
                logger.warning(
                    f"Ignoring extra_metadata keys that collide with chunk fields: {clobbered}"
                )
            source.update({key: value for key, value in metadata.items() if key not in source})

            actions.append({
                "_index": self.index_name,
                "_id": self._generate_id(i, chunk),
                "_source": source,
            })
        return actions

    def write(
        self,
        chunks: Union[Chunk, dict[str, Any], list[Union[Chunk, dict[str, Any]]]],
    ) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single item or a list of items, where each item is a
                ``Chunk`` or a ``{"chunk": Chunk, "extra_metadata": ...}`` dict.

        Returns:
            The bulk actions that were submitted to Elasticsearch.
        """
        if isinstance(chunks, (Chunk, dict)):
            chunks = [chunks]
        actions = self._create_bulk_actions(chunks)

        # Local import: elasticsearch is only loaded when write() is called.
        from elasticsearch.helpers import bulk

        # raise_on_error=False: collect per-document failures instead of
        # raising on the first one.
        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Log only the first few failures to keep the output readable.
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions