Working on llm_factory
This commit is contained in:
parent
77751ee8ac
commit
e01e424fac
|
|
@ -10,7 +10,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 83,
|
||||
"execution_count": 145,
|
||||
"id": "9e974df6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -30,12 +30,16 @@
|
|||
"from langgraph.graph import StateGraph, END\n",
|
||||
"from langgraph.prebuilt import ToolNode\n",
|
||||
"from langfuse import get_client, Langfuse\n",
|
||||
"from langfuse.langchain import CallbackHandler"
|
||||
"from langfuse.langchain import CallbackHandler\n",
|
||||
"\n",
|
||||
"from typing import TypedDict, List, Optional, Annotated, Literal\n",
|
||||
"from pydantic import BaseModel, Field\n",
|
||||
"from langchain_core.messages import BaseMessage, SystemMessage, AIMessage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 84,
|
||||
"execution_count": 146,
|
||||
"id": "30edcecc",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -72,15 +76,32 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 147,
|
||||
"id": "5f8c88cf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class AgentState(TypedDict):\n",
|
||||
" messages: Annotated[list, add_messages]\n",
|
||||
" language_ok: bool\n",
|
||||
" language_retries: int"
|
||||
" messages: Annotated[list, add_messages]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 148,
|
||||
"id": "30473bce",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class AgentResponse(BaseModel):\n",
|
||||
" \"\"\"\n",
|
||||
" Structured output contract for final assistant responses.\n",
|
||||
" \"\"\"\n",
|
||||
" language: Literal[\"en\"] = Field(\n",
|
||||
" description=\"ISO code. Must always be 'en'.\"\n",
|
||||
" )\n",
|
||||
" content: str = Field(\n",
|
||||
" description=\"Final answer in English.\"\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -93,7 +114,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 86,
|
||||
"execution_count": 149,
|
||||
"id": "f0a21230",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -103,7 +124,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 87,
|
||||
"execution_count": 150,
|
||||
"id": "f9359747",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -134,7 +155,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 88,
|
||||
"execution_count": 151,
|
||||
"id": "e5247ab9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -149,7 +170,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 89,
|
||||
"execution_count": 152,
|
||||
"id": "a644f6fa",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -169,27 +190,11 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 153,
|
||||
"id": "36d0f54e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def _safe_detect_language(text: str) -> str:\n",
|
||||
" stripped_text = (text or \"\").strip()\n",
|
||||
" if not stripped_text:\n",
|
||||
" return \"unknown\"\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" from langdetect import LangDetectException, detect\n",
|
||||
" return detect(stripped_text)\n",
|
||||
" except Exception:\n",
|
||||
" cjk_pattern = r\"[\\u3400-\\u4dbf\\u4e00-\\u9fff\\uf900-\\ufaff\\u3040-\\u30ff\\uac00-\\ud7af]\"\n",
|
||||
" if re.search(cjk_pattern, stripped_text):\n",
|
||||
" return \"non-en\"\n",
|
||||
" ascii_ratio = sum(1 for char in stripped_text if ord(char) < 128) / max(len(stripped_text), 1)\n",
|
||||
" return \"en\" if ascii_ratio > 0.9 else \"unknown\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def _message_text(message: BaseMessage) -> str:\n",
|
||||
" content = getattr(message, \"content\", \"\")\n",
|
||||
" if isinstance(content, str):\n",
|
||||
|
|
@ -215,7 +220,7 @@
|
|||
" return \"\"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"MAX_LANGUAGE_RETRIES = 2\n",
|
||||
"MAX_LANGUAGE_RETRIES = 5\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def agent(state: AgentState) -> AgentState:\n",
|
||||
|
|
@ -242,45 +247,21 @@
|
|||
" )\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" model = llm.bind_tools(tools)\n",
|
||||
" response = model.invoke([system, *messages])\n",
|
||||
"\n",
|
||||
" return {\"messages\": [*messages, response]}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def language_guard(state: AgentState) -> AgentState:\n",
|
||||
" messages: List[BaseMessage] = state[\"messages\"]\n",
|
||||
" retries = state.get(\"language_retries\", 0)\n",
|
||||
" assistant_text = _last_assistant_text(messages)\n",
|
||||
" detected_language = _safe_detect_language(assistant_text)\n",
|
||||
" is_english = detected_language == \"en\"\n",
|
||||
"\n",
|
||||
" if is_english or retries >= MAX_LANGUAGE_RETRIES:\n",
|
||||
" return {\n",
|
||||
" \"messages\": messages,\n",
|
||||
" \"language_ok\": is_english,\n",
|
||||
" \"language_retries\": retries,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" correction_instruction = SystemMessage(\n",
|
||||
" content=(\n",
|
||||
" \"Your previous answer was not in English. Regenerate the final answer in English only. \"\n",
|
||||
" \"Keep the same meaning and format requirements.\"\n",
|
||||
" structured_model = llm.with_structured_output(AgentResponse)\n",
|
||||
" structured_response: AgentResponse = structured_model.invoke(\n",
|
||||
" [system, *messages]\n",
|
||||
" )\n",
|
||||
" )\n",
|
||||
" return {\n",
|
||||
" \"messages\": [*messages, correction_instruction],\n",
|
||||
" \"language_ok\": False,\n",
|
||||
" \"language_retries\": retries + 1,\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" # Keep LangGraph message flow unchanged\n",
|
||||
" final_message = AIMessage(content=structured_response.content)\n",
|
||||
" return {\"messages\": [*messages, final_message]}\n",
|
||||
"\n",
|
||||
"def route_after_language_guard(state: AgentState) -> str:\n",
|
||||
" if state.get(\"language_ok\", False):\n",
|
||||
" return \"end\"\n",
|
||||
" if state.get(\"language_retries\", 0) >= MAX_LANGUAGE_RETRIES:\n",
|
||||
" return \"end\"\n",
|
||||
" return \"retry\""
|
||||
" # model = llm.bind_tools(tools)\n",
|
||||
" # response = model.invoke([system, *messages])\n",
|
||||
"\n",
|
||||
" # return {\"messages\": [*messages, response]}"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
@ -293,22 +274,17 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 154,
|
||||
"id": "fae46a58",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"graph = StateGraph(AgentState)\n",
|
||||
"graph.add_node(\"agent\", agent)\n",
|
||||
"graph.add_node(\"language_guard\", language_guard)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"graph.set_entry_point(\"agent\")\n",
|
||||
"graph.add_edge(\"agent\", \"language_guard\")\n",
|
||||
"graph.add_conditional_edges(\n",
|
||||
" \"language_guard\",\n",
|
||||
" route_after_language_guard,\n",
|
||||
" {\"retry\": \"agent\", \"end\": END},\n",
|
||||
")\n",
|
||||
"graph.add_edge(\"agent\", END)\n",
|
||||
"\n",
|
||||
"# Alternative mode (single pass) - kept commented for quick rollback.\n",
|
||||
"# graph.set_entry_point(\"agent\")\n",
|
||||
|
|
@ -319,7 +295,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 92,
|
||||
"execution_count": 155,
|
||||
"id": "2fec3fdb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
|
@ -351,7 +327,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 93,
|
||||
"execution_count": 156,
|
||||
"id": "8569cf39",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
|
|
@ -372,23 +348,19 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 94,
|
||||
"execution_count": 162,
|
||||
"id": "a1a1f3cf",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"user_input = (\n",
|
||||
" \"Create exactly two variables in AVAP. \"\n",
|
||||
" \"Each variable must be assigned a single integer number using addVar(). \"\n",
|
||||
" \"For example: first variable named x with value 10, second variable named y with value 20. \"\n",
|
||||
" \"Return only the AVAP code snippet with the two addVar() statements in one fenced ```avap``` block. \"\n",
|
||||
" \"Do not add explanations, just the code.\"\n",
|
||||
" \"Generate two variables, asigning them one number to each, do it in AVAP language.\"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 95,
|
||||
"execution_count": 163,
|
||||
"id": "53b89690",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
|
|
@ -398,23 +370,10 @@
|
|||
"text": [
|
||||
"================================\u001b[1m Human Message \u001b[0m=================================\n",
|
||||
"\n",
|
||||
"Create exactly two variables in AVAP. Each variable must be assigned a single integer number using addVar(). For example: first variable named x with value 10, second variable named y with value 20. Return only the AVAP code snippet with the two addVar() statements in one fenced ```avap``` block. Do not add explanations, just the code.\n",
|
||||
"Generate two variables, asigning them one number to each, do it in AVAP language.\n",
|
||||
"==================================\u001b[1m Ai Message \u001b[0m==================================\n",
|
||||
"\n",
|
||||
"<think>\n",
|
||||
"Okay, the user wants me to create exactly two variables in the AVAP language. Each variable needs to be an integer, and they should be assigned using addVar(). The example given is first variable x with value 10, second y with 20. I need to return the code in a single ```avap``` block without any explanations.\n",
|
||||
"\n",
|
||||
"First, I should make sure I understand the task correctly. The user specified two variables, each assigned an integer. The function to use is addVar(), and the variables should be named x and y. The example uses x and y, so I'll follow that format.\n",
|
||||
"\n",
|
||||
"I need to create two addVar() statements. The first one would be addVar(\"x\", 10), and the second would be addVar(\"y\", 20). Then, I have to put all this into a single ```avap``` block. Let me check the syntax again to make sure. The function call should be in JSON within the tool_call tags. Each addVar is a separate function call. \n",
|
||||
"\n",
|
||||
"Wait, the user said \"return only the AVAP code snippet with the two addVar() statements\". So the code should be two lines: first the addVar for x, then another addVar for y. No explanations, just the code. Alright, that's clear. I'll structure the tool calls correctly, ensuring that each function call is properly formatted.\n",
|
||||
"</think>\n",
|
||||
"\n",
|
||||
"```avap\n",
|
||||
"addVar(x, 10)\n",
|
||||
"addVar(y, 20)\n",
|
||||
"```\n"
|
||||
"user: generate two variables, assigning them one number to each, you assign themize in, avap language.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()

# --- Secrets -----------------------------------------------------------------
# Secrets are read ONLY from the environment (or the local .env loaded above).
# Never keep real keys as in-code fallback defaults: any credential that was
# previously hard-coded here must be treated as leaked and rotated.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# --- Ollama ------------------------------------------------------------------
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://host.docker.internal:11434")
OLLAMA_LOCAL_URL = os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434")
OLLAMA_MODEL_NAME = os.getenv("OLLAMA_MODEL_NAME", "qwen3-0.6B:latest")
OLLAMA_EMB_MODEL_NAME = os.getenv("OLLAMA_EMB_MODEL_NAME", "qwen3-0.6B-emb:latest")

# --- Langfuse ----------------------------------------------------------------
LANGFUSE_HOST = os.getenv("LANGFUSE_HOST", "http://45.77.119.180")
LANGFUSE_PUBLIC_KEY = os.getenv("LANGFUSE_PUBLIC_KEY")
LANGFUSE_SECRET_KEY = os.getenv("LANGFUSE_SECRET_KEY")

# --- Elasticsearch -----------------------------------------------------------
ELASTICSEARCH_URL = os.getenv("ELASTICSEARCH_URL", "http://host.docker.internal:9200")
ELASTICSEARCH_LOCAL_URL = os.getenv("ELASTICSEARCH_LOCAL_URL", "http://localhost:9200")
ELASTICSEARCH_INDEX = os.getenv("ELASTICSEARCH_INDEX", "avap-docs-test")

# Connection string embeds credentials -> environment only, no default.
DATABASE_URL = os.getenv("DATABASE_URL")

KUBECONFIG_PATH = os.getenv("KUBECONFIG_PATH", "kubernetes/kubeconfig.yaml")

# --- Hugging Face ------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN")
HF_EMB_MODEL_NAME = os.getenv("HF_EMB_MODEL_NAME", "Qwen/Qwen3-Embedding-0.6B")

# --- Project directory layout (rooted one level above this file) --------------
PROJ_ROOT = Path(__file__).resolve().parents[1]

DATA_DIR = PROJ_ROOT / "data"
MODELS_DIR = DATA_DIR / "models"
RAW_DIR = DATA_DIR / "raw"
PROCESSED_DIR = DATA_DIR / "processed"
INTERIM_DIR = DATA_DIR / "interim"
EXTERNAL_DIR = DATA_DIR / "external"
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
from langchain_ollama import ChatOllama, OllamaEmbeddings
|
||||
|
||||
|
||||
class Provider(StrEnum):
|
||||
OLLAMA = "ollama"
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
HUGGINGFACE = "huggingface"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ChatModelConfig:
    """Immutable settings for building a chat model via `build_chat_model`.

    Only the fields relevant to `provider` are consulted by the builder;
    the others may stay at their defaults.
    """

    provider: Provider  # which backend to build
    model: str  # provider-specific model name / identifier
    temperature: float = 0.0  # sampling temperature forwarded to the wrapper

    # Ollama
    ollama_base_url: Optional[str] = None  # server URL; wrapper default when None
    validate_model_on_init: bool = True  # check the model exists at construction time

    # OpenAI / Anthropic / Azure
    api_key: Optional[str] = None
    azure_endpoint: Optional[str] = None
    azure_deployment: Optional[str] = None
    api_version: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EmbeddingsConfig:
    """Immutable settings for building an embeddings client via `build_embeddings`.

    Only the fields relevant to `provider` are consulted by the builder.
    """

    provider: Provider  # which backend to build
    model: str  # provider-specific embeddings model name

    # Ollama
    ollama_base_url: Optional[str] = None  # server URL; wrapper default when None

    # OpenAI / Azure
    api_key: Optional[str] = None
    azure_endpoint: Optional[str] = None
    azure_deployment: Optional[str] = None
    api_version: Optional[str] = None
|
||||
|
||||
|
||||
def build_chat_model(cfg: ChatModelConfig):
|
||||
match cfg.provider:
|
||||
case Provider.OLLAMA:
|
||||
return ChatOllama(
|
||||
model=cfg.model,
|
||||
temperature=cfg.temperature,
|
||||
validate_model_on_init=cfg.validate_model_on_init,
|
||||
base_url=cfg.ollama_base_url,
|
||||
)
|
||||
|
||||
case Provider.OPENAI:
|
||||
from langchain_openai import ChatOpenAI # pip install langchain-openai
|
||||
|
||||
if not cfg.api_key:
|
||||
raise ValueError("Missing api_key for OpenAI provider.")
|
||||
return ChatOpenAI(
|
||||
model=cfg.model,
|
||||
temperature=cfg.temperature,
|
||||
api_key=cfg.api_key,
|
||||
)
|
||||
|
||||
case Provider.ANTHROPIC:
|
||||
from langchain_anthropic import ChatAnthropic # pip install langchain-anthropic
|
||||
|
||||
if not cfg.api_key:
|
||||
raise ValueError("Missing api_key for Anthropic provider.")
|
||||
return ChatAnthropic(
|
||||
model=cfg.model,
|
||||
temperature=cfg.temperature,
|
||||
api_key=cfg.api_key,
|
||||
)
|
||||
|
||||
case Provider.AZURE_OPENAI:
|
||||
from langchain_openai import AzureChatOpenAI # pip install langchain-openai
|
||||
|
||||
missing = [
|
||||
name
|
||||
for name, value in {
|
||||
"api_key": cfg.api_key,
|
||||
"azure_endpoint": cfg.azure_endpoint,
|
||||
"azure_deployment": cfg.azure_deployment,
|
||||
"api_version": cfg.api_version,
|
||||
}.items()
|
||||
if not value
|
||||
]
|
||||
if missing:
|
||||
raise ValueError(f"Missing Azure settings: {', '.join(missing)}")
|
||||
|
||||
return AzureChatOpenAI(
|
||||
api_key=cfg.api_key,
|
||||
azure_endpoint=cfg.azure_endpoint,
|
||||
azure_deployment=cfg.azure_deployment,
|
||||
api_version=cfg.api_version,
|
||||
temperature=cfg.temperature,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported provider: {cfg.provider}")
|
||||
|
||||
|
||||
def build_embeddings(cfg: EmbeddingsConfig):
|
||||
match cfg.provider:
|
||||
case Provider.OLLAMA:
|
||||
return OllamaEmbeddings(
|
||||
model=cfg.model,
|
||||
base_url=cfg.ollama_base_url,
|
||||
)
|
||||
|
||||
case Provider.OPENAI:
|
||||
from langchain_openai import OpenAIEmbeddings # pip install langchain-openai
|
||||
|
||||
if not cfg.api_key:
|
||||
raise ValueError("Missing api_key for OpenAI embeddings provider.")
|
||||
return OpenAIEmbeddings(
|
||||
model=cfg.model,
|
||||
api_key=cfg.api_key,
|
||||
)
|
||||
|
||||
case Provider.AZURE_OPENAI:
|
||||
from langchain_openai import AzureOpenAIEmbeddings # pip install langchain-openai
|
||||
|
||||
missing = [
|
||||
name
|
||||
for name, value in {
|
||||
"api_key": cfg.api_key,
|
||||
"azure_endpoint": cfg.azure_endpoint,
|
||||
"azure_deployment": cfg.azure_deployment,
|
||||
"api_version": cfg.api_version,
|
||||
}.items()
|
||||
if not value
|
||||
]
|
||||
if missing:
|
||||
raise ValueError(f"Missing Azure settings: {', '.join(missing)}")
|
||||
|
||||
return AzureOpenAIEmbeddings(
|
||||
api_key=cfg.api_key,
|
||||
azure_endpoint=cfg.azure_endpoint,
|
||||
azure_deployment=cfg.azure_deployment,
|
||||
api_version=cfg.api_version,
|
||||
)
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported embeddings provider: {cfg.provider}")
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# ---------- Providers ----------
|
||||
class Provider(StrEnum):
    """Closed set of supported backends; values double as config strings."""

    OLLAMA = "ollama"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    AWS_BEDROCK = "aws_bedrock"
    HUGGINGFACE = "huggingface"
|
||||
|
||||
|
||||
# ---------- Provider-specific configs ----------
|
||||
@dataclass(frozen=True)
class OllamaCfg:
    """Ollama connection settings."""

    base_url: Optional[str] = None  # server URL; the wrapper's default is used when None
    validate_model_on_init: bool = True  # verify the model exists when the client is built
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OpenAICfg:
    """OpenAI credentials."""

    api_key: str  # required; deliberately no default
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AnthropicCfg:
    """Anthropic credentials."""

    api_key: str  # required; deliberately no default
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class BedrockCfg:
    """AWS Bedrock settings.

    Authentication (env vars, AWS profile, IAM role, ...) is left to the
    AWS SDK. Only the region is captured here; wrapper-specific
    model_kwargs (temperature, max_tokens, ...) are kept out on purpose to
    avoid coupling to a particular wrapper.
    """

    region_name: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class HuggingFaceCfg:
    """Hugging Face settings."""

    # May be an HF token or an endpoint URL, depending on whether you use
    # the Inference API or a local/self-hosted deployment.
    api_key: Optional[str] = None
    endpoint_url: Optional[str] = None
|
||||
|
||||
|
||||
# ---------- Base configs ----------
|
||||
@dataclass(frozen=True)
class ChatModelConfig:
    """Provider-agnostic chat-model settings.

    The builder consults only the sub-config matching `provider` and raises
    when the required one is missing.
    """

    provider: Provider  # which backend to build
    model: str  # provider-specific model name / identifier
    temperature: float = 0.0  # sampling temperature forwarded to the wrapper

    # EXACTLY one of these should be populated, according to `provider`:
    ollama: Optional[OllamaCfg] = None
    openai: Optional[OpenAICfg] = None
    anthropic: Optional[AnthropicCfg] = None
    bedrock: Optional[BedrockCfg] = None
    huggingface: Optional[HuggingFaceCfg] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EmbeddingsConfig:
    """Provider-agnostic embeddings settings.

    The builder consults only the sub-config matching `provider` and raises
    when the required one is missing.
    """

    provider: Provider  # which backend to build
    model: str  # provider-specific embeddings model name

    # Populate the one matching `provider`:
    ollama: Optional[OllamaCfg] = None
    openai: Optional[OpenAICfg] = None
    bedrock: Optional[BedrockCfg] = None
    huggingface: Optional[HuggingFaceCfg] = None
|
||||
|
||||
|
||||
# ---------- Helpers ----------
|
||||
def _require(value, msg: str):
|
||||
if value is None:
|
||||
raise ValueError(msg)
|
||||
return value
|
||||
|
||||
|
||||
def _require_cfg(cfg_obj, msg: str):
|
||||
if cfg_obj is None:
|
||||
raise ValueError(msg)
|
||||
return cfg_obj
|
||||
|
||||
|
||||
# ---------- Builders ----------
|
||||
def build_chat_model(cfg: ChatModelConfig):
    """Build a LangChain chat model from *cfg*, importing each backend lazily."""
    provider = cfg.provider

    if provider == Provider.OLLAMA:
        from langchain_ollama import ChatOllama

        ollama_cfg = cfg.ollama if cfg.ollama is not None else OllamaCfg()
        return ChatOllama(
            model=cfg.model,
            temperature=cfg.temperature,
            validate_model_on_init=ollama_cfg.validate_model_on_init,
            base_url=ollama_cfg.base_url,
        )

    if provider == Provider.OPENAI:
        from langchain_openai import ChatOpenAI  # pip install langchain-openai

        openai_cfg = _require_cfg(cfg.openai, "Missing cfg.openai for OpenAI provider.")
        return ChatOpenAI(
            model=cfg.model,
            temperature=cfg.temperature,
            api_key=openai_cfg.api_key,
        )

    if provider == Provider.ANTHROPIC:
        from langchain_anthropic import ChatAnthropic  # pip install langchain-anthropic

        anthropic_cfg = _require_cfg(cfg.anthropic, "Missing cfg.anthropic for Anthropic provider.")
        return ChatAnthropic(
            model=cfg.model,
            temperature=cfg.temperature,
            api_key=anthropic_cfg.api_key,
        )

    if provider == Provider.AWS_BEDROCK:
        # Typical wrapper: langchain-aws (or langchain-community on some
        # setups); kept here with an explicit guardrail.
        try:
            from langchain_aws import ChatBedrock  # pip install langchain-aws
        except Exception as e:
            raise ImportError(
                "To use AWS Bedrock, install `langchain-aws` and configure AWS credentials."
            ) from e

        bedrock_cfg = cfg.bedrock if cfg.bedrock is not None else BedrockCfg()
        # NOTE: ChatBedrock usually takes model_id rather than model,
        # depending on the wrapper version.
        return ChatBedrock(
            model_id=cfg.model,
            region_name=bedrock_cfg.region_name,
            model_kwargs={"temperature": cfg.temperature},
        )

    if provider == Provider.HUGGINGFACE:
        # Highly setup-dependent: endpoint, local pipeline, Inference API...
        raise NotImplementedError(
            "HUGGINGFACE provider not implemented here (depends on whether you use Inference API, TGI, or local pipeline)."
        )

    raise ValueError(f"Unsupported provider: {cfg.provider}")
|
||||
|
||||
|
||||
def build_embeddings(cfg: EmbeddingsConfig):
    """Build a LangChain embeddings client from *cfg*, importing each backend lazily."""
    provider = cfg.provider

    if provider == Provider.OLLAMA:
        from langchain_ollama import OllamaEmbeddings

        ollama_cfg = cfg.ollama if cfg.ollama is not None else OllamaCfg()
        return OllamaEmbeddings(
            model=cfg.model,
            base_url=ollama_cfg.base_url,
        )

    if provider == Provider.OPENAI:
        from langchain_openai import OpenAIEmbeddings  # pip install langchain-openai

        openai_cfg = _require_cfg(cfg.openai, "Missing cfg.openai for OpenAI embeddings provider.")
        return OpenAIEmbeddings(
            model=cfg.model,
            api_key=openai_cfg.api_key,
        )

    if provider == Provider.AWS_BEDROCK:
        # Same caveat as the chat builder: depends on the wrapper in use.
        raise NotImplementedError("Bedrock embeddings: añade el wrapper que uses y mapea aquí.")

    if provider == Provider.HUGGINGFACE:
        raise NotImplementedError("HuggingFace embeddings: depende del wrapper (endpoint/local).")

    raise ValueError(f"Unsupported embeddings provider: {cfg.provider}")
|
||||
Loading…
Reference in New Issue