55 lines
2.1 KiB
Python
55 lines
2.1 KiB
Python
from typing import Any, Union
|
|
|
|
from chonkie import Chunk, ElasticHandshake
|
|
from loguru import logger
|
|
|
|
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.

    Items passed to ``write`` may be bare ``Chunk`` objects or dicts of the
    form ``{"chunk": Chunk, "extra_metadata": <mapping or None>}``. Metadata
    key/value pairs are merged into each indexed document next to the standard
    chunk fields. Metadata keys that collide with those standard fields are
    dropped (with a warning), so they cannot overwrite the chunk text or the
    embedding vector.
    """

    # Document fields derived from the chunk itself; metadata must not overwrite them.
    _RESERVED_FIELDS = frozenset({"text", "embedding", "start_index", "end_index", "token_count"})

    @staticmethod
    def _normalize_items(
        chunks: Union[Chunk, dict[str, Any], list[Union[Chunk, dict[str, Any]]]],
    ) -> list[dict[str, Any]]:
        """Return ``chunks`` as a list of ``{"chunk": ...}`` dicts.

        A single Chunk or dict is first wrapped in a list. Bare Chunk items are
        wrapped as ``{"chunk": item}``; dict items are passed through unchanged,
        so calling this on already-normalized input is a no-op.
        """
        if isinstance(chunks, (Chunk, dict)):
            chunks = [chunks]
        return [item if isinstance(item, dict) else {"chunk": item} for item in chunks]

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate Elasticsearch bulk index actions including metadata.

        Args:
            chunks: Items of the form ``{"chunk": Chunk, "extra_metadata": ...}``.
                Bare ``Chunk`` objects (or a single item) are also accepted and
                are indexed without extra metadata.

        Returns:
            One action dict per chunk with ``_index``, ``_id`` and ``_source``
            keys. Returns an empty list for empty input.

        Raises:
            ValueError: If the embedding model returns a different number of
                embeddings than there are chunks.
        """
        items = self._normalize_items(chunks)
        if not items:
            # Nothing to index; skip calling the embedding model with an empty batch.
            return []

        embeddings = self.embedding_model.embed_batch([item["chunk"].text for item in items])
        if len(embeddings) != len(items):
            raise ValueError(
                f"Embedding model returned {len(embeddings)} embeddings for {len(items)} chunks."
            )

        actions: list[dict[str, Any]] = []
        for i, item in enumerate(items):
            chunk = item["chunk"]
            source: dict[str, Any] = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }

            # Include metadata if it exists, without letting it clobber core fields.
            metadata = item.get("extra_metadata")
            if metadata:
                extra = dict(metadata)
                clashes = self._RESERVED_FIELDS.intersection(extra)
                if clashes:
                    logger.warning(
                        f"Ignoring extra_metadata keys that collide with core fields: {sorted(clashes)}"
                    )
                source.update(
                    {key: value for key, value in extra.items() if key not in self._RESERVED_FIELDS}
                )

            actions.append({
                "_index": self.index_name,
                "_id": self._generate_id(i, chunk),
                "_source": source,
            })

        return actions

    def write(
        self,
        chunks: Union[Chunk, dict[str, Any], list[Union[Chunk, dict[str, Any]]]],
    ) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single ``Chunk`` or ``{"chunk": ..., "extra_metadata": ...}``
                dict, or a list of either.

        Returns:
            The bulk actions that were submitted (empty if there was nothing to write).
        """
        actions = self._create_bulk_actions(chunks)
        if not actions:
            return actions

        # Deferred import, kept local to this method; presumably because
        # elasticsearch is an optional dependency — confirm.
        from elasticsearch.helpers import bulk

        # raise_on_error=False: per-document failures are returned in `errors`
        # instead of raising, so they can be logged below.
        success, errors = bulk(self.client, actions, raise_on_error=False)

        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Log only the first few errors for debugging.
            for n, error in enumerate(errors[:5], start=1):  # type: ignore
                logger.error(f"Error {n}: {error}")

        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")

        return actions