diff --git a/scripts/pipelines/tasks/chunk.py b/scripts/pipelines/tasks/chunk.py index d267043..ff952c6 100644 --- a/scripts/pipelines/tasks/chunk.py +++ b/scripts/pipelines/tasks/chunk.py @@ -2,11 +2,10 @@ import json from copy import deepcopy from dataclasses import replace from pathlib import Path -from typing import Any, Union +from typing import Any from chonkie import ( Chunk, - ElasticHandshake, FileFetcher, MarkdownChef, TextChef, @@ -18,6 +17,7 @@ from loguru import logger from transformers import AutoTokenizer from scripts.pipelines.tasks.embeddings import OllamaEmbeddings +from scripts.pipelines.wrappers.chonkie_wrappers import ElasticHandshakeWithMetadata from src.config import settings @@ -99,58 +99,6 @@ def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocumen return fused_processed_doc -class ElasticHandshakeWithMetadata(ElasticHandshake): - """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.""" - - def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]: - """Generate bulk actions including metadata.""" - actions = [] - embeddings = self.embedding_model.embed_batch([chunk["chunk"].text for chunk in chunks]) - - for i, chunk in enumerate(chunks): - source = { - "text": chunk["chunk"].text, - "embedding": embeddings[i], - "start_index": chunk["chunk"].start_index, - "end_index": chunk["chunk"].end_index, - "token_count": chunk["chunk"].token_count, - } - - # Include metadata if it exists - if chunk.get("extra_metadata"): - source.update(chunk["extra_metadata"]) - - actions.append({ - "_index": self.index_name, - "_id": self._generate_id(i, chunk["chunk"]), - "_source": source, - }) - - return actions - - def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]: - """Write the chunks to the Elasticsearch index using the bulk API.""" - if isinstance(chunks, Chunk): - chunks = [chunks] - - actions = self._create_bulk_actions(chunks) - - # Use the bulk helper to 
efficiently write the documents - from elasticsearch.helpers import bulk - - success, errors = bulk(self.client, actions, raise_on_error=False) - - if errors: - logger.warning(f"Encountered {len(errors)} errors during bulk indexing.") # type: ignore - # Optionally log the first few errors for debugging - for i, error in enumerate(errors[:5]): # type: ignore - logger.error(f"Error {i + 1}: {error}") - - logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}") - - return actions - - def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]: """ Fetch files from a folder that match the specified extensions. diff --git a/scripts/pipelines/wrappers/chonkie_wrappers.py b/scripts/pipelines/wrappers/chonkie_wrappers.py new file mode 100644 index 0000000..63738c9 --- /dev/null +++ b/scripts/pipelines/wrappers/chonkie_wrappers.py @@ -0,0 +1,55 @@ +from typing import Any, Union + +from chonkie import Chunk, ElasticHandshake +from loguru import logger + +class ElasticHandshakeWithMetadata(ElasticHandshake): + """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.""" + + def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]: + """Generate bulk actions including metadata.""" + actions = [] + embeddings = self.embedding_model.embed_batch([chunk["chunk"].text for chunk in chunks]) + + for i, chunk in enumerate(chunks): + source = { + "text": chunk["chunk"].text, + "embedding": embeddings[i], + "start_index": chunk["chunk"].start_index, + "end_index": chunk["chunk"].end_index, + "token_count": chunk["chunk"].token_count, + } + + # Include metadata if it exists + if chunk.get("extra_metadata"): + source.update(chunk["extra_metadata"]) + + actions.append({ + "_index": self.index_name, + "_id": self._generate_id(i, chunk["chunk"]), + "_source": source, + }) + + return actions + + def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]: + """Write the chunks to the 
Elasticsearch index using the bulk API.""" +        if isinstance(chunks, Chunk): +            chunks = [chunks] + +        actions = self._create_bulk_actions(chunks) + +        # Use the bulk helper to efficiently write the documents +        from elasticsearch.helpers import bulk + +        success, errors = bulk(self.client, actions, raise_on_error=False) + +        if errors: +            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore +            # Optionally log the first few errors for debugging +            for i, error in enumerate(errors[:5]):  # type: ignore +                logger.error(f"Error {i + 1}: {error}") + +        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}") + +        return actions