API: handle functions with multiple arguments
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
|
||||
from clan_cli.api import API
|
||||
|
||||
|
||||
def main() -> None:
|
||||
schema = API.to_json_schema()
|
||||
print(
|
||||
f"""export const schema = {schema} as const;
|
||||
f"""export const schema = {json.dumps(schema, indent=2)} as const;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@@ -24,15 +24,14 @@ class ApiResponse(Generic[ResponseDataType]):
|
||||
|
||||
class _MethodRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._registry: dict[str, Callable] = {}
|
||||
self._registry: dict[str, Callable[[Any], Any]] = {}
|
||||
|
||||
def register(self, fn: Callable[..., T]) -> Callable[..., T]:
|
||||
self._registry[fn.__name__] = fn
|
||||
return fn
|
||||
|
||||
def to_json_schema(self) -> str:
|
||||
def to_json_schema(self) -> dict[str, Any]:
|
||||
# Import only when needed
|
||||
import json
|
||||
from typing import get_type_hints
|
||||
|
||||
from clan_cli.api.util import type_to_dict
|
||||
@@ -41,25 +40,51 @@ class _MethodRegistry:
|
||||
"$comment": "An object containing API methods. ",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"required": ["list_machines"],
|
||||
"required": [func_name for func_name in self._registry.keys()],
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
for name, func in self._registry.items():
|
||||
hints = get_type_hints(func)
|
||||
|
||||
serialized_hints = {
|
||||
"argument" if key != "return" else "return": type_to_dict(
|
||||
key: type_to_dict(
|
||||
value, scope=name + " argument" if key != "return" else "return"
|
||||
)
|
||||
for key, value in hints.items()
|
||||
}
|
||||
|
||||
return_type = serialized_hints.pop("return")
|
||||
|
||||
api_schema["properties"][name] = {
|
||||
"type": "object",
|
||||
"required": [k for k in serialized_hints.keys()],
|
||||
"required": ["arguments", "return"],
|
||||
"additionalProperties": False,
|
||||
"properties": {**serialized_hints},
|
||||
"properties": {
|
||||
"return": return_type,
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"required": [k for k in serialized_hints.keys()],
|
||||
"additionalProperties": False,
|
||||
"properties": serialized_hints,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return json.dumps(api_schema, indent=2)
|
||||
return api_schema
|
||||
|
||||
def get_method_argtype(self, method_name: str, arg_name: str) -> Any:
|
||||
from inspect import signature
|
||||
|
||||
func = self._registry.get(method_name, None)
|
||||
if func:
|
||||
sig = signature(func)
|
||||
param = sig.parameters.get(arg_name)
|
||||
if param:
|
||||
param_class = param.annotation
|
||||
return param_class
|
||||
|
||||
return None
|
||||
|
||||
|
||||
API = _MethodRegistry()
|
||||
|
||||
@@ -42,10 +42,14 @@ def type_to_dict(t: Any, scope: str = "") -> dict:
|
||||
return {"type": "array", "items": type_to_dict(t.__args__[0], scope)}
|
||||
|
||||
elif issubclass(origin, dict):
|
||||
return {
|
||||
"type": "object",
|
||||
"additionalProperties": type_to_dict(t.__args__[1], scope),
|
||||
}
|
||||
value_type = t.__args__[1]
|
||||
if value_type is Any:
|
||||
return {"type": "object", "additionalProperties": True}
|
||||
else:
|
||||
return {
|
||||
"type": "object",
|
||||
"additionalProperties": type_to_dict(value_type, scope),
|
||||
}
|
||||
|
||||
raise BaseException(f"Error api type not yet supported {t!s}")
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
|
||||
system = config["system"]
|
||||
|
||||
# Check if the machine exists
|
||||
machines = list_machines(False, flake_url)
|
||||
machines = list_machines(flake_url, False)
|
||||
if machine_name not in machines:
|
||||
raise ClanError(
|
||||
f"Machine {machine_name} not found in {flake_url}. Available machines: {', '.join(machines)}"
|
||||
|
||||
@@ -12,7 +12,7 @@ log = logging.getLogger(__name__)
|
||||
@dataclass
|
||||
class MachineCreateRequest:
|
||||
name: str
|
||||
config: dict
|
||||
config: dict[str, int]
|
||||
|
||||
|
||||
@API.register
|
||||
|
||||
@@ -20,7 +20,7 @@ class MachineInfo:
|
||||
|
||||
|
||||
@API.register
|
||||
def list_machines(debug: bool, flake_url: Path | str) -> dict[str, MachineInfo]:
|
||||
def list_machines(flake_url: str | Path, debug: bool) -> dict[str, MachineInfo]:
|
||||
config = nix_config()
|
||||
system = config["system"]
|
||||
cmd = nix_eval(
|
||||
@@ -57,7 +57,7 @@ def list_command(args: argparse.Namespace) -> None:
|
||||
print("Listing all machines:\n")
|
||||
print("Source: ", flake_path)
|
||||
print("-" * 40)
|
||||
for name, machine in list_machines(args.debug, flake_path).items():
|
||||
for name, machine in list_machines(flake_path, args.debug).items():
|
||||
description = machine.machine_description or "[no description]"
|
||||
print(f"{name}\n: {description}\n")
|
||||
print("-" * 40)
|
||||
|
||||
Reference in New Issue
Block a user