Merge branch 'mrh-online-dev' of github.com:BRUNIX-AI/assistance-engine into mrh-online-dev
This commit is contained in:
commit
b2e5d06d96
|
|
@ -19,7 +19,6 @@ from beir.retrieval.search.dense import DenseRetrievalExactSearch
|
|||
from beir import util
|
||||
from datasets import load_dataset
|
||||
from src.config import settings
|
||||
from src.utils.emb_factory import create_embedding_model
|
||||
# Import embedding factory
|
||||
project_root = settings.proj_root
|
||||
DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
|
||||
|
|
@ -27,6 +26,21 @@ DATASETS_ROOT = project_root / "research" / "embeddings" / "datasets"
|
|||
app = typer.Typer()
|
||||
|
||||
|
||||
def _has_local_beir_files(data_path: Path) -> bool:
|
||||
"""Return True when a dataset folder already has the required BEIR files."""
|
||||
required_files = [
|
||||
data_path / "corpus.jsonl",
|
||||
data_path / "queries.jsonl",
|
||||
data_path / "qrels" / "test.tsv",
|
||||
]
|
||||
return all(path.exists() and path.stat().st_size > 0 for path in required_files)
|
||||
|
||||
|
||||
def _load_local_beir_dataset(data_path: Path) -> tuple[Dict, Dict, Dict]:
|
||||
"""Load a BEIR-formatted dataset from local disk."""
|
||||
return GenericDataLoader(str(data_path)).load(split="test")
|
||||
|
||||
|
||||
class BEIROllamaEmbeddings:
|
||||
"""
|
||||
Adapter that makes LangChain's OllamaEmbeddings compatible with BEIR.
|
||||
|
|
@ -171,17 +185,30 @@ class BEIRHuggingFaceEmbeddings:
|
|||
def load_scifact_dataset() -> tuple[Dict, Dict, Dict]:
|
||||
"""Load SciFact benchmark."""
|
||||
DATASETS_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
scifact_path = DATASETS_ROOT / "scifact"
|
||||
|
||||
if _has_local_beir_files(scifact_path):
|
||||
print(" Using local SciFact dataset cache")
|
||||
return _load_local_beir_dataset(scifact_path)
|
||||
|
||||
print(" SciFact dataset not found locally. Downloading...")
|
||||
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip"
|
||||
data_path = util.download_and_unzip(url, out_dir=str(DATASETS_ROOT))
|
||||
scifact_path = Path(data_path)
|
||||
if scifact_path.name != "scifact":
|
||||
scifact_path = DATASETS_ROOT / "scifact"
|
||||
return GenericDataLoader(str(scifact_path)).load(split="test")
|
||||
downloaded_path = Path(data_path)
|
||||
if downloaded_path.name == "scifact" and _has_local_beir_files(downloaded_path):
|
||||
return _load_local_beir_dataset(downloaded_path)
|
||||
|
||||
return _load_local_beir_dataset(scifact_path)
|
||||
|
||||
|
||||
def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
|
||||
"""Load CoSQA benchmark."""
|
||||
data_path = DATASETS_ROOT / "cosqa"
|
||||
if _has_local_beir_files(data_path):
|
||||
print(" Using local CoSQA dataset cache")
|
||||
return _load_local_beir_dataset(data_path)
|
||||
|
||||
print(" CoSQA dataset not found locally. Downloading and preparing...")
|
||||
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load from HuggingFace
|
||||
|
|
@ -208,12 +235,17 @@ def load_cosqa_dataset() -> tuple[Dict, Dict, Dict]:
|
|||
for item in hf_qrels:
|
||||
f.write(f"{item['query-id']}\t{item['corpus-id']}\t{item['score']}\n")
|
||||
|
||||
return GenericDataLoader(str(data_path)).load(split="test")
|
||||
return _load_local_beir_dataset(data_path)
|
||||
|
||||
|
||||
def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
|
||||
"""Load CodexGlue benchmark."""
|
||||
data_path = DATASETS_ROOT / "codexglue"
|
||||
if _has_local_beir_files(data_path):
|
||||
print(" Using local CodexGlue dataset cache")
|
||||
return _load_local_beir_dataset(data_path)
|
||||
|
||||
print(" CodexGlue dataset not found locally. Downloading and preparing...")
|
||||
(data_path / "qrels").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
raw_dataset = load_dataset("google/code_x_glue_tc_nl_code_search_adv", split="test")
|
||||
|
|
@ -243,7 +275,7 @@ def load_codexglue_dataset() -> tuple[Dict, Dict, Dict]:
|
|||
for i, _ in enumerate(raw_dataset):
|
||||
qrels_file.write(f"q_{i}\tdoc_{i}\t1\n")
|
||||
|
||||
return GenericDataLoader(str(data_path)).load(split="test")
|
||||
return _load_local_beir_dataset(data_path)
|
||||
|
||||
|
||||
BENCHMARK_LOADERS = {
|
||||
|
|
@ -280,10 +312,10 @@ def evaluate_model_on_benchmark(
|
|||
retriever = DenseRetrievalExactSearch(adapter, batch_size=64)
|
||||
evaluator = EvaluateRetrieval(retriever, score_function="cos_sim")
|
||||
|
||||
print(f" Running retrieval...")
|
||||
print(" Running retrieval...")
|
||||
results = evaluator.retrieve(corpus, queries)
|
||||
|
||||
print(f" Computing metrics...")
|
||||
print(" Computing metrics...")
|
||||
ndcg, _map, recall, precision = evaluator.evaluate(qrels, results, k_values)
|
||||
|
||||
return {"NDCG": ndcg, "MAP": _map, "Recall": recall, "Precision": precision}
|
||||
|
|
@ -324,7 +356,7 @@ def evaluate_models(
|
|||
benchmark, provider, model_name, k_values=k_values
|
||||
)
|
||||
model_results[benchmark] = metrics
|
||||
print(f"✓ Complete")
|
||||
print("✓ Complete")
|
||||
except Exception as e:
|
||||
print(f"✗ Error: {e}")
|
||||
import traceback
|
||||
|
|
@ -332,12 +364,11 @@ def evaluate_models(
|
|||
|
||||
all_results[model_spec] = model_results
|
||||
|
||||
# Save results
|
||||
output_file = output_folder / "results.json"
|
||||
output_file = output_folder / f"results_{'_'.join(models)}_{'_'.join(benchmarks)}.json"
|
||||
print(f"\n{'='*60}\nSaving to {output_file}")
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
print(f"✓ Done")
|
||||
print("✓ Done")
|
||||
|
||||
|
||||
@app.command()
|
||||
|
|
|
|||
Loading…
Reference in New Issue