432 lines
14 KiB
Python
432 lines
14 KiB
Python
"""
|
|
Embedding Evaluation Pipeline
|
|
|
|
Evaluate embedding models across CodexGlue, CoSQA, and SciFact benchmarks.
|
|
Supports multiple embedding providers via factory methods.
|
|
"""
|
|
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Union
|
|
|
|
import numpy as np
|
|
import typer
|
|
from langchain_ollama import OllamaEmbeddings
|
|
from langchain_huggingface import HuggingFaceEmbeddings
|
|
from beir.datasets.data_loader import GenericDataLoader
|
|
from beir.retrieval.evaluation import EvaluateRetrieval
|
|
from beir.retrieval.search.dense import DenseRetrievalExactSearch
|
|
from beir import util
|
|
from datasets import load_dataset
|
|
from src.config import settings
|
|
# Resolve filesystem locations from project settings.
project_root = settings.proj_root
# All benchmark datasets are cached under this folder.
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"

# Typer application exposing the evaluation CLI.
app = typer.Typer()
|
|
|
|
|
|
def _has_local_beir_files(data_path: Path) -> bool:
|
|
"""Return True when a dataset folder already has the required BEIR files."""
|
|
required_files = [
|
|
data_path / "corpus.jsonl",
|
|
data_path / "queries.jsonl",
|
|
data_path / "qrels" / "test.tsv",
|
|
]
|
|
return all(path.exists() and path.stat().st_size > 0 for path in required_files)
|
|
|
|
|
|
def _load_local_beir_dataset(data_path: Path) -> tuple[Dict, Dict, Dict]:
    """Load a BEIR-formatted dataset (corpus, queries, qrels) from local disk."""
    loader = GenericDataLoader(str(data_path))
    return loader.load(split="test")
