vars: refactor - unify get_generators and _get_closure
This commit is contained in:
@@ -318,8 +318,13 @@ export const useMachineGenerators = (
|
|||||||
],
|
],
|
||||||
queryFn: async () => {
|
queryFn: async () => {
|
||||||
const call = client.fetch("get_generators", {
|
const call = client.fetch("get_generators", {
|
||||||
base_dir: clanUri,
|
machine: {
|
||||||
machine_name: machineName,
|
name: machineName,
|
||||||
|
flake: {
|
||||||
|
identifier: clanUri,
|
||||||
|
},
|
||||||
|
full_closure: true, // TODO: Make this configurable
|
||||||
|
},
|
||||||
// TODO: Make this configurable
|
// TODO: Make this configurable
|
||||||
include_previous_values: true,
|
include_previous_values: true,
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -725,8 +725,9 @@ def test_api_set_prompts(
|
|||||||
)
|
)
|
||||||
assert store.get(my_generator, "prompt1").decode() == "input2"
|
assert store.get(my_generator, "prompt1").decode() == "input2"
|
||||||
|
|
||||||
|
machine = Machine(name="my_machine", flake=Flake(str(flake.path)))
|
||||||
generators = get_generators(
|
generators = get_generators(
|
||||||
machine_name="my_machine", base_dir=flake.path, include_previous_values=True
|
machine=machine, full_closure=True, include_previous_values=True
|
||||||
)
|
)
|
||||||
# get_generators should bind the store
|
# get_generators should bind the store
|
||||||
assert generators[0].files[0]._store is not None
|
assert generators[0].files[0]._store is not None
|
||||||
|
|||||||
@@ -427,12 +427,25 @@ def _get_previous_value(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _get_closure(
|
@API.register
|
||||||
|
def get_generators(
|
||||||
machine: "Machine",
|
machine: "Machine",
|
||||||
generator_name: str | None,
|
|
||||||
full_closure: bool,
|
full_closure: bool,
|
||||||
|
generator_name: str | None = None,
|
||||||
include_previous_values: bool = False,
|
include_previous_values: bool = False,
|
||||||
) -> list[Generator]:
|
) -> list[Generator]:
|
||||||
|
"""
|
||||||
|
Get generators for a machine, with optional closure computation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
machine: The machine to get generators for.
|
||||||
|
full_closure: If True, include all dependency generators. If False, only include missing ones.
|
||||||
|
generator_name: Name of a specific generator to get, or None for all generators.
|
||||||
|
include_previous_values: If True, populate prompts with their previous values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of generators based on the specified selection and closure mode.
|
||||||
|
"""
|
||||||
from . import graph
|
from . import graph
|
||||||
|
|
||||||
vars_generators = Generator.get_machine_generators(machine.name, machine.flake)
|
vars_generators = Generator.get_machine_generators(machine.name, machine.flake)
|
||||||
@@ -510,31 +523,6 @@ def _generate_vars_for_machine(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@API.register
|
|
||||||
def get_generators(
|
|
||||||
machine_name: str,
|
|
||||||
base_dir: Path,
|
|
||||||
include_previous_values: bool = False,
|
|
||||||
) -> list[Generator]:
|
|
||||||
"""
|
|
||||||
Get the list of generators for a machine, optionally with previous values.
|
|
||||||
If `full_closure` is True, it returns the full closure of generators.
|
|
||||||
If `include_previous_values` is True, it includes the previous values for prompts.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
machine_name (str): The name of the machine.
|
|
||||||
base_dir (Path): The base directory of the flake.
|
|
||||||
Returns:
|
|
||||||
list[Generator]: A list of generators for the machine.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return Generator.get_machine_generators(
|
|
||||||
machine_name,
|
|
||||||
Flake(str(base_dir)),
|
|
||||||
include_previous_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@API.register
|
@API.register
|
||||||
def run_generators(
|
def run_generators(
|
||||||
machine_name: str,
|
machine_name: str,
|
||||||
@@ -585,7 +573,7 @@ def create_machine_vars_interactive(
|
|||||||
regenerate: bool,
|
regenerate: bool,
|
||||||
no_sandbox: bool = False,
|
no_sandbox: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
generators = _get_closure(machine, generator_name, regenerate)
|
generators = get_generators(machine, regenerate, generator_name)
|
||||||
if len(generators) == 0:
|
if len(generators) == 0:
|
||||||
return
|
return
|
||||||
all_prompt_values = {}
|
all_prompt_values = {}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from clan_cli.completions import add_dynamic_completer, complete_machines
|
from clan_cli.completions import add_dynamic_completer, complete_machines
|
||||||
from clan_lib.flake import Flake, require_flake
|
from clan_lib.flake import Flake, require_flake
|
||||||
@@ -20,7 +19,7 @@ def get_machine_vars(base_dir: str, machine_name: str) -> list[Var]:
|
|||||||
|
|
||||||
all_vars = []
|
all_vars = []
|
||||||
|
|
||||||
generators = get_generators(base_dir=Path(base_dir), machine_name=machine_name)
|
generators = get_generators(machine=machine, full_closure=True)
|
||||||
for generator in generators:
|
for generator in generators:
|
||||||
for var in generator.files:
|
for var in generator.files:
|
||||||
if var.secret:
|
if var.secret:
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ def test_clan_create_api(
|
|||||||
# Invalidate cache because of new inventory
|
# Invalidate cache because of new inventory
|
||||||
clan_dir_flake.invalidate_cache()
|
clan_dir_flake.invalidate_cache()
|
||||||
|
|
||||||
generators = get_generators(machine.name, machine.flake.path)
|
generators = get_generators(machine=machine, full_closure=True)
|
||||||
all_prompt_values = {}
|
all_prompt_values = {}
|
||||||
for generator in generators:
|
for generator in generators:
|
||||||
prompt_values = {}
|
prompt_values = {}
|
||||||
|
|||||||
Reference in New Issue
Block a user