# assistance-engine/scripts/pipelines/ingestion/avap_chunker.py
# (795 lines, 27 KiB, Python)
"""
chunker.py v1.0
Uso:
python chunker.py --lang-config avap_config.json --docs-path ./docs/samples
python chunker.py --lang-config avap_config.json --docs-path ./docs/samples --workers 8
python chunker.py --lang-config avap_config.json --docs-path ./docs/samples --redis-url redis://localhost:6379
python chunker.py --lang-config avap_config.json --docs-path ./docs/samples --no-dedup
"""
import re
import os
import json
import hashlib
import argparse
import tempfile
import warnings as py_warnings
from pathlib import Path
from dataclasses import dataclass, asdict, field
from typing import Optional, Generator, IO
from concurrent.futures import ProcessPoolExecutor, as_completed
# Token-counting backend: exact counts via tiktoken when installed,
# whitespace word count as a rough fallback otherwise.
try:
    import tiktoken
    _ENC = tiktoken.get_encoding("cl100k_base")
    def count_tokens(text: str) -> int:
        # Exact token count under the cl100k_base encoding.
        return len(_ENC.encode(text))
    TOKEN_BACKEND = "tiktoken/cl100k_base"
except ImportError:
    py_warnings.warn("tiktoken no instalado — usando word-count. pip install tiktoken",
                     stacklevel=2)
    def count_tokens(text: str) -> int:  # type: ignore[misc]
        # Fallback estimate: whitespace-separated word count.
        return len(text.split())
    TOKEN_BACKEND = "word-count (estimación)"
# Optional near-duplicate detection backend (MinHash LSH from datasketch).
# When missing, deduplication becomes a no-op (see StreamingDeduplicator).
try:
    from datasketch import MinHash, MinHashLSH
    MINHASH_AVAILABLE = True
except ImportError:
    MINHASH_AVAILABLE = False
    py_warnings.warn("datasketch no instalado — dedup desactivada. pip install datasketch",
                     stacklevel=2)
# Optional progress bars; fall back to a pass-through iterator.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(x, **kwargs): return x  # type: ignore[misc]
# Tuning defaults (most are overridable via CLI flags in main()).
MAX_NARRATIVE_TOKENS = 400   # max tokens per narrative (prose) chunk
OVERLAP_LINES = 3            # trailing lines carried into the next chunk as context
DEDUP_THRESHOLD = 0.85       # LSH similarity threshold for near-duplicate removal
MINHASH_NUM_PERM = 128       # MinHash permutations (accuracy vs. speed trade-off)
MINHASH_SHINGLE_SIZE = 3     # words per shingle when hashing chunk content
DEFAULT_WORKERS = max(1, (os.cpu_count() or 4) - 1)  # leave one core free
@dataclass
class BlockDef:
    """Definition of a multi-line block construct (function, loop, ...)."""
    name: str
    doc_type: str
    opener_re: re.Pattern
    closer_re: re.Pattern
    extract_signature: bool = False
    signature_template: str = ""

    def extract_sig(self, clean_line):
        """Render the signature template for *clean_line*.

        Returns the template with each ``{groupN}`` placeholder replaced by
        the N-th (1-based) capture group of the opener regex, or ``None``
        when signature extraction is disabled or the line does not match.
        """
        if not self.extract_signature:
            return None
        match = self.opener_re.match(clean_line)
        if match is None:
            return None
        rendered = self.signature_template
        for idx, group in enumerate(match.groups(), start=1):
            rendered = rendered.replace(f"{{group{idx}}}", (group or "").strip())
        return rendered
@dataclass
class StatementDef:
    """A single-line statement category, classified by a regex match."""
    name: str  # label used as the chunk's block_type
    re: re.Pattern  # compiled pattern; NOTE: field name shadows the `re` module here
@dataclass
class SemanticTag:
    """A boolean metadata flag, set when its regex matches chunk content."""
    tag: str  # metadata key written into chunk.metadata
    re: re.Pattern  # compiled pattern searched against the chunk text
