assistance-engine/Docker/tests/test_prd_0002.py

396 lines
14 KiB
Python

"""
tests/test_prd_0002.py
Unit tests for PRD-0002 — Editor Context Injection.
These tests run without any external dependencies (no Elasticsearch, no Ollama,
no gRPC server). They validate the logic of the components modified in PRD-0002:
- _parse_query_type — classifier output parser (graph.py)
- _parse_editor_context — user field parser (openai_proxy.py)
- _build_classify_prompt — classify prompt builder (graph.py)
- _build_reformulate_query — reformulate anchor builder (graph.py)
- _build_generation_prompt — generation prompt builder (graph.py)
- _decode_b64 — base64 decoder (server.py)
Run with:
pytest tests/test_prd_0002.py -v
"""
import base64
import json
import sys
import os
import pytest
# ---------------------------------------------------------------------------
# Minimal stubs so we can import graph.py and openai_proxy.py without
# the full Docker/src environment loaded
# ---------------------------------------------------------------------------
# NOTE: everything below runs at import time of this test module. The stubs
# are written straight into sys.modules, so any later `import grpc` /
# `import brunix_pb2` in the same interpreter receives the stub object.
# Stub brunix_pb2 so openai_proxy imports cleanly
import types
brunix_pb2 = types.ModuleType("brunix_pb2")
# The two message names are replaced by callables that simply return their
# keyword arguments as a plain dict.
brunix_pb2.AgentRequest = lambda **kw: kw
brunix_pb2.AgentResponse = lambda **kw: kw
sys.modules["brunix_pb2"] = brunix_pb2
# The gRPC service module is registered as an empty module (no attributes).
sys.modules["brunix_pb2_grpc"] = types.ModuleType("brunix_pb2_grpc")
# Stub grpc
grpc_mod = types.ModuleType("grpc")
# insecure_channel accepts any arguments and returns None instead of a channel.
grpc_mod.insecure_channel = lambda *a, **kw: None
grpc_mod.Channel = object
# RpcError is aliased to the built-in Exception, so `except grpc.RpcError`
# in code using this stub catches every Exception subclass.
grpc_mod.RpcError = Exception
sys.modules["grpc"] = grpc_mod
# Stub grpc_reflection
# Each level of the dotted package path gets its own sys.modules entry; the
# parent stubs are empty modules and are not linked to `refl` by attribute.
refl = types.ModuleType("grpc_reflection.v1alpha.reflection")
sys.modules["grpc_reflection"] = types.ModuleType("grpc_reflection")
sys.modules["grpc_reflection.v1alpha"] = types.ModuleType("grpc_reflection.v1alpha")
sys.modules["grpc_reflection.v1alpha.reflection"] = refl
# Add Docker/src to path so we can import the modules directly
# NOTE(review): this resolves to <directory of this file>/../Docker/src —
# confirm that this matches where the test file lives in the repository.
DOCKER_SRC = os.path.join(os.path.dirname(__file__), "..", "Docker", "src")
# Inserted at index 0 so this directory is searched before other sys.path entries.
sys.path.insert(0, os.path.abspath(DOCKER_SRC))
# ---------------------------------------------------------------------------
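# Functions under test (local copies)
# ---------------------------------------------------------------------------
# Only the pure functions are reproduced here as local copies, so the tests
# need no LLM, no ES and no gRPC calls. Keep the copies in sync with the
# originals in graph.py, openai_proxy.py and server.py.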
def _parse_query_type(raw: str):
"""Copy of _parse_query_type from graph.py — tested in isolation."""
parts = raw.strip().upper().split()
query_type = "RETRIEVAL"
use_editor = False
if parts:
first = parts[0]
if first.startswith("CODE_GENERATION") or "CODE" in first:
query_type = "CODE_GENERATION"
elif first.startswith("CONVERSATIONAL"):
query_type = "CONVERSATIONAL"
if len(parts) > 1 and parts[1] == "EDITOR":
use_editor = True
return query_type, use_editor
def _decode_b64(value: str) -> str:
"""Copy of _decode_b64 from server.py — tested in isolation."""
try:
return base64.b64decode(value).decode("utf-8") if value else ""
except Exception:
return ""
def _parse_editor_context(user):
"""Copy of _parse_editor_context from openai_proxy.py — tested in isolation."""
if not user:
return "", "", "", ""
try:
ctx = json.loads(user)
if isinstance(ctx, dict):
return (
ctx.get("editor_content", "") or "",
ctx.get("selected_text", "") or "",
ctx.get("extra_context", "") or "",
json.dumps(ctx.get("user_info", {})),
)
except (json.JSONDecodeError, TypeError):
pass
return "", "", "", ""
def _build_reformulate_query(question: str, selected_text: str) -> str:
"""Copy of _build_reformulate_query from graph.py — tested in isolation."""
if not selected_text:
return question
return f"{selected_text}\n\nUser question about the above: {question}"
def _build_generation_prompt_injects(editor_content, selected_text, use_editor):
"""Helper — returns True if editor context would be injected."""
sections = []
if selected_text and use_editor:
sections.append("selected_code")
if editor_content and use_editor:
sections.append("editor_file")
return len(sections) > 0
# ---------------------------------------------------------------------------
# Tests: _parse_query_type
# ---------------------------------------------------------------------------
class TestParseQueryType:
    """Checks how raw classifier output maps to (query_type, use_editor)."""

    def test_retrieval_no_editor(self):
        query_type, use_editor = _parse_query_type("RETRIEVAL NO_EDITOR")
        assert query_type == "RETRIEVAL"
        assert use_editor is False

    def test_retrieval_editor(self):
        query_type, use_editor = _parse_query_type("RETRIEVAL EDITOR")
        assert query_type == "RETRIEVAL"
        assert use_editor is True

    def test_code_generation_no_editor(self):
        query_type, use_editor = _parse_query_type("CODE_GENERATION NO_EDITOR")
        assert query_type == "CODE_GENERATION"
        assert use_editor is False

    def test_code_generation_editor(self):
        query_type, use_editor = _parse_query_type("CODE_GENERATION EDITOR")
        assert query_type == "CODE_GENERATION"
        assert use_editor is True

    def test_conversational_no_editor(self):
        query_type, use_editor = _parse_query_type("CONVERSATIONAL NO_EDITOR")
        assert query_type == "CONVERSATIONAL"
        assert use_editor is False

    def test_single_token_defaults_no_editor(self):
        """A lone type token leaves the editor flag at its default, False."""
        query_type, use_editor = _parse_query_type("RETRIEVAL")
        assert query_type == "RETRIEVAL"
        assert use_editor is False

    def test_empty_defaults_retrieval_no_editor(self):
        query_type, use_editor = _parse_query_type("")
        assert query_type == "RETRIEVAL"
        assert use_editor is False

    def test_case_insensitive(self):
        query_type, use_editor = _parse_query_type("retrieval editor")
        assert query_type == "RETRIEVAL"
        assert use_editor is True

    def test_code_shorthand(self):
        """The bare token 'CODE' is accepted as CODE_GENERATION."""
        query_type, use_editor = _parse_query_type("CODE NO_EDITOR")
        assert query_type == "CODE_GENERATION"
        assert use_editor is False

    def test_extra_whitespace(self):
        query_type, use_editor = _parse_query_type(" RETRIEVAL NO_EDITOR ")
        assert query_type == "RETRIEVAL"
        assert use_editor is False
# ---------------------------------------------------------------------------
# Tests: _decode_b64
# ---------------------------------------------------------------------------
class TestDecodeB64:
    """Round-trip and failure-mode checks for _decode_b64."""

    @staticmethod
    def _to_b64(plain):
        """Return *plain* encoded as UTF-8 and then as a base64 string."""
        return base64.b64encode(plain.encode("utf-8")).decode("utf-8")

    def test_valid_base64_spanish(self):
        plain = "addVar(mensaje, \"Hola mundo\")\naddResult(mensaje)"
        assert _decode_b64(self._to_b64(plain)) == plain

    def test_valid_base64_english(self):
        plain = "registerEndpoint(\"GET\", \"/hello\", [], \"public\", handler, \"\")"
        assert _decode_b64(self._to_b64(plain)) == plain

    def test_empty_string_returns_empty(self):
        assert _decode_b64("") == ""

    def test_none_equivalent_empty(self):
        assert _decode_b64(None) == ""

    def test_invalid_base64_returns_empty(self):
        assert _decode_b64("not_valid_base64!!!") == ""

    def test_unicode_content(self):
        plain = "// función de validación\nif(token, \"SECRET\", \"=\")"
        assert _decode_b64(self._to_b64(plain)) == plain
# ---------------------------------------------------------------------------
# Tests: _parse_editor_context
# ---------------------------------------------------------------------------
class TestParseEditorContext:
    """Tests for _parse_editor_context.

    The function under test returns a 4-tuple
    ``(editor_content, selected_text, extra_context, user_info_json)``.
    The "empty" cases assert on the complete tuple, so the fourth element is
    checked as well as the first three.
    """

    EMPTY = ("", "", "", "")

    def _encode(self, text: str) -> str:
        """Return *text* as a base64 string."""
        return base64.b64encode(text.encode()).decode()

    def test_full_context_parsed(self):
        editor = self._encode("addVar(x, 10)")
        selected = self._encode("addResult(x)")
        extra = self._encode("/path/to/file.avap")
        user_json = json.dumps({
            "editor_content": editor,
            "selected_text": selected,
            "extra_context": extra,
            "user_info": {"dev_id": 1, "project_id": 2, "org_id": 3},
        })
        ec, st, ex, ui = _parse_editor_context(user_json)
        # The fields are passed through as received (still base64-encoded).
        assert ec == editor
        assert st == selected
        assert ex == extra
        assert json.loads(ui) == {"dev_id": 1, "project_id": 2, "org_id": 3}

    def test_empty_user_returns_empty_tuple(self):
        assert _parse_editor_context(None) == self.EMPTY

    def test_empty_string_returns_empty_tuple(self):
        assert _parse_editor_context("") == self.EMPTY

    def test_plain_string_not_json_returns_empty(self):
        """Non-JSON user field — backward compat, no error raised."""
        assert _parse_editor_context("plain string") == self.EMPTY

    def test_missing_fields_default_empty(self):
        user_json = json.dumps({"editor_content": "abc"})
        ec, st, ex, _ = _parse_editor_context(user_json)
        assert ec == "abc"
        assert st == ""
        assert ex == ""

    def test_user_info_missing_defaults_empty_object(self):
        user_json = json.dumps({"editor_content": "abc"})
        _, _, _, ui = _parse_editor_context(user_json)
        assert json.loads(ui) == {}

    def test_user_info_full_object(self):
        user_json = json.dumps({
            "editor_content": "",
            "selected_text": "",
            "extra_context": "",
            "user_info": {"dev_id": 42, "project_id": 7, "org_id": 99},
        })
        _, _, _, ui = _parse_editor_context(user_json)
        parsed = json.loads(ui)
        assert parsed["dev_id"] == 42
        assert parsed["project_id"] == 7
        assert parsed["org_id"] == 99

    def test_session_id_not_leaked_into_context(self):
        """session_id must NOT appear in editor context — it has its own field."""
        session_id = "sess-1234"
        # The payload must actually carry a session_id; without one the
        # assertions below could never fail.
        user_json = json.dumps({
            "session_id": session_id,
            "editor_content": "",
            "selected_text": "",
            "extra_context": "",
            "user_info": {},
        })
        fields = _parse_editor_context(user_json)
        # Check both the key name and its value in every returned field,
        # including the serialised user_info.
        for field in fields:
            assert "session_id" not in field
            assert session_id not in field
# ---------------------------------------------------------------------------
# Tests: _build_reformulate_query
# ---------------------------------------------------------------------------
class TestBuildReformulateQuery:
    """Checks the anchor text built from the question and the selection."""

    def test_no_selected_text_returns_question(self):
        question = "Que significa AVAP?"
        assert _build_reformulate_query(question, "") == question

    def test_selected_text_prepended_to_question(self):
        question = "que hace esto?"
        selection = "addVar(x, 10)\naddResult(x)"
        anchored = _build_reformulate_query(question, selection)
        assert anchored.startswith(selection)
        assert question in anchored

    def test_selected_text_anchor_format(self):
        question = "fix this"
        selection = "try()\n ormDirect(query, res)\nexception(e)\nend()"
        anchored = _build_reformulate_query(question, selection)
        assert "User question about the above:" in anchored
        assert selection in anchored
        assert question in anchored
# ---------------------------------------------------------------------------
# Tests: editor context injection logic
# ---------------------------------------------------------------------------
class TestEditorContextInjection:
    """Checks when editor content is (and is not) injected into the prompt."""

    def test_no_injection_when_use_editor_false(self):
        """Editor content must NOT be injected when use_editor_context is False."""
        outcome = _build_generation_prompt_injects(
            editor_content="addVar(x, 10)",
            selected_text="addResult(x)",
            use_editor=False,
        )
        assert outcome is False

    def test_injection_when_use_editor_true_and_content_present(self):
        """Editor content MUST be injected when use_editor_context is True."""
        outcome = _build_generation_prompt_injects(
            editor_content="addVar(x, 10)",
            selected_text="addResult(x)",
            use_editor=True,
        )
        assert outcome is True

    def test_no_injection_when_content_empty_even_if_flag_true(self):
        """Empty fields must never be injected even if flag is True."""
        outcome = _build_generation_prompt_injects(
            editor_content="",
            selected_text="",
            use_editor=True,
        )
        assert outcome is False

    def test_partial_injection_selected_only(self):
        """selected_text alone triggers injection when flag is True."""
        outcome = _build_generation_prompt_injects(
            editor_content="",
            selected_text="addResult(x)",
            use_editor=True,
        )
        assert outcome is True
# ---------------------------------------------------------------------------
# Tests: classifier routing — EDITOR signal
# ---------------------------------------------------------------------------
class TestClassifierEditorSignal:
    """
    Validates that the classifier's two-token output format is parsed
    correctly for every type / editor-flag combination listed below.
    """

    # (raw classifier output, expected query type, expected editor flag)
    VALID_OUTPUTS = [
        ("RETRIEVAL NO_EDITOR", "RETRIEVAL", False),
        ("RETRIEVAL EDITOR", "RETRIEVAL", True),
        ("CODE_GENERATION NO_EDITOR", "CODE_GENERATION", False),
        ("CODE_GENERATION EDITOR", "CODE_GENERATION", True),
        ("CONVERSATIONAL NO_EDITOR", "CONVERSATIONAL", False),
        ("CONVERSATIONAL EDITOR", "CONVERSATIONAL", True),
    ]

    @pytest.mark.parametrize("raw,expected_qt,expected_ue", VALID_OUTPUTS)
    def test_valid_two_token_output(self, raw, expected_qt, expected_ue):
        query_type, use_editor = _parse_query_type(raw)
        assert query_type == expected_qt
        assert use_editor == expected_ue

    def test_editor_flag_false_for_general_avap_question(self):
        """'Que significa AVAP?' -> RETRIEVAL NO_EDITOR."""
        _, use_editor = _parse_query_type("RETRIEVAL NO_EDITOR")
        assert use_editor is False

    def test_editor_flag_true_for_explicit_editor_reference(self):
        """'que hace este codigo?' with selected_text -> RETRIEVAL EDITOR."""
        _, use_editor = _parse_query_type("RETRIEVAL EDITOR")
        assert use_editor is True

    def test_editor_flag_false_for_code_generation_without_reference(self):
        """'dame un API de hello world' -> CODE_GENERATION NO_EDITOR."""
        _, use_editor = _parse_query_type("CODE_GENERATION NO_EDITOR")
        assert use_editor is False