|
|
|
|
|
|
class BEIROllamaEmbeddings:
    """
    Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.

    BEIR's dense retriever expects an object exposing ``encode_queries`` and
    ``encode_corpus``; this class maps those calls onto batched
    ``embed_documents`` requests against an Ollama server.
    """

    def __init__(
        self,
        base_url: str,
        model: str,
        batch_size: int = 64,
    ) -> None:
        # batch_size bounds how many texts go into one embedding request.
        self.batch_size = batch_size
        self.embeddings = OllamaEmbeddings(base_url=base_url, model=model)

    def _batch_embed(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* in fixed-size batches, scrubbing NaN/inf values."""
        sanitized = []
        step = self.batch_size
        for start in range(0, len(texts), step):
            chunk = texts[start : start + step]
            for raw in self.embeddings.embed_documents(chunk):
                if not isinstance(raw, (list, np.ndarray)):
                    # Pass through unexpected payloads unchanged.
                    sanitized.append(raw)
                    continue
                # NaN/inf would poison cosine-similarity scoring downstream.
                clean = np.nan_to_num(
                    np.asarray(raw, dtype=np.float32),
                    nan=0.0,
                    posinf=0.0,
                    neginf=0.0,
                )
                sanitized.append(clean)
        return np.asarray(sanitized, dtype=np.float32)

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """
        BEIR query encoder
        """
        # Non-strings and blank strings get placeholders so every query
        # still yields a well-defined embedding.
        prepared = [
            (q.strip() or "[EMPTY]") if isinstance(q, str) else "[INVALID]"
            for q in queries
        ]
        return self._batch_embed(prepared)

    def encode_corpus(
        self,
        corpus: Union[List[Dict[str, str]], Dict[str, Dict[str, str]]],
        **kwargs,
    ) -> np.ndarray:
        """
        BEIR corpus encoder
        """
        docs = list(corpus.values()) if isinstance(corpus, dict) else corpus

        texts = []
        for doc in docs:
            parts = [
                (doc.get("title") or "").strip(),
                (doc.get("text") or "").strip(),
            ]
            # Placeholder avoids NaN embeddings from fully empty documents.
            texts.append(" ".join(p for p in parts if p) or "[EMPTY]")

        return self._batch_embed(texts)
|
|
|
|
|
|
class BEIRHuggingFaceEmbeddings:
    """
    Adapter that makes LangChain's HuggingFaceEmbeddings compatible with BEIR.
    """

    def __init__(self, model: str, batch_size: int = 64) -> None:
        # Number of texts embedded per embed_documents call.
        self.batch_size = batch_size
        self.embeddings = HuggingFaceEmbeddings(model_name=model)

    def _batch_embed(self, texts: List[str]) -> np.ndarray:
        """Embed texts in batches; NaN/inf components are replaced with 0."""
        collected = []
        for lo in range(0, len(texts), self.batch_size):
            window = texts[lo : lo + self.batch_size]
            for vec in self.embeddings.embed_documents(window):
                if isinstance(vec, (list, np.ndarray)):
                    arr = np.asarray(vec, dtype=np.float32)
                    # Scrub non-finite values before scoring.
                    collected.append(
                        np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
                    )
                else:
                    collected.append(vec)
        return np.asarray(collected, dtype=np.float32)

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """BEIR query encoder"""
        cleaned = []
        for query in queries:
            if not isinstance(query, str):
                # Non-string input: embed a stable placeholder instead.
                cleaned.append("[INVALID]")
                continue
            stripped = query.strip()
            cleaned.append(stripped if stripped else "[EMPTY]")
        return self._batch_embed(cleaned)

    def encode_corpus(
        self,
        corpus: Union[List[Dict[str, str]], Dict[str, Dict[str, str]]],
        **kwargs,
    ) -> np.ndarray:
        """BEIR corpus encoder"""
        if isinstance(corpus, dict):
            corpus = list(corpus.values())

        texts = []
        for doc in corpus:
            fragments = filter(
                None,
                [(doc.get("title") or "").strip(), (doc.get("text") or "").strip()],
            )
            joined = " ".join(fragments)
            # Fully empty documents get a placeholder to avoid NaN embeddings.
            texts.append(joined if joined else "[EMPTY]")

        return self._batch_embed(texts)
|
|
|
|
|
|
def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
    """Load the SciFact benchmark, downloading the BEIR archive on first use."""
    DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
    scifact_path = DATASETS_ROOT / "scifact"

    # Prefer the on-disk cache when it is complete.
    if _has_local_beir_files(scifact_path):
        print(" Using local SciFact dataset cache")
        return _load_local_beir_dataset(scifact_path)

    print(" SciFact dataset not found locally. Downloading...")
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
    extracted = Path(util.download_and_unzip(url, out_dir=str(DATASETS_ROOT)))

    # Use the path reported by the downloader when it matches expectations.
    if extracted.name == "scifact" and _has_local_beir_files(extracted):
        return _load_local_beir_dataset(extracted)

    # Otherwise fall back to the conventional location under DATASETS_ROOT.
    return _load_local_beir_dataset(scifact_path)
|
|
|
|
|
|
def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
    """Load the CoSQA benchmark, converting it to BEIR format on first use.

    Downloads the corpus/queries/qrels splits from HuggingFace and writes
    them into the BEIR on-disk layout so later runs hit the local cache.

    Returns:
        (corpus, queries, qrels) as produced by BEIR's GenericDataLoader.
    """
    data_path = DATASETS_ROOT / "cosqa"
    if _has_local_beir_files(data_path):
        print(" Using local CoSQA dataset cache")
        return _load_local_beir_dataset(data_path)

    print(" CoSQA dataset not found locally. Downloading and preparing...")
    (data_path / "qrels").mkdir(parents=True, exist_ok=True)

    # Load from HuggingFace
    hf_corpus = load_dataset("CoIR-Retrieval/cosqa", "corpus", split="corpus")
    hf_queries = load_dataset("CoIR-Retrieval/cosqa", "queries", split="queries")
    hf_qrels = load_dataset("CoIR-Retrieval/cosqa", "default", split="test")

    # Save in BEIR format. Write UTF-8 explicitly so code snippets with
    # non-ASCII characters round-trip on every platform (Windows defaults
    # to a legacy codec otherwise).
    with open(data_path / "corpus.jsonl", "w", encoding="utf-8") as f:
        for item in hf_corpus:
            f.write(
                json.dumps(
                    {"_id": str(item["_id"]), "text": item["text"], "title": ""}
                )
                + "\n"
            )

    with open(data_path / "queries.jsonl", "w", encoding="utf-8") as f:
        for item in hf_queries:
            f.write(json.dumps({"_id": str(item["_id"]), "text": item["text"]}) + "\n")

    with open(data_path / "qrels" / "test.tsv", "w", encoding="utf-8") as f:
        f.write("query-id\tcorpus-id\tscore\n")
        for item in hf_qrels:
            f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")

    return _load_local_beir_dataset(data_path)
|
|
|
|
|
|
def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
    """Load the CodexGlue benchmark, converting it to BEIR format on first use.

    Each test example contributes one document (its code) and one query
    (its docstring); qrels pair q_i with doc_i one-to-one.

    Returns:
        (corpus, queries, qrels) as produced by BEIR's GenericDataLoader.
    """
    data_path = DATASETS_ROOT / "codexglue"
    if _has_local_beir_files(data_path):
        print(" Using local CodexGlue dataset cache")
        return _load_local_beir_dataset(data_path)

    print(" CodexGlue dataset not found locally. Downloading and preparing...")
    (data_path / "qrels").mkdir(parents=True, exist_ok=True)

    raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
    # Write UTF-8 explicitly: source code and docstrings may contain
    # non-ASCII characters, and the platform-default codec can mangle them.
    with open(data_path / "corpus.jsonl", "w", encoding="utf-8") as corpus_file:
        for i, data in enumerate(raw_dataset):
            docid = f"doc_{i}"
            corpus_file.write(
                json.dumps(
                    {
                        "_id": docid,
                        "title": data.get("func_name", ""),
                        "text": data["code"],
                    }
                )
                + "\n"
            )

    with open(data_path / "queries.jsonl", "w", encoding="utf-8") as query_file:
        for i, data in enumerate(raw_dataset):
            queryid = f"q_{i}"
            query_file.write(
                json.dumps({"_id": queryid, "text": data["docstring"]}) + "\n"
            )

    with open(data_path / "qrels" / "test.tsv", "w", encoding="utf-8") as qrels_file:
        qrels_file.write("query-id\tcorpus-id\tscore\n")
        # Each query's sole relevant document is its own code snippet.
        for i, _ in enumerate(raw_dataset):
            qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")

    return _load_local_beir_dataset(data_path)
|
|
|
|
|
|
# Registry mapping CLI benchmark names to their dataset loader functions.
BENCHMARK_LOADERS = {
    "scifact": load_scifact_dataset,
    "cosqa": load_cosqa_dataset,
    "codexglue": load_codexglue_dataset,
}
|
|
|
|
|
|
def evaluate_model_on_benchmark(
    benchmark: str, provider: str, model: str, k_values: Union[List[int], None] = None
) -> Dict[str, Any]:
    """Evaluate one embedding model on one benchmark.

    Args:
        benchmark: Key into BENCHMARK_LOADERS ("scifact", "cosqa", "codexglue").
        provider: Embedding provider, "ollama" or "huggingface".
        model: Provider-specific model name.
        k_values: Metric cutoffs; defaults to [1, 5, 10, 100] when None.
            (Annotation fixed: the default is None, so the type is Optional.)

    Returns:
        Dict with "NDCG", "MAP", "Recall", and "Precision" metric dicts.

    Raises:
        ValueError: If *provider* is not recognized.
    """
    if k_values is None:
        k_values = [1, 5, 10, 100]

    print(f" Loading {benchmark.upper()} dataset...")
    corpus, queries, qrels = BENCHMARK_LOADERS[benchmark]()

    print(f" Corpus: {len(corpus)}, Queries: {len(queries)}")

    # Select adapter based on provider
    if provider == "ollama":
        adapter = BEIROllamaEmbeddings(
            base_url=settings.ollama_local_url,
            model=model,
            batch_size=64,
        )
    elif provider == "huggingface":
        adapter = BEIRHuggingFaceEmbeddings(model=model, batch_size=64)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
    evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")

    print(" Running retrieval...")
    results = evaluator.retrieve(corpus, queries)

    print(" Computing metrics...")
    ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)

    return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
|
|
|
|
|
|
def parse_model_spec(model_spec: str) -> tuple[str, str]:
    """
    Parse model spec. Format: "provider:model_name" (default provider: ollama).
    Examples: "ollama:qwen3", "openai:text-embedding-3-small", "bge-me3:latest"
    """
    known_providers = {"ollama", "openai", "huggingface", "bedrock"}
    prefix, sep, remainder = model_spec.partition(":")
    if sep and prefix.lower() in known_providers:
        return prefix.lower(), remainder
    # No recognized provider prefix: treat the whole spec as an Ollama model
    # (Ollama tags themselves may contain ":", e.g. "bge-m3:latest").
    return "ollama", model_spec
|
|
|
|
|
|
def evaluate_models(
    models: List[str], benchmarks: List[str], output_folder: Path, k_values: List[int]
) -> None:
    """Evaluate multiple models on multiple benchmarks and save JSON results.

    Args:
        models: Model specs ("provider:model" or bare Ollama model names).
        benchmarks: Benchmark names; unknown names are skipped with a warning.
        output_folder: Directory where the results JSON file is written.
        k_values: Metric cutoffs passed through to the evaluator.
    """

    def _sanitize(fragment: str) -> str:
        # Model specs may contain "/" and ":" (e.g. "huggingface:org/model");
        # raw "/" would create nested (nonexistent) directories and ":" is
        # illegal in Windows filenames, so map them to "-".
        return "".join(c if c.isalnum() or c in "-._" else "-" for c in fragment)

    output_folder.mkdir(parents=True, exist_ok=True)
    all_results = {}

    for model_spec in models:
        provider, model_name = parse_model_spec(model_spec)
        print(f"\n{'='*60}\nModel: {model_name} ({provider})\n{'='*60}")

        model_results = {}
        for benchmark in benchmarks:
            if benchmark not in BENCHMARK_LOADERS:
                print(f"✗ Unknown benchmark: {benchmark}")
                continue

            print(f"\nEvaluating on {benchmark}...")
            try:
                metrics = evaluate_model_on_benchmark(
                    benchmark, provider, model_name, k_values=k_values
                )
                model_results[benchmark] = metrics
                print("✓ Complete")
            except Exception as e:
                # Keep going so one failing benchmark doesn't abort the run.
                print(f"✗ Error: {e}")
                import traceback

                traceback.print_exc()

        all_results[model_spec] = model_results

    model_tag = "_".join(_sanitize(m) for m in models)
    bench_tag = "_".join(_sanitize(b) for b in benchmarks)
    output_file = output_folder / f"results_{model_tag}_{bench_tag}.json"
    print(f"\n{'='*60}\nSaving to {output_file}")
    with open(output_file, "w") as f:
        json.dump(all_results, f, indent=2)
    print("✓ Done")
|
|
|
|
|
|
@app.command()
def main(
    models: List[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model spec (format: 'provider:model' or just 'model' for Ollama). "
        "Providers: ollama, huggingface. Can specify multiple times. "
        # Help text fixed: it previously advertised a HuggingFace default
        # that did not match the actual fallback below.
        "Default: bge-m3:latest and qwen3-0.6B-emb:latest (Ollama).",
    ),
    benchmarks: List[str] = typer.Option(
        None,
        "--benchmark",
        "-b",
        help="Benchmark name (scifact, cosqa, codexglue). Default: all three",
    ),
    output_folder: Path = typer.Option(
        Path("research/embedding_eval_results"),
        "--output",
        "-o",
        help="Output folder for results.",
    ),
    k_values: str = typer.Option(
        "1,5,10,100",
        "--k-values",
        "-k",
        help="Comma-separated k values for metrics.",
    ),
) -> None:
    """
    Evaluate embedding models on CodexGlue, CoSQA, and SciFact benchmarks.

    Examples:
        # Default Ollama models on all benchmarks
        python evaluate_embeddings_pipeline.py

        # HuggingFace model (no Ollama required)
        python evaluate_embeddings_pipeline.py -m huggingface:sentence-transformers/bge-small-en-v1.5

        # Ollama model
        python evaluate_embeddings_pipeline.py -m ollama:qwen:embeddings

        # Multiple models and single benchmark
        python evaluate_embeddings_pipeline.py -m huggingface:all-MiniLM-L6-v2 -m ollama:bge-m3 -b scifact -o ./results
    """
    # Fall back to the default Ollama embedding models when none are given.
    if not models:
        models = ["bge-m3:latest", "qwen3-0.6B-emb:latest"]

    # Default to running every registered benchmark.
    if not benchmarks:
        benchmarks = ["scifact", "cosqa", "codexglue"]

    # Parse the comma-separated k cutoffs into integers.
    k_list = [int(k.strip()) for k in k_values.split(",")]

    evaluate_models(models=models, benchmarks=benchmarks, output_folder=output_folder, k_values=k_list)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
app()
|