cli: fix remaining typing errors
This commit is contained in:
@@ -2,7 +2,7 @@ import asyncio
|
||||
import json
|
||||
import logging
|
||||
import shlex
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Iterator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Body, HTTPException, Request, status
|
||||
@@ -34,12 +34,10 @@ def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
|
||||
|
||||
|
||||
class NixBuildException(HTTPException):
|
||||
def __init__(self, uuid: UUID, msg: str, loc: list = ["body", "flake_attr"]):
|
||||
self.uuid = uuid
|
||||
def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
|
||||
detail = [
|
||||
{
|
||||
"loc": loc,
|
||||
"uuid": str(uuid),
|
||||
"msg": msg,
|
||||
"type": "value_error",
|
||||
}
|
||||
@@ -65,7 +63,7 @@ class BuildVmTask(BaseTask):
|
||||
vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
|
||||
self.log.debug(f"vm_path: {vm_path}")
|
||||
|
||||
self.run_cmd(vm_path)
|
||||
self.run_cmd([vm_path])
|
||||
self.finished = True
|
||||
except Exception as e:
|
||||
self.failed = True
|
||||
@@ -103,7 +101,6 @@ async def inspect_vm(
|
||||
|
||||
if proc.returncode != 0:
|
||||
raise NixBuildException(
|
||||
""
|
||||
f"""
|
||||
Failed to evaluate vm from '{flake_url}#{flake_attr}'.
|
||||
command: {shlex.join(cmd)}
|
||||
@@ -127,7 +124,7 @@ async def get_status(uuid: UUID) -> VmStatusResponse:
|
||||
@router.get("/api/vms/{uuid}/logs")
|
||||
async def get_logs(uuid: UUID) -> StreamingResponse:
|
||||
# Generator function that yields log lines as they are available
|
||||
def stream_logs():
|
||||
def stream_logs() -> Iterator[str]:
|
||||
task = get_task(uuid)
|
||||
|
||||
for proc in task.procs:
|
||||
|
||||
@@ -5,6 +5,7 @@ import select
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
|
||||
@@ -105,14 +106,14 @@ def get_task(uuid: UUID) -> BaseTask:
|
||||
return POOL[uuid]
|
||||
|
||||
|
||||
def register_task(task: BaseTask, *kwargs) -> UUID:
|
||||
def register_task(task: type, *args: Any) -> UUID:
|
||||
global POOL
|
||||
if not issubclass(task, BaseTask):
|
||||
raise TypeError("task must be a subclass of BaseTask")
|
||||
|
||||
uuid = uuid4()
|
||||
|
||||
inst_task = task(uuid, *kwargs)
|
||||
inst_task = task(uuid, *args)
|
||||
POOL[uuid] = inst_task
|
||||
inst_task.start()
|
||||
return uuid
|
||||
|
||||
@@ -2,6 +2,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from api import TestClient
|
||||
from httpx import SyncByteStream
|
||||
|
||||
# @pytest.mark.impure
|
||||
# def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
|
||||
@@ -41,6 +42,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/logs")
|
||||
print("=========FLAKE LOGS==========")
|
||||
assert isinstance(response.stream, SyncByteStream)
|
||||
for line in response.stream:
|
||||
assert line != b"", "Failed to get vm logs"
|
||||
print(line.decode("utf-8"), end="")
|
||||
@@ -48,6 +50,7 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
|
||||
assert response.status_code == 200, "Failed to get vm logs"
|
||||
|
||||
response = api.get(f"/api/vms/{uuid}/logs")
|
||||
assert isinstance(response.stream, SyncByteStream)
|
||||
print("=========VM LOGS==========")
|
||||
for line in response.stream:
|
||||
assert line != b"", "Failed to get vm logs"
|
||||
|
||||
Reference in New Issue
Block a user