# assistance-engine/scripts/pipelines/classifier/train_classifier.py
#!/usr/bin/env python3
"""
Phase 2 — Train Layer 2 embedding classifier (ADR-0008).
Reads labeled (query, type) pairs from a JSONL file, embeds them with bge-m3
via Ollama, trains a LogisticRegression classifier, and serializes the model
with joblib.
Usage:
python train_classifier.py
python train_classifier.py --data path/to/dataset.jsonl --output /data/classifier_model.pkl
The output file is loaded at engine startup by graph.py (_load_layer2_model).
Any JSONL file produced by classifier_export.py is compatible as additional
training data — merge with the seed dataset before retraining.
Requirements (add to requirements.txt if not present):
scikit-learn
joblib
numpy
"""
import argparse
import json
import os
import sys
from collections import Counter
from pathlib import Path
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
def load_data(path: str) -> tuple[list[str], list[str]]:
    """Load labeled (query, type) pairs from a JSONL dataset.

    Each line must be a JSON object with non-empty string fields "query" and
    "type". Blank lines are ignored; malformed JSON and records with
    missing, empty, or non-string fields are skipped with a warning rather
    than aborting the whole load.

    Args:
        path: Path to the JSONL file.

    Returns:
        (queries, labels) — parallel lists of equal length.
    """
    queries: list[str] = []
    labels: list[str] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue  # blank lines are tolerated silently
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"  [WARN] line {i} skipped — JSON error: {e}", file=sys.stderr)
                skipped += 1
                continue
            q = rec.get("query", "")
            t = rec.get("type", "")
            # BUG FIX: guard against non-string values (e.g. a JSON number),
            # which previously crashed the loader with AttributeError on
            # .strip() instead of being skipped like other bad records.
            if not isinstance(q, str) or not isinstance(t, str):
                skipped += 1
                continue
            q, t = q.strip(), t.strip()
            if not q or not t:
                skipped += 1
                continue
            queries.append(q)
            labels.append(t)
    if skipped:
        print(f"  [WARN] {skipped} records skipped (missing query/type or invalid JSON)")
    return queries, labels
def embed_queries(queries: list[str], labels: list[str], base_url: str) -> tuple[np.ndarray, list[str]]:
    """Embed queries with bge-m3 via Ollama, one query per request.

    bge-m3 occasionally produces NaN vectors for certain inputs (known
    Ollama bug). Embedding individually lets us detect and drop those
    queries instead of failing the whole batch.

    Args:
        queries: Texts to embed.
        labels: Labels parallel to ``queries``.
        base_url: Base URL of the Ollama server.

    Returns:
        (vectors, kept_labels) — labels aligned to the returned vectors;
        both exclude any queries that failed or produced NaN embeddings.
    """
    try:
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        print("[ERROR] langchain-ollama not installed. Run: pip install langchain-ollama", file=sys.stderr)
        sys.exit(1)

    embedder = OllamaEmbeddings(model="bge-m3", base_url=base_url)
    kept_vectors: list[list[float]] = []
    kept_labels: list[str] = []
    dropped = 0
    total = len(queries)

    for idx, (query, label) in enumerate(zip(queries, labels), start=1):
        try:
            vec = embedder.embed_query(query)
        except Exception as e:
            print(f" [WARN] query {idx} embedding failed ({e}), skipped: '{query[:60]}'", file=sys.stderr)
            dropped += 1
            continue
        if any(v != v for v in vec):  # NaN check (NaN never equals itself)
            print(f" [WARN] query {idx} produced NaN vector, skipped: '{query[:60]}'", file=sys.stderr)
            dropped += 1
            continue
        kept_vectors.append(vec)
        kept_labels.append(label)
        # Progress line only advances on successful embeds (as in original).
        if idx % 10 == 0 or idx == total:
            print(f" Embedded {idx}/{total}...", end="\r")

    print()
    if dropped:
        print(f" [WARN] {dropped} queries skipped due to NaN or embedding errors")
    return np.array(kept_vectors), kept_labels
