feat: add chunking methods and ingestion process for Elasticsearch
This commit is contained in:
parent
f2482cae19
commit
26603a9f45
|
|
@ -0,0 +1,396 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "0a8abbfa",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"True"
|
||||
]
|
||||
},
|
||||
"execution_count": 15,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import re\n",
|
||||
"import uuid\n",
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Iterable, List, Dict, Any, Callable, Protocol\n",
|
||||
"\n",
|
||||
"import torch\n",
|
||||
"import torch.nn.functional as F\n",
|
||||
"from loguru import logger\n",
|
||||
"from transformers import AutoTokenizer, AutoModel\n",
|
||||
"from elasticsearch import Elasticsearch\n",
|
||||
"from elasticsearch.helpers import bulk\n",
|
||||
"import nltk\n",
|
||||
"from nltk.tokenize import sent_tokenize\n",
|
||||
"nltk.download(\"punkt\", quiet=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "77f6c552",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Domain model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "c4cd2bc2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
@dataclass(frozen=True)
class Chunk:
    """Immutable unit of text produced by chunking a source document."""

    # Identifier of the parent document; shared by all chunks of that document.
    doc_id: str
    # 0-based position of this chunk within the parent document.
    chunk_id: int
    # The chunk's text content.
    text: str
    # Origin label for the document (e.g. ingestion source name).
    source: str
    # Document-level metadata carried through to the search index.
    metadata: Dict[str, Any]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "5cd700bd",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Utilities"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "84e834d9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def clean_text(text: str) -> str:
    """Normalize whitespace: NBSP -> space, collapse runs, trim both ends."""
    without_nbsp = text.replace("\u00a0", " ")
    return re.sub(r"\s+", " ", without_nbsp).strip()
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "4ebdc5f5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
class ChunkingStrategy(Protocol):
    """Structural interface for chunkers: full text in, list of chunk strings out."""

    def __call__(self, text: str, **kwargs) -> List[str]:
        ...
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "82209fc0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Chunking strategies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"id": "9f360449",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def fixed_size_token_chunking(
    text: str,
    embedding_model_name: str = None,
    chunk_size: int = 1200,
    overlap: int = 200,
) -> List[str]:
    """Split *text* into overlapping windows of at most ``chunk_size`` tokens.

    Tokenization uses the tokenizer of ``embedding_model_name`` so chunk
    boundaries align with the embedding model's token budget.

    Args:
        text: Raw document text to split.
        embedding_model_name: HF model name; defaults to the
            ``EMBEDDING_MODEL_NAME`` environment variable, read at call time.
        chunk_size: Maximum tokens per chunk.
        overlap: Tokens shared between consecutive chunks.

    Returns:
        Decoded chunk strings, in document order.

    Raises:
        ValueError: If ``overlap`` is negative, ``chunk_size <= overlap``,
            or no embedding model name can be resolved.
    """
    if overlap < 0:
        # A negative overlap would silently skip tokens between chunks.
        raise ValueError("overlap must be non-negative")
    if chunk_size <= overlap:
        raise ValueError("chunk_size must be greater than overlap")

    # Resolve the model name at call time. The original default of
    # os.getenv(...) was evaluated once at definition time and could silently
    # bind None, failing later inside from_pretrained with an obscure error.
    if embedding_model_name is None:
        embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME")
    if not embedding_model_name:
        raise ValueError(
            "embedding_model_name not provided and EMBEDDING_MODEL_NAME is unset"
        )

    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, use_fast=True)
    token_ids = tokenizer.encode(text, add_special_tokens=False)

    chunks: List[str] = []
    start = 0
    n = len(token_ids)

    while start < n:
        end = min(start + chunk_size, n)
        chunks.append(tokenizer.decode(token_ids[start:end], skip_special_tokens=True))

        if end == n:
            break

        # Step back by `overlap` tokens so consecutive chunks share context.
        start = end - overlap

    return chunks
