clan_lib/llm: get_llm_turn uses state transitions instead of callback function

Qubasa
2025-10-24 15:57:26 +02:00
parent 183de9209f
commit c3456c1f0c
10 changed files with 529 additions and 859 deletions

View File

@@ -76,7 +76,6 @@ in
cmd = "su - text-user -c 'pytest -s -n0 -m service_runner -p no:cacheprovider -o addopts="" ${cli.passthru.sourceWithTests}/clan_lib/llm'"
print("Running tests with command: " + cmd)
# Run tests as text-user (environment variables are set automatically)
peer1.succeed(cmd)
'';

View File

@@ -1,200 +0,0 @@
"""High-level API functions for LLM interactions, suitable for HTTP APIs and web UIs.
This module provides a clean, stateless API for integrating LLM functionality into
web applications and HTTP services. It wraps the complex multi-stage workflow into
simple function calls with serializable inputs and outputs.
"""
from pathlib import Path
from typing import Any, Literal, TypedDict, cast
from clan_lib.api import API
from clan_lib.flake.flake import Flake
from .llm import (
DEFAULT_MODELS,
ChatResult,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ModelConfig,
ProgressCallback,
ProgressEvent,
ReadmeFetchProgressEvent,
get_model_config,
process_chat_turn,
)
from .schemas import ChatMessage, ConversationHistory, SessionState
class ChatTurnRequest(TypedDict, total=False):
"""Request payload for a chat turn.
Attributes:
user_message: The user's message/request
conversation_history: Optional list of prior messages in the conversation
provider: The LLM provider to use (default: "claude")
trace_file: Optional path to write LLM interaction traces for debugging
session_state: Opaque state returned from the previous turn
"""
user_message: str
conversation_history: ConversationHistory | None
provider: Literal["openai", "ollama", "claude"]
trace_file: Path | None
session_state: SessionState | None
class ChatTurnResponse(TypedDict):
"""Response payload for a chat turn.
Attributes:
proposed_instances: List of inventory instances suggested by the LLM
conversation_history: Updated conversation history after this turn
assistant_message: Message from the assistant
requires_user_response: Whether the assistant is waiting for user input
error: Error message if something went wrong (None on success)
session_state: State blob to pass into the next turn when continuing the workflow
"""
proposed_instances: list[dict[str, Any]]
conversation_history: list[ChatMessage]
assistant_message: str
requires_user_response: bool
error: str | None
session_state: SessionState
class ProgressEventResponse(TypedDict):
"""Progress event for streaming updates.
Attributes:
stage: The current stage of processing
status: The status within that stage (if applicable)
count: Count of items (for readme_fetch stage)
message: Message content (for conversation stage)
"""
stage: str
status: str | None
count: int | None
message: str | None
@API.register
def get_llm_turn(
flake: Flake,
request: ChatTurnRequest,
progress_callback: ProgressCallback | None = None,
) -> ChatTurnResponse:
"""Process a single chat turn through the LLM workflow.
This is the main entry point for HTTP APIs and web UIs to interact with
the LLM functionality. It handles:
- Service discovery
- Documentation fetching
- Final decision making
- Conversation management
Args:
flake: The Flake object representing the clan configuration
request: The chat turn request containing user message and optional history
progress_callback: Optional callback for progress updates
Returns:
ChatTurnResponse with proposed instances and conversation state
Example:
>>> from clan_lib.flake.flake import Flake
>>> flake = Flake("/path/to/clan")
>>> request: ChatTurnRequest = {
... "user_message": "Set up a web server",
... "provider": "claude"
... }
>>> response = get_llm_turn(flake, request)
>>> if response["proposed_instances"]:
... print("LLM suggests:", response["proposed_instances"])
>>> if response["requires_user_response"]:
... print("Assistant asks:", response["assistant_message"])
"""
result: ChatResult = process_chat_turn(
user_request=request["user_message"],
flake=flake,
conversation_history=request.get("conversation_history"),
provider=request.get("provider", "claude"),
progress_callback=progress_callback,
trace_file=request.get("trace_file"),
session_state=request.get("session_state"),
)
# Convert frozen tuples to lists for JSON serialization
return ChatTurnResponse(
proposed_instances=[dict(inst) for inst in result.proposed_instances],
conversation_history=list(result.conversation_history),
assistant_message=result.assistant_message,
requires_user_response=result.requires_user_response,
error=result.error,
session_state=cast("SessionState", dict(result.session_state)),
)
def progress_event_to_dict(event: ProgressEvent) -> ProgressEventResponse:
"""Convert a ProgressEvent to a dictionary suitable for JSON serialization.
This helper function is useful for streaming progress updates over HTTP
(e.g., Server-Sent Events or WebSockets).
Args:
event: The progress event to convert
Returns:
Dictionary representation of the event
Example:
>>> from clan_lib.llm.llm import DiscoveryProgressEvent
>>> event = DiscoveryProgressEvent(status="analyzing")
>>> progress_event_to_dict(event)
{'stage': 'discovery', 'status': 'analyzing', 'count': None, 'message': None}
"""
base_response: ProgressEventResponse = {
"stage": event.stage,
"status": None,
"count": None,
"message": None,
}
if isinstance(event, (DiscoveryProgressEvent, FinalDecisionProgressEvent)):
base_response["status"] = event.status
elif isinstance(event, ReadmeFetchProgressEvent):
base_response["status"] = event.status
base_response["count"] = event.count
# ConversationProgressEvent has message field
elif hasattr(event, "message"):
base_response["message"] = event.message # type: ignore[attr-defined]
if hasattr(event, "awaiting_response"):
base_response["status"] = (
"awaiting_response"
if event.awaiting_response # type: ignore[attr-defined]
else "complete"
)
return base_response
# Re-export types for convenience
__all__ = [
"DEFAULT_MODELS",
"ChatTurnRequest",
"ChatTurnResponse",
"ModelConfig",
"ProgressCallback",
"ProgressEvent",
"ProgressEventResponse",
"get_llm_turn",
"get_model_config",
"progress_event_to_dict",
]

View File

@@ -2,16 +2,86 @@ import contextlib
import json
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch
import pytest
from clan_lib.flake.flake import Flake
from clan_lib.llm.llm import (
process_chat_turn,
)
from clan_lib.llm.orchestrator import get_llm_turn
from clan_lib.llm.service import create_llm_model, run_llm_service
from clan_lib.service_runner import create_service_manager
if TYPE_CHECKING:
from clan_lib.llm.llm_types import ChatResult
from clan_lib.llm.schemas import ChatMessage, SessionState
def get_current_mode(session_state: "SessionState") -> str:
"""Extract the current mode from session state."""
if "pending_service_selection" in session_state:
return "SERVICE_SELECTION"
if "pending_final_decision" in session_state:
return "FINAL_DECISION"
return "DISCOVERY"
def print_separator(
title: str = "", char: str = "=", width: int = 80, double: bool = True
) -> None:
"""Print a separator line with optional title."""
if double:
print(f"\n{char * width}")
if title:
padding = (width - len(title) - 2) // 2
print(f"{char * padding} {title} {char * padding}")
if double or title:
print(f"{char * width}")
def print_meta_info(result: "ChatResult", turn: int, phase: str) -> None:
"""Print meta information section in a structured format."""
mode = get_current_mode(result.session_state)
print_separator("META INFORMATION", char="-", width=80, double=False)
print(f" Turn Number: {turn}")
print(f" Phase: {phase}")
print(f" Current Mode: {mode}")
print(f" Requires User Input: {result.requires_user_response}")
print(f" Conversation Length: {len(result.conversation_history)} messages")
print(f" Proposed Instances: {len(result.proposed_instances)}")
print(f" Has Next Action: {result.next_action is not None}")
print(f" Session State Keys: {list(result.session_state.keys())}")
if result.error:
print(f" Error: {result.error}")
print("-" * 80)
def print_chat_exchange(
user_msg: str | None, assistant_msg: str, session_state: "SessionState"
) -> None:
"""Print a chat exchange with role labels and current mode."""
mode = get_current_mode(session_state)
print_separator("CHAT EXCHANGE", char="-", width=80, double=False)
if user_msg:
print("\n USER:")
print(f" {user_msg}")
print(f"\n ASSISTANT [{mode}]:")
# Wrap long messages
max_line_length = 76
words = assistant_msg.split()
current_line = " "
for word in words:
if len(current_line) + len(word) + 1 > max_line_length:
print(current_line)
current_line = " " + word
else:
current_line += (" " if len(current_line) > 2 else "") + word
if current_line.strip():
print(current_line)
print("\n" + "-" * 80)
@pytest.fixture
def mock_flake() -> MagicMock:
@@ -47,9 +117,9 @@ def mock_flake() -> MagicMock:
}
match arg:
case "clanInternals.inventoryClass.inventory.{instances,machines,meta}":
case "clanInternals.inventoryClass.inventorySerialization.{instances,machines,meta}":
return load_json("inventory_instances_machines_meta.json")
case "clanInternals.inventoryClass.inventory.{tags}":
case "clanInternals.inventoryClass.inventorySerialization.{tags}":
return load_json("inventory_tags.json")
case "clanInternals.inventoryClass.modulesPerSource":
return load_json("modules_per_source.json")
@@ -98,6 +168,40 @@ def llm_service() -> Iterator[None]:
service_manager.stop_service("ollama")
def execute_multi_turn_workflow(
user_request: str,
flake: Flake | MagicMock,
conversation_history: list["ChatMessage"] | None = None,
provider: str = "ollama",
session_state: "SessionState | None" = None,
) -> "ChatResult":
"""Execute the multi-turn workflow, auto-executing all pending operations.
This simulates the behavior of the CLI auto-execute loop in workflow.py.
"""
result = get_llm_turn(
user_request=user_request,
flake=flake,
conversation_history=conversation_history,
provider=provider, # type: ignore[arg-type]
session_state=session_state,
execute_next_action=False,
)
# Auto-execute any pending operations
while result.next_action:
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
)
return result
@pytest.mark.service_runner
@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
def test_full_conversation_flow(mock_flake: MagicMock) -> None:
@@ -112,10 +216,9 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
- Error handling and edge cases
"""
flake = mock_flake
return
# ========== TURN 1: Discovery Phase - Initial vague request ==========
print("\n=== TURN 1: Initial discovery request ===")
result = process_chat_turn(
print_separator("TURN 1: Discovery Phase", char="=", width=80)
result = execute_multi_turn_workflow(
user_request="What VPN options do I have?",
flake=flake,
provider="ollama",
@@ -133,24 +236,25 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
assert result.conversation_history[-1]["role"] == "assistant"
assert len(result.assistant_message) > 0, "Assistant should provide a response"
# Should transition to service selection phase with pending state
assert "pending_service_selection" in result.session_state, (
"Should have pending service selection"
)
assert "readme_results" in result.session_state["pending_service_selection"]
# After multi-turn execution, we may have either:
# - pending_service_selection (if the LLM provided options and is waiting for a choice)
# - pending_final_decision (if the LLM directly selected a service)
# - no pending state (if the LLM asked a clarifying question)
# No instances yet
assert len(result.proposed_instances) == 0
assert result.error is None
print(f"Assistant: {result.assistant_message[:200]}...")
print(f"State: {list(result.session_state.keys())}")
print(f"History length: {len(result.conversation_history)}")
print_chat_exchange(
"What VPN options do I have?", result.assistant_message, result.session_state
)
print_meta_info(result, turn=1, phase="Discovery")
# ========== TURN 2: Service Selection Phase - User makes a choice ==========
print("\n=== TURN 2: User selects ZeroTier ===")
result = process_chat_turn(
user_request="I'll use ZeroTier please",
print_separator("TURN 2: Service Selection", char="=", width=80)
user_msg_2 = "I'll use ZeroTier please"
result = execute_multi_turn_workflow(
user_request=user_msg_2,
flake=flake,
conversation_history=list(result.conversation_history),
provider="ollama",
@@ -176,11 +280,8 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
assert len(result.proposed_instances) > 0
assert result.proposed_instances[0]["module"]["name"] == "zerotier"
print(
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..."
)
print(f"State: {list(result.session_state.keys())}")
print(f"Requires response: {result.requires_user_response}")
print_chat_exchange(user_msg_2, result.assistant_message, result.session_state)
print_meta_info(result, turn=2, phase="Service Selection")
# ========== Continue conversation until we reach final decision or completion ==========
max_turns = 10
@@ -188,22 +289,24 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
while result.requires_user_response and turn_count < max_turns:
turn_count += 1
print(f"\n=== TURN {turn_count}: Continuing conversation ===")
# Determine appropriate response based on current state
if "pending_service_selection" in result.session_state:
# Still selecting service
user_request = "Yes, ZeroTier"
phase = "Service Selection (continued)"
elif "pending_final_decision" in result.session_state:
# Configuring the service
user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer"
phase = "Final Configuration"
else:
# Generic continuation
user_request = "Yes, that sounds good. Use gchq-local as controller."
phase = "Continuing Conversation"
print(f"User: {user_request}")
print_separator(f"TURN {turn_count}: {phase}", char="=", width=80)
result = process_chat_turn(
result = execute_multi_turn_workflow(
user_request=user_request,
flake=flake,
conversation_history=list(result.conversation_history),
@@ -221,19 +324,18 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
result.conversation_history[0]["content"] == "What VPN options do I have?"
)
print(
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..."
print_chat_exchange(
user_request, result.assistant_message, result.session_state
)
print(f"State: {list(result.session_state.keys())}")
print(f"Requires response: {result.requires_user_response}")
print(f"Proposed instances: {len(result.proposed_instances)}")
print_meta_info(result, turn=turn_count, phase=phase)
# Check for completion
if not result.requires_user_response:
print("\n=== Conversation completed! ===")
print_separator("CONVERSATION COMPLETED", char="=", width=80)
break
# ========== Final Verification ==========
print_separator("FINAL VERIFICATION", char="=", width=80)
assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})"
# If conversation completed, verify we have valid configuration
@@ -253,22 +355,29 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
"mycelium",
]
# Should have roles configuration
if "roles" in instance:
print(f"\nConfiguration roles: {list(instance['roles'].keys())}")
# Should not be in pending state anymore
assert "pending_service_selection" not in result.session_state
assert "pending_final_decision" not in result.session_state
assert result.error is None, f"Should not have error: {result.error}"
print(f"\nFinal instance: {instance['module']['name']}")
print(f"Total conversation turns: {turn_count}")
print(f"Final history length: {len(result.conversation_history)}")
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(" Status: SUCCESS")
print(f" Module Name: {instance['module']['name']}")
print(f" Total Turns: {turn_count}")
print(f" Final History Length: {len(result.conversation_history)} messages")
if "roles" in instance:
roles_list = ", ".join(instance["roles"].keys())
print(f" Configuration Roles: {roles_list}")
print(" Errors: None")
print("-" * 80)
else:
# Conversation didn't complete but should have made progress
assert len(result.conversation_history) > 2
assert result.error is None
print(f"\nConversation in progress after {turn_count} turns")
print(f"Current state: {list(result.session_state.keys())}")
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(" Status: IN PROGRESS")
print(f" Total Turns: {turn_count}")
print(f" Current State: {list(result.session_state.keys())}")
print(f" History Length: {len(result.conversation_history)} messages")
print("-" * 80)

View File

@@ -1,65 +0,0 @@
"""High-level LLM orchestration functions.
This module re-exports the LLM orchestration API from submodules.
"""
# Re-export types and dataclasses
from .llm_types import ( # noqa: F401
DEFAULT_MODELS,
ChatResult,
ConversationProgressEvent,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ModelConfig,
ProgressCallback,
ProgressEvent,
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
ServiceSelectionResult,
get_model_config,
)
# Re-export high-level orchestrator
from .orchestrator import process_chat_turn # noqa: F401
# Re-export low-level phase functions
from .phases import ( # noqa: F401
execute_readme_requests,
get_llm_discovery_phase,
get_llm_final_decision,
get_llm_service_selection,
llm_final_decision_to_inventory_instances,
)
# Re-export commonly used functions and types from schemas
from .schemas import ( # noqa: F401
AiAggregate,
ChatMessage,
ConversationHistory,
FunctionCallType,
JSONValue,
MachineDescription,
OllamaFunctionSchema,
OpenAIFunctionSchema,
PendingFinalDecisionState,
PendingServiceSelectionState,
ReadmeRequest,
SessionState,
SimplifiedServiceSchema,
TagDescription,
aggregate_ollama_function_schemas,
aggregate_openai_function_schemas,
create_get_readme_tool,
create_select_service_tool,
create_simplified_service_schemas,
)
# Re-export service functions
from .service import create_llm_model, run_llm_service # noqa: F401
# Re-export utility functions and constants
from .utils import ( # noqa: F401
ASSISTANT_MODE_DISCOVERY,
ASSISTANT_MODE_FINAL,
ASSISTANT_MODE_SELECTION,
)

View File

@@ -3,12 +3,13 @@ from collections.abc import Callable
import pytest
from clan_cli.tests.fixtures_flakes import nested_dict
from clan_lib.flake.flake import Flake
from clan_lib.llm.llm import (
from clan_lib.llm.phases import llm_final_decision_to_inventory_instances
from clan_lib.llm.schemas import (
FunctionCallType,
OpenAIFunctionSchema,
aggregate_openai_function_schemas,
llm_final_decision_to_inventory_instances,
clan_module_to_openai_spec,
)
from clan_lib.llm.schemas import FunctionCallType, clan_module_to_openai_spec
from clan_lib.services.modules import list_service_modules

View File

@@ -1,57 +1,28 @@
"""Type definitions and dataclasses for LLM orchestration."""
from collections.abc import Callable
from dataclasses import dataclass
from typing import Literal
from typing import Any, Literal, TypedDict
from clan_lib.nix_models.clan import InventoryInstance
from .schemas import ChatMessage, SessionState
@dataclass(frozen=True)
class DiscoveryProgressEvent:
"""Progress event during discovery phase."""
class NextAction(TypedDict):
"""Describes the next expensive operation that will be performed.
service_names: list[str] | None = None
stage: Literal["discovery"] = "discovery"
status: Literal["analyzing", "complete"] = "analyzing"
Attributes:
type: The type of operation (discovery, fetch_readmes, service_selection, final_decision)
description: Human-readable description of what will happen
estimated_duration_seconds: Rough estimate of operation duration
details: Phase-specific information (e.g., service names, count)
"""
@dataclass(frozen=True)
class ReadmeFetchProgressEvent:
"""Progress event during readme fetching."""
count: int
service_names: list[str]
stage: Literal["readme_fetch"] = "readme_fetch"
status: Literal["fetching", "complete"] = "fetching"
@dataclass(frozen=True)
class ServiceSelectionProgressEvent:
"""Progress event during service selection phase."""
service_names: list[str]
stage: Literal["service_selection"] = "service_selection"
status: Literal["selecting", "complete"] = "selecting"
@dataclass(frozen=True)
class FinalDecisionProgressEvent:
"""Progress event during final decision phase."""
stage: Literal["final_decision"] = "final_decision"
status: Literal["reviewing", "complete"] = "reviewing"
@dataclass(frozen=True)
class ConversationProgressEvent:
"""Progress event for conversation continuation."""
message: str
stage: Literal["conversation"] = "conversation"
awaiting_response: bool = True
type: Literal["discovery", "fetch_readmes", "service_selection", "final_decision"]
description: str
estimated_duration_seconds: int
details: dict[str, Any]
@dataclass(frozen=True)
@@ -70,17 +41,6 @@ class ServiceSelectionResult:
clarifying_message: str
ProgressEvent = (
DiscoveryProgressEvent
| ReadmeFetchProgressEvent
| ServiceSelectionProgressEvent
| FinalDecisionProgressEvent
| ConversationProgressEvent
)
ProgressCallback = Callable[[ProgressEvent], None]
@dataclass(frozen=True)
class ChatResult:
"""Result of a complete chat turn through the multi-stage workflow.
@@ -92,6 +52,7 @@ class ChatResult:
requires_user_response: True if the assistant asked a question and needs a response
error: Error message if something went wrong (None on success)
session_state: Serializable state to pass into the next turn when continuing a workflow
next_action: Description of the next operation to be performed (None if workflow complete)
"""
@@ -100,6 +61,7 @@ class ChatResult:
assistant_message: str
requires_user_response: bool
session_state: SessionState
next_action: NextAction | None
error: str | None = None
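For orientation, here is a minimal caller-side sketch (not part of this diff; the helper name is hypothetical) of how the new next_action field on ChatResult can be surfaced to a user before re-invoking get_llm_turn with execute_next_action=True. The dictionary keys follow the NextAction TypedDict added above:
from clan_lib.llm.llm_types import ChatResult  # added in this commit
def confirm_next_action(result: ChatResult) -> bool:
    """Hypothetical helper: ask the user to confirm the next expensive step."""
    action = result.next_action  # NextAction | None
    if action is None:
        return False  # workflow finished or waiting on a user reply
    prompt = (
        f"{action['description']} "
        f"(~{action['estimated_duration_seconds']}s). Continue? [y/N] "
    )
    return input(prompt).strip().lower() == "y"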

View File

@@ -4,18 +4,12 @@ import json
from pathlib import Path
from typing import Literal, cast
from clan_lib.api import API
from clan_lib.errors import ClanAiError
from clan_lib.flake.flake import Flake
from clan_lib.services.modules import InputName, ServiceReadmeCollection
from .llm_types import (
ChatResult,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ProgressCallback,
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
)
from .llm_types import ChatResult, NextAction
from .phases import (
execute_readme_requests,
get_llm_discovery_phase,
@@ -26,8 +20,11 @@ from .phases import (
from .schemas import (
ConversationHistory,
JSONValue,
PendingDiscoveryState,
PendingFinalDecisionState,
PendingReadmeFetchState,
PendingServiceSelectionState,
ReadmeRequest,
SessionState,
)
from .utils import (
@@ -41,45 +38,50 @@ from .utils import (
)
def process_chat_turn(
@API.register
def get_llm_turn(
user_request: str,
flake: Flake,
conversation_history: ConversationHistory | None = None,
provider: Literal["openai", "ollama", "claude"] = "ollama",
progress_callback: ProgressCallback | None = None,
trace_file: Path | None = None,
session_state: SessionState | None = None,
execute_next_action: bool = False,
) -> ChatResult:
"""High-level API that orchestrates the entire multi-stage chat workflow.
This function handles the complete flow:
This function handles the complete flow using a multi-turn approach:
1. Discovery phase - LLM selects relevant services
2. Readme fetching - Retrieves detailed documentation
3. Final decision - LLM makes informed suggestions
4. Conversion - Transforms suggestions to inventory instances
Before each expensive operation, the function returns with next_action describing
what will happen. The caller must call again with execute_next_action=True.
Args:
user_request: The user's message/request
flake: The Flake object to get services from
conversation_history: Optional list of prior messages in the conversation
provider: The LLM provider to use
progress_callback: Optional callback for progress updates
trace_file: Optional path to write LLM interaction traces for debugging
session_state: Optional cross-turn state to resume pending workflows
execute_next_action: If True, execute the pending operation in session_state
Returns:
ChatResult containing proposed instances, updated history, and assistant message
ChatResult containing proposed instances, updated history, next_action, and assistant message
Example:
>>> result = process_chat_turn(
... "Set up a web server",
... flake,
... progress_callback=lambda event: print(f"Stage: {event.stage}")
... )
>>> if result.proposed_instances:
... print("LLM suggested:", result.proposed_instances)
>>> if result.requires_user_response:
... print("Assistant asks:", result.assistant_message)
>>> result = process_chat_turn("Set up a web server", flake)
>>> while result.next_action:
... # Show user what will happen
... print(result.next_action["description"])
... result = process_chat_turn(
... user_request="",
... flake=flake,
... session_state=result.session_state,
... execute_next_action=True
... )
"""
history = list(conversation_history) if conversation_history else []
@@ -87,6 +89,11 @@ def process_chat_turn(
"SessionState", dict(session_state) if session_state else {}
)
# Add non-empty user message to history for conversation tracking
# Phases will also add it to their messages array for LLM calls
if user_request:
history.append(_user_message(user_request))
def _state_snapshot() -> dict[str, JSONValue]:
try:
return json.loads(json.dumps(state))
@@ -102,6 +109,17 @@ def process_chat_turn(
def _state_copy() -> SessionState:
return cast("SessionState", dict(state))
# Check pending states in workflow order (earliest to latest)
pending_discovery_raw = state.get("pending_discovery")
pending_discovery: PendingDiscoveryState | None = (
pending_discovery_raw if isinstance(pending_discovery_raw, dict) else None
)
pending_readme_fetch_raw = state.get("pending_readme_fetch")
pending_readme_fetch: PendingReadmeFetchState | None = (
pending_readme_fetch_raw if isinstance(pending_readme_fetch_raw, dict) else None
)
pending_final_raw = state.get("pending_final_decision")
pending_final: PendingFinalDecisionState | None = (
pending_final_raw if isinstance(pending_final_raw, dict) else None
@@ -117,87 +135,139 @@ def process_chat_turn(
if serialized_results is not None:
resume_readme_results = _deserialize_readme_results(serialized_results)
# Only pop if we can't deserialize (invalid state)
if resume_readme_results is None:
state.pop("pending_service_selection", None)
else:
state.pop("pending_service_selection", None)
# Handle pending_discovery state: execute discovery if execute_next_action=True
if pending_discovery is not None and execute_next_action:
state.pop("pending_discovery", None)
# Continue to execute discovery below (after pending state checks)
# Handle pending_readme_fetch state: execute readme fetch if execute_next_action=True
if pending_readme_fetch is not None and execute_next_action:
readme_requests_raw = pending_readme_fetch.get("readme_requests", [])
readme_requests = cast("list[ReadmeRequest]", readme_requests_raw)
if readme_requests:
state.pop("pending_readme_fetch", None)
readme_results = execute_readme_requests(readme_requests, flake)
# Save readme results and return next_action for service selection
state["pending_service_selection"] = cast(
"PendingServiceSelectionState",
{"readme_results": _serialize_readme_results(readme_results)},
)
service_count = len(readme_results)
next_action_selection: NextAction = {
"type": "service_selection",
"description": f"Analyzing {service_count} service(s) to find the best match",
"estimated_duration_seconds": 15,
"details": {"service_count": service_count},
}
return ChatResult(
next_action=next_action_selection,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
state.pop("pending_readme_fetch", None)
if pending_final is not None:
service_name = pending_final.get("service_name")
service_summary = pending_final.get("service_summary")
if isinstance(service_name, str) and isinstance(service_summary, str):
if progress_callback:
progress_callback(FinalDecisionProgressEvent(status="reviewing"))
function_calls, final_message = get_llm_final_decision(
user_request,
flake,
service_name,
service_summary,
conversation_history,
provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{
"selected_service": service_name,
"resume": True,
}
),
)
if progress_callback:
progress_callback(FinalDecisionProgressEvent(status="complete"))
history.append(_user_message(user_request))
if function_calls:
proposed_instances = llm_final_decision_to_inventory_instances(
function_calls
)
instance_names = [inst["module"]["name"] for inst in proposed_instances]
summary = (
f"I suggest configuring these services: {', '.join(instance_names)}"
)
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
if execute_next_action:
state.pop("pending_final_decision", None)
return ChatResult(
proposed_instances=tuple(proposed_instances),
conversation_history=tuple(history),
assistant_message=summary,
requires_user_response=False,
error=None,
session_state=_state_copy(),
function_calls, final_message = get_llm_final_decision(
user_request,
flake,
service_name,
service_summary,
conversation_history,
provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{
"selected_service": service_name,
"resume": True,
}
),
)
if final_message:
history.append(
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)
)
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": service_name,
"service_summary": service_summary,
},
if function_calls:
proposed_instances = llm_final_decision_to_inventory_instances(
function_calls
)
instance_names = [
inst["module"]["name"] for inst in proposed_instances
]
summary = f"I suggest configuring these services: {', '.join(instance_names)}"
history.append(
_assistant_message(summary, mode=ASSISTANT_MODE_FINAL)
)
return ChatResult(
next_action=None,
proposed_instances=tuple(proposed_instances),
conversation_history=tuple(history),
assistant_message=summary,
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
if final_message:
history.append(
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)
)
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": service_name,
"service_summary": service_summary,
},
)
return ChatResult(
next_action=None,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=final_message,
requires_user_response=True,
error=None,
session_state=_state_copy(),
)
state.pop("pending_final_decision", None)
msg = "LLM did not provide any response or recommendations"
raise ClanAiError(
msg,
description="Expected either function calls (configuration) or a clarifying message",
location="Final Decision Phase (pending)",
)
return ChatResult(
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=final_message,
requires_user_response=True,
error=None,
session_state=_state_copy(),
)
state.pop("pending_final_decision", None)
msg = "LLM did not provide any response or recommendations"
raise ClanAiError(
msg,
description="Expected either function calls (configuration) or a clarifying message",
location="Final Decision Phase (pending)",
# If not executing, return next_action for final decision
next_action_final_pending: NextAction = {
"type": "final_decision",
"description": f"Generating configuration for {service_name}",
"estimated_duration_seconds": 20,
"details": {"service_name": service_name},
}
return ChatResult(
next_action=next_action_final_pending,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
state.pop("pending_final_decision", None)
@@ -206,19 +276,12 @@ def process_chat_turn(
readme_results: dict[InputName, ServiceReadmeCollection],
) -> ChatResult:
# Extract all service names from readme results
all_service_names = [
[
service_name
for collection in readme_results.values()
for service_name in collection.readmes
]
if progress_callback:
progress_callback(
ServiceSelectionProgressEvent(
service_names=all_service_names, status="selecting"
)
)
selection_result = get_llm_service_selection(
user_request,
readme_results,
@@ -232,7 +295,6 @@ def process_chat_turn(
selection_result.clarifying_message
and not selection_result.selected_service
):
history.append(_user_message(user_request))
history.append(
_assistant_message(
selection_result.clarifying_message,
@@ -247,6 +309,7 @@ def process_chat_turn(
)
return ChatResult(
next_action=None,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=selection_result.clarifying_message,
@@ -267,81 +330,79 @@ def process_chat_turn(
location="Service Selection Phase",
)
if progress_callback:
progress_callback(FinalDecisionProgressEvent(status="reviewing"))
function_calls, final_message = get_llm_final_decision(
user_request,
flake,
selection_result.selected_service,
selection_result.service_summary,
conversation_history,
provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{"selected_service": selection_result.selected_service}
),
# After service selection, always return next_action for final decision
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": selection_result.selected_service,
"service_summary": selection_result.service_summary,
},
)
if progress_callback:
progress_callback(FinalDecisionProgressEvent(status="complete"))
if function_calls:
history.append(_user_message(user_request))
proposed_instances = llm_final_decision_to_inventory_instances(
function_calls
)
instance_names = [inst["module"]["name"] for inst in proposed_instances]
summary = (
f"I suggest configuring these services: {', '.join(instance_names)}"
)
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
state.pop("pending_final_decision", None)
return ChatResult(
proposed_instances=tuple(proposed_instances),
conversation_history=tuple(history),
assistant_message=summary,
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
if final_message:
history.append(_user_message(user_request))
history.append(_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL))
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": selection_result.selected_service,
"service_summary": selection_result.service_summary,
},
)
return ChatResult(
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=final_message,
requires_user_response=True,
error=None,
session_state=_state_copy(),
)
msg = "LLM did not provide any response or recommendations"
raise ClanAiError(
msg,
description="Expected either function calls (configuration) or a clarifying message after service selection",
location="Final Decision Phase",
next_action_final: NextAction = {
"type": "final_decision",
"description": f"Generating configuration for {selection_result.selected_service}",
"estimated_duration_seconds": 20,
"details": {"service_name": selection_result.selected_service},
}
return ChatResult(
next_action=next_action_final,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
if resume_readme_results is not None:
return _continue_with_service_selection(resume_readme_results)
if execute_next_action:
# Pop the pending state now that we're executing
state.pop("pending_service_selection", None)
return _continue_with_service_selection(resume_readme_results)
# If not executing, return next_action for service selection
service_count = len(resume_readme_results)
next_action_sel_resume: NextAction = {
"type": "service_selection",
"description": f"Analyzing {service_count} service(s) to find the best match",
"estimated_duration_seconds": 15,
"details": {"service_count": service_count},
}
return ChatResult(
next_action=next_action_sel_resume,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
# Stage 1: Discovery phase
if progress_callback:
progress_callback(DiscoveryProgressEvent(status="analyzing"))
# If we're not executing and have no pending states, return next_action for discovery
has_pending_states = (
pending_discovery is not None
or pending_readme_fetch is not None
or pending_selection is not None
or pending_final is not None
)
if not execute_next_action and not has_pending_states:
state["pending_discovery"] = cast("PendingDiscoveryState", {})
next_action: NextAction = {
"type": "discovery",
"description": "Analyzing your request and discovering relevant services",
"estimated_duration_seconds": 10,
"details": {"phase": "discovery"},
}
return ChatResult(
next_action=next_action,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
readme_requests, discovery_message = get_llm_discovery_phase(
user_request,
@@ -352,23 +413,14 @@ def process_chat_turn(
trace_metadata=_metadata(),
)
if progress_callback:
selected_services = [req["function_name"] for req in readme_requests]
progress_callback(
DiscoveryProgressEvent(
service_names=selected_services if selected_services else None,
status="complete",
)
)
# If LLM asked a question or made a recommendation without readme requests
if discovery_message and not readme_requests:
history.append(_user_message(user_request))
history.append(
_assistant_message(discovery_message, mode=ASSISTANT_MODE_DISCOVERY)
)
return ChatResult(
next_action=None,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=discovery_message,
@@ -377,34 +429,28 @@ def process_chat_turn(
session_state=_state_copy(),
)
# If we got readme requests, continue to selecting services
# If we got readme requests, save them and return next_action for readme fetch
if readme_requests:
# Stage 2: Fetch readmes
service_names = [
f"{req['function_name']} (from {req['input_name'] or 'built-in'})"
for req in readme_requests
]
if progress_callback:
progress_callback(
ReadmeFetchProgressEvent(
count=len(readme_requests),
service_names=service_names,
status="fetching",
)
)
readme_results = execute_readme_requests(readme_requests, flake)
if progress_callback:
progress_callback(
ReadmeFetchProgressEvent(
count=len(readme_requests),
service_names=service_names,
status="complete",
)
)
return _continue_with_service_selection(readme_results)
state["pending_readme_fetch"] = cast(
"PendingReadmeFetchState",
{"readme_requests": cast("list[dict[str, JSONValue]]", readme_requests)},
)
service_count = len(readme_requests)
next_action_fetch: NextAction = {
"type": "fetch_readmes",
"description": f"Fetching documentation for {service_count} service(s)",
"estimated_duration_seconds": 5,
"details": {"service_count": service_count},
}
return ChatResult(
next_action=next_action_fetch,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
# No readme requests and no message - unexpected
msg = "LLM did not provide any response or recommendations"

View File

@@ -85,7 +85,8 @@ def get_llm_discovery_phase(
{"role": "assistant", "content": assistant_context},
]
messages.extend(_strip_conversation_metadata(conversation_history))
messages.append(_user_message(user_request))
if user_request:
messages.append(_user_message(user_request))
# Call LLM with only get_readme tool
model_config = get_model_config(provider)
@@ -233,7 +234,8 @@ def get_llm_service_selection(
{"role": "assistant", "content": combined_assistant_context},
]
messages.extend(_strip_conversation_metadata(conversation_history))
messages.append(_user_message(user_request))
if user_request:
messages.append(_user_message(user_request))
model_config = get_model_config(provider)
@@ -430,8 +432,8 @@ def get_llm_final_decision(
{"role": "assistant", "content": combined_assistant_context},
]
messages.extend(_strip_conversation_metadata(conversation_history))
messages.append(_user_message(user_request))
if user_request:
messages.append(_user_message(user_request))
# Get full schemas
model_config = get_model_config(provider)

View File

@@ -66,18 +66,28 @@ class ChatMessage(TypedDict):
ConversationHistory = list[ChatMessage]
class PendingFinalDecisionState(TypedDict, total=False):
service_name: NotRequired[str]
service_summary: NotRequired[str]
class PendingDiscoveryState(TypedDict, total=False):
user_request: NotRequired[str]
class PendingReadmeFetchState(TypedDict, total=False):
readme_requests: NotRequired[list[dict[str, JSONValue]]]
class PendingServiceSelectionState(TypedDict, total=False):
readme_results: NotRequired[list[dict[str, JSONValue]]]
class PendingFinalDecisionState(TypedDict, total=False):
service_name: NotRequired[str]
service_summary: NotRequired[str]
class SessionState(TypedDict, total=False):
pending_final_decision: NotRequired[PendingFinalDecisionState]
pending_discovery: NotRequired[PendingDiscoveryState]
pending_readme_fetch: NotRequired[PendingReadmeFetchState]
pending_service_selection: NotRequired[PendingServiceSelectionState]
pending_final_decision: NotRequired[PendingFinalDecisionState]
class JSONSchemaProperty(TypedDict, total=False):
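As a concrete illustration (example values only, mirroring the unit tests further down), a session_state blob handed back to the caller between turns carries the pending phase to resume. For instance, right before the final decision step it might look like this:
from clan_lib.llm.schemas import SessionState  # defined in this file
# Example session_state awaiting the final decision phase; pass it back to
# get_llm_turn together with execute_next_action=True to generate the config.
example_state: SessionState = {
    "pending_final_decision": {
        "service_name": "zerotier",
        "service_summary": "ZeroTier mesh VPN",
    }
}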

View File

@@ -16,16 +16,12 @@ from clan_lib.llm.endpoints import (
parse_ollama_response,
parse_openai_response,
)
from clan_lib.llm.llm import (
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
ServiceSelectionResult,
from clan_lib.llm.llm_types import ServiceSelectionResult
from clan_lib.llm.orchestrator import get_llm_turn
from clan_lib.llm.phases import (
execute_readme_requests,
get_llm_final_decision,
get_llm_service_selection,
process_chat_turn,
)
from clan_lib.llm.schemas import (
AiAggregate,
@@ -37,9 +33,44 @@ from clan_lib.llm.schemas import (
from clan_lib.services.modules import ServiceReadmeCollection
if TYPE_CHECKING:
from clan_lib.llm.llm_types import ChatResult
from clan_lib.llm.schemas import ChatMessage
def execute_multi_turn_workflow(
user_request: str,
flake: Flake,
conversation_history: list["ChatMessage"] | None = None,
provider: str = "claude",
session_state: SessionState | None = None,
) -> "ChatResult":
"""Execute the multi-turn workflow, auto-executing all pending operations.
This simulates the behavior of the CLI auto-execute loop in workflow.py.
"""
result = get_llm_turn(
user_request=user_request,
flake=flake,
conversation_history=conversation_history,
provider=provider, # type: ignore[arg-type]
session_state=session_state,
execute_next_action=False,
)
# Auto-execute any pending operations
while result.next_action:
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
)
return result
@pytest.fixture
def trace_data() -> list[dict[str, Any]]:
"""Load trace data from mytrace.json."""
@@ -181,17 +212,23 @@ class TestProcessChatTurn:
# Mock final decision (shouldn't be called, but mock it anyway for safety)
mock_final.return_value = ([], "")
# Run process_chat_turn
result = process_chat_turn(
# Run multi-turn workflow
result = execute_multi_turn_workflow(
user_request="What VPNs are available?",
flake=mock_flake,
conversation_history=None,
provider="claude",
)
# Verify the call was made
# Verify the discovery call was made
assert mock_call.called
# Verify readme execution was called
assert mock_execute.called
# Verify service selection was called
assert mock_selection.called
# Final decision should NOT be called since we return early with a clarifying message
assert not mock_final.called
@@ -262,8 +299,8 @@ class TestProcessChatTurn:
final_trace["response"]["message"],
)
# Run process_chat_turn with session state
result = process_chat_turn(
# Run multi-turn workflow with session state
result = execute_multi_turn_workflow(
user_request="Hmm zerotier please",
flake=mock_flake,
conversation_history=conversation_history,
@@ -271,6 +308,12 @@ class TestProcessChatTurn:
session_state=session_state,
)
# Verify service selection was called
assert mock_selection.called
# Verify final decision was called
assert mock_final.called
# Verify the result
assert result.requires_user_response is True
assert "controller" in result.assistant_message.lower()
@@ -335,8 +378,8 @@ class TestProcessChatTurn:
"",
)
# Run process_chat_turn
result = process_chat_turn(
# Run multi-turn workflow
result = execute_multi_turn_workflow(
user_request="okay then gchq-local as controller and qube-email as moon please everything else as peer",
flake=mock_flake,
conversation_history=conversation_history,
@@ -344,6 +387,9 @@ class TestProcessChatTurn:
session_state=session_state,
)
# Verify final decision was called
assert mock_final.called
# Verify the result
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
@@ -394,19 +440,19 @@ class TestProcessChatTurn:
)
mock_final.return_value = ([], "")
result1 = process_chat_turn(
result1 = execute_multi_turn_workflow(
user_request="What VPNs are available?",
flake=mock_flake,
provider="claude",
)
# Verify final decision was not called
# Verify final decision was not called (since we get clarifying message)
assert not mock_final.called
# Verify discovery completed and moved to service selection
# Verify discovery completed and service selection asked a clarifying question
assert result1.requires_user_response is True
assert "VPN" in result1.assistant_message
# Session state should have pending_service_selection
# Session state should have pending_service_selection (with readme results)
assert "pending_service_selection" in result1.session_state
# Test Turn 2: Continue with session state
@@ -425,7 +471,7 @@ class TestProcessChatTurn:
)
mock_final.return_value = ([], trace_data[3]["response"]["message"])
result2 = process_chat_turn(
result2 = execute_multi_turn_workflow(
user_request="Hmm zerotier please",
flake=mock_flake,
conversation_history=list(result1.conversation_history),
@@ -485,7 +531,7 @@ class TestProcessChatTurn:
# Return empty function_calls but with a clarifying message
mock_final.return_value = ([], clarify_trace["response"]["message"])
result = process_chat_turn(
result = execute_multi_turn_workflow(
user_request="Set up zerotier with gchq-local as controller",
flake=mock_flake,
conversation_history=conversation_history,
@@ -535,7 +581,7 @@ class TestProcessChatTurn:
]
mock_final.return_value = ([], "")
result = process_chat_turn(
result = execute_multi_turn_workflow(
user_request="I want to set up a VPN",
flake=mock_flake,
provider="claude",
@@ -624,7 +670,7 @@ class TestProcessChatTurn:
]
)
result = process_chat_turn(
result = execute_multi_turn_workflow(
user_request="Use zerotier with gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake,
conversation_history=conversation_history,
@@ -632,6 +678,12 @@ class TestProcessChatTurn:
session_state=session_state,
)
# Verify service selection was called
assert mock_selection.called
# Verify final decision was called
assert mock_final.called
# Verify the function_calls branch in _continue_with_service_selection
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
@@ -1004,15 +1056,16 @@ class TestProcessChatTurnPendingFinalDecision:
response = create_openai_response([], clarify_trace["response"]["message"])
mock_call.return_value = response
result = process_chat_turn(
result = get_llm_turn(
user_request="gchq-local as controller",
flake=mock_flake,
conversation_history=conversation_history,
provider="claude",
session_state=session_state,
execute_next_action=True, # Execute the pending final decision
)
# Verify the if final_message branch at line 425 was taken
# Verify the if final_message branch was taken
assert result.requires_user_response is True
assert result.assistant_message == clarify_trace["response"]["message"]
@@ -1074,12 +1127,13 @@ class TestProcessChatTurnPendingFinalDecision:
response = create_openai_response(function_calls, "")
mock_call.return_value = response
result = process_chat_turn(
result = get_llm_turn(
user_request="gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake,
conversation_history=conversation_history,
provider="claude",
session_state=session_state,
execute_next_action=True, # Execute the pending final decision
)
# Verify configuration completed
@@ -1094,186 +1148,6 @@ class TestProcessChatTurnPendingFinalDecision:
assert result.error is None
class TestProgressCallbacks:
"""Test progress_callback functionality in process_chat_turn."""
def test_progress_callback_during_readme_fetch(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test that progress_callback is called during README fetching."""
# Use trace entry with README requests
discovery_trace = trace_data[0]
function_calls = discovery_trace["response"]["function_calls"]
assert len(function_calls) > 0
# Track progress events
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Create response with get_readme calls
response = create_openai_response(function_calls, "")
with (
patch("clan_lib.llm.phases.call_claude_api", return_value=response),
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
patch(
"clan_lib.llm.orchestrator.get_llm_service_selection"
) as mock_selection,
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
):
mock_execute.return_value = {
None: ServiceReadmeCollection(
input_name=None,
readmes={
"wireguard": "# WireGuard README",
"zerotier": "# ZeroTier README",
"mycelium": "# Mycelium README",
"yggdrasil": "# Yggdrasil README",
},
)
}
mock_selection.return_value = ServiceSelectionResult(
selected_service=None,
service_summary=None,
clarifying_message=trace_data[1]["response"]["message"],
)
mock_final.return_value = ([], "")
result = process_chat_turn(
user_request="What VPNs are available?",
flake=mock_flake,
provider="claude",
progress_callback=track_progress,
)
# Verify final decision was not called
assert not mock_final.called
# Verify progress events were sent
assert len(progress_events) > 0
# Check for discovery progress events
discovery_events = [
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
]
assert len(discovery_events) >= 2 # At least start and complete
# Check for readme fetch progress events
fetch_events = [
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
]
assert len(fetch_events) >= 2 # fetching and complete
# Verify the fetching event has correct data
fetching_event = next(e for e in fetch_events if e.status == "fetching")
assert fetching_event.count == len(function_calls)
# Service names include "(from built-in)" or "(from <input>)" suffix
assert any("wireguard" in name for name in fetching_event.service_names)
# Verify the complete event
complete_event = next(e for e in fetch_events if e.status == "complete")
assert complete_event.count == len(function_calls)
# Result should still be successful
assert result.requires_user_response is True
def test_progress_callback_through_full_workflow(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test progress_callback through entire workflow from discovery to config."""
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Setup for full workflow
discovery_response = create_openai_response(
trace_data[0]["response"]["function_calls"],
trace_data[0]["response"]["message"],
)
with (
patch(
"clan_lib.llm.phases.call_claude_api", return_value=discovery_response
),
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
patch(
"clan_lib.llm.orchestrator.get_llm_service_selection"
) as mock_selection,
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
):
mock_execute.return_value = {
None: ServiceReadmeCollection(
input_name=None, readmes={"zerotier": "# ZeroTier README"}
)
}
mock_selection.return_value = ServiceSelectionResult(
selected_service="zerotier",
service_summary="ZeroTier mesh VPN",
clarifying_message="",
)
# Return configuration
final_trace = trace_data[-1]
mock_final.return_value = (
[
FunctionCallType(
id="call_0",
call_id="call_0",
type="function_call",
name="zerotier",
arguments=json.dumps(
final_trace["response"]["function_calls"][0]["arguments"]
),
)
],
"",
)
mock_agg.return_value = MagicMock(
tools=[
{
"type": "function",
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
}
]
)
result = process_chat_turn(
user_request="Setup zerotier with gchq-local as controller",
flake=mock_flake,
provider="claude",
progress_callback=track_progress,
)
# Verify we got progress events for all phases
discovery_events = [
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
]
fetch_events = [
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
]
selection_events = [
e
for e in progress_events
if isinstance(e, ServiceSelectionProgressEvent)
]
final_events = [
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
]
# Should have events from all phases
assert len(discovery_events) > 0
assert len(fetch_events) > 0
assert len(selection_events) > 0
assert len(final_events) > 0
# Result should be successful with config
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
class TestErrorCases:
"""Test error handling in process_chat_turn."""
@@ -1288,7 +1162,8 @@ class TestErrorCases:
patch("clan_lib.llm.phases.call_claude_api", return_value=response),
pytest.raises(ClanAiError, match="did not provide any response"),
):
process_chat_turn(
# Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup a VPN",
flake=mock_flake,
provider="claude",
@@ -1304,7 +1179,8 @@ class TestErrorCases:
),
pytest.raises(ValueError, match="Test error"),
):
process_chat_turn(
# Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup a VPN",
flake=mock_flake,
provider="claude",
@@ -1326,86 +1202,14 @@ class TestErrorCases:
),
pytest.raises(RuntimeError, match="Network error"),
):
process_chat_turn(
# Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup zerotier",
flake=mock_flake,
conversation_history=conversation_history,
provider="claude",
)
def test_progress_callback_final_decision_reviewing_and_complete(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test FinalDecisionProgressEvent with reviewing and complete statuses."""
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Build conversation history and session state for pending_final_decision
conversation_history: list[ChatMessage] = [
{"role": "user", "content": "Setup VPN"},
{"role": "assistant", "content": "Which service?"},
{"role": "user", "content": "Use zerotier"},
{"role": "assistant", "content": "Which machine as controller?"},
]
session_state: SessionState = cast(
"SessionState",
{
"pending_final_decision": {
"service_name": "zerotier",
"service_summary": "ZeroTier mesh VPN",
}
},
)
# Use final trace with configuration
final_trace = trace_data[-1]
function_calls = final_trace["response"]["function_calls"]
with (
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
patch("clan_lib.llm.phases.call_claude_api") as mock_call,
):
mock_agg.return_value = MagicMock(
tools=[
{
"type": "function",
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
}
]
)
response = create_openai_response(function_calls, "")
mock_call.return_value = response
result = process_chat_turn(
user_request="gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake,
conversation_history=conversation_history,
provider="claude",
session_state=session_state,
progress_callback=track_progress,
)
# Verify we got FinalDecisionProgressEvent with both statuses
final_events = [
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
]
assert len(final_events) >= 2
# Check for "reviewing" status
reviewing_events = [e for e in final_events if e.status == "reviewing"]
assert len(reviewing_events) >= 1
# Check for "complete" status
complete_events = [e for e in final_events if e.status == "complete"]
assert len(complete_events) >= 1
# Result should be successful
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
def test_service_selection_fails_no_service_selected(
self, mock_flake: MagicMock
) -> None:
@@ -1443,7 +1247,8 @@ class TestErrorCases:
# Should raise ClanAiError
with pytest.raises(ClanAiError, match="Failed to select service"):
process_chat_turn(
# Use multi-turn workflow to execute through service selection
execute_multi_turn_workflow(
user_request="Setup VPN",
flake=mock_flake,
provider="claude",
@@ -1680,7 +1485,8 @@ class TestGetLlmFinalDecisionErrors:
# Should raise ClanAiError
with pytest.raises(ClanAiError, match="LLM did not provide any response"):
process_chat_turn(
# Use multi-turn workflow to execute through final decision
execute_multi_turn_workflow(
user_request="gchq-local as controller",
flake=mock_flake,
conversation_history=conversation_history,