feat: Implement ElasticHandshakeWithMetadata to preserve chunk metadata in Elasticsearch
This commit is contained in:
parent 8aa12bd8eb
commit ab1022d8b6
@@ -2,11 +2,10 @@ import json
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from typing import Any, Union
from typing import Any

from chonkie import (
    Chunk,
    ElasticHandshake,
    FileFetcher,
    MarkdownChef,
    TextChef,
@@ -18,6 +17,7 @@ from loguru import logger
from transformers import AutoTokenizer

from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from scripts.pipelines.wrappers.chonkie_wrappers import ElasticHandshakeWithMetadata
from src.config import settings
@@ -99,58 +99,6 @@ def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    return fused_processed_doc


class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch."""

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata."""
        actions = []
        embeddings = self.embedding_model.embed_batch([chunk["chunk"].text for chunk in chunks])

        for i, chunk in enumerate(chunks):
            source = {
                "text": chunk["chunk"].text,
                "embedding": embeddings[i],
                "start_index": chunk["chunk"].start_index,
                "end_index": chunk["chunk"].end_index,
                "token_count": chunk["chunk"].token_count,
            }

            # Include metadata if it exists
            if chunk.get("extra_metadata"):
                source.update(chunk["extra_metadata"])

            actions.append({
                "_index": self.index_name,
                "_id": self._generate_id(i, chunk["chunk"]),
                "_source": source,
            })

        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API."""
        if isinstance(chunks, Chunk):
            chunks = [chunks]

        actions = self._create_bulk_actions(chunks)

        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        success, errors = bulk(self.client, actions, raise_on_error=False)

        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")

        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")

        return actions


def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Fetch files from a folder that match the specified extensions.
@@ -0,0 +1,55 @@
from typing import Any, Union

from chonkie import Chunk, ElasticHandshake
from loguru import logger


class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch."""

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata."""
        actions = []
        embeddings = self.embedding_model.embed_batch([chunk["chunk"].text for chunk in chunks])

        for i, chunk in enumerate(chunks):
            source = {
                "text": chunk["chunk"].text,
                "embedding": embeddings[i],
                "start_index": chunk["chunk"].start_index,
                "end_index": chunk["chunk"].end_index,
                "token_count": chunk["chunk"].token_count,
            }

            # Include metadata if it exists
            if chunk.get("extra_metadata"):
                source.update(chunk["extra_metadata"])

            actions.append({
                "_index": self.index_name,
                "_id": self._generate_id(i, chunk["chunk"]),
                "_source": source,
            })

        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API."""
        if isinstance(chunks, Chunk):
            chunks = [chunks]

        actions = self._create_bulk_actions(chunks)

        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        success, errors = bulk(self.client, actions, raise_on_error=False)

        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")

        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")

        return actions
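For context, here is a minimal usage sketch of the new wrapper. The {"chunk": ..., "extra_metadata": ...} dict shape follows what _create_bulk_actions reads above; the handshake construction itself is only illustrative, since the ElasticHandshake constructor arguments are not part of this diff and depend on the chonkie version in use.

from chonkie import Chunk

from scripts.pipelines.wrappers.chonkie_wrappers import ElasticHandshakeWithMetadata

# Illustrative construction only: the keyword arguments below are assumptions,
# not taken from this commit; configure it exactly as a plain ElasticHandshake.
handshake = ElasticHandshakeWithMetadata(
    index_name="documents",              # assumed parameter name
    embedding_model="nomic-embed-text",  # assumed parameter name and value
)

# Each item wraps a Chunk together with the metadata to preserve; anything in
# "extra_metadata" is merged into the indexed _source document.
chunks_with_metadata = [
    {
        "chunk": Chunk(
            text="First section of the document.",
            start_index=0,
            end_index=30,
            token_count=6,
        ),
        "extra_metadata": {"source_file": "report.md", "section": "intro"},
    }
]

# write() builds the bulk actions and indexes them, so the resulting document
# carries source_file and section alongside text, embedding, start_index,
# end_index and token_count.
actions = handshake.write(chunks_with_metadata)

Note that write() is still annotated as taking Chunk objects, but the implementation indexes into chunk["chunk"], so the wrapped-dict shape shown above is what the code actually consumes.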