Refactor(clan_lib): move clan_cli.api into clan_lib.api
This commit is contained in:
286
pkgs/clan-cli/clan_lib/api/__init__.py
Normal file
286
pkgs/clan-cli/clan_lib/api/__init__.py
Normal file
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from functools import wraps
|
||||
from inspect import Parameter, Signature, signature
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
TypeVar,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from .serde import dataclass_to_dict, from_dict, sanitize_string
|
||||
|
||||
__all__ = ["dataclass_to_dict", "from_dict", "sanitize_string"]
|
||||
|
||||
from clan_cli.errors import ClanError
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
ResponseDataType = TypeVar("ResponseDataType")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiError:
|
||||
message: str
|
||||
description: str | None
|
||||
location: list[str] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuccessDataClass(Generic[ResponseDataType]):
|
||||
op_key: str
|
||||
status: Annotated[Literal["success"], "The status of the response."]
|
||||
data: ResponseDataType
|
||||
|
||||
|
||||
@dataclass
|
||||
class ErrorDataClass:
|
||||
op_key: str
|
||||
status: Literal["error"]
|
||||
errors: list[ApiError]
|
||||
|
||||
|
||||
ApiResponse = SuccessDataClass[ResponseDataType] | ErrorDataClass
|
||||
|
||||
|
||||
def update_wrapper_signature(wrapper: Callable, wrapped: Callable) -> None:
|
||||
sig = signature(wrapped)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
# Add 'op_key' parameter
|
||||
op_key_param = Parameter(
|
||||
"op_key",
|
||||
Parameter.KEYWORD_ONLY,
|
||||
# we add a None default value so that typescript code gen drops the parameter
|
||||
# FIXME: this is a hack, we should filter out op_key in the typescript code gen
|
||||
default=None,
|
||||
annotation=str,
|
||||
)
|
||||
params.append(op_key_param)
|
||||
|
||||
# Create a new signature
|
||||
new_sig = sig.replace(parameters=params)
|
||||
wrapper.__signature__ = new_sig # type: ignore
|
||||
|
||||
|
||||
class MethodRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._orig_signature: dict[str, Signature] = {}
|
||||
self._registry: dict[str, Callable[..., Any]] = {}
|
||||
|
||||
@property
|
||||
def orig_signatures(self) -> dict[str, Signature]:
|
||||
return self._orig_signature
|
||||
|
||||
@property
|
||||
def signatures(self) -> dict[str, Signature]:
|
||||
return {name: signature(fn) for name, fn in self.functions.items()}
|
||||
|
||||
@property
|
||||
def functions(self) -> dict[str, Callable[..., Any]]:
|
||||
return self._registry
|
||||
|
||||
def reset(self) -> None:
|
||||
self._orig_signature.clear()
|
||||
self._registry.clear()
|
||||
|
||||
def register_abstract(self, fn: Callable[..., T]) -> Callable[..., T]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: Any, op_key: str, **kwargs: Any) -> ApiResponse[T]:
|
||||
msg = f"""{fn.__name__} - The platform didn't implement this function.
|
||||
|
||||
---
|
||||
# Example
|
||||
|
||||
The function 'open_file()' depends on the platform.
|
||||
|
||||
def open_file(file_request: FileRequest) -> str | None:
|
||||
# In GTK we open a file dialog window
|
||||
# In Android we open a file picker dialog
|
||||
# and so on.
|
||||
pass
|
||||
|
||||
# At runtime the clan-app must override platform specific functions
|
||||
API.register(open_file)
|
||||
---
|
||||
"""
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
self.register(wrapper)
|
||||
return fn
|
||||
|
||||
def overwrite_fn(self, fn: Callable[..., Any]) -> None:
|
||||
fn_name = fn.__name__
|
||||
|
||||
if fn_name not in self._registry:
|
||||
msg = f"Function '{fn_name}' is not registered as an API method"
|
||||
raise ClanError(msg)
|
||||
|
||||
fn_signature = signature(fn)
|
||||
abstract_signature = signature(self._registry[fn_name])
|
||||
|
||||
# Remove the default argument of op_key from abstract_signature
|
||||
# FIXME: This is a hack to make the signature comparison work
|
||||
# because the other hack above where default value of op_key is None in the wrapper
|
||||
abstract_params = list(abstract_signature.parameters.values())
|
||||
for i, param in enumerate(abstract_params):
|
||||
if param.name == "op_key":
|
||||
abstract_params[i] = param.replace(default=Parameter.empty)
|
||||
break
|
||||
abstract_signature = abstract_signature.replace(parameters=abstract_params)
|
||||
|
||||
if fn_signature != abstract_signature:
|
||||
msg = f"Expected signature: {abstract_signature}\nActual signature: {fn_signature}"
|
||||
raise ClanError(msg)
|
||||
|
||||
self._registry[fn_name] = fn
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
def register(self, fn: F) -> F:
|
||||
if fn.__name__ in self._registry:
|
||||
msg = f"Function {fn.__name__} already registered"
|
||||
raise ClanError(msg)
|
||||
if fn.__name__ in self._orig_signature:
|
||||
msg = f"Function {fn.__name__} already registered"
|
||||
raise ClanError(msg)
|
||||
# make copy of original function
|
||||
self._orig_signature[fn.__name__] = signature(fn)
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args: Any, op_key: str, **kwargs: Any) -> ApiResponse[T]:
|
||||
try:
|
||||
data: T = fn(*args, **kwargs)
|
||||
return SuccessDataClass(status="success", data=data, op_key=op_key)
|
||||
except ClanError as e:
|
||||
log.exception(f"Error calling wrapped {fn.__name__}")
|
||||
return ErrorDataClass(
|
||||
op_key=op_key,
|
||||
status="error",
|
||||
errors=[
|
||||
ApiError(
|
||||
message=e.msg,
|
||||
description=e.description,
|
||||
location=[fn.__name__, e.location],
|
||||
)
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error calling wrapped {fn.__name__}")
|
||||
return ErrorDataClass(
|
||||
op_key=op_key,
|
||||
status="error",
|
||||
errors=[
|
||||
ApiError(
|
||||
message=str(e),
|
||||
description="An unexpected error occurred",
|
||||
location=[fn.__name__],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# @wraps preserves all metadata of fn
|
||||
# we need to update the annotation, because our wrapper changes the return type
|
||||
# This overrides the new return type annotation with the generic typeVar filled in
|
||||
orig_return_type = get_type_hints(fn).get("return")
|
||||
wrapper.__annotations__["return"] = ApiResponse[orig_return_type] # type: ignore
|
||||
|
||||
# Add additional argument for the operation key
|
||||
wrapper.__annotations__["op_key"] = str # type: ignore
|
||||
|
||||
update_wrapper_signature(wrapper, fn)
|
||||
|
||||
self._registry[fn.__name__] = wrapper
|
||||
|
||||
return fn
|
||||
|
||||
def to_json_schema(self) -> dict[str, Any]:
|
||||
from typing import get_type_hints
|
||||
|
||||
from .util import type_to_dict
|
||||
|
||||
api_schema: dict[str, Any] = {
|
||||
"$comment": "An object containing API methods. ",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
"required": list(self._registry.keys()),
|
||||
"properties": {},
|
||||
}
|
||||
|
||||
err_type = None
|
||||
for name, func in self._registry.items():
|
||||
hints = get_type_hints(func)
|
||||
|
||||
serialized_hints = {
|
||||
key: type_to_dict(
|
||||
value, scope=name + " argument" if key != "return" else "return"
|
||||
)
|
||||
for key, value in hints.items()
|
||||
}
|
||||
|
||||
return_type = serialized_hints.pop("return")
|
||||
|
||||
if err_type is None:
|
||||
err_type = next(
|
||||
t
|
||||
for t in return_type["oneOf"]
|
||||
if ("error" in t["properties"]["status"]["enum"])
|
||||
)
|
||||
|
||||
return_type["oneOf"][1] = {"$ref": "#/$defs/error"}
|
||||
|
||||
sig = signature(func)
|
||||
required_args = []
|
||||
for n, param in sig.parameters.items():
|
||||
if param.default == Parameter.empty:
|
||||
required_args.append(n)
|
||||
|
||||
api_schema["properties"][name] = {
|
||||
"type": "object",
|
||||
"required": ["arguments", "return"],
|
||||
"additionalProperties": False,
|
||||
"properties": {
|
||||
"return": return_type,
|
||||
"arguments": {
|
||||
"type": "object",
|
||||
"required": required_args,
|
||||
"additionalProperties": False,
|
||||
"properties": serialized_hints,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
api_schema["$defs"] = {"error": err_type}
|
||||
|
||||
return api_schema
|
||||
|
||||
def get_method_argtype(self, method_name: str, arg_name: str) -> Any:
|
||||
from inspect import signature
|
||||
|
||||
func = self._registry.get(method_name, None)
|
||||
if not func:
|
||||
msg = f"API Method {method_name} not found in registry. Available methods: {list(self._registry.keys())}"
|
||||
raise ClanError(msg)
|
||||
|
||||
sig = signature(func)
|
||||
|
||||
# seems direct 'key in dict' doesnt work here
|
||||
if arg_name not in sig.parameters.keys(): # noqa: SIM118
|
||||
msg = f"Argument {arg_name} not found in api method '{method_name}'. Available arguments: {list(sig.parameters.keys())}"
|
||||
raise ClanError(msg)
|
||||
|
||||
param = sig.parameters.get(arg_name)
|
||||
if param:
|
||||
param_class = param.annotation
|
||||
return param_class
|
||||
|
||||
return None
|
||||
|
||||
|
||||
API = MethodRegistry()
|
||||
37
pkgs/clan-cli/clan_lib/api/admin.py
Normal file
37
pkgs/clan-cli/clan_lib/api/admin.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# @API.register
|
||||
# def set_admin_service(
|
||||
# base_url: str,
|
||||
# allowed_keys: dict[str, str],
|
||||
# instance_name: str = "admin",
|
||||
# extra_machines: list[str] | None = None,
|
||||
# ) -> None:
|
||||
# """
|
||||
# Set the admin service of a clan
|
||||
# Every machine is by default part of the admin service via the 'all' tag
|
||||
# """
|
||||
# if extra_machines is None:
|
||||
# extra_machines = []
|
||||
# inventory = load_inventory_eval(base_url)
|
||||
|
||||
# if not allowed_keys:
|
||||
# msg = "At least one key must be provided to ensure access"
|
||||
# raise ClanError(msg)
|
||||
|
||||
# instance = ServiceAdmin(
|
||||
# meta=ServiceMeta(name=instance_name),
|
||||
# roles=ServiceAdminRole(
|
||||
# default=ServiceAdminRoleDefault(
|
||||
# machines=extra_machines,
|
||||
# tags=["all"],
|
||||
# )
|
||||
# ),
|
||||
# config=AdminConfig(allowedKeys=allowed_keys),
|
||||
# )
|
||||
|
||||
# inventory.services.admin[instance_name] = instance
|
||||
|
||||
# save_inventory(
|
||||
# inventory,
|
||||
# base_url,
|
||||
# f"Set admin service: '{instance_name}'",
|
||||
# )
|
||||
13
pkgs/clan-cli/clan_lib/api/cli.py
Executable file
13
pkgs/clan-cli/clan_lib/api/cli.py
Executable file
@@ -0,0 +1,13 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from . import API
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Debug the API.")
|
||||
args = parser.parse_args()
|
||||
|
||||
schema = API.to_json_schema()
|
||||
print(json.dumps(schema, indent=4))
|
||||
145
pkgs/clan-cli/clan_lib/api/directory.py
Normal file
145
pkgs/clan-cli/clan_lib/api/directory.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from clan_cli.cmd import RunOpts
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.nix import nix_shell, run_no_stdout
|
||||
|
||||
from . import API
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileFilter:
|
||||
title: str | None = field(default=None)
|
||||
mime_types: list[str] | None = field(default=None)
|
||||
patterns: list[str] | None = field(default=None)
|
||||
suffixes: list[str] | None = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileRequest:
|
||||
# Mode of the os dialog window
|
||||
mode: Literal["open_file", "select_folder", "save", "open_multiple_files"]
|
||||
# Title of the os dialog window
|
||||
title: str | None = field(default=None)
|
||||
# Pre-applied filters for the file dialog
|
||||
filters: FileFilter | None = field(default=None)
|
||||
initial_file: str | None = field(default=None)
|
||||
initial_folder: str | None = field(default=None)
|
||||
|
||||
|
||||
@API.register_abstract
|
||||
def open_file(file_request: FileRequest) -> list[str] | None:
|
||||
"""
|
||||
Abstract api method to open a file dialog window.
|
||||
It must return the name of the selected file or None if no file was selected.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class File:
|
||||
path: str
|
||||
file_type: Literal["file", "directory", "symlink"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Directory:
|
||||
path: str
|
||||
files: list[File] = field(default_factory=list)
|
||||
|
||||
|
||||
@API.register
|
||||
def get_directory(current_path: str) -> Directory:
|
||||
curr_dir = Path(current_path)
|
||||
directory = Directory(path=str(curr_dir))
|
||||
|
||||
if not curr_dir.is_dir():
|
||||
msg = f"Path {curr_dir} is not a directory"
|
||||
raise ClanError(msg)
|
||||
|
||||
with os.scandir(curr_dir.resolve()) as it:
|
||||
for entry in it:
|
||||
if entry.is_symlink():
|
||||
directory.files.append(
|
||||
File(
|
||||
path=str(curr_dir / Path(entry.name)),
|
||||
file_type="symlink",
|
||||
)
|
||||
)
|
||||
elif entry.is_file():
|
||||
directory.files.append(
|
||||
File(
|
||||
path=str(curr_dir / Path(entry.name)),
|
||||
file_type="file",
|
||||
)
|
||||
)
|
||||
|
||||
elif entry.is_dir():
|
||||
directory.files.append(
|
||||
File(
|
||||
path=str(curr_dir / Path(entry.name)),
|
||||
file_type="directory",
|
||||
)
|
||||
)
|
||||
|
||||
return directory
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlkInfo:
|
||||
name: str
|
||||
id_link: str
|
||||
path: str
|
||||
rm: str
|
||||
size: str
|
||||
ro: bool
|
||||
mountpoints: list[str]
|
||||
type_: Literal["disk"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Blockdevices:
|
||||
blockdevices: list[BlkInfo]
|
||||
|
||||
|
||||
def blk_from_dict(data: dict) -> BlkInfo:
|
||||
return BlkInfo(
|
||||
name=data["name"],
|
||||
path=data["path"],
|
||||
rm=data["rm"],
|
||||
size=data["size"],
|
||||
ro=data["ro"],
|
||||
mountpoints=data["mountpoints"],
|
||||
type_=data["type"], # renamed
|
||||
id_link=data["id-link"], # renamed
|
||||
)
|
||||
|
||||
|
||||
@API.register
|
||||
def show_block_devices() -> Blockdevices:
|
||||
"""
|
||||
Api method to show local block devices.
|
||||
|
||||
It must return a list of block devices.
|
||||
"""
|
||||
|
||||
cmd = nix_shell(
|
||||
["util-linux"],
|
||||
[
|
||||
"lsblk",
|
||||
"--json",
|
||||
"--output",
|
||||
"PATH,NAME,RM,SIZE,RO,MOUNTPOINTS,TYPE,ID-LINK",
|
||||
],
|
||||
)
|
||||
proc = run_no_stdout(cmd, RunOpts(needs_user_terminal=True))
|
||||
res = proc.stdout.strip()
|
||||
|
||||
blk_info: dict[str, Any] = json.loads(res)
|
||||
|
||||
return Blockdevices(
|
||||
blockdevices=[blk_from_dict(device) for device in blk_info["blockdevices"]]
|
||||
)
|
||||
228
pkgs/clan-cli/clan_lib/api/disk.py
Normal file
228
pkgs/clan-cli/clan_lib/api/disk.py
Normal file
@@ -0,0 +1,228 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
from uuid import uuid4
|
||||
|
||||
from clan_cli.dirs import TemplateType, clan_templates
|
||||
from clan_cli.errors import ClanError
|
||||
from clan_cli.git import commit_file
|
||||
from clan_cli.machines.hardware import HardwareConfig, show_machine_hardware_config
|
||||
|
||||
from clan_lib.api import API
|
||||
from clan_lib.api.modules import Frontmatter, extract_frontmatter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def disk_in_facter_report(hw_report: dict) -> bool:
|
||||
return "hardware" in hw_report and "disk" in hw_report["hardware"]
|
||||
|
||||
|
||||
def get_best_unix_device_name(unix_device_names: list[str]) -> str:
|
||||
# find the first device name that is disk/by-id
|
||||
for device_name in unix_device_names:
|
||||
if "disk/by-id" in device_name:
|
||||
return device_name
|
||||
else:
|
||||
# if no by-id found, use the first device name
|
||||
return unix_device_names[0]
|
||||
|
||||
|
||||
def hw_main_disk_options(hw_report: dict) -> list[str] | None:
|
||||
options: list[str] = []
|
||||
if not disk_in_facter_report(hw_report):
|
||||
return None
|
||||
|
||||
disks = hw_report["hardware"]["disk"]
|
||||
|
||||
for disk in disks:
|
||||
unix_device_names = disk["unix_device_names"]
|
||||
device_name = get_best_unix_device_name(unix_device_names)
|
||||
options += [device_name]
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@dataclass
|
||||
class Placeholder:
|
||||
# Input name for the user
|
||||
label: str
|
||||
options: list[str] | None
|
||||
required: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiskSchema:
|
||||
name: str
|
||||
readme: str
|
||||
frontmatter: Frontmatter
|
||||
placeholders: dict[str, Placeholder]
|
||||
|
||||
|
||||
# must be manually kept in sync with the ${clancore}/templates/disks directory
|
||||
templates: dict[str, dict[str, Callable[[dict[str, Any]], Placeholder]]] = {
|
||||
"single-disk": {
|
||||
# Placeholders
|
||||
"mainDisk": lambda hw_report: Placeholder(
|
||||
label="Main disk", options=hw_main_disk_options(hw_report), required=True
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@API.register
|
||||
def get_disk_schemas(
|
||||
base_path: Path, machine_name: str | None = None
|
||||
) -> dict[str, DiskSchema]:
|
||||
"""
|
||||
Get the available disk schemas
|
||||
"""
|
||||
disk_templates = clan_templates(TemplateType.DISK)
|
||||
disk_schemas = {}
|
||||
hw_report = {}
|
||||
|
||||
if machine_name is not None:
|
||||
hw_report_path = HardwareConfig.NIXOS_FACTER.config_path(
|
||||
base_path, machine_name
|
||||
)
|
||||
if not hw_report_path.exists():
|
||||
msg = "Hardware configuration missing"
|
||||
raise ClanError(msg)
|
||||
with hw_report_path.open("r") as hw_report_file:
|
||||
hw_report = json.load(hw_report_file)
|
||||
|
||||
for disk_template in disk_templates.iterdir():
|
||||
if disk_template.is_dir():
|
||||
schema_name = disk_template.stem
|
||||
if schema_name not in templates:
|
||||
msg = f"Disk schema {schema_name} not found in templates {templates.keys()}"
|
||||
raise ClanError(
|
||||
msg,
|
||||
description="This is an internal architecture problem. Because disk schemas dont define their own interface",
|
||||
)
|
||||
|
||||
placeholder_getters = templates.get(schema_name)
|
||||
placeholders = {}
|
||||
|
||||
if placeholder_getters:
|
||||
placeholders = {k: v(hw_report) for k, v in placeholder_getters.items()}
|
||||
|
||||
raw_readme = (disk_template / "README.md").read_text()
|
||||
frontmatter, readme = extract_frontmatter(
|
||||
raw_readme, f"{disk_template}/README.md"
|
||||
)
|
||||
|
||||
disk_schemas[schema_name] = DiskSchema(
|
||||
name=schema_name,
|
||||
placeholders=placeholders,
|
||||
readme=readme,
|
||||
frontmatter=frontmatter,
|
||||
)
|
||||
|
||||
return disk_schemas
|
||||
|
||||
|
||||
class MachineDiskMatter(TypedDict):
|
||||
schema_name: str
|
||||
placeholders: dict[str, str]
|
||||
|
||||
|
||||
@API.register
|
||||
def set_machine_disk_schema(
|
||||
base_path: Path,
|
||||
machine_name: str,
|
||||
schema_name: str,
|
||||
# Placeholders are used to fill in the disk schema
|
||||
# Use get disk schemas to get the placeholders and their options
|
||||
placeholders: dict[str, str],
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Set the disk placeholders of the template
|
||||
"""
|
||||
# Assert the hw-config must exist before setting the disk
|
||||
hw_config = show_machine_hardware_config(base_path, machine_name)
|
||||
hw_config_path = hw_config.config_path(base_path, machine_name)
|
||||
|
||||
if not hw_config_path.exists():
|
||||
msg = "Hardware configuration must exist before applying disk schema"
|
||||
raise ClanError(msg)
|
||||
|
||||
if hw_config != HardwareConfig.NIXOS_FACTER:
|
||||
msg = "Hardware configuration must use type FACTER for applying disk schema automatically"
|
||||
raise ClanError(msg)
|
||||
|
||||
disk_schema_path = clan_templates(TemplateType.DISK) / f"{schema_name}/default.nix"
|
||||
|
||||
if not disk_schema_path.exists():
|
||||
msg = f"Disk schema not found at {disk_schema_path}"
|
||||
raise ClanError(msg)
|
||||
|
||||
# Check that the placeholders are valid
|
||||
disk_schema = get_disk_schemas(base_path, machine_name)[schema_name]
|
||||
# check that all required placeholders are present
|
||||
for placeholder_name, schema_placeholder in disk_schema.placeholders.items():
|
||||
if schema_placeholder.required and placeholder_name not in placeholders:
|
||||
msg = f"Required placeholder {placeholder_name} - {schema_placeholder} missing"
|
||||
raise ClanError(msg)
|
||||
|
||||
# For every placeholder check that the value is valid
|
||||
for placeholder_name, placeholder_value in placeholders.items():
|
||||
ph = disk_schema.placeholders.get(placeholder_name)
|
||||
|
||||
# Unknown placeholder
|
||||
if not ph:
|
||||
msg = (
|
||||
f"Placeholder {placeholder_name} not found in disk schema {schema_name}"
|
||||
)
|
||||
raise ClanError(
|
||||
msg,
|
||||
description=f"Available placeholders: {disk_schema.placeholders.keys()}",
|
||||
)
|
||||
|
||||
# Invalid value. Check if the value is one of the provided options
|
||||
if ph.options and placeholder_value not in ph.options:
|
||||
msg = (
|
||||
f"Invalid value {placeholder_value} for placeholder {placeholder_name}"
|
||||
)
|
||||
raise ClanError(msg, description=f"Valid options: {ph.options}")
|
||||
|
||||
placeholders_toml = "\n".join(
|
||||
[f"""# {k} = "{v}" """ for k, v in placeholders.items() if v is not None]
|
||||
)
|
||||
header = f"""# ---
|
||||
# schema = "{schema_name}"
|
||||
# [placeholders]
|
||||
{placeholders_toml}
|
||||
# ---
|
||||
# This file was automatically generated!
|
||||
# CHANGING this configuration requires wiping and reinstalling the machine
|
||||
"""
|
||||
with disk_schema_path.open("r") as disk_template:
|
||||
config_str = disk_template.read()
|
||||
for placeholder_name, placeholder_value in placeholders.items():
|
||||
config_str = config_str.replace(
|
||||
r"{{" + placeholder_name + r"}}", placeholder_value
|
||||
)
|
||||
|
||||
# Custom replacements
|
||||
config_str = config_str.replace(r"{{uuid}}", str(uuid4()).replace("-", ""))
|
||||
|
||||
# place disko.nix alongside the hw-config
|
||||
disko_file_path = hw_config_path.parent.joinpath("disko.nix")
|
||||
if disko_file_path.exists() and not force:
|
||||
msg = f"Disk schema already exists at {disko_file_path}"
|
||||
raise ClanError(msg, description="Use 'force' to overwrite")
|
||||
|
||||
with disko_file_path.open("w") as disk_config:
|
||||
disk_config.write(header)
|
||||
disk_config.write(config_str)
|
||||
|
||||
commit_file(
|
||||
disko_file_path,
|
||||
base_path,
|
||||
commit_message=f"Set disk schema of machine: {machine_name} to {schema_name}",
|
||||
)
|
||||
67
pkgs/clan-cli/clan_lib/api/iwd.py
Normal file
67
pkgs/clan-cli/clan_lib/api/iwd.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def instance_name(machine_name: str) -> str:
|
||||
return f"{machine_name}_wifi_0_"
|
||||
|
||||
|
||||
@dataclass
|
||||
class NetworkConfig:
|
||||
ssid: str
|
||||
password: str
|
||||
|
||||
|
||||
# @API.register
|
||||
# def set_iwd_service_for_machine(
|
||||
# base_url: str, machine_name: str, networks: dict[str, NetworkConfig]
|
||||
# ) -> None:
|
||||
# """
|
||||
# Set the admin service of a clan
|
||||
# Every machine is by default part of the admin service via the 'all' tag
|
||||
# """
|
||||
# _instance_name = instance_name(machine_name)
|
||||
|
||||
# inventory = load_inventory_eval(base_url)
|
||||
|
||||
# instance = ServiceIwd(
|
||||
# meta=ServiceMeta(name="wifi_0"),
|
||||
# roles=ServiceIwdRole(
|
||||
# default=ServiceIwdRoleDefault(
|
||||
# machines=[machine_name],
|
||||
# )
|
||||
# ),
|
||||
# config=IwdConfig(
|
||||
# networks={k: IwdConfigNetwork(v.ssid) for k, v in networks.items()}
|
||||
# ),
|
||||
# )
|
||||
|
||||
# inventory.services.iwd[_instance_name] = instance
|
||||
|
||||
# save_inventory(
|
||||
# inventory,
|
||||
# base_url,
|
||||
# f"Set iwd service: '{_instance_name}'",
|
||||
# )
|
||||
|
||||
# pubkey = maybe_get_public_key()
|
||||
# if not pubkey:
|
||||
# # TODO: do this automatically
|
||||
# # pubkey = generate_key()
|
||||
# raise ClanError(msg="No public key found. Please initialize your key.")
|
||||
|
||||
# registered_key = maybe_get_user_or_machine(Path(base_url), pubkey)
|
||||
# if not registered_key:
|
||||
# # TODO: do this automatically
|
||||
# # username = os.getlogin()
|
||||
# # add_user(Path(base_url), username, pubkey, force=False)
|
||||
# raise ClanError(msg="Your public key is not registered for use with this clan.")
|
||||
|
||||
# password_dict = {f"iwd.{net.ssid}": net.password for net in networks.values()}
|
||||
# for net in networks.values():
|
||||
# generate_facts(
|
||||
# service=f"iwd.{net.ssid}",
|
||||
# machines=[Machine(machine_name, FlakeId(base_url))],
|
||||
# regenerate=True,
|
||||
# # Just returns the password
|
||||
# prompt=lambda service, _msg: password_dict[service],
|
||||
# )
|
||||
116
pkgs/clan-cli/clan_lib/api/mdns_discovery.py
Normal file
116
pkgs/clan-cli/clan_lib/api/mdns_discovery.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import argparse
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
|
||||
from clan_cli.cmd import run_no_stdout
|
||||
from clan_cli.nix import nix_shell
|
||||
|
||||
from . import API
|
||||
|
||||
|
||||
@dataclass
|
||||
class Host:
|
||||
# Part of the discovery
|
||||
interface: str
|
||||
protocol: str
|
||||
name: str
|
||||
type_: str
|
||||
domain: str
|
||||
# Optional, only if more data is available
|
||||
host: str | None
|
||||
ip: str | None
|
||||
port: str | None
|
||||
txt: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DNSInfo:
|
||||
""" "
|
||||
mDNS/DNS-SD services discovered on the network
|
||||
"""
|
||||
|
||||
services: dict[str, Host]
|
||||
|
||||
|
||||
def decode_escapes(s: str) -> str:
|
||||
return re.sub(r"\\(\d{3})", lambda x: chr(int(x.group(1))), s)
|
||||
|
||||
|
||||
def parse_avahi_output(output: str) -> DNSInfo:
|
||||
dns_info = DNSInfo(services={})
|
||||
for line in output.splitlines():
|
||||
parts = line.split(";")
|
||||
# New service discovered
|
||||
# print(parts)
|
||||
if parts[0] == "+" and len(parts) >= 6:
|
||||
interface, protocol, name, type_, domain = parts[1:6]
|
||||
|
||||
name = decode_escapes(name)
|
||||
|
||||
dns_info.services[name] = Host(
|
||||
interface=interface,
|
||||
protocol=protocol,
|
||||
name=name,
|
||||
type_=decode_escapes(type_),
|
||||
domain=domain,
|
||||
host=None,
|
||||
ip=None,
|
||||
port=None,
|
||||
txt=None,
|
||||
)
|
||||
|
||||
# Resolved more data for already discovered services
|
||||
elif parts[0] == "=" and len(parts) >= 9:
|
||||
interface, protocol, name, type_, domain, host, ip, port = parts[1:9]
|
||||
|
||||
name = decode_escapes(name)
|
||||
|
||||
if name in dns_info.services:
|
||||
dns_info.services[name].host = decode_escapes(host)
|
||||
dns_info.services[name].ip = ip
|
||||
dns_info.services[name].port = port
|
||||
if len(parts) > 9:
|
||||
dns_info.services[name].txt = decode_escapes(parts[9])
|
||||
else:
|
||||
dns_info.services[name] = Host(
|
||||
interface=parts[1],
|
||||
protocol=parts[2],
|
||||
name=name,
|
||||
type_=decode_escapes(parts[4]),
|
||||
domain=parts[5],
|
||||
host=decode_escapes(parts[6]),
|
||||
ip=parts[7],
|
||||
port=parts[8],
|
||||
txt=decode_escapes(parts[9]) if len(parts) > 9 else None,
|
||||
)
|
||||
|
||||
return dns_info
|
||||
|
||||
|
||||
@API.register
|
||||
def show_mdns() -> DNSInfo:
|
||||
cmd = nix_shell(
|
||||
["avahi"],
|
||||
[
|
||||
"avahi-browse",
|
||||
"--all",
|
||||
"--resolve",
|
||||
"--parsable",
|
||||
"-l", # Ignore local services
|
||||
"--terminate",
|
||||
],
|
||||
)
|
||||
proc = run_no_stdout(cmd)
|
||||
data = parse_avahi_output(proc.stdout)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def mdns_command(args: argparse.Namespace) -> None:
|
||||
dns_info = show_mdns()
|
||||
for name, info in dns_info.services.items():
|
||||
print(f"Hostname: {name} - ip: {info.ip}")
|
||||
|
||||
|
||||
def register_mdns(parser: argparse.ArgumentParser) -> None:
|
||||
parser.set_defaults(func=mdns_command)
|
||||
246
pkgs/clan-cli/clan_lib/api/modules.py
Normal file
246
pkgs/clan-cli/clan_lib/api/modules.py
Normal file
@@ -0,0 +1,246 @@
|
||||
import json
|
||||
import re
|
||||
import tomllib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from clan_cli.cmd import run_no_stdout
|
||||
from clan_cli.errors import ClanCmdError, ClanError
|
||||
from clan_cli.nix import nix_eval
|
||||
|
||||
from . import API
|
||||
|
||||
|
||||
class CategoryInfo(TypedDict):
|
||||
color: str
|
||||
description: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Frontmatter:
|
||||
description: str
|
||||
categories: list[str] = field(default_factory=lambda: ["Uncategorized"])
|
||||
features: list[str] = field(default_factory=list)
|
||||
constraints: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def categories_info(self) -> dict[str, CategoryInfo]:
|
||||
category_map: dict[str, CategoryInfo] = {
|
||||
"AudioVideo": {
|
||||
"color": "#AEC6CF",
|
||||
"description": "Applications for presenting, creating, or processing multimedia (audio/video)",
|
||||
},
|
||||
"Audio": {"color": "#CFCFC4", "description": "Audio"},
|
||||
"Video": {"color": "#FFD1DC", "description": "Video"},
|
||||
"Development": {"color": "#F49AC2", "description": "Development"},
|
||||
"Education": {"color": "#B39EB5", "description": "Education"},
|
||||
"Game": {"color": "#FFB347", "description": "Game"},
|
||||
"Graphics": {"color": "#FF6961", "description": "Graphics"},
|
||||
"Social": {"color": "#76D7C4", "description": "Social"},
|
||||
"Network": {"color": "#77DD77", "description": "Network"},
|
||||
"Office": {"color": "#85C1E9", "description": "Office"},
|
||||
"Science": {"color": "#779ECB", "description": "Science"},
|
||||
"System": {"color": "#F5C3C0", "description": "System"},
|
||||
"Settings": {"color": "#03C03C", "description": "Settings"},
|
||||
"Utility": {"color": "#B19CD9", "description": "Utility"},
|
||||
"Uncategorized": {"color": "#C23B22", "description": "Uncategorized"},
|
||||
}
|
||||
|
||||
return category_map
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
for category in self.categories:
|
||||
if category not in self.categories_info:
|
||||
msg = f"Invalid category: {category}"
|
||||
raise ValueError(msg)
|
||||
|
||||
|
||||
def parse_frontmatter(readme_content: str) -> tuple[dict[str, Any] | None, str]:
|
||||
"""
|
||||
Extracts TOML frontmatter from a string
|
||||
|
||||
Raises:
|
||||
- ClanError: If the toml frontmatter is invalid
|
||||
"""
|
||||
# Pattern to match YAML frontmatter enclosed by triple-dashed lines
|
||||
frontmatter_pattern = r"^---\s+(.*?)\s+---\s?+(.*)$"
|
||||
|
||||
# Search for the frontmatter using the pattern
|
||||
match = re.search(frontmatter_pattern, readme_content, re.DOTALL)
|
||||
|
||||
# If a match is found, return the frontmatter content
|
||||
match = re.search(frontmatter_pattern, readme_content, re.DOTALL)
|
||||
|
||||
# If a match is found, parse the TOML frontmatter and return both parts
|
||||
if match:
|
||||
frontmatter_raw, remaining_content = match.groups()
|
||||
try:
|
||||
# Parse the TOML frontmatter
|
||||
frontmatter_parsed = tomllib.loads(frontmatter_raw)
|
||||
except tomllib.TOMLDecodeError as e:
|
||||
msg = f"Error parsing TOML frontmatter: {e}"
|
||||
raise ClanError(
|
||||
msg,
|
||||
description="Invalid TOML frontmatter",
|
||||
location="extract_frontmatter",
|
||||
) from e
|
||||
|
||||
return frontmatter_parsed, remaining_content
|
||||
return None, readme_content
|
||||
|
||||
|
||||
def extract_frontmatter(readme_content: str, err_scope: str) -> tuple[Frontmatter, str]:
|
||||
"""
|
||||
Extracts TOML frontmatter from a README file content.
|
||||
|
||||
Parameters:
|
||||
- readme_content (str): The content of the README file as a string.
|
||||
|
||||
Returns:
|
||||
- str: The extracted frontmatter as a string.
|
||||
- str: The content of the README file without the frontmatter.
|
||||
|
||||
Raises:
|
||||
- ValueError: If the README does not contain valid frontmatter.
|
||||
"""
|
||||
frontmatter_raw, remaining_content = parse_frontmatter(readme_content)
|
||||
|
||||
if frontmatter_raw:
|
||||
return Frontmatter(**frontmatter_raw), remaining_content
|
||||
|
||||
# If no frontmatter is found, raise an error
|
||||
msg = "Invalid README: Frontmatter not found."
|
||||
raise ClanError(
|
||||
msg,
|
||||
location="extract_frontmatter",
|
||||
description=f"{err_scope} does not contain valid frontmatter.",
|
||||
)
|
||||
|
||||
|
||||
def has_inventory_feature(module_path: Path) -> bool:
|
||||
readme_file = module_path / "README.md"
|
||||
if not readme_file.exists():
|
||||
return False
|
||||
with readme_file.open() as f:
|
||||
readme = f.read()
|
||||
frontmatter, _ = extract_frontmatter(readme, f"{module_path}")
|
||||
return "inventory" in frontmatter.features
|
||||
|
||||
|
||||
def get_roles(module_path: Path) -> None | list[str]:
|
||||
if not has_inventory_feature(module_path):
|
||||
return None
|
||||
|
||||
roles_dir = module_path / "roles"
|
||||
if not roles_dir.exists() or not roles_dir.is_dir():
|
||||
return []
|
||||
|
||||
return [
|
||||
role.stem # filename without .nix extension
|
||||
for role in roles_dir.iterdir()
|
||||
if role.is_file() and role.suffix == ".nix"
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleInfo:
|
||||
description: str
|
||||
readme: str
|
||||
categories: list[str]
|
||||
roles: list[str] | None
|
||||
features: list[str] = field(default_factory=list)
|
||||
constraints: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def get_modules(base_path: str) -> dict[str, str]:
|
||||
cmd = nix_eval(
|
||||
[
|
||||
f"{base_path}#clanInternals.inventory.modules",
|
||||
"--json",
|
||||
]
|
||||
)
|
||||
try:
|
||||
proc = run_no_stdout(cmd)
|
||||
res = proc.stdout.strip()
|
||||
except ClanCmdError as e:
|
||||
msg = "clanInternals might not have inventory.modules attributes"
|
||||
raise ClanError(
|
||||
msg,
|
||||
location=f"list_modules {base_path}",
|
||||
description="Evaluation failed on clanInternals.inventory.modules attribute",
|
||||
) from e
|
||||
modules: dict[str, str] = json.loads(res)
|
||||
return modules
|
||||
|
||||
|
||||
@API.register
|
||||
def get_module_interface(base_path: str, module_name: str) -> dict[str, Any]:
|
||||
"""
|
||||
Check if a module exists and returns the interface schema
|
||||
Returns an error if the module does not exist or has an incorrect interface
|
||||
"""
|
||||
cmd = nix_eval([f"{base_path}#clanInternals.moduleSchemas.{module_name}", "--json"])
|
||||
try:
|
||||
proc = run_no_stdout(cmd)
|
||||
res = proc.stdout.strip()
|
||||
except ClanCmdError as e:
|
||||
msg = "clanInternals might not have moduleSchemas attributes"
|
||||
raise ClanError(
|
||||
msg,
|
||||
location=f"list_modules {base_path}",
|
||||
description="Evaluation failed on clanInternals.moduleSchemas attribute",
|
||||
) from e
|
||||
modules_schema: dict[str, Any] = json.loads(res)
|
||||
|
||||
return modules_schema
|
||||
|
||||
|
||||
@API.register
|
||||
def list_modules(base_path: str) -> dict[str, ModuleInfo]:
|
||||
"""
|
||||
Show information about a module
|
||||
"""
|
||||
modules = get_modules(base_path)
|
||||
return {
|
||||
module_name: get_module_info(module_name, Path(module_path))
|
||||
for module_name, module_path in modules.items()
|
||||
}
|
||||
|
||||
|
||||
def get_module_info(
|
||||
module_name: str,
|
||||
module_path: Path,
|
||||
) -> ModuleInfo:
|
||||
"""
|
||||
Retrieves information about a module
|
||||
"""
|
||||
if not module_path.exists():
|
||||
msg = "Module not found"
|
||||
raise ClanError(
|
||||
msg,
|
||||
location=f"show_module_info {module_name}",
|
||||
description="Module does not exist",
|
||||
)
|
||||
module_readme = module_path / "README.md"
|
||||
if not module_readme.exists():
|
||||
msg = "Module not found"
|
||||
raise ClanError(
|
||||
msg,
|
||||
location=f"show_module_info {module_name}",
|
||||
description="Module does not exist or doesn't have any README.md file",
|
||||
)
|
||||
with module_readme.open() as f:
|
||||
readme = f.read()
|
||||
frontmatter, readme_content = extract_frontmatter(
|
||||
readme, f"{module_path}/README.md"
|
||||
)
|
||||
|
||||
return ModuleInfo(
|
||||
description=frontmatter.description,
|
||||
categories=frontmatter.categories,
|
||||
roles=get_roles(module_path),
|
||||
readme=readme_content,
|
||||
features=["inventory"] if has_inventory_feature(module_path) else [],
|
||||
constraints=frontmatter.constraints,
|
||||
)
|
||||
331
pkgs/clan-cli/clan_lib/api/serde.py
Normal file
331
pkgs/clan-cli/clan_lib/api/serde.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
This module provides utility functions for serialization and deserialization of data classes.
|
||||
|
||||
Functions:
|
||||
- sanitize_string(s: str) -> str: Ensures a string is properly escaped for json serializing.
|
||||
- dataclass_to_dict(obj: Any) -> Any: Converts a data class and its nested data classes, lists, tuples, and dictionaries to dictionaries.
|
||||
- from_dict(t: type[T], data: Any) -> T: Dynamically instantiates a data class from a dictionary, constructing nested data classes, validates all required fields exist and have the expected type.
|
||||
|
||||
Classes:
|
||||
- TypeAdapter: A Pydantic type adapter for data classes.
|
||||
|
||||
Exceptions:
|
||||
- ValidationError: Raised when there is a validation error during deserialization.
|
||||
- ClanError: Raised when there is an error during serialization or deserialization.
|
||||
|
||||
Dependencies:
|
||||
- dataclasses: Provides the @dataclass decorator and related functions for creating data classes.
|
||||
- json: Provides functions for working with JSON data.
|
||||
- collections.abc: Provides abstract base classes for collections.
|
||||
- functools: Provides functions for working with higher-order functions and decorators.
|
||||
- inspect: Provides functions for inspecting live objects.
|
||||
- operator: Provides functions for working with operators.
|
||||
- pathlib: Provides classes for working with filesystem paths.
|
||||
- types: Provides functions for working with types.
|
||||
- typing: Provides support for type hints.
|
||||
- pydantic: A library for data validation and settings management.
|
||||
- pydantic_core: Core functionality for Pydantic.
|
||||
|
||||
Note: This module assumes the presence of other modules and classes such as `ClanError` and `ErrorDetails` from the `clan_cli.errors` module.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from dataclasses import dataclass, fields, is_dataclass
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from types import UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_typeddict,
|
||||
)
|
||||
|
||||
from clan_cli.errors import ClanError
|
||||
|
||||
|
||||
def sanitize_string(s: str) -> str:
|
||||
# Using the native string sanitizer to handle all edge cases
|
||||
# Remove the outer quotes '"string"'
|
||||
# return json.dumps(s)[1:-1]
|
||||
return s
|
||||
|
||||
|
||||
def is_enum(obj: Any) -> bool:
|
||||
"""
|
||||
Safely checks if the object or one of its attributes is an Enum.
|
||||
"""
|
||||
# Check if the object itself is an Enum
|
||||
if isinstance(obj, Enum):
|
||||
return True
|
||||
|
||||
# Check if the object has an 'enum' attribute and if it's an Enum
|
||||
enum_attr = getattr(obj, "enum", None)
|
||||
return isinstance(enum_attr, Enum)
|
||||
|
||||
|
||||
def get_enum_value(obj: Any) -> Any:
|
||||
"""
|
||||
Safely checks if the object or one of its attributes is an Enum.
|
||||
"""
|
||||
# Check if the object itself is an Enum
|
||||
value = getattr(obj, "value", None)
|
||||
if value is None and obj.enum:
|
||||
value = getattr(obj.enum, "value", None)
|
||||
|
||||
if value is None:
|
||||
error_msg = f"Cannot determine enum value for {obj}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return dataclass_to_dict(value)
|
||||
|
||||
|
||||
def dataclass_to_dict(obj: Any, *, use_alias: bool = True) -> Any:
|
||||
def _to_dict(obj: Any) -> Any:
|
||||
"""
|
||||
Utility function to convert dataclasses to dictionaries
|
||||
It converts all nested dataclasses, lists, tuples, and dictionaries to dictionaries
|
||||
|
||||
It does NOT convert member functions.
|
||||
"""
|
||||
if is_enum(obj):
|
||||
return get_enum_value(obj)
|
||||
if is_dataclass(obj):
|
||||
return {
|
||||
# Use either the original name or name
|
||||
sanitize_string(
|
||||
field.metadata.get("alias", field.name) if use_alias else field.name
|
||||
): _to_dict(getattr(obj, field.name))
|
||||
for field in fields(obj)
|
||||
if not field.name.startswith("_")
|
||||
and getattr(obj, field.name) is not None # type: ignore
|
||||
}
|
||||
if isinstance(obj, list | tuple | set):
|
||||
return [_to_dict(item) for item in obj]
|
||||
if isinstance(obj, dict):
|
||||
return {sanitize_string(k): _to_dict(v) for k, v in obj.items()}
|
||||
if isinstance(obj, Path):
|
||||
return sanitize_string(str(obj))
|
||||
if isinstance(obj, str):
|
||||
return sanitize_string(obj)
|
||||
return obj
|
||||
|
||||
return _to_dict(obj)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=dataclass) # type: ignore
|
||||
|
||||
|
||||
def is_union_type(type_hint: type | UnionType) -> bool:
|
||||
return (
|
||||
type(type_hint) is UnionType
|
||||
or isinstance(type_hint, UnionType)
|
||||
or get_origin(type_hint) is Union
|
||||
)
|
||||
|
||||
|
||||
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||
if get_origin(union_type) is UnionType:
|
||||
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||
return union_type == target_type
|
||||
|
||||
|
||||
def unwrap_none_type(type_hint: type | UnionType) -> type:
|
||||
"""
|
||||
Takes a type union and returns the first non-None type.
|
||||
None | str
|
||||
=>
|
||||
str
|
||||
"""
|
||||
|
||||
if is_union_type(type_hint):
|
||||
# Return the first non-None type
|
||||
return next(t for t in get_args(type_hint) if t is not type(None))
|
||||
|
||||
return type_hint # type: ignore
|
||||
|
||||
|
||||
JsonValue = str | float | dict[str, Any] | list[Any] | None
|
||||
|
||||
|
||||
def construct_value(
|
||||
t: type | UnionType, field_value: JsonValue, loc: list[str] | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Construct a field value from a type hint and a field value.
|
||||
"""
|
||||
if loc is None:
|
||||
loc = []
|
||||
if t is None and field_value:
|
||||
msg = f"Trying to construct field of type None. But got: {field_value}. loc: {loc}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
|
||||
if is_type_in_union(t, type(None)) and field_value is None:
|
||||
# Sometimes the field value is None, which is valid if the type hint allows None
|
||||
return None
|
||||
|
||||
# If the field is another dataclass
|
||||
# Field_value must be a dictionary
|
||||
if is_dataclass(t) and isinstance(field_value, dict):
|
||||
assert isinstance(t, type)
|
||||
return construct_dataclass(t, field_value)
|
||||
|
||||
# If the field expects a path
|
||||
# Field_value must be a string
|
||||
if is_type_in_union(t, Path):
|
||||
if not isinstance(field_value, str):
|
||||
msg = (
|
||||
f"Expected string, cannot construct pathlib.Path() from: {field_value} "
|
||||
)
|
||||
raise ClanError(
|
||||
msg,
|
||||
location=f"{loc}",
|
||||
)
|
||||
|
||||
return Path(field_value)
|
||||
|
||||
if t is str:
|
||||
if not isinstance(field_value, str):
|
||||
msg = f"Expected string, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
|
||||
return field_value
|
||||
|
||||
if t is int and not isinstance(field_value, str):
|
||||
return int(field_value) # type: ignore
|
||||
if t is float and not isinstance(field_value, str):
|
||||
return float(field_value) # type: ignore
|
||||
if t is bool and isinstance(field_value, bool):
|
||||
return field_value # type: ignore
|
||||
|
||||
# Union types construct the first non-None type
|
||||
if is_union_type(t):
|
||||
# Unwrap the union type
|
||||
inner = unwrap_none_type(t)
|
||||
# Construct the field value
|
||||
return construct_value(inner, field_value)
|
||||
|
||||
# Nested types
|
||||
# list
|
||||
# dict
|
||||
origin = get_origin(t)
|
||||
if origin is list:
|
||||
if not isinstance(field_value, list):
|
||||
msg = f"Expected list, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
|
||||
return [construct_value(get_args(t)[0], item) for item in field_value]
|
||||
if origin is dict and isinstance(field_value, dict):
|
||||
return {
|
||||
key: construct_value(get_args(t)[1], value)
|
||||
for key, value in field_value.items()
|
||||
}
|
||||
if origin is Literal:
|
||||
valid_values = get_args(t)
|
||||
if field_value not in valid_values:
|
||||
msg = f"Expected one of {', '.join(valid_values)}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
return field_value
|
||||
|
||||
# Enums
|
||||
if origin is Enum:
|
||||
try:
|
||||
return t(field_value) # type: ignore
|
||||
except ValueError:
|
||||
msg = f"Expected one of {', '.join(str(origin))}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||
|
||||
if isinstance(t, type) and issubclass(t, Enum):
|
||||
try:
|
||||
return t(field_value) # type: ignore
|
||||
except ValueError:
|
||||
msg = f"Expected one of {', '.join(t.__members__)}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}") from ValueError
|
||||
|
||||
if origin is Annotated:
|
||||
(base_type,) = get_args(t)
|
||||
return construct_value(base_type, field_value)
|
||||
|
||||
# elif get_origin(t) is Union:
|
||||
if t is Any:
|
||||
return field_value
|
||||
|
||||
if is_typeddict(t):
|
||||
if not isinstance(field_value, dict):
|
||||
msg = f"Expected TypedDict {t}, got {field_value}"
|
||||
raise ClanError(msg, location=f"{loc}")
|
||||
|
||||
return t(field_value) # type: ignore
|
||||
|
||||
msg = f"Unhandled field type {t} with value {field_value}"
|
||||
raise ClanError(msg)
|
||||
|
||||
|
||||
def construct_dataclass(
|
||||
t: type[T], data: dict[str, Any], path: list[str] | None = None
|
||||
) -> T:
|
||||
"""
|
||||
type t MUST be a dataclass
|
||||
Dynamically instantiate a data class from a dictionary, handling nested data classes.
|
||||
"""
|
||||
if path is None:
|
||||
path = []
|
||||
if not is_dataclass(t):
|
||||
msg = f"{t.__name__} is not a dataclass"
|
||||
raise ClanError(msg)
|
||||
|
||||
# Attempt to create an instance of the data_class#
|
||||
field_values: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
|
||||
for field in fields(t):
|
||||
if field.name.startswith("_"):
|
||||
continue
|
||||
# The first type in a Union
|
||||
# str <- None | str | Path
|
||||
field_type: type[Any] = unwrap_none_type(field.type) # type: ignore
|
||||
data_field_name = field.metadata.get("alias", field.name)
|
||||
|
||||
if (
|
||||
field.default is dataclasses.MISSING
|
||||
and field.default_factory is dataclasses.MISSING
|
||||
):
|
||||
required.append(field.name)
|
||||
|
||||
# Populate the field_values dictionary with the field value
|
||||
# if present in the data
|
||||
if data_field_name in data:
|
||||
field_value = data.get(data_field_name)
|
||||
|
||||
if field_value is None and (
|
||||
field.type is None or is_type_in_union(field.type, type(None)) # type: ignore
|
||||
):
|
||||
field_values[field.name] = None
|
||||
else:
|
||||
field_values[field.name] = construct_value(field_type, field_value)
|
||||
|
||||
# Check that all required field are present.
|
||||
for field_name in required:
|
||||
if field_name not in field_values:
|
||||
formatted_path = " ".join(path)
|
||||
msg = f"Default value missing for: '{field_name}' in {t} {formatted_path}, got Value: {data}"
|
||||
raise ClanError(msg)
|
||||
|
||||
return t(**field_values) # type: ignore
|
||||
|
||||
|
||||
def from_dict(
|
||||
t: type | UnionType, data: dict[str, Any] | Any, path: list[str] | None = None
|
||||
) -> Any:
|
||||
if path is None:
|
||||
path = []
|
||||
if is_dataclass(t):
|
||||
if not isinstance(data, dict):
|
||||
msg = f"{data} is not a dict. Expected {t}"
|
||||
raise ClanError(msg)
|
||||
return construct_dataclass(t, data, path) # type: ignore
|
||||
return construct_value(t, data, path)
|
||||
291
pkgs/clan-cli/clan_lib/api/util.py
Normal file
291
pkgs/clan-cli/clan_lib/api/util.py
Normal file
@@ -0,0 +1,291 @@
|
||||
import copy
|
||||
import dataclasses
|
||||
import pathlib
|
||||
from dataclasses import MISSING
|
||||
from enum import EnumType
|
||||
from inspect import get_annotations
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
NewType,
|
||||
NotRequired,
|
||||
Required,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
is_typeddict,
|
||||
)
|
||||
|
||||
from clan_lib.api.serde import dataclass_to_dict
|
||||
|
||||
|
||||
class JSchemaTypeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Inspect the fields of the parameterized type
|
||||
def inspect_dataclass_fields(t: type) -> dict[TypeVar, type]:
|
||||
"""
|
||||
Returns a map of type variables to actual types for a parameterized type.
|
||||
"""
|
||||
origin = get_origin(t)
|
||||
type_args = get_args(t)
|
||||
if origin is None:
|
||||
return {}
|
||||
|
||||
type_params = origin.__parameters__
|
||||
# Create a map from type parameters to actual type arguments
|
||||
type_map = dict(zip(type_params, type_args, strict=False))
|
||||
|
||||
return type_map
|
||||
|
||||
|
||||
def apply_annotations(schema: dict[str, Any], annotations: list[Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Add metadata from typing.annotations to the json Schema.
|
||||
The annotations can be a dict, a tuple, or a string and is directly applied to the schema as shown below.
|
||||
No further validation is done, the caller is responsible for following json-schema.
|
||||
|
||||
Examples
|
||||
|
||||
```python
|
||||
# String annotation
|
||||
Annotated[int, "This is an int"] -> {"type": "integer", "description": "This is an int"}
|
||||
|
||||
# Dict annotation
|
||||
Annotated[int, {"minimum": 0, "maximum": 10}] -> {"type": "integer", "minimum": 0, "maximum": 10}
|
||||
|
||||
# Tuple annotation
|
||||
Annotated[int, ("minimum", 0)] -> {"type": "integer", "minimum": 0}
|
||||
```
|
||||
"""
|
||||
for annotation in annotations:
|
||||
if isinstance(annotation, dict):
|
||||
# Assuming annotation is a dict that can directly apply to the schema
|
||||
schema.update(annotation)
|
||||
elif isinstance(annotation, tuple) and len(annotation) == 2:
|
||||
# Assuming a tuple where first element is a keyword (like 'minLength') and the second is the value
|
||||
schema[annotation[0]] = annotation[1]
|
||||
elif isinstance(annotation, str):
|
||||
# String annotations can be used for description
|
||||
schema.update({"description": f"{annotation}"})
|
||||
return schema
|
||||
|
||||
|
||||
def is_typed_dict(t: type) -> bool:
|
||||
return is_typeddict(t)
|
||||
|
||||
|
||||
# Function to get member names and their types
|
||||
def get_typed_dict_fields(typed_dict_class: type, scope: str) -> dict[str, type]:
|
||||
"""Retrieve member names and their types from a TypedDict."""
|
||||
if not hasattr(typed_dict_class, "__annotations__"):
|
||||
msg = f"{typed_dict_class} is not a TypedDict."
|
||||
raise JSchemaTypeError(msg, scope)
|
||||
return get_annotations(typed_dict_class)
|
||||
|
||||
|
||||
def is_type_in_union(union_type: type | UnionType, target_type: type) -> bool:
|
||||
if get_origin(union_type) is UnionType:
|
||||
return any(issubclass(arg, target_type) for arg in get_args(union_type))
|
||||
return union_type == target_type
|
||||
|
||||
|
||||
def is_total(typed_dict_class: type) -> bool:
|
||||
"""
|
||||
Check if a TypedDict has total=true
|
||||
https://typing.readthedocs.io/en/latest/spec/typeddict.html#interaction-with-total-false
|
||||
"""
|
||||
return getattr(typed_dict_class, "__total__", True) # Default to True if not set
|
||||
|
||||
|
||||
def type_to_dict(
|
||||
t: Any, scope: str = "", type_map: dict[TypeVar, type] | None = None
|
||||
) -> dict:
|
||||
if type_map is None:
|
||||
type_map = {}
|
||||
if t is None:
|
||||
return {"type": "null"}
|
||||
|
||||
if dataclasses.is_dataclass(t):
|
||||
fields = dataclasses.fields(t)
|
||||
properties = {}
|
||||
for f in fields:
|
||||
if f.name.startswith("_"):
|
||||
continue
|
||||
assert not isinstance(f.type, str), (
|
||||
f"Expected field type to be a type, got {f.type}, Have you imported `from __future__ import annotations`?"
|
||||
)
|
||||
properties[f.metadata.get("alias", f.name)] = type_to_dict(
|
||||
f.type,
|
||||
f"{scope} {t.__name__}.{f.name}", # type: ignore
|
||||
type_map, # type: ignore
|
||||
)
|
||||
|
||||
required = set()
|
||||
for pn, pv in properties.items():
|
||||
if pv.get("type") is not None:
|
||||
if "null" not in pv["type"]:
|
||||
required.add(pn)
|
||||
|
||||
elif pv.get("oneOf") is not None:
|
||||
if "null" not in [i.get("type") for i in pv.get("oneOf", [])]:
|
||||
required.add(pn)
|
||||
|
||||
required_fields = {
|
||||
f.name
|
||||
for f in fields
|
||||
if f.default is MISSING and f.default_factory is MISSING
|
||||
}
|
||||
|
||||
# Find intersection
|
||||
intersection = required & required_fields
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": list(intersection),
|
||||
# Dataclasses can only have the specified properties
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
if is_typed_dict(t):
|
||||
dict_fields = get_typed_dict_fields(t, scope)
|
||||
dict_properties: dict = {}
|
||||
dict_required: list[str] = []
|
||||
for field_name, field_type in dict_fields.items():
|
||||
if (
|
||||
not is_type_in_union(field_type, type(None))
|
||||
and get_origin(field_type) is not NotRequired
|
||||
) or get_origin(field_type) is Required:
|
||||
dict_required.append(field_name)
|
||||
|
||||
dict_properties[field_name] = type_to_dict(
|
||||
field_type, f"{scope} {t.__name__}.{field_name}", type_map
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": dict_properties,
|
||||
"required": dict_required if is_total(t) else [],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
if type(t) is UnionType:
|
||||
return {
|
||||
"oneOf": [type_to_dict(arg, scope, type_map) for arg in t.__args__],
|
||||
}
|
||||
|
||||
if isinstance(t, TypeVar):
|
||||
# if t is a TypeVar, look up the type in the type_map
|
||||
# And return the resolved type instead of the TypeVar
|
||||
resolved = type_map.get(t)
|
||||
if not resolved:
|
||||
msg = f"{scope} - TypeVar {t} not found in type_map, map: {type_map}"
|
||||
raise JSchemaTypeError(msg)
|
||||
return type_to_dict(type_map.get(t), scope, type_map)
|
||||
|
||||
if isinstance(t, NewType):
|
||||
origtype = t.__supertype__
|
||||
return type_to_dict(origtype, scope, type_map)
|
||||
|
||||
if hasattr(t, "__origin__"): # Check if it's a generic type
|
||||
origin = get_origin(t)
|
||||
args = get_args(t)
|
||||
|
||||
if origin is None:
|
||||
# Non-generic user-defined or built-in type
|
||||
# TODO: handle custom types
|
||||
msg = f"{scope} Unhandled Type: "
|
||||
raise JSchemaTypeError(msg, origin)
|
||||
|
||||
if origin is Literal:
|
||||
# Handle Literal values for enums in JSON Schema
|
||||
return {
|
||||
"type": "string",
|
||||
"enum": list(args), # assumes all args are strings
|
||||
}
|
||||
|
||||
if origin is Annotated:
|
||||
base_type, *metadata = get_args(t)
|
||||
schema = type_to_dict(base_type, scope) # Generate schema for the base type
|
||||
return apply_annotations(schema, metadata)
|
||||
|
||||
if origin is Union:
|
||||
union_types = [type_to_dict(arg, scope, type_map) for arg in t.__args__]
|
||||
return {
|
||||
"oneOf": union_types,
|
||||
}
|
||||
|
||||
if origin in {list, set, frozenset, tuple}:
|
||||
return {
|
||||
"type": "array",
|
||||
"items": type_to_dict(t.__args__[0], scope, type_map),
|
||||
}
|
||||
|
||||
# Used to mark optional fields in TypedDict
|
||||
# Here we just unwrap the type and return the schema for the inner type
|
||||
if origin is NotRequired or origin is Required:
|
||||
return type_to_dict(t.__args__[0], scope, type_map)
|
||||
|
||||
if issubclass(origin, dict):
|
||||
value_type = t.__args__[1]
|
||||
if value_type is Any:
|
||||
return {"type": "object", "additionalProperties": True}
|
||||
return {
|
||||
"type": "object",
|
||||
"additionalProperties": type_to_dict(value_type, scope, type_map),
|
||||
}
|
||||
# Generic dataclass with type parameters
|
||||
if dataclasses.is_dataclass(origin):
|
||||
# This behavior should mimic the scoping of typeVars in dataclasses
|
||||
# Once type_to_dict() encounters a TypeVar, it will look up the type in the type_map
|
||||
# When type_to_dict() returns the map goes out of scope.
|
||||
# This behaves like a stack, where the type_map is pushed and popped as we traverse the dataclass fields
|
||||
new_map = copy.deepcopy(type_map)
|
||||
new_map.update(inspect_dataclass_fields(t))
|
||||
return type_to_dict(origin, scope, new_map)
|
||||
|
||||
msg = f"{scope} - Error api type not yet supported {t!s}"
|
||||
raise JSchemaTypeError(msg)
|
||||
|
||||
if isinstance(t, type):
|
||||
if t is str:
|
||||
return {"type": "string"}
|
||||
if t is int:
|
||||
return {"type": "integer"}
|
||||
if t is float:
|
||||
return {"type": "number"}
|
||||
if t is bool:
|
||||
return {"type": "boolean"}
|
||||
if t is object:
|
||||
return {"type": "object"}
|
||||
if type(t) is EnumType:
|
||||
return {
|
||||
"type": "string",
|
||||
# Construct every enum value and use the same method as the serde module for converting it into the same literal string
|
||||
"enum": [dataclass_to_dict(t(value)) for value in t], # type: ignore
|
||||
}
|
||||
if t is Any:
|
||||
msg = f"{scope} - Usage of the Any type is not supported for API functions. In: {scope}"
|
||||
raise JSchemaTypeError(msg)
|
||||
if t is pathlib.Path:
|
||||
return {
|
||||
# TODO: maybe give it a pattern for URI
|
||||
"type": "string",
|
||||
}
|
||||
if t is dict:
|
||||
msg = f"{scope} - Generic 'dict' type not supported. Use dict[str, Any] or any more expressive type."
|
||||
raise JSchemaTypeError(msg)
|
||||
|
||||
# Optional[T] gets internally transformed Union[T,NoneType]
|
||||
if t is NoneType:
|
||||
return {"type": "null"}
|
||||
|
||||
msg = f"{scope} - Basic type '{t!s}' is not supported"
|
||||
raise JSchemaTypeError(msg)
|
||||
msg = f"{scope} - Type '{t!s}' is not supported"
|
||||
raise JSchemaTypeError(msg)
|
||||
Reference in New Issue
Block a user