{ "cells": [ { "cell_type": "markdown", "id": "7cf31ec4", "metadata": {}, "source": [ "# Libraries" ] }, { "cell_type": "code", "execution_count": 47, "id": "5fe736bf", "metadata": {}, "outputs": [], "source": [ "import os\n", "import sys\n", "from pathlib import Path\n", "from typing import TypedDict, List, Optional, Annotated, Literal\n", "from IPython.display import Image, display\n", "from pydantic import BaseModel, Field\n", "\n", "# Ensure the project root is on the path so `src` is importable\n", "_project_root = str(Path(__file__).resolve().parents[2]) if \"__file__\" in dir() else str(Path.cwd().parents[1])\n", "if _project_root not in sys.path:\n", " sys.path.insert(0, _project_root)\n", "\n", "from langchain_core.documents import Document\n", "from langchain_core.messages import BaseMessage, SystemMessage, AIMessage\n", "from langchain_core.tools import tool\n", "from langgraph.checkpoint.memory import InMemorySaver\n", "from langgraph.graph.message import add_messages\n", "from langchain_ollama import ChatOllama, OllamaEmbeddings\n", "from langchain_elasticsearch import ElasticsearchStore\n", "from langgraph.graph import StateGraph, END\n", "from langgraph.prebuilt import ToolNode, tools_condition\n", "from langfuse import Langfuse\n", "\n", "from src.utils.llm_factory import create_chat_model\n", "from src.utils.emb_factory import create_embedding_model\n", "from src.config import (\n", " ELASTICSEARCH_LOCAL_URL,\n", " ELASTICSEARCH_INDEX,\n", " OLLAMA_MODEL_NAME,\n", " OLLAMA_EMB_MODEL_NAME\n", ")" ] }, { "cell_type": "markdown", "id": "a73629e1", "metadata": {}, "source": [ "# State" ] }, { "cell_type": "code", "execution_count": 48, "id": "5097b8a5", "metadata": {}, "outputs": [], "source": [ "class AgentState(TypedDict):\n", " messages: Annotated[list, add_messages]\n", " reformulated_query: str\n", " context: str" ] }, { "cell_type": "code", "execution_count": 49, "id": "c5328a0d", "metadata": {}, "outputs": [], "source": [ "class AgenticAgentState(TypedDict):\n", " messages: Annotated[list, add_messages]" ] }, { "cell_type": "code", "execution_count": 50, "id": "e8f0a0bc", "metadata": {}, "outputs": [], "source": [ "REFORMULATE_PROMPT = SystemMessage(\n", " content=(\n", " \"You are a deterministic query rewriting function.\\n\"\n", " \"You convert natural language questions into keyword search queries.\\n\\n\"\n", " \"Strict constraints:\\n\"\n", " \"1. Keep function names and technical tokens unchanged.\\n\"\n", " \"2. Remove filler phrases.\\n\"\n", " \"3. Do not answer.\\n\"\n", " \"4. Do not explain.\\n\"\n", " \"5. Do not generate code.\\n\"\n", " \"6. Return a single-line query only.\\n\"\n", " \"7. 
If already optimal, return unchanged.\n\"\n", "    )\n", ")\n", "\n", "GENERATE_PROMPT = SystemMessage(\n", "    content=\"\"\"You are an agent designed to assist users with AVAP (Advanced Virtual API Programming) language.\n", "    It's a new language, so assume you know nothing about it.\n", "    Use ONLY the provided context to answer AVAP-related questions.\n", "    If the context does not contain enough information, say so honestly.\n", "    If the question is not related to AVAP, answer based on your general knowledge.\n", "\n", "    Context:\n", "    {context}\"\"\"\n", ")\n", "\n", "AGENTIC_PROMPT = SystemMessage(\n", "    content=\"\"\"You are an agent designed to assist users with AVAP (Advanced Virtual API Programming) language.\n", "    It's a new language, so assume you know nothing about it.\n", "    Use ONLY the provided 'context_retrieve' tool to answer AVAP-related questions.\n", "    The 'context_retrieve' tool receives a user query (as a string) and returns relevant context from a vector store.\n", "    If the context does not contain enough information, say so honestly.\n", "    If the question is not related to AVAP, answer based on your general knowledge.\n", "    \"\"\"\n", ")" ] }, { "cell_type": "markdown", "id": "ebd003d9", "metadata": {}, "source": [ "# Function" ] }, { "cell_type": "code", "execution_count": 51, "id": "b4e5d981", "metadata": {}, "outputs": [], "source": [ "def format_context(docs: List[Document]) -> str:\n", "    chunks: List[str] = []\n", "    for i, doc in enumerate(docs, 1):\n", "        source = (doc.metadata or {}).get(\"source\", \"Untitled\")\n", "        source_id = (doc.metadata or {}).get(\"id\", f\"chunk-{i}\")\n", "        text = doc.page_content or \"\"\n", "        chunks.append(f\"[{i}] id={source_id} source={source}\\n{text}\")\n", "    return \"\\n\\n\".join(chunks)" ] },
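{ "cell_type": "markdown", "id": "format_context_demo_md", "metadata": {}, "source": [ "A quick, hedged sanity check of `format_context` (added for illustration): it formats two made-up in-memory `Document` objects, so the sources and ids below are placeholders, not real index entries." ] }, { "cell_type": "code", "execution_count": null, "id": "format_context_demo", "metadata": {}, "outputs": [], "source": [ "# Illustrative only: dummy documents carrying the metadata keys format_context expects (\"source\", \"id\").\n", "_demo_docs = [\n", "    Document(page_content=\"First sample chunk about AVAP.\", metadata={\"source\": \"demo/doc_a.md\", \"id\": \"demo-1\"}),\n", "    Document(page_content=\"Second sample chunk about AVAP.\", metadata={\"source\": \"demo/doc_b.md\", \"id\": \"demo-2\"}),\n", "]\n", "print(format_context(_demo_docs))" ] },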
{ "cell_type": "code", "execution_count": 52, "id": "8a38359a", "metadata": {}, "outputs": [], "source": [ "def reformulate(state: AgentState, llm=None) -> AgentState:\n", "    \"\"\"Use the LLM to rewrite the user query for better retrieval.\"\"\"\n", "    # The graph runner passes the state only. Accept an optional `llm` so\n", "    # this function can also be called directly with an explicit model.\n", "    user_msg = state[\"messages\"][-1]\n", "    if llm is None:\n", "        llm = globals().get('llm')\n", "    if llm is None:\n", "        raise RuntimeError('No LLM available for reformulate')\n", "    resp = llm.invoke([REFORMULATE_PROMPT, user_msg])\n", "    reformulated = resp.content.strip()\n", "    print(f\"[reformulate] '{user_msg.content}' → '{reformulated}'\")\n", "    return {\"reformulated_query\": reformulated}\n", "\n", "\n", "def retrieve(state: AgentState, vector_store=None, retrieve_kwargs=None) -> AgentState:\n", "    \"\"\"Retrieve context using the reformulated query.\"\"\"\n", "    # Graph runner passes state first. Accept optional `vector_store` and\n", "    # `retrieve_kwargs` for direct calls; otherwise fall back to globals.\n", "    if vector_store is None:\n", "        vector_store = globals().get('vector_store')\n", "    if vector_store is None:\n", "        raise RuntimeError('No vector_store available for retrieve')\n", "    if retrieve_kwargs is None:\n", "        retrieve_kwargs = {}\n", "    query = state[\"reformulated_query\"]\n", "    docs = vector_store.as_retriever(\n", "        search_type=\"similarity\",\n", "        search_kwargs=retrieve_kwargs,\n", "    ).invoke(query)\n", "    context = format_context(docs)\n", "    print(f\"[retrieve] {len(docs)} docs fetched\")\n", "    print(context)\n", "    return {\"context\": context}\n", "\n", "\n", "def generate(state: AgentState, llm=None) -> AgentState:\n", "    \"\"\"Generate the final answer using retrieved context.\"\"\"\n", "    # Same convention as the nodes above: the graph passes only the state,\n", "    # so fall back to the notebook-level `llm` when none is given.\n", "    if llm is None:\n", "        llm = globals().get('llm')\n", "    if llm is None:\n", "        raise RuntimeError('No LLM available for generate')\n", "    prompt = SystemMessage(\n", "        content=GENERATE_PROMPT.content.format(context=state[\"context\"])\n", "    )\n", "    resp = llm.invoke([prompt] + state[\"messages\"])\n", "    return {\"messages\": [resp]}" ] }, { "cell_type": "code", "execution_count": 53, "id": "d5001041", "metadata": {}, "outputs": [], "source": [ "def agent(llm, tools, state: AgenticAgentState) -> AgenticAgentState:\n", "    llm_with_tools = llm.bind_tools(tools)\n", "    return {\"messages\": [llm_with_tools.invoke([SystemMessage(content=AGENTIC_PROMPT.content)] + state[\"messages\"])]}" ] }, { "cell_type": "markdown", "id": "538f812d", "metadata": {}, "source": [ "# Code" ] }, { "cell_type": "code", "execution_count": 54, "id": "89786481", "metadata": {}, "outputs": [], "source": [ "langfuse = Langfuse()\n", "\n", "llm = create_chat_model(\n", "    provider=\"ollama\",\n", "    model=OLLAMA_MODEL_NAME,\n", "    temperature=0,\n", "    validate_model_on_init=True,\n", ")\n", "embeddings = create_embedding_model(\n", "    provider=\"ollama\",\n", "    model=OLLAMA_EMB_MODEL_NAME,\n", ")\n", "vector_store = ElasticsearchStore(\n", "    es_url=ELASTICSEARCH_LOCAL_URL,\n", "    index_name=ELASTICSEARCH_INDEX,\n", "    embedding=embeddings,\n", "    query_field=\"text\",\n", "    vector_query_field=\"vector\",\n", "    # strategy=ElasticsearchStore.ApproxRetrievalStrategy(\n", "    #     hybrid=True,\n", "    #     rrf={\"rank_constant\": 60, \"window_size\": 100}\n", "    # )\n", ")" ] }, { "cell_type": "markdown", "id": "22115cdc", "metadata": {}, "source": [ "## Tool" ] }, { "cell_type": "code", "execution_count": 55, "id": "42df0acc", "metadata": {}, "outputs": [], "source": [ "retrieve_kwargs = {\"k\": 3}" ] }, { "cell_type": "code", "execution_count": 56, "id": "a132e0e7", "metadata": {}, "outputs": [], "source": [ "@tool\n", "def context_retrieve(query: str) -> str:\n", "    \"\"\"Consult the vector store to answer AVAP-related questions.\n", "\n", "    Args:\n", "        query (str): The input query for which to retrieve relevant documents.\n", "    \"\"\"\n", "    retriever = vector_store.as_retriever(\n", "        search_type=\"similarity\",\n", "        search_kwargs=retrieve_kwargs,\n", "    )\n", "    docs = retriever.invoke(query)\n", "    return format_context(docs)" ] }, { "cell_type": "markdown", "id": "0165d8d7", "metadata": {}, "source": [ "## Graph" ] }, { "cell_type": "code", "execution_count": 57, "id": "70fdd80f", "metadata": {}, "outputs": [], "source": [ "memory = InMemorySaver()\n", "\n", "graph_builder = StateGraph(AgentState)\n", "\n", "graph_builder.add_node(\"reformulate\", reformulate)\n", "graph_builder.add_node(\"retrieve\", retrieve)\n", "graph_builder.add_node(\"generate\", generate)\n", "\n", "graph_builder.set_entry_point(\"reformulate\")\n", "graph_builder.add_edge(\"reformulate\", \"retrieve\")\n",
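"# generate consumes state[\"context\"] written by retrieve; the edges below keep the flow strictly linear\n",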
"graph_builder.add_edge(\"retrieve\", \"generate\")\n", "graph_builder.add_edge(\"generate\", END)\n", "\n", "guided_graph = graph_builder.compile()" ] }, { "cell_type": "code", "execution_count": 58, "id": "be526413", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Failed to export span batch code: 404, reason: