Enhance generate_mbap_v2.py with new reward mechanism and GoldPool integration
- Added GoldPool class to manage a top-K pool of high-reward examples.
- Implemented compute_reward function to calculate composite rewards based on execution coverage, novelty, and test quality.
- Introduced call_api_reward function for API calls in the new reward mode.
- Updated main function to support the new reward mode with adjustable weights for ECS, novelty, and test quality.
- Enhanced dataset saving functionality to include reward statistics.
- Refactored existing code for improved readability and consistency.
This commit is contained in:
parent
c6b57849cd
commit
f747c140c8
|
|
@ -11,6 +11,8 @@ import argparse
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import io
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
|
@ -19,8 +21,11 @@ from pathlib import Path
|
|||
|
||||
import anthropic
|
||||
import requests
|
||||
from construct_prior import AVAP_NODE_NAMES, ConstructPrior
|
||||
|
||||
from construct_prior import ConstructPrior, AVAP_NODE_NAMES
|
||||
from src.config import settings
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
AVAP_NODE_TYPES = {
|
||||
"addParam": ["addParam("],
|
||||
|
|
@ -66,14 +71,13 @@ AVAP_NODE_TYPES = {
|
|||
NODE_TYPE_NAMES = AVAP_NODE_NAMES
|
||||
_PRIOR_EPSILON = 0.05
|
||||
|
||||
class CellValidator:
|
||||
|
||||
class CellValidator:
|
||||
def __init__(self, parser_url: str, parser_timeout: int = 5):
|
||||
self.parser_url = parser_url.rstrip("/")
|
||||
self.parser_timeout = parser_timeout
|
||||
self._parser_available = True
|
||||
|
||||
|
||||
def parse(self, code: str) -> tuple[bool, dict, str]:
|
||||
|
||||
if not self._parser_available:
|
||||
|
|
@ -81,7 +85,9 @@ class CellValidator:
|
|||
try:
|
||||
resp = requests.post(
|
||||
f"{self.parser_url}/parse",
|
||||
#f"{settings.parser_url}/api/v1/upload",
|
||||
json={"code": code},
|
||||
#files={"file": ("task.json", io.BytesIO(json.dumps([code]).encode("utf-8")), "application/json")},
|
||||
timeout=self.parser_timeout,
|
||||
)
|
||||
data = resp.json()
|
||||
|
|
@ -93,6 +99,7 @@ class CellValidator:
|
|||
return None, {}, "parser_unavailable"
|
||||
except Exception as e:
|
||||
return False, {}, str(e)
|
||||
|
||||
def detect_constructs(self, code: str, ast: dict) -> set:
|
||||
if ast:
|
||||
return self._from_ast(ast)
|
||||
|
|
@ -147,7 +154,8 @@ class CellValidator:
|
|||
bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1)
|
||||
|
||||
tq = sum(
|
||||
1 for t in test_list
|
||||
1
|
||||
for t in test_list
|
||||
if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
|
||||
) / max(len(test_list), 1)
|
||||
|
||||
|
|
@ -169,8 +177,6 @@ class CellValidator:
|
|||
|
||||
|
||||
class CoverageMap:
|
||||
|
||||
|
||||
def __init__(self, cell_size: int = 3):
|
||||
|
||||
self.cell_size = cell_size
|
||||
|
|
@ -215,10 +221,7 @@ class CoverageMap:
|
|||
return [c for c in self._all_cells if c not in self._map]
|
||||
|
||||
def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]:
|
||||
return [
|
||||
c for c, (_, q, _) in self._map.items()
|
||||
if q < threshold
|
||||
]
|
||||
return [c for c, (_, q, _) in self._map.items() if q < threshold]
|
||||
|
||||
def get_example(self, cell: frozenset) -> dict | None:
|
||||
entry = self._map.get(cell)
|
||||
|
|
@ -254,15 +257,14 @@ class CoverageMap:
|
|||
entropy = self.distribution_entropy()
|
||||
return (
|
||||
f"Cells: {self.filled_cells}/{self.total_cells} filled "
|
||||
f"({100*self.fill_rate:.1f}%) | "
|
||||
f"({100 * self.fill_rate:.1f}%) | "
|
||||
f"Low quality: {low} | "
|
||||
f"Empty: {empty} | "
|
||||
f"Entropy: {entropy:.2f} bits"
|
||||
)
|
||||
|
||||
|
||||
class CellSelector:
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
coverage_map: CoverageMap,
|
||||
|
|
@ -274,6 +276,7 @@ class CellSelector:
|
|||
self.ucb_c = ucb_c
|
||||
self._total_calls = 0
|
||||
import random
|
||||
|
||||
self._rng = random.Random(42)
|
||||
|
||||
def select(self) -> frozenset:
|
||||
|
|
@ -304,8 +307,8 @@ class CellSelector:
|
|||
|
||||
return best_cell
|
||||
|
||||
class CellSelectorPrior(CellSelector):
|
||||
|
||||
class CellSelectorPrior(CellSelector):
|
||||
def __init__(
|
||||
self,
|
||||
coverage_map: CoverageMap,
|
||||
|
|
@ -326,8 +329,7 @@ class CellSelectorPrior(CellSelector):
|
|||
|
||||
if empty:
|
||||
high_prior_empty = [
|
||||
c for c in empty
|
||||
if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
|
||||
c for c in empty if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
|
||||
]
|
||||
if high_prior_empty:
|
||||
return self._weighted_sample(high_prior_empty)
|
||||
|
|
@ -371,6 +373,107 @@ class CellSelectorPrior(CellSelector):
|
|||
|
||||
return best_cell
|
||||
|
||||
|
||||
class GoldPool:
    """Top-K pool of high-reward examples for Candidate A (CW-Reward).

    Holds at most ``max_size`` (example, reward) pairs, always sorted by
    descending reward, so the worst entry sits at the end of the list.
    """

    def __init__(self, max_size: int = 50, seed: int = 42):
        # Entries are (example, reward) tuples, kept sorted best-first.
        self._pool: list[tuple[dict, float]] = []
        self.max_size = max_size
        # Dedicated RNG so few-shot sampling is reproducible across runs.
        self._rng = random.Random(seed)

    @property
    def size(self) -> int:
        """Current number of examples held in the pool."""
        return len(self._pool)

    def add(self, example: dict, reward: float) -> bool:
        """Add example to pool. Returns True if it entered the pool.

        While the pool is not full every example is accepted; afterwards a
        newcomer only displaces the current worst entry when its reward is
        strictly higher.
        """
        if len(self._pool) < self.max_size:
            self._pool.append((example, reward))
        elif reward > self._pool[-1][1]:
            self._pool[-1] = (example, reward)
        else:
            return False
        # Re-sort best-first (stable, so ties keep insertion order).
        self._pool.sort(key=lambda entry: entry[1], reverse=True)
        return True

    def few_shot_examples(self, n: int = 3) -> list[dict]:
        """Return n diverse examples from the pool for few-shot prompting."""
        pool_size = len(self._pool)
        if pool_size == 0:
            return []
        if pool_size <= n:
            return [example for example, _ in self._pool]
        # Sample from the better half (but at least n entries) for diversity.
        candidates = self._pool[: max(pool_size // 2, n)]
        return [example for example, _ in self._rng.sample(candidates, n)]

    def construct_sets(self) -> list[set[str]]:
        """Return the construct sets of all pool examples."""
        return [set(example.get("_detected", [])) for example, _ in self._pool]

    def pool_summary(self) -> str:
        """One-line human-readable summary of pool fill and reward spread."""
        if not self._pool:
            return "GoldPool: empty"
        rewards = [reward for _, reward in self._pool]
        mean_reward = sum(rewards) / len(rewards)
        return (
            f"GoldPool: {len(self._pool)}/{self.max_size} | "
            f"reward: min={min(rewards):.3f} max={max(rewards):.3f} "
            f"mean={mean_reward:.3f}"
        )
|
||||
|
||||
|
||||
def compute_ecs(detected: set[str]) -> float:
    """Execution Coverage Score — fraction of all AVAP node types covered."""
    # max(..., 1) guards against a division by zero on an empty registry.
    total_types = max(len(NODE_TYPE_NAMES), 1)
    return len(detected) / total_types
||||
|
||||
|
||||
def jaccard_similarity(a: set, b: set) -> float:
    """Jaccard similarity |a ∩ b| / |a ∪ b|; defined as 1.0 for two empty sets."""
    if not a and not b:
        # Empty union: treat two empty sets as identical.
        return 1.0
    return len(a & b) / len(a | b)
||||
|
||||
|
||||
def jaccard_novelty(detected: set[str], pool: GoldPool) -> float:
    """Novelty = 1 - max Jaccard similarity with any pool example."""
    similarities = [
        jaccard_similarity(detected, constructs)
        for constructs in pool.construct_sets()
    ]
    if not similarities:
        # Empty pool: everything is maximally novel.
        return 1.0
    return 1.0 - max(similarities)
|
||||
|
||||
|
||||
def compute_reward(
    detected: set[str],
    test_list: list,
    pool: GoldPool,
    w_ecs: float,
    w_novelty: float,
    w_tests: float,
) -> tuple[float, dict]:
    """
    Candidate A composite reward:
      reward(e) = w_ecs * ECS(e) + w_novelty * Jaccard_novelty(e, Pool) + w_tests * test_quality(e)

    Returns the raw reward plus a components dict (rounded for reporting).
    """
    ecs = compute_ecs(detected)
    novelty = jaccard_novelty(detected, pool)

    # Test quality: share of tests that look like real re.match assertions
    # (non-trivial strings containing a re.match call).
    good_tests = 0
    for test in test_list:
        if isinstance(test, str) and "re.match(" in test and len(test.strip()) > 10:
            good_tests += 1
    tq = good_tests / max(len(test_list), 1)

    reward = w_ecs * ecs + w_novelty * novelty + w_tests * tq
    components = {
        "ecs": round(ecs, 3),
        "novelty": round(novelty, 3),
        "test_quality": round(tq, 3),
        "reward": round(reward, 3),
        "detected": sorted(detected),
    }
    return reward, components
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP.
|
||||
Se te proporciona el Language Reference Manual (LRM) completo de AVAP.
|
||||
Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar
|
||||
|
|
@ -420,7 +523,7 @@ El siguiente ejemplo YA existe para esta combinación con calidad mejorable.
|
|||
Genera algo DISTINTO y MÁS COMPLEJO que lo supere:
|
||||
|
||||
```
|
||||
{existing_example.get('code', '')}
|
||||
{existing_example.get("code", "")}
|
||||
```
|
||||
"""
|
||||
|
||||
|
|
@ -455,6 +558,80 @@ Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
|
|||
"""
|
||||
|
||||
|
||||
def _reward_coverage_text(dataset: list) -> tuple[str, list[str]]:
    """Build coverage summary and list of underrepresented constructs."""
    # Count how often each construct appears across the dataset.
    freq: dict[str, int] = {}
    for example in dataset:
        for node_type in example.get("_detected", []):
            freq[node_type] = freq.get(node_type, 0) + 1

    if not freq:
        # Nothing generated yet — steer toward every known construct.
        return (
            "No examples generated yet. Cover as many AVAP constructs as possible.",
            list(NODE_TYPE_NAMES),
        )

    covered = set(freq)
    uncovered = sorted(set(NODE_TYPE_NAMES) - covered)
    # Ten constructs with the lowest counts (stable sort keeps tie order).
    least_covered = [name for name, _ in sorted(freq.items(), key=lambda kv: kv[1])[:10]]

    summary_lines = [
        f"Total examples: {len(dataset)}",
        f"Covered constructs: {len(covered)}/{len(NODE_TYPE_NAMES)}",
    ]
    if uncovered:
        summary_lines.append(f"UNCOVERED (prioritize): {', '.join(uncovered)}")
    if least_covered:
        summary_lines.append(f"Least common: {', '.join(least_covered)}")
    return "\n".join(summary_lines), uncovered + least_covered
|
||||
|
||||
|
||||
def build_reward_prompt(
    lrm: str,
    few_shots: list[dict],
    coverage_text: str,
    underrepresented: list[str],
) -> str:
    """Build the user prompt for Candidate A reward mode."""
    # Few-shot section: one header plus one entry per pool example.
    shot_parts: list[str] = []
    if few_shots:
        shot_parts.append("\n# EJEMPLOS DE ALTA CALIDAD (referencia)\n\n")
        for idx, shot in enumerate(few_shots, 1):
            shot_parts.append(f"## Ejemplo {idx}\n")
            shot_parts.append(f"Enunciado: {shot.get('text', '')}\n")
            shot_parts.append(f"```\n{shot.get('code', '')}\n```\n\n")
    few_shot_block = "".join(shot_parts)

    # Steering hint listing up to eight underrepresented constructs.
    steer_text = ""
    if underrepresented:
        constructs_str = ", ".join(f"`{c}`" for c in underrepresented[:8])
        steer_text = (
            f"\nPRIORIDAD: Constructs subrepresentados — intenta incluir "
            f"algunos de ellos: {constructs_str}\n"
        )

    return f"""# LRM AVAP — Language Reference Manual

{lrm}

---

# ESTADO DE COBERTURA DEL DATASET

{coverage_text}
{few_shot_block}
---

# TAREA

Genera UN ejemplo AVAP original y complejo como problema de benchmark.

Requisitos:
- Escenario realista de microservicio HTTP en AVAP
- Usa la mayor cantidad de constructs AVAP distintos posible (aumenta la puntuación)
- El ejemplo debe ser DIFERENTE a los ejemplos de referencia mostrados arriba
- Código complejo y rico — no ejemplos triviales de 3 líneas
- 2-3 aserciones re.match() en test_list
{steer_text}
Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
"""
|
||||
|
||||
|
||||
def call_api(
|
||||
client: anthropic.Anthropic,
|
||||
lrm: str,
|
||||
|
|
@ -471,16 +648,22 @@ def call_api(
|
|||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=4000,
|
||||
system=SYSTEM_PROMPT,
|
||||
messages=[{
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": build_cell_prompt(lrm, cell, existing_example, map_summary),
|
||||
}],
|
||||
"content": build_cell_prompt(
|
||||
lrm, cell, existing_example, map_summary
|
||||
),
|
||||
}
|
||||
],
|
||||
)
|
||||
raw = message.content[0].text.strip()
|
||||
|
||||
if raw.startswith("```"):
|
||||
lines = raw.splitlines()
|
||||
raw = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
|
||||
raw = "\n".join(
|
||||
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
|
||||
)
|
||||
|
||||
problem = json.loads(raw)
|
||||
if not isinstance(problem, dict):
|
||||
|
|
@ -496,7 +679,7 @@ def call_api(
|
|||
except (json.JSONDecodeError, ValueError) as e:
|
||||
print(f"\n Attempt {attempt}/{retries} — parse error: {e}")
|
||||
if attempt < retries:
|
||||
time.sleep(2 ** attempt)
|
||||
time.sleep(2**attempt)
|
||||
except anthropic.RateLimitError:
|
||||
wait = 30 * attempt
|
||||
print(f"\n Rate limit — waiting {wait}s...")
|
||||
|
|
@ -509,6 +692,65 @@ def call_api(
|
|||
return None
|
||||
|
||||
|
||||
def call_api_reward(
    client: anthropic.Anthropic,
    lrm: str,
    few_shots: list[dict],
    coverage_text: str,
    underrepresented: list[str],
    task_id: int,
    retries: int = 3,
) -> dict | None:
    """API call for Candidate A reward mode.

    Sends the reward-mode prompt (LRM + few-shot pool examples + coverage
    steering) to the model, parses the JSON response into a problem dict,
    stamps it with *task_id*, and retries transient failures.
    Returns None after *retries* failed attempts.
    """
    for attempt in range(1, retries + 1):
        try:
            message = client.messages.create(
                model="claude-sonnet-4-20250514",
                max_tokens=4000,
                system=SYSTEM_PROMPT,
                messages=[
                    {
                        "role": "user",
                        "content": build_reward_prompt(
                            lrm,
                            few_shots,
                            coverage_text,
                            underrepresented,
                        ),
                    }
                ],
            )
            raw = message.content[0].text.strip()
            # Strip a markdown code fence if the model wrapped its JSON in one.
            if raw.startswith("```"):
                lines = raw.splitlines()
                raw = "\n".join(
                    lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
                )
            problem = json.loads(raw)
            if not isinstance(problem, dict):
                raise ValueError("Response is not a JSON object")
            # Required fields for an MBPP-style benchmark problem.
            for field in ("text", "code", "test_list"):
                if field not in problem:
                    raise ValueError(f"Missing field '{field}'")
            # test_inputs is optional; default to empty mapping.
            if "test_inputs" not in problem:
                problem["test_inputs"] = {}
            problem["task_id"] = task_id
            return problem
        except (json.JSONDecodeError, ValueError) as e:
            # Malformed response: exponential backoff, then retry.
            print(f"\n Attempt {attempt}/{retries} — parse error: {e}")
            if attempt < retries:
                time.sleep(2**attempt)
        except anthropic.RateLimitError:
            # Rate limit: wait longer on each successive attempt.
            wait = 30 * attempt
            print(f"\n Rate limit — waiting {wait}s...")
            time.sleep(wait)
        except anthropic.APIError as e:
            # Other API errors: short fixed pause before retrying.
            print(f"\n API error at attempt {attempt}: {e}")
            if attempt < retries:
                time.sleep(5)
    return None
|
||||
|
||||
|
||||
def run_map_elites(args, client, lrm, output_path):
|
||||
|
||||
validator = CellValidator(parser_url=args.parser)
|
||||
|
|
@ -520,14 +762,17 @@ def run_map_elites(args, client, lrm, output_path):
|
|||
valid_count = 0
|
||||
cell_updates = 0
|
||||
|
||||
print(f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples")
|
||||
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}")
|
||||
print(
|
||||
f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples"
|
||||
)
|
||||
print(
|
||||
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
|
||||
)
|
||||
print("─" * 65)
|
||||
|
||||
max_calls = args.problems * 4
|
||||
|
||||
while len(dataset) < args.problems and call_count < max_calls:
|
||||
|
||||
cell = selector.select()
|
||||
existing = cmap.get_example(cell)
|
||||
call_count += 1
|
||||
|
|
@ -536,11 +781,15 @@ def run_map_elites(args, client, lrm, output_path):
|
|||
f" [{call_count:04d}] Cell {sorted(cell)} "
|
||||
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
||||
f"| dataset={len(dataset)} ... ",
|
||||
end="", flush=True,
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
problem = call_api(
|
||||
client, lrm, cell, task_id,
|
||||
client,
|
||||
lrm,
|
||||
cell,
|
||||
task_id,
|
||||
existing_example=existing,
|
||||
map_summary=cmap.fill_summary(),
|
||||
)
|
||||
|
|
@ -557,7 +806,7 @@ def run_map_elites(args, client, lrm, output_path):
|
|||
if is_valid is None:
|
||||
is_valid, ast = True, {}
|
||||
if call_count == 1:
|
||||
print(f"\n Parser unavailable — using keyword fallback", flush=True)
|
||||
print("\n Parser unavailable — using keyword fallback", flush=True)
|
||||
|
||||
if is_valid is False:
|
||||
print(f"INVALID ({error_msg[:40]})")
|
||||
|
|
@ -568,8 +817,13 @@ def run_map_elites(args, client, lrm, output_path):
|
|||
|
||||
# Compute cell quality
|
||||
quality, components = validator.cell_quality(
|
||||
code, ast, test_list, cell,
|
||||
alpha=args.alpha, beta=args.beta, gamma=args.gamma,
|
||||
code,
|
||||
ast,
|
||||
test_list,
|
||||
cell,
|
||||
alpha=args.alpha,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
)
|
||||
problem["_cell"] = sorted(cell)
|
||||
problem["_quality"] = components
|
||||
|
|
@ -596,22 +850,25 @@ def run_map_elites(args, client, lrm, output_path):
|
|||
_save(dataset, output_path, cmap)
|
||||
freq = cmap.node_type_frequency()
|
||||
entropy = cmap.distribution_entropy()
|
||||
print(f"\n ── Checkpoint ──────────────────────────────────")
|
||||
print("\n ── Checkpoint ──────────────────────────────────")
|
||||
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
||||
print(f" {cmap.fill_summary()}")
|
||||
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
||||
print(
|
||||
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
|
||||
)
|
||||
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
||||
print(f" ────────────────────────────────────────────────\n")
|
||||
print(" ────────────────────────────────────────────────\n")
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
_save(dataset, output_path, cmap)
|
||||
return dataset, cmap, valid_count, call_count
|
||||
|
||||
|
||||
def run_map_elites_prior(args, client, lrm, output_path):
|
||||
|
||||
print("\n Loading ConstructPrior...", flush=True)
|
||||
prior_map = getattr(args, "prior_map","construct_map.yaml")
|
||||
prior_map = getattr(args, "prior_map", "construct_map.yaml")
|
||||
epsilon = getattr(args, "prior_epsilon", _PRIOR_EPSILON)
|
||||
yaml_path = Path(prior_map)
|
||||
|
||||
|
|
@ -620,8 +877,10 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
else:
|
||||
# Fallback: yaml not found — use static prior and warn
|
||||
print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.")
|
||||
print(f" [WARN] Using static fallback prior. Generate the real prior with:")
|
||||
print(f" [WARN] python construct_prior.py --generate-map --github-token TOKEN")
|
||||
print(" [WARN] Using static fallback prior. Generate the real prior with:")
|
||||
print(
|
||||
" [WARN] python construct_prior.py --generate-map --github-token TOKEN"
|
||||
)
|
||||
prior = ConstructPrior.from_static_fallback(epsilon=epsilon)
|
||||
|
||||
print(f" {prior.coverage_summary()}")
|
||||
|
|
@ -629,7 +888,8 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
validator = CellValidator(parser_url=args.parser)
|
||||
cmap = CoverageMap(cell_size=args.cell_size)
|
||||
selector = CellSelectorPrior(
|
||||
cmap, prior,
|
||||
cmap,
|
||||
prior,
|
||||
quality_threshold=args.quality_threshold,
|
||||
phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70),
|
||||
)
|
||||
|
|
@ -639,14 +899,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
valid_count = 0
|
||||
cell_updates = 0
|
||||
|
||||
print(f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples")
|
||||
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}")
|
||||
print(
|
||||
f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples"
|
||||
)
|
||||
print(
|
||||
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
|
||||
)
|
||||
print("─" * 65)
|
||||
|
||||
max_calls = args.problems * 4
|
||||
|
||||
while len(dataset) < args.problems and call_count < max_calls:
|
||||
|
||||
cell = selector.select()
|
||||
existing = cmap.get_example(cell)
|
||||
prior_w = prior.cell_weight(cell)
|
||||
|
|
@ -657,11 +920,15 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
f"| prior={prior_w:.3f} "
|
||||
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
||||
f"| dataset={len(dataset)} ... ",
|
||||
end="", flush=True,
|
||||
end="",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
problem = call_api(
|
||||
client, lrm, cell, task_id,
|
||||
client,
|
||||
lrm,
|
||||
cell,
|
||||
task_id,
|
||||
existing_example=existing,
|
||||
map_summary=cmap.fill_summary(),
|
||||
)
|
||||
|
|
@ -678,7 +945,7 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
if is_valid is None:
|
||||
is_valid, ast = True, {}
|
||||
if call_count == 1:
|
||||
print(f"\n Parser unavailable — using keyword fallback", flush=True)
|
||||
print("\n Parser unavailable — using keyword fallback", flush=True)
|
||||
|
||||
if is_valid is False:
|
||||
print(f"INVALID ({error_msg[:40]})")
|
||||
|
|
@ -688,8 +955,13 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
valid_count += 1
|
||||
|
||||
quality, components = validator.cell_quality(
|
||||
code, ast, test_list, cell,
|
||||
alpha=args.alpha, beta=args.beta, gamma=args.gamma,
|
||||
code,
|
||||
ast,
|
||||
test_list,
|
||||
cell,
|
||||
alpha=args.alpha,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
)
|
||||
problem["_cell"] = sorted(cell)
|
||||
problem["_prior_weight"] = round(prior_w, 4)
|
||||
|
|
@ -719,13 +991,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
|||
freq = cmap.node_type_frequency()
|
||||
entropy = cmap.distribution_entropy()
|
||||
kl = prior.kl_divergence(freq)
|
||||
print(f"\n ── Checkpoint ──────────────────────────────────")
|
||||
print("\n ── Checkpoint ──────────────────────────────────")
|
||||
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
||||
print(f" {cmap.fill_summary()}")
|
||||
print(f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)")
|
||||
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
||||
print(
|
||||
f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)"
|
||||
)
|
||||
print(
|
||||
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
|
||||
)
|
||||
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
||||
print(f" ────────────────────────────────────────────────\n")
|
||||
print(" ────────────────────────────────────────────────\n")
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
|
|
@ -755,6 +1031,170 @@ def _save(dataset: list, path: Path, cmap: CoverageMap, prior: ConstructPrior =
|
|||
with open(stats_path, "w", encoding="utf-8") as f:
|
||||
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def run_reward(args, client, lrm, output_path):
    """Candidate A — CW-Reward with GoldPool feedback loop.

    Generates examples until args.problems valid ones are collected (or the
    call budget of 4x that is exhausted). Each valid example is scored with
    compute_reward; the GoldPool keeps the best ones and feeds them back as
    few-shot prompts. Returns (dataset, pool, valid_count, call_count).
    """
    validator = CellValidator(parser_url=args.parser)
    pool = GoldPool(max_size=args.pool_size)
    dataset = []
    task_id = 1
    call_count = 0
    valid_count = 0
    pool_entries = 0  # how many examples ever entered the pool

    # Reward weights from CLI flags (--w-ecs / --w-novelty / --w-tests).
    w_ecs = args.w_ecs
    w_novelty = args.w_novelty
    w_tests = args.w_tests

    print(f"\n CW-Reward mode | target: {args.problems} examples")
    print(f" Weights: ECS={w_ecs} Novelty={w_novelty} Tests={w_tests}")
    print(f" Pool size: {args.pool_size}")
    print("─" * 65)

    # Budget: allow up to 4 API calls per requested example.
    max_calls = args.problems * 4

    while len(dataset) < args.problems and call_count < max_calls:
        # Recompute coverage steering and few-shots from the current state.
        coverage_text, underrepresented = _reward_coverage_text(dataset)
        few_shots = pool.few_shot_examples(n=3)
        call_count += 1

        print(
            f" [{call_count:04d}] "
            f"| {pool.pool_summary()} "
            f"| dataset={len(dataset)} ... ",
            end="",
            flush=True,
        )

        problem = call_api_reward(
            client,
            lrm,
            few_shots,
            coverage_text,
            underrepresented,
            task_id,
        )

        if problem is None:
            # Generation/parsing failed after retries — skip this slot.
            print("SKIP (generation failed)")
            continue

        code = problem["code"]
        test_list = problem.get("test_list", [])

        is_valid, ast, error_msg = validator.parse(code)

        # None means the parser service is unreachable: accept the code and
        # fall back to keyword-based construct detection.
        if is_valid is None:
            is_valid, ast = True, {}
            if call_count == 1:
                print(
                    "\n Parser unavailable — using keyword fallback",
                    flush=True,
                )

        if is_valid is False:
            print(f"INVALID ({error_msg[:40]})")
            continue

        valid_count += 1
        detected = validator.detect_constructs(code, ast)

        reward, components = compute_reward(
            detected,
            test_list,
            pool,
            w_ecs=w_ecs,
            w_novelty=w_novelty,
            w_tests=w_tests,
        )

        # Annotate the example with its score breakdown for later analysis.
        problem["_detected"] = sorted(detected)
        problem["_reward"] = components

        entered = pool.add(problem, reward)
        if entered:
            pool_entries += 1

        dataset.append(problem)
        task_id += 1

        print(
            f"OK reward={reward:.3f} "
            f"ecs={components['ecs']:.2f} "
            f"novelty={components['novelty']:.2f} "
            f"{'→ POOL' if entered else ''}"
        )

        # Periodic checkpoint: persist progress and print coverage stats.
        if len(dataset) % 50 == 0:
            _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
            freq: dict[str, int] = defaultdict(int)
            for ex in dataset:
                for nt in ex.get("_detected", []):
                    freq[nt] += 1
            # Shannon entropy of the construct distribution (bits).
            total_f = sum(freq.values())
            entropy = 0.0
            if total_f > 0:
                for count in freq.values():
                    p = count / total_f
                    if p > 0:
                        entropy -= p * math.log2(p)
            print("\n ── Checkpoint ──────────────────────────────────")
            print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
            print(f" {pool.pool_summary()}")
            print(
                f" Coverage: {len(freq)}/{len(NODE_TYPE_NAMES)} | Entropy: {entropy:.2f} bits"
            )
            uncov = sorted(set(NODE_TYPE_NAMES) - set(freq.keys()))
            print(f" Uncovered: {uncov[:10]}")
            print(" ────────────────────────────────────────────────\n")

        # Gentle pacing between API calls.
        time.sleep(0.5)

    _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
    return dataset, pool, valid_count, call_count
|
||||
|
||||
|
||||
def _save_reward(
    dataset: list,
    path: Path,
    pool: GoldPool,
    w_ecs: float,
    w_novelty: float,
    w_tests: float,
):
    """Save dataset and stats for Candidate A reward mode."""
    # Persist the dataset itself.
    with open(path, "w", encoding="utf-8") as out:
        json.dump(dataset, out, ensure_ascii=False, indent=2)

    # Construct frequency across the whole dataset.
    freq: dict[str, int] = defaultdict(int)
    for example in dataset:
        for node_type in example.get("_detected", []):
            freq[node_type] += 1

    # Shannon entropy of the construct distribution (bits).
    total_count = sum(freq.values())
    entropy = 0.0
    if total_count > 0:
        for count in freq.values():
            p = count / total_count
            if p > 0:
                entropy -= p * math.log2(p)

    rewards = [example.get("_reward", {}).get("reward", 0) for example in dataset]
    stats = {
        "mode": "reward",
        "weights": {"w_ecs": w_ecs, "w_novelty": w_novelty, "w_tests": w_tests},
        "dataset_size": len(dataset),
        "pool_size": pool.size,
        "pool_summary": pool.pool_summary(),
        "distribution_entropy": round(entropy, 3),
        "node_type_frequency": dict(freq),
        "covered_constructs": len(freq),
        "total_constructs": len(NODE_TYPE_NAMES),
        "mean_reward": round(sum(rewards) / max(len(rewards), 1), 4),
    }

    # Stats land next to the dataset as <stem>_reward_stats.json.
    stats_path = path.with_name(path.stem + "_reward_stats.json")
    with open(stats_path, "w", encoding="utf-8") as out:
        json.dump(stats, out, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline"
|
||||
|
|
@ -762,18 +1202,39 @@ def main():
|
|||
parser.add_argument("--lrm", default="avap.md")
|
||||
parser.add_argument("--output", default="output/mbpp_avap_v2.json")
|
||||
parser.add_argument("--problems", type=int, default=5000)
|
||||
parser.add_argument("--parser", default="http://localhost:8080",
|
||||
help="AVAP parser URL")
|
||||
parser.add_argument("--cell-size", type=int, default=3,
|
||||
help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)")
|
||||
parser.add_argument("--quality-threshold", type=float, default=0.80,
|
||||
help="Min quality to consider a cell 'good' (default: 0.80)")
|
||||
parser.add_argument("--alpha", type=float, default=0.30,
|
||||
help="Weight for bonus constructs in cell quality (default: 0.30)")
|
||||
parser.add_argument("--beta", type=float, default=0.20,
|
||||
help="Weight for test quality in cell quality (default: 0.20)")
|
||||
parser.add_argument("--gamma", type=float, default=0.10,
|
||||
help="Weight for code richness in cell quality (default: 0.10)")
|
||||
parser.add_argument(
|
||||
"--parser", default="http://localhost:8080", help="AVAP parser URL"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cell-size",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quality-threshold",
|
||||
type=float,
|
||||
default=0.80,
|
||||
help="Min quality to consider a cell 'good' (default: 0.80)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--alpha",
|
||||
type=float,
|
||||
default=0.30,
|
||||
help="Weight for bonus constructs in cell quality (default: 0.30)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--beta",
|
||||
type=float,
|
||||
default=0.20,
|
||||
help="Weight for test quality in cell quality (default: 0.20)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gamma",
|
||||
type=float,
|
||||
default=0.10,
|
||||
help="Weight for code richness in cell quality (default: 0.10)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
choices=["map-elites-prior", "map-elites", "reward"],
|
||||
|
|
@ -809,6 +1270,30 @@ def main():
|
|||
"cells become the focus. Default: 0.70"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--w-ecs",
|
||||
type=float,
|
||||
default=0.50,
|
||||
help="Candidate A: weight for ECS in reward formula (default: 0.50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--w-novelty",
|
||||
type=float,
|
||||
default=0.35,
|
||||
help="Candidate A: weight for Jaccard novelty in reward formula (default: 0.35)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--w-tests",
|
||||
type=float,
|
||||
default=0.15,
|
||||
help="Candidate A: weight for test quality in reward formula (default: 0.15)",
|
||||
)
|
||||
# NOTE(review): the argparse default is 5, but the help text claimed 50
# (GoldPool's own constructor default). Help text corrected to match the
# actual default; bump `default=` instead if 50 was the intended value.
parser.add_argument(
    "--pool-size",
    type=int,
    default=5,
    help="Candidate A: max examples in GoldPool (default: 5)",
)
|
||||
parser.add_argument("--api-key", default=None)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -844,36 +1329,72 @@ def main():
|
|||
print(f" Quality thresh : {args.quality_threshold}")
|
||||
if args.mode == "map-elites-prior":
|
||||
yaml_exists = Path(args.prior_map).exists()
|
||||
print(f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})")
|
||||
print(
|
||||
f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})"
|
||||
)
|
||||
print(f" Prior epsilon : {args.prior_epsilon}")
|
||||
elif args.mode == "reward":
|
||||
print(
|
||||
f" Reward weights : ECS={args.w_ecs} Novelty={args.w_novelty} Tests={args.w_tests}"
|
||||
)
|
||||
print(f" Pool size : {args.pool_size}")
|
||||
print("=" * 65)
|
||||
|
||||
prior = None
|
||||
pool = None
|
||||
cmap = None
|
||||
|
||||
if args.mode == "map-elites-prior":
|
||||
result = run_map_elites_prior(args, client, lrm, output_path)
|
||||
dataset, cmap, valid_count, call_count, prior = result
|
||||
elif args.mode == "map-elites":
|
||||
dataset, cmap, valid_count, call_count = run_map_elites(args, client, lrm, output_path)
|
||||
dataset, cmap, valid_count, call_count = run_map_elites(
|
||||
args, client, lrm, output_path
|
||||
)
|
||||
else:
|
||||
sys.exit("ERROR: --mode reward (Candidate A) is not yet implemented in v2. "
|
||||
"Use generate_mbap.py for the v1 reward baseline.")
|
||||
dataset, pool, valid_count, call_count = run_reward(
|
||||
args, client, lrm, output_path
|
||||
)
|
||||
|
||||
# Final report
|
||||
if cmap is not None:
|
||||
freq = cmap.node_type_frequency()
|
||||
entropy = cmap.distribution_entropy()
|
||||
else:
|
||||
freq = defaultdict(int)
|
||||
for ex in dataset:
|
||||
for nt in ex.get("_detected", []):
|
||||
freq[nt] += 1
|
||||
freq = dict(freq)
|
||||
total_f = sum(freq.values())
|
||||
entropy = 0.0
|
||||
if total_f > 0:
|
||||
for count in freq.values():
|
||||
p = count / total_f
|
||||
if p > 0:
|
||||
entropy -= p * math.log2(p)
|
||||
entropy = round(entropy, 3)
|
||||
|
||||
print("\n" + "=" * 65)
|
||||
print(" Pipeline complete")
|
||||
print(f" Mode : {mode_label}")
|
||||
print(f" Total API calls : {call_count}")
|
||||
print(f" Valid examples : {valid_count} ({100*valid_count/max(call_count,1):.1f}%)")
|
||||
print(
|
||||
f" Valid examples : {valid_count} ({100 * valid_count / max(call_count, 1):.1f}%)"
|
||||
)
|
||||
print(f" Dataset size : {len(dataset)}")
|
||||
if cmap is not None:
|
||||
print(f" {cmap.fill_summary()}")
|
||||
print(f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})")
|
||||
if pool is not None:
|
||||
print(f" {pool.pool_summary()}")
|
||||
print(
|
||||
f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})"
|
||||
)
|
||||
if prior is not None:
|
||||
kl = prior.kl_divergence(freq)
|
||||
print(f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)")
|
||||
print(
|
||||
f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)"
|
||||
)
|
||||
print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
||||
print(f" Least covered : {sorted(freq, key=freq.get)[:5]}")
|
||||
print(f" Output : {output_path}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue