Merge pull request 'clan_lib/llm: get_llm_turn uses state transitions instead of callback function' (#5659) from Qubasa/clan-core:llm_no_callback2 into main
Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/5659
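For reviewers skimming the diff below: this PR replaces the progress-callback style of get_llm_turn with explicit state transitions. A minimal sketch of the new calling pattern, assuming an Ollama provider and a placeholder flake path (both illustrative, not taken from this repository):

from clan_lib.flake.flake import Flake
from clan_lib.llm.orchestrator import get_llm_turn

flake = Flake("/path/to/clan")  # placeholder path
result = get_llm_turn(user_request="Set up a VPN", flake=flake, provider="ollama")

# Each expensive step is announced via next_action; the caller re-invokes
# with execute_next_action=True until no pending action remains.
while result.next_action:
    print(result.next_action["description"])
    result = get_llm_turn(
        user_request="",
        flake=flake,
        conversation_history=list(result.conversation_history),
        provider="ollama",
        session_state=result.session_state,
        execute_next_action=True,
    )

if result.requires_user_response:
    print("Assistant asks:", result.assistant_message)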
@@ -76,7 +76,6 @@ in
cmd = "su - text-user -c 'pytest -s -n0 -m service_runner -p no:cacheprovider -o addopts="" ${cli.passthru.sourceWithTests}/clan_lib/llm'"
print("Running tests with command: " + cmd)

# Run tests as text-user (environment variables are set automatically)
peer1.succeed(cmd)
'';

@@ -1,200 +0,0 @@
"""High-level API functions for LLM interactions, suitable for HTTP APIs and web UIs.

This module provides a clean, stateless API for integrating LLM functionality into
web applications and HTTP services. It wraps the complex multi-stage workflow into
simple function calls with serializable inputs and outputs.
"""

from pathlib import Path
from typing import Any, Literal, TypedDict, cast

from clan_lib.api import API
from clan_lib.flake.flake import Flake

from .llm import (
    DEFAULT_MODELS,
    ChatResult,
    DiscoveryProgressEvent,
    FinalDecisionProgressEvent,
    ModelConfig,
    ProgressCallback,
    ProgressEvent,
    ReadmeFetchProgressEvent,
    get_model_config,
    process_chat_turn,
)
from .schemas import ChatMessage, ConversationHistory, SessionState


class ChatTurnRequest(TypedDict, total=False):
    """Request payload for a chat turn.

    Attributes:
        user_message: The user's message/request
        conversation_history: Optional list of prior messages in the conversation
        provider: The LLM provider to use (default: "claude")
        trace_file: Optional path to write LLM interaction traces for debugging
        session_state: Opaque state returned from the previous turn

    """

    user_message: str
    conversation_history: ConversationHistory | None
    provider: Literal["openai", "ollama", "claude"]
    trace_file: Path | None
    session_state: SessionState | None


class ChatTurnResponse(TypedDict):
    """Response payload for a chat turn.

    Attributes:
        proposed_instances: List of inventory instances suggested by the LLM
        conversation_history: Updated conversation history after this turn
        assistant_message: Message from the assistant
        requires_user_response: Whether the assistant is waiting for user input
        error: Error message if something went wrong (None on success)
        session_state: State blob to pass into the next turn when continuing the workflow

    """

    proposed_instances: list[dict[str, Any]]
    conversation_history: list[ChatMessage]
    assistant_message: str
    requires_user_response: bool
    error: str | None
    session_state: SessionState


class ProgressEventResponse(TypedDict):
    """Progress event for streaming updates.

    Attributes:
        stage: The current stage of processing
        status: The status within that stage (if applicable)
        count: Count of items (for readme_fetch stage)
        message: Message content (for conversation stage)

    """

    stage: str
    status: str | None
    count: int | None
    message: str | None


@API.register
def get_llm_turn(
    flake: Flake,
    request: ChatTurnRequest,
    progress_callback: ProgressCallback | None = None,
) -> ChatTurnResponse:
    """Process a single chat turn through the LLM workflow.

    This is the main entry point for HTTP APIs and web UIs to interact with
    the LLM functionality. It handles:
    - Service discovery
    - Documentation fetching
    - Final decision making
    - Conversation management

    Args:
        flake: The Flake object representing the clan configuration
        request: The chat turn request containing user message and optional history
        progress_callback: Optional callback for progress updates

    Returns:
        ChatTurnResponse with proposed instances and conversation state

    Example:
        >>> from clan_lib.flake.flake import Flake
        >>> flake = Flake("/path/to/clan")
        >>> request: ChatTurnRequest = {
        ...     "user_message": "Set up a web server",
        ...     "provider": "claude"
        ... }
        >>> response = get_llm_turn(flake, request)
        >>> if response["proposed_instances"]:
        ...     print("LLM suggests:", response["proposed_instances"])
        >>> if response["requires_user_response"]:
        ...     print("Assistant asks:", response["assistant_message"])

    """
    result: ChatResult = process_chat_turn(
        user_request=request["user_message"],
        flake=flake,
        conversation_history=request.get("conversation_history"),
        provider=request.get("provider", "claude"),
        progress_callback=progress_callback,
        trace_file=request.get("trace_file"),
        session_state=request.get("session_state"),
    )

    # Convert frozen tuples to lists for JSON serialization
    return ChatTurnResponse(
        proposed_instances=[dict(inst) for inst in result.proposed_instances],
        conversation_history=list(result.conversation_history),
        assistant_message=result.assistant_message,
        requires_user_response=result.requires_user_response,
        error=result.error,
        session_state=cast("SessionState", dict(result.session_state)),
    )


def progress_event_to_dict(event: ProgressEvent) -> ProgressEventResponse:
    """Convert a ProgressEvent to a dictionary suitable for JSON serialization.

    This helper function is useful for streaming progress updates over HTTP
    (e.g., Server-Sent Events or WebSockets).

    Args:
        event: The progress event to convert

    Returns:
        Dictionary representation of the event

    Example:
        >>> from clan_lib.llm.llm import DiscoveryProgressEvent
        >>> event = DiscoveryProgressEvent(status="analyzing")
        >>> progress_event_to_dict(event)
        {'stage': 'discovery', 'status': 'analyzing', 'count': None, 'message': None}

    """
    base_response: ProgressEventResponse = {
        "stage": event.stage,
        "status": None,
        "count": None,
        "message": None,
    }

    if isinstance(event, (DiscoveryProgressEvent, FinalDecisionProgressEvent)):
        base_response["status"] = event.status
    elif isinstance(event, ReadmeFetchProgressEvent):
        base_response["status"] = event.status
        base_response["count"] = event.count
    # ConversationProgressEvent has message field
    elif hasattr(event, "message"):
        base_response["message"] = event.message  # type: ignore[attr-defined]
        if hasattr(event, "awaiting_response"):
            base_response["status"] = (
                "awaiting_response"
                if event.awaiting_response  # type: ignore[attr-defined]
                else "complete"
            )

    return base_response


# Re-export types for convenience
__all__ = [
    "DEFAULT_MODELS",
    "ChatTurnRequest",
    "ChatTurnResponse",
    "ModelConfig",
    "ProgressCallback",
    "ProgressEvent",
    "ProgressEventResponse",
    "get_llm_turn",
    "get_model_config",
    "progress_event_to_dict",
]

@@ -2,16 +2,86 @@ import contextlib
|
||||
import json
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from clan_lib.flake.flake import Flake
|
||||
from clan_lib.llm.llm import (
|
||||
process_chat_turn,
|
||||
)
|
||||
from clan_lib.llm.orchestrator import get_llm_turn
|
||||
from clan_lib.llm.service import create_llm_model, run_llm_service
|
||||
from clan_lib.service_runner import create_service_manager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clan_lib.llm.llm_types import ChatResult
|
||||
from clan_lib.llm.schemas import ChatMessage, SessionState
|
||||
|
||||
|
||||
def get_current_mode(session_state: "SessionState") -> str:
    """Extract the current mode from session state."""
    if "pending_service_selection" in session_state:
        return "SERVICE_SELECTION"
    if "pending_final_decision" in session_state:
        return "FINAL_DECISION"
    return "DISCOVERY"
def print_separator(
|
||||
title: str = "", char: str = "=", width: int = 80, double: bool = True
|
||||
) -> None:
|
||||
"""Print a separator line with optional title."""
|
||||
if double:
|
||||
print(f"\n{char * width}")
|
||||
if title:
|
||||
padding = (width - len(title) - 2) // 2
|
||||
print(f"{char * padding} {title} {char * padding}")
|
||||
if double or title:
|
||||
print(f"{char * width}")
|
||||
|
||||
|
||||
def print_meta_info(result: "ChatResult", turn: int, phase: str) -> None:
|
||||
"""Print meta information section in a structured format."""
|
||||
mode = get_current_mode(result.session_state)
|
||||
print_separator("META INFORMATION", char="-", width=80, double=False)
|
||||
print(f" Turn Number: {turn}")
|
||||
print(f" Phase: {phase}")
|
||||
print(f" Current Mode: {mode}")
|
||||
print(f" Requires User Input: {result.requires_user_response}")
|
||||
print(f" Conversation Length: {len(result.conversation_history)} messages")
|
||||
print(f" Proposed Instances: {len(result.proposed_instances)}")
|
||||
print(f" Has Next Action: {result.next_action is not None}")
|
||||
print(f" Session State Keys: {list(result.session_state.keys())}")
|
||||
if result.error:
|
||||
print(f" Error: {result.error}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def print_chat_exchange(
|
||||
user_msg: str | None, assistant_msg: str, session_state: "SessionState"
|
||||
) -> None:
|
||||
"""Print a chat exchange with role labels and current mode."""
|
||||
mode = get_current_mode(session_state)
|
||||
print_separator("CHAT EXCHANGE", char="-", width=80, double=False)
|
||||
|
||||
if user_msg:
|
||||
print("\n USER:")
|
||||
print(f" {user_msg}")
|
||||
|
||||
print(f"\n ASSISTANT [{mode}]:")
|
||||
# Wrap long messages
|
||||
max_line_length = 76
|
||||
words = assistant_msg.split()
|
||||
current_line = " "
|
||||
for word in words:
|
||||
if len(current_line) + len(word) + 1 > max_line_length:
|
||||
print(current_line)
|
||||
current_line = " " + word
|
||||
else:
|
||||
current_line += (" " if len(current_line) > 2 else "") + word
|
||||
if current_line.strip():
|
||||
print(current_line)
|
||||
|
||||
print("\n" + "-" * 80)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_flake() -> MagicMock:
|
||||
@@ -47,9 +117,9 @@ def mock_flake() -> MagicMock:
|
||||
}
|
||||
|
||||
match arg:
|
||||
case "clanInternals.inventoryClass.inventory.{instances,machines,meta}":
|
||||
case "clanInternals.inventoryClass.inventorySerialization.{instances,machines,meta}":
|
||||
return load_json("inventory_instances_machines_meta.json")
|
||||
case "clanInternals.inventoryClass.inventory.{tags}":
|
||||
case "clanInternals.inventoryClass.inventorySerialization.{tags}":
|
||||
return load_json("inventory_tags.json")
|
||||
case "clanInternals.inventoryClass.modulesPerSource":
|
||||
return load_json("modules_per_source.json")
|
||||
@@ -98,6 +168,40 @@ def llm_service() -> Iterator[None]:
|
||||
service_manager.stop_service("ollama")
|
||||
|
||||
|
||||
def execute_multi_turn_workflow(
    user_request: str,
    flake: Flake | MagicMock,
    conversation_history: list["ChatMessage"] | None = None,
    provider: str = "ollama",
    session_state: "SessionState | None" = None,
) -> "ChatResult":
    """Execute the multi-turn workflow, auto-executing all pending operations.

    This simulates the behavior of the CLI auto-execute loop in workflow.py.
    """
    result = get_llm_turn(
        user_request=user_request,
        flake=flake,
        conversation_history=conversation_history,
        provider=provider,  # type: ignore[arg-type]
        session_state=session_state,
        execute_next_action=False,
    )

    # Auto-execute any pending operations
    while result.next_action:
        result = get_llm_turn(
            user_request="",
            flake=flake,
            conversation_history=list(result.conversation_history),
            provider=provider,  # type: ignore[arg-type]
            session_state=result.session_state,
            execute_next_action=True,
        )

    return result
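A usage sketch of this helper, mirroring how the tests below drive it (mock_flake is the fixture defined in this file):

result = execute_multi_turn_workflow(
    user_request="What VPN options do I have?",
    flake=mock_flake,
    provider="ollama",
)
# All pending next_action steps have been auto-executed at this point.
assert result.next_action is None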
@pytest.mark.service_runner
|
||||
@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
|
||||
def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
@@ -112,10 +216,9 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
- Error handling and edge cases
|
||||
"""
|
||||
flake = mock_flake
|
||||
return
|
||||
# ========== TURN 1: Discovery Phase - Initial vague request ==========
|
||||
print("\n=== TURN 1: Initial discovery request ===")
|
||||
result = process_chat_turn(
|
||||
print_separator("TURN 1: Discovery Phase", char="=", width=80)
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="What VPN options do I have?",
|
||||
flake=flake,
|
||||
provider="ollama",
|
||||
@@ -133,24 +236,25 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
assert result.conversation_history[-1]["role"] == "assistant"
|
||||
assert len(result.assistant_message) > 0, "Assistant should provide a response"
|
||||
|
||||
# Should transition to service selection phase with pending state
|
||||
assert "pending_service_selection" in result.session_state, (
|
||||
"Should have pending service selection"
|
||||
)
|
||||
assert "readme_results" in result.session_state["pending_service_selection"]
|
||||
# After multi-turn execution, we may have either:
|
||||
# - pending_service_selection (if LLM provided options and is waiting for choice)
|
||||
# - pending_final_decision (if LLM directly selected a service)
|
||||
# - no pending state (if LLM asked a clarifying question)
|
||||
|
||||
# No instances yet
|
||||
assert len(result.proposed_instances) == 0
|
||||
assert result.error is None
|
||||
|
||||
print(f"Assistant: {result.assistant_message[:200]}...")
|
||||
print(f"State: {list(result.session_state.keys())}")
|
||||
print(f"History length: {len(result.conversation_history)}")
|
||||
print_chat_exchange(
|
||||
"What VPN options do I have?", result.assistant_message, result.session_state
|
||||
)
|
||||
print_meta_info(result, turn=1, phase="Discovery")
|
||||
|
||||
# ========== TURN 2: Service Selection Phase - User makes a choice ==========
|
||||
print("\n=== TURN 2: User selects ZeroTier ===")
|
||||
result = process_chat_turn(
|
||||
user_request="I'll use ZeroTier please",
|
||||
print_separator("TURN 2: Service Selection", char="=", width=80)
|
||||
user_msg_2 = "I'll use ZeroTier please"
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request=user_msg_2,
|
||||
flake=flake,
|
||||
conversation_history=list(result.conversation_history),
|
||||
provider="ollama",
|
||||
@@ -176,11 +280,8 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
assert len(result.proposed_instances) > 0
|
||||
assert result.proposed_instances[0]["module"]["name"] == "zerotier"
|
||||
|
||||
print(
|
||||
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..."
|
||||
)
|
||||
print(f"State: {list(result.session_state.keys())}")
|
||||
print(f"Requires response: {result.requires_user_response}")
|
||||
print_chat_exchange(user_msg_2, result.assistant_message, result.session_state)
|
||||
print_meta_info(result, turn=2, phase="Service Selection")
|
||||
|
||||
# ========== Continue conversation until we reach final decision or completion ==========
|
||||
max_turns = 10
|
||||
@@ -188,22 +289,24 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
|
||||
while result.requires_user_response and turn_count < max_turns:
|
||||
turn_count += 1
|
||||
print(f"\n=== TURN {turn_count}: Continuing conversation ===")
|
||||
|
||||
# Determine appropriate response based on current state
|
||||
if "pending_service_selection" in result.session_state:
|
||||
# Still selecting service
|
||||
user_request = "Yes, ZeroTier"
|
||||
phase = "Service Selection (continued)"
|
||||
elif "pending_final_decision" in result.session_state:
|
||||
# Configuring the service
|
||||
user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer"
|
||||
phase = "Final Configuration"
|
||||
else:
|
||||
# Generic continuation
|
||||
user_request = "Yes, that sounds good. Use gchq-local as controller."
|
||||
phase = "Continuing Conversation"
|
||||
|
||||
print(f"User: {user_request}")
|
||||
print_separator(f"TURN {turn_count}: {phase}", char="=", width=80)
|
||||
|
||||
result = process_chat_turn(
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request=user_request,
|
||||
flake=flake,
|
||||
conversation_history=list(result.conversation_history),
|
||||
@@ -221,19 +324,18 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
result.conversation_history[0]["content"] == "What VPN options do I have?"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Assistant: {result.assistant_message[:200] if result.assistant_message else 'No message'}..."
|
||||
print_chat_exchange(
|
||||
user_request, result.assistant_message, result.session_state
|
||||
)
|
||||
print(f"State: {list(result.session_state.keys())}")
|
||||
print(f"Requires response: {result.requires_user_response}")
|
||||
print(f"Proposed instances: {len(result.proposed_instances)}")
|
||||
print_meta_info(result, turn=turn_count, phase=phase)
|
||||
|
||||
# Check for completion
|
||||
if not result.requires_user_response:
|
||||
print("\n=== Conversation completed! ===")
|
||||
print_separator("CONVERSATION COMPLETED", char="=", width=80)
|
||||
break
|
||||
|
||||
# ========== Final Verification ==========
|
||||
print_separator("FINAL VERIFICATION", char="=", width=80)
|
||||
assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})"
|
||||
|
||||
# If conversation completed, verify we have valid configuration
|
||||
@@ -253,22 +355,29 @@ def test_full_conversation_flow(mock_flake: MagicMock) -> None:
|
||||
"mycelium",
|
||||
]
|
||||
|
||||
# Should have roles configuration
|
||||
if "roles" in instance:
|
||||
print(f"\nConfiguration roles: {list(instance['roles'].keys())}")
|
||||
|
||||
# Should not be in pending state anymore
|
||||
assert "pending_service_selection" not in result.session_state
|
||||
assert "pending_final_decision" not in result.session_state
|
||||
|
||||
assert result.error is None, f"Should not have error: {result.error}"
|
||||
|
||||
print(f"\nFinal instance: {instance['module']['name']}")
|
||||
print(f"Total conversation turns: {turn_count}")
|
||||
print(f"Final history length: {len(result.conversation_history)}")
|
||||
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
|
||||
print(" Status: SUCCESS")
|
||||
print(f" Module Name: {instance['module']['name']}")
|
||||
print(f" Total Turns: {turn_count}")
|
||||
print(f" Final History Length: {len(result.conversation_history)} messages")
|
||||
if "roles" in instance:
|
||||
roles_list = ", ".join(instance["roles"].keys())
|
||||
print(f" Configuration Roles: {roles_list}")
|
||||
print(" Errors: None")
|
||||
print("-" * 80)
|
||||
else:
|
||||
# Conversation didn't complete but should have made progress
|
||||
assert len(result.conversation_history) > 2
|
||||
assert result.error is None
|
||||
print(f"\nConversation in progress after {turn_count} turns")
|
||||
print(f"Current state: {list(result.session_state.keys())}")
|
||||
print_separator("FINAL SUMMARY", char="-", width=80, double=False)
|
||||
print(" Status: IN PROGRESS")
|
||||
print(f" Total Turns: {turn_count}")
|
||||
print(f" Current State: {list(result.session_state.keys())}")
|
||||
print(f" History Length: {len(result.conversation_history)} messages")
|
||||
print("-" * 80)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
"""High-level LLM orchestration functions.
|
||||
|
||||
This module re-exports the LLM orchestration API from submodules.
|
||||
"""
|
||||
|
||||
# Re-export types and dataclasses
|
||||
from .llm_types import ( # noqa: F401
|
||||
DEFAULT_MODELS,
|
||||
ChatResult,
|
||||
ConversationProgressEvent,
|
||||
DiscoveryProgressEvent,
|
||||
FinalDecisionProgressEvent,
|
||||
ModelConfig,
|
||||
ProgressCallback,
|
||||
ProgressEvent,
|
||||
ReadmeFetchProgressEvent,
|
||||
ServiceSelectionProgressEvent,
|
||||
ServiceSelectionResult,
|
||||
get_model_config,
|
||||
)
|
||||
|
||||
# Re-export high-level orchestrator
|
||||
from .orchestrator import process_chat_turn # noqa: F401
|
||||
|
||||
# Re-export low-level phase functions
|
||||
from .phases import ( # noqa: F401
|
||||
execute_readme_requests,
|
||||
get_llm_discovery_phase,
|
||||
get_llm_final_decision,
|
||||
get_llm_service_selection,
|
||||
llm_final_decision_to_inventory_instances,
|
||||
)
|
||||
|
||||
# Re-export commonly used functions and types from schemas
|
||||
from .schemas import ( # noqa: F401
|
||||
AiAggregate,
|
||||
ChatMessage,
|
||||
ConversationHistory,
|
||||
FunctionCallType,
|
||||
JSONValue,
|
||||
MachineDescription,
|
||||
OllamaFunctionSchema,
|
||||
OpenAIFunctionSchema,
|
||||
PendingFinalDecisionState,
|
||||
PendingServiceSelectionState,
|
||||
ReadmeRequest,
|
||||
SessionState,
|
||||
SimplifiedServiceSchema,
|
||||
TagDescription,
|
||||
aggregate_ollama_function_schemas,
|
||||
aggregate_openai_function_schemas,
|
||||
create_get_readme_tool,
|
||||
create_select_service_tool,
|
||||
create_simplified_service_schemas,
|
||||
)
|
||||
|
||||
# Re-export service functions
|
||||
from .service import create_llm_model, run_llm_service # noqa: F401
|
||||
|
||||
# Re-export utility functions and constants
|
||||
from .utils import ( # noqa: F401
|
||||
ASSISTANT_MODE_DISCOVERY,
|
||||
ASSISTANT_MODE_FINAL,
|
||||
ASSISTANT_MODE_SELECTION,
|
||||
)
|
||||
@@ -3,12 +3,13 @@ from collections.abc import Callable
|
||||
import pytest
|
||||
from clan_cli.tests.fixtures_flakes import nested_dict
|
||||
from clan_lib.flake.flake import Flake
|
||||
from clan_lib.llm.llm import (
|
||||
from clan_lib.llm.phases import llm_final_decision_to_inventory_instances
|
||||
from clan_lib.llm.schemas import (
|
||||
FunctionCallType,
|
||||
OpenAIFunctionSchema,
|
||||
aggregate_openai_function_schemas,
|
||||
llm_final_decision_to_inventory_instances,
|
||||
clan_module_to_openai_spec,
|
||||
)
|
||||
from clan_lib.llm.schemas import FunctionCallType, clan_module_to_openai_spec
|
||||
from clan_lib.services.modules import list_service_modules
|
||||
|
||||
|
||||
|
||||
@@ -1,57 +1,28 @@
|
||||
"""Type definitions and dataclasses for LLM orchestration."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from clan_lib.nix_models.clan import InventoryInstance
|
||||
|
||||
from .schemas import ChatMessage, SessionState
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DiscoveryProgressEvent:
|
||||
"""Progress event during discovery phase."""
|
||||
class NextAction(TypedDict):
|
||||
"""Describes the next expensive operation that will be performed.
|
||||
|
||||
service_names: list[str] | None = None
|
||||
stage: Literal["discovery"] = "discovery"
|
||||
status: Literal["analyzing", "complete"] = "analyzing"
|
||||
Attributes:
|
||||
type: The type of operation (discovery, fetch_readmes, service_selection, final_decision)
|
||||
description: Human-readable description of what will happen
|
||||
estimated_duration_seconds: Rough estimate of operation duration
|
||||
details: Phase-specific information (e.g., service names, count)
|
||||
|
||||
"""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReadmeFetchProgressEvent:
|
||||
"""Progress event during readme fetching."""
|
||||
|
||||
count: int
|
||||
service_names: list[str]
|
||||
stage: Literal["readme_fetch"] = "readme_fetch"
|
||||
status: Literal["fetching", "complete"] = "fetching"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServiceSelectionProgressEvent:
|
||||
"""Progress event during service selection phase."""
|
||||
|
||||
service_names: list[str]
|
||||
stage: Literal["service_selection"] = "service_selection"
|
||||
status: Literal["selecting", "complete"] = "selecting"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FinalDecisionProgressEvent:
|
||||
"""Progress event during final decision phase."""
|
||||
|
||||
stage: Literal["final_decision"] = "final_decision"
|
||||
status: Literal["reviewing", "complete"] = "reviewing"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversationProgressEvent:
|
||||
"""Progress event for conversation continuation."""
|
||||
|
||||
message: str
|
||||
stage: Literal["conversation"] = "conversation"
|
||||
awaiting_response: bool = True
|
||||
type: Literal["discovery", "fetch_readmes", "service_selection", "final_decision"]
|
||||
description: str
|
||||
estimated_duration_seconds: int
|
||||
details: dict[str, Any]
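For orientation, a NextAction value as the orchestrator builds it before the readme-fetch step looks roughly like this (the strings mirror those used later in this diff; the count is illustrative):

example_next_action: NextAction = {
    "type": "fetch_readmes",
    "description": "Fetching documentation for 2 service(s)",
    "estimated_duration_seconds": 5,
    "details": {"service_count": 2},
}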
@dataclass(frozen=True)
|
||||
@@ -70,17 +41,6 @@ class ServiceSelectionResult:
|
||||
clarifying_message: str
|
||||
|
||||
|
||||
ProgressEvent = (
|
||||
DiscoveryProgressEvent
|
||||
| ReadmeFetchProgressEvent
|
||||
| ServiceSelectionProgressEvent
|
||||
| FinalDecisionProgressEvent
|
||||
| ConversationProgressEvent
|
||||
)
|
||||
|
||||
ProgressCallback = Callable[[ProgressEvent], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ChatResult:
    """Result of a complete chat turn through the multi-stage workflow.
@@ -92,6 +52,7 @@ class ChatResult:
        requires_user_response: True if the assistant asked a question and needs a response
        error: Error message if something went wrong (None on success)
        session_state: Serializable state to pass into the next turn when continuing a workflow
        next_action: Description of the next operation to be performed (None if workflow complete)

    """

@@ -100,6 +61,7 @@ class ChatResult:
    assistant_message: str
    requires_user_response: bool
    session_state: SessionState
    next_action: NextAction | None
    error: str | None = None

@@ -4,18 +4,12 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Literal, cast
|
||||
|
||||
from clan_lib.api import API
|
||||
from clan_lib.errors import ClanAiError
|
||||
from clan_lib.flake.flake import Flake
|
||||
from clan_lib.services.modules import InputName, ServiceReadmeCollection
|
||||
|
||||
from .llm_types import (
|
||||
ChatResult,
|
||||
DiscoveryProgressEvent,
|
||||
FinalDecisionProgressEvent,
|
||||
ProgressCallback,
|
||||
ReadmeFetchProgressEvent,
|
||||
ServiceSelectionProgressEvent,
|
||||
)
|
||||
from .llm_types import ChatResult, NextAction
|
||||
from .phases import (
|
||||
execute_readme_requests,
|
||||
get_llm_discovery_phase,
|
||||
@@ -26,8 +20,11 @@ from .phases import (
|
||||
from .schemas import (
|
||||
ConversationHistory,
|
||||
JSONValue,
|
||||
PendingDiscoveryState,
|
||||
PendingFinalDecisionState,
|
||||
PendingReadmeFetchState,
|
||||
PendingServiceSelectionState,
|
||||
ReadmeRequest,
|
||||
SessionState,
|
||||
)
|
||||
from .utils import (
|
||||
@@ -41,45 +38,50 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
def process_chat_turn(
|
||||
@API.register
|
||||
def get_llm_turn(
|
||||
user_request: str,
|
||||
flake: Flake,
|
||||
conversation_history: ConversationHistory | None = None,
|
||||
provider: Literal["openai", "ollama", "claude"] = "ollama",
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
trace_file: Path | None = None,
|
||||
session_state: SessionState | None = None,
|
||||
execute_next_action: bool = False,
|
||||
) -> ChatResult:
|
||||
"""High-level API that orchestrates the entire multi-stage chat workflow.
|
||||
|
||||
This function handles the complete flow:
|
||||
This function handles the complete flow using a multi-turn approach:
|
||||
1. Discovery phase - LLM selects relevant services
|
||||
2. Readme fetching - Retrieves detailed documentation
|
||||
3. Final decision - LLM makes informed suggestions
|
||||
4. Conversion - Transforms suggestions to inventory instances
|
||||
|
||||
Before each expensive operation, the function returns with next_action describing
|
||||
what will happen. The caller must call again with execute_next_action=True.
|
||||
|
||||
Args:
|
||||
user_request: The user's message/request
|
||||
flake: The Flake object to get services from
|
||||
conversation_history: Optional list of prior messages in the conversation
|
||||
provider: The LLM provider to use
|
||||
progress_callback: Optional callback for progress updates
|
||||
trace_file: Optional path to write LLM interaction traces for debugging
|
||||
session_state: Optional cross-turn state to resume pending workflows
|
||||
execute_next_action: If True, execute the pending operation in session_state
|
||||
|
||||
Returns:
|
||||
ChatResult containing proposed instances, updated history, and assistant message
|
||||
ChatResult containing proposed instances, updated history, next_action, and assistant message
|
||||
|
||||
Example:
|
||||
>>> result = process_chat_turn(
|
||||
... "Set up a web server",
|
||||
... flake,
|
||||
... progress_callback=lambda event: print(f"Stage: {event.stage}")
|
||||
... )
|
||||
>>> if result.proposed_instances:
|
||||
... print("LLM suggested:", result.proposed_instances)
|
||||
>>> if result.requires_user_response:
|
||||
... print("Assistant asks:", result.assistant_message)
|
||||
>>> result = process_chat_turn("Set up a web server", flake)
|
||||
>>> while result.next_action:
|
||||
... # Show user what will happen
|
||||
... print(result.next_action["description"])
|
||||
... result = process_chat_turn(
|
||||
... user_request="",
|
||||
... flake=flake,
|
||||
... session_state=result.session_state,
|
||||
... execute_next_action=True
|
||||
... )
|
||||
|
||||
"""
|
||||
history = list(conversation_history) if conversation_history else []
|
||||
@@ -87,6 +89,11 @@ def process_chat_turn(
|
||||
"SessionState", dict(session_state) if session_state else {}
|
||||
)
|
||||
|
||||
# Add non-empty user message to history for conversation tracking
|
||||
# Phases will also add it to their messages array for LLM calls
|
||||
if user_request:
|
||||
history.append(_user_message(user_request))
|
||||
|
||||
def _state_snapshot() -> dict[str, JSONValue]:
|
||||
try:
|
||||
return json.loads(json.dumps(state))
|
||||
@@ -102,6 +109,17 @@ def process_chat_turn(
|
||||
def _state_copy() -> SessionState:
|
||||
return cast("SessionState", dict(state))
|
||||
|
||||
# Check pending states in workflow order (earliest to latest)
|
||||
pending_discovery_raw = state.get("pending_discovery")
|
||||
pending_discovery: PendingDiscoveryState | None = (
|
||||
pending_discovery_raw if isinstance(pending_discovery_raw, dict) else None
|
||||
)
|
||||
|
||||
pending_readme_fetch_raw = state.get("pending_readme_fetch")
|
||||
pending_readme_fetch: PendingReadmeFetchState | None = (
|
||||
pending_readme_fetch_raw if isinstance(pending_readme_fetch_raw, dict) else None
|
||||
)
|
||||
|
||||
pending_final_raw = state.get("pending_final_decision")
|
||||
pending_final: PendingFinalDecisionState | None = (
|
||||
pending_final_raw if isinstance(pending_final_raw, dict) else None
|
||||
@@ -117,87 +135,139 @@ def process_chat_turn(
|
||||
if serialized_results is not None:
|
||||
resume_readme_results = _deserialize_readme_results(serialized_results)
|
||||
|
||||
# Only pop if we can't deserialize (invalid state)
|
||||
if resume_readme_results is None:
|
||||
state.pop("pending_service_selection", None)
|
||||
else:
|
||||
state.pop("pending_service_selection", None)
|
||||
|
||||
# Handle pending_discovery state: execute discovery if execute_next_action=True
|
||||
if pending_discovery is not None and execute_next_action:
|
||||
state.pop("pending_discovery", None)
|
||||
# Continue to execute discovery below (after pending state checks)
|
||||
|
||||
# Handle pending_readme_fetch state: execute readme fetch if execute_next_action=True
|
||||
if pending_readme_fetch is not None and execute_next_action:
|
||||
readme_requests_raw = pending_readme_fetch.get("readme_requests", [])
|
||||
readme_requests = cast("list[ReadmeRequest]", readme_requests_raw)
|
||||
|
||||
if readme_requests:
|
||||
state.pop("pending_readme_fetch", None)
|
||||
readme_results = execute_readme_requests(readme_requests, flake)
|
||||
|
||||
# Save readme results and return next_action for service selection
|
||||
state["pending_service_selection"] = cast(
|
||||
"PendingServiceSelectionState",
|
||||
{"readme_results": _serialize_readme_results(readme_results)},
|
||||
)
|
||||
service_count = len(readme_results)
|
||||
next_action_selection: NextAction = {
|
||||
"type": "service_selection",
|
||||
"description": f"Analyzing {service_count} service(s) to find the best match",
|
||||
"estimated_duration_seconds": 15,
|
||||
"details": {"service_count": service_count},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action_selection,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
state.pop("pending_readme_fetch", None)
|
||||
|
||||
if pending_final is not None:
|
||||
service_name = pending_final.get("service_name")
|
||||
service_summary = pending_final.get("service_summary")
|
||||
|
||||
if isinstance(service_name, str) and isinstance(service_summary, str):
|
||||
if progress_callback:
|
||||
progress_callback(FinalDecisionProgressEvent(status="reviewing"))
|
||||
|
||||
function_calls, final_message = get_llm_final_decision(
|
||||
user_request,
|
||||
flake,
|
||||
service_name,
|
||||
service_summary,
|
||||
conversation_history,
|
||||
provider=provider,
|
||||
trace_file=trace_file,
|
||||
trace_metadata=_metadata(
|
||||
{
|
||||
"selected_service": service_name,
|
||||
"resume": True,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(FinalDecisionProgressEvent(status="complete"))
|
||||
|
||||
history.append(_user_message(user_request))
|
||||
|
||||
if function_calls:
|
||||
proposed_instances = llm_final_decision_to_inventory_instances(
|
||||
function_calls
|
||||
)
|
||||
instance_names = [inst["module"]["name"] for inst in proposed_instances]
|
||||
summary = (
|
||||
f"I suggest configuring these services: {', '.join(instance_names)}"
|
||||
)
|
||||
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
|
||||
if execute_next_action:
|
||||
state.pop("pending_final_decision", None)
|
||||
|
||||
return ChatResult(
|
||||
proposed_instances=tuple(proposed_instances),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=summary,
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
function_calls, final_message = get_llm_final_decision(
|
||||
user_request,
|
||||
flake,
|
||||
service_name,
|
||||
service_summary,
|
||||
conversation_history,
|
||||
provider=provider,
|
||||
trace_file=trace_file,
|
||||
trace_metadata=_metadata(
|
||||
{
|
||||
"selected_service": service_name,
|
||||
"resume": True,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
if final_message:
|
||||
history.append(
|
||||
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)
|
||||
)
|
||||
state["pending_final_decision"] = cast(
|
||||
"PendingFinalDecisionState",
|
||||
{
|
||||
"service_name": service_name,
|
||||
"service_summary": service_summary,
|
||||
},
|
||||
if function_calls:
|
||||
proposed_instances = llm_final_decision_to_inventory_instances(
|
||||
function_calls
|
||||
)
|
||||
instance_names = [
|
||||
inst["module"]["name"] for inst in proposed_instances
|
||||
]
|
||||
summary = f"I suggest configuring these services: {', '.join(instance_names)}"
|
||||
history.append(
|
||||
_assistant_message(summary, mode=ASSISTANT_MODE_FINAL)
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
next_action=None,
|
||||
proposed_instances=tuple(proposed_instances),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=summary,
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
if final_message:
|
||||
history.append(
|
||||
_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL)
|
||||
)
|
||||
state["pending_final_decision"] = cast(
|
||||
"PendingFinalDecisionState",
|
||||
{
|
||||
"service_name": service_name,
|
||||
"service_summary": service_summary,
|
||||
},
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
next_action=None,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=final_message,
|
||||
requires_user_response=True,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
state.pop("pending_final_decision", None)
|
||||
msg = "LLM did not provide any response or recommendations"
|
||||
raise ClanAiError(
|
||||
msg,
|
||||
description="Expected either function calls (configuration) or a clarifying message",
|
||||
location="Final Decision Phase (pending)",
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=final_message,
|
||||
requires_user_response=True,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
state.pop("pending_final_decision", None)
|
||||
msg = "LLM did not provide any response or recommendations"
|
||||
raise ClanAiError(
|
||||
msg,
|
||||
description="Expected either function calls (configuration) or a clarifying message",
|
||||
location="Final Decision Phase (pending)",
|
||||
# If not executing, return next_action for final decision
|
||||
next_action_final_pending: NextAction = {
|
||||
"type": "final_decision",
|
||||
"description": f"Generating configuration for {service_name}",
|
||||
"estimated_duration_seconds": 20,
|
||||
"details": {"service_name": service_name},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action_final_pending,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
state.pop("pending_final_decision", None)
|
||||
@@ -206,19 +276,12 @@ def process_chat_turn(
|
||||
readme_results: dict[InputName, ServiceReadmeCollection],
|
||||
) -> ChatResult:
|
||||
# Extract all service names from readme results
|
||||
all_service_names = [
|
||||
[
|
||||
service_name
|
||||
for collection in readme_results.values()
|
||||
for service_name in collection.readmes
|
||||
]
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
ServiceSelectionProgressEvent(
|
||||
service_names=all_service_names, status="selecting"
|
||||
)
|
||||
)
|
||||
|
||||
selection_result = get_llm_service_selection(
|
||||
user_request,
|
||||
readme_results,
|
||||
@@ -232,7 +295,6 @@ def process_chat_turn(
|
||||
selection_result.clarifying_message
|
||||
and not selection_result.selected_service
|
||||
):
|
||||
history.append(_user_message(user_request))
|
||||
history.append(
|
||||
_assistant_message(
|
||||
selection_result.clarifying_message,
|
||||
@@ -247,6 +309,7 @@ def process_chat_turn(
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
next_action=None,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=selection_result.clarifying_message,
|
||||
@@ -267,81 +330,79 @@ def process_chat_turn(
|
||||
location="Service Selection Phase",
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(FinalDecisionProgressEvent(status="reviewing"))
|
||||
|
||||
function_calls, final_message = get_llm_final_decision(
|
||||
user_request,
|
||||
flake,
|
||||
selection_result.selected_service,
|
||||
selection_result.service_summary,
|
||||
conversation_history,
|
||||
provider=provider,
|
||||
trace_file=trace_file,
|
||||
trace_metadata=_metadata(
|
||||
{"selected_service": selection_result.selected_service}
|
||||
),
|
||||
# After service selection, always return next_action for final decision
|
||||
state["pending_final_decision"] = cast(
|
||||
"PendingFinalDecisionState",
|
||||
{
|
||||
"service_name": selection_result.selected_service,
|
||||
"service_summary": selection_result.service_summary,
|
||||
},
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(FinalDecisionProgressEvent(status="complete"))
|
||||
|
||||
if function_calls:
|
||||
history.append(_user_message(user_request))
|
||||
|
||||
proposed_instances = llm_final_decision_to_inventory_instances(
|
||||
function_calls
|
||||
)
|
||||
|
||||
instance_names = [inst["module"]["name"] for inst in proposed_instances]
|
||||
summary = (
|
||||
f"I suggest configuring these services: {', '.join(instance_names)}"
|
||||
)
|
||||
history.append(_assistant_message(summary, mode=ASSISTANT_MODE_FINAL))
|
||||
state.pop("pending_final_decision", None)
|
||||
|
||||
return ChatResult(
|
||||
proposed_instances=tuple(proposed_instances),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=summary,
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
if final_message:
|
||||
history.append(_user_message(user_request))
|
||||
history.append(_assistant_message(final_message, mode=ASSISTANT_MODE_FINAL))
|
||||
state["pending_final_decision"] = cast(
|
||||
"PendingFinalDecisionState",
|
||||
{
|
||||
"service_name": selection_result.selected_service,
|
||||
"service_summary": selection_result.service_summary,
|
||||
},
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=final_message,
|
||||
requires_user_response=True,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
msg = "LLM did not provide any response or recommendations"
|
||||
raise ClanAiError(
|
||||
msg,
|
||||
description="Expected either function calls (configuration) or a clarifying message after service selection",
|
||||
location="Final Decision Phase",
|
||||
next_action_final: NextAction = {
|
||||
"type": "final_decision",
|
||||
"description": f"Generating configuration for {selection_result.selected_service}",
|
||||
"estimated_duration_seconds": 20,
|
||||
"details": {"service_name": selection_result.selected_service},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action_final,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
if resume_readme_results is not None:
|
||||
return _continue_with_service_selection(resume_readme_results)
|
||||
if execute_next_action:
|
||||
# Pop the pending state now that we're executing
|
||||
state.pop("pending_service_selection", None)
|
||||
return _continue_with_service_selection(resume_readme_results)
|
||||
|
||||
# If not executing, return next_action for service selection
|
||||
service_count = len(resume_readme_results)
|
||||
next_action_sel_resume: NextAction = {
|
||||
"type": "service_selection",
|
||||
"description": f"Analyzing {service_count} service(s) to find the best match",
|
||||
"estimated_duration_seconds": 15,
|
||||
"details": {"service_count": service_count},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action_sel_resume,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
# Stage 1: Discovery phase
|
||||
if progress_callback:
|
||||
progress_callback(DiscoveryProgressEvent(status="analyzing"))
|
||||
# If we're not executing and have no pending states, return next_action for discovery
|
||||
has_pending_states = (
|
||||
pending_discovery is not None
|
||||
or pending_readme_fetch is not None
|
||||
or pending_selection is not None
|
||||
or pending_final is not None
|
||||
)
|
||||
if not execute_next_action and not has_pending_states:
|
||||
state["pending_discovery"] = cast("PendingDiscoveryState", {})
|
||||
next_action: NextAction = {
|
||||
"type": "discovery",
|
||||
"description": "Analyzing your request and discovering relevant services",
|
||||
"estimated_duration_seconds": 10,
|
||||
"details": {"phase": "discovery"},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
readme_requests, discovery_message = get_llm_discovery_phase(
|
||||
user_request,
|
||||
@@ -352,23 +413,14 @@ def process_chat_turn(
|
||||
trace_metadata=_metadata(),
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
selected_services = [req["function_name"] for req in readme_requests]
|
||||
progress_callback(
|
||||
DiscoveryProgressEvent(
|
||||
service_names=selected_services if selected_services else None,
|
||||
status="complete",
|
||||
)
|
||||
)
|
||||
|
||||
# If LLM asked a question or made a recommendation without readme requests
|
||||
if discovery_message and not readme_requests:
|
||||
history.append(_user_message(user_request))
|
||||
history.append(
|
||||
_assistant_message(discovery_message, mode=ASSISTANT_MODE_DISCOVERY)
|
||||
)
|
||||
|
||||
return ChatResult(
|
||||
next_action=None,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message=discovery_message,
|
||||
@@ -377,34 +429,28 @@ def process_chat_turn(
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
# If we got readme requests, continue to selecting services
|
||||
# If we got readme requests, save them and return next_action for readme fetch
|
||||
if readme_requests:
|
||||
# Stage 2: Fetch readmes
|
||||
service_names = [
|
||||
f"{req['function_name']} (from {req['input_name'] or 'built-in'})"
|
||||
for req in readme_requests
|
||||
]
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
ReadmeFetchProgressEvent(
|
||||
count=len(readme_requests),
|
||||
service_names=service_names,
|
||||
status="fetching",
|
||||
)
|
||||
)
|
||||
|
||||
readme_results = execute_readme_requests(readme_requests, flake)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
ReadmeFetchProgressEvent(
|
||||
count=len(readme_requests),
|
||||
service_names=service_names,
|
||||
status="complete",
|
||||
)
|
||||
)
|
||||
|
||||
return _continue_with_service_selection(readme_results)
|
||||
state["pending_readme_fetch"] = cast(
|
||||
"PendingReadmeFetchState",
|
||||
{"readme_requests": cast("list[dict[str, JSONValue]]", readme_requests)},
|
||||
)
|
||||
service_count = len(readme_requests)
|
||||
next_action_fetch: NextAction = {
|
||||
"type": "fetch_readmes",
|
||||
"description": f"Fetching documentation for {service_count} service(s)",
|
||||
"estimated_duration_seconds": 5,
|
||||
"details": {"service_count": service_count},
|
||||
}
|
||||
return ChatResult(
|
||||
next_action=next_action_fetch,
|
||||
proposed_instances=(),
|
||||
conversation_history=tuple(history),
|
||||
assistant_message="",
|
||||
requires_user_response=False,
|
||||
error=None,
|
||||
session_state=_state_copy(),
|
||||
)
|
||||
|
||||
# No readme requests and no message - unexpected
|
||||
msg = "LLM did not provide any response or recommendations"
|
||||
|
||||
@@ -85,7 +85,8 @@ def get_llm_discovery_phase(
|
||||
{"role": "assistant", "content": assistant_context},
|
||||
]
|
||||
messages.extend(_strip_conversation_metadata(conversation_history))
|
||||
messages.append(_user_message(user_request))
|
||||
if user_request:
|
||||
messages.append(_user_message(user_request))
|
||||
|
||||
# Call LLM with only get_readme tool
|
||||
model_config = get_model_config(provider)
|
||||
@@ -233,7 +234,8 @@ def get_llm_service_selection(
|
||||
{"role": "assistant", "content": combined_assistant_context},
|
||||
]
|
||||
messages.extend(_strip_conversation_metadata(conversation_history))
|
||||
messages.append(_user_message(user_request))
|
||||
if user_request:
|
||||
messages.append(_user_message(user_request))
|
||||
|
||||
model_config = get_model_config(provider)
|
||||
|
||||
@@ -430,8 +432,8 @@ def get_llm_final_decision(
|
||||
{"role": "assistant", "content": combined_assistant_context},
|
||||
]
|
||||
messages.extend(_strip_conversation_metadata(conversation_history))
|
||||
|
||||
messages.append(_user_message(user_request))
|
||||
if user_request:
|
||||
messages.append(_user_message(user_request))
|
||||
|
||||
# Get full schemas
|
||||
model_config = get_model_config(provider)
|
||||
|
||||
@@ -66,18 +66,28 @@ class ChatMessage(TypedDict):
|
||||
ConversationHistory = list[ChatMessage]
|
||||
|
||||
|
||||
class PendingFinalDecisionState(TypedDict, total=False):
|
||||
service_name: NotRequired[str]
|
||||
service_summary: NotRequired[str]
|
||||
class PendingDiscoveryState(TypedDict, total=False):
|
||||
user_request: NotRequired[str]
|
||||
|
||||
|
||||
class PendingReadmeFetchState(TypedDict, total=False):
|
||||
readme_requests: NotRequired[list[dict[str, JSONValue]]]
|
||||
|
||||
|
||||
class PendingServiceSelectionState(TypedDict, total=False):
|
||||
readme_results: NotRequired[list[dict[str, JSONValue]]]
|
||||
|
||||
|
||||
class PendingFinalDecisionState(TypedDict, total=False):
|
||||
service_name: NotRequired[str]
|
||||
service_summary: NotRequired[str]
|
||||
|
||||
|
||||
class SessionState(TypedDict, total=False):
|
||||
pending_final_decision: NotRequired[PendingFinalDecisionState]
|
||||
pending_discovery: NotRequired[PendingDiscoveryState]
|
||||
pending_readme_fetch: NotRequired[PendingReadmeFetchState]
|
||||
pending_service_selection: NotRequired[PendingServiceSelectionState]
|
||||
pending_final_decision: NotRequired[PendingFinalDecisionState]
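For orientation, a serialized session_state during the readme-fetch hand-off might look like this (keys follow the TypedDicts above; the request entry is illustrative):

example_state: SessionState = {
    "pending_readme_fetch": {
        "readme_requests": [
            {"function_name": "zerotier", "input_name": None},
        ],
    },
}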
class JSONSchemaProperty(TypedDict, total=False):
|
||||
|
||||
@@ -16,16 +16,12 @@ from clan_lib.llm.endpoints import (
|
||||
parse_ollama_response,
|
||||
parse_openai_response,
|
||||
)
|
||||
from clan_lib.llm.llm import (
|
||||
DiscoveryProgressEvent,
|
||||
FinalDecisionProgressEvent,
|
||||
ReadmeFetchProgressEvent,
|
||||
ServiceSelectionProgressEvent,
|
||||
ServiceSelectionResult,
|
||||
from clan_lib.llm.llm_types import ServiceSelectionResult
|
||||
from clan_lib.llm.orchestrator import get_llm_turn
|
||||
from clan_lib.llm.phases import (
|
||||
execute_readme_requests,
|
||||
get_llm_final_decision,
|
||||
get_llm_service_selection,
|
||||
process_chat_turn,
|
||||
)
|
||||
from clan_lib.llm.schemas import (
|
||||
AiAggregate,
|
||||
@@ -37,9 +33,44 @@ from clan_lib.llm.schemas import (
|
||||
from clan_lib.services.modules import ServiceReadmeCollection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from clan_lib.llm.llm_types import ChatResult
|
||||
from clan_lib.llm.schemas import ChatMessage
|
||||
|
||||
|
||||
def execute_multi_turn_workflow(
|
||||
user_request: str,
|
||||
flake: Flake,
|
||||
conversation_history: list["ChatMessage"] | None = None,
|
||||
provider: str = "claude",
|
||||
session_state: SessionState | None = None,
|
||||
) -> "ChatResult":
|
||||
"""Execute the multi-turn workflow, auto-executing all pending operations.
|
||||
|
||||
This simulates the behavior of the CLI auto-execute loop in workflow.py.
|
||||
"""
|
||||
result = get_llm_turn(
|
||||
user_request=user_request,
|
||||
flake=flake,
|
||||
conversation_history=conversation_history,
|
||||
provider=provider, # type: ignore[arg-type]
|
||||
session_state=session_state,
|
||||
execute_next_action=False,
|
||||
)
|
||||
|
||||
# Auto-execute any pending operations
|
||||
while result.next_action:
|
||||
result = get_llm_turn(
|
||||
user_request="",
|
||||
flake=flake,
|
||||
conversation_history=list(result.conversation_history),
|
||||
provider=provider, # type: ignore[arg-type]
|
||||
session_state=result.session_state,
|
||||
execute_next_action=True,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trace_data() -> list[dict[str, Any]]:
|
||||
"""Load trace data from mytrace.json."""
|
||||
@@ -181,17 +212,23 @@ class TestProcessChatTurn:
|
||||
# Mock final decision (shouldn't be called, but mock it anyway for safety)
|
||||
mock_final.return_value = ([], "")
|
||||
|
||||
# Run process_chat_turn
|
||||
result = process_chat_turn(
|
||||
# Run multi-turn workflow
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="What VPNs are available?",
|
||||
flake=mock_flake,
|
||||
conversation_history=None,
|
||||
provider="claude",
|
||||
)
|
||||
|
||||
# Verify the call was made
|
||||
# Verify the discovery call was made
|
||||
assert mock_call.called
|
||||
|
||||
# Verify readme execution was called
|
||||
assert mock_execute.called
|
||||
|
||||
# Verify service selection was called
|
||||
assert mock_selection.called
|
||||
|
||||
# Final decision should NOT be called since we return early with clarifying message
|
||||
assert not mock_final.called
|
||||
|
||||
@@ -262,8 +299,8 @@ class TestProcessChatTurn:
|
||||
final_trace["response"]["message"],
|
||||
)
|
||||
|
||||
# Run process_chat_turn with session state
|
||||
result = process_chat_turn(
|
||||
# Run multi-turn workflow with session state
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="Hmm zerotier please",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
@@ -271,6 +308,12 @@ class TestProcessChatTurn:
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
# Verify service selection was called
|
||||
assert mock_selection.called
|
||||
|
||||
# Verify final decision was called
|
||||
assert mock_final.called
|
||||
|
||||
# Verify the result
|
||||
assert result.requires_user_response is True
|
||||
assert "controller" in result.assistant_message.lower()
|
||||
@@ -335,8 +378,8 @@ class TestProcessChatTurn:
|
||||
"",
|
||||
)
|
||||
|
||||
# Run process_chat_turn
|
||||
result = process_chat_turn(
|
||||
# Run multi-turn workflow
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="okay then gchq-local as controller and qube-email as moon please everything else as peer",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
@@ -344,6 +387,9 @@ class TestProcessChatTurn:
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
# Verify final decision was called
|
||||
assert mock_final.called
|
||||
|
||||
# Verify the result
|
||||
assert result.requires_user_response is False
|
||||
assert len(result.proposed_instances) == 1
|
||||
@@ -394,19 +440,19 @@ class TestProcessChatTurn:
|
||||
)
|
||||
mock_final.return_value = ([], "")
|
||||
|
||||
result1 = process_chat_turn(
|
||||
result1 = execute_multi_turn_workflow(
|
||||
user_request="What VPNs are available?",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
)
|
||||
|
||||
# Verify final decision was not called
|
||||
# Verify final decision was not called (since we get clarifying message)
|
||||
assert not mock_final.called
|
||||
|
||||
# Verify discovery completed and moved to service selection
|
||||
# Verify discovery completed and service selection asked clarifying question
|
||||
assert result1.requires_user_response is True
|
||||
assert "VPN" in result1.assistant_message
|
||||
# Session state should have pending_service_selection
|
||||
# Session state should have pending_service_selection (with readme results)
|
||||
assert "pending_service_selection" in result1.session_state
|
||||
|
||||
# Test Turn 2: Continue with session state
|
||||
@@ -425,7 +471,7 @@ class TestProcessChatTurn:
|
||||
)
|
||||
mock_final.return_value = ([], trace_data[3]["response"]["message"])
|
||||
|
||||
result2 = process_chat_turn(
|
||||
result2 = execute_multi_turn_workflow(
|
||||
user_request="Hmm zerotier please",
|
||||
flake=mock_flake,
|
||||
conversation_history=list(result1.conversation_history),
|
||||
@@ -485,7 +531,7 @@ class TestProcessChatTurn:
|
||||
# Return empty function_calls but with a clarifying message
|
||||
mock_final.return_value = ([], clarify_trace["response"]["message"])
|
||||
|
||||
result = process_chat_turn(
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="Set up zerotier with gchq-local as controller",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
@@ -535,7 +581,7 @@ class TestProcessChatTurn:
|
||||
]
|
||||
mock_final.return_value = ([], "")
|
||||
|
||||
result = process_chat_turn(
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="I want to set up a VPN",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
@@ -624,7 +670,7 @@ class TestProcessChatTurn:
|
||||
]
|
||||
)
|
||||
|
||||
result = process_chat_turn(
|
||||
result = execute_multi_turn_workflow(
|
||||
user_request="Use zerotier with gchq-local as controller, qube-email as moon, rest as peers",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
@@ -632,6 +678,12 @@ class TestProcessChatTurn:
|
||||
session_state=session_state,
|
||||
)
|
||||
|
||||
# Verify service selection was called
|
||||
assert mock_selection.called
|
||||
|
||||
# Verify final decision was called
|
||||
assert mock_final.called
|
||||
|
||||
# Verify the function_calls branch in _continue_with_service_selection
|
||||
assert result.requires_user_response is False
|
||||
assert len(result.proposed_instances) == 1
|
||||
@@ -1004,15 +1056,16 @@ class TestProcessChatTurnPendingFinalDecision:
|
||||
response = create_openai_response([], clarify_trace["response"]["message"])
|
||||
mock_call.return_value = response
|
||||
|
||||
result = process_chat_turn(
|
||||
result = get_llm_turn(
|
||||
user_request="gchq-local as controller",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
provider="claude",
|
||||
session_state=session_state,
|
||||
execute_next_action=True, # Execute the pending final decision
|
||||
)
|
||||
|
||||
# Verify the if final_message branch at line 425 was taken
|
||||
# Verify the if final_message branch was taken
|
||||
assert result.requires_user_response is True
|
||||
assert result.assistant_message == clarify_trace["response"]["message"]
|
||||
|
||||
@@ -1074,12 +1127,13 @@ class TestProcessChatTurnPendingFinalDecision:
|
||||
response = create_openai_response(function_calls, "")
|
||||
mock_call.return_value = response
|
||||
|
||||
result = process_chat_turn(
|
||||
result = get_llm_turn(
|
||||
user_request="gchq-local as controller, qube-email as moon, rest as peers",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
provider="claude",
|
||||
session_state=session_state,
|
||||
execute_next_action=True, # Execute the pending final decision
|
||||
)
|
||||
|
||||
# Verify configuration completed
|
||||
@@ -1094,186 +1148,6 @@ class TestProcessChatTurnPendingFinalDecision:
|
||||
assert result.error is None
|
||||
|
||||
|
||||
class TestProgressCallbacks:
|
||||
"""Test progress_callback functionality in process_chat_turn."""
|
||||
|
||||
def test_progress_callback_during_readme_fetch(
|
||||
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
|
||||
) -> None:
|
||||
"""Test that progress_callback is called during README fetching."""
|
||||
# Use trace entry with README requests
|
||||
discovery_trace = trace_data[0]
|
||||
function_calls = discovery_trace["response"]["function_calls"]
|
||||
assert len(function_calls) > 0
|
||||
|
||||
# Track progress events
|
||||
progress_events: list[Any] = []
|
||||
|
||||
def track_progress(event: Any) -> None:
|
||||
progress_events.append(event)
|
||||
|
||||
# Create response with get_readme calls
|
||||
response = create_openai_response(function_calls, "")
|
||||
|
||||
with (
|
||||
patch("clan_lib.llm.phases.call_claude_api", return_value=response),
|
||||
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
|
||||
patch(
|
||||
"clan_lib.llm.orchestrator.get_llm_service_selection"
|
||||
) as mock_selection,
|
||||
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
|
||||
):
|
||||
mock_execute.return_value = {
|
||||
None: ServiceReadmeCollection(
|
||||
input_name=None,
|
||||
readmes={
|
||||
"wireguard": "# WireGuard README",
|
||||
"zerotier": "# ZeroTier README",
|
||||
"mycelium": "# Mycelium README",
|
||||
"yggdrasil": "# Yggdrasil README",
|
||||
},
|
||||
)
|
||||
}
|
||||
mock_selection.return_value = ServiceSelectionResult(
|
||||
selected_service=None,
|
||||
service_summary=None,
|
||||
clarifying_message=trace_data[1]["response"]["message"],
|
||||
)
|
||||
mock_final.return_value = ([], "")
|
||||
|
||||
result = process_chat_turn(
|
||||
user_request="What VPNs are available?",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
progress_callback=track_progress,
|
||||
)
|
||||
|
||||
# Verify final decision was not called
|
||||
assert not mock_final.called
|
||||
|
||||
# Verify progress events were sent
|
||||
assert len(progress_events) > 0
|
||||
|
||||
# Check for discovery progress events
|
||||
discovery_events = [
|
||||
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
|
||||
]
|
||||
assert len(discovery_events) >= 2 # At least start and complete
|
||||
|
||||
# Check for readme fetch progress events
|
||||
fetch_events = [
|
||||
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
|
||||
]
|
||||
assert len(fetch_events) >= 2 # fetching and complete
|
||||
|
||||
# Verify the fetching event has correct data
|
||||
fetching_event = next(e for e in fetch_events if e.status == "fetching")
|
||||
assert fetching_event.count == len(function_calls)
|
||||
# Service names include "(from built-in)" or "(from <input>)" suffix
|
||||
assert any("wireguard" in name for name in fetching_event.service_names)
|
||||
|
||||
# Verify the complete event
|
||||
complete_event = next(e for e in fetch_events if e.status == "complete")
|
||||
assert complete_event.count == len(function_calls)
|
||||
|
||||
# Result should still be successful
|
||||
assert result.requires_user_response is True
|
||||
|
||||
def test_progress_callback_through_full_workflow(
|
||||
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
|
||||
) -> None:
|
||||
"""Test progress_callback through entire workflow from discovery to config."""
|
||||
progress_events: list[Any] = []
|
||||
|
||||
def track_progress(event: Any) -> None:
|
||||
progress_events.append(event)
|
||||
|
||||
# Setup for full workflow
|
||||
discovery_response = create_openai_response(
|
||||
trace_data[0]["response"]["function_calls"],
|
||||
trace_data[0]["response"]["message"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"clan_lib.llm.phases.call_claude_api", return_value=discovery_response
|
||||
),
|
||||
patch("clan_lib.llm.orchestrator.execute_readme_requests") as mock_execute,
|
||||
patch(
|
||||
"clan_lib.llm.orchestrator.get_llm_service_selection"
|
||||
) as mock_selection,
|
||||
patch("clan_lib.llm.orchestrator.get_llm_final_decision") as mock_final,
|
||||
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
|
||||
):
|
||||
mock_execute.return_value = {
|
||||
None: ServiceReadmeCollection(
|
||||
input_name=None, readmes={"zerotier": "# ZeroTier README"}
|
||||
)
|
||||
}
|
||||
mock_selection.return_value = ServiceSelectionResult(
|
||||
selected_service="zerotier",
|
||||
service_summary="ZeroTier mesh VPN",
|
||||
clarifying_message="",
|
||||
)
|
||||
# Return configuration
|
||||
final_trace = trace_data[-1]
|
||||
mock_final.return_value = (
|
||||
[
|
||||
FunctionCallType(
|
||||
id="call_0",
|
||||
call_id="call_0",
|
||||
type="function_call",
|
||||
name="zerotier",
|
||||
arguments=json.dumps(
|
||||
final_trace["response"]["function_calls"][0]["arguments"]
|
||||
),
|
||||
)
|
||||
],
|
||||
"",
|
||||
)
|
||||
mock_agg.return_value = MagicMock(
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
result = process_chat_turn(
|
||||
user_request="Setup zerotier with gchq-local as controller",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
progress_callback=track_progress,
|
||||
)
|
||||
|
||||
# Verify we got progress events for all phases
|
||||
discovery_events = [
|
||||
e for e in progress_events if isinstance(e, DiscoveryProgressEvent)
|
||||
]
|
||||
fetch_events = [
|
||||
e for e in progress_events if isinstance(e, ReadmeFetchProgressEvent)
|
||||
]
|
||||
selection_events = [
|
||||
e
|
||||
for e in progress_events
|
||||
if isinstance(e, ServiceSelectionProgressEvent)
|
||||
]
|
||||
final_events = [
|
||||
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
|
||||
]
|
||||
|
||||
# Should have events from all phases
|
||||
assert len(discovery_events) > 0
|
||||
assert len(fetch_events) > 0
|
||||
assert len(selection_events) > 0
|
||||
assert len(final_events) > 0
|
||||
|
||||
# Result should be successful with config
|
||||
assert result.requires_user_response is False
|
||||
assert len(result.proposed_instances) == 1
|
||||
|
||||
|
||||
class TestErrorCases:
|
||||
"""Test error handling in process_chat_turn."""
|
||||
|
||||
@@ -1288,7 +1162,8 @@ class TestErrorCases:
|
||||
patch("clan_lib.llm.phases.call_claude_api", return_value=response),
|
||||
pytest.raises(ClanAiError, match="did not provide any response"),
|
||||
):
|
||||
process_chat_turn(
|
||||
# Use multi-turn workflow to execute through discovery
|
||||
execute_multi_turn_workflow(
|
||||
user_request="Setup a VPN",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
@@ -1304,7 +1179,8 @@ class TestErrorCases:
|
||||
),
|
||||
pytest.raises(ValueError, match="Test error"),
|
||||
):
|
||||
process_chat_turn(
|
||||
# Use multi-turn workflow to execute through discovery
|
||||
execute_multi_turn_workflow(
|
||||
user_request="Setup a VPN",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
@@ -1326,86 +1202,14 @@ class TestErrorCases:
|
||||
),
|
||||
pytest.raises(RuntimeError, match="Network error"),
|
||||
):
|
||||
process_chat_turn(
|
||||
# Use multi-turn workflow to execute through discovery
|
||||
execute_multi_turn_workflow(
|
||||
user_request="Setup zerotier",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
provider="claude",
|
||||
)
|
||||
|
||||
def test_progress_callback_final_decision_reviewing_and_complete(
|
||||
self, trace_data: list[dict[str, Any]], mock_flake: MagicMock
|
||||
) -> None:
|
||||
"""Test FinalDecisionProgressEvent with reviewing and complete statuses."""
|
||||
progress_events: list[Any] = []
|
||||
|
||||
def track_progress(event: Any) -> None:
|
||||
progress_events.append(event)
|
||||
|
||||
# Build conversation history and session state for pending_final_decision
|
||||
conversation_history: list[ChatMessage] = [
|
||||
{"role": "user", "content": "Setup VPN"},
|
||||
{"role": "assistant", "content": "Which service?"},
|
||||
{"role": "user", "content": "Use zerotier"},
|
||||
{"role": "assistant", "content": "Which machine as controller?"},
|
||||
]
|
||||
|
||||
session_state: SessionState = cast(
|
||||
"SessionState",
|
||||
{
|
||||
"pending_final_decision": {
|
||||
"service_name": "zerotier",
|
||||
"service_summary": "ZeroTier mesh VPN",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Use final trace with configuration
|
||||
final_trace = trace_data[-1]
|
||||
function_calls = final_trace["response"]["function_calls"]
|
||||
|
||||
with (
|
||||
patch("clan_lib.llm.phases.aggregate_ollama_function_schemas") as mock_agg,
|
||||
patch("clan_lib.llm.phases.call_claude_api") as mock_call,
|
||||
):
|
||||
mock_agg.return_value = MagicMock(
|
||||
tools=[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {"name": "zerotier", "description": "ZeroTier VPN"},
|
||||
}
|
||||
]
|
||||
)
|
||||
response = create_openai_response(function_calls, "")
|
||||
mock_call.return_value = response
|
||||
|
||||
result = process_chat_turn(
|
||||
user_request="gchq-local as controller, qube-email as moon, rest as peers",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,
|
||||
provider="claude",
|
||||
session_state=session_state,
|
||||
progress_callback=track_progress,
|
||||
)
|
||||
|
||||
# Verify we got FinalDecisionProgressEvent with both statuses
|
||||
final_events = [
|
||||
e for e in progress_events if isinstance(e, FinalDecisionProgressEvent)
|
||||
]
|
||||
assert len(final_events) >= 2
|
||||
|
||||
# Check for "reviewing" status
|
||||
reviewing_events = [e for e in final_events if e.status == "reviewing"]
|
||||
assert len(reviewing_events) >= 1
|
||||
|
||||
# Check for "complete" status
|
||||
complete_events = [e for e in final_events if e.status == "complete"]
|
||||
assert len(complete_events) >= 1
|
||||
|
||||
# Result should be successful
|
||||
assert result.requires_user_response is False
|
||||
assert len(result.proposed_instances) == 1
|
||||
|
||||
def test_service_selection_fails_no_service_selected(
|
||||
self, mock_flake: MagicMock
|
||||
) -> None:
|
||||
@@ -1443,7 +1247,8 @@ class TestErrorCases:
|
||||
|
||||
# Should raise ClanAiError
|
||||
with pytest.raises(ClanAiError, match="Failed to select service"):
|
||||
process_chat_turn(
|
||||
# Use multi-turn workflow to execute through service selection
|
||||
execute_multi_turn_workflow(
|
||||
user_request="Setup VPN",
|
||||
flake=mock_flake,
|
||||
provider="claude",
|
||||
@@ -1680,7 +1485,8 @@ class TestGetLlmFinalDecisionErrors:
|
||||
|
||||
# Should raise ClanAiError
|
||||
with pytest.raises(ClanAiError, match="LLM did not provide any response"):
|
||||
process_chat_turn(
|
||||
# Use multi-turn workflow to execute through final decision
|
||||
execute_multi_turn_workflow(
|
||||
user_request="gchq-local as controller",
|
||||
flake=mock_flake,
|
||||
conversation_history=conversation_history,