Merge pull request 'Restructuring CLI to use API' (#387) from Qubasa-main into main
Reviewed-on: https://git.clan.lol/clan/clan-core/pulls/387
.gitignore (vendored) | 1

@@ -1,4 +1,5 @@
.direnv
democlan
result*
/pkgs/clan-cli/clan_cli/nixpkgs
/pkgs/clan-cli/clan_cli/webui/assets
pkgs/clan-cli/.vscode/launch.json (vendored) | 9

@@ -12,6 +12,15 @@
            "justMyCode": false,
            "args": [ "--reload", "--no-open", "--log-level", "debug" ],
        },
        {
            "name": "Clan Cli VMs",
            "type": "python",
            "request": "launch",
            "module": "clan_cli",
            "justMyCode": false,
            "args": [ "vms" ],
        }
    ]
}
@@ -1,12 +1,15 @@
import argparse
import logging
import sys
from types import ModuleType
from typing import Optional

from . import config, create, machines, secrets, vms, webui
from . import config, create, custom_logger, machines, secrets, vms, webui
from .errors import ClanError
from .ssh import cli as ssh_cli

log = logging.getLogger(__name__)

argcomplete: Optional[ModuleType] = None
try:
    import argcomplete  # type: ignore[no-redef]

@@ -62,14 +65,20 @@ def create_parser(prog: Optional[str] = None) -> argparse.ArgumentParser:
def main() -> None:
    parser = create_parser()
    args = parser.parse_args()

    if args.debug:
        custom_logger.register(logging.DEBUG)
        log.debug("Debug logging enabled")
    else:
        custom_logger.register(logging.INFO)

    if not hasattr(args, "func"):
        log.error("No argparse function registered")
        return
    try:
        args.func(args)
    except ClanError as e:
        if args.debug:
            raise
        print(f"{sys.argv[0]}: {e}")
        log.exception(e)
        sys.exit(1)
pkgs/clan-cli/clan_cli/__main__.py (new file) | 4

@@ -0,0 +1,4 @@
from . import main

if __name__ == "__main__":
    main()
pkgs/clan-cli/clan_cli/async_cmd.py (new file) | 28

@@ -0,0 +1,28 @@
import asyncio
import logging
import shlex

from .errors import ClanError

log = logging.getLogger(__name__)


async def run(cmd: list[str]) -> bytes:
    log.debug(f"$: {shlex.join(cmd)}")
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()

    if proc.returncode != 0:
        raise ClanError(
            f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
        )
    return stdout
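For context, a minimal sketch of how this new async helper is meant to be consumed, modeled on its use in vms/inspect.py further down in this PR; the flake URL and attribute here are placeholders:

    import asyncio
    import json

    from clan_cli.async_cmd import run
    from clan_cli.nix import nix_eval

    async def show_vm_config(flake_url: str, flake_attr: str) -> dict:
        # nix_eval builds the command line; run() executes it and raises ClanError on a non-zero exit code.
        cmd = nix_eval(
            [
                f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config"
            ]
        )
        stdout = await run(cmd)
        return json.loads(stdout)

    # asyncio.run(show_vm_config("/path/to/clan-flake", "vm1"))  # hypothetical invocation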
@@ -1,12 +1,16 @@
import argparse
import logging
import os

from .folders import machines_folder
from .types import validate_hostname

log = logging.getLogger(__name__)


def list_machines() -> list[str]:
    path = machines_folder()
    log.debug(f"Listing machines in {path}")
    if not path.exists():
        return []
    objs: list[str] = []
@@ -1,14 +1,18 @@
import argparse
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path

from clan_cli.errors import ClanError

from ..dirs import get_clan_flake_toplevel, module_root
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_config

log = logging.getLogger(__name__)


def build_generate_script(machine: str, clan_dir: Path) -> str:
    config = nix_config()

@@ -31,7 +35,8 @@ def build_generate_script(machine: str, clan_dir: Path) -> str:
def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None:
    env = os.environ.copy()
    env["CLAN_DIR"] = str(clan_dir)
    env["PYTHONPATH"] = str(module_root().parent)  # TODO do this in the clanCore module
    env["PYTHONPATH"] = ":".join(sys.path)  # TODO do this in the clanCore module

    print(f"generating secrets... {secret_generator_script}")
    proc = subprocess.run(
        [secret_generator_script],

@@ -39,6 +44,8 @@ def run_generate_secrets(secret_generator_script: str, clan_dir: Path) -> None:
    )

    if proc.returncode != 0:
        log.error("stdout: %s", proc.stdout)
        log.error("stderr: %s", proc.stderr)
        raise ClanError("failed to generate secrets")
    else:
        print("successfully generated secrets")
@@ -1,16 +1,20 @@
import argparse
import json
import logging
import os
import shlex
import subprocess
import sys
from pathlib import Path
from tempfile import TemporaryDirectory

from ..dirs import get_clan_flake_toplevel, module_root
from ..dirs import get_clan_flake_toplevel
from ..errors import ClanError
from ..nix import nix_build, nix_config, nix_shell
from ..ssh import parse_deployment_address

log = logging.getLogger(__name__)


def build_upload_script(machine: str, clan_dir: Path) -> str:
    config = nix_config()

@@ -53,7 +57,7 @@ def run_upload_secrets(
) -> None:
    env = os.environ.copy()
    env["CLAN_DIR"] = str(clan_dir)
    env["PYTHONPATH"] = str(module_root().parent)  # TODO do this in the clanCore module
    env["PYTHONPATH"] = ":".join(sys.path)  # TODO do this in the clanCore module
    print(f"uploading secrets... {flake_attr}")
    with TemporaryDirectory() as tempdir_:
        tempdir = Path(tempdir_)

@@ -67,6 +71,8 @@ def run_upload_secrets(
    )

    if proc.returncode != 0:
        log.error("Stdout: %s", proc.stdout)
        log.error("Stderr: %s", proc.stderr)
        raise ClanError("failed to upload secrets")

    h = parse_deployment_address(flake_attr, target)
pkgs/clan-cli/clan_cli/task_manager.py (new file) | 169

@@ -0,0 +1,169 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from typing import Any, Iterator
from uuid import UUID, uuid4


class CmdState:
    def __init__(self, log: logging.Logger) -> None:
        self.log: logging.Logger = log
        self.p: subprocess.Popen | None = None
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        self._output: queue.SimpleQueue = queue.SimpleQueue()
        self.returncode: int | None = None
        self.done: bool = False
        self.running: bool = False
        self.cmd_str: str | None = None
        self.workdir: str | None = None

    def close_queue(self) -> None:
        if self.p is not None:
            self.returncode = self.p.returncode
        self._output.put(None)
        self.running = False
        self.done = True

    def run(self, cmd: list[str]) -> None:
        self.running = True
        try:
            self.cmd_str = shlex.join(cmd)
            self.workdir = os.getcwd()
            self.log.debug(f"Working directory: {self.workdir}")
            self.log.debug(f"Running command: {shlex.join(cmd)}")
            self.p = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                encoding="utf-8",
                cwd=self.workdir,
            )

            while self.p.poll() is None:
                # Check if stderr is ready to be read from
                rlist, _, _ = select.select([self.p.stderr, self.p.stdout], [], [], 0)
                if self.p.stderr in rlist:
                    assert self.p.stderr is not None
                    line = self.p.stderr.readline()
                    if line != "":
                        line = line.strip("\n")
                        self.stderr.append(line)
                        self.log.debug("stderr: %s", line)
                        self._output.put(line + "\n")

                if self.p.stdout in rlist:
                    assert self.p.stdout is not None
                    line = self.p.stdout.readline()
                    if line != "":
                        line = line.strip("\n")
                        self.stdout.append(line)
                        self.log.debug("stdout: %s", line)
                        self._output.put(line + "\n")

            if self.p.returncode != 0:
                raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")

            self.log.debug("Successfully ran command")
        finally:
            self.close_queue()


class BaseTask(threading.Thread):
    def __init__(self, uuid: UUID) -> None:
        # calling parent class constructor
        threading.Thread.__init__(self)

        # constructor
        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        self.procs: list[CmdState] = []
        self.failed: bool = False
        self.finished: bool = False
        self.logs_lock = threading.Lock()

    def run(self) -> None:
        try:
            self.task_run()
        except Exception as e:
            for proc in self.procs:
                proc.close_queue()
            self.failed = True
            self.log.exception(e)
        finally:
            self.finished = True

    def task_run(self) -> None:
        raise NotImplementedError

    ## TODO: If two clients are connected to the same task,
    def logs_iter(self) -> Iterator[str]:
        with self.logs_lock:
            for proc in self.procs:
                if self.finished:
                    self.log.debug("log iter: Task is finished")
                    break
                if proc.done:
                    for line in proc.stderr:
                        yield line + "\n"
                    for line in proc.stdout:
                        yield line + "\n"
                    continue
                while True:
                    out = proc._output
                    line = out.get()
                    if line is None:
                        break
                    yield line

    def register_cmds(self, num_cmds: int) -> Iterator[CmdState]:
        for i in range(num_cmds):
            cmd = CmdState(self.log)
            self.procs.append(cmd)

        for cmd in self.procs:
            yield cmd


# TODO: We need to test concurrency
class TaskPool:
    def __init__(self) -> None:
        self.lock: threading.RLock = threading.RLock()
        self.pool: dict[UUID, BaseTask] = {}

    def __getitem__(self, uuid: UUID) -> BaseTask:
        with self.lock:
            return self.pool[uuid]

    def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
        with self.lock:
            if uuid in self.pool:
                raise KeyError(f"Task with uuid {uuid} already exists")
            if type(uuid) is not UUID:
                raise TypeError("uuid must be of type UUID")
            self.pool[uuid] = task


POOL: TaskPool = TaskPool()


def get_task(uuid: UUID) -> BaseTask:
    global POOL
    return POOL[uuid]


def register_task(task: type, *args: Any) -> UUID:
    global POOL
    if not issubclass(task, BaseTask):
        raise TypeError("task must be a subclass of BaseTask")

    uuid = uuid4()

    inst_task = task(uuid, *args)
    POOL[uuid] = inst_task
    inst_task.start()
    return uuid
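A minimal usage sketch of the new task manager, modeled on BuildVmTask in vms/create.py below; EchoTask is a hypothetical subclass used only for illustration and is not part of this PR:

    from uuid import UUID

    from clan_cli.task_manager import BaseTask, get_task, register_task

    class EchoTask(BaseTask):  # hypothetical example subclass
        def __init__(self, uuid: UUID, message: str) -> None:
            super().__init__(uuid)
            self.message = message

        def task_run(self) -> None:
            cmds = self.register_cmds(1)  # pre-register how many commands this task will run
            cmd = next(cmds)
            cmd.run(["echo", self.message])  # blocks until the subprocess exits

    uuid = register_task(EchoTask, "hello")  # starts the worker thread, returns its UUID
    for line in get_task(uuid).logs_iter():  # stream captured stdout/stderr
        print(line, end="")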
@@ -1,101 +1,129 @@
import argparse
import json
import subprocess
import tempfile
from pathlib import Path
from typing import Iterator
from uuid import UUID

from ..dirs import get_clan_flake_toplevel
from ..nix import nix_build, nix_shell
from ..task_manager import BaseTask, CmdState, get_task, register_task
from .inspect import VmConfig


def get_vm_create_info(machine: str) -> dict:
class BuildVmTask(BaseTask):
    def __init__(self, uuid: UUID, vm: VmConfig) -> None:
        super().__init__(uuid)
        self.vm = vm

    def get_vm_create_info(self, cmds: Iterator[CmdState]) -> dict:
        clan_dir = self.vm.flake_url
        machine = self.vm.flake_attr
        cmd = next(cmds)
        cmd.run(
            nix_build(
                [
                    # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
                    f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
                ]
            )
        )
        vm_json = "".join(cmd.stdout)
        self.log.debug(f"VM JSON path: {vm_json}")
        with open(vm_json) as f:
            return json.load(f)

    def task_run(self) -> None:
        cmds = self.register_cmds(4)

        machine = self.vm.flake_attr
        self.log.debug(f"Creating VM for {machine}")

        # TODO: We should get this from the vm argument
        vm_config = self.get_vm_create_info(cmds)

        with tempfile.TemporaryDirectory() as tmpdir_:
            xchg_dir = Path(tmpdir_) / "xchg"
            xchg_dir.mkdir()
            disk_img = f"{tmpdir_}/disk.img"

            cmd = next(cmds)
            cmd.run(
                nix_shell(
                    ["qemu"],
                    [
                        "qemu-img",
                        "create",
                        "-f",
                        "raw",
                        disk_img,
                        "1024M",
                    ],
                )
            )

            cmd = next(cmds)
            cmd.run(
                nix_shell(
                    ["e2fsprogs"],
                    [
                        "mkfs.ext4",
                        "-L",
                        "nixos",
                        disk_img,
                    ],
                )
            )

            cmd = next(cmds)
            cmd.run(
                nix_shell(
                    ["qemu"],
                    [
                        # fmt: off
                        "qemu-kvm",
                        "-name", machine,
                        "-m", f'{vm_config["memorySize"]}M',
                        "-smp", str(vm_config["cores"]),
                        "-device", "virtio-rng-pci",
                        "-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
                        "-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
                        "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
                        "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
                        "-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
                        "-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
                        "-device", "virtio-keyboard",
                        "-usb",
                        "-device", "usb-tablet,bus=usb-bus.0",
                        "-kernel", f'{vm_config["toplevel"]}/kernel',
                        "-initrd", vm_config["initrd"],
                        "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
                        # fmt: on
                    ],
                )
            )


def create_vm(vm: VmConfig) -> UUID:
    return register_task(BuildVmTask, vm)


def create_command(args: argparse.Namespace) -> None:
    clan_dir = get_clan_flake_toplevel().as_posix()
    vm = VmConfig(
        flake_url=clan_dir,
        flake_attr=args.machine,
        cores=0,
        graphics=False,
        memory_size=0,
    )

    # config = nix_config()
    # system = config["system"]

    vm_json = subprocess.run(
        nix_build(
            [
                # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation.createJSON' # TODO use this
                f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.create'
            ]
        ),
        stdout=subprocess.PIPE,
        check=True,
        text=True,
    ).stdout.strip()
    with open(vm_json) as f:
        return json.load(f)


def create(args: argparse.Namespace) -> None:
    print(f"Creating VM for {args.machine}")
    machine = args.machine
    vm_config = get_vm_create_info(machine)
    with tempfile.TemporaryDirectory() as tmpdir_:
        xchg_dir = Path(tmpdir_) / "xchg"
        xchg_dir.mkdir()
        disk_img = f"{tmpdir_}/disk.img"
        subprocess.run(
            nix_shell(
                ["qemu"],
                [
                    "qemu-img",
                    "create",
                    "-f",
                    "raw",
                    disk_img,
                    "1024M",
                ],
            ),
            stdout=subprocess.PIPE,
            check=True,
            text=True,
        )
        subprocess.run(
            [
                "mkfs.ext4",
                "-L",
                "nixos",
                disk_img,
            ],
            stdout=subprocess.PIPE,
            check=True,
            text=True,
        )

        subprocess.run(
            nix_shell(
                ["qemu"],
                [
                    # fmt: off
                    "qemu-kvm",
                    "-name", machine,
                    "-m", f'{vm_config["memorySize"]}M',
                    "-smp", str(vm_config["cores"]),
                    "-device", "virtio-rng-pci",
                    "-net", "nic,netdev=user.0,model=virtio", "-netdev", "user,id=user.0",
                    "-virtfs", "local,path=/nix/store,security_model=none,mount_tag=nix-store",
                    "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=shared",
                    "-virtfs", f"local,path={xchg_dir},security_model=none,mount_tag=xchg",
                    "-drive", f'cache=writeback,file={disk_img},format=raw,id=drive1,if=none,index=1,werror=report',
                    "-device", "virtio-blk-pci,bootindex=1,drive=drive1,serial=root",
                    "-device", "virtio-keyboard",
                    "-usb",
                    "-device", "usb-tablet,bus=usb-bus.0",
                    "-kernel", f'{vm_config["toplevel"]}/kernel',
                    "-initrd", vm_config["initrd"],
                    "-append", f'{(Path(vm_config["toplevel"]) / "kernel-params").read_text()} init={vm_config["toplevel"]}/init regInfo={vm_config["regInfo"]}/registration console=ttyS0,115200n8 console=tty0',
                    # fmt: on
                ],
            ),
            stdout=subprocess.PIPE,
            check=True,
            text=True,
        )
    uuid = create_vm(vm)
    task = get_task(uuid)
    for line in task.logs_iter():
        print(line, end="")


def register_create_parser(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("machine", type=str)
    parser.set_defaults(func=create)
    parser.set_defaults(func=create_command)
@@ -1,38 +1,42 @@
import argparse
import asyncio
import json
import subprocess

from pydantic import BaseModel

from ..async_cmd import run
from ..dirs import get_clan_flake_toplevel
from ..nix import nix_eval


def get_vm_inspect_info(machine: str) -> dict:
    clan_dir = get_clan_flake_toplevel().as_posix()
class VmConfig(BaseModel):
    flake_url: str
    flake_attr: str

    # config = nix_config()
    # system = config["system"]
    cores: int
    memory_size: int
    graphics: bool

    return json.loads(
        subprocess.run(
            nix_eval(
                [
                    # f'{clan_dir}#clanInternals.machines."{system}"."{machine}".config.clan.virtualisation' # TODO use this
                    f'{clan_dir}#nixosConfigurations."{machine}".config.system.clan.vm.config'
                ]
            ),
            stdout=subprocess.PIPE,
            check=True,
            text=True,
        ).stdout

async def inspect_vm(flake_url: str, flake_attr: str) -> VmConfig:
    cmd = nix_eval(
        [
            f"{flake_url}#nixosConfigurations.{json.dumps(flake_attr)}.config.system.clan.vm.config"
        ]
    )
    stdout = await run(cmd)
    data = json.loads(stdout)
    return VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)


def inspect(args: argparse.Namespace) -> None:
    print(f"Creating VM for {args.machine}")
    machine = args.machine
    print(get_vm_inspect_info(machine))
def inspect_command(args: argparse.Namespace) -> None:
    clan_dir = get_clan_flake_toplevel().as_posix()
    res = asyncio.run(inspect_vm(flake_url=clan_dir, flake_attr=args.machine))
    print("Cores:", res.cores)
    print("Memory size:", res.memory_size)
    print("Graphics:", res.graphics)


def register_inspect_parser(parser: argparse.ArgumentParser) -> None:
    parser.add_argument("machine", type=str)
    parser.set_defaults(func=inspect)
    parser.set_defaults(func=inspect_command)
@@ -45,6 +45,8 @@ def register_parser(parser: argparse.ArgumentParser) -> None:
        help="Log level",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
    )

    # Set the args.func variable in args
    if start_server is None:
        parser.set_defaults(func=fastapi_is_not_installed)
    else:

@@ -5,6 +5,11 @@ from . import register_parser
if __name__ == "__main__":
    # this is use in our integration test
    parser = argparse.ArgumentParser()
    # call the register_parser function, which adds arguments to the parser
    register_parser(parser)
    args = parser.parse_args()

    # call the function that is stored
    # in the func attribute of args, and pass args as the argument
    # look into register_parser to see how this is done
    args.func(args)
@@ -5,9 +5,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi.staticfiles import StaticFiles

from .. import custom_logger
from ..errors import ClanError
from .assets import asset_path
from .routers import flake, health, machines, root, utils, vms
from .error_handlers import clan_error_handler
from .routers import flake, health, machines, root, vms

origins = [
    "http://localhost:3000",

@@ -33,9 +34,7 @@ def setup_app() -> FastAPI:
    # Needs to be last in register. Because of wildcard route
    app.include_router(root.router)

    app.add_exception_handler(
        utils.NixBuildException, utils.nix_build_exception_handler
    )
    app.add_exception_handler(ClanError, clan_error_handler)

    app.mount("/static", StaticFiles(directory=asset_path()), name="static")

@@ -43,15 +42,11 @@ def setup_app() -> FastAPI:
        if isinstance(route, APIRoute):
            route.operation_id = route.name  # in this case, 'read_items'
        log.debug(f"Registered route: {route}")

    for i in app.exception_handlers.items():
        log.debug(f"Registered exception handler: {i}")

    return app


# TODO: How do I get the log level from the command line in here?
custom_logger.register(logging.DEBUG)
app = setup_app()

for i in app.exception_handlers.items():
    log.info(f"Registered exception handler: {i}")

log.warning("log warn")
log.debug("log debug")
@@ -1,7 +1,39 @@
import functools
import logging
from pathlib import Path

log = logging.getLogger(__name__)


def get_hash(string: str) -> str:
    """
    This function takes a string like '/nix/store/kkvk20b8zh8aafdnfjp6dnf062x19732-source'
    and returns the hash part 'kkvk20b8zh8aafdnfjp6dnf062x19732' after '/nix/store/' and before '-source'.
    """
    # Split the string by '/' and get the last element
    last_element = string.split("/")[-1]
    # Split the last element by '-' and get the first element
    hash_part = last_element.split("-")[0]
    # Return the hash part
    return hash_part


def check_divergence(path: Path) -> None:
    p = path.resolve()

    log.info("Absolute web asset path: %s", p)
    if not p.is_dir():
        raise FileNotFoundError(p)

    # Get the hash part of the path
    gh = get_hash(str(p))

    log.debug(f"Serving webui asset with hash {gh}")


@functools.cache
def asset_path() -> Path:
    return Path(__file__).parent / "assets"
    path = Path(__file__).parent / "assets"
    log.debug("Serving assets from: %s", path)
    check_divergence(path)
    return path
pkgs/clan-cli/clan_cli/webui/error_handlers.py (new file) | 23

@@ -0,0 +1,23 @@
import logging

from fastapi import Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

from ..errors import ClanError

log = logging.getLogger(__name__)


def clan_error_handler(request: Request, exc: ClanError) -> JSONResponse:
    log.error("ClanError: %s", exc)
    detail = [
        {
            "loc": [],
            "msg": str(exc),
        }
    ]
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        content=jsonable_encoder(dict(detail=detail)),
    )
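For reference, a minimal sketch of how this handler gets wired up; the app.py hunk above does exactly this inside setup_app():

    from fastapi import FastAPI

    from clan_cli.errors import ClanError
    from clan_cli.webui.error_handlers import clan_error_handler

    app = FastAPI()
    # Any ClanError raised inside a route is now returned as a 422 JSON response.
    app.add_exception_handler(ClanError, clan_error_handler)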
@@ -6,15 +6,15 @@ from fastapi import APIRouter, HTTPException
from clan_cli.webui.schemas import FlakeAction, FlakeAttrResponse, FlakeResponse

from ...async_cmd import run
from ...nix import nix_command, nix_flake_show
from .utils import run_cmd

router = APIRouter()


async def get_attrs(url: str) -> list[str]:
    cmd = nix_flake_show(url)
    stdout = await run_cmd(cmd)
    stdout = await run(cmd)

    data: dict[str, dict] = {}
    try:

@@ -45,7 +45,7 @@ async def inspect_flake(
    # Extract the flake from the given URL
    # We do this by running 'nix flake prefetch {url} --json'
    cmd = nix_command(["flake", "prefetch", url, "--json", "--refresh"])
    stdout = await run_cmd(cmd)
    stdout = await run(cmd)
    data: dict[str, str] = json.loads(stdout)

    if data.get("storePath") is None:
@@ -1,3 +1,4 @@
import logging
import os
from mimetypes import guess_type
from pathlib import Path

@@ -8,6 +9,8 @@ from ..assets import asset_path

router = APIRouter()

log = logging.getLogger(__name__)


@router.get("/{path_name:path}")
async def root(path_name: str) -> Response:

@@ -16,6 +19,7 @@ async def root(path_name: str) -> Response:
    filename = Path(os.path.normpath(asset_path() / path_name))

    if not filename.is_relative_to(asset_path()):
        log.error("Prevented directory traversal: %s", filename)
        # prevent directory traversal
        return Response(status_code=403)

@@ -23,6 +27,7 @@ async def root(path_name: str) -> Response:
    if filename.suffix == "":
        filename = filename.with_suffix(".html")
    if not filename.is_file():
        log.error("File not found: %s", filename)
        return Response(status_code=404)

    content_type, _ = guess_type(filename)
@@ -1,54 +0,0 @@
import asyncio
import logging
import shlex

from fastapi import HTTPException, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse

log = logging.getLogger(__name__)


class NixBuildException(HTTPException):
    def __init__(self, msg: str, loc: list = ["body", "flake_attr"]):
        detail = [
            {
                "loc": loc,
                "msg": msg,
                "type": "value_error",
            }
        ]
        super().__init__(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail
        )


def nix_build_exception_handler(
    request: Request, exc: NixBuildException
) -> JSONResponse:
    log.error("NixBuildException: %s", exc)
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(dict(detail=exc.detail)),
    )


async def run_cmd(cmd: list[str]) -> bytes:
    log.debug(f"Running command: {shlex.join(cmd)}")
    proc = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await proc.communicate()

    if proc.returncode != 0:
        raise NixBuildException(
            f"""
command: {shlex.join(cmd)}
exit code: {proc.returncode}
command output:
{stderr.decode("utf-8")}
"""
        )
    return stdout
@@ -1,73 +1,27 @@
import json
import logging
from typing import Annotated, Iterator
from uuid import UUID

from fastapi import APIRouter, BackgroundTasks, Body, status
from fastapi import APIRouter, Body, status
from fastapi.exceptions import HTTPException
from fastapi.responses import StreamingResponse

from clan_cli.webui.routers.flake import get_attrs

from ...nix import nix_build, nix_eval
from ...task_manager import get_task
from ...vms import create, inspect
from ..schemas import VmConfig, VmCreateResponse, VmInspectResponse, VmStatusResponse
from ..task_manager import BaseTask, get_task, register_task
from .utils import run_cmd

log = logging.getLogger(__name__)
router = APIRouter()


def nix_inspect_vm_cmd(machine: str, flake_url: str) -> list[str]:
    return nix_eval(
        [
            f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.clan.vm.config"
        ]
    )


def nix_build_vm_cmd(machine: str, flake_url: str) -> list[str]:
    return nix_build(
        [
            f"{flake_url}#nixosConfigurations.{json.dumps(machine)}.config.system.build.vm"
        ]
    )


class BuildVmTask(BaseTask):
    def __init__(self, uuid: UUID, vm: VmConfig) -> None:
        super().__init__(uuid)
        self.vm = vm

    def run(self) -> None:
        try:
            self.log.debug(f"BuildVM with uuid {self.uuid} started")
            cmd = nix_build_vm_cmd(self.vm.flake_attr, flake_url=self.vm.flake_url)

            proc = self.run_cmd(cmd)
            self.log.debug(f"stdout: {proc.stdout}")

            vm_path = f"{''.join(proc.stdout[0])}/bin/run-nixos-vm"
            self.log.debug(f"vm_path: {vm_path}")

            self.run_cmd([vm_path])
            self.finished = True
        except Exception as e:
            self.failed = True
            self.finished = True
            log.exception(e)


@router.post("/api/vms/inspect")
async def inspect_vm(
    flake_url: Annotated[str, Body()], flake_attr: Annotated[str, Body()]
) -> VmInspectResponse:
    cmd = nix_inspect_vm_cmd(flake_attr, flake_url=flake_url)
    stdout = await run_cmd(cmd)
    data = json.loads(stdout)
    return VmInspectResponse(
        config=VmConfig(flake_url=flake_url, flake_attr=flake_attr, **data)
    )
    config = await inspect.inspect_vm(flake_url, flake_attr)
    return VmInspectResponse(config=config)


@router.get("/api/vms/{uuid}/status")

@@ -84,21 +38,7 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:
    def stream_logs() -> Iterator[str]:
        task = get_task(uuid)

        for proc in task.procs:
            if proc.done:
                log.debug("stream logs and proc is done")
                for line in proc.stderr:
                    yield line + "\n"
                for line in proc.stdout:
                    yield line + "\n"
                continue
            while True:
                out = proc.output
                line = out.get()
                if line is None:
                    log.debug("stream logs and line is None")
                    break
                yield line
        yield from task.logs_iter()

    return StreamingResponse(
        content=stream_logs(),

@@ -107,14 +47,12 @@ async def get_vm_logs(uuid: UUID) -> StreamingResponse:


@router.post("/api/vms/create")
async def create_vm(
    vm: Annotated[VmConfig, Body()], background_tasks: BackgroundTasks
) -> VmCreateResponse:
async def create_vm(vm: Annotated[VmConfig, Body()]) -> VmCreateResponse:
    flake_attrs = await get_attrs(vm.flake_url)
    if vm.flake_attr not in flake_attrs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Provided attribute '{vm.flake_attr}' does not exist.",
        )
    uuid = register_task(BuildVmTask, vm)
    uuid = create.create_vm(vm)
    return VmCreateResponse(uuid=str(uuid))
@@ -3,6 +3,8 @@ from typing import List
from pydantic import BaseModel, Field

from ..vms.inspect import VmConfig


class Status(Enum):
    ONLINE = "online"

@@ -35,15 +37,6 @@ class SchemaResponse(BaseModel):
    schema_: dict = Field(alias="schema")


class VmConfig(BaseModel):
    flake_url: str
    flake_attr: str

    cores: int
    memory_size: int
    graphics: bool


class VmStatusResponse(BaseModel):
    returncode: list[int | None]
    running: bool
@@ -1,6 +1,11 @@
import argparse
import logging
import multiprocessing as mp
import os
import socket
import subprocess
import sys
import syslog
import time
import urllib.request
import webbrowser

@@ -90,3 +95,98 @@ def start_server(args: argparse.Namespace) -> None:
        access_log=args.log_level == "debug",
        headers=headers,
    )


# Define a function that takes the path of the file socket as input and returns True if it is served, False otherwise
def is_served(file_socket: Path) -> bool:
    # Create a Unix stream socket
    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # Try to connect to the file socket
    try:
        client.connect(str(file_socket))
        # Connection succeeded, return True
        return True
    except OSError:
        # Connection failed, return False
        return False
    finally:
        # Close the client socket
        client.close()


def set_out_to_syslog() -> None:  # type: ignore
    # Define some constants for convenience
    log_levels = {
        "emerg": syslog.LOG_EMERG,
        "alert": syslog.LOG_ALERT,
        "crit": syslog.LOG_CRIT,
        "err": syslog.LOG_ERR,
        "warning": syslog.LOG_WARNING,
        "notice": syslog.LOG_NOTICE,
        "info": syslog.LOG_INFO,
        "debug": syslog.LOG_DEBUG,
    }
    facility = syslog.LOG_USER  # Use user facility for custom applications

    # Open a connection to the system logger
    syslog.openlog("clan-cli", 0, facility)  # Use "myapp" as the prefix for messages

    # Define a custom write function that sends messages to syslog
    def write(message: str) -> int:
        # Strip the newline character from the message
        message = message.rstrip("\n")
        # Check if the message is not empty
        if message:
            # Send the message to syslog with the appropriate level
            if message.startswith("ERROR:"):
                # Use error level for messages that start with "ERROR:"
                syslog.syslog(log_levels["err"], message)
            else:
                # Use info level for other messages
                syslog.syslog(log_levels["info"], message)
        return 0

    # Assign the custom write function to sys.stdout and sys.stderr
    setattr(sys.stdout, "write", write)
    setattr(sys.stderr, "write", write)

    # Define a dummy flush function to prevent errors
    def flush() -> None:
        pass

    # Assign the dummy flush function to sys.stdout and sys.stderr
    setattr(sys.stdout, "flush", flush)
    setattr(sys.stderr, "flush", flush)


def _run_socketfile(socket_file: Path, debug: bool) -> None:
    set_out_to_syslog()
    run(
        "clan_cli.webui.app:app",
        uds=str(socket_file),
        access_log=debug,
        reload=False,
        log_level="debug" if debug else "info",
    )


@contextmanager
def api_server(debug: bool) -> Iterator[Path]:
    runtime_dir = os.getenv("XDG_RUNTIME_DIR")
    if runtime_dir is None:
        raise RuntimeError("XDG_RUNTIME_DIR not set")
    socket_path = Path(runtime_dir) / "clan.sock"
    socket_path = socket_path.resolve()

    log.debug("Socketfile lies at %s", socket_path)

    if not is_served(socket_path):
        log.debug("Starting api server...")
        mp.set_start_method(method="spawn")
        proc = mp.Process(target=_run_socketfile, args=(socket_path, debug))
        proc.start()
    else:
        log.info("Api server is already running on %s", socket_path)

    yield socket_path
    proc.terminate()
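A minimal usage sketch of the api_server() context manager shown above: it yields the unix socket path, starting the server in a background process if nothing is listening there yet. The import path is an assumption, since the file name is not visible in this diff.

    from clan_cli.webui.server import api_server  # module path is an assumption

    with api_server(debug=False) as socket_path:
        # Talk to the API over the unix socket, e.g. with an HTTP-over-UDS client.
        print(f"API listening on {socket_path}")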
@@ -1,119 +0,0 @@
import logging
import os
import queue
import select
import shlex
import subprocess
import threading
from typing import Any
from uuid import UUID, uuid4


class CmdState:
    def __init__(self, proc: subprocess.Popen) -> None:
        global LOOP
        self.proc: subprocess.Popen = proc
        self.stdout: list[str] = []
        self.stderr: list[str] = []
        self.output: queue.SimpleQueue = queue.SimpleQueue()
        self.returncode: int | None = None
        self.done: bool = False


class BaseTask(threading.Thread):
    def __init__(self, uuid: UUID) -> None:
        # calling parent class constructor
        threading.Thread.__init__(self)

        # constructor
        self.uuid: UUID = uuid
        self.log = logging.getLogger(__name__)
        self.procs: list[CmdState] = []
        self.failed: bool = False
        self.finished: bool = False

    def run(self) -> None:
        self.finished = True

    def run_cmd(self, cmd: list[str]) -> CmdState:
        cwd = os.getcwd()
        self.log.debug(f"Working directory: {cwd}")
        self.log.debug(f"Running command: {shlex.join(cmd)}")
        p = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            # shell=True,
            cwd=cwd,
        )
        self.procs.append(CmdState(p))
        p_state = self.procs[-1]

        while p.poll() is None:
            # Check if stderr is ready to be read from
            rlist, _, _ = select.select([p.stderr, p.stdout], [], [], 0)
            if p.stderr in rlist:
                assert p.stderr is not None
                line = p.stderr.readline()
                if line != "":
                    p_state.stderr.append(line.strip("\n"))
                    self.log.debug(f"stderr: {line}")
                    p_state.output.put(line)

            if p.stdout in rlist:
                assert p.stdout is not None
                line = p.stdout.readline()
                if line != "":
                    p_state.stdout.append(line.strip("\n"))
                    self.log.debug(f"stdout: {line}")
                    p_state.output.put(line)

        p_state.returncode = p.returncode
        p_state.output.put(None)
        p_state.done = True

        if p.returncode != 0:
            raise RuntimeError(f"Failed to run command: {shlex.join(cmd)}")

        self.log.debug("Successfully ran command")
        return p_state


class TaskPool:
    def __init__(self) -> None:
        self.lock: threading.RLock = threading.RLock()
        self.pool: dict[UUID, BaseTask] = {}

    def __getitem__(self, uuid: UUID) -> BaseTask:
        with self.lock:
            return self.pool[uuid]

    def __setitem__(self, uuid: UUID, task: BaseTask) -> None:
        with self.lock:
            if uuid in self.pool:
                raise KeyError(f"Task with uuid {uuid} already exists")
            if type(uuid) is not UUID:
                raise TypeError("uuid must be of type UUID")
            self.pool[uuid] = task


POOL: TaskPool = TaskPool()


def get_task(uuid: UUID) -> BaseTask:
    global POOL
    return POOL[uuid]


def register_task(task: type, *args: Any) -> UUID:
    global POOL
    if not issubclass(task, BaseTask):
        raise TypeError("task must be a subclass of BaseTask")

    uuid = uuid4()

    inst_task = task(uuid, *args)
    POOL[uuid] = inst_task
    inst_task.start()
    return uuid
@@ -29,6 +29,7 @@
, copyDesktopItems
, qemu
, gnupg
, e2fsprogs
}:
let

@@ -63,6 +64,7 @@ let
    sops
    git
    qemu
    e2fsprogs
  ];

  runtimeDependenciesAsSet = builtins.listToAttrs (builtins.map (p: lib.nameValuePair (lib.getName p.name) p) runtimeDependencies);
@@ -1,3 +1,4 @@
import json
from pathlib import Path

import pytest

@@ -28,3 +29,23 @@ def test_inspect_err(api: TestClient) -> None:
    data = response.json()
    print("Data: ", data)
    assert data.get("detail")


@pytest.mark.impure
def test_inspect_flake(api: TestClient, test_flake_with_core: Path) -> None:
    params = {"url": str(test_flake_with_core)}
    response = api.get(
        "/api/flake",
        params=params,
    )
    assert response.status_code == 200, "Failed to inspect vm"
    data = response.json()
    print("Data: ", json.dumps(data, indent=2))
    assert data.get("content") is not None
    actions = data.get("actions")
    assert actions is not None
    assert len(actions) == 2
    assert actions[0].get("id") == "vms/inspect"
    assert actions[0].get("uri") == "api/vms/inspect"
    assert actions[1].get("id") == "vms/create"
    assert actions[1].get("uri") == "api/vms/create"
@@ -6,24 +6,6 @@ from api import TestClient
from httpx import SyncByteStream


def is_running_in_ci() -> bool:
    # Check if pytest is running in GitHub Actions
    if os.getenv("GITHUB_ACTIONS") == "true":
        print("Running on GitHub Actions")
        return True

    # Check if pytest is running in Travis CI
    if os.getenv("TRAVIS") == "true":
        print("Running on Travis CI")
        return True

    # Check if pytest is running in Circle CI
    if os.getenv("CIRCLECI") == "true":
        print("Running on Circle CI")
        return True
    return False


@pytest.mark.impure
def test_inspect(api: TestClient, test_flake_with_core: Path) -> None:
    response = api.post(

@@ -49,10 +31,9 @@ def test_incorrect_uuid(api: TestClient) -> None:
    assert response.status_code == 422, "Failed to get vm status"


@pytest.mark.skipif(not os.path.exists("/dev/kvm"), reason="Requires KVM")
@pytest.mark.impure
def test_create(api: TestClient, test_flake_with_core: Path) -> None:
    if is_running_in_ci():
        pytest.skip("Skipping test in CI. As it requires KVM")
    print(f"flake_url: {test_flake_with_core} ")
    response = api.post(
        "/api/vms/create",

@@ -74,20 +55,11 @@ def test_create(api: TestClient, test_flake_with_core: Path) -> None:
    assert response.status_code == 200, "Failed to get vm status"

    response = api.get(f"/api/vms/{uuid}/logs")
    print("=========FLAKE LOGS==========")
    assert isinstance(response.stream, SyncByteStream)
    for line in response.stream:
        assert line != b"", "Failed to get vm logs"
        print(line.decode("utf-8"), end="")
    print("=========END LOGS==========")
    assert response.status_code == 200, "Failed to get vm logs"

    response = api.get(f"/api/vms/{uuid}/logs")
    assert isinstance(response.stream, SyncByteStream)
    print("=========VM LOGS==========")
    assert isinstance(response.stream, SyncByteStream)
    for line in response.stream:
        assert line != b"", "Failed to get vm logs"
        print(line.decode("utf-8"), end="")
        print(line.decode("utf-8"))
    print("=========END LOGS==========")
    assert response.status_code == 200, "Failed to get vm logs"
pkgs/clan-cli/tests/test_vms_cli.py (new file) | 22

@@ -0,0 +1,22 @@
import os
from pathlib import Path

import pytest
from cli import Cli

no_kvm = not os.path.exists("/dev/kvm")


@pytest.mark.impure
def test_inspect(test_flake_with_core: Path, capsys: pytest.CaptureFixture) -> None:
    cli = Cli()
    cli.run(["vms", "inspect", "vm1"])
    out = capsys.readouterr()  # empty the buffer
    assert "Cores" in out.out


@pytest.mark.skipif(no_kvm, reason="Requires KVM")
@pytest.mark.impure
def test_create(test_flake_with_core: Path) -> None:
    cli = Cli()
    cli.run(["vms", "create", "vm1"])