import json
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from typing import Any, Union

from lark import Lark, Tree

from chonkie import (
    Chunk,
    ElasticHandshake,
    FileFetcher,
    MarkdownChef,
    MarkdownDocument,
    TextChef,
    TokenChunker,
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer

from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings

# Maps Lark grammar rule names (the `data` attribute of parse-tree nodes) to
# the AVAP command each rule represents.  Chunks of AVAP code are tagged in
# Elasticsearch with the commands they actually use (see
# `_extract_command_metadata`), enabling command-filtered retrieval.
COMMAND_METADATA_NAMES = {
    # system
    "register_cmd": "registerEndpoint",
    "addvar_cmd": "addVar",
    "addparam_cmd": "addParam",
    "getlistlen_cmd": "getListLen",
    "getparamlist_cmd": "getQueryParamList",
    "addresult_cmd": "addResult",
    # async
    "go_stmt": "go",
    "gather_stmt": "gather",
    # connector
    "connector_instantiation": "avapConnector",
    # http
    "req_post_cmd": "RequestPost",
    "req_get_cmd": "RequestGet",
    # db
    "orm_direct": "ormDirect",
    "orm_check": "ormCheckTable",
    "orm_create": "ormCreateTable",
    "orm_select": "ormAccessSelect",
    "orm_insert": "ormAccessInsert",
    "orm_update": "ormAccessUpdate",
    # util
    "json_list_cmd": "json_list_ops",
    "crypto_cmd": "crypto_ops",
    "regex_cmd": "getRegex",
    "datetime_cmd": "getDateTime",
    "stamp_cmd": "timestamp_ops",
    "string_cmd": "randomString",
    "replace_cmd": "replace",
    # modularity
    "include_stmt": "include",
    "import_stmt": "import",
    # generic statements
    "assignment": "assignment",
    "call_stmt": "call",
    "return_stmt": "return",
    "if_stmt": "if",
    "loop_stmt": "startLoop",
    "try_stmt": "try",
    "function_decl": "function",
}


def _extract_command_metadata(ast: Tree | None) -> dict[str, bool]:
    """Collect the AVAP commands used in a parsed program.

    Args:
        ast: Root of the Lark parse tree, or ``None`` when parsing failed.

    Returns:
        Dict mapping each used command name to ``True``, in sorted order.
        Empty when ``ast`` is ``None`` or no known command appears.
    """
    if ast is None:
        return {}
    used_commands: set[str] = {
        COMMAND_METADATA_NAMES[subtree.data]
        for subtree in ast.iter_subtrees()
        if subtree.data in COMMAND_METADATA_NAMES
    }
    return {command_name: True for command_name in sorted(used_commands)}


def _get_text(element: Any) -> str:
    """Return the textual payload of a chunk-like element.

    Chonkie element types expose their text under different attribute names
    (``text``, ``content``, or ``markdown``); try each known one in turn.

    Raises:
        AttributeError: If none of the known attributes holds a string.
    """
    for attr in ("text", "content", "markdown"):
        value = getattr(element, attr, None)
        if isinstance(value, str):
            return value
    raise AttributeError(
        f"Could not extract text from element of type {type(element).__name__}"
    )


def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    """Fuse code blocks and tables back into their preceding prose chunks.

    ``MarkdownChef`` splits a document into prose chunks, code blocks, and
    tables.  For retrieval each prose chunk should carry the code/tables that
    follow it, so all elements are interleaved in document order and every
    code/table element is appended to the most recent prose chunk.

    Args:
        processed_doc: Document produced by ``MarkdownChef.process``.

    Returns:
        A copy of ``processed_doc`` whose ``chunks`` hold the merged text;
        ``code`` and ``tables`` still reference the original elements.
    """
    # Tag every element with its kind and document position so all three
    # element types can be walked in reading order.
    elements: list[tuple[str, int, int, Any]] = [
        ("chunk", c.start_index, c.end_index, c) for c in processed_doc.chunks
    ]
    elements += [
        ("code", c.start_index, c.end_index, c) for c in processed_doc.code
    ]
    elements += [
        ("table", t.start_index, t.end_index, t) for t in processed_doc.tables
    ]
    elements.sort(key=lambda item: (item[1], item[2]))

    merged_chunks: list[Any] = []
    current_chunk = None
    current_parts: list[str] = []
    current_end_index = None
    current_token_count = None

    def flush() -> None:
        """Emit the chunk currently being accumulated (if any), then reset."""
        nonlocal current_chunk, current_parts, current_end_index, current_token_count
        if current_chunk is None:
            return
        merged_text = "\n\n".join(part for part in current_parts if part)
        merged_chunks.append(
            replace(
                current_chunk,
                text=merged_text,
                end_index=current_end_index,
                token_count=current_token_count,
            )
        )
        current_chunk = None
        current_parts = []
        current_end_index = None
        current_token_count = None

    for kind, _, _, element in elements:
        if kind == "chunk":
            # A new prose chunk starts a new accumulation window.
            flush()
            current_chunk = element
            current_parts = [_get_text(element)]
            current_end_index = element.end_index
            current_token_count = element.token_count
            continue
        # Code/tables appearing before the first prose chunk have nothing to
        # attach to; they stay reachable via the document's code/tables lists.
        if current_chunk is None:
            continue
        current_parts.append(_get_text(element))
        current_end_index = max(current_end_index, element.end_index)
        # `or 0` guards against elements whose token_count attribute is None.
        current_token_count += getattr(element, "token_count", 0) or 0
    flush()

    fused_processed_doc = deepcopy(processed_doc)
    fused_processed_doc.chunks = merged_chunks
    fused_processed_doc.code = processed_doc.code
    fused_processed_doc.tables = processed_doc.tables
    return fused_processed_doc


class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch."""

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Args:
            chunks: Dicts with a ``"chunk"`` (Chunk) entry and an optional
                ``"extra_metadata"`` dict that is merged into ``_source``.

        Returns:
            One Elasticsearch bulk action per input chunk.
        """
        # Embed all chunk texts in a single batch call.
        embeddings = self.embedding_model.embed_batch(
            [chunk["chunk"].text for chunk in chunks]
        )
        actions = []
        for i, chunk in enumerate(chunks):
            source = {
                "text": chunk["chunk"].text,
                "embedding": embeddings[i],
                "start_index": chunk["chunk"].start_index,
                "end_index": chunk["chunk"].end_index,
                "token_count": chunk["chunk"].token_count,
            }
            # Include metadata if it exists
            if chunk.get("extra_metadata"):
                source.update(chunk["extra_metadata"])
            actions.append(
                {
                    "_index": self.index_name,
                    "_id": self._generate_id(i, chunk["chunk"]),
                    "_source": source,
                }
            )
        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API."""
        if isinstance(chunks, Chunk):
            chunks = [chunks]
        actions = self._create_bulk_actions(chunks)
        # Use the bulk helper to efficiently write the documents
        from elasticsearch.helpers import bulk

        # raise_on_error=False: collect per-document errors instead of raising.
        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Optionally log the first few errors for debugging
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions


def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Fetch files from a folder that match the specified extensions.

    Args:
        docs_folder_path (str): Path to the folder containing documents
        docs_extension (list[str]): List of file extensions to filter by (e.g., [".md", ".avap"])

    Returns:
        List of Paths to the fetched documents
    """
    fetcher = FileFetcher()
    docs_path = fetcher.fetch(dir=f"{settings.proj_root}/{docs_folder_path}", ext=docs_extension)
    return docs_path


def process_documents(docs_path: list[Path]) -> list[dict[str, Any]]:
    """
    Process documents by applying appropriate chefs and chunking strategies based on file type.

    Markdown files (``.md``) go through MarkdownChef and have their code/table
    elements fused into the surrounding prose chunks; AVAP source files
    (``.avap``) are token-chunked and tagged with the commands found by
    parsing them against the AVAP Lark grammar.  Other extensions are skipped.

    Args:
        docs_path: List of Paths to the documents to be processed.

    Returns:
        List of dicts with "chunk" (Chunk object) and "extra_metadata" (dict with file info).
    """
    processed_docs: list[dict[str, Any]] = []
    custom_tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    chef_md = MarkdownChef(tokenizer=custom_tokenizer)
    chef_txt = TextChef()
    chunker = TokenChunker(tokenizer=custom_tokenizer)

    # The grammar is loaded once and shared across all .avap files.
    with open(
        settings.proj_root / "research/code_indexing/BNF/avap.lark", encoding="utf-8"
    ) as grammar:
        lark_parser = Lark(
            grammar.read(),
            parser="lalr",
            propagate_positions=True,
            start="program",
        )

    for doc_path in docs_path:
        doc_extension = doc_path.suffix.lower()
        if doc_extension == ".md":
            processed_doc = chef_md.process(doc_path)
            fused_doc = _merge_markdown_document(processed_doc)
            chunked_doc = fused_doc.chunks
            specific_metadata = {
                "file_type": "avap_docs",
                "filename": doc_path.name,
            }
        elif doc_extension == ".avap":
            processed_doc = chef_txt.process(doc_path)
            try:
                ast = lark_parser.parse(processed_doc.content)
            except Exception as e:
                # A file that fails to parse is still indexed, just without
                # command metadata.
                logger.error(f"Error parsing AVAP code in {doc_path.name}: {e}")
                ast = None
            chunked_doc = chunker.chunk(processed_doc.content)
            specific_metadata = {
                "file_type": "avap_code",
                "filename": doc_path.name,
                **_extract_command_metadata(ast),
            }
        else:
            continue

        for chunk in chunked_doc:
            processed_docs.append(
                {
                    "chunk": chunk,
                    # Copy per chunk so downstream mutation of one chunk's
                    # metadata cannot leak into its siblings.
                    "extra_metadata": {**specific_metadata},
                }
            )
    return processed_docs


def ingest_documents(
    chunked_docs: list[dict[str, Chunk | dict[str, Any]]],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> list[dict[str, Any]]:
    """
    Ingest processed documents into an Elasticsearch index.

    Args:
        chunked_docs (list[dict[str, Any]]): List of dicts with "chunk" and "metadata" keys
        es_index (str): Name of the Elasticsearch index to ingest into
        es_request_timeout (int): Timeout for Elasticsearch requests in seconds
        es_max_retries (int): Maximum number of retries for Elasticsearch requests
        es_retry_on_timeout (bool): Whether to retry on Elasticsearch request timeouts
        delete_es_index (bool): Whether to delete the existing Elasticsearch index before ingestion

    Returns:
        List of dicts with Elasticsearch response for each chunk
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    es = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )

    if delete_es_index and es.indices.exists(index=es_index):
        logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
        es.indices.delete(index=es_index)

    logger.info(f"Using {settings.ollama_emb_model_name} for embeddings...")
    handshake = ElasticHandshakeWithMetadata(
        client=es,
        index_name=es_index,
        embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
    )

    logger.info(
        f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
    )
    elasticsearch_chunks = handshake.write(chunked_docs)
    return elasticsearch_chunks


def export_documents(elasticsearch_chunks: list[dict[str, Any]], output_path: str) -> None:
    """
    Export processed documents to JSON files in the specified output folder.

    Args:
        elasticsearch_chunks (list[dict[str, Any]]): List of dicts with Elasticsearch response for each chunk
        output_path (str): Path to the file where the JSON will be saved

    Returns:
        None
    """
    # Keep the str parameter intact; resolve against the project root locally.
    resolved_path = settings.proj_root / output_path
    for chunk in elasticsearch_chunks:
        embedding = chunk["_source"]["embedding"]
        # For JSON serialization: numpy arrays need .tolist(); plain lists
        # pass through untouched (the original crashed on list embeddings).
        if hasattr(embedding, "tolist"):
            chunk["_source"]["embedding"] = embedding.tolist()
    with resolved_path.open("w", encoding="utf-8") as f:
        json.dump(elasticsearch_chunks, f, ensure_ascii=False, indent=4)
    logger.info(f"Exported processed documents to {resolved_path}")