assistance-engine/scripts/pipelines/classifier/retrain_pipeline.py

#!/usr/bin/env python3
"""
retrain_pipeline.py — Automatic Champion/Challenger retraining (ADR-0010).
Triggered automatically by classifier_export.py when RETRAIN_THRESHOLD new
sessions accumulate. Can also be run manually.
Flow:
1. Merge all JSONL exports in EXPORT_DIR with the seed dataset
2. Split into train (80%) and held-out (20%) — stratified
3. Train challenger model
4. Load champion model (current production model at CLASSIFIER_MODEL_PATH)
5. Evaluate both on held-out set
6. If challenger accuracy >= champion accuracy: deploy challenger → champion
7. If challenger < champion: discard challenger, keep champion, log alert
8. Archive processed export files to EXPORT_ARCHIVE_DIR
Usage:
python retrain_pipeline.py # auto mode (uses env vars)
python retrain_pipeline.py --force # skip champion comparison, always deploy
python retrain_pipeline.py --dry-run # evaluate only, do not deploy
"""
import argparse
import json
import logging
import os
import shutil
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score, StratifiedKFold
logging.basicConfig(level=logging.INFO, format="%(asctime)s [retrain] %(message)s")
logger = logging.getLogger("retrain_pipeline")
# ── Configuration ──────────────────────────────────────────────────────────────
EXPORT_DIR = Path(os.getenv("CLASSIFIER_EXPORT_DIR", "/data/classifier_labels"))
EXPORT_ARCHIVE = Path(os.getenv("CLASSIFIER_ARCHIVE_DIR", "/data/classifier_labels/archived"))
CHAMPION_PATH = Path(os.getenv("CLASSIFIER_MODEL_PATH", "/data/classifier_model.pkl"))
CHALLENGER_PATH = CHAMPION_PATH.parent / "classifier_model_challenger.pkl"
SEED_DATASET = Path(os.getenv("CLASSIFIER_SEED_DATASET",
                              str(Path(__file__).parent / "seed_classifier_dataset.jsonl")))
OLLAMA_URL = os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434")
MIN_CV_ACCURACY = float(os.getenv("CLASSIFIER_MIN_CV_ACCURACY", "0.90"))
HELD_OUT_RATIO = float(os.getenv("CLASSIFIER_HELD_OUT_RATIO", "0.20"))
# ── Data loading ───────────────────────────────────────────────────────────────
def load_jsonl(path: Path) -> tuple[list[str], list[str]]:
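    """Read one JSONL file of records like {"query": "...", "type": "..."} and return parallel (queries, labels) lists. Blank or malformed lines are skipped."""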
    queries, labels = [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
                q = rec.get("query", "").strip()
                t = rec.get("type", "").strip()
                if q and t:
                    queries.append(q)
                    labels.append(t)
            except json.JSONDecodeError:
                pass
    return queries, labels

def load_all_data(export_dir: Path, seed_path: Path) -> tuple[list[str], list[str], list[Path]]:
"""Load seed dataset + all unarchived export files. Returns (queries, labels, export_files)."""
queries, labels = [], []
if seed_path.exists():
q, l = load_jsonl(seed_path)
queries.extend(q)
labels.extend(l)
logger.info(f"Seed dataset: {len(q)} examples from {seed_path}")
else:
logger.warning(f"Seed dataset not found at {seed_path}")
export_files = sorted(export_dir.glob("classifier_labels_*.jsonl"))
for f in export_files:
q, l = load_jsonl(f)
queries.extend(q)
labels.extend(l)
logger.info(f"Export file: {len(q)} examples from {f.name}")
logger.info(f"Total dataset: {len(queries)} examples — {dict(Counter(labels))}")
return queries, labels, export_files
# ── Embedding ──────────────────────────────────────────────────────────────────
def embed_queries(queries: list[str], labels: list[str], base_url: str) -> tuple[np.ndarray, list[str]]:
"""Embed with bge-m3, one at a time to handle NaN vectors gracefully."""
try:
from langchain_ollama import OllamaEmbeddings
except ImportError:
logger.error("langchain-ollama not installed")
raise
emb = OllamaEmbeddings(model="bge-m3", base_url=base_url)
vectors, kept_labels = [], []
skipped = 0
for i, (query, label) in enumerate(zip(queries, labels)):
try:
vec = emb.embed_query(query)
if any(v != v for v in vec):
skipped += 1
continue
vectors.append(vec)
kept_labels.append(label)
except Exception as e:
logger.warning(f"Embedding failed for query {i}: {e}")
skipped += 1
if (i + 1) % 20 == 0:
logger.info(f" Embedded {i+1}/{len(queries)}...")
if skipped:
logger.warning(f"Skipped {skipped} queries due to NaN or embedding errors")
return np.array(vectors), kept_labels
# ── Training ───────────────────────────────────────────────────────────────────
def train_model(X: np.ndarray, y: list[str]) -> LogisticRegression:
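    """Fit the challenger: a logistic-regression classifier over the query embeddings."""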
    clf = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    clf.fit(X, y)
    return clf

def cross_validate(X: np.ndarray, y: list[str]) -> float:
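    """Stratified k-fold CV on the train split (up to 5 folds, capped by the rarest class). Returns mean accuracy, or 0.0 if any class has fewer than 2 examples."""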
    dist = Counter(y)
    n_splits = min(5, min(dist.values()))
    if n_splits < 2:
        logger.warning("Too few examples per class for cross-validation")
        return 0.0
    clf = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = cross_val_score(clf, X, y, cv=cv, scoring="accuracy")
    logger.info(f"CV accuracy: {scores.mean():.3f} ± {scores.std():.3f} (folds: {scores.round(3).tolist()})")
    return float(scores.mean())

# ── Champion evaluation ────────────────────────────────────────────────────────
def evaluate_champion(champion_path: Path, X_held: np.ndarray, y_held: list[str]) -> float:
"""Evaluate the current champion on the held-out set. Returns accuracy."""
if not champion_path.exists():
logger.info("No champion model found — challenger will be deployed unconditionally")
return 0.0
model = joblib.load(champion_path)
preds = model["clf"].predict(X_held)
acc = accuracy_score(y_held, preds)
logger.info(f"Champion accuracy on held-out set: {acc:.3f} (trained on {model.get('n_train', '?')} examples)")
return acc
# ── Archive ────────────────────────────────────────────────────────────────────
def archive_exports(export_files: list[Path], archive_dir: Path) -> None:
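    """Move processed export files into the archive directory so later runs do not re-ingest them."""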
    archive_dir.mkdir(parents=True, exist_ok=True)
    for f in export_files:
        dest = archive_dir / f.name
        shutil.move(str(f), str(dest))
        logger.info(f"Archived {f.name} → {archive_dir}")

# ── Main ───────────────────────────────────────────────────────────────────────
def main():
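    """Run the full pipeline: load data, embed, split, cross-validate, train the challenger, compare it to the champion, promote if not worse, then archive exports."""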
    parser = argparse.ArgumentParser(description="Champion/Challenger classifier retraining (ADR-0010)")
    parser.add_argument("--ollama", default=OLLAMA_URL)
    parser.add_argument("--force", action="store_true", help="Deploy challenger without comparing to champion")
    parser.add_argument("--dry-run", action="store_true", help="Evaluate only — do not deploy or archive")
    parser.add_argument("--min-cv", type=float, default=MIN_CV_ACCURACY)
    args = parser.parse_args()

    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    logger.info(f"=== Retraining pipeline started {ts} ===")

    # ── 1. Load data ───────────────────────────────────────────────────────────
    queries, labels, export_files = load_all_data(EXPORT_DIR, SEED_DATASET)
    if len(queries) < 20:
        logger.error("Too few examples (< 20). Aborting.")
        return

    # ── 2. Embed ───────────────────────────────────────────────────────────────
    logger.info(f"Embedding {len(queries)} queries via bge-m3 at {args.ollama}...")
    X, labels = embed_queries(queries, labels, args.ollama)
    logger.info(f"Embedding matrix: {X.shape}")

    # ── 3. Stratified train / held-out split ───────────────────────────────────
    sss = StratifiedShuffleSplit(n_splits=1, test_size=HELD_OUT_RATIO, random_state=42)
    train_idx, held_idx = next(sss.split(X, labels))
    X_train, X_held = X[train_idx], X[held_idx]
    y_train = [labels[i] for i in train_idx]
    y_held = [labels[i] for i in held_idx]
    logger.info(f"Train: {len(y_train)} | Held-out: {len(y_held)}")

    # ── 4. Cross-validate challenger ───────────────────────────────────────────
    logger.info("Cross-validating challenger...")
    cv_acc = cross_validate(X_train, y_train)
    if cv_acc < args.min_cv:
        logger.error(f"Challenger CV accuracy {cv_acc:.3f} below minimum {args.min_cv:.2f}. Aborting.")
        return

    # ── 5. Train challenger on full train split ────────────────────────────────
    challenger_clf = train_model(X_train, y_train)
    challenger_acc = accuracy_score(y_held, challenger_clf.predict(X_held))
    logger.info(f"Challenger accuracy on held-out: {challenger_acc:.3f}")

    # ── 6. Evaluate champion on same held-out set ──────────────────────────────
    champion_acc = evaluate_champion(CHAMPION_PATH, X_held, y_held)

    # ── 7. Promotion decision ──────────────────────────────────────────────────
    promote = args.force or (challenger_acc >= champion_acc)
    if promote:
        reason = "forced" if args.force else f"challenger ({challenger_acc:.3f}) >= champion ({champion_acc:.3f})"
        logger.info(f"PROMOTING challenger — {reason}")
        if not args.dry_run:
            # Back up champion before overwriting
            if CHAMPION_PATH.exists():
                backup = CHAMPION_PATH.parent / f"classifier_model_backup_{ts}.pkl"
                shutil.copy(str(CHAMPION_PATH), str(backup))
                logger.info(f"Champion backed up → {backup.name}")
            joblib.dump(
                {
                    "clf": challenger_clf,
                    "classes": challenger_clf.classes_.tolist(),
                    "n_train": len(y_train),
                    "cv_mean": round(cv_acc, 4),
                    "held_acc": round(challenger_acc, 4),
                    "champion_acc": round(champion_acc, 4),
                    "retrained_at": ts,
                },
                CHAMPION_PATH,
            )
            logger.info(f"New champion saved → {CHAMPION_PATH}")
            logger.info("Restart brunix-assistance-engine to load the new model.")
        else:
            logger.info("[dry-run] Promotion skipped")
    else:
        logger.warning(
            f"KEEPING champion — challenger ({challenger_acc:.3f}) < champion ({champion_acc:.3f}). "
            "Consider adding more labeled data."
        )

    # ── 8. Archive processed exports ──────────────────────────────────────────
    if not args.dry_run and export_files:
        archive_exports(export_files, EXPORT_ARCHIVE)

    logger.info(f"=== Retraining pipeline finished {datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')} ===")


if __name__ == "__main__":
    main()