import inspect
import logging
import pkgutil
import sys
from fnmatch import fnmatch
from types import FunctionType
from types import ModuleType
from typing import Generator
from typing import List
from typing import Optional
from typing import TypedDict
from ewoksutils.import_utils import import_module
from ewoksutils.import_utils import qualname
from .entry_points import entry_points
from .task import Task
class _TaskInputs(TypedDict):
    """Input-related fields of a task description."""

    # Names of the inputs that have no default value.
    required_input_names: List[str]
    # Names of the inputs that have a default value.
    optional_input_names: List[str]
    # Number of inputs that are passed by position instead of by name.
    n_required_positional_inputs: int
class _CommonTaskFields(_TaskInputs):
    """Task description fields shared by every task type.

    Holds everything a task description contains except ``task_type``.
    """

    # Dotted name identifying the task (registry name of a task class or
    # qualified name of a function).
    task_identifier: str
    # Names of the task outputs (``["return_value"]`` for function tasks).
    output_names: List[str]
    # First dot-separated component of ``task_identifier``.
    category: str
    # Docstring of the task class or function, ``None`` when it has none.
    description: Optional[str]
    # Qualified name of the input model, ``None`` when the task has none.
    input_model: Optional[str]
    # Qualified name of the output model, ``None`` when the task has none.
    output_model: Optional[str]
[docs]
class TaskDict(_CommonTaskFields):
    """Description of one discovered task."""

    # Kind of task. The discovery functions of this module emit "class",
    # "ppfmethod" or "method".
    task_type: str
# Module-level logger, used to report modules that cannot be imported.
logger = logging.getLogger(__name__)
[docs]
def discover_tasks_from_modules(
    *module_names_or_patterns: str,
    task_type: Optional[str] = None,
    reload: bool = False,
    raise_import_failure: bool = True,
) -> List[TaskDict]:
    """Discover the tasks provided by the given modules.

    :param module_names_or_patterns: module names, or module patterns
        containing ``*`` which are expanded to the matching module names.
    :param task_type: restrict the discovery to one task type. When ``None``
        the types ``"class"``, ``"ppfmethod"`` and ``"method"`` are handled,
        in that order.
    :param reload: forwarded to the module import.
    :param raise_import_failure: when ``True`` an import error propagates,
        otherwise it is logged and the module is skipped.
    :returns: the task descriptions, grouped by task type, then by
        module name or pattern, in the order the arguments were given.
    """
    if task_type is not None:
        selected_types = (task_type,)
    else:
        selected_types = ("class", "ppfmethod", "method")

    discovered: List[TaskDict] = []
    for selected_type in selected_types:
        for name_or_pattern in module_names_or_patterns:
            # A pattern may expand to several concrete module names.
            resolved_names = _iter_modules_from_pattern(
                name_or_pattern,
                reload=reload,
                raise_import_failure=raise_import_failure,
            )
            for resolved_name in resolved_names:
                discovered += _iter_discover_tasks_from_modules(
                    resolved_name,
                    task_type=selected_type,
                    reload=reload,
                    raise_import_failure=raise_import_failure,
                )
    return discovered
def _iter_discover_tasks_from_modules(
    *module_names: str,
    task_type: str,
    reload: bool = False,
    raise_import_failure: bool = True,
) -> Generator[TaskDict, None, None]:
    """Yield the tasks of one task type found in the given modules.

    :param module_names: names of the modules to inspect (no patterns).
    :param task_type: ``"class"``, ``"ppfmethod"`` or ``"method"``.
    :param reload: forwarded to the module import.
    :param raise_import_failure: when ``True`` an import error propagates,
        otherwise it is logged and the module is skipped.
    :raises ValueError: for any other task type. As this is a generator,
        the error is raised when iteration starts.
    """
    # The empty entry is missing from `sys.path` when the python process
    # was launched through a python console script.
    if "" not in sys.path:
        sys.path.append("")

    if task_type == "class":
        # Importing the modules is what registers their task classes.
        for name in module_names:
            _safe_import_module(
                name, reload=reload, raise_import_failure=raise_import_failure
            )
        yield from _iter_registered_tasks(*module_names)
        return

    if task_type == "ppfmethod":
        yield from _iter_ppfmethod_tasks(
            *module_names, reload=reload, raise_import_failure=raise_import_failure
        )
        return

    if task_type == "method":
        yield from _iter_method_tasks(
            *module_names, reload=reload, raise_import_failure=raise_import_failure
        )
        return

    raise ValueError(f"Task type {task_type} does not support discovery")
