clan-lib: Remove injected "op_key" argument from all functions and do it over the threadcontext instead. Remove double threading in http server
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from clan_lib.api import ApiResponse
|
from clan_lib.api import ApiResponse
|
||||||
from clan_lib.api.tasks import WebThread
|
from clan_lib.api.tasks import WebThread
|
||||||
from clan_lib.async_run import set_should_cancel
|
from clan_lib.async_run import set_current_thread_opkey, set_should_cancel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .middleware import Middleware
|
from .middleware import Middleware
|
||||||
@@ -98,7 +98,7 @@ class ApiBridge(ABC):
|
|||||||
*,
|
*,
|
||||||
thread_name: str = "ApiBridgeThread",
|
thread_name: str = "ApiBridgeThread",
|
||||||
wait_for_completion: bool = False,
|
wait_for_completion: bool = False,
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0 * 60, # 1 hour default timeout
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process an API request in a separate thread with cancellation support.
|
"""Process an API request in a separate thread with cancellation support.
|
||||||
|
|
||||||
@@ -112,6 +112,7 @@ class ApiBridge(ABC):
|
|||||||
|
|
||||||
def thread_task(stop_event: threading.Event) -> None:
|
def thread_task(stop_event: threading.Event) -> None:
|
||||||
set_should_cancel(lambda: stop_event.is_set())
|
set_should_cancel(lambda: stop_event.is_set())
|
||||||
|
set_current_thread_opkey(op_key)
|
||||||
try:
|
try:
|
||||||
log.debug(
|
log.debug(
|
||||||
f"Processing {request.method_name} with args {request.args} "
|
f"Processing {request.method_name} with args {request.args} "
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ gi.require_version("Gtk", "4.0")
|
|||||||
|
|
||||||
from clan_lib.api import ApiError, ErrorDataClass, SuccessDataClass
|
from clan_lib.api import ApiError, ErrorDataClass, SuccessDataClass
|
||||||
from clan_lib.api.directory import FileRequest
|
from clan_lib.api.directory import FileRequest
|
||||||
|
from clan_lib.async_run import get_current_thread_opkey
|
||||||
from clan_lib.clan.check import check_clan_valid
|
from clan_lib.clan.check import check_clan_valid
|
||||||
from clan_lib.flake import Flake
|
from clan_lib.flake import Flake
|
||||||
from gi.repository import Gio, GLib, Gtk
|
from gi.repository import Gio, GLib, Gtk
|
||||||
@@ -24,7 +25,7 @@ def remove_none(_list: list) -> list:
|
|||||||
RESULT: dict[str, SuccessDataClass[list[str] | None] | ErrorDataClass] = {}
|
RESULT: dict[str, SuccessDataClass[list[str] | None] | ErrorDataClass] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_clan_folder(*, op_key: str) -> SuccessDataClass[Flake] | ErrorDataClass:
|
def get_clan_folder() -> SuccessDataClass[Flake] | ErrorDataClass:
|
||||||
"""
|
"""
|
||||||
Opens the clan folder using the GTK file dialog.
|
Opens the clan folder using the GTK file dialog.
|
||||||
Returns the path to the clan folder or an error if it fails.
|
Returns the path to the clan folder or an error if it fails.
|
||||||
@@ -34,7 +35,10 @@ def get_clan_folder(*, op_key: str) -> SuccessDataClass[Flake] | ErrorDataClass:
|
|||||||
title="Select Clan Folder",
|
title="Select Clan Folder",
|
||||||
initial_folder=str(Path.home()),
|
initial_folder=str(Path.home()),
|
||||||
)
|
)
|
||||||
response = get_system_file(file_request, op_key=op_key)
|
|
||||||
|
response = get_system_file(file_request)
|
||||||
|
|
||||||
|
op_key = response.op_key
|
||||||
|
|
||||||
if isinstance(response, ErrorDataClass):
|
if isinstance(response, ErrorDataClass):
|
||||||
return response
|
return response
|
||||||
@@ -70,8 +74,13 @@ def get_clan_folder(*, op_key: str) -> SuccessDataClass[Flake] | ErrorDataClass:
|
|||||||
|
|
||||||
|
|
||||||
def get_system_file(
|
def get_system_file(
|
||||||
file_request: FileRequest, *, op_key: str
|
file_request: FileRequest,
|
||||||
) -> SuccessDataClass[list[str] | None] | ErrorDataClass:
|
) -> SuccessDataClass[list[str] | None] | ErrorDataClass:
|
||||||
|
op_key = get_current_thread_opkey()
|
||||||
|
|
||||||
|
if not op_key:
|
||||||
|
msg = "No operation key found in the current thread context."
|
||||||
|
raise RuntimeError(msg)
|
||||||
GLib.idle_add(gtk_open_file, file_request, op_key)
|
GLib.idle_add(gtk_open_file, file_request, op_key)
|
||||||
|
|
||||||
while RESULT.get(op_key) is None:
|
while RESULT.get(op_key) is None:
|
||||||
|
|||||||
@@ -21,18 +21,12 @@ class ArgumentParsingMiddleware(Middleware):
|
|||||||
# Convert dictionary arguments to dataclass instances
|
# Convert dictionary arguments to dataclass instances
|
||||||
reconciled_arguments = {}
|
reconciled_arguments = {}
|
||||||
for k, v in context.request.args.items():
|
for k, v in context.request.args.items():
|
||||||
if k == "op_key":
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Get the expected argument type from the API
|
# Get the expected argument type from the API
|
||||||
arg_class = self.api.get_method_argtype(context.request.method_name, k)
|
arg_class = self.api.get_method_argtype(context.request.method_name, k)
|
||||||
|
|
||||||
# Convert dictionary to dataclass instance
|
# Convert dictionary to dataclass instance
|
||||||
reconciled_arguments[k] = from_dict(arg_class, v)
|
reconciled_arguments[k] = from_dict(arg_class, v)
|
||||||
|
|
||||||
# Add op_key to arguments
|
|
||||||
reconciled_arguments["op_key"] = context.request.op_key
|
|
||||||
|
|
||||||
# Create a new request with reconciled arguments
|
# Create a new request with reconciled arguments
|
||||||
|
|
||||||
updated_request = BackendRequest(
|
updated_request = BackendRequest(
|
||||||
|
|||||||
@@ -1,13 +1,22 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from http.server import BaseHTTPRequestHandler
|
from http.server import BaseHTTPRequestHandler
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from clan_lib.api import MethodRegistry, SuccessDataClass, dataclass_to_dict
|
from clan_lib.api import (
|
||||||
|
MethodRegistry,
|
||||||
|
SuccessDataClass,
|
||||||
|
dataclass_to_dict,
|
||||||
|
)
|
||||||
from clan_lib.api.tasks import WebThread
|
from clan_lib.api.tasks import WebThread
|
||||||
|
from clan_lib.async_run import (
|
||||||
|
set_current_thread_opkey,
|
||||||
|
set_should_cancel,
|
||||||
|
)
|
||||||
|
|
||||||
from clan_app.api.api_bridge import ApiBridge, BackendRequest, BackendResponse
|
from clan_app.api.api_bridge import ApiBridge, BackendRequest, BackendResponse
|
||||||
|
|
||||||
@@ -324,17 +333,34 @@ class HttpBridge(ApiBridge, BaseHTTPRequestHandler):
|
|||||||
msg = f"Operation key '{op_key}' is already in use. Please try again."
|
msg = f"Operation key '{op_key}' is already in use. Please try again."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
def process_request_in_thread(
|
||||||
|
self,
|
||||||
|
request: BackendRequest,
|
||||||
|
*,
|
||||||
|
thread_name: str = "ApiBridgeThread",
|
||||||
|
wait_for_completion: bool = False,
|
||||||
|
timeout: float = 60.0 * 60, # 1 hour default timeout
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
def _process_api_request_in_thread(
|
def _process_api_request_in_thread(
|
||||||
self, api_request: BackendRequest, method_name: str
|
self, api_request: BackendRequest, method_name: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process the API request in a separate thread."""
|
"""Process the API request in a separate thread."""
|
||||||
# Use the inherited thread processing method
|
stop_event = threading.Event()
|
||||||
self.process_request_in_thread(
|
request = api_request
|
||||||
api_request,
|
op_key = request.op_key or "unknown"
|
||||||
thread_name="HttpThread",
|
set_should_cancel(lambda: stop_event.is_set())
|
||||||
wait_for_completion=True,
|
set_current_thread_opkey(op_key)
|
||||||
timeout=60.0,
|
|
||||||
|
curr_thread = threading.current_thread()
|
||||||
|
self.threads[op_key] = WebThread(thread=curr_thread, stop_event=stop_event)
|
||||||
|
|
||||||
|
log.debug(
|
||||||
|
f"Processing {request.method_name} with args {request.args} "
|
||||||
|
f"and header {request.header}"
|
||||||
)
|
)
|
||||||
|
self.process_request(request)
|
||||||
|
|
||||||
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
|
def log_message(self, format: str, *args: Any) -> None: # noqa: A002
|
||||||
"""Override default logging to use our logger."""
|
"""Override default logging to use our logger."""
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ export type SuccessData<T extends OperationNames> = SuccessQuery<T>["data"];
|
|||||||
|
|
||||||
interface SendHeaderType {
|
interface SendHeaderType {
|
||||||
logging?: { group_path: string[] };
|
logging?: { group_path: string[] };
|
||||||
|
op_key?: string;
|
||||||
}
|
}
|
||||||
interface BackendSendType<K extends OperationNames> {
|
interface BackendSendType<K extends OperationNames> {
|
||||||
body: OperationArgs<K>;
|
body: OperationArgs<K>;
|
||||||
@@ -64,9 +65,14 @@ export const callApi = <K extends OperationNames>(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
const req: BackendSendType<OperationNames> = {
|
const op_key = backendOpts?.op_key ?? crypto.randomUUID();
|
||||||
|
|
||||||
|
let req: BackendSendType<OperationNames> = {
|
||||||
body: args,
|
body: args,
|
||||||
header: backendOpts,
|
header: {
|
||||||
|
...backendOpts,
|
||||||
|
op_key,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const result = (
|
const result = (
|
||||||
@@ -78,9 +84,6 @@ export const callApi = <K extends OperationNames>(
|
|||||||
>
|
>
|
||||||
)[method](req) as Promise<BackendReturnType<K>>;
|
)[method](req) as Promise<BackendReturnType<K>>;
|
||||||
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
const op_key = (result as any)._webviewMessageId as string;
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
uuid: op_key,
|
uuid: op_key,
|
||||||
result: result.then(({ body }) => body),
|
result: result.then(({ body }) => body),
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from clan_lib.api.util import JSchemaTypeError
|
from clan_lib.api.util import JSchemaTypeError
|
||||||
|
from clan_lib.async_run import get_current_thread_opkey
|
||||||
from clan_lib.errors import ClanError
|
from clan_lib.errors import ClanError
|
||||||
|
|
||||||
from .serde import dataclass_to_dict, from_dict, sanitize_string
|
from .serde import dataclass_to_dict, from_dict, sanitize_string
|
||||||
@@ -54,26 +55,6 @@ class ErrorDataClass:
|
|||||||
ApiResponse = SuccessDataClass[ResponseDataType] | ErrorDataClass
|
ApiResponse = SuccessDataClass[ResponseDataType] | ErrorDataClass
|
||||||
|
|
||||||
|
|
||||||
def update_wrapper_signature(wrapper: Callable, wrapped: Callable) -> None:
|
|
||||||
sig = signature(wrapped)
|
|
||||||
params = list(sig.parameters.values())
|
|
||||||
|
|
||||||
# Add 'op_key' parameter
|
|
||||||
op_key_param = Parameter(
|
|
||||||
"op_key",
|
|
||||||
Parameter.KEYWORD_ONLY,
|
|
||||||
# we add a None default value so that typescript code gen drops the parameter
|
|
||||||
# FIXME: this is a hack, we should filter out op_key in the typescript code gen
|
|
||||||
default=None,
|
|
||||||
annotation=str,
|
|
||||||
)
|
|
||||||
params.append(op_key_param)
|
|
||||||
|
|
||||||
# Create a new signature
|
|
||||||
new_sig = sig.replace(parameters=params)
|
|
||||||
wrapper.__signature__ = new_sig # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class MethodRegistry:
|
class MethodRegistry:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._orig_signature: dict[str, Signature] = {}
|
self._orig_signature: dict[str, Signature] = {}
|
||||||
@@ -130,18 +111,8 @@ API.register(get_system_file)
|
|||||||
fn_signature = signature(fn)
|
fn_signature = signature(fn)
|
||||||
abstract_signature = signature(self._registry[fn_name])
|
abstract_signature = signature(self._registry[fn_name])
|
||||||
|
|
||||||
# Remove the default argument of op_key from abstract_signature
|
|
||||||
# FIXME: This is a hack to make the signature comparison work
|
|
||||||
# because the other hack above where default value of op_key is None in the wrapper
|
|
||||||
abstract_params = list(abstract_signature.parameters.values())
|
|
||||||
for i, param in enumerate(abstract_params):
|
|
||||||
if param.name == "op_key":
|
|
||||||
abstract_params[i] = param.replace(default=Parameter.empty)
|
|
||||||
break
|
|
||||||
abstract_signature = abstract_signature.replace(parameters=abstract_params)
|
|
||||||
|
|
||||||
if fn_signature != abstract_signature:
|
if fn_signature != abstract_signature:
|
||||||
msg = f"Expected signature: {abstract_signature}\nActual signature: {fn_signature}"
|
msg = f"For function: {fn_name}. Expected signature: {abstract_signature}\nActual signature: {fn_signature}"
|
||||||
raise ClanError(msg)
|
raise ClanError(msg)
|
||||||
|
|
||||||
self._registry[fn_name] = fn
|
self._registry[fn_name] = fn
|
||||||
@@ -159,7 +130,11 @@ API.register(get_system_file)
|
|||||||
self._orig_signature[fn.__name__] = signature(fn)
|
self._orig_signature[fn.__name__] = signature(fn)
|
||||||
|
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
def wrapper(*args: Any, op_key: str, **kwargs: Any) -> ApiResponse[T]:
|
def wrapper(*args: Any, **kwargs: Any) -> ApiResponse[T]:
|
||||||
|
op_key = get_current_thread_opkey()
|
||||||
|
if op_key is None:
|
||||||
|
msg = f"While executing {fn.__name__}. Middleware forgot to set_current_thread_opkey()"
|
||||||
|
raise RuntimeError(msg)
|
||||||
try:
|
try:
|
||||||
data: T = fn(*args, **kwargs)
|
data: T = fn(*args, **kwargs)
|
||||||
return SuccessDataClass(status="success", data=data, op_key=op_key)
|
return SuccessDataClass(status="success", data=data, op_key=op_key)
|
||||||
@@ -196,11 +171,6 @@ API.register(get_system_file)
|
|||||||
orig_return_type = get_type_hints(fn).get("return")
|
orig_return_type = get_type_hints(fn).get("return")
|
||||||
wrapper.__annotations__["return"] = ApiResponse[orig_return_type] # type: ignore
|
wrapper.__annotations__["return"] = ApiResponse[orig_return_type] # type: ignore
|
||||||
|
|
||||||
# Add additional argument for the operation key
|
|
||||||
wrapper.__annotations__["op_key"] = str # type: ignore
|
|
||||||
|
|
||||||
update_wrapper_signature(wrapper, fn)
|
|
||||||
|
|
||||||
self._registry[fn.__name__] = wrapper
|
self._registry[fn.__name__] = wrapper
|
||||||
|
|
||||||
return fn
|
return fn
|
||||||
|
|||||||
@@ -74,6 +74,7 @@ class AsyncContext:
|
|||||||
should_cancel: Callable[[], bool] = (
|
should_cancel: Callable[[], bool] = (
|
||||||
lambda: False
|
lambda: False
|
||||||
) # Used to signal cancellation of task
|
) # Used to signal cancellation of task
|
||||||
|
op_key: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -90,6 +91,22 @@ class AsyncOpts:
|
|||||||
ASYNC_CTX_THREAD_LOCAL = threading.local()
|
ASYNC_CTX_THREAD_LOCAL = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_thread_opkey(op_key: str) -> None:
|
||||||
|
"""
|
||||||
|
Set the current thread's operation key.
|
||||||
|
"""
|
||||||
|
ctx = get_async_ctx()
|
||||||
|
ctx.op_key = op_key
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_thread_opkey() -> str | None:
|
||||||
|
"""
|
||||||
|
Get the current thread's operation key.
|
||||||
|
"""
|
||||||
|
ctx = get_async_ctx()
|
||||||
|
return ctx.op_key
|
||||||
|
|
||||||
|
|
||||||
def is_async_cancelled() -> bool:
|
def is_async_cancelled() -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the current task has been cancelled.
|
Check if the current task has been cancelled.
|
||||||
|
|||||||
Reference in New Issue
Block a user