205 lines
7.7 KiB
Python
205 lines
7.7 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Phase 2 — Train Layer 2 embedding classifier (ADR-0008).
|
|
|
|
Reads labeled (query, type) pairs from a JSONL file, embeds them with bge-m3
|
|
via Ollama, trains a LogisticRegression classifier, and serializes the model
|
|
with joblib.
|
|
|
|
Usage:
|
|
python train_classifier.py
|
|
python train_classifier.py --data path/to/dataset.jsonl --output /data/classifier_model.pkl
|
|
|
|
The output file is loaded at engine startup by graph.py (_load_layer2_model).
|
|
Any JSONL file produced by classifier_export.py is compatible as additional
|
|
training data — merge with the seed dataset before retraining.
|
|
|
|
Requirements (add to requirements.txt if not present):
|
|
scikit-learn
|
|
joblib
|
|
numpy
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import joblib
|
|
import numpy as np
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.metrics import classification_report
|
|
from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict
|
|
|
|
|
|
def load_data(path: str) -> tuple[list[str], list[str]]:
    """Load labeled (query, type) pairs from a JSONL dataset.

    Each line must be a JSON object with non-empty string fields "query"
    and "type". Bad lines — invalid JSON, non-object records, missing or
    empty fields, or non-string values — are skipped with a warning
    instead of aborting the run.

    Args:
        path: Path to the JSONL file.

    Returns:
        (queries, labels) — two parallel lists of equal length.
    """
    queries: list[str] = []
    labels: list[str] = []
    skipped = 0
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError as e:
                print(f" [WARN] line {i} skipped — JSON error: {e}", file=sys.stderr)
                skipped += 1
                continue
            # BUG FIX: the original called .strip() on rec.get(...) directly,
            # which raises AttributeError for non-string values (e.g. a number)
            # or non-object records — guard types before stripping.
            if not isinstance(rec, dict):
                skipped += 1
                continue
            q = rec.get("query")
            t = rec.get("type")
            if not isinstance(q, str) or not isinstance(t, str):
                skipped += 1
                continue
            q, t = q.strip(), t.strip()
            if not q or not t:
                skipped += 1
                continue
            queries.append(q)
            labels.append(t)
    if skipped:
        print(f" [WARN] {skipped} records skipped (missing query/type or invalid JSON)")
    return queries, labels
|
|
|
|
|
def embed_queries(queries: list[str], labels: list[str], base_url: str) -> tuple[np.ndarray, list[str]]:
    """Embed queries with bge-m3 via Ollama, one at a time.

    bge-m3 occasionally produces NaN vectors for certain inputs (known Ollama bug).
    Embedding one by one lets us detect and skip those queries instead of failing
    the entire batch.

    Args:
        queries: Input texts to embed.
        labels: Labels parallel to ``queries``; entries whose query is skipped
            are dropped so the two outputs stay aligned.
        base_url: Base URL of the Ollama server.

    Returns:
        (vectors, kept_labels) — labels aligned to the returned vectors.
    """
    try:
        from langchain_ollama import OllamaEmbeddings
    except ImportError:
        print("[ERROR] langchain-ollama not installed. Run: pip install langchain-ollama", file=sys.stderr)
        sys.exit(1)

    emb = OllamaEmbeddings(model="bge-m3", base_url=base_url)
    vectors: list[list[float]] = []
    kept_labels: list[str] = []
    skipped = 0
    total = len(queries)

    for i, (query, label) in enumerate(zip(queries, labels), 1):
        try:
            vec = emb.embed_query(query)
        except Exception as e:
            print(f" [WARN] query {i} embedding failed ({e}), skipped: '{query[:60]}'", file=sys.stderr)
            skipped += 1
        else:
            if any(v != v for v in vec):  # NaN check (NaN != NaN)
                print(f" [WARN] query {i} produced NaN vector, skipped: '{query[:60]}'", file=sys.stderr)
                skipped += 1
            else:
                vectors.append(vec)
                kept_labels.append(label)

        # BUG FIX: the progress print used to sit below `continue` statements,
        # so any skipped query (including a skipped final one) suppressed the
        # progress update. With try/except/else it now runs every iteration.
        if i % 10 == 0 or i == total:
            print(f" Embedded {i}/{total}...", end="\r")

    print()
    if skipped:
        print(f" [WARN] {skipped} queries skipped due to NaN or embedding errors")

    return np.array(vectors), kept_labels