class LanguageConfig:
    """Language description loaded from a JSON config file.

    Holds lexer settings (string delimiters, comment syntax), block
    definitions (opener/closer regexes), statement classifiers and
    semantic-tag patterns used for metadata enrichment.
    """

    def __init__(self, config_path: str):
        raw = json.loads(Path(config_path).read_text(encoding="utf-8"))
        self.language = raw.get("language", "unknown")
        self.version = raw.get("version", "1.0")
        self.extensions = set(raw.get("file_extensions", []))
        lex = raw.get("lexer", {})
        self.string_delimiters = lex.get("string_delimiters", ['"', "'"])
        self.escape_char = lex.get("escape_char", "\\")
        # Longest prefixes first so e.g. "//" is tried before "/" when stripping.
        self.comment_line = sorted(lex.get("comment_line", ["#"]), key=len, reverse=True)
        cb = lex.get("comment_block", {})
        self.comment_block_open = cb.get("open", "")
        self.comment_block_close = cb.get("close", "")
        self.line_oriented = lex.get("line_oriented", True)
        # Multi-line block constructs (functions, loops, ...).
        self.blocks: list[BlockDef] = []
        for b in raw.get("blocks", []):
            self.blocks.append(BlockDef(
                name = b["name"],
                doc_type = b.get("doc_type", "code"),
                opener_re = re.compile(b["opener_pattern"]),
                closer_re = re.compile(b["closer_pattern"]),
                extract_signature = b.get("extract_signature", False),
                signature_template = b.get("signature_template", ""),
            ))
        # Single-line statement classifiers, in config order.
        self.statements: list[StatementDef] = [
            StatementDef(name=s["name"], re=re.compile(s["pattern"]))
            for s in raw.get("statements", [])
        ]
        # Patterns that set boolean metadata flags on chunks.
        self.semantic_tags: list[SemanticTag] = [
            SemanticTag(tag=t["tag"], re=re.compile(t["pattern"]))
            for t in raw.get("semantic_tags", [])
        ]

    def match_opener(self, clean_line):
        """Return the first BlockDef whose opener matches *clean_line*, else None."""
        for block in self.blocks:
            if block.opener_re.match(clean_line):
                return block
        return None

    def match_closer(self, clean_line):
        """True when *clean_line* closes any known block type."""
        for block in self.blocks:
            if block.closer_re.match(clean_line):
                return True
        return False

    def classify_statement(self, clean_line):
        """Name of the first matching statement pattern, or generic "statement"."""
        for stmt in self.statements:
            if stmt.re.match(clean_line):
                return stmt.name
        return "statement"

    def enrich_metadata(self, content):
        """Boolean flag per semantic tag found in *content*, plus a naive
        "complexity" score equal to the number of flags set."""
        meta: dict = {}
        for tag in self.semantic_tags:
            if tag.re.search(content):
                meta[tag.tag] = True
        meta["complexity"] = sum(1 for v in meta.values() if v is True)
        return meta
@dataclass
class Chunk:
    """One unit of ingested content plus its provenance and metadata."""
    chunk_id: str
    source_file: str
    doc_type: str
    block_type: str
    section: str
    start_line: int
    end_line: int
    content: str
    metadata: dict = field(default_factory=dict)

    def token_count(self):
        """Token count of the chunk content (backend-dependent)."""
        return count_tokens(self.content)

    def to_dict(self):
        """Serializable form of the chunk with the token estimate attached."""
        return {**asdict(self), "token_estimate": self.token_count()}
def make_chunk_id(filepath, start, end, content):
    """Stable 16-hex-char id from file name, line span and a content prefix.

    Only the first 60 characters of *content* participate, so large chunks
    hash cheaply while ids still change when the chunk header changes.
    """
    key = f"{filepath.name}:{start}:{end}:{content[:60]}"
    return hashlib.sha1(key.encode()).hexdigest()[:16]
def make_chunk(filepath: Path, doc_type, block_type, section, start, end, content, cfg, extra_meta = None):
    """Build a Chunk from stripped *content*, enriching its metadata with the
    language's semantic tags and any caller-supplied *extra_meta*."""
    body = content.strip()
    metadata = cfg.enrich_metadata(body)
    if extra_meta:
        metadata.update(extra_meta)
    return Chunk(
        chunk_id=make_chunk_id(filepath, start, end, body),
        source_file=str(filepath),
        doc_type=doc_type,
        block_type=block_type,
        section=section,
        start_line=start,
        end_line=end,
        content=body,
        metadata=metadata,
    )
class GenericLexer:
    """Stateful line lexer: strips block comments and line comments while
    respecting string literals, as described by a LanguageConfig.

    NOTE(review): block-comment delimiters are found with a plain substring
    search, so an opener inside a string literal (e.g. s = "/*") is still
    treated as a comment — pre-existing limitation, confirm acceptable.
    """

    def __init__(self, cfg: "LanguageConfig"):
        self.cfg = cfg
        # True while inside a block comment opened on a previous line.
        self.in_block_comment = False

    def process_line(self, raw):
        """Return ``(has_code, cleaned_line)`` for one raw source line."""
        cb_open = self.cfg.comment_block_open
        cb_close = self.cfg.comment_block_close
        if self.in_block_comment:
            if cb_close and cb_close in raw:
                self.in_block_comment = False
                # BUGFIX: code after the closing delimiter used to be
                # discarded; re-process the remainder of the line instead.
                remainder = raw[raw.index(cb_close) + len(cb_close):]
                return self.process_line(remainder)
            return False, ""
        if cb_open and cb_open in raw:
            idx_open = raw.index(cb_open)
            rest = raw[idx_open + len(cb_open):]
            if cb_close and cb_close in rest:
                # BUGFIX: search for the closer strictly *after* the opener;
                # the old raw.index(cb_close, idx_open) could match inside
                # the opener itself (e.g. "/*/" with open="/*", close="*/").
                idx_close = raw.index(cb_close, idx_open + len(cb_open))
                code_part = raw[:idx_open] + raw[idx_close + len(cb_close):]
                # Recurse: the remainder may contain further block comments.
                return self.process_line(code_part)
            # Unterminated on this line: comment continues on the next one.
            self.in_block_comment = True
            return self._strip_line_comments(raw[:idx_open])
        return self._strip_line_comments(raw)

    def _strip_line_comments(self, raw):
        """Drop a trailing line comment, ignoring markers inside strings.

        Returns ``(has_code, stripped_code)``.
        """
        in_str: Optional[str] = None  # current string delimiter, if inside one
        result = []
        i = 0
        while i < len(raw):
            ch = raw[i]
            if in_str and ch == self.cfg.escape_char:
                # Keep the escape character and the escaped character verbatim.
                result.append(ch)
                if i + 1 < len(raw):
                    result.append(raw[i + 1])
                    i += 2
                else:
                    i += 1
                continue
            if in_str and ch == in_str:
                in_str = None
                result.append(ch)
                i += 1
                continue
            if not in_str and ch in self.cfg.string_delimiters:
                in_str = ch
                result.append(ch)
                i += 1
                continue
            if not in_str and any(raw.startswith(p, i) for p in self.cfg.comment_line):
                break  # rest of the line is a comment
            result.append(ch)
            i += 1
        code = "".join(result).strip()
        return bool(code), code
class SemanticOverlapBuffer:
    """Adds cross-chunk context: either the enclosing function signature or
    the last few lines of the previous chunk from the same file.

    Stateful across a stream of chunks; call notify_file_change() when the
    source file changes so that context never leaks between files.
    """

    def __init__(self, overlap_lines = OVERLAP_LINES):
        self.overlap_lines = overlap_lines
        self._prev = None              # last chunk seen (before augmentation)
        self._current_fn_sig = None    # signature of the enclosing function, if any
        self._current_fn_file = None   # file that signature belongs to

    def notify_function(self, sig, source_file):
        # Record the most recent function signature seen in *source_file*.
        self._current_fn_sig = sig
        self._current_fn_file = source_file

    def notify_file_change(self, source_file):
        # Reset per-file state when a new file starts.
        if self._current_fn_file != source_file:
            self._current_fn_sig = None
            self._current_fn_file = source_file
            self._prev = None

    def apply(self, chunk):
        """Return *chunk*, possibly with a context header prepended.

        NOTE(review): the augmented Chunk keeps the original chunk_id even
        though its content changed — ids stay stable across overlap settings.
        """
        if self.overlap_lines <= 0:
            self._prev = chunk
            return chunk
        if self._prev and self._prev.source_file != chunk.source_file:
            self.notify_file_change(chunk.source_file)
        context_header = None
        if (self._current_fn_sig
                and self._current_fn_file == chunk.source_file
                and chunk.block_type not in ("function", "function_signature")):
            # Inside a known function: prepend its signature as context.
            context_header = f"// contexto: {self._current_fn_sig}"
            overlap_type = "function_sig"
        elif (self._prev
                and self._prev.source_file == chunk.source_file
                and self._prev.doc_type == chunk.doc_type):
            # Otherwise reuse the tail lines of the previous same-type chunk.
            context_header = "\n".join(
                self._prev.content.splitlines()[-self.overlap_lines:])
            overlap_type = "line_tail"
        else:
            overlap_type = "none"
        # Remember the *unaugmented* chunk so tails never stack up.
        self._prev = chunk
        if context_header:
            new_content = (context_header + "\n" + chunk.content).strip()
            return Chunk(
                chunk_id=chunk.chunk_id, source_file=chunk.source_file,
                doc_type=chunk.doc_type, block_type=chunk.block_type,
                section=chunk.section, start_line=chunk.start_line,
                end_line=chunk.end_line, content=new_content,
                metadata={**chunk.metadata,
                          "has_overlap": True,
                          "overlap_type": overlap_type},
            )
        return chunk
