assistance-engine/scripts/pipelines/flows/translate_human_eval.py

132 lines
4.4 KiB
Python

import json
from enum import Enum
from pathlib import Path
from typing import Any

import boto3
import typer
from botocore.config import Config
from datasets import load_dataset
from langchain_core.messages import SystemMessage, HumanMessage
from loguru import logger

from scripts.pipelines.tasks.prompts import get_prompt_human_eval
from src.utils.llm_factory import create_chat_model
# Typer CLI application. `generate_synthetic_dataset` is registered on it via
# `@app.command()`, and the `__main__` guard at the bottom of the file runs it.
app = typer.Typer()
class Provider(str, Enum):
    """LLM backends selectable through the ``provider`` CLI option.

    Because the class mixes in ``str``, every member compares equal to its
    plain string value (``Provider.bedrock == "bedrock"``).
    """

    bedrock = "bedrock"
    openai = "openai"
    ollama = "ollama"

    def __str__(self) -> str:
        # On Python 3.11+ str()/f-strings render a (str, Enum) member as
        # "Provider.bedrock". Return the bare value so log lines and any
        # str(provider) call see "bedrock" on every Python version.
        return self.value
def _message_text(content: Any) -> str:
    """Return the plain text of a LangChain message ``content`` value.

    ``content`` may be a plain string or a list of content parts. String parts
    are kept verbatim. Dict parts contribute their ``"text"`` entry. Dict parts
    whose ``"type"`` is not ``"text"`` are skipped.
    """
    if isinstance(content, str):
        return content
    pieces = []
    for part in content:
        if isinstance(part, str):
            pieces.append(part)
        elif isinstance(part, dict) and part.get("type", "text") == "text":
            pieces.append(str(part.get("text", "")))
    return "".join(pieces)


def _strip_code_fence(text: str) -> str:
    """Strip a Markdown code fence (```json or bare ```) wrapped around *text*.

    Surrounding whitespace is removed before the fences are checked. As a
    result, a closing fence followed by a trailing newline is still
    recognised. Text without a fence is returned stripped.
    """
    cleaned = text.strip()
    # Check the longer "```json" opener first so the tag is not left behind.
    for fence in ("```json", "```"):
        if cleaned.startswith(fence):
            cleaned = cleaned[len(fence):]
            break
    return cleaned.strip().removesuffix("```").strip()


def _select_test_samples(num_samples: int, seed: int) -> dict[str, dict[str, Any]]:
    """Load HumanEval and return up to *num_samples* shuffled test problems.

    The result maps each problem's ``task_id`` (as a string) to a dict of the
    fields sent to the LLM.
    """
    logger.debug("Loading OpenAI Human Eval dataset")
    test_split = load_dataset("openai_humaneval")["test"]
    logger.info("✓ Human Eval dataset loaded successfully")

    logger.debug(f"Selecting {num_samples} random test samples from Human Eval dataset")
    sample_count = max(0, min(num_samples, len(test_split)))
    selected = test_split.shuffle(seed=seed).select(range(sample_count))
    logger.info(f"✓ Selected {len(selected)} test samples")

    logger.debug("Preparing test samples dictionary")
    fields = ("task_id", "prompt", "canonical_solution", "test", "entry_point")
    return {
        str(sample["task_id"]): {field: sample[field] for field in fields}
        for sample in selected
    }


@app.command()
def generate_synthetic_dataset(
    provider: Provider = Provider.bedrock,
    model: str = "global.anthropic.claude-sonnet-4-6",
    temperature: float = 0.0,
    num_samples: int = 10,
    seed: int = 42,
    context_docs_path: str = "data/avap.txt",
    synthetic_output_path: str = "synthetic_datasets",
) -> Any:
    """Generate synthetic dataset using the specified LLM with Human Eval dataset.

    \f
    Samples HumanEval test problems and sends them, as JSON, to the LLM
    together with a system prompt built from the AVAP documentation. The JSON
    reply is parsed and written to
    ``<synthetic_output_path>/synthetic_data_human_eval_<provider>.json``.

    Args:
        provider: LLM backend. A plain string such as ``"openai"`` is also
            accepted when this function is called directly from Python.
        model: Model identifier passed to ``create_chat_model``.
        temperature: Sampling temperature passed to ``create_chat_model``.
        num_samples: Number of HumanEval test problems to sample. The value
            is capped at the size of the test split.
        seed: Seed for the dataset shuffle that picks the samples.
        context_docs_path: UTF-8 text file with the AVAP documentation.
        synthetic_output_path: Output directory, created if missing.

    Returns:
        The parsed JSON value returned by the LLM (also written to disk).

    Raises:
        ValueError: If ``provider`` is not a known provider name.
        json.JSONDecodeError: If the LLM reply is not valid JSON.
    """
    provider = Provider(provider)
    logger.info("🚀 Starting synthetic dataset generation pipeline (Human Eval)")
    logger.info(
        f"Configuration - Provider: {provider.value}, Model: {model}, "
        f"Temperature: {temperature}, Samples: {num_samples}, Seed: {seed}"
    )

    # Only Bedrock needs a boto3 client. Creating one unconditionally made the
    # openai/ollama providers fail when no AWS region was configured.
    client = None
    if provider is Provider.bedrock:
        config = Config(connect_timeout=10, read_timeout=600)
        client = boto3.client("bedrock-runtime", config=config)
        logger.info("✓ Bedrock client initialized successfully")

    logger.debug(f"Creating LLM instance with provider: {provider.value}")
    # NOTE(review): assumes create_chat_model accepts client=None for the
    # non-Bedrock providers (it previously always received a Bedrock client).
    llm = create_chat_model(
        provider=provider,
        client=client,
        model=model,
        temperature=temperature,
    )
    logger.info(f"✓ LLM initialized: {model}")

    test_samples_dict = _select_test_samples(num_samples, seed)
    logger.info(f"✓ Prepared {len(test_samples_dict)} samples for processing")

    logger.debug(f"Loading AVAP documentation from {context_docs_path}")
    # Use explicit UTF-8 so the result does not depend on the platform locale.
    avap_docs = Path(context_docs_path).read_text(encoding="utf-8")
    logger.info(f"✓ AVAP documentation loaded ({len(avap_docs)} characters)")

    logger.debug("Generating prompt with AVAP context")
    system_prompt = get_prompt_human_eval(avap_docs=avap_docs)
    logger.debug("✓ Prompt generated successfully")

    # Send real JSON rather than str(dict). A Python repr is not JSON, and the
    # reply below is parsed as JSON.
    samples_payload = json.dumps(test_samples_dict, ensure_ascii=False)
    logger.info("Invoking LLM to generate synthetic dataset...")
    llm_response = llm.invoke([system_prompt, HumanMessage(content=samples_payload)])
    logger.info("✓ LLM response received")

    response_text = _message_text(llm_response.content)
    logger.info(f"LLM Response: {response_text}")

    logger.debug("Parsing LLM response as JSON")
    json_str = _strip_code_fence(response_text)
    logger.debug(f"JSON string: {json_str}")
    try:
        synthetic_data = json.loads(json_str)
    except json.JSONDecodeError as err:
        logger.error(f"LLM response is not valid JSON: {err}")
        raise
    logger.info(
        f"✓ Successfully parsed synthetic data with {len(synthetic_data)} samples"
    )

    output_dir = Path(synthetic_output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / f"synthetic_data_human_eval_{provider.value}.json"
    with output_file.open("w", encoding="utf-8") as f:
        json.dump(synthetic_data, f, ensure_ascii=False, indent=2)
    logger.info(f"✓ Synthetic dataset written to {output_file}")

    # Report success only once the output file is actually on disk.
    logger.info(
        f"Pipeline completed successfully! Generated {len(synthetic_data)} synthetic samples"
    )
    return synthetic_data
if __name__ == "__main__":
    # Print a framed banner, then hand control to the Typer CLI.
    _banner_rule = "=" * 50
    for _banner_line in (
        _banner_rule,
        "Human Eval Synthetic Generation Pipeline",
        _banner_rule,
    ):
        logger.info(_banner_line)
    app()