diff --git a/scripts/pipelines/samples_generator/generate_mbap_v2.py b/scripts/pipelines/samples_generator/generate_mbap_v2.py index bb77055..980accc 100644 --- a/scripts/pipelines/samples_generator/generate_mbap_v2.py +++ b/scripts/pipelines/samples_generator/generate_mbap_v2.py @@ -11,6 +11,8 @@ import argparse import json import math import os +import io +import random import sys import time from collections import defaultdict @@ -19,61 +21,63 @@ from pathlib import Path import anthropic import requests +from construct_prior import AVAP_NODE_NAMES, ConstructPrior -from construct_prior import ConstructPrior, AVAP_NODE_NAMES +from src.config import settings +from dotenv import load_dotenv +load_dotenv() AVAP_NODE_TYPES = { - "addParam": ["addParam("], - "addResult": ["addResult("], - "_status": ["_status"], - "addVar": ["addVar("], - "getListLen": ["getListLen("], - "getQueryParamList": ["getQueryParamList("], - "itemFromList": ["itemFromList("], - "replace": ["replace("], - "randomString": ["randomString("], - "if_mode1": ["if("], - "if_mode2": ["if(None, None,"], - "else": ["else()"], - "end": ["end()"], - "startLoop": ["startLoop("], - "endLoop": ["endLoop()"], - "try": ["try()"], - "exception": ["exception()"], - "return": ["return("], - "go": ["go("], - "gather": ["gather("], - "avapConnector": ["avapConnector("], - "ormCheckTable": ["ormCheckTable("], - "ormDirect": ["ormDirect("], - "ormAccessSelect": ["ormAccessSelect("], - "ormAccessInsert": ["ormAccessInsert("], - "ormAccessUpdate": ["ormAccessUpdate("], - "variableFromJSON": ["variableFromJSON("], - "AddVariableToJSON": ["AddVariableToJSON("], - "encodeSHA256": ["encodeSHA256("], - "encodeMD5": ["encodeMD5("], - "getTimeStamp": ["getTimeStamp("], - "getDateTime": ["getDateTime("], - "stampToDatetime": ["stampToDatetime("], - "RequestGet": ["RequestGet("], - "RequestPost": ["RequestPost("], - "function": ["function "], - "import": ["import "], - "include": ["include("], + "addParam": ["addParam("], + 
"addResult": ["addResult("], + "_status": ["_status"], + "addVar": ["addVar("], + "getListLen": ["getListLen("], + "getQueryParamList": ["getQueryParamList("], + "itemFromList": ["itemFromList("], + "replace": ["replace("], + "randomString": ["randomString("], + "if_mode1": ["if("], + "if_mode2": ["if(None, None,"], + "else": ["else()"], + "end": ["end()"], + "startLoop": ["startLoop("], + "endLoop": ["endLoop()"], + "try": ["try()"], + "exception": ["exception()"], + "return": ["return("], + "go": ["go("], + "gather": ["gather("], + "avapConnector": ["avapConnector("], + "ormCheckTable": ["ormCheckTable("], + "ormDirect": ["ormDirect("], + "ormAccessSelect": ["ormAccessSelect("], + "ormAccessInsert": ["ormAccessInsert("], + "ormAccessUpdate": ["ormAccessUpdate("], + "variableFromJSON": ["variableFromJSON("], + "AddVariableToJSON": ["AddVariableToJSON("], + "encodeSHA256": ["encodeSHA256("], + "encodeMD5": ["encodeMD5("], + "getTimeStamp": ["getTimeStamp("], + "getDateTime": ["getDateTime("], + "stampToDatetime": ["stampToDatetime("], + "RequestGet": ["RequestGet("], + "RequestPost": ["RequestPost("], + "function": ["function "], + "import": ["import "], + "include": ["include("], } NODE_TYPE_NAMES = AVAP_NODE_NAMES -_PRIOR_EPSILON = 0.05 +_PRIOR_EPSILON = 0.05 + class CellValidator: - def __init__(self, parser_url: str, parser_timeout: int = 5): - self.parser_url = parser_url.rstrip("/") + self.parser_url = parser_url.rstrip("/") self.parser_timeout = parser_timeout self._parser_available = True - def parse(self, code: str) -> tuple[bool, dict, str]: if not self._parser_available: @@ -81,7 +85,9 @@ class CellValidator: try: resp = requests.post( f"{self.parser_url}/parse", + #f"{settings.parser_url}/api/v1/upload", json={"code": code}, + #files={"file": ("task.json", io.BytesIO(json.dumps([code]).encode("utf-8")), "application/json")}, timeout=self.parser_timeout, ) data = resp.json() @@ -93,6 +99,7 @@ class CellValidator: return None, {}, "parser_unavailable" 
except Exception as e: return False, {}, str(e) + def detect_constructs(self, code: str, ast: dict) -> set: if ast: return self._from_ast(ast) @@ -124,7 +131,7 @@ class CellValidator: found.add(name) break return found - + def cell_quality( self, code: str, @@ -132,7 +139,7 @@ class CellValidator: test_list: list, cell: frozenset, alpha: float = 0.3, - beta: float = 0.2, + beta: float = 0.2, gamma: float = 0.1, ) -> tuple[float, dict]: @@ -143,11 +150,12 @@ class CellValidator: present_required = cell_constructs & detected fidelity = len(present_required) / max(len(cell_constructs), 1) - extra = detected - cell_constructs + extra = detected - cell_constructs bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1) tq = sum( - 1 for t in test_list + 1 + for t in test_list if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10 ) / max(len(test_list), 1) @@ -169,8 +177,6 @@ class CellValidator: class CoverageMap: - - def __init__(self, cell_size: int = 3): self.cell_size = cell_size @@ -215,10 +221,7 @@ class CoverageMap: return [c for c in self._all_cells if c not in self._map] def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]: - return [ - c for c, (_, q, _) in self._map.items() - if q < threshold - ] + return [c for c, (_, q, _) in self._map.items() if q < threshold] def get_example(self, cell: frozenset) -> dict | None: entry = self._map.get(cell) @@ -228,7 +231,7 @@ class CoverageMap: return [ex for ex, _, _ in self._map.values()] def node_type_frequency(self) -> dict[str, int]: - + freq = defaultdict(int) for cell in self._map: for nt in cell: @@ -236,7 +239,7 @@ class CoverageMap: return dict(freq) def distribution_entropy(self) -> float: - + freq = self.node_type_frequency() total = sum(freq.values()) if total == 0: @@ -254,15 +257,14 @@ class CoverageMap: entropy = self.distribution_entropy() return ( f"Cells: {self.filled_cells}/{self.total_cells} filled " - f"({100*self.fill_rate:.1f}%) | " + f"({100 * 
self.fill_rate:.1f}%) | " f"Low quality: {low} | " f"Empty: {empty} | " f"Entropy: {entropy:.2f} bits" ) + class CellSelector: - - def __init__( self, coverage_map: CoverageMap, @@ -274,6 +276,7 @@ class CellSelector: self.ucb_c = ucb_c self._total_calls = 0 import random + self._rng = random.Random(42) def select(self) -> frozenset: @@ -300,12 +303,12 @@ class CellSelector: score = quality + self.ucb_c * math.sqrt(math.log(total) / attempts) if score > best_score: best_score = score - best_cell = cell + best_cell = cell return best_cell + class CellSelectorPrior(CellSelector): - def __init__( self, coverage_map: CoverageMap, @@ -326,8 +329,7 @@ class CellSelectorPrior(CellSelector): if empty: high_prior_empty = [ - c for c in empty - if self.prior.cell_weight(c) > self.prior.epsilon * 1.5 + c for c in empty if self.prior.cell_weight(c) > self.prior.epsilon * 1.5 ] if high_prior_empty: return self._weighted_sample(high_prior_empty) @@ -353,7 +355,7 @@ class CellSelectorPrior(CellSelector): return cells[-1] def _ucb_prior_select(self, cells) -> frozenset: - + best_cell = None best_score = -float("inf") total = max(self._total_calls, 1) @@ -367,10 +369,111 @@ class CellSelectorPrior(CellSelector): score = prior_w * (quality + ucb_term) if score > best_score: best_score = score - best_cell = cell + best_cell = cell return best_cell + +class GoldPool: + """Top-K pool of high-reward examples for Candidate A (CW-Reward).""" + + def __init__(self, max_size: int = 50, seed: int = 42): + self._pool: list[tuple[dict, float]] = [] + self.max_size = max_size + self._rng = random.Random(seed) + + @property + def size(self) -> int: + return len(self._pool) + + def add(self, example: dict, reward: float) -> bool: + """Add example to pool. 
Returns True if it entered the pool.""" + if len(self._pool) < self.max_size: + self._pool.append((example, reward)) + self._pool.sort(key=lambda x: -x[1]) + return True + if reward > self._pool[-1][1]: + self._pool[-1] = (example, reward) + self._pool.sort(key=lambda x: -x[1]) + return True + return False + + def few_shot_examples(self, n: int = 3) -> list[dict]: + """Return n diverse examples from the pool for few-shot prompting.""" + if not self._pool: + return [] + if len(self._pool) <= n: + return [ex for ex, _ in self._pool] + top_half = self._pool[: max(len(self._pool) // 2, n)] + sampled = self._rng.sample(top_half, n) + return [ex for ex, _ in sampled] + + def construct_sets(self) -> list[set[str]]: + """Return the construct sets of all pool examples.""" + return [set(ex.get("_detected", [])) for ex, _ in self._pool] + + def pool_summary(self) -> str: + if not self._pool: + return "GoldPool: empty" + rewards = [r for _, r in self._pool] + return ( + f"GoldPool: {len(self._pool)}/{self.max_size} | " + f"reward: min={min(rewards):.3f} max={max(rewards):.3f} " + f"mean={sum(rewards) / len(rewards):.3f}" + ) + + +def compute_ecs(detected: set[str]) -> float: + """Execution Coverage Score — fraction of all AVAP node types covered.""" + return len(detected) / max(len(NODE_TYPE_NAMES), 1) + + +def jaccard_similarity(a: set, b: set) -> float: + """Jaccard similarity between two sets.""" + union = a | b + if not union: + return 1.0 + return len(a & b) / len(union) + + +def jaccard_novelty(detected: set[str], pool: GoldPool) -> float: + """Novelty = 1 - max Jaccard similarity with any pool example.""" + pool_sets = pool.construct_sets() + if not pool_sets: + return 1.0 + max_sim = max(jaccard_similarity(detected, ps) for ps in pool_sets) + return 1.0 - max_sim + + +def compute_reward( + detected: set[str], + test_list: list, + pool: GoldPool, + w_ecs: float, + w_novelty: float, + w_tests: float, +) -> tuple[float, dict]: + """ + Candidate A composite reward: + 
reward(e) = w_ecs * ECS(e) + w_novelty * Jaccard_novelty(e, Pool) + w_tests * test_quality(e) + """ + ecs = compute_ecs(detected) + novelty = jaccard_novelty(detected, pool) + tq = sum( + 1 + for t in test_list + if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10 + ) / max(len(test_list), 1) + reward = w_ecs * ecs + w_novelty * novelty + w_tests * tq + return reward, { + "ecs": round(ecs, 3), + "novelty": round(novelty, 3), + "test_quality": round(tq, 3), + "reward": round(reward, 3), + "detected": sorted(detected), + } + + SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP. Se te proporciona el Language Reference Manual (LRM) completo de AVAP. Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar @@ -420,7 +523,7 @@ El siguiente ejemplo YA existe para esta combinación con calidad mejorable. Genera algo DISTINTO y MÁS COMPLEJO que lo supere: ``` -{existing_example.get('code', '')} +{existing_example.get("code", "")} ``` """ @@ -455,6 +558,80 @@ Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después. """ +def _reward_coverage_text(dataset: list) -> tuple[str, list[str]]: + """Build coverage summary and list of underrepresented constructs.""" + freq: dict[str, int] = defaultdict(int) + for ex in dataset: + for nt in ex.get("_detected", []): + freq[nt] += 1 + if not freq: + return ( + "No examples generated yet. 
Cover as many AVAP constructs as possible.", + list(NODE_TYPE_NAMES), + ) + covered = set(freq.keys()) + uncovered = sorted(set(NODE_TYPE_NAMES) - covered) + least_covered = sorted(freq, key=freq.get)[:10] + underrepresented = uncovered + least_covered + lines = [ + f"Total examples: {len(dataset)}", + f"Covered constructs: {len(covered)}/{len(NODE_TYPE_NAMES)}", + ] + if uncovered: + lines.append(f"UNCOVERED (prioritize): {', '.join(uncovered)}") + if least_covered: + lines.append(f"Least common: {', '.join(least_covered)}") + return "\n".join(lines), underrepresented + + +def build_reward_prompt( + lrm: str, + few_shots: list[dict], + coverage_text: str, + underrepresented: list[str], +) -> str: + """Build the user prompt for Candidate A reward mode.""" + few_shot_block = "" + if few_shots: + few_shot_block = "\n# EJEMPLOS DE ALTA CALIDAD (referencia)\n\n" + for i, ex in enumerate(few_shots, 1): + few_shot_block += f"## Ejemplo {i}\n" + few_shot_block += f"Enunciado: {ex.get('text', '')}\n" + few_shot_block += f"```\n{ex.get('code', '')}\n```\n\n" + steer_text = "" + if underrepresented: + constructs_str = ", ".join(f"`{c}`" for c in underrepresented[:8]) + steer_text = ( + f"\nPRIORIDAD: Constructs subrepresentados — intenta incluir " + f"algunos de ellos: {constructs_str}\n" + ) + return f"""# LRM AVAP — Language Reference Manual + +{lrm} + +--- + +# ESTADO DE COBERTURA DEL DATASET + +{coverage_text} +{few_shot_block} +--- + +# TAREA + +Genera UN ejemplo AVAP original y complejo como problema de benchmark. + +Requisitos: +- Escenario realista de microservicio HTTP en AVAP +- Usa la mayor cantidad de constructs AVAP distintos posible (aumenta la puntuación) +- El ejemplo debe ser DIFERENTE a los ejemplos de referencia mostrados arriba +- Código complejo y rico — no ejemplos triviales de 3 líneas +- 2-3 aserciones re.match() en test_list +{steer_text} +Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después. 
+""" + + def call_api( client: anthropic.Anthropic, lrm: str, @@ -471,16 +648,22 @@ def call_api( model="claude-sonnet-4-20250514", max_tokens=4000, system=SYSTEM_PROMPT, - messages=[{ - "role": "user", - "content": build_cell_prompt(lrm, cell, existing_example, map_summary), - }], + messages=[ + { + "role": "user", + "content": build_cell_prompt( + lrm, cell, existing_example, map_summary + ), + } + ], ) raw = message.content[0].text.strip() if raw.startswith("```"): lines = raw.splitlines() - raw = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:]) + raw = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) problem = json.loads(raw) if not isinstance(problem, dict): @@ -496,7 +679,7 @@ def call_api( except (json.JSONDecodeError, ValueError) as e: print(f"\n Attempt {attempt}/{retries} — parse error: {e}") if attempt < retries: - time.sleep(2 ** attempt) + time.sleep(2**attempt) except anthropic.RateLimitError: wait = 30 * attempt print(f"\n Rate limit — waiting {wait}s...") @@ -509,6 +692,65 @@ def call_api( return None +def call_api_reward( + client: anthropic.Anthropic, + lrm: str, + few_shots: list[dict], + coverage_text: str, + underrepresented: list[str], + task_id: int, + retries: int = 3, +) -> dict | None: + """API call for Candidate A reward mode.""" + for attempt in range(1, retries + 1): + try: + message = client.messages.create( + model="claude-sonnet-4-20250514", + max_tokens=4000, + system=SYSTEM_PROMPT, + messages=[ + { + "role": "user", + "content": build_reward_prompt( + lrm, + few_shots, + coverage_text, + underrepresented, + ), + } + ], + ) + raw = message.content[0].text.strip() + if raw.startswith("```"): + lines = raw.splitlines() + raw = "\n".join( + lines[1:-1] if lines[-1].strip() == "```" else lines[1:] + ) + problem = json.loads(raw) + if not isinstance(problem, dict): + raise ValueError("Response is not a JSON object") + for field in ("text", "code", "test_list"): + if field not in problem: + 
raise ValueError(f"Missing field '{field}'") + if "test_inputs" not in problem: + problem["test_inputs"] = {} + problem["task_id"] = task_id + return problem + except (json.JSONDecodeError, ValueError) as e: + print(f"\n Attempt {attempt}/{retries} — parse error: {e}") + if attempt < retries: + time.sleep(2**attempt) + except anthropic.RateLimitError: + wait = 30 * attempt + print(f"\n Rate limit — waiting {wait}s...") + time.sleep(wait) + except anthropic.APIError as e: + print(f"\n API error at attempt {attempt}: {e}") + if attempt < retries: + time.sleep(5) + return None + + def run_map_elites(args, client, lrm, output_path): validator = CellValidator(parser_url=args.parser) @@ -520,14 +762,17 @@ def run_map_elites(args, client, lrm, output_path): valid_count = 0 cell_updates = 0 - print(f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples") - print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}") + print( + f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples" + ) + print( + f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}" + ) print("─" * 65) - max_calls = args.problems * 4 + max_calls = args.problems * 4 while len(dataset) < args.problems and call_count < max_calls: - cell = selector.select() existing = cmap.get_example(cell) call_count += 1 @@ -536,11 +781,15 @@ def run_map_elites(args, client, lrm, output_path): f" [{call_count:04d}] Cell {sorted(cell)} " f"| filled={cmap.filled_cells}/{cmap.total_cells} " f"| dataset={len(dataset)} ... 
", - end="", flush=True, + end="", + flush=True, ) problem = call_api( - client, lrm, cell, task_id, + client, + lrm, + cell, + task_id, existing_example=existing, map_summary=cmap.fill_summary(), ) @@ -549,7 +798,7 @@ def run_map_elites(args, client, lrm, output_path): print("SKIP (generation failed)") continue - code = problem["code"] + code = problem["code"] test_list = problem.get("test_list", []) is_valid, ast, error_msg = validator.parse(code) @@ -557,7 +806,7 @@ def run_map_elites(args, client, lrm, output_path): if is_valid is None: is_valid, ast = True, {} if call_count == 1: - print(f"\n Parser unavailable — using keyword fallback", flush=True) + print("\n Parser unavailable — using keyword fallback", flush=True) if is_valid is False: print(f"INVALID ({error_msg[:40]})") @@ -568,8 +817,13 @@ def run_map_elites(args, client, lrm, output_path): # Compute cell quality quality, components = validator.cell_quality( - code, ast, test_list, cell, - alpha=args.alpha, beta=args.beta, gamma=args.gamma, + code, + ast, + test_list, + cell, + alpha=args.alpha, + beta=args.beta, + gamma=args.gamma, ) problem["_cell"] = sorted(cell) problem["_quality"] = components @@ -596,23 +850,26 @@ def run_map_elites(args, client, lrm, output_path): _save(dataset, output_path, cmap) freq = cmap.node_type_frequency() entropy = cmap.distribution_entropy() - print(f"\n ── Checkpoint ──────────────────────────────────") + print("\n ── Checkpoint ──────────────────────────────────") print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}") print(f" {cmap.fill_summary()}") - print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}") + print( + f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}" + ) print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}") - print(f" ────────────────────────────────────────────────\n") + print(" ────────────────────────────────────────────────\n") time.sleep(0.5) _save(dataset, output_path, 
cmap) return dataset, cmap, valid_count, call_count + def run_map_elites_prior(args, client, lrm, output_path): print("\n Loading ConstructPrior...", flush=True) - prior_map = getattr(args, "prior_map","construct_map.yaml") - epsilon = getattr(args, "prior_epsilon", _PRIOR_EPSILON) + prior_map = getattr(args, "prior_map", "construct_map.yaml") + epsilon = getattr(args, "prior_epsilon", _PRIOR_EPSILON) yaml_path = Path(prior_map) if yaml_path.exists(): @@ -620,16 +877,19 @@ def run_map_elites_prior(args, client, lrm, output_path): else: # Fallback: yaml not found — use static prior and warn print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.") - print(f" [WARN] Using static fallback prior. Generate the real prior with:") - print(f" [WARN] python construct_prior.py --generate-map --github-token TOKEN") + print(" [WARN] Using static fallback prior. Generate the real prior with:") + print( + " [WARN] python construct_prior.py --generate-map --github-token TOKEN" + ) prior = ConstructPrior.from_static_fallback(epsilon=epsilon) print(f" {prior.coverage_summary()}") - validator = CellValidator(parser_url=args.parser) - cmap = CoverageMap(cell_size=args.cell_size) - selector = CellSelectorPrior( - cmap, prior, + validator = CellValidator(parser_url=args.parser) + cmap = CoverageMap(cell_size=args.cell_size) + selector = CellSelectorPrior( + cmap, + prior, quality_threshold=args.quality_threshold, phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70), ) @@ -639,14 +899,17 @@ def run_map_elites_prior(args, client, lrm, output_path): valid_count = 0 cell_updates = 0 - print(f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples") - print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}") + print( + f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples" + ) + print( + f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}" + ) print("─" * 
65) max_calls = args.problems * 4 while len(dataset) < args.problems and call_count < max_calls: - cell = selector.select() existing = cmap.get_example(cell) prior_w = prior.cell_weight(cell) @@ -657,11 +920,15 @@ def run_map_elites_prior(args, client, lrm, output_path): f"| prior={prior_w:.3f} " f"| filled={cmap.filled_cells}/{cmap.total_cells} " f"| dataset={len(dataset)} ... ", - end="", flush=True, + end="", + flush=True, ) problem = call_api( - client, lrm, cell, task_id, + client, + lrm, + cell, + task_id, existing_example=existing, map_summary=cmap.fill_summary(), ) @@ -670,7 +937,7 @@ def run_map_elites_prior(args, client, lrm, output_path): print("SKIP (generation failed)") continue - code = problem["code"] + code = problem["code"] test_list = problem.get("test_list", []) is_valid, ast, error_msg = validator.parse(code) @@ -678,7 +945,7 @@ def run_map_elites_prior(args, client, lrm, output_path): if is_valid is None: is_valid, ast = True, {} if call_count == 1: - print(f"\n Parser unavailable — using keyword fallback", flush=True) + print("\n Parser unavailable — using keyword fallback", flush=True) if is_valid is False: print(f"INVALID ({error_msg[:40]})") @@ -688,8 +955,13 @@ def run_map_elites_prior(args, client, lrm, output_path): valid_count += 1 quality, components = validator.cell_quality( - code, ast, test_list, cell, - alpha=args.alpha, beta=args.beta, gamma=args.gamma, + code, + ast, + test_list, + cell, + alpha=args.alpha, + beta=args.beta, + gamma=args.gamma, ) problem["_cell"] = sorted(cell) problem["_prior_weight"] = round(prior_w, 4) @@ -716,16 +988,20 @@ def run_map_elites_prior(args, client, lrm, output_path): if len(dataset) % 50 == 0: _save(dataset, output_path, cmap, prior=prior) - freq = cmap.node_type_frequency() + freq = cmap.node_type_frequency() entropy = cmap.distribution_entropy() - kl = prior.kl_divergence(freq) - print(f"\n ── Checkpoint ──────────────────────────────────") + kl = prior.kl_divergence(freq) + print("\n ── 
Checkpoint ──────────────────────────────────") print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}") print(f" {cmap.fill_summary()}") - print(f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)") - print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}") + print( + f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)" + ) + print( + f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}" + ) print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}") - print(f" ────────────────────────────────────────────────\n") + print(" ────────────────────────────────────────────────\n") time.sleep(0.5) @@ -739,7 +1015,7 @@ def _save(dataset: list, path: Path, cmap: CoverageMap, prior: ConstructPrior = # Save coverage map statistics alongside dataset stats_path = path.with_name(path.stem + "_coverage_stats.json") - freq = cmap.node_type_frequency() + freq = cmap.node_type_frequency() stats = { "total_cells": cmap.total_cells, "filled_cells": cmap.filled_cells, @@ -755,25 +1031,210 @@ def _save(dataset: list, path: Path, cmap: CoverageMap, prior: ConstructPrior = with open(stats_path, "w", encoding="utf-8") as f: json.dump(stats, f, ensure_ascii=False, indent=2) + +def run_reward(args, client, lrm, output_path): + """Candidate A — CW-Reward with GoldPool feedback loop.""" + validator = CellValidator(parser_url=args.parser) + pool = GoldPool(max_size=args.pool_size) + dataset = [] + task_id = 1 + call_count = 0 + valid_count = 0 + pool_entries = 0 + + w_ecs = args.w_ecs + w_novelty = args.w_novelty + w_tests = args.w_tests + + print(f"\n CW-Reward mode | target: {args.problems} examples") + print(f" Weights: ECS={w_ecs} Novelty={w_novelty} Tests={w_tests}") + print(f" Pool size: {args.pool_size}") + print("─" * 65) + + max_calls = args.problems * 4 + + while len(dataset) < args.problems and call_count < max_calls: + coverage_text, underrepresented = 
_reward_coverage_text(dataset) + few_shots = pool.few_shot_examples(n=3) + call_count += 1 + + print( + f" [{call_count:04d}] " + f"| {pool.pool_summary()} " + f"| dataset={len(dataset)} ... ", + end="", + flush=True, + ) + + problem = call_api_reward( + client, + lrm, + few_shots, + coverage_text, + underrepresented, + task_id, + ) + + if problem is None: + print("SKIP (generation failed)") + continue + + code = problem["code"] + test_list = problem.get("test_list", []) + + is_valid, ast, error_msg = validator.parse(code) + + if is_valid is None: + is_valid, ast = True, {} + if call_count == 1: + print( + "\n Parser unavailable — using keyword fallback", + flush=True, + ) + + if is_valid is False: + print(f"INVALID ({error_msg[:40]})") + continue + + valid_count += 1 + detected = validator.detect_constructs(code, ast) + + reward, components = compute_reward( + detected, + test_list, + pool, + w_ecs=w_ecs, + w_novelty=w_novelty, + w_tests=w_tests, + ) + + problem["_detected"] = sorted(detected) + problem["_reward"] = components + + entered = pool.add(problem, reward) + if entered: + pool_entries += 1 + + dataset.append(problem) + task_id += 1 + + print( + f"OK reward={reward:.3f} " + f"ecs={components['ecs']:.2f} " + f"novelty={components['novelty']:.2f} " + f"{'→ POOL' if entered else ''}" + ) + + if len(dataset) % 50 == 0: + _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests) + freq: dict[str, int] = defaultdict(int) + for ex in dataset: + for nt in ex.get("_detected", []): + freq[nt] += 1 + total_f = sum(freq.values()) + entropy = 0.0 + if total_f > 0: + for count in freq.values(): + p = count / total_f + if p > 0: + entropy -= p * math.log2(p) + print("\n ── Checkpoint ──────────────────────────────────") + print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}") + print(f" {pool.pool_summary()}") + print( + f" Coverage: {len(freq)}/{len(NODE_TYPE_NAMES)} | Entropy: {entropy:.2f} bits" + ) + uncov = sorted(set(NODE_TYPE_NAMES) - 
set(freq.keys())) + print(f" Uncovered: {uncov[:10]}") + print(" ────────────────────────────────────────────────\n") + + time.sleep(0.5) + + _save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests) + return dataset, pool, valid_count, call_count + + +def _save_reward( + dataset: list, + path: Path, + pool: GoldPool, + w_ecs: float, + w_novelty: float, + w_tests: float, +): + """Save dataset and stats for Candidate A reward mode.""" + with open(path, "w", encoding="utf-8") as f: + json.dump(dataset, f, ensure_ascii=False, indent=2) + + stats_path = path.with_name(path.stem + "_reward_stats.json") + freq: dict[str, int] = defaultdict(int) + for ex in dataset: + for nt in ex.get("_detected", []): + freq[nt] += 1 + total_f = sum(freq.values()) + entropy = 0.0 + if total_f > 0: + for count in freq.values(): + p = count / total_f + if p > 0: + entropy -= p * math.log2(p) + rewards = [ex.get("_reward", {}).get("reward", 0) for ex in dataset] + stats = { + "mode": "reward", + "weights": {"w_ecs": w_ecs, "w_novelty": w_novelty, "w_tests": w_tests}, + "dataset_size": len(dataset), + "pool_size": pool.size, + "pool_summary": pool.pool_summary(), + "distribution_entropy": round(entropy, 3), + "node_type_frequency": dict(freq), + "covered_constructs": len(freq), + "total_constructs": len(NODE_TYPE_NAMES), + "mean_reward": round(sum(rewards) / max(len(rewards), 1), 4), + } + with open(stats_path, "w", encoding="utf-8") as f: + json.dump(stats, f, ensure_ascii=False, indent=2) + + def main(): parser = argparse.ArgumentParser( description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline" ) - parser.add_argument("--lrm", default="avap.md") - parser.add_argument("--output", default="output/mbpp_avap_v2.json") - parser.add_argument("--problems", type=int, default=5000) - parser.add_argument("--parser", default="http://localhost:8080", - help="AVAP parser URL") - parser.add_argument("--cell-size", type=int, default=3, - help="Max constructs per cell: 
2=pairs, 3=pairs+trios (default: 3)") - parser.add_argument("--quality-threshold", type=float, default=0.80, - help="Min quality to consider a cell 'good' (default: 0.80)") - parser.add_argument("--alpha", type=float, default=0.30, - help="Weight for bonus constructs in cell quality (default: 0.30)") - parser.add_argument("--beta", type=float, default=0.20, - help="Weight for test quality in cell quality (default: 0.20)") - parser.add_argument("--gamma", type=float, default=0.10, - help="Weight for code richness in cell quality (default: 0.10)") + parser.add_argument("--lrm", default="avap.md") + parser.add_argument("--output", default="output/mbpp_avap_v2.json") + parser.add_argument("--problems", type=int, default=5000) + parser.add_argument( + "--parser", default="http://localhost:8080", help="AVAP parser URL" + ) + parser.add_argument( + "--cell-size", + type=int, + default=3, + help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)", + ) + parser.add_argument( + "--quality-threshold", + type=float, + default=0.80, + help="Min quality to consider a cell 'good' (default: 0.80)", + ) + parser.add_argument( + "--alpha", + type=float, + default=0.30, + help="Weight for bonus constructs in cell quality (default: 0.30)", + ) + parser.add_argument( + "--beta", + type=float, + default=0.20, + help="Weight for test quality in cell quality (default: 0.20)", + ) + parser.add_argument( + "--gamma", + type=float, + default=0.10, + help="Weight for code richness in cell quality (default: 0.10)", + ) parser.add_argument( "--mode", choices=["map-elites-prior", "map-elites", "reward"], @@ -809,6 +1270,30 @@ def main(): "cells become the focus. 
Default: 0.70" ), ) + parser.add_argument( + "--w-ecs", + type=float, + default=0.50, + help="Candidate A: weight for ECS in reward formula (default: 0.50)", + ) + parser.add_argument( + "--w-novelty", + type=float, + default=0.35, + help="Candidate A: weight for Jaccard novelty in reward formula (default: 0.35)", + ) + parser.add_argument( + "--w-tests", + type=float, + default=0.15, + help="Candidate A: weight for test quality in reward formula (default: 0.15)", + ) + parser.add_argument( + "--pool-size", + type=int, + default=50, + help="Candidate A: max examples in GoldPool (default: 50)", + ) parser.add_argument("--api-key", default=None) args = parser.parse_args() @@ -844,36 +1329,72 @@ def main(): print(f" Quality thresh : {args.quality_threshold}") if args.mode == "map-elites-prior": yaml_exists = Path(args.prior_map).exists() - print(f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})") + print( + f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})" + ) print(f" Prior epsilon : {args.prior_epsilon}") + elif args.mode == "reward": + print( + f" Reward weights : ECS={args.w_ecs} Novelty={args.w_novelty} Tests={args.w_tests}" + ) + print(f" Pool size : {args.pool_size}") print("=" * 65) prior = None + pool = None + cmap = None if args.mode == "map-elites-prior": result = run_map_elites_prior(args, client, lrm, output_path) dataset, cmap, valid_count, call_count, prior = result elif args.mode == "map-elites": - dataset, cmap, valid_count, call_count = run_map_elites(args, client, lrm, output_path) + dataset, cmap, valid_count, call_count = run_map_elites( + args, client, lrm, output_path + ) else: - sys.exit("ERROR: --mode reward (Candidate A) is not yet implemented in v2. 
" - "Use generate_mbap.py for the v1 reward baseline.") + dataset, pool, valid_count, call_count = run_reward( + args, client, lrm, output_path + ) # Final report - freq = cmap.node_type_frequency() - entropy = cmap.distribution_entropy() + if cmap is not None: + freq = cmap.node_type_frequency() + entropy = cmap.distribution_entropy() + else: + freq = defaultdict(int) + for ex in dataset: + for nt in ex.get("_detected", []): + freq[nt] += 1 + freq = dict(freq) + total_f = sum(freq.values()) + entropy = 0.0 + if total_f > 0: + for count in freq.values(): + p = count / total_f + if p > 0: + entropy -= p * math.log2(p) + entropy = round(entropy, 3) print("\n" + "=" * 65) print(" Pipeline complete") print(f" Mode : {mode_label}") print(f" Total API calls : {call_count}") - print(f" Valid examples : {valid_count} ({100*valid_count/max(call_count,1):.1f}%)") + print( + f" Valid examples : {valid_count} ({100 * valid_count / max(call_count, 1):.1f}%)" + ) print(f" Dataset size : {len(dataset)}") - print(f" {cmap.fill_summary()}") - print(f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})") + if cmap is not None: + print(f" {cmap.fill_summary()}") + if pool is not None: + print(f" {pool.pool_summary()}") + print( + f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})" + ) if prior is not None: kl = prior.kl_divergence(freq) - print(f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)") + print( + f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)" + ) print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}") print(f" Least covered : {sorted(freq, key=freq.get)[:5]}") print(f" Output : {output_path}")