111 lines
3.5 KiB
Python
111 lines
3.5 KiB
Python
import io
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
from loguru import logger
|
|
|
|
|
|
def load_tasks(dataset_path: Path) -> list[dict]:
    """Read the task list stored in a dataset JSON file.

    The file is decoded as UTF-8 and its parsed JSON content is returned
    unchanged. The number of loaded tasks is logged at info level.

    Args:
        dataset_path: Location of the JSON file to read.

    Returns:
        The parsed content of the file.
    """
    loaded: list[dict] = json.loads(dataset_path.read_text(encoding="utf-8"))
    logger.info(f"Loaded {len(loaded)} tasks from {dataset_path}")
    return loaded
|
|
|
|
|
|
def _post_single_task(task: dict, api_url: str, timeout: int) -> dict:
    """Upload one task to the validation API and return its parsed result.

    The task is wrapped in a single-element JSON array, UTF-8 encoded, and
    passed to ``requests.post`` through ``files=`` under the ``file`` field
    with the filename ``task.json``. The HTTP status code is not inspected:
    whatever body text comes back is handed to ``_parse_task_response``.

    Args:
        task: Task dict to submit.
        api_url: URL of the validation API endpoint.
        timeout: Timeout in seconds forwarded to ``requests.post``.

    Returns:
        The dict produced by ``_parse_task_response`` for the response body.
    """
    encoded_body = json.dumps([task]).encode("utf-8")
    upload = ("task.json", io.BytesIO(encoded_body), "application/json")
    reply = requests.post(api_url, files={"file": upload}, timeout=timeout)
    return _parse_task_response(reply.text)
|
|
|
|
|
|
def _parse_task_response(raw: str) -> dict:
|
|
"""Parse the API response for a single task."""
|
|
raw = raw.strip()
|
|
if not raw:
|
|
return {"success": False, "error": "Empty response from API"}
|
|
|
|
decoder = json.JSONDecoder()
|
|
objects: list[dict] = []
|
|
idx = 0
|
|
while idx < len(raw):
|
|
try:
|
|
obj, end_idx = decoder.raw_decode(raw, idx)
|
|
objects.append(obj)
|
|
idx = end_idx
|
|
except json.JSONDecodeError:
|
|
idx += 1
|
|
while idx < len(raw) and raw[idx] in " \t\n\r":
|
|
idx += 1
|
|
|
|
if not objects:
|
|
return {"success": False, "error": f"Could not parse response: {raw[:200]}"}
|
|
|
|
for obj in objects:
|
|
if not obj.get("success"):
|
|
return obj
|
|
if "result_sequence" in obj and obj["result_sequence"]:
|
|
return obj["result_sequence"][0]
|
|
|
|
return objects[0]
|
|
|
|
|
|
def validate_all_tasks(tasks: list[dict], api_url: str, timeout: int) -> list[dict]:
    """Submit every task to the validation API and keep the ones that pass.

    Tasks are posted one at a time, so a request failure for one task is
    recorded and the loop moves on to the next. A task passes when the
    parsed result has a truthy ``success`` and its ``assertion_result`` is
    truthy or absent. All failures are logged individually and again in a
    single summary at the end.

    Args:
        tasks: List of task dicts to validate.
        api_url: URL of the validation API endpoint.
        timeout: Timeout in seconds for each API request.

    Returns:
        List of tasks that passed validation, in input order.
    """
    passed: list[dict] = []
    failures: list[str] = []

    for position, task in enumerate(tasks):
        # Fall back to the list position when the task carries no id.
        label = task.get("task_id", position)

        try:
            outcome = _post_single_task(task, api_url, timeout)
        except requests.RequestException as exc:
            failure = f"Task {label}: Request failed — {exc}"
            failures.append(failure)
            logger.error(failure)
            continue

        if outcome.get("success") and outcome.get("assertion_result", True):
            passed.append(task)
            logger.debug(f"Task {label}: passed")
            continue

        failure = f"Task {label}: {outcome}"
        failures.append(failure)
        logger.warning(failure)

    if failures:
        rule = "=" * 60
        bullet_lines = "\n".join(f" - {failure}" for failure in failures)
        logger.error(
            f"\n{rule}\n"
            f"VALIDATION ERROR SUMMARY — {len(failures)} task(s) failed:\n"
            f"{bullet_lines}"
            f"\n{rule}"
        )

    logger.info(f"Validation complete: {len(passed)}/{len(tasks)} tasks passed")
    return passed
|
|
|
|
|
|
def save_validated_tasks(tasks: list[dict], output_path: Path) -> None:
    """Serialize the validated tasks as indented UTF-8 JSON.

    Missing parent directories of ``output_path`` are created first.
    Non-ASCII characters are written verbatim rather than escaped, and the
    destination is logged at info level once the file has been written.

    Args:
        tasks: Task dicts to persist.
        output_path: File to write; overwritten if it already exists.
    """
    target_dir = output_path.parent
    target_dir.mkdir(parents=True, exist_ok=True)

    with output_path.open(mode="w", encoding="utf-8") as sink:
        json.dump(tasks, sink, indent=2, ensure_ascii=False)

    logger.info(f"Saved validated dataset to {output_path}")
|