import logging
import threading
import time
from pprint import pformat
from typing import Dict
from typing import List
from typing import Optional
import pytest
from ewoksutils.event_utils import FIELD_TYPES
from ewoksutils.import_utils import qualname
from ewoksutils.sqlite3_utils import connect
from ewoksutils.sqlite3_utils import select
from ..bindings import execute_graph
from ..events import cleanup as cleanup_events
from ..task import Task
# Module logger; _fetch_events reports events that could not be fetched through it.
logger = logging.getLogger(__name__)
[docs]
@pytest.fixture
def sqlite_path(tmp_path):
    """Temporary directory in which the tests create their SQLite event databases.

    ``cleanup_events`` (imported from ``..events``) is called on teardown, also
    when the test body raised.
    """
    directory = tmp_path
    try:
        yield directory
    finally:
        # NOTE(review): presumably releases the event handlers created by the
        # workflows of the test so they do not leak into the next test.
        cleanup_events()
[docs]
def test_succesfull_workfow(sqlite_path):
    """A workflow whose tasks all succeed records the 10 expected events."""
    db_file = sqlite_path / "ewoks_events.db"
    _run_succesfull_workfow(db_file, execute_graph)
    recorded = _fetch_events(db_file, 10)
    _assert_succesfull_workfow_events(recorded)
[docs]
def test_failed_workfow(sqlite_path):
    """A workflow with a failing node records 8 events, including the error messages."""
    db_file = sqlite_path / "ewoks_events.db"
    _run_failed_workfow(db_file, execute_graph)
    recorded = _fetch_events(db_file, 8)
    _assert_failed_workfow_events(recorded)
[docs]
def test_changing_handlers(sqlite_path):
    """Executing a second workflow with another database must not touch the first one.

    Both runs use the successful three-node workflow, so both databases are checked
    with the matching successful-workflow assertion.
    """
    database1 = sqlite_path / "ewoks_events1.db"
    _run_succesfull_workfow(database1, execute_graph)
    events = _fetch_events(database1, 10)
    _assert_succesfull_workfow_events(events)
    size_before = database1.stat().st_size

    database2 = sqlite_path / "ewoks_events2.db"
    _run_succesfull_workfow(database2, execute_graph)
    events = _fetch_events(database2, 10)
    _assert_succesfull_workfow_events(events)

    size_after = database1.stat().st_size
    assert size_before == size_after
    # The size check alone is weak: SQLite allocates the file in whole pages, so a
    # few stray rows could be written to database1 without changing its size.
    assert len(_fetch_events(database1, 10)) == 10
[docs]
def test_changing_handlers_parallel(sqlite_path, n_concurrent=4):
    """Run ``n_concurrent`` sleep workflows in threads, each with its own database.

    Every database must contain the 10 events of one workflow, and all database
    files must end up with the same size.
    """
    databases = []
    threads = []
    for index in range(n_concurrent):
        db_file = sqlite_path / f"ewoks_events_{index}.db"
        databases.append(db_file)
        threads.append(
            threading.Thread(target=_run_sleep_workfow, args=(db_file, execute_graph))
        )
    _run_threads(threads)

    distinct_sizes = set()
    for db_file in databases:
        _assert_sleep_workfow_events(_fetch_events(db_file, 10))
        distinct_sizes.add(db_file.stat().st_size)
    assert len(distinct_sizes) == 1
def _run_threads(threads, timeout=20):
    """Start all ``threads`` and wait until they have finished.

    :param threads: ``threading.Thread`` objects that have not been started yet.
    :param timeout: maximal total time in seconds to wait for all threads together.
    :raises: fails the current pytest test (``pytest.fail``) when a thread is still
        alive once ``timeout`` seconds have elapsed.
    """
    for thread in threads:
        thread.start()
    # One shared deadline bounds the total wait (not `timeout` per thread).
    # monotonic() is immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    for thread in threads:
        # Clamp to 0 so an already-finished thread never causes a spurious failure.
        thread.join(timeout=max(0.0, deadline - time.monotonic()))
        if thread.is_alive():
            pytest.fail(
                f"Workflow thread {thread.name} did not finish within {timeout}s"
            )
class _MyTask(
    Task, input_names=["ctr"], optional_input_names=["error_msg"], output_names=["ctr"]
):
    """Output ``ctr + 1``, or raise ``ValueError(error_msg)`` when ``error_msg`` is truthy."""

    def run(self):
        error_msg = self.inputs.error_msg
        if error_msg:
            raise ValueError(error_msg)
        self.outputs.ctr = self.inputs.ctr + 1
class _MySleepTask(
    Task, input_names=["ctr"], optional_input_names=["error_msg"], output_names=["ctr"]
):
    """Wait 0.2 s, then output ``ctr + 1``.

    ``error_msg`` is declared as an optional input but ``run`` does not use it.
    """

    def run(self):
        # NOTE(review): the delay presumably makes the concurrent workflows of
        # test_changing_handlers_parallel overlap in time.
        time.sleep(0.2)
        counter = self.inputs.ctr
        self.outputs.ctr = counter + 1
