#!/usr/bin/env python3
"""
retrain_pipeline.py — Automatic Champion/Challenger retraining (ADR-0010).

Triggered automatically by classifier_export.py when RETRAIN_THRESHOLD new
sessions accumulate. Can also be run manually.

Flow:
1. Merge all JSONL exports in EXPORT_DIR with the seed dataset
2. Split into train (80%) and held-out (20%) — stratified
3. Train challenger model
4. Load champion model (current production model at CLASSIFIER_MODEL_PATH)
5. Evaluate both on held-out set
6. If challenger accuracy >= champion accuracy: deploy challenger → champion
7. If challenger < champion: discard challenger, keep champion, log alert
8. Archive processed export files to EXPORT_ARCHIVE_DIR

Usage:
    python retrain_pipeline.py            # auto mode (uses env vars)
    python retrain_pipeline.py --force    # skip champion comparison, always deploy
    python retrain_pipeline.py --dry-run  # evaluate only, do not deploy
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import shutil
|
|
from collections import Counter
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import joblib
|
|
import numpy as np
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import accuracy_score
|
|
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score, StratifiedKFold
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [retrain] %(message)s")
logger = logging.getLogger("retrain_pipeline")


# ── Configuration ──────────────────────────────────────────────────────────────

# Directory where new labeled JSONL export files accumulate (see module docstring).
EXPORT_DIR = Path(os.getenv("CLASSIFIER_EXPORT_DIR", "/data/classifier_labels"))
# Processed export files are moved here after a successful, non-dry-run cycle.
EXPORT_ARCHIVE = Path(os.getenv("CLASSIFIER_ARCHIVE_DIR", "/data/classifier_labels/archived"))
# Production ("champion") model artifact; promotion overwrites this file in place.
CHAMPION_PATH = Path(os.getenv("CLASSIFIER_MODEL_PATH", "/data/classifier_model.pkl"))
# Sibling path for a challenger artifact.  NOTE(review): not referenced anywhere
# else in this file — confirm whether it is still needed or dead config.
CHALLENGER_PATH = CHAMPION_PATH.parent / "classifier_model_challenger.pkl"
# Hand-labeled seed dataset shipped next to this script; always merged into training.
SEED_DATASET = Path(os.getenv("CLASSIFIER_SEED_DATASET",
                              str(Path(__file__).parent / "seed_classifier_dataset.jsonl")))
# Ollama endpoint used for bge-m3 embeddings.
OLLAMA_URL = os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434")
# Minimum cross-validation accuracy a challenger must reach before it is even
# compared against the champion (overridable via --min-cv).
MIN_CV_ACCURACY = float(os.getenv("CLASSIFIER_MIN_CV_ACCURACY", "0.90"))
# Fraction of the data reserved as the held-out evaluation split.
HELD_OUT_RATIO = float(os.getenv("CLASSIFIER_HELD_OUT_RATIO", "0.20"))
|
|
|
|
|
|
# ── Data loading ───────────────────────────────────────────────────────────────
|
|
|
|
def load_jsonl(path: Path) -> tuple[list[str], list[str]]:
    """Load (query, label) pairs from a JSONL file.

    Each line must be a JSON object with non-empty "query" and "type" fields.
    Blank lines and records missing either field are skipped; malformed JSON
    lines are skipped with a warning (previously they were swallowed silently,
    so corrupt export files could shrink the dataset unnoticed).

    Args:
        path: JSONL file to read (UTF-8).

    Returns:
        (queries, labels) — two parallel lists of equal length, whitespace-stripped.
    """
    queries: list[str] = []
    labels: list[str] = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Skipping malformed JSON at {path.name}:{lineno}")
                continue
            q = rec.get("query", "").strip()
            t = rec.get("type", "").strip()
            if q and t:
                queries.append(q)
                labels.append(t)
    return queries, labels
|
|
|
|
|
|
def load_all_data(export_dir: Path, seed_path: Path) -> tuple[list[str], list[str], list[Path]]:
    """Load seed dataset + all unarchived export files. Returns (queries, labels, export_files).

    Args:
        export_dir: Directory scanned for ``classifier_labels_*.jsonl`` exports.
        seed_path: Hand-labeled seed dataset; skipped with a warning if absent.

    Returns:
        (queries, labels, export_files) — merged parallel lists plus the export
        paths that were consumed, so the caller can archive them afterwards.
    """
    queries: list[str] = []
    labels: list[str] = []

    if seed_path.exists():
        seed_queries, seed_labels = load_jsonl(seed_path)  # renamed from q/l — `l` is ambiguous (E741)
        queries.extend(seed_queries)
        labels.extend(seed_labels)
        logger.info(f"Seed dataset: {len(seed_queries)} examples from {seed_path}")
    else:
        logger.warning(f"Seed dataset not found at {seed_path}")

    # sorted() gives a deterministic order (chronological, given timestamped names).
    export_files = sorted(export_dir.glob("classifier_labels_*.jsonl"))
    for export_path in export_files:
        file_queries, file_labels = load_jsonl(export_path)
        queries.extend(file_queries)
        labels.extend(file_labels)
        logger.info(f"Export file: {len(file_queries)} examples from {export_path.name}")

    logger.info(f"Total dataset: {len(queries)} examples — {dict(Counter(labels))}")
    return queries, labels, export_files
|
|
|
|
|
|
# ── Embedding ──────────────────────────────────────────────────────────────────
|
|
|
|
def embed_queries(queries: list[str], labels: list[str], base_url: str) -> tuple[np.ndarray, list[str]]:
    """Embed with bge-m3, one at a time to handle NaN vectors gracefully."""
    try:
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        logger.error("langchain-ollama not installed")
        raise

    embedder = OllamaEmbeddings(model="bge-m3", base_url=base_url)
    total = len(queries)
    vectors: list[list[float]] = []
    kept_labels: list[str] = []
    skipped = 0

    for idx, (query, label) in enumerate(zip(queries, labels)):
        try:
            vec = embedder.embed_query(query)
            # NaN is the only float not equal to itself — drop vectors containing any.
            if any(v != v for v in vec):
                skipped += 1
                continue
            vectors.append(vec)
            kept_labels.append(label)
        except Exception as exc:
            logger.warning(f"Embedding failed for query {idx}: {exc}")
            skipped += 1
        # Periodic progress report (skipped entirely for NaN rows via `continue` above).
        if (idx + 1) % 20 == 0:
            logger.info(f" Embedded {idx+1}/{total}...")

    if skipped:
        logger.warning(f"Skipped {skipped} queries due to NaN or embedding errors")

    return np.array(vectors), kept_labels
|
|
|
|
|
|
# ── Training ───────────────────────────────────────────────────────────────────
|
|
|
|
def train_model(X: np.ndarray, y: list[str]) -> LogisticRegression:
    """Fit and return a logistic-regression classifier on embedded queries."""
    # fit() returns the estimator itself, so construction and training chain cleanly.
    return LogisticRegression(max_iter=1000, C=1.0, random_state=42).fit(X, y)
|
|
|
|
|
|
def cross_validate(X: np.ndarray, y: list[str]) -> float:
    """Return mean stratified k-fold CV accuracy, or 0.0 if data is too sparse.

    The fold count adapts to the rarest class (capped at 5) so stratification
    never asks for more folds than a class has members.
    """
    rarest_class_count = min(Counter(y).values())
    n_splits = min(5, rarest_class_count)
    if n_splits < 2:
        logger.warning("Too few examples per class for cross-validation")
        return 0.0
    estimator = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    folds = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = cross_val_score(estimator, X, y, cv=folds, scoring="accuracy")
    logger.info(f"CV accuracy: {scores.mean():.3f} ± {scores.std():.3f} (folds: {scores.round(3).tolist()})")
    return float(scores.mean())
|
|
|
|
|
|
# ── Champion evaluation ────────────────────────────────────────────────────────
|
|
|
|
def evaluate_champion(champion_path: Path, X_held: np.ndarray, y_held: list[str]) -> float:
    """Evaluate the current champion on the held-out set. Returns accuracy."""
    if not champion_path.exists():
        # 0.0 guarantees any challenger wins the >= comparison in main().
        logger.info("No champion model found — challenger will be deployed unconditionally")
        return 0.0
    bundle = joblib.load(champion_path)
    # Bundles are dicts written by main() with the fitted estimator under "clf".
    predictions = bundle["clf"].predict(X_held)
    accuracy = accuracy_score(y_held, predictions)
    n_train = bundle.get("n_train", "?")
    logger.info(f"Champion accuracy on held-out set: {accuracy:.3f} (trained on {n_train} examples)")
    return accuracy
|
|
|
|
|
|
# ── Archive ────────────────────────────────────────────────────────────────────
|
|
|
|
def archive_exports(export_files: list[Path], archive_dir: Path) -> None:
    """Move processed export files into `archive_dir`, creating it if missing."""
    archive_dir.mkdir(parents=True, exist_ok=True)
    for src in export_files:
        shutil.move(str(src), str(archive_dir / src.name))
        logger.info(f"Archived {src.name} → {archive_dir}")
|
|
|
|
|
|
# ── Main ───────────────────────────────────────────────────────────────────────
|
|
|
|
def main():
    """CLI entry point: run one champion/challenger retraining cycle (ADR-0010).

    Unless --dry-run is given, a successful cycle may overwrite CHAMPION_PATH
    with the challenger model (backing up the old champion first) and archive
    the processed export files. Aborts early when fewer than 20 examples exist
    or when the challenger's cross-validation accuracy is below --min-cv.
    """
    parser = argparse.ArgumentParser(description="Champion/Challenger classifier retraining (ADR-0010)")
    parser.add_argument("--ollama", default=OLLAMA_URL)
    parser.add_argument("--force", action="store_true", help="Deploy challenger without comparing to champion")
    parser.add_argument("--dry-run", action="store_true", help="Evaluate only — do not deploy or archive")
    parser.add_argument("--min-cv", type=float, default=MIN_CV_ACCURACY)
    args = parser.parse_args()

    # UTC timestamp reused for log lines, backup filename, and model metadata.
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
    logger.info(f"=== Retraining pipeline started {ts} ===")

    # ── 1. Load data ───────────────────────────────────────────────────────────
    queries, labels, export_files = load_all_data(EXPORT_DIR, SEED_DATASET)

    if len(queries) < 20:
        logger.error("Too few examples (< 20). Aborting.")
        return

    # ── 2. Embed ───────────────────────────────────────────────────────────────
    logger.info(f"Embedding {len(queries)} queries via bge-m3 at {args.ollama}...")
    # `labels` is deliberately reassigned: embed_queries drops entries whose
    # embedding failed or contained NaN, keeping X and labels aligned.
    X, labels = embed_queries(queries, labels, args.ollama)
    logger.info(f"Embedding matrix: {X.shape}")

    # ── 3. Stratified train / held-out split ───────────────────────────────────
    sss = StratifiedShuffleSplit(n_splits=1, test_size=HELD_OUT_RATIO, random_state=42)
    train_idx, held_idx = next(sss.split(X, labels))
    X_train, X_held = X[train_idx], X[held_idx]
    y_train = [labels[i] for i in train_idx]
    y_held = [labels[i] for i in held_idx]
    logger.info(f"Train: {len(y_train)} | Held-out: {len(y_held)}")

    # ── 4. Cross-validate challenger ───────────────────────────────────────────
    logger.info("Cross-validating challenger...")
    cv_acc = cross_validate(X_train, y_train)

    if cv_acc < args.min_cv:
        logger.error(f"Challenger CV accuracy {cv_acc:.3f} below minimum {args.min_cv:.2f}. Aborting.")
        return

    # ── 5. Train challenger on full train split ────────────────────────────────
    challenger_clf = train_model(X_train, y_train)
    challenger_acc = accuracy_score(y_held, challenger_clf.predict(X_held))
    logger.info(f"Challenger accuracy on held-out: {challenger_acc:.3f}")

    # ── 6. Evaluate champion on same held-out set ──────────────────────────────
    # Returns 0.0 when no champion exists, so a first challenger always wins.
    champion_acc = evaluate_champion(CHAMPION_PATH, X_held, y_held)

    # ── 7. Promotion decision ──────────────────────────────────────────────────
    # Ties favor the challenger (>=), per step 6 of the module docstring.
    promote = args.force or (challenger_acc >= champion_acc)

    if promote:
        reason = "forced" if args.force else f"challenger ({challenger_acc:.3f}) >= champion ({champion_acc:.3f})"
        logger.info(f"PROMOTING challenger — {reason}")

        if not args.dry_run:
            # Back up champion before overwriting
            if CHAMPION_PATH.exists():
                backup = CHAMPION_PATH.parent / f"classifier_model_backup_{ts}.pkl"
                shutil.copy(str(CHAMPION_PATH), str(backup))
                logger.info(f"Champion backed up → {backup.name}")

            # Persist the challenger plus its evaluation metadata as the new
            # champion; evaluate_champion() reads "clf" and "n_train" back.
            joblib.dump(
                {
                    "clf": challenger_clf,
                    "classes": challenger_clf.classes_.tolist(),
                    "n_train": len(y_train),
                    "cv_mean": round(cv_acc, 4),
                    "held_acc": round(challenger_acc, 4),
                    "champion_acc": round(champion_acc, 4),
                    "retrained_at": ts,
                },
                CHAMPION_PATH,
            )
            logger.info(f"New champion saved → {CHAMPION_PATH}")
            logger.info("Restart brunix-assistance-engine to load the new model.")
        else:
            logger.info("[dry-run] Promotion skipped")
    else:
        logger.warning(
            f"KEEPING champion — challenger ({challenger_acc:.3f}) < champion ({champion_acc:.3f}). "
            "Consider adding more labeled data."
        )

    # ── 8. Archive processed exports ──────────────────────────────────────────
    # Skipped on --dry-run so the same exports feed the next real run.
    if not args.dry_run and export_files:
        archive_exports(export_files, EXPORT_ARCHIVE)

    logger.info(f"=== Retraining pipeline finished {datetime.now(timezone.utc).strftime('%Y%m%dT%H%M%SZ')} ===")
|
|
|
|
|
|
# Script entry point — see the module docstring for CLI usage and flags.
if __name__ == "__main__":
    main()
|