feat: add chunking methods and ingestion process for Elasticsearch
This commit is contained in:
parent
f2482cae19
commit
26603a9f45
|
|
@ -0,0 +1,396 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "0a8abbfa",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"True"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 15,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"import re\n",
|
||||||
|
"import uuid\n",
|
||||||
|
"from dataclasses import dataclass\n",
|
||||||
|
"from typing import Iterable, List, Dict, Any, Callable, Protocol\n",
|
||||||
|
"\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"from loguru import logger\n",
|
||||||
|
"from transformers import AutoTokenizer, AutoModel\n",
|
||||||
|
"from elasticsearch import Elasticsearch\n",
|
||||||
|
"from elasticsearch.helpers import bulk\n",
|
||||||
|
"import nltk\n",
|
||||||
|
"from nltk.tokenize import sent_tokenize\n",
|
||||||
|
"nltk.download(\"punkt\", quiet=True)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "77f6c552",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Domain model"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "c4cd2bc2",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"@dataclass(frozen=True)\n",
|
||||||
|
"class Chunk:\n",
|
||||||
|
" doc_id: str\n",
|
||||||
|
" chunk_id: int\n",
|
||||||
|
" text: str\n",
|
||||||
|
" source: str\n",
|
||||||
|
" metadata: Dict[str, Any]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "5cd700bd",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Utilities"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "84e834d9",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def clean_text(text: str) -> str:\n",
|
||||||
|
" text = text.replace(\"\\u00a0\", \" \")\n",
|
||||||
|
" text = re.sub(r\"\\s+\", \" \", text).strip()\n",
|
||||||
|
" return text"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "4ebdc5f5",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class ChunkingStrategy(Protocol):\n",
|
||||||
|
" def __call__(self, text: str, **kwargs) -> List[str]:\n",
|
||||||
|
" ..."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "82209fc0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Chunking strategies"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 23,
|
||||||
|
"id": "9f360449",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def fixed_size_token_chunking(\n",
|
||||||
|
" text: str,\n",
|
||||||
|
" embedding_model_name: str = os.getenv(\"EMBEDDING_MODEL_NAME\"),\n",
|
||||||
|
" chunk_size: int = 1200,\n",
|
||||||
|
" overlap: int = 200,\n",
|
||||||
|
") -> List[str]:\n",
|
||||||
|
"\n",
|
||||||
|
" if chunk_size <= overlap:\n",
|
||||||
|
" raise ValueError(\"chunk_size must be greater than overlap\")\n",
|
||||||
|
"\n",
|
||||||
|
" tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, use_fast=True)\n",
|
||||||
|
" token_ids = tokenizer.encode(text, add_special_tokens=False)\n",
|
||||||
|
"\n",
|
||||||
|
" chunks: List[str] = []\n",
|
||||||
|
" start = 0\n",
|
||||||
|
" n = len(token_ids)\n",
|
||||||
|
"\n",
|
||||||
|
" while start < n:\n",
|
||||||
|
" end = min(start + chunk_size, n)\n",
|
||||||
|
" chunk_ids = token_ids[start:end]\n",
|
||||||
|
" chunks.append(tokenizer.decode(chunk_ids, skip_special_tokens=True))\n",
|
||||||
|
"\n",
|
||||||
|
" if end == n:\n",
|
||||||
|
" break\n",
|
||||||
|
"\n",
|
||||||
|
" start = end - overlap\n",
|
||||||
|
"\n",
|
||||||
|
" return chunks\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def semantic_chunking(\n",
|
||||||
|
" text: str,\n",
|
||||||
|
" embedding_model_name: str = os.getenv(\"EMBEDDING_MODEL_NAME\"),\n",
|
||||||
|
" similarity_threshold: float = 0.78,\n",
|
||||||
|
" max_sentences_per_chunk: int = 12,\n",
|
||||||
|
") -> List[str]:\n",
|
||||||
|
" sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]\n",
|
||||||
|
" if not sentences:\n",
|
||||||
|
" return []\n",
|
||||||
|
" print(f\"Semantic chunking: {len(sentences)} sentences found\")\n",
|
||||||
|
" device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
||||||
|
"\n",
|
||||||
|
" tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)\n",
|
||||||
|
" model = AutoModel.from_pretrained(embedding_model_name).to(device)\n",
|
||||||
|
" model.eval()\n",
|
||||||
|
"\n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" enc = tokenizer(sentences, padding=True, truncation=True, return_tensors=\"pt\").to(device)\n",
|
||||||
|
" out = model(**enc)\n",
|
||||||
|
" mask = enc[\"attention_mask\"].unsqueeze(-1)\n",
|
||||||
|
" vecs = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)\n",
|
||||||
|
" vecs = F.normalize(vecs, p=2, dim=1)\n",
|
||||||
|
"\n",
|
||||||
|
" chunks: List[List[str]] = [[sentences[0]]]\n",
|
||||||
|
"\n",
|
||||||
|
" for i in range(1, len(sentences)):\n",
|
||||||
|
" sim = float((vecs[i - 1] * vecs[i]).sum())\n",
|
||||||
|
" if sim < similarity_threshold or len(chunks[-1]) >= max_sentences_per_chunk:\n",
|
||||||
|
" chunks.append([])\n",
|
||||||
|
" chunks[-1].append(sentences[i])\n",
|
||||||
|
"\n",
|
||||||
|
" return [\" \".join(chunk) for chunk in chunks if chunk]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "bc7267d7",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"CHUNKING_REGISTRY: Dict[str, ChunkingStrategy] = {\n",
|
||||||
|
" \"fixed\": fixed_size_token_chunking,\n",
|
||||||
|
" \"semantic\": semantic_chunking,\n",
|
||||||
|
"}"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "87f2f70c",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def build_chunks(\n",
|
||||||
|
" doc_text: str,\n",
|
||||||
|
" source: str,\n",
|
||||||
|
" metadata: Dict[str, Any],\n",
|
||||||
|
" chunking_strategy: str = \"fixed\",\n",
|
||||||
|
" **chunking_kwargs,\n",
|
||||||
|
") -> List[Chunk]:\n",
|
||||||
|
"\n",
|
||||||
|
" if chunking_strategy not in CHUNKING_REGISTRY:\n",
|
||||||
|
" raise ValueError(\n",
|
||||||
|
" f\"Unknown chunking strategy '{chunking_strategy}'. \"\n",
|
||||||
|
" f\"Available: {list(CHUNKING_REGISTRY.keys())}\"\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" doc_id = metadata.get(\"doc_id\") or str(uuid.uuid4())\n",
|
||||||
|
" cleaned = clean_text(doc_text)\n",
|
||||||
|
"\n",
|
||||||
|
" chunking_fn = CHUNKING_REGISTRY[chunking_strategy]\n",
|
||||||
|
"\n",
|
||||||
|
" parts = chunking_fn(cleaned, **chunking_kwargs)\n",
|
||||||
|
"\n",
|
||||||
|
" return [\n",
|
||||||
|
" Chunk(\n",
|
||||||
|
" doc_id=doc_id,\n",
|
||||||
|
" chunk_id=i,\n",
|
||||||
|
" text=part,\n",
|
||||||
|
" source=source,\n",
|
||||||
|
" metadata={**metadata, \"doc_id\": doc_id},\n",
|
||||||
|
" )\n",
|
||||||
|
" for i, part in enumerate(parts)\n",
|
||||||
|
" if part.strip()\n",
|
||||||
|
" ]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ba5649e9",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Ingestion in elasticsearch"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ff03c689",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def index_chunks(\n",
|
||||||
|
" es: Elasticsearch,\n",
|
||||||
|
" index_name: str,\n",
|
||||||
|
" model: SentenceTransformer,\n",
|
||||||
|
" chunks: List[Chunk],\n",
|
||||||
|
" batch_size: int = 64,\n",
|
||||||
|
") -> None:\n",
|
||||||
|
" def actions() -> Iterable[Dict[str, Any]]:\n",
|
||||||
|
" # Embed in batches for speed\n",
|
||||||
|
" for i in range(0, len(chunks), batch_size):\n",
|
||||||
|
" batch = chunks[i:i + batch_size]\n",
|
||||||
|
" texts = [c.text for c in batch]\n",
|
||||||
|
" vectors = model.encode(texts, normalize_embeddings=True).tolist()\n",
|
||||||
|
"\n",
|
||||||
|
" for c, v in zip(batch, vectors):\n",
|
||||||
|
" yield {\n",
|
||||||
|
" \"_op_type\": \"index\",\n",
|
||||||
|
" \"_index\": index_name,\n",
|
||||||
|
" \"_id\": f\"{c.doc_id}:{c.chunk_id}\",\n",
|
||||||
|
" \"_source\": {\n",
|
||||||
|
" \"doc_id\": c.doc_id,\n",
|
||||||
|
" \"chunk_id\": c.chunk_id,\n",
|
||||||
|
" \"text\": c.text,\n",
|
||||||
|
" \"source\": c.source,\n",
|
||||||
|
" \"metadata\": c.metadata,\n",
|
||||||
|
" \"embedding\": v,\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" bulk(es, actions(), request_timeout=120)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "7bcf0c87",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"es_url = os.environ.get(\"ES_URL\", \"http://localhost:9200\")\n",
|
||||||
|
"es_user = os.environ.get(\"ES_USER\")\n",
|
||||||
|
"es_pass = os.environ.get(\"ES_PASS\")\n",
|
||||||
|
"index_name = \"my_docs_v1\"\n",
|
||||||
|
"\n",
|
||||||
|
"es = Elasticsearch(\n",
|
||||||
|
" es_url,\n",
|
||||||
|
" basic_auth=(es_user, es_pass) if es_user and es_pass else None,\n",
|
||||||
|
" request_timeout=60,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Pick a model with dims matching your index mapping.\n",
|
||||||
|
"model = SentenceTransformer(\"sentence-transformers/all-mpnet-base-v2\") # 768 dims\n",
|
||||||
|
"\n",
|
||||||
|
"# Example document\n",
|
||||||
|
"doc_text = \"\"\"\n",
|
||||||
|
"This is a sample document. Replace this with your PDF/HTML extraction output.\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"chunks = build_chunks(\n",
|
||||||
|
" doc_text=doc_text,\n",
|
||||||
|
" source=\"local_demo\",\n",
|
||||||
|
" metadata={\"title\": \"Demo\", \"doc_id\": \"demo-001\"},\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"index_chunks(es, index_name, model, chunks)\n",
|
||||||
|
"es.indices.refresh(index=index_name)\n",
|
||||||
|
"print(f\"Indexed {len(chunks)} chunks into {index_name}.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 24,
|
||||||
|
"id": "b1ba8e85",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"model_id": "b763a493689549a180ab815567520c0a",
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0
|
||||||
|
},
|
||||||
|
"text/plain": [
|
||||||
|
"Loading weights: 0%| | 0/310 [00:00<?, ?it/s]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Example document\n",
|
||||||
|
"doc_text = \"\"\"\n",
|
||||||
|
"This is a sample document. Replace this with your PDF/HTML extraction output.\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"chunks = build_chunks(\n",
|
||||||
|
" doc_text=doc_text,\n",
|
||||||
|
" source=\"local_demo\",\n",
|
||||||
|
" metadata={\"title\": \"Demo\", \"doc_id\": \"demo-001\"},\n",
|
||||||
|
" chunking_strategy=\"semantic\"\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 25,
|
||||||
|
"id": "b2c52b38",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"[Chunk(doc_id='demo-001', chunk_id=0, text='This is a sample document. Replace this with your PDF/HTML extraction output.', source='local_demo', metadata={'title': 'Demo', 'doc_id': 'demo-001'})]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 25,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"chunks"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "daa57061",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "assistance-engine",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.13"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue