import gc
import json
import re
import warnings
from glob import glob
from pathlib import Path
import pytest
from ewoksutils.exceptions import TaskInputError
from ewoksutils.exceptions import TaskInputWarning
from ..task import Task
from .examples.tasks.sumtask import SumTask
[docs]
def find_files(tmp_path: Path, extension):
    """Return all files below *tmp_path* whose name ends with *extension*.

    The search is recursive (``**`` with ``recursive=True``): files directly in
    *tmp_path* and in any of its subdirectories are returned. The root directory
    is passed through ``glob.escape`` so glob metacharacters in the directory
    name (``*``, ``?``, ``[``) are matched literally instead of being treated as
    a pattern. *extension* is still used as a glob pattern suffix.

    :param tmp_path: directory to search (a ``Path`` or a path string).
    :param extension: file-name suffix, e.g. ``".json"``.
    :returns: list of matching file paths as strings, in ``glob`` order.
    """
    import os
    from glob import escape

    # Escape only the root: the "**" and "*<extension>" parts must stay patterns.
    root = escape(str(tmp_path))
    pattern = os.path.join(root, "**", f"*{extension}")
    return glob(pattern, recursive=True)
[docs]
def expected_task_output_storage(task):
    """Build the list of payloads an executed *task* is expected to have stored.

    The list holds the serialized form of every output variable, in the order
    of ``task.output_variables.values()``, followed by the serialized form of
    the output-variable collection itself.
    """
    payloads = [
        output_variable.serialize()
        for output_variable in task.output_variables.values()
    ]
    payloads += [task.output_variables.serialize()]
    return payloads
[docs]
def assert_storage(tmp_path, expected):
    """Assert that the JSON files under *tmp_path* contain exactly *expected*.

    Every ``.json`` file found recursively below *tmp_path* is loaded and its
    ``"data"`` entry collected; a ``"__traceback__"`` key is removed from dict
    entries before comparing. The comparison ignores order but counts
    duplicates (multiset equality). ``list.remove`` is used instead of a set or
    ``Counter`` because the stored payloads are not necessarily hashable.

    :param tmp_path: root directory searched for ``.json`` files.
    :param expected: list of payloads that must be stored, and nothing else.
    :raises AssertionError: when the stored payloads differ from *expected*.
    """
    stored = []
    for filename in find_files(tmp_path, ".json"):
        # Explicit encoding: do not depend on the platform's locale default.
        with open(filename, "r", encoding="utf-8") as fileobj:
            data = json.load(fileobj)["data"]
        if isinstance(data, dict):
            data.pop("__traceback__", None)
        stored.append(data)

    assert len(stored) == len(expected), (
        f"Expected {len(expected)} stored items, found {len(stored)}"
    )
    remaining = list(stored)
    for value in expected:
        try:
            remaining.remove(value)
        except ValueError:
            raise AssertionError(f"Expected data not saved: {value!r}") from None
    assert not remaining, f"Unexpected data saved: {remaining!r}"
[docs]
def test_no_public_reserved_names():
    """Every variable name reserved by ``Task`` must be private (start with ``_``)."""
    public_names = [
        name for name in Task._reserved_variable_names() if not name.startswith("_")
    ]
    assert not public_names
[docs]
def test_task_done(varinfo):
    """``done`` becomes true after execution; only with ``varinfo`` does it carry over.

    With ``varinfo``, a new task with the same inputs is reported as done right
    after a previous one executed (presumably because its outputs were stored).
    Without ``varinfo``, every new task instance starts as not done.
    """
    # With varinfo: a second, identical task is already done.
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    assert not sum_task.done
    sum_task.execute()
    assert sum_task.done
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    assert sum_task.done

    # Without varinfo: nothing carries over between task instances.
    sum_task = SumTask(inputs={"a": 10})
    assert not sum_task.done
    sum_task.execute()
    assert sum_task.done
    sum_task = SumTask(inputs={"a": 10})
    assert not sum_task.done
[docs]
def test_task_uhash(varinfo):
    """The task hash matches its output hash, differs from its input hash,
    and changes when an input value changes."""

    def check_hash_relations(sum_task):
        assert sum_task.uhash == sum_task.output_variables.uhash
        assert sum_task.uhash != sum_task.input_variables.uhash

    task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    uhash_before = task.uhash
    check_hash_relations(task)

    # Mutating an input value in place must change the task hash.
    task.input_variables["a"].value += 1
    assert task.uhash != uhash_before
    check_hash_relations(task)