def _iter_registered_tasks(*filter_modules: str) -> Generator[TaskDict, None, None]:
    """Yield a description of every task class registered in this process.

    Task classes without a registry name, and task classes whose name (last
    component of the registry name) starts with an underscore, are skipped.

    :param filter_modules: when provided, only task classes defined in a
        module whose name starts with one of these strings are yielded.
        Note that this is a plain string prefix comparison.
    """
    for task_class in Task.get_subclasses():
        assert issubclass(task_class, Task)

        defining_module = task_class.__module__
        # `str.startswith` accepts a tuple of candidate prefixes.
        if filter_modules and not defining_module.startswith(filter_modules):
            continue

        registry_name = task_class.class_registry_name()
        if registry_name is None:
            # Unregistered task
            continue

        name_parts = registry_name.split(".")
        if name_parts[-1].startswith("_"):
            # Hidden task
            continue

        input_model = task_class.input_model()
        output_model = task_class.output_model()
        description: TaskDict = {
            "task_type": "class",
            "task_identifier": registry_name,
            "required_input_names": sorted(task_class.required_input_names()),
            "optional_input_names": sorted(task_class.optional_input_names()),
            "output_names": sorted(task_class.output_names()),
            "category": name_parts[0],
            "description": task_class.__doc__,
            "input_model": qualname(input_model) if input_model else None,
            "output_model": qualname(output_model) if output_model else None,
            "n_required_positional_inputs": task_class.n_required_positional_inputs(),
        }
        yield description
def _iter_method_tasks(
    *module_names: str,
    reload: bool = False,
    raise_import_failure: bool = False,
) -> Generator[TaskDict, None, None]:
    """Yield a ``"method"`` task description for every public function found
    in the given modules.

    The modules are imported for discovery. Functions whose name starts
    with an underscore are skipped, and so are modules that cannot be
    imported (unless `raise_import_failure` is set).
    """
    for module_name in module_names:
        module = _safe_import_module(
            module_name, reload=reload, raise_import_failure=raise_import_failure
        )
        if module is None:
            continue
        functions = inspect.getmembers(module, inspect.isfunction)
        for function_name, function in functions:
            if not function_name.startswith("_"):
                common_fields = _common_method_task_fields(
                    function_name, function, module
                )
                yield {"task_type": "method", **common_fields}
def _iter_ppfmethod_tasks(
    *module_names: str,
    reload: bool = False,
    raise_import_failure: bool = False,
) -> Generator[TaskDict, None, None]:
    """Yield a ``"ppfmethod"`` task description for the `run` function of
    each of the given modules.

    The modules are imported for discovery. Unlike regular methods, a
    ppfmethod is expected to be called `run`: functions with any other name
    are ignored. Modules that cannot be imported are skipped (unless
    `raise_import_failure` is set).
    """
    for module_name in module_names:
        module = _safe_import_module(
            module_name, reload=reload, raise_import_failure=raise_import_failure
        )
        if module is None:
            continue
        functions = inspect.getmembers(module, inspect.isfunction)
        for function_name, function in functions:
            if function_name == "run":
                common_fields = _common_method_task_fields(
                    function_name, function, module
                )
                yield {"task_type": "ppfmethod", **common_fields}
def _iter_discover_all_tasks(
    reload: bool = False,
    task_type: Optional[str] = None,
    raise_import_failure: bool = False,
) -> Generator[TaskDict, None, None]:
    """Yield the tasks of all modules advertised through entry points.

    For each task type the entry point group ``ewoks.tasks.<task_type>`` is
    scanned. The *name* of each entry point is used as module name or
    module pattern; a name seen twice within one group is handled once.

    :param reload: forwarded to the module import.
    :param task_type: restrict the discovery to one task type. When ``None``
        the types ``"class"``, ``"ppfmethod"`` and ``"method"`` are handled,
        in that order.
    :param raise_import_failure: when ``True`` an import error propagates,
        otherwise it is logged and the module is skipped.
    """
    if task_type is not None:
        requested_types = (task_type,)
    else:
        requested_types = ("class", "ppfmethod", "method")

    for requested_type in requested_types:
        group_name = "ewoks.tasks." + requested_type
        seen_patterns = set()
        for entrypoint in entry_points(group_name):
            pattern = entrypoint.name
            if pattern not in seen_patterns:
                seen_patterns.add(pattern)
                yield from discover_tasks_from_modules(
                    pattern,
                    task_type=requested_type,
                    reload=reload,
                    raise_import_failure=raise_import_failure,
                )
[docs]
def discover_all_tasks(
    reload: bool = False,
    task_type: Optional[str] = None,
    raise_import_failure: bool = False,
) -> List[TaskDict]:
    """Discover the tasks of all modules advertised through entry points.

    :param reload: forwarded to the module import.
    :param task_type: restrict the discovery to one task type; all
        discoverable task types are handled when ``None``.
    :param raise_import_failure: when ``True`` an import error propagates,
        otherwise it is logged and the module is skipped.
    :returns: the list of task descriptions.
    """
    discovered = _iter_discover_all_tasks(
        reload=reload,
        task_type=task_type,
        raise_import_failure=raise_import_failure,
    )
    return [*discovered]
def _iter_modules_from_pattern(
    module_name_or_pattern: str,
    reload: bool = False,
    raise_import_failure: bool = False,
) -> Generator[str, None, None]:
    """Yield the module names designated by a module name or module pattern.

    A string without ``*`` is considered to be a module name and is yielded
    as is, without importing anything. Otherwise the top-level package of
    the pattern is imported and walked: every module found that has as many
    dots as the pattern and matches it (``fnmatch`` rules) is yielded.

    :param module_name_or_pattern: for example ``"pkg.tasks.mod"`` or
        ``"pkg.tasks.*"``.
    :param reload: forwarded to the import of the top-level package.
    :param raise_import_failure: when ``True`` an import error (of the
        top-level package or while walking it) propagates, otherwise it is
        logged and the module is skipped.
    """
    if "*" not in module_name_or_pattern:
        yield module_name_or_pattern
        return

    module_pattern = module_name_or_pattern
    ndots = module_pattern.count(".")
    root_name = module_pattern.split(".", 1)[0]
    root = _safe_import_module(
        root_name, reload=reload, raise_import_failure=raise_import_failure
    )
    if root is None:
        return

    search_path = getattr(root, "__path__", None)
    if search_path is None:
        # Only packages have a `__path__`. A plain module has no submodules
        # so the pattern cannot match anything.
        return

    if raise_import_failure:

        def onerror(module_name):
            # Re-raise the import error currently being handled. This relies
            # on `walk_packages` invoking the handler while that error is
            # still active.
            raise

    else:
        onerror = _onerror

    for pkginfo in pkgutil.walk_packages(
        search_path, root.__name__ + ".", onerror=onerror
    ):
        # `*` could match dots as well: enforce the same nesting depth.
        if pkginfo.name.count(".") == ndots and fnmatch(pkginfo.name, module_pattern):
            yield pkginfo.name
def _safe_import_module(
    module_name: str, reload: bool = False, raise_import_failure: bool = False
) -> Optional[ModuleType]:
    """Import a module, optionally tolerating failures.

    :param module_name: name of the module to import.
    :param reload: forwarded to `import_module`.
    :param raise_import_failure: when ``True`` any exception raised by the
        import propagates, otherwise it is logged.
    :returns: the imported module or ``None`` when the import failed and
        the failure was logged.
    """
    try:
        module = import_module(module_name, reload=reload)
    except Exception as import_error:
        if raise_import_failure:
            raise
        _onerror(module_name, exception=import_error)
        return None
    return module
def _onerror(module_name, exception: Optional[Exception] = None):
    """Log that a module cannot be imported.

    :param module_name: name of the module that failed to import.
    :param exception: the cause of the failure. When omitted, the exception
        currently being handled (if any) is reported instead.
    """
    cause = sys.exc_info()[1] if exception is None else exception
    logger.error(f"Module '{module_name}' cannot be imported: {cause}")
def _method_arguments(method) -> _TaskInputs:
    """Derive the task inputs from the signature of a callable.

    * Positional-only parameters have no input name: the ones without a
      default value are counted as required positional inputs, the ones
      with a default value are optional and therefore not counted.
    * ``*args`` and ``**kwargs`` parameters are ignored.
    * Any other parameter is a named input: required when it has no
      default value, optional otherwise.

    :param method: any callable supported by `inspect.signature`.
    :returns: the required and optional input names (in signature order)
        and the number of required positional inputs.
    """
    required_input_names: List[str] = []
    optional_input_names: List[str] = []
    n_required_positional_inputs = 0

    for name, param in inspect.signature(method).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        # `Parameter.empty` is the public marker for "no default value".
        has_default = param.default is not inspect.Parameter.empty
        if param.kind == param.POSITIONAL_ONLY:
            if not has_default:
                n_required_positional_inputs += 1
            continue
        if has_default:
            optional_input_names.append(name)
        else:
            required_input_names.append(name)

    return {
        "required_input_names": required_input_names,
        "optional_input_names": optional_input_names,
        "n_required_positional_inputs": n_required_positional_inputs,
    }
def _common_method_task_fields(
    method_name: str, method_qn: FunctionType, mod: ModuleType
) -> _CommonTaskFields:
    """Build the task description fields shared by the function-based task
    types (everything except ``task_type``).

    :param method_name: name of the function attribute in `mod`.
    :param method_qn: the function object, used to derive the task
        identifier.
    :param mod: the module from which the function is retrieved to inspect
        its signature and docstring.
    :returns: the task identifier (qualified name of the function), its
        category (first component of the identifier), its inputs, the single
        output ``"return_value"``, its docstring and no input/output models.
    """
    # Computed once and reused for both the identifier and the category.
    task_identifier = qualname(method_qn)
    method = getattr(mod, method_name)
    return {
        **_method_arguments(method),
        "task_identifier": task_identifier,
        "output_names": ["return_value"],
        "category": task_identifier.split(".")[0],
        "description": method.__doc__,
        "input_model": None,
        "output_model": None,
    }