Source code for ewokscore.graph.graph_io

from typing import Dict
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union

import networkx

from .. import missing_data
from ..node import NodeIdType
from ..node import get_node_label
from ..task import Task
from .analysis import end_nodes
from .analysis import start_nodes


[docs] def update_default_inputs( graph: networkx.DiGraph, inputs: Optional[List[dict]] = None ) -> None: """Input items have the following keys: - name: input variable name - value: input variable value - id (optional): node id - label (optional): used when `id` is missing - task_identifier (optional): used when `id` is missing - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: start nodes) """ inputs = parse_inputs(graph, inputs) keys_to_update = "name", "value" for input_item in inputs: node_id = input_item.get("id") if node_id is None: continue node_attrs = graph.nodes[node_id] default_inputs = node_attrs.get("default_inputs") input_item = {k: input_item[k] for k in keys_to_update} if default_inputs: for existing_input_item in default_inputs: if existing_input_item["name"] == input_item["name"]: existing_input_item.update(input_item) break else: default_inputs.append(input_item) else: node_attrs["default_inputs"] = [input_item]
[docs] def parse_inputs( graph: networkx.DiGraph, inputs: Optional[List[dict]] = None ) -> List[dict]: """Input items have the following keys: - name: input variable name - value: input variable value - id (optional): node id - label (optional): used when `id` is missing - task_identifier (optional): used when `id` is missing - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: start nodes) """ if not inputs: return list() required = {"name", "value"} returned = {"id", "name", "value"} parsed = list() for input_item in list(inputs): missing = required - input_item.keys() if missing: raise ValueError(f"missing keys in one of the graph inputs: {missing}") if "id" in input_item: parsed.append({k: v for k, v in input_item.items() if k in returned}) continue node_filters = dict() for k in ("label", "task_identifier"): if k in input_item: node_filters[k] = input_item[k] if node_filters: node_ids = iter_node_ids(graph, **node_filters) elif input_item.get("all"): node_ids = graph.nodes else: node_ids = start_nodes(graph) for node_id in node_ids: input_item = {k: v for k, v in input_item.items() if k in returned} input_item["id"] = node_id parsed.append(input_item) return parsed
[docs] def parse_outputs( graph: networkx.DiGraph, outputs: Optional[List[dict]] = None ) -> List[dict]: """Output items have the following keys: - name (optional): output variable name (all outputs when missing) - new_name (optional): optional renaming when `name` is defined - id (optional): node id - label (optional): used when `id` is missing - task_identifier (optional): used when `id` is missing - all (optional): used when `id`, `label` and `task_identifier` are missing (`True`: all nodes, `False`: end nodes) """ if outputs is None: outputs = [{"all": False}] parsed = list() returned = {"id", "name", "new_name"} for output_item in outputs: if "id" in output_item: parsed.append({k: v for k, v in output_item.items() if k in returned}) continue node_filters = dict() for k in ("label", "task_identifier"): if k in output_item: node_filters[k] = output_item[k] if node_filters: node_ids = iter_node_ids(graph, **node_filters) elif output_item.get("all"): node_ids = graph.nodes else: node_ids = end_nodes(graph) for node_id in node_ids: output_item = {k: v for k, v in output_item.items() if k in returned} output_item["id"] = node_id parsed.append(output_item) return parsed
[docs] def iter_node_ids( graph: networkx.DiGraph, label: Optional[str] = None, task_identifier: Optional[str] = None, ) -> Iterator[NodeIdType]: """Yield nodes with matching `label` AND `task_identifier`""" for node_id, node_attrs in graph.nodes.items(): return_id = False if label is not None: node_label = get_node_label(node_id, node_attrs) if label != node_label: continue return_id = True if task_identifier is not None: s = node_attrs.get("task_identifier") if not s or not s.endswith(task_identifier): continue return_id = True if return_id: yield node_id
[docs] def extract_output_values( node_id: NodeIdType, task_or_outputs: Union[Task, Mapping], outputs: List[dict] ) -> Optional[dict]: """Output items have the following keys: - id: node id - label (optional): used when `id` is missing - name (optional): output variable name (all outputs when missing) - new_name (optional): optional renaming when name is defined """ output_values = None if isinstance(task_or_outputs, Task): task_output_values = None else: task_output_values = task_or_outputs for output_item in outputs: if output_item.get("id") != node_id: continue if task_output_values is None: task_output_values = task_or_outputs.get_output_values() if output_values is None: output_values = dict() name = output_item.get("name") if name: new_name = output_item.get("new_name", name) output_values[new_name] = task_output_values.get( name, missing_data.MISSING_DATA ) else: output_values.update(task_output_values) return output_values
[docs] def add_output_values( output_values: dict, node_id: NodeIdType, task_or_outputs: Union[Task, Dict], outputs: List[dict], merge_outputs: Optional[bool] = True, ) -> None: """Output items have the following keys: - id: node id - label (optional): used when `id` is missing - name (optional): output variable name (all outputs when missing) - new_name (optional): optional renaming when name is defined """ task_output_values = extract_output_values(node_id, task_or_outputs, outputs) if task_output_values is not None: if merge_outputs: output_values.update(task_output_values) else: output_values[node_id] = task_output_values