clan_lib/llm: get_llm_turn uses state transitions instead of callback function

Qubasa
2025-10-24 15:57:26 +02:00
parent 183de9209f
commit c3456c1f0c
10 changed files with 529 additions and 859 deletions
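
The core change: instead of streaming progress through a `ProgressCallback`, `get_llm_turn` now returns early with a `next_action` describing each expensive step, and the caller re-invokes it with `execute_next_action=True` to run that step. A minimal sketch of the resulting caller loop, modeled on the `execute_multi_turn_workflow` helper added in the tests below (the flake path is a placeholder):

from clan_lib.flake.flake import Flake
from clan_lib.llm.orchestrator import get_llm_turn

flake = Flake("/path/to/clan")  # placeholder path

# First call: no execution yet, just learn what would happen next.
result = get_llm_turn(
    user_request="Set up a VPN",
    flake=flake,
    provider="ollama",
    execute_next_action=False,
)

# Drive the workflow: each pending operation is described before it runs.
while result.next_action:
    print(f"Next: {result.next_action['description']} "
          f"(~{result.next_action['estimated_duration_seconds']}s)")
    result = get_llm_turn(
        user_request="",
        flake=flake,
        conversation_history=list(result.conversation_history),
        provider="ollama",
        session_state=result.session_state,
        execute_next_action=True,
    )

if result.requires_user_response:
    print("Assistant:", result.assistant_message)
else:
    print("Proposed instances:", result.proposed_instances)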

View File

@@ -76,7 +76,6 @@ in
cmd = "su - text-user -c 'pytest -s -n0 -m service_runner -p no:cacheprovider -o addopts="" ${cli.passthru.sourceWithTests}/clan_lib/llm'" cmd = "su - text-user -c 'pytest -s -n0 -m service_runner -p no:cacheprovider -o addopts="" ${cli.passthru.sourceWithTests}/clan_lib/llm'"
print("Running tests with command: " + cmd) print("Running tests with command: " + cmd)
# Run tests as text-user (environment variables are set automatically) # Run tests as text-user (environment variables are set automatically)
peer1.succeed(cmd) peer1.succeed(cmd)
''; '';

View File

@@ -1,200 +0,0 @@
"""High-level API functions for LLM interactions, suitable for HTTP APIs and web UIs.
This module provides a clean, stateless API for integrating LLM functionality into
web applications and HTTP services. It wraps the complex multi-stage workflow into
simple function calls with serializable inputs and outputs.
"""
from pathlib import Path
from typing import Any, Literal, TypedDict, cast
from clan_lib.api import API
from clan_lib.flake.flake import Flake
from .llm import (
DEFAULT_MODELS,
ChatResult,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ModelConfig,
ProgressCallback,
ProgressEvent,
ReadmeFetchProgressEvent,
get_model_config,
process_chat_turn,
)
from .schemas import ChatMessage, ConversationHistory, SessionState
class ChatTurnRequest(TypedDict, total=False):
"""Request payload for a chat turn.
Attributes:
user_message: The user's message/request
conversation_history: Optional list of prior messages in the conversation
provider: The LLM provider to use (default: "claude")
trace_file: Optional path to write LLM interaction traces for debugging
session_state: Opaque state returned from the previous turn
"""
user_message: str
conversation_history: ConversationHistory | None
provider: Literal["openai", "ollama", "claude"]
trace_file: Path | None
session_state: SessionState | None
class ChatTurnResponse(TypedDict):
"""Response payload for a chat turn.
Attributes:
proposed_instances: List of inventory instances suggested by the LLM
conversation_history: Updated conversation history after this turn
assistant_message: Message from the assistant
requires_user_response: Whether the assistant is waiting for user input
error: Error message if something went wrong (None on success)
session_state: State blob to pass into the next turn when continuing the workflow
"""
proposed_instances: list[dict[str, Any]]
conversation_history: list[ChatMessage]
assistant_message: str
requires_user_response: bool
error: str | None
session_state: SessionState
class ProgressEventResponse(TypedDict):
"""Progress event for streaming updates.
Attributes:
stage: The current stage of processing
status: The status within that stage (if applicable)
count: Count of items (for readme_fetch stage)
message: Message content (for conversation stage)
"""
stage: str
status: str | None
count: int | None
message: str | None
@API.register
def get_llm_turn(
flake: Flake,
request: ChatTurnRequest,
progress_callback: ProgressCallback | None = None,
) -> ChatTurnResponse:
"""Process a single chat turn through the LLM workflow.
This is the main entry point for HTTP APIs and web UIs to interact with
the LLM functionality. It handles:
- Service discovery
- Documentation fetching
- Final decision making
- Conversation management
Args:
flake: The Flake object representing the clan configuration
request: The chat turn request containing user message and optional history
progress_callback: Optional callback for progress updates
Returns:
ChatTurnResponse with proposed instances and conversation state
Example:
>>> from clan_lib.flake.flake import Flake
>>> flake = Flake("/path/to/clan")
>>> request: ChatTurnRequest = {
... "user_message": "Set up a web server",
... "provider": "claude"
... }
>>> response = chat_turn(flake, request)
>>> if response["proposed_instances"]:
... print("LLM suggests:", response["proposed_instances"])
>>> if response["requires_user_response"]:
... print("Assistant asks:", response["assistant_message"])
"""
result: ChatResult = process_chat_turn(
user_request=request["user_message"],
flake=flake,
conversation_history=request.get("conversation_history"),
provider=request.get("provider", "claude"),
progress_callback=progress_callback,
trace_file=request.get("trace_file"),
session_state=request.get("session_state"),
)
# Convert frozen tuples to lists for JSON serialization
return ChatTurnResponse(
proposed_instances=[dict(inst) for inst in result.proposed_instances],
conversation_history=list(result.conversation_history),
assistant_message=result.assistant_message,
requires_user_response=result.requires_user_response,
error=result.error,
session_state=cast("SessionState", dict(result.session_state)),
)
def progress_event_to_dict(event: ProgressEvent) -> ProgressEventResponse:
"""Convert a ProgressEvent to a dictionary suitable for JSON serialization.
This helper function is useful for streaming progress updates over HTTP
(e.g., Server-Sent Events or WebSockets).
Args:
event: The progress event to convert
Returns:
Dictionary representation of the event
Example:
>>> from clan_lib.llm.llm import DiscoveryProgressEvent
>>> event = DiscoveryProgressEvent(status="analyzing")
>>> progress_event_to_dict(event)
{'stage': 'discovery', 'status': 'analyzing', 'count': None, 'message': None}
"""
base_response: ProgressEventResponse = {
"stage": event.stage,
"status": None,
"count": None,
"message": None,
}
if isinstance(event, (DiscoveryProgressEvent, FinalDecisionProgressEvent)):
base_response["status"] = event.status
elif isinstance(event, ReadmeFetchProgressEvent):
base_response["status"] = event.status
base_response["count"] = event.count
# ConversationProgressEvent has message field
elif hasattr(event, "message"):
base_response["message"] = event.message # type: ignore[attr-defined]
if hasattr(event, "awaiting_response"):
base_response["status"] = (
"awaiting_response"
if event.awaiting_response # type: ignore[attr-defined]
else "complete"
)
return base_response
# Re-export types for convenience
__all__ = [
"DEFAULT_MODELS",
"ChatTurnRequest",
"ChatTurnResponse",
"ModelConfig",
"ProgressCallback",
"ProgressEvent",
"ProgressEventResponse",
"get_llm_turn",
"get_model_config",
"progress_event_to_dict",
]

View File

@@ -2,16 +2,86 @@ import contextlib
import json import json
from collections.abc import Iterator from collections.abc import Iterator
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
from clan_lib.flake.flake import Flake from clan_lib.flake.flake import Flake
from clan_lib.llm.llm import ( from clan_lib.llm.orchestrator import get_llm_turn
process_chat_turn,
)
from clan_lib.llm.service import create_llm_model, run_llm_service from clan_lib.llm.service import create_llm_model, run_llm_service
from clan_lib.service_runner import create_service_manager from clan_lib.service_runner import create_service_manager
if TYPE_CHECKING:
from clan_lib.llm.llm_types import ChatResult
from clan_lib.llm.schemas import ChatMessage, SessionState
def get_current_mode(session_state: "SessionState") -> str:
"""Extract the current mode from session state."""
if "pending_service_selection" in session_state:
return "SERVICE_SELECTION"
if "pending_final_decision" in session_state:
return "FINAL_DECISION"
return "DISCOVERY"
def print_separator(
title: str = "", char: str = "=", width: int = 80, double: bool = True
) -> None:
"""Print a separator line with optional title."""
if double:
print(f"\n{char * width}")
if title:
padding = (width - len(title) - 2) // 2
print(f"{char * padding} {title} {char * padding}")
if double or title:
print(f"{char * width}")
def print_meta_info(result: "ChatResult", turn: int, phase: str) -> None:
"""Print meta information section in a structured format."""
mode = get_current_mode(result.session_state)
print_separator("META INFORMATION", char="-", width=80, double=False)
print(f" Turn Number: {turn}")
print(f" Phase: {phase}")
print(f" Current Mode: {mode}")
print(f" Requires User Input: {result.requires_user_response}")
print(f" Conversation Length: {len(result.conversation_history)} messages")
print(f" Proposed Instances: {len(result.proposed_instances)}")
print(f" Has Next Action: {result.next_action is not None}")
print(f" Session State Keys: {list(result.session_state.keys())}")
if result.error:
print(f" Error: {result.error}")
print("-" * 80)
def print_chat_exchange(
user_msg: str | None, assistant_msg: str, session_state: "SessionState"
) -> None:
"""Print a chat exchange with role labels and current mode."""
mode = get_current_mode(session_state)
print_separator("CHAT EXCHANGE", char="-", width=80, double=False)
if user_msg:
print("\n USER:")
print(f" {user_msg}")
print(f"\n ASSISTANT [{mode}]:")
# Wrap long messages
max_line_length = 76
words = assistant_msg.split()
current_line = " "
for word in words:
if len(current_line) + len(word) + 1 > max_line_length:
print(current_line)
current_line = " " + word
else:
current_line += (" " if len(current_line) > 2 else "") + word
if current_line.strip():
print(current_line)
print("\n" + "-" * 80)
@pytest.fixture @pytest.fixture
def mock_flake() -> MagicMock: def mock_flake() -> MagicMock:
@@ -47,9 +117,9 @@ def mock_flake() -> MagicMock:
} }
match arg: match arg:
case "clanInternals.inventoryClass.inventory.{instances,machines,meta}": case "clanInternals.inventoryClass.inventorySerialization.{instances,machines,meta}":
return load_json("inventory_instances_machines_meta.json") return load_json("inventory_instances_machines_meta.json")
case "clanInternals.inventoryClass.inventory.{tags}": case "clanInternals.inventoryClass.inventorySerialization.{tags}":
return load_json("inventory_tags.json") return load_json("inventory_tags.json")
case "clanInternals.inventoryClass.modulesPerSource": case "clanInternals.inventoryClass.modulesPerSource":
return load_json("modules_per_source.json") return load_json("modules_per_source.json")
@@ -98,6 +168,40 @@ def llm_service() -> Iterator[None]:
service_manager.stop_service("ollama") service_manager.stop_service("ollama")
def execute_multi_turn_workflow(
user_request: str,
flake: Flake | MagicMock,
conversation_history: list["ChatMessage"] | None = None,
provider: str = "ollama",
session_state: "SessionState | None" = None,
) -> "ChatResult":
"""Execute the multi-turn workflow, auto-executing all pending operations.
This simulates the behavior of the CLI auto-execute loop in workflow.py.
"""
result = get_llm_turn(
user_request=user_request,
flake=flake,
conversation_history=conversation_history,
provider=provider, # type: ignore[arg-type]
session_state=session_state,
execute_next_action=False,
)
# Auto-execute any pending operations
while result.next_action:
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
)
return result
@pytest.mark.service_runner @pytest.mark.service_runner
@pytest.mark.usefixtures("mock_nix_shell", "llm_service") @pytest.mark.usefixtures("mock_nix_shell", "llm_service")
def test_full_conversation_flow(mock_flake: MagicMock) -> None: def test_full_conversation_flow(mock_flake: MagicMock) -> None:
@@ -112,10 +216,9 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
- Error handling and edge cases - Error handling and edge cases
""" """
flake = mock_flake flake = mock_flake
return
# ========== TURN 1: Discovery Phase - Initial vague request ========== # ========== TURN 1: Discovery Phase - Initial vague request ==========
print("\n=== TURN 1: Initial discovery request ===") print_separator("TURN 1: Discovery Phase", char="=", width=80)
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="What VPN options do I have?", user_request="What VPN options do I have?",
flake=flake, flake=flake,
provider="ollama", provider="ollama",
@@ -133,24 +236,25 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
assert result.conversation_history[-1]["role"] == "assistant" assert result.conversation_history[-1]["role"] == "assistant"
assert len(result.assistant_message) > 0, "Assistant should provide a response" assert len(result.assistant_message) > 0, "Assistant should provide a response"
# Should transition to service selection phase with pending state # After multi-turn execution, we may have either:
assert "pending_service_selection" in result.session_state, ( # - pending_service_selection (if LLM provided options and is waiting for choice)
"Should have pending service selection" # - pending_final_decision (if LLM directly selected a service)
) # - no pending state (if LLM asked a clarifying question)
assert "readme_results" in result.session_state["pending_service_selection"]
# No instances yet # No instances yet
assert len(result.proposed_instances) == 0 assert len(result.proposed_instances) == 0
assert result.error is None assert result.error is None
print(f"Assistant: {result.assistant_message[:200]}...") print_chat_exchange(
print(f"State: {list(result.session_state.keys())}") "What VPN options do I have?", result.assistant_message, result.session_state
print(f"History length: {len(result.conversation_history)}") )
print_meta_info(result, turn=1, phase="Discovery")
# ========== TURN 2: Service Selection Phase - User makes a choice ========== # ========== TURN 2: Service Selection Phase - User makes a choice ==========
print("\n=== TURN 2: User selects ZeroTier ===") print_separator("TURN 2: Service Selection", char="=", width=80)
result = process_chat_turn( user_msg_2 = "I'll use ZeroTier please"
user_request="I'll use ZeroTier please", result = execute_multi_turn_workflow(
user_request=user_msg_2,
flake=flake, flake=flake,
conversation_history=list(result.conversation_history), conversation_history=list(result.conversation_history),
provider="ollama", provider="ollama",
@@ -176,11 +280,8 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
assert len(result.proposed_instances) > 0 assert len(result.proposed_instances) > 0
assert result.proposed_instances[0]["module"]["name"] == "zerotier" assert result.proposed_instances[0]["module"]["name"] == "zerotier"
print( print_chat_exchange(user_msg_2, result.assistant_message, result.session_state)
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..." print_meta_info(result, turn=2, phase="Service Selection")
)
print(f"State: {list(result.session_state.keys())}")
print(f"Requires response: {result.requires_user_response}")
# ========== Continue conversation until we reach final decision or completion ========== # ========== Continue conversation until we reach final decision or completion ==========
max_turns = 10 max_turns = 10
@@ -188,22 +289,24 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
while result.requires_user_response and turn_count < max_turns: while result.requires_user_response and turn_count < max_turns:
turn_count += 1 turn_count += 1
print(f"\n=== TURN {turn_count}: Continuing conversation ===")
# Determine appropriate response based on current state # Determine appropriate response based on current state
if "pending_service_selection" in result.session_state: if "pending_service_selection" in result.session_state:
# Still selecting service # Still selecting service
user_request = "Yes, ZeroTier" user_request = "Yes, ZeroTier"
phase = "Service Selection (continued)"
elif "pending_final_decision" in result.session_state: elif "pending_final_decision" in result.session_state:
# Configuring the service # Configuring the service
user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer" user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer"
phase = "Final Configuration"
else: else:
# Generic continuation # Generic continuation
user_request = "Yes, that sounds good. Use gchq-local as controller." user_request = "Yes, that sounds good. Use gchq-local as controller."
phase = "Continuing Conversation"
print(f"User: {user_request}") print_separator(f"TURN {turn_count}: {phase}", char="=", width=80)
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request=user_request, user_request=user_request,
flake=flake, flake=flake,
conversation_history=list(result.conversation_history), conversation_history=list(result.conversation_history),
@@ -221,19 +324,18 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
result.conversation_history[0]["content"] == "What VPN options do I have?" result.conversation_history[0]["content"] == "What VPN options do I have?"
) )
print( print_chat_exchange(
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..." user_request, result.assistant_message, result.session_state
) )
print(f"State: {list(result.session_state.keys())}") print_meta_info(result, turn=turn_count, phase=phase)
print(f"Requires response: {result.requires_user_response}")
print(f"Proposed instances: {len(result.proposed_instances)}")
# Check for completion # Check for completion
if not result.requires_user_response: if not result.requires_user_response:
print("\n=== Conversation completed! ===") print_separator("CONVERSATION COMPLETED", char="=", width=80)
break break
# ========== Final Verification ========== # ========== Final Verification ==========
print_separator("FINAL VERIFICATION", char="=", width=80)
assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})" assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})"
# If conversation completed, verify we have valid configuration # If conversation completed, verify we have valid configuration
@@ -253,22 +355,29 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
"mycelium", "mycelium",
] ]
# Should have roles configuration
if "roles" in instance:
print(f"\nConfiguration roles: {list(instance['roles'].keys())}")
# Should not be in pending state anymore # Should not be in pending state anymore
assert "pending_service_selection" not in result.session_state assert "pending_service_selection" not in result.session_state
assert "pending_final_decision" not in result.session_state assert "pending_final_decision" not in result.session_state
assert result.error is None, f"Should not have error: {result.error}" assert result.error is None, f"Should not have error: {result.error}"
print(f"\nFinal instance: {instance['module']['name']}") print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(f"Total conversation turns: {turn_count}") print(" Status: SUCCESS")
print(f"Final history length: {len(result.conversation_history)}") print(f" Module Name: {instance['module']['name']}")
print(f" Total Turns: {turn_count}")
print(f" Final History Length: {len(result.conversation_history)} messages")
if "roles" in instance:
roles_list = ", ".join(instance["roles"].keys())
print(f" Configuration Roles: {roles_list}")
print(" Errors: None")
print("-" * 80)
else: else:
# Conversation didn't complete but should have made progress # Conversation didn't complete but should have made progress
assert len(result.conversation_history) > 2 assert len(result.conversation_history) > 2
assert result.error is None assert result.error is None
print(f"\nConversation in progress after {turn_count} turns") print_separator("FINAL SUMMARY", char="-", width=80, double=False)
print(f"Current state: {list(result.session_state.keys())}") print(" Status: IN PROGRESS")
print(f" Total Turns: {turn_count}")
print(f" Current State: {list(result.session_state.keys())}")
print(f" History Length: {len(result.conversation_history)} messages")
print("-" * 80)

View File

@@ -1,65 +0,0 @@
"""High-level LLM orchestration functions.
This module re-exports the LLM orchestration API from submodules.
"""
# Re-export types and dataclasses
from .llm_types import ( # noqa: F401
DEFAULT_MODELS,
ChatResult,
ConversationProgressEvent,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ModelConfig,
ProgressCallback,
ProgressEvent,
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
ServiceSelectionResult,
get_model_config,
)
# Re-export high-level orchestrator
from .orchestrator import process_chat_turn # noqa: F401
# Re-export low-level phase functions
from .phases import ( # noqa: F401
execute_readme_requests,
get_llm_discovery_phase,
get_llm_final_decision,
get_llm_service_selection,
llm_final_decision_to_inventory_instances,
)
# Re-export commonly used functions and types from schemas
from .schemas import ( # noqa: F401
AiAggregate,
ChatMessage,
ConversationHistory,
FunctionCallType,
JSONValue,
MachineDescription,
OllamaFunctionSchema,
OpenAIFunctionSchema,
PendingFinalDecisionState,
PendingServiceSelectionState,
ReadmeRequest,
SessionState,
SimplifiedServiceSchema,
TagDescription,
aggregate_ollama_function_schemas,
aggregate_openai_function_schemas,
create_get_readme_tool,
create_select_service_tool,
create_simplified_service_schemas,
)
# Re-export service functions
from .service import create_llm_model, run_llm_service # noqa: F401
# Re-export utility functions and constants
from .utils import ( # noqa: F401
ASSISTANT_MODE_DISCOVERY,
ASSISTANT_MODE_FINAL,
ASSISTANT_MODE_SELECTION,
)

View File

@@ -3,12 +3,13 @@ from collections.abc import Callable
import pytest import pytest
from clan_cli.tests.fixtures_flakes import nested_dict from clan_cli.tests.fixtures_flakes import nested_dict
from clan_lib.flake.flake import Flake from clan_lib.flake.flake import Flake
from clan_lib.llm.llm import ( from clan_lib.llm.phases import llm_final_decision_to_inventory_instances
from clan_lib.llm.schemas import (
FunctionCallType,
OpenAIFunctionSchema, OpenAIFunctionSchema,
aggregate_openai_function_schemas, aggregate_openai_function_schemas,
llm_final_decision_to_inventory_instances, clan_module_to_openai_spec,
) )
from clan_lib.llm.schemas import FunctionCallType, clan_module_to_openai_spec
from clan_lib.services.modules import list_service_modules from clan_lib.services.modules import list_service_modules

View File

@@ -1,57 +1,28 @@
"""Type definitions and dataclasses for LLM orchestration.""" """Type definitions and dataclasses for LLM orchestration."""
from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Literal from typing import Any, Literal, TypedDict
from clan_lib.nix_models.clan import InventoryInstance from clan_lib.nix_models.clan import InventoryInstance
from .schemas import ChatMessage, SessionState from .schemas import ChatMessage, SessionState
@dataclass(frozen=True) class NextAction(TypedDict):
class DiscoveryProgressEvent: """Describes the next expensive operation that will be performed.
"""Progress event during discovery phase."""
service_names: list[str] | None = None Attributes:
stage: Literal["discovery"] = "discovery" type: The type of operation (discovery, fetch_readmes, service_selection, final_decision)
status: Literal["analyzing", "complete"] = "analyzing" description: Human-readable description of what will happen
estimated_duration_seconds: Rough estimate of operation duration
details: Phase-specific information (e.g., service names, count)
"""
@dataclass(frozen=True) type: Literal["discovery", "fetch_readmes", "service_selection", "final_decision"]
class ReadmeFetchProgressEvent: description: str
"""Progress event during readme fetching.""" estimated_duration_seconds: int
details: dict[str, Any]
count: int
service_names: list[str]
stage: Literal["readme_fetch"] = "readme_fetch"
status: Literal["fetching", "complete"] = "fetching"
@dataclass(frozen=True)
class ServiceSelectionProgressEvent:
"""Progress event during service selection phase."""
service_names: list[str]
stage: Literal["service_selection"] = "service_selection"
status: Literal["selecting", "complete"] = "selecting"
@dataclass(frozen=True)
class FinalDecisionProgressEvent:
"""Progress event during final decision phase."""
stage: Literal["final_decision"] = "final_decision"
status: Literal["reviewing", "complete"] = "reviewing"
@dataclass(frozen=True)
class ConversationProgressEvent:
"""Progress event for conversation continuation."""
message: str
stage: Literal["conversation"] = "conversation"
awaiting_response: bool = True
@dataclass(frozen=True) @dataclass(frozen=True)
@@ -70,17 +41,6 @@ class ServiceSelectionResult:
clarifying_message: str clarifying_message: str
ProgressEvent = (
DiscoveryProgressEvent
| ReadmeFetchProgressEvent
| ServiceSelectionProgressEvent
| FinalDecisionProgressEvent
| ConversationProgressEvent
)
ProgressCallback = Callable[[ProgressEvent], None]
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatResult: class ChatResult:
"""Result of a complete chat turn through the multi-stage workflow. """Result of a complete chat turn through the multi-stage workflow.
@@ -92,6 +52,7 @@ class ChatResult:
requires_user_response: True if the assistant asked a question and needs a response requires_user_response: True if the assistant asked a question and needs a response
error: Error message if something went wrong (None on success) error: Error message if something went wrong (None on success)
session_state: Serializable state to pass into the next turn when continuing a workflow session_state: Serializable state to pass into the next turn when continuing a workflow
next_action: Description of the next operation to be performed (None if workflow complete)
""" """
@@ -100,6 +61,7 @@ class ChatResult:
assistant_message: str assistant_message: str
requires_user_response: bool requires_user_response: bool
session_state: SessionState session_state: SessionState
next_action: NextAction | None
error: str | None = None error: str | None = None
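
With the frozen progress-event dataclasses and `ProgressCallback` removed, a UI that used to react to `DiscoveryProgressEvent` and friends now inspects the `next_action` dict on the returned `ChatResult` before deciding to proceed. A small illustrative helper (the function name and labels below are not part of the commit):

from clan_lib.llm.llm_types import NextAction

def describe_next_action(action: NextAction) -> str:
    """Format a NextAction for display; this helper is illustrative only."""
    label = {
        "discovery": "Discovering services",
        "fetch_readmes": "Fetching documentation",
        "service_selection": "Selecting a service",
        "final_decision": "Generating configuration",
    }[action["type"]]
    return f"{label}: {action['description']} (~{action['estimated_duration_seconds']}s)"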

View File

@@ -4,18 +4,12 @@ import json
from pathlib import Path from pathlib import Path
from typing import Literal, cast from typing import Literal, cast
from clan_lib.api import API
from clan_lib.errors import ClanAiError from clan_lib.errors import ClanAiError
from clan_lib.flake.flake import Flake from clan_lib.flake.flake import Flake
from clan_lib.services.modules import InputName, ServiceReadmeCollection from clan_lib.services.modules import InputName, ServiceReadmeCollection
from .llm_types import ( from .llm_types import ChatResult, NextAction
ChatResult,
DiscoveryProgressEvent,
FinalDecisionProgressEvent,
ProgressCallback,
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
)
from .phases import ( from .phases import (
execute_readme_requests, execute_readme_requests,
get_llm_discovery_phase, get_llm_discovery_phase,
@@ -26,8 +20,11 @@ from .phases import (
from .schemas import ( from .schemas import (
ConversationHistory, ConversationHistory,
JSONValue, JSONValue,
PendingDiscoveryState,
PendingFinalDecisionState, PendingFinalDecisionState,
PendingReadmeFetchState,
PendingServiceSelectionState, PendingServiceSelectionState,
ReadmeRequest,
SessionState, SessionState,
) )
from .utils import ( from .utils import (
@@ -41,45 +38,50 @@ from .utils import (
) )
def process_chat_turn( @API.register
def get_llm_turn(
user_request: str, user_request: str,
flake: Flake, flake: Flake,
conversation_history: ConversationHistory | None = None, conversation_history: ConversationHistory | None = None,
provider: Literal["openai", "ollama", "claude"] = "ollama", provider: Literal["openai", "ollama", "claude"] = "ollama",
progress_callback: ProgressCallback | None = None,
trace_file: Path | None = None, trace_file: Path | None = None,
session_state: SessionState | None = None, session_state: SessionState | None = None,
execute_next_action: bool = False,
) -> ChatResult: ) -> ChatResult:
"""High-level API that orchestrates the entire multi-stage chat workflow. """High-level API that orchestrates the entire multi-stage chat workflow.
This function handles the complete flow: This function handles the complete flow using a multi-turn approach:
1. Discovery phase - LLM selects relevant services 1. Discovery phase - LLM selects relevant services
2. Readme fetching - Retrieves detailed documentation 2. Readme fetching - Retrieves detailed documentation
3. Final decision - LLM makes informed suggestions 3. Final decision - LLM makes informed suggestions
4. Conversion - Transforms suggestions to inventory instances 4. Conversion - Transforms suggestions to inventory instances
Before each expensive operation, the function returns with next_action describing
what will happen. The caller must call again with execute_next_action=True.
Args: Args:
user_request: The user's message/request user_request: The user's message/request
flake: The Flake object to get services from flake: The Flake object to get services from
conversation_history: Optional list of prior messages in the conversation conversation_history: Optional list of prior messages in the conversation
provider: The LLM provider to use provider: The LLM provider to use
progress_callback: Optional callback for progress updates
trace_file: Optional path to write LLM interaction traces for debugging trace_file: Optional path to write LLM interaction traces for debugging
session_state: Optional cross-turn state to resume pending workflows session_state: Optional cross-turn state to resume pending workflows
execute_next_action: If True, execute the pending operation in session_state
Returns: Returns:
ChatResult containing proposed instances, updated history, and assistant message ChatResult containing proposed instances, updated history, next_action, and assistant message
Example: Example:
>>> result = process_chat_turn( >>> result = process_chat_turn("Set up a web server", flake)
... "Set up a web server", >>> while result.next_action:
... flake, ... # Show user what will happen
... progress_callback=lambda event: print(f"Stage: {event.stage}") ... print(result.next_action["description"])
... ) ... result = process_chat_turn(
>>> if result.proposed_instances: ... user_request="",
... print("LLM suggested:", result.proposed_instances) ... flake=flake,
>>> if result.requires_user_response: ... session_state=result.session_state,
... print("Assistant asks:", result.assistant_message) ... execute_next_action=True
... )
""" """
history = list(conversation_history) if conversation_history else [] history = list(conversation_history) if conversation_history else []
@@ -87,6 +89,11 @@ def process_chat_turn(
"SessionState", dict(session_state) if session_state else {} "SessionState", dict(session_state) if session_state else {}
) )
# Add non-empty user message to history for conversation tracking
# Phases will also add it to their messages array for LLM calls
if user_request:
history.append(_user_message(user_request))
def _state_snapshot() -> dict[str, JSONValue]: def _state_snapshot() -> dict[str, JSONValue]:
try: try:
return json.loads(json.dumps(state)) return json.loads(json.dumps(state))
@@ -102,6 +109,17 @@ def process_chat_turn(
def _state_copy() -> SessionState: def _state_copy() -> SessionState:
return cast("SessionState", dict(state)) return cast("SessionState", dict(state))
# Check pending states in workflow order (earliest to latest)
pending_discovery_raw = state.get("pending_discovery")
pending_discovery: PendingDiscoveryState | None = (
pending_discovery_raw if isinstance(pending_discovery_raw, dict) else None
)
pending_readme_fetch_raw = state.get("pending_readme_fetch")
pending_readme_fetch: PendingReadmeFetchState | None = (
pending_readme_fetch_raw if isinstance(pending_readme_fetch_raw, dict) else None
)
pending_final_raw = state.get("pending_final_decision") pending_final_raw = state.get("pending_final_decision")
pending_final: PendingFinalDecisionState | None = ( pending_final: PendingFinalDecisionState | None = (
pending_final_raw if isinstance(pending_final_raw, dict) else None pending_final_raw if isinstance(pending_final_raw, dict) else None
@@ -117,87 +135,139 @@ def process_chat_turn(
if serialized_results is not None: if serialized_results is not None:
resume_readme_results = _deserialize_readme_results(serialized_results) resume_readme_results = _deserialize_readme_results(serialized_results)
# Only pop if we can't deserialize (invalid state)
if resume_readme_results is None: if resume_readme_results is None:
state.pop("pending_service_selection", None) state.pop("pending_service_selection", None)
else:
state.pop("pending_service_selection", None) # Handle pending_discovery state: execute discovery if execute_next_action=True
if pending_discovery is not None and execute_next_action:
state.pop("pending_discovery", None)
# Continue to execute discovery below (after pending state checks)
# Handle pending_readme_fetch state: execute readme fetch if execute_next_action=True
if pending_readme_fetch is not None and execute_next_action:
readme_requests_raw = pending_readme_fetch.get("readme_requests", [])
readme_requests = cast("list[ReadmeRequest]", readme_requests_raw)
if readme_requests:
state.pop("pending_readme_fetch", None)
readme_results = execute_readme_requests(readme_requests, flake)
# Save readme results and return next_action for service selection
state["pending_service_selection"] = cast(
"PendingServiceSelectionState",
{"readme_results": _serialize_readme_results(readme_results)},
)
service_count = len(readme_results)
next_action_selection: NextAction = {
"type": "service_selection",
"description": f"Analyzing {service_count} service(s) to find the best match",
"estimated_duration_seconds": 15,
"details": {"service_count": service_count},
}
return ChatResult(
next_action=next_action_selection,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
state.pop("pending_readme_fetch", None)
if pending_final is not None: if pending_final is not None:
service_name = pending_final.get("service_name") service_name = pending_final.get("service_name")
service_summary = pending_final.get("service_summary") service_summary = pending_final.get("service_summary")
if isinstance(service_name, str) and isinstance(service_summary, str): if isinstance(service_name, str) and isinstance(service_summary, str):
if progress_callback: if execute_next_action:
progress_callback(FinalDecisionProgressEvent(status="reviewing"))
function_calls, final_message = get_llm_final_decision(
user_request,
flake,
service_name,
service_summary,
conversation_history,
provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{
"selected_service": service_name,
"resume": True,
}
),
)
if progress_callback:
progress_callback(FinalDecisionProgressEvent(status="complete"))
history.append(_user_message(user_request))
if function_calls:
proposed_instances = llm_final_decision_to_inventory_instances(
function_calls
)
instance_names = [inst["module"]["name"] for inst in proposed_instances]
summary = (
f"I suggest configuring these services: {', '.join(instance_names)}"
)
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
state.pop("pending_final_decision", None) state.pop("pending_final_decision", None)
return ChatResult( function_calls, final_message = get_llm_final_decision(
proposed_instances=tuple(proposed_instances), user_request,
conversation_history=tuple(history), flake,
assistant_message=summary, service_name,
requires_user_response=False, service_summary,
error=None, conversation_history,
session_state=_state_copy(), provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{
"selected_service": service_name,
"resume": True,
}
),
) )
if final_message: if function_calls:
history.append( proposed_instances = llm_final_decision_to_inventory_instances(
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL) function_calls
) )
state["pending_final_decision"] = cast( instance_names = [
"PendingFinalDecisionState", inst["module"]["name"] for inst in proposed_instances
{ ]
"service_name": service_name, summary = f"I suggest configuring these services: {', '.join(instance_names)}"
"service_summary": service_summary, history.append(
}, _assistant_message(summary, mode=ASSISTANT_MODE_FINAL)
)
return ChatResult(
next_action=None,
proposed_instances=tuple(proposed_instances),
conversation_history=tuple(history),
assistant_message=summary,
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
if final_message:
history.append(
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)
)
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": service_name,
"service_summary": service_summary,
},
)
return ChatResult(
next_action=None,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=final_message,
requires_user_response=True,
error=None,
session_state=_state_copy(),
)
state.pop("pending_final_decision", None)
msg = "LLM did not provide any response or recommendations"
raise ClanAiError(
msg,
description="Expected either function calls (configuration) or a clarifying message",
location="Final Decision Phase (pending)",
) )
return ChatResult( # If not executing, return next_action for final decision
proposed_instances=(), next_action_final_pending: NextAction = {
conversation_history=tuple(history), "type": "final_decision",
assistant_message=final_message, "description": f"Generating configuration for {service_name}",
requires_user_response=True, "estimated_duration_seconds": 20,
error=None, "details": {"service_name": service_name},
session_state=_state_copy(), }
) return ChatResult(
next_action=next_action_final_pending,
state.pop("pending_final_decision", None) proposed_instances=(),
msg = "LLM did not provide any response or recommendations" conversation_history=tuple(history),
raise ClanAiError( assistant_message="",
msg, requires_user_response=False,
description="Expected either function calls (configuration) or a clarifying message", error=None,
location="Final Decision Phase (pending)", session_state=_state_copy(),
) )
state.pop("pending_final_decision", None) state.pop("pending_final_decision", None)
@@ -206,19 +276,12 @@ def process_chat_turn(
readme_results: dict[InputName, ServiceReadmeCollection], readme_results: dict[InputName, ServiceReadmeCollection],
) -> ChatResult: ) -> ChatResult:
# Extract all service names from readme results # Extract all service names from readme results
all_service_names = [ [
service_name service_name
for collection in readme_results.values() for collection in readme_results.values()
for service_name in collection.readmes for service_name in collection.readmes
] ]
if progress_callback:
progress_callback(
ServiceSelectionProgressEvent(
service_names=all_service_names, status="selecting"
)
)
selection_result = get_llm_service_selection( selection_result = get_llm_service_selection(
user_request, user_request,
readme_results, readme_results,
@@ -232,7 +295,6 @@ def process_chat_turn(
selection_result.clarifying_message selection_result.clarifying_message
and not selection_result.selected_service and not selection_result.selected_service
): ):
history.append(_user_message(user_request))
history.append( history.append(
_assistant_message( _assistant_message(
selection_result.clarifying_message, selection_result.clarifying_message,
@@ -247,6 +309,7 @@ def process_chat_turn(
) )
return ChatResult( return ChatResult(
next_action=None,
proposed_instances=(), proposed_instances=(),
conversation_history=tuple(history), conversation_history=tuple(history),
assistant_message=selection_result.clarifying_message, assistant_message=selection_result.clarifying_message,
@@ -267,81 +330,79 @@ def process_chat_turn(
location="Service Selection Phase", location="Service Selection Phase",
) )
if progress_callback: # After service selection, always return next_action for final decision
progress_callback(FinalDecisionProgressEvent(status="reviewing")) state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
function_calls, final_message = get_llm_final_decision( {
user_request, "service_name": selection_result.selected_service,
flake, "service_summary": selection_result.service_summary,
selection_result.selected_service, },
selection_result.service_summary,
conversation_history,
provider=provider,
trace_file=trace_file,
trace_metadata=_metadata(
{"selected_service": selection_result.selected_service}
),
) )
next_action_final: NextAction = {
if progress_callback: "type": "final_decision",
progress_callback(FinalDecisionProgressEvent(status="complete")) "description": f"Generating configuration for {selection_result.selected_service}",
"estimated_duration_seconds": 20,
if function_calls: "details": {"service_name": selection_result.selected_service},
history.append(_user_message(user_request)) }
return ChatResult(
proposed_instances = llm_final_decision_to_inventory_instances( next_action=next_action_final,
function_calls proposed_instances=(),
) conversation_history=tuple(history),
assistant_message="",
instance_names = [inst["module"]["name"] for inst in proposed_instances] requires_user_response=False,
summary = ( error=None,
f"I suggest configuring these services: {', '.join(instance_names)}" session_state=_state_copy(),
)
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
state.pop("pending_final_decision", None)
return ChatResult(
proposed_instances=tuple(proposed_instances),
conversation_history=tuple(history),
assistant_message=summary,
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
if final_message:
history.append(_user_message(user_request))
history.append(_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL))
state["pending_final_decision"] = cast(
"PendingFinalDecisionState",
{
"service_name": selection_result.selected_service,
"service_summary": selection_result.service_summary,
},
)
return ChatResult(
proposed_instances=(),
conversation_history=tuple(history),
assistant_message=final_message,
requires_user_response=True,
error=None,
session_state=_state_copy(),
)
msg = "LLM did not provide any response or recommendations"
raise ClanAiError(
msg,
description="Expected either function calls (configuration) or a clarifying message after service selection",
location="Final Decision Phase",
) )
if resume_readme_results is not None: if resume_readme_results is not None:
return _continue_with_service_selection(resume_readme_results) if execute_next_action:
# Pop the pending state now that we're executing
state.pop("pending_service_selection", None)
return _continue_with_service_selection(resume_readme_results)
# If not executing, return next_action for service selection
service_count = len(resume_readme_results)
next_action_sel_resume: NextAction = {
"type": "service_selection",
"description": f"Analyzing {service_count} service(s) to find the best match",
"estimated_duration_seconds": 15,
"details": {"service_count": service_count},
}
return ChatResult(
next_action=next_action_sel_resume,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
# Stage 1: Discovery phase # Stage 1: Discovery phase
if progress_callback: # If we're not executing and have no pending states, return next_action for discovery
progress_callback(DiscoveryProgressEvent(status="analyzing")) has_pending_states = (
pending_discovery is not None
or pending_readme_fetch is not None
or pending_selection is not None
or pending_final is not None
)
if not execute_next_action and not has_pending_states:
state["pending_discovery"] = cast("PendingDiscoveryState", {})
next_action: NextAction = {
"type": "discovery",
"description": "Analyzing your request and discovering relevant services",
"estimated_duration_seconds": 10,
"details": {"phase": "discovery"},
}
return ChatResult(
next_action=next_action,
proposed_instances=(),
conversation_history=tuple(history),
assistant_message="",
requires_user_response=False,
error=None,
session_state=_state_copy(),
)
readme_requests, discovery_message = get_llm_discovery_phase( readme_requests, discovery_message = get_llm_discovery_phase(
user_request, user_request,
@@ -352,23 +413,14 @@ def process_chat_turn(
trace_metadata=_metadata(), trace_metadata=_metadata(),
) )
if progress_callback:
selected_services = [req["function_name"] for req in readme_requests]
progress_callback(
DiscoveryProgressEvent(
service_names=selected_services if selected_services else None,
status="complete",
)
)
# If LLM asked a question or made a recommendation without readme requests # If LLM asked a question or made a recommendation without readme requests
if discovery_message and not readme_requests: if discovery_message and not readme_requests:
history.append(_user_message(user_request))
history.append( history.append(
_assistant_message(discovery_message, mode=ASSISTANT_MODE_DISCOVERY) _assistant_message(discovery_message, mode=ASSISTANT_MODE_DISCOVERY)
) )
return ChatResult( return ChatResult(
next_action=None,
proposed_instances=(), proposed_instances=(),
conversation_history=tuple(history), conversation_history=tuple(history),
assistant_message=discovery_message, assistant_message=discovery_message,
@@ -377,34 +429,28 @@ def process_chat_turn(
session_state=_state_copy(), session_state=_state_copy(),
) )
# If we got readme requests, continue to selecting services # If we got readme requests, save them and return next_action for readme fetch
if readme_requests: if readme_requests:
# Stage 2: Fetch readmes state["pending_readme_fetch"] = cast(
service_names = [ "PendingReadmeFetchState",
f"{req['function_name']} (from {req['input_name'] or 'built-in'})" {"readme_requests": cast("list[dict[str, JSONValue]]", readme_requests)},
for req in readme_requests )
] service_count = len(readme_requests)
if progress_callback: next_action_fetch: NextAction = {
progress_callback( "type": "fetch_readmes",
ReadmeFetchProgressEvent( "description": f"Fetching documentation for {service_count} service(s)",
count=len(readme_requests), "estimated_duration_seconds": 5,
service_names=service_names, "details": {"service_count": service_count},
status="fetching", }
) return ChatResult(
) next_action=next_action_fetch,
proposed_instances=(),
readme_results = execute_readme_requests(readme_requests, flake) conversation_history=tuple(history),
assistant_message="",
if progress_callback: requires_user_response=False,
progress_callback( error=None,
ReadmeFetchProgressEvent( session_state=_state_copy(),
count=len(readme_requests), )
service_names=service_names,
status="complete",
)
)
return _continue_with_service_selection(readme_results)
# No readme requests and no message - unexpected # No readme requests and no message - unexpected
msg = "LLM did not provide any response or recommendations" msg = "LLM did not provide any response or recommendations"

View File

@@ -85,7 +85,8 @@ def get_llm_discovery_phase(
{"role": "assistant", "content": assistant_context}, {"role": "assistant", "content": assistant_context},
] ]
messages.extend(_strip_conversation_metadata(conversation_history)) messages.extend(_strip_conversation_metadata(conversation_history))
messages.append(_user_message(user_request)) if user_request:
messages.append(_user_message(user_request))
# Call LLM with only get_readme tool # Call LLM with only get_readme tool
model_config = get_model_config(provider) model_config = get_model_config(provider)
@@ -233,7 +234,8 @@ def get_llm_service_selection(
{"role": "assistant", "content": combined_assistant_context}, {"role": "assistant", "content": combined_assistant_context},
] ]
messages.extend(_strip_conversation_metadata(conversation_history)) messages.extend(_strip_conversation_metadata(conversation_history))
messages.append(_user_message(user_request)) if user_request:
messages.append(_user_message(user_request))
model_config = get_model_config(provider) model_config = get_model_config(provider)
@@ -430,8 +432,8 @@ def get_llm_final_decision(
{"role": "assistant", "content": combined_assistant_context}, {"role": "assistant", "content": combined_assistant_context},
] ]
messages.extend(_strip_conversation_metadata(conversation_history)) messages.extend(_strip_conversation_metadata(conversation_history))
if user_request:
messages.append(_user_message(user_request)) messages.append(_user_message(user_request))
# Get full schemas # Get full schemas
model_config = get_model_config(provider) model_config = get_model_config(provider)
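
These `if user_request:` guards exist because auto-execute turns (the `execute_next_action=True` calls) pass an empty `user_request`, which must not be appended to the prompt as an empty user message. A tiny illustration of the effect (the role/content message shape is assumed here):

# Auto-execute turns supply an empty user_request, so nothing is appended.
messages: list[dict[str, str]] = []
user_request = ""  # value passed by the auto-execute loop in the tests
if user_request:
    messages.append({"role": "user", "content": user_request})
assert messages == []  # no empty user message pollutes the prompt history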

View File

@@ -66,18 +66,28 @@ class ChatMessage(TypedDict):
ConversationHistory = list[ChatMessage] ConversationHistory = list[ChatMessage]
class PendingFinalDecisionState(TypedDict, total=False): class PendingDiscoveryState(TypedDict, total=False):
service_name: NotRequired[str] user_request: NotRequired[str]
service_summary: NotRequired[str]
class PendingReadmeFetchState(TypedDict, total=False):
readme_requests: NotRequired[list[dict[str, JSONValue]]]
class PendingServiceSelectionState(TypedDict, total=False): class PendingServiceSelectionState(TypedDict, total=False):
readme_results: NotRequired[list[dict[str, JSONValue]]] readme_results: NotRequired[list[dict[str, JSONValue]]]
class PendingFinalDecisionState(TypedDict, total=False):
service_name: NotRequired[str]
service_summary: NotRequired[str]
class SessionState(TypedDict, total=False): class SessionState(TypedDict, total=False):
pending_final_decision: NotRequired[PendingFinalDecisionState] pending_discovery: NotRequired[PendingDiscoveryState]
pending_readme_fetch: NotRequired[PendingReadmeFetchState]
pending_service_selection: NotRequired[PendingServiceSelectionState] pending_service_selection: NotRequired[PendingServiceSelectionState]
pending_final_decision: NotRequired[PendingFinalDecisionState]
class JSONSchemaProperty(TypedDict, total=False): class JSONSchemaProperty(TypedDict, total=False):
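
The `pending_*` keys in `SessionState` effectively encode the workflow's current stage, which is what the test helper `get_current_mode` relies on. A sketch of that mapping extended to the two new keys (the stage names here are informal, not defined by the commit):

from clan_lib.llm.schemas import SessionState

def current_stage(state: SessionState) -> str:
    """Informal stage name derived from the pending_* keys (mirrors the test helper)."""
    if "pending_final_decision" in state:
        return "FINAL_DECISION"
    if "pending_service_selection" in state:
        return "SERVICE_SELECTION"
    if "pending_readme_fetch" in state:
        return "README_FETCH"
    if "pending_discovery" in state:
        return "DISCOVERY"
    return "IDLE"  # no pending work; the next call starts a fresh discovery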

View File

@@ -16,16 +16,12 @@ from clan_lib.llm.endpoints import (
parse_ollama_response, parse_ollama_response,
parse_openai_response, parse_openai_response,
) )
from clan_lib.llm.llm import ( from clan_lib.llm.llm_types import ServiceSelectionResult
DiscoveryProgressEvent, from clan_lib.llm.orchestrator import get_llm_turn
FinalDecisionProgressEvent, from clan_lib.llm.phases import (
ReadmeFetchProgressEvent,
ServiceSelectionProgressEvent,
ServiceSelectionResult,
execute_readme_requests, execute_readme_requests,
get_llm_final_decision, get_llm_final_decision,
get_llm_service_selection, get_llm_service_selection,
process_chat_turn,
) )
from clan_lib.llm.schemas import ( from clan_lib.llm.schemas import (
AiAggregate, AiAggregate,
@@ -37,9 +33,44 @@ from clan_lib.llm.schemas import (
from clan_lib.services.modules import ServiceReadmeCollection from clan_lib.services.modules import ServiceReadmeCollection
if TYPE_CHECKING: if TYPE_CHECKING:
from clan_lib.llm.llm_types import ChatResult
from clan_lib.llm.schemas import ChatMessage from clan_lib.llm.schemas import ChatMessage
def execute_multi_turn_workflow(
user_request: str,
flake: Flake,
conversation_history: list["ChatMessage"] | None = None,
provider: str = "claude",
session_state: SessionState | None = None,
) -> "ChatResult":
"""Execute the multi-turn workflow, auto-executing all pending operations.
This simulates the behavior of the CLI auto-execute loop in workflow.py.
"""
result = get_llm_turn(
user_request=user_request,
flake=flake,
conversation_history=conversation_history,
provider=provider, # type: ignore[arg-type]
session_state=session_state,
execute_next_action=False,
)
# Auto-execute any pending operations
while result.next_action:
result = get_llm_turn(
user_request="",
flake=flake,
conversation_history=list(result.conversation_history),
provider=provider, # type: ignore[arg-type]
session_state=result.session_state,
execute_next_action=True,
)
return result
@pytest.fixture @pytest.fixture
def trace_data() -> list[dict[str, Any]]: def trace_data() -> list[dict[str, Any]]:
"""Load trace data from mytrace.json.""" """Load trace data from mytrace.json."""
@@ -181,17 +212,23 @@ class TestProcessChatTurn:
# Mock final decision (shouldn't be called, but mock it anyway for safety) # Mock final decision (shouldn't be called, but mock it anyway for safety)
mock_final.return_value = ([], "") mock_final.return_value = ([], "")
# Run process_chat_turn # Run multi-turn workflow
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="What VPNs are available?", user_request="What VPNs are available?",
flake=mock_flake, flake=mock_flake,
conversation_history=None, conversation_history=None,
provider="claude", provider="claude",
) )
# Verify the call was made # Verify the discovery call was made
assert mock_call.called assert mock_call.called
# Verify readme execution was called
assert mock_execute.called
# Verify service selection was called
assert mock_selection.called
# Final decision should NOT be called since we return early with clarifying message # Final decision should NOT be called since we return early with clarifying message
assert not mock_final.called assert not mock_final.called
@@ -262,8 +299,8 @@ class TestProcessChatTurn:
final_trace["response"]["message"], final_trace["response"]["message"],
) )
# Run process_chat_turn with session state # Run multi-turn workflow with session state
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="Hmm zerotier please", user_request="Hmm zerotier please",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
@@ -271,6 +308,12 @@ class TestProcessChatTurn:
session_state=session_state, session_state=session_state,
) )
# Verify service selection was called
assert mock_selection.called
# Verify final decision was called
assert mock_final.called
# Verify the result # Verify the result
assert result.requires_user_response is True assert result.requires_user_response is True
assert "controller" in result.assistant_message.lower() assert "controller" in result.assistant_message.lower()
@@ -335,8 +378,8 @@ class TestProcessChatTurn:
"", "",
) )
# Run process_chat_turn # Run multi-turn workflow
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="okay then gchq-local as controller and qube-email as moon please everything else as peer", user_request="okay then gchq-local as controller and qube-email as moon please everything else as peer",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
@@ -344,6 +387,9 @@ class TestProcessChatTurn:
session_state=session_state, session_state=session_state,
) )
# Verify final decision was called
assert mock_final.called
# Verify the result # Verify the result
assert result.requires_user_response is False assert result.requires_user_response is False
assert len(result.proposed_instances) == 1 assert len(result.proposed_instances) == 1
@@ -394,19 +440,19 @@ class TestProcessChatTurn:
) )
mock_final.return_value = ([], "") mock_final.return_value = ([], "")
result1 = process_chat_turn( result1 = execute_multi_turn_workflow(
user_request="What VPNs are available?", user_request="What VPNs are available?",
flake=mock_flake, flake=mock_flake,
provider="claude", provider="claude",
) )
# Verify final decision was not called # Verify final decision was not called (since we get clarifying message)
assert not mock_final.called assert not mock_final.called
# Verify discovery completed and moved to service selection # Verify discovery completed and service selection asked clarifying question
assert result1.requires_user_response is True assert result1.requires_user_response is True
assert "VPN" in result1.assistant_message assert "VPN" in result1.assistant_message
# Session state should have pending_service_selection # Session state should have pending_service_selection (with readme results)
assert "pending_service_selection" in result1.session_state assert "pending_service_selection" in result1.session_state
# Test Turn 2: Continue with session state # Test Turn 2: Continue with session state
@@ -425,7 +471,7 @@ class TestProcessChatTurn:
) )
mock_final.return_value = ([], trace_data[3]["response"]["message"]) mock_final.return_value = ([], trace_data[3]["response"]["message"])
result2 = process_chat_turn( result2 = execute_multi_turn_workflow(
user_request="Hmm zerotier please", user_request="Hmm zerotier please",
flake=mock_flake, flake=mock_flake,
conversation_history=list(result1.conversation_history), conversation_history=list(result1.conversation_history),
@@ -485,7 +531,7 @@ class TestProcessChatTurn:
# Return empty function_calls but with a clarifying message # Return empty function_calls but with a clarifying message
mock_final.return_value = ([], clarify_trace["response"]["message"]) mock_final.return_value = ([], clarify_trace["response"]["message"])
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="Set up zerotier with gchq-local as controller", user_request="Set up zerotier with gchq-local as controller",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
@@ -535,7 +581,7 @@ class TestProcessChatTurn:
] ]
mock_final.return_value = ([], "") mock_final.return_value = ([], "")
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="I want to set up a VPN", user_request="I want to set up a VPN",
flake=mock_flake, flake=mock_flake,
provider="claude", provider="claude",
@@ -624,7 +670,7 @@ class TestProcessChatTurn:
] ]
) )
result = process_chat_turn( result = execute_multi_turn_workflow(
user_request="Use zerotier with gchq-local as controller, qube-email as moon, rest as peers", user_request="Use zerotier with gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
@@ -632,6 +678,12 @@ class TestProcessChatTurn:
session_state=session_state, session_state=session_state,
) )
# Verify service selection was called
assert mock_selection.called
# Verify final decision was called
assert mock_final.called
# Verify the function_calls branch in _continue_with_service_selection # Verify the function_calls branch in _continue_with_service_selection
assert result.requires_user_response is False assert result.requires_user_response is False
assert len(result.proposed_instances) == 1 assert len(result.proposed_instances) == 1
@@ -1004,15 +1056,16 @@ class TestProcessChatTurnPendingFinalDecision:
response = create_openai_response([], clarify_trace["response"]["message"]) response = create_openai_response([], clarify_trace["response"]["message"])
mock_call.return_value = response mock_call.return_value = response
result = process_chat_turn( result = get_llm_turn(
user_request="gchq-local as controller", user_request="gchq-local as controller",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
provider="claude", provider="claude",
session_state=session_state, session_state=session_state,
execute_next_action=True, # Execute the pending final decision
) )
# Verify the if final_message branch at line 425 was taken # Verify the if final_message branch was taken
assert result.requires_user_response is True assert result.requires_user_response is True
assert result.assistant_message == clarify_trace["response"]["message"] assert result.assistant_message == clarify_trace["response"]["message"]
@@ -1074,12 +1127,13 @@ class TestProcessChatTurnPendingFinalDecision:
response = create_openai_response(function_calls, "") response = create_openai_response(function_calls, "")
mock_call.return_value = response mock_call.return_value = response
result = process_chat_turn( result = get_llm_turn(
user_request="gchq-local as controller, qube-email as moon, rest as peers", user_request="gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
provider="claude", provider="claude",
session_state=session_state, session_state=session_state,
execute_next_action=True, # Execute the pending final decision
) )
# Verify configuration completed # Verify configuration completed
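For comparison, the single-step entry point only runs the queued step when explicitly asked to. A hedged sketch of the pattern this hunk exercises, with argument names copied from the test; `flake`, `history`, and `state` are placeholders for values produced by the previous turn, and `state` is expected to carry "pending_final_decision":

    # Sketch only: placeholders stand in for the previous turn's outputs.
    result = get_llm_turn(
        user_request="gchq-local as controller, qube-email as moon, rest as peers",
        flake=flake,
        conversation_history=history,
        provider="claude",
        session_state=state,
        execute_next_action=True,  # run the pending final decision now
    )
    if not result.requires_user_response:
        proposed = result.proposed_instances  # configuration suggested by the LLM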
@@ -1094,186 +1148,6 @@ class TestProcessChatTurnPendingFinalDecision:
assert result.error is None assert result.error is None
class TestProgressCallbacks:
"""Test progress_callback functionality in process_chat_turn."""
def test_progress_callback_during_readme_fetch(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test that progress_callback is called during README fetching."""
# Use trace entry with README requests
discovery_trace = trace_data[0]
function_calls = discovery_trace["response"]["function_calls"]
assert len(function_calls) > 0
# Track progress events
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Create response with get_readme calls
response = create_openai_response(function_calls, "")
with (
patch("clan_lib.llm.phases.call_claude_api", return_value=response),
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
patch(
"clan_lib.llm.orchestrator.get_llm_service_selection"
) as mock_selection,
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
):
mock_execute.return_value = {
None: ServiceReadmeCollection(
input_name=None,
readmes={
"wireguard": "# WireGuard README",
"zerotier": "# ZeroTier README",
"mycelium": "# Mycelium README",
"yggdrasil": "# Yggdrasil README",
},
)
}
mock_selection.return_value = ServiceSelectionResult(
selected_service=None,
service_summary=None,
clarifying_message=trace_data[1]["response"]["message"],
)
mock_final.return_value = ([], "")
result = process_chat_turn(
user_request="What VPNs are available?",
flake=mock_flake,
provider="claude",
progress_callback=track_progress,
)
# Verify final decision was not called
assert not mock_final.called
# Verify progress events were sent
assert len(progress_events) > 0
# Check for discovery progress events
discovery_events = [
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
]
assert len(discovery_events) >= 2 # At least start and complete
# Check for readme fetch progress events
fetch_events = [
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
]
assert len(fetch_events) >= 2 # fetching and complete
# Verify the fetching event has correct data
fetching_event = next(e for e in fetch_events if e.status == "fetching")
assert fetching_event.count == len(function_calls)
# Service names include "(from built-in)" or "(from <input>)" suffix
assert any("wireguard" in name for name in fetching_event.service_names)
# Verify the complete event
complete_event = next(e for e in fetch_events if e.status == "complete")
assert complete_event.count == len(function_calls)
# Result should still be successful
assert result.requires_user_response is True
def test_progress_callback_through_full_workflow(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test progress_callback through entire workflow from discovery to config."""
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Setup for full workflow
discovery_response = create_openai_response(
trace_data[0]["response"]["function_calls"],
trace_data[0]["response"]["message"],
)
with (
patch(
"clan_lib.llm.phases.call_claude_api", return_value=discovery_response
),
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
patch(
"clan_lib.llm.orchestrator.get_llm_service_selection"
) as mock_selection,
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
):
mock_execute.return_value = {
None: ServiceReadmeCollection(
input_name=None, readmes={"zerotier": "# ZeroTier README"}
)
}
mock_selection.return_value = ServiceSelectionResult(
selected_service="zerotier",
service_summary="ZeroTier mesh VPN",
clarifying_message="",
)
# Return configuration
final_trace = trace_data[-1]
mock_final.return_value = (
[
FunctionCallType(
id="call_0",
call_id="call_0",
type="function_call",
name="zerotier",
arguments=json.dumps(
final_trace["response"]["function_calls"][0]["arguments"]
),
)
],
"",
)
mock_agg.return_value = MagicMock(
tools=[
{
"type": "function",
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
}
]
)
result = process_chat_turn(
user_request="Setup zerotier with gchq-local as controller",
flake=mock_flake,
provider="claude",
progress_callback=track_progress,
)
# Verify we got progress events for all phases
discovery_events = [
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
]
fetch_events = [
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
]
selection_events = [
e
for e in progress_events
if isinstance(e, ServiceSelectionProgressEvent)
]
final_events = [
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
]
# Should have events from all phases
assert len(discovery_events) > 0
assert len(fetch_events) > 0
assert len(selection_events) > 0
assert len(final_events) > 0
# Result should be successful with config
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
class TestErrorCases: class TestErrorCases:
"""Test error handling in process_chat_turn.""" """Test error handling in process_chat_turn."""
@@ -1288,7 +1162,8 @@ class TestErrorCases:
patch("clan_lib.llm.phases.call_claude_api", return_value=response), patch("clan_lib.llm.phases.call_claude_api", return_value=response),
pytest.raises(ClanAiError, match="did not provide any response"), pytest.raises(ClanAiError, match="did not provide any response"),
): ):
process_chat_turn( # Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup a VPN", user_request="Setup a VPN",
flake=mock_flake, flake=mock_flake,
provider="claude", provider="claude",
@@ -1304,7 +1179,8 @@ class TestErrorCases:
), ),
pytest.raises(ValueError, match="Test error"), pytest.raises(ValueError, match="Test error"),
): ):
process_chat_turn( # Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup a VPN", user_request="Setup a VPN",
flake=mock_flake, flake=mock_flake,
provider="claude", provider="claude",
@@ -1326,86 +1202,14 @@ class TestErrorCases:
), ),
pytest.raises(RuntimeError, match="Network error"), pytest.raises(RuntimeError, match="Network error"),
): ):
process_chat_turn( # Use multi-turn workflow to execute through discovery
execute_multi_turn_workflow(
user_request="Setup zerotier", user_request="Setup zerotier",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,
provider="claude", provider="claude",
) )
def test_progress_callback_final_decision_reviewing_and_complete(
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
) -> None:
"""Test FinalDecisionProgressEvent with reviewing and complete statuses."""
progress_events: list[Any] = []
def track_progress(event: Any) -> None:
progress_events.append(event)
# Build conversation history and session state for pending_final_decision
conversation_history: list[ChatMessage] = [
{"role": "user", "content": "Setup VPN"},
{"role": "assistant", "content": "Which service?"},
{"role": "user", "content": "Use zerotier"},
{"role": "assistant", "content": "Which machine as controller?"},
]
session_state: SessionState = cast(
"SessionState",
{
"pending_final_decision": {
"service_name": "zerotier",
"service_summary": "ZeroTier mesh VPN",
}
},
)
# Use final trace with configuration
final_trace = trace_data[-1]
function_calls = final_trace["response"]["function_calls"]
with (
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
patch("clan_lib.llm.phases.call_claude_api") as mock_call,
):
mock_agg.return_value = MagicMock(
tools=[
{
"type": "function",
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
}
]
)
response = create_openai_response(function_calls, "")
mock_call.return_value = response
result = process_chat_turn(
user_request="gchq-local as controller, qube-email as moon, rest as peers",
flake=mock_flake,
conversation_history=conversation_history,
provider="claude",
session_state=session_state,
progress_callback=track_progress,
)
# Verify we got FinalDecisionProgressEvent with both statuses
final_events = [
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
]
assert len(final_events) >= 2
# Check for "reviewing" status
reviewing_events = [e for e in final_events if e.status == "reviewing"]
assert len(reviewing_events) >= 1
# Check for "complete" status
complete_events = [e for e in final_events if e.status == "complete"]
assert len(complete_events) >= 1
# Result should be successful
assert result.requires_user_response is False
assert len(result.proposed_instances) == 1
def test_service_selection_fails_no_service_selected( def test_service_selection_fails_no_service_selected(
self, mock_flake: MagicMock self, mock_flake: MagicMock
) -> None: ) -> None:
@@ -1443,7 +1247,8 @@ class TestErrorCases:
# Should raise ClanAiError # Should raise ClanAiError
with pytest.raises(ClanAiError, match="Failed to select service"): with pytest.raises(ClanAiError, match="Failed to select service"):
process_chat_turn( # Use multi-turn workflow to execute through service selection
execute_multi_turn_workflow(
user_request="Setup VPN", user_request="Setup VPN",
flake=mock_flake, flake=mock_flake,
provider="claude", provider="claude",
@@ -1680,7 +1485,8 @@ class TestGetLlmFinalDecisionErrors:
# Should raise ClanAiError # Should raise ClanAiError
with pytest.raises(ClanAiError, match="LLM did not provide any response"): with pytest.raises(ClanAiError, match="LLM did not provide any response"):
process_chat_turn( # Use multi-turn workflow to execute through final decision
execute_multi_turn_workflow(
user_request="gchq-local as controller", user_request="gchq-local as controller",
flake=mock_flake, flake=mock_flake,
conversation_history=conversation_history, conversation_history=conversation_history,