# Synthetic MBPP data generation via a Bedrock-hosted LLM (Typer CLI).
import json
import logging

import boto3
import typer
from botocore.config import Config
from datasets import load_dataset
from langchain_core.messages import HumanMessage, SystemMessage

from src.config import INTERIM_DIR, RAW_DIR
from src.utils.llm_factory import create_chat_model
# Module-level logger following the standard getLogger(__name__) pattern.
logger = logging.getLogger(__name__)

# Typer CLI application; invoked from the __main__ guard at the bottom of the file.
app = typer.Typer()

# botocore client configuration.
# NOTE(review): region and timeouts are hard-coded; the long read_timeout
# (600 s) presumably accommodates slow LLM generations — confirm these match
# the deployment environment.
config = Config(
    region_name="us-east-1",
    connect_timeout=10,
    read_timeout=600,
)

# Bedrock runtime client shared by the chat model below.
client = boto3.client("bedrock-runtime", config=config)

# Chat model handle created via the project factory; temperature=0 requests
# deterministic generations.
# NOTE(review): model id is hard-coded — verify it is a valid Bedrock model
# identifier for this account/region.
llm = create_chat_model(
    provider="bedrock",
    client=client,
    model="global.anthropic.claude-sonnet-4-6",
    temperature=0,
)
# System prompt for the synthetic-data request.
# NOTE(review): the original code referenced PROMPT_MBPP without ever defining
# or importing it (NameError at import time). A reasonable prompt is defined
# here — confirm the wording against the intended project prompt.
PROMPT_MBPP = SystemMessage(
    content=(
        "You are given a JSON object mapping ids to MBPP problems, each with a "
        "'text' task description and a reference Python 'code' solution. For "
        "every entry, write a new, similar programming problem together with a "
        "correct Python solution. Return ONLY a JSON object with the same ids, "
        "each value an object with 'text' and 'code' keys, inside a ```json "
        "fenced block."
    )
)


@app.command()
def generate(n_samples: int = 50, seed: int = 42) -> dict:
    """Sample MBPP test problems and ask the LLM for synthetic variants.

    Previously this pipeline executed at module import time (downloading the
    dataset and calling the LLM as an import side effect) and the Typer app
    had no registered command, so ``app()`` in the ``__main__`` guard aborted.
    It is now an explicit CLI command.

    Args:
        n_samples: Number of MBPP test problems to sample (original
            hard-coded value: 50).
        seed: Shuffle seed for reproducible sampling (original: 42).

    Returns:
        The parsed JSON object produced by the LLM.

    Raises:
        json.JSONDecodeError: If the LLM response is not valid JSON after
            stripping an optional ```json fence.
    """
    dataset_full = load_dataset("mbpp")

    # Deterministic sample of the test split.
    random_test_samples = (
        dataset_full["test"].shuffle(seed=seed).select(range(n_samples))
    )

    test_samples_dict = {
        str(i): {"text": sample["text"], "code": sample["code"]}
        for i, sample in enumerate(random_test_samples)
    }

    llm_response = llm.invoke(
        [PROMPT_MBPP, HumanMessage(content=str(test_samples_dict))]
    )

    # Strip an optional ```json fenced block before parsing.
    json_str = llm_response.content.removeprefix("```json").removesuffix("```").strip()
    synthetic_data = json.loads(json_str)
    logger.info("Generated %d synthetic samples", len(synthetic_data))
    # TODO(review): RAW_DIR / INTERIM_DIR are imported but the result is never
    # persisted — confirm whether the data should be written to disk here.
    return synthetic_data
if __name__ == "__main__":
    # Configure root logging once for CLI execution.
    _log_format = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    logging.basicConfig(level=logging.INFO, format=_log_format)
    try:
        app()
    except Exception as exc:  # log the full traceback, then propagate
        logger.exception(exc)
        raise