assistance-engine/scripts/pipelines/samples_generator/construct_prior.py

1049 lines
41 KiB
Python

#!/usr/bin/env python3
import argparse
import ast as pyast
import base64
import json
import math
import os
import sys
import time
from collections import defaultdict
from datetime import datetime, timezone
from itertools import combinations
from pathlib import Path
try:
import yaml
except ImportError:
print("ERROR: pyyaml not installed. Run: pip install pyyaml")
sys.exit(1)
try:
import requests
except ImportError:
print("ERROR: requests not installed. Run: pip install requests")
sys.exit(1)
# Canonical list of AVAP command (node) names. Grouped by theme: declarations,
# list/string helpers, control flow, concurrency, connectors, ORM, JSON,
# crypto, datetime, HTTP, functions, and the module system. The list's length
# and contents are recorded in the generated construct_map.yaml metadata.
AVAP_NODE_NAMES: list[str] = [
    "addParam", "addResult", "addVar", "_status",
    "getListLen", "getQueryParamList", "itemFromList",
    "replace", "randomString",
    "if_mode1", "if_mode2", "else", "end",
    "startLoop", "endLoop",
    "try", "exception",
    "return",
    "go", "gather",
    "avapConnector",
    "ormCheckTable", "ormDirect",
    "ormAccessSelect", "ormAccessInsert", "ormAccessUpdate",
    "variableFromJSON", "AddVariableToJSON",
    "encodeSHA256", "encodeMD5",
    "getTimeStamp", "getDateTime", "stampToDatetime",
    "RequestGet", "RequestPost",
    "function",
    "import", "include",
]
# Maps each AVAP command to its closest equivalents in mainstream languages.
# Per entry:
#   description      — human-readable summary of the AVAP command
#   python_ast_calls — callee patterns (substring/suffix heuristics) for Python
#                      call expressions; PythonASTDetector keeps its own copies
#                      of these pattern sets
#   python_ast_node  — structural ast node(s) the command corresponds to
#   go_keywords      — Go source fragments indicating the same semantics
#   sql_keywords     — SQL keywords indicating the same semantics
# This dict is embedded verbatim into the generated construct_map.yaml.
LANGUAGE_MAPPINGS: dict[str, dict] = {
    "ormAccessSelect": {
        "description": "ORM read/query operation — SELECT semantics",
        "python_ast_calls": [
            ".fetchall", ".fetchone", ".fetchmany",
            ".query", ".filter", ".filter_by", ".all", ".first",
            ".execute", ".select",
        ],
        "go_keywords": ["db.Query(", "rows.Scan("],
        "sql_keywords": ["SELECT ", "JOIN ", "WHERE "],
    },
    "ormAccessInsert": {
        "description": "ORM write operation — INSERT semantics",
        "python_ast_calls": [".add", ".insert", ".bulk_insert_mappings", ".create"],
        "go_keywords": ['db.Exec("INSERT'],
        "sql_keywords": ["INSERT INTO"],
    },
    "ormAccessUpdate": {
        "description": "ORM update/delete operation — UPDATE/DELETE semantics",
        "python_ast_calls": [".update", ".merge", ".save", ".commit"],
        "go_keywords": ['db.Exec("UPDATE'],
        "sql_keywords": ["UPDATE ", "DELETE FROM"],
    },
    "ormCheckTable": {
        "description": "Check if a table exists before operating on it",
        "python_ast_calls": [
            "inspect.has_table", "engine.dialect.has_table",
            "inspector.get_table_names",
        ],
        "go_keywords": ["db.QueryRow(\"SELECT EXISTS"],
        "sql_keywords": ["SHOW TABLES", "information_schema.tables"],
    },
    "ormDirect": {
        "description": "Raw/direct SQL execution — bypasses ORM abstraction",
        "python_ast_calls": [
            "cursor.execute", "connection.execute",
            "db.execute", "session.execute",
        ],
        "go_keywords": ["db.Exec(", "db.QueryRow("],
        "sql_keywords": ["EXECUTE ", "CALL "],
    },
    "RequestGet": {
        "description": "HTTP GET request",
        "python_ast_calls": [
            "requests.get", "httpx.get", "session.get", "client.get",
            "aiohttp.ClientSession.get",
        ],
        "go_keywords": ["http.Get(", "client.Get("],
        "sql_keywords": [],
    },
    "RequestPost": {
        "description": "HTTP POST request",
        "python_ast_calls": [
            "requests.post", "httpx.post", "session.post", "client.post",
            "aiohttp.ClientSession.post",
        ],
        "go_keywords": ["http.Post(", "client.Post("],
        "sql_keywords": [],
    },
    "try": {
        "description": "Exception handling block — try",
        "python_ast_node": "ast.Try",
        "go_keywords": ["if err != nil"],
        "sql_keywords": ["BEGIN TRY"],
    },
    "exception": {
        "description": "Exception handler — except/catch clause",
        "python_ast_node": "ast.ExceptHandler",
        "go_keywords": ["if err != nil"],
        "sql_keywords": ["BEGIN CATCH"],
    },
    "startLoop": {
        "description": "Loop / iteration construct",
        "python_ast_node": "ast.For / ast.AsyncFor",
        "go_keywords": ["for _, v := range", "for i, "],
        "sql_keywords": ["CURSOR LOOP"],
    },
    "endLoop": {
        "description": "End of loop block (AVAP explicit close)",
        "python_ast_node": "end of ast.For scope",
        "go_keywords": [],
        "sql_keywords": [],
    },
    "function": {
        "description": "Function definition",
        "python_ast_node": "ast.FunctionDef / ast.AsyncFunctionDef",
        "go_keywords": ["func "],
        "sql_keywords": ["CREATE FUNCTION", "CREATE PROCEDURE"],
    },
    "return": {
        "description": "Return statement",
        "python_ast_node": "ast.Return",
        "go_keywords": ["return "],
        "sql_keywords": ["RETURN "],
    },
    "if_mode1": {
        "description": "Conditional — if(var, comparison, operator) form",
        "python_ast_node": "ast.If",
        "go_keywords": ["if "],
        "sql_keywords": ["IF ", "CASE WHEN"],
    },
    "if_mode2": {
        "description": "Conditional — if(None, None, expression) form",
        "python_ast_node": "ast.If (complex condition)",
        "go_keywords": [],
        "sql_keywords": [],
    },
    "else": {
        "description": "Else branch of a conditional",
        "python_ast_node": "ast.If.orelse",
        "go_keywords": ["} else {"],
        "sql_keywords": ["ELSE"],
    },
    "end": {
        "description": "Block terminator (AVAP explicit close)",
        "python_ast_node": "end of ast.If scope",
        "go_keywords": [],
        "sql_keywords": ["END IF", "END"],
    },
    "go": {
        "description": "Async/concurrent task launch",
        "python_ast_calls": [
            "asyncio.create_task", "asyncio.ensure_future",
            "ThreadPoolExecutor", "executor.submit",
        ],
        "go_keywords": ["go func(", "go "],
        "sql_keywords": [],
    },
    "gather": {
        "description": "Wait for concurrent tasks to complete",
        "python_ast_calls": [
            "asyncio.gather", "asyncio.wait",
            "executor.map", "wg.Wait",
        ],
        "go_keywords": ["sync.WaitGroup", "wg.Wait()"],
        "sql_keywords": [],
    },
    "avapConnector": {
        "description": "AVAP connector — external service integration point",
        "python_ast_calls": [],
        "go_keywords": [],
        "sql_keywords": [],
        "note": "No direct mainstream language equivalent. Rare in co-occurrence.",
    },
    "encodeSHA256": {
        "description": "SHA-256 hashing",
        "python_ast_calls": ["hashlib.sha256", "sha256", ".hexdigest"],
        "go_keywords": ["sha256.New()", "sha256.Sum256("],
        "sql_keywords": ["SHA2(", "HASHBYTES('SHA2_256'"],
    },
    "encodeMD5": {
        "description": "MD5 hashing",
        "python_ast_calls": ["hashlib.md5", "md5", ".hexdigest"],
        "go_keywords": ["md5.New()", "md5.Sum("],
        "sql_keywords": ["MD5(", "HASHBYTES('MD5'"],
    },
    "variableFromJSON": {
        "description": "Parse JSON string into variable",
        "python_ast_calls": ["json.loads", "json.load", "orjson.loads", "ujson.loads"],
        "go_keywords": ["json.Unmarshal("],
        "sql_keywords": ["JSON_VALUE(", "JSON_EXTRACT("],
    },
    "AddVariableToJSON": {
        "description": "Serialize variable to JSON string",
        "python_ast_calls": ["json.dumps", "json.dump", "orjson.dumps", "ujson.dumps"],
        "go_keywords": ["json.Marshal("],
        "sql_keywords": ["JSON_OBJECT(", "FOR JSON"],
    },
    "getDateTime": {
        "description": "Get current date and time",
        "python_ast_calls": [
            "datetime.now", "datetime.utcnow", "datetime.today", "date.today",
        ],
        "go_keywords": ["time.Now()"],
        "sql_keywords": ["NOW()", "GETDATE()", "CURRENT_TIMESTAMP"],
    },
    "getTimeStamp": {
        "description": "Get current Unix timestamp",
        "python_ast_calls": ["time.time", "time.monotonic"],
        "go_keywords": ["time.Now().Unix()"],
        "sql_keywords": ["UNIX_TIMESTAMP()", "EXTRACT(EPOCH"],
    },
    "stampToDatetime": {
        "description": "Convert Unix timestamp to datetime",
        "python_ast_calls": ["datetime.fromtimestamp", "datetime.utcfromtimestamp"],
        "go_keywords": ["time.Unix("],
        "sql_keywords": ["FROM_UNIXTIME(", "DATEADD"],
    },
    "randomString": {
        "description": "Generate random string or token",
        "python_ast_calls": [
            "secrets.token_hex", "secrets.token_urlsafe",
            "uuid.uuid4", "uuid.uuid1",
            "random.choices", "random.randbytes",
        ],
        "go_keywords": ["rand.Read(", "uuid.New("],
        "sql_keywords": ["NEWID()", "UUID()"],
    },
    "replace": {
        "description": "String replacement operation",
        "python_ast_calls": [".replace", "re.sub", "str.replace"],
        "go_keywords": ["strings.Replace(", "strings.ReplaceAll("],
        "sql_keywords": ["REPLACE("],
    },
    "addParam": {
        "description": "Declare/receive input parameter",
        "python_ast_node": "function argument / request.args.get",
        "go_keywords": ["r.URL.Query().Get("],
        "sql_keywords": [],
    },
    "addResult": {
        "description": "Declare output/result variable",
        "python_ast_node": "variable assignment for return value",
        "go_keywords": [],
        "sql_keywords": [],
    },
    "addVar": {
        "description": "Declare intermediate variable",
        "python_ast_node": "ast.Assign",
        "go_keywords": [":=", "var "],
        "sql_keywords": ["DECLARE "],
    },
    "_status": {
        "description": "HTTP status code variable",
        "python_ast_calls": ["response.status_code", ".status"],
        "go_keywords": ["resp.StatusCode", "w.WriteHeader("],
        "sql_keywords": [],
    },
    "getListLen": {
        "description": "Get length of a list",
        "python_ast_calls": ["len("],
        "go_keywords": ["len("],
        "sql_keywords": ["COUNT("],
    },
    "getQueryParamList": {
        "description": "Get multiple values for a query parameter (list form)",
        "python_ast_calls": ["request.args.getlist", "request.GET.getlist"],
        "go_keywords": ["r.URL.Query()"],
        "sql_keywords": [],
    },
    "itemFromList": {
        "description": "Get item by index from a list",
        "python_ast_calls": [],
        "python_ast_node": "ast.Subscript on list",
        "go_keywords": [],
        "sql_keywords": [],
    },
    "import": {
        "description": "Module/package import",
        "python_ast_node": "ast.Import / ast.ImportFrom",
        "go_keywords": ["import (", 'import "'],
        "sql_keywords": [],
    },
    "include": {
        "description": "Include another AVAP file/module",
        "python_ast_calls": ["importlib.import_module"],
        "go_keywords": [],
        "sql_keywords": [],
    },
}
# ─────────────────────────────────────────────────────────────────────────────
# GITHUB SEARCH QUERIES
# Each query targets a specific capability domain.
# Queries use GitHub Code Search syntax:
# language:python — only Python files
# path:api — files in directories named "api" (common microservice layout)
# The coverage across queries ensures we sample diverse API patterns.
# ─────────────────────────────────────────────────────────────────────────────
# NOTE: queries are executed in listed order; GitHubFetcher stops as soon as
# its max_files budget is reached, so earlier queries dominate the sample.
_GITHUB_QUERIES: list[str] = [
    # ORM + error handling (most common microservice pattern)
    "language:python path:api session.query try except",
    "language:python path:routes db.query fetchall",
    "language:python path:handlers cursor.execute SELECT",
    # HTTP clients in API endpoints
    "language:python path:api requests.get requests.post",
    "language:python path:api httpx.get httpx.post",
    # Authentication / crypto
    "language:python path:api hashlib.sha256",
    "language:python path:api hashlib.md5 token",
    # Async / concurrency patterns
    "language:python path:api asyncio.gather asyncio.create_task",
    "language:python path:api ThreadPoolExecutor executor.submit",
    # JSON handling
    "language:python path:api json.loads json.dumps",
    # DateTime
    "language:python path:api datetime.now time.time",
    # UUID / token generation
    "language:python path:api uuid.uuid4 secrets.token",
    # ORM insert/update patterns
    "language:python path:api session.add session.commit",
    "language:python path:routes db.execute INSERT UPDATE",
    # Flask / FastAPI endpoint patterns (ensure we get realistic API structure)
    "language:python path:api @app.route @router",
    "language:python path:api @app.get @app.post fastapi",
]
# ─────────────────────────────────────────────────────────────────────────────
# PYTHON AST DETECTOR
# Parses Python files with the standard ast module and maps AST nodes to
# AVAP semantic equivalents. This is AST-level, not keyword scanning —
# no false positives from variable names, strings, or comments.
# ─────────────────────────────────────────────────────────────────────────────
class PythonASTDetector:
    """Detects AVAP semantic equivalents in a Python file using the ast module.

    ``detect`` parses the source with :mod:`ast` and walks every node,
    mapping structural constructs (functions, loops, conditionals,
    try/except, imports, assignments, subscripts) and call expressions
    (ORM, HTTP, crypto, JSON, datetime, randomness, ...) onto AVAP command
    names. Matching is AST-level, so identifiers inside strings or comments
    cannot trigger false positives. Files that fail to parse fall back to
    plain keyword scanning (:meth:`_keyword_fallback`).
    """

    # Suffix patterns for ORM call detection. Matched against the unparsed
    # callee expression, e.g. "session.query" ends with ".query".
    _ORM_SELECT_SUFFIXES = frozenset([
        ".fetchall", ".fetchone", ".fetchmany",
        ".query", ".filter", ".filter_by", ".all", ".first",
        ".execute", ".select",
    ])
    _ORM_INSERT_SUFFIXES = frozenset([".add", ".insert", ".bulk_insert_mappings", ".create"])
    _ORM_UPDATE_SUFFIXES = frozenset([".update", ".merge", ".save", ".commit"])

    def detect(self, code: str) -> set[str]:
        """
        Parse Python source and return set of AVAP command names detected.
        Falls back to keyword scanning if file has syntax errors.
        """
        try:
            tree = pyast.parse(code)
        except SyntaxError:
            return self._keyword_fallback(code)
        detected: set[str] = set()
        for node in pyast.walk(tree):
            # ── Structural nodes ──────────────────────────────────────────
            if isinstance(node, (pyast.FunctionDef, pyast.AsyncFunctionDef)):
                detected.add("function")
                # Count function parameters as addParam proxies.
                # FIX: node.args.args alone misses positional-only and
                # keyword-only parameters and *args/**kwargs (e.g.
                # "def g(*, key)" has a parameter but empty args.args).
                a = node.args
                if a.args or a.posonlyargs or a.kwonlyargs or a.vararg or a.kwarg:
                    detected.add("addParam")
            elif isinstance(node, pyast.Return) and node.value is not None:
                detected.add("return")
            elif isinstance(node, (pyast.For, pyast.AsyncFor)):
                detected.add("startLoop")
            elif isinstance(node, pyast.If):
                detected.add("if_mode1")
                if node.orelse:
                    detected.add("else")
            elif isinstance(node, (pyast.Import, pyast.ImportFrom)):
                detected.add("import")
            elif isinstance(node, pyast.Try):
                detected.add("try")
                if node.handlers:
                    detected.add("exception")
            elif isinstance(node, pyast.Assign):
                detected.add("addVar")
            elif isinstance(node, pyast.Subscript):
                # list[index] read access → itemFromList proxy
                if isinstance(node.ctx, pyast.Load):
                    detected.add("itemFromList")
            # ── Call nodes ────────────────────────────────────────────────
            elif isinstance(node, pyast.Call):
                self._analyse_call(node, detected)
        return detected

    def _analyse_call(self, node: pyast.Call, detected: set[str]):
        """Analyse a Call AST node and add matching AVAP constructs.

        Heuristics match the unparsed callee text (exact names, suffixes,
        or substrings); one call may map to several constructs.
        """
        try:
            callee = pyast.unparse(node.func)
        except Exception:
            # Unparse can fail on exotic/malformed nodes — skip silently.
            return
        # ORM
        if any(callee.endswith(s) for s in self._ORM_SELECT_SUFFIXES):
            detected.add("ormAccessSelect")
        if any(callee.endswith(s) for s in self._ORM_INSERT_SUFFIXES):
            detected.add("ormAccessInsert")
        if any(callee.endswith(s) for s in self._ORM_UPDATE_SUFFIXES):
            detected.add("ormAccessUpdate")
        # ORM — raw SQL (cursor.execute / db.execute)
        if any(callee.endswith(p) for p in ("cursor.execute", "connection.execute",
                                            "db.execute", "session.execute")):
            detected.add("ormDirect")
        # ORM — table inspection
        if any(p in callee for p in ("has_table", "get_table_names", "inspector")):
            detected.add("ormCheckTable")
        # HTTP GET — known client, or a generic ".get" whose first arg looks
        # like a URL (see _has_url_arg).
        if callee in ("requests.get", "httpx.get") or (
            callee.endswith(".get") and self._has_url_arg(node)
        ):
            detected.add("RequestGet")
        # HTTP POST
        if callee in ("requests.post", "httpx.post") or (
            callee.endswith(".post") and self._has_url_arg(node)
        ):
            detected.add("RequestPost")
        # Async concurrency
        if callee in ("asyncio.gather", "asyncio.wait"):
            detected.add("go")
            detected.add("gather")
        if callee in ("asyncio.create_task", "asyncio.ensure_future"):
            detected.add("go")
        if any(callee.endswith(p) for p in ("executor.submit", "executor.map",
                                            "ThreadPoolExecutor")):
            detected.add("go")
        # Crypto ("hmac" excluded from md5 to avoid counting HMAC keys)
        if any(p in callee for p in ("sha256", "sha_256", "SHA256")):
            detected.add("encodeSHA256")
        if any(p in callee for p in ("md5", "MD5")) and "hmac" not in callee:
            detected.add("encodeMD5")
        # JSON
        if callee in ("json.loads", "json.load", "orjson.loads", "ujson.loads"):
            detected.add("variableFromJSON")
        if callee in ("json.dumps", "json.dump", "orjson.dumps", "ujson.dumps"):
            detected.add("AddVariableToJSON")
        # DateTime
        if any(p in callee for p in ("datetime.now", "datetime.utcnow",
                                     "datetime.today", "date.today")):
            detected.add("getDateTime")
        if callee in ("time.time", "time.monotonic", "time.time_ns"):
            detected.add("getTimeStamp")
        if any(p in callee for p in ("fromtimestamp", "utcfromtimestamp")):
            detected.add("stampToDatetime")
        # Random / UUID
        if any(p in callee for p in ("secrets.token", "uuid.uuid", "random.choice",
                                     "random.randbytes")):
            detected.add("randomString")
        # String replace
        if callee.endswith(".replace"):
            detected.add("replace")
        # len() → getListLen
        if callee == "len":
            detected.add("getListLen")
        # Query param list (Flask/FastAPI specific)
        if any(p in callee for p in ("getlist", "get_list")):
            detected.add("getQueryParamList")
        # Status code (broad substring heuristic — matches any callee
        # containing "status")
        if any(p in callee for p in ("status_code", "WriteHeader", "status")):
            detected.add("_status")

    def _has_url_arg(self, node: pyast.Call) -> bool:
        """Heuristic: first argument looks like a URL or URL variable."""
        if not node.args:
            return False
        first = node.args[0]
        if isinstance(first, pyast.Constant) and isinstance(first.value, str):
            v = first.value
            # FIX: dropped redundant "https" — "http" is already its prefix.
            # "{" catches f-string-style URL templates that became constants.
            return v.startswith(("http", "/")) or "{" in v
        if isinstance(first, pyast.Name):
            return first.id.lower() in ("url", "endpoint", "uri", "base_url")
        if isinstance(first, pyast.Attribute):
            try:
                return "url" in pyast.unparse(first).lower()
            except Exception:
                return False
        return False

    def _keyword_fallback(self, code: str) -> set[str]:
        """Keyword-based fallback for files that fail Python parsing.

        Plain substring scan — less precise than the AST path, but better
        than discarding the file entirely.
        """
        detected: set[str] = set()
        checks = {
            "def ": "function", "return ": "return",
            "for ": "startLoop", "if ": "if_mode1",
            "try:": "try", "except ": "exception",
            "import ": "import",
            "requests.get(": "RequestGet", "requests.post(": "RequestPost",
            "httpx.get(": "RequestGet", "httpx.post(": "RequestPost",
            ".fetchall(": "ormAccessSelect", ".query(": "ormAccessSelect",
            ".add(": "ormAccessInsert", ".execute(": "ormDirect",
            "json.loads(": "variableFromJSON", "json.dumps(": "AddVariableToJSON",
            "hashlib.sha256": "encodeSHA256", "hashlib.md5": "encodeMD5",
            "asyncio.gather": "gather", "asyncio.create_task": "go",
            "datetime.now": "getDateTime", "time.time()": "getTimeStamp",
            "uuid.uuid4": "randomString",
        }
        for pattern, avap in checks.items():
            if pattern in code:
                detected.add(avap)
        return detected
