Source code for ewokscore.tests.test_task_model
from typing import Union
import pytest
from ewokscore.missing_data import MISSING_DATA, MissingData, is_missing_data
from ewokscore.model import BaseInputModel
from ewokscore.task import Task, TaskInputError
from .examples.tasks.sumtask import SumTask
[docs]
class User(BaseInputModel):
    """Input model shared by the pass-through test tasks in this module.

    ``id`` must be supplied and must validate as an integer; ``name`` may be
    omitted, in which case the default below is used.
    """

    id: int  # required; missing or non-integer values fail task input validation
    name: str = "Jane Doe"  # optional; this default appears in the task's input values when omitted
[docs]
class PassThroughTask(Task, input_model=User, output_names=["result"]):
    """Task whose single ``result`` output is its validated input values."""

    def run(self):
        """Publish the validated input values unchanged as ``outputs.result``."""
        validated_inputs = self.get_input_values()
        self.outputs.result = validated_inputs
[docs]
def test_validation():
    """Invalid or missing ``id`` inputs are rejected with a TaskInputError."""
    invalid_cases = (
        ({}, r"id(\s*)Field required"),
        ({"id": "ff"}, r"id(\s*)Input should be a valid integer"),
    )
    for bad_inputs, expected_pattern in invalid_cases:
        with pytest.raises(TaskInputError, match=expected_pattern):
            PassThroughTask(inputs=bad_inputs)
[docs]
def test_default_value():
    """Inputs left out by the caller are filled from the model defaults."""
    input_values = PassThroughTask(inputs={"id": 5}).get_input_values()
    assert input_values == {"id": 5, "name": "Jane Doe"}
[docs]
def test_run():
    """Executing the task forwards every provided input into ``result``."""
    provided = {"id": 5, "name": "Smith"}
    task = PassThroughTask(inputs=provided)
    task.execute()
    produced = task.outputs
    assert produced["result"] == {"id": 5, "name": "Smith"}
[docs]
def test_error_on_subclass_with_wrong_submodel():
    """A task subclass cannot switch to an input model unrelated to its parent's."""
    expected_message = (
        "Input model (.*) from task subclass must be a subclass of the original task input model"
    )

    class Car(BaseInputModel):
        wheels: int

    with pytest.raises(TypeError, match=expected_message):

        class PassThroughCarTask(PassThroughTask, input_model=Car):
            """Invalid on purpose: ``Car`` does not derive from ``User``."""
[docs]
def test_subclass_with_no_change():
    """A subclass that adds nothing behaves exactly like its parent task."""

    class ChildPassThroughTask(PassThroughTask):
        """Subclass that inherits everything from PassThroughTask."""

    child = ChildPassThroughTask(inputs={"id": 5, "name": "Smith"})
    child.execute()
    expected_result = {"id": 5, "name": "Smith"}
    assert child.outputs["result"] == expected_result
[docs]
class SuperUser(User):
    """``User`` extended with an additional required ``age`` field.

    Serves as the narrowed input model of a task subclass, which is allowed
    because it derives from the parent task's model.
    """

    age: int  # required; omitting it fails task input validation
[docs]
class PassThroughSubTask(PassThroughTask, input_model=SuperUser):
    """PassThroughTask whose inputs are validated against ``SuperUser``."""
[docs]
def test_subclass_validation():
    """The extra required field of the subclass model is enforced."""
    missing_age_pattern = r"age(\s*)Field required"
    with pytest.raises(TaskInputError, match=missing_age_pattern):
        PassThroughSubTask(inputs={"id": 5})
[docs]
def test_subclass():
    """Parent-model defaults still apply when the subclass model is used."""
    sub_task = PassThroughSubTask(inputs={"id": 5, "age": 18})
    sub_task.execute()
    expected = {"id": 5, "name": "Jane Doe", "age": 18}
    assert sub_task.outputs["result"] == expected
[docs]
def test_missing_data():
    """An optional field defaulting to MISSING_DATA matches optional_input_names.

    A task declared with ``input_names``/``optional_input_names`` and a task
    declared with an equivalent input model must report the same input values
    and the same missing-input status for the omitted optional input.
    """

    class RegularTask(Task, input_names=["one"], optional_input_names=["two"]):
        pass

    class Model(BaseInputModel):
        one: int
        two: Union[int, MissingData] = MISSING_DATA

    class ModelTask(Task, input_model=Model):
        pass

    regular_task = RegularTask(inputs={"one": 1})
    model_task = ModelTask(inputs={"one": 1})

    # Omitted optional inputs are absent from the reported values.
    model_values = model_task.get_input_values()
    regular_values = regular_task.get_input_values()
    assert model_values == regular_values == {"one": 1}

    # Reading the omitted optional input yields missing data for both tasks.
    model_two_missing = is_missing_data(model_task.get_input_value("two"))
    regular_two_missing = is_missing_data(regular_task.get_input_value("two"))
    assert model_two_missing == regular_two_missing == True  # noqa: E712

    # Both tasks flag the omitted optional input as missing.
    model_flag = model_task.missing_inputs["two"]
    regular_flag = regular_task.missing_inputs["two"]
    assert model_flag == regular_flag == True  # noqa: E712