# assistance-engine/scripts/pipelines/tasks/chunk.py
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
from chonkie import (
Chunk,
ElasticHandshake,
MarkdownChef,
TextChef,
TokenChunker,
)
from elasticsearch import Elasticsearch
from loguru import logger
from transformers import AutoTokenizer
from scripts.pipelines.tasks.embeddings import OllamaEmbeddings
from src.config import settings
def _get_text(element) -> str:
for attr in ("text", "content", "markdown"):
value = getattr(element, attr, None)
if isinstance(value, str):
return value
raise AttributeError(
f"Could not extract text from element of type {type(element).__name__}"
)
def _merge_markdown_document(doc):
    """Fuse code blocks and tables into the chunk that precedes them.

    Builds a position-sorted stream of the document's chunks, code blocks
    and tables, then folds each code/table element into the most recent
    chunk: its text is appended, the chunk's end index is extended, and the
    chunk's token count grows by the element's token count (0 if absent).
    Code/table elements appearing before the first chunk are dropped.

    Returns a deep copy of *doc* whose ``chunks`` are the merged chunks;
    ``code`` and ``tables`` still alias the original document's lists.
    """
    stream = sorted(
        [("chunk", c.start_index, c.end_index, c) for c in doc.chunks]
        + [("code", c.start_index, c.end_index, c) for c in doc.code]
        + [("table", t.start_index, t.end_index, t) for t in doc.tables],
        key=lambda item: (item[1], item[2]),
    )
    # Each group is [anchor chunk, text parts, end index, token count].
    groups = []
    for kind, _start, _end, element in stream:
        if kind == "chunk":
            groups.append(
                [element, [_get_text(element)], element.end_index, element.token_count]
            )
        elif groups:
            # Attach to the most recent chunk; ignore leading orphans.
            group = groups[-1]
            group[1].append(_get_text(element))
            group[2] = max(group[2], element.end_index)
            group[3] += getattr(element, "token_count", 0)
    merged_chunks = [
        replace(
            anchor,
            text="\n\n".join(part for part in parts if part),
            end_index=end_index,
            token_count=token_count,
        )
        for anchor, parts, end_index, token_count in groups
    ]
    fused = deepcopy(doc)
    fused.chunks = merged_chunks
    fused.code = doc.code
    fused.tables = doc.tables
    return fused
def process_documents(docs_path: list[Path], docs_extension: str) -> list[Chunk]:
    """Parse and chunk a set of documents.

    Args:
        docs_path: Paths of the documents to process.
        docs_extension: Extension selecting the pipeline. ``".md"`` files
            are parsed with ``MarkdownChef`` and their code/table elements
            are fused into the surrounding chunks; ``".avap"`` files are
            read as plain text and split with ``TokenChunker``.

    Returns:
        The flat list of resulting chunks. Any other extension yields an
        empty list (unchanged behavior of the original implementation).
    """
    chunked_docs: list[Chunk] = []
    # Both pipelines tokenize with the embedding model's tokenizer so chunk
    # sizes line up with the embedder's expectations.
    custom_tokenizer = AutoTokenizer.from_pretrained(settings.hf_emb_model_name)
    if docs_extension == ".md":
        chef = MarkdownChef(tokenizer=custom_tokenizer)
        # Single pass: the original buffered (doc, filename) tuples in an
        # intermediate list and never used the filename.
        for doc in docs_path:
            processed_doc = chef.process(doc)
            fused_doc = _merge_markdown_document(processed_doc)
            chunked_docs.extend(fused_doc.chunks)
    elif docs_extension == ".avap":
        chef = TextChef()
        chunker = TokenChunker(tokenizer=custom_tokenizer)
        for doc in docs_path:
            processed_doc = chef.process(doc)
            chunked_docs.extend(chunker.chunk(processed_doc.content))
    return chunked_docs
def ingest_documents(
    chunked_docs: list[Chunk],
    es_index: str,
    es_request_timeout: int,
    es_max_retries: int,
    es_retry_on_timeout: bool,
    delete_es_index: bool,
) -> None:
    """Embed and write chunks into an Elasticsearch index.

    Connects to the locally configured Elasticsearch, optionally drops a
    pre-existing index of the same name, then hands the chunks to Chonkie's
    ``ElasticHandshake`` using the configured Ollama embedding model.

    Args:
        chunked_docs: Chunks to ingest.
        es_index: Target index name.
        es_request_timeout: Per-request timeout (seconds) for the client.
        es_max_retries: Maximum retry attempts for failed requests.
        es_retry_on_timeout: Whether timed-out requests are retried.
        delete_es_index: If True, delete the index first when it exists.
    """
    logger.info(
        f"Instantiating Elasticsearch client with URL: {settings.elasticsearch_local_url}..."
    )
    client = Elasticsearch(
        hosts=settings.elasticsearch_local_url,
        request_timeout=es_request_timeout,
        max_retries=es_max_retries,
        retry_on_timeout=es_retry_on_timeout,
    )
    if delete_es_index and client.indices.exists(index=es_index):
        logger.info(f"Deleting existing Elasticsearch index: {es_index}...")
        client.indices.delete(index=es_index)
    embedder = OllamaEmbeddings(model=settings.ollama_emb_model_name)
    handshake = ElasticHandshake(
        client=client,
        index_name=es_index,
        embedding_model=embedder,
    )
    logger.info(
        f"Ingesting {len(chunked_docs)} chunks into Elasticsearch index: {es_index}..."
    )
    handshake.write(chunked_docs)