checks: Fix flaky llm test, improve performance

Author: Qubasa
Date: 2025-10-27 16:56:40 +01:00
parent 9a442c15e9
commit bdd5de5628
6 changed files with 162 additions and 195 deletions

View File

@@ -5,15 +5,17 @@ from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import clan_lib.llm.llm_types
import pytest
from clan_lib.flake.flake import Flake
from clan_lib.llm.llm_types import ModelConfig
from clan_lib.llm.orchestrator import get_llm_turn
from clan_lib.llm.service import create_llm_model, run_llm_service
from clan_lib.service_runner import create_service_manager
if TYPE_CHECKING:
from clan_lib.llm.llm_types import ChatResult
from clan_lib.llm.schemas import ChatMessage, SessionState
from clan_lib.llm.schemas import SessionState
def get_current_mode(session_state: "SessionState") -> str:
@@ -168,28 +170,54 @@ def llm_service() -> Iterator[None]:
service_manager.stop_service("ollama")
def execute_multi_turn_workflow(
user_request: str,
flake: Flake | MagicMock,
conversation_history: list["ChatMessage"] | None = None,
provider: str = "ollama",
session_state: "SessionState | None" = None,
) -> "ChatResult":
"""Execute the multi-turn workflow, auto-executing all pending operations.
@pytest.mark.service_runner
@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
def test_full_conversation_flow(mock_flake: MagicMock) -> None:
"""Test the complete conversation flow by manually calling get_llm_turn at each step.
This simulates the behavior of the CLI auto-execute loop in workflow.py.
This test verifies:
- State transitions through discovery -> readme_fetch -> service_selection -> final_decision
- Each step returns the correct next_action
- Conversation history is preserved across turns
- Session state is correctly maintained
"""
flake = mock_flake
trace_file = Path("~/.ollama/container_test_llm_trace.json").expanduser()
trace_file.unlink(missing_ok=True) # Start fresh
provider = "ollama"
# Override DEFAULT_MODELS with 5-minute timeouts for container tests
clan_lib.llm.llm_types.DEFAULT_MODELS = {
"ollama": ModelConfig(
name="qwen3:4b-instruct",
provider="ollama",
timeout=300, # set inference timeout to 5 minutes as CI may be slow
temperature=0, # set randomness to 0 for consistent test results
),
}
# ========== STEP 1: Initial request (should return next_action for discovery) ==========
print_separator("STEP 1: Initial Request", char="=", width=80)
result = get_llm_turn(
user_request=user_request,
user_request="What VPN options do I have?",
flake=flake,
conversation_history=conversation_history,
provider=provider, # type: ignore[arg-type]
session_state=session_state,
execute_next_action=False,
trace_file=trace_file,
)
# Auto-execute any pending operations
while result.next_action:
# Should have next_action for discovery phase
assert result.next_action is not None, "Should have next_action for discovery"
assert result.next_action["type"] == "discovery"
assert result.requires_user_response is False
assert len(result.proposed_instances) == 0
assert "pending_discovery" in result.session_state
print(f" Next Action: {result.next_action['type']}")
print(f" Description: {result.next_action['description']}")
print_meta_info(result, turn=1, phase="Initial Request")
# ========== STEP 2: Execute discovery (should return next_action for readme_fetch) ==========
print_separator("STEP 2: Execute Discovery", char="=", width=80)
result = get_llm_turn(
user_request="",
flake=flake,
@@ -197,187 +225,95 @@ def execute_multi_turn_workflow(
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
trace_file=trace_file,
)
return result
@pytest.mark.service_runner
@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
def test_full_conversation_flow(mock_flake: MagicMock) -> None:
"""Comprehensive test that exercises the complete conversation flow with the actual LLM service.
This test simulates a realistic multi-turn conversation that covers:
- Discovery phase: Initial request and LLM gathering information
- Service selection phase: User choosing from available options
- Final decision phase: Configuring the selected service with specific parameters
- State transitions: pending_service_selection -> pending_final_decision -> completion
- Conversation history preservation across all turns
- Error handling and edge cases
"""
flake = mock_flake
# ========== TURN 1: Discovery Phase - Initial vague request ==========
print_separator("TURN 1: Discovery Phase", char="=", width=80)
result = execute_multi_turn_workflow(
user_request="What VPN options do I have?",
flake=flake,
provider="ollama",
)
# Verify discovery phase behavior
assert result.requires_user_response is True, (
"Should require user response in discovery"
)
assert len(result.conversation_history) >= 2, (
"Should have user + assistant messages"
)
assert result.conversation_history[0]["role"] == "user"
assert result.conversation_history[0]["content"] == "What VPN options do I have?"
assert result.conversation_history[-1]["role"] == "assistant"
assert len(result.assistant_message) > 0, "Assistant should provide a response"
# After multi-turn execution, we may have either:
# - pending_service_selection (if LLM provided options and is waiting for choice)
# - pending_final_decision (if LLM directly selected a service)
# - no pending state (if LLM asked a clarifying question)
# No instances yet
assert len(result.proposed_instances) == 0
assert result.error is None
print_chat_exchange(
"What VPN options do I have?", result.assistant_message, result.session_state
)
print_meta_info(result, turn=1, phase="Discovery")
# ========== TURN 2: Service Selection Phase - User makes a choice ==========
print_separator("TURN 2: Service Selection", char="=", width=80)
user_msg_2 = "I'll use ZeroTier please"
result = execute_multi_turn_workflow(
user_request=user_msg_2,
flake=flake,
conversation_history=list(result.conversation_history),
provider="ollama",
session_state=result.session_state,
)
# Verify conversation history growth and preservation
assert len(result.conversation_history) > 2, "History should grow"
assert result.conversation_history[0]["content"] == "What VPN options do I have?"
assert result.conversation_history[2]["content"] == "I'll use ZeroTier please"
# Should either ask for configuration details or provide direct config
# Most likely will ask for more details (pending_final_decision)
if result.requires_user_response:
# LLM is asking for configuration details
# Should have next_action for readme fetch OR a clarifying question
if result.next_action:
assert result.next_action["type"] == "fetch_readmes"
assert "pending_readme_fetch" in result.session_state
print(f" Next Action: {result.next_action['type']}")
print(f" Description: {result.next_action['description']}")
else:
# LLM asked a clarifying question
assert result.requires_user_response is True
assert len(result.assistant_message) > 0
# Should transition to final decision phase
if "pending_final_decision" not in result.session_state:
# Might still be in service selection asking clarifications
assert "pending_service_selection" in result.session_state
else:
# LLM provided configuration immediately (less likely)
assert len(result.proposed_instances) > 0
assert result.proposed_instances[0]["module"]["name"] == "zerotier"
print(f" Assistant Message: {result.assistant_message[:100]}...")
print_meta_info(result, turn=2, phase="Discovery Executed")
print_chat_exchange(user_msg_2, result.assistant_message, result.session_state)
print_meta_info(result, turn=2, phase="Service Selection")
# ========== Continue conversation until we reach final decision or completion ==========
max_turns = 10
turn_count = 2
while result.requires_user_response and turn_count < max_turns:
turn_count += 1
# Determine appropriate response based on current state
if "pending_service_selection" in result.session_state:
# Still selecting service
user_request = "Yes, ZeroTier"
phase = "Service Selection (continued)"
elif "pending_final_decision" in result.session_state:
# Configuring the service
user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer"
phase = "Final Configuration"
else:
# Generic continuation
user_request = "Yes, that sounds good. Use gchq-local as controller."
phase = "Continuing Conversation"
print_separator(f"TURN {turn_count}: {phase}", char="=", width=80)
result = execute_multi_turn_workflow(
user_request=user_request,
# ========== STEP 3: Execute readme fetch (if applicable) ==========
if result.next_action and result.next_action["type"] == "fetch_readmes":
print_separator("STEP 3: Execute Readme Fetch", char="=", width=80)
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider="ollama",
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
trace_file=trace_file,
)
# Verify conversation history continues to grow
assert len(result.conversation_history) == (turn_count * 2), (
f"History should have {turn_count * 2} messages (turn {turn_count})"
# Should have next_action for service selection
assert result.next_action is not None
assert result.next_action["type"] == "service_selection"
assert "pending_service_selection" in result.session_state
print(f" Next Action: {result.next_action['type']}")
print(f" Description: {result.next_action['description']}")
print_meta_info(result, turn=3, phase="Readme Fetch Executed")
# ========== STEP 4: Execute service selection ==========
print_separator("STEP 4: Execute Service Selection", char="=", width=80)
result = get_llm_turn(
user_request="I want ZeroTier.",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
trace_file=trace_file,
)
# Verify history preservation
assert (
result.conversation_history[0]["content"] == "What VPN options do I have?"
)
print_chat_exchange(
user_request, result.assistant_message, result.session_state
)
print_meta_info(result, turn=turn_count, phase=phase)
# Check for completion
if not result.requires_user_response:
print_separator("CONVERSATION COMPLETED", char="=", width=80)
break
# ========== Final Verification ==========
print_separator("FINAL VERIFICATION", char="=", width=80)
assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})"
# If conversation completed, verify we have valid configuration
if not result.requires_user_response:
assert len(result.proposed_instances) > 0, (
"Should have at least one proposed instance"
)
instance = result.proposed_instances[0]
# Verify instance structure
assert "module" in instance
assert "name" in instance["module"]
assert instance["module"]["name"] in [
"zerotier",
"wireguard",
"yggdrasil",
"mycelium",
]
# Should not be in pending state anymore
assert "pending_service_selection" not in result.session_state
assert "pending_final_decision" not in result.session_state
assert result.error is None, f"Should not have error: {result.error}"
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(" Status: SUCCESS")
print(f" Module Name: {instance['module']['name']}")
print(f" Total Turns: {turn_count}")
print(f" Final History Length: {len(result.conversation_history)} messages")
if "roles" in instance:
roles_list = ", ".join(instance["roles"].keys())
print(f" Configuration Roles: {roles_list}")
print(" Errors: None")
print("-" * 80)
# Should either have next_action for final_decision OR a clarifying question
if result.next_action:
assert result.next_action["type"] == "final_decision"
assert "pending_final_decision" in result.session_state
print(f" Next Action: {result.next_action['type']}")
print(f" Description: {result.next_action['description']}")
else:
# Conversation didn't complete but should have made progress
assert len(result.conversation_history) > 2
assert result.error is None
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(" Status: IN PROGRESS")
print(f" Total Turns: {turn_count}")
print(f" Current State: {list(result.session_state.keys())}")
print(f" History Length: {len(result.conversation_history)} messages")
print("-" * 80)
# LLM asked a clarifying question during service selection
assert result.requires_user_response is True
assert len(result.assistant_message) > 0
print(f" Assistant Message: {result.assistant_message[:100]}...")
print_meta_info(result, turn=4, phase="Service Selection Executed")
# ========== STEP 5: Execute final decision (if applicable) ==========
if result.next_action and result.next_action["type"] == "final_decision":
print_separator("STEP 5: Execute Final Decision", char="=", width=80)
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
trace_file=trace_file,
)
# Should either have proposed_instances OR ask a clarifying question
if result.proposed_instances:
assert len(result.proposed_instances) > 0
assert result.next_action is None
print(f" Proposed Instances: {len(result.proposed_instances)}")
for inst in result.proposed_instances:
print(f" - {inst['module']['name']}")
else:
# LLM asked a clarifying question
assert result.requires_user_response is True
assert len(result.assistant_message) > 0
print(f" Assistant Message: {result.assistant_message[:100]}...")
print_meta_info(result, turn=5, phase="Final Decision Executed")
# Verify conversation history has grown
assert len(result.conversation_history) > 0
assert result.conversation_history[0]["content"] == "What VPN options do I have?"
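
For reference, the CLI auto-execute loop in workflow.py that this test simulates step by step can be sketched as follows. This is a minimal sketch built from the get_llm_turn signature used above; the helper name is invented here and mirrors the removed execute_multi_turn_workflow helper.

from clan_lib.llm.orchestrator import get_llm_turn


def drive_turn_to_next_prompt(user_request, flake, provider="ollama",
                              conversation_history=None, session_state=None,
                              trace_file=None):
    """Hypothetical helper: submit one user message, then auto-execute every
    pending next_action (discovery, fetch_readmes, service_selection,
    final_decision) until the LLM needs user input again."""
    result = get_llm_turn(
        user_request=user_request,
        flake=flake,
        conversation_history=conversation_history,
        provider=provider,
        session_state=session_state,
        execute_next_action=False,
        trace_file=trace_file,
    )
    while result.next_action:
        # Pending operations are executed with an empty user_request and the
        # session_state carried over from the previous turn.
        result = get_llm_turn(
            user_request="",
            flake=flake,
            conversation_history=list(result.conversation_history),
            provider=provider,
            session_state=result.session_state,
            execute_next_action=True,
            trace_file=trace_file,
        )
    return result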

View File

@@ -149,6 +149,7 @@ def call_openai_api(
trace_file: Path | None = None,
stage: str = "unknown",
trace_metadata: dict[str, Any] | None = None,
temperature: float | None = None,
) -> OpenAIChatCompletionResponse:
"""Call the OpenAI API for chat completion.
@@ -160,6 +161,7 @@ def call_openai_api(
trace_file: Optional path to write trace entries for debugging
stage: Stage name for trace entries (default: "unknown")
trace_metadata: Optional metadata to include in trace entries
temperature: Sampling temperature (default: None = use API default)
Returns:
The parsed JSON response from the API
@@ -178,6 +180,8 @@ def call_openai_api(
"messages": messages,
"tools": list(tools),
}
if temperature is not None:
payload["temperature"] = temperature
_debug_log_request("openai", messages, tools)
url = "https://api.openai.com/v1/chat/completions"
headers = {
@@ -256,6 +260,7 @@ def call_claude_api(
trace_file: Path | None = None,
stage: str = "unknown",
trace_metadata: dict[str, Any] | None = None,
temperature: float | None = None,
) -> OpenAIChatCompletionResponse:
"""Call the Claude API (via OpenAI-compatible endpoint) for chat completion.
@@ -268,6 +273,7 @@ def call_claude_api(
trace_file: Optional path to write trace entries for debugging
stage: Stage name for trace entries (default: "unknown")
trace_metadata: Optional metadata to include in trace entries
temperature: Sampling temperature (default: None = use API default)
Returns:
The parsed JSON response from the API
@@ -293,6 +299,8 @@ def call_claude_api(
"messages": messages,
"tools": list(tools),
}
if temperature is not None:
payload["temperature"] = temperature
_debug_log_request("claude", messages, tools)
url = f"{base_url}chat/completions"
@@ -372,6 +380,7 @@ def call_ollama_api(
stage: str = "unknown",
max_tokens: int | None = None,
trace_metadata: dict[str, Any] | None = None,
temperature: float | None = None,
) -> OllamaChatResponse:
"""Call the Ollama API for chat completion.
@@ -384,6 +393,7 @@ def call_ollama_api(
stage: Stage name for trace entries (default: "unknown")
max_tokens: Maximum number of tokens to generate (default: None = unlimited)
trace_metadata: Optional metadata to include in trace entries
temperature: Sampling temperature (default: None = use API default)
Returns:
The parsed JSON response from the API
@@ -399,9 +409,14 @@ def call_ollama_api(
"tools": list(tools),
}
# Add max_tokens limit if specified
# Add options for max_tokens and temperature if specified
options: dict[str, int | float] = {}
if max_tokens is not None:
payload["options"] = {"num_predict": max_tokens} # type: ignore[typeddict-item]
options["num_predict"] = max_tokens
if temperature is not None:
options["temperature"] = temperature
if options:
payload["options"] = options # type: ignore[typeddict-item]
_debug_log_request("ollama", messages, tools)
url = "http://localhost:11434/api/chat"

View File

@@ -73,19 +73,21 @@ class ModelConfig:
name: The model identifier/name
provider: The LLM provider
timeout: Request timeout in seconds (default: 120)
temperature: Sampling temperature for the model (default: None = use API default)
"""
name: str
provider: Literal["openai", "ollama", "claude"]
timeout: int = 120
temperature: float | None = None
# Default model configurations for each provider
DEFAULT_MODELS: dict[Literal["openai", "ollama", "claude"], ModelConfig] = {
"openai": ModelConfig(name="gpt-4o", provider="openai", timeout=60),
"claude": ModelConfig(name="claude-sonnet-4-5", provider="claude", timeout=60),
"ollama": ModelConfig(name="qwen3:4b-instruct", provider="ollama", timeout=120),
"ollama": ModelConfig(name="qwen3:4b-instruct", provider="ollama", timeout=180),
}
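
A short sketch of how a caller sees the new field and the bumped Ollama timeout, and how a test can swap in a deterministic configuration (the assertions reflect the values set in this commit; test usage mirrors test_full_conversation_flow above):

from clan_lib.llm.llm_types import DEFAULT_MODELS, ModelConfig

cfg = DEFAULT_MODELS["ollama"]
assert cfg.timeout == 180          # raised from 120 in this commit
assert cfg.temperature is None     # None means the API's default sampling is used

# Tests that need reproducible output substitute their own config:
deterministic = ModelConfig(
    name="qwen3:4b-instruct",
    provider="ollama",
    timeout=300,       # generous inference timeout for slow CI runners
    temperature=0,     # deterministic sampling for stable assertions
)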

View File

@@ -100,6 +100,7 @@ def get_llm_discovery_phase(
trace_file=trace_file,
stage="discovery",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
openai_response, provider="openai"
@@ -113,6 +114,7 @@ def get_llm_discovery_phase(
trace_file=trace_file,
stage="discovery",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
claude_response, provider="claude"
@@ -127,6 +129,7 @@ def get_llm_discovery_phase(
stage="discovery",
max_tokens=300, # Limit output for discovery phase (get_readme calls or short question)
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_ollama_response(
ollama_response, provider="ollama"
@@ -249,6 +252,7 @@ def get_llm_service_selection(
trace_file=trace_file,
stage="select_service",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
openai_response, provider="openai"
@@ -262,6 +266,7 @@ def get_llm_service_selection(
trace_file=trace_file,
stage="select_service",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
claude_response, provider="claude"
@@ -276,6 +281,7 @@ def get_llm_service_selection(
stage="select_service",
max_tokens=600, # Allow space for summary
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_ollama_response(
ollama_response, provider="ollama"
@@ -447,6 +453,7 @@ def get_llm_final_decision(
trace_file=trace_file,
stage="final_decision",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
openai_response, provider="openai"
@@ -462,6 +469,7 @@ def get_llm_final_decision(
trace_file=trace_file,
stage="final_decision",
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_openai_response(
claude_response, provider="claude"
@@ -477,6 +485,7 @@ def get_llm_final_decision(
stage="final_decision",
max_tokens=500, # Limit output to prevent excessive verbosity
trace_metadata=trace_metadata,
temperature=model_config.temperature,
)
function_calls, message_content = parse_ollama_response(
ollama_response, provider="ollama"

View File

@@ -231,6 +231,7 @@ class ChatCompletionRequestPayload(TypedDict, total=False):
messages: list[ChatMessage]
tools: list[ToolDefinition]
stream: NotRequired[bool]
temperature: NotRequired[float]
@dataclass(frozen=True)
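
Since the new key is NotRequired, callers keep building the payload as before and only attach temperature when it is explicitly configured, matching the pattern in call_openai_api and call_claude_api above. A minimal sketch, where messages, tools, and temperature stand in for the caller's values:

payload: ChatCompletionRequestPayload = {
    "messages": messages,
    "tools": list(tools),
}
if temperature is not None:
    payload["temperature"] = temperature  # omitted entirely when unset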

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Literal, TypedDict
from clan_lib.cmd import RunOpts, run
from clan_lib.cmd import Log, RunOpts, run
from clan_lib.errors import ClanError
if TYPE_CHECKING:
@@ -70,7 +70,7 @@ class SystemdUserService:
"""Run systemctl command with --user flag."""
return run(
["systemctl", "--user", action, f"{service_name}.service"],
RunOpts(check=False),
RunOpts(check=False, log=Log.NONE),
)
def _get_property(self, service_name: str, prop: str) -> str:
@@ -240,11 +240,15 @@ class SystemdUserService:
service_name = self._service_name(name)
result = self._systemctl("stop", service_name)
if result.returncode != 0 and "not loaded" not in result.stderr.lower():
if (
result.returncode != 0
and "not loaded" not in result.stderr.lower()
and "does not exist" not in result.stderr.lower()
):
msg = f"Failed to stop service: {result.stderr}"
raise ClanError(msg)
self._systemctl("disable", service_name) # Ignore errors for transient units
result = self._systemctl("disable", service_name)
unit_file = self._unit_file_path(name)
if unit_file.exists():
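
Read together with the stricter error check above, the stop path now looks roughly like this. This is a sketch only: the method name is assumed from the service_manager.stop_service("ollama") call in the test file, and the trailing cleanup is not shown because the hunk is truncated.

def stop_service(self, name: str) -> None:
    """Stop and disable a transient user unit, tolerating units that are
    already gone ('not loaded' / 'does not exist')."""
    service_name = self._service_name(name)
    result = self._systemctl("stop", service_name)
    stderr = result.stderr.lower()
    already_gone = "not loaded" in stderr or "does not exist" in stderr
    if result.returncode != 0 and not already_gone:
        msg = f"Failed to stop service: {result.stderr}"
        raise ClanError(msg)
    result = self._systemctl("disable", service_name)
    # ...unit-file cleanup continues as in the original method (not shown in this hunk)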