Enhance generate_mbap_v2.py with new reward mechanism and GoldPool integration

- Added GoldPool class to manage a top-K pool of high-reward examples.
- Implemented compute_reward function to calculate composite rewards based on execution coverage, novelty, and test quality.
- Introduced call_api_reward function for API calls in the new reward mode.
- Updated main function to support new reward mode with adjustable weights for ECS, novelty, and test quality.
- Enhanced dataset saving functionality to include reward statistics.
- Refactored existing code for improved readability and consistency.
This commit is contained in:
acano 2026-03-27 14:04:21 +01:00
parent c6b57849cd
commit f747c140c8
1 changed file with 654 additions and 133 deletions

View File

@ -11,6 +11,8 @@ import argparse
import json import json
import math import math
import os import os
import io
import random
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
@ -19,8 +21,11 @@ from pathlib import Path
import anthropic import anthropic
import requests import requests
from construct_prior import AVAP_NODE_NAMES, ConstructPrior
from construct_prior import ConstructPrior, AVAP_NODE_NAMES from src.config import settings
from dotenv import load_dotenv
load_dotenv()
AVAP_NODE_TYPES = { AVAP_NODE_TYPES = {
"addParam": ["addParam("], "addParam": ["addParam("],
@ -66,14 +71,13 @@ AVAP_NODE_TYPES = {
NODE_TYPE_NAMES = AVAP_NODE_NAMES NODE_TYPE_NAMES = AVAP_NODE_NAMES
_PRIOR_EPSILON = 0.05 _PRIOR_EPSILON = 0.05
class CellValidator:
class CellValidator:
def __init__(self, parser_url: str, parser_timeout: int = 5): def __init__(self, parser_url: str, parser_timeout: int = 5):
self.parser_url = parser_url.rstrip("/") self.parser_url = parser_url.rstrip("/")
self.parser_timeout = parser_timeout self.parser_timeout = parser_timeout
self._parser_available = True self._parser_available = True
def parse(self, code: str) -> tuple[bool, dict, str]: def parse(self, code: str) -> tuple[bool, dict, str]:
if not self._parser_available: if not self._parser_available:
@ -81,7 +85,9 @@ class CellValidator:
try: try:
resp = requests.post( resp = requests.post(
f"{self.parser_url}/parse", f"{self.parser_url}/parse",
#f"{settings.parser_url}/api/v1/upload",
json={"code": code}, json={"code": code},
#files={"file": ("task.json", io.BytesIO(json.dumps([code]).encode("utf-8")), "application/json")},
timeout=self.parser_timeout, timeout=self.parser_timeout,
) )
data = resp.json() data = resp.json()
@ -93,6 +99,7 @@ class CellValidator:
return None, {}, "parser_unavailable" return None, {}, "parser_unavailable"
except Exception as e: except Exception as e:
return False, {}, str(e) return False, {}, str(e)
def detect_constructs(self, code: str, ast: dict) -> set: def detect_constructs(self, code: str, ast: dict) -> set:
if ast: if ast:
return self._from_ast(ast) return self._from_ast(ast)
@ -147,7 +154,8 @@ class CellValidator:
bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1) bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1)
tq = sum( tq = sum(
1 for t in test_list 1
for t in test_list
if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10 if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
) / max(len(test_list), 1) ) / max(len(test_list), 1)
@ -169,8 +177,6 @@ class CellValidator:
class CoverageMap: class CoverageMap:
def __init__(self, cell_size: int = 3): def __init__(self, cell_size: int = 3):
self.cell_size = cell_size self.cell_size = cell_size
@ -215,10 +221,7 @@ class CoverageMap:
return [c for c in self._all_cells if c not in self._map] return [c for c in self._all_cells if c not in self._map]
def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]: def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]:
return [ return [c for c, (_, q, _) in self._map.items() if q < threshold]
c for c, (_, q, _) in self._map.items()
if q < threshold
]
def get_example(self, cell: frozenset) -> dict | None: def get_example(self, cell: frozenset) -> dict | None:
entry = self._map.get(cell) entry = self._map.get(cell)
@ -260,9 +263,8 @@ class CoverageMap:
f"Entropy: {entropy:.2f} bits" f"Entropy: {entropy:.2f} bits"
) )
class CellSelector: class CellSelector:
def __init__( def __init__(
self, self,
coverage_map: CoverageMap, coverage_map: CoverageMap,
@ -274,6 +276,7 @@ class CellSelector:
self.ucb_c = ucb_c self.ucb_c = ucb_c
self._total_calls = 0 self._total_calls = 0
import random import random
self._rng = random.Random(42) self._rng = random.Random(42)
def select(self) -> frozenset: def select(self) -> frozenset:
@ -304,8 +307,8 @@ class CellSelector:
return best_cell return best_cell
class CellSelectorPrior(CellSelector):
class CellSelectorPrior(CellSelector):
def __init__( def __init__(
self, self,
coverage_map: CoverageMap, coverage_map: CoverageMap,
@ -326,8 +329,7 @@ class CellSelectorPrior(CellSelector):
if empty: if empty:
high_prior_empty = [ high_prior_empty = [
c for c in empty c for c in empty if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
] ]
if high_prior_empty: if high_prior_empty:
return self._weighted_sample(high_prior_empty) return self._weighted_sample(high_prior_empty)
@ -371,6 +373,107 @@ class CellSelectorPrior(CellSelector):
return best_cell return best_cell
class GoldPool:
    """Top-K pool of high-reward examples for Candidate A (CW-Reward).

    Keeps at most ``max_size`` (example, reward) entries, sorted by
    descending reward, and serves diverse few-shot examples from the
    top half of the pool.
    """

    def __init__(self, max_size: int = 50, seed: int = 42):
        # Entries are (example, reward) tuples, highest reward first.
        self._pool: list[tuple[dict, float]] = []
        self.max_size = max_size
        self._rng = random.Random(seed)

    @property
    def size(self) -> int:
        return len(self._pool)

    def add(self, example: dict, reward: float) -> bool:
        """Add example to pool. Returns True if it entered the pool."""
        has_room = len(self._pool) < self.max_size
        if not has_room and reward <= self._pool[-1][1]:
            # Full and not better than the current worst entry.
            return False
        if has_room:
            self._pool.append((example, reward))
        else:
            # Evict the lowest-reward entry.
            self._pool[-1] = (example, reward)
        self._pool.sort(key=lambda entry: entry[1], reverse=True)
        return True

    def few_shot_examples(self, n: int = 3) -> list[dict]:
        """Return n diverse examples from the pool for few-shot prompting."""
        examples = [entry[0] for entry in self._pool]
        if len(examples) <= n:
            # Covers the empty-pool case as well.
            return examples
        # Sample from the top half (at least n candidates) for diversity.
        cutoff = max(len(examples) // 2, n)
        return self._rng.sample(examples[:cutoff], n)

    def construct_sets(self) -> list[set[str]]:
        """Return the construct sets of all pool examples."""
        return [set(entry[0].get("_detected", [])) for entry in self._pool]

    def pool_summary(self) -> str:
        if not self._pool:
            return "GoldPool: empty"
        rewards = [entry[1] for entry in self._pool]
        mean_reward = sum(rewards) / len(rewards)
        return (
            f"GoldPool: {len(self._pool)}/{self.max_size} | "
            f"reward: min={min(rewards):.3f} max={max(rewards):.3f} "
            f"mean={mean_reward:.3f}"
        )
def compute_ecs(detected: set[str]) -> float:
    """Execution Coverage Score — fraction of all AVAP node types covered."""
    # Guard against an empty node-type registry.
    total_types = max(len(NODE_TYPE_NAMES), 1)
    return len(detected) / total_types
def jaccard_similarity(a: set, b: set) -> float:
    """Jaccard similarity between two sets.

    Two empty sets are treated as identical (similarity 1.0).
    """
    union_size = len(a | b)
    if union_size == 0:
        return 1.0
    return len(a & b) / union_size
def jaccard_novelty(detected: set[str], pool: GoldPool) -> float:
    """Novelty = 1 - max Jaccard similarity with any pool example."""
    similarities = [
        jaccard_similarity(detected, construct_set)
        for construct_set in pool.construct_sets()
    ]
    # An empty pool means everything is maximally novel.
    if not similarities:
        return 1.0
    return 1.0 - max(similarities)
def compute_reward(
    detected: set[str],
    test_list: list,
    pool: GoldPool,
    w_ecs: float,
    w_novelty: float,
    w_tests: float,
) -> tuple[float, dict]:
    """
    Candidate A composite reward:
        reward(e) = w_ecs * ECS(e) + w_novelty * Jaccard_novelty(e, Pool) + w_tests * test_quality(e)

    Returns (reward, components) where components holds the rounded
    sub-scores plus the sorted list of detected constructs.
    """
    ecs = compute_ecs(detected)
    novelty = jaccard_novelty(detected, pool)
    # Test quality: share of tests that are non-trivial re.match() assertions.
    good_tests = sum(
        isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
        for t in test_list
    )
    tq = good_tests / max(len(test_list), 1)
    reward = w_ecs * ecs + w_novelty * novelty + w_tests * tq
    components = {
        "ecs": round(ecs, 3),
        "novelty": round(novelty, 3),
        "test_quality": round(tq, 3),
        "reward": round(reward, 3),
        "detected": sorted(detected),
    }
    return reward, components
SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP. SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP.
Se te proporciona el Language Reference Manual (LRM) completo de AVAP. Se te proporciona el Language Reference Manual (LRM) completo de AVAP.
Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar
@ -420,7 +523,7 @@ El siguiente ejemplo YA existe para esta combinación con calidad mejorable.
Genera algo DISTINTO y MÁS COMPLEJO que lo supere: Genera algo DISTINTO y MÁS COMPLEJO que lo supere:
``` ```
{existing_example.get('code', '')} {existing_example.get("code", "")}
``` ```
""" """
@ -455,6 +558,80 @@ Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
""" """
def _reward_coverage_text(dataset: list) -> tuple[str, list[str]]:
    """Build coverage summary and list of underrepresented constructs.

    Returns (summary_text, underrepresented) where underrepresented is
    the uncovered constructs followed by the 10 least-frequent ones.
    """
    # Frequency of each detected construct across the dataset so far.
    freq: dict[str, int] = {}
    for example in dataset:
        for node_type in example.get("_detected", []):
            freq[node_type] = freq.get(node_type, 0) + 1
    if not freq:
        return (
            "No examples generated yet. Cover as many AVAP constructs as possible.",
            list(NODE_TYPE_NAMES),
        )
    covered = set(freq)
    uncovered = sorted(set(NODE_TYPE_NAMES) - covered)
    least_covered = sorted(freq, key=freq.get)[:10]
    underrepresented = uncovered + least_covered
    summary_lines = [
        f"Total examples: {len(dataset)}",
        f"Covered constructs: {len(covered)}/{len(NODE_TYPE_NAMES)}",
    ]
    if uncovered:
        summary_lines.append(f"UNCOVERED (prioritize): {', '.join(uncovered)}")
    if least_covered:
        summary_lines.append(f"Least common: {', '.join(least_covered)}")
    return "\n".join(summary_lines), underrepresented
def build_reward_prompt(
    lrm: str,
    few_shots: list[dict],
    coverage_text: str,
    underrepresented: list[str],
) -> str:
    """Build the user prompt for Candidate A reward mode.

    Assembles the LRM, the current coverage summary, an optional few-shot
    section from the GoldPool, and an optional steering hint toward
    underrepresented constructs.
    """
    # Few-shot section: only rendered when the pool supplied examples.
    if few_shots:
        parts = ["\n# EJEMPLOS DE ALTA CALIDAD (referencia)\n\n"]
        for idx, shot in enumerate(few_shots, 1):
            parts.append(f"## Ejemplo {idx}\n")
            parts.append(f"Enunciado: {shot.get('text', '')}\n")
            parts.append(f"```\n{shot.get('code', '')}\n```\n\n")
        few_shot_block = "".join(parts)
    else:
        few_shot_block = ""
    # Steering hint: at most 8 underrepresented constructs, backtick-quoted.
    steer_text = ""
    if underrepresented:
        constructs_str = ", ".join(f"`{c}`" for c in underrepresented[:8])
        steer_text = (
            f"\nPRIORIDAD: Constructs subrepresentados — intenta incluir "
            f"algunos de ellos: {constructs_str}\n"
        )
    return f"""# LRM AVAP — Language Reference Manual
{lrm}
---
# ESTADO DE COBERTURA DEL DATASET
{coverage_text}
{few_shot_block}
---
# TAREA
Genera UN ejemplo AVAP original y complejo como problema de benchmark.
Requisitos:
- Escenario realista de microservicio HTTP en AVAP
- Usa la mayor cantidad de constructs AVAP distintos posible (aumenta la puntuación)
- El ejemplo debe ser DIFERENTE a los ejemplos de referencia mostrados arriba
- Código complejo y rico no ejemplos triviales de 3 líneas
- 2-3 aserciones re.match() en test_list
{steer_text}
Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
"""
def call_api( def call_api(
client: anthropic.Anthropic, client: anthropic.Anthropic,
lrm: str, lrm: str,
@ -471,16 +648,22 @@ def call_api(
model="claude-sonnet-4-20250514", model="claude-sonnet-4-20250514",
max_tokens=4000, max_tokens=4000,
system=SYSTEM_PROMPT, system=SYSTEM_PROMPT,
messages=[{ messages=[
{
"role": "user", "role": "user",
"content": build_cell_prompt(lrm, cell, existing_example, map_summary), "content": build_cell_prompt(
}], lrm, cell, existing_example, map_summary
),
}
],
) )
raw = message.content[0].text.strip() raw = message.content[0].text.strip()
if raw.startswith("```"): if raw.startswith("```"):
lines = raw.splitlines() lines = raw.splitlines()
raw = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) raw = "\n".join(
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
)
problem = json.loads(raw) problem = json.loads(raw)
if not isinstance(problem, dict): if not isinstance(problem, dict):
@ -509,6 +692,65 @@ def call_api(
return None return None
def call_api_reward(
    client: anthropic.Anthropic,
    lrm: str,
    few_shots: list[dict],
    coverage_text: str,
    underrepresented: list[str],
    task_id: int,
    retries: int = 3,
) -> dict | None:
    """API call for Candidate A reward mode.

    Asks the model for one MBPP-style benchmark problem using the
    reward-mode prompt, strips an optional markdown code fence, parses
    the JSON and checks the required fields. Returns the problem dict
    (with ``task_id`` injected) or None after ``retries`` failed attempts.
    """
    for attempt in range(1, retries + 1):
        try:
            message = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4000,
                system=SYSTEM_PROMPT,
                messages=[
                    {
                        "role": "user",
                        "content": build_reward_prompt(
                            lrm,
                            few_shots,
                            coverage_text,
                            underrepresented,
                        ),
                    }
                ],
            )
            raw = message.content[0].text.strip()
            # Strip a surrounding ``` fence if the model wrapped its JSON.
            if raw.startswith("```"):
                lines = raw.splitlines()
                raw = "\n".join(
                    lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
                )
            problem = json.loads(raw)
            if not isinstance(problem, dict):
                raise ValueError("Response is not a JSON object")
            # Minimal schema check for MBPP-style problems.
            for field in ("text", "code", "test_list"):
                if field not in problem:
                    raise ValueError(f"Missing field '{field}'")
            if "test_inputs" not in problem:
                problem["test_inputs"] = {}
            problem["task_id"] = task_id
            return problem
        except (json.JSONDecodeError, ValueError) as e:
            print(f"\n Attempt {attempt}/{retries} — parse error: {e}")
            if attempt < retries:
                # Exponential backoff on malformed responses.
                time.sleep(2**attempt)
        except anthropic.RateLimitError:
            # Linear backoff on rate limits; this still consumes an attempt.
            wait = 30 * attempt
            print(f"\n Rate limit — waiting {wait}s...")
            time.sleep(wait)
        except anthropic.APIError as e:
            print(f"\n API error at attempt {attempt}: {e}")
            if attempt < retries:
                time.sleep(5)
    return None
def run_map_elites(args, client, lrm, output_path): def run_map_elites(args, client, lrm, output_path):
validator = CellValidator(parser_url=args.parser) validator = CellValidator(parser_url=args.parser)
@ -520,14 +762,17 @@ def run_map_elites(args, client, lrm, output_path):
valid_count = 0 valid_count = 0
cell_updates = 0 cell_updates = 0
print(f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples") print(
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}") f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples"
)
print(
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
)
print("" * 65) print("" * 65)
max_calls = args.problems * 4 max_calls = args.problems * 4
while len(dataset) < args.problems and call_count < max_calls: while len(dataset) < args.problems and call_count < max_calls:
cell = selector.select() cell = selector.select()
existing = cmap.get_example(cell) existing = cmap.get_example(cell)
call_count += 1 call_count += 1
@ -536,11 +781,15 @@ def run_map_elites(args, client, lrm, output_path):
f" [{call_count:04d}] Cell {sorted(cell)} " f" [{call_count:04d}] Cell {sorted(cell)} "
f"| filled={cmap.filled_cells}/{cmap.total_cells} " f"| filled={cmap.filled_cells}/{cmap.total_cells} "
f"| dataset={len(dataset)} ... ", f"| dataset={len(dataset)} ... ",
end="", flush=True, end="",
flush=True,
) )
problem = call_api( problem = call_api(
client, lrm, cell, task_id, client,
lrm,
cell,
task_id,
existing_example=existing, existing_example=existing,
map_summary=cmap.fill_summary(), map_summary=cmap.fill_summary(),
) )
@ -557,7 +806,7 @@ def run_map_elites(args, client, lrm, output_path):
if is_valid is None: if is_valid is None:
is_valid, ast = True, {} is_valid, ast = True, {}
if call_count == 1: if call_count == 1:
print(f"\n Parser unavailable — using keyword fallback", flush=True) print("\n Parser unavailable — using keyword fallback", flush=True)
if is_valid is False: if is_valid is False:
print(f"INVALID ({error_msg[:40]})") print(f"INVALID ({error_msg[:40]})")
@ -568,8 +817,13 @@ def run_map_elites(args, client, lrm, output_path):
# Compute cell quality # Compute cell quality
quality, components = validator.cell_quality( quality, components = validator.cell_quality(
code, ast, test_list, cell, code,
alpha=args.alpha, beta=args.beta, gamma=args.gamma, ast,
test_list,
cell,
alpha=args.alpha,
beta=args.beta,
gamma=args.gamma,
) )
problem["_cell"] = sorted(cell) problem["_cell"] = sorted(cell)
problem["_quality"] = components problem["_quality"] = components
@ -596,18 +850,21 @@ def run_map_elites(args, client, lrm, output_path):
_save(dataset, output_path, cmap) _save(dataset, output_path, cmap)
freq = cmap.node_type_frequency() freq = cmap.node_type_frequency()
entropy = cmap.distribution_entropy() entropy = cmap.distribution_entropy()
print(f"\n ── Checkpoint ──────────────────────────────────") print("\n ── Checkpoint ──────────────────────────────────")
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}") print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
print(f" {cmap.fill_summary()}") print(f" {cmap.fill_summary()}")
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}") print(
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
)
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}") print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
print(f" ────────────────────────────────────────────────\n") print(" ────────────────────────────────────────────────\n")
time.sleep(0.5) time.sleep(0.5)
_save(dataset, output_path, cmap) _save(dataset, output_path, cmap)
return dataset, cmap, valid_count, call_count return dataset, cmap, valid_count, call_count
def run_map_elites_prior(args, client, lrm, output_path): def run_map_elites_prior(args, client, lrm, output_path):
print("\n Loading ConstructPrior...", flush=True) print("\n Loading ConstructPrior...", flush=True)
@ -620,8 +877,10 @@ def run_map_elites_prior(args, client, lrm, output_path):
else: else:
# Fallback: yaml not found — use static prior and warn # Fallback: yaml not found — use static prior and warn
print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.") print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.")
print(f" [WARN] Using static fallback prior. Generate the real prior with:") print(" [WARN] Using static fallback prior. Generate the real prior with:")
print(f" [WARN] python construct_prior.py --generate-map --github-token TOKEN") print(
" [WARN] python construct_prior.py --generate-map --github-token TOKEN"
)
prior = ConstructPrior.from_static_fallback(epsilon=epsilon) prior = ConstructPrior.from_static_fallback(epsilon=epsilon)
print(f" {prior.coverage_summary()}") print(f" {prior.coverage_summary()}")
@ -629,7 +888,8 @@ def run_map_elites_prior(args, client, lrm, output_path):
validator = CellValidator(parser_url=args.parser) validator = CellValidator(parser_url=args.parser)
cmap = CoverageMap(cell_size=args.cell_size) cmap = CoverageMap(cell_size=args.cell_size)
selector = CellSelectorPrior( selector = CellSelectorPrior(
cmap, prior, cmap,
prior,
quality_threshold=args.quality_threshold, quality_threshold=args.quality_threshold,
phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70), phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70),
) )
@ -639,14 +899,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
valid_count = 0 valid_count = 0
cell_updates = 0 cell_updates = 0
print(f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples") print(
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}") f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples"
)
print(
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
)
print("" * 65) print("" * 65)
max_calls = args.problems * 4 max_calls = args.problems * 4
while len(dataset) < args.problems and call_count < max_calls: while len(dataset) < args.problems and call_count < max_calls:
cell = selector.select() cell = selector.select()
existing = cmap.get_example(cell) existing = cmap.get_example(cell)
prior_w = prior.cell_weight(cell) prior_w = prior.cell_weight(cell)
@ -657,11 +920,15 @@ def run_map_elites_prior(args, client, lrm, output_path):
f"| prior={prior_w:.3f} " f"| prior={prior_w:.3f} "
f"| filled={cmap.filled_cells}/{cmap.total_cells} " f"| filled={cmap.filled_cells}/{cmap.total_cells} "
f"| dataset={len(dataset)} ... ", f"| dataset={len(dataset)} ... ",
end="", flush=True, end="",
flush=True,
) )
problem = call_api( problem = call_api(
client, lrm, cell, task_id, client,
lrm,
cell,
task_id,
existing_example=existing, existing_example=existing,
map_summary=cmap.fill_summary(), map_summary=cmap.fill_summary(),
) )
@ -678,7 +945,7 @@ def run_map_elites_prior(args, client, lrm, output_path):
if is_valid is None: if is_valid is None:
is_valid, ast = True, {} is_valid, ast = True, {}
if call_count == 1: if call_count == 1:
print(f"\n Parser unavailable — using keyword fallback", flush=True) print("\n Parser unavailable — using keyword fallback", flush=True)
if is_valid is False: if is_valid is False:
print(f"INVALID ({error_msg[:40]})") print(f"INVALID ({error_msg[:40]})")
@ -688,8 +955,13 @@ def run_map_elites_prior(args, client, lrm, output_path):
valid_count += 1 valid_count += 1
quality, components = validator.cell_quality( quality, components = validator.cell_quality(
code, ast, test_list, cell, code,
alpha=args.alpha, beta=args.beta, gamma=args.gamma, ast,
test_list,
cell,
alpha=args.alpha,
beta=args.beta,
gamma=args.gamma,
) )
problem["_cell"] = sorted(cell) problem["_cell"] = sorted(cell)
problem["_prior_weight"] = round(prior_w, 4) problem["_prior_weight"] = round(prior_w, 4)
@ -719,13 +991,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
freq = cmap.node_type_frequency() freq = cmap.node_type_frequency()
entropy = cmap.distribution_entropy() entropy = cmap.distribution_entropy()
kl = prior.kl_divergence(freq) kl = prior.kl_divergence(freq)
print(f"\n ── Checkpoint ──────────────────────────────────") print("\n ── Checkpoint ──────────────────────────────────")
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}") print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
print(f" {cmap.fill_summary()}") print(f" {cmap.fill_summary()}")
print(f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)") print(
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}") f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)"
)
print(
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
)
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}") print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
print(f" ────────────────────────────────────────────────\n") print(" ────────────────────────────────────────────────\n")
time.sleep(0.5) time.sleep(0.5)
@ -755,6 +1031,170 @@ def _save(dataset: list, path: Path, cmap: CoverageMap, prior: ConstructPrior =
with open(stats_path, "w", encoding="utf-8") as f: with open(stats_path, "w", encoding="utf-8") as f:
json.dump(stats, f, ensure_ascii=False, indent=2) json.dump(stats, f, ensure_ascii=False, indent=2)
def run_reward(args, client, lrm, output_path):
    """Candidate A — CW-Reward with GoldPool feedback loop.

    Each iteration rebuilds the coverage summary from the dataset so far,
    draws few-shot examples from the GoldPool, asks the model for one new
    problem, validates it against the AVAP parser, scores it with the
    composite reward, and feeds high-reward examples back into the pool.

    Returns (dataset, pool, valid_count, call_count).
    """
    validator = CellValidator(parser_url=args.parser)
    pool = GoldPool(max_size=args.pool_size)
    dataset = []
    task_id = 1
    call_count = 0
    valid_count = 0
    # NOTE(review): pool_entries is counted but never reported or returned —
    # confirm whether it should appear in the final stats.
    pool_entries = 0
    # Reward weights (CLI-tunable).
    w_ecs = args.w_ecs
    w_novelty = args.w_novelty
    w_tests = args.w_tests
    print(f"\n CW-Reward mode | target: {args.problems} examples")
    print(f" Weights: ECS={w_ecs} Novelty={w_novelty} Tests={w_tests}")
    print(f" Pool size: {args.pool_size}")
    # NOTE(review): "" * 65 prints an empty line — a separator character was
    # probably lost in this source; confirm against the other run_* modes.
    print("" * 65)
    # Budget: at most 4 API calls per requested problem.
    max_calls = args.problems * 4
    while len(dataset) < args.problems and call_count < max_calls:
        coverage_text, underrepresented = _reward_coverage_text(dataset)
        few_shots = pool.few_shot_examples(n=3)
        call_count += 1
        print(
            f" [{call_count:04d}] "
            f"| {pool.pool_summary()} "
            f"| dataset={len(dataset)} ... ",
            end="",
            flush=True,
        )
        problem = call_api_reward(
            client,
            lrm,
            few_shots,
            coverage_text,
            underrepresented,
            task_id,
        )
        if problem is None:
            print("SKIP (generation failed)")
            continue
        code = problem["code"]
        test_list = problem.get("test_list", [])
        # Parser returns None validity when unreachable — fall back to
        # keyword-based construct detection in that case.
        is_valid, ast, error_msg = validator.parse(code)
        if is_valid is None:
            is_valid, ast = True, {}
            if call_count == 1:
                print(
                    "\n Parser unavailable — using keyword fallback",
                    flush=True,
                )
        if is_valid is False:
            print(f"INVALID ({error_msg[:40]})")
            continue
        valid_count += 1
        detected = validator.detect_constructs(code, ast)
        reward, components = compute_reward(
            detected,
            test_list,
            pool,
            w_ecs=w_ecs,
            w_novelty=w_novelty,
            w_tests=w_tests,
        )
        problem["_detected"] = sorted(detected)
        problem["_reward"] = components
        entered = pool.add(problem, reward)
        if entered:
            pool_entries += 1
        dataset.append(problem)
        task_id += 1
        print(
            f"OK reward={reward:.3f} "
            f"ecs={components['ecs']:.2f} "
            f"novelty={components['novelty']:.2f} "
            f"{'→ POOL' if entered else ''}"
        )
        # Checkpoint every 50 accepted examples: persist dataset + stats
        # and print a coverage/entropy report.
        if len(dataset) % 50 == 0:
            _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
            freq: dict[str, int] = defaultdict(int)
            for ex in dataset:
                for nt in ex.get("_detected", []):
                    freq[nt] += 1
            # Shannon entropy (bits) of the construct distribution.
            total_f = sum(freq.values())
            entropy = 0.0
            if total_f > 0:
                for count in freq.values():
                    p = count / total_f
                    if p > 0:
                        entropy -= p * math.log2(p)
            print("\n ── Checkpoint ──────────────────────────────────")
            print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
            print(f" {pool.pool_summary()}")
            print(
                f" Coverage: {len(freq)}/{len(NODE_TYPE_NAMES)} | Entropy: {entropy:.2f} bits"
            )
            uncov = sorted(set(NODE_TYPE_NAMES) - set(freq.keys()))
            print(f" Uncovered: {uncov[:10]}")
            print(" ────────────────────────────────────────────────\n")
        time.sleep(0.5)
    _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
    return dataset, pool, valid_count, call_count
def _save_reward(
    dataset: list,
    path: Path,
    pool: GoldPool,
    w_ecs: float,
    w_novelty: float,
    w_tests: float,
):
    """Save dataset and stats for Candidate A reward mode.

    Writes the dataset to ``path`` and a companion
    ``<stem>_reward_stats.json`` with coverage, entropy, and reward
    aggregates next to it.
    """
    with path.open("w", encoding="utf-8") as out:
        json.dump(dataset, out, ensure_ascii=False, indent=2)
    # Construct frequency across the whole dataset.
    freq: dict[str, int] = {}
    for example in dataset:
        for node_type in example.get("_detected", []):
            freq[node_type] = freq.get(node_type, 0) + 1
    # Shannon entropy (bits) of the construct distribution.
    total_count = sum(freq.values())
    entropy = 0.0
    if total_count > 0:
        for count in freq.values():
            share = count / total_count
            if share > 0:
                entropy -= share * math.log2(share)
    rewards = [example.get("_reward", {}).get("reward", 0) for example in dataset]
    stats = {
        "mode": "reward",
        "weights": {"w_ecs": w_ecs, "w_novelty": w_novelty, "w_tests": w_tests},
        "dataset_size": len(dataset),
        "pool_size": pool.size,
        "pool_summary": pool.pool_summary(),
        "distribution_entropy": round(entropy, 3),
        "node_type_frequency": dict(freq),
        "covered_constructs": len(freq),
        "total_constructs": len(NODE_TYPE_NAMES),
        "mean_reward": round(sum(rewards) / max(len(rewards), 1), 4),
    }
    stats_path = path.with_name(path.stem + "_reward_stats.json")
    with stats_path.open("w", encoding="utf-8") as out:
        json.dump(stats, out, ensure_ascii=False, indent=2)
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline" description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline"
@ -762,18 +1202,39 @@ def main():
parser.add_argument("--lrm", default="avap.md") parser.add_argument("--lrm", default="avap.md")
parser.add_argument("--output", default="output/mbpp_avap_v2.json") parser.add_argument("--output", default="output/mbpp_avap_v2.json")
parser.add_argument("--problems", type=int, default=5000) parser.add_argument("--problems", type=int, default=5000)
parser.add_argument("--parser", default="http://localhost:8080", parser.add_argument(
help="AVAP parser URL") "--parser", default="http://localhost:8080", help="AVAP parser URL"
parser.add_argument("--cell-size", type=int, default=3, )
help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)") parser.add_argument(
parser.add_argument("--quality-threshold", type=float, default=0.80, "--cell-size",
help="Min quality to consider a cell 'good' (default: 0.80)") type=int,
parser.add_argument("--alpha", type=float, default=0.30, default=3,
help="Weight for bonus constructs in cell quality (default: 0.30)") help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)",
parser.add_argument("--beta", type=float, default=0.20, )
help="Weight for test quality in cell quality (default: 0.20)") parser.add_argument(
parser.add_argument("--gamma", type=float, default=0.10, "--quality-threshold",
help="Weight for code richness in cell quality (default: 0.10)") type=float,
default=0.80,
help="Min quality to consider a cell 'good' (default: 0.80)",
)
parser.add_argument(
"--alpha",
type=float,
default=0.30,
help="Weight for bonus constructs in cell quality (default: 0.30)",
)
parser.add_argument(
"--beta",
type=float,
default=0.20,
help="Weight for test quality in cell quality (default: 0.20)",
)
parser.add_argument(
"--gamma",
type=float,
default=0.10,
help="Weight for code richness in cell quality (default: 0.10)",
)
parser.add_argument( parser.add_argument(
"--mode", "--mode",
choices=["map-elites-prior", "map-elites", "reward"], choices=["map-elites-prior", "map-elites", "reward"],
@ -809,6 +1270,30 @@ def main():
"cells become the focus. Default: 0.70" "cells become the focus. Default: 0.70"
), ),
) )
# Candidate A (reward mode) CLI options.
parser.add_argument(
    "--w-ecs",
    type=float,
    default=0.50,
    help="Candidate A: weight for ECS in reward formula (default: 0.50)",
)
parser.add_argument(
    "--w-novelty",
    type=float,
    default=0.35,
    help="Candidate A: weight for Jaccard novelty in reward formula (default: 0.35)",
)
parser.add_argument(
    "--w-tests",
    type=float,
    default=0.15,
    help="Candidate A: weight for test quality in reward formula (default: 0.15)",
)
parser.add_argument(
    "--pool-size",
    type=int,
    # Fix: was 5, contradicting both the help text below and
    # GoldPool.__init__'s default of 50.
    default=50,
    help="Candidate A: max examples in GoldPool (default: 50)",
)
parser.add_argument("--api-key", default=None) parser.add_argument("--api-key", default=None)
args = parser.parse_args() args = parser.parse_args()
@ -844,36 +1329,72 @@ def main():
print(f" Quality thresh : {args.quality_threshold}") print(f" Quality thresh : {args.quality_threshold}")
if args.mode == "map-elites-prior": if args.mode == "map-elites-prior":
yaml_exists = Path(args.prior_map).exists() yaml_exists = Path(args.prior_map).exists()
print(f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})") print(
f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})"
)
print(f" Prior epsilon : {args.prior_epsilon}") print(f" Prior epsilon : {args.prior_epsilon}")
elif args.mode == "reward":
print(
f" Reward weights : ECS={args.w_ecs} Novelty={args.w_novelty} Tests={args.w_tests}"
)
print(f" Pool size : {args.pool_size}")
print("=" * 65) print("=" * 65)
prior = None prior = None
pool = None
cmap = None
if args.mode == "map-elites-prior": if args.mode == "map-elites-prior":
result = run_map_elites_prior(args, client, lrm, output_path) result = run_map_elites_prior(args, client, lrm, output_path)
dataset, cmap, valid_count, call_count, prior = result dataset, cmap, valid_count, call_count, prior = result
elif args.mode == "map-elites": elif args.mode == "map-elites":
dataset, cmap, valid_count, call_count = run_map_elites(args, client, lrm, output_path) dataset, cmap, valid_count, call_count = run_map_elites(
args, client, lrm, output_path
)
else: else:
sys.exit("ERROR: --mode reward (Candidate A) is not yet implemented in v2. " dataset, pool, valid_count, call_count = run_reward(
"Use generate_mbap.py for the v1 reward baseline.") args, client, lrm, output_path
)
# Final report # Final report
if cmap is not None:
freq = cmap.node_type_frequency() freq = cmap.node_type_frequency()
entropy = cmap.distribution_entropy() entropy = cmap.distribution_entropy()
else:
freq = defaultdict(int)
for ex in dataset:
for nt in ex.get("_detected", []):
freq[nt] += 1
freq = dict(freq)
total_f = sum(freq.values())
entropy = 0.0
if total_f > 0:
for count in freq.values():
p = count / total_f
if p > 0:
entropy -= p * math.log2(p)
entropy = round(entropy, 3)
print("\n" + "=" * 65) print("\n" + "=" * 65)
print(" Pipeline complete") print(" Pipeline complete")
print(f" Mode : {mode_label}") print(f" Mode : {mode_label}")
print(f" Total API calls : {call_count}") print(f" Total API calls : {call_count}")
print(f" Valid examples : {valid_count} ({100*valid_count/max(call_count,1):.1f}%)") print(
f" Valid examples : {valid_count} ({100 * valid_count / max(call_count, 1):.1f}%)"
)
print(f" Dataset size : {len(dataset)}") print(f" Dataset size : {len(dataset)}")
if cmap is not None:
print(f" {cmap.fill_summary()}") print(f" {cmap.fill_summary()}")
print(f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})") if pool is not None:
print(f" {pool.pool_summary()}")
print(
f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})"
)
if prior is not None: if prior is not None:
kl = prior.kl_divergence(freq) kl = prior.kl_divergence(freq)
print(f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)") print(
f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)"
)
print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}") print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}")
print(f" Least covered : {sorted(freq, key=freq.get)[:5]}") print(f" Least covered : {sorted(freq, key=freq.get)[:5]}")
print(f" Output : {output_path}") print(f" Output : {output_path}")