assistance-engine/scripts/pipelines/flows/synthetic_dataset_generatio...

252 lines
7.9 KiB
Python

import json
from enum import Enum
from typing import Optional
from datasets import load_dataset
import boto3
import typer
from loguru import logger
from botocore.config import Config
from pathlib import Path
from langchain_core.messages import SystemMessage, HumanMessage
from src.utils.llm_factory import create_chat_model
from scripts.pipelines.tasks.prompts import (
get_prompt_mbpp,
get_prompt_human_eval,
get_prompt_generation,
)
# Typer application object; @app.command() below registers the CLI entry point.
app = typer.Typer()
class Provider(str, Enum):
    """Supported LLM backends; str mixin lets typer parse the value directly."""

    bedrock = "bedrock"
    openai = "openai"
    ollama = "ollama"
class Dataset(str, Enum):
    """Source datasets for translation mode; values are Hugging Face dataset ids."""

    mbpp = "mbpp"
    human_eval = "openai_humaneval"
@app.command()
def generate_synthetic_dataset(
    provider: Provider = Provider.bedrock,
    model: str = "global.anthropic.claude-sonnet-4-6",
    temperature: float = 0.0,
    num_samples: int = 10,
    seed: int = 42,
    context_docs_path: str = "docs/LRM/avap.md",
    synthetic_output_path: str = "synthetic_datasets",
    dataset: Optional[Dataset] = None,
    problems_per_category: int = 10,
) -> None:
    """
    Generate synthetic AVAP dataset.

    Modes:
    - With --dataset {mbpp|human_eval}: Translate existing dataset to AVAP
    - Without --dataset: Generate new problems from scratch using the prompt

    Args:
        provider: LLM backend (bedrock, openai or ollama).
        model: Model identifier passed to the chat-model factory.
        temperature: Sampling temperature for generation.
        num_samples: Samples to translate / problems to generate.
        seed: Shuffle seed for reproducible sample selection (translation mode).
        context_docs_path: Path to the AVAP documentation injected as context.
        synthetic_output_path: Directory where the output JSON is written.
        dataset: When set, selects translation mode for that dataset.
        problems_per_category: Per-category cap used in generation mode.
    """
    logger.info("🚀 Starting synthetic dataset generation pipeline")
    logger.info(
        f"Configuration - Provider: {provider}, Model: {model}, Temperature: {temperature}, "
        f"Samples: {num_samples}, Seed: {seed}, Dataset: {dataset or 'generation mode'}"
    )
    # Long read timeout: a single dataset-generation call can run for minutes.
    config = Config(
        connect_timeout=10,
        read_timeout=600,
    )
    # NOTE(review): the Bedrock client is created even for openai/ollama runs.
    # boto3 defers credential resolution until first use, so this is usually
    # harmless, but consider creating it only when provider is bedrock.
    client = boto3.client("bedrock-runtime", config=config)
    logger.info("✓ Bedrock client initialized successfully")
    # Create LLM instance with specified parameters
    logger.debug(f"Creating LLM instance with provider: {provider}")
    llm = create_chat_model(
        provider=provider,
        client=client,
        model=model,
        temperature=temperature,
    )
    logger.info(f"✓ LLM initialized: {model}")
    # Load AVAP documentation. Explicit UTF-8: the docs contain non-ASCII text
    # and the platform default encoding (e.g. cp1252 on Windows) would fail.
    logger.debug(f"Loading AVAP documentation from {context_docs_path}")
    with open(context_docs_path, "r", encoding="utf-8") as f:
        avap_docs = f.read()
    logger.info(f"✓ AVAP documentation loaded ({len(avap_docs)} characters)")
    # Choose mode: translation or generation
    if dataset:
        logger.info(f"🔄 Translation mode: Converting {dataset.value} dataset to AVAP")
        _generate_from_dataset(
            llm=llm,
            avap_docs=avap_docs,
            dataset_name=dataset.value,
            num_samples=num_samples,
            seed=seed,
            output_path=synthetic_output_path,
            provider=provider.value,
        )
    else:
        logger.info("✨ Generation mode: Creating new problems from prompt")
        _generate_from_prompt(
            llm=llm,
            avap_docs=avap_docs,
            num_samples=num_samples,
            output_path=synthetic_output_path,
            provider=provider.value,
            problems_per_category=problems_per_category,
        )
