Source code for ewokscore.graph.graph_io

import logging
from typing import Dict
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Union

import networkx

from .. import missing_data
from ..node import NodeIdType
from ..node import get_node_label
from ..task import Task
from ..utils import build_close_match_report
from .analysis import end_nodes
from .analysis import start_nodes

logger = logging.getLogger(__name__)


def update_default_inputs(
    graph: networkx.DiGraph, inputs: Optional[List[dict]] = None
) -> None:
    """Merge input items into the `default_inputs` attribute of graph nodes.

    The items are first resolved to node ids with `parse_inputs`. For every
    resolved item, a default input with the same name that already exists on
    the node is updated in place; otherwise the item is added to the node's
    default inputs.

    Input items have the following keys:

    - name: input variable name
    - value: input variable value
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are
      missing (`True`: all nodes, `False`: start nodes)
    """
    for parsed_item in parse_inputs(graph, inputs):
        target_id = parsed_item.get("id")
        if target_id is None:
            continue

        attrs = graph.nodes[target_id]
        current_defaults = attrs.get("default_inputs")
        # Only the name and the value are stored on the node.
        new_default = {"name": parsed_item["name"], "value": parsed_item["value"]}

        if not current_defaults:
            # Missing or empty: start a fresh list on the node.
            attrs["default_inputs"] = [new_default]
            continue

        same_name = next(
            (
                existing
                for existing in current_defaults
                if existing["name"] == new_default["name"]
            ),
            None,
        )
        if same_name is None:
            current_defaults.append(new_default)
        else:
            same_name.update(new_default)
def parse_inputs(
    graph: networkx.DiGraph, inputs: Optional[List[dict]] = None
) -> List[dict]:
    """Resolve input items to items that each target one explicit node id.

    Input items have the following keys:

    - name: input variable name
    - value: input variable value
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are
      missing (`True`: all nodes, `False`: start nodes)

    :param graph: graph whose nodes are matched against the input items
    :param inputs: input items to resolve; `None` or empty yields an empty list
    :returns: a list of new dictionaries restricted to the keys `id`, `name`
        and `value`. An item without `id` is expanded to one dictionary per
        matching node. An item that matches no node contributes nothing and
        a warning is logged.
    :raises ValueError: when an input item lacks `name` or `value`
    """
    if not inputs:
        return []

    required = {"name", "value"}
    returned = {"id", "name", "value"}
    parsed = []
    for input_item in inputs:
        missing = required - input_item.keys()
        if missing:
            raise ValueError(
                f"Missing required keys {missing} in input item: {input_item}."
                + build_close_match_report(missing, input_item.keys())
            )

        selected = {k: v for k, v in input_item.items() if k in returned}
        if "id" in input_item:
            parsed.append(selected)
            continue

        node_filters = {
            k: input_item[k] for k in ("label", "task_identifier") if k in input_item
        }
        if node_filters:
            # `iter_node_ids` is a generator and a generator object is always
            # truthy: materialize it so the emptiness check below can detect
            # that no node matched.
            node_ids = list(iter_node_ids(graph, **node_filters))
        elif input_item.get("all"):
            node_ids = graph.nodes
        else:
            node_ids = start_nodes(graph)

        if not node_ids:
            error_report = (
                f"\n The input item {input_item} could not be associated to any node."
                + build_close_match_report(
                    input_item.keys(), ["id", "label", "task_identifier"]
                )
            )
            logger.warning(error_report)

        for node_id in node_ids:
            parsed.append({**selected, "id": node_id})
    return parsed
def parse_outputs(
    graph: networkx.DiGraph, outputs: Optional[List[dict]] = None
) -> List[dict]:
    """Resolve output items to items that each target one explicit node id.

    Output items have the following keys:

    - name (optional): output variable name (all outputs when missing)
    - new_name (optional): optional renaming when `name` is defined
    - id (optional): node id
    - label (optional): used when `id` is missing
    - task_identifier (optional): used when `id` is missing
    - all (optional): used when `id`, `label` and `task_identifier` are
      missing (`True`: all nodes, `False`: end nodes)

    :param graph: graph whose nodes are matched against the output items
    :param outputs: output items to resolve; `None` selects the end nodes
    :returns: a list of new dictionaries restricted to the keys `id`, `name`
        and `new_name`. An item without `id` is expanded to one dictionary per
        matching node. An item that matches no node contributes nothing and
        a warning is logged.
    """
    if outputs is None:
        outputs = [{"all": False}]

    returned = {"id", "name", "new_name"}
    parsed = []
    for output_item in outputs:
        selected = {k: v for k, v in output_item.items() if k in returned}
        if "id" in output_item:
            parsed.append(selected)
            continue

        node_filters = {
            k: output_item[k]
            for k in ("label", "task_identifier")
            if k in output_item
        }
        if node_filters:
            # `iter_node_ids` is a generator and a generator object is always
            # truthy: materialize it so the emptiness check below can detect
            # that no node matched.
            node_ids = list(iter_node_ids(graph, **node_filters))
        elif output_item.get("all"):
            node_ids = graph.nodes
        else:
            node_ids = end_nodes(graph)

        if not node_ids:
            error_report = (
                f"\n The output item {output_item} could not be associated to any node."
                + build_close_match_report(
                    output_item.keys(), ["id", "label", "task_identifier"]
                )
            )
            logger.warning(error_report)

        for node_id in node_ids:
            parsed.append({**selected, "id": node_id})
    return parsed
def iter_node_ids(
    graph: networkx.DiGraph,
    label: Optional[str] = None,
    task_identifier: Optional[str] = None,
) -> Iterator[NodeIdType]:
    """Yield the ids of nodes matching `label` AND `task_identifier`.

    A filter that is `None` is not applied. When both filters are `None`
    nothing is yielded. The task identifier of a node matches when it ends
    with `task_identifier`.
    """
    filter_on_label = label is not None
    filter_on_task = task_identifier is not None
    for node_id, node_attrs in graph.nodes.items():
        if filter_on_label and label != get_node_label(node_id, node_attrs):
            continue
        if filter_on_task:
            node_task = node_attrs.get("task_identifier")
            if not (node_task and node_task.endswith(task_identifier)):
                continue
        # Without any active filter no node is selected.
        if filter_on_label or filter_on_task:
            yield node_id
def extract_output_values(
    node_id: NodeIdType, task_or_outputs: Union[Task, Mapping], outputs: List[dict]
) -> Optional[dict]:
    """Collect the output values selected for one node.

    Only output items whose `id` equals `node_id` are considered. The keys
    read from a matching item are:

    - id: node id
    - name (optional): output variable name (all outputs when missing)
    - new_name (optional): optional renaming when name is defined

    When `task_or_outputs` is a `Task`, its output values are obtained with
    `get_output_values()` once a matching item is encountered; otherwise
    `task_or_outputs` itself is used as the mapping of output values. A named
    output that is absent from the mapping is reported as
    `missing_data.MISSING_DATA`.

    :returns: the selected values, or `None` when no output item refers to
        `node_id`
    """
    available = None if isinstance(task_or_outputs, Task) else task_or_outputs
    selected = None
    for item in outputs:
        if item.get("id") != node_id:
            continue
        if available is None:
            available = task_or_outputs.get_output_values()
        if selected is None:
            selected = {}

        source_name = item.get("name")
        if not source_name:
            # No specific variable requested: take every output value.
            selected.update(available)
            continue

        target_name = item.get("new_name", source_name)
        selected[target_name] = available.get(source_name, missing_data.MISSING_DATA)
    return selected
def add_output_values(
    output_values: dict,
    node_id: NodeIdType,
    task_or_outputs: Union[Task, Dict],
    outputs: List[dict],
    merge_outputs: Optional[bool] = True,
) -> None:
    """Add the selected output values of one node to `output_values` in place.

    The values are selected with `extract_output_values`; nothing is added
    when it returns `None`. With a truthy `merge_outputs` the values are
    merged directly into `output_values`, otherwise they are stored under the
    key `node_id`.
    """
    node_values = extract_output_values(node_id, task_or_outputs, outputs)
    if node_values is None:
        return
    if merge_outputs:
        output_values.update(node_values)
        return
    output_values[node_id] = node_values