def _shingles(text, k = MINHASH_SHINGLE_SIZE):
    """Lower-cased word k-shingles of *text*, encoded as bytes for MinHash.

    Texts shorter than *k* words yield a single shingle of all their words.
    """
    words = text.lower().split()
    if len(words) < k:
        return [" ".join(words).encode()]
    windows = range(len(words) - k + 1)
    return [" ".join(words[j:j + k]).encode() for j in windows]
def _build_minhash(text):
    """Build a MinHash signature over the word shingles of *text*."""
    m = MinHash(num_perm=MINHASH_NUM_PERM)
    for s in _shingles(text):
        m.update(s)
    return m
class StreamingDeduplicator:
    """Near-duplicate filter backed by one MinHash LSH index per doc_type.

    Chunks of different doc_types never deduplicate against each other.
    Becomes a no-op when datasketch is not installed.
    """

    def __init__(self, threshold: float = DEDUP_THRESHOLD):
        self.threshold = threshold
        # One LSH index per doc_type, created lazily on first use.
        self._lsh: dict[str, "MinHashLSH"] = {}
        self.removed = 0  # count of chunks rejected as duplicates

    def _get_lsh(self, doc_type):
        """Return (creating on first use) the LSH index for *doc_type*."""
        if doc_type not in self._lsh:
            self._lsh[doc_type] = MinHashLSH(
                threshold=self.threshold, num_perm=MINHASH_NUM_PERM)
        return self._lsh[doc_type]

    def is_duplicate(self, chunk):
        """Return True when *chunk* is a near-duplicate of one already seen;
        otherwise register it in the index and return False."""
        if not MINHASH_AVAILABLE:
            return False
        lsh = self._get_lsh(chunk.doc_type)
        m = _build_minhash(chunk.content)
        try:
            if lsh.query(m):
                self.removed += 1
                return True
        except Exception as exc:
            # Best-effort: a failed query means "not a duplicate", but the
            # failure is surfaced instead of being silently swallowed.
            py_warnings.warn(f"LSH query failed: {exc}", stacklevel=2)
        try:
            lsh.insert(chunk.chunk_id, m)
        except Exception as exc:
            # BUGFIX: was a bare print() left in library code (datasketch
            # raises ValueError on duplicate keys, among others). Warn so the
            # condition is visible without polluting stdout.
            py_warnings.warn(f"LSH insert failed for {chunk.chunk_id}: {exc}",
                             stacklevel=2)
        return False
class JsonlWriter:
    """Streaming writer of chunks to a JSON Lines file.

    A ".json" suffix on *path* is rewritten to ".jsonl" because the output
    format is one JSON object per line. Also usable as a context manager.
    """

    def __init__(self, path):
        out = Path(path)
        if out.suffix.lower() == ".json":
            out = out.with_suffix(".jsonl")
        out.parent.mkdir(parents=True, exist_ok=True)
        self.path = out  # final (possibly suffix-corrected) output path
        self._handle: Optional[IO] = open(out, "w", encoding="utf-8")
        self.written = 0  # number of rows written so far

    def write(self, chunk):
        """Serialize one chunk (anything with .to_dict()) as a JSONL row."""
        self._handle.write(json.dumps(chunk.to_dict(), ensure_ascii=False) + "\n")
        self.written += 1

    def close(self):
        """Close the underlying file; explicitly safe to call repeatedly.

        The old guard (`if self._handle:`) tested a handle that was never
        None; the handle is now dropped so repeated close() is a clear no-op.
        """
        if self._handle is not None:
            self._handle.close()
            self._handle = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
def validate_syntax(lines, filepath, cfg):
    """Single-pass open/close balance check over *lines*.

    Returns a list of human-readable warning strings: one per closer with no
    matching opener, and one per block left open at end of file.
    """
    problems = []
    open_blocks = []  # stack of (block_name, line_no) still awaiting a closer
    lexer = GenericLexer(cfg)
    for idx, raw in enumerate(lines, start=1):
        is_code, clean = lexer.process_line(raw)
        if not (is_code and clean):
            continue
        opened = cfg.match_opener(clean)
        if opened is not None:
            open_blocks.append((opened.name, idx))
        elif cfg.match_closer(clean):
            if open_blocks:
                open_blocks.pop()
            else:
                problems.append(f"{filepath.name}:{idx} — close without open")
    problems.extend(
        f"{filepath.name}:{ln} — not closed block '{bt}'"
        for bt, ln in open_blocks)
    return problems
