# assistance-engine/scripts/pipelines/flows/dataset_synthetic_generation.py
# (header reconstructed from paste metadata: ~59 lines, 1.4 KiB, Python)
import json
from datasets import load_dataset
import boto3
import typer
import logging
from botocore.config import Config
from langchain_core.messages import SystemMessage, HumanMessage
from src.utils.llm_factory import create_chat_model
from src.config import RAW_DIR, INTERIM_DIR
logger = logging.getLogger(__name__)
app = typer.Typer()

# Bedrock client with a long read timeout: a single synthetic-generation
# completion over dozens of samples can take minutes.
config = Config(
    region_name="us-east-1",
    connect_timeout=10,
    read_timeout=600,
)
client = boto3.client("bedrock-runtime", config=config)
llm = create_chat_model(
    provider="bedrock",
    client=client,
    model="global.anthropic.claude-sonnet-4-6",
    temperature=0,
)

# NOTE(review): the original script referenced an undefined ``PROMPT_MBPP``
# (NameError at runtime). Defined here as a system prompt matching how it is
# used below; confirm the wording if the intended prompt lives elsewhere.
PROMPT_MBPP = SystemMessage(
    content=(
        "You are given a JSON object mapping string ids to MBPP programming "
        "problems, each entry having a 'text' description and a reference "
        "'code' solution. For every entry, write a NEW synthetic problem of "
        "similar difficulty with the same structure (keys 'text' and 'code'). "
        "Respond with only the resulting JSON object, optionally wrapped in "
        "a ```json fence."
    )
)


@app.command()
def generate(n_samples: int = 50, seed: int = 42) -> None:
    """Generate synthetic MBPP-style problems with an LLM.

    Samples ``n_samples`` problems from the MBPP test split (deterministic
    for a fixed ``seed``), sends them to the Bedrock chat model together
    with ``PROMPT_MBPP``, parses the JSON response, and writes the result
    to ``INTERIM_DIR / "mbpp_synthetic.json"``.

    Raises:
        json.JSONDecodeError: if the model response is not valid JSON after
            stripping an optional ```json fence.
    """
    # Originally this pipeline ran at module import time, before the
    # __main__ guard and with no typer command registered; it is now an
    # explicit CLI command.
    dataset_full = load_dataset("mbpp")
    random_test_samples = (
        dataset_full["test"].shuffle(seed=seed).select(range(n_samples))
    )
    test_samples_dict = {
        str(i): {
            "text": sample["text"],
            "code": sample["code"],
        }
        for i, sample in enumerate(random_test_samples)
    }
    llm_response = llm.invoke(
        [PROMPT_MBPP, HumanMessage(content=str(test_samples_dict))]
    )
    # Models frequently wrap JSON in a markdown fence; strip it before parsing.
    json_str = (
        llm_response.content.removeprefix("```json").removesuffix("```").strip()
    )
    synthetic_data = json.loads(json_str)
    # Persist the result — the original computed it and discarded it
    # (INTERIM_DIR was imported but unused).
    out_path = INTERIM_DIR / "mbpp_synthetic.json"
    out_path.write_text(
        json.dumps(synthetic_data, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    logger.info("Wrote %d synthetic samples to %s", len(synthetic_data), out_path)
if __name__ == "__main__":
    # Configure root logging only when executed as a script, never on import.
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=logging.INFO,
    )
    # Top-level boundary: capture the full traceback in the log, then let
    # the exception propagate so the process exits non-zero.
    try:
        app()
    except Exception as exc:
        logger.exception(exc)
        raise