assistance-engine/scripts/pipelines/tasks/validate.py

111 lines
3.5 KiB
Python

import io
import json
from pathlib import Path
import requests
from loguru import logger
def load_tasks(dataset_path: Path) -> list[dict]:
    """Read a synthetic-dataset JSON file and return its list of task dicts."""
    raw_text = dataset_path.read_text(encoding="utf-8")
    tasks: list[dict] = json.loads(raw_text)
    logger.info(f"Loaded {len(tasks)} tasks from {dataset_path}")
    return tasks
def _post_single_task(task: dict, api_url: str, timeout: int) -> dict:
    """Send one task to the validation API as a file upload and return the parsed result."""
    # The API expects a JSON *file* containing a list of tasks, so wrap the
    # single task in a list and upload it from an in-memory buffer.
    body = io.BytesIO(json.dumps([task]).encode("utf-8"))
    upload = {"file": ("task.json", body, "application/json")}
    response = requests.post(api_url, files=upload, timeout=timeout)
    return _parse_task_response(response.text)
def _parse_task_response(raw: str) -> dict:
"""Parse the API response for a single task."""
raw = raw.strip()
if not raw:
return {"success": False, "error": "Empty response from API"}
decoder = json.JSONDecoder()
objects: list[dict] = []
idx = 0
while idx < len(raw):
try:
obj, end_idx = decoder.raw_decode(raw, idx)
objects.append(obj)
idx = end_idx
except json.JSONDecodeError:
idx += 1
while idx < len(raw) and raw[idx] in " \t\n\r":
idx += 1
if not objects:
return {"success": False, "error": f"Could not parse response: {raw[:200]}"}
for obj in objects:
if not obj.get("success"):
return obj
if "result_sequence" in obj and obj["result_sequence"]:
return obj["result_sequence"][0]
return objects[0]
def validate_all_tasks(tasks: list[dict], api_url: str, timeout: int) -> list[dict]:
    """Validate each task individually against the API.

    Tasks are posted one at a time so that a failure in any single task
    cannot block the remaining tasks from being checked.

    Args:
        tasks: List of task dicts to validate.
        api_url: URL of the validation API endpoint.
        timeout: Timeout in seconds for each API request.

    Returns:
        List of tasks that passed validation.
    """
    passed: list[dict] = []
    failures: list[str] = []
    for position, task in enumerate(tasks):
        task_id = task.get("task_id", position)
        try:
            outcome = _post_single_task(task, api_url, timeout)
        except requests.RequestException as exc:
            message = f"Task {task_id}: Request failed — {exc}"
            failures.append(message)
            logger.error(message)
            continue
        # "assertion_result" defaults to True so responses that omit it
        # still count as passing when "success" is truthy.
        if outcome.get("success") and outcome.get("assertion_result", True):
            passed.append(task)
            logger.debug(f"Task {task_id}: passed")
        else:
            message = f"Task {task_id}: {outcome}"
            failures.append(message)
            logger.warning(message)
    if failures:
        logger.error(
            f"\n{'=' * 60}\n"
            f"VALIDATION ERROR SUMMARY — {len(failures)} task(s) failed:\n"
            + "\n".join(f" - {e}" for e in failures)
            + f"\n{'=' * 60}"
        )
    logger.info(f"Validation complete: {len(passed)}/{len(tasks)} tasks passed")
    return passed
def save_validated_tasks(tasks: list[dict], output_path: Path) -> None:
    """Serialize the validated task list to ``output_path`` as pretty-printed JSON."""
    # Ensure the destination directory exists before writing.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(tasks, ensure_ascii=False, indent=2)
    output_path.write_text(serialized, encoding="utf-8")
    logger.info(f"Saved validated dataset to {output_path}")