from typing import Any, Union

from chonkie import Chunk, ElasticHandshake
from loguru import logger


class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch."""

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Args:
            chunks: Dicts with a "chunk" key holding a Chunk and an optional
                "extra_metadata" key holding a mapping that is merged into the
                indexed document.

        Returns:
            Bulk-API action dicts ready for ``elasticsearch.helpers.bulk``.
        """
        # Embed all chunk texts in a single batch call rather than per chunk.
        embeddings = self.embedding_model.embed_batch(
            [entry["chunk"].text for entry in chunks]
        )
        actions: list[dict[str, Any]] = []
        for i, entry in enumerate(chunks):
            chunk = entry["chunk"]
            source = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }
            # Include metadata if it exists. NOTE: metadata keys that collide
            # with the core fields above will overwrite them.
            if entry.get("extra_metadata"):
                source.update(entry["extra_metadata"])
            actions.append(
                {
                    "_index": self.index_name,
                    "_id": self._generate_id(i, chunk),
                    "_source": source,
                }
            )
        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single Chunk, a list of Chunks, or (legacy shape) a list
                of ``{"chunk": Chunk, "extra_metadata": dict}`` dicts.

        Returns:
            The bulk action dicts that were submitted to Elasticsearch.
        """
        if isinstance(chunks, Chunk):
            chunks = [chunks]
        # BUG FIX: _create_bulk_actions expects {"chunk": ..., "extra_metadata": ...}
        # dicts, but this method's declared interface accepts bare Chunk objects,
        # which previously raised TypeError on chunk["chunk"]. Normalize bare
        # Chunks into the dict shape; existing dict callers pass through as-is.
        normalized = [c if isinstance(c, dict) else {"chunk": c} for c in chunks]
        actions = self._create_bulk_actions(normalized)

        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        # raise_on_error=False: collect per-item failures instead of aborting
        # the whole batch on the first error.
        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions