API: handle functions with multiple arguments

This commit is contained in:
Johannes Kirschbauer
2024-05-26 18:04:49 +02:00
parent ed171f0264
commit ab656d5655
9 changed files with 85 additions and 24 deletions

View File

@@ -1,10 +1,12 @@
import json
from clan_cli.api import API
def main() -> None:
schema = API.to_json_schema()
print(
f"""export const schema = {schema} as const;
f"""export const schema = {json.dumps(schema, indent=2)} as const;
"""
)

View File

@@ -24,15 +24,14 @@ class ApiResponse(Generic[ResponseDataType]):
class _MethodRegistry:
def __init__(self) -> None:
self._registry: dict[str, Callable] = {}
self._registry: dict[str, Callable[[Any], Any]] = {}
def register(self, fn: Callable[..., T]) -> Callable[..., T]:
self._registry[fn.__name__] = fn
return fn
def to_json_schema(self) -> str:
def to_json_schema(self) -> dict[str, Any]:
# Import only when needed
import json
from typing import get_type_hints
from clan_cli.api.util import type_to_dict
@@ -41,25 +40,51 @@ class _MethodRegistry:
"$comment": "An object containing API methods. ",
"type": "object",
"additionalProperties": False,
"required": ["list_machines"],
"required": [func_name for func_name in self._registry.keys()],
"properties": {},
}
for name, func in self._registry.items():
hints = get_type_hints(func)
serialized_hints = {
"argument" if key != "return" else "return": type_to_dict(
key: type_to_dict(
value, scope=name + " argument" if key != "return" else "return"
)
for key, value in hints.items()
}
return_type = serialized_hints.pop("return")
api_schema["properties"][name] = {
"type": "object",
"required": [k for k in serialized_hints.keys()],
"required": ["arguments", "return"],
"additionalProperties": False,
"properties": {**serialized_hints},
"properties": {
"return": return_type,
"arguments": {
"type": "object",
"required": [k for k in serialized_hints.keys()],
"additionalProperties": False,
"properties": serialized_hints,
},
},
}
return json.dumps(api_schema, indent=2)
return api_schema
def get_method_argtype(self, method_name: str, arg_name: str) -> Any:
from inspect import signature
func = self._registry.get(method_name, None)
if func:
sig = signature(func)
param = sig.parameters.get(arg_name)
if param:
param_class = param.annotation
return param_class
return None
API = _MethodRegistry()

View File

@@ -42,10 +42,14 @@ def type_to_dict(t: Any, scope: str = "") -> dict:
return {"type": "array", "items": type_to_dict(t.__args__[0], scope)}
elif issubclass(origin, dict):
return {
"type": "object",
"additionalProperties": type_to_dict(t.__args__[1], scope),
}
value_type = t.__args__[1]
if value_type is Any:
return {"type": "object", "additionalProperties": True}
else:
return {
"type": "object",
"additionalProperties": type_to_dict(value_type, scope),
}
raise BaseException(f"Error api type not yet supported {t!s}")

View File

@@ -39,7 +39,7 @@ def inspect_flake(flake_url: str | Path, machine_name: str) -> FlakeConfig:
system = config["system"]
# Check if the machine exists
machines = list_machines(False, flake_url)
machines = list_machines(flake_url, False)
if machine_name not in machines:
raise ClanError(
f"Machine {machine_name} not found in {flake_url}. Available machines: {', '.join(machines)}"

View File

@@ -12,7 +12,7 @@ log = logging.getLogger(__name__)
@dataclass
class MachineCreateRequest:
name: str
config: dict
config: dict[str, int]
@API.register

View File

@@ -20,7 +20,7 @@ class MachineInfo:
@API.register
def list_machines(debug: bool, flake_url: Path | str) -> dict[str, MachineInfo]:
def list_machines(flake_url: str | Path, debug: bool) -> dict[str, MachineInfo]:
config = nix_config()
system = config["system"]
cmd = nix_eval(
@@ -57,7 +57,7 @@ def list_command(args: argparse.Namespace) -> None:
print("Listing all machines:\n")
print("Source: ", flake_path)
print("-" * 40)
for name, machine in list_machines(args.debug, flake_path).items():
for name, machine in list_machines(flake_path, args.debug).items():
description = machine.machine_description or "[no description]"
print(f"{name}\n: {description}\n")
print("-" * 40)