# Reconstructed from patch b5745173402eaddd43d22c98baa87c95687c275c
# ("working on ADR0005", pseco, Mon 23 Mar 2026). The patch creates this
# file and additionally edits src/config.py — see the note at the bottom.
"""
Embedding Evaluation Pipeline

Evaluate embedding models across CodexGlue, CoSQA, and SciFact benchmarks.
Supports multiple embedding providers via factory methods.
"""

import json
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import typer
from langchain_ollama import OllamaEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch
from beir import util
from datasets import load_dataset
from src.config import settings
# Import embedding factory (currently unused here; kept for upcoming providers
# per ADR0005 — TODO confirm before removing).
from src.utils.emb_factory import create_embedding_model  # noqa: F401

project_root = settings.proj_root
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"

app = typer.Typer()


def _clean_query(query: Any) -> str:
    """Normalize one BEIR query; placeholders avoid NaN embeddings on empty input."""
    if not isinstance(query, str):
        return "[INVALID]"
    return query.strip() or "[EMPTY]"


def _doc_text(doc: Dict[str, str]) -> str:
    """Join a BEIR document's title and text, falling back to a placeholder."""
    title = (doc.get("title") or "").strip()
    text = (doc.get("text") or "").strip()
    # Combine title and text, filtering out empty strings; use a placeholder
    # if both are empty to avoid NaN embeddings.
    return " ".join(filter(None, [title, text])) or "[EMPTY]"


class _BEIRLangChainAdapter:
    """
    Shared BEIR-compatible encoding logic for LangChain embedding backends.

    Subclasses must set ``self.embeddings`` (an object exposing
    ``embed_documents(texts) -> list[list[float]]``) and ``self.batch_size``
    in their ``__init__`` before any encode call.
    """

    embeddings: Any
    batch_size: int

    def _batch_embed(self, texts: List[str]) -> np.ndarray:
        """Embed texts in batches, sanitizing NaN/inf components to 0.0."""
        vectors: List[Any] = []
        for start in range(0, len(texts), self.batch_size):
            batch = texts[start : start + self.batch_size]
            for vec in self.embeddings.embed_documents(batch):
                if isinstance(vec, (list, np.ndarray)):
                    arr = np.asarray(vec, dtype=np.float32)
                    # NaN/inf values poison cosine-similarity scoring
                    # downstream, so zero them instead of failing the run.
                    vectors.append(
                        np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0)
                    )
                else:
                    vectors.append(vec)
        return np.asarray(vectors, dtype=np.float32)

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        """BEIR query encoder: clean each query, then batch-embed."""
        return self._batch_embed([_clean_query(q) for q in queries])

    def encode_corpus(
        self,
        corpus: Union[List[Dict[str, str]], Dict[str, Dict[str, str]]],
        **kwargs,
    ) -> np.ndarray:
        """BEIR corpus encoder: accepts list- or dict-shaped corpora."""
        if isinstance(corpus, dict):
            corpus = list(corpus.values())
        return self._batch_embed([_doc_text(doc) for doc in corpus])


class BEIROllamaEmbeddings(_BEIRLangChainAdapter):
    """Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR."""

    def __init__(self, base_url: str, model: str, batch_size: int = 64) -> None:
        self.batch_size = batch_size
        self.embeddings = OllamaEmbeddings(base_url=base_url, model=model)


class BEIRHuggingFaceEmbeddings(_BEIRLangChainAdapter):
    """Adapter that makes LangChain's HuggingFaceEmbeddings compatible with BEIR."""

    def __init__(self, model: str, batch_size: int = 64) -> None:
        self.batch_size = batch_size
        self.embeddings = HuggingFaceEmbeddings(model_name=model)


def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
    """Load SciFact benchmark (downloaded from the BEIR mirror on first use)."""
    DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
    data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
    scifact_path = Path(data_path)
    # download_and_unzip may return the parent dir; normalize to .../scifact.
    if scifact_path.name != "scifact":
        scifact_path = DATASETS_ROOT / "scifact"
    return GenericDataLoader(str(scifact_path)).load(split="test")


