import itertools
from copy import deepcopy
from typing import Any
from typing import Iterator
from typing import Optional
from typing import Tuple
from typing import Union
import networkx
from ..node import flatten_node_id
from ..utils import dict_merge
# Identifier of a node: either a plain string or a nested pair
# ``(parent_id, child_id)``. The ``Any`` slot stands for a nested NodeIdType.
NodeIdType = Union[str, Tuple[str, Any]]  # Any is NodeIdType
def _append_subnode_id(node_id: NodeIdType, sub_node_id: str) -> NodeIdType:
    """Attach `sub_node_id` at the innermost level of `node_id`.

    A non-tuple id becomes the pair ``(node_id, sub_node_id)``. A pair
    ``(parent, child)`` keeps its parent and recurses into the child, so the
    new id always ends up at the deepest nesting level.
    """
    if not isinstance(node_id, tuple):
        # Innermost level reached: pair it with the new sub-node id
        return node_id, sub_node_id
    outer, inner = node_id
    return outer, _append_subnode_id(inner, sub_node_id)
def _get_subgraph(node_id: NodeIdType, subgraphs: dict):
    """Return the subgraph that `node_id` refers to, or `None` for a task node.

    :param node_id: a string id or a nested ``(subgraph_id, subnode_id)`` pair
    :param subgraphs: maps subgraph ids to objects exposing the node ids
        through ``.graph.nodes``
    :returns: for a string id, ``subgraphs.get(node_id)`` (so `None` when the
        id is not a subgraph). For a pair, `None` when the sub-node id matches
        a node of the subgraph exactly (a task node), or the entry of
        `subgraphs` for ``subgraph_id`` when the sub-node id is a strict prefix
        of a node id (a graph node).
    :raises ValueError: when ``subgraph_id`` is not a key of `subgraphs`, or
        when the sub-node id matches no node of that subgraph.
    """
    if isinstance(node_id, str):
        return subgraphs.get(node_id)

    subgraph_id, subnode_id = node_id
    try:
        subgraph = subgraphs[subgraph_id]
    except KeyError:
        # The KeyError is an implementation detail: suppress it like the
        # other lookups in this module do.
        raise ValueError(node_id, f"{repr(subgraph_id)} is not a subgraph") from None

    # Compare on flattened ids so nested ids of any depth can be matched
    flat_subnode_id = flatten_node_id(subnode_id)
    n = len(flat_subnode_id)
    for existing_id in subgraph.graph.nodes:
        flat_node_id = flatten_node_id(existing_id)
        if flat_node_id == flat_subnode_id:
            return None  # a task node
        if len(flat_node_id) > n and flat_node_id[:n] == flat_subnode_id:
            return subgraph  # a graph node
    raise ValueError(
        f"{subnode_id} is not a node or subgraph of subgraph {repr(subgraph_id)}",
    )
def _alias_to_node_id(alias_attrs: dict) -> NodeIdType:
    """Build the node id an alias definition points to.

    The result is ``alias_attrs["node"]`` on its own, or paired with
    ``alias_attrs["sub_node"]`` when that entry is present and truthy.
    """
    nested_id = alias_attrs.get("sub_node")
    parent_id = alias_attrs["node"]
    return (parent_id, nested_id) if nested_id else parent_id
def _resolve_node_aliases(
    node_id: NodeIdType, graph_attrs: dict, input_nodes: bool
) -> Iterator[Tuple[NodeIdType, dict]]:
    """Translate an alias id into the node ids it stands for.

    The aliases are read from ``graph_attrs["input_nodes"]`` or
    ``graph_attrs["output_nodes"]`` depending on `input_nodes`. For every
    alias whose ``"id"`` equals `node_id`, the aliased node id is yielded
    together with the alias' ``"link_attributes"`` (an empty dict when
    absent). When no alias matches, `node_id` itself is yielded with an
    empty dict.
    """
    section = "input_nodes" if input_nodes else "output_nodes"
    matching = [
        candidate
        for candidate in graph_attrs.get(section, list())
        if candidate["id"] == node_id
    ]
    if not matching:
        # Not an alias: the id is used as is
        yield node_id, dict()
        return
    for alias in matching:
        yield _alias_to_node_id(alias), alias.get("link_attributes", dict())
def _resolve_all_node_aliases(
graph_attrs: dict, input_nodes: bool
) -> Iterator[NodeIdType]:
if input_nodes:
aliases = graph_attrs.get("input_nodes", list())
else:
aliases = graph_attrs.get("output_nodes", list())
for alias_attrs in aliases:
yield _alias_to_node_id(alias_attrs)
def _get_subnode_ids(
    node_id: NodeIdType, link_attrs_subgraph_keys: dict, subgraphs: dict, source: bool
) -> Iterator[Tuple[NodeIdType, Optional[dict]]]:
    """Resolve one end of a link to the actual node id(s) it connects to.

    :param node_id: id of the link's source or target node
    :param link_attrs_subgraph_keys: link attributes that may hold
        ``"sub_source"`` / ``"sub_target"``
    :param subgraphs: maps subgraph ids to subgraph objects
    :param source: `True` when `node_id` is the link source (uses
        ``"sub_source"`` and the subgraph's output aliases), `False` when it
        is the target (uses ``"sub_target"`` and the input aliases)
    :returns: yields ``(node_id, None)`` when `node_id` is not a subgraph.
        Otherwise yields ``(nested_node_id, link_attributes)`` for each node
        the sub-node id resolves to inside the subgraph.
    :raises ValueError: when a sub-node key is given for a node that is not a
        subgraph, or is missing for a node that is a subgraph.
    """
    key = "sub_source" if source else "sub_target"
    subgraph = _get_subgraph(node_id, subgraphs)

    if subgraph is None:
        # node_id is not a subgraph
        if key in link_attrs_subgraph_keys:
            # Report the key that was actually checked (was always 'sub_source')
            raise ValueError(
                f"'{node_id}' is not a graph so '{key}' should not be specified"
            )
        yield node_id, None
        return

    # node_id is a subgraph
    try:
        alias_id = link_attrs_subgraph_keys[key]
    except KeyError:
        raise ValueError(
            f"the '{key}' attribute to specify a node in subgraph '{node_id}' is missing"
        ) from None
    # A link source leaves through the output aliases, a target enters
    # through the input aliases.
    for sub_node_id, link_attributes in _resolve_node_aliases(
        alias_id, subgraph.graph.graph, input_nodes=not source
    ):
        yield _append_subnode_id(node_id, sub_node_id), link_attributes
def _get_subnode_attributes(
    node_id: NodeIdType, subgraphs: dict, graph_node_attrs: dict
) -> Iterator[Tuple[NodeIdType, dict]]:
    """Hand selected attributes of a graph node down to the input nodes of its subgraph.

    Only the attributes listed below are transferred. Nothing is yielded
    when none of them is present or when `node_id` is not a subgraph.
    Otherwise every input alias of the subgraph yields its nested node id
    together with the (shared) dict of transferred attributes.
    """
    transferable = (
        "default_inputs",
        "force_start_node",
        "conditions_else_value",
        "default_error_node",
    )
    inherited = {
        name: value
        for name, value in graph_node_attrs.items()
        if name in transferable
    }
    if not inherited:
        return

    subgraph = _get_subgraph(node_id, subgraphs)
    if subgraph is None:
        # node_id is a task node: nothing to hand down
        return

    # node_id is a subgraph
    input_ids = _resolve_all_node_aliases(subgraph.graph.graph, input_nodes=True)
    for input_id in input_ids:
        yield _append_subnode_id(node_id, input_id), inherited
def _get_subnode_links(
    source_id: NodeIdType,
    target_id: NodeIdType,
    link_attrs_subgraph_keys: dict,
    subgraphs: dict,
) -> Iterator[Tuple[NodeIdType, NodeIdType, dict, bool]]:
    """Expand a link between (sub)graph nodes into links between actual nodes.

    :param source_id: id of the link's source node
    :param target_id: id of the link's target node
    :param link_attrs_subgraph_keys: link attributes that may hold
        ``"sub_source"`` / ``"sub_target"``
    :param subgraphs: maps subgraph ids to subgraph objects
    :returns: yields ``(sub_source, sub_target, link_attrs, target_is_graph)``
        for every combination of resolved source and target. `link_attrs` is
        a new dict holding the source alias attributes updated with the
        target alias attributes. `target_is_graph` is `True` when the target
        was resolved through a subgraph.
    :raises ValueError: propagated from the source/target resolution
    """
    sources = list(
        _get_subnode_ids(source_id, link_attrs_subgraph_keys, subgraphs, source=True)
    )
    targets = list(
        _get_subnode_ids(target_id, link_attrs_subgraph_keys, subgraphs, source=False)
    )
    for sub_source, source_link_attrs in sources:
        for sub_target, target_link_attrs in targets:
            # Copy: the source attributes may be the dict stored in the
            # subgraph's alias definition. Updating it in place would leak
            # target attributes into that definition and into the links
            # yielded for the other targets.
            link_attrs = dict(source_link_attrs) if source_link_attrs else dict()
            if target_link_attrs:
                link_attrs.update(target_link_attrs)
            # Only targets resolved through a subgraph carry a dict
            target_is_graph = target_link_attrs is not None
            yield sub_source, sub_target, link_attrs, target_is_graph
def _replace_aliases(
    graph: networkx.DiGraph, subgraphs: dict, input_nodes: bool
) -> None:
    """Rewrite the input or output aliases of `graph` in place so they refer to resolved node ids.

    :param graph: graph whose ``graph.graph["input_nodes"]`` or
        ``graph.graph["output_nodes"]`` list is replaced
    :param subgraphs: maps subgraph ids to subgraph objects
    :param input_nodes: `True` to process the input aliases, `False` for the
        output aliases
    :raises ValueError: propagated from the sub-node resolution

    Nothing happens when the alias list is missing or empty. Aliases without
    a ``"node"`` entry are dropped. The ``"sub_node"`` entry is removed from
    every processed alias. Aliases that refer to a nested node are replaced
    by one alias per resolved node; the link attributes of the resolved
    sub-node alias are merged in, with the alias' own ``"link_attributes"``
    taking precedence.
    """
    if input_nodes:
        aliases_key, source, key = "input_nodes", False, "sub_target"
    else:
        aliases_key, source, key = "output_nodes", True, "sub_source"

    aliases = graph.graph.get(aliases_key)
    if not aliases:
        return

    new_aliases = list()
    for alias_attrs in aliases:
        node_id = alias_attrs.get("node")
        if node_id is None:
            continue
        sub_node = alias_attrs.pop("sub_node", None)
        if sub_node:
            node_id = node_id, sub_node
        if not isinstance(node_id, tuple):
            # Refers to a top-level node: nothing to resolve
            new_aliases.append(alias_attrs)
            continue

        parent, child = node_id
        for resolved_id, link_attrs in _get_subnode_ids(
            parent, {key: child}, subgraphs=subgraphs, source=source
        ):
            new_alias = deepcopy(alias_attrs)
            if link_attrs:
                # Merge into a fresh dict: `link_attrs` may be the dict
                # stored in the subgraph's own alias definition, which must
                # not be modified.
                merged = deepcopy(link_attrs)
                merged.update(new_alias.get("link_attributes", dict()))
                new_alias["link_attributes"] = merged
            new_alias["node"] = resolved_id
            new_aliases.append(new_alias)

    graph.graph[aliases_key] = new_aliases
def add_subgraph_links(graph: networkx.DiGraph, edges: list, update_attrs: dict):
    """Add links to `graph` and merge attribute updates into its nodes.

    :param graph: graph to modify in place
    :param edges: ``(source, target, attributes)`` triples (output from
        extract_graph_nodes)
    :param update_attrs: maps node ids to attributes that are merged into the
        existing node attributes, overwriting on conflict; empty entries are
        skipped
    :raises ValueError: when the source or target of a link is not a node of
        `graph`; raised before any link is added
    """
    nodes = graph.nodes
    # Validate everything first: add_edges_from would otherwise silently
    # create the missing nodes.
    for source, target, _ in edges:
        if source not in nodes:
            raise ValueError(
                f"Source node {repr(source)} of link |{repr(source)} -> {repr(target)}| does not exist"
            )
        if target not in nodes:
            raise ValueError(
                f"Target node {repr(target)} of link |{repr(source)} -> {repr(target)}| does not exist"
            )
    graph.add_edges_from(edges)

    for node, attrs in update_attrs.items():
        if not attrs:
            continue
        dict_merge(graph.nodes[node], attrs, overwrite=True)