# RAG pipeline evaluation: hybrid retrieval (BM25 + kNN, RRF-fused) judged
# with RAGAS metrics using Claude as the evaluator LLM.
import os
|
|
import time
|
|
import json
|
|
import logging
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
from ragas import evaluate as ragas_evaluate
|
|
from ragas.metrics import ( faithfulness, answer_relevancy, context_recall, context_precision,)
|
|
from ragas.llms import LangchainLLMWrapper
|
|
from ragas.embeddings import LangchainEmbeddingsWrapper
|
|
from datasets import Dataset
|
|
from langchain_anthropic import ChatAnthropic
|
|
|
|
# Module-level logger, following the standard one-logger-per-module pattern.
logger = logging.getLogger(__name__)

# Golden Q&A dataset used as evaluation ground truth, located next to this file.
GOLDEN_DATASET_PATH = Path(__file__).parent / "golden_dataset.json"

# Judge model id; overridable via the ANTHROPIC_MODEL environment variable.
CLAUDE_MODEL = os.getenv("ANTHROPIC_MODEL", "claude-sonnet-4-20250514")

# API key for the Claude judge; empty string means "not configured".
ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY", "")

# Number of context chunks retrieved per question.
K_RETRIEVE = 5

# NOTE(review): hardcoded to True even though run_evaluation() checks it as an
# availability guard — presumably this was once a try/except around the
# langchain_anthropic import (which is now done unconditionally at the top of
# the file). Confirm whether the guard is still needed.
ANTHROPIC_AVAILABLE = True

from elasticsearch import Elasticsearch
from langchain_core.messages import SystemMessage, HumanMessage
def retrieve_context(es_client, embeddings, question, index, k=K_RETRIEVE):
    """Hybrid retrieval: BM25 and kNN legs fused with Reciprocal Rank Fusion.

    Args:
        es_client: Elasticsearch client used for both searches.
        embeddings: Embedding model exposing ``embed_query``.
        question: Natural-language query string.
        index: Name of the Elasticsearch index to search.
        k: Number of chunks to return (defaults to ``K_RETRIEVE``).

    Returns:
        list[str]: Up to ``k`` non-empty chunk texts, best-ranked first.
        Each failure (embedding, BM25, kNN) is logged and degrades
        gracefully, so one broken leg still leaves the other contributing.
    """
    # Embed the query; on failure fall back to BM25-only retrieval.
    query_vector = None
    try:
        query_vector = embeddings.embed_query(question)
    except Exception as e:
        logger.warning(f"[eval] embed_query fails: {e}")

    # --- Lexical leg: BM25 multi_match over content/text fields ---
    bm25_hits = []
    try:
        resp = es_client.search(
            index=index,
            body={
                "size": k,
                "query": {
                    "multi_match": {
                        "query": question,
                        "fields": ["content^2", "text^2"],
                        "type": "best_fields",
                        "fuzziness": "AUTO",
                    }
                },
                # Embedding vectors are large; never ship them back.
                "_source": {"excludes": ["embedding"]},
            }
        )
        bm25_hits = resp["hits"]["hits"]
    except Exception as e:
        logger.warning(f"[eval] BM25 fails: {e}")

    # --- Semantic leg: kNN over the embedding field (needs a vector) ---
    knn_hits = []
    if query_vector:
        try:
            resp = es_client.search(
                index=index,
                body={
                    "size": k,
                    "knn": {
                        "field": "embedding",
                        "query_vector": query_vector,
                        "k": k,
                        "num_candidates": k * 5,
                    },
                    "_source": {"excludes": ["embedding"]},
                }
            )
            knn_hits = resp["hits"]["hits"]
        except Exception as e:
            logger.warning(f"[eval] kNN falló: {e}")

    # --- Reciprocal Rank Fusion: score(doc) = sum over legs of 1/(rank+60).
    # 60 is the conventional RRF damping constant.
    rrf_scores: dict[str, float] = defaultdict(float)
    hit_by_id: dict[str, dict] = {}

    for rank, hit in enumerate(bm25_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + 60)
        hit_by_id[doc_id] = hit

    for rank, hit in enumerate(knn_hits):
        doc_id = hit["_id"]
        rrf_scores[doc_id] += 1.0 / (rank + 60)
        if doc_id not in hit_by_id:
            hit_by_id[doc_id] = hit

    ranked = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)[:k]

    # Extract each chunk's text exactly once (the original comprehension
    # evaluated the same "content or text" expression twice per document),
    # preferring "content" over "text", and drop blank chunks.
    contexts = []
    for doc_id, _ in ranked:
        source = hit_by_id[doc_id]["_source"]
        text = source.get("content") or source.get("text", "")
        if text.strip():
            contexts.append(text)
    return contexts
def generate_answer(llm, question: str, contexts: list[str]) -> str:
    """Generate an answer to *question* grounded in the retrieved *contexts*.

    Builds a system prompt by numbering each context chunk and formatting it
    into ``GENERATE_PROMPT``, then invokes *llm* with the question. Any
    failure (including a missing ``prompts`` module) is logged and yields an
    empty string, which the caller treats as "skip this question".
    """
    try:
        from prompts import GENERATE_PROMPT

        # Number the chunks "[1] ...", "[2] ..." so the model can cite them.
        numbered = [f"[{idx + 1}] {chunk}" for idx, chunk in enumerate(contexts)]
        joined_context = "\n\n".join(numbered)

        system_msg = SystemMessage(
            content=GENERATE_PROMPT.content.format(context=joined_context)
        )
        reply = llm.invoke([system_msg, HumanMessage(content=question)])
        return reply.content.strip()
    except Exception as e:
        # Best-effort: an empty answer signals the caller to skip the sample.
        logger.warning(f"[eval] generate_answer fails: {e}")
        return ""
