feat: Implement ElasticHandshakeWithMetadata to preserve chunk metadata in Elasticsearch

This commit is contained in:
acano 2026-03-13 11:02:32 +01:00
parent 8aa12bd8eb
commit ab1022d8b6
2 changed files with 57 additions and 54 deletions

View File

@ -2,11 +2,10 @@ import json
 from copy import deepcopy
 from dataclasses import replace
 from pathlib import Path
-from typing import Any, Union
+from typing import Any
 from chonkie import (
     Chunk,
-    ElasticHandshake,
     FileFetcher,
     MarkdownChef,
     TextChef,
@ -18,6 +17,7 @@ from loguru import logger
 from transformers import AutoTokenizer
 from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
+from scripts.pipelines.wrappers.chonkie_wrappers import ElasticHandshakeWithMetadata
 from src.config import settings
@ -99,58 +99,6 @@ def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocumen
 return fused_processed_doc
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.

    Accepts either bare ``Chunk`` objects or dicts of the shape
    ``{"chunk": Chunk, "extra_metadata": dict | None}``; any keys in
    ``extra_metadata`` are merged into the indexed ``_source`` document.
    """

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Args:
            chunks: Dicts with a "chunk" key holding a Chunk and an optional
                "extra_metadata" dict whose keys are merged into ``_source``.

        Returns:
            Bulk actions ready for ``elasticsearch.helpers.bulk``.
        """
        actions: list[dict[str, Any]] = []
        # Guard the empty case: no point embedding an empty batch.
        if not chunks:
            return actions
        embeddings = self.embedding_model.embed_batch(
            [item["chunk"].text for item in chunks]
        )
        for i, item in enumerate(chunks):
            chunk = item["chunk"]
            source: dict[str, Any] = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }
            # Include metadata if it exists
            if item.get("extra_metadata"):
                source.update(item["extra_metadata"])
            actions.append(
                {
                    "_index": self.index_name,
                    "_id": self._generate_id(i, chunk),
                    "_source": source,
                }
            )
        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single Chunk, a list of Chunks, or a list of dicts of
                the shape expected by ``_create_bulk_actions``.

        Returns:
            The list of bulk actions that were sent to Elasticsearch.
        """
        if isinstance(chunks, Chunk):
            chunks = [chunks]
        # BUGFIX: normalize bare Chunk objects into the dict shape that
        # _create_bulk_actions indexes via item["chunk"]; previously any input
        # matching the declared Union[Chunk, list[Chunk]] signature crashed
        # with TypeError. Dict inputs pass through unchanged.
        normalized = [
            item if isinstance(item, dict) else {"chunk": item} for item in chunks
        ]
        actions = self._create_bulk_actions(normalized)
        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions
 def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
     """
     Fetch files from a folder that match the specified extensions.

View File

@ -0,0 +1,55 @@
from typing import Any, Union
from chonkie import Chunk, ElasticHandshake
from loguru import logger
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.

    Accepts either bare ``Chunk`` objects or dicts of the shape
    ``{"chunk": Chunk, "extra_metadata": dict | None}``; any keys in
    ``extra_metadata`` are merged into the indexed ``_source`` document.
    """

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Args:
            chunks: Dicts with a "chunk" key holding a Chunk and an optional
                "extra_metadata" dict whose keys are merged into ``_source``.

        Returns:
            Bulk actions ready for ``elasticsearch.helpers.bulk``.
        """
        actions: list[dict[str, Any]] = []
        # Guard the empty case: no point embedding an empty batch.
        if not chunks:
            return actions
        embeddings = self.embedding_model.embed_batch(
            [item["chunk"].text for item in chunks]
        )
        for i, item in enumerate(chunks):
            chunk = item["chunk"]
            source: dict[str, Any] = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }
            # Include metadata if it exists
            if item.get("extra_metadata"):
                source.update(item["extra_metadata"])
            actions.append(
                {
                    "_index": self.index_name,
                    "_id": self._generate_id(i, chunk),
                    "_source": source,
                }
            )
        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single Chunk, a list of Chunks, or a list of dicts of
                the shape expected by ``_create_bulk_actions``.

        Returns:
            The list of bulk actions that were sent to Elasticsearch.
        """
        if isinstance(chunks, Chunk):
            chunks = [chunks]
        # BUGFIX: normalize bare Chunk objects into the dict shape that
        # _create_bulk_actions indexes via item["chunk"]; previously any input
        # matching the declared Union[Chunk, list[Chunk]] signature crashed
        # with TypeError. Dict inputs pass through unchanged.
        normalized = [
            item if isinstance(item, dict) else {"chunk": item} for item in chunks
        ]
        actions = self._create_bulk_actions(normalized)
        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions