"""API client code for LLM providers (OpenAI and Ollama)."""

import json
import logging
import os
import time
import urllib.request
from collections.abc import Sequence
from http import HTTPStatus
from pathlib import Path
from typing import Any, cast
from urllib.error import HTTPError, URLError

from clan_lib.errors import ClanError

from .schemas import (
    ChatCompletionRequestPayload,
    ChatMessage,
    FunctionCallType,
    MessageContent,
    OllamaChatResponse,
    OpenAIChatCompletionResponse,
    ToolDefinition,
)
from .trace import (
    format_messages_for_trace,
    format_tools_for_trace,
    write_trace_entry,
)

log = logging.getLogger(__name__)


def _stringify_message_content(content: MessageContent | None) -> str:
    """Convert message content payloads to human-readable text for logging."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts: list[str] = []
        for item in content:
            if isinstance(item, dict) and "text" in item:
                text_part = item.get("text")
                if isinstance(text_part, str):
                    parts.append(text_part)
                    continue
            parts.append(json.dumps(item, ensure_ascii=False))
        return "\n".join(parts)
    return json.dumps(content, ensure_ascii=False)
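
# Illustrative note (not executed): a list-style content payload such as
#     [{"type": "text", "text": "hello"}, {"foo": "bar"}]
# is flattened by _stringify_message_content to the two lines
#     hello
#     {"foo": "bar"}
# since non-text parts are JSON-dumped and everything is joined with newlines.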


def _summarize_tools(
    tools: Sequence[ToolDefinition],
) -> str:
    """Create a concise comma-separated list of tool names for logging."""
    names: list[str] = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        function_block = tool.get("function")
        if isinstance(function_block, dict) and "name" in function_block:
            name = function_block.get("name")
        else:
            name = tool.get("name")
        if isinstance(name, str):
            names.append(name)
    return ", ".join(names)


def _debug_log_request(
    provider: str,
    messages: list[ChatMessage],
    tools: Sequence[ToolDefinition],
) -> None:
    """Emit structured debug logs for outbound LLM requests."""
    if not log.isEnabledFor(logging.DEBUG):
        return

    log.debug("[%s] >>> sending %d message(s)", provider, len(messages))
    for idx, message in enumerate(messages):
        role = message.get("role", "unknown")
        content_str = _stringify_message_content(message.get("content"))
        log.debug(
            "[%s] >>> message[%02d] role=%s len=%d",
            provider,
            idx,
            role,
            len(content_str),
        )
        if content_str:
            log.debug("[%s] >>> message[%02d] content:\n%s", provider, idx, content_str)

    if tools:
        log.debug("[%s] >>> tool summary: %s", provider, _summarize_tools(tools))
        log.debug(
            "[%s] >>> tool payload:\n%s",
            provider,
            json.dumps(list(tools), indent=2, ensure_ascii=False),
        )


def _debug_log_response(
    provider: str,
    text: str,
    function_calls: list[FunctionCallType],
) -> None:
    """Emit structured debug logs for inbound LLM responses."""
    if not log.isEnabledFor(logging.DEBUG):
        return

    if text:
        log.debug(
            "[%s] <<< response text len=%d\n%s",
            provider,
            len(text),
            text,
        )
    else:
        log.debug("[%s] <<< no textual response", provider)

    if not function_calls:
        log.debug("[%s] <<< no function calls", provider)
        return

    for idx, call in enumerate(function_calls):
        args_repr = call.get("arguments", "")
        formatted_args = args_repr
        if isinstance(args_repr, str):
            try:
                parsed_args = json.loads(args_repr)
                formatted_args = json.dumps(parsed_args, indent=2, ensure_ascii=False)
            except json.JSONDecodeError:
                formatted_args = args_repr
        log.debug(
            "[%s] <<< call[%02d] name=%s\n%s",
            provider,
            idx,
            call.get("name"),
            formatted_args,
        )
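
# The _debug_log_* helpers above are no-ops unless DEBUG logging is enabled for
# this module's logger. A minimal sketch for turning that on from calling code
# (a coarse, illustrative switch, not something this module does itself):
#
#     import logging
#     logging.basicConfig(level=logging.DEBUG)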


def call_openai_api(
    model: str,
    messages: list[ChatMessage],
    tools: Sequence[ToolDefinition],
    timeout: int = 60,
    trace_file: Path | None = None,
    stage: str = "unknown",
    trace_metadata: dict[str, Any] | None = None,
    temperature: float | None = None,
) -> OpenAIChatCompletionResponse:
    """Call the OpenAI API for chat completion.

    Args:
        model: The OpenAI model to use
        messages: List of message dictionaries
        tools: List of OpenAI function schemas
        timeout: Request timeout in seconds (default: 60)
        trace_file: Optional path to write trace entries for debugging
        stage: Stage name for trace entries (default: "unknown")
        trace_metadata: Optional metadata to include in trace entries
        temperature: Sampling temperature (default: None = use API default)

    Returns:
        The parsed JSON response from the API

    Raises:
        ClanError: If the API call fails

    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        msg = "OPENAI_API_KEY environment variable is required for OpenAI provider"
        raise ClanError(msg)

    payload: ChatCompletionRequestPayload = {
        "model": model,
        "messages": messages,
        "tools": list(tools),
    }
    if temperature is not None:
        payload["temperature"] = temperature
    _debug_log_request("openai", messages, tools)
    url = "https://api.openai.com/v1/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    start_time = time.time()
    try:
        req = urllib.request.Request(  # noqa: S310
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:  # noqa: S310
            if resp.getcode() != HTTPStatus.OK.value:
                msg = f"OpenAI API returned status {resp.getcode()}"
                raise ClanError(msg)

            raw = resp.read().decode("utf-8")
            response = cast("OpenAIChatCompletionResponse", json.loads(raw))

            # Write trace if requested
            if trace_file:
                duration_ms = (time.time() - start_time) * 1000
                function_calls, message_content = parse_openai_response(
                    response, provider="openai"
                )
                write_trace_entry(
                    trace_file=trace_file,
                    provider="openai",
                    model=model,
                    stage=stage,
                    request={
                        "messages": format_messages_for_trace(messages),
                        "tools": format_tools_for_trace(
                            cast("list[dict[str, Any]]", list(tools))
                        ),
                    },
                    response={
                        "function_calls": [
                            {
                                "name": call["name"],
                                "arguments": json.loads(call["arguments"])
                                if isinstance(call["arguments"], str)
                                else call["arguments"],
                            }
                            for call in function_calls
                        ],
                        "message": message_content,
                    },
                    duration_ms=duration_ms,
                    metadata=trace_metadata,
                )

            return response

    except HTTPError as e:
        error_body = e.read().decode("utf-8") if e.fp else ""
        msg = f"OpenAI returned HTTP {e.code}: {error_body}"
        raise ClanError(msg) from e
    except URLError as e:
        msg = "OpenAI API not reachable"
        raise ClanError(msg) from e
    except json.JSONDecodeError as e:
        msg = "Failed to parse OpenAI API response"
        raise ClanError(msg) from e
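
# Minimal usage sketch (illustrative; assumes OPENAI_API_KEY is exported and a
# reachable network; the model name and message text are placeholders):
#
#     messages: list[ChatMessage] = [
#         {"role": "user", "content": "Summarize the cluster state."}
#     ]
#     response = call_openai_api("gpt-4o-mini", messages, tools=[])
#     calls, text = parse_openai_response(response)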


def call_claude_api(
    model: str,
    messages: list[ChatMessage],
    tools: Sequence[ToolDefinition],
    base_url: str | None = None,
    timeout: int = 60,
    trace_file: Path | None = None,
    stage: str = "unknown",
    trace_metadata: dict[str, Any] | None = None,
    temperature: float | None = None,
) -> OpenAIChatCompletionResponse:
    """Call the Claude API (via OpenAI-compatible endpoint) for chat completion.

    Args:
        model: The Claude model to use
        messages: List of message dictionaries
        tools: List of function schemas (OpenAI format)
        base_url: Optional base URL for the API (defaults to https://api.anthropic.com/v1/)
        timeout: Request timeout in seconds (default: 60)
        trace_file: Optional path to write trace entries for debugging
        stage: Stage name for trace entries (default: "unknown")
        trace_metadata: Optional metadata to include in trace entries
        temperature: Sampling temperature (default: None = use API default)

    Returns:
        The parsed JSON response from the API

    Raises:
        ClanError: If the API call fails

    """
    api_key = os.environ.get("ANTHROPIC_API_KEY")
    if not api_key:
        msg = "ANTHROPIC_API_KEY environment variable is required for Claude provider"
        raise ClanError(msg)

    if base_url is None:
        base_url = os.environ.get("ANTHROPIC_BASE_URL", "https://api.anthropic.com/v1/")

    # Ensure base_url ends with /
    if not base_url.endswith("/"):
        base_url += "/"

    payload: ChatCompletionRequestPayload = {
        "model": model,
        "messages": messages,
        "tools": list(tools),
    }
    if temperature is not None:
        payload["temperature"] = temperature
    _debug_log_request("claude", messages, tools)

    url = f"{base_url}chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    start_time = time.time()
    try:
        req = urllib.request.Request(  # noqa: S310
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:  # noqa: S310
            if resp.getcode() != HTTPStatus.OK.value:
                msg = f"Claude API returned status {resp.getcode()}"
                raise ClanError(msg)

            raw = resp.read().decode("utf-8")
            response = cast("OpenAIChatCompletionResponse", json.loads(raw))

            # Write trace if requested
            if trace_file:
                duration_ms = (time.time() - start_time) * 1000
                function_calls, message_content = parse_openai_response(
                    response, provider="claude"
                )
                write_trace_entry(
                    trace_file=trace_file,
                    provider="claude",
                    model=model,
                    stage=stage,
                    request={
                        "messages": format_messages_for_trace(messages),
                        "tools": format_tools_for_trace(
                            cast("list[dict[str, Any]]", list(tools))
                        ),
                    },
                    response={
                        "function_calls": [
                            {
                                "name": call["name"],
                                "arguments": json.loads(call["arguments"])
                                if isinstance(call["arguments"], str)
                                else call["arguments"],
                            }
                            for call in function_calls
                        ],
                        "message": message_content,
                    },
                    duration_ms=duration_ms,
                    metadata=trace_metadata,
                )

            return response

    except HTTPError as e:
        error_body = e.read().decode("utf-8") if e.fp else ""
        msg = f"Claude returned HTTP {e.code}: {error_body}"
        raise ClanError(msg) from e
    except URLError as e:
        msg = f"Claude API not reachable at {url}"
        raise ClanError(msg) from e
    except json.JSONDecodeError as e:
        msg = "Failed to parse Claude API response"
        raise ClanError(msg) from e
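
# Minimal usage sketch (illustrative; assumes ANTHROPIC_API_KEY is exported;
# the model name and proxy URL below are placeholders, not values this module
# requires):
#
#     response = call_claude_api(
#         "claude-sonnet-4-20250514",
#         messages,
#         tools=[],
#         base_url="https://my-proxy.example/v1/",  # optional override
#     )
#     calls, text = parse_openai_response(response, provider="claude")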


def call_ollama_api(
    model: str,
    messages: list[ChatMessage],
    tools: Sequence[ToolDefinition],
    timeout: int = 120,
    trace_file: Path | None = None,
    stage: str = "unknown",
    max_tokens: int | None = None,
    trace_metadata: dict[str, Any] | None = None,
    temperature: float | None = None,
) -> OllamaChatResponse:
    """Call the Ollama API for chat completion.

    Args:
        model: The Ollama model to use
        messages: List of message dictionaries
        tools: List of Ollama function schemas
        timeout: Request timeout in seconds (default: 120)
        trace_file: Optional path to write trace entries for debugging
        stage: Stage name for trace entries (default: "unknown")
        max_tokens: Maximum number of tokens to generate (default: None = unlimited)
        trace_metadata: Optional metadata to include in trace entries
        temperature: Sampling temperature (default: None = use API default)

    Returns:
        The parsed JSON response from the API

    Raises:
        ClanError: If the API call fails

    """
    payload: ChatCompletionRequestPayload = {
        "model": model,
        "messages": messages,
        "stream": False,
        "tools": list(tools),
    }

    # Add options for max_tokens and temperature if specified
    options: dict[str, int | float] = {}
    if max_tokens is not None:
        options["num_predict"] = max_tokens
    if temperature is not None:
        options["temperature"] = temperature
    if options:
        payload["options"] = options  # type: ignore[typeddict-item]
    _debug_log_request("ollama", messages, tools)
    url = "http://localhost:11434/api/chat"

    start_time = time.time()
    try:
        req = urllib.request.Request(  # noqa: S310
            url,
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
        )
        with urllib.request.urlopen(req, timeout=timeout) as resp:  # noqa: S310
            if resp.getcode() != HTTPStatus.OK.value:
                msg = f"Ollama API returned status {resp.getcode()}"
                raise ClanError(msg)

            raw = resp.read().decode("utf-8")
            response = cast("OllamaChatResponse", json.loads(raw))

            # Write trace if requested
            if trace_file:
                duration_ms = (time.time() - start_time) * 1000
                function_calls, message_content = parse_ollama_response(
                    response, provider="ollama"
                )
                write_trace_entry(
                    trace_file=trace_file,
                    provider="ollama",
                    model=model,
                    stage=stage,
                    request={
                        "messages": format_messages_for_trace(messages),
                        "tools": format_tools_for_trace(
                            cast("list[dict[str, Any]]", list(tools))
                        ),
                    },
                    response={
                        "function_calls": [
                            {
                                "name": call["name"],
                                "arguments": json.loads(call["arguments"])
                                if isinstance(call["arguments"], str)
                                else call["arguments"],
                            }
                            for call in function_calls
                        ],
                        "message": message_content,
                    },
                    duration_ms=duration_ms,
                    metadata=trace_metadata,
                )

            return response

    except HTTPError as e:
        msg = f"Ollama returned HTTP {e.code} when requesting chat completion."
        raise ClanError(msg) from e
    except URLError as e:
        msg = "Ollama API not reachable at http://localhost:11434"
        raise ClanError(msg) from e
    except json.JSONDecodeError as e:
        msg = "Failed to parse Ollama API response"
        raise ClanError(msg) from e
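
# Minimal usage sketch (illustrative; assumes an Ollama daemon is listening on
# localhost:11434; the model name is a placeholder):
#
#     response = call_ollama_api(
#         "llama3.1",
#         messages,
#         tools=[],
#         max_tokens=512,    # forwarded as options["num_predict"]
#         temperature=0.2,   # forwarded as options["temperature"]
#     )
#     calls, text = parse_ollama_response(response)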


def parse_openai_response(
    response_data: OpenAIChatCompletionResponse,
    provider: str = "openai",
) -> tuple[list[FunctionCallType], str]:
    """Parse OpenAI API response to extract function calls.

    Args:
        response_data: The raw response from OpenAI API
        provider: The provider name for logging purposes (default: "openai")

    Returns:
        Tuple of (function_calls, message_content)

    """
    choices = response_data.get("choices") or []
    if not choices:
        return [], ""

    message = choices[0].get("message") or {}
    tool_calls = message.get("tool_calls") or []
    raw_content = message.get("content") or ""
    model_content = _stringify_message_content(raw_content)

    result: list[FunctionCallType] = []
    for tool_call in tool_calls:
        tc_id = tool_call.get("id") or f"call_{int(time.time() * 1000)}"
        function = tool_call.get("function") or {}
        function_name = function.get("name") or ""
        function_args = function.get("arguments") or "{}"

        result.append(
            FunctionCallType(
                id=tc_id,
                call_id=tc_id,
                type="function_call",
                name=function_name,
                arguments=function_args,
            )
        )

    _debug_log_response(provider, model_content, result)

    return result, model_content


def parse_ollama_response(
    response_data: OllamaChatResponse,
    provider: str = "ollama",
) -> tuple[list[FunctionCallType], str]:
    """Parse Ollama API response to extract function calls.

    Args:
        response_data: The raw response from Ollama API
        provider: The provider name for logging purposes (default: "ollama")

    Returns:
        Tuple of (function_calls, message_content)

    """
    message = response_data.get("message") or {}
    tool_calls = message.get("tool_calls") or []
    raw_content = message.get("content") or ""
    model_content = _stringify_message_content(raw_content)

    result: list[FunctionCallType] = []
    for idx, tool_call in enumerate(tool_calls):
        function = tool_call.get("function") or {}
        function_name = function.get("name") or ""
        function_args = function.get("arguments") or {}

        # Generate unique IDs (similar to OpenAI format)
        call_id = f"call_{idx}_{int(time.time() * 1000)}"
        fc_id = f"fc_{idx}_{int(time.time() * 1000)}"

        result.append(
            FunctionCallType(
                id=fc_id,
                call_id=call_id,
                type="function_call",
                name=function_name,
                arguments=json.dumps(function_args),
            )
        )

    _debug_log_response(provider, model_content, result)

    return result, model_content
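
# Illustrative parse example with a hand-built response; the tool name and
# arguments are hypothetical, shaped like Ollama's non-streaming /api/chat reply:
#
#     fake = {
#         "message": {
#             "role": "assistant",
#             "content": "",
#             "tool_calls": [
#                 {"function": {"name": "get_machines", "arguments": {"flake": "."}}}
#             ],
#         },
#     }
#     calls, text = parse_ollama_response(cast("OllamaChatResponse", fake))
#     # calls[0]["name"] == "get_machines"
#     # calls[0]["arguments"] == '{"flake": "."}' (always a JSON string)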