Merge pull request 'checks: Fix flakey llm test, improve performance' (#5678) from Qubasa/clan-core:fix_slow_llm into main
Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/5678
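In short, the container test now drives get_llm_turn() one step at a time instead of auto-looping, the Ollama model gains a configurable sampling temperature (pinned to 0 in the test so tool-call decisions are reproducible), and timeouts are raised for slow CI runners. A minimal sketch of the new knob as the test below uses it (values copied from the diff; not a snippet lifted verbatim from the codebase):

import clan_lib.llm.llm_types
from clan_lib.llm.llm_types import ModelConfig

# Deterministic, CI-friendly Ollama settings: temperature=0 removes sampling
# randomness, timeout=300 tolerates slow container runners.
clan_lib.llm.llm_types.DEFAULT_MODELS["ollama"] = ModelConfig(
    name="qwen3:4b-instruct",
    provider="ollama",
    timeout=300,
    temperature=0,
)
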
@@ -5,15 +5,18 @@ from pathlib import Path
 from typing import TYPE_CHECKING
 from unittest.mock import MagicMock, patch
 
+import clan_lib.llm.llm_types
 import pytest
 from clan_lib.flake.flake import Flake
+from clan_lib.llm.llm_types import ModelConfig
 from clan_lib.llm.orchestrator import get_llm_turn
 from clan_lib.llm.service import create_llm_model, run_llm_service
 from clan_lib.service_runner import create_service_manager
 
 if TYPE_CHECKING:
     from clan_lib.llm.llm_types import ChatResult
-    from clan_lib.llm.schemas import ChatMessage, SessionState
+    from clan_lib.llm.schemas import SessionState
+import platform
 
 
 def get_current_mode(session_state: "SessionState") -> str:
@@ -168,28 +171,80 @@ def llm_service() -> Iterator[None]:
     service_manager.stop_service("ollama")
 
 
-def execute_multi_turn_workflow(
-    user_request: str,
-    flake: Flake | MagicMock,
-    conversation_history: list["ChatMessage"] | None = None,
-    provider: str = "ollama",
-    session_state: "SessionState | None" = None,
-) -> "ChatResult":
-    """Execute the multi-turn workflow, auto-executing all pending operations.
+@pytest.mark.service_runner
+@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
+def test_full_conversation_flow(mock_flake: MagicMock) -> None:
+    """Test the complete conversation flow by manually calling get_llm_turn at each step.
 
-    This simulates the behavior of the CLI auto-execute loop in workflow.py.
+    This test verifies:
+    - State transitions through discovery -> readme_fetch -> service_selection -> final_decision
+    - Each step returns the correct next_action
+    - Conversation history is preserved across turns
+    - Session state is correctly maintained
     """
+    flake = mock_flake
+    trace_file = Path("~/.ollama/container_test_llm_trace.json").expanduser()
+    trace_file.unlink(missing_ok=True)  # Start fresh
+    provider = "ollama"
+
+    # Override DEFAULT_MODELS with 4-minute timeouts for container tests
+    clan_lib.llm.llm_types.DEFAULT_MODELS = {
+        "ollama": ModelConfig(
+            name="qwen3:4b-instruct",
+            provider="ollama",
+            timeout=300,  # set inference timeout to 5 minutes as CI may be slow
+            temperature=0,  # set randomness to 0 for consistent test results
+        ),
+    }
+
+    # ========== STEP 1: Initial request (should return next_action for discovery) ==========
+    print_separator("STEP 1: Initial Request", char="=", width=80)
     result = get_llm_turn(
-        user_request=user_request,
+        user_request="What VPN options do I have?",
         flake=flake,
-        conversation_history=conversation_history,
         provider=provider,  # type: ignore[arg-type]
-        session_state=session_state,
         execute_next_action=False,
+        trace_file=trace_file,
     )
 
-    # Auto-execute any pending operations
-    while result.next_action:
+    # Should have next_action for discovery phase
+    assert result.next_action is not None, "Should have next_action for discovery"
+    assert result.next_action["type"] == "discovery"
+    assert result.requires_user_response is False
+    assert len(result.proposed_instances) == 0
+    assert "pending_discovery" in result.session_state
+    print(f" Next Action: {result.next_action['type']}")
+    print(f" Description: {result.next_action['description']}")
+    print_meta_info(result, turn=1, phase="Initial Request")
+
+    # ========== STEP 2: Execute discovery (should return next_action for readme_fetch) ==========
+    print_separator("STEP 2: Execute Discovery", char="=", width=80)
+    result = get_llm_turn(
+        user_request="",
+        flake=flake,
+        conversation_history=list(result.conversation_history),
+        provider=provider,  # type: ignore[arg-type]
+        session_state=result.session_state,
+        execute_next_action=True,
+        trace_file=trace_file,
+    )
+
+    # Should have next_action for readme fetch OR a clarifying question
+    if result.next_action:
+        assert result.next_action["type"] == "fetch_readmes"
+        assert "pending_readme_fetch" in result.session_state
+        print(f" Next Action: {result.next_action['type']}")
+        print(f" Description: {result.next_action['description']}")
+    else:
+        # LLM asked a clarifying question
+        assert result.requires_user_response is True
+        assert len(result.assistant_message) > 0
+        print(f" Assistant Message: {result.assistant_message[:100]}...")
+    print_meta_info(result, turn=2, phase="Discovery Executed")
+
+    # ========== STEP 3: Execute readme fetch (if applicable) ==========
+    if result.next_action and result.next_action["type"] == "fetch_readmes":
+        print_separator("STEP 3: Execute Readme Fetch", char="=", width=80)
         result = get_llm_turn(
             user_request="",
             flake=flake,
@@ -197,187 +252,74 @@ def execute_multi_turn_workflow(
             provider=provider,  # type: ignore[arg-type]
             session_state=result.session_state,
             execute_next_action=True,
+            trace_file=trace_file,
         )
 
-    return result
-
-
-@pytest.mark.service_runner
-@pytest.mark.usefixtures("mock_nix_shell", "llm_service")
-def test_full_conversation_flow(mock_flake: MagicMock) -> None:
-    """Comprehensive test that exercises the complete conversation flow with the actual LLM service.
-
-    This test simulates a realistic multi-turn conversation that covers:
-    - Discovery phase: Initial request and LLM gathering information
-    - Service selection phase: User choosing from available options
-    - Final decision phase: Configuring the selected service with specific parameters
-    - State transitions: pending_service_selection -> pending_final_decision -> completion
-    - Conversation history preservation across all turns
-    - Error handling and edge cases
-    """
-    flake = mock_flake
-    # ========== TURN 1: Discovery Phase - Initial vague request ==========
-    print_separator("TURN 1: Discovery Phase", char="=", width=80)
-    result = execute_multi_turn_workflow(
-        user_request="What VPN options do I have?",
-        flake=flake,
-        provider="ollama",
-    )
-
-    # Verify discovery phase behavior
-    assert result.requires_user_response is True, (
-        "Should require user response in discovery"
-    )
-    assert len(result.conversation_history) >= 2, (
-        "Should have user + assistant messages"
-    )
-    assert result.conversation_history[0]["role"] == "user"
-    assert result.conversation_history[0]["content"] == "What VPN options do I have?"
-    assert result.conversation_history[-1]["role"] == "assistant"
-    assert len(result.assistant_message) > 0, "Assistant should provide a response"
-
-    # After multi-turn execution, we may have either:
-    # - pending_service_selection (if LLM provided options and is waiting for choice)
-    # - pending_final_decision (if LLM directly selected a service)
-    # - no pending state (if LLM asked a clarifying question)
-
-    # No instances yet
-    assert len(result.proposed_instances) == 0
-    assert result.error is None
-
-    print_chat_exchange(
-        "What VPN options do I have?", result.assistant_message, result.session_state
-    )
-    print_meta_info(result, turn=1, phase="Discovery")
-
-    # ========== TURN 2: Service Selection Phase - User makes a choice ==========
-    print_separator("TURN 2: Service Selection", char="=", width=80)
-    user_msg_2 = "I'll use ZeroTier please"
-    result = execute_multi_turn_workflow(
-        user_request=user_msg_2,
-        flake=flake,
-        conversation_history=list(result.conversation_history),
-        provider="ollama",
-        session_state=result.session_state,
-    )
-
-    # Verify conversation history growth and preservation
-    assert len(result.conversation_history) > 2, "History should grow"
-    assert result.conversation_history[0]["content"] == "What VPN options do I have?"
-    assert result.conversation_history[2]["content"] == "I'll use ZeroTier please"
-
-    # Should either ask for configuration details or provide direct config
-    # Most likely will ask for more details (pending_final_decision)
-    if result.requires_user_response:
-        # LLM is asking for configuration details
-        assert len(result.assistant_message) > 0
-        # Should transition to final decision phase
-        if "pending_final_decision" not in result.session_state:
-            # Might still be in service selection asking clarifications
-            assert "pending_service_selection" in result.session_state
-    else:
-        # LLM provided configuration immediately (less likely)
-        assert len(result.proposed_instances) > 0
-        assert result.proposed_instances[0]["module"]["name"] == "zerotier"
-
-    print_chat_exchange(user_msg_2, result.assistant_message, result.session_state)
-    print_meta_info(result, turn=2, phase="Service Selection")
-
-    # ========== Continue conversation until we reach final decision or completion ==========
-    max_turns = 10
-    turn_count = 2
-
-    while result.requires_user_response and turn_count < max_turns:
-        turn_count += 1
-
-        # Determine appropriate response based on current state
-        if "pending_service_selection" in result.session_state:
-            # Still selecting service
-            user_request = "Yes, ZeroTier"
-            phase = "Service Selection (continued)"
-        elif "pending_final_decision" in result.session_state:
-            # Configuring the service
-            user_request = "Set up gchq-local as controller, qube-email as moon, and wintux as peer"
-            phase = "Final Configuration"
-        else:
-            # Generic continuation
-            user_request = "Yes, that sounds good. Use gchq-local as controller."
-            phase = "Continuing Conversation"
-
-        print_separator(f"TURN {turn_count}: {phase}", char="=", width=80)
-
-        result = execute_multi_turn_workflow(
-            user_request=user_request,
-            flake=flake,
-            conversation_history=list(result.conversation_history),
-            provider="ollama",
-            session_state=result.session_state,
-        )
-
-        # Verify conversation history continues to grow
-        assert len(result.conversation_history) == (turn_count * 2), (
-            f"History should have {turn_count * 2} messages (turn {turn_count})"
-        )
-
-        # Verify history preservation
-        assert (
-            result.conversation_history[0]["content"] == "What VPN options do I have?"
-        )
-
-        print_chat_exchange(
-            user_request, result.assistant_message, result.session_state
-        )
-        print_meta_info(result, turn=turn_count, phase=phase)
-
-        # Check for completion
-        if not result.requires_user_response:
-            print_separator("CONVERSATION COMPLETED", char="=", width=80)
-            break
-
-    # ========== Final Verification ==========
-    print_separator("FINAL VERIFICATION", char="=", width=80)
-    assert turn_count < max_turns, f"Conversation took too many turns ({turn_count})"
-
-    # If conversation completed, verify we have valid configuration
-    if not result.requires_user_response:
-        assert len(result.proposed_instances) > 0, (
-            "Should have at least one proposed instance"
-        )
-        instance = result.proposed_instances[0]
-
-        # Verify instance structure
-        assert "module" in instance
-        assert "name" in instance["module"]
-        assert instance["module"]["name"] in [
-            "zerotier",
-            "wireguard",
-            "yggdrasil",
-            "mycelium",
-        ]
-
-        # Should not be in pending state anymore
-        assert "pending_service_selection" not in result.session_state
-        assert "pending_final_decision" not in result.session_state
-
-        assert result.error is None, f"Should not have error: {result.error}"
-
-        print_separator("FINAL SUMMARY", char="-", width=80, double=False)
-        print(" Status: SUCCESS")
-        print(f" Module Name: {instance['module']['name']}")
-        print(f" Total Turns: {turn_count}")
-        print(f" Final History Length: {len(result.conversation_history)} messages")
-        if "roles" in instance:
-            roles_list = ", ".join(instance["roles"].keys())
-            print(f" Configuration Roles: {roles_list}")
-        print(" Errors: None")
-        print("-" * 80)
-    else:
-        # Conversation didn't complete but should have made progress
-        assert len(result.conversation_history) > 2
-        assert result.error is None
-        print_separator("FINAL SUMMARY", char="-", width=80, double=False)
-        print(" Status: IN PROGRESS")
-        print(f" Total Turns: {turn_count}")
-        print(f" Current State: {list(result.session_state.keys())}")
-        print(f" History Length: {len(result.conversation_history)} messages")
-        print("-" * 80)
+        # Should have next_action for service selection
+        assert result.next_action is not None
+        assert result.next_action["type"] == "service_selection"
+        assert "pending_service_selection" in result.session_state
+        print(f" Next Action: {result.next_action['type']}")
+        print(f" Description: {result.next_action['description']}")
+        print_meta_info(result, turn=3, phase="Readme Fetch Executed")
+
+    if platform.machine() == "aarch64":
+        pytest.skip(
+            "aarch64 detected: skipping readme/service-selection and final step for performance reasons"
+        )
+
+    # ========== STEP 4: Execute service selection ==========
+    print_separator("STEP 4: Execute Service Selection", char="=", width=80)
+    result = get_llm_turn(
+        user_request="I want ZeroTier.",
+        flake=flake,
+        conversation_history=list(result.conversation_history),
+        provider=provider,  # type: ignore[arg-type]
+        session_state=result.session_state,
+        execute_next_action=True,
+        trace_file=trace_file,
+    )
+
+    # Should either have next_action for final_decision OR a clarifying question
+    if result.next_action:
+        assert result.next_action["type"] == "final_decision"
+        assert "pending_final_decision" in result.session_state
+        print(f" Next Action: {result.next_action['type']}")
+        print(f" Description: {result.next_action['description']}")
+    else:
+        # LLM asked a clarifying question during service selection
+        assert result.requires_user_response is True
+        assert len(result.assistant_message) > 0
+        print(f" Assistant Message: {result.assistant_message[:100]}...")
+    print_meta_info(result, turn=4, phase="Service Selection Executed")
+
+    # ========== STEP 5: Execute final decision (if applicable) ==========
+    if result.next_action and result.next_action["type"] == "final_decision":
+        print_separator("STEP 5: Execute Final Decision", char="=", width=80)
+        result = get_llm_turn(
+            user_request="",
+            flake=flake,
+            conversation_history=list(result.conversation_history),
+            provider=provider,  # type: ignore[arg-type]
+            session_state=result.session_state,
+            execute_next_action=True,
+            trace_file=trace_file,
+        )
+
+        # Should either have proposed_instances OR ask a clarifying question
+        if result.proposed_instances:
+            assert len(result.proposed_instances) > 0
+            assert result.next_action is None
+            print(f" Proposed Instances: {len(result.proposed_instances)}")
+            for inst in result.proposed_instances:
+                print(f" - {inst['module']['name']}")
+        else:
+            # LLM asked a clarifying question
+            assert result.requires_user_response is True
+            assert len(result.assistant_message) > 0
+            print(f" Assistant Message: {result.assistant_message[:100]}...")
+        print_meta_info(result, turn=5, phase="Final Decision Executed")
+
+    # Verify conversation history has grown
+    assert len(result.conversation_history) > 0
+    assert result.conversation_history[0]["content"] == "What VPN options do I have?"
@@ -149,6 +149,7 @@ def call_openai_api(
     trace_file: Path | None = None,
     stage: str = "unknown",
     trace_metadata: dict[str, Any] | None = None,
+    temperature: float | None = None,
 ) -> OpenAIChatCompletionResponse:
     """Call the OpenAI API for chat completion.
 
@@ -160,6 +161,7 @@ def call_openai_api(
         trace_file: Optional path to write trace entries for debugging
        stage: Stage name for trace entries (default: "unknown")
         trace_metadata: Optional metadata to include in trace entries
+        temperature: Sampling temperature (default: None = use API default)
 
     Returns:
         The parsed JSON response from the API
@@ -178,6 +180,8 @@ def call_openai_api(
         "messages": messages,
         "tools": list(tools),
     }
+    if temperature is not None:
+        payload["temperature"] = temperature
     _debug_log_request("openai", messages, tools)
     url = "https://api.openai.com/v1/chat/completions"
     headers = {
@@ -256,6 +260,7 @@ def call_claude_api(
     trace_file: Path | None = None,
     stage: str = "unknown",
     trace_metadata: dict[str, Any] | None = None,
+    temperature: float | None = None,
 ) -> OpenAIChatCompletionResponse:
     """Call the Claude API (via OpenAI-compatible endpoint) for chat completion.
 
@@ -268,6 +273,7 @@ def call_claude_api(
         trace_file: Optional path to write trace entries for debugging
         stage: Stage name for trace entries (default: "unknown")
         trace_metadata: Optional metadata to include in trace entries
+        temperature: Sampling temperature (default: None = use API default)
 
     Returns:
         The parsed JSON response from the API
@@ -293,6 +299,8 @@ def call_claude_api(
         "messages": messages,
         "tools": list(tools),
     }
+    if temperature is not None:
+        payload["temperature"] = temperature
     _debug_log_request("claude", messages, tools)
 
     url = f"{base_url}chat/completions"
@@ -372,6 +380,7 @@ def call_ollama_api(
     stage: str = "unknown",
     max_tokens: int | None = None,
     trace_metadata: dict[str, Any] | None = None,
+    temperature: float | None = None,
 ) -> OllamaChatResponse:
     """Call the Ollama API for chat completion.
 
@@ -384,6 +393,7 @@ def call_ollama_api(
         stage: Stage name for trace entries (default: "unknown")
         max_tokens: Maximum number of tokens to generate (default: None = unlimited)
         trace_metadata: Optional metadata to include in trace entries
+        temperature: Sampling temperature (default: None = use API default)
 
     Returns:
         The parsed JSON response from the API
@@ -399,9 +409,14 @@ def call_ollama_api(
         "tools": list(tools),
     }
 
-    # Add max_tokens limit if specified
+    # Add options for max_tokens and temperature if specified
+    options: dict[str, int | float] = {}
     if max_tokens is not None:
-        payload["options"] = {"num_predict": max_tokens}  # type: ignore[typeddict-item]
+        options["num_predict"] = max_tokens
+    if temperature is not None:
+        options["temperature"] = temperature
+    if options:
+        payload["options"] = options  # type: ignore[typeddict-item]
     _debug_log_request("ollama", messages, tools)
     url = "http://localhost:11434/api/chat"
 
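With both limits set, the Ollama request gains a single options object. A rough sketch of the resulting /api/chat body, assuming the usual model/messages/tools fields that the surrounding code already builds (illustrative values only):

payload = {
    "model": "qwen3:4b-instruct",
    "messages": [{"role": "user", "content": "What VPN options do I have?"}],
    "tools": [],
    "options": {
        "num_predict": 300,  # max_tokens cap for the discovery phase
        "temperature": 0,  # deterministic sampling when configured
    },
}
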
@@ -73,19 +73,21 @@ class ModelConfig:
         name: The model identifier/name
         provider: The LLM provider
         timeout: Request timeout in seconds (default: 120)
+        temperature: Sampling temperature for the model (default: None = use API default)
 
     """
 
     name: str
     provider: Literal["openai", "ollama", "claude"]
     timeout: int = 120
+    temperature: float | None = None
 
 
 # Default model configurations for each provider
 DEFAULT_MODELS: dict[Literal["openai", "ollama", "claude"], ModelConfig] = {
     "openai": ModelConfig(name="gpt-4o", provider="openai", timeout=60),
     "claude": ModelConfig(name="claude-sonnet-4-5", provider="claude", timeout=60),
-    "ollama": ModelConfig(name="qwen3:4b-instruct", provider="ollama", timeout=120),
+    "ollama": ModelConfig(name="qwen3:4b-instruct", provider="ollama", timeout=180),
 }
 
 
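Because the new field defaults to None and every call site guards on it, existing configurations keep each provider's default sampling behaviour. A small sanity-check sketch of that guard (hypothetical standalone code mirroring the `if temperature is not None` branches above):

from clan_lib.llm.llm_types import DEFAULT_MODELS

cfg = DEFAULT_MODELS["openai"]  # temperature is None by default
payload: dict = {"messages": [], "tools": []}
if cfg.temperature is not None:  # same guard as in call_openai_api / call_claude_api
    payload["temperature"] = cfg.temperature
assert "temperature" not in payload  # None -> payload untouched, API default applies
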
@@ -100,6 +100,7 @@ def get_llm_discovery_phase(
             trace_file=trace_file,
             stage="discovery",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             openai_response, provider="openai"
@@ -113,6 +114,7 @@ def get_llm_discovery_phase(
             trace_file=trace_file,
             stage="discovery",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             claude_response, provider="claude"
@@ -127,6 +129,7 @@ def get_llm_discovery_phase(
             stage="discovery",
             max_tokens=300,  # Limit output for discovery phase (get_readme calls or short question)
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_ollama_response(
             ollama_response, provider="ollama"
@@ -249,6 +252,7 @@ def get_llm_service_selection(
             trace_file=trace_file,
             stage="select_service",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             openai_response, provider="openai"
@@ -262,6 +266,7 @@ def get_llm_service_selection(
             trace_file=trace_file,
             stage="select_service",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             claude_response, provider="claude"
@@ -276,6 +281,7 @@ def get_llm_service_selection(
             stage="select_service",
             max_tokens=600,  # Allow space for summary
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_ollama_response(
             ollama_response, provider="ollama"
@@ -447,6 +453,7 @@ def get_llm_final_decision(
             trace_file=trace_file,
             stage="final_decision",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             openai_response, provider="openai"
@@ -462,6 +469,7 @@ def get_llm_final_decision(
             trace_file=trace_file,
             stage="final_decision",
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_openai_response(
             claude_response, provider="claude"
@@ -477,6 +485,7 @@ def get_llm_final_decision(
             stage="final_decision",
             max_tokens=500,  # Limit output to prevent excessive verbosity
             trace_metadata=trace_metadata,
+            temperature=model_config.temperature,
         )
         function_calls, message_content = parse_ollama_response(
             ollama_response, provider="ollama"
@@ -231,6 +231,7 @@ class ChatCompletionRequestPayload(TypedDict, total=False):
     messages: list[ChatMessage]
     tools: list[ToolDefinition]
     stream: NotRequired[bool]
+    temperature: NotRequired[float]
 
 
 @dataclass(frozen=True)
@@ -5,7 +5,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal, TypedDict
 
-from clan_lib.cmd import RunOpts, run
+from clan_lib.cmd import Log, RunOpts, run
 from clan_lib.errors import ClanError
 
 if TYPE_CHECKING:
@@ -70,7 +70,7 @@ class SystemdUserService:
         """Run systemctl command with --user flag."""
         return run(
             ["systemctl", "--user", action, f"{service_name}.service"],
-            RunOpts(check=False),
+            RunOpts(check=False, log=Log.NONE),
         )
 
     def _get_property(self, service_name: str, prop: str) -> str:
@@ -240,11 +240,15 @@ class SystemdUserService:
         service_name = self._service_name(name)
 
         result = self._systemctl("stop", service_name)
-        if result.returncode != 0 and "not loaded" not in result.stderr.lower():
+        if (
+            result.returncode != 0
+            and "not loaded" not in result.stderr.lower()
+            and "does not exist" not in result.stderr.lower()
+        ):
             msg = f"Failed to stop service: {result.stderr}"
             raise ClanError(msg)
 
-        self._systemctl("disable", service_name)  # Ignore errors for transient units
+        result = self._systemctl("disable", service_name)
 
         unit_file = self._unit_file_path(name)
         if unit_file.exists():
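The stop_service() change makes teardown idempotent: a non-zero exit from `systemctl --user stop` only counts as a failure when the unit actually existed. Expressed as a standalone predicate (a hypothetical helper, not code from this diff):

def stop_failed(returncode: int, stderr: str) -> bool:
    """True only when stopping a unit that exists actually failed."""
    err = stderr.lower()
    return (
        returncode != 0
        and "not loaded" not in err
        and "does not exist" not in err
    )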