API: handle functions with multiple arguments
This commit is contained in:
@@ -24,15 +24,14 @@ class ApiResponse(Generic[ResponseDataType]):
|
||||
|
||||
class _MethodRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._registry: dict[str, Callable] = {}
|
||||
self._registry: dict[str, Callable[[Any], Any]] = {}
|
||||
|
||||
def register(self, fn: Callable[..., T]) -> Callable[..., T]:
|
||||
self._registry[fn.__name__] = fn
|
||||
return fn
|
||||
|
||||
def to_json_schema(self) -> str:
|
||||
def to_json_schema(self) -> dict[str, Any]:
|
||||
# Import only when needed
|
||||
import json
|
||||
from typing import get_type_hints
|
||||
|
||||
from clan_cli.api.util import type_to_dict
|
||||
@@ -41,25 +40,51 @@ class _MethodRegistry:
|
||||
"$comment": "An object containing API methods. ",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"required": ["list_machines"],
|
||||
"required": [func_name for func_name in self._registry.keys()],
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
for name, func in self._registry.items():
|
||||
hints = get_type_hints(func)
|
||||
|
||||
serialized_hints = {
|
||||
"argument" if key != "return" else "return": type_to_dict(
|
||||
key: type_to_dict(
|
||||
value, scope=name + " argument" if key != "return" else "return"
|
||||
)
|
||||
for key, value in hints.items()
|
||||
}
|
||||
|
||||
return_type = serialized_hints.pop("return")
|
||||
|
||||
api_schema["properties"][name] = {
|
||||
"type": "object",
|
||||
"required": [k for k in serialized_hints.keys()],
|
||||
"required": ["arguments", "return"],
|
||||
"additionalProperties": False,
|
||||
"properties": {**serialized_hints},
|
||||
"properties": {
|
||||
"return": return_type,
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"required": [k for k in serialized_hints.keys()],
|
||||
"additionalProperties": False,
|
||||
"properties": serialized_hints,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return json.dumps(api_schema, indent=2)
|
||||
return api_schema
|
||||
|
||||
def get_method_argtype(self, method_name: str, arg_name: str) -> Any:
|
||||
from inspect import signature
|
||||
|
||||
func = self._registry.get(method_name, None)
|
||||
if func:
|
||||
sig = signature(func)
|
||||
param = sig.parameters.get(arg_name)
|
||||
if param:
|
||||
param_class = param.annotation
|
||||
return param_class
|
||||
|
||||
return None
|
||||
|
||||
|
||||
API = _MethodRegistry()
|
||||
|
||||
@@ -42,10 +42,14 @@ def type_to_dict(t: Any, scope: str = "") -> dict:
|
||||
return {"type": "array", "items": type_to_dict(t.__args__[0], scope)}
|
||||
|
||||
elif issubclass(origin, dict):
|
||||
return {
|
||||
"type": "object",
|
||||
"additionalProperties": type_to_dict(t.__args__[1], scope),
|
||||
}
|
||||
value_type = t.__args__[1]
|
||||
if value_type is Any:
|
||||
return {"type": "object", "additionalProperties": True}
|
||||
else:
|
||||
return {
|
||||
"type": "object",
|
||||
"additionalProperties": type_to_dict(value_type, scope),
|
||||
}
|
||||
|
||||
raise BaseException(f"Error api type not yet supported {t!s}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user