"""Synthetic dataset generation pipeline for the OpenAI HumanEval benchmark."""
import json
|
|
from enum import Enum
|
|
from datasets import load_dataset
|
|
import boto3
|
|
import typer
|
|
from loguru import logger
|
|
from botocore.config import Config
|
|
from pathlib import Path
|
|
from langchain_core.messages import SystemMessage, HumanMessage
|
|
from src.utils.llm_factory import create_chat_model
|
|
from scripts.pipelines.tasks.prompts import get_prompt_human_eval
|
|
|
|
|
|
# Typer CLI application; commands are registered below via @app.command().
app = typer.Typer()
|
|
|
|
|
|
class Provider(str, Enum):
    """Supported LLM backends for the generation pipeline.

    Subclassing ``str`` lets typer use the values directly as CLI choices.
    """

    bedrock = "bedrock"
    openai = "openai"
    ollama = "ollama"
|
|
|
|
|
|
@app.command()
def generate_synthetic_dataset(
    provider: Provider = Provider.bedrock,
    model: str = "global.anthropic.claude-sonnet-4-6",
    temperature: float = 0.0,
    num_samples: int = 10,
    seed: int = 42,
    context_docs_path: str = "data/avap.txt",
    synthetic_output_path: str = "synthetic_datasets",
) -> dict:
    """Generate a synthetic dataset from HumanEval samples using the specified LLM.

    Selects a reproducible random subset of the HumanEval test split, sends it
    to the LLM together with the AVAP documentation as context, parses the
    JSON response, and writes it to ``synthetic_output_path``.

    Args:
        provider: LLM backend to use (bedrock, openai, or ollama).
        model: Model identifier passed to the provider.
        temperature: Sampling temperature for generation.
        num_samples: Number of HumanEval test samples to include.
        seed: Shuffle seed for reproducible sample selection.
        context_docs_path: Path to the AVAP documentation text file.
        synthetic_output_path: Directory where the generated JSON is written.

    Returns:
        The parsed synthetic dataset (also written to disk).

    Raises:
        FileNotFoundError: If ``context_docs_path`` does not exist.
        json.JSONDecodeError: If the LLM response is not valid JSON.
    """
    logger.info("🚀 Starting synthetic dataset generation pipeline (Human Eval)")
    logger.info(
        f"Configuration - Provider: {provider}, Model: {model}, Temperature: {temperature}, Samples: {num_samples}, Seed: {seed}"
    )

    # Only the Bedrock backend needs a boto3 runtime client; skip it for
    # other providers so the pipeline runs without AWS configuration.
    # NOTE(review): assumes create_chat_model accepts client=None for
    # openai/ollama — confirm against src.utils.llm_factory.
    client = None
    if provider is Provider.bedrock:
        config = Config(
            connect_timeout=10,
            read_timeout=600,  # generation can take minutes; allow long reads
        )
        client = boto3.client("bedrock-runtime", config=config)
        logger.info("✓ Bedrock client initialized successfully")

    logger.debug(f"Creating LLM instance with provider: {provider}")
    llm = create_chat_model(
        provider=provider,
        client=client,
        model=model,
        temperature=temperature,
    )
    logger.info(f"✓ LLM initialized: {model}")

    # Load the HumanEval benchmark and pick a reproducible random subset.
    logger.debug("Loading OpenAI Human Eval dataset")
    dataset_full = load_dataset("openai_humaneval")
    logger.info("✓ Human Eval dataset loaded successfully")

    logger.debug(f"Selecting {num_samples} random test samples from Human Eval dataset")
    test_split = dataset_full["test"]
    random_test_samples = test_split.shuffle(seed=seed).select(
        range(min(num_samples, len(test_split)))
    )
    logger.info(f"✓ Selected {len(random_test_samples)} test samples")

    # Key samples by task_id so the LLM receives a stable, self-describing map.
    logger.debug("Preparing test samples dictionary")
    test_samples_dict = {
        str(sample["task_id"]): {
            "task_id": sample["task_id"],
            "prompt": sample["prompt"],
            "canonical_solution": sample["canonical_solution"],
            "test": sample["test"],
            "entry_point": sample["entry_point"],
        }
        for sample in random_test_samples
    }
    logger.info(f"✓ Prepared {len(test_samples_dict)} samples for processing")

    logger.debug(f"Loading AVAP documentation from {context_docs_path}")
    avap_docs = Path(context_docs_path).read_text(encoding="utf-8")
    logger.info(f"✓ AVAP documentation loaded ({len(avap_docs)} characters)")

    # Build the system prompt with the AVAP docs baked in.
    logger.debug("Generating prompt with AVAP context")
    system_prompt = get_prompt_human_eval(avap_docs=avap_docs)
    logger.debug("✓ Prompt generated successfully")

    logger.info("Invoking LLM to generate synthetic dataset...")
    llm_response = llm.invoke(
        [system_prompt, HumanMessage(content=str(test_samples_dict))]
    )
    logger.info("✓ LLM response received")
    logger.info(f"LLM Response: {llm_response.content}")

    logger.debug("Parsing LLM response as JSON")
    synthetic_data = _parse_json_response(llm_response.content)
    logger.info(
        f"✓ Successfully parsed synthetic data with {len(synthetic_data)} samples"
    )

    # Persist the result before declaring the pipeline complete.
    output_dir = Path(synthetic_output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"synthetic_data_human_eval_{provider.value}.json"
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(synthetic_data, f, ensure_ascii=False, indent=2)

    logger.info(
        f"Pipeline completed successfully! Generated {len(synthetic_data)} synthetic samples"
    )
    return synthetic_data


def _parse_json_response(content: str) -> dict:
    """Strip optional ```json fences from *content* and parse it as JSON.

    Strips surrounding whitespace first so a fence preceded by a newline is
    still removed, then parses; a parse failure is logged with the offending
    text before re-raising.
    """
    json_str = content.strip().removeprefix("```json").removesuffix("```").strip()
    logger.debug(f"JSON string: {json_str}")
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        logger.exception("LLM response was not valid JSON")
        raise
|
|
|
|
|
|
if __name__ == "__main__":
    # Print a framed banner before handing control to the typer CLI.
    banner = "=" * 50
    logger.info(banner)
    logger.info("Human Eval Synthetic Generation Pipeline")
    logger.info(banner)
    app()
|