# assistance-engine/research/code_indexing/chunk.py

import json
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from typing import Any, Union
from lark import Lark, Tree
from chonkie import (
Chunk,
ElasticHandshake,
FileFetcher,
MarkdownChef,
TextChef,
TokenChunker,
MarkdownDocument
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer
from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings
# Maps Lark grammar rule names (the `data` field of parse-tree nodes produced
# by the AVAP grammar) to the command names that are stored as boolean
# metadata flags on indexed chunks (see _extract_command_metadata).
COMMAND_METADATA_NAMES = {
    # system
    "register_cmd": "registerEndpoint",
    "addvar_cmd": "addVar",
    "addparam_cmd": "addParam",
    "getlistlen_cmd": "getListLen",
    "getparamlist_cmd": "getQueryParamList",
    "addresult_cmd": "addResult",
    # async
    "go_stmt": "go",
    "gather_stmt": "gather",
    # connector
    "connector_instantiation": "avapConnector",
    # http
    "req_post_cmd": "RequestPost",
    "req_get_cmd": "RequestGet",
    # db
    "orm_direct": "ormDirect",
    "orm_check": "ormCheckTable",
    "orm_create": "ormCreateTable",
    "orm_select": "ormAccessSelect",
    "orm_insert": "ormAccessInsert",
    "orm_update": "ormAccessUpdate",
    # util
    "json_list_cmd": "json_list_ops",
    "crypto_cmd": "crypto_ops",
    "regex_cmd": "getRegex",
    "datetime_cmd": "getDateTime",
    "stamp_cmd": "timestamp_ops",
    "string_cmd": "randomString",
    "replace_cmd": "replace",
    # modularity
    "include_stmt": "include",
    "import_stmt": "import",
    # generic statements
    "assignment": "assignment",
    "call_stmt": "call",
    "return_stmt": "return",
    "if_stmt": "if",
    "loop_stmt": "startLoop",
    "try_stmt": "try",
    "function_decl": "function",
}
def _extract_command_metadata(ast: Tree | None) -> dict[str, bool]:
    """Collect a boolean flag for every known AVAP command used in *ast*.

    Returns an empty mapping when no parse tree is available; otherwise a
    ``{command_name: True}`` dict, keyed in sorted command-name order.
    """
    if ast is None:
        return {}
    found = {
        COMMAND_METADATA_NAMES[node.data]
        for node in ast.iter_subtrees()
        if node.data in COMMAND_METADATA_NAMES
    }
    return {name: True for name in sorted(found)}
def _get_text(element) -> str:
for attr in ("text", "content", "markdown"):
value = getattr(element, attr, None)
if isinstance(value, str):
return value
raise AttributeError(
f"Could not extract text from element of type {type(element).__name__}"
)
def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    """Fuse code blocks and tables into the text chunk that precedes them.

    MarkdownChef emits text chunks, fenced code blocks and tables as three
    separate element lists. This interleaves all three by document position
    and appends each code/table element's text onto the most recent text
    chunk, returning a copy of the document whose ``chunks`` list holds the
    merged chunks.

    NOTE(review): code/table elements that appear before the first text chunk
    are silently dropped (there is no chunk to attach them to) — confirm this
    is intended.
    """
    # Collect (kind, start, end, element) tuples so the three element lists
    # can be sorted into a single document-order stream.
    elements = []
    for chunk in processed_doc.chunks:
        elements.append(("chunk", chunk.start_index, chunk.end_index, chunk))
    for code in processed_doc.code:
        elements.append(("code", code.start_index, code.end_index, code))
    for table in processed_doc.tables:
        elements.append(("table", table.start_index, table.end_index, table))
    elements.sort(key=lambda item: (item[1], item[2]))
    merged_chunks = []
    # Accumulator state for the chunk currently being extended.
    current_chunk = None  # the Chunk whose text is being accumulated
    current_parts = []  # text fragments, later joined with blank lines
    current_end_index = None  # furthest end_index absorbed so far
    current_token_count = None  # running token total for this chunk
    def flush():
        # Emit the accumulated chunk (if any) as one merged Chunk and reset.
        nonlocal current_chunk, current_parts, current_end_index, current_token_count
        if current_chunk is None:
            return
        # Skip empty fragments so no stray blank-line separators are emitted.
        merged_text = "\n\n".join(part for part in current_parts if part)
        merged_chunks.append(
            replace(
                current_chunk,
                text=merged_text,
                end_index=current_end_index,
                token_count=current_token_count,
            )
        )
        current_chunk = None
        current_parts = []
        current_end_index = None
        current_token_count = None
    for kind, _, _, element in elements:
        if kind == "chunk":
            # A new text chunk starts: emit the previous one and re-seed.
            flush()
            current_chunk = element
            current_parts = [_get_text(element)]
            current_end_index = element.end_index
            current_token_count = element.token_count
            continue
        if current_chunk is None:
            # Code/table before any text chunk — nothing to merge into.
            continue
        current_parts.append(_get_text(element))
        current_end_index = max(current_end_index, element.end_index)
        current_token_count += getattr(element, "token_count", 0)
    flush()  # emit the trailing accumulated chunk
    fused_processed_doc = deepcopy(processed_doc)
    fused_processed_doc.chunks = merged_chunks
    # NOTE(review): these reassignments point the deep copy back at the
    # ORIGINAL code/table lists (shared references), undoing the deepcopy for
    # those two fields — confirm whether that sharing is intentional.
    fused_processed_doc.code = processed_doc.code
    fused_processed_doc.tables = processed_doc.tables
    return fused_processed_doc
class ElasticHandshakeWithMetadata(ElasticHandshake):
    """Extended ElasticHandshake that preserves chunk metadata in Elasticsearch."""

    def _create_bulk_actions(self, chunks: list[dict]) -> list[dict[str, Any]]:
        """Generate bulk actions including metadata.

        Each entry in *chunks* is a dict with a ``"chunk"`` (Chunk) and an
        optional ``"extra_metadata"`` mapping that is flattened into the
        document ``_source``.
        """
        texts = [entry["chunk"].text for entry in chunks]
        embeddings = self.embedding_model.embed_batch(texts)
        actions = []
        for position, (entry, embedding) in enumerate(zip(chunks, embeddings)):
            piece = entry["chunk"]
            source = {
                "text": piece.text,
                "embedding": embedding,
                "start_index": piece.start_index,
                "end_index": piece.end_index,
                "token_count": piece.token_count,
            }
            extra = entry.get("extra_metadata")
            if extra:
                # Flatten per-chunk metadata alongside the standard fields.
                source.update(extra)
            actions.append(
                {
                    "_index": self.index_name,
                    "_id": self._generate_id(position, piece),
                    "_source": source,
                }
            )
        return actions

    def write(self, chunks: Union[Chunk, list[Chunk]]) -> list[dict[str, Any]]:
        """Write the chunks to the Elasticsearch index using the bulk API."""
        # Use the bulk helper to efficiently write the documents.
        from elasticsearch.helpers import bulk

        if isinstance(chunks, Chunk):
            chunks = [chunks]
        actions = self._create_bulk_actions(chunks)
        success, errors = bulk(self.client, actions, raise_on_error=False)
        if errors:
            logger.warning(f"Encountered {len(errors)} errors during bulk indexing.")  # type: ignore
            # Log the first few errors for debugging.
            for i, error in enumerate(errors[:5]):  # type: ignore
                logger.error(f"Error {i + 1}: {error}")
        logger.info(f"Chonkie wrote {success} chunks to Elasticsearch index: {self.index_name}")
        return actions
def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Collect files under a project folder whose suffix matches one of the given extensions.

    Args:
        docs_folder_path (str): Folder to scan, relative to the project root
        docs_extension (list[str]): Accepted file extensions (e.g. [".md", ".avap"])

    Returns:
        List of Paths to the matching documents
    """
    folder = f"{settings.proj_root}/{docs_folder_path}"
    return FileFetcher().fetch(dir=folder, ext=docs_extension)
def process_documents(docs_path: list[Path]) -> list[dict[str, Any]]:
    """
    Chunk documents with the chef/chunker matching their file type.

    ``.md`` files go through MarkdownChef and have code/tables fused into the
    surrounding text chunks; ``.avap`` files go through TextChef, are parsed
    with the AVAP Lark grammar for command metadata, and are token-chunked.
    Other extensions are skipped.

    Args:
        docs_path: List of Paths to the documents to be processed.

    Returns:
        List of dicts with "chunk" (Chunk object) and "extra_metadata" (dict with file info).
    """
    tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    markdown_chef = MarkdownChef(tokenizer=tokenizer)
    text_chef = TextChef()
    token_chunker = TokenChunker(tokenizer=tokenizer)
    with open(settings.proj_root / "research/code_indexing/BNF/avap.lark", encoding="utf-8") as grammar:
        avap_parser = Lark(
            grammar.read(),
            parser="lalr",
            propagate_positions=True,
            start="program",
        )
    results: list[dict[str, Any]] = []
    for path in docs_path:
        suffix = path.suffix.lower()
        if suffix == ".md":
            fused = _merge_markdown_document(markdown_chef.process(path))
            chunks = fused.chunks
            metadata = {
                "file_type": "avap_docs",
                "filename": path.name,
            }
        elif suffix == ".avap":
            document = text_chef.process(path)
            try:
                tree = avap_parser.parse(document.content)
            except Exception as exc:
                # Parse failures are non-fatal: the file is still chunked,
                # just without command metadata.
                logger.error(f"Error parsing AVAP code in {path.name}: {exc}")
                tree = None
            chunks = token_chunker.chunk(document.content)
            metadata = {
                "file_type": "avap_code",
                "filename": path.name,
                **_extract_command_metadata(tree),
            }
        else:
            # Unsupported extension — skip silently.
            continue
        # Each chunk gets its own shallow copy of the per-file metadata.
        results.extend({"chunk": chunk, "extra_metadata": dict(metadata)} for chunk in chunks)
    return results
def ingest_documents(
    chunked_docs: list[dict[str, Chunk | dict[str, Any]]],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> list[dict[str, Any]]:
    """
    Ingest processed documents into an Elasticsearch index.

    Args:
        chunked_docs (list[dict[str, Any]]): List of dicts with "chunk" and "extra_metadata" keys
        es_index (str): Name of the Elasticsearch index to ingest into
        es_request_timeout (int): Timeout for Elasticsearch requests in seconds
        es_max_retries (int): Maximum number of retries for Elasticsearch requests
        es_retry_on_timeout (bool): Whether to retry on Elasticsearch request timeouts
        delete_es_index (bool): Whether to delete the existing Elasticsearch index before ingestion

    Returns:
        List of dicts with Elasticsearch response for each chunk
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    client = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )
    # Optionally start from a clean slate so re-runs don't mix stale documents.
    if delete_es_index and client.indices.exists(index=es_index):
        logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
        client.indices.delete(index=es_index)
    handshake = ElasticHandshakeWithMetadata(
        client=client,
        index_name=es_index,
        embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
    )
    logger.info(
        f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
    )
    return handshake.write(chunked_docs)
def export_documents(elasticsearch_chunks: list[dict[str, Any]], output_path: str) -> None:
    """
    Export processed documents to JSON files in the specified output folder.

    Args:
        elasticsearch_chunks (list[dict[str, Any]]): List of dicts with Elasticsearch response for each chunk
        output_path (str): Path to the file where the JSON will be saved, relative to the project root

    Returns:
        None
    """
    # Keep the str parameter intact; build the resolved Path under a new name.
    destination = settings.proj_root / output_path
    for chunk in elasticsearch_chunks:
        embedding = chunk["_source"]["embedding"]
        # numpy arrays are not JSON-serializable; embeddings that are already
        # plain lists (no .tolist) pass through unchanged instead of crashing.
        if hasattr(embedding, "tolist"):
            chunk["_source"]["embedding"] = embedding.tolist()
    with destination.open("w", encoding="utf-8") as f:
        json.dump(elasticsearch_chunks, f, ensure_ascii=False, indent=4)
    logger.info(f"Exported processed documents to {destination}")