|
||||
"\n",
|
||||
"\n",
|
||||
def semantic_chunking(
    text: str,
    embedding_model_name: str = None,
    similarity_threshold: float = 0.78,
    max_sentences_per_chunk: int = 12,
) -> List[str]:
    """Group consecutive sentences into chunks by embedding similarity.

    A new chunk starts whenever the cosine similarity between adjacent
    sentence embeddings drops below ``similarity_threshold`` or the current
    chunk reaches ``max_sentences_per_chunk`` sentences.

    Args:
        text: Raw document text to split.
        embedding_model_name: HF model name; defaults to the
            ``EMBEDDING_MODEL_NAME`` environment variable, read at call time.
        similarity_threshold: Cosine-similarity cutoff for starting a new chunk.
        max_sentences_per_chunk: Hard cap on sentences per chunk.

    Returns:
        Chunk strings (sentences joined by spaces), in document order;
        empty list for empty/whitespace-only input.

    Raises:
        ValueError: If no embedding model name can be resolved.
    """
    sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
    if not sentences:
        return []
    # Use the notebook's loguru logger instead of a bare print for
    # consistency with the rest of the ingestion pipeline.
    logger.info("Semantic chunking: {} sentences found", len(sentences))

    # Resolve at call time; a def-time os.getenv default could bind None.
    if embedding_model_name is None:
        embedding_model_name = os.getenv("EMBEDDING_MODEL_NAME")
    if not embedding_model_name:
        raise ValueError(
            "embedding_model_name not provided and EMBEDDING_MODEL_NAME is unset"
        )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE(review): tokenizer/model are re-loaded on every call; consider
    # caching at module level if this runs once per document.
    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
    model = AutoModel.from_pretrained(embedding_model_name).to(device)
    model.eval()

    with torch.no_grad():
        enc = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device)
        out = model(**enc)
        # Mean-pool token embeddings with the attention mask, then
        # L2-normalize so a dot product equals cosine similarity.
        mask = enc["attention_mask"].unsqueeze(-1)
        vecs = (out.last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        vecs = F.normalize(vecs, p=2, dim=1)

    chunks: List[List[str]] = [[sentences[0]]]

    for i in range(1, len(sentences)):
        sim = float((vecs[i - 1] * vecs[i]).sum())
        if sim < similarity_threshold or len(chunks[-1]) >= max_sentences_per_chunk:
            chunks.append([])
        chunks[-1].append(sentences[i])

    return [" ".join(chunk) for chunk in chunks if chunk]
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "bc7267d7",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# Maps strategy names (as accepted by build_chunks' `chunking_strategy`
# argument) to their implementations.
CHUNKING_REGISTRY: Dict[str, ChunkingStrategy] = {
    "fixed": fixed_size_token_chunking,
    "semantic": semantic_chunking,
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "87f2f70c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def build_chunks(
    doc_text: str,
    source: str,
    metadata: Dict[str, Any],
    chunking_strategy: str = "fixed",
    **chunking_kwargs,
) -> List[Chunk]:
    """Clean *doc_text*, split it with the named strategy, and wrap each
    non-empty piece in a Chunk carrying the shared document metadata.

    Raises:
        ValueError: If ``chunking_strategy`` is not registered.
    """
    if chunking_strategy not in CHUNKING_REGISTRY:
        raise ValueError(
            f"Unknown chunking strategy '{chunking_strategy}'. "
            f"Available: {list(CHUNKING_REGISTRY.keys())}"
        )
    chunking_fn = CHUNKING_REGISTRY[chunking_strategy]

    # Reuse a caller-supplied doc_id when present; otherwise mint one.
    doc_id = metadata.get("doc_id") or str(uuid.uuid4())

    normalized_text = clean_text(doc_text)
    parts = chunking_fn(normalized_text, **chunking_kwargs)

    result: List[Chunk] = []
    for index, part in enumerate(parts):
        if not part.strip():
            continue
        result.append(
            Chunk(
                doc_id=doc_id,
                chunk_id=index,
                text=part,
                source=source,
                # Fresh dict per chunk so chunks never alias each other's metadata.
                metadata={**metadata, "doc_id": doc_id},
            )
        )
    return result
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ba5649e9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Ingestion into Elasticsearch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ff03c689",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
def index_chunks(
    es: Elasticsearch,
    index_name: str,
    # String forward-reference: sentence_transformers is never imported in
    # this notebook, so a live `SentenceTransformer` annotation raised
    # NameError the moment this cell was executed.
    model: "SentenceTransformer",
    chunks: List[Chunk],
    batch_size: int = 64,
) -> None:
    """Embed chunk texts in batches and bulk-index them into Elasticsearch.

    Args:
        es: Connected Elasticsearch client.
        index_name: Target index (its mapping must match the embedding dims).
        model: SentenceTransformer used to embed chunk texts.
        chunks: Chunks to index.
        batch_size: Number of texts embedded per model.encode call.
    """
    def actions() -> Iterable[Dict[str, Any]]:
        # Embed in batches for speed
        for i in range(0, len(chunks), batch_size):
            batch = chunks[i:i + batch_size]
            texts = [c.text for c in batch]
            vectors = model.encode(texts, normalize_embeddings=True).tolist()

            for c, v in zip(batch, vectors):
                yield {
                    "_op_type": "index",
                    "_index": index_name,
                    # Deterministic id, so re-ingesting a document overwrites
                    # its chunks instead of duplicating them.
                    "_id": f"{c.doc_id}:{c.chunk_id}",
                    "_source": {
                        "doc_id": c.doc_id,
                        "chunk_id": c.chunk_id,
                        "text": c.text,
                        "source": c.source,
                        "metadata": c.metadata,
                        "embedding": v,
                    },
                }

    bulk(es, actions(), request_timeout=120)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7bcf0c87",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# NOTE(review): SentenceTransformer was used below but never imported
# anywhere in this notebook, so this cell raised NameError as written.
from sentence_transformers import SentenceTransformer

es_url = os.environ.get("ES_URL", "http://localhost:9200")
es_user = os.environ.get("ES_USER")
es_pass = os.environ.get("ES_PASS")
index_name = "my_docs_v1"

# Only send credentials when both user and password are configured.
es = Elasticsearch(
    es_url,
    basic_auth=(es_user, es_pass) if es_user and es_pass else None,
    request_timeout=60,
)

# Pick a model with dims matching your index mapping.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # 768 dims

# Example document
doc_text = """
This is a sample document. Replace this with your PDF/HTML extraction output.
"""
chunks = build_chunks(
    doc_text=doc_text,
    source="local_demo",
    metadata={"title": "Demo", "doc_id": "demo-001"},
)

index_chunks(es, index_name, model, chunks)
# Refresh so the documents are immediately visible to searches.
es.indices.refresh(index=index_name)
print(f"Indexed {len(chunks)} chunks into {index_name}.")
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"id": "b1ba8e85",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "b763a493689549a180ab815567520c0a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"Loading weights: 0%| | 0/310 [00:00<?, ?it/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
# Example document
doc_text = """
This is a sample document. Replace this with your PDF/HTML extraction output.
"""
# Re-chunk the same demo document, this time with the semantic strategy
# (embedding model taken from EMBEDDING_MODEL_NAME).
chunks = build_chunks(
    doc_text=doc_text,
    source="local_demo",
    metadata={"title": "Demo", "doc_id": "demo-001"},
    chunking_strategy="semantic"
)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "b2c52b38",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[Chunk(doc_id='demo-001', chunk_id=0, text='This is a sample document. Replace this with your PDF/HTML extraction output.', source='local_demo', metadata={'title': 'Demo', 'doc_id': 'demo-001'})]"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"chunks"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "daa57061",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "assistance-engine",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
Reference in New Issue