# ─────────────────────────────────────────────────────────────────────────────
# GITHUB CODEBASE FETCHER
# Queries GitHub Code Search API and downloads Python files.
# Logs every file fetched so the user can verify real codebases are queried.
# ─────────────────────────────────────────────────────────────────────────────
class GitHubFetcher:
    """Fetches Python source files from GitHub Code Search.

    Runs each query in ``_GITHUB_QUERIES`` against the GitHub code-search
    API, downloads matching files via raw.githubusercontent.com, and logs
    every fetched file so real-codebase provenance can be verified.
    """

    SEARCH_URL = "https://api.github.com/search/code"

    def __init__(self, token: str = None, max_files: int = 100, verbose: bool = True):
        # Fall back to the GITHUB_TOKEN env var; authenticated requests get
        # a higher search rate limit than anonymous ones.
        self.token = token or os.environ.get("GITHUB_TOKEN")
        self.max_files = max_files
        self.verbose = verbose
        self._files_fetched: list[dict] = []  # [{repo, path, url, code}]

    @property
    def headers(self) -> dict:
        """Request headers for the GitHub API (auth header only with a token)."""
        h = {"Accept": "application/vnd.github.v3+json"}
        if self.token:
            h["Authorization"] = f"Bearer {self.token}"
        return h

    def fetch_all(self) -> list[dict]:
        """
        Execute all GitHub search queries, download files, return list of
        {repo, path, url, code} dicts. Logs every file to stdout.
        """
        # FIX: banner separators were ''*60 (empty string repeated 60 times),
        # which printed blank lines; use a visible horizontal rule instead.
        print(f"\n{'─' * 60}")
        print(f" GitHub Codebase Extraction")
        print(f" Token: {'✓ authenticated (30 req/min)' if self.token else '✗ anonymous (10 req/min)'}")
        print(f" Max files: {self.max_files}")
        print(f" Queries: {len(_GITHUB_QUERIES)}")
        print(f"{'─' * 60}\n")
        fetched: list[dict] = []
        urls_seen: set[str] = set()
        rate_sleep = 2.0 if self.token else 6.5  # seconds between queries
        for q_idx, query in enumerate(_GITHUB_QUERIES):
            if len(fetched) >= self.max_files:
                break
            print(f" [{q_idx+1:02d}/{len(_GITHUB_QUERIES)}] Query: {query[:70]}")
            try:
                resp = requests.get(
                    self.SEARCH_URL,
                    params={"q": query, "per_page": 10},
                    headers=self.headers,
                    timeout=15,
                )
            except requests.exceptions.RequestException as e:
                print(f" ⚠ Network error: {e}")
                time.sleep(5)
                continue
            if resp.status_code == 403:
                # Rate limited: sleep until the advertised reset time (with a
                # 10s floor), then retry the same query once.
                reset_ts = int(resp.headers.get("X-RateLimit-Reset", time.time() + 60))
                wait_sec = max(reset_ts - int(time.time()), 10)
                print(f" ⚠ Rate limit — waiting {wait_sec}s...")
                time.sleep(wait_sec)
                try:
                    resp = requests.get(
                        self.SEARCH_URL,
                        params={"q": query, "per_page": 10},
                        headers=self.headers,
                        timeout=15,
                    )
                except Exception:
                    continue
            if resp.status_code != 200:
                print(f" ⚠ HTTP {resp.status_code}")
                time.sleep(2)
                continue
            items = resp.json().get("items", [])
            print(f"{len(items)} results from GitHub")
            for item in items:
                if len(fetched) >= self.max_files:
                    break
                # Rewrite the HTML blob URL to the raw-content host.
                raw_url = (
                    item.get("html_url", "")
                    .replace("https://github.com/", "https://raw.githubusercontent.com/")
                    .replace("/blob/", "/")
                )
                if not raw_url or raw_url in urls_seen:
                    continue
                urls_seen.add(raw_url)
                repo = item.get("repository", {}).get("full_name", "?")
                path = item.get("path", "?")
                try:
                    content_resp = requests.get(raw_url, timeout=10)
                    if content_resp.status_code != 200:
                        continue
                    code = content_resp.text
                except Exception as e:
                    print(f" ⚠ Download error ({repo}/{path}): {e}")
                    continue
                fetched.append({"repo": repo, "path": path, "url": raw_url, "code": code})
                print(f"{repo} / {path} ({len(code):,} chars)")
            time.sleep(rate_sleep)
        print(f"\n {'─' * 40}")
        print(f" Total files fetched from GitHub: {len(fetched)}")
        print(f" {'─' * 40}\n")
        self._files_fetched = fetched
        return fetched

    @property
    def fetch_log(self) -> list[dict]:
        """Returns log of all fetched files (without code content)."""
        return [
            {"repo": f["repo"], "path": f["path"], "url": f["url"]}
            for f in self._files_fetched
        ]