def run_evaluation(es_client, llm, embeddings, index_name, category=None, limit=None):
    """Evaluate the RAG pipeline over the golden dataset with RAGAS.

    For each golden question: retrieve hybrid context, generate an answer
    with *llm*, then score the (question, answer, contexts, ground_truth)
    samples with RAGAS metrics using Claude as the judge model.

    Args:
        es_client: Elasticsearch client used for retrieval.
        llm: Chat model whose answers are being evaluated.
        embeddings: Embedding model, used for kNN retrieval and by the
            embedding-based RAGAS metrics.
        index_name: Elasticsearch index to retrieve from.
        category: Optional category filter applied to golden questions.
        limit: Optional cap on the number of questions evaluated.

    Returns:
        dict: ``{"error": ...}`` on a precondition failure, otherwise a
        report with per-metric scores, a global score, a verdict, timing
        and per-question details.
    """
    # --- Preconditions ------------------------------------------------
    if not ANTHROPIC_AVAILABLE:
        return {"error": "langchain-anthropic no instalado. pip install langchain-anthropic"}
    if not ANTHROPIC_API_KEY:
        return {"error": "ANTHROPIC_API_KEY no configurada en .env"}
    if not GOLDEN_DATASET_PATH.exists():
        return {"error": f"Golden dataset no encontrado en {GOLDEN_DATASET_PATH}"}

    # --- Load and filter the golden dataset ---------------------------
    questions = json.loads(GOLDEN_DATASET_PATH.read_text(encoding="utf-8"))
    if category:
        questions = [q for q in questions if q.get("category") == category]
    if limit:
        questions = questions[:limit]
    if not questions:
        return {"error": "NO QUESTIONS MATCH THESE FILTERS"}

    logger.info(f"[eval] running: {len(questions)} questions, index={index_name}")

    # Deterministic judge (temperature=0) for reproducible scoring.
    claude_judge = ChatAnthropic(
        model=CLAUDE_MODEL,
        api_key=ANTHROPIC_API_KEY,
        temperature=0,
        max_tokens=2048,
    )

    # Column-oriented samples, as expected by datasets.Dataset.from_dict.
    rows = {"question": [], "answer": [], "contexts": [], "ground_truth": []}
    details = []
    t_start = time.time()

    # --- Build one evaluation sample per golden question ---------------
    for item in questions:
        q_id = item["id"]
        question = item["question"]
        gt = item["ground_truth"]

        logger.info(f"[eval] {q_id}: {question[:60]}")

        contexts = retrieve_context(es_client, embeddings, question, index_name)
        if not contexts:
            logger.warning(f"[eval] No context for {q_id} — skipping")
            continue

        answer = generate_answer(llm, question, contexts)
        if not answer:
            logger.warning(f"[eval] No answers for {q_id} — skipping")
            continue

        rows["question"].append(question)
        rows["answer"].append(answer)
        rows["contexts"].append(contexts)
        rows["ground_truth"].append(gt)

        details.append({
            "id": q_id,
            "category": item.get("category", ""),
            "question": question,
            "answer_preview": answer[:300],
            "n_chunks": len(contexts),
        })

    if not rows["question"]:
        return {"error": "NO SAMPLES GENERATED"}

    # --- Score with RAGAS, Claude as judge -----------------------------
    dataset = Dataset.from_dict(rows)
    ragas_llm = LangchainLLMWrapper(claude_judge)
    ragas_emb = LangchainEmbeddingsWrapper(embeddings)

    metrics = [faithfulness, answer_relevancy, context_recall, context_precision]
    for metric in metrics:
        metric.llm = ragas_llm
        # Only some metrics (e.g. answer_relevancy) use embeddings.
        if hasattr(metric, "embeddings"):
            metric.embeddings = ragas_emb

    logger.info("[eval] JUDGING BY CLAUDE...")
    result = ragas_evaluate(dataset, metrics=metrics)

    elapsed = time.time() - t_start

    # RAGAS >= 0.2 returns an EvaluationResult object, not a dict.
    # Extract per-metric means from the underlying DataFrame.
    try:
        df = result.to_pandas()

        def _mean(col):
            # Mean over non-NaN rows; 0.0 when the metric column is absent.
            return round(float(df[col].dropna().mean()), 4) if col in df.columns else 0.0
    except Exception:
        # Fallback: try legacy dict-style access
        df = None

        def _mean(col):
            try:
                return round(float(result[col]), 4)
            except Exception:
                return 0.0

    scores = {
        "faithfulness": _mean("faithfulness"),
        "answer_relevancy": _mean("answer_relevancy"),
        "context_recall": _mean("context_recall"),
        "context_precision": _mean("context_precision"),
    }

    # A metric that failed outright reports 0.0; exclude it from the global
    # mean so one broken metric does not drag the overall score down.
    valid_scores = [v for v in scores.values() if v > 0]
    global_score = round(sum(valid_scores) / len(valid_scores), 4) if valid_scores else 0.0

    verdict = (
        "EXCELLENT" if global_score >= 0.8 else
        "ACCEPTABLE" if global_score >= 0.6 else
        "INSUFFICIENT"
    )

    logger.info(f"[eval] FINISHED — global={global_score} verdict={verdict} "
                f"elapsed={elapsed:.0f}s")

    return {
        "status": "ok",
        "questions_evaluated": len(rows["question"]),
        "elapsed_seconds": round(elapsed, 1),
        "judge_model": CLAUDE_MODEL,
        "index": index_name,
        "category_filter": category or "all",
        "scores": scores,
        "global_score": global_score,
        "verdict": verdict,
        "details": details,
    }