|
|
|
|
|
def main():
    """CLI entry point: load data, embed, cross-validate, train, and save.

    Exits non-zero when the dataset is too small, a class is
    under-represented, embedding yields no usable vectors, or CV accuracy
    falls below the ``--min-cv-accuracy`` threshold.
    """
    parser = argparse.ArgumentParser(description="Train Layer 2 classifier for ADR-0008")
    parser.add_argument(
        "--data",
        default=str(Path(__file__).parent / "seed_classifier_dataset.jsonl"),
        help="Path to labeled JSONL dataset",
    )
    parser.add_argument(
        "--output",
        default=os.getenv("CLASSIFIER_MODEL_PATH", "/data/classifier_model.pkl"),
        help="Output path for serialized model",
    )
    parser.add_argument(
        "--ollama",
        default=os.getenv("OLLAMA_LOCAL_URL", "http://localhost:11434"),
        help="Ollama base URL",
    )
    parser.add_argument(
        "--min-cv-accuracy",
        type=float,
        default=0.90,
        help="Minimum cross-validation accuracy to proceed with saving (default: 0.90)",
    )
    args = parser.parse_args()

    # ── Load data ──────────────────────────────────────────────────────────
    print(f"\n[1/4] Loading data from {args.data}")
    queries, labels = load_data(args.data)
    print(f" {len(queries)} examples loaded")

    dist = Counter(labels)
    print(f" Distribution: {dict(dist)}")

    if len(queries) < 20:
        print("[ERROR] Too few examples (< 20). Add more to the dataset.", file=sys.stderr)
        sys.exit(1)

    # Early sanity check on the raw distribution (cheap, before embedding).
    if min(dist.values()) < 2:
        print("[ERROR] At least one class has fewer than 2 examples. Add more data.", file=sys.stderr)
        sys.exit(1)

    # ── Embed ──────────────────────────────────────────────────────────────
    print(f"\n[2/4] Embedding with bge-m3 via {args.ollama}")
    vectors, labels = embed_queries(queries, labels, args.ollama)
    print(f" Embedding matrix: {vectors.shape} ({len(labels)} examples kept)")

    if not labels:
        print("[ERROR] No usable embeddings produced; cannot train.", file=sys.stderr)
        sys.exit(1)

    # BUG FIX: n_splits was previously derived from the distribution BEFORE
    # embedding, but embed_queries can drop examples (NaN vectors / errors),
    # so a class could shrink below the fold count and StratifiedKFold would
    # raise. Recompute from the kept labels.
    dist = Counter(labels)
    n_splits = min(5, min(dist.values()))
    if n_splits < 2:
        print("[ERROR] At least one class has fewer than 2 examples after embedding. Add more data.", file=sys.stderr)
        sys.exit(1)

    # ── Train + cross-validate ─────────────────────────────────────────────
    print(f"\n[3/4] Training LogisticRegression (C=1.0) with {n_splits}-fold CV")
    clf = LogisticRegression(max_iter=1000, C=1.0, random_state=42)
    cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    scores = cross_val_score(clf, vectors, labels, cv=cv, scoring="accuracy")
    cv_mean = scores.mean()
    cv_std = scores.std()
    print(f" CV accuracy: {cv_mean:.3f} ± {cv_std:.3f} (folds: {scores.round(3).tolist()})")

    # Per-class report via cross-validated predictions (out-of-fold, so the
    # report is not inflated by training-set memorization).
    y_pred = cross_val_predict(clf, vectors, labels, cv=cv)
    print("\n Per-class report:")
    report = classification_report(labels, y_pred, zero_division=0)
    for line in report.splitlines():
        print(f" {line}")

    if cv_mean < args.min_cv_accuracy:
        print(
            f"\n[FAIL] CV accuracy {cv_mean:.3f} is below threshold {args.min_cv_accuracy:.2f}. "
            "Add more examples to the dataset before deploying.",
            file=sys.stderr,
        )
        sys.exit(1)

    print(f"\n CV accuracy {cv_mean:.3f} ≥ {args.min_cv_accuracy:.2f} — proceeding to save.")

    # ── Fit final model on full dataset ────────────────────────────────────
    clf.fit(vectors, labels)

    # ── Save ───────────────────────────────────────────────────────────────
    print(f"\n[4/4] Saving model to {args.output}")
    out_path = Path(args.output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(
        {
            "clf": clf,
            "classes": clf.classes_.tolist(),
            # BUG FIX: record the number of examples actually trained on
            # (post-embedding), not the raw pre-embedding dataset size.
            "n_train": len(labels),
            "cv_mean": round(cv_mean, 4),
            "cv_std": round(cv_std, 4),
        },
        out_path,
    )
    print(f" Model saved → {out_path}")
    print(f"\nDone. Classes: {clf.classes_.tolist()}")
|
|
|
|
|
# Script entry point — run the full load → embed → train → save pipeline.
if __name__ == "__main__":
    main()
|