Source code for ewokscore.tests.test_task_output_model

from dataclasses import dataclass

import pytest
from pydantic import BaseModel
from pydantic import field_validator

from ..model import BaseOutputModel
from ..task import Task
from .examples.tasks.sumtask import SumTask


[docs] class User(BaseOutputModel): id: int name: str = "Jane Doe"
[docs] class PassThroughTask( Task, input_names=["id"], optional_input_names=["name"], output_model=User ):
[docs] def run(self): self.outputs.id = self.inputs.id if not self.missing_inputs.name: self.outputs.name = self.inputs.name
[docs] def test_error_if_output_model_does_not_derive_from_base_model(): class WrongBaseModelUser(BaseModel): id: int name: str = "Jane Doe" with pytest.raises( TypeError, match=r"output_model should be a subclass of ewokscore.model.BaseOutputModel", ): class WrongPassThroughTask(Task, output_model=WrongBaseModelUser): pass
[docs] def test_error_if_output_model_used_with_output_names(): with pytest.raises( TypeError, match="output_model cannot be used with output_names" ): class WrongPassThroughTask(Task, output_model=User, output_names=["user"]): pass
[docs] def test_validation(): task = PassThroughTask(inputs={"id": "wrong type"}) with pytest.raises(RuntimeError, match=r"id(\s*)Input should be a valid integer"): task.execute()
[docs] def test_default_value(): task = PassThroughTask(inputs={"id": 5}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Jane Doe"}
[docs] def test_run(): task = PassThroughTask(inputs={"id": 5, "name": "Smith"}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Smith"}
[docs] def test_error_on_subclass_with_wrong_submodel(): class Car(BaseOutputModel): wheels: int with pytest.raises( TypeError, match="Output model (.*) from task subclass must be a subclass of the original task output model", ): class PassThroughCarTask(PassThroughTask, output_model=Car): pass
[docs] def test_error_on_subclass_with_output_names(): with pytest.raises( TypeError, match="Cannot use output_names", ): class ChildPassThroughTask(PassThroughTask, output_names=["age"]): pass
[docs] def test_error_on_subclass_with_output_model_if_output_names(): with pytest.raises( TypeError, match="Cannot use output_model", ): class ChildPassThroughTask(SumTask, output_model=User): pass
[docs] def test_subclass_with_no_change(): class ChildPassThroughTask(PassThroughTask): pass task = ChildPassThroughTask(inputs={"id": 5, "name": "Smith"}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Smith"}
[docs] class SuperUser(User): age: int
[docs] class PassThroughSubTask( PassThroughTask, optional_input_names=["age"], output_model=SuperUser ):
[docs] def run(self): super().run() if not self.missing_inputs.age: self.outputs.age = self.inputs.age
[docs] def test_subclass_validation(): task = PassThroughSubTask(inputs={"id": 5}) with pytest.raises(RuntimeError, match=r"1 validation error for SuperUser\nage"): task.execute()
[docs] def test_subclass(): task = PassThroughSubTask(inputs={"id": 5, "age": 18}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Jane Doe", "age": 18}
[docs] class UserWithTypeCoercion(User): age: int
[docs] @field_validator("age", mode="before") def coerce_age(cls, value): if isinstance(value, float): return int(value + 0.5) if not isinstance(value, int): return -1 return value
[docs] class TaskWithTypeCoercion(PassThroughTask, output_model=UserWithTypeCoercion):
[docs] def run(self): super().run() if not self.missing_inputs.age: self.outputs.age = self.inputs.age
[docs] def test_output_type_coercion(): task = TaskWithTypeCoercion(inputs={"id": 5, "age": 18}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Jane Doe", "age": 18} task = TaskWithTypeCoercion(inputs={"id": 5, "age": 18.1}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Jane Doe", "age": 18} task = TaskWithTypeCoercion(inputs={"id": 5, "age": "wrong type"}) task.execute() assert task.get_output_values() == {"id": 5, "name": "Jane Doe", "age": -1}
[docs] def test_wrapped_type_coercion(tmp_path): varinfo = {"root_uri": str(tmp_path / "task_results")} task = TaskWithTypeCoercion(inputs={"id": 5, "age": 18}, varinfo=varinfo) task.execute() coerced_output_values = {"id": 5, "name": "Jane Doe", "age": 18} assert task.get_output_values() == coerced_output_values
[docs] def test_dataclass_field_stays_dataclass(): @dataclass class Address: city: str street: str class Outputs(BaseOutputModel): address: Address class CheckAddress(Task, input_names=["address"], output_model=Outputs): def run(self): self.outputs.address = self.inputs.address address = Address(city="Grenoble", street="Jean-Jaurès") task = CheckAddress(inputs={"address": address}) task.execute() assert isinstance(task.get_output_value("address"), Address)