from copy import deepcopy
from dataclasses import replace
from pathlib import Path

from chonkie import (
    Chunk,
    ElasticHandshake,
    FileFetcher,
    MarkdownChef,
    MarkdownDocument,
    TextChef,
    TokenChunker,
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer

from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings


def _get_text(element) -> str:
    """Return the textual payload of a chonkie element.

    Chonkie element types expose their content under different attribute
    names, so probe the known candidates in order of preference.

    Args:
        element: Any chonkie element (chunk, code block, or table).

    Returns:
        The first string found among the ``text``/``content``/``markdown``
        attributes.

    Raises:
        AttributeError: If none of the known attributes holds a string.
    """
    for attr in ("text", "content", "markdown"):
        value = getattr(element, attr, None)
        if isinstance(value, str):
            return value
    raise AttributeError(
        f"Could not extract text from element of type {type(element).__name__}"
    )


def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    """Fuse code blocks and tables into the chunk that precedes them.

    ``MarkdownChef`` emits text chunks, code blocks, and tables as three
    separate streams. This walks all elements in document order (by their
    character spans) and appends each code/table element to the most recent
    text chunk, so downstream ingestion sees self-contained chunks.

    NOTE(review): code/table elements that occur *before* the first text
    chunk have nothing to attach to and are silently dropped — confirm this
    is acceptable for the corpus being ingested.

    Args:
        processed_doc (MarkdownDocument): Output of ``MarkdownChef.process``.

    Returns:
        A copy of ``processed_doc`` whose ``chunks`` carry the merged text;
        ``code`` and ``tables`` still reference the original streams.
    """
    # Tag every element with its character span so the three streams can be
    # interleaved back into document order.
    elements = []
    for chunk in processed_doc.chunks:
        elements.append(("chunk", chunk.start_index, chunk.end_index, chunk))
    for code in processed_doc.code:
        elements.append(("code", code.start_index, code.end_index, code))
    for table in processed_doc.tables:
        elements.append(("table", table.start_index, table.end_index, table))
    elements.sort(key=lambda item: (item[1], item[2]))

    merged_chunks = []
    current_chunk = None
    current_parts = []
    current_end_index = None
    current_token_count = None

    def flush():
        # Emit the chunk under construction (if any) as a new Chunk whose
        # text, end index, and token count reflect everything merged into it.
        nonlocal current_chunk, current_parts, current_end_index, current_token_count
        if current_chunk is None:
            return
        merged_text = "\n\n".join(part for part in current_parts if part)
        merged_chunks.append(
            replace(
                current_chunk,
                text=merged_text,
                end_index=current_end_index,
                token_count=current_token_count,
            )
        )
        current_chunk = None
        current_parts = []
        current_end_index = None
        current_token_count = None

    for kind, _, _, element in elements:
        if kind == "chunk":
            # A new text chunk closes the previous accumulation window and
            # opens a fresh one.
            flush()
            current_chunk = element
            current_parts = [_get_text(element)]
            current_end_index = element.end_index
            current_token_count = element.token_count
            continue
        if current_chunk is None:
            # Code/table before any chunk: nothing to attach it to (see
            # the NOTE in the docstring).
            continue
        current_parts.append(_get_text(element))
        current_end_index = max(current_end_index, element.end_index)
        # `or 0` guards against elements whose token_count attribute exists
        # but is None, which would otherwise raise a TypeError on `+=`.
        current_token_count += getattr(element, "token_count", 0) or 0
    flush()

    fused_processed_doc = deepcopy(processed_doc)
    fused_processed_doc.chunks = merged_chunks
    # Re-point code/tables at the original streams so callers that still
    # need the unmerged elements see the same objects they started with.
    fused_processed_doc.code = processed_doc.code
    fused_processed_doc.tables = processed_doc.tables
    return fused_processed_doc


def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Fetch files from a folder that match the specified extensions.

    Args:
        docs_folder_path (str): Path to the folder containing documents,
            relative to ``settings.proj_root``.
        docs_extension (list[str]): List of file extensions to filter by
            (e.g., [".md", ".avap"])

    Returns:
        List of Paths to the fetched documents
    """
    fetcher = FileFetcher()
    docs_path = fetcher.fetch(
        dir=f"{settings.proj_root}/{docs_folder_path}", ext=docs_extension
    )
    return docs_path


def process_documents(docs_path: list[Path]) -> list[Chunk]:
    """
    Process documents by applying appropriate chefs and chunking strategies
    based on file type.

    Markdown files go through ``MarkdownChef`` and have their code/table
    elements fused into the surrounding chunks; ``.avap`` files are treated
    as plain text and token-chunked. Any other extension is skipped.

    Args:
        docs_path (list[Path]): List of Paths to the documents to be processed

    Returns:
        List of processed documents ready for ingestion
    """
    processed_docs = []
    # One shared HF tokenizer so markdown processing and token chunking
    # count tokens consistently with the embedding model.
    custom_tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    chef_md = MarkdownChef(tokenizer=custom_tokenizer)
    chef_txt = TextChef()
    chunker = TokenChunker(tokenizer=custom_tokenizer)

    for doc_path in docs_path:
        doc_extension = doc_path.suffix.lower()
        if doc_extension == ".md":
            processed_doc = chef_md.process(doc_path)
            fused_doc = _merge_markdown_document(processed_doc)
            processed_docs.extend(fused_doc.chunks)
        elif doc_extension == ".avap":
            processed_doc = chef_txt.process(doc_path)
            chunked_doc = chunker.chunk(processed_doc.content)
            processed_docs.extend(chunked_doc)
        else:
            # Make silently skipped files visible instead of dropping them
            # without a trace.
            logger.debug(f"Skipping unsupported file type: {doc_path}")
    return processed_docs


def ingest_documents(
    chunked_docs: list[Chunk],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> None:
    """
    Write chunked documents into an Elasticsearch index via chonkie's
    ``ElasticHandshake``, embedding them with the configured Ollama model.

    Args:
        chunked_docs (list[Chunk]): Chunks to ingest.
        es_index (str): Target Elasticsearch index name.
        es_request_timeout (int): Per-request timeout (seconds) for the client.
        es_max_retries (int): Maximum retry count for failed requests.
        es_retry_on_timeout (bool): Whether timed-out requests are retried.
        delete_es_index (bool): If True, drop the index first (when it exists)
            so ingestion starts from a clean slate.
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    es = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )

    if delete_es_index and es.indices.exists(index=es_index):
        logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
        es.indices.delete(index=es_index)

    handshake = ElasticHandshake(
        client=es,
        index_name=es_index,
        embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
    )
    logger.info(
        f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
    )
    handshake.write(chunked_docs)