[docs]
def test_task_storage(tmp_path, varinfo):
    """Executed tasks store their outputs, and later tasks can consume them.

    Each new task is checked to be not done, then executed, and its result is
    verified. Afterwards the JSON files under ``tmp_path`` must hold exactly
    the payloads of all tasks executed so far. Inputs are given as plain
    values, as another task's output variable, as that variable's data proxy,
    and as the proxy's URI.
    """
    expected = []

    def execute_new_task(inputs, result):
        # Run a task that has not been executed yet and record what it stores.
        new_task = SumTask(inputs=inputs, varinfo=varinfo)
        assert not new_task.done
        new_task.execute()
        assert new_task.done
        assert new_task.outputs.result == result
        expected.extend(expected_task_output_storage(new_task))
        assert_storage(tmp_path, expected)
        return new_task

    task = execute_new_task({"a": 10, "b": 2}, 12)

    # Same inputs again: already done, same result, nothing new written.
    task = SumTask(inputs={"a": 10, "b": 2}, varinfo=varinfo)
    assert task.done
    assert task.outputs.result == 12
    assert_storage(tmp_path, expected)

    task = execute_new_task({"a": 2, "b": 10}, 12)
    # Input given as another task's output variable.
    task = execute_new_task({"a": task.output_variables["result"], "b": 0}, 12)
    # Input given as an output variable's data proxy.
    task = execute_new_task(
        {"a": 1, "b": task.output_variables["result"].data_proxy}, 13
    )
    # Input given as the data proxy's URI.
    task = execute_new_task(
        {"a": 1, "b": task.output_variables["result"].data_proxy.uri}, 14
    )
[docs]
def test_init_subclass_rejects_output_name_typo():
    """A misspelled ``output_names`` class keyword raises a helpful ``TaskInputError``.

    ``pytest.raises(match=...)`` applies ``re.search`` to the error message, so a
    plain ``"output_name"`` pattern would also match the suggested
    ``"output_names"``. The negative lookahead makes the first check require the
    misspelled name itself.
    """
    # The misspelled keyword must appear in the error message.
    with pytest.raises(TaskInputError, match=r"output_name(?!s)"):

        class Bad1(Task, output_name=["reason"]):
            pass

    # The correct keyword must be suggested in the error message.
    with pytest.raises(TaskInputError, match="output_names"):

        class Bad2(Task, output_name=["reason"]):
            pass
[docs]
def test_init_subclass_accepts_valid_params():
    """Defining a ``Task`` subclass with correctly spelled class keywords succeeds."""

    # The class definition itself is the check: any exception fails the test.
    class Good(Task, input_names=["reason"], output_names=["result", "reason"]):
        pass
[docs]
def test_task_cleanup_references():
    """``cleanup_references`` releases references to input data but keeps hashes.

    ``task2`` takes ``task1``'s output variables as its inputs. After both tasks
    execute, ``obj`` has more direct referrers than at the start. After
    ``task1.cleanup_references()`` and a full garbage collection, the referrer
    count is back to its starting value. The task hash and output hashes of both
    tasks must be unchanged by the cleanup.
    """

    class MyTask(Task, input_names=["mylist"], output_names=["mylist"]):
        def run(self):
            # Output is a new list (input plus its length); the input list is not mutated.
            self.outputs.mylist = self.inputs.mylist + [len(self.inputs.mylist)]

    obj = [0, 1, 2]
    # Baseline number of direct referrers of obj before any task holds it.
    nref_start = len(gc.get_referrers(obj))
    task1 = MyTask(inputs={"mylist": obj})
    task2 = MyTask(inputs=task1.output_variables)
    task1.execute()
    task2.execute()
    # Creating/executing the tasks added new direct referrers of obj.
    assert len(gc.get_referrers(obj)) > nref_start
    # Record the hashes so they can be compared after the cleanup.
    uhash1 = task1.uhash
    uhashes1 = task1.get_output_uhashes()
    uhash2 = task2.uhash
    uhashes2 = task2.get_output_uhashes()
    task1.cleanup_references()
    # Collect until nothing more is freed, so only live referrers are counted below.
    while gc.collect():
        pass
    assert len(gc.get_referrers(obj)) == nref_start
    # Hashes must survive the cleanup unchanged.
    assert uhash1 == task1.uhash
    assert uhashes1 == task1.get_output_uhashes()
    assert uhash2 == task2.uhash
    assert uhashes2 == task2.get_output_uhashes()
[docs]
def test_task_cancel(varinfo):
    """Cancelling a ``SumTask`` is not supported and raises ``NotImplementedError``."""
    sum_task = SumTask(inputs={"a": 10}, varinfo=varinfo)
    with pytest.raises(NotImplementedError):
        sum_task.cancel()