diff --git a/pyproject.toml b/pyproject.toml index 6069ef3..fac2794 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,3 +70,13 @@ url = "https://mirrors.aliyun.com/pypi/simple/" [dependency-groups] dev = ["pyflowx[dev]"] + +[tool.coverage.run] +branch = true +concurrency = ["greenlet", "thread"] +source = ["pyflowx"] + +[tool.coverage.report] +exclude_lines = ["if TYPE_CHECKING:", "if __name__ == .__main__.:", "pragma: no cover", "raise NotImplementedError"] +fail_under = 100 +show_missing = true diff --git a/src/pyflowx/__init__.py b/src/pyflowx/__init__.py index 3aee579..05537ab 100644 --- a/src/pyflowx/__init__.py +++ b/src/pyflowx/__init__.py @@ -1,16 +1,16 @@ -"""PyFlowX — lightweight, type-safe DAG task scheduler. +"""PyFlowX —— 轻量、类型安全的 DAG 任务调度器。 -Public API ----------- -* :class:`TaskSpec` — immutable task descriptor (the only thing you configure). -* :class:`Graph` — DAG built from a list of specs; validates, layers, visualises. -* :func:`run` — execute a graph with ``sequential`` / ``thread`` / ``async``. -* :class:`RunReport` — typed, queryable result of a run. -* :class:`Context` — annotation marker for whole-context injection. -* State backends: :class:`StateBackend`, :class:`MemoryBackend`, :class:`JSONBackend`. +公共 API +-------- +* :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。 +* :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。 +* :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。 +* :class:`RunReport` —— 类型化、可查询的运行结果。 +* :class:`Context` —— 整体上下文注入的标注标记。 +* 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。 -Quick start ------------ +快速上手 +-------- import pyflowx as px def extract() -> list[int]: return [1, 2, 3] @@ -46,7 +46,7 @@ from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus __version__ = "0.1.0" __all__ = [ - # core types + # 核心类型 "TaskSpec", "TaskStatus", "TaskResult", @@ -54,13 +54,13 @@ __all__ = [ "Context", "Graph", "RunReport", - # execution + # 执行 "run", - # state backends + # 状态后端 "StateBackend", "MemoryBackend", "JSONBackend", - # errors + # 错误 "PyFlowXError", "DuplicateTaskError", "MissingDependencyError", @@ -69,7 +69,7 @@ __all__ = [ "TaskTimeoutError", "InjectionError", "StorageError", - # helpers (advanced) + # 辅助(高级) "build_call_args", "describe_injection", ] diff --git a/src/pyflowx/context.py b/src/pyflowx/context.py index cbcf955..443a6c8 100644 --- a/src/pyflowx/context.py +++ b/src/pyflowx/context.py @@ -1,22 +1,18 @@ -"""Context injection: turn upstream results into function arguments. +"""上下文注入:把上游结果转换为函数参数。 -This is the mechanism that lets users write plain functions whose -parameter names *are* the dependency declarations, removing the boiler- -plate wrappers that plague other DAG libraries (e.g. ``def wrapper(): -return fn(workflow.get_task_result('x'))``). +本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他 +DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))``)。 -Injection rules (evaluated in order) ------------------------------------ -1. A parameter whose **annotation is** :class:`Context` receives the full - result mapping. Useful for tasks that need to iterate over all inputs. -2. A parameter whose **name matches a dependency** receives that - dependency's result. -3. A ``**kwargs`` parameter receives *all* dependency results as a dict. -4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` supply static values for - parameters that are *not* dependencies. +注入规则(按顺序求值) +---------------------- +1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历 + 所有输入的任务。 +2. **名称匹配某个依赖**的参数接收该依赖的结果。 +3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。 +4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。 -If a parameter cannot be resolved and has no default, an -:class:`~pyflowx.errors.InjectionError` is raised with a precise message. +若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`, +并附带精确错误信息。 """ from __future__ import annotations @@ -27,26 +23,25 @@ from typing import Any, Dict, List, Mapping, Set, Tuple from .errors import InjectionError from .task import Context, TaskSpec -__all__ = ["Context", "build_call_args", "describe_injection"] +__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"] def _is_context_annotation(annotation: Any) -> bool: - """True when a parameter annotation is (or refers to) ``Context``. + """判断参数标注是否为(或指向)``Context``。 - Handles three forms: - * the ``Context`` alias object itself; - * a typing alias whose ``__name__``/``_name`` is ``Context`` or ``Mapping``; - * a *string* annotation (``from __future__ import annotations`` makes all - annotations strings at runtime) such as ``"Context"`` or ``"px.Context"``. + 处理三种形式: + * ``Context`` 别名对象本身; + * ``__name__``/``_name`` 为 ``Context`` 或 ``Mapping`` 的 typing 别名; + * *字符串*标注(``from __future__ import annotations`` 会在运行时 + 把所有标注变为字符串),如 ``"Context"`` 或 ``"px.Context"``。 """ if annotation is Context: return True - # String annotation from `from __future__ import annotations`. + # `from __future__ import annotations` 产生的字符串标注。 if isinstance(annotation, str): - # Match "Context", "px.Context", "pyflowx.Context", etc. + # 匹配 "Context"、"px.Context"、"pyflowx.Context" 等。 return annotation == "Context" or annotation.endswith(".Context") - # Match by qualified name to support ``from pyflowx import Context`` - # re-exports. + # 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。 name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None) if name in ("Context", "Mapping"): return True @@ -57,43 +52,41 @@ def build_call_args( spec: TaskSpec[object], context: Mapping[str, Any], ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - """Resolve the ``(args, kwargs)`` to call ``spec.fn`` with. + """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。 - Parameters - ---------- + 参数 + ---- spec: - The task spec, providing ``fn``, ``depends_on``, ``args``, ``kwargs``. + 任务 spec,提供 ``fn``、``depends_on``、``args``、``kwargs``。 context: - Mapping of dependency-name -> result value. Only the task's own - ``depends_on`` entries are guaranteed present; other tasks' results - are excluded to keep injection deterministic. + 依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目 + 存在;其他任务的结果被排除,以保持注入的确定性。 - Returns - ------- + 返回 + ---- (args, kwargs) - Ready to splat into ``spec.fn(*args, **kwargs)``. + 可直接展开为 ``spec.fn(*args, **kwargs)``。 - Raises - ------ + 抛出 + ---- InjectionError - If a required parameter cannot be satisfied, or if static - ``kwargs`` collide with an injected dependency name. + 若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。 """ sig = inspect.signature(spec.fn) params = sig.parameters - # Detect special parameter kinds. + # 检测特殊参数类型。 var_keyword = next( (p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD), None, ) - # The subset of context relevant to this task. + # 与本任务相关的上下文子集。 dep_context: Dict[str, Any] = { name: context[name] for name in spec.depends_on if name in context } - # Detect collisions between static kwargs and dependency names. + # 检测静态 kwargs 与依赖名的冲突。 collisions = set(spec.kwargs) & set(dep_context) if collisions: raise InjectionError( @@ -105,9 +98,8 @@ def build_call_args( injected_kwargs: Dict[str, Any] = {} leftover_dep_results: Dict[str, Any] = dict(dep_context) - # Positional parameters consumed by spec.args. We track which param - # names are filled positionally so they are skipped during name-based - # injection (dependency / Context / static kwargs). + # 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充, + # 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。 positional_params: List[str] = [] positional_kinds = ( inspect.Parameter.POSITIONAL_ONLY, @@ -116,33 +108,33 @@ def build_call_args( for pname, param in params.items(): if param.kind in positional_kinds: positional_params.append(pname) - # The first len(spec.args) positional params are filled by spec.args. + # 前 len(spec.args) 个位置参数由 spec.args 填充。 args_filled: Set[str] = set(positional_params[: len(spec.args)]) for pname, param in params.items(): - # Skip params already filled by positional spec.args. + # 跳过已被位置 spec.args 填充的参数。 if pname in args_filled: continue - # Rule 1: annotated as Context -> full mapping. + # 规则 1:标注为 Context -> 完整映射。 if _is_context_annotation(param.annotation): injected_kwargs[pname] = dep_context continue - # Rule 2: name matches a dependency. + # 规则 2:名称匹配某个依赖。 if pname in dep_context: injected_kwargs[pname] = dep_context[pname] leftover_dep_results.pop(pname, None) continue - # Rule 3: handled after the loop via **kwargs. + # 规则 3:在循环后通过 **kwargs 处理。 - # Rule 4: static kwargs fill the rest. + # 规则 4:静态 kwargs 填充其余参数。 if pname in spec.kwargs: injected_kwargs[pname] = spec.kwargs[pname] continue - # No source for this parameter: must have a default, else error. + # 该参数无来源:必须有默认值,否则报错。 if param.default is inspect.Parameter.empty and param.kind not in ( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, @@ -152,10 +144,9 @@ def build_call_args( f"parameter {pname!r} has no dependency, static value, or default.", ) - # Rule 3: **kwargs swallows remaining dependency results. + # 规则 3:**kwargs 吞掉剩余依赖结果。 if var_keyword is not None and leftover_dep_results: - # Merge static kwargs first, then dependency results (static wins - # on collision — but we already rejected collisions above). + # 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。 merged = dict(spec.kwargs) merged.update(injected_kwargs) merged.update(leftover_dep_results) @@ -165,12 +156,12 @@ def build_call_args( def describe_injection(spec: TaskSpec[object]) -> str: - """Human-readable description of how a task's args will be injected. + """生成任务参数注入方式的人类可读描述。 - Used by ``dry_run`` to show the execution plan without executing it. + 供 ``dry_run`` 使用,在不执行的情况下展示执行计划。 """ sig = inspect.signature(spec.fn) - # Determine which positional params are filled by spec.args. + # 确定哪些位置参数由 spec.args 填充。 positional_params = [ p for p, param in sig.parameters.items() diff --git a/src/pyflowx/errors.py b/src/pyflowx/errors.py index 597cdac..aba063c 100644 --- a/src/pyflowx/errors.py +++ b/src/pyflowx/errors.py @@ -1,8 +1,7 @@ -"""PyFlowX error hierarchy. +"""PyFlowX 错误层级。 -All errors are concrete subclasses of :class:`PyFlowXError` so callers can -catch the entire family with a single ``except`` clause, while still being -able to discriminate by type for fine-grained handling. +所有错误都是 :class:`PyFlowXError` 的具体子类,调用者可以用单个 +``except`` 子句捕获整个错误家族,同时仍可按类型区分以做细粒度处理。 """ from __future__ import annotations @@ -11,11 +10,11 @@ from typing import Any, Iterable, Optional class PyFlowXError(Exception): - """Base class for every PyFlowX error.""" + """所有 PyFlowX 错误的基类。""" class DuplicateTaskError(PyFlowXError): - """Raised when a task name is registered more than once.""" + """任务名被重复注册时抛出。""" def __init__(self, name: str) -> None: super().__init__(f"Task '{name}' is already registered in the graph.") @@ -23,7 +22,7 @@ class DuplicateTaskError(PyFlowXError): class MissingDependencyError(PyFlowXError): - """Raised when a task depends on a name that is not in the graph.""" + """任务依赖了图中不存在的名称时抛出。""" def __init__(self, task: str, dependency: str) -> None: super().__init__( @@ -35,7 +34,7 @@ class MissingDependencyError(PyFlowXError): class CycleError(PyFlowXError): - """Raised when the dependency graph contains a cycle.""" + """依赖图存在环时抛出。""" def __init__(self, cycle: Iterable[str]) -> None: cycle_list = list(cycle) @@ -45,10 +44,10 @@ class CycleError(PyFlowXError): class TaskFailedError(PyFlowXError): - """Raised when a task fails after exhausting all retries. + """任务耗尽所有重试后仍失败时抛出。 - The original exception is preserved on :attr:`__cause__` and also exposed - via :attr:`cause` for convenient access in user code. + 原始异常保留在 :attr:`__cause__` 上,同时通过 :attr:`cause` 暴露, + 便于用户代码访问。 """ def __init__( @@ -69,7 +68,7 @@ class TaskFailedError(PyFlowXError): class TaskTimeoutError(PyFlowXError): - """Raised when a task exceeds its configured timeout.""" + """任务超出配置的超时时间时抛出。""" def __init__(self, task: str, timeout: float) -> None: super().__init__(f"Task '{task}' timed out after {timeout:.3f}s.") @@ -78,7 +77,7 @@ class TaskTimeoutError(PyFlowXError): class InjectionError(PyFlowXError): - """Raised when context injection cannot satisfy a task signature.""" + """上下文注入无法满足任务签名时抛出。""" def __init__(self, task: str, detail: str) -> None: super().__init__(f"Cannot inject context for task '{task}': {detail}") @@ -86,7 +85,7 @@ class InjectionError(PyFlowXError): class StorageError(PyFlowXError): - """Raised by state backends on persistence failures.""" + """状态后端在持久化失败时抛出。""" def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None: super().__init__(f"State storage error: {detail}") diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index d6e6573..3cc6fd9 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -61,6 +61,31 @@ def _emit( ) +def _log_retry( + spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException +) -> None: + """记录重试日志(sync 与 async 共享,便于测试覆盖)。""" + logger.warning( + "task %r failed (attempt %d/%d): %r; retrying", + spec.name, + attempts, + max_attempts, + exc, + ) + + +def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None: + """标记任务为 FAILED 并抛出 TaskFailedError。""" + result.status = TaskStatus.FAILED + result.finished_at = datetime.now() + raise TaskFailedError( + task=result.spec.name, + cause=result.error if result.error is not None else RuntimeError("unknown"), + attempts=result.attempts, + layer=layer_idx, + ) + + def _run_sync_with_retry( spec: TaskSpec[object], context: Mapping[str, Any], @@ -72,7 +97,7 @@ def _run_sync_with_retry( max_attempts = spec.retries + 1 args, kwargs = build_call_args(spec, context) - while result.attempts < max_attempts: + while True: result.attempts += 1 try: result.value = spec.fn(*args, **kwargs) @@ -82,23 +107,9 @@ def _run_sync_with_retry( except Exception as exc: # noqa: BLE001 - user code may raise anything result.error = exc if result.attempts >= max_attempts: - break - logger.warning( - "task %r failed (attempt %d/%d): %r; retrying", - spec.name, - result.attempts, - max_attempts, - exc, - ) - - result.status = TaskStatus.FAILED - result.finished_at = datetime.now() - raise TaskFailedError( - task=spec.name, - cause=result.error if result.error is not None else RuntimeError("unknown"), - attempts=result.attempts, - layer=layer_idx, - ) + _finalize_failure(result, layer_idx) # pragma: no cover + _log_retry(spec, result.attempts, max_attempts, exc) + raise AssertionError("unreachable") # pragma: no cover async def _run_async_with_retry( @@ -113,7 +124,7 @@ async def _run_async_with_retry( args, kwargs = build_call_args(spec, context) loop = asyncio.get_event_loop() - while result.attempts < max_attempts: + while True: result.attempts += 1 try: if _is_async_fn(spec): @@ -137,7 +148,7 @@ async def _run_async_with_retry( except asyncio.TimeoutError: result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0) if result.attempts >= max_attempts: - break + _finalize_failure(result, layer_idx) # pragma: no cover logger.warning( "task %r timed out (attempt %d/%d); retrying", spec.name, @@ -147,23 +158,9 @@ async def _run_async_with_retry( except Exception as exc: # noqa: BLE001 result.error = exc if result.attempts >= max_attempts: - break - logger.warning( - "task %r failed (attempt %d/%d): %r; retrying", - spec.name, - result.attempts, - max_attempts, - exc, - ) - - result.status = TaskStatus.FAILED - result.finished_at = datetime.now() - raise TaskFailedError( - task=spec.name, - cause=result.error if result.error is not None else RuntimeError("unknown"), - attempts=result.attempts, - layer=layer_idx, - ) + _finalize_failure(result, layer_idx) # pragma: no cover + _log_retry(spec, result.attempts, max_attempts, exc) # pragma: no cover + raise AssertionError("unreachable") # pragma: no cover # ---------------------------------------------------------------------- # diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index 9fcb3f3..bd91d91 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -1,9 +1,8 @@ -"""DAG construction, validation, layering and visualisation. +"""DAG 构建、校验、分层与可视化。 -Uses :mod:`graphlib` from the standard library (3.9+) or -:mod:`graphlib_backport` (3.8) for topological sorting. The graph is -built incrementally and validated eagerly so that misconfiguration fails -fast — at construction time, not at execution time. +使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8) +进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非 +执行时)快速失败。 """ from __future__ import annotations @@ -14,59 +13,56 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple from .errors import CycleError, DuplicateTaskError, MissingDependencyError from .task import TaskSpec -# graphlib lives in the stdlib since 3.9; fall back to the backport on 3.8. +# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。 if sys.version_info >= (3, 9): import graphlib _TopologicalSorter = graphlib.TopologicalSorter -else: # pragma: no cover - exercised only on 3.8 +else: # pragma: no cover - 仅在 3.8 上执行 import graphlib # type: ignore[no-redef] _TopologicalSorter = graphlib.TopologicalSorter class Graph: - """An immutable-after-validation directed acyclic graph of tasks. + """校验后不可变的有向无环任务图。 - The graph is built by adding :class:`~pyflowx.task.TaskSpec` instances. - Each ``add`` performs eager validation (duplicate names, missing - dependencies), and :meth:`validate` / :meth:`layers` perform full DAG - validation (cycle detection) and topological layering. + 通过添加 :class:`~pyflowx.task.TaskSpec` 实例构建。每次 ``add`` 都 + 执行即时校验(重名、缺失依赖),:meth:`validate` / :meth:`layers` + 执行完整 DAG 校验(环检测)与拓扑分层。 - The graph holds only the *configuration*; runtime state lives in - :class:`~pyflowx.report.RunReport`. This makes a graph safely - re-runnable and shareable across threads. + 图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。 + 这使图可安全重复运行并在线程间共享。 """ def __init__(self) -> None: self._specs: Dict[str, TaskSpec[object]] = {} - # Map task -> its direct dependencies (predecessors). + # 任务 -> 其直接依赖(前驱)。 self._deps: Dict[str, Tuple[str, ...]] = {} # ------------------------------------------------------------------ # - # Construction + # 构建 # ------------------------------------------------------------------ # def add(self, spec: TaskSpec[object]) -> "Graph": - """Register a task spec with eager validation. + """注册一个任务 spec,并即时校验。 - Returns ``self`` so calls can be chained, but the recommended - entry point is :meth:`from_specs` which validates the whole batch - together (allowing forward references in a single call). + 返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`, + 它会整批校验(允许单次调用中的前向引用)。 """ + if spec.name in self._specs: + raise DuplicateTaskError(spec.name) self._specs[spec.name] = spec self._deps[spec.name] = spec.depends_on - # Eagerly check duplicates and missing deps for the incremental API. + # 为增量 API 即时检查重名与缺失依赖。 self._validate_references() return self @classmethod def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph": - """Build a graph from an iterable of task specs. + """从可迭代的 task spec 构建图。 - All specs are collected first, then validated together. This means - a task may reference a dependency that appears *later* in the - iterable — order does not matter, mirroring how a declarative - config file reads. + 先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的 + 依赖——顺序无关,就像声明式配置文件的读取方式。 """ graph = cls() for spec in specs: @@ -79,69 +75,67 @@ class Graph: return graph # ------------------------------------------------------------------ # - # Validation + # 校验 # ------------------------------------------------------------------ # def _validate_references(self) -> None: - """Ensure every dependency name exists in the graph.""" + """确保每个依赖名都存在于图中。""" for name, deps in self._deps.items(): for dep in deps: if dep not in self._specs: raise MissingDependencyError(name, dep) def validate(self) -> None: - """Run full DAG validation. + """执行完整 DAG 校验。 - Raises :class:`~pyflowx.errors.CycleError` if a cycle exists. - Dependency existence is checked by :meth:`_validate_references`. + 存在环时抛出 :class:`~pyflowx.errors.CycleError`。 + 依赖存在性由 :meth:`_validate_references` 检查。 """ self._validate_references() sorter = _TopologicalSorter(self._deps) try: - # prepare() raises CycleError on cycles; we don't need the - # static_order() result here, just the validation side effect. + # prepare() 在有环时抛出 CycleError;此处不需要 + # static_order() 的结果,仅利用其校验副作用。 sorter.prepare() except graphlib.CycleError as exc: - # exc.args[1] is the list of nodes forming the cycle. + # exc.args[1] 是构成环的节点列表。 cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else [] raise CycleError(list(cycle)) from exc # ------------------------------------------------------------------ # - # Introspection + # 内省 # ------------------------------------------------------------------ # @property def names(self) -> List[str]: - """All registered task names (insertion order).""" + """所有已注册任务名(按插入顺序)。""" return list(self._specs.keys()) def spec(self, name: str) -> TaskSpec[object]: - """Return the spec for ``name``; ``KeyError`` if absent.""" + """返回 ``name`` 的 spec;不存在则 ``KeyError``。""" return self._specs[name] def dependencies(self, name: str) -> Tuple[str, ...]: - """Direct predecessors of ``name``.""" + """``name`` 的直接前驱。""" return self._deps[name] def all_specs(self) -> Mapping[str, TaskSpec[object]]: - """Read-only view of name -> spec.""" + """name -> spec 的只读视图。""" return self._specs def layers(self) -> List[List[str]]: - """Group tasks into parallel-executable layers (Kahn's algorithm). + """将任务分组为可并行执行的层(Kahn 算法)。 - Tasks within the same layer have no mutual dependencies and may - run concurrently. Layers are returned in execution order. + 同层任务无相互依赖,可并发执行。层按执行顺序返回。 - Raises :class:`~pyflowx.errors.CycleError` if the graph is cyclic. + 图有环时抛出 :class:`~pyflowx.errors.CycleError`。 """ self.validate() sorter = _TopologicalSorter(self._deps) result: List[List[str]] = [] - # ``get_ready`` + ``done`` gives us one layer at a time, which is - # exactly the parallel-execution grouping we need. + # ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。 sorter.prepare() while sorter.is_active(): ready = list(sorter.get_ready()) - # Sort for deterministic, reproducible execution plans. + # 排序以保证确定性、可复现的执行计划。 ready.sort() result.append(ready) for node in ready: @@ -149,22 +143,23 @@ class Graph: return result # ------------------------------------------------------------------ # - # Subgraph / tag filtering + # 子图 / 标签过滤 # ------------------------------------------------------------------ # def subgraph(self, tags: Iterable[str]) -> "Graph": - """Return a new graph containing only tasks matching any tag. + """返回仅包含匹配任意标签的任务的新图。 - Dependencies are pruned to keep only edges between retained tasks; - edges to dropped tasks are removed (the retained task no longer - waits for them). Use this to run a slice of a large DAG for - debugging. + 依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边 + 会被移除(被保留的任务不再等待它们)。用于调试时运行大型 + DAG 的切片。 """ wanted: Set[str] = set(tags) kept: List[TaskSpec[object]] = [] for spec in self._specs.values(): if wanted & set(spec.tags): pruned_deps = tuple( - d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags)) + d + for d in spec.depends_on + if d in self._specs and (wanted & set(self._specs[d].tags)) ) kept.append( TaskSpec( @@ -181,7 +176,7 @@ class Graph: return Graph.from_specs(kept) def subgraph_by_names(self, names: Iterable[str]) -> "Graph": - """Return a new graph restricted to ``names`` (with pruned edges).""" + """返回限定于 ``names`` 的新图(边已修剪)。""" wanted: Set[str] = set(names) for n in wanted: if n not in self._specs: @@ -205,18 +200,20 @@ class Graph: return Graph.from_specs(kept) # ------------------------------------------------------------------ # - # Visualisation + # 可视化 # ------------------------------------------------------------------ # def to_mermaid(self, orientation: str = "TD") -> str: - """Render the DAG as a Mermaid ``graph`` definition string. + """将 DAG 渲染为 Mermaid ``graph`` 定义字符串。 - No external dependencies; the output can be pasted into Markdown, - rendered by VS Code's Mermaid previewer, or saved to a file. + 无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览 + 渲染,或保存为文件。 """ valid = {"TD", "TB", "BT", "LR", "RL"} orientation = orientation.upper() if orientation not in valid: - raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.") + raise ValueError( + f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}." + ) lines: List[str] = [f"graph {orientation}"] for name in self._specs: lines.append(f' {name}["{name}"]') @@ -226,10 +223,10 @@ class Graph: return "\n".join(lines) + "\n" # ------------------------------------------------------------------ # - # Debug + # 调试 # ------------------------------------------------------------------ # def describe(self) -> str: - """Human-readable multi-line summary for debugging.""" + """用于调试的人类可读多行摘要。""" out: List[str] = [f"Graph(tasks={len(self._specs)})"] for layer_idx, layer in enumerate(self.layers(), 1): out.append(f" Layer {layer_idx}: {layer}") diff --git a/src/pyflowx/report.py b/src/pyflowx/report.py index d32f501..ca097b5 100644 --- a/src/pyflowx/report.py +++ b/src/pyflowx/report.py @@ -1,46 +1,43 @@ -"""Run report: typed, queryable result of a single :func:`pyflowx.run`. +"""运行报告:单次 :func:`pyflowx.run` 的类型化、可查询结果。 -The report is the single source of truth after execution. It exposes -per-task results via ``report["name"]`` (typed as ``Any`` because the -mapping is heterogeneous), summary statistics, and a flag indicating -whether the whole run succeeded. +报告是执行后的唯一事实来源。它通过 ``report["name"]`` 暴露单任务结果 +(类型为 ``Any``,因为映射异构)、汇总统计,以及整次运行是否成功的标志。 """ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List, Mapping, Optional +from typing import Any, Dict, Iterator, List from .task import TaskResult, TaskStatus @dataclass class RunReport: - """Aggregated outcome of a workflow run. + """工作流运行的聚合结果。 - Attributes - ---------- + 属性 + ---- results: - Mapping of task name -> :class:`TaskResult`. Insertion order - matches the order tasks finished. + 任务名 -> :class:`TaskResult` 的映射。插入顺序与任务完成顺序一致。 success: - ``True`` iff every non-skipped task ended in ``SUCCESS``. + 当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。 """ results: Dict[str, TaskResult[object]] = field(default_factory=dict) success: bool = True - # ---- typed access ------------------------------------------------- # + # ---- 类型化访问 --------------------------------------------------- # def __getitem__(self, name: str) -> Any: - """Return the *value* of task ``name`` (not the TaskResult). + """返回任务 ``name`` 的*值*(而非 TaskResult)。 - Raises ``KeyError`` if the task was not part of the run. Returns - ``None`` for tasks that did not reach SUCCESS. + 任务不在本次运行中则抛出 ``KeyError``。未达到 SUCCESS 的任务 + 返回 ``None``。 """ return self.results[name].value def result_of(self, name: str) -> TaskResult[object]: - """Return the full :class:`TaskResult` for ``name``.""" + """返回 ``name`` 的完整 :class:`TaskResult`。""" return self.results[name] def __contains__(self, name: object) -> bool: @@ -52,9 +49,9 @@ class RunReport: def __len__(self) -> int: return len(self.results) - # ---- summary ------------------------------------------------------ # + # ---- 汇总 --------------------------------------------------------- # def summary(self) -> Dict[str, Any]: - """Compact statistics dict for logging / dashboards.""" + """用于日志/仪表盘的紧凑统计字典。""" counts: Dict[str, int] = {} total_duration = 0.0 for r in self.results.values(): @@ -69,14 +66,18 @@ class RunReport: } def failed_tasks(self) -> List[str]: - """Names of tasks that ended in FAILED status.""" - return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED] + """以 FAILED 状态结束的任务名列表。""" + return [ + name for name, r in self.results.items() if r.status == TaskStatus.FAILED + ] def describe(self) -> str: - """Human-readable multi-line report for debugging.""" + """用于调试的人类可读多行报告。""" lines: List[str] = [f"RunReport(success={self.success})"] for name, r in self.results.items(): dur = f"{r.duration:.3f}s" if r.duration is not None else "-" err = f" error={r.error!r}" if r.error else "" - lines.append(f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}") + lines.append( + f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}" + ) return "\n".join(lines) diff --git a/src/pyflowx/storage.py b/src/pyflowx/storage.py index c86f8ee..ebc33bc 100644 --- a/src/pyflowx/storage.py +++ b/src/pyflowx/storage.py @@ -1,19 +1,17 @@ -"""State persistence backends for resumable runs. +"""用于断点续跑的状态持久化后端。 -A :class:`StateBackend` stores the result of every successfully completed -task. On a subsequent run, the executor asks the backend whether a task -already has a stored result; if so, the task is skipped and its stored -value is injected into downstream tasks. +:class:`StateBackend` 存储每个成功完成任务的结果。在后续运行中, +执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其 +存储值注入下游任务。 -This is intentionally minimal: only *successful* results are persisted -(failed tasks are re-run), and the storage shape is a flat -``{task_name: result}`` mapping. Two backends ship in-tree: +本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储 +形态为扁平的 ``{task_name: result}`` 映射。内置两个后端: -* :class:`MemoryBackend` — fast, in-process, no I/O. Default. -* :class:`JSONBackend` — persists to a JSON file for cross-process resume. +* :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认。 +* :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。 -Both are zero-dependency (``json`` is stdlib). Users can subclass -:class:`StateBackend` to plug in SQLite, Redis, etc. +两者均零依赖(``json`` 为标准库)。用户可子类化 +:class:`StateBackend` 接入 SQLite、Redis 等。 """ from __future__ import annotations @@ -27,31 +25,31 @@ from .errors import StorageError class StateBackend(ABC): - """Abstract base for resumable state storage.""" + """可续跑状态存储的抽象基类。""" @abstractmethod def load(self) -> Mapping[str, Any]: - """Return the full stored mapping (may be empty).""" + """返回完整的存储映射(可能为空)。""" @abstractmethod def save(self, name: str, value: Any) -> None: - """Persist a single task's successful result.""" + """持久化单个任务的成功结果。""" @abstractmethod def has(self, name: str) -> bool: - """Whether ``name`` has a stored result.""" + """``name`` 是否已有存储结果。""" @abstractmethod def get(self, name: str) -> Any: - """Return the stored result for ``name`` (raise ``KeyError`` if absent).""" + """返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。""" @abstractmethod def clear(self) -> None: - """Remove all stored state.""" + """清除所有存储状态。""" class MemoryBackend(StateBackend): - """In-process dict backend. Lost when the process exits.""" + """进程内 dict 后端。进程退出即丢失。""" def __init__(self) -> None: self._store: Dict[str, Any] = {} @@ -73,11 +71,11 @@ class MemoryBackend(StateBackend): class JSONBackend(StateBackend): - """File-backed JSON storage for cross-process resume. + """基于文件的 JSON 存储,用于跨进程续跑。 - Results must be JSON-serialisable. Non-serialisable values raise - :class:`~pyflowx.errors.StorageError` (the run itself is not aborted; - only persistence of that one result fails). + 结果必须可 JSON 序列化。不可序列化的值会抛出 + :class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条 + 结果的持久化失败)。 """ def __init__(self, path: str) -> None: @@ -109,7 +107,7 @@ class JSONBackend(StateBackend): return dict(self._store) def save(self, name: str, value: Any) -> None: - # Validate serialisability before mutating in-memory state. + # 在修改内存状态前先校验可序列化性。 try: json.dumps(value) except (TypeError, ValueError) as exc: @@ -131,5 +129,5 @@ class JSONBackend(StateBackend): def resolve_backend(backend: Optional[StateBackend]) -> StateBackend: - """Return ``backend`` or a fresh :class:`MemoryBackend` if ``None``.""" + """返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。""" return backend if backend is not None else MemoryBackend() diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index fa26344..983b5a6 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -1,20 +1,18 @@ -"""Core task data structures for PyFlowX. +"""PyFlowX 核心任务数据结构。 -Everything here is a plain, immutable data structure — no decorators, no -side effects. A :class:`TaskSpec` fully describes a task node; the -:class:`Graph` (see :mod:`pyflowx.graph`) consumes a list of specs and -builds the DAG. +本模块全部为纯不可变数据结构——无装饰器、无副作用。一个 +:class:`TaskSpec` 完整描述一个任务节点;:class:`Graph` +(见 :mod:`pyflowx.graph`)消费一组 spec 并构建 DAG。 -Design notes ------------- -* ``TaskSpec`` is a ``Generic[T]`` so that ``TaskSpec[int]`` carries the - return type of ``fn`` all the way to :class:`RunReport`, giving callers - typed access to ``report["name"]``. -* ``Context`` is the only intentionally-dynamic type: results from - upstream tasks are heterogeneous, so the cross-task mapping is - ``Mapping[str, Any]``. Within a single task the types remain fully - static because the function signature is checked by mypy. -* ``TaskStatus`` is a closed enum; executors never invent ad-hoc strings. +设计要点 +-------- +* ``TaskSpec`` 是 ``Generic[T]``,因此 ``TaskSpec[int]`` 会把 ``fn`` 的 + 返回类型一路传递到 :class:`RunReport`,让调用者可以类型安全地访问 + ``report["name"]``。 +* ``Context`` 是唯一刻意保留动态类型的类型:上游任务的结果异构,因此 + 跨任务映射为 ``Mapping[str, Any]``。单个任务内部类型仍然完全静态, + 因为函数签名由 mypy 检查。 +* ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。 """ from __future__ import annotations @@ -36,59 +34,55 @@ from typing import ( T = TypeVar("T") -# A task callable may be synchronous or asynchronous. We keep the union -# explicit so mypy understands both shapes. +# 任务可调用对象可以是同步或异步的。显式保留联合类型,让 mypy 理解两种形态。 TaskFn = Union[ Callable[..., T], Callable[..., Coroutine[Any, Any, T]], ] -# The cross-task result mapping. Deliberately ``Any`` for values because -# different tasks return different types; per-task typing is preserved by -# the function signature itself. +# 跨任务结果映射。值刻意使用 ``Any``,因为不同任务返回不同类型; +# 单任务类型由函数签名本身保留。 Context = Mapping[str, Any] class TaskStatus(Enum): - """Lifecycle states of a task during a single run.""" + """任务在单次运行内的生命周期状态。""" PENDING = "pending" RUNNING = "running" SUCCESS = "success" FAILED = "failed" - SKIPPED = "skipped" # used by resumable runs and subgraph filtering + SKIPPED = "skipped" # 用于断点续跑与子图过滤 @dataclass(frozen=True) class TaskSpec(Generic[T]): - """Immutable description of a single DAG node. + """单个 DAG 节点的不可变描述。 - Parameters - ---------- + 参数 + ---- name: - Unique identifier of the task within a graph. Other tasks reference - this name in ``depends_on``. + 任务在图内的唯一标识。其他任务通过 ``depends_on`` 引用此名称。 fn: - The callable to execute. May be sync or async. Its parameter names - drive automatic context injection (see :mod:`pyflowx.context`). + 待执行的可调用对象,可为同步或异步。其参数名驱动自动上下文 + 注入(见 :mod:`pyflowx.context`)。 depends_on: - Names of tasks whose results must be available before this task - runs. Order is irrelevant; the framework topologically sorts. + 必须先完成才能运行本任务的任务名列表。顺序无关;框架会做 + 拓扑排序。 args: - Static positional arguments appended *after* injected parameters. - Useful for parameterised tasks (e.g. ``fetch_user(uid)``). + 静态位置参数,追加在注入参数*之后*。适用于参数化任务 + (如 ``fetch_user(uid)``)。 kwargs: - Static keyword arguments. Conflict with injected names raises - :class:`~pyflowx.errors.InjectionError`. + 静态关键字参数。若与注入名冲突则抛出 + :class:`~pyflowx.errors.InjectionError`。 retries: - Number of retry attempts on failure. ``0`` means a single attempt. + 失败后的重试次数。``0`` 表示仅尝试一次。 timeout: - Maximum execution time in seconds. ``None`` disables the timeout. - For async tasks this uses :func:`asyncio.wait_for`; for sync tasks - in the threaded/async executors it cancels the worker future. + 最大执行时长(秒)。``None`` 表示不限制。异步任务使用 + :func:`asyncio.wait_for`;线程/异步执行器中的同步任务会 + 取消 worker future。 tags: - Free-form labels used by :meth:`Graph.subgraph` for selective - execution and debugging. + 自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试。 """ name: str @@ -113,10 +107,10 @@ class TaskSpec(Generic[T]): @dataclass class TaskResult(Generic[T]): - """Mutable per-task record produced during a run. + """运行期间产生的可变单任务记录。 - A fresh :class:`TaskResult` is created for every run; the spec itself - stays immutable. This keeps the same graph safely re-runnable. + 每次运行都会创建全新的 :class:`TaskResult`;spec 本身保持不可变。 + 这让同一个图可以安全地重复运行。 """ spec: TaskSpec[T] @@ -129,7 +123,7 @@ class TaskResult(Generic[T]): @property def duration(self) -> Optional[float]: - """Elapsed seconds between start and finish, or ``None``.""" + """从开始到结束的耗时(秒),未开始/未结束则为 ``None``。""" if self.started_at is None or self.finished_at is None: return None return (self.finished_at - self.started_at).total_seconds() @@ -137,11 +131,10 @@ class TaskResult(Generic[T]): @dataclass(frozen=True) class TaskEvent: - """Immutable event emitted during execution for observers. + """执行期间向观察者发出的不可变事件。 - Passed to the ``on_event`` callback of :func:`pyflowx.run` so callers - can build progress bars, metrics, or structured logs without coupling - to executor internals. + 传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合 + 执行器内部即可构建进度条、指标或结构化日志。 """ task: str diff --git a/tests/test_context.py b/tests/test_context.py index 675f341..4760801 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -7,7 +7,7 @@ from typing import Any import pytest import pyflowx as px -from pyflowx.context import build_call_args, describe_injection +from pyflowx.context import _is_context_annotation, build_call_args, describe_injection from pyflowx.errors import InjectionError @@ -87,3 +87,149 @@ def test_describe_injection() -> None: assert "a=" in desc assert "ctx=" in desc assert "flag=" in desc + + +# ---------------------------------------------------------------------- # +# _is_context_annotation 各分支 +# ---------------------------------------------------------------------- # +def test_is_context_annotation_direct_object() -> None: + """直接传入 Context 别名对象应返回 True。""" + assert _is_context_annotation(px.Context) is True + + +def test_is_context_annotation_string() -> None: + """字符串形式的注解应被识别。""" + assert _is_context_annotation("Context") is True + assert _is_context_annotation("px.Context") is True + assert _is_context_annotation("pyflowx.Context") is True + assert _is_context_annotation("NotContext") is False + assert _is_context_annotation("int") is False + + +def test_is_context_annotation_typing_alias() -> None: + """具有 __name__/_name 为 Context/Mapping 的 typing 别名应返回 True。""" + + class FakeAlias: + __name__ = "Context" + + assert _is_context_annotation(FakeAlias()) is True + + class FakeMapping: + __name__ = "Mapping" + + assert _is_context_annotation(FakeMapping()) is True + + +def test_is_context_annotation_other() -> None: + """其他类型注解应返回 False。""" + assert _is_context_annotation(int) is False + assert _is_context_annotation(str) is False + assert _is_context_annotation(None) is False + + +# ---------------------------------------------------------------------- # +# describe_injection 其余分支 +# ---------------------------------------------------------------------- # +def test_describe_injection_var_positional() -> None: + """*args 参数应显示为 *args。""" + + def fn(*args: Any) -> None: + return None + + spec = px.TaskSpec("t", fn) + desc = describe_injection(spec) + assert "*args" in desc + + +def test_describe_injection_var_keyword() -> None: + """**kwargs 参数应显示为 **kwargs=。""" + + def fn(**kwargs: Any) -> None: + return None + + spec = px.TaskSpec("t", fn, ("a",)) + desc = describe_injection(spec) + assert "**kwargs=" in desc + + +def test_describe_injection_unresolved() -> None: + """无依赖、无静态值、无默认的参数应显示为 。""" + + def fn(missing: int) -> None: + return None + + spec = px.TaskSpec("t", fn) + desc = describe_injection(spec) + assert "missing=" in desc + + +def test_describe_injection_static_kwargs() -> None: + """静态 kwargs 应显示具体值。""" + + def fn(flag: bool = False) -> None: + return None + + spec = px.TaskSpec("t", fn, kwargs={"flag": True}) + desc = describe_injection(spec) + assert "flag=True" in desc + + +def test_describe_injection_positional_args_filled() -> None: + """spec.args 填充的位置参数应显示具体值(覆盖 args_filled 分支)。""" + + def fn(a: int, b: str) -> None: + return None + + spec = px.TaskSpec("t", fn, args=(1, "x")) + desc = describe_injection(spec) + assert "a=1" in desc + assert "b='x'" in desc + + +# ---------------------------------------------------------------------- # +# build_call_args 边界 +# ---------------------------------------------------------------------- # +def test_build_call_args_var_positional_not_required() -> None: + """*args 参数不应触发 InjectionError。""" + + def fn(*args: Any) -> int: + return len(args) + + spec = px.TaskSpec("t", fn, args=(1, 2, 3)) + args, kwargs = build_call_args(spec, {}) + assert args == (1, 2, 3) + assert kwargs == {} + + +def test_build_call_args_var_keyword_consumes_leftover() -> None: + """**kwargs 应吞掉未被具名参数消费的依赖结果。""" + + def fn(a: int, **rest: Any) -> int: + return a + sum(rest.values()) + + spec = px.TaskSpec("t", fn, ("a", "b", "c")) + args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 3}) + assert kwargs == {"a": 1, "b": 2, "c": 3} + + +def test_build_call_args_no_var_keyword_drops_leftover() -> None: + """无 **kwargs 时,未被消费的依赖结果被丢弃(不报错)。""" + + def fn(a: int) -> int: + return a + + spec = px.TaskSpec("t", fn, ("a", "b")) + # b 是依赖但 fn 不接收它 —— 应正常工作 + args, kwargs = build_call_args(spec, {"a": 1, "b": 2}) + assert kwargs == {"a": 1} + + +def test_build_call_args_context_annotation_only_deps() -> None: + """Context 标注只接收该任务自身 depends_on 的结果。""" + + def fn(ctx: px.Context) -> int: + return len(ctx) + + spec = px.TaskSpec("t", fn, ("a", "b")) + args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99}) + assert kwargs == {"ctx": {"a": 1, "b": 2}} diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..d684d41 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,90 @@ +"""错误类型测试。""" + +from __future__ import annotations + +import pytest + +import pyflowx as px +from pyflowx.errors import ( + CycleError, + DuplicateTaskError, + InjectionError, + MissingDependencyError, + PyFlowXError, + StorageError, + TaskFailedError, + TaskTimeoutError, +) + + +def test_all_errors_are_pyflowx_subclass() -> None: + assert issubclass(DuplicateTaskError, PyFlowXError) + assert issubclass(MissingDependencyError, PyFlowXError) + assert issubclass(CycleError, PyFlowXError) + assert issubclass(TaskFailedError, PyFlowXError) + assert issubclass(TaskTimeoutError, PyFlowXError) + assert issubclass(InjectionError, PyFlowXError) + assert issubclass(StorageError, PyFlowXError) + + +def test_duplicate_task_error_attributes() -> None: + err = DuplicateTaskError("foo") + assert err.name == "foo" + assert "foo" in str(err) + + +def test_missing_dependency_error_attributes() -> None: + err = MissingDependencyError("child", "parent") + assert err.task == "child" + assert err.dependency == "parent" + assert "child" in str(err) + assert "parent" in str(err) + + +def test_cycle_error_attributes() -> None: + err = CycleError(["a", "b", "c"]) + assert err.cycle == ["a", "b", "c"] + # 链应首尾相接展示 + assert "a -> b -> c -> a" in str(err) + + +def test_task_failed_error_attributes() -> None: + cause = ValueError("boom") + err = TaskFailedError(task="t", cause=cause, attempts=3, layer=2) + assert err.task == "t" + assert err.cause is cause + assert err.attempts == 3 + assert err.layer == 2 + assert "layer 2" in str(err) + + +def test_task_failed_error_without_layer() -> None: + err = TaskFailedError(task="t", cause=RuntimeError("x"), attempts=1) + assert err.layer is None + assert "layer" not in str(err) + + +def test_task_timeout_error_attributes() -> None: + err = TaskTimeoutError(task="t", timeout=1.5) + assert err.task == "t" + assert err.timeout == 1.5 + assert "1.500s" in str(err) + + +def test_injection_error_attributes() -> None: + err = InjectionError(task="t", detail="missing param") + assert err.task == "t" + assert "missing param" in str(err) + + +def test_storage_error_with_cause() -> None: + cause = OSError("disk full") + err = StorageError(detail="write failed", cause=cause) + assert err.cause is cause + assert "write failed" in str(err) + + +def test_storage_error_without_cause() -> None: + err = StorageError(detail="bad") + assert err.cause is None + assert "bad" in str(err) diff --git a/tests/test_executors.py b/tests/test_executors.py index 0da3f9c..aa09fc6 100644 --- a/tests/test_executors.py +++ b/tests/test_executors.py @@ -320,3 +320,196 @@ def test_invalid_strategy() -> None: graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type] with pytest.raises(ValueError): px.run(graph, strategy="bogus") # type: ignore[arg-type] + + +# ---------------------------------------------------------------------- # +# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支 +# ---------------------------------------------------------------------- # +def test_async_sync_task_without_timeout() -> None: + """async 策略下执行 sync 任务且无 timeout(覆盖 line 131)。""" + + def sync_fn() -> int: + return 42 + + graph = px.Graph.from_specs([px.TaskSpec("a", sync_fn)]) + report = px.run(graph, strategy="async") + assert report.success + assert report["a"] == 42 + + +def test_async_sync_task_with_timeout() -> None: + """async 策略下执行 sync 任务且带 timeout(覆盖 line 129)。""" + + def sync_fn() -> int: + return 42 + + graph = px.Graph.from_specs([px.TaskSpec("a", sync_fn, timeout=5.0)]) + report = px.run(graph, strategy="async") + assert report.success + assert report["a"] == 42 + + +def test_async_timeout_retry_then_succeed() -> None: + """async 超时后重试成功(覆盖 line 141-151 的重试分支)。""" + calls = {"n": 0} + + async def flaky() -> str: + calls["n"] += 1 + if calls["n"] < 2: + await asyncio.sleep(10) # 触发超时 + return "ok" + + graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2, timeout=0.05)]) + report = px.run(graph, strategy="async") + assert report.success + assert report["a"] == "ok" + assert calls["n"] == 2 + + +def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None: + """async 普通异常重试分支(覆盖 line 141-151 的 except Exception 分支)。""" + calls = {"n": 0} + + async def flaky() -> str: + calls["n"] += 1 + if calls["n"] < 2: + raise RuntimeError("not yet") + return "ok" + + graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)]) + with caplog.at_level("WARNING", logger="pyflowx"): + report = px.run(graph, strategy="async") + assert report.success + assert report["a"] == "ok" + # 确认重试日志确实输出 + assert any("retrying" in r.message for r in caplog.records) + + +# ---------------------------------------------------------------------- # +# 缓存跳过分支:threaded 与 async +# ---------------------------------------------------------------------- # +def test_threaded_skips_cached_tasks() -> None: + """threaded 策略下命中缓存的任务应被跳过(覆盖 line 224-230)。""" + runs: List[str] = [] + + def make(name: str) -> Any: + def fn() -> str: + runs.append(name) + return name + + return fn + + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", make("a")), + px.TaskSpec("b", make("b"), ("a",)), + ] + ) + backend = px.MemoryBackend() + # 第一次运行填充缓存 + px.run(graph, strategy="thread", max_workers=2, state=backend) + assert runs == ["a", "b"] + # 第二次运行应全部跳过 + px.run(graph, strategy="thread", max_workers=2, state=backend) + assert runs == ["a", "b"] # 未再执行 + + +def test_threaded_all_cached_layer() -> None: + """整层全部命中缓存时应直接返回(覆盖 line 235 的 if not to_run: return)。""" + graph = px.Graph.from_specs([px.TaskSpec("a", lambda: 1)]) # type: ignore[arg-type] + backend = px.MemoryBackend() + backend.save("a", 99) + report = px.run(graph, strategy="thread", max_workers=2, state=backend) + assert report["a"] == 99 + assert report.result_of("a").status == px.TaskStatus.SKIPPED + + +def test_async_skips_cached_tasks() -> None: + """async 策略下命中缓存的任务应被跳过(覆盖 line 268-274)。""" + runs: List[str] = [] + + async def make(name: str) -> Any: + async def fn() -> str: + runs.append(name) + return name + + return fn() + + # 用闭包制造可重复调用的 async 函数 + async def a() -> str: + runs.append("a") + return "a" + + async def b(a: str) -> str: + runs.append("b") + return a + "b" + + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", a), + px.TaskSpec("b", b, ("a",)), + ] + ) + backend = px.MemoryBackend() + px.run(graph, strategy="async", state=backend) + assert runs == ["a", "b"] + px.run(graph, strategy="async", state=backend) + assert runs == ["a", "b"] + + +def test_async_all_cached_layer() -> None: + """async 整层全部命中缓存(覆盖 line 279 的 if not to_run: return)。""" + + async def a() -> int: + return 1 + + graph = px.Graph.from_specs([px.TaskSpec("a", a)]) + backend = px.MemoryBackend() + backend.save("a", 77) + report = px.run(graph, strategy="async", state=backend) + assert report["a"] == 77 + assert report.result_of("a").status == px.TaskStatus.SKIPPED + + +# ---------------------------------------------------------------------- # +# 失败后 report.success 标记为 False +# ---------------------------------------------------------------------- # +def test_failure_marks_report_unsuccessful() -> None: + def boom() -> None: + raise ValueError("fail") + + graph = px.Graph.from_specs([px.TaskSpec("a", boom)]) + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="sequential") + # report 在异常前未返回,但若捕获异常则 success 应为 False + # 这里验证 run() 抛异常的行为本身 + + +# ---------------------------------------------------------------------- # +# dry_run 各策略 +# ---------------------------------------------------------------------- # +def test_dry_run_thread(capsys: pytest.CaptureFixture[str]) -> None: + graph = px.Graph.from_specs([px.TaskSpec("a", lambda: 1)]) # type: ignore[arg-type] + report = px.run(graph, strategy="thread", dry_run=True) + assert len(report) == 0 + assert "Dry run" in capsys.readouterr().out + + +def test_dry_run_async(capsys: pytest.CaptureFixture[str]) -> None: + async def a() -> int: + return 1 + + graph = px.Graph.from_specs([px.TaskSpec("a", a)]) + report = px.run(graph, strategy="async", dry_run=True) + assert len(report) == 0 + assert "Dry run" in capsys.readouterr().out + + +# ---------------------------------------------------------------------- # +# 空图运行 +# ---------------------------------------------------------------------- # +def test_run_empty_graph() -> None: + graph = px.Graph() + report = px.run(graph, strategy="sequential") + assert report.success + assert len(report) == 0 diff --git a/tests/test_graph.py b/tests/test_graph.py index fd6294c..ad268c3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -13,11 +13,13 @@ def _fn() -> None: def test_from_specs_builds_graph() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, ("a",)), - px.TaskSpec("c", _fn, ("a", "b")), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, ("a",)), + px.TaskSpec("c", _fn, ("a", "b")), + ] + ) assert set(graph.names) == {"a", "b", "c"} assert graph.dependencies("c") == ("a", "b") assert len(graph) == 3 @@ -26,19 +28,23 @@ def test_from_specs_builds_graph() -> None: def test_from_specs_allows_forward_references() -> None: # b depends on a, but a is declared after b — order should not matter. - graph = px.Graph.from_specs([ - px.TaskSpec("b", _fn, ("a",)), - px.TaskSpec("a", _fn), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("b", _fn, ("a",)), + px.TaskSpec("a", _fn), + ] + ) assert graph.layers() == [["a"], ["b"]] def test_duplicate_task_raises() -> None: with pytest.raises(DuplicateTaskError): - px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("a", _fn), - ]) + px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("a", _fn), + ] + ) def test_missing_dependency_raises() -> None: @@ -50,20 +56,24 @@ def test_missing_dependency_raises() -> None: def test_cycle_detection() -> None: with pytest.raises(CycleError): - px.Graph.from_specs([ - px.TaskSpec("a", _fn, ("c",)), - px.TaskSpec("b", _fn, ("a",)), - px.TaskSpec("c", _fn, ("b",)), - ]) + px.Graph.from_specs( + [ + px.TaskSpec("a", _fn, ("c",)), + px.TaskSpec("b", _fn, ("a",)), + px.TaskSpec("c", _fn, ("b",)), + ] + ) def test_layers_grouping() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn), - px.TaskSpec("c", _fn, ("a", "b")), - px.TaskSpec("d", _fn, ("c",)), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn), + px.TaskSpec("c", _fn, ("a", "b")), + px.TaskSpec("d", _fn, ("c",)), + ] + ) layers = graph.layers() assert layers == [["a", "b"], ["c"], ["d"]] @@ -74,10 +84,12 @@ def test_self_dependency_rejected() -> None: def test_to_mermaid() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, ("a",)), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, ("a",)), + ] + ) mermaid = graph.to_mermaid() assert mermaid.startswith("graph TD") assert 'a["a"]' in mermaid @@ -91,11 +103,13 @@ def test_to_mermaid_invalid_orientation() -> None: def test_subgraph_by_tags() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn, tags=("ingest",)), - px.TaskSpec("b", _fn, ("a",), tags=("ingest",)), - px.TaskSpec("c", _fn, ("b",), tags=("report",)), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn, tags=("ingest",)), + px.TaskSpec("b", _fn, ("a",), tags=("ingest",)), + px.TaskSpec("c", _fn, ("b",), tags=("report",)), + ] + ) sub = graph.subgraph(["ingest"]) assert set(sub.names) == {"a", "b"} # Edge to dropped task c is removed; b no longer waits for anything @@ -104,11 +118,13 @@ def test_subgraph_by_tags() -> None: def test_subgraph_by_names() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, ("a",)), - px.TaskSpec("c", _fn, ("b",)), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, ("a",)), + px.TaskSpec("c", _fn, ("b",)), + ] + ) sub = graph.subgraph_by_names(["a", "b"]) assert set(sub.names) == {"a", "b"} # c is dropped, so b's dep on c (none here) — but a->b edge preserved. @@ -122,10 +138,93 @@ def test_subgraph_by_names_unknown() -> None: def test_describe() -> None: - graph = px.Graph.from_specs([ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, ("a",)), - ]) + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, ("a",)), + ] + ) desc = graph.describe() assert "Layer 1" in desc assert "Layer 2" in desc + + +# ---------------------------------------------------------------------- # +# 增量 add API 与其他访问器 +# ---------------------------------------------------------------------- # +def test_add_chains_and_validates() -> None: + """add() 应返回 self 以支持链式调用,并即时校验。""" + graph = px.Graph() + ret = graph.add(px.TaskSpec("a", _fn)) + assert ret is graph + assert "a" in graph + # 缺失依赖应即时报错 + with pytest.raises(MissingDependencyError): + graph.add(px.TaskSpec("b", _fn, ("missing",))) + + +def test_add_duplicate_raises() -> None: + graph = px.Graph() + graph.add(px.TaskSpec("a", _fn)) + with pytest.raises(DuplicateTaskError): + graph.add(px.TaskSpec("a", _fn)) + + +def test_all_specs_returns_view() -> None: + graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) + view = graph.all_specs() + assert set(view.keys()) == {"a"} + # 返回的是只读视图,修改不影响内部 + assert view is graph.all_specs() or view == graph.all_specs() + + +def test_spec_accessor() -> None: + graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) + assert graph.spec("a").name == "a" + with pytest.raises(KeyError): + graph.spec("missing") + + +def test_dependencies_accessor() -> None: + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, ("a",)), + ] + ) + assert graph.dependencies("a") == () + assert graph.dependencies("b") == ("a",) + + +def test_repr() -> None: + graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) + assert repr(graph) == "Graph(tasks=1)" + + +def test_empty_graph_layers() -> None: + """空图的 layers() 应返回空列表。""" + graph = px.Graph() + assert graph.layers() == [] + assert graph.to_mermaid() == "graph TD\n" + + +def test_subgraph_preserves_metadata() -> None: + """子图应保留原任务的 retries/timeout/tags 等元数据。""" + graph = px.Graph.from_specs( + [ + px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0), + px.TaskSpec("b", _fn, ("a",), tags=("y",)), + ] + ) + sub = graph.subgraph(["x"]) + spec = sub.spec("a") + assert spec.retries == 3 + assert spec.timeout == 5.0 + assert spec.tags == ("x",) + + +def test_subgraph_by_tags_no_match() -> None: + """无匹配 tag 时返回空图。""" + graph = px.Graph.from_specs([px.TaskSpec("a", _fn, tags=("x",))]) + sub = graph.subgraph(["z"]) + assert len(sub) == 0 diff --git a/tests/test_report.py b/tests/test_report.py new file mode 100644 index 0000000..a4dc3d6 --- /dev/null +++ b/tests/test_report.py @@ -0,0 +1,121 @@ +"""RunReport 测试。""" + +from __future__ import annotations + +from datetime import datetime + +import pyflowx as px +from pyflowx.task import TaskResult, TaskSpec, TaskStatus + + +def _fn() -> int: + return 1 + + +def _make_result( + name: str = "a", + status: TaskStatus = TaskStatus.SUCCESS, + value: object = 42, + error: object = None, + duration: float = 0.5, + attempts: int = 1, +) -> TaskResult[object]: + spec: TaskSpec[object] = TaskSpec(name, _fn) # type: ignore[arg-type] + start = datetime(2024, 1, 1, 0, 0, 0) + # 用 timedelta 精确表达秒数,避免 int() 截断小数 + from datetime import timedelta + + end = start + timedelta(seconds=duration) if duration else None + return TaskResult( + spec=spec, + status=status, + value=value, # type: ignore[arg-type] + error=error, # type: ignore[arg-type] + attempts=attempts, + started_at=start, + finished_at=end, + ) + + +def test_getitem_returns_value() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a", value=7) + assert report["a"] == 7 + + +def test_result_of_returns_full_result() -> None: + report = px.RunReport() + r = _make_result("a") + report.results["a"] = r + assert report.result_of("a") is r + + +def test_contains() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a") + assert "a" in report + assert "b" not in report + + +def test_iter_and_len() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a") + report.results["b"] = _make_result("b") + assert list(report) == ["a", "b"] + assert len(report) == 2 + + +def test_summary_success() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=1.0) + report.results["b"] = _make_result("b", status=TaskStatus.SKIPPED, duration=0.0) + s = report.summary() + assert s["success"] is True + assert s["total_tasks"] == 2 + assert s["by_status"] == {"success": 1, "skipped": 1} + assert s["total_duration_seconds"] == 1.0 + + +def test_summary_with_none_duration() -> None: + """未开始/未结束的任务 duration 为 None,不应计入总时长。""" + report = px.RunReport() + spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type] + report.results["a"] = TaskResult(spec=spec, status=TaskStatus.FAILED) + s = report.summary() + assert s["total_duration_seconds"] == 0.0 + + +def test_failed_tasks() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS) + report.results["b"] = _make_result( + "b", status=TaskStatus.FAILED, error=ValueError("x") + ) + assert report.failed_tasks() == ["b"] + + +def test_describe_success() -> None: + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5) + desc = report.describe() + assert "RunReport(success=True)" in desc + assert "a: success" in desc + assert "0.500s" in desc + + +def test_describe_with_error() -> None: + report = px.RunReport(success=False) + report.results["a"] = _make_result( + "a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1 + ) + desc = report.describe() + assert "success=False" in desc + assert "error=ValueError" in desc + + +def test_describe_no_duration() -> None: + report = px.RunReport() + spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type] + report.results["a"] = TaskResult(spec=spec, status=TaskStatus.PENDING) + desc = report.describe() + assert "-" in desc # duration 显示为 "-" diff --git a/tests/test_storage.py b/tests/test_storage.py new file mode 100644 index 0000000..9506912 --- /dev/null +++ b/tests/test_storage.py @@ -0,0 +1,162 @@ +"""状态后端测试。""" + +from __future__ import annotations + +import json +import os +import tempfile +from typing import Any + +import pytest + +from pyflowx.errors import StorageError +from pyflowx.storage import JSONBackend, MemoryBackend, StateBackend, resolve_backend + + +# ---------------------------------------------------------------------- # +# MemoryBackend +# ---------------------------------------------------------------------- # +def test_memory_backend_lifecycle() -> None: + b = MemoryBackend() + assert not b.has("a") + b.save("a", 1) + assert b.has("a") + assert b.get("a") == 1 + assert dict(b.load()) == {"a": 1} + b.clear() + assert not b.has("a") + assert dict(b.load()) == {} + + +def test_memory_backend_get_missing_raises() -> None: + b = MemoryBackend() + with pytest.raises(KeyError): + b.get("nope") + + +# ---------------------------------------------------------------------- # +# JSONBackend +# ---------------------------------------------------------------------- # +def test_json_backend_save_and_load() -> None: + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + b = JSONBackend(path) + b.save("a", {"x": 1}) + b.save("b", [1, 2, 3]) + # 重新打开应读到已保存内容 + b2 = JSONBackend(path) + assert b2.has("a") + assert b2.get("a") == {"x": 1} + assert b2.get("b") == [1, 2, 3] + assert dict(b2.load()) == {"a": {"x": 1}, "b": [1, 2, 3]} + + +def test_json_backend_clear() -> None: + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + b = JSONBackend(path) + b.save("a", 1) + b.clear() + assert not b.has("a") + # 文件应被写入空 dict + with open(path, "r", encoding="utf-8") as fh: + assert json.load(fh) == {} + + +def test_json_backend_nonexistent_file_starts_empty() -> None: + """文件不存在时应正常初始化为空。""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "absent.json") + b = JSONBackend(path) + assert dict(b.load()) == {} + assert not b.has("anything") + + +def test_json_backend_non_serialisable_raises() -> None: + """不可 JSON 序列化的值应抛 StorageError,且不污染内存状态。""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + b = JSONBackend(path) + with pytest.raises(StorageError): + b.save("a", object()) # object() 不可序列化 + assert not b.has("a") + + +def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None: + """_flush 时 json.dump 抛 TypeError 应转为 StorageError(覆盖 line 105-106)。 + + 通过 monkeypatch 让 json.dump 在写入文件时抛 TypeError,模拟值通过 + save 的 dumps 校验但在 dump 到文件句柄时失败(如自定义对象的边缘情况)。 + """ + import json as _json + + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + b = JSONBackend(path) + + original_dump = _json.dump + + def flaky_dump(*args: Any, **kwargs: Any) -> None: + raise TypeError("simulated flush failure") + + monkeypatch.setattr(_json, "dump", flaky_dump) + with pytest.raises(StorageError, match="cannot write"): + b.save("a", 1) + # 恢复以便后续测试不受影响 + monkeypatch.setattr(_json, "dump", original_dump) + + +def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None: + """_flush 时 OSError 应转为 StorageError。""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + b = JSONBackend(path) + + original_replace = os.replace + + def fail_replace(*args: Any, **kwargs: Any) -> None: + raise OSError("simulated os.replace failure") + + monkeypatch.setattr(os, "replace", fail_replace) + with pytest.raises(StorageError, match="cannot write"): + b.save("a", 1) + monkeypatch.setattr(os, "replace", original_replace) + + +def test_json_backend_corrupt_file_raises() -> None: + """损坏的 JSON 文件应抛 StorageError。""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + with open(path, "w", encoding="utf-8") as fh: + fh.write("{not valid json") + with pytest.raises(StorageError): + JSONBackend(path) + + +def test_json_backend_non_dict_content_ignored() -> None: + """文件内容是合法 JSON 但非 dict 时应被忽略(保持空)。""" + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "state.json") + with open(path, "w", encoding="utf-8") as fh: + json.dump([1, 2, 3], fh) # list 而非 dict + b = JSONBackend(path) + assert dict(b.load()) == {} + + +# ---------------------------------------------------------------------- # +# resolve_backend +# ---------------------------------------------------------------------- # +def test_resolve_backend_returns_input() -> None: + b = MemoryBackend() + assert resolve_backend(b) is b + + +def test_resolve_backend_creates_memory_when_none() -> None: + b = resolve_backend(None) + assert isinstance(b, MemoryBackend) + + +def test_state_backend_is_abstract() -> None: + """StateBackend 是 ABC,不能直接实例化。""" + with pytest.raises(TypeError): + StateBackend() # type: ignore[abstract] diff --git a/tests/test_task.py b/tests/test_task.py new file mode 100644 index 0000000..5a38878 --- /dev/null +++ b/tests/test_task.py @@ -0,0 +1,66 @@ +"""TaskSpec / TaskResult 数据结构测试。""" + +from __future__ import annotations + +from datetime import datetime + +import pytest + +import pyflowx as px +from pyflowx.task import TaskResult, TaskSpec, TaskStatus + + +def _fn() -> None: + return None + + +def test_spec_empty_name_rejected() -> None: + with pytest.raises(ValueError, match="non-empty"): + TaskSpec("", _fn) + + +def test_spec_negative_retries_rejected() -> None: + with pytest.raises(ValueError, match="retries"): + TaskSpec("a", _fn, retries=-1) + + +def test_spec_zero_timeout_rejected() -> None: + with pytest.raises(ValueError, match="timeout"): + TaskSpec("a", _fn, timeout=0) + + +def test_spec_self_dependency_rejected() -> None: + with pytest.raises(ValueError, match="depend on itself"): + TaskSpec("a", _fn, depends_on=("a",)) + + +def test_task_result_duration_none_when_not_started() -> None: + spec: TaskSpec[None] = TaskSpec("a", _fn) + result: TaskResult[None] = TaskResult(spec=spec) + assert result.duration is None + + +def test_task_result_duration_when_partial() -> None: + spec: TaskSpec[None] = TaskSpec("a", _fn) + result: TaskResult[None] = TaskResult(spec=spec, started_at=datetime.now()) + # started_at 已设但 finished_at 未设 -> None + assert result.duration is None + + +def test_task_result_duration_computed() -> None: + spec: TaskSpec[None] = TaskSpec("a", _fn) + start = datetime(2024, 1, 1, 0, 0, 0) + end = datetime(2024, 1, 1, 0, 0, 5) + result: TaskResult[None] = TaskResult( + spec=spec, started_at=start, finished_at=end + ) + assert result.duration == 5.0 + + +def test_task_result_default_status() -> None: + spec: TaskSpec[None] = TaskSpec("a", _fn) + result: TaskResult[None] = TaskResult(spec=spec) + assert result.status == TaskStatus.PENDING + assert result.value is None + assert result.error is None + assert result.attempts == 0