#!/usr/bin/env python3
"""Phase 2 — Train Layer 2 embedding classifier (ADR-0008).

Reads labeled (query, type) pairs from a JSONL file, embeds them with bge-m3
via Ollama, trains a LogisticRegression classifier, and serializes the model
with joblib.

Usage:
    python train_classifier.py
    python train_classifier.py --data path/to/dataset.jsonl --output /data/classifier_model.pkl

The output file is loaded at engine startup by graph.py (_load_layer2_model).
Any JSONL file produced by classifier_export.py is compatible as additional
training data — merge with the seed dataset before retraining.

Requirements (add to requirements.txt if not present):
    scikit-learn
    joblib
    numpy
"""

import argparse
import json
import os
import sys
from collections import Counter
from pathlib import Path

import numpy as np

# NOTE: joblib / scikit-learn are imported lazily inside main() (mirroring the
# langchain_ollama pattern in embed_queries) so that --help and offline data
# inspection work without the heavy ML dependencies installed.


def load_data(path: str) -> tuple[list[str], list[str]]:
    """Load labeled (query, type) pairs from a JSONL dataset.

    Each non-blank line must be a JSON object with non-empty "query" and
    "type" string fields. Blank lines, invalid JSON, and records missing
    either field are skipped with a warning rather than aborting the run.

    Args:
        path: Path to the JSONL file.

    Returns:
        (queries, labels) as parallel lists of equal length.
    """
    queries, labels = [], []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f" [WARN] line {i} skipped — JSON error: {e}", file=sys.stderr)
                skipped += 1
                continue
            q = rec.get("query", "").strip()
            t = rec.get("type", "").strip()
            if not q or not t:
                skipped += 1
                continue
            queries.append(q)
            labels.append(t)
    if skipped:
        print(f" [WARN] {skipped} records skipped (missing query/type or invalid JSON)")
    return queries, labels


def embed_queries(queries: list[str], labels: list[str], base_url: str) -> tuple[np.ndarray, list[str]]:
    """Embed queries with bge-m3 via Ollama, one at a time.

    bge-m3 occasionally produces NaN vectors for certain inputs (known Ollama
    bug). Embedding one by one lets us detect and skip those queries instead
    of failing the entire batch.

    Args:
        queries: Query strings to embed.
        labels: Labels parallel to ``queries``.
        base_url: Ollama server base URL.

    Returns:
        (vectors, kept_labels) — labels aligned to the returned vectors; both
        may be shorter than the inputs when queries are skipped.
    """
    try:
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        print("[ERROR] langchain-ollama not installed. Run: pip install langchain-ollama", file=sys.stderr)
        sys.exit(1)

    emb = OllamaEmbeddings(model="bge-m3", base_url=base_url)
    vectors: list[list[float]] = []
    kept_labels: list[str] = []
    skipped = 0
    for i, (query, label) in enumerate(zip(queries, labels)):
        try:
            vec = emb.embed_query(query)
            if any(v != v for v in vec):  # NaN check (NaN != NaN)
                print(f" [WARN] query {i+1} produced NaN vector, skipped: '{query[:60]}'", file=sys.stderr)
                skipped += 1
                continue
            vectors.append(vec)
            kept_labels.append(label)
        except Exception as e:
            # Best-effort: one bad query must not kill the whole training run.
            print(f" [WARN] query {i+1} embedding failed ({e}), skipped: '{query[:60]}'", file=sys.stderr)
            skipped += 1
            continue
        if (i + 1) % 10 == 0 or (i + 1) == len(queries):
            print(f" Embedded {i+1}/{len(queries)}...", end="\r")
    print()
    if skipped:
        print(f" [WARN] {skipped} queries skipped due to NaN or embedding errors")
    return np.array(vectors), kept_labels


