updated pipeline to only download files when missing
This commit is contained in:
parent
2a33f8eb06
commit
9b1a0e54d5
|
|
@ -19,7 +19,6 @@ from beir.retrieval.search.dense import DenseRetrievalExactSearch
|
||||||
from beir import util
|
from beir import util
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from src.config import settings
|
from src.config import settings
|
||||||
from src.utils.emb_factory import create_embedding_model
|
|
||||||
# Import embedding factory
|
# Import embedding factory
|
||||||
project_root = settings.proj_root
|
project_root = settings.proj_root
|
||||||
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
|
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
|
||||||
|
|
@ -27,6 +26,21 @@ DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
|
|
||||||
|
def _has_local_beir_files(data_path: Path) -> bool:
    """Return True when a dataset folder already has the required BEIR files.

    The files checked are ``corpus.jsonl``, ``queries.jsonl`` and
    ``qrels/test.tsv`` below ``data_path``. Each one must be a regular,
    non-empty file for the folder to count as a usable local copy.

    Args:
        data_path: Directory expected to hold the BEIR-formatted dataset.

    Returns:
        True if every required file is present and non-empty, otherwise
        False. Paths that cannot be inspected (missing, removed mid-check,
        unreadable) are treated as absent instead of raising.
    """
    required_files = (
        data_path / "corpus.jsonl",
        data_path / "queries.jsonl",
        data_path / "qrels" / "test.tsv",
    )
    for path in required_files:
        try:
            # is_file() rejects a directory that happens to sit at the
            # expected file path; exists() alone would accept it.
            if not path.is_file() or path.stat().st_size == 0:
                return False
        except OSError:
            # The path vanished or is unreadable between the two checks:
            # report "not cached" so the caller can re-download.
            return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
def _load_local_beir_dataset(data_path: Path) -> tuple[Dict, Dict, Dict]:
    """Read the ``test`` split of a BEIR-formatted dataset stored on disk.

    Args:
        data_path: Folder containing the BEIR-formatted dataset files.

    Returns:
        The result of ``GenericDataLoader.load(split="test")`` for that folder.
    """
    # GenericDataLoader is given the folder as a plain string path.
    loader = GenericDataLoader(str(data_path))
    return loader.load(split="test")
