updated pipeline to only download files when missing

This commit is contained in:
pseco 2026-03-25 10:06:07 +01:00
parent 2a33f8eb06
commit 9b1a0e54d5
1 changed file with 44 additions and 13 deletions

View File

@ -19,7 +19,6 @@ from beir.retrieval.search.dense import DenseRetrievalExactSearch
from beir import util
from datasets import load_dataset
from src.config import settings
from src.utils.emb_factory import create_embedding_model
# Import embedding factory
project_root = settings.proj_root
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
@ -27,6 +26,21 @@ DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
app = typer.Typer()
def _has_local_beir_files(data_path: Path) -> bool:
"""Return True when a dataset folder already has the required BEIR files."""
required_files = [
data_path / "corpus.jsonl",
data_path / "queries.jsonl",
data_path / "qrels" / "test.tsv",
]
return all(path.exists() and path.stat().st_size > 0 for path in required_files)
def _load_local_beir_dataset(data_path: Path) -> tuple[Dict, Dict, Dict]:
    """Read a BEIR-formatted dataset (corpus, queries, qrels) from local disk."""
    loader = GenericDataLoader(str(data_path))
    return loader.load(split="test")
class BEIROllamaEmbeddings:
"""
Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.
@ -171,17 +185,30 @@ class BEIRHuggingFaceEmbeddings:
def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
    """Load the SciFact benchmark, downloading it only when missing locally.

    Returns:
        (corpus, queries, qrels) for the ``test`` split.
    """
    DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
    scifact_path = DATASETS_ROOT / "scifact"
    # Fast path: reuse the on-disk cache when all required files exist.
    if _has_local_beir_files(scifact_path):
        print(" Using local SciFact dataset cache")
        return _load_local_beir_dataset(scifact_path)
    print(" SciFact dataset not found locally. Downloading...")
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
    data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
    downloaded_path = Path(data_path)
    # download_and_unzip may report a different extraction path; prefer it
    # only when it is the expected "scifact" folder and actually complete.
    if downloaded_path.name == "scifact" and _has_local_beir_files(downloaded_path):
        return _load_local_beir_dataset(downloaded_path)
    return _load_local_beir_dataset(scifact_path)
def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
"""Load CoSQA benchmark."""
data_path = DATASETS_ROOT / "cosqa"
if _has_local_beir_files(data_path):
print(" Using local CoSQA dataset cache")
return _load_local_beir_dataset(data_path)
print(" CoSQA dataset not found locally. Downloading and preparing...")
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
# Load from HuggingFace
@ -208,12 +235,17 @@ def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
for item in hf_qrels:
f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")
return GenericDataLoader(str(data_path)).load(split="test")
return _load_local_beir_dataset(data_path)
def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
"""Load CodexGlue benchmark."""
data_path = DATASETS_ROOT / "codexglue"
if _has_local_beir_files(data_path):
print(" Using local CodexGlue dataset cache")
return _load_local_beir_dataset(data_path)
print(" CodexGlue dataset not found locally. Downloading and preparing...")
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
@ -243,7 +275,7 @@ def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
for i, _ in enumerate(raw_dataset):
qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")
return GenericDataLoader(str(data_path)).load(split="test")
return _load_local_beir_dataset(data_path)
BENCHMARK_LOADERS = {
@ -280,10 +312,10 @@ def evaluate_model_on_benchmark(
retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")
print(f" Running retrieval...")
print(" Running retrieval...")
results = evaluator.retrieve(corpus, queries)
print(f" Computing metrics...")
print(" Computing metrics...")
ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
@ -324,7 +356,7 @@ def evaluate_models(
benchmark, provider, model_name, k_values=k_values
)
model_results[benchmark] = metrics
print(f"✓ Complete")
print("✓ Complete")
except Exception as e:
print(f"✗ Error: {e}")
import traceback
@ -332,12 +364,11 @@ def evaluate_models(
all_results[model_spec] = model_results
# Save results
output_file = output_folder / "results.json"
output_file = output_folder / f"results_{'_'.join(models)}_{'_'.join(benchmarks)}.json"
print(f"\n{'='*60}\nSaving to {output_file}")
with open(output_file, "w") as f:
json.dump(all_results, f, indent=2)
print(f"✓ Done")
print("✓ Done")
@app.command()