# ─────────────────────────────────────────────────────────────────────────────
# COOCCURRENCE EXTRACTOR
# Processes fetched files through the AST detector and builds
# pair/trio co-occurrence counts and normalized weights.
# ─────────────────────────────────────────────────────────────────────────────
class CooccurrenceExtractor:
    """Accumulates AVAP-construct co-occurrence statistics across files.

    Each file is run through :class:`PythonASTDetector`; every file with at
    least two detected constructs contributes one count to each pair and
    trio drawn from its (alphabetically sorted) construct set.
    """

    def __init__(self):
        self.detector = PythonASTDetector()
        self._pair_counts: dict[tuple, int] = defaultdict(int)
        self._trio_counts: dict[tuple, int] = defaultdict(int)
        self._file_results: list[dict] = []  # per-file detection results

    def process_files(self, files: list[dict]) -> None:
        """Process a list of {repo, path, code} dicts and accumulate counts."""
        print(f" Processing {len(files)} files through Python AST detector...\n")
        for i, f in enumerate(files):
            code = f.get("code", "")
            detected = self.detector.detect(code)
            if len(detected) < 2:
                # FIX: original f-string concatenation ran the path straight
                # into the count ("…/app.py1 constructs"); add a separator.
                print(f" [{i+1:03d}] {f['repo']}/{f['path']}"
                      f" → {len(detected)} constructs (skipped, need ≥2)")
                continue
            # Sort so pair/trio keys are deterministic and alphabetical.
            sorted_d = sorted(detected)
            pairs = list(combinations(sorted_d, 2))
            trios = list(combinations(sorted_d, 3))
            for p in pairs:
                self._pair_counts[p] += 1
            for t in trios:
                self._trio_counts[t] += 1
            self._file_results.append({
                "repo": f["repo"],
                "path": f["path"],
                "constructs": sorted_d,
                "pairs": len(pairs),
                "trios": len(trios),
            })
            if self.verbose_log(i):
                print(f" [{i+1:03d}] {f['repo']}/{f['path']}")
                print(f" Constructs ({len(detected)}): {', '.join(sorted_d)}")
                print(f" Pairs: {len(pairs)} Trios: {len(trios)}")
        print(f"\n ─────────────────────────────────────────────────────")
        print(f" Files with ≥2 constructs: {len(self._file_results)}")
        print(f" Unique pair co-occurrences: {len(self._pair_counts)}")
        print(f" Unique trio co-occurrences: {len(self._trio_counts)}")
        print(f" ─────────────────────────────────────────────────────\n")

    def verbose_log(self, idx: int) -> bool:
        """Log every file for full traceability."""
        return True  # Always show — the user needs to verify real files are analyzed

    def normalized_pair_weights(self) -> dict[str, float]:
        """Normalize pair counts to [0, 1]. Key = 'a+b' sorted alphabetically."""
        if not self._pair_counts:
            return {}
        max_c = max(self._pair_counts.values())
        return {
            "+".join(k): round(v / max_c, 6)
            for k, v in sorted(
                self._pair_counts.items(), key=lambda x: -x[1]
            )
        }

    def normalized_trio_weights(self) -> dict[str, float]:
        """Normalize trio counts to [0, 1]. Key = 'a+b+c' sorted alphabetically."""
        if not self._trio_counts:
            return {}
        max_c = max(self._trio_counts.values())
        return {
            "+".join(k): round(v / max_c, 6)
            for k, v in sorted(
                self._trio_counts.items(), key=lambda x: -x[1]
            )
        }

    def top_pairs(self, n: int = 20) -> list[tuple]:
        """Return top-n pairs by count."""
        return sorted(self._pair_counts.items(), key=lambda x: -x[1])[:n]

    def top_trios(self, n: int = 20) -> list[tuple]:
        """Return top-n trios by count."""
        return sorted(self._trio_counts.items(), key=lambda x: -x[1])[:n]

    def per_file_results(self) -> list[dict]:
        """Per-file results: repo, path, constructs, pair/trio counts."""
        return self._file_results