def _generate_from_dataset(
    llm,
    avap_docs: str,
    dataset_name: str,
    num_samples: int,
    seed: int,
    output_path: str,
    provider: str,
) -> None:
    """Generate by translating an existing dataset to AVAP.

    Loads ``dataset_name`` from the Hugging Face hub, selects up to
    ``num_samples`` items from its test split (seeded shuffle for
    reproducibility), asks the LLM to translate them, and writes the parsed
    JSON result under ``output_path``.

    Args:
        llm: Chat model exposing ``invoke`` (LangChain-style interface).
        avap_docs: AVAP reference documentation injected into the prompt.
        dataset_name: Either "mbpp" or "openai_humaneval".
        num_samples: Upper bound on the number of samples translated.
        seed: Shuffle seed for reproducible sample selection.
        output_path: Directory for the resulting JSON file (created if missing).
        provider: Provider name, embedded in the output filename.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
        json.JSONDecodeError: If the LLM response is not valid JSON.
    """
    # Load dataset
    logger.debug(f"Loading {dataset_name} dataset")
    dataset_full = load_dataset(dataset_name)
    logger.info(f"{dataset_name} dataset loaded successfully")
    # Seeded shuffle + select keeps runs reproducible; cap at the split size.
    logger.debug(f"Selecting {num_samples} random samples from {dataset_name}")
    test_split = dataset_full["test"]
    random_test_samples = test_split.shuffle(seed=seed).select(
        range(min(num_samples, len(test_split)))
    )
    logger.info(f"✓ Selected {len(random_test_samples)} samples")
    # Prepare samples dictionary
    logger.debug("Preparing samples dictionary")
    if dataset_name == "mbpp":
        # MBPP rows have no stable id field here, so key on the enumeration index.
        test_samples_dict = {
            str(i): {"text": sample["text"], "code": sample["code"]}
            for i, sample in enumerate(random_test_samples)
        }
        prompt_func = get_prompt_mbpp(avap_docs=avap_docs)
        output_suffix = "mbpp"
    elif dataset_name == "openai_humaneval":
        # HumanEval rows carry a unique task_id, so key on it directly.
        test_samples_dict = {
            str(sample["task_id"]): {
                "task_id": sample["task_id"],
                "prompt": sample["prompt"],
                "canonical_solution": sample["canonical_solution"],
                "test": sample["test"],
                "entry_point": sample["entry_point"],
            }
            for sample in random_test_samples
        }
        prompt_func = get_prompt_human_eval(avap_docs=avap_docs)
        output_suffix = "human_eval"
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    logger.info(f"✓ Prepared {len(test_samples_dict)} samples for processing")
    # Invoke LLM
    logger.info("Invoking LLM to generate synthetic dataset...")
    llm_response = llm.invoke(
        [prompt_func, HumanMessage(content=str(test_samples_dict))]
    )
    logger.info("✓ LLM response received")
    logger.debug(f"LLM Response: {llm_response.content}")
    # Parse JSON response. Strip surrounding whitespace BEFORE removing the
    # fences so a leading newline doesn't defeat removeprefix, and also accept
    # a bare ``` fence — consistent with _generate_from_prompt's parsing.
    logger.debug("Parsing LLM response as JSON")
    json_str = (
        llm_response.content.strip()
        .removeprefix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    logger.debug(f"JSON string: {json_str}")
    synthetic_data = json.loads(json_str)
    logger.info(
        f"✓ Successfully parsed synthetic data with {len(synthetic_data)} samples"
    )
    # Save output
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"synthetic_data_{output_suffix}_{provider}.json"
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(synthetic_data, f, ensure_ascii=False, indent=2)
    logger.info(f"✓ Output saved to {output_file}")
    logger.info(
        f"Pipeline completed successfully! Generated {len(synthetic_data)} synthetic samples"
    )
def _generate_from_prompt(
    llm,
    avap_docs: str,
    num_samples: int,
    output_path: str,
    provider: str,
    problems_per_category: int = 10,
) -> None:
    """Generate new problems from scratch using the generation prompt.

    Builds the generation prompt from the AVAP docs, invokes the LLM once,
    parses the fenced JSON reply and writes it under ``output_path``.

    Args:
        llm: Chat model exposing ``invoke`` (LangChain-style interface).
        avap_docs: AVAP reference documentation injected into the prompt.
        num_samples: Total number of problems requested from the model.
        output_path: Directory for the resulting JSON file (created if missing).
        provider: Provider name, embedded in the output filename.
        problems_per_category: Per-category cap forwarded to the prompt builder.

    Raises:
        json.JSONDecodeError: If the LLM response is not valid JSON.
    """
    logger.debug("Generating prompt for problem generation")
    prompt_func = get_prompt_generation(
        avap_docs=avap_docs,
        num_problems=num_samples,
        problems_per_category=problems_per_category,
    )
    logger.debug("✓ Prompt generated successfully")
    # Invoke LLM
    logger.info("Invoking LLM to generate new problems...")
    llm_response = llm.invoke(
        [prompt_func, HumanMessage(content="Generate the synthetic dataset now.")]
    )
    logger.info("✓ LLM response received")
    logger.debug(f"LLM Response: {llm_response.content}")
    # Parse JSON response. Strip surrounding whitespace BEFORE removing the
    # fences: a reply that starts with a newline before ```json would
    # otherwise slip past removeprefix and break json.loads.
    logger.debug("Parsing LLM response as JSON")
    json_str = (
        llm_response.content.strip()
        .removeprefix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    logger.debug(f"JSON string: {json_str}")
    synthetic_data = json.loads(json_str)
    logger.info(
        f"✓ Successfully parsed synthetic data with {len(synthetic_data)} samples"
    )
    # Save output
    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"synthetic_data_generated_{provider}.json"
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(synthetic_data, f, ensure_ascii=False, indent=2)
    logger.info(f"✓ Output saved to {output_file}")
    logger.info(
        f"Pipeline completed successfully! Generated {len(synthetic_data)} synthetic samples"
    )
if __name__ == "__main__":
    # Banner makes individual pipeline runs easy to spot in aggregated logs.
    logger.info("=" * 50)
    logger.info("Synthetic Dataset Generation Pipeline")
    logger.info("=" * 50)
    # Hand control to typer, which dispatches to generate_synthetic_dataset.
    app()