def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
    """Load CoSQA benchmark: fetch from HuggingFace and re-save in BEIR format."""
    data_path = DATASETS_ROOT / "cosqa"
    (data_path / "qrels").mkdir(parents=True, exist_ok=True)

    # Load from HuggingFace
    hf_corpus = load_dataset("CoIR-Retrieval/cosqa", "corpus", split="corpus")
    hf_queries = load_dataset("CoIR-Retrieval/cosqa", "queries", split="queries")
    hf_qrels = load_dataset("CoIR-Retrieval/cosqa", "default", split="test")

    # Save in BEIR format (corpus.jsonl / queries.jsonl / qrels/test.tsv)
    with open(data_path / "corpus.jsonl", "w") as f:
        for item in hf_corpus:
            f.write(
                json.dumps(
                    {"_id": str(item["_id"]), "text": item["text"], "title": ""}
                )
                + "\n"
            )

    with open(data_path / "queries.jsonl", "w") as f:
        for item in hf_queries:
            f.write(json.dumps({"_id": str(item["_id"]), "text": item["text"]}) + "\n")

    with open(data_path / "qrels" / "test.tsv", "w") as f:
        f.write("query-id\tcorpus-id\tscore\n")
        for item in hf_qrels:
            f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")

    return GenericDataLoader(str(data_path)).load(split="test")


def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
    """Load CodexGlue benchmark: each docstring queries its own function (1:1 qrels)."""
    data_path = DATASETS_ROOT / "codexglue"
    (data_path / "qrels").mkdir(parents=True, exist_ok=True)

    raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
    with open(data_path / "corpus.jsonl", "w") as corpus_file:
        for i, data in enumerate(raw_dataset):
            docid = f"doc_{i}"
            corpus_file.write(
                json.dumps(
                    {
                        "_id": docid,
                        "title": data.get("func_name", ""),
                        "text": data["code"],
                    }
                )
                + "\n"
            )

    with open(data_path / "queries.jsonl", "w") as query_file:
        for i, data in enumerate(raw_dataset):
            queryid = f"q_{i}"
            query_file.write(
                json.dumps({"_id": queryid, "text": data["docstring"]}) + "\n"
            )

    with open(data_path / "qrels" / "test.tsv", "w") as qrels_file:
        qrels_file.write("query-id\tcorpus-id\tscore\n")
        # One relevant doc per query: q_i <-> doc_i.
        for i, _ in enumerate(raw_dataset):
            qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")

    return GenericDataLoader(str(data_path)).load(split="test")


BENCHMARK_LOADERS = {
    "scifact": load_scifact_dataset,
    "cosqa": load_cosqa_dataset,
    "codexglue": load_codexglue_dataset,
}


def evaluate_model_on_benchmark(
    benchmark: str,
    provider: str,
    model: str,
    k_values: Optional[List[int]] = None,
) -> Dict[str, Any]:
    """
    Evaluate one model on one benchmark.

    Args:
        benchmark: key into BENCHMARK_LOADERS ("scifact", "cosqa", "codexglue").
        provider: "ollama" or "huggingface".
        model: provider-specific model name.
        k_values: cutoffs for the metrics; defaults to [1, 5, 10, 100].

    Returns:
        Dict with "NDCG", "MAP", "Recall", "Precision" metric dicts.

    Raises:
        ValueError: for an unknown provider.
    """
    if k_values is None:
        k_values = [1, 5, 10, 100]

    print(f" Loading {benchmark.upper()} dataset...")
    corpus, queries, qrels = BENCHMARK_LOADERS[benchmark]()

    print(f" Corpus: {len(corpus)}, Queries: {len(queries)}")

    # Select adapter based on provider
    if provider == "ollama":
        adapter = BEIROllamaEmbeddings(
            base_url=settings.ollama_local_url, model=model, batch_size=64
        )
    elif provider == "huggingface":
        adapter = BEIRHuggingFaceEmbeddings(model=model, batch_size=64)
    else:
        raise ValueError(f"Unknown provider: {provider}")

    retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
    evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")

    print(" Running retrieval...")
    results = evaluator.retrieve(corpus, queries)

    print(" Computing metrics...")
    ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)

    return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}


