clan-app: Make http server non blocking, add tests for the http server and for cancelling tasks
This commit is contained in:
@@ -126,6 +126,7 @@ class ApiBridge(ABC):
|
||||
target=thread_task, args=(stop_event,), name=thread_name
|
||||
)
|
||||
thread.start()
|
||||
|
||||
self.threads[op_key] = WebThread(thread=thread, stop_event=stop_event)
|
||||
|
||||
if wait_for_completion:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import threading
|
||||
from http.server import HTTPServer
|
||||
from http.server import HTTPServer, ThreadingHTTPServer
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -36,7 +36,7 @@ class HttpApiServer:
|
||||
self._server_thread: threading.Thread | None = None
|
||||
# Bridge is now the request handler itself, no separate instance needed
|
||||
self._middleware: list[Middleware] = []
|
||||
self.shared_threads = shared_threads or {}
|
||||
self.shared_threads = shared_threads if shared_threads is not None else {}
|
||||
|
||||
def add_middleware(self, middleware: "Middleware") -> None:
|
||||
"""Add middleware to the middleware chain."""
|
||||
@@ -84,9 +84,9 @@ class HttpApiServer:
|
||||
log.warning("HTTP server is already running")
|
||||
return
|
||||
|
||||
# Create the server
|
||||
# Create the server using ThreadingHTTPServer for concurrent request handling
|
||||
handler_class = self._create_request_handler()
|
||||
self._server = HTTPServer((self.host, self.port), handler_class)
|
||||
self._server = ThreadingHTTPServer((self.host, self.port), handler_class)
|
||||
|
||||
def run_server() -> None:
|
||||
if self._server:
|
||||
|
||||
@@ -2,19 +2,18 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import Mock
|
||||
from urllib.request import Request, urlopen
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from clan_lib.api import MethodRegistry
|
||||
from clan_lib.api import MethodRegistry, tasks
|
||||
from clan_lib.async_run import is_async_cancelled
|
||||
from clan_lib.log_manager import LogManager
|
||||
from clan_lib.api import tasks
|
||||
|
||||
from clan_app.api.middleware import (
|
||||
ArgumentParsingMiddleware,
|
||||
LoggingMiddleware,
|
||||
MethodExecutionMiddleware,
|
||||
)
|
||||
from clan_app.deps.http.http_server import HttpApiServer
|
||||
@@ -27,6 +26,8 @@ def mock_api() -> MethodRegistry:
|
||||
"""Create a mock API with test methods."""
|
||||
api = MethodRegistry()
|
||||
|
||||
api.register(tasks.delete_task)
|
||||
|
||||
@api.register
|
||||
def test_method(message: str) -> dict[str, str]:
|
||||
return {"response": f"Hello {message}!"}
|
||||
@@ -45,9 +46,7 @@ def mock_api() -> MethodRegistry:
|
||||
if is_async_cancelled():
|
||||
log.debug("Task was cancelled")
|
||||
return "Task was cancelled"
|
||||
log.debug(
|
||||
f"Processing {i} for {wtime}"
|
||||
)
|
||||
log.debug(f"Processing {i} for {wtime}")
|
||||
time.sleep(1)
|
||||
return f"Task completed with wtime: {wtime}"
|
||||
|
||||
@@ -70,7 +69,7 @@ def http_bridge(
|
||||
"""Create HTTP bridge dependencies for testing."""
|
||||
middleware_chain = (
|
||||
ArgumentParsingMiddleware(api=mock_api),
|
||||
LoggingMiddleware(log_manager=mock_log_manager),
|
||||
# LoggingMiddleware(log_manager=mock_log_manager),
|
||||
MethodExecutionMiddleware(api=mock_api),
|
||||
)
|
||||
return mock_api, middleware_chain
|
||||
@@ -87,7 +86,7 @@ def http_server(mock_api: MethodRegistry, mock_log_manager: Mock) -> HttpApiServ
|
||||
|
||||
# Add middleware
|
||||
server.add_middleware(ArgumentParsingMiddleware(api=mock_api))
|
||||
server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
# server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
server.add_middleware(MethodExecutionMiddleware(api=mock_api))
|
||||
|
||||
# Bridge will be created automatically when accessed
|
||||
@@ -104,7 +103,7 @@ class TestHttpBridge:
|
||||
# We'll test initialization through the server
|
||||
api, middleware_chain = http_bridge
|
||||
assert api is not None
|
||||
assert len(middleware_chain) == 3
|
||||
assert len(middleware_chain) == 2
|
||||
|
||||
def test_http_bridge_middleware_setup(self, http_bridge: tuple) -> None:
|
||||
"""Test that middleware is properly set up."""
|
||||
@@ -112,10 +111,10 @@ class TestHttpBridge:
|
||||
|
||||
# Test that we can create the bridge with middleware
|
||||
# The actual HTTP handling will be tested through the server integration tests
|
||||
assert len(middleware_chain) == 3
|
||||
assert len(middleware_chain) == 2
|
||||
assert isinstance(middleware_chain[0], ArgumentParsingMiddleware)
|
||||
assert isinstance(middleware_chain[1], LoggingMiddleware)
|
||||
assert isinstance(middleware_chain[2], MethodExecutionMiddleware)
|
||||
# assert isinstance(middleware_chain[1], LoggingMiddleware)
|
||||
assert isinstance(middleware_chain[1], MethodExecutionMiddleware)
|
||||
|
||||
|
||||
class TestHttpApiServer:
|
||||
@@ -268,7 +267,7 @@ class TestIntegration:
|
||||
|
||||
# Add middleware
|
||||
server.add_middleware(ArgumentParsingMiddleware(api=mock_api))
|
||||
server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
# server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
server.add_middleware(MethodExecutionMiddleware(api=mock_api))
|
||||
|
||||
# Bridge will be created automatically when accessed
|
||||
@@ -301,52 +300,49 @@ class TestIntegration:
|
||||
# Always stop server
|
||||
server.stop()
|
||||
|
||||
|
||||
def test_blocking_task(
|
||||
self, mock_api: MethodRegistry, mock_log_manager: Mock
|
||||
) -> None:
|
||||
|
||||
shared_threads: dict[str, tasks.WebThread] = {}
|
||||
tasks.BAKEND_THREADS = shared_threads
|
||||
|
||||
"""Test a long-running blocking task."""
|
||||
server: HttpApiServer = HttpApiServer(
|
||||
api=mock_api,
|
||||
host="127.0.0.1",
|
||||
port=8083,
|
||||
shared_threads=shared_threads,
|
||||
)
|
||||
api=mock_api,
|
||||
host="127.0.0.1",
|
||||
port=8083,
|
||||
shared_threads=shared_threads,
|
||||
)
|
||||
|
||||
# Add middleware
|
||||
server.add_middleware(ArgumentParsingMiddleware(api=mock_api))
|
||||
server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
# server.add_middleware(LoggingMiddleware(log_manager=mock_log_manager))
|
||||
server.add_middleware(MethodExecutionMiddleware(api=mock_api))
|
||||
|
||||
# Start server
|
||||
server.start()
|
||||
time.sleep(0.1) # Give server time to start
|
||||
|
||||
blocking_op_key = "b37f920f-ce8c-4c8d-b595-28ca983d265e" # str(uuid.uuid4())
|
||||
|
||||
sucess = threading.Event()
|
||||
def parallel_task() -> None:
|
||||
|
||||
time.sleep(1)
|
||||
# Make API call
|
||||
request_data: dict = {
|
||||
"body": {"message": "Integration"},
|
||||
"body": {"wtime": 60},
|
||||
"header": {"op_key": blocking_op_key},
|
||||
}
|
||||
req: Request = Request(
|
||||
"http://127.0.0.1:8083/api/v1/test_method",
|
||||
"http://127.0.0.1:8083/api/v1/run_task_blocking",
|
||||
data=json.dumps(request_data).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response = urlopen(req)
|
||||
data: dict = json.loads(response.read().decode())
|
||||
|
||||
# thread.join()
|
||||
assert "body" in data
|
||||
assert "header" in data
|
||||
assert data["body"]["status"] == "success"
|
||||
assert data["body"]["data"] == {"response": "Hello Integration!"}
|
||||
sucess.set()
|
||||
assert data["body"]["data"] == "Task was cancelled"
|
||||
|
||||
thread = threading.Thread(
|
||||
target=parallel_task,
|
||||
@@ -355,23 +351,21 @@ class TestIntegration:
|
||||
)
|
||||
thread.start()
|
||||
|
||||
# Make API call
|
||||
time.sleep(1)
|
||||
request_data: dict = {
|
||||
"body": {"wtime": 3},
|
||||
"body": {"task_id": blocking_op_key},
|
||||
}
|
||||
req: Request = Request(
|
||||
"http://127.0.0.1:8083/api/v1/run_task_blocking",
|
||||
"http://127.0.0.1:8083/api/v1/delete_task",
|
||||
data=json.dumps(request_data).encode(),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
response = urlopen(req)
|
||||
data: dict = json.loads(response.read().decode())
|
||||
|
||||
# thread.join()
|
||||
assert "body" in data
|
||||
assert "header" in data
|
||||
assert data["body"]["status"] == "success"
|
||||
assert data["body"]["data"] == "Task completed with wtime: 3"
|
||||
assert sucess.is_set(), "Parallel task did not complete successfully"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -21,7 +21,7 @@ exclude = ["result", "**/__pycache__"]
|
||||
clan_app = ["**/assets/*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = "tests"
|
||||
testpaths = [ "tests", "clan_app" ]
|
||||
faulthandler_timeout = 60
|
||||
log_level = "DEBUG"
|
||||
log_format = "%(levelname)s: %(message)s\n %(pathname)s:%(lineno)d::%(funcName)s"
|
||||
|
||||
@@ -58,6 +58,7 @@ mkShell {
|
||||
with ps;
|
||||
[
|
||||
mypy
|
||||
pytest-cov
|
||||
]
|
||||
++ (clan-app.devshellPyDeps ps)
|
||||
))
|
||||
|
||||
@@ -23,6 +23,7 @@ def delete_task(task_id: str) -> None:
|
||||
"""Cancel a task by its op_key."""
|
||||
assert BAKEND_THREADS is not None, "Backend threads not initialized"
|
||||
future = BAKEND_THREADS.get(task_id)
|
||||
|
||||
log.debug(f"Thread ID: {threading.get_ident()}")
|
||||
if future:
|
||||
future.stop_event.set()
|
||||
|
||||
Reference in New Issue
Block a user