def iter_code_chunks(filepath, cfg, overlap_buf):
    """Yield Chunk objects for one source file of the configured language.

    Block constructs (functions, loops, ...) become one chunk each, tracked
    with a nesting depth counter; consecutive loose statements of the same
    classified type are coalesced into a single chunk. After all chunks, a
    final ``(None, warnings)`` sentinel tuple is yielded when the syntax
    validator reported problems — callers must check for it.
    """
    lines = filepath.read_text(encoding="utf-8").splitlines()
    warnings = validate_syntax(lines, filepath, cfg)
    overlap_buf.notify_file_change(str(filepath))
    lexer = GenericLexer(cfg)
    i = 0
    pending_raw = []    # comment/blank lines awaiting attachment to the next chunk
    loose_buffer = []   # (line_no, raw) pairs of coalesced loose statements
    loose_type = None   # classified type of the statements in loose_buffer

    def flush_loose():
        # Emit the accumulated loose statements as one chunk (if any).
        nonlocal loose_buffer, loose_type
        if not loose_buffer:
            return
        start = loose_buffer[0][0]
        end = loose_buffer[-1][0]
        content = "\n".join(t for _, t in loose_buffer)
        chunk = make_chunk(filepath, "code", loose_type or "statement",
                           "", start, end, content, cfg)
        chunk = overlap_buf.apply(chunk)
        loose_buffer.clear(); loose_type = None
        yield chunk

    while i < len(lines):
        raw = lines[i]
        line_no = i + 1
        is_code, clean = lexer.process_line(raw)
        if not is_code or not clean:
            # Comments/blank lines ride along with whatever chunk comes next.
            pending_raw.append(raw); i += 1; continue
        block_def = cfg.match_opener(clean)
        if block_def:
            yield from flush_loose()
            block_start = line_no
            # Leading comments become part of the block chunk.
            block_lines = list(pending_raw) + [raw]
            pending_raw.clear()
            sig = block_def.extract_sig(clean)
            if sig:
                overlap_buf.notify_function(sig, str(filepath))
            # Consume lines until the matching closer, counting nesting.
            depth = 1; i += 1
            while i < len(lines) and depth > 0:
                inner_raw = lines[i]
                _, inner_clean = lexer.process_line(inner_raw)
                block_lines.append(inner_raw)
                if inner_clean:
                    if cfg.match_opener(inner_clean):
                        depth += 1
                    elif cfg.match_closer(inner_clean):
                        depth -= 1
                i += 1
            chunk = make_chunk(filepath, block_def.doc_type, block_def.name, "", block_start, i, "\n".join(block_lines), cfg)
            chunk = overlap_buf.apply(chunk)
            yield chunk
            if sig:
                # Additionally emit a tiny signature-only chunk pointing back
                # at the full block's line span.
                yield make_chunk(
                    filepath, "function_signature", "function_signature", "", block_start, block_start, sig, cfg,
                    extra_meta={"full_block_start": block_start,
                                "full_block_end": i}
                )
            continue
        stmt_type = cfg.classify_statement(clean)
        if loose_type and stmt_type != loose_type:
            # Statement type changed: close the current loose chunk.
            yield from flush_loose()
        if pending_raw and not loose_buffer:
            # NOTE(review): pending comment lines are recorded with the
            # *current* line number, not their original ones — confirm this
            # imprecision in start_line is acceptable.
            for pc in pending_raw:
                loose_buffer.append((line_no, pc))
            pending_raw.clear()
        loose_type = stmt_type
        loose_buffer.append((line_no, raw))
        i += 1
    yield from flush_loose()
    if warnings:
        # Sentinel: consumers distinguish this tuple from Chunk objects.
        yield (None, warnings)
# Markdown structure patterns used by iter_markdown_chunks.
RE_MD_H1 = re.compile(r"^# (.+)")         # level-1 heading
RE_MD_H2 = re.compile(r"^## (.+)")        # level-2 heading
RE_MD_H3 = re.compile(r"^### (.+)")       # level-3 heading
RE_FENCE_OPEN = re.compile(r"^```(\w*)")  # opening code fence, optional language tag
RE_FENCE_CLOSE = re.compile(r"^```\s*$")  # bare closing code fence
RE_TABLE_ROW = re.compile(r"^\|")         # table row (pipe-prefixed line)
def split_narrative_by_tokens(text, max_tokens):
    """Split *text* on blank lines, then greedily pack paragraphs into
    pieces of at most *max_tokens* tokens.

    A single paragraph larger than the budget is kept whole; empty pieces
    are dropped from the result.
    """
    pieces = []
    bucket = []
    bucket_tokens = 0
    for paragraph in re.split(r"\n\s*\n", text):
        size = count_tokens(paragraph)
        if bucket and bucket_tokens + size > max_tokens:
            pieces.append("\n\n".join(bucket))
            bucket = [paragraph]
            bucket_tokens = size
        else:
            bucket.append(paragraph)
            bucket_tokens += size
    if bucket:
        pieces.append("\n\n".join(bucket))
    return [piece for piece in pieces if piece.strip()]
def iter_markdown_chunks(filepath, cfg, max_tokens = MAX_NARRATIVE_TOKENS):
    """Yield Chunk objects for a Markdown file.

    Headings (H1-H3) update a hierarchical section label; fenced code blocks
    and tables become dedicated chunks; everything else accumulates as
    narrative prose and is flushed in token-bounded pieces.
    """
    lines = filepath.read_text(encoding="utf-8").splitlines()
    current_h1 = current_h2 = current_h3 = ""

    def section_label() -> str:
        # "H1 > H2 > H3" path of the headings currently in effect.
        return " > ".join(p for p in [current_h1, current_h2, current_h3] if p)

    def make_md_chunk(doc_type, block_type, start, end, content) -> Chunk:
        return make_chunk(filepath, doc_type, block_type,
                          section_label(), start, end, content, cfg)

    i = 0
    narrative_start = 1; narrative_lines: list[str] = []

    def flush_narrative() -> Generator:
        # Emit accumulated prose, split into token-bounded pieces.
        # NOTE(review): end lines are derived from each piece's own line
        # count, so spans of later pieces are approximate — confirm.
        nonlocal narrative_lines, narrative_start
        text = "\n".join(narrative_lines).strip()
        if not text:
            narrative_lines.clear(); return
        for sub in split_narrative_by_tokens(text, max_tokens):
            sl = sub.count("\n") + 1
            yield make_md_chunk("spec", "narrative",
                                narrative_start, narrative_start + sl - 1, sub)
        narrative_lines.clear()

    while i < len(lines):
        raw = lines[i]; line_no = i + 1
        m1 = RE_MD_H1.match(raw); m2 = RE_MD_H2.match(raw); m3 = RE_MD_H3.match(raw)
        if m1:
            # New H1 resets the deeper heading levels.
            yield from flush_narrative()
            current_h1 = m1.group(1).strip(); current_h2 = current_h3 = ""
            narrative_start = line_no + 1; i += 1; continue
        if m2:
            yield from flush_narrative()
            current_h2 = m2.group(1).strip(); current_h3 = ""
            narrative_start = line_no + 1; i += 1; continue
        if m3:
            yield from flush_narrative()
            current_h3 = m3.group(1).strip()
            narrative_start = line_no + 1; i += 1; continue
        fm = RE_FENCE_OPEN.match(raw)
        if fm and not RE_FENCE_CLOSE.match(raw):
            # Fenced code block: emitted whole, tagged with its language.
            yield from flush_narrative()
            lang = fm.group(1).lower() or "code"
            doc_type = "bnf" if lang == "bnf" else "code_example"
            fence_start = line_no
            fence_lines = [raw]; i += 1
            while i < len(lines):
                fence_lines.append(lines[i])
                if RE_FENCE_CLOSE.match(lines[i]) and len(fence_lines) > 1:
                    i += 1; break
                i += 1
            yield make_md_chunk(doc_type, lang,
                                fence_start, fence_start + len(fence_lines) - 1,
                                "\n".join(fence_lines))
            narrative_start = i + 1
            continue
        if RE_TABLE_ROW.match(raw):
            # A run of pipe-prefixed lines becomes one table chunk.
            yield from flush_narrative()
            ts = line_no; tl = []
            while i < len(lines) and RE_TABLE_ROW.match(lines[i]):
                tl.append(lines[i]); i += 1
            yield make_md_chunk("spec", "table", ts, ts + len(tl) - 1, "\n".join(tl))
            narrative_start = i + 1
            continue
        if not narrative_lines:
            narrative_start = line_no
        narrative_lines.append(raw)
        i += 1
    yield from flush_narrative()
