evaluation on acano
This commit is contained in:
parent
f6a907911d
commit
b01a76e71d
|
|
@ -479,6 +479,139 @@
|
|||
" \n",
|
||||
"print(results)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "07f9f5e5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Evaluate"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ec2362c4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, Iterable\n",
|
||||
" \n",
|
||||
"import numpy as np\n",
|
||||
" \n",
|
||||
"import mteb\n",
|
||||
"from mteb.types import Array\n",
|
||||
"from mteb.models import SearchEncoderWrapper\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
|
||||
" norms = np.linalg.norm(x, axis=1, keepdims=True)\n",
|
||||
" return x / np.clip(norms, eps, None)\n",
|
||||
" \n",
|
||||
" \n",
|
||||
"def _to_text_list(batch: dict[str, Any]) -> list[str]:\n",
|
||||
" \"\"\"\n",
|
||||
" MTEB batched inputs can be:\n",
|
||||
" - TextInput: {\"text\": [..]}\n",
|
||||
" - CorpusInput: {\"title\": [..], \"body\": [..], \"text\": [..]}\n",
|
||||
" - QueryInput: {\"query\": [..], \"instruction\": [..], \"text\": [..]}\n",
|
||||
" We prefer \"text\" if present; otherwise compose from title/body or query/instruction.\n",
|
||||
" \"\"\"\n",
|
||||
" if \"text\" in batch and batch[\"text\"] is not None:\n",
|
||||
" return list(batch[\"text\"])\n",
|
||||
" \n",
|
||||
" if \"title\" in batch and \"body\" in batch:\n",
|
||||
" titles = batch[\"title\"] or [\"\"] * len(batch[\"body\"])\n",
|
||||
" bodies = batch[\"body\"] or [\"\"] * len(batch[\"title\"])\n",
|
||||
" return [f\"{t} {b}\".strip() for t, b in zip(titles, bodies)]\n",
|
||||
" \n",
|
||||
" if \"query\" in batch:\n",
|
||||
" queries = list(batch[\"query\"])\n",
|
||||
" instructions = batch.get(\"instruction\")\n",
|
||||
" if instructions:\n",
|
||||
" return [f\"{i} {q}\".strip() for q, i in zip(queries, instructions)]\n",
|
||||
" return queries\n",
|
||||
" \n",
|
||||
" raise ValueError(f\"Unsupported batch keys: {sorted(batch.keys())}\")\n",
|
||||
" \n",
|
||||
" \n",
|
||||
@dataclass
class OllamaLangChainEncoder:
    """Adapter exposing a LangChain embeddings object through MTEB's encoder API.

    Embeds text via ``lc_embeddings.embed_documents`` and (optionally)
    L2-normalizes the vectors so dot products equal cosine similarities.
    """

    lc_embeddings: Any  # OllamaEmbeddings implements embed_documents()
    normalize: bool = True

    # Optional metadata hook used by some wrappers; safe to keep as None for local runs
    mteb_model_meta: Any = None

    def encode(
        self,
        inputs: Iterable[dict[str, Any]],
        *,
        task_metadata: Any,
        hf_split: str,
        hf_subset: str,
        prompt_type: Any = None,
        **kwargs: Any,
    ) -> Array:
        """Embed every batch in *inputs* and stack the vectors into one float32 array."""
        chunks: list[np.ndarray] = []
        for raw_batch in inputs:
            embedded = self.lc_embeddings.embed_documents(_to_text_list(raw_batch))
            matrix = np.asarray(embedded, dtype=np.float32)
            chunks.append(_l2_normalize(matrix) if self.normalize else matrix)
        # No batches at all -> empty (0, 0) array rather than a vstack error.
        if chunks:
            return np.vstack(chunks)
        return np.zeros((0, 0), dtype=np.float32)

    def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:
        """Full pairwise cosine-similarity matrix between two embedding sets."""
        lhs = np.asarray(embeddings1, dtype=np.float32)
        rhs = np.asarray(embeddings2, dtype=np.float32)
        if not self.normalize:
            lhs = _l2_normalize(lhs)
            rhs = _l2_normalize(rhs)
        # With unit-length rows on both sides, the dot product IS the cosine.
        return lhs @ rhs.T

    def similarity_pairwise(self, embeddings1: Array, embeddings2: Array) -> Array:
        """Cosine similarity of corresponding rows (row i of each input)."""
        lhs = np.asarray(embeddings1, dtype=np.float32)
        rhs = np.asarray(embeddings2, dtype=np.float32)
        if not self.normalize:
            lhs = _l2_normalize(lhs)
            rhs = _l2_normalize(rhs)
        return np.sum(lhs * rhs, axis=1)
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "db6fa201",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
# Wrap the locally-defined LangChain embeddings object so MTEB can drive it
# as a retrieval ("search") model.
# NOTE(review): assumes `embeddings` was created in an earlier cell — confirm.
encoder = OllamaLangChainEncoder(lc_embeddings=embeddings, normalize=True)
search_model = SearchEncoderWrapper(encoder)

# Code-retrieval benchmark tasks to evaluate against.
tasks = mteb.get_tasks([
    "CodeSearchNetRetrieval",
    "CodeSearchNetCCRetrieval",
    "AppsRetrieval",
    "StackOverflowDupQuestions",
])
# Run the evaluation; batches of 32 texts are passed to encode() per call.
results = mteb.evaluate(
    model=search_model,
    tasks=tasks,
    encode_kwargs={"batch_size": 32, "show_progress_bar": True}
)

print(results)
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
|
|||
Loading…
Reference in New Issue