417 lines
11 KiB
Plaintext
417 lines
11 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "0a8abbfa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import re\n",
|
|
"import uuid\n",
|
|
"from pathlib import Path\n",
|
|
"from typing import Any, Protocol\n",
|
|
"import markdown\n",
|
|
"from bs4 import BeautifulSoup\n",
|
|
"\n",
|
|
"from langchain_core.documents import Document\n",
|
|
"from langchain_elasticsearch import ElasticsearchStore\n",
|
|
"import torch\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from loguru import logger\n",
|
|
"from langchain_ollama import OllamaEmbeddings\n",
|
|
"from transformers import AutoTokenizer, AutoModel, AutoConfig\n",
|
|
"from elasticsearch import Elasticsearch\n",
|
|
"import nltk\n",
|
|
"from nltk.tokenize import sent_tokenize\n",
|
# Tokenizer data used by nltk's sent_tokenize. Recent NLTK releases load the
# pickle-free "punkt_tab" resource instead of "punkt", so both are fetched to
# keep sentence splitting working whichever NLTK version is installed.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)

# Runtime configuration, read from the environment. Any of these is None when
# the corresponding variable is not set.
ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_LOCAL_URL")
ELASTICSEARCH_INDEX = os.getenv("ELASTICSEARCH_INDEX")
HF_EMB_MODEL_NAME = os.getenv("HF_EMB_MODEL_NAME")
OLLAMA_URL = os.getenv("OLLAMA_URL")
OLLAMA_LOCAL_URL = os.getenv("OLLAMA_LOCAL_URL")
OLLAMA_MODEL_NAME = os.getenv("OLLAMA_MODEL_NAME")
OLLAMA_EMB_MODEL_NAME = os.getenv("OLLAMA_EMB_MODEL_NAME")
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "baa779f3",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Functions"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "148a4bb5",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Utilities"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "3c1e4649",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def clean_text(text: str) -> str:
    """Normalize whitespace in ``text``.

    Non-breaking spaces become regular spaces, every run of whitespace is
    collapsed to a single space, and leading/trailing whitespace is removed.
    """
    without_nbsp = text.replace("\u00a0", " ")
    collapsed = re.sub(r"\s+", " ", without_nbsp)
    return collapsed.strip()
|
|
"\n",
|
def markdown_to_text(md_text: str) -> str:
    """Convert Markdown to plain text.

    The Markdown is rendered to HTML first, then the HTML is parsed with
    BeautifulSoup's "html.parser" backend and only its text is returned.
    """
    rendered_html = markdown.markdown(md_text)
    return BeautifulSoup(rendered_html, "html.parser").get_text()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "acecbf08",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Chunking Strategies"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "8360441b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
class ChunkingStrategy(Protocol):
    """Structural type for a chunker.

    A chunker is any callable that takes the text to split, plus
    strategy-specific keyword options, and returns the chunks as strings.
    """

    def __call__(self, text: str, **kwargs) -> list[str]:
        """Split ``text`` into a list of chunk strings."""
        ...
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "bcb8862f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def fixed_size_token_chunking(
    text: str,
    embedding_model_name: str = HF_EMB_MODEL_NAME,
    chunk_size: int = 1200,
    overlap: int = 200,
) -> list[str]:
    """Split ``text`` into overlapping windows of at most ``chunk_size`` tokens.

    The text is tokenized (without special tokens) with the tokenizer loaded
    for ``embedding_model_name``; each window of token ids is decoded back to
    a string.

    Args:
        text: Text to split.
        embedding_model_name: Name passed to ``AutoTokenizer.from_pretrained``.
        chunk_size: Maximum number of tokens per chunk.
        overlap: Number of trailing tokens of a chunk repeated at the start of
            the next one.

    Returns:
        The decoded chunks in document order; an empty list for empty text.

    Raises:
        ValueError: If ``chunk_size`` is not greater than ``overlap`` or if
            ``overlap`` is negative.
    """
    if chunk_size <= overlap:
        raise ValueError("chunk_size must be greater than overlap")
    # A negative overlap would move each window start past the previous
    # window's end and silently drop the tokens in between.
    if overlap < 0:
        raise ValueError("overlap must be non-negative")

    # Nothing to tokenize, so avoid loading the tokenizer at all.
    if not text:
        return []

    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name, use_fast=True)
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    total = len(token_ids)

    # Window starts are `stride` tokens apart, so each window begins with the
    # last `overlap` tokens of the previous one.
    stride = chunk_size - overlap

    chunks: list[str] = []
    for start in range(0, total, stride):
        end = min(start + chunk_size, total)
        chunks.append(tokenizer.decode(token_ids[start:end], skip_special_tokens=True))
        # Stop once a window reaches the last token; a further window would
        # only repeat the overlapping tail.
        if end == total:
            break

    return chunks