def _worker(args):
    """Process-pool worker: chunk a batch of files into a private temp JSONL.

    *args* is a (paths, config_path, overlap_lines, max_tokens) tuple, packed
    so the call pickles cleanly across processes. Returns
    (tmp_path, stats, warnings); the parent process merges the temp files.
    """
    paths, config_path, overlap_lines, max_tokens = args
    # Each worker re-parses the config (compiled regexes don't pickle).
    cfg = LanguageConfig(config_path)
    overlap_buf = SemanticOverlapBuffer(overlap_lines)
    stats = {t: 0 for t in ["code", "function_signature", "spec", "bnf", "code_example", "unknown", "total"]}
    all_warnings = []
    fd, tmp_path = tempfile.mkstemp(suffix=".jsonl", prefix="worker_")
    os.close(fd)  # the raw fd is not needed; we reopen by path below
    with open(tmp_path, "w", encoding="utf-8") as f:
        for path in paths:
            ext = path.suffix.lower()
            if ext in cfg.extensions:
                for item in iter_code_chunks(path, cfg, overlap_buf):
                    # iter_code_chunks ends with a (None, warnings) sentinel.
                    if isinstance(item, tuple) and item[0] is None:
                        all_warnings.extend(item[1])
                        continue
                    chunk = item
                    f.write(json.dumps(chunk.to_dict(), ensure_ascii=False) + "\n")
                    stats[chunk.doc_type] = stats.get(chunk.doc_type, 0) + 1
                    stats["total"] += 1
            elif ext == ".md":
                for chunk in iter_markdown_chunks(path, cfg, max_tokens):
                    f.write(json.dumps(chunk.to_dict(), ensure_ascii=False) + "\n")
                    stats[chunk.doc_type] = stats.get(chunk.doc_type, 0) + 1
                    stats["total"] += 1
            else:
                # Unknown extension: emit the whole file as one raw chunk.
                content = path.read_text(encoding="utf-8")
                chunk = make_chunk(path, "unknown", "raw", "", 1,
                                   content.count("\n") + 1, content, cfg)
                f.write(json.dumps(chunk.to_dict(), ensure_ascii=False) + "\n")
                stats["unknown"] += 1; stats["total"] += 1
    return tmp_path, stats, all_warnings
def fetch_documents(docs_path, cfg, extra_extensions):
    """Recursively collect files under *docs_path* whose suffix is one of the
    language's extensions or of *extra_extensions*, sorted for determinism.

    Raises FileNotFoundError when *docs_path* does not exist.
    """
    root = Path(docs_path)
    if not root.exists():
        raise FileNotFoundError(f"PATH not found: {root}")
    wanted = cfg.extensions | set(extra_extensions)
    candidates = (p for p in root.rglob("*") if p.is_file())
    return sorted(p for p in candidates if p.suffix.lower() in wanted)