def _run_succesfull_workfow(database, execute_graph, **execute_options):
    """Execute a linear three-node ``_MyTask`` workflow in which every task succeeds.

    The nodes are chained node1 -> node2 -> node3, forwarding ``ctr``. Events are
    recorded in the SQLite ``database`` by ``_execute_graph``.
    """
    node_ids = ["node1", "node2", "node3"]
    task_identifier = qualname(_MyTask)
    nodes = [
        {
            "id": node_id,
            "task_type": "class",
            "task_identifier": task_identifier,
            "default_inputs": [{"name": "ctr", "value": 0}],
        }
        for node_id in node_ids
    ]
    links = [
        {
            "source": source,
            "target": target,
            "data_mapping": [{"source_output": "ctr", "target_input": "ctr"}],
        }
        for source, target in zip(node_ids[:-1], node_ids[1:])
    ]
    taskgraph = {
        "graph": {"id": "test_graph", "schema_version": "1.1"},
        "nodes": nodes,
        "links": links,
    }
    _execute_graph(database, taskgraph, execute_graph, **execute_options)
def _assert_succesfull_workfow_events(events):
    """Check that ``events`` are the events of a fully successful three-node run.

    Only the ``context``, ``node_id`` and ``type`` fields of each event are compared.
    """
    keys = ("context", "node_id", "type")
    expected_rows = [
        ("job", None, "start"),
        ("workflow", None, "start"),
        ("node", "node1", "start"),
        ("node", "node1", "end"),
        ("node", "node2", "start"),
        ("node", "node2", "end"),
        ("node", "node3", "start"),
        ("node", "node3", "end"),
        ("workflow", None, "end"),
        ("job", None, "end"),
    ]
    expected = [dict(zip(keys, row)) for row in expected_rows]
    captured = [{key: event[key] for key in keys} for event in events]
    _assert_events(expected, captured)
def _run_failed_workfow(database, execute_graph, **execute_options):
    """Execute a linear three-node ``_MyTask`` workflow whose second node fails.

    node2 receives ``error_msg="abc"``, which makes ``_MyTask.run`` raise
    ``ValueError``. Events are recorded in the SQLite ``database``.
    """
    task_identifier = qualname(_MyTask)

    def make_node(node_id, *extra_inputs):
        return {
            "id": node_id,
            "task_type": "class",
            "task_identifier": task_identifier,
            "default_inputs": [{"name": "ctr", "value": 0}, *extra_inputs],
        }

    def make_link(source, target):
        return {
            "source": source,
            "target": target,
            "data_mapping": [{"source_output": "ctr", "target_input": "ctr"}],
        }

    taskgraph = {
        "graph": {"id": "test_graph", "schema_version": "1.1"},
        "nodes": [
            make_node("node1"),
            make_node("node2", {"name": "error_msg", "value": "abc"}),
            make_node("node3"),
        ],
        "links": [make_link("node1", "node2"), make_link("node2", "node3")],
    }
    _execute_graph(database, taskgraph, execute_graph, **execute_options)
def _assert_failed_workfow_events(events):
    """Check that ``events`` are the events of ``_run_failed_workfow``.

    node2's end event carries the task's own error message (``"abc"``). The workflow
    and job end events carry the wrapped ewoks message, and no node3 event is
    expected. The fields ``context``, ``node_id``, ``type`` and ``error_message``
    are compared.
    """
    # Derive the task identifier exactly as the workflow does, instead of
    # hard-coding this module's import path.
    err_msg = (
        "Execution failed for ewoks task 'node2' "
        f"(id: 'node2', task: '{qualname(_MyTask)}'): abc"
    )
    keys = ("context", "node_id", "type", "error_message")
    expected_rows = [
        ("job", None, "start", None),
        ("workflow", None, "start", None),
        ("node", "node1", "start", None),
        ("node", "node1", "end", None),
        ("node", "node2", "start", None),
        ("node", "node2", "end", "abc"),
        ("workflow", None, "end", err_msg),
        ("job", None, "end", err_msg),
    ]
    expected = [dict(zip(keys, row)) for row in expected_rows]
    captured = [{key: event[key] for key in keys} for event in events]
    _assert_events(expected, captured)
def _run_sleep_workfow(database, execute_graph, **execute_options):
    """Execute a linear three-node workflow of ``_MySleepTask`` nodes.

    The nodes are chained node1 -> node2 -> node3, forwarding ``ctr``. Events are
    recorded in the SQLite ``database`` by ``_execute_graph``.
    """
    task_identifier = qualname(_MySleepTask)
    nodes = []
    links = []
    previous_id = None
    for index in (1, 2, 3):
        node_id = f"node{index}"
        nodes.append(
            {
                "id": node_id,
                "task_type": "class",
                "task_identifier": task_identifier,
                "default_inputs": [{"name": "ctr", "value": 0}],
            }
        )
        if previous_id is not None:
            links.append(
                {
                    "source": previous_id,
                    "target": node_id,
                    "data_mapping": [{"source_output": "ctr", "target_input": "ctr"}],
                }
            )
        previous_id = node_id
    taskgraph = {
        "graph": {"id": "test_graph", "schema_version": "1.1"},
        "nodes": nodes,
        "links": links,
    }
    _execute_graph(database, taskgraph, execute_graph, **execute_options)
def _assert_sleep_workfow_events(events):
    """Check that ``events`` are the events of a successful three-node sleep run.

    Only the ``context``, ``node_id`` and ``type`` fields of each event are compared.
    """
    fields = ("context", "node_id", "type")
    node_events = [
        {"context": "node", "node_id": node_id, "type": event_type}
        for node_id in ("node1", "node2", "node3")
        for event_type in ("start", "end")
    ]
    expected = (
        [
            {"context": "job", "node_id": None, "type": "start"},
            {"context": "workflow", "node_id": None, "type": "start"},
        ]
        + node_events
        + [
            {"context": "workflow", "node_id": None, "type": "end"},
            {"context": "job", "node_id": None, "type": "end"},
        ]
    )
    captured = []
    for event in events:
        captured.append({field: event[field] for field in fields})
    _assert_events(expected, captured)
def _execute_graph(database, graph, execute_graph, **execute_options):
    """Execute ``graph`` while recording its ewoks events in an SQLite database.

    An ``Sqlite3EwoksEventHandler`` writing to ``database`` is appended to the
    handlers in ``execinfo``. A caller-supplied ``execinfo`` dict and its handler
    list are copied, not modified in place, so they can be reused safely.

    :param database: SQLite database file that receives the events.
    :param graph: workflow description passed to ``execute_graph``.
    :param execute_graph: the graph execution function under test.
    :param execute_options: extra keyword arguments for ``execute_graph``.
    """
    execinfo = dict(execute_options.get("execinfo") or {})
    handlers = list(execinfo.get("handlers") or [])
    handlers.append(
        {
            "class": "ewokscore.events.handlers.Sqlite3EwoksEventHandler",
            "arguments": [{"name": "uri", "value": database}],
        }
    )
    execinfo["handlers"] = handlers
    execute_options["execinfo"] = execinfo
    try:
        execute_graph(graph, **execute_options)
    except RuntimeError:
        # Deliberate: workflows with a failing task are expected to raise here
        # (presumably RuntimeError). The tests verify the failure through the
        # recorded events instead.
        pass
def _assert_events(expected, captured):
    """Assert that ``captured`` holds exactly the ``expected`` events, in any order.

    The lists are compared as multisets. Each expected event must match a distinct
    captured event, and no captured event may be left over.

    :raises AssertionError: listing the missing and the unexpected events.
    """
    missing = []
    unexpected = list(captured)  # copy: matched events are removed from it below
    for event in expected:
        try:
            unexpected.remove(event)
        except ValueError:
            missing.append(event)
    if missing or unexpected:
        # Pretty-print both lists so large event dicts stay readable.
        raise AssertionError(
            "ewoks events not as expected\n"
            f"missing:\n{pformat(missing)}\n"
            f"unexpected:\n{pformat(unexpected)}"
        )
def _fetch_events(
    database: str, nevents: int, timeout: float = 3.0, interval: float = 0.1
) -> List[Dict[str, Optional[str]]]:
    """Read all rows of the ``ewoks_events`` table of an SQLite database.

    Events are handled asynchronously, so the database is polled every ``interval``
    seconds until it holds exactly ``nevents`` events or ``timeout`` seconds have
    passed.

    :param database: path of the SQLite database file. The tests in this module
        pass ``pathlib.Path`` objects.
    :param nevents: number of events to wait for.
    :param timeout: maximal polling time in seconds.
    :param interval: delay in seconds between two read attempts.
    :returns: the events from the last successful read. After a timeout this may
        hold fewer or more than ``nevents`` events, and the caller's assertions
        then report the difference. The reason for the timeout is logged.
    """
    deadline = time.monotonic() + timeout
    events = []
    last_error = None
    while True:
        try:
            with connect(database) as conn:
                events = list(select(conn, "ewoks_events", field_types=FIELD_TYPES))
        except Exception as e:
            # Broad on purpose: the database or table may not exist yet while the
            # asynchronous handler is starting up.
            last_error = e
        else:
            if len(events) == nevents:
                return events
            last_error = RuntimeError(f"{len(events)} ewoks events instead of {nevents}")
        if time.monotonic() >= deadline:
            break
        time.sleep(interval)
    logger.error("Fetching ewoks events from %s failed: %s", database, last_error)
    return events