def parse_model_spec(model_spec: str) -> tuple[str, str]:
    """
    Parse model spec. Format: "provider:model_name" (default provider: ollama).
    Examples: "ollama:qwen3", "openai:text-embedding-3-small", "bge-me3:latest"

    A leading token is only treated as a provider if it is a known provider
    name; otherwise the whole spec (colons included) is an Ollama model tag.
    """
    if ":" in model_spec:
        head, tail = model_spec.split(":", 1)
        if head.lower() in ["ollama", "openai", "huggingface", "bedrock"]:
            return head.lower(), tail
    return "ollama", model_spec


def evaluate_models(
    models: List[str], benchmarks: List[str], output_folder: Path, k_values: List[int]
) -> None:
    """Evaluate multiple models on multiple benchmarks and save results.json."""
    output_folder.mkdir(parents=True, exist_ok=True)
    all_results: Dict[str, Dict[str, Any]] = {}

    for model_spec in models:
        provider, model_name = parse_model_spec(model_spec)
        print(f"\n{'='*60}\nModel: {model_name} ({provider})\n{'='*60}")

        model_results: Dict[str, Any] = {}
        for benchmark in benchmarks:
            if benchmark not in BENCHMARK_LOADERS:
                print(f"✗ Unknown benchmark: {benchmark}")
                continue

            print(f"\nEvaluating on {benchmark}...")
            try:
                metrics = evaluate_model_on_benchmark(
                    benchmark, provider, model_name, k_values=k_values
                )
                model_results[benchmark] = metrics
                print("✓ Complete")
            except Exception as e:
                # Best-effort sweep: log the failure and keep evaluating the
                # remaining benchmarks/models.
                print(f"✗ Error: {e}")
                traceback.print_exc()

        all_results[model_spec] = model_results

    # Save results
    output_file = output_folder / "results.json"
    print(f"\n{'='*60}\nSaving to {output_file}")
    with open(output_file, "w") as f:
        json.dump(all_results, f, indent=2)
    print("✓ Done")


@app.command()
def main(
    models: List[str] = typer.Option(
        None,
        "--model",
        "-m",
        help="Model spec (format: 'provider:model' or just 'model' for Ollama). "
        "Providers: ollama, huggingface. Can specify multiple times. "
        "Default: huggingface:sentence-transformers/all-MiniLM-L6-v2",
    ),
    benchmarks: List[str] = typer.Option(
        None,
        "--benchmark",
        "-b",
        help="Benchmark name (scifact, cosqa, codexglue). Default: all three",
    ),
    output_folder: Path = typer.Option(
        Path("research/embedding_eval_results"),
        "--output",
        "-o",
        help="Output folder for results.",
    ),
    k_values: str = typer.Option(
        "1,5,10,100",
        "--k-values",
        "-k",
        help="Comma-separated k values for metrics.",
    ),
) -> None:
    """
    Evaluate embedding models on CodexGlue, CoSQA, and SciFact benchmarks.

    Examples:
        # HuggingFace model (no Ollama required)
        python evaluate_embeddings_pipeline.py

        # Different HuggingFace model
        python evaluate_embeddings_pipeline.py -m huggingface:sentence-transformers/bge-small-en-v1.5

        # Ollama model
        python evaluate_embeddings_pipeline.py -m ollama:qwen:embeddings

        # Multiple models and single benchmark
        python evaluate_embeddings_pipeline.py -m huggingface:all-MiniLM-L6-v2 -m ollama:bge-m3 -b scifact -o ./results
    """
    if not models:
        models = ["huggingface:sentence-transformers/all-MiniLM-L6-v2"]

    if not benchmarks:
        benchmarks = ["scifact", "cosqa", "codexglue"]

    k_list = [int(k.strip()) for k in k_values.split(",")]

    evaluate_models(
        models=models,
        benchmarks=benchmarks,
        output_folder=output_folder,
        k_values=k_list,
    )


if __name__ == "__main__":
    app()

# ---------------------------------------------------------------------------
# Companion hunk in the same patch (src/config.py): immediately after the
# pydantic_settings import it adds
#     from dotenv import load_dotenv
#     load_dotenv()
# so .env values are loaded at import time, before the Settings class is
# defined.
# ---------------------------------------------------------------------------