|
||||||
|
|
||||||
|
|
||||||
class BEIROllamaEmbeddings:
|
class BEIROllamaEmbeddings:
|
||||||
"""
|
"""
|
||||||
Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.
|
Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.
|
||||||
|
|
@ -171,17 +185,30 @@ class BEIRHuggingFaceEmbeddings:
|
||||||
def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
|
def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
|
||||||
"""Load SciFact benchmark."""
|
"""Load SciFact benchmark."""
|
||||||
DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
|
DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
|
||||||
|
scifact_path = DATASETS_ROOT / "scifact"
|
||||||
|
|
||||||
|
if _has_local_beir_files(scifact_path):
|
||||||
|
print(" Using local SciFact dataset cache")
|
||||||
|
return _load_local_beir_dataset(scifact_path)
|
||||||
|
|
||||||
|
print(" SciFact dataset not found locally. Downloading...")
|
||||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
|
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
|
||||||
data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
|
data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
|
||||||
scifact_path = Path(data_path)
|
downloaded_path = Path(data_path)
|
||||||
if scifact_path.name != "scifact":
|
if downloaded_path.name == "scifact" and _has_local_beir_files(downloaded_path):
|
||||||
scifact_path = DATASETS_ROOT / "scifact"
|
return _load_local_beir_dataset(downloaded_path)
|
||||||
return GenericDataLoader(str(scifact_path)).load(split="test")
|
|
||||||
|
return _load_local_beir_dataset(scifact_path)
|
||||||
|
|
||||||
|
|
||||||
def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
|
def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
|
||||||
"""Load CoSQA benchmark."""
|
"""Load CoSQA benchmark."""
|
||||||
data_path = DATASETS_ROOT / "cosqa"
|
data_path = DATASETS_ROOT / "cosqa"
|
||||||
|
if _has_local_beir_files(data_path):
|
||||||
|
print(" Using local CoSQA dataset cache")
|
||||||
|
return _load_local_beir_dataset(data_path)
|
||||||
|
|
||||||
|
print(" CoSQA dataset not found locally. Downloading and preparing...")
|
||||||
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Load from HuggingFace
|
# Load from HuggingFace
|
||||||
|
|
@ -208,12 +235,17 @@ def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
|
||||||
for item in hf_qrels:
|
for item in hf_qrels:
|
||||||
f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")
|
f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")
|
||||||
|
|
||||||
return GenericDataLoader(str(data_path)).load(split="test")
|
return _load_local_beir_dataset(data_path)
|
||||||
|
|
||||||
|
|
||||||
def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
|
def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
|
||||||
"""Load CodexGlue benchmark."""
|
"""Load CodexGlue benchmark."""
|
||||||
data_path = DATASETS_ROOT / "codexglue"
|
data_path = DATASETS_ROOT / "codexglue"
|
||||||
|
if _has_local_beir_files(data_path):
|
||||||
|
print(" Using local CodexGlue dataset cache")
|
||||||
|
return _load_local_beir_dataset(data_path)
|
||||||
|
|
||||||
|
print(" CodexGlue dataset not found locally. Downloading and preparing...")
|
||||||
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
|
raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
|
||||||
|
|
@ -243,7 +275,7 @@ def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
|
||||||
for i, _ in enumerate(raw_dataset):
|
for i, _ in enumerate(raw_dataset):
|
||||||
qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")
|
qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")
|
||||||
|
|
||||||
return GenericDataLoader(str(data_path)).load(split="test")
|
return _load_local_beir_dataset(data_path)
|
||||||
|
|
||||||
|
|
||||||
BENCHMARK_LOADERS = {
|
BENCHMARK_LOADERS = {
|
||||||
|
|
@ -280,10 +312,10 @@ def evaluate_model_on_benchmark(
|
||||||
retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
|
retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
|
||||||
evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")
|
evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")
|
||||||
|
|
||||||
print(f" Running retrieval...")
|
print(" Running retrieval...")
|
||||||
results = evaluator.retrieve(corpus, queries)
|
results = evaluator.retrieve(corpus, queries)
|
||||||
|
|
||||||
print(f" Computing metrics...")
|
print(" Computing metrics...")
|
||||||
ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
|
ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
|
||||||
|
|
||||||
return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
|
return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
|
||||||
|
|
@ -324,7 +356,7 @@ def evaluate_models(
|
||||||
benchmark, provider, model_name, k_values=k_values
|
benchmark, provider, model_name, k_values=k_values
|
||||||
)
|
)
|
||||||
model_results[benchmark] = metrics
|
model_results[benchmark] = metrics
|
||||||
print(f"✓ Complete")
|
print("✓ Complete")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error: {e}")
|
print(f"✗ Error: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
@ -332,12 +364,11 @@ def evaluate_models(
|
||||||
|
|
||||||
all_results[model_spec] = model_results
|
all_results[model_spec] = model_results
|
||||||
|
|
||||||
# Save results
|
output_file = output_folder / f"results_{'_'.join(models)}_{'_'.join(benchmarks)}.json"
|
||||||
output_file = output_folder / "results.json"
|
|
||||||
print(f"\n{'='*60}\nSaving to {output_file}")
|
print(f"\n{'='*60}\nSaving to {output_file}")
|
||||||
with open(output_file, "w") as f:
|
with open(output_file, "w") as f:
|
||||||
json.dump(all_results, f, indent=2)
|
json.dump(all_results, f, indent=2)
|
||||||
print(f"✓ Done")
|
print("✓ Done")
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue