250 lines
7.1 KiB
Plaintext
250 lines
7.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "66cbbaf8",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Libraries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "c01c19dc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from typing import Dict, List, Union\n",
|
|
"import numpy as np\n",
|
|
"from langchain_ollama import OllamaEmbeddings\n",
|
|
"from beir.datasets.data_loader import GenericDataLoader\n",
|
|
"from beir.retrieval.search.dense import DenseRetrievalExactSearch\n",
|
|
"from beir.retrieval.evaluation import EvaluateRetrieval\n",
|
|
"from beir import util"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ac011c1c",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Utils: BEIR adapter for Ollama embeddings"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "b83e7900",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class BEIROllamaEmbeddings:\n",
"    \"\"\"\n",
"    Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.\n",
"\n",
"    BEIR dense retrievers call ``encode_queries`` and ``encode_corpus`` on the\n",
"    model object. Both methods delegate to an Ollama server through\n",
"    ``OllamaEmbeddings.embed_documents`` and return a float32 matrix with one\n",
"    row per input, in input order.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        base_url: str,\n",
"        model: str,\n",
"        batch_size: int = 64,\n",
"        query_instruction: Union[str, None] = None,\n",
"    ) -> None:\n",
"        \"\"\"\n",
"        Args:\n",
"            base_url: URL of the Ollama server, e.g. \"http://localhost:11434\".\n",
"            model: Name of the Ollama model used to compute embeddings.\n",
"            batch_size: Number of texts sent to Ollama per request (>= 1).\n",
"            query_instruction: Optional task description for instruction-aware\n",
"                embedding models such as Qwen3-Embedding. When set, each query\n",
"                is sent as \"Instruct: <query_instruction>\", a newline, then\n",
"                \"Query:<query>\"; documents are never prefixed. The default\n",
"                (None) sends queries unchanged.\n",
"\n",
"        Raises:\n",
"            ValueError: if batch_size < 1 (0 would crash range() on the first\n",
"                encode call; a negative value would silently embed nothing).\n",
"        \"\"\"\n",
"        if batch_size < 1:\n",
"            raise ValueError(f\"batch_size must be >= 1, got {batch_size}\")\n",
"        self.batch_size = batch_size\n",
"        self.query_instruction = query_instruction\n",
"        self.embeddings = OllamaEmbeddings(\n",
"            base_url=base_url,\n",
"            model=model,\n",
"        )\n",
"\n",
"    def _batch_embed(self, texts: List[str]) -> np.ndarray:\n",
"        \"\"\"\n",
"        Embed ``texts`` in chunks of ``self.batch_size``.\n",
"\n",
"        Returns:\n",
"            float32 array of shape (len(texts), dim); shape (0, 0) for an\n",
"            empty input, so the result is always 2-D.\n",
"\n",
"        Raises:\n",
"            RuntimeError: if Ollama returns a different number of vectors than\n",
"                texts sent. Rows are matched to inputs by position, so a short\n",
"                batch would silently shift every later embedding onto the\n",
"                wrong id.\n",
"        \"\"\"\n",
"        if not texts:\n",
"            return np.empty((0, 0), dtype=np.float32)\n",
"\n",
"        vectors: List[List[float]] = []\n",
"        for start in range(0, len(texts), self.batch_size):\n",
"            batch = texts[start : start + self.batch_size]\n",
"            batch_vectors = self.embeddings.embed_documents(batch)\n",
"            if len(batch_vectors) != len(batch):\n",
"                raise RuntimeError(\n",
"                    f\"Ollama returned {len(batch_vectors)} embeddings \"\n",
"                    f\"for a batch of {len(batch)} texts\"\n",
"                )\n",
"            vectors.extend(batch_vectors)\n",
"\n",
"        return np.asarray(vectors, dtype=np.float32)\n",
"\n",
"    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:\n",
"        \"\"\"\n",
"        BEIR query encoder.\n",
"\n",
"        Extra keyword arguments from the retriever are accepted and ignored;\n",
"        batching is governed by ``self.batch_size``.\n",
"        \"\"\"\n",
"        if self.query_instruction:\n",
"            queries = [\n",
"                f\"Instruct: {self.query_instruction}\\nQuery:{query}\"\n",
"                for query in queries\n",
"            ]\n",
"        return self._batch_embed(queries)\n",
"\n",
"    def encode_corpus(\n",
"        self,\n",
"        corpus: Union[List[Dict[str, str]], Dict[str, Dict[str, str]]],\n",
"        **kwargs,\n",
"    ) -> np.ndarray:\n",
"        \"\"\"\n",
"        BEIR corpus encoder.\n",
"\n",
"        Accepts documents as a list, or as a dict whose values are used in\n",
"        iteration order. Each document is embedded as its title and text\n",
"        joined by a newline, or as the text alone when the title is empty.\n",
"        Missing or None \"title\"/\"text\" fields count as empty strings. Extra\n",
"        keyword arguments are accepted and ignored.\n",
"        \"\"\"\n",
"        if isinstance(corpus, dict):\n",
"            corpus = list(corpus.values())\n",
"\n",
"        texts = []\n",
"        for doc in corpus:\n",
"            title = (doc.get(\"title\") or \"\").strip()\n",
"            text = (doc.get(\"text\") or \"\").strip()\n",
"            texts.append(f\"{title}\\n{text}\" if title else text)\n",
"\n",
"        return self._batch_embed(texts)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c9528fb6",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "230aae25",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "1915c67ec20f4806b30b48eff9a132e2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
" 0%| | 0/5183 [00:00<?, ?it/s]"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# Fetch the BEIR benchmark archive and unzip it under ./datasets.\n",
"dataset = \"scifact\"\n",
"url = f\"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip\"\n",
"data_path = util.download_and_unzip(url, out_dir=\"datasets\")\n",
"\n",
"# Load the test split: documents, queries and relevance judgements (qrels).\n",
"data_loader = GenericDataLoader(data_path)\n",
"corpus, queries, qrels = data_loader.load(split=\"test\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "13050d31",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Evaluate qwen3-0.6B-emb:latest on SciFact (test split)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "514540af",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"NDCG: {'NDCG@1': 0.56333, 'NDCG@3': 0.64367, 'NDCG@5': 0.66577, 'NDCG@10': 0.68551, 'NDCG@100': 0.71285}\n",
|
|
"MAP: {'MAP@1': 0.52994, 'MAP@3': 0.6117, 'MAP@5': 0.62815, 'MAP@10': 0.6383, 'MAP@100': 0.64466}\n",
|
|
"Recall: {'Recall@1': 0.52994, 'Recall@3': 0.7035, 'Recall@5': 0.75967, 'Recall@10': 0.81611, 'Recall@100': 0.94}\n",
|
|
"Precision: {'P@1': 0.56333, 'P@3': 0.25889, 'P@5': 0.17067, 'P@10': 0.093, 'P@100': 0.0107}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Embedding model served by the local Ollama instance.\n",
"model_qwen_3 = BEIROllamaEmbeddings(\n",
"    base_url=\"http://localhost:11434\",\n",
"    model=\"qwen3-0.6B-emb:latest\",\n",
"    batch_size=64,\n",
")\n",
"\n",
"# Exact (non-approximate) dense search, ranked by cosine similarity.\n",
"retriever_qwen_3 = DenseRetrievalExactSearch(model_qwen_3, batch_size=64)\n",
"evaluator_qwen_3 = EvaluateRetrieval(retriever_qwen_3, score_function=\"cos_sim\")\n",
"\n",
"results_qwen_3 = evaluator_qwen_3.retrieve(corpus, queries)\n",
"ndcg_qwen_3, _map_qwen_3, recall_qwen_3, precision_qwen_3 = evaluator_qwen_3.evaluate(\n",
"    qrels, results_qwen_3, [1, 3, 5, 10, 100]\n",
")\n",
"\n",
"for label, scores in (\n",
"    (\"NDCG:\", ndcg_qwen_3),\n",
"    (\"MAP:\", _map_qwen_3),\n",
"    (\"Recall:\", recall_qwen_3),\n",
"    (\"Precision:\", precision_qwen_3),\n",
"):\n",
"    print(label, scores)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "c4e643ca",
|
|
"metadata": {},
|
|
"source": [
|
|
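"# Evaluate qwen2.5:1.5b on SciFact (test split)\n",
"\n",
"`qwen2.5:1.5b` is a general-purpose chat model rather than a dedicated embedding model. This likely explains why its NDCG@10 (about 0.046) is far below that of `qwen3-0.6B-emb:latest` (about 0.686)."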
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "5ced1c25",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"NDCG: {'NDCG@1': 0.02333, 'NDCG@3': 0.03498, 'NDCG@5': 0.0404, 'NDCG@10': 0.04619, 'NDCG@100': 0.07768}\n",
|
|
"MAP: {'MAP@1': 0.02083, 'MAP@3': 0.03083, 'MAP@5': 0.03375, 'MAP@10': 0.03632, 'MAP@100': 0.04123}\n",
|
|
"Recall: {'Recall@1': 0.02083, 'Recall@3': 0.04417, 'Recall@5': 0.0575, 'Recall@10': 0.07417, 'Recall@100': 0.23144}\n",
|
|
"Precision: {'P@1': 0.02333, 'P@3': 0.01556, 'P@5': 0.01267, 'P@10': 0.00833, 'P@100': 0.00277}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# NOTE(review): qwen2.5:1.5b appears to be a general-purpose chat LLM rather than a\n",
"# dedicated embedding model; the near-zero scores recorded for this cell fit that.\n",
"model_qwen_2 = BEIROllamaEmbeddings(\n",
"    base_url=\"http://localhost:11434\",\n",
"    model=\"qwen2.5:1.5b\",\n",
"    batch_size=64,\n",
")\n",
"\n",
"# Exact (non-approximate) dense search, ranked by cosine similarity.\n",
"retriever_qwen_2 = DenseRetrievalExactSearch(model_qwen_2, batch_size=64)\n",
"evaluator_qwen_2 = EvaluateRetrieval(retriever_qwen_2, score_function=\"cos_sim\")\n",
"\n",
"results_qwen_2 = evaluator_qwen_2.retrieve(corpus, queries)\n",
"ndcg_qwen_2, _map_qwen_2, recall_qwen_2, precision_qwen_2 = evaluator_qwen_2.evaluate(\n",
"    qrels, results_qwen_2, [1, 3, 5, 10, 100]\n",
")\n",
"\n",
"for label, scores in (\n",
"    (\"NDCG:\", ndcg_qwen_2),\n",
"    (\"MAP:\", _map_qwen_2),\n",
"    (\"Recall:\", recall_qwen_2),\n",
"    (\"Precision:\", precision_qwen_2),\n",
"):\n",
"    print(label, scores)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "assistance-engine",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|