vars: improve generator execution pipeline

- ensure all dependents are re-generated as well
- refactor: separate out computation of generator update closure
This commit is contained in:
DavHau
2024-09-06 14:30:23 +02:00
parent f4fb9ea96c
commit 5cd9960ed4
5 changed files with 104 additions and 38 deletions

View File

@@ -18,6 +18,7 @@ class Prompt:
previous_value: str | None = None
# TODO: add a 'pending' flag indicating that the generator still needs to be executed
@dataclass
class Generator:
name: str

View File

@@ -1,5 +1,4 @@
import argparse
import importlib
import logging
import os
import sys
@@ -197,9 +196,9 @@ def execute_generator(
return True
def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
def _get_subgraph(graph: dict[str, set], vertices: list[str]) -> dict[str, set]:
visited = set()
queue = [vertex]
queue = vertices
while queue:
vertex = queue.pop(0)
if vertex not in visited:
@@ -208,50 +207,90 @@ def _get_subgraph(graph: dict[str, set], vertex: str) -> dict[str, set]:
return {k: v for k, v in graph.items() if k in visited}
def _dependency_graph(
    machine: Machine, entry_nodes: None | list[str] = None
) -> dict[str, set]:
    """
    Build the generator dependency graph of a machine.

    Each generator name in ``machine.vars_generators`` is mapped to the set of
    names listed under its ``"dependencies"`` key.

    If ``entry_nodes`` is given and non-empty, only the sub-graph that
    ``_get_subgraph`` selects for those nodes is returned; otherwise the full
    graph is returned. The ``entry_nodes`` list is never modified.
    """
    graph = {
        gen_name: set(generator["dependencies"])
        for gen_name, generator in machine.vars_generators.items()
    }
    if entry_nodes:
        # _get_subgraph uses the list it receives as its work queue and pops
        # from it, so hand it a copy to keep the caller's list intact.
        return _get_subgraph(graph, list(entry_nodes))
    return graph
def _reverse_dependency_graph(
    machine: Machine, entry_nodes: None | list[str] = None
) -> dict[str, set]:
    """
    Build the reversed dependency graph (generator -> its direct dependents).

    Every generator of the machine is a key; its value is the set of
    generators that list it as a dependency.

    If ``entry_nodes`` is given and non-empty, only the sub-graph that
    ``_get_subgraph`` selects for those nodes is returned; otherwise the full
    reversed graph is returned. The ``entry_nodes`` list is never modified.
    """
    graph = _dependency_graph(machine)
    reverse_graph: dict[str, set] = {gen_name: set() for gen_name in graph}
    for gen_name, dependencies in graph.items():
        for dep in dependencies:
            # NOTE: a dependency that is not itself a generator of this
            # machine has no key in reverse_graph and raises KeyError here.
            reverse_graph[dep].add(gen_name)
    if entry_nodes:
        # _get_subgraph uses the list it receives as its work queue and pops
        # from it, so hand it a copy to keep the caller's list intact.
        return _get_subgraph(reverse_graph, list(entry_nodes))
    return reverse_graph
def _required_generators(
    machine: Machine,
    desired_generators: list[str],
) -> list[str]:
    """
    Receives list of desired generators to update and returns list of required generators to update.
    This is needed because some generators might depend on others, so we need to update them first.
    Dependents of the desired generators are included as well, since their vars
    might depend on the output of a desired generator.
    The returned list is sorted topologically (dependencies first).

    An empty ``desired_generators`` list selects every generator of the machine.
    The ``desired_generators`` list itself is never modified.

    Raises ClanError if a generator in the selected sub-graph depends on a
    generator that is not part of that sub-graph.
    """
    dependency_graph = _dependency_graph(machine)
    # extract sub-graph if specific generators selected
    # (_get_subgraph pops from the list it is given; pass a copy so that
    # desired_generators is still populated for the reverse lookup below)
    dependency_graph = _get_subgraph(dependency_graph, list(desired_generators))
    # check if all dependencies actually exist
    for gen_name, dependencies in dependency_graph.items():
        for dep in dependencies:
            if dep not in dependency_graph:
                msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
                raise ClanError(msg)
    # ensure that all dependents are regenerated as well as their vars might depend on the current generator
    reverse_dependency_graph = _reverse_dependency_graph(
        machine, list(desired_generators)
    )
    final_graph = _dependency_graph(
        machine, entry_nodes=list(reverse_dependency_graph)
    )
    # process generators in topological order (dependencies first)
    sorter = TopologicalSorter(final_graph)
    return list(sorter.static_order())
def _generate_vars_for_machine(
machine: Machine,
generator_name: str | None,
regenerate: bool,
) -> bool:
secret_vars_module = importlib.import_module(machine.secret_vars_module)
secret_vars_store = secret_vars_module.SecretStore(machine=machine)
return _generate_vars_for_machine_multi(
machine, [generator_name] if generator_name else [], regenerate
)
public_vars_module = importlib.import_module(machine.public_vars_module)
public_vars_store = public_vars_module.FactStore(machine=machine)
def _generate_vars_for_machine_multi(
machine: Machine,
generator_names: list[str],
regenerate: bool,
) -> bool:
machine_updated = False
if generator_name and generator_name not in machine.vars_generators:
generators = list(machine.vars_generators.keys())
msg = f"Could not find generator with name: {generator_name}. The following generators are available: {generators}"
raise ClanError(msg)
graph = {
gen_name: set(generator["dependencies"])
for gen_name, generator in machine.vars_generators.items()
}
# extract sub-graph if specific generator selected
if generator_name:
graph = _get_subgraph(graph, generator_name)
# check if all dependencies actually exist
for gen_name, dependencies in graph.items():
for dep in dependencies:
if dep not in graph:
msg = f"Generator {gen_name} has a dependency on {dep}, which does not exist"
raise ClanError(msg)
# process generators in topological order
sorter = TopologicalSorter(graph)
for generator_name in sorter.static_order():
generators_to_update = _required_generators(machine, generator_names)
for generator_name in generators_to_update:
assert generator_name is not None
machine_updated |= execute_generator(
machine=machine,
generator_name=generator_name,
regenerate=regenerate,
secret_vars_store=secret_vars_store,
public_vars_store=public_vars_store,
secret_vars_store=machine.secret_vars_store,
public_vars_store=machine.public_vars_store,
)
if machine_updated:
# flush caches to make sure the new secrets are available in evaluation

View File

@@ -30,8 +30,10 @@ def get_vars(machine: Machine) -> list[Var]:
return pub_store.get_all() + sec_store.get_all()
def _get_prompt_value(
machine: Machine, generator: Generator, prompt: Prompt
def _get_previous_value(
machine: Machine,
generator: Generator,
prompt: Prompt,
) -> str | None:
if not prompt.has_file:
return None
@@ -40,10 +42,16 @@ def _get_prompt_value(
return pub_store.get(
generator.name, prompt.name, shared=generator.share
).decode()
sec_store = secret_store(machine)
if sec_store.exists(generator.name, prompt.name, shared=generator.share):
return sec_store.get(
generator.name, prompt.name, shared=generator.share
).decode()
return None
@API.register
# TODO: use machine_name
def get_prompts(machine: Machine) -> list[Generator]:
generators = []
for gen_name, generator in machine.vars_generators.items():
@@ -61,7 +69,7 @@ def get_prompts(machine: Machine) -> list[Generator]:
has_file=prompt["createFile"],
generator=gen_name,
)
prompt.previous_value = _get_prompt_value(machine, gen, prompt)
prompt.previous_value = _get_previous_value(machine, gen, prompt)
prompts.append(prompt)
generators.append(gen)
@@ -69,6 +77,8 @@ def get_prompts(machine: Machine) -> list[Generator]:
# TODO: Ensure generator dependencies are met (executed in correct order etc.)
# TODO: for missing prompts, default to existing values
# TODO: raise error if mandatory prompt not provided
@API.register
def set_prompts(machine: Machine, updates: list[GeneratorUpdate]) -> None:
for update in updates: