Merge pull request 'clan-app: WebExecutor now mirrors jsonschema api types generically' (#1768) from Qubasa/clan-core:Qubasa-main into main

This commit is contained in:
clan-bot
2024-07-16 14:38:17 +00:00
5 changed files with 78 additions and 55 deletions

View File

@@ -1,6 +1,14 @@
import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import Any, ClassVar, Generic, ParamSpec, TypeVar, cast from typing import (
Any,
ClassVar,
Generic,
ParamSpec,
TypeVar,
cast,
)
from gi.repository import GLib, GObject from gi.repository import GLib, GObject
@@ -12,9 +20,8 @@ class GResult(GObject.Object):
op_key: str op_key: str
method_name: str method_name: str
def __init__(self, result: Any, method_name: str, op_key: str) -> None: def __init__(self, result: Any, method_name: str) -> None:
super().__init__() super().__init__()
self.op_key = op_key
self.result = result self.result = result
self.method_name = method_name self.method_name = method_name
@@ -32,9 +39,13 @@ class ImplFunc(GObject.Object, Generic[P, B]):
def returns(self, result: B, *, method_name: str | None = None) -> None: def returns(self, result: B, *, method_name: str | None = None) -> None:
if method_name is None: if method_name is None:
method_name = self.__class__.__name__ method_name = self.__class__.__name__
if self.op_key is None:
raise ValueError(f"op_key is not set for the function {method_name}") self.emit("returns", GResult(result, method_name))
self.emit("returns", GResult(result, method_name, self.op_key))
def _signature_check(self, *args: P.args, **kwargs: P.kwargs) -> B:
raise RuntimeError(
"This method is only for typechecking and should never be called"
)
def await_result(self, fn: Callable[["ImplFunc[..., Any]", B], None]) -> None: def await_result(self, fn: Callable[["ImplFunc[..., Any]", B], None]) -> None:
self.connect("returns", fn) self.connect("returns", fn)
@@ -42,8 +53,7 @@ class ImplFunc(GObject.Object, Generic[P, B]):
def async_run(self, *args: P.args, **kwargs: P.kwargs) -> bool: def async_run(self, *args: P.args, **kwargs: P.kwargs) -> bool:
raise NotImplementedError("Method 'async_run' must be implemented") raise NotImplementedError("Method 'async_run' must be implemented")
def _async_run(self, data: Any, op_key: str) -> bool: def _async_run(self, data: Any) -> bool:
self.op_key = op_key
result = GLib.SOURCE_REMOVE result = GLib.SOURCE_REMOVE
try: try:
result = self.async_run(**data) result = self.async_run(**data)
@@ -62,33 +72,33 @@ class GObjApi:
def overwrite_fn(self, obj: type[ImplFunc]) -> None: def overwrite_fn(self, obj: type[ImplFunc]) -> None:
fn_name = obj.__name__ fn_name = obj.__name__
if not isinstance(obj, type(ImplFunc)):
raise ValueError(f"Object '{fn_name}' is not an instance of ImplFunc")
if fn_name in self._obj_registry: if fn_name in self._obj_registry:
raise ValueError(f"Function '{fn_name}' already registered") raise ValueError(f"Function '{fn_name}' already registered")
self._obj_registry[fn_name] = obj self._obj_registry[fn_name] = obj
def check_signature(self, method_annotations: dict[str, dict[str, Any]]) -> None: def check_signature(self, fn_signatures: dict[str, inspect.Signature]) -> None:
overwrite_fns = self._obj_registry overwrite_fns = self._obj_registry
# iterate over the methods and check if all are implemented # iterate over the methods and check if all are implemented
for m_name, m_annotations in method_annotations.items(): for m_name, m_signature in fn_signatures.items():
if m_name not in overwrite_fns: if m_name not in overwrite_fns:
continue continue
else: else:
# check if the signature of the abstract method matches the implementation # check if the signature of the overriden method matches
# abstract signature # the implementation signature
values = list(m_annotations.values()) exp_args = []
expected_signature = (tuple(values[:-1]), values[-1:][0]) exp_return = m_signature.return_annotation
for param in dict(m_signature.parameters).values():
exp_args.append(param.annotation)
exp_signature = (tuple(exp_args), exp_return)
# implementation signature # implementation signature
obj = dict(overwrite_fns[m_name].__dict__) obj = dict(overwrite_fns[m_name].__dict__)
obj_type = obj["__orig_bases__"][0] obj_type = obj["__orig_bases__"][0]
got_signature = obj_type.__args__ got_signature = obj_type.__args__
if expected_signature != got_signature: if exp_signature != got_signature:
log.error(f"Expected signature: {expected_signature}") log.error(f"Expected signature: {exp_signature}")
log.error(f"Actual signature: {got_signature}") log.error(f"Actual signature: {got_signature}")
raise ValueError( raise ValueError(
f"Overwritten method '{m_name}' has different signature than the implementation" f"Overwritten method '{m_name}' has different signature than the implementation"

View File

@@ -5,6 +5,7 @@ gi.require_version("Gtk", "4.0")
import logging import logging
from clan_cli.api import ErrorDataClass, SuccessDataClass
from clan_cli.api.directory import FileRequest from clan_cli.api.directory import FileRequest
from gi.repository import Gio, GLib, Gtk from gi.repository import Gio, GLib, Gtk
@@ -15,17 +16,23 @@ log = logging.getLogger(__name__)
# This implements the abstract function open_file with one argument, file_request, # This implements the abstract function open_file with one argument, file_request,
# which is a FileRequest object and returns a string or None. # which is a FileRequest object and returns a string or None.
class open_file(ImplFunc[[FileRequest], str | None]): class open_file(
ImplFunc[[FileRequest, str], SuccessDataClass[str | None] | ErrorDataClass]
):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def async_run(self, file_request: FileRequest) -> bool: def async_run(self, file_request: FileRequest, op_key: str) -> bool:
def on_file_select(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None: def on_file_select(file_dialog: Gtk.FileDialog, task: Gio.Task) -> None:
try: try:
gfile = file_dialog.open_finish(task) gfile = file_dialog.open_finish(task)
if gfile: if gfile:
selected_path = gfile.get_path() selected_path = gfile.get_path()
self.returns(selected_path) self.returns(
SuccessDataClass(
op_key=op_key, data=selected_path, status="success"
)
)
except Exception as e: except Exception as e:
print(f"Error getting selected file or directory: {e}") print(f"Error getting selected file or directory: {e}")
@@ -34,7 +41,11 @@ class open_file(ImplFunc[[FileRequest], str | None]):
gfile = file_dialog.select_folder_finish(task) gfile = file_dialog.select_folder_finish(task)
if gfile: if gfile:
selected_path = gfile.get_path() selected_path = gfile.get_path()
self.returns(selected_path) self.returns(
SuccessDataClass(
op_key=op_key, data=selected_path, status="success"
)
)
except Exception as e: except Exception as e:
print(f"Error getting selected directory: {e}") print(f"Error getting selected directory: {e}")
@@ -43,7 +54,11 @@ class open_file(ImplFunc[[FileRequest], str | None]):
gfile = file_dialog.save_finish(task) gfile = file_dialog.save_finish(task)
if gfile: if gfile:
selected_path = gfile.get_path() selected_path = gfile.get_path()
self.returns(selected_path) self.returns(
SuccessDataClass(
op_key=op_key, data=selected_path, status="success"
)
)
except Exception as e: except Exception as e:
print(f"Error getting selected file: {e}") print(f"Error getting selected file: {e}")

View File

@@ -17,9 +17,9 @@ log = logging.getLogger(__name__)
class WebExecutor(GObject.Object): class WebExecutor(GObject.Object):
def __init__(self, content_uri: str, plain_api: MethodRegistry) -> None: def __init__(self, content_uri: str, jschema_api: MethodRegistry) -> None:
super().__init__() super().__init__()
self.plain_api: MethodRegistry = plain_api self.jschema_api: MethodRegistry = jschema_api
self.webview: WebKit.WebView = WebKit.WebView() self.webview: WebKit.WebView = WebKit.WebView()
settings: WebKit.Settings = self.webview.get_settings() settings: WebKit.Settings = self.webview.get_settings()
@@ -40,10 +40,10 @@ class WebExecutor(GObject.Object):
self.webview.load_uri(content_uri) self.webview.load_uri(content_uri)
self.content_uri = content_uri self.content_uri = content_uri
self.api: GObjApi = GObjApi(self.plain_api.functions) self.api: GObjApi = GObjApi(self.jschema_api.functions)
self.api.overwrite_fn(open_file) self.api.overwrite_fn(open_file)
self.api.check_signature(self.plain_api.annotations) self.api.check_signature(self.jschema_api.signatures)
def on_decide_policy( def on_decide_policy(
self, self,
@@ -94,30 +94,25 @@ class WebExecutor(GObject.Object):
# Initialize dataclasses from the payload # Initialize dataclasses from the payload
reconciled_arguments = {} reconciled_arguments = {}
op_key = data.pop("op_key", None)
for k, v in data.items(): for k, v in data.items():
# Some functions expect to be called with dataclass instances # Some functions expect to be called with dataclass instances
# But the js api returns dictionaries. # But the js api returns dictionaries.
# Introspect the function and create the expected dataclass from dict dynamically # Introspect the function and create the expected dataclass from dict dynamically
# Depending on the introspected argument_type # Depending on the introspected argument_type
arg_class = self.plain_api.get_method_argtype(method_name, k) arg_class = self.jschema_api.get_method_argtype(method_name, k)
if dataclasses.is_dataclass(arg_class): if dataclasses.is_dataclass(arg_class):
reconciled_arguments[k] = from_dict(arg_class, v) reconciled_arguments[k] = from_dict(arg_class, v)
else: else:
reconciled_arguments[k] = v reconciled_arguments[k] = v
GLib.idle_add( GLib.idle_add(fn_instance._async_run, reconciled_arguments)
fn_instance._async_run,
reconciled_arguments,
op_key,
)
def on_result(self, source: ImplFunc, data: GResult) -> None: def on_result(self, source: ImplFunc, data: GResult) -> None:
result = dict() result = dataclass_to_dict(data.result)
result["result"] = dataclass_to_dict(data.result)
result["op_key"] = data.op_key
serialized = json.dumps(result, indent=4) serialized = json.dumps(result, indent=4)
log.debug(f"Result for {data.method_name}: {serialized}") log.debug(f"Result for {data.method_name}: {serialized}")
# Use idle_add to queue the response call to js on the main GTK thread # Use idle_add to queue the response call to js on the main GTK thread
self.return_data_to_js(data.method_name, serialized) self.return_data_to_js(data.method_name, serialized)

View File

@@ -36,7 +36,7 @@ class MainWindow(Adw.ApplicationWindow):
stack_view = ViewStack.use().view stack_view = ViewStack.use().view
webexec = WebExecutor(plain_api=API, content_uri=config.content_uri) webexec = WebExecutor(jschema_api=API, content_uri=config.content_uri)
stack_view.add_named(webexec.get_webview(), "webview") stack_view.add_named(webexec.get_webview(), "webview")
stack_view.set_visible_child_name(config.initial_view) stack_view.set_visible_child_name(config.initial_view)

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable from collections.abc import Callable
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from inspect import Parameter, signature from inspect import Parameter, Signature, signature
from typing import Annotated, Any, Generic, Literal, TypeVar, get_type_hints from typing import Annotated, Any, Generic, Literal, TypeVar, get_type_hints
from clan_cli.errors import ClanError from clan_cli.errors import ClanError
@@ -20,12 +20,14 @@ class ApiError:
@dataclass @dataclass
class SuccessDataClass(Generic[ResponseDataType]): class SuccessDataClass(Generic[ResponseDataType]):
op_key: str
status: Annotated[Literal["success"], "The status of the response."] status: Annotated[Literal["success"], "The status of the response."]
data: ResponseDataType data: ResponseDataType
@dataclass @dataclass
class ErrorDataClass: class ErrorDataClass:
op_key: str
status: Literal["error"] status: Literal["error"]
errors: list[ApiError] errors: list[ApiError]
@@ -39,7 +41,7 @@ def update_wrapper_signature(wrapper: Callable, wrapped: Callable) -> None:
# Add 'op_key' parameter # Add 'op_key' parameter
op_key_param = Parameter( op_key_param = Parameter(
"op_key", Parameter.KEYWORD_ONLY, default=None, annotation=str | None "op_key", Parameter.KEYWORD_ONLY, default=None, annotation=str
) )
params.append(op_key_param) params.append(op_key_param)
@@ -50,26 +52,28 @@ def update_wrapper_signature(wrapper: Callable, wrapped: Callable) -> None:
class MethodRegistry: class MethodRegistry:
def __init__(self) -> None: def __init__(self) -> None:
self._orig_annotations: dict[str, dict[str, Any]] = {} self._orig_signature: dict[str, Signature] = {}
self._registry: dict[str, Callable[..., Any]] = {} self._registry: dict[str, Callable[..., Any]] = {}
@property @property
def annotations(self) -> dict[str, dict[str, Any]]: def orig_signatures(self) -> dict[str, Signature]:
return self._orig_annotations return self._orig_signature
@property
def signatures(self) -> dict[str, Signature]:
return {name: signature(fn) for name, fn in self.functions.items()}
@property @property
def functions(self) -> dict[str, Callable[..., Any]]: def functions(self) -> dict[str, Callable[..., Any]]:
return self._registry return self._registry
def reset(self) -> None: def reset(self) -> None:
self._orig_annotations.clear() self._orig_signature.clear()
self._registry.clear() self._registry.clear()
def register_abstract(self, fn: Callable[..., T]) -> Callable[..., T]: def register_abstract(self, fn: Callable[..., T]) -> Callable[..., T]:
@wraps(fn) @wraps(fn)
def wrapper( def wrapper(*args: Any, op_key: str, **kwargs: Any) -> ApiResponse[T]:
*args: Any, op_key: str | None = None, **kwargs: Any
) -> ApiResponse[T]:
raise NotImplementedError( raise NotImplementedError(
f"""{fn.__name__} - The platform didn't implement this function. f"""{fn.__name__} - The platform didn't implement this function.
@@ -96,20 +100,19 @@ API.register(open_file)
def register(self, fn: Callable[..., T]) -> Callable[..., T]: def register(self, fn: Callable[..., T]) -> Callable[..., T]:
if fn.__name__ in self._registry: if fn.__name__ in self._registry:
raise ValueError(f"Function {fn.__name__} already registered") raise ValueError(f"Function {fn.__name__} already registered")
if fn.__name__ in self._orig_annotations: if fn.__name__ in self._orig_signature:
raise ValueError(f"Function {fn.__name__} already registered") raise ValueError(f"Function {fn.__name__} already registered")
# make copy of original function # make copy of original function
self._orig_annotations[fn.__name__] = fn.__annotations__.copy() self._orig_signature[fn.__name__] = signature(fn)
@wraps(fn) @wraps(fn)
def wrapper( def wrapper(*args: Any, op_key: str, **kwargs: Any) -> ApiResponse[T]:
*args: Any, op_key: str | None = None, **kwargs: Any
) -> ApiResponse[T]:
try: try:
data: T = fn(*args, **kwargs) data: T = fn(*args, **kwargs)
return SuccessDataClass(status="success", data=data) return SuccessDataClass(status="success", data=data, op_key=op_key)
except ClanError as e: except ClanError as e:
return ErrorDataClass( return ErrorDataClass(
op_key=op_key,
status="error", status="error",
errors=[ errors=[
ApiError( ApiError(
@@ -127,7 +130,7 @@ API.register(open_file)
wrapper.__annotations__["return"] = ApiResponse[orig_return_type] # type: ignore wrapper.__annotations__["return"] = ApiResponse[orig_return_type] # type: ignore
# Add additional argument for the operation key # Add additional argument for the operation key
wrapper.__annotations__["op_key"] = str | None # type: ignore wrapper.__annotations__["op_key"] = str # type: ignore
update_wrapper_signature(wrapper, fn) update_wrapper_signature(wrapper, fn)