Source code for ewokscore.tests.test_task

import gc
import json
from glob import glob
from pathlib import Path

import pytest

from ..task import Task
from ..task import TaskInputError
from .examples.tasks.sumtask import SumTask


[docs] def find_files(tmp_path: Path, extension): return glob(str(tmp_path / "**" / f"*{extension}"), recursive=True)
[docs] def expected_task_output_storage(task): expected = [var.serialize() for var in task.output_variables.values()] expected.append(task.output_variables.serialize()) return expected
[docs] def assert_storage(tmp_path, expected): lst = [] for filename in find_files(tmp_path, ".json"): with open(filename, "r") as fileobj: lst.append(json.load(fileobj)) for v in lst: if isinstance(v, dict): v.pop("__traceback__", None) assert len(lst) == len(expected) for v in expected: lst.pop(lst.index(v)) assert not lst, "Unexpected data saved"
[docs] def test_no_public_reserved_names(): assert not [s for s in Task._reserved_variable_names() if not s.startswith("_")]
[docs] def test_task_missing_input(): with pytest.raises(TaskInputError): SumTask()
[docs] def test_task_readonly_input(): task = SumTask(inputs={"a": 10}) with pytest.raises(RuntimeError): task.inputs.a = 10
[docs] def test_task_optional_input(tmp_path, varinfo): task = SumTask(inputs={"a": 10}, varinfo=varinfo) assert not task.done task.execute() assert task.done assert task.outputs.result == 10 expected = expected_task_output_storage(task) assert_storage(tmp_path, expected)
[docs] def test_task_done(varinfo): task = SumTask(inputs={"a": 10}, varinfo=varinfo) assert not task.done task.execute() assert task.done task = SumTask(inputs={"a": 10}, varinfo=varinfo) assert task.done task = SumTask(inputs={"a": 10}) assert not task.done task.execute() assert task.done task = SumTask(inputs={"a": 10}) assert not task.done
[docs] def test_task_uhash(varinfo): task = SumTask(inputs={"a": 10}, varinfo=varinfo) uhash = task.uhash assert task.uhash == task.output_variables.uhash assert task.uhash != task.input_variables.uhash task.input_variables["a"].value += 1 assert task.uhash != uhash assert task.uhash == task.output_variables.uhash assert task.uhash != task.input_variables.uhash
[docs] def test_task_storage(tmp_path, varinfo): task = SumTask(inputs={"a": 10, "b": 2}, varinfo=varinfo) assert not task.done task.execute() assert task.done assert task.outputs.result == 12 expected = expected_task_output_storage(task) assert_storage(tmp_path, expected) task = SumTask(inputs={"a": 10, "b": 2}, varinfo=varinfo) assert task.done assert task.outputs.result == 12 assert_storage(tmp_path, expected) task = SumTask({"a": 2, "b": 10}, varinfo=varinfo) assert not task.done task.execute() assert task.done assert task.outputs.result == 12 expected += expected_task_output_storage(task) assert_storage(tmp_path, expected) task = SumTask({"a": task.output_variables["result"], "b": 0}, varinfo=varinfo) assert not task.done task.execute() assert task.done assert task.outputs.result == 12 expected += expected_task_output_storage(task) assert_storage(tmp_path, expected) task = SumTask( {"a": 1, "b": task.output_variables["result"].data_proxy}, varinfo=varinfo ) assert not task.done task.execute() assert task.done assert task.outputs.result == 13 expected += expected_task_output_storage(task) assert_storage(tmp_path, expected) task = SumTask( {"a": 1, "b": task.output_variables["result"].data_proxy.uri}, varinfo=varinfo ) assert not task.done task.execute() assert task.done assert task.outputs.result == 14 expected += expected_task_output_storage(task) assert_storage(tmp_path, expected)
[docs] def test_task_required_positional_inputs(): class MyTask(Task, n_required_positional_inputs=1): pass with pytest.raises(TaskInputError): MyTask()
[docs] def test_task_cleanup_references(): class MyTask(Task, input_names=["mylist"], output_names=["mylist"]): def run(self): self.outputs.mylist = self.inputs.mylist + [len(self.inputs.mylist)] obj = [0, 1, 2] nref_start = len(gc.get_referrers(obj)) task1 = MyTask(inputs={"mylist": obj}) task2 = MyTask(inputs=task1.output_variables) task1.execute() task2.execute() assert len(gc.get_referrers(obj)) > nref_start uhash1 = task1.uhash uhashes1 = task1.get_output_uhashes() uhash2 = task2.uhash uhashes2 = task2.get_output_uhashes() task1.cleanup_references() while gc.collect(): pass assert len(gc.get_referrers(obj)) == nref_start assert uhash1 == task1.uhash assert uhashes1 == task1.get_output_uhashes() assert uhash2 == task2.uhash assert uhashes2 == task2.get_output_uhashes()
[docs] def test_task_cancel(varinfo): task = SumTask(inputs={"a": 10}, varinfo=varinfo) with pytest.raises(NotImplementedError): task.cancel()