From c3456c1f0c80bf8162ba88cc20feb6f346df5fc4 Mon Sep 17 00:00:00 2001 From: Qubasa Date: Fri, 24 Oct 2025 15:57:26 +0200 Subject: [PATCH] clan_lib/llm: get_llm_turn uses state transitions instead of callback function --- checks/llm/default.nix | 1 - pkgs/clan-cli/clan_lib/llm/api.py | 200 -------- pkgs/clan-cli/clan_lib/llm/container_test.py | 193 ++++++-- pkgs/clan-cli/clan_lib/llm/llm.py | 65 --- pkgs/clan-cli/clan_lib/llm/llm_test.py | 7 +- pkgs/clan-cli/clan_lib/llm/llm_types.py | 68 +-- pkgs/clan-cli/clan_lib/llm/orchestrator.py | 454 ++++++++++-------- pkgs/clan-cli/clan_lib/llm/phases.py | 10 +- pkgs/clan-cli/clan_lib/llm/schemas.py | 18 +- .../clan_lib/llm/test_process_chat_turn.py | 372 ++++---------- 10 files changed, 529 insertions(+), 859 deletions(-) delete mode 100644 pkgs/clan-cli/clan_lib/llm/api.py delete mode 100644 pkgs/clan-cli/clan_lib/llm/llm.py diff --git a/checks/llm/default.nix b/checks/llm/default.nix index c79beb0db..0c0e143fc 100644 --- a/checks/llm/default.nix +++ b/checks/llm/default.nix @@ -76,7 +76,6 @@ in cmd = "su - text-user -c 'pytest -s -n0 -m service_runner -p no:cacheprovider -o addopts="" ${cli.passthru.sourceWithTests}/clan_lib/llm'" print("Running tests with command: " + cmd) - # Run tests as text-user (environment variables are set automatically) peer1.succeed(cmd) ''; diff --git a/pkgs/clan-cli/clan_lib/llm/api.py b/pkgs/clan-cli/clan_lib/llm/api.py deleted file mode 100644 index a0c150adc..000000000 --- a/pkgs/clan-cli/clan_lib/llm/api.py +++ /dev/null @@ -1,200 +0,0 @@ -"""High-level API functions for LLM interactions, suitable for HTTP APIs and web UIs. - -This module provides a clean, stateless API for integrating LLM functionality into -web applications and HTTP services. It wraps the complex multi-stage workflow into -simple function calls with serializable inputs and outputs. -""" - -from pathlib import Path -from typing import Any, Literal, TypedDict, cast - -from clan_lib.api import API -from clan_lib.flake.flake import Flake - -from .llm import ( - DEFAULT_MODELS, - ChatResult, - DiscoveryProgressEvent, - FinalDecisionProgressEvent, - ModelConfig, - ProgressCallback, - ProgressEvent, - ReadmeFetchProgressEvent, - get_model_config, - process_chat_turn, -) -from .schemas import ChatMessage, ConversationHistory, SessionState - - -class ChatTurnRequest(TypedDict, total=False): - """Request payload for a chat turn. - - Attributes: - user_message: The user's message/request - conversation_history: Optional list of prior messages in the conversation - provider: The LLM provider to use (default: "claude") - trace_file: Optional path to write LLM interaction traces for debugging - session_state: Opaque state returned from the previous turn - - """ - - user_message: str - conversation_history: ConversationHistory | None - provider: Literal["openai", "ollama", "claude"] - trace_file: Path | None - session_state: SessionState | None - - -class ChatTurnResponse(TypedDict): - """Response payload for a chat turn. 
- - Attributes: - proposed_instances: List of inventory instances suggested by the LLM - conversation_history: Updated conversation history after this turn - assistant_message: Message from the assistant - requires_user_response: Whether the assistant is waiting for user input - error: Error message if something went wrong (None on success) - session_state: State blob to pass into the next turn when continuing the workflow - - """ - - proposed_instances: list[dict[str, Any]] - conversation_history: list[ChatMessage] - assistant_message: str - requires_user_response: bool - error: str | None - session_state: SessionState - - -class ProgressEventResponse(TypedDict): - """Progress event for streaming updates. - - Attributes: - stage: The current stage of processing - status: The status within that stage (if applicable) - count: Count of items (for readme_fetch stage) - message: Message content (for conversation stage) - - """ - - stage: str - status: str | None - count: int | None - message: str | None - - -@API.register -def get_llm_turn( - flake: Flake, - request: ChatTurnRequest, - progress_callback: ProgressCallback | None = None, -) -> ChatTurnResponse: - """Process a single chat turn through the LLM workflow. - - This is the main entry point for HTTP APIs and web UIs to interact with - the LLM functionality. It handles: - - Service discovery - - Documentation fetching - - Final decision making - - Conversation management - - Args: - flake: The Flake object representing the clan configuration - request: The chat turn request containing user message and optional history - progress_callback: Optional callback for progress updates - - Returns: - ChatTurnResponse with proposed instances and conversation state - - Example: - >>> from clan_lib.flake.flake import Flake - >>> flake = Flake("/path/to/clan") - >>> request: ChatTurnRequest = { - ... "user_message": "Set up a web server", - ... "provider": "claude" - ... } - >>> response = chat_turn(flake, request) - >>> if response["proposed_instances"]: - ... print("LLM suggests:", response["proposed_instances"]) - >>> if response["requires_user_response"]: - ... print("Assistant asks:", response["assistant_message"]) - - """ - result: ChatResult = process_chat_turn( - user_request=request["user_message"], - flake=flake, - conversation_history=request.get("conversation_history"), - provider=request.get("provider", "claude"), - progress_callback=progress_callback, - trace_file=request.get("trace_file"), - session_state=request.get("session_state"), - ) - - # Convert frozen tuples to lists for JSON serialization - return ChatTurnResponse( - proposed_instances=[dict(inst) for inst in result.proposed_instances], - conversation_history=list(result.conversation_history), - assistant_message=result.assistant_message, - requires_user_response=result.requires_user_response, - error=result.error, - session_state=cast("SessionState", dict(result.session_state)), - ) - - -def progress_event_to_dict(event: ProgressEvent) -> ProgressEventResponse: - """Convert a ProgressEvent to a dictionary suitable for JSON serialization. - - This helper function is useful for streaming progress updates over HTTP - (e.g., Server-Sent Events or WebSockets). 
- - Args: - event: The progress event to convert - - Returns: - Dictionary representation of the event - - Example: - >>> from clan_lib.llm.llm import DiscoveryProgressEvent - >>> event = DiscoveryProgressEvent(status="analyzing") - >>> progress_event_to_dict(event) - {'stage': 'discovery', 'status': 'analyzing', 'count': None, 'message': None} - - """ - base_response: ProgressEventResponse = { - "stage": event.stage, - "status": None, - "count": None, - "message": None, - } - - if isinstance(event, (DiscoveryProgressEvent, FinalDecisionProgressEvent)): - base_response["status"] = event.status - elif isinstance(event, ReadmeFetchProgressEvent): - base_response["status"] = event.status - base_response["count"] = event.count - # ConversationProgressEvent has message field - elif hasattr(event, "message"): - base_response["message"] = event.message # type: ignore[attr-defined] - if hasattr(event, "awaiting_response"): - base_response["status"] = ( - "awaiting_response" - if event.awaiting_response # type: ignore[attr-defined] - else "complete" - ) - - return base_response - - -# Re-export types for convenience -__all__ = [ - "DEFAULT_MODELS", - "ChatTurnRequest", - "ChatTurnResponse", - "ModelConfig", - "ProgressCallback", - "ProgressEvent", - "ProgressEventResponse", - "get_llm_turn", - "get_model_config", - "progress_event_to_dict", -] diff --git a/pkgs/clan-cli/clan_lib/llm/container_test.py b/pkgs/clan-cli/clan_lib/llm/container_test.py index c9f9be420..2ce39695d 100644 --- a/pkgs/clan-cli/clan_lib/llm/container_test.py +++ b/pkgs/clan-cli/clan_lib/llm/container_test.py @@ -2,16 +2,86 @@ import contextlib import json from collections.abc import Iterator from pathlib import Path +from typing import TYPE_CHECKING from unittest.mock import MagicMock, patch import pytest from clan_lib.flake.flake import Flake -from clan_lib.llm.llm import ( - process_chat_turn, -) +from clan_lib.llm.orchestrator import get_llm_turn from clan_lib.llm.service import create_llm_model, run_llm_service from clan_lib.service_runner import create_service_manager +if TYPE_CHECKING: + from clan_lib.llm.llm_types import ChatResult + from clan_lib.llm.schemas import ChatMessage, SessionState + + +def get_current_mode(session_state: "SessionState") -> str: + """Extract the current mode from session state.""" + if "pending_service_selection" in session_state: + return "SERVICE_SELECTION" + if "pending_final_decision" in session_state: + return "FINAL_DECISION" + return "DISCOVERY" + + +def print_separator( + title: str = "", char: str = "=", width: int = 80, double: bool = True +) -> None: + """Print a separator line with optional title.""" + if double: + print(f"\n{char * width}") + if title: + padding = (width - len(title) - 2) // 2 + print(f"{char * padding} {title} {char * padding}") + if double or title: + print(f"{char * width}") + + +def print_meta_info(result: "ChatResult", turn: int, phase: str) -> None: + """Print meta information section in a structured format.""" + mode = get_current_mode(result.session_state) + print_separator("META INFORMATION", char="-", width=80, double=False) + print(f" Turn Number: {turn}") + print(f" Phase: {phase}") + print(f" Current Mode: {mode}") + print(f" Requires User Input: {result.requires_user_response}") + print(f" Conversation Length: {len(result.conversation_history)} messages") + print(f" Proposed Instances: {len(result.proposed_instances)}") + print(f" Has Next Action: {result.next_action is not None}") + print(f" Session State Keys: {list(result.session_state.keys())}") 
+ if result.error: + print(f" Error: {result.error}") + print("-" * 80) + + +def print_chat_exchange( + user_msg: str | None, assistant_msg: str, session_state: "SessionState" +) -> None: + """Print a chat exchange with role labels and current mode.""" + mode = get_current_mode(session_state) + print_separator("CHAT EXCHANGE", char="-", width=80, double=False) + + if user_msg: + print("\n USER:") + print(f" {user_msg}") + + print(f"\n ASSISTANT [{mode}]:") + # Wrap long messages + max_line_length = 76 + words = assistant_msg.split() + current_line = " " + for word in words: + if len(current_line) + len(word) + 1 > max_line_length: + print(current_line) + current_line = " " + word + else: + current_line += (" " if len(current_line) > 2 else "") + word + if current_line.strip(): + print(current_line) + + print("\n" + "-" * 80) + @pytest.fixture def mock_flake() -> MagicMock: @@ -47,9 +117,9 @@ def mock_flake() -> MagicMock: } match arg: - case "clanInternals.inventoryClass.inventory.{instances,machines,meta}": + case "clanInternals.inventoryClass.inventorySerialization.{instances,machines,meta}": return load_json("inventory_instances_machines_meta.json") - case "clanInternals.inventoryClass.inventory.{tags}": + case "clanInternals.inventoryClass.inventorySerialization.{tags}": return load_json("inventory_tags.json") case "clanInternals.inventoryClass.modulesPerSource": return load_json("modules_per_source.json") @@ -98,6 +168,40 @@ def llm_service() -> Iterator[None]: service_manager.stop_service("ollama") +def execute_multi_turn_workflow( + user_request: str, + flake: Flake | MagicMock, + conversation_history: list["ChatMessage"] | None = None, + provider: str = "ollama", + session_state: "SessionState | None" = None, +) -> "ChatResult": + """Execute the multi-turn workflow, auto-executing all pending operations. + + This simulates the behavior of the CLI auto-execute loop in workflow.py. 
+ """ + result = get_llm_turn( + user_request=user_request, + flake=flake, + conversation_history=conversation_history, + provider=provider, # type: ignore[arg-type] + session_state=session_state, + execute_next_action=False, + ) + + # Auto-execute any pending operations + while result.next_action: + result = get_llm_turn( + user_request="", + flake=flake, + conversation_history=list(result.conversation_history), + provider=provider, # type: ignore[arg-type] + session_state=result.session_state, + execute_next_action=True, + ) + + return result + + @pytest.mark.service_runner @pytest.mark.usefixtures("mock_nix_shell", "llm_service") def test_full_conversation_flow(mock_flake: MagicMock) -> None: @@ -112,10 +216,9 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: - Error handling and edge cases """ flake = mock_flake - return # ========== TURN 1: Discovery Phase - Initial vague request ========== - print("\n=== TURN 1: Initial discovery request ===") - result = process_chat_turn( + print_separator("TURN 1: Discovery Phase", char="=", width=80) + result = execute_multi_turn_workflow( user_request="What VPN options do I have?", flake=flake, provider="ollama", @@ -133,24 +236,25 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: assert result.conversation_history[-1]["role"] == "assistant" assert len(result.assistant_message) > 0, "Assistant should provide a response" - # Should transition to service selection phase with pending state - assert "pending_service_selection" in result.session_state, ( - "Should have pending service selection" - ) - assert "readme_results" in result.session_state["pending_service_selection"] + # After multi-turn execution, we may have either: + # - pending_service_selection (if LLM provided options and is waiting for choice) + # - pending_final_decision (if LLM directly selected a service) + # - no pending state (if LLM asked a clarifying question) # No instances yet assert len(result.proposed_instances) == 0 assert result.error is None - print(f"Assistant: {result.assistant_message[:200]}...") - print(f"State: {list(result.session_state.keys())}") - print(f"History length: {len(result.conversation_history)}") + print_chat_exchange( + "What VPN options do I have?", result.assistant_message, result.session_state + ) + print_meta_info(result, turn=1, phase="Discovery") # ========== TURN 2: Service Selection Phase - User makes a choice ========== - print("\n=== TURN 2: User selects ZeroTier ===") - result = process_chat_turn( - user_request="I'll use ZeroTier please", + print_separator("TURN 2: Service Selection", char="=", width=80) + user_msg_2 = "I'll use ZeroTier please" + result = execute_multi_turn_workflow( + user_request=user_msg_2, flake=flake, conversation_history=list(result.conversation_history), provider="ollama", @@ -176,11 +280,8 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: assert len(result.proposed_instances) > 0 assert result.proposed_instances[0]["module"]["name"] == "zerotier" - print( - f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..." 
- ) - print(f"State: {list(result.session_state.keys())}") - print(f"Requires response: {result.requires_user_response}") + print_chat_exchange(user_msg_2, result.assistant_message, result.session_state) + print_meta_info(result, turn=2, phase="Service Selection") # ========== Continue conversation until we reach final decision or completion ========== max_turns = 10 @@ -188,22 +289,24 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: while result.requires_user_response and turn_count < max_turns: turn_count += 1 - print(f"\n=== TURN {turn_count}: Continuing conversation ===") # Determine appropriate response based on current state if "pending_service_selection" in result.session_state: # Still selecting service user_request = "Yes, ZeroTier" + phase = "Service Selection (continued)" elif "pending_final_decision" in result.session_state: # Configuring the service user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer" + phase = "Final Configuration" else: # Generic continuation user_request = "Yes, that sounds good. Use gchq-local as controller." + phase = "Continuing Conversation" - print(f"User: {user_request}") + print_separator(f"TURN {turn_count}: {phase}", char="=", width=80) - result = process_chat_turn( + result = execute_multi_turn_workflow( user_request=user_request, flake=flake, conversation_history=list(result.conversation_history), @@ -221,19 +324,18 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: result.conversation_history[0]["content"] == "What VPN options do I have?" ) - print( - f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..." + print_chat_exchange( + user_request, result.assistant_message, result.session_state ) - print(f"State: {list(result.session_state.keys())}") - print(f"Requires response: {result.requires_user_response}") - print(f"Proposed instances: {len(result.proposed_instances)}") + print_meta_info(result, turn=turn_count, phase=phase) # Check for completion if not result.requires_user_response: - print("\n=== Conversation completed! 
===") + print_separator("CONVERSATION COMPLETED", char="=", width=80) break # ========== Final Verification ========== + print_separator("FINAL VERIFICATION", char="=", width=80) assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})" # If conversation completed, verify we have valid configuration @@ -253,22 +355,29 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None: "mycelium", ] - # Should have roles configuration - if "roles" in instance: - print(f"\nConfiguration roles: {list(instance['roles'].keys())}") - # Should not be in pending state anymore assert "pending_service_selection" not in result.session_state assert "pending_final_decision" not in result.session_state assert result.error is None, f"Should not have error: {result.error}" - print(f"\nFinal instance: {instance['module']['name']}") - print(f"Total conversation turns: {turn_count}") - print(f"Final history length: {len(result.conversation_history)}") + print_separator("FINAL SUMMARY", char="-", width=80, double=False) + print(" Status: SUCCESS") + print(f" Module Name: {instance['module']['name']}") + print(f" Total Turns: {turn_count}") + print(f" Final History Length: {len(result.conversation_history)} messages") + if "roles" in instance: + roles_list = ", ".join(instance["roles"].keys()) + print(f" Configuration Roles: {roles_list}") + print(" Errors: None") + print("-" * 80) else: # Conversation didn't complete but should have made progress assert len(result.conversation_history) > 2 assert result.error is None - print(f"\nConversation in progress after {turn_count} turns") - print(f"Current state: {list(result.session_state.keys())}") + print_separator("FINAL SUMMARY", char="-", width=80, double=False) + print(" Status: IN PROGRESS") + print(f" Total Turns: {turn_count}") + print(f" Current State: {list(result.session_state.keys())}") + print(f" History Length: {len(result.conversation_history)} messages") + print("-" * 80) diff --git a/pkgs/clan-cli/clan_lib/llm/llm.py b/pkgs/clan-cli/clan_lib/llm/llm.py deleted file mode 100644 index 1e0904ab7..000000000 --- a/pkgs/clan-cli/clan_lib/llm/llm.py +++ /dev/null @@ -1,65 +0,0 @@ -"""High-level LLM orchestration functions. - -This module re-exports the LLM orchestration API from submodules. 
-""" - -# Re-export types and dataclasses -from .llm_types import ( # noqa: F401 - DEFAULT_MODELS, - ChatResult, - ConversationProgressEvent, - DiscoveryProgressEvent, - FinalDecisionProgressEvent, - ModelConfig, - ProgressCallback, - ProgressEvent, - ReadmeFetchProgressEvent, - ServiceSelectionProgressEvent, - ServiceSelectionResult, - get_model_config, -) - -# Re-export high-level orchestrator -from .orchestrator import process_chat_turn # noqa: F401 - -# Re-export low-level phase functions -from .phases import ( # noqa: F401 - execute_readme_requests, - get_llm_discovery_phase, - get_llm_final_decision, - get_llm_service_selection, - llm_final_decision_to_inventory_instances, -) - -# Re-export commonly used functions and types from schemas -from .schemas import ( # noqa: F401 - AiAggregate, - ChatMessage, - ConversationHistory, - FunctionCallType, - JSONValue, - MachineDescription, - OllamaFunctionSchema, - OpenAIFunctionSchema, - PendingFinalDecisionState, - PendingServiceSelectionState, - ReadmeRequest, - SessionState, - SimplifiedServiceSchema, - TagDescription, - aggregate_ollama_function_schemas, - aggregate_openai_function_schemas, - create_get_readme_tool, - create_select_service_tool, - create_simplified_service_schemas, -) - -# Re-export service functions -from .service import create_llm_model, run_llm_service # noqa: F401 - -# Re-export utility functions and constants -from .utils import ( # noqa: F401 - ASSISTANT_MODE_DISCOVERY, - ASSISTANT_MODE_FINAL, - ASSISTANT_MODE_SELECTION, -) diff --git a/pkgs/clan-cli/clan_lib/llm/llm_test.py b/pkgs/clan-cli/clan_lib/llm/llm_test.py index 6b2acb73e..b0f72ec0b 100644 --- a/pkgs/clan-cli/clan_lib/llm/llm_test.py +++ b/pkgs/clan-cli/clan_lib/llm/llm_test.py @@ -3,12 +3,13 @@ from collections.abc import Callable import pytest from clan_cli.tests.fixtures_flakes import nested_dict from clan_lib.flake.flake import Flake -from clan_lib.llm.llm import ( +from clan_lib.llm.phases import llm_final_decision_to_inventory_instances +from clan_lib.llm.schemas import ( + FunctionCallType, OpenAIFunctionSchema, aggregate_openai_function_schemas, - llm_final_decision_to_inventory_instances, + clan_module_to_openai_spec, ) -from clan_lib.llm.schemas import FunctionCallType, clan_module_to_openai_spec from clan_lib.services.modules import list_service_modules diff --git a/pkgs/clan-cli/clan_lib/llm/llm_types.py b/pkgs/clan-cli/clan_lib/llm/llm_types.py index 44b1e0352..7e19c05bd 100644 --- a/pkgs/clan-cli/clan_lib/llm/llm_types.py +++ b/pkgs/clan-cli/clan_lib/llm/llm_types.py @@ -1,57 +1,28 @@ """Type definitions and dataclasses for LLM orchestration.""" -from collections.abc import Callable from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal, TypedDict from clan_lib.nix_models.clan import InventoryInstance from .schemas import ChatMessage, SessionState -@dataclass(frozen=True) -class DiscoveryProgressEvent: - """Progress event during discovery phase.""" +class NextAction(TypedDict): + """Describes the next expensive operation that will be performed. 
- service_names: list[str] | None = None - stage: Literal["discovery"] = "discovery" - status: Literal["analyzing", "complete"] = "analyzing" + Attributes: + type: The type of operation (discovery, fetch_readmes, service_selection, final_decision) + description: Human-readable description of what will happen + estimated_duration_seconds: Rough estimate of operation duration + details: Phase-specific information (e.g., service names, count) + """ -@dataclass(frozen=True) -class ReadmeFetchProgressEvent: - """Progress event during readme fetching.""" - - count: int - service_names: list[str] - stage: Literal["readme_fetch"] = "readme_fetch" - status: Literal["fetching", "complete"] = "fetching" - - -@dataclass(frozen=True) -class ServiceSelectionProgressEvent: - """Progress event during service selection phase.""" - - service_names: list[str] - stage: Literal["service_selection"] = "service_selection" - status: Literal["selecting", "complete"] = "selecting" - - -@dataclass(frozen=True) -class FinalDecisionProgressEvent: - """Progress event during final decision phase.""" - - stage: Literal["final_decision"] = "final_decision" - status: Literal["reviewing", "complete"] = "reviewing" - - -@dataclass(frozen=True) -class ConversationProgressEvent: - """Progress event for conversation continuation.""" - - message: str - stage: Literal["conversation"] = "conversation" - awaiting_response: bool = True + type: Literal["discovery", "fetch_readmes", "service_selection", "final_decision"] + description: str + estimated_duration_seconds: int + details: dict[str, Any] @dataclass(frozen=True) @@ -70,17 +41,6 @@ class ServiceSelectionResult: clarifying_message: str -ProgressEvent = ( - DiscoveryProgressEvent - | ReadmeFetchProgressEvent - | ServiceSelectionProgressEvent - | FinalDecisionProgressEvent - | ConversationProgressEvent -) - -ProgressCallback = Callable[[ProgressEvent], None] - - @dataclass(frozen=True) class ChatResult: """Result of a complete chat turn through the multi-stage workflow. 
@@ -92,6 +52,7 @@ class ChatResult: requires_user_response: True if the assistant asked a question and needs a response error: Error message if something went wrong (None on success) session_state: Serializable state to pass into the next turn when continuing a workflow + next_action: Description of the next operation to be performed (None if workflow complete) """ @@ -100,6 +61,7 @@ class ChatResult: assistant_message: str requires_user_response: bool session_state: SessionState + next_action: NextAction | None error: str | None = None diff --git a/pkgs/clan-cli/clan_lib/llm/orchestrator.py b/pkgs/clan-cli/clan_lib/llm/orchestrator.py index b37adbdad..baccca437 100644 --- a/pkgs/clan-cli/clan_lib/llm/orchestrator.py +++ b/pkgs/clan-cli/clan_lib/llm/orchestrator.py @@ -4,18 +4,12 @@ import json from pathlib import Path from typing import Literal, cast +from clan_lib.api import API from clan_lib.errors import ClanAiError from clan_lib.flake.flake import Flake from clan_lib.services.modules import InputName, ServiceReadmeCollection -from .llm_types import ( - ChatResult, - DiscoveryProgressEvent, - FinalDecisionProgressEvent, - ProgressCallback, - ReadmeFetchProgressEvent, - ServiceSelectionProgressEvent, -) +from .llm_types import ChatResult, NextAction from .phases import ( execute_readme_requests, get_llm_discovery_phase, @@ -26,8 +20,11 @@ from .phases import ( from .schemas import ( ConversationHistory, JSONValue, + PendingDiscoveryState, PendingFinalDecisionState, + PendingReadmeFetchState, PendingServiceSelectionState, + ReadmeRequest, SessionState, ) from .utils import ( @@ -41,45 +38,50 @@ from .utils import ( ) -def process_chat_turn( +@API.register +def get_llm_turn( user_request: str, flake: Flake, conversation_history: ConversationHistory | None = None, provider: Literal["openai", "ollama", "claude"] = "ollama", - progress_callback: ProgressCallback | None = None, trace_file: Path | None = None, session_state: SessionState | None = None, + execute_next_action: bool = False, ) -> ChatResult: """High-level API that orchestrates the entire multi-stage chat workflow. - This function handles the complete flow: + This function handles the complete flow using a multi-turn approach: 1. Discovery phase - LLM selects relevant services 2. Readme fetching - Retrieves detailed documentation 3. Final decision - LLM makes informed suggestions 4. Conversion - Transforms suggestions to inventory instances + Before each expensive operation, the function returns with next_action describing + what will happen. The caller must call again with execute_next_action=True. + Args: user_request: The user's message/request flake: The Flake object to get services from conversation_history: Optional list of prior messages in the conversation provider: The LLM provider to use - progress_callback: Optional callback for progress updates trace_file: Optional path to write LLM interaction traces for debugging session_state: Optional cross-turn state to resume pending workflows + execute_next_action: If True, execute the pending operation in session_state Returns: - ChatResult containing proposed instances, updated history, and assistant message + ChatResult containing proposed instances, updated history, next_action, and assistant message Example: - >>> result = process_chat_turn( - ... "Set up a web server", - ... flake, - ... progress_callback=lambda event: print(f"Stage: {event.stage}") - ... ) - >>> if result.proposed_instances: - ... 
print("LLM suggested:", result.proposed_instances) - >>> if result.requires_user_response: - ... print("Assistant asks:", result.assistant_message) + >>> result = process_chat_turn("Set up a web server", flake) + >>> while result.next_action: + ... # Show user what will happen + ... print(result.next_action["description"]) + ... result = process_chat_turn( + ... user_request="", + ... flake=flake, + ... session_state=result.session_state, + ... execute_next_action=True + ... ) """ history = list(conversation_history) if conversation_history else [] @@ -87,6 +89,11 @@ def process_chat_turn( "SessionState", dict(session_state) if session_state else {} ) + # Add non-empty user message to history for conversation tracking + # Phases will also add it to their messages array for LLM calls + if user_request: + history.append(_user_message(user_request)) + def _state_snapshot() -> dict[str, JSONValue]: try: return json.loads(json.dumps(state)) @@ -102,6 +109,17 @@ def process_chat_turn( def _state_copy() -> SessionState: return cast("SessionState", dict(state)) + # Check pending states in workflow order (earliest to latest) + pending_discovery_raw = state.get("pending_discovery") + pending_discovery: PendingDiscoveryState | None = ( + pending_discovery_raw if isinstance(pending_discovery_raw, dict) else None + ) + + pending_readme_fetch_raw = state.get("pending_readme_fetch") + pending_readme_fetch: PendingReadmeFetchState | None = ( + pending_readme_fetch_raw if isinstance(pending_readme_fetch_raw, dict) else None + ) + pending_final_raw = state.get("pending_final_decision") pending_final: PendingFinalDecisionState | None = ( pending_final_raw if isinstance(pending_final_raw, dict) else None @@ -117,87 +135,139 @@ def process_chat_turn( if serialized_results is not None: resume_readme_results = _deserialize_readme_results(serialized_results) + # Only pop if we can't deserialize (invalid state) if resume_readme_results is None: state.pop("pending_service_selection", None) - else: - state.pop("pending_service_selection", None) + + # Handle pending_discovery state: execute discovery if execute_next_action=True + if pending_discovery is not None and execute_next_action: + state.pop("pending_discovery", None) + # Continue to execute discovery below (after pending state checks) + + # Handle pending_readme_fetch state: execute readme fetch if execute_next_action=True + if pending_readme_fetch is not None and execute_next_action: + readme_requests_raw = pending_readme_fetch.get("readme_requests", []) + readme_requests = cast("list[ReadmeRequest]", readme_requests_raw) + + if readme_requests: + state.pop("pending_readme_fetch", None) + readme_results = execute_readme_requests(readme_requests, flake) + + # Save readme results and return next_action for service selection + state["pending_service_selection"] = cast( + "PendingServiceSelectionState", + {"readme_results": _serialize_readme_results(readme_results)}, + ) + service_count = len(readme_results) + next_action_selection: NextAction = { + "type": "service_selection", + "description": f"Analyzing {service_count} service(s) to find the best match", + "estimated_duration_seconds": 15, + "details": {"service_count": service_count}, + } + return ChatResult( + next_action=next_action_selection, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), + ) + + state.pop("pending_readme_fetch", None) if pending_final is not None: service_name = 
pending_final.get("service_name") service_summary = pending_final.get("service_summary") if isinstance(service_name, str) and isinstance(service_summary, str): - if progress_callback: - progress_callback(FinalDecisionProgressEvent(status="reviewing")) - - function_calls, final_message = get_llm_final_decision( - user_request, - flake, - service_name, - service_summary, - conversation_history, - provider=provider, - trace_file=trace_file, - trace_metadata=_metadata( - { - "selected_service": service_name, - "resume": True, - } - ), - ) - - if progress_callback: - progress_callback(FinalDecisionProgressEvent(status="complete")) - - history.append(_user_message(user_request)) - - if function_calls: - proposed_instances = llm_final_decision_to_inventory_instances( - function_calls - ) - instance_names = [inst["module"]["name"] for inst in proposed_instances] - summary = ( - f"I suggest configuring these services: {', '.join(instance_names)}" - ) - history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL)) + if execute_next_action: state.pop("pending_final_decision", None) - return ChatResult( - proposed_instances=tuple(proposed_instances), - conversation_history=tuple(history), - assistant_message=summary, - requires_user_response=False, - error=None, - session_state=_state_copy(), + function_calls, final_message = get_llm_final_decision( + user_request, + flake, + service_name, + service_summary, + conversation_history, + provider=provider, + trace_file=trace_file, + trace_metadata=_metadata( + { + "selected_service": service_name, + "resume": True, + } + ), ) - if final_message: - history.append( - _assistant_message(final_message, mode=ASSISTANT_MODE_FINAL) - ) - state["pending_final_decision"] = cast( - "PendingFinalDecisionState", - { - "service_name": service_name, - "service_summary": service_summary, - }, + if function_calls: + proposed_instances = llm_final_decision_to_inventory_instances( + function_calls + ) + instance_names = [ + inst["module"]["name"] for inst in proposed_instances + ] + summary = f"I suggest configuring these services: {', '.join(instance_names)}" + history.append( + _assistant_message(summary, mode=ASSISTANT_MODE_FINAL) + ) + + return ChatResult( + next_action=None, + proposed_instances=tuple(proposed_instances), + conversation_history=tuple(history), + assistant_message=summary, + requires_user_response=False, + error=None, + session_state=_state_copy(), + ) + + if final_message: + history.append( + _assistant_message(final_message, mode=ASSISTANT_MODE_FINAL) + ) + state["pending_final_decision"] = cast( + "PendingFinalDecisionState", + { + "service_name": service_name, + "service_summary": service_summary, + }, + ) + + return ChatResult( + next_action=None, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message=final_message, + requires_user_response=True, + error=None, + session_state=_state_copy(), + ) + + state.pop("pending_final_decision", None) + msg = "LLM did not provide any response or recommendations" + raise ClanAiError( + msg, + description="Expected either function calls (configuration) or a clarifying message", + location="Final Decision Phase (pending)", ) - return ChatResult( - proposed_instances=(), - conversation_history=tuple(history), - assistant_message=final_message, - requires_user_response=True, - error=None, - session_state=_state_copy(), - ) - - state.pop("pending_final_decision", None) - msg = "LLM did not provide any response or recommendations" - raise ClanAiError( - msg, - 
description="Expected either function calls (configuration) or a clarifying message", - location="Final Decision Phase (pending)", + # If not executing, return next_action for final decision + next_action_final_pending: NextAction = { + "type": "final_decision", + "description": f"Generating configuration for {service_name}", + "estimated_duration_seconds": 20, + "details": {"service_name": service_name}, + } + return ChatResult( + next_action=next_action_final_pending, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), ) state.pop("pending_final_decision", None) @@ -206,19 +276,12 @@ def process_chat_turn( readme_results: dict[InputName, ServiceReadmeCollection], ) -> ChatResult: # Extract all service names from readme results - all_service_names = [ + [ service_name for collection in readme_results.values() for service_name in collection.readmes ] - if progress_callback: - progress_callback( - ServiceSelectionProgressEvent( - service_names=all_service_names, status="selecting" - ) - ) - selection_result = get_llm_service_selection( user_request, readme_results, @@ -232,7 +295,6 @@ def process_chat_turn( selection_result.clarifying_message and not selection_result.selected_service ): - history.append(_user_message(user_request)) history.append( _assistant_message( selection_result.clarifying_message, @@ -247,6 +309,7 @@ def process_chat_turn( ) return ChatResult( + next_action=None, proposed_instances=(), conversation_history=tuple(history), assistant_message=selection_result.clarifying_message, @@ -267,81 +330,79 @@ def process_chat_turn( location="Service Selection Phase", ) - if progress_callback: - progress_callback(FinalDecisionProgressEvent(status="reviewing")) - - function_calls, final_message = get_llm_final_decision( - user_request, - flake, - selection_result.selected_service, - selection_result.service_summary, - conversation_history, - provider=provider, - trace_file=trace_file, - trace_metadata=_metadata( - {"selected_service": selection_result.selected_service} - ), + # After service selection, always return next_action for final decision + state["pending_final_decision"] = cast( + "PendingFinalDecisionState", + { + "service_name": selection_result.selected_service, + "service_summary": selection_result.service_summary, + }, ) - - if progress_callback: - progress_callback(FinalDecisionProgressEvent(status="complete")) - - if function_calls: - history.append(_user_message(user_request)) - - proposed_instances = llm_final_decision_to_inventory_instances( - function_calls - ) - - instance_names = [inst["module"]["name"] for inst in proposed_instances] - summary = ( - f"I suggest configuring these services: {', '.join(instance_names)}" - ) - history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL)) - state.pop("pending_final_decision", None) - - return ChatResult( - proposed_instances=tuple(proposed_instances), - conversation_history=tuple(history), - assistant_message=summary, - requires_user_response=False, - error=None, - session_state=_state_copy(), - ) - - if final_message: - history.append(_user_message(user_request)) - history.append(_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)) - state["pending_final_decision"] = cast( - "PendingFinalDecisionState", - { - "service_name": selection_result.selected_service, - "service_summary": selection_result.service_summary, - }, - ) - - return ChatResult( - proposed_instances=(), - 
conversation_history=tuple(history), - assistant_message=final_message, - requires_user_response=True, - error=None, - session_state=_state_copy(), - ) - - msg = "LLM did not provide any response or recommendations" - raise ClanAiError( - msg, - description="Expected either function calls (configuration) or a clarifying message after service selection", - location="Final Decision Phase", + next_action_final: NextAction = { + "type": "final_decision", + "description": f"Generating configuration for {selection_result.selected_service}", + "estimated_duration_seconds": 20, + "details": {"service_name": selection_result.selected_service}, + } + return ChatResult( + next_action=next_action_final, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), ) if resume_readme_results is not None: - return _continue_with_service_selection(resume_readme_results) + if execute_next_action: + # Pop the pending state now that we're executing + state.pop("pending_service_selection", None) + return _continue_with_service_selection(resume_readme_results) + + # If not executing, return next_action for service selection + service_count = len(resume_readme_results) + next_action_sel_resume: NextAction = { + "type": "service_selection", + "description": f"Analyzing {service_count} service(s) to find the best match", + "estimated_duration_seconds": 15, + "details": {"service_count": service_count}, + } + return ChatResult( + next_action=next_action_sel_resume, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), + ) # Stage 1: Discovery phase - if progress_callback: - progress_callback(DiscoveryProgressEvent(status="analyzing")) + # If we're not executing and have no pending states, return next_action for discovery + has_pending_states = ( + pending_discovery is not None + or pending_readme_fetch is not None + or pending_selection is not None + or pending_final is not None + ) + if not execute_next_action and not has_pending_states: + state["pending_discovery"] = cast("PendingDiscoveryState", {}) + next_action: NextAction = { + "type": "discovery", + "description": "Analyzing your request and discovering relevant services", + "estimated_duration_seconds": 10, + "details": {"phase": "discovery"}, + } + return ChatResult( + next_action=next_action, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), + ) readme_requests, discovery_message = get_llm_discovery_phase( user_request, @@ -352,23 +413,14 @@ def process_chat_turn( trace_metadata=_metadata(), ) - if progress_callback: - selected_services = [req["function_name"] for req in readme_requests] - progress_callback( - DiscoveryProgressEvent( - service_names=selected_services if selected_services else None, - status="complete", - ) - ) - # If LLM asked a question or made a recommendation without readme requests if discovery_message and not readme_requests: - history.append(_user_message(user_request)) history.append( _assistant_message(discovery_message, mode=ASSISTANT_MODE_DISCOVERY) ) return ChatResult( + next_action=None, proposed_instances=(), conversation_history=tuple(history), assistant_message=discovery_message, @@ -377,34 +429,28 @@ def process_chat_turn( session_state=_state_copy(), ) - # If we got readme requests, continue to 
selecting services + # If we got readme requests, save them and return next_action for readme fetch if readme_requests: - # Stage 2: Fetch readmes - service_names = [ - f"{req['function_name']} (from {req['input_name'] or 'built-in'})" - for req in readme_requests - ] - if progress_callback: - progress_callback( - ReadmeFetchProgressEvent( - count=len(readme_requests), - service_names=service_names, - status="fetching", - ) - ) - - readme_results = execute_readme_requests(readme_requests, flake) - - if progress_callback: - progress_callback( - ReadmeFetchProgressEvent( - count=len(readme_requests), - service_names=service_names, - status="complete", - ) - ) - - return _continue_with_service_selection(readme_results) + state["pending_readme_fetch"] = cast( + "PendingReadmeFetchState", + {"readme_requests": cast("list[dict[str, JSONValue]]", readme_requests)}, + ) + service_count = len(readme_requests) + next_action_fetch: NextAction = { + "type": "fetch_readmes", + "description": f"Fetching documentation for {service_count} service(s)", + "estimated_duration_seconds": 5, + "details": {"service_count": service_count}, + } + return ChatResult( + next_action=next_action_fetch, + proposed_instances=(), + conversation_history=tuple(history), + assistant_message="", + requires_user_response=False, + error=None, + session_state=_state_copy(), + ) # No readme requests and no message - unexpected msg = "LLM did not provide any response or recommendations" diff --git a/pkgs/clan-cli/clan_lib/llm/phases.py b/pkgs/clan-cli/clan_lib/llm/phases.py index d8960e487..80d8e74f6 100644 --- a/pkgs/clan-cli/clan_lib/llm/phases.py +++ b/pkgs/clan-cli/clan_lib/llm/phases.py @@ -85,7 +85,8 @@ def get_llm_discovery_phase( {"role": "assistant", "content": assistant_context}, ] messages.extend(_strip_conversation_metadata(conversation_history)) - messages.append(_user_message(user_request)) + if user_request: + messages.append(_user_message(user_request)) # Call LLM with only get_readme tool model_config = get_model_config(provider) @@ -233,7 +234,8 @@ def get_llm_service_selection( {"role": "assistant", "content": combined_assistant_context}, ] messages.extend(_strip_conversation_metadata(conversation_history)) - messages.append(_user_message(user_request)) + if user_request: + messages.append(_user_message(user_request)) model_config = get_model_config(provider) @@ -430,8 +432,8 @@ def get_llm_final_decision( {"role": "assistant", "content": combined_assistant_context}, ] messages.extend(_strip_conversation_metadata(conversation_history)) - - messages.append(_user_message(user_request)) + if user_request: + messages.append(_user_message(user_request)) # Get full schemas model_config = get_model_config(provider) diff --git a/pkgs/clan-cli/clan_lib/llm/schemas.py b/pkgs/clan-cli/clan_lib/llm/schemas.py index edb4d9687..0955245ff 100644 --- a/pkgs/clan-cli/clan_lib/llm/schemas.py +++ b/pkgs/clan-cli/clan_lib/llm/schemas.py @@ -66,18 +66,28 @@ class ChatMessage(TypedDict): ConversationHistory = list[ChatMessage] -class PendingFinalDecisionState(TypedDict, total=False): - service_name: NotRequired[str] - service_summary: NotRequired[str] +class PendingDiscoveryState(TypedDict, total=False): + user_request: NotRequired[str] + + +class PendingReadmeFetchState(TypedDict, total=False): + readme_requests: NotRequired[list[dict[str, JSONValue]]] class PendingServiceSelectionState(TypedDict, total=False): readme_results: NotRequired[list[dict[str, JSONValue]]] +class PendingFinalDecisionState(TypedDict, total=False): + 
service_name: NotRequired[str] + service_summary: NotRequired[str] + + class SessionState(TypedDict, total=False): - pending_final_decision: NotRequired[PendingFinalDecisionState] + pending_discovery: NotRequired[PendingDiscoveryState] + pending_readme_fetch: NotRequired[PendingReadmeFetchState] pending_service_selection: NotRequired[PendingServiceSelectionState] + pending_final_decision: NotRequired[PendingFinalDecisionState] class JSONSchemaProperty(TypedDict, total=False): diff --git a/pkgs/clan-cli/clan_lib/llm/test_process_chat_turn.py b/pkgs/clan-cli/clan_lib/llm/test_process_chat_turn.py index 0c11df2c3..20977dc3a 100644 --- a/pkgs/clan-cli/clan_lib/llm/test_process_chat_turn.py +++ b/pkgs/clan-cli/clan_lib/llm/test_process_chat_turn.py @@ -16,16 +16,12 @@ from clan_lib.llm.endpoints import ( parse_ollama_response, parse_openai_response, ) -from clan_lib.llm.llm import ( - DiscoveryProgressEvent, - FinalDecisionProgressEvent, - ReadmeFetchProgressEvent, - ServiceSelectionProgressEvent, - ServiceSelectionResult, +from clan_lib.llm.llm_types import ServiceSelectionResult +from clan_lib.llm.orchestrator import get_llm_turn +from clan_lib.llm.phases import ( execute_readme_requests, get_llm_final_decision, get_llm_service_selection, - process_chat_turn, ) from clan_lib.llm.schemas import ( AiAggregate, @@ -37,9 +33,44 @@ from clan_lib.llm.schemas import ( from clan_lib.services.modules import ServiceReadmeCollection if TYPE_CHECKING: + from clan_lib.llm.llm_types import ChatResult from clan_lib.llm.schemas import ChatMessage +def execute_multi_turn_workflow( + user_request: str, + flake: Flake, + conversation_history: list["ChatMessage"] | None = None, + provider: str = "claude", + session_state: SessionState | None = None, +) -> "ChatResult": + """Execute the multi-turn workflow, auto-executing all pending operations. + + This simulates the behavior of the CLI auto-execute loop in workflow.py. 
+ """ + result = get_llm_turn( + user_request=user_request, + flake=flake, + conversation_history=conversation_history, + provider=provider, # type: ignore[arg-type] + session_state=session_state, + execute_next_action=False, + ) + + # Auto-execute any pending operations + while result.next_action: + result = get_llm_turn( + user_request="", + flake=flake, + conversation_history=list(result.conversation_history), + provider=provider, # type: ignore[arg-type] + session_state=result.session_state, + execute_next_action=True, + ) + + return result + + @pytest.fixture def trace_data() -> list[dict[str, Any]]: """Load trace data from mytrace.json.""" @@ -181,17 +212,23 @@ class TestProcessChatTurn: # Mock final decision (shouldn't be called, but mock it anyway for safety) mock_final.return_value = ([], "") - # Run process_chat_turn - result = process_chat_turn( + # Run multi-turn workflow + result = execute_multi_turn_workflow( user_request="What VPNs are available?", flake=mock_flake, conversation_history=None, provider="claude", ) - # Verify the call was made + # Verify the discovery call was made assert mock_call.called + # Verify readme execution was called + assert mock_execute.called + + # Verify service selection was called + assert mock_selection.called + # Final decision should NOT be called since we return early with clarifying message assert not mock_final.called @@ -262,8 +299,8 @@ class TestProcessChatTurn: final_trace["response"]["message"], ) - # Run process_chat_turn with session state - result = process_chat_turn( + # Run multi-turn workflow with session state + result = execute_multi_turn_workflow( user_request="Hmm zerotier please", flake=mock_flake, conversation_history=conversation_history, @@ -271,6 +308,12 @@ class TestProcessChatTurn: session_state=session_state, ) + # Verify service selection was called + assert mock_selection.called + + # Verify final decision was called + assert mock_final.called + # Verify the result assert result.requires_user_response is True assert "controller" in result.assistant_message.lower() @@ -335,8 +378,8 @@ class TestProcessChatTurn: "", ) - # Run process_chat_turn - result = process_chat_turn( + # Run multi-turn workflow + result = execute_multi_turn_workflow( user_request="okay then gchq-local as controller and qube-email as moon please everything else as peer", flake=mock_flake, conversation_history=conversation_history, @@ -344,6 +387,9 @@ class TestProcessChatTurn: session_state=session_state, ) + # Verify final decision was called + assert mock_final.called + # Verify the result assert result.requires_user_response is False assert len(result.proposed_instances) == 1 @@ -394,19 +440,19 @@ class TestProcessChatTurn: ) mock_final.return_value = ([], "") - result1 = process_chat_turn( + result1 = execute_multi_turn_workflow( user_request="What VPNs are available?", flake=mock_flake, provider="claude", ) - # Verify final decision was not called + # Verify final decision was not called (since we get clarifying message) assert not mock_final.called - # Verify discovery completed and moved to service selection + # Verify discovery completed and service selection asked clarifying question assert result1.requires_user_response is True assert "VPN" in result1.assistant_message - # Session state should have pending_service_selection + # Session state should have pending_service_selection (with readme results) assert "pending_service_selection" in result1.session_state # Test Turn 2: Continue with session state @@ -425,7 +471,7 @@ class 
TestProcessChatTurn: ) mock_final.return_value = ([], trace_data[3]["response"]["message"]) - result2 = process_chat_turn( + result2 = execute_multi_turn_workflow( user_request="Hmm zerotier please", flake=mock_flake, conversation_history=list(result1.conversation_history), @@ -485,7 +531,7 @@ class TestProcessChatTurn: # Return empty function_calls but with a clarifying message mock_final.return_value = ([], clarify_trace["response"]["message"]) - result = process_chat_turn( + result = execute_multi_turn_workflow( user_request="Set up zerotier with gchq-local as controller", flake=mock_flake, conversation_history=conversation_history, @@ -535,7 +581,7 @@ class TestProcessChatTurn: ] mock_final.return_value = ([], "") - result = process_chat_turn( + result = execute_multi_turn_workflow( user_request="I want to set up a VPN", flake=mock_flake, provider="claude", @@ -624,7 +670,7 @@ class TestProcessChatTurn: ] ) - result = process_chat_turn( + result = execute_multi_turn_workflow( user_request="Use zerotier with gchq-local as controller, qube-email as moon, rest as peers", flake=mock_flake, conversation_history=conversation_history, @@ -632,6 +678,12 @@ class TestProcessChatTurn: session_state=session_state, ) + # Verify service selection was called + assert mock_selection.called + + # Verify final decision was called + assert mock_final.called + # Verify the function_calls branch in _continue_with_service_selection assert result.requires_user_response is False assert len(result.proposed_instances) == 1 @@ -1004,15 +1056,16 @@ class TestProcessChatTurnPendingFinalDecision: response = create_openai_response([], clarify_trace["response"]["message"]) mock_call.return_value = response - result = process_chat_turn( + result = get_llm_turn( user_request="gchq-local as controller", flake=mock_flake, conversation_history=conversation_history, provider="claude", session_state=session_state, + execute_next_action=True, # Execute the pending final decision ) - # Verify the if final_message branch at line 425 was taken + # Verify the if final_message branch was taken assert result.requires_user_response is True assert result.assistant_message == clarify_trace["response"]["message"] @@ -1074,12 +1127,13 @@ class TestProcessChatTurnPendingFinalDecision: response = create_openai_response(function_calls, "") mock_call.return_value = response - result = process_chat_turn( + result = get_llm_turn( user_request="gchq-local as controller, qube-email as moon, rest as peers", flake=mock_flake, conversation_history=conversation_history, provider="claude", session_state=session_state, + execute_next_action=True, # Execute the pending final decision ) # Verify configuration completed @@ -1094,186 +1148,6 @@ class TestProcessChatTurnPendingFinalDecision: assert result.error is None -class TestProgressCallbacks: - """Test progress_callback functionality in process_chat_turn.""" - - def test_progress_callback_during_readme_fetch( - self, trace_data: list[dict[str, Any]], mock_flake: MagicMock - ) -> None: - """Test that progress_callback is called during README fetching.""" - # Use trace entry with README requests - discovery_trace = trace_data[0] - function_calls = discovery_trace["response"]["function_calls"] - assert len(function_calls) > 0 - - # Track progress events - progress_events: list[Any] = [] - - def track_progress(event: Any) -> None: - progress_events.append(event) - - # Create response with get_readme calls - response = create_openai_response(function_calls, "") - - with ( - 
patch("clan_lib.llm.phases.call_claude_api", return_value=response), - patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute, - patch( - "clan_lib.llm.orchestrator.get_llm_service_selection" - ) as mock_selection, - patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final, - ): - mock_execute.return_value = { - None: ServiceReadmeCollection( - input_name=None, - readmes={ - "wireguard": "# WireGuard README", - "zerotier": "# ZeroTier README", - "mycelium": "# Mycelium README", - "yggdrasil": "# Yggdrasil README", - }, - ) - } - mock_selection.return_value = ServiceSelectionResult( - selected_service=None, - service_summary=None, - clarifying_message=trace_data[1]["response"]["message"], - ) - mock_final.return_value = ([], "") - - result = process_chat_turn( - user_request="What VPNs are available?", - flake=mock_flake, - provider="claude", - progress_callback=track_progress, - ) - - # Verify final decision was not called - assert not mock_final.called - - # Verify progress events were sent - assert len(progress_events) > 0 - - # Check for discovery progress events - discovery_events = [ - e for e in progress_events if isinstance(e, DiscoveryProgressEvent) - ] - assert len(discovery_events) >= 2 # At least start and complete - - # Check for readme fetch progress events - fetch_events = [ - e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent) - ] - assert len(fetch_events) >= 2 # fetching and complete - - # Verify the fetching event has correct data - fetching_event = next(e for e in fetch_events if e.status == "fetching") - assert fetching_event.count == len(function_calls) - # Service names include "(from built-in)" or "(from )" suffix - assert any("wireguard" in name for name in fetching_event.service_names) - - # Verify the complete event - complete_event = next(e for e in fetch_events if e.status == "complete") - assert complete_event.count == len(function_calls) - - # Result should still be successful - assert result.requires_user_response is True - - def test_progress_callback_through_full_workflow( - self, trace_data: list[dict[str, Any]], mock_flake: MagicMock - ) -> None: - """Test progress_callback through entire workflow from discovery to config.""" - progress_events: list[Any] = [] - - def track_progress(event: Any) -> None: - progress_events.append(event) - - # Setup for full workflow - discovery_response = create_openai_response( - trace_data[0]["response"]["function_calls"], - trace_data[0]["response"]["message"], - ) - - with ( - patch( - "clan_lib.llm.phases.call_claude_api", return_value=discovery_response - ), - patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute, - patch( - "clan_lib.llm.orchestrator.get_llm_service_selection" - ) as mock_selection, - patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final, - patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg, - ): - mock_execute.return_value = { - None: ServiceReadmeCollection( - input_name=None, readmes={"zerotier": "# ZeroTier README"} - ) - } - mock_selection.return_value = ServiceSelectionResult( - selected_service="zerotier", - service_summary="ZeroTier mesh VPN", - clarifying_message="", - ) - # Return configuration - final_trace = trace_data[-1] - mock_final.return_value = ( - [ - FunctionCallType( - id="call_0", - call_id="call_0", - type="function_call", - name="zerotier", - arguments=json.dumps( - final_trace["response"]["function_calls"][0]["arguments"] - ), - ) - ], - "", - ) - 
mock_agg.return_value = MagicMock( - tools=[ - { - "type": "function", - "function": {"name": "zerotier", "description": "ZeroTier VPN"}, - } - ] - ) - - result = process_chat_turn( - user_request="Setup zerotier with gchq-local as controller", - flake=mock_flake, - provider="claude", - progress_callback=track_progress, - ) - - # Verify we got progress events for all phases - discovery_events = [ - e for e in progress_events if isinstance(e, DiscoveryProgressEvent) - ] - fetch_events = [ - e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent) - ] - selection_events = [ - e - for e in progress_events - if isinstance(e, ServiceSelectionProgressEvent) - ] - final_events = [ - e for e in progress_events if isinstance(e, FinalDecisionProgressEvent) - ] - - # Should have events from all phases - assert len(discovery_events) > 0 - assert len(fetch_events) > 0 - assert len(selection_events) > 0 - assert len(final_events) > 0 - - # Result should be successful with config - assert result.requires_user_response is False - assert len(result.proposed_instances) == 1 - - class TestErrorCases: """Test error handling in process_chat_turn.""" @@ -1288,7 +1162,8 @@ class TestErrorCases: patch("clan_lib.llm.phases.call_claude_api", return_value=response), pytest.raises(ClanAiError, match="did not provide any response"), ): - process_chat_turn( + # Use multi-turn workflow to execute through discovery + execute_multi_turn_workflow( user_request="Setup a VPN", flake=mock_flake, provider="claude", @@ -1304,7 +1179,8 @@ class TestErrorCases: ), pytest.raises(ValueError, match="Test error"), ): - process_chat_turn( + # Use multi-turn workflow to execute through discovery + execute_multi_turn_workflow( user_request="Setup a VPN", flake=mock_flake, provider="claude", @@ -1326,86 +1202,14 @@ class TestErrorCases: ), pytest.raises(RuntimeError, match="Network error"), ): - process_chat_turn( + # Use multi-turn workflow to execute through discovery + execute_multi_turn_workflow( user_request="Setup zerotier", flake=mock_flake, conversation_history=conversation_history, provider="claude", ) - def test_progress_callback_final_decision_reviewing_and_complete( - self, trace_data: list[dict[str, Any]], mock_flake: MagicMock - ) -> None: - """Test FinalDecisionProgressEvent with reviewing and complete statuses.""" - progress_events: list[Any] = [] - - def track_progress(event: Any) -> None: - progress_events.append(event) - - # Build conversation history and session state for pending_final_decision - conversation_history: list[ChatMessage] = [ - {"role": "user", "content": "Setup VPN"}, - {"role": "assistant", "content": "Which service?"}, - {"role": "user", "content": "Use zerotier"}, - {"role": "assistant", "content": "Which machine as controller?"}, - ] - - session_state: SessionState = cast( - "SessionState", - { - "pending_final_decision": { - "service_name": "zerotier", - "service_summary": "ZeroTier mesh VPN", - } - }, - ) - - # Use final trace with configuration - final_trace = trace_data[-1] - function_calls = final_trace["response"]["function_calls"] - - with ( - patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg, - patch("clan_lib.llm.phases.call_claude_api") as mock_call, - ): - mock_agg.return_value = MagicMock( - tools=[ - { - "type": "function", - "function": {"name": "zerotier", "description": "ZeroTier VPN"}, - } - ] - ) - response = create_openai_response(function_calls, "") - mock_call.return_value = response - - result = process_chat_turn( - user_request="gchq-local 
as controller, qube-email as moon, rest as peers", - flake=mock_flake, - conversation_history=conversation_history, - provider="claude", - session_state=session_state, - progress_callback=track_progress, - ) - - # Verify we got FinalDecisionProgressEvent with both statuses - final_events = [ - e for e in progress_events if isinstance(e, FinalDecisionProgressEvent) - ] - assert len(final_events) >= 2 - - # Check for "reviewing" status - reviewing_events = [e for e in final_events if e.status == "reviewing"] - assert len(reviewing_events) >= 1 - - # Check for "complete" status - complete_events = [e for e in final_events if e.status == "complete"] - assert len(complete_events) >= 1 - - # Result should be successful - assert result.requires_user_response is False - assert len(result.proposed_instances) == 1 - def test_service_selection_fails_no_service_selected( self, mock_flake: MagicMock ) -> None: @@ -1443,7 +1247,8 @@ class TestErrorCases: # Should raise ClanAiError with pytest.raises(ClanAiError, match="Failed to select service"): - process_chat_turn( + # Use multi-turn workflow to execute through service selection + execute_multi_turn_workflow( user_request="Setup VPN", flake=mock_flake, provider="claude", @@ -1680,7 +1485,8 @@ class TestGetLlmFinalDecisionErrors: # Should raise ClanAiError with pytest.raises(ClanAiError, match="LLM did not provide any response"): - process_chat_turn( + # Use multi-turn workflow to execute through final decision + execute_multi_turn_workflow( user_request="gchq-local as controller", flake=mock_flake, conversation_history=conversation_history,