def generate_construct_map(
    extractor: CooccurrenceExtractor,
    fetcher: GitHubFetcher,
    output_path: Path,
) -> None:
    """Write construct_map.yaml and print a summary of top co-occurrences.

    Serializes generation metadata, the static language mappings, the
    fetcher's per-file log, and the normalized pair/trio weights produced
    by *extractor* to *output_path* (parent directories are created).
    """
    # (removed a commented-out debug print left over from development)
    pair_weights = extractor.normalized_pair_weights()
    trio_weights = extractor.normalized_trio_weights()
    doc = {
        "meta": {
            "description": (
                "Auto-generated by construct_prior.py. "
                "Weights derived from real GitHub codebases via Python AST analysis. "
                "DO NOT EDIT MANUALLY — regenerate with: python construct_prior.py --generate-map"
            ),
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "generator_version": "2.0",
            "avap_node_count": len(AVAP_NODE_NAMES),
            "avap_node_names": AVAP_NODE_NAMES,
            "source_stats": {
                "github_files_analyzed": len(extractor.per_file_results()),
                "github_files_fetched": len(fetcher.fetch_log),
                "total_pair_cooccurrences": len(pair_weights),
                "total_trio_cooccurrences": len(trio_weights),
            },
        },
        "language_mappings": LANGUAGE_MAPPINGS,
        "fetch_log": fetcher.fetch_log,
        "pair_weights": pair_weights,
        "trio_weights": trio_weights,
    }
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        # sort_keys=False preserves the insertion order defined above.
        yaml.dump(doc, f, default_flow_style=False, allow_unicode=True,
                  sort_keys=False, width=120)
    print(f"construct_map.yaml written to: {output_path}")
    print(f"Pair weights: {len(pair_weights)}")
    print(f"Trio weights: {len(trio_weights)}")
    print()
    print("Top-20 construct pairs by co-occurrence frequency:")
    for (a, b), count in extractor.top_pairs(20):
        # Keys are alphabetically sorted at generation; look up both orders
        # defensively.
        w = pair_weights.get(f"{a}+{b}", pair_weights.get(f"{b}+{a}", 0))
        print(f" {w:.4f} {a} + {b} (n={count})")
    print()
    print("Top-10 construct trios by co-occurrence frequency:")
    for trio, count in extractor.top_trios(10):
        w = trio_weights.get("+".join(trio), 0)
        print(f" {w:.4f} {' + '.join(trio)} (n={count})")