|
|
"\n",
|
|
"\n",
|
def semantic_chunking(
    text: str,
    embedding_model_name: str = HF_EMB_MODEL_NAME,
    similarity_threshold: float = 0.6,
    max_sentences_per_chunk: int = 12,
    batch_size: int = 64,
) -> list[str]:
    """Group consecutive sentences into chunks by embedding similarity.

    The text is split with ``sent_tokenize``. Each sentence is embedded by
    mean-pooling the model's last hidden state over non-padding tokens and
    L2-normalizing the result, so the dot product of two embeddings is their
    cosine similarity. A new chunk starts whenever the similarity between a
    sentence and the previous one falls below ``similarity_threshold`` or the
    current chunk already holds ``max_sentences_per_chunk`` sentences.

    Args:
        text: Text to split.
        embedding_model_name: Name passed to ``AutoTokenizer.from_pretrained``
            and ``AutoModel.from_pretrained``.
        similarity_threshold: Minimum adjacent-sentence similarity needed to
            keep a sentence in the current chunk.
        max_sentences_per_chunk: Sentence count at which a chunk is closed.
        batch_size: Number of sentences embedded per forward pass; bounds peak
            memory for long documents.

    Returns:
        Chunks as strings, each the space-joined sentences of one group.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    sentences = [s.strip() for s in sent_tokenize(text) if s.strip()]
    if not sentences:
        return []
    logger.info(f"Semantic chunking: {len(sentences)} sentences found")

    # A single sentence has no neighbour to compare with, so no model is needed.
    if len(sentences) == 1:
        return sentences

    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(embedding_model_name)
    model = AutoModel.from_pretrained(embedding_model_name).to(device)
    model.eval()

    # Embed in batches so memory does not grow with the number of sentences.
    batch_vectors = []
    with torch.no_grad():
        for offset in range(0, len(sentences), batch_size):
            batch = sentences[offset:offset + batch_size]
            enc = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(device)
            hidden = model(**enc).last_hidden_state
            # Cast the integer mask to the hidden-state dtype so the clamp
            # below acts on floats rather than on an integer tensor.
            mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
            batch_vectors.append(F.normalize(pooled, p=2, dim=1))

    vecs = torch.cat(batch_vectors, dim=0)
    # Similarity of every adjacent pair in one vectorized op and one transfer
    # to the host, instead of one device sync per sentence.
    adjacent_sims = (vecs[:-1] * vecs[1:]).sum(dim=1).tolist()

    chunks: list[list[str]] = [[sentences[0]]]

    for i, sim in enumerate(adjacent_sims, start=1):
        logger.info(f"Similarity between sentence {i-1} and {i}: {sim:.4f}")
        if sim < similarity_threshold or len(chunks[-1]) >= max_sentences_per_chunk:
            chunks.append([])
        chunks[-1].append(sentences[i])

    return [" ".join(chunk) for chunk in chunks if chunk]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "e2a856fe",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Maps each strategy name accepted by build_chunks to the function that
# implements it.
CHUNKING_REGISTRY: dict[str, ChunkingStrategy] = dict(
    fixed=fixed_size_token_chunking,
    semantic=semantic_chunking,
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "35a937ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def build_chunks(
    doc_text: str,
    metadata: dict[str, Any],
    chunking_strategy: str = "fixed",
    **chunking_kwargs,
) -> list[Document]:
    """Chunk ``doc_text`` and wrap every non-blank chunk in a ``Document``.

    Args:
        doc_text: Text to split.
        metadata: Metadata copied onto every produced document.
        chunking_strategy: Key into ``CHUNKING_REGISTRY``.
        **chunking_kwargs: Forwarded unchanged to the selected chunker.

    Returns:
        One ``Document`` per chunk that contains non-whitespace text, each
        with a fresh UUID4 id and its own copy of ``metadata``.

    Raises:
        ValueError: If ``chunking_strategy`` is not a registered name.
    """
    if chunking_strategy not in CHUNKING_REGISTRY:
        raise ValueError(
            f"Unknown chunking strategy '{chunking_strategy}'. "
            f"Available: {list(CHUNKING_REGISTRY.keys())}"
        )

    splitter = CHUNKING_REGISTRY[chunking_strategy]
    pieces = splitter(doc_text, **chunking_kwargs)

    documents: list[Document] = []
    for piece in pieces:
        # Whitespace-only chunks carry nothing worth indexing.
        if not piece.strip():
            continue
        documents.append(
            Document(
                id=str(uuid.uuid4()),
                page_content=piece,
                metadata=dict(metadata),
            )
        )
    return documents
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "eb3f44f0",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\u001b[32m2026-03-09 14:45:22.477\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36m__main__\u001b[0m:\u001b[36msemantic_chunking\u001b[0m:\u001b[36m40\u001b[0m - \u001b[1mSemantic chunking: 1089 sentences found\u001b[0m\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
# Trial run of the "semantic" strategy; the resulting documents are kept in `a`.
# NOTE(review): `md_content` is not assigned in any cell visible in this
# notebook, so this call appears to rely on leftover kernel state — confirm
# where it should be loaded from before relying on Restart & Run All.
a = build_chunks(
    doc_text=md_content,
    metadata={"source": "test_doc"},
    chunking_strategy="semantic",
    similarity_threshold=0.8,
    max_sentences_per_chunk=20
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "39a10e99",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Build Chunks"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8e214f79",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
def build_chunks_from_folder(
    folder_path: str,
    pattern: str = "*.txt",
) -> list[Document]:
    """Load every matching file in a folder as one whitespace-normalized ``Document``.

    Args:
        folder_path: Directory to scan (not recursive).
        pattern: Glob pattern selecting the files to read.

    Returns:
        One ``Document`` per non-blank file, in sorted path order. Each has a
        fresh UUID4 id, the cleaned file text as content, and the file name
        under the ``"source"`` metadata key.

    Raises:
        ValueError: If ``folder_path`` does not exist or is not a directory.
    """
    folder = Path(folder_path)

    # is_dir() is already False for a path that does not exist.
    if not folder.is_dir():
        raise ValueError(f"Invalid folder path: {folder_path}")

    all_chunks: list[Document] = []

    # glob() yields entries in arbitrary order; sorting keeps the ingestion
    # order reproducible between runs and machines.
    for file_path in sorted(folder.glob(pattern)):
        # The pattern can also match directories, which read_text cannot open.
        if not file_path.is_file():
            continue

        doc_text = clean_text(file_path.read_text(encoding="utf-8"))

        # clean_text strips the text, so blank files come back as "".
        if not doc_text:
            continue

        all_chunks.append(
            Document(
                id=str(uuid.uuid4()),
                page_content=doc_text,
                metadata={"source": file_path.name},
            )
        )

    return all_chunks
|
|
"\n",
|
|
"\n",
|
# Folder holding the .txt documents to ingest. Set INGESTION_DOCS_DIR to point
# at another location; the fallback is the path this notebook originally used.
DOCS_FOLDER = os.getenv(
    "INGESTION_DOCS_DIR",
    "/home/acano/PycharmProjects/assistance-engine/ingestion/docs",
)

chunks = build_chunks_from_folder(folder_path=DOCS_FOLDER)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2a5dc98e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Each Document holds the full text of one file, so show the count and a short
# preview instead of echoing the entire list into the notebook output.
print(f"{len(chunks)} documents loaded")
chunks[:3]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "77f6c552",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Elasticsearch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "09ce3e29",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Elasticsearch client used by the cells below: request_timeout of 120, and
# requests that time out are retried, up to 5 retries.
es = Elasticsearch(
    ELASTICSEARCH_URL,
    retry_on_timeout=True,
    max_retries=5,
    request_timeout=120,
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d575c386",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Start from a clean slate: drop the target index when it is already present.
index_already_exists = es.indices.exists(index=ELASTICSEARCH_INDEX)
if index_already_exists:
    es.indices.delete(index=ELASTICSEARCH_INDEX)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "40ea0af8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Print the name of every index returned for the "*" wildcard.
all_indices = es.indices.get(index="*")
for index in all_indices:
    print(index)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4e091b39",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Embedding client backed by the local Ollama server; the bare name on the
# last line displays its configuration.
embeddings = OllamaEmbeddings(
    model=OLLAMA_EMB_MODEL_NAME,
    base_url=OLLAMA_LOCAL_URL,
)
embeddings
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1ed4c817",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Embed the loaded documents and store them in the target index, using the
# "COSINE" distance strategy.
db = ElasticsearchStore.from_documents(
    chunks,
    embeddings,
    index_name=ELASTICSEARCH_INDEX,
    distance_strategy="COSINE",
    client=es,
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "74c0a377",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
# Sanity check: fetch up to 10 documents from the index and print their ids
# and stored sources.
match_all_request = {"query": {"match_all": {}}, "size": 10}
response = es.search(index=ELASTICSEARCH_INDEX, body=match_all_request)

separator = "-" * 40
for hit in response["hits"]["hits"]:
    print("ID:", hit["_id"])
    print("Source:", hit["_source"])
    print(separator)
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "assistance-engine",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.11.13"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|