338 lines
9.4 KiB
Plaintext
338 lines
9.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "9f97dd1e",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Libraries"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9e974df6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"from typing import TypedDict, List, Optional, Annotated\n",
|
|
"from IPython.display import Image, display\n",
|
|
"\n",
|
|
"from langchain_core.documents import Document\n",
|
|
"from langchain_core.messages import BaseMessage, SystemMessage\n",
|
|
"from langchain_core.tools import tool\n",
|
|
"from langgraph.checkpoint.memory import InMemorySaver\n",
|
|
"from langgraph.graph.message import add_messages\n",
|
|
"from langchain_ollama import ChatOllama, OllamaEmbeddings\n",
|
|
"from langchain_elasticsearch import ElasticsearchStore\n",
|
|
"from langgraph.graph import StateGraph, END\n",
|
|
"from langgraph.prebuilt import ToolNode\n",
|
|
"from langfuse import observe, get_client\n",
|
|
"from langfuse.langchain import CallbackHandler"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "30edcecc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# All connection settings come from environment variables (set before launching Jupyter).\n",
|
|
"ES_URL = os.getenv(\"ELASTICSEARCH_LOCAL_URL\")\n",
|
|
"INDEX_NAME = os.getenv(\"ELASTICSEARCH_INDEX\")\n",
|
|
"BASE_URL = os.getenv(\"LLM_BASE_LOCAL_URL\")\n",
|
|
"MODEL_NAME = os.getenv(\"OLLAMA_MODEL_NAME\")\n",
|
|
"LANGFUSE_PUBLIC_KEY = os.getenv(\"LANGFUSE_PUBLIC_KEY\")\n",
|
|
"LANGFUSE_SECRET_KEY = os.getenv(\"LANGFUSE_SECRET_KEY\")\n",
|
|
"LANGFUSE_HOST = os.getenv(\"LANGFUSE_HOST\")\n",
|
|
"\n",
|
|
"print(f\"DEBUG: LANGFUSE_HOST from env = {LANGFUSE_HOST}\")\n",
|
|
"\n",
|
|
"# Initialize Langfuse client (get_client() reads the LANGFUSE_* env vars)\n",
|
|
"langfuse = get_client()\n",
|
|
"\n",
|
|
"# Print actual client configuration\n",
|
|
"print(f\"DEBUG: Langfuse client base_url = {langfuse.base_url if hasattr(langfuse, 'base_url') else 'NOT SET'}\")\n",
|
|
"\n",
|
|
"# Create CallbackHandler - it will use the client from environment variables\n",
|
|
"langfuse_handler = CallbackHandler()\n",
|
|
"\n",
|
|
"# NOTE(review): the same OLLAMA_MODEL_NAME is used for both embeddings and chat -- confirm intended.\n",
|
|
"embeddings = OllamaEmbeddings(base_url=BASE_URL, model=MODEL_NAME)\n",
|
|
"llm = ChatOllama(base_url=BASE_URL, model=MODEL_NAME)\n",
|
|
"\n",
|
|
"# Vector store over an existing index; field names must match the index mapping.\n",
|
|
"vector_store = ElasticsearchStore(\n",
|
|
"    es_url=ES_URL,\n",
|
|
"    index_name=INDEX_NAME,\n",
|
|
"    embedding=embeddings,\n",
|
|
"    query_field=\"text\",\n",
|
|
"    vector_query_field=\"vector\",\n",
|
|
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "ad98841b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sanity-check Langfuse configuration; mask keys, never print the actual values.\n",
|
|
"print(\"Langfuse Configuration:\")\n",
|
|
"print(f\" Host: {LANGFUSE_HOST}\")\n",
|
|
"print(f\" Public Key: {'*' * 10 if LANGFUSE_PUBLIC_KEY else 'NOT SET'}\")\n",
|
|
"print(f\" Secret Key: {'*' * 10 if LANGFUSE_SECRET_KEY else 'NOT SET'}\")\n",
|
|
"\n",
|
|
"if all([LANGFUSE_HOST, LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY]):\n",
|
|
"    print(\"\\n✓ All Langfuse environment variables are set\")\n",
|
|
"    print(\" Tracing will be sent to Langfuse when you run the agent\")\n",
|
|
"else:\n",
|
|
"    print(\"\\n⚠ Some Langfuse variables are missing - tracing may not work\")\n",
|
|
"    print(\" Set LANGFUSE_HOST, LANGFUSE_PUBLIC_KEY, and LANGFUSE_SECRET_KEY\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "873ea2f6",
|
|
"metadata": {},
|
|
"source": [
|
|
"### State"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5f8c88cf",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class AgentState(TypedDict):\n",
|
|
"    \"\"\"Graph state: the running chat history.\"\"\"\n",
|
|
"    # add_messages is a reducer: node returns are appended (or updated by\n",
|
|
"    # message id) rather than replacing the list wholesale.\n",
|
|
"    messages: Annotated[list, add_messages]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1d60c120",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Tools"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f9359747",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def format_context(docs: List[Document]) -> str:\n",
|
|
"    \"\"\"Render retrieved docs as numbered, citable context blocks.\n",
|
|
"\n",
|
|
"    Each block is '[i] id=... title=...' followed by the chunk text;\n",
|
|
"    blocks are separated by blank lines.\n",
|
|
"    \"\"\"\n",
|
|
"    chunks: List[str] = []\n",
|
|
"    for i, doc in enumerate(docs, 1):\n",
|
|
"        # metadata may be None; fall back to placeholders\n",
|
|
"        title = (doc.metadata or {}).get(\"title\", \"Untitled\")\n",
|
|
"        source_id = (doc.metadata or {}).get(\"id\", f\"chunk-{i}\")\n",
|
|
"        text = doc.page_content or \"\"\n",
|
|
"        chunks.append(f\"[{i}] id={source_id} title={title}\\n{text}\")\n",
|
|
"    return \"\\n\\n\".join(chunks)\n",
|
|
"\n",
|
|
"\n",
|
|
"# NOTE: the docstring below is the tool description shown to the LLM -- keep it short.\n",
|
|
"@tool(\"retrieve\", return_direct=False)\n",
|
|
"def retrieve(query: str, k: int = 4, title_filter: Optional[str] = None) -> str:\n",
|
|
"    \"\"\"Retrieve relevant context from Elasticsearch for a given query.\"\"\"\n",
|
|
"    # k = number of chunks; title_filter = exact match on metadata.title.keyword\n",
|
|
"    search_kwargs = {\"k\": k}\n",
|
|
"    if title_filter:\n",
|
|
"        search_kwargs[\"filter\"] = {\"term\": {\"metadata.title.keyword\": title_filter}}\n",
|
|
"\n",
|
|
"    retriever = vector_store.as_retriever(\n",
|
|
"        search_type=\"similarity\",\n",
|
|
"        search_kwargs=search_kwargs,\n",
|
|
"    )\n",
|
|
"    docs = retriever.invoke(query)\n",
|
|
"    return format_context(docs)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e5247ab9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def should_continue(state: AgentState) -> str:\n",
|
|
"    \"\"\"Route after the agent node: run tools if the model requested them, else finish.\"\"\"\n",
|
|
"    newest = state[\"messages\"][-1]\n",
|
|
"    pending_calls = getattr(newest, \"tool_calls\", None)\n",
|
|
"    return \"tools\" if pending_calls else \"end\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a644f6fa",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Tools available to the agent; ToolNode executes any tool calls it emits\n",
|
|
"tools = [retrieve]\n",
|
|
"tool_node = ToolNode(tools)\n",
|
|
"# In-memory checkpointer: conversation state persists per thread_id for this kernel\n",
|
|
"memory = InMemorySaver()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "395966e2",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Agent"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "36d0f54e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def agent(state: AgentState) -> AgentState:\n",
|
|
"    \"\"\"LLM node: call the tool-bound model on the running chat history.\n",
|
|
"\n",
|
|
"    Returns only the model's new message: the add_messages reducer on\n",
|
|
"    AgentState.messages appends it to the existing history, so returning\n",
|
|
"    the whole history would push every prior message back through the\n",
|
|
"    reducer on every turn.\n",
|
|
"    \"\"\"\n",
|
|
"    messages: List[BaseMessage] = state[\"messages\"]\n",
|
|
"\n",
|
|
"    system = SystemMessage(\n",
|
|
"        content=(\n",
|
|
"            \"You are a helpful assistant. You must use the tools provided to respond.\\n\"\n",
|
|
"            \"If you don't have enough info, ask a precise follow-up question.\"\n",
|
|
"        )\n",
|
|
"    )\n",
|
|
"\n",
|
|
"    # IMPORTANT: bind tools so the model can emit tool calls\n",
|
|
"    model = llm.bind_tools(tools)\n",
|
|
"\n",
|
|
"    # Route the call through the Langfuse handler so the LLM span is traced\n",
|
|
"    resp = model.invoke([system, *messages], config={\"callbacks\": [langfuse_handler]})\n",
|
|
"\n",
|
|
"    # Only the new message; add_messages merges it into state[\"messages\"]\n",
|
|
"    return {\"messages\": [resp]}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ef55bca3",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Graph"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "fae46a58",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Wire the graph: agent -> (tools -> agent)* -> END\n",
|
|
"graph = StateGraph(AgentState)\n",
|
|
"graph.add_node(\"agent\", agent)\n",
|
|
"graph.add_node(\"tools\", tool_node)\n",
|
|
"\n",
|
|
"graph.set_entry_point(\"agent\")\n",
|
|
"# After each agent turn, either execute the requested tools or finish\n",
|
|
"graph.add_conditional_edges(\"agent\", should_continue, {\"tools\": \"tools\", \"end\": END})\n",
|
|
"graph.add_edge(\"tools\", \"agent\")\n",
|
|
"\n",
|
|
"# Checkpointing enables multi-turn memory keyed by configurable.thread_id\n",
|
|
"agent_graph = graph.compile(checkpointer=memory)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2fec3fdb",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Best-effort visualization: mermaid PNG rendering needs optional\n",
|
|
"# dependencies/network access, so failures are deliberately ignored.\n",
|
|
"try:\n",
|
|
"    display(Image(agent_graph.get_graph().draw_mermaid_png()))\n",
|
|
"except Exception:\n",
|
|
"    pass"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "1e9aff05",
|
|
"metadata": {},
|
|
"source": [
|
|
"### Test"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a7f4fbf6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sample question for the test run below (grammar fixed: 'does ... propose')\n",
|
|
"user_input = \"What does vertical farming propose?\""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8569cf39",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# thread_id selects the conversation in the checkpointer; callbacks route\n",
|
|
"# the whole run to Langfuse under the given run_name.\n",
|
|
"config = {\"configurable\": {\"thread_id\": \"1\"}, \n",
|
|
"          \"callbacks\": [langfuse_handler],\n",
|
|
"          \"run_name\": \"rag-local-test\"}\n",
|
|
"\n",
|
|
"\n",
|
|
"@observe(name=\"agent-graph-stream\")\n",
|
|
"def stream_graph_updates(user_input: str):\n",
|
|
"    \"\"\"Stream one user turn through the graph, printing each state's newest message.\"\"\"\n",
|
|
"    # stream_mode='values' yields the full state after every node\n",
|
|
"    for event in agent_graph.stream(\n",
|
|
"        {\"messages\": [{\"role\": \"user\", \"content\": user_input}]},\n",
|
|
"        config=config,\n",
|
|
"        stream_mode=\"values\",\n",
|
|
"    ):\n",
|
|
"        event[\"messages\"][-1].pretty_print()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "53b89690",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(\"Starting agent...\")\n",
|
|
"stream_graph_updates(user_input)\n",
|
|
"\n",
|
|
"# Ensure all spans are flushed to Langfuse (the SDK batches in the background)\n",
|
|
"print(\"\\nFlushing traces to Langfuse...\")\n",
|
|
"langfuse.flush()\n",
|
|
"# f-string instead of '+': avoids TypeError when LANGFUSE_HOST is unset (None)\n",
|
|
"print(f\"✓ Traces flushed (check your Langfuse dashboard at {LANGFUSE_HOST})\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a5bfbf18",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "assistance-engine",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|