diff --git a/research/embeddings/evaluate_embeddings_pipeline.py b/research/embeddings/evaluate_embeddings_pipeline.py
index ce1eded..b479e67 100644
--- a/research/embeddings/evaluate_embeddings_pipeline.py
+++ b/research/embeddings/evaluate_embeddings_pipeline.py
@@ -19,7 +19,6 @@ from beir.retrieval.search.dense import DenseRetrievalExactSearch
 from beir import util
 from datasets import load_dataset
 from src.config import settings
-from src.utils.emb_factory import create_embedding_model # Import embedding factory
 
 project_root = settings.proj_root
 DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
@@ -27,6 +26,21 @@ DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
 
 app = typer.Typer()
 
+def _has_local_beir_files(data_path: Path) -> bool:
+    """Return True when a dataset folder already has the required BEIR files."""
+    required_files = [
+        data_path / "corpus.jsonl",
+        data_path / "queries.jsonl",
+        data_path / "qrels" / "test.tsv",
+    ]
+    return all(path.exists() and path.stat().st_size > 0 for path in required_files)
+
+
+def _load_local_beir_dataset(data_path: Path) -> tuple[Dict, Dict, Dict]:
+    """Load a BEIR-formatted dataset from local disk."""
+    return GenericDataLoader(str(data_path)).load(split="test")
+
+
 class BEIROllamaEmbeddings:
     """
     Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.
@@ -171,17 +185,30 @@ class BEIRHuggingFaceEmbeddings:
 def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
     """Load SciFact benchmark."""
     DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
+    scifact_path = DATASETS_ROOT / "scifact"
+
+    if _has_local_beir_files(scifact_path):
+        print(" Using local SciFact dataset cache")
+        return _load_local_beir_dataset(scifact_path)
+
+    print(" SciFact dataset not found locally. Downloading...")
     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
     data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
-    scifact_path = Path(data_path)
-    if scifact_path.name != "scifact":
-        scifact_path = DATASETS_ROOT / "scifact"
-    return GenericDataLoader(str(scifact_path)).load(split="test")
+    downloaded_path = Path(data_path)
+    if downloaded_path.name == "scifact" and _has_local_beir_files(downloaded_path):
+        return _load_local_beir_dataset(downloaded_path)
+
+    return _load_local_beir_dataset(scifact_path)
 
 
 def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
     """Load CoSQA benchmark."""
     data_path = DATASETS_ROOT / "cosqa"
+    if _has_local_beir_files(data_path):
+        print(" Using local CoSQA dataset cache")
+        return _load_local_beir_dataset(data_path)
+
+    print(" CoSQA dataset not found locally. Downloading and preparing...")
     (data_path / "qrels").mkdir(parents=True, exist_ok=True)
 
     # Load from HuggingFace
@@ -208,12 +235,17 @@
         for item in hf_qrels:
             f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")
 
-    return GenericDataLoader(str(data_path)).load(split="test")
+    return _load_local_beir_dataset(data_path)
 
 
 def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
     """Load CodexGlue benchmark."""
     data_path = DATASETS_ROOT / "codexglue"
+    if _has_local_beir_files(data_path):
+        print(" Using local CodexGlue dataset cache")
+        return _load_local_beir_dataset(data_path)
+
+    print(" CodexGlue dataset not found locally. Downloading and preparing...")
     (data_path / "qrels").mkdir(parents=True, exist_ok=True)
 
     raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
@@ -243,7 +275,7 @@
         for i, _ in enumerate(raw_dataset):
             qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")
 
-    return GenericDataLoader(str(data_path)).load(split="test")
+    return _load_local_beir_dataset(data_path)
 
 
 BENCHMARK_LOADERS = {
@@ -280,10 +312,10 @@
     retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
     evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")
 
-    print(f" Running retrieval...")
+    print(" Running retrieval...")
     results = evaluator.retrieve(corpus, queries)
 
-    print(f" Computing metrics...")
+    print(" Computing metrics...")
     ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
 
     return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
@@ -324,7 +356,7 @@
                 benchmark, provider, model_name, k_values=k_values
             )
             model_results[benchmark] = metrics
-            print(f"✓ Complete")
+            print("✓ Complete")
         except Exception as e:
             print(f"✗ Error: {e}")
             import traceback
@@ -332,12 +364,11 @@
 
         all_results[model_spec] = model_results
 
-    # Save results
-    output_file = output_folder / "results.json"
+    output_file = output_folder / f"results_{'_'.join(models)}_{'_'.join(benchmarks)}.json"
     print(f"\n{'='*60}\nSaving to {output_file}")
     with open(output_file, "w") as f:
         json.dump(all_results, f, indent=2)
-    print(f"✓ Done")
+    print("✓ Done")
 
 
 @app.command()