from copy import deepcopy
from dataclasses import replace
from pathlib import Path

from chonkie import (
    Chunk,
    ElasticHandshake,
    FileFetcher,
    MarkdownChef,
    MarkdownDocument,
    TextChef,
    TokenChunker,
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer

from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings


def _get_text(element) -> str:
|
|
for attr in ("text", "content", "markdown"):
|
|
value = getattr(element, attr, None)
|
|
if isinstance(value, str):
|
|
return value
|
|
raise AttributeError(
|
|
f"Could not extract text from element of type {type(element).__name__}"
|
|
)
|
|
|
|
|
|
def _merge_markdown_document(processed_doc: MarkdownDocument) -> MarkdownDocument:
    """Fuse code blocks and tables back into their preceding text chunks.

    Chunks, code blocks, and tables are interleaved by their character
    offsets; every code/table element is appended to the chunk that starts
    before it. The fused chunk's ``end_index`` grows to cover the attached
    elements and its ``token_count`` accumulates their token counts.

    Args:
        processed_doc (MarkdownDocument): Document produced by ``MarkdownChef``.

    Returns:
        A copy of *processed_doc* whose ``chunks`` are the fused chunks.

    Notes:
        Code/table elements that appear before the first chunk are dropped.
    """
    # Tag every element so one sorted pass can interleave the three streams.
    # Listing chunks first keeps them ahead of code/tables that share the
    # same (start, end) key, because the sort is stable.
    tagged = [("chunk", c.start_index, c.end_index, c) for c in processed_doc.chunks]
    tagged += [("code", c.start_index, c.end_index, c) for c in processed_doc.code]
    tagged += [("table", t.start_index, t.end_index, t) for t in processed_doc.tables]
    tagged.sort(key=lambda entry: (entry[1], entry[2]))

    # Group pass: each chunk opens a group; code/tables attach to the most
    # recently opened group. Anything before the first chunk is discarded.
    groups = []
    for kind, _, _, element in tagged:
        if kind == "chunk":
            groups.append((element, []))
        elif groups:
            groups[-1][1].append(element)

    # Merge pass: build one fused chunk per group.
    merged_chunks = []
    for base_chunk, extras in groups:
        parts = [_get_text(base_chunk)] + [_get_text(extra) for extra in extras]
        end_index = base_chunk.end_index
        token_count = base_chunk.token_count
        for extra in extras:
            end_index = max(end_index, extra.end_index)
            # Elements without a token_count contribute nothing.
            token_count += getattr(extra, "token_count", 0)
        merged_chunks.append(
            replace(
                base_chunk,
                text="\n\n".join(part for part in parts if part),
                end_index=end_index,
                token_count=token_count,
            )
        )

    fused_processed_doc = deepcopy(processed_doc)
    fused_processed_doc.chunks = merged_chunks
    # NOTE(review): these re-point the copy's code/tables at the ORIGINAL
    # lists (sharing mutable state despite the deepcopy) — confirm intended.
    fused_processed_doc.code = processed_doc.code
    fused_processed_doc.tables = processed_doc.tables

    return fused_processed_doc


def fetch_documents(docs_folder_path: str, docs_extension: list[str]) -> list[Path]:
    """
    Fetch files from a folder that match the specified extensions.

    Args:
        docs_folder_path (str): Path to the folder containing documents,
            relative to the project root.
        docs_extension (list[str]): List of file extensions to filter by
            (e.g., [".md", ".avap"])

    Returns:
        List of Paths to the fetched documents
    """
    # FileFetcher walks the directory and keeps only matching extensions.
    return FileFetcher().fetch(
        dir=f"{settings.proj_root}/{docs_folder_path}",
        ext=docs_extension,
    )


def process_documents(docs_path: list[Path]) -> list[Chunk]:
    """
    Process documents by applying appropriate chefs and chunking strategies
    based on file type.

    Markdown files go through ``MarkdownChef`` and have their code/tables
    fused back into the surrounding chunks; ``.avap`` files are treated as
    plain text and token-chunked. Files with any other extension are
    silently skipped.

    Args:
        docs_path (list[Path]): List of Paths to the documents to be processed

    Returns:
        List of processed documents ready for ingestion
    """
    # Heavy objects are built once, outside the per-document loop.
    tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    markdown_chef = MarkdownChef(tokenizer=tokenizer)
    text_chef = TextChef()
    token_chunker = TokenChunker(tokenizer=tokenizer)

    all_chunks = []
    for path in docs_path:
        suffix = path.suffix.lower()

        if suffix == ".md":
            fused_doc = _merge_markdown_document(markdown_chef.process(path))
            all_chunks.extend(fused_doc.chunks)
        elif suffix == ".avap":
            # .avap files carry no markdown structure; plain token chunking.
            all_chunks.extend(token_chunker.chunk(text_chef.process(path).content))
        # NOTE(review): other extensions fall through unprocessed — confirm
        # this silent skip is intended.

    return all_chunks


def ingest_documents(
    chunked_docs: list[Chunk],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> None:
    """
    Ingest processed documents into an Elasticsearch index.

    Args:
        chunked_docs (list[Chunk]): List of processed document chunks to be ingested
        es_index (str): Name of the Elasticsearch index to ingest into
        es_request_timeout (int): Timeout for Elasticsearch requests in seconds
        es_max_retries (int): Maximum number of retries for Elasticsearch requests
        es_retry_on_timeout (bool): Whether to retry on Elasticsearch request timeouts
        delete_es_index (bool): Whether to delete the existing Elasticsearch
            index before ingestion

    Returns:
        None
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    client = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )

    # Optional clean slate: only delete when the index actually exists.
    if delete_es_index and client.indices.exists(index=es_index):
        logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
        client.indices.delete(index=es_index)

    # ElasticHandshake embeds each chunk with Ollama and writes it to ES.
    handshake = ElasticHandshake(
        client=client,
        index_name=es_index,
        embedding_model=OllamaEmbeddings(model=settings.ollama_emb_model_name),
    )

    logger.info(
        f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
    )
    handshake.write(chunked_docs)