Refactor code structure for improved readability and maintainability
This commit is contained in:
parent
d4d7d9d2a1
commit
4b5352d93c
|
|
@ -28,5 +28,7 @@ dev = [
|
||||||
"beir>=2.2.0",
|
"beir>=2.2.0",
|
||||||
"jupyter>=1.1.1",
|
"jupyter>=1.1.1",
|
||||||
"langfuse>=3.14.4",
|
"langfuse>=3.14.4",
|
||||||
|
"mteb>=2.8.8",
|
||||||
|
"polars>=1.38.1",
|
||||||
"ruff>=0.15.1",
|
"ruff>=0.15.1",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -369,7 +369,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": null,
|
||||||
"id": "74c0a377",
|
"id": "74c0a377",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
|
@ -412,7 +412,7 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"response = es.search(\n",
|
"response = es.search(\n",
|
||||||
" index=ES_INDEX_NAME,\n",
|
" index=ES_INDEX_NAME,################dfdfdfasdad\n",
|
||||||
" body={\n",
|
" body={\n",
|
||||||
" \"query\": {\"match_all\": {}},\n",
|
" \"query\": {\"match_all\": {}},\n",
|
||||||
" \"size\": 10 \n",
|
" \"size\": 10 \n",
|
||||||
|
|
|
||||||
|
|
@ -28,13 +28,13 @@
|
||||||
"from langchain_elasticsearch import ElasticsearchStore\n",
|
"from langchain_elasticsearch import ElasticsearchStore\n",
|
||||||
"from langgraph.graph import StateGraph, END\n",
|
"from langgraph.graph import StateGraph, END\n",
|
||||||
"from langgraph.prebuilt import ToolNode\n",
|
"from langgraph.prebuilt import ToolNode\n",
|
||||||
"from langfuse import get_client\n",
|
"from langfuse import get_client, Langfuse\n",
|
||||||
"from langfuse.langchain import CallbackHandler"
|
"from langfuse.langchain import CallbackHandler"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": null,
|
||||||
"id": "30edcecc",
|
"id": "30edcecc",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -48,8 +48,8 @@
|
||||||
"LANGFUSE_SECRET_KEY = os.getenv(\"LANGFUSE_SECRET_KEY\")\n",
|
"LANGFUSE_SECRET_KEY = os.getenv(\"LANGFUSE_SECRET_KEY\")\n",
|
||||||
"LANGFUSE_HOST = os.getenv(\"LANGFUSE_HOST\")\n",
|
"LANGFUSE_HOST = os.getenv(\"LANGFUSE_HOST\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# langfuse = get_client()\n",
|
"langfuse = get_client()\n",
|
||||||
"# langfuse_handler = CallbackHandler()\n",
|
"langfuse_handler = CallbackHandler()\n",
|
||||||
"\n",
|
"\n",
|
||||||
"embeddings = OllamaEmbeddings(base_url=BASE_URL, model=EMB_MODEL_NAME)\n",
|
"embeddings = OllamaEmbeddings(base_url=BASE_URL, model=EMB_MODEL_NAME)\n",
|
||||||
"llm = ChatOllama(base_url=BASE_URL, model=MODEL_NAME)\n",
|
"llm = ChatOllama(base_url=BASE_URL, model=MODEL_NAME)\n",
|
||||||
|
|
@ -65,15 +65,31 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 3,
|
||||||
"id": "ad98841b",
|
"id": "ad98841b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "ValidationError",
|
||||||
|
"evalue": "2 validation errors for ParsingModel[Projects]\n__root__ -> data -> 0 -> organization\n field required (type=value_error.missing)\n__root__ -> data -> 0 -> metadata\n field required (type=value_error.missing)",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
||||||
|
"\u001b[31mValidationError\u001b[39m Traceback (most recent call last)",
|
||||||
|
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[3]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[43mlangfuse\u001b[49m\u001b[43m.\u001b[49m\u001b[43mauth_check\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[32m 2\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mLangfuse client is authenticated and ready!\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m 3\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/PycharmProjects/assistance-engine/.venv/lib/python3.11/site-packages/langfuse/_client/client.py:3385\u001b[39m, in \u001b[36mLangfuse.auth_check\u001b[39m\u001b[34m(self)\u001b[39m\n\u001b[32m 3376\u001b[39m \u001b[38;5;250m\u001b[39m\u001b[33;03m\"\"\"Check if the provided credentials (public and secret key) are valid.\u001b[39;00m\n\u001b[32m 3377\u001b[39m \n\u001b[32m 3378\u001b[39m \u001b[33;03mRaises:\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 3382\u001b[39m \u001b[33;03m This method is blocking. It is discouraged to use it in production code.\u001b[39;00m\n\u001b[32m 3383\u001b[39m \u001b[33;03m\"\"\"\u001b[39;00m\n\u001b[32m 3384\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m3385\u001b[39m projects = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mapi\u001b[49m\u001b[43m.\u001b[49m\u001b[43mprojects\u001b[49m\u001b[43m.\u001b[49m\u001b[43mget\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 3386\u001b[39m langfuse_logger.debug(\n\u001b[32m 3387\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mAuth check successful, found \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(projects.data)\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m projects\u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 3388\u001b[39m )\n\u001b[32m 3389\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(projects.data) == \u001b[32m0\u001b[39m:\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/PycharmProjects/assistance-engine/.venv/lib/python3.11/site-packages/langfuse/api/resources/projects/client.py:65\u001b[39m, in \u001b[36mProjectsClient.get\u001b[39m\u001b[34m(self, request_options)\u001b[39m\n\u001b[32m 63\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 64\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[32m200\u001b[39m <= _response.status_code < \u001b[32m300\u001b[39m:\n\u001b[32m---> \u001b[39m\u001b[32m65\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpydantic_v1\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparse_obj_as\u001b[49m\u001b[43m(\u001b[49m\u001b[43mProjects\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m_response\u001b[49m\u001b[43m.\u001b[49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[32m 66\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m _response.status_code == \u001b[32m400\u001b[39m:\n\u001b[32m 67\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m Error(pydantic_v1.parse_obj_as(typing.Any, _response.json())) \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/PycharmProjects/assistance-engine/.venv/lib/python3.11/site-packages/pydantic/v1/tools.py:38\u001b[39m, in \u001b[36mparse_obj_as\u001b[39m\u001b[34m(type_, obj, type_name)\u001b[39m\n\u001b[32m 36\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mparse_obj_as\u001b[39m(type_: Type[T], obj: Any, *, type_name: Optional[NameFactory] = \u001b[38;5;28;01mNone\u001b[39;00m) -> T:\n\u001b[32m 37\u001b[39m model_type = _get_parsing_type(type_, type_name=type_name) \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m38\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mmodel_type\u001b[49m\u001b[43m(\u001b[49m\u001b[43m__root__\u001b[49m\u001b[43m=\u001b[49m\u001b[43mobj\u001b[49m\u001b[43m)\u001b[49m.__root__\n",
|
||||||
|
"\u001b[36mFile \u001b[39m\u001b[32m~/PycharmProjects/assistance-engine/.venv/lib/python3.11/site-packages/pydantic/v1/main.py:347\u001b[39m, in \u001b[36mBaseModel.__init__\u001b[39m\u001b[34m(__pydantic_self__, **data)\u001b[39m\n\u001b[32m 345\u001b[39m values, fields_set, validation_error = validate_model(__pydantic_self__.\u001b[34m__class__\u001b[39m, data)\n\u001b[32m 346\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m validation_error:\n\u001b[32m--> \u001b[39m\u001b[32m347\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m validation_error\n\u001b[32m 348\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 349\u001b[39m object_setattr(__pydantic_self__, \u001b[33m'\u001b[39m\u001b[33m__dict__\u001b[39m\u001b[33m'\u001b[39m, values)\n",
|
||||||
|
"\u001b[31mValidationError\u001b[39m: 2 validation errors for ParsingModel[Projects]\n__root__ -> data -> 0 -> organization\n field required (type=value_error.missing)\n__root__ -> data -> 0 -> metadata\n field required (type=value_error.missing)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"# if langfuse.auth_check():\n",
|
"if langfuse.auth_check():\n",
|
||||||
"# print(\"Langfuse client is authenticated and ready!\")\n",
|
" print(\"Langfuse client is authenticated and ready!\")\n",
|
||||||
"# else:\n",
|
"else:\n",
|
||||||
"# print(\"Authentication failed. Please check your credentials and host.\")"
|
" print(\"Authentication failed. Please check your credentials and host.\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -86,7 +102,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 7,
|
||||||
"id": "5f8c88cf",
|
"id": "5f8c88cf",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -105,7 +121,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 8,
|
||||||
"id": "f0a21230",
|
"id": "f0a21230",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -115,7 +131,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 9,
|
||||||
"id": "f9359747",
|
"id": "f9359747",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -146,7 +162,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 10,
|
||||||
"id": "e5247ab9",
|
"id": "e5247ab9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -161,7 +177,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 11,
|
||||||
"id": "a644f6fa",
|
"id": "a644f6fa",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -181,7 +197,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 12,
|
||||||
"id": "36d0f54e",
|
"id": "36d0f54e",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -215,7 +231,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 13,
|
||||||
"id": "fae46a58",
|
"id": "fae46a58",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -233,7 +249,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 14,
|
||||||
"id": "2fec3fdb",
|
"id": "2fec3fdb",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
|
|
@ -265,7 +281,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 15,
|
||||||
"id": "8569cf39",
|
"id": "8569cf39",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -286,7 +302,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 16,
|
||||||
"id": "a1a1f3cf",
|
"id": "a1a1f3cf",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
|
|
@ -296,12 +312,172 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 17,
|
||||||
"id": "53b89690",
|
"id": "53b89690",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
||||||
|
"\n",
|
||||||
|
"Create a small snippet that adds 2 numbers using AVAP language. Use the provided tool.\n",
|
||||||
|
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||||
|
"\n",
|
||||||
|
"<think>\n",
|
||||||
|
"Okay, the user wants a small snippet that adds two numbers using the AVAP language. They mentioned using the provided tool. Let me check the tools available.\n",
|
||||||
|
"\n",
|
||||||
|
"The only tool provided is retrieve, which gets documents about AVAP. But the user is asking for a code snippet, not a document retrieval. Since the function retrieve isn't for generating code, maybe I can't use it here. Wait, maybe I should inform the user that the tool isn't suitable for this task. Alternatively, perhaps the vector store has some code examples, but I don't have access to it. So, I need to respond that I can't create the code snippet with the available tool. But the user wants the answer. Hmm, maybe I should proceed by writing the code snippet using AVAP syntax, even though the tool isn't applicable. Let me think of AVAP syntax. AVAP uses a specific syntax for APIs, so perhaps the code would look something like adding two numbers using a function. Let me draft that.\n",
|
||||||
|
"</think>\n",
|
||||||
|
"\n",
|
||||||
|
"The AVAP language allows creating API endpoints using simple syntax. Here's a snippet that adds two numbers:\n",
|
||||||
|
"\n",
|
||||||
|
"```avap\n",
|
||||||
|
"function add(a, b) {\n",
|
||||||
|
" return a + b;\n",
|
||||||
|
"}\n",
|
||||||
|
"```\n",
|
||||||
|
"\n",
|
||||||
|
"This snippet uses the `add` function to sum two numbers. Let me know if you need further assistance!\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"a = stream_graph_updates(user_input)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "3707574b",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### MTEB"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "d9657ec4",
|
||||||
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"a = stream_graph_updates(user_input)"
|
"from dataclasses import dataclass\n",
|
||||||
|
"from typing import Any, Iterable\n",
|
||||||
|
" \n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
" \n",
|
||||||
|
"import mteb\n",
|
||||||
|
"from mteb.types import Array\n",
|
||||||
|
"from mteb.models import SearchEncoderWrapper\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"def _l2_normalize(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:\n",
|
||||||
|
" norms = np.linalg.norm(x, axis=1, keepdims=True)\n",
|
||||||
|
" return x / np.clip(norms, eps, None)\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"def _to_text_list(batch: dict[str, Any]) -> list[str]:\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" MTEB batched inputs can be:\n",
|
||||||
|
" - TextInput: {\"text\": [..]}\n",
|
||||||
|
" - CorpusInput: {\"title\": [..], \"body\": [..], \"text\": [..]}\n",
|
||||||
|
" - QueryInput: {\"query\": [..], \"instruction\": [..], \"text\": [..]}\n",
|
||||||
|
" We prefer \"text\" if present; otherwise compose from title/body or query/instruction.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" if \"text\" in batch and batch[\"text\"] is not None:\n",
|
||||||
|
" return list(batch[\"text\"])\n",
|
||||||
|
" \n",
|
||||||
|
" if \"title\" in batch and \"body\" in batch:\n",
|
||||||
|
" titles = batch[\"title\"] or [\"\"] * len(batch[\"body\"])\n",
|
||||||
|
" bodies = batch[\"body\"] or [\"\"] * len(batch[\"title\"])\n",
|
||||||
|
" return [f\"{t} {b}\".strip() for t, b in zip(titles, bodies)]\n",
|
||||||
|
" \n",
|
||||||
|
" if \"query\" in batch:\n",
|
||||||
|
" queries = list(batch[\"query\"])\n",
|
||||||
|
" instructions = batch.get(\"instruction\")\n",
|
||||||
|
" if instructions:\n",
|
||||||
|
" return [f\"{i} {q}\".strip() for q, i in zip(queries, instructions)]\n",
|
||||||
|
" return queries\n",
|
||||||
|
" \n",
|
||||||
|
" raise ValueError(f\"Unsupported batch keys: {sorted(batch.keys())}\")\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
"@dataclass\n",
|
||||||
|
"class OllamaLangChainEncoder:\n",
|
||||||
|
" lc_embeddings: Any # OllamaEmbeddings implements embed_documents()\n",
|
||||||
|
" normalize: bool = True\n",
|
||||||
|
" \n",
|
||||||
|
" # Optional metadata hook used by some wrappers; safe to keep as None for local runs\n",
|
||||||
|
" mteb_model_meta: Any = None\n",
|
||||||
|
" \n",
|
||||||
|
" def encode(\n",
|
||||||
|
" self,\n",
|
||||||
|
" inputs: Iterable[dict[str, Any]],\n",
|
||||||
|
" *,\n",
|
||||||
|
" task_metadata: Any,\n",
|
||||||
|
" hf_split: str,\n",
|
||||||
|
" hf_subset: str,\n",
|
||||||
|
" prompt_type: Any = None,\n",
|
||||||
|
" **kwargs: Any,\n",
|
||||||
|
" ) -> Array:\n",
|
||||||
|
" all_vecs: list[np.ndarray] = []\n",
|
||||||
|
" \n",
|
||||||
|
" for batch in inputs:\n",
|
||||||
|
" texts = _to_text_list(batch)\n",
|
||||||
|
" vecs = self.lc_embeddings.embed_documents(texts)\n",
|
||||||
|
" arr = np.asarray(vecs, dtype=np.float32)\n",
|
||||||
|
" if self.normalize:\n",
|
||||||
|
" arr = _l2_normalize(arr)\n",
|
||||||
|
" all_vecs.append(arr)\n",
|
||||||
|
" \n",
|
||||||
|
" if not all_vecs:\n",
|
||||||
|
" return np.zeros((0, 0), dtype=np.float32)\n",
|
||||||
|
" \n",
|
||||||
|
" return np.vstack(all_vecs)\n",
|
||||||
|
" \n",
|
||||||
|
" def similarity(self, embeddings1: Array, embeddings2: Array) -> Array:\n",
|
||||||
|
" a = np.asarray(embeddings1, dtype=np.float32)\n",
|
||||||
|
" b = np.asarray(embeddings2, dtype=np.float32)\n",
|
||||||
|
" if self.normalize:\n",
|
||||||
|
" # dot == cosine if already normalized\n",
|
||||||
|
" return a @ b.T\n",
|
||||||
|
" a = _l2_normalize(a)\n",
|
||||||
|
" b = _l2_normalize(b)\n",
|
||||||
|
" return a @ b.T\n",
|
||||||
|
" \n",
|
||||||
|
" def similarity_pairwise(self, embeddings1: Array, embeddings2: Array) -> Array:\n",
|
||||||
|
" a = np.asarray(embeddings1, dtype=np.float32)\n",
|
||||||
|
" b = np.asarray(embeddings2, dtype=np.float32)\n",
|
||||||
|
" if not self.normalize:\n",
|
||||||
|
" a = _l2_normalize(a)\n",
|
||||||
|
" b = _l2_normalize(b)\n",
|
||||||
|
" return np.sum(a * b, axis=1)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "85727a68",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"encoder = OllamaLangChainEncoder(lc_embeddings=embeddings, normalize=True)\n",
|
||||||
|
"search_model = SearchEncoderWrapper(encoder)\n",
|
||||||
|
" \n",
|
||||||
|
"tasks = mteb.get_tasks([\n",
|
||||||
|
" \"CodeSearchNetRetrieval\",\n",
|
||||||
|
" \"CodeSearchNetCCRetrieval\",\n",
|
||||||
|
" \"AppsRetrieval\",\n",
|
||||||
|
" \"StackOverflowDupQuestions\",\n",
|
||||||
|
"])\n",
|
||||||
|
"results = mteb.evaluate(\n",
|
||||||
|
" model=search_model,\n",
|
||||||
|
" tasks=tasks,\n",
|
||||||
|
" encode_kwargs={\"batch_size\": 32, \"show_progress_bar\": True}\n",
|
||||||
|
")\n",
|
||||||
|
" \n",
|
||||||
|
"print(results)"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -264,9 +264,46 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": 14,
|
||||||
"id": "1db7d110",
|
"id": "1db7d110",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Resultados guardados en /home/acano/PycharmProjects/assistance-engine/data/interim/beir_cosqa_results.json\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"results_data = {\n",
|
||||||
|
" \"qwen3-0.6B-emb:latest\": {\n",
|
||||||
|
" \"NDCG\": ndcg,\n",
|
||||||
|
" \"MAP\": _map,\n",
|
||||||
|
" \"Recall\": recall,\n",
|
||||||
|
" \"Precision\": precision,\n",
|
||||||
|
" },\n",
|
||||||
|
" \"qwen2.5:1.5b\": {\n",
|
||||||
|
" \"NDCG\": ndcg_qwen_2,\n",
|
||||||
|
" \"MAP\": _map_qwen_2,\n",
|
||||||
|
" \"Recall\": recall_qwen_2,\n",
|
||||||
|
" \"Precision\": precision_qwen_2,\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n",
|
||||||
|
" \n",
|
||||||
|
"output_file = \"/home/acano/PycharmProjects/assistance-engine/data/interim/beir_cosqa_results.json\"\n",
|
||||||
|
"with open(output_file, \"w\") as f:\n",
|
||||||
|
" json.dump(results_data, f, indent=2)\n",
|
||||||
|
" \n",
|
||||||
|
"print(f\"Resultados guardados en {output_file}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"id": "e4f8d78b",
|
||||||
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue