Source code for ewokscore.graph.subgraph

import itertools
from copy import deepcopy
from typing import Any
from typing import Iterator
from typing import Optional
from typing import Tuple
from typing import Union

import networkx

from ..node import flatten_node_id
from ..utils import dict_merge

NodeIdType = Union[str, Tuple[str, Any]]  # Any is NodeIdType


def _append_subnode_id(node_id: NodeIdType, sub_node_id: str) -> NodeIdType:
    if isinstance(node_id, tuple):
        parent, child = node_id
        return parent, _append_subnode_id(child, sub_node_id)
    else:
        return node_id, sub_node_id


def _get_subgraph(node_id: NodeIdType, subgraphs: dict):
    if isinstance(node_id, str):
        return subgraphs.get(node_id)
    subgraph_id, subnode_id = node_id
    try:
        subgraph = subgraphs[subgraph_id]
    except KeyError:
        raise ValueError(node_id, f"{repr(subgraph_id)} is not a subgraph")
    flat_subnode_id = flatten_node_id(subnode_id)
    n = len(flat_subnode_id)
    for node_id in subgraph.graph.nodes:
        flat_node_id = flatten_node_id(node_id)
        nflatid = len(flat_node_id)
        if flat_node_id == flat_subnode_id:
            return None  # a task node
        if nflatid > n and flat_node_id[:n] == flat_subnode_id:
            return subgraph  # a graph node
    raise ValueError(
        f"{subnode_id} is not a node or subgraph of subgraph {repr(subgraph_id)}",
    )


def _alias_to_node_id(alias_attrs: dict) -> NodeIdType:
    sub_node = alias_attrs.get("sub_node", None)
    if sub_node:
        return alias_attrs["node"], sub_node
    else:
        return alias_attrs["node"]


def _resolve_node_aliases(
    node_id: NodeIdType, graph_attrs: dict, input_nodes: bool
) -> Iterator[Tuple[NodeIdType, dict]]:
    if input_nodes:
        aliases = graph_attrs.get("input_nodes", list())
    else:
        aliases = graph_attrs.get("output_nodes", list())
    aliases = [alias_attrs for alias_attrs in aliases if alias_attrs["id"] == node_id]
    if aliases:
        for alias_attrs in aliases:
            sub_node_id = _alias_to_node_id(alias_attrs)
            link_attributes = alias_attrs.get("link_attributes", dict())
            yield sub_node_id, link_attributes
    else:
        yield node_id, dict()


def _resolve_all_node_aliases(
    graph_attrs: dict, input_nodes: bool
) -> Iterator[NodeIdType]:
    if input_nodes:
        aliases = graph_attrs.get("input_nodes", list())
    else:
        aliases = graph_attrs.get("output_nodes", list())
    for alias_attrs in aliases:
        yield _alias_to_node_id(alias_attrs)


def _get_subnode_ids(
    node_id: NodeIdType, link_attrs_subgraph_keys: dict, subgraphs: dict, source: bool
) -> Iterator[Tuple[NodeIdType, Optional[dict]]]:
    if source:
        key = "sub_source"
    else:
        key = "sub_target"
    subgraph = _get_subgraph(node_id, subgraphs)
    if subgraph is None:
        # node_id is not a subgraph
        if key in link_attrs_subgraph_keys:
            raise ValueError(
                f"'{node_id}' is not a graph so 'sub_source' should not be specified"
            )
        yield node_id, None
    else:
        # node_id is a subgraph
        try:
            sub_node_id = link_attrs_subgraph_keys[key]
        except KeyError:
            raise ValueError(
                f"the '{key}' attribute to specify a node in subgraph '{node_id}' is missing"
            ) from None
        for sub_node_id, link_attributes in _resolve_node_aliases(
            sub_node_id, subgraph.graph.graph, input_nodes=not source
        ):
            new_node_id = _append_subnode_id(node_id, sub_node_id)
            yield new_node_id, link_attributes


def _get_subnode_attributes(
    node_id: NodeIdType, subgraphs: dict, graph_node_attrs: dict
) -> Iterator[Tuple[NodeIdType, dict]]:
    """Update all input node attributes of the subgraph with the graph node attributes from the super graph"""
    transfer_attributes = {
        "default_inputs",
        "force_start_node",
        "conditions_else_value",
        "default_error_node",
    }
    node_attrs = {k: v for k, v in graph_node_attrs.items() if k in transfer_attributes}
    if not node_attrs:
        return
    subgraph = _get_subgraph(node_id, subgraphs)
    if subgraph is None:
        # node_id is not a subgraph
        return
    # node_id is a subgraph
    for sub_node_id in _resolve_all_node_aliases(
        subgraph.graph.graph, input_nodes=True
    ):
        new_node_id = _append_subnode_id(node_id, sub_node_id)
        yield new_node_id, node_attrs


def _get_subnode_links(
    source_id: NodeIdType,
    target_id: NodeIdType,
    link_attrs_subgraph_keys: dict,
    subgraphs: dict,
) -> Iterator[Tuple[NodeIdType, NodeIdType, dict, bool]]:
    sources = list(
        _get_subnode_ids(source_id, link_attrs_subgraph_keys, subgraphs, source=True)
    )
    targets = list(
        _get_subnode_ids(target_id, link_attrs_subgraph_keys, subgraphs, source=False)
    )
    for sub_source, source_link_attrs in sources:
        for sub_target, target_link_attrs in targets:
            if source_link_attrs:
                link_attrs = source_link_attrs
            else:
                link_attrs = dict()
            if target_link_attrs:
                link_attrs.update(target_link_attrs)
            target_is_graph = target_link_attrs is not None
            yield sub_source, sub_target, link_attrs, target_is_graph


def _replace_aliases(
    graph: networkx.DiGraph, subgraphs: dict, input_nodes: bool
) -> dict:
    if input_nodes:
        aliases = graph.graph.get("input_nodes")
        if not aliases:
            return
        source = False
        key = "sub_target"
    else:
        aliases = graph.graph.get("output_nodes")
        if not aliases:
            return
        source = True
        key = "sub_source"

    new_aliases = list()
    for alias_attrs in aliases:
        node_id = alias_attrs.get("node")
        if node_id is None:
            continue
        sub_node = alias_attrs.pop("sub_node", None)
        if sub_node:
            node_id = node_id, sub_node
        if not isinstance(node_id, tuple):
            new_aliases.append(alias_attrs)
            continue
        parent, child = node_id
        original_alias_attrs = alias_attrs
        for node_id, link_attrs in _get_subnode_ids(
            parent, {key: child}, subgraphs=subgraphs, source=source
        ):
            alias_attrs = deepcopy(original_alias_attrs)
            if link_attrs:
                link_attrs.update(alias_attrs.get("link_attributes", dict()))
                alias_attrs["link_attributes"] = link_attrs
            alias_attrs["node"] = node_id
            new_aliases.append(alias_attrs)

    if input_nodes:
        graph.graph["input_nodes"] = new_aliases
    else:
        graph.graph["output_nodes"] = new_aliases


[docs] def extract_graph_nodes(graph: networkx.DiGraph, subgraphs) -> Tuple[list, dict]: """Removes all graph nodes from `graph` and returns a list of edges between the nodes from `graph` and `subgraphs`. Nodes in sub-graphs are defines in the `link_attrs_subgraph_keys` link attribute. For example: .. code: python link_attrs = { "source": "subgraph1", "target": "subgraph2", "data_mapping": [{"source_output":"return_value", "target_input": 0}], "link_attrs_subgraph_keys": { "sub_source": ("subsubgraph", ("subsubsubgraph", "task2")), "sub_target": "task1", }, } """ edges = list() update_attrs = dict() graph_is_multi = graph.is_multigraph() for subgraph_id in subgraphs: # Nodes to be updated after flattening the graph for node_id, graph_node_attrs in _get_subnode_attributes( subgraph_id, subgraphs, graph.nodes[subgraph_id] ): if node_id in update_attrs: update_attrs[node_id].update(graph_node_attrs) else: update_attrs[node_id] = graph_node_attrs # Links to be made after flattening the graph it1 = ( (source_id, subgraph_id) for source_id in graph.predecessors(subgraph_id) ) it2 = ((subgraph_id, target_id) for target_id in graph.successors(subgraph_id)) for source_id, target_id in itertools.chain(it1, it2): all_link_attrs = graph[source_id][target_id] if graph_is_multi: all_link_attrs = all_link_attrs.values() else: all_link_attrs = [all_link_attrs] for link_attrs in all_link_attrs: link_attrs_subgraph_keys = { key: link_attrs.pop(key) for key in ["sub_source", "sub_target", "sub_target_attributes"] if key in link_attrs } if not link_attrs_subgraph_keys: continue original_link_attrs = link_attrs links = _get_subnode_links( source_id, target_id, link_attrs_subgraph_keys, subgraphs ) for source, target, default_link_attrs, target_is_graph in links: link_attrs = deepcopy(original_link_attrs) if default_link_attrs: link_attrs = {**default_link_attrs, **link_attrs} sub_target_attributes = link_attrs_subgraph_keys.get( "sub_target_attributes", None ) if sub_target_attributes: if not target_is_graph: raise ValueError( f"'{target_id}' is not a graph so 'sub_target_attributes' should not be specified" ) if target in update_attrs: update_attrs[target].update(sub_target_attributes) else: update_attrs[target] = sub_target_attributes edges.append((source, target, link_attrs)) _replace_aliases(graph, subgraphs, input_nodes=True) _replace_aliases(graph, subgraphs, input_nodes=False) graph.remove_nodes_from(subgraphs.keys()) return edges, update_attrs