{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f9c2a71",
   "metadata": {},
   "source": [
    "# Document chunking and Elasticsearch ingestion\n",
    "\n",
    "Split raw document text into chunks (fixed-size token windows or embedding-based semantic groups), embed each chunk with a sentence-transformers model, and bulk-index the chunks together with their vectors into Elasticsearch.\n",
    "\n",
    "**Configuration (environment variables):** `EMBEDDING_MODEL_NAME` (Hugging Face model id used for chunk tokenization, semantic chunking and embedding; the demo falls back to `sentence-transformers/all-mpnet-base-v2`), `ES_URL`, `ES_USER`, `ES_PASS`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a8abbfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import uuid\n",
    "from dataclasses import dataclass, field\n",
    "from functools import lru_cache\n",
    "from typing import Any, Callable, Dict, Iterable, List, Optional, Protocol\n",
    "\n",
    "import nltk\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from elasticsearch import Elasticsearch\n",
    "from elasticsearch.helpers import bulk\n",
    "from loguru import logger\n",
    "from nltk.tokenize import sent_tokenize\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from transformers import AutoModel, AutoTokenizer\n",
    "\n",
    "# `sent_tokenize` needs the Punkt sentence-splitter data. Recent NLTK releases\n",
    "# load \"punkt_tab\", older ones load \"punkt\", so fetch both.\n",
    "for _nltk_resource in (\"punkt\", \"punkt_tab\"):\n",
    "    nltk.download(_nltk_resource, quiet=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77f6c552",
   "metadata": {},
   "source": [
    "### Domain model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4cd2bc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass(frozen=True)\n",
    "class Chunk:\n",
    "    \"\"\"One indexable piece of a source document.\n",
    "\n",
    "    Attributes:\n",
    "        doc_id: Identifier shared by all chunks of the same document.\n",
    "        chunk_id: 0-based position of the chunk within its document.\n",
    "        text: Chunk text that is embedded and indexed.\n",
    "        source: Label describing where the document came from.\n",
    "        metadata: Document metadata copied onto every chunk.\n",
    "    \"\"\"\n",
    "\n",
    "    doc_id: str\n",
    "    chunk_id: int\n",
    "    text: str\n",
    "    source: str\n",
    "    # A dict is unhashable; excluding it from the generated __hash__ keeps\n",
    "    # frozen Chunk instances usable in sets and as dict keys.\n",
    "    metadata: Dict[str, Any] = field(hash=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5cd700bd",
   "metadata": {},
   "source": [
    "### Utilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84e834d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_text(text: str) -> str:\n",
    "    \"\"\"Normalise whitespace.\n",
    "\n",
    "    Non-breaking spaces become regular spaces, every whitespace run\n",
    "    (including newlines) collapses to one space, and both ends are stripped.\n",
    "    \"\"\"\n",
    "    text = text.replace(\"\\u00a0\", \" \")\n",
    "    return re.sub(r\"\\s+\", \" \", text).strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ebdc5f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ChunkingStrategy(Protocol):\n",
    "    \"\"\"Callable that splits cleaned document text into chunk strings.\n",
    "\n",
    "    Strategy-specific options (model name, sizes, thresholds) arrive as\n",
    "    keyword arguments forwarded by `build_chunks`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __call__(self, text: str, **kwargs) -> List[str]:\n",
    "        ..."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82209fc0",
   "metadata": {},
   "source": [
    "### Chunking strategies\n",
    "\n",
    "- **fixed** — sliding windows of `chunk_size` tokens; consecutive windows share `overlap` tokens.\n",
    "- **semantic** — consecutive sentences are grouped until the cosine similarity between neighbouring sentences drops below `similarity_threshold`, or the group reaches `max_sentences_per_chunk`.\n",
    "\n",
    "Tokenizers and models are cached per process, so repeated calls do not reload weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d1e4b7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "@lru_cache(maxsize=4)\n",
    "def _load_tokenizer(model_name: str):\n",
    "    \"\"\"Load the Hugging Face tokenizer for `model_name` once per process.\"\"\"\n",
    "    return AutoTokenizer.from_pretrained(model_name, use_fast=True)\n",
    "\n",
    "\n",
    "@lru_cache(maxsize=2)\n",
    "def _load_encoder(model_name: str, device: str):\n",
    "    \"\"\"Load a Hugging Face encoder on `device` in eval mode, once per process.\"\"\"\n",
    "    model = AutoModel.from_pretrained(model_name).to(device)\n",
    "    model.eval()\n",
    "    return model\n",
    "\n",
    "\n",
    "def _resolve_model_name(embedding_model_name: Optional[str]) -> str:\n",
    "    \"\"\"Return `embedding_model_name`, falling back to $EMBEDDING_MODEL_NAME.\n",
    "\n",
    "    The environment is read at call time (not when a function is defined),\n",
    "    so setting the variable after the definitions have run still works.\n",
    "    \"\"\"\n",
    "    name = embedding_model_name or os.getenv(\"EMBEDDING_MODEL_NAME\")\n",
    "    if not name:\n",
    "        raise ValueError(\n",
    "            \"No embedding model configured: pass embedding_model_name or set \"\n",
    "            \"the EMBEDDING_MODEL_NAME environment variable.\"\n",
    "        )\n",
    "    return name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f360449",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fixed_size_token_chunking(\n",
    "    text: str,\n",
    "    embedding_model_name: Optional[str] = None,\n",
    "    chunk_size: int = 1200,\n",
    "    overlap: int = 200,\n",
    ") -> List[str]:\n",
    "    \"\"\"Split `text` into overlapping windows of tokenizer tokens.\n",
    "\n",
    "    Args:\n",
    "        text: Text to split (normally already normalised by `clean_text`).\n",
    "        embedding_model_name: Hugging Face model id whose tokenizer defines the\n",
    "            tokens. Defaults to $EMBEDDING_MODEL_NAME, read at call time.\n",
    "        chunk_size: Maximum tokens per chunk. Keep it at or below the embedding\n",
    "            model's input limit: sentence-transformers truncates longer inputs,\n",
    "            so the tail of an oversized chunk would never be embedded.\n",
    "        overlap: Tokens shared by consecutive chunks.\n",
    "\n",
    "    Returns:\n",
    "        Decoded chunk strings; an empty list if `text` has no tokens.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `overlap` is negative, `chunk_size <= overlap`, or no\n",
    "            model name can be resolved.\n",
    "    \"\"\"\n",
    "    if overlap < 0:\n",
    "        raise ValueError(\"overlap must be non-negative\")\n",
    "    if chunk_size <= overlap:\n",
    "        raise ValueError(\"chunk_size must be greater than overlap\")\n",
    "\n",
    "    tokenizer = _load_tokenizer(_resolve_model_name(embedding_model_name))\n",
    "    token_ids = tokenizer.encode(text, add_special_tokens=False)\n",
    "    n = len(token_ids)\n",
    "\n",
    "    chunks: List[str] = []\n",
    "    for start in range(0, n, chunk_size - overlap):\n",
    "        end = min(start + chunk_size, n)\n",
    "        chunks.append(tokenizer.decode(token_ids[start:end], skip_special_tokens=True))\n",
    "        if end == n:\n",
    "            # This window reaches the end of the text; another one would only\n",
    "            # repeat tokens that are already covered.\n",
    "            break\n",
    "    return chunks\n",
    "\n",
    "\n",
    "def semantic_chunking(\n",
    "    text: str,\n",
    "    embedding_model_name: Optional[str] = None,\n",
    "    similarity_threshold: float = 0.78,\n",
    "    max_sentences_per_chunk: int = 12,\n",
    "    batch_size: int = 64,\n",
    ") -> List[str]:\n",
    "    \"\"\"Group consecutive sentences into chunks by embedding similarity.\n",
    "\n",
    "    Sentences are embedded by mean-pooling the model's last hidden state over\n",
    "    non-padding tokens and L2-normalising the result. A new chunk starts when\n",
    "    the cosine similarity between a sentence and the previous one is below\n",
    "    `similarity_threshold`, or when the current chunk already holds\n",
    "    `max_sentences_per_chunk` sentences.\n",
    "\n",
    "    Args:\n",
    "        text: Text to split into sentences with NLTK `sent_tokenize`.\n",
    "        embedding_model_name: Hugging Face model id for the sentence\n",
    "            embeddings. Defaults to $EMBEDDING_MODEL_NAME, read at call time.\n",
    "        similarity_threshold: Cosine-similarity cut-off between neighbours.\n",
    "        max_sentences_per_chunk: Maximum sentences per chunk.\n",
    "        batch_size: Sentences encoded per forward pass; bounds peak memory on\n",
    "            long documents.\n",
    "\n",
    "    Returns:\n",
    "        Chunk strings (sentences joined with a single space); an empty list if\n",
    "        `text` contains no sentences.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `batch_size < 1` or no model name can be resolved.\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError(\"batch_size must be at least 1\")\n",
    "\n",
    "    sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]\n",
    "    if not sentences:\n",
    "        return []\n",
    "    logger.debug(f\"Semantic chunking: {len(sentences)} sentences found\")\n",
    "\n",
    "    model_name = _resolve_model_name(embedding_model_name)\n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    tokenizer = _load_tokenizer(model_name)\n",
    "    model = _load_encoder(model_name, device)\n",
    "\n",
    "    pooled_batches = []\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, len(sentences), batch_size):\n",
    "            enc = tokenizer(\n",
    "                sentences[i:i + batch_size],\n",
    "                padding=True,\n",
    "                truncation=True,\n",
    "                return_tensors=\"pt\",\n",
    "            ).to(device)\n",
    "            hidden = model(**enc).last_hidden_state\n",
    "            # Float mask so padding positions contribute nothing to the mean.\n",
    "            mask = enc[\"attention_mask\"].unsqueeze(-1).float()\n",
    "            pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)\n",
    "            pooled_batches.append(F.normalize(pooled, p=2, dim=1))\n",
    "    vecs = torch.cat(pooled_batches)\n",
    "\n",
    "    # Row-wise dot products of unit vectors = cosine similarity between each\n",
    "    # sentence and the one before it.\n",
    "    neighbour_sims = (vecs[1:] * vecs[:-1]).sum(dim=1).tolist()\n",
    "\n",
    "    chunks: List[List[str]] = [[sentences[0]]]\n",
    "    for sentence, sim in zip(sentences[1:], neighbour_sims):\n",
    "        if sim < similarity_threshold or len(chunks[-1]) >= max_sentences_per_chunk:\n",
    "            chunks.append([])\n",
    "        chunks[-1].append(sentence)\n",
    "\n",
    "    return [\" \".join(chunk) for chunk in chunks]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc7267d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strategy name -> callable; `build_chunks` selects a strategy by this key.\n",
    "CHUNKING_REGISTRY: Dict[str, ChunkingStrategy] = {\n",
    "    \"fixed\": fixed_size_token_chunking,\n",
    "    \"semantic\": semantic_chunking,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87f2f70c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_chunks(\n",
    "    doc_text: str,\n",
    "    source: str,\n",
    "    metadata: Dict[str, Any],\n",
    "    chunking_strategy: str = \"fixed\",\n",
    "    **chunking_kwargs,\n",
    ") -> List[Chunk]:\n",
    "    \"\"\"Clean a document, split it with a registered strategy, and wrap the\n",
    "    non-empty pieces as `Chunk` objects.\n",
    "\n",
    "    Args:\n",
    "        doc_text: Raw document text; whitespace is normalised by `clean_text`.\n",
    "        source: Origin label stored on every chunk.\n",
    "        metadata: Document metadata copied onto every chunk. Its \"doc_id\" is\n",
    "            used when truthy; otherwise a random UUID4 string is generated.\n",
    "        chunking_strategy: Key into CHUNKING_REGISTRY.\n",
    "        **chunking_kwargs: Forwarded unchanged to the chunking strategy.\n",
    "\n",
    "    Returns:\n",
    "        Chunks with contiguous chunk_ids 0..n-1 in document order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `chunking_strategy` is not registered.\n",
    "    \"\"\"\n",
    "    if chunking_strategy not in CHUNKING_REGISTRY:\n",
    "        raise ValueError(\n",
    "            f\"Unknown chunking strategy '{chunking_strategy}'. \"\n",
    "            f\"Available: {list(CHUNKING_REGISTRY.keys())}\"\n",
    "        )\n",
    "\n",
    "    doc_id = metadata.get(\"doc_id\") or str(uuid.uuid4())\n",
    "    chunking_fn = CHUNKING_REGISTRY[chunking_strategy]\n",
    "    parts = chunking_fn(clean_text(doc_text), **chunking_kwargs)\n",
    "\n",
    "    # Drop blank parts *before* numbering so chunk_ids stay contiguous; they\n",
    "    # become part of the Elasticsearch document id \"<doc_id>:<chunk_id>\".\n",
    "    non_empty = [part for part in parts if part.strip()]\n",
    "    return [\n",
    "        Chunk(\n",
    "            doc_id=doc_id,\n",
    "            chunk_id=i,\n",
    "            text=part,\n",
    "            source=source,\n",
    "            metadata={**metadata, \"doc_id\": doc_id},\n",
    "        )\n",
    "        for i, part in enumerate(non_empty)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba5649e9",
   "metadata": {},
   "source": [
    "### Ingestion into Elasticsearch\n",
    "\n",
    "Chunks are embedded in batches and bulk-indexed under the deterministic id `<doc_id>:<chunk_id>`, so re-indexing a document overwrites its existing chunks in place.\n",
    "\n",
    "**Caveat:** if a re-ingested document yields fewer chunks than before, the surplus old chunks remain in the index. Delete them by `doc_id` first if that matters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5c3a8d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ensure_index(es: Elasticsearch, index_name: str, dims: int) -> None:\n",
    "    \"\"\"Create `index_name` with an explicit mapping unless it already exists.\n",
    "\n",
    "    Without this mapping, dynamic mapping would store `embedding` as a plain\n",
    "    multi-valued `float` field instead of a `dense_vector`, which kNN search\n",
    "    cannot use. An existing index is left untouched, so its vector `dims`\n",
    "    must already match the embedding model.\n",
    "\n",
    "    Args:\n",
    "        es: Connected Elasticsearch client.\n",
    "        index_name: Index to create if missing.\n",
    "        dims: Embedding dimensionality (768 for all-mpnet-base-v2).\n",
    "    \"\"\"\n",
    "    if es.indices.exists(index=index_name):\n",
    "        return\n",
    "    es.indices.create(\n",
    "        index=index_name,\n",
    "        mappings={\n",
    "            \"properties\": {\n",
    "                \"doc_id\": {\"type\": \"keyword\"},\n",
    "                \"chunk_id\": {\"type\": \"integer\"},\n",
    "                \"text\": {\"type\": \"text\"},\n",
    "                \"source\": {\"type\": \"keyword\"},\n",
    "                \"metadata\": {\"type\": \"object\"},\n",
    "                \"embedding\": {\n",
    "                    \"type\": \"dense_vector\",\n",
    "                    \"dims\": dims,\n",
    "                    \"index\": True,\n",
    "                    \"similarity\": \"cosine\",\n",
    "                },\n",
    "            }\n",
    "        },\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff03c689",
   "metadata": {},
   "outputs": [],
   "source": [
    "def index_chunks(\n",
    "    es: Elasticsearch,\n",
    "    index_name: str,\n",
    "    model: SentenceTransformer,\n",
    "    chunks: List[Chunk],\n",
    "    batch_size: int = 64,\n",
    ") -> None:\n",
    "    \"\"\"Embed `chunks` with `model` and bulk-index them into `index_name`.\n",
    "\n",
    "    Each document gets the id \"<doc_id>:<chunk_id>\" and an L2-normalised\n",
    "    `embedding`, so re-indexing a chunk overwrites its previous version.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `batch_size < 1`.\n",
    "        elasticsearch.helpers.BulkIndexError: if any document fails to index\n",
    "            (`bulk` raises on errors by default).\n",
    "    \"\"\"\n",
    "    if batch_size < 1:\n",
    "        raise ValueError(\"batch_size must be at least 1\")\n",
    "\n",
    "    def actions() -> Iterable[Dict[str, Any]]:\n",
    "        # Embed lazily, one batch at a time, so vectors for all chunks are\n",
    "        # never held in memory at once.\n",
    "        for i in range(0, len(chunks), batch_size):\n",
    "            batch = chunks[i:i + batch_size]\n",
    "            vectors = model.encode(\n",
    "                [c.text for c in batch],\n",
    "                batch_size=batch_size,\n",
    "                normalize_embeddings=True,\n",
    "            ).tolist()\n",
    "\n",
    "            for c, v in zip(batch, vectors):\n",
    "                yield {\n",
    "                    \"_op_type\": \"index\",\n",
    "                    \"_index\": index_name,\n",
    "                    \"_id\": f\"{c.doc_id}:{c.chunk_id}\",\n",
    "                    \"_source\": {\n",
    "                        \"doc_id\": c.doc_id,\n",
    "                        \"chunk_id\": c.chunk_id,\n",
    "                        \"text\": c.text,\n",
    "                        \"source\": c.source,\n",
    "                        \"metadata\": c.metadata,\n",
    "                        \"embedding\": v,\n",
    "                    },\n",
    "                }\n",
    "\n",
    "    indexed, _ = bulk(es, actions(), request_timeout=120)\n",
    "    logger.info(f\"Bulk-indexed {indexed} chunks into {index_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8f1a2c3",
   "metadata": {},
   "source": [
    "### Demo: index a sample document\n",
    "\n",
    "One model id drives both chunk tokenization and embedding. `chunk_size` is capped at the model's input limit, so no chunk is truncated before it is embedded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7bcf0c87",
   "metadata": {},
   "outputs": [],
   "source": [
    "es_url = os.environ.get(\"ES_URL\", \"http://localhost:9200\")\n",
    "es_user = os.environ.get(\"ES_USER\")\n",
    "es_pass = os.environ.get(\"ES_PASS\")\n",
    "index_name = \"my_docs_v1\"\n",
    "embedding_model_name = os.environ.get(\n",
    "    \"EMBEDDING_MODEL_NAME\", \"sentence-transformers/all-mpnet-base-v2\"\n",
    ")\n",
    "\n",
    "es = Elasticsearch(\n",
    "    es_url,\n",
    "    basic_auth=(es_user, es_pass) if es_user and es_pass else None,\n",
    "    request_timeout=60,\n",
    ")\n",
    "\n",
    "model = SentenceTransformer(embedding_model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7d2f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Largest chunk the embedder accepts without truncation: its sequence limit\n",
    "# minus the special tokens the tokenizer adds around every input.\n",
    "max_chunk_tokens = model.max_seq_length - model.tokenizer.num_special_tokens_to_add()\n",
    "\n",
    "# Example document\n",
    "doc_text = \"\"\"\n",
    "This is a sample document. Replace this with your PDF/HTML extraction output.\n",
    "\"\"\"\n",
    "chunks = build_chunks(\n",
    "    doc_text=doc_text,\n",
    "    source=\"local_demo\",\n",
    "    metadata={\"title\": \"Demo\", \"doc_id\": \"demo-001\"},\n",
    "    chunking_strategy=\"fixed\",\n",
    "    embedding_model_name=embedding_model_name,\n",
    "    chunk_size=max_chunk_tokens,\n",
    "    overlap=64,\n",
    ")\n",
    "\n",
    "ensure_index(es, index_name, dims=model.get_sentence_embedding_dimension())\n",
    "index_chunks(es, index_name, model, chunks)\n",
    "es.indices.refresh(index=index_name)\n",
    "print(f\"Indexed {len(chunks)} chunks into {index_name}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b1ba8e85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b763a493689549a180ab815567520c0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading weights: 0%| | 0/310 [00:00