assistance-engine/scripts/pipelines/tasks/chunk.py

278 lines
9.1 KiB
Python

import json
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from typing import Any, Union
from chonkie import (
Chunk,
ElasticHandshake,
FileFetcher,
MarkdownChef,
TextChef,
TokenChunker,
MarkdownDocument
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer
from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings
def _get_text(element) -> str:
    """Return the textual payload carried by a document element.

    The attributes ``text``, ``content`` and ``markdown`` are probed in that
    order, and the first one that holds a ``str`` wins.

    Args:
        element: Any object that may expose one of the attributes above.

    Returns:
        The first string value found (an empty string counts as found).

    Raises:
        AttributeError: If none of the probed attributes holds a string.
    """
    # Lazy generator: attributes are only read until the first match.
    candidates = (
        getattr(element, name, None) for name in ("text", "content", "markdown")
    )
    found = next((value for value in candidates if isinstance(value, str)), None)
    if found is None:
        raise AttributeError(
            f"Could not extract text from element of type {type(element).__name__}"
        )
    return found
def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    """Fold code blocks and tables into the text chunk that precedes them.

    Chunks, code blocks and tables are ordered by ``(start_index, end_index)``.
    Every text chunk opens a new group; each code/table element that follows
    is appended to the currently open group. Code/table elements located
    before the first text chunk have no group to join and are therefore not
    part of any merged chunk.

    Args:
        processed_doc: Document exposing ``chunks``, ``code`` and ``tables``.

    Returns:
        A deep copy of ``processed_doc`` whose ``chunks`` are the merged
        chunks, while ``code`` and ``tables`` reference the original lists.
    """
    # (start, end, is_text_chunk, element); insertion order chunks -> code ->
    # tables is kept for ties because list.sort is stable.
    ordered = [(c.start_index, c.end_index, True, c) for c in processed_doc.chunks]
    ordered.extend((b.start_index, b.end_index, False, b) for b in processed_doc.code)
    ordered.extend((t.start_index, t.end_index, False, t) for t in processed_doc.tables)
    ordered.sort(key=lambda entry: entry[:2])

    # Each group is [base_chunk, text_parts, end_index, token_count].
    groups = []
    for _, _, is_text, element in ordered:
        if is_text:
            groups.append(
                [element, [_get_text(element)], element.end_index, element.token_count]
            )
        elif groups:
            open_group = groups[-1]
            open_group[1].append(_get_text(element))
            open_group[2] = max(open_group[2], element.end_index)
            # Elements without a token_count contribute 0 tokens.
            open_group[3] += getattr(element, "token_count", 0)

    merged_chunks = [
        replace(
            base,
            text="\n\n".join(piece for piece in pieces if piece),
            end_index=last_index,
            token_count=tokens,
        )
        for base, pieces, last_index, tokens in groups
    ]

    fused_processed_doc = deepcopy(processed_doc)
    fused_processed_doc.chunks = merged_chunks
    fused_processed_doc.code = processed_doc.code
    fused_processed_doc.tables = processed_doc.tables
    return fused_processed_doc
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch.

    Items handed to this handshake are dicts with a ``"chunk"`` entry (the
    chunk object) and an optional ``"extra_metadata"`` dict whose keys are
    merged into the indexed document. Bare ``Chunk`` objects are accepted as
    well and are treated as items without extra metadata.
    """

    @staticmethod
    def _as_record(item: Union[Chunk, dict[str, Any]]) -> dict[str, Any]:
        """Wrap a bare chunk as ``{"chunk": item}``; pass dict records through."""
        return item if isinstance(item, dict) else {"chunk": item}

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Args:
            chunks: Records (``{"chunk": ..., "extra_metadata": ...}``) or bare
                chunk objects.

        Returns:
            One bulk action per item, with ``_index``, ``_id`` and ``_source``.
            Keys from ``extra_metadata`` are merged into ``_source`` last, so
            they override the core fields on a name clash.
        """
        records = [self._as_record(item) for item in chunks]
        if not records:
            # Nothing to index: avoid calling the embedding model with no texts.
            return []

        # All texts are embedded in a single batch call.
        embeddings = self.embedding_model.embed_batch(
            [record["chunk"].text for record in records]
        )

        actions = []
        for i, record in enumerate(records):
            chunk = record["chunk"]
            source = {
                "text": chunk.text,
                "embedding": embeddings[i],
                "start_index": chunk.start_index,
                "end_index": chunk.end_index,
                "token_count": chunk.token_count,
            }
            # Include metadata if it exists
            if record.get("extra_metadata"):
                source.update(record["extra_metadata"])
            actions.append({
                "_index": self.index_name,
                "_id": self._generate_id(i, chunk),
                "_source": source,
            })
        return actions

    def write(
        self,
        chunks: Union[Chunk, dict[str, Any], list[Union[Chunk, dict[str, Any]]]],
    ) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API.

        Args:
            chunks: A single chunk, a single record dict, or a list of either.

        Returns:
            The bulk actions that were submitted (not the Elasticsearch
            responses). Indexing errors are logged rather than raised.
        """
        # A lone Chunk or record is promoted to a one-element batch.
        if isinstance(chunks, (Chunk, dict)):
            chunks = [chunks]
        actions = self._create_bulk_actions(chunks)

        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions
def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Collect the files under a project folder that carry one of the given extensions.

    Args:
        docs_folder_path (str): Folder to scan, appended to ``settings.proj_root``
        docs_extension (list[str]): Extensions to keep (e.g., [".md", ".avap"])

    Returns:
        The result of ``FileFetcher.fetch`` for that folder (annotated as a list of Paths)
    """
    search_dir = f"{settings.proj_root}/{docs_folder_path}"
    return FileFetcher().fetch(dir=search_dir, ext=docs_extension)
def process_documents(docs_path: list[Path]) -> list[dict[str, Chunk | dict[str, Any]]]:
    """
    Turn documents into chunk records, picking the chef/chunker by file extension.

    ``.md`` files go through the Markdown chef and are then merged with
    ``_merge_markdown_document``; ``.avap`` files go through the text chef and
    the token chunker. Files with any other extension are skipped.

    Args:
        docs_path (list[Path]): Paths of the documents to process

    Returns:
        List of dicts with "chunk" (the chunk object) and "extra_metadata"
        (``{"file": <file name>}``)
    """
    tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    markdown_chef = MarkdownChef(tokenizer=tokenizer)
    text_chef = TextChef()
    token_chunker = TokenChunker(tokenizer=tokenizer)

    records = []
    for path in docs_path:
        suffix = path.suffix.lower()
        if suffix not in (".md", ".avap"):
            continue

        if suffix == ".md":
            pieces = _merge_markdown_document(markdown_chef.process(path)).chunks
        else:
            pieces = token_chunker.chunk(text_chef.process(path).content)

        source_name = path.name
        records.extend(
            {"chunk": piece, "extra_metadata": {"file": source_name}}
            for piece in pieces
        )
    return records
def ingest_documents(
    chunked_docs: list[dict[str, Chunk | dict[str, Any]]],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> list[dict[str, Any]]:
    """
    Ingest processed documents into an Elasticsearch index.

    A client is created for ``settings.elasticsearch_local_url``, used for the
    whole ingestion and closed before returning, even when ingestion fails.

    Args:
        chunked_docs (list[dict[str, Any]]): List of dicts with "chunk" and "extra_metadata" keys
        es_index (str): Name of the Elasticsearch index to ingest into
        es_request_timeout (int): Timeout for Elasticsearch requests in seconds
        es_max_retries (int): Maximum number of retries for Elasticsearch requests
        es_retry_on_timeout (bool): Whether to retry on Elasticsearch request timeouts
        delete_es_index (bool): Whether to delete the existing Elasticsearch index before ingestion

    Returns:
        The value returned by the handshake's ``write`` call (one bulk action dict per chunk)
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    es = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )
    try:
        if delete_es_index and es.indices.exists(index=es_index):
            logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
            es.indices.delete(index=es_index)

        handshake = ElasticHandshakeWithMetadata(
            client=es,
            index_name=es_index,
            embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
        )
        logger.info(
            f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
        )
        elasticsearch_chunks = handshake.write(chunked_docs)
        return elasticsearch_chunks
    finally:
        # The client is local to this call; release its connections instead of
        # leaving them open until garbage collection.
        es.close()
def export_documents(elasticsearch_chunks: list[dict[str, Any]], output_path: str) -> None:
    """
    Export the ingested chunks to a single JSON file.

    Embeddings that expose ``tolist()`` are converted to plain lists so they
    can be serialized. The conversion is done in place on the given dicts (as
    before); embeddings without ``tolist()`` are left untouched, so exporting
    the same data twice no longer fails. Missing parent directories of the
    output file are created.

    Args:
        elasticsearch_chunks (list[dict[str, Any]]): Dicts with a "_source" entry holding an "embedding"
        output_path (str): Path of the JSON file, relative to ``settings.proj_root``

    Returns:
        None
    """
    destination = settings.proj_root / output_path

    for chunk in elasticsearch_chunks:
        source = chunk["_source"]
        embedding = source["embedding"]
        # Only array-like embeddings need converting for JSON serialization.
        if hasattr(embedding, "tolist"):
            source["embedding"] = embedding.tolist()

    # Opening the file for writing would fail if the target folder is missing.
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as f:
        json.dump(elasticsearch_chunks, f, ensure_ascii=False, indent=4)
    logger.info(f"Exported processed documents to {destination}")