def _partition(paths, n):
k = max(1, len(paths) // n)
return [paths[i:i+k] for i in range(0, len(paths), k)]
def run_pipeline(paths,
                 config_path,
                 writer,
                 deduplicator,
                 overlap_lines,
                 max_tokens,
                 workers):
    """Fan out chunking across worker processes, then merge and deduplicate.

    Phase 1: partition *paths* and let each worker write a private temp
    JSONL (results collected as futures complete). Phase 2: stream each temp
    file through the optional *deduplicator* into *writer*, deleting the
    temp file afterwards. Returns (total_stats, all_warnings).
    """
    total_stats = {t: 0 for t in ["code", "function_signature", "spec", "bnf", "code_example", "unknown", "total", "dedup_removed"]}
    all_warnings = []
    tmp_files = []
    partitions = _partition(paths, workers)
    worker_args = [(part, config_path, overlap_lines, max_tokens) for part in partitions]
    print(f"{len(paths)} Files in {len(partitions)} workers...\n")
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(_worker, arg): i
                   for i, arg in enumerate(worker_args)}
        for future in tqdm(as_completed(futures), total=len(futures),
                           desc=" Workers", unit="worker"):
            tmp_path, stats, warns = future.result()
            tmp_files.append(tmp_path)
            all_warnings.extend(warns)
            for k, v in stats.items():
                total_stats[k] = total_stats.get(k, 0) + v
    print(f"\n Mergin {len(tmp_files)} partial files...")
    for tmp_path in tqdm(tmp_files, desc=" Merge + dedup", unit="file"):
        with open(tmp_path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    cd = json.loads(line)
                    if deduplicator:
                        # Rehydrate a Chunk only for the dedup check.
                        c = Chunk(
                            chunk_id=cd["chunk_id"], source_file=cd["source_file"],
                            doc_type=cd["doc_type"], block_type=cd["block_type"],
                            section=cd["section"], start_line=cd["start_line"],
                            end_line=cd["end_line"], content=cd["content"],
                            metadata=cd.get("metadata", {}),
                        )
                        if deduplicator.is_duplicate(c):
                            total_stats["dedup_removed"] = \
                                total_stats.get("dedup_removed", 0) + 1
                            continue
                    # NOTE(review): writes through the writer's private handle
                    # to avoid re-serializing; this couples the loop to
                    # JsonlWriter internals — consider a write_raw() method.
                    writer._handle.write(line + "\n")
                    writer.written += 1
                except json.JSONDecodeError as e:
                    # Skip malformed lines from a partial worker write.
                    print(e)
                    pass
        Path(tmp_path).unlink(missing_ok=True)
    return total_stats, all_warnings
def print_report(stats, warnings, output_path, token_backend, workers, language):
    """Print the final run summary: backends, per-type counts, totals and
    the first twenty syntax warnings."""
    print(f" RESULT — [{language}]")
    print(f" Tokenizer : {token_backend}")
    dedup_backend = "MinHash LSH (RAM)" if MINHASH_AVAILABLE else "desactivada"
    print(f" Dedup backend : {dedup_backend}")
    print(f" Workers : {workers}")
    print()
    for doc_type in ["code", "function_signature", "spec", "bnf", "code_example", "unknown"]:
        count = stats.get(doc_type, 0)
        if count:
            print(f" {doc_type:<25}: {count:>6} chunks")
    print(f"\n Total written : {stats.get('total', 0)}")
    print(f" Erased (dedup) : {stats.get('dedup_removed', 0)}")
    if not warnings:
        print("\n Ok")
    else:
        print(f"\n Warnings ({len(warnings)}):")
        for warning in warnings[:20]:
            print(warning)
        if len(warnings) > 20:
            print(f" ... and {len(warnings) - 20} more")
    print(f"\n OUTPUT File {output_path}")
def main():
    """CLI entry point: parse arguments, discover documents, run the
    chunking pipeline and print a summary report."""
    parser = argparse.ArgumentParser(
        description="Generic chunker"  # BUGFIX: "GEneric" typo
    )
    parser.add_argument("--lang-config", required=True,
                        help="(ej: avap_config.json)")
    parser.add_argument("--docs-path", default="docs/samples")
    parser.add_argument("--output", default="ingestion/chunks.jsonl")
    parser.add_argument("--overlap", type=int, default=OVERLAP_LINES)
    parser.add_argument("--max-tokens", type=int, default=MAX_NARRATIVE_TOKENS)
    parser.add_argument("--dedup-threshold", type=float, default=DEDUP_THRESHOLD)
    parser.add_argument("--no-dedup", action="store_true")
    parser.add_argument("--no-overlap", action="store_true")
    parser.add_argument("--workers", type=int, default=DEFAULT_WORKERS)
    args = parser.parse_args()

    cfg = LanguageConfig(args.lang_config)
    overlap = 0 if args.no_overlap else args.overlap

    print(f" Lenguaje : {cfg.language} v{cfg.version}")
    print(f" Config : {args.lang_config}")
    print(f" Extensiones : {cfg.extensions | {'.md'}}")
    print(f" Docs path : {args.docs_path}")
    print(f" Output : {args.output}")
    print(f" Workers : {args.workers}")
    print(f" Tokenizador : {TOKEN_BACKEND}")
    print(f" Overlap : {overlap} líneas (semántico)")
    print(f" Max tokens : {args.max_tokens}")
    # BUGFIX: the old banner reported MinHash dedup even when datasketch was
    # missing (dedup silently disabled), and appended a pointless `(f" RAM")`
    # f-string. Reflect the actual dedup state instead.
    if args.no_dedup or not MINHASH_AVAILABLE:
        dedup_info = "deactivated"
    else:
        dedup_info = f"MinHash LSH threshold={args.dedup_threshold} RAM"
    print(f" Dedup : {dedup_info}")
    print()

    paths = fetch_documents(args.docs_path, cfg, [".md"])
    if not paths:
        print("No files found.")
        return
    print(f"{len(paths)} files found\n")

    deduplicator = None
    if not args.no_dedup and MINHASH_AVAILABLE:
        deduplicator = StreamingDeduplicator(
            threshold=args.dedup_threshold,
        )
    writer = JsonlWriter(args.output)
    try:
        stats, warnings = run_pipeline(
            paths, args.lang_config, writer, deduplicator,
            overlap, args.max_tokens, args.workers
        )
    finally:
        # Always release the output handle, even when the pipeline fails.
        writer.close()
    print_report(stats, warnings, str(writer.path),
                 TOKEN_BACKEND, args.workers, cfg.language)


if __name__ == "__main__":
    main()