Merge branch 'mrh-online-dev' of github.com:BRUNIX-AI/assistance-engine into mrh-online-dev
This commit is contained in:
commit
04fa15ff1e
|
|
@ -261,8 +261,6 @@
|
||||||
"print(\"Recall:\", recall_qwen_2)\n",
|
"print(\"Recall:\", recall_qwen_2)\n",
|
||||||
"print(\"Precision:\", precision_qwen_2)"
|
"print(\"Precision:\", precision_qwen_2)"
|
||||||
]
|
]
|
||||||
<<<<<<< HEAD
|
|
||||||
=======
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
|
|
@ -308,7 +306,6 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
>>>>>>> 4b5352d93cf89b7562895b550fb5bd62160586c5
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,8 @@ import argparse
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
import io
|
||||||
|
import random
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
@ -19,8 +21,9 @@ from pathlib import Path
|
||||||
|
|
||||||
import anthropic
|
import anthropic
|
||||||
import requests
|
import requests
|
||||||
|
from construct_prior import AVAP_NODE_NAMES, ConstructPrior
|
||||||
|
|
||||||
from construct_prior import ConstructPrior, AVAP_NODE_NAMES
|
from src.config import settings
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
@ -68,14 +71,13 @@ AVAP_NODE_TYPES = {
|
||||||
NODE_TYPE_NAMES = AVAP_NODE_NAMES
|
NODE_TYPE_NAMES = AVAP_NODE_NAMES
|
||||||
_PRIOR_EPSILON = 0.05
|
_PRIOR_EPSILON = 0.05
|
||||||
|
|
||||||
class CellValidator:
|
|
||||||
|
|
||||||
|
class CellValidator:
|
||||||
def __init__(self, parser_url: str, parser_timeout: int = 5):
|
def __init__(self, parser_url: str, parser_timeout: int = 5):
|
||||||
self.parser_url = parser_url.rstrip("/")
|
self.parser_url = parser_url.rstrip("/")
|
||||||
self.parser_timeout = parser_timeout
|
self.parser_timeout = parser_timeout
|
||||||
self._parser_available = True
|
self._parser_available = True
|
||||||
|
|
||||||
|
|
||||||
def parse(self, code: str) -> tuple[bool, dict, str]:
|
def parse(self, code: str) -> tuple[bool, dict, str]:
|
||||||
|
|
||||||
if not self._parser_available:
|
if not self._parser_available:
|
||||||
|
|
@ -83,7 +85,9 @@ class CellValidator:
|
||||||
try:
|
try:
|
||||||
resp = requests.post(
|
resp = requests.post(
|
||||||
f"{self.parser_url}/parse",
|
f"{self.parser_url}/parse",
|
||||||
|
#f"{settings.parser_url}/api/v1/upload",
|
||||||
json={"code": code},
|
json={"code": code},
|
||||||
|
#files={"file": ("task.json", io.BytesIO(json.dumps([code]).encode("utf-8")), "application/json")},
|
||||||
timeout=self.parser_timeout,
|
timeout=self.parser_timeout,
|
||||||
)
|
)
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
|
|
@ -95,6 +99,7 @@ class CellValidator:
|
||||||
return None, {}, "parser_unavailable"
|
return None, {}, "parser_unavailable"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return False, {}, str(e)
|
return False, {}, str(e)
|
||||||
|
|
||||||
def detect_constructs(self, code: str, ast: dict) -> set:
|
def detect_constructs(self, code: str, ast: dict) -> set:
|
||||||
if ast:
|
if ast:
|
||||||
return self._from_ast(ast)
|
return self._from_ast(ast)
|
||||||
|
|
@ -149,7 +154,8 @@ class CellValidator:
|
||||||
bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1)
|
bonus_ratio = len(extra) / max(len(all_types) - len(cell_constructs), 1)
|
||||||
|
|
||||||
tq = sum(
|
tq = sum(
|
||||||
1 for t in test_list
|
1
|
||||||
|
for t in test_list
|
||||||
if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
|
if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
|
||||||
) / max(len(test_list), 1)
|
) / max(len(test_list), 1)
|
||||||
|
|
||||||
|
|
@ -171,8 +177,6 @@ class CellValidator:
|
||||||
|
|
||||||
|
|
||||||
class CoverageMap:
|
class CoverageMap:
|
||||||
|
|
||||||
|
|
||||||
def __init__(self, cell_size: int = 3):
|
def __init__(self, cell_size: int = 3):
|
||||||
|
|
||||||
self.cell_size = cell_size
|
self.cell_size = cell_size
|
||||||
|
|
@ -217,10 +221,7 @@ class CoverageMap:
|
||||||
return [c for c in self._all_cells if c not in self._map]
|
return [c for c in self._all_cells if c not in self._map]
|
||||||
|
|
||||||
def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]:
|
def get_low_quality_cells(self, threshold: float = 0.7) -> list[frozenset]:
|
||||||
return [
|
return [c for c, (_, q, _) in self._map.items() if q < threshold]
|
||||||
c for c, (_, q, _) in self._map.items()
|
|
||||||
if q < threshold
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_example(self, cell: frozenset) -> dict | None:
|
def get_example(self, cell: frozenset) -> dict | None:
|
||||||
entry = self._map.get(cell)
|
entry = self._map.get(cell)
|
||||||
|
|
@ -262,9 +263,8 @@ class CoverageMap:
|
||||||
f"Entropy: {entropy:.2f} bits"
|
f"Entropy: {entropy:.2f} bits"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CellSelector:
|
class CellSelector:
|
||||||
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
coverage_map: CoverageMap,
|
coverage_map: CoverageMap,
|
||||||
|
|
@ -276,6 +276,7 @@ class CellSelector:
|
||||||
self.ucb_c = ucb_c
|
self.ucb_c = ucb_c
|
||||||
self._total_calls = 0
|
self._total_calls = 0
|
||||||
import random
|
import random
|
||||||
|
|
||||||
self._rng = random.Random(42)
|
self._rng = random.Random(42)
|
||||||
|
|
||||||
def select(self) -> frozenset:
|
def select(self) -> frozenset:
|
||||||
|
|
@ -306,8 +307,8 @@ class CellSelector:
|
||||||
|
|
||||||
return best_cell
|
return best_cell
|
||||||
|
|
||||||
class CellSelectorPrior(CellSelector):
|
|
||||||
|
|
||||||
|
class CellSelectorPrior(CellSelector):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
coverage_map: CoverageMap,
|
coverage_map: CoverageMap,
|
||||||
|
|
@ -328,8 +329,7 @@ class CellSelectorPrior(CellSelector):
|
||||||
|
|
||||||
if empty:
|
if empty:
|
||||||
high_prior_empty = [
|
high_prior_empty = [
|
||||||
c for c in empty
|
c for c in empty if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
|
||||||
if self.prior.cell_weight(c) > self.prior.epsilon * 1.5
|
|
||||||
]
|
]
|
||||||
if high_prior_empty:
|
if high_prior_empty:
|
||||||
return self._weighted_sample(high_prior_empty)
|
return self._weighted_sample(high_prior_empty)
|
||||||
|
|
@ -373,6 +373,107 @@ class CellSelectorPrior(CellSelector):
|
||||||
|
|
||||||
return best_cell
|
return best_cell
|
||||||
|
|
||||||
|
|
||||||
|
class GoldPool:
|
||||||
|
"""Top-K pool of high-reward examples for Candidate A (CW-Reward)."""
|
||||||
|
|
||||||
|
def __init__(self, max_size: int = 50, seed: int = 42):
|
||||||
|
self._pool: list[tuple[dict, float]] = []
|
||||||
|
self.max_size = max_size
|
||||||
|
self._rng = random.Random(seed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def size(self) -> int:
|
||||||
|
return len(self._pool)
|
||||||
|
|
||||||
|
def add(self, example: dict, reward: float) -> bool:
|
||||||
|
"""Add example to pool. Returns True if it entered the pool."""
|
||||||
|
if len(self._pool) < self.max_size:
|
||||||
|
self._pool.append((example, reward))
|
||||||
|
self._pool.sort(key=lambda x: -x[1])
|
||||||
|
return True
|
||||||
|
if reward > self._pool[-1][1]:
|
||||||
|
self._pool[-1] = (example, reward)
|
||||||
|
self._pool.sort(key=lambda x: -x[1])
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def few_shot_examples(self, n: int = 3) -> list[dict]:
|
||||||
|
"""Return n diverse examples from the pool for few-shot prompting."""
|
||||||
|
if not self._pool:
|
||||||
|
return []
|
||||||
|
if len(self._pool) <= n:
|
||||||
|
return [ex for ex, _ in self._pool]
|
||||||
|
top_half = self._pool[: max(len(self._pool) // 2, n)]
|
||||||
|
sampled = self._rng.sample(top_half, n)
|
||||||
|
return [ex for ex, _ in sampled]
|
||||||
|
|
||||||
|
def construct_sets(self) -> list[set[str]]:
|
||||||
|
"""Return the construct sets of all pool examples."""
|
||||||
|
return [set(ex.get("_detected", [])) for ex, _ in self._pool]
|
||||||
|
|
||||||
|
def pool_summary(self) -> str:
|
||||||
|
if not self._pool:
|
||||||
|
return "GoldPool: empty"
|
||||||
|
rewards = [r for _, r in self._pool]
|
||||||
|
return (
|
||||||
|
f"GoldPool: {len(self._pool)}/{self.max_size} | "
|
||||||
|
f"reward: min={min(rewards):.3f} max={max(rewards):.3f} "
|
||||||
|
f"mean={sum(rewards) / len(rewards):.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_ecs(detected: set[str]) -> float:
|
||||||
|
"""Execution Coverage Score — fraction of all AVAP node types covered."""
|
||||||
|
return len(detected) / max(len(NODE_TYPE_NAMES), 1)
|
||||||
|
|
||||||
|
|
||||||
|
def jaccard_similarity(a: set, b: set) -> float:
|
||||||
|
"""Jaccard similarity between two sets."""
|
||||||
|
union = a | b
|
||||||
|
if not union:
|
||||||
|
return 1.0
|
||||||
|
return len(a & b) / len(union)
|
||||||
|
|
||||||
|
|
||||||
|
def jaccard_novelty(detected: set[str], pool: GoldPool) -> float:
|
||||||
|
"""Novelty = 1 - max Jaccard similarity with any pool example."""
|
||||||
|
pool_sets = pool.construct_sets()
|
||||||
|
if not pool_sets:
|
||||||
|
return 1.0
|
||||||
|
max_sim = max(jaccard_similarity(detected, ps) for ps in pool_sets)
|
||||||
|
return 1.0 - max_sim
|
||||||
|
|
||||||
|
|
||||||
|
def compute_reward(
|
||||||
|
detected: set[str],
|
||||||
|
test_list: list,
|
||||||
|
pool: GoldPool,
|
||||||
|
w_ecs: float,
|
||||||
|
w_novelty: float,
|
||||||
|
w_tests: float,
|
||||||
|
) -> tuple[float, dict]:
|
||||||
|
"""
|
||||||
|
Candidate A composite reward:
|
||||||
|
reward(e) = w_ecs * ECS(e) + w_novelty * Jaccard_novelty(e, Pool) + w_tests * test_quality(e)
|
||||||
|
"""
|
||||||
|
ecs = compute_ecs(detected)
|
||||||
|
novelty = jaccard_novelty(detected, pool)
|
||||||
|
tq = sum(
|
||||||
|
1
|
||||||
|
for t in test_list
|
||||||
|
if isinstance(t, str) and "re.match(" in t and len(t.strip()) > 10
|
||||||
|
) / max(len(test_list), 1)
|
||||||
|
reward = w_ecs * ecs + w_novelty * novelty + w_tests * tq
|
||||||
|
return reward, {
|
||||||
|
"ecs": round(ecs, 3),
|
||||||
|
"novelty": round(novelty, 3),
|
||||||
|
"test_quality": round(tq, 3),
|
||||||
|
"reward": round(reward, 3),
|
||||||
|
"detected": sorted(detected),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP.
|
SYSTEM_PROMPT = """Eres un experto en el lenguaje AVAP.
|
||||||
Se te proporciona el Language Reference Manual (LRM) completo de AVAP.
|
Se te proporciona el Language Reference Manual (LRM) completo de AVAP.
|
||||||
Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar
|
Tu tarea es generar UN problema de benchmark estilo MBPP para evaluar
|
||||||
|
|
@ -422,7 +523,7 @@ El siguiente ejemplo YA existe para esta combinación con calidad mejorable.
|
||||||
Genera algo DISTINTO y MÁS COMPLEJO que lo supere:
|
Genera algo DISTINTO y MÁS COMPLEJO que lo supere:
|
||||||
|
|
||||||
```
|
```
|
||||||
{existing_example.get('code', '')}
|
{existing_example.get("code", "")}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -457,6 +558,80 @@ Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _reward_coverage_text(dataset: list) -> tuple[str, list[str]]:
|
||||||
|
"""Build coverage summary and list of underrepresented constructs."""
|
||||||
|
freq: dict[str, int] = defaultdict(int)
|
||||||
|
for ex in dataset:
|
||||||
|
for nt in ex.get("_detected", []):
|
||||||
|
freq[nt] += 1
|
||||||
|
if not freq:
|
||||||
|
return (
|
||||||
|
"No examples generated yet. Cover as many AVAP constructs as possible.",
|
||||||
|
list(NODE_TYPE_NAMES),
|
||||||
|
)
|
||||||
|
covered = set(freq.keys())
|
||||||
|
uncovered = sorted(set(NODE_TYPE_NAMES) - covered)
|
||||||
|
least_covered = sorted(freq, key=freq.get)[:10]
|
||||||
|
underrepresented = uncovered + least_covered
|
||||||
|
lines = [
|
||||||
|
f"Total examples: {len(dataset)}",
|
||||||
|
f"Covered constructs: {len(covered)}/{len(NODE_TYPE_NAMES)}",
|
||||||
|
]
|
||||||
|
if uncovered:
|
||||||
|
lines.append(f"UNCOVERED (prioritize): {', '.join(uncovered)}")
|
||||||
|
if least_covered:
|
||||||
|
lines.append(f"Least common: {', '.join(least_covered)}")
|
||||||
|
return "\n".join(lines), underrepresented
|
||||||
|
|
||||||
|
|
||||||
|
def build_reward_prompt(
|
||||||
|
lrm: str,
|
||||||
|
few_shots: list[dict],
|
||||||
|
coverage_text: str,
|
||||||
|
underrepresented: list[str],
|
||||||
|
) -> str:
|
||||||
|
"""Build the user prompt for Candidate A reward mode."""
|
||||||
|
few_shot_block = ""
|
||||||
|
if few_shots:
|
||||||
|
few_shot_block = "\n# EJEMPLOS DE ALTA CALIDAD (referencia)\n\n"
|
||||||
|
for i, ex in enumerate(few_shots, 1):
|
||||||
|
few_shot_block += f"## Ejemplo {i}\n"
|
||||||
|
few_shot_block += f"Enunciado: {ex.get('text', '')}\n"
|
||||||
|
few_shot_block += f"```\n{ex.get('code', '')}\n```\n\n"
|
||||||
|
steer_text = ""
|
||||||
|
if underrepresented:
|
||||||
|
constructs_str = ", ".join(f"`{c}`" for c in underrepresented[:8])
|
||||||
|
steer_text = (
|
||||||
|
f"\nPRIORIDAD: Constructs subrepresentados — intenta incluir "
|
||||||
|
f"algunos de ellos: {constructs_str}\n"
|
||||||
|
)
|
||||||
|
return f"""# LRM AVAP — Language Reference Manual
|
||||||
|
|
||||||
|
{lrm}
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
# ESTADO DE COBERTURA DEL DATASET
|
||||||
|
|
||||||
|
{coverage_text}
|
||||||
|
{few_shot_block}
|
||||||
|
---
|
||||||
|
|
||||||
|
# TAREA
|
||||||
|
|
||||||
|
Genera UN ejemplo AVAP original y complejo como problema de benchmark.
|
||||||
|
|
||||||
|
Requisitos:
|
||||||
|
- Escenario realista de microservicio HTTP en AVAP
|
||||||
|
- Usa la mayor cantidad de constructs AVAP distintos posible (aumenta la puntuación)
|
||||||
|
- El ejemplo debe ser DIFERENTE a los ejemplos de referencia mostrados arriba
|
||||||
|
- Código complejo y rico — no ejemplos triviales de 3 líneas
|
||||||
|
- 2-3 aserciones re.match() en test_list
|
||||||
|
{steer_text}
|
||||||
|
Responde ÚNICAMENTE con el objeto JSON. Sin texto antes ni después.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def call_api(
|
def call_api(
|
||||||
client: anthropic.Anthropic,
|
client: anthropic.Anthropic,
|
||||||
lrm: str,
|
lrm: str,
|
||||||
|
|
@ -473,16 +648,22 @@ def call_api(
|
||||||
model="claude-sonnet-4-20250514",
|
model="claude-sonnet-4-20250514",
|
||||||
max_tokens=4000,
|
max_tokens=4000,
|
||||||
system=SYSTEM_PROMPT,
|
system=SYSTEM_PROMPT,
|
||||||
messages=[{
|
messages=[
|
||||||
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": build_cell_prompt(lrm, cell, existing_example, map_summary),
|
"content": build_cell_prompt(
|
||||||
}],
|
lrm, cell, existing_example, map_summary
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
)
|
)
|
||||||
raw = message.content[0].text.strip()
|
raw = message.content[0].text.strip()
|
||||||
|
|
||||||
if raw.startswith("```"):
|
if raw.startswith("```"):
|
||||||
lines = raw.splitlines()
|
lines = raw.splitlines()
|
||||||
raw = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
|
raw = "\n".join(
|
||||||
|
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
|
||||||
|
)
|
||||||
|
|
||||||
problem = json.loads(raw)
|
problem = json.loads(raw)
|
||||||
if not isinstance(problem, dict):
|
if not isinstance(problem, dict):
|
||||||
|
|
@ -511,6 +692,65 @@ def call_api(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def call_api_reward(
|
||||||
|
client: anthropic.Anthropic,
|
||||||
|
lrm: str,
|
||||||
|
few_shots: list[dict],
|
||||||
|
coverage_text: str,
|
||||||
|
underrepresented: list[str],
|
||||||
|
task_id: int,
|
||||||
|
retries: int = 3,
|
||||||
|
) -> dict | None:
|
||||||
|
"""API call for Candidate A reward mode."""
|
||||||
|
for attempt in range(1, retries + 1):
|
||||||
|
try:
|
||||||
|
message = client.messages.create(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
max_tokens=4000,
|
||||||
|
system=SYSTEM_PROMPT,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": build_reward_prompt(
|
||||||
|
lrm,
|
||||||
|
few_shots,
|
||||||
|
coverage_text,
|
||||||
|
underrepresented,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
raw = message.content[0].text.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
lines = raw.splitlines()
|
||||||
|
raw = "\n".join(
|
||||||
|
lines[1:-1] if lines[-1].strip() == "```" else lines[1:]
|
||||||
|
)
|
||||||
|
problem = json.loads(raw)
|
||||||
|
if not isinstance(problem, dict):
|
||||||
|
raise ValueError("Response is not a JSON object")
|
||||||
|
for field in ("text", "code", "test_list"):
|
||||||
|
if field not in problem:
|
||||||
|
raise ValueError(f"Missing field '{field}'")
|
||||||
|
if "test_inputs" not in problem:
|
||||||
|
problem["test_inputs"] = {}
|
||||||
|
problem["task_id"] = task_id
|
||||||
|
return problem
|
||||||
|
except (json.JSONDecodeError, ValueError) as e:
|
||||||
|
print(f"\n Attempt {attempt}/{retries} — parse error: {e}")
|
||||||
|
if attempt < retries:
|
||||||
|
time.sleep(2**attempt)
|
||||||
|
except anthropic.RateLimitError:
|
||||||
|
wait = 30 * attempt
|
||||||
|
print(f"\n Rate limit — waiting {wait}s...")
|
||||||
|
time.sleep(wait)
|
||||||
|
except anthropic.APIError as e:
|
||||||
|
print(f"\n API error at attempt {attempt}: {e}")
|
||||||
|
if attempt < retries:
|
||||||
|
time.sleep(5)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def run_map_elites(args, client, lrm, output_path):
|
def run_map_elites(args, client, lrm, output_path):
|
||||||
|
|
||||||
validator = CellValidator(parser_url=args.parser)
|
validator = CellValidator(parser_url=args.parser)
|
||||||
|
|
@ -522,14 +762,17 @@ def run_map_elites(args, client, lrm, output_path):
|
||||||
valid_count = 0
|
valid_count = 0
|
||||||
cell_updates = 0
|
cell_updates = 0
|
||||||
|
|
||||||
print(f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples")
|
print(
|
||||||
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}")
|
f"\n MAP-Elites mode | cells: {cmap.total_cells} | target: {args.problems} examples"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
|
||||||
|
)
|
||||||
print("─" * 65)
|
print("─" * 65)
|
||||||
|
|
||||||
max_calls = args.problems * 4
|
max_calls = args.problems * 4
|
||||||
|
|
||||||
while len(dataset) < args.problems and call_count < max_calls:
|
while len(dataset) < args.problems and call_count < max_calls:
|
||||||
|
|
||||||
cell = selector.select()
|
cell = selector.select()
|
||||||
existing = cmap.get_example(cell)
|
existing = cmap.get_example(cell)
|
||||||
call_count += 1
|
call_count += 1
|
||||||
|
|
@ -538,11 +781,15 @@ def run_map_elites(args, client, lrm, output_path):
|
||||||
f" [{call_count:04d}] Cell {sorted(cell)} "
|
f" [{call_count:04d}] Cell {sorted(cell)} "
|
||||||
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
||||||
f"| dataset={len(dataset)} ... ",
|
f"| dataset={len(dataset)} ... ",
|
||||||
end="", flush=True,
|
end="",
|
||||||
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
problem = call_api(
|
problem = call_api(
|
||||||
client, lrm, cell, task_id,
|
client,
|
||||||
|
lrm,
|
||||||
|
cell,
|
||||||
|
task_id,
|
||||||
existing_example=existing,
|
existing_example=existing,
|
||||||
map_summary=cmap.fill_summary(),
|
map_summary=cmap.fill_summary(),
|
||||||
)
|
)
|
||||||
|
|
@ -559,7 +806,7 @@ def run_map_elites(args, client, lrm, output_path):
|
||||||
if is_valid is None:
|
if is_valid is None:
|
||||||
is_valid, ast = True, {}
|
is_valid, ast = True, {}
|
||||||
if call_count == 1:
|
if call_count == 1:
|
||||||
print(f"\n Parser unavailable — using keyword fallback", flush=True)
|
print("\n Parser unavailable — using keyword fallback", flush=True)
|
||||||
|
|
||||||
if is_valid is False:
|
if is_valid is False:
|
||||||
print(f"INVALID ({error_msg[:40]})")
|
print(f"INVALID ({error_msg[:40]})")
|
||||||
|
|
@ -570,8 +817,13 @@ def run_map_elites(args, client, lrm, output_path):
|
||||||
|
|
||||||
# Compute cell quality
|
# Compute cell quality
|
||||||
quality, components = validator.cell_quality(
|
quality, components = validator.cell_quality(
|
||||||
code, ast, test_list, cell,
|
code,
|
||||||
alpha=args.alpha, beta=args.beta, gamma=args.gamma,
|
ast,
|
||||||
|
test_list,
|
||||||
|
cell,
|
||||||
|
alpha=args.alpha,
|
||||||
|
beta=args.beta,
|
||||||
|
gamma=args.gamma,
|
||||||
)
|
)
|
||||||
problem["_cell"] = sorted(cell)
|
problem["_cell"] = sorted(cell)
|
||||||
problem["_quality"] = components
|
problem["_quality"] = components
|
||||||
|
|
@ -598,18 +850,21 @@ def run_map_elites(args, client, lrm, output_path):
|
||||||
_save(dataset, output_path, cmap)
|
_save(dataset, output_path, cmap)
|
||||||
freq = cmap.node_type_frequency()
|
freq = cmap.node_type_frequency()
|
||||||
entropy = cmap.distribution_entropy()
|
entropy = cmap.distribution_entropy()
|
||||||
print(f"\n ── Checkpoint ──────────────────────────────────")
|
print("\n ── Checkpoint ──────────────────────────────────")
|
||||||
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
||||||
print(f" {cmap.fill_summary()}")
|
print(f" {cmap.fill_summary()}")
|
||||||
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
print(
|
||||||
|
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
|
||||||
|
)
|
||||||
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
||||||
print(f" ────────────────────────────────────────────────\n")
|
print(" ────────────────────────────────────────────────\n")
|
||||||
|
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
_save(dataset, output_path, cmap)
|
_save(dataset, output_path, cmap)
|
||||||
return dataset, cmap, valid_count, call_count
|
return dataset, cmap, valid_count, call_count
|
||||||
|
|
||||||
|
|
||||||
def run_map_elites_prior(args, client, lrm, output_path):
|
def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
|
|
||||||
print("\n Loading ConstructPrior...", flush=True)
|
print("\n Loading ConstructPrior...", flush=True)
|
||||||
|
|
@ -622,8 +877,10 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
else:
|
else:
|
||||||
# Fallback: yaml not found — use static prior and warn
|
# Fallback: yaml not found — use static prior and warn
|
||||||
print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.")
|
print(f" [WARN] construct_map.yaml not found at '{yaml_path}'.")
|
||||||
print(f" [WARN] Using static fallback prior. Generate the real prior with:")
|
print(" [WARN] Using static fallback prior. Generate the real prior with:")
|
||||||
print(f" [WARN] python construct_prior.py --generate-map --github-token TOKEN")
|
print(
|
||||||
|
" [WARN] python construct_prior.py --generate-map --github-token TOKEN"
|
||||||
|
)
|
||||||
prior = ConstructPrior.from_static_fallback(epsilon=epsilon)
|
prior = ConstructPrior.from_static_fallback(epsilon=epsilon)
|
||||||
|
|
||||||
print(f" {prior.coverage_summary()}")
|
print(f" {prior.coverage_summary()}")
|
||||||
|
|
@ -631,7 +888,8 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
validator = CellValidator(parser_url=args.parser)
|
validator = CellValidator(parser_url=args.parser)
|
||||||
cmap = CoverageMap(cell_size=args.cell_size)
|
cmap = CoverageMap(cell_size=args.cell_size)
|
||||||
selector = CellSelectorPrior(
|
selector = CellSelectorPrior(
|
||||||
cmap, prior,
|
cmap,
|
||||||
|
prior,
|
||||||
quality_threshold=args.quality_threshold,
|
quality_threshold=args.quality_threshold,
|
||||||
phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70),
|
phase3_threshold=getattr(args, "prior_phase3_threshold", 0.70),
|
||||||
)
|
)
|
||||||
|
|
@ -641,14 +899,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
valid_count = 0
|
valid_count = 0
|
||||||
cell_updates = 0
|
cell_updates = 0
|
||||||
|
|
||||||
print(f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples")
|
print(
|
||||||
print(f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}")
|
f"\n MAP-Elites+Prior mode | cells: {cmap.total_cells} | target: {args.problems} examples"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Cell size: {args.cell_size} | Quality threshold: {args.quality_threshold}"
|
||||||
|
)
|
||||||
print("─" * 65)
|
print("─" * 65)
|
||||||
|
|
||||||
max_calls = args.problems * 4
|
max_calls = args.problems * 4
|
||||||
|
|
||||||
while len(dataset) < args.problems and call_count < max_calls:
|
while len(dataset) < args.problems and call_count < max_calls:
|
||||||
|
|
||||||
cell = selector.select()
|
cell = selector.select()
|
||||||
existing = cmap.get_example(cell)
|
existing = cmap.get_example(cell)
|
||||||
prior_w = prior.cell_weight(cell)
|
prior_w = prior.cell_weight(cell)
|
||||||
|
|
@ -659,11 +920,15 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
f"| prior={prior_w:.3f} "
|
f"| prior={prior_w:.3f} "
|
||||||
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
f"| filled={cmap.filled_cells}/{cmap.total_cells} "
|
||||||
f"| dataset={len(dataset)} ... ",
|
f"| dataset={len(dataset)} ... ",
|
||||||
end="", flush=True,
|
end="",
|
||||||
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
problem = call_api(
|
problem = call_api(
|
||||||
client, lrm, cell, task_id,
|
client,
|
||||||
|
lrm,
|
||||||
|
cell,
|
||||||
|
task_id,
|
||||||
existing_example=existing,
|
existing_example=existing,
|
||||||
map_summary=cmap.fill_summary(),
|
map_summary=cmap.fill_summary(),
|
||||||
)
|
)
|
||||||
|
|
@ -680,7 +945,7 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
if is_valid is None:
|
if is_valid is None:
|
||||||
is_valid, ast = True, {}
|
is_valid, ast = True, {}
|
||||||
if call_count == 1:
|
if call_count == 1:
|
||||||
print(f"\n Parser unavailable — using keyword fallback", flush=True)
|
print("\n Parser unavailable — using keyword fallback", flush=True)
|
||||||
|
|
||||||
if is_valid is False:
|
if is_valid is False:
|
||||||
print(f"INVALID ({error_msg[:40]})")
|
print(f"INVALID ({error_msg[:40]})")
|
||||||
|
|
@ -690,8 +955,13 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
valid_count += 1
|
valid_count += 1
|
||||||
|
|
||||||
quality, components = validator.cell_quality(
|
quality, components = validator.cell_quality(
|
||||||
code, ast, test_list, cell,
|
code,
|
||||||
alpha=args.alpha, beta=args.beta, gamma=args.gamma,
|
ast,
|
||||||
|
test_list,
|
||||||
|
cell,
|
||||||
|
alpha=args.alpha,
|
||||||
|
beta=args.beta,
|
||||||
|
gamma=args.gamma,
|
||||||
)
|
)
|
||||||
problem["_cell"] = sorted(cell)
|
problem["_cell"] = sorted(cell)
|
||||||
problem["_prior_weight"] = round(prior_w, 4)
|
problem["_prior_weight"] = round(prior_w, 4)
|
||||||
|
|
@ -721,13 +991,17 @@ def run_map_elites_prior(args, client, lrm, output_path):
|
||||||
freq = cmap.node_type_frequency()
|
freq = cmap.node_type_frequency()
|
||||||
entropy = cmap.distribution_entropy()
|
entropy = cmap.distribution_entropy()
|
||||||
kl = prior.kl_divergence(freq)
|
kl = prior.kl_divergence(freq)
|
||||||
print(f"\n ── Checkpoint ──────────────────────────────────")
|
print("\n ── Checkpoint ──────────────────────────────────")
|
||||||
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
||||||
print(f" {cmap.fill_summary()}")
|
print(f" {cmap.fill_summary()}")
|
||||||
print(f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)")
|
print(
|
||||||
print(f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
f" KL(dataset ‖ prior): {kl:.4f} (lower = closer to production patterns)"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Top-5 most frequent: {sorted(freq, key=freq.get, reverse=True)[:5]}"
|
||||||
|
)
|
||||||
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
print(f" Top-5 least frequent: {sorted(freq, key=freq.get)[:5]}")
|
||||||
print(f" ────────────────────────────────────────────────\n")
|
print(" ────────────────────────────────────────────────\n")
|
||||||
|
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
|
@ -757,6 +1031,170 @@ def _save(dataset: list, path: Path, cmap: CoverageMap, prior: ConstructPrior =
|
||||||
with open(stats_path, "w", encoding="utf-8") as f:
|
with open(stats_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(stats, f, ensure_ascii=False, indent=2)
|
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def run_reward(args, client, lrm, output_path):
|
||||||
|
"""Candidate A — CW-Reward with GoldPool feedback loop."""
|
||||||
|
validator = CellValidator(parser_url=args.parser)
|
||||||
|
pool = GoldPool(max_size=args.pool_size)
|
||||||
|
dataset = []
|
||||||
|
task_id = 1
|
||||||
|
call_count = 0
|
||||||
|
valid_count = 0
|
||||||
|
pool_entries = 0
|
||||||
|
|
||||||
|
w_ecs = args.w_ecs
|
||||||
|
w_novelty = args.w_novelty
|
||||||
|
w_tests = args.w_tests
|
||||||
|
|
||||||
|
print(f"\n CW-Reward mode | target: {args.problems} examples")
|
||||||
|
print(f" Weights: ECS={w_ecs} Novelty={w_novelty} Tests={w_tests}")
|
||||||
|
print(f" Pool size: {args.pool_size}")
|
||||||
|
print("─" * 65)
|
||||||
|
|
||||||
|
max_calls = args.problems * 4
|
||||||
|
|
||||||
|
while len(dataset) < args.problems and call_count < max_calls:
|
||||||
|
coverage_text, underrepresented = _reward_coverage_text(dataset)
|
||||||
|
few_shots = pool.few_shot_examples(n=3)
|
||||||
|
call_count += 1
|
||||||
|
|
||||||
|
print(
|
||||||
|
f" [{call_count:04d}] "
|
||||||
|
f"| {pool.pool_summary()} "
|
||||||
|
f"| dataset={len(dataset)} ... ",
|
||||||
|
end="",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
problem = call_api_reward(
|
||||||
|
client,
|
||||||
|
lrm,
|
||||||
|
few_shots,
|
||||||
|
coverage_text,
|
||||||
|
underrepresented,
|
||||||
|
task_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if problem is None:
|
||||||
|
print("SKIP (generation failed)")
|
||||||
|
continue
|
||||||
|
|
||||||
|
code = problem["code"]
|
||||||
|
test_list = problem.get("test_list", [])
|
||||||
|
|
||||||
|
is_valid, ast, error_msg = validator.parse(code)
|
||||||
|
|
||||||
|
if is_valid is None:
|
||||||
|
is_valid, ast = True, {}
|
||||||
|
if call_count == 1:
|
||||||
|
print(
|
||||||
|
"\n Parser unavailable — using keyword fallback",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_valid is False:
|
||||||
|
print(f"INVALID ({error_msg[:40]})")
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid_count += 1
|
||||||
|
detected = validator.detect_constructs(code, ast)
|
||||||
|
|
||||||
|
reward, components = compute_reward(
|
||||||
|
detected,
|
||||||
|
test_list,
|
||||||
|
pool,
|
||||||
|
w_ecs=w_ecs,
|
||||||
|
w_novelty=w_novelty,
|
||||||
|
w_tests=w_tests,
|
||||||
|
)
|
||||||
|
|
||||||
|
problem["_detected"] = sorted(detected)
|
||||||
|
problem["_reward"] = components
|
||||||
|
|
||||||
|
entered = pool.add(problem, reward)
|
||||||
|
if entered:
|
||||||
|
pool_entries += 1
|
||||||
|
|
||||||
|
dataset.append(problem)
|
||||||
|
task_id += 1
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"OK reward={reward:.3f} "
|
||||||
|
f"ecs={components['ecs']:.2f} "
|
||||||
|
f"novelty={components['novelty']:.2f} "
|
||||||
|
f"{'→ POOL' if entered else ''}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(dataset) % 50 == 0:
|
||||||
|
_save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
|
||||||
|
freq: dict[str, int] = defaultdict(int)
|
||||||
|
for ex in dataset:
|
||||||
|
for nt in ex.get("_detected", []):
|
||||||
|
freq[nt] += 1
|
||||||
|
total_f = sum(freq.values())
|
||||||
|
entropy = 0.0
|
||||||
|
if total_f > 0:
|
||||||
|
for count in freq.values():
|
||||||
|
p = count / total_f
|
||||||
|
if p > 0:
|
||||||
|
entropy -= p * math.log2(p)
|
||||||
|
print("\n ── Checkpoint ──────────────────────────────────")
|
||||||
|
print(f" Dataset: {len(dataset)} | Valid: {valid_count}/{call_count}")
|
||||||
|
print(f" {pool.pool_summary()}")
|
||||||
|
print(
|
||||||
|
f" Coverage: {len(freq)}/{len(NODE_TYPE_NAMES)} | Entropy: {entropy:.2f} bits"
|
||||||
|
)
|
||||||
|
uncov = sorted(set(NODE_TYPE_NAMES) - set(freq.keys()))
|
||||||
|
print(f" Uncovered: {uncov[:10]}")
|
||||||
|
print(" ────────────────────────────────────────────────\n")
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
_save_reward(dataset, output_path, pool, w_ecs, w_novelty, w_tests)
|
||||||
|
return dataset, pool, valid_count, call_count
|
||||||
|
|
||||||
|
|
||||||
|
def _save_reward(
|
||||||
|
dataset: list,
|
||||||
|
path: Path,
|
||||||
|
pool: GoldPool,
|
||||||
|
w_ecs: float,
|
||||||
|
w_novelty: float,
|
||||||
|
w_tests: float,
|
||||||
|
):
|
||||||
|
"""Save dataset and stats for Candidate A reward mode."""
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(dataset, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
stats_path = path.with_name(path.stem + "_reward_stats.json")
|
||||||
|
freq: dict[str, int] = defaultdict(int)
|
||||||
|
for ex in dataset:
|
||||||
|
for nt in ex.get("_detected", []):
|
||||||
|
freq[nt] += 1
|
||||||
|
total_f = sum(freq.values())
|
||||||
|
entropy = 0.0
|
||||||
|
if total_f > 0:
|
||||||
|
for count in freq.values():
|
||||||
|
p = count / total_f
|
||||||
|
if p > 0:
|
||||||
|
entropy -= p * math.log2(p)
|
||||||
|
rewards = [ex.get("_reward", {}).get("reward", 0) for ex in dataset]
|
||||||
|
stats = {
|
||||||
|
"mode": "reward",
|
||||||
|
"weights": {"w_ecs": w_ecs, "w_novelty": w_novelty, "w_tests": w_tests},
|
||||||
|
"dataset_size": len(dataset),
|
||||||
|
"pool_size": pool.size,
|
||||||
|
"pool_summary": pool.pool_summary(),
|
||||||
|
"distribution_entropy": round(entropy, 3),
|
||||||
|
"node_type_frequency": dict(freq),
|
||||||
|
"covered_constructs": len(freq),
|
||||||
|
"total_constructs": len(NODE_TYPE_NAMES),
|
||||||
|
"mean_reward": round(sum(rewards) / max(len(rewards), 1), 4),
|
||||||
|
}
|
||||||
|
with open(stats_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline"
|
description="AVAP Dataset Generator v2 — MAP-Elites Quality-Diversity Pipeline"
|
||||||
|
|
@ -764,18 +1202,39 @@ def main():
|
||||||
parser.add_argument("--lrm", default="avap.md")
|
parser.add_argument("--lrm", default="avap.md")
|
||||||
parser.add_argument("--output", default="output/mbpp_avap_v2.json")
|
parser.add_argument("--output", default="output/mbpp_avap_v2.json")
|
||||||
parser.add_argument("--problems", type=int, default=5000)
|
parser.add_argument("--problems", type=int, default=5000)
|
||||||
parser.add_argument("--parser", default="http://localhost:8080",
|
parser.add_argument(
|
||||||
help="AVAP parser URL")
|
"--parser", default="http://localhost:8080", help="AVAP parser URL"
|
||||||
parser.add_argument("--cell-size", type=int, default=3,
|
)
|
||||||
help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)")
|
parser.add_argument(
|
||||||
parser.add_argument("--quality-threshold", type=float, default=0.80,
|
"--cell-size",
|
||||||
help="Min quality to consider a cell 'good' (default: 0.80)")
|
type=int,
|
||||||
parser.add_argument("--alpha", type=float, default=0.30,
|
default=3,
|
||||||
help="Weight for bonus constructs in cell quality (default: 0.30)")
|
help="Max constructs per cell: 2=pairs, 3=pairs+trios (default: 3)",
|
||||||
parser.add_argument("--beta", type=float, default=0.20,
|
)
|
||||||
help="Weight for test quality in cell quality (default: 0.20)")
|
parser.add_argument(
|
||||||
parser.add_argument("--gamma", type=float, default=0.10,
|
"--quality-threshold",
|
||||||
help="Weight for code richness in cell quality (default: 0.10)")
|
type=float,
|
||||||
|
default=0.80,
|
||||||
|
help="Min quality to consider a cell 'good' (default: 0.80)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha",
|
||||||
|
type=float,
|
||||||
|
default=0.30,
|
||||||
|
help="Weight for bonus constructs in cell quality (default: 0.30)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--beta",
|
||||||
|
type=float,
|
||||||
|
default=0.20,
|
||||||
|
help="Weight for test quality in cell quality (default: 0.20)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gamma",
|
||||||
|
type=float,
|
||||||
|
default=0.10,
|
||||||
|
help="Weight for code richness in cell quality (default: 0.10)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mode",
|
"--mode",
|
||||||
choices=["map-elites-prior", "map-elites", "reward"],
|
choices=["map-elites-prior", "map-elites", "reward"],
|
||||||
|
|
@ -811,6 +1270,30 @@ def main():
|
||||||
"cells become the focus. Default: 0.70"
|
"cells become the focus. Default: 0.70"
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--w-ecs",
|
||||||
|
type=float,
|
||||||
|
default=0.50,
|
||||||
|
help="Candidate A: weight for ECS in reward formula (default: 0.50)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--w-novelty",
|
||||||
|
type=float,
|
||||||
|
default=0.35,
|
||||||
|
help="Candidate A: weight for Jaccard novelty in reward formula (default: 0.35)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--w-tests",
|
||||||
|
type=float,
|
||||||
|
default=0.15,
|
||||||
|
help="Candidate A: weight for test quality in reward formula (default: 0.15)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pool-size",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
help="Candidate A: max examples in GoldPool (default: 50)",
|
||||||
|
)
|
||||||
parser.add_argument("--api-key", default=None)
|
parser.add_argument("--api-key", default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -846,36 +1329,72 @@ def main():
|
||||||
print(f" Quality thresh : {args.quality_threshold}")
|
print(f" Quality thresh : {args.quality_threshold}")
|
||||||
if args.mode == "map-elites-prior":
|
if args.mode == "map-elites-prior":
|
||||||
yaml_exists = Path(args.prior_map).exists()
|
yaml_exists = Path(args.prior_map).exists()
|
||||||
print(f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})")
|
print(
|
||||||
|
f" Prior map : {args.prior_map} ({'✓ found' if yaml_exists else '✗ not found — will use static fallback'})"
|
||||||
|
)
|
||||||
print(f" Prior epsilon : {args.prior_epsilon}")
|
print(f" Prior epsilon : {args.prior_epsilon}")
|
||||||
|
elif args.mode == "reward":
|
||||||
|
print(
|
||||||
|
f" Reward weights : ECS={args.w_ecs} Novelty={args.w_novelty} Tests={args.w_tests}"
|
||||||
|
)
|
||||||
|
print(f" Pool size : {args.pool_size}")
|
||||||
print("=" * 65)
|
print("=" * 65)
|
||||||
|
|
||||||
prior = None
|
prior = None
|
||||||
|
pool = None
|
||||||
|
cmap = None
|
||||||
|
|
||||||
if args.mode == "map-elites-prior":
|
if args.mode == "map-elites-prior":
|
||||||
result = run_map_elites_prior(args, client, lrm, output_path)
|
result = run_map_elites_prior(args, client, lrm, output_path)
|
||||||
dataset, cmap, valid_count, call_count, prior = result
|
dataset, cmap, valid_count, call_count, prior = result
|
||||||
elif args.mode == "map-elites":
|
elif args.mode == "map-elites":
|
||||||
dataset, cmap, valid_count, call_count = run_map_elites(args, client, lrm, output_path)
|
dataset, cmap, valid_count, call_count = run_map_elites(
|
||||||
|
args, client, lrm, output_path
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
sys.exit("ERROR: --mode reward (Candidate A) is not yet implemented in v2. "
|
dataset, pool, valid_count, call_count = run_reward(
|
||||||
"Use generate_mbap.py for the v1 reward baseline.")
|
args, client, lrm, output_path
|
||||||
|
)
|
||||||
|
|
||||||
# Final report
|
# Final report
|
||||||
|
if cmap is not None:
|
||||||
freq = cmap.node_type_frequency()
|
freq = cmap.node_type_frequency()
|
||||||
entropy = cmap.distribution_entropy()
|
entropy = cmap.distribution_entropy()
|
||||||
|
else:
|
||||||
|
freq = defaultdict(int)
|
||||||
|
for ex in dataset:
|
||||||
|
for nt in ex.get("_detected", []):
|
||||||
|
freq[nt] += 1
|
||||||
|
freq = dict(freq)
|
||||||
|
total_f = sum(freq.values())
|
||||||
|
entropy = 0.0
|
||||||
|
if total_f > 0:
|
||||||
|
for count in freq.values():
|
||||||
|
p = count / total_f
|
||||||
|
if p > 0:
|
||||||
|
entropy -= p * math.log2(p)
|
||||||
|
entropy = round(entropy, 3)
|
||||||
|
|
||||||
print("\n" + "=" * 65)
|
print("\n" + "=" * 65)
|
||||||
print(" Pipeline complete")
|
print(" Pipeline complete")
|
||||||
print(f" Mode : {mode_label}")
|
print(f" Mode : {mode_label}")
|
||||||
print(f" Total API calls : {call_count}")
|
print(f" Total API calls : {call_count}")
|
||||||
print(f" Valid examples : {valid_count} ({100*valid_count/max(call_count,1):.1f}%)")
|
print(
|
||||||
|
f" Valid examples : {valid_count} ({100 * valid_count / max(call_count, 1):.1f}%)"
|
||||||
|
)
|
||||||
print(f" Dataset size : {len(dataset)}")
|
print(f" Dataset size : {len(dataset)}")
|
||||||
|
if cmap is not None:
|
||||||
print(f" {cmap.fill_summary()}")
|
print(f" {cmap.fill_summary()}")
|
||||||
print(f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})")
|
if pool is not None:
|
||||||
|
print(f" {pool.pool_summary()}")
|
||||||
|
print(
|
||||||
|
f" Distribution entropy : {entropy:.3f} bits (max={math.log2(len(NODE_TYPE_NAMES)):.2f})"
|
||||||
|
)
|
||||||
if prior is not None:
|
if prior is not None:
|
||||||
kl = prior.kl_divergence(freq)
|
kl = prior.kl_divergence(freq)
|
||||||
print(f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)")
|
print(
|
||||||
|
f" KL(dataset ‖ prior) : {kl:.4f} (0 = perfect alignment with production code)"
|
||||||
|
)
|
||||||
print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
print(f" Most covered : {sorted(freq, key=freq.get, reverse=True)[:5]}")
|
||||||
print(f" Least covered : {sorted(freq, key=freq.get)[:5]}")
|
print(f" Least covered : {sorted(freq, key=freq.get)[:5]}")
|
||||||
print(f" Output : {output_path}")
|
print(f" Output : {output_path}")
|
||||||
|
|
@ -884,3 +1403,4 @@ def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue