evaluation on acano
This commit is contained in:
parent
f6a907911d
commit
b01a76e71d
|
|
@ -479,6 +479,139 @@
|
||||||
" \n",
|
" \n",
|
||||||
"print(results)"
|
"print(results)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "07f9f5e5",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Evaluate"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "ec2362c4",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from dataclasses import dataclass\n",
|
||||||
|
"from typing import Any, Iterable\n",
|
||||||
|
" \n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
" \n",
|
||||||
|
"import mteb\n",
|
||||||
|
"from mteb.types import Array\n",
|
||||||
|
"from mteb.models import SearchEncoderWrapper\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
|
||||||
|
" norms = np.linalg.norm(x, axis=1, keepdims=True)\n",
|
||||||
|
" return x / np.clip(norms, eps, None)\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"def _to_text_list(batch: dict[str, Any]) -> list[str]:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" MTEB batched inputs can be:\n",
|
||||||
|
" - TextInput: {\"text\": [..]}\n",
|
||||||
|
" - CorpusInput: {\"title\": [..], \"body\": [..], \"text\": [..]}\n",
|
||||||
|
" - QueryInput: {\"query\": [..], \"instruction\": [..], \"text\": [..]}\n",
|
||||||
|
" We prefer \"text\" if present; otherwise compose from title/body or query/instruction.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" if \"text\" in batch and batch[\"text\"] is not None:\n",
|
||||||
|
" return list(batch[\"text\"])\n",
|
||||||
|
" \n",
|
||||||
|
" if \"title\" in batch and \"body\" in batch:\n",
|
||||||
|
" titles = batch[\"title\"] or [\"\"] * len(batch[\"body\"])\n",
|
||||||
|
" bodies = batch[\"body\"] or [\"\"] * len(batch[\"title\"])\n",
|
||||||
|
" return [f\"{t} {b}\".strip() for t, b in zip(titles, bodies)]\n",
|
||||||
|
" \n",
|
||||||
|
" if \"query\" in batch:\n",
|
||||||
|
" queries = list(batch[\"query\"])\n",
|
||||||
|
" instructions = batch.get(\"instruction\")\n",
|
||||||
|
" if instructions:\n",
|
||||||
|
" return [f\"{i} {q}\".strip() for q, i in zip(queries, instructions)]\n",
|
||||||
|
" return queries\n",
|
||||||
|
" \n",
|
||||||
|
" raise ValueError(f\"Unsupported batch keys: {sorted(batch.keys())}\")\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
@dataclass
class OllamaLangChainEncoder:
    """Adapter exposing a LangChain embeddings object through MTEB's encoder protocol.

    Embeds batched MTEB inputs via ``lc_embeddings.embed_documents`` and
    (optionally) L2-normalizes the vectors so dot products equal cosine
    similarity.
    """

    lc_embeddings: Any  # OllamaEmbeddings implements embed_documents()
    normalize: bool = True

    # Optional metadata hook used by some wrappers; safe to keep as None for local runs
    mteb_model_meta: Any = None

    def encode(
        self,
        inputs: Iterable[dict[str, Any]],
        *,
        task_metadata: Any,
        hf_split: str,
        hf_subset: str,
        prompt_type: Any = None,
        **kwargs: Any,
    ) -> Array:
        """Embed every batch in *inputs* and stack the vectors into one float32 matrix."""
        chunks: list[np.ndarray] = []
        for batch in inputs:
            texts = _to_text_list(batch)
            matrix = np.asarray(
                self.lc_embeddings.embed_documents(texts), dtype=np.float32
            )
            chunks.append(_l2_normalize(matrix) if self.normalize else matrix)

        if not chunks:
            # No batches produced anything: hand back an empty float32 matrix.
            return np.zeros((0, 0), dtype=np.float32)
        return np.vstack(chunks)

    def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
        """Return the cosine-similarity matrix between two embedding sets."""
        left = np.asarray(embeddings1, dtype=np.float32)
        right = np.asarray(embeddings2, dtype=np.float32)
        if not self.normalize:
            left = _l2_normalize(left)
            right = _l2_normalize(right)
        # Rows are unit-length at this point (either produced normalized by
        # encode(), or normalized just above), so dot product == cosine.
        return left @ right.T

    def similarity_pairwise(self, embeddings1: Array, embeddings2: Array) -> Array:
        """Return the row-by-row cosine similarity of two equally shaped sets."""
        left = np.asarray(embeddings1, dtype=np.float32)
        right = np.asarray(embeddings2, dtype=np.float32)
        if not self.normalize:
            left = _l2_normalize(left)
            right = _l2_normalize(right)
        return np.sum(left * right, axis=1)
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "db6fa201",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
# Wrap the LangChain embeddings object in the MTEB search-encoder protocol.
# NOTE(review): `embeddings` is created in an earlier cell — presumably an
# OllamaEmbeddings instance (it must provide embed_documents()); confirm.
encoder = OllamaLangChainEncoder(lc_embeddings=embeddings, normalize=True)
search_model = SearchEncoderWrapper(encoder)

# Code-retrieval benchmark tasks to evaluate the local embedding model on.
tasks = mteb.get_tasks([
    "CodeSearchNetRetrieval",
    "CodeSearchNetCCRetrieval",
    "AppsRetrieval",
    "StackOverflowDupQuestions",
])
# Run the evaluation; batch_size/show_progress_bar are forwarded to encode().
results = mteb.evaluate(
    model=search_model,
    tasks=tasks,
    encode_kwargs={"batch_size": 32, "show_progress_bar": True}
)

print(results)
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue