Source code for ewokscore.tests.test_task

import gc
import json
import re
import warnings
from glob import glob
from pathlib import Path

import pytest
from ewoksutils.exceptions import TaskInputError
from ewoksutils.exceptions import TaskInputWarning

from ..task import Task
from .examples.tasks.sumtask import SumTask


def find_files(tmp_path: Path, extension):
    """Return the paths (as strings) of all files below ``tmp_path`` whose
    name ends with ``extension``, searching sub-directories recursively.
    """
    # "**" combined with recursive=True matches zero or more directory levels.
    recursive_pattern = tmp_path / "**" / f"*{extension}"
    return glob(str(recursive_pattern), recursive=True)
def expected_task_output_storage(task):
    """Build the list of serialized data expected for the outputs of ``task``.

    The list holds the serialization of every individual output variable,
    followed by the serialization of the output variable container itself.
    """
    outputs = task.output_variables
    return [
        *(variable.serialize() for variable in outputs.values()),
        outputs.serialize(),
    ]
def assert_storage(tmp_path, expected):
    """Assert that the JSON files below ``tmp_path`` hold exactly ``expected``.

    The ``"data"`` entry of every ``.json`` file is compared with the items
    of ``expected``, ignoring order. A ``"__traceback__"`` key in a stored
    dictionary is discarded before comparing.
    """
    stored = []
    for json_path in find_files(tmp_path, ".json"):
        with open(json_path, "r") as handle:
            payload = json.load(handle)["data"]
        if isinstance(payload, dict):
            payload.pop("__traceback__", None)
        stored.append(payload)

    assert len(stored) == len(expected)
    # Remove each expected item once; a missing item raises ValueError.
    for item in expected:
        del stored[stored.index(item)]
    assert not stored, "Unexpected data saved"
def test_no_public_reserved_names():
    """All reserved variable names of ``Task`` must start with an underscore."""
    public_names = [
        name for name in Task._reserved_variable_names() if not name.startswith("_")
    ]
    assert not public_names
def test_task_missing_input():
    """Instantiating ``SumTask`` without any inputs raises ``TaskInputError``."""
    with pytest.raises(TaskInputError):
        SumTask()
def test_task_unexpected_input_warning():
    """An unknown input name triggers a ``TaskInputWarning`` but the input is
    still available in the task's input variables.
    """
    warning_pattern = re.escape("Unexpected inputs for task SumTask: ['unknown']")
    with pytest.warns(TaskInputWarning, match=warning_pattern):
        task = SumTask(inputs={"a": 10, "unknown": 1})

    assert "unknown" in task.input_variables
def test_task_readonly_input():
    """Assigning to an attribute of ``task.inputs`` raises ``RuntimeError``."""
    sum_task = SumTask(inputs={"a": 10})
    with pytest.raises(RuntimeError):
        sum_task.inputs.a = 10
def test_task_optional_input(tmp_path, varinfo):
    """``SumTask`` runs with only input ``a`` and stores its outputs."""
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    assert not sum_task.done

    sum_task.execute()
    assert sum_task.done
    assert sum_task.outputs.result == 10

    # The files below tmp_path must match the serialized task outputs.
    assert_storage(tmp_path, expected_task_output_storage(sum_task))
def test_task_done(varinfo):
    """Check the ``done`` state of tasks created with and without ``varinfo``."""
    # With varinfo: done after execution ...
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    assert not sum_task.done
    sum_task.execute()
    assert sum_task.done

    # ... and a new task with the same inputs and varinfo is done right away.
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    assert sum_task.done

    # Without varinfo: done after execution ...
    sum_task = SumTask(inputs={"a": 10})
    assert not sum_task.done
    sum_task.execute()
    assert sum_task.done

    # ... but a new task with the same inputs starts out as not done.
    sum_task = SumTask(inputs={"a": 10})
    assert not sum_task.done
def test_task_uhash(varinfo):
    """The task uhash equals the uhash of its output variables, differs from
    the uhash of its input variables and changes when an input value changes.
    """
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    initial_uhash = sum_task.uhash
    assert sum_task.uhash == sum_task.output_variables.uhash
    assert sum_task.uhash != sum_task.input_variables.uhash

    # Modify an input value in place.
    sum_task.input_variables["a"].value += 1

    assert sum_task.uhash != initial_uhash
    assert sum_task.uhash == sum_task.output_variables.uhash
    assert sum_task.uhash != sum_task.input_variables.uhash
def test_task_storage(tmp_path, varinfo):
    """Execute a chain of ``SumTask`` instances and check, after every step,
    that the JSON files below ``tmp_path`` match the accumulated outputs.
    """
    stored = []  # serialized outputs expected on disk so far

    def execute_and_verify(pending, total):
        """Run a not-yet-done task, check its result and the storage."""
        assert not pending.done
        pending.execute()
        assert pending.done
        assert pending.outputs.result == total
        stored.extend(expected_task_output_storage(pending))
        assert_storage(tmp_path, stored)

    # Plain values as inputs.
    task = SumTask(inputs={"a": 10, "b": 2}, varinfo=varinfo)
    execute_and_verify(task, 12)

    # Same inputs again: already done without executing, storage unchanged.
    task = SumTask(inputs={"a": 10, "b": 2}, varinfo=varinfo)
    assert task.done
    assert task.outputs.result == 12
    assert_storage(tmp_path, stored)

    # Swapped inputs: same sum but not done yet, so new data gets stored.
    task = SumTask({"a": 2, "b": 10}, varinfo=varinfo)
    execute_and_verify(task, 12)

    # Output variable of the previous task as input.
    previous_result = task.output_variables["result"]
    task = SumTask({"a": previous_result, "b": 0}, varinfo=varinfo)
    execute_and_verify(task, 12)

    # Data proxy of the previous output as input.
    previous_proxy = task.output_variables["result"].data_proxy
    task = SumTask({"a": 1, "b": previous_proxy}, varinfo=varinfo)
    execute_and_verify(task, 13)

    # URI of the previous output's data proxy as input.
    previous_uri = task.output_variables["result"].data_proxy.uri
    task = SumTask({"a": 1, "b": previous_uri}, varinfo=varinfo)
    execute_and_verify(task, 14)
def test_task_required_positional_inputs():
    """A task declaring one required positional input raises
    ``TaskInputError`` when instantiated without inputs.
    """

    class MyTask(Task, n_required_positional_inputs=1):
        pass

    with pytest.raises(TaskInputError):
        MyTask()
def test_task_unexpected_input_warning_ignores_positional_inputs():
    """Providing a required positional input (key ``0``) must not emit a
    ``UserWarning``.
    """

    class MyTask(Task, n_required_positional_inputs=1):
        pass

    with warnings.catch_warnings():
        # Turn any UserWarning into an exception so the test fails on it.
        warnings.simplefilter("error", UserWarning)
        MyTask(inputs={0: "value"})
def test_init_subclass_rejects_input_name_typo():
    """Subclassing ``Task`` with the misspelled keyword ``input`` raises
    ``TaskInputError``.
    """
    # The error message mentions the misspelled keyword.
    with pytest.raises(TaskInputError, match=re.escape("input")):

        class Bad1(Task, input=["reason"]):
            pass

    # The error message also mentions the correctly spelled keyword.
    with pytest.raises(TaskInputError, match=re.escape("input_names")):

        class Bad2(Task, input=["reason"]):
            pass
def test_init_subclass_rejects_optional_input_names_typo():
    """Subclassing ``Task`` with the misspelled keyword
    ``optinal_input_names`` raises ``TaskInputError``.

    The error message is expected to mention both the misspelled keyword and
    the correctly spelled ``optional_input_names``.
    """
    # test that the wrong name appears in the error statement
    with pytest.raises(TaskInputError, match=re.escape("optinal_input_names")):

        class Bad1(Task, optinal_input_names=["reason"]):
            pass

    # test that the correct name appears somewhere in the error statement
    # (previously this matched the misspelled name a second time, so the
    # correct name was never checked)
    with pytest.raises(TaskInputError, match=re.escape("optional_input_names")):

        class Bad2(Task, optinal_input_names=["reason"]):
            pass
def test_init_subclass_rejects_output_name_typo():
    """Subclassing ``Task`` with the misspelled keyword ``output_name``
    raises ``TaskInputError``.
    """
    # The error message mentions the misspelled keyword.
    with pytest.raises(TaskInputError, match="output_name"):

        class Bad1(Task, output_name=["reason"]):
            pass

    # The error message also mentions the correctly spelled keyword.
    with pytest.raises(TaskInputError, match="output_names"):

        class Bad2(Task, output_name=["reason"]):
            pass
def test_init_subclass_accepts_valid_params():
    """Subclassing ``Task`` with correctly spelled keywords must not raise."""

    # The class statement itself is the check: it passes when nothing is raised.
    class Good(Task, input_names=["reason"], output_names=["result", "reason"]):
        pass
def test_task_cleanup_references():
    """``cleanup_references`` drops the references a task holds to its input
    data while the task and output uhashes stay the same.
    """

    class MyTask(Task, input_names=["mylist"], output_names=["mylist"]):
        def run(self):
            self.outputs.mylist = self.inputs.mylist + [len(self.inputs.mylist)]

    obj = [0, 1, 2]
    # Number of referrers of the list before any task has seen it.
    baseline_referrers = len(gc.get_referrers(obj))

    task1 = MyTask(inputs={"mylist": obj})
    task2 = MyTask(inputs=task1.output_variables)
    task1.execute()
    task2.execute()

    # The tasks now hold on to the list.
    assert len(gc.get_referrers(obj)) > baseline_referrers

    task1_uhash = task1.uhash
    task1_output_uhashes = task1.get_output_uhashes()
    task2_uhash = task2.uhash
    task2_output_uhashes = task2.get_output_uhashes()

    task1.cleanup_references()
    # Collect repeatedly until a pass finds nothing more to collect.
    while gc.collect():
        pass

    assert len(gc.get_referrers(obj)) == baseline_referrers

    # The hashes are unaffected by the cleanup.
    assert task1_uhash == task1.uhash
    assert task1_output_uhashes == task1.get_output_uhashes()
    assert task2_uhash == task2.uhash
    assert task2_output_uhashes == task2.get_output_uhashes()
def test_task_cancel(varinfo):
    """Calling ``cancel`` on a ``SumTask`` raises ``NotImplementedError``."""
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    with pytest.raises(NotImplementedError):
        sum_task.cancel()