_DEFAULT_EPSILON = 0.05
class ConstructPrior:
def __init__(self, weights: dict[frozenset, float], epsilon: float = _DEFAULT_EPSILON,
source_stats: dict = None):
self._weights = weights
self.epsilon = epsilon
self._source_stats = source_stats or {}
self._propagate_subset_weights()
#print(self._source_stats )
@classmethod
def from_yaml(cls, path: Path, epsilon: float = _DEFAULT_EPSILON) -> "ConstructPrior":
if not path.exists():
raise FileNotFoundError(
f"construct_map.yaml not found at {path}.\n"
f"Generate it first: python construct_prior.py --generate-map"
)
with open(path, encoding="utf-8") as f:
doc = yaml.safe_load(f)
weights: dict[frozenset, float] = {}
for key_str, w in doc.get("pair_weights", {}).items():
parts = key_str.split("+")
if len(parts) == 2:
weights[frozenset(parts)] = float(w)
for key_str, w in doc.get("trio_weights", {}).items():
parts = key_str.split("+")
if len(parts) == 3:
weights[frozenset(parts)] = float(w)
source_stats = doc.get("meta", {}).get("source_stats", {})
print(f"ConstructPrior loaded from {path.name}")
print(f"Files analyzed: {source_stats.get('github_files_analyzed', '?')}")
print(f"-Pair weights: {len([k for k in weights if len(k)==2])}")
print(f"-Trio weights: {len([k for k in weights if len(k)==3])}")
return cls(weights=weights, epsilon=epsilon, source_stats=source_stats)
@classmethod
def from_static_fallback(cls, epsilon: float = _DEFAULT_EPSILON) -> "ConstructPrior":
print("[WARN] Using static fallback prior (no construct_map.yaml).")
print("Run: python construct_prior.py --generate-map")
static: list[tuple[tuple, float]] = [
(("try", "exception"),1.00),
(("function", "return"),0.98),
(("function", "try", "return"),0.95),
(("ormAccessSelect", "try"),0.90),
(("ormAccessSelect", "try", "exception"),0.88),
(("RequestGet", "try"),0.85),
(("RequestPost", "try"),0.84),
(("if_mode1", "return"),0.82),
(("function", "if_mode1", "return"),0.80),
(("ormAccessSelect", "return"), 0.78),
(("ormAccessInsert", "try"),0.75),
(("ormAccessUpdate", "try"),0.72),
(("RequestGet", "variableFromJSON"),0.70),
(("RequestPost", "variableFromJSON"),0.68),
(("variableFromJSON", "return"),0.65),
(("startLoop", "return"),0.62),
(("startLoop", "ormAccessSelect"),0.60),
(("function", "import"),0.58),
(("if_mode1", "ormAccessSelect"),0.55),
(("ormAccessSelect", "ormAccessInsert"),0.52),
(("go", "gather"),0.48),
(("go", "RequestGet"),0.45),
(("go", "RequestPost"),0.43),
(("go", "gather", "return"),0.42),
(("encodeSHA256", "return"),0.40),
(("encodeSHA256", "if_mode1"),0.38),
(("encodeMD5", "return"),0.36),
(("AddVariableToJSON", "return"),0.35),
(("variableFromJSON", "AddVariableToJSON"),0.33),
(("getDateTime", "ormAccessInsert"),0.30),
(("getTimeStamp", "return"),0.28),
(("startLoop", "if_mode1"),0.27),
(("startLoop", "if_mode1", "return"),0.25),
(("randomString", "return"),0.22),
(("randomString", "encodeSHA256"),0.20),
(("replace", "return"),0.16),
(("avapConnector", "try"),0.14),
(("go", "ormAccessSelect"),0.10),
(("go", "gather", "ormAccessSelect"),0.09),
]
weights = {frozenset(k): v for k, v in static}
return cls(weights=weights, epsilon=epsilon,
source_stats={"mode": "static_fallback"})
def cell_weight(self, cell: frozenset) -> float:
#print(cell)
#www = self._weights.get(cell, self.epsilon)
#print(www)
return max(self._weights.get(cell, self.epsilon), self.epsilon)
def kl_divergence(self, dataset_freq: dict[str, int]) -> float:
total_d = sum(dataset_freq.values())
total_p = sum(self._weights.values())
if total_d == 0 or total_p == 0:
return float("inf")
kl = 0.0
for nt in AVAP_NODE_NAMES:
p = dataset_freq.get(nt, 0) / total_d
q_raw = sum(w for cell, w in self._weights.items() if nt in cell)
q = max(q_raw / total_p, 1e-9)
if p > 0:
kl += p * math.log2(p / q)
return round(kl, 4)
def coverage_summary(self) -> str:
n = len(self._weights)
avg = sum(self._weights.values()) / max(n, 1)
stats = " ".join(f"{k}={v}" for k, v in self._source_stats.items())
return (
f"ConstructPrior: {n} cells | mean={avg:.3f} | epsilon={self.epsilon}"
+ (f" | {stats}" if stats else "")
)
def _propagate_subset_weights(self):
pairs = [(cell, w) for cell, w in self._weights.items() if len(cell) == 2]
for pair_cell, pair_w in pairs:
inherited = pair_w * 0.60
for trio_cell in list(self._weights):
if len(trio_cell) == 3 and pair_cell.issubset(trio_cell):
#print(trio_cell)
if inherited > self._weights.get(trio_cell, 0):
self._weights[trio_cell] = inherited
def main():
    """CLI entry point.

    --generate-map : fetch real GitHub codebases and write construct_map.yaml
    --verify       : load an existing construct_map.yaml and print its stats
    (no flag)      : show help
    """
    parser = argparse.ArgumentParser(
        description="Generate construct_map.yaml from real GitHub codebases",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Fetch 200 files for a richer prior:
python construct_prior.py --generate-map --max-files 200 --github-token ghp_...
""",
    )
    parser.add_argument(
        "--generate-map", action="store_true",
        help="Fetch real codebases from GitHub and generate construct_map.yaml",
    )
    parser.add_argument(
        "--verify", action="store_true",
        help="Load and print stats for an existing construct_map.yaml",
    )
    parser.add_argument(
        "--github-token", default=None,
        help="GitHub personal access token (or set GITHUB_TOKEN env var)",
    )
    parser.add_argument(
        "--max-files", type=int, default=100,
        help="Maximum number of GitHub files to analyze (default: 100)",
    )
    parser.add_argument(
        "--output", default="construct_map.yaml",
        help="Output path for construct_map.yaml (default: construct_map.yaml)",
    )
    parser.add_argument(
        "--map", default="construct_map.yaml",
        help="Path to existing construct_map.yaml (for --verify)",
    )
    args = parser.parse_args()
    if args.generate_map:
        print("\n====================================================")
        print(" ConstructPrior — Codebase Extraction")
        print(" Querying REAL GitHub codebases. No hardcoded data.")
        print("=====================================================")
        # Fetch candidate files from GitHub Code Search.
        fetcher = GitHubFetcher(
            token=args.github_token,
            max_files=args.max_files,
            verbose=True,
        )
        files = fetcher.fetch_all()
        if not files:
            print("\nERROR: No files fetched from GitHub.")
            print("Check your internet connection and GitHub token.")
            sys.exit(1)
        # Run AST detection + co-occurrence counting, then write the map.
        extractor = CooccurrenceExtractor()
        extractor.process_files(files)
        output_path = Path(args.output)
        generate_construct_map(extractor, fetcher, output_path)
        print("\n================================================")
        print(f" construct_map.yaml generated successfully.")
        print(f" Files analyzed from real GitHub codebases: {len(extractor.per_file_results())}")
        print(f" Use in generator: --prior-map {output_path}")
        print("==================================================\n")
    elif args.verify:
        map_path = Path(args.map)
        try:
            prior = ConstructPrior.from_yaml(map_path)
            print(f"\n{prior.coverage_summary()}")
            # Re-read the raw YAML to show generation metadata alongside the
            # loaded prior's own summary.
            with open(map_path, encoding="utf-8") as f:
                doc = yaml.safe_load(f)
            meta = doc.get("meta", {})
            stats = meta.get("source_stats", {})
            print(f"\n Generated: {meta.get('generated_at', '?')}")
            print(f" AVAP commands: {meta.get('avap_node_count', '?')}")
            print(f" Files analyzed: {stats.get('github_files_analyzed', '?')}")
            print(f" Pair weights: {stats.get('total_pair_cooccurrences', '?')}")
            print(f" Trio weights: {stats.get('total_trio_cooccurrences', '?')}")
        except FileNotFoundError as e:
            print(f" ERROR: {e}")
            sys.exit(1)
    else:
        parser.print_help()
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()