# assistance-engine/scratches/acano/elasticsearch_ingestion_v2.py
# Scratch script: chunk documents and ingest them into an Elasticsearch index.

import typer
import logging
from loguru import logger
from elasticsearch import Elasticsearch
from chonkie import MarkdownChef, FileFetcher, ElasticHandshake
from transformers import AutoTokenizer
from src.config import settings
from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from scripts.pipelines.tasks.chunk import merge_markdown_document
# Typer CLI application; `elasticsearch_ingestion` is registered on it as a command via `@app.command()`.
app = typer.Typer()
def get_processing_and_chunking_config(docs_extension: str, chunk_size: int,
chunk_threshold: float | None,
chunk_similarity_window: int| None,
chunk_skip_window: int | None) -> tuple[str, dict, str, dict]:
"""
Check the file extension and return the appropriate processing and chunking strategies and their kwargs.
Args:
docs_extension (str): The file extension of the documents to be ingested.
chunk_size (int): The size of the chunks to be created.
chunk_threshold (float, optional): The threshold for semantic chunking. Required if docs_extension is .md.
chunk_similarity_window (int, optional): The similarity window for semantic chunking
chunk_skip_window (int, optional): The skip window for semantic chunking.
Returns:
tuple[str, dict, str, dict]: A tuple containing the processing strategy, its kwargs, the chunking strategy, and its kwargs.
"""
if docs_extension == ".md":
process_type = "markdown"
custom_tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
process_kwargs = {"tokenizer": custom_tokenizer}
# process_type = "text"
# process_kwargs = {}
chunk_strat = "semantic"
chunk_kwargs = {"embedding_model": settings.hf_emb_model_name, "threshold": chunk_threshold, "chunk_size": chunk_size,
"similarity_window": chunk_similarity_window, "skip_window": chunk_skip_window}
elif docs_extension == ".avap":
process_type = "text"
process_kwargs = {}
chunk_strat = "recursive" # Once we have the BNF and uploaded to tree-sitter, we can use code (?)
chunk_kwargs = {"chunk_size": chunk_size}
return process_type, process_kwargs, chunk_strat, chunk_kwargs
@app.command()
def elasticsearch_ingestion(
    docs_folder_path: str = "docs/LRM",
    docs_extension: str = ".md",
    es_index: str = "avap-docs-test-v3",
    es_request_timeout: int = 120,
    es_max_retries: int = 5,
    es_retry_on_timeout: bool = True,
    delete_es_index: bool = True,
    chunk_size: int = 2048,
    chunk_threshold: float | None = 0.5,
    chunk_similarity_window: int | None = 3,
    chunk_skip_window: int | None = 1
):
    """
    Fetch, process and chunk documents, then write the chunks to Elasticsearch.

    Args:
        docs_folder_path: Folder of documents, relative to settings.proj_root.
        docs_extension: Extension of the files to ingest (e.g. ".md", ".avap").
        es_index: Name of the target Elasticsearch index.
        es_request_timeout: Request timeout passed to the Elasticsearch client.
        es_max_retries: Maximum retries passed to the Elasticsearch client.
        es_retry_on_timeout: Whether the client retries on timeouts.
        delete_es_index: If True, delete es_index first when it already exists.
        chunk_size: Forwarded to get_processing_and_chunking_config.
        chunk_threshold: Forwarded to get_processing_and_chunking_config.
        chunk_similarity_window: Forwarded to get_processing_and_chunking_config.
        chunk_skip_window: Forwarded to get_processing_and_chunking_config.
    """
    # Resolve the per-extension configuration before touching Elasticsearch,
    # so an invalid configuration cannot fail only after the index was deleted.
    (process_type,
     process_kwargs,
     chunk_strat,
     chunk_kwargs) = get_processing_and_chunking_config(
        docs_extension, chunk_size, chunk_threshold, chunk_similarity_window, chunk_skip_window
    )
    # Reuse the tokenizer the config step already loaded (".md") instead of
    # loading the same Hugging Face model a second time.
    custom_tokenizer = process_kwargs.get("tokenizer")
    if custom_tokenizer is None:
        custom_tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)

    logger.info(f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}...")
    es = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )
    try:
        if delete_es_index and es.indices.exists(index=es_index):
            logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
            es.indices.delete(index=es_index)
        logger.info("Starting Elasticsearch ingestion pipeline...")

        logger.info(f"Fetching files from {docs_folder_path}...")
        fetcher = FileFetcher()
        # Restrict the fetch to the requested extension; otherwise every file in
        # the folder would be sent through the Markdown chef.
        docs = fetcher.fetch(dir=f"{settings.proj_root}/{docs_folder_path}", ext=[docs_extension])

        logger.info(f"Processing documents with process_type: {process_type}...")
        # NOTE(review): MarkdownChef is used for every extension; process_type is
        # only logged, not used to select the chef.
        chef = MarkdownChef(tokenizer=custom_tokenizer)
        processed_docs = [chef.process(doc) for doc in docs]

        logger.info(f"Chunking documents with chunk_strat: {chunk_strat}...")
        # NOTE(review): chunk_strat / chunk_kwargs are not applied here; chunks
        # come from merge_markdown_document.
        fused_docs = [merge_markdown_document(processed_doc) for processed_doc in processed_docs]

        logger.info(f"Ingesting chunks in Elasticsearch index: {es_index}...")
        handshake = ElasticHandshake(
            client=es,
            index_name=es_index,
            embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
        )
        for fused_doc in fused_docs:
            handshake.write(fused_doc.chunks)
        logger.info(f"Finished ingesting in {es_index}.")
    finally:
        # Release the client's connections even when a pipeline step fails.
        es.close()
if __name__ == "__main__":
    # Configure the stdlib root logger; this affects code that logs through
    # `logging` rather than through loguru.
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=logging.INFO,
    )
    try:
        app()
    except Exception as error:
        # Record the traceback through loguru, then propagate so the process
        # still terminates with the original exception.
        logger.exception(error)
        raise