evaluation on acano

This commit is contained in:
pseco 2026-02-24 17:03:33 +01:00
parent f6a907911d
commit b01a76e71d
1 changed files with 133 additions and 0 deletions

View File

@@ -479,6 +479,139 @@
" \n",
"print(results)"
]
},
{
"cell_type": "markdown",
"id": "07f9f5e5",
"metadata": {},
"source": [
"# Evaluate"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec2362c4",
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, Iterable\n",
" \n",
"import numpy as np\n",
" \n",
"import mteb\n",
"from mteb.types import Array\n",
"from mteb.models import SearchEncoderWrapper\n",
" \n",
" \n",
"def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
"    \"\"\"Scale vectors to unit L2 length along the last axis.\n",
"\n",
"    Args:\n",
"        x: A single vector or a stack of row vectors.\n",
"        eps: Lower bound applied to every norm before dividing, so an\n",
"            all-zero vector stays all-zero instead of turning into NaN.\n",
"\n",
"    Returns:\n",
"        An array with the same shape as ``x``.\n",
"    \"\"\"\n",
"    # axis=-1 equals the previous axis=1 for 2-D input and also accepts 1-D vectors.\n",
"    norms = np.linalg.norm(x, axis=-1, keepdims=True)\n",
"    return x / np.clip(norms, eps, None)\n",
" \n",
" \n",
"def _to_text_list(batch: dict[str, Any]) -> list[str]:\n",
"    \"\"\"Flatten one MTEB batch into the list of strings to embed.\n",
"\n",
"    MTEB batched inputs can be:\n",
"    - TextInput: {\"text\": [..]}\n",
"    - CorpusInput: {\"title\": [..], \"body\": [..], \"text\": [..]}\n",
"    - QueryInput: {\"query\": [..], \"instruction\": [..], \"text\": [..]}\n",
"    \"text\" is preferred when present; otherwise the strings are composed\n",
"    from title/body or instruction/query.\n",
"\n",
"    Args:\n",
"        batch: Mapping from column name to a list of per-example values.\n",
"\n",
"    Returns:\n",
"        One string per example in the batch.\n",
"\n",
"    Raises:\n",
"        ValueError: If no supported column combination carries data.\n",
"    \"\"\"\n",
"    texts = batch.get(\"text\")\n",
"    if texts is not None:\n",
"        return list(texts)\n",
"\n",
"    if \"title\" in batch and \"body\" in batch:\n",
"        titles, bodies = batch[\"title\"], batch[\"body\"]\n",
"        # Both columns being None used to crash on len(None); fall through instead.\n",
"        if titles is not None or bodies is not None:\n",
"            titles = titles or [\"\"] * len(bodies or [])\n",
"            bodies = bodies or [\"\"] * len(titles)\n",
"            # `or ''` keeps a missing (None) entry from rendering as \"None\".\n",
"            return [f\"{t or ''} {b or ''}\".strip() for t, b in zip(titles, bodies)]\n",
"\n",
"    if batch.get(\"query\") is not None:\n",
"        queries = list(batch[\"query\"])\n",
"        instructions = batch.get(\"instruction\")\n",
"        if instructions:\n",
"            # The instruction goes in front of the query it belongs to.\n",
"            return [\n",
"                f\"{instruction or ''} {query}\".strip()\n",
"                for query, instruction in zip(queries, instructions)\n",
"            ]\n",
"        return queries\n",
"\n",
"    raise ValueError(f\"Unsupported batch keys: {sorted(batch.keys())}\")\n",
" \n",
" \n",
"@dataclass\n",
"class OllamaLangChainEncoder:\n",
"    \"\"\"Encoder adapter that feeds MTEB batches to a LangChain embeddings client.\n",
"\n",
"    The wrapped client only needs an ``embed_documents(list_of_texts)`` method;\n",
"    it is expected to return one vector per text (assumption, not checked here).\n",
"    \"\"\"\n",
"\n",
"    lc_embeddings: Any  # OllamaEmbeddings implements embed_documents()\n",
"    normalize: bool = True  # L2-normalize vectors right after embedding\n",
"\n",
"    # Optional metadata hook used by some wrappers; safe to keep as None for local runs\n",
"    mteb_model_meta: Any = None\n",
"\n",
"    def encode(\n",
"        self,\n",
"        inputs: Iterable[dict[str, Any]],\n",
"        *,\n",
"        task_metadata: Any,\n",
"        hf_split: str,\n",
"        hf_subset: str,\n",
"        prompt_type: Any = None,\n",
"        **kwargs: Any,\n",
"    ) -> Array:\n",
"        \"\"\"Embed every batch in ``inputs`` and stack the vectors row-wise.\n",
"\n",
"        ``task_metadata``, ``hf_split``, ``hf_subset``, ``prompt_type`` and any\n",
"        extra keyword arguments are accepted but not used by this method.\n",
"\n",
"        Returns:\n",
"            The stacked embeddings, or an empty ``(0, 0)`` float32 array when\n",
"            there was nothing to embed.\n",
"        \"\"\"\n",
"        all_vecs: list[np.ndarray] = []\n",
"\n",
"        for batch in inputs:\n",
"            texts = _to_text_list(batch)\n",
"            if not texts:\n",
"                # An empty batch would become a 1-D array, which breaks both\n",
"                # row-wise normalization and np.vstack below.\n",
"                continue\n",
"            vecs = self.lc_embeddings.embed_documents(texts)\n",
"            arr = np.asarray(vecs, dtype=np.float32)\n",
"            if self.normalize:\n",
"                arr = _l2_normalize(arr)\n",
"            all_vecs.append(arr)\n",
"\n",
"        if not all_vecs:\n",
"            return np.zeros((0, 0), dtype=np.float32)\n",
"\n",
"        return np.vstack(all_vecs)\n",
"\n",
"    def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:\n",
"        \"\"\"Return the matrix of dot products between unit-length row vectors.\n",
"\n",
"        With ``normalize=True`` the inputs are used as given (they are assumed\n",
"        to come from ``encode`` and therefore to be normalized already);\n",
"        otherwise both sides are normalized here first.\n",
"        \"\"\"\n",
"        a = np.asarray(embeddings1, dtype=np.float32)\n",
"        b = np.asarray(embeddings2, dtype=np.float32)\n",
"        if not self.normalize:\n",
"            a = _l2_normalize(a)\n",
"            b = _l2_normalize(b)\n",
"        # dot == cosine once both sides are unit length\n",
"        return a @ b.T\n",
"\n",
"    def similarity_pairwise(self, embeddings1: Array, embeddings2: Array) -> Array:\n",
"        \"\"\"Return the row-by-row dot product of two equally shaped embedding arrays.\n",
"\n",
"        Normalization is handled the same way as in ``similarity``.\n",
"        \"\"\"\n",
"        a = np.asarray(embeddings1, dtype=np.float32)\n",
"        b = np.asarray(embeddings2, dtype=np.float32)\n",
"        if not self.normalize:\n",
"            a = _l2_normalize(a)\n",
"            b = _l2_normalize(b)\n",
"        return np.sum(a * b, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db6fa201",
"metadata": {},
"outputs": [],
"source": [
"# `embeddings` is expected to be defined in an earlier cell of this notebook.\n",
"encoder = OllamaLangChainEncoder(lc_embeddings=embeddings, normalize=True)\n",
"search_model = SearchEncoderWrapper(encoder)\n",
"\n",
"# Task names are passed by keyword so they cannot be bound to a different\n",
"# positional parameter of get_tasks.\n",
"tasks = mteb.get_tasks(\n",
"    tasks=[\n",
"        \"CodeSearchNetRetrieval\",\n",
"        \"CodeSearchNetCCRetrieval\",\n",
"        \"AppsRetrieval\",\n",
"        \"StackOverflowDupQuestions\",\n",
"    ],\n",
")\n",
"results = mteb.evaluate(\n",
"    model=search_model,\n",
"    tasks=tasks,\n",
"    encode_kwargs={\"batch_size\": 32, \"show_progress_bar\": True},\n",
")\n",
"\n",
"print(results)"
]
}
],
"metadata": {