def main():
    """CLI entry point: load data → embed → cross-validate → fit → save."""
    parser = argparse.ArgumentParser(description="Train Layer 2 classifier for ADR-0008")
    parser.add_argument(
        "--data",
        default=str(Path(__file__).parent / "seed_classifier_dataset.jsonl"),
        help="Path to labeled JSONL dataset",
    )
    parser.add_argument(
        "--output",
        default=os.getenv("CLASSIFIER_MODEL_PATH", "/data/classifier_model.pkl"),
        help="Output path for serialized model",
    )
    parser.add_argument(
        "--ollama",
        default=os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434"),
        help="Ollama base URL",
    )
    parser.add_argument(
        "--min-cv-accuracy",
        type=float,
        default=0.90,
        help="Minimum cross-validation accuracy to proceed with saving (default: 0.90)",
    )
    args = parser.parse_args()

    # Lazy import of heavy ML deps — consistent with embed_queries' handling
    # of langchain_ollama, and gives a clear error instead of a traceback.
    try:
        import joblib
        from sklearn.linear_model import LogisticRegression
        from sklearn.metrics import classification_report
        from sklearn.model_selection import StratifiedKFold, cross_val_predict, cross_val_score
    except ImportError:
        print("[ERROR] scikit-learn/joblib not installed. Run: pip install scikit-learn joblib", file=sys.stderr)
        sys.exit(1)

    # ── Load data ──────────────────────────────────────────────────────────
    print(f"\n[1/4] Loading data from {args.data}")
    queries, labels = load_data(args.data)
    print(f" {len(queries)} examples loaded")
    dist = Counter(labels)
    print(f" Distribution: {dict(dist)}")
    if len(queries) < 20:
        print("[ERROR] Too few examples (< 20). Add more to the dataset.", file=sys.stderr)
        sys.exit(1)
    min_class_count = min(dist.values())
    n_splits = min(5, min_class_count)
    if n_splits < 2:
        print("[ERROR] At least one class has fewer than 2 examples. Add more data.", file=sys.stderr)
        sys.exit(1)

    # ── Embed ──────────────────────────────────────────────────────────────
    print(f"\n[2/4] Embedding with bge-m3 via {args.ollama}")
    vectors, labels = embed_queries(queries, labels, args.ollama)
    print(f" Embedding matrix: {vectors.shape} ({len(labels)} examples kept)")

    # BUGFIX: embedding may drop queries (NaN vectors / errors), so the class
    # distribution and fold count must be re-validated on what survived —
    # otherwise StratifiedKFold can crash on a class with < n_splits members.
    if not labels:
        print("[ERROR] No usable embeddings produced. Check the Ollama server.", file=sys.stderr)
        sys.exit(1)
    dist = Counter(labels)
    min_class_count = min(dist.values())
    n_splits = min(5, min_class_count)
    if n_splits < 2:
        print("[ERROR] After embedding, at least one class has fewer than 2 examples.", file=sys.stderr)
        sys.exit(1)

    # ── Train + cross-validate ─────────────────────────────────────────────
    print(f"\n[3/4] Training LogisticRegression (C=1.0) with {n_splits}-fold CV")
    clf = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    scores = cross_val_score(clf, vectors, labels, cv=cv, scoring="accuracy")
    cv_mean = scores.mean()
    cv_std = scores.std()
    print(f" CV accuracy: {cv_mean:.3f} ± {cv_std:.3f} (folds: {scores.round(3).tolist()})")

    # Per-class report via cross-validated predictions
    y_pred = cross_val_predict(clf, vectors, labels, cv=cv)
    print("\n Per-class report:")
    report = classification_report(labels, y_pred, zero_division=0)
    for line in report.splitlines():
        print(f" {line}")

    if cv_mean < args.min_cv_accuracy:
        print(
            f"\n[FAIL] CV accuracy {cv_mean:.3f} is below threshold {args.min_cv_accuracy:.2f}. "
            "Add more examples to the dataset before deploying.",
            file=sys.stderr,
        )
        sys.exit(1)
    print(f"\n CV accuracy {cv_mean:.3f} ≥ {args.min_cv_accuracy:.2f} — proceeding to save.")

    # ── Fit final model on full dataset ────────────────────────────────────
    clf.fit(vectors, labels)

    # ── Save ───────────────────────────────────────────────────────────────
    print(f"\n[4/4] Saving model to {args.output}")
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(
        {
            "clf": clf,
            "classes": clf.classes_.tolist(),
            # BUGFIX: record the number of examples actually trained on
            # (post-embedding survivors), not the raw dataset size.
            "n_train": len(labels),
            "cv_mean": round(cv_mean, 4),
            "cv_std": round(cv_std, 4),
        },
        out_path,
    )
    print(f" Model saved → {out_path}")
    print(f"\nDone. Classes: {clf.classes_.tolist()}")


if __name__ == "__main__":
    main()