def main():
    """Train, validate, and serialize the Layer 2 query-type classifier.

    Pipeline: load labeled JSONL → embed with bge-m3 via Ollama →
    cross-validate a LogisticRegression → refit on the full dataset →
    joblib-dump a model bundle readable by graph.py (_load_layer2_model).
    Exits non-zero on insufficient data or CV accuracy below the
    --min-cv-accuracy threshold.
    """
    parser = argparse.ArgumentParser(description="Train Layer 2 classifier for ADR-0008")
    parser.add_argument(
        "--data",
        default=str(Path(__file__).parent / "seed_classifier_dataset.jsonl"),
        help="Path to labeled JSONL dataset",
    )
    parser.add_argument(
        "--output",
        default=os.getenv("CLASSIFIER_MODEL_PATH", "/data/classifier_model.pkl"),
        help="Output path for serialized model",
    )
    parser.add_argument(
        "--ollama",
        default=os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434"),
        help="Ollama base URL",
    )
    parser.add_argument(
        "--min-cv-accuracy",
        type=float,
        default=0.90,
        help="Minimum cross-validation accuracy to proceed with saving (default: 0.90)",
    )
    args = parser.parse_args()

    # ── Load data ──────────────────────────────────────────────────────────
    print(f"\n[1/4] Loading data from {args.data}")
    queries, labels = load_data(args.data)
    print(f" {len(queries)} examples loaded")
    dist = Counter(labels)
    print(f" Distribution: {dict(dist)}")
    if len(queries) < 20:
        print("[ERROR] Too few examples (< 20). Add more to the dataset.", file=sys.stderr)
        sys.exit(1)
    # Cap the fold count at the rarest class size so StratifiedKFold is valid.
    min_class_count = min(dist.values())
    n_splits = min(5, min_class_count)
    if n_splits < 2:
        print("[ERROR] At least one class has fewer than 2 examples. Add more data.", file=sys.stderr)
        sys.exit(1)

    # ── Embed ──────────────────────────────────────────────────────────────
    print(f"\n[2/4] Embedding with bge-m3 via {args.ollama}")
    vectors, labels = embed_queries(queries, labels, args.ollama)
    # ROBUSTNESS FIX: embed_queries may drop examples (NaN vectors / errors);
    # re-check that enough remain before handing the matrix to scikit-learn,
    # otherwise CV would fail with an opaque error (or on an empty matrix).
    if len(labels) < n_splits:
        print("[ERROR] Too few usable examples remain after embedding.", file=sys.stderr)
        sys.exit(1)
    print(f" Embedding matrix: {vectors.shape} ({len(labels)} examples kept)")

    # ── Train + cross-validate ─────────────────────────────────────────────
    print(f"\n[3/4] Training LogisticRegression (C=1.0) with {n_splits}-fold CV")
    clf = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = cross_val_score(clf, vectors, labels, cv=cv, scoring="accuracy")
    cv_mean = scores.mean()
    cv_std = scores.std()
    print(f" CV accuracy: {cv_mean:.3f} ± {cv_std:.3f} (folds: {scores.round(3).tolist()})")

    # Per-class report via cross-validated predictions
    y_pred = cross_val_predict(clf, vectors, labels, cv=cv)
    print("\n Per-class report:")
    report = classification_report(labels, y_pred, zero_division=0)
    for line in report.splitlines():
        print(f" {line}")

    if cv_mean < args.min_cv_accuracy:
        print(
            f"\n[FAIL] CV accuracy {cv_mean:.3f} is below threshold {args.min_cv_accuracy:.2f}. "
            "Add more examples to the dataset before deploying.",
            file=sys.stderr,
        )
        sys.exit(1)
    # BUG FIX: the original f-string ran the two numbers together with no
    # comparator ("0.9520.90"); insert the intended ">=".
    print(f"\n CV accuracy {cv_mean:.3f} >= {args.min_cv_accuracy:.2f} — proceeding to save.")

    # ── Fit final model on full dataset ────────────────────────────────────
    clf.fit(vectors, labels)

    # ── Save ───────────────────────────────────────────────────────────────
    print(f"\n[4/4] Saving model to {args.output}")
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(
        {
            "clf": clf,
            "classes": clf.classes_.tolist(),
            # BUG FIX: record the number of examples actually trained on
            # (post-embedding), not the raw pre-embedding count — the two
            # differ whenever embed_queries dropped NaN/failed queries.
            "n_train": len(labels),
            "cv_mean": round(cv_mean, 4),
            "cv_std": round(cv_std, 4),
        },
        out_path,
    )
    print(f" Model saved → {out_path}")
    print(f"\nDone. Classes: {clf.classes_.tolist()}")


if __name__ == "__main__":
    main()