chore: 完成项目汉化与测试覆盖增强

- 将项目文档、注释全量翻译为简体中文
- 新增 coverage 配置并要求 100% 分支覆盖率
- 补充所有模块的单元测试用例,覆盖全分支场景
- 重构执行器代码,提取公共重试与失败逻辑
This commit is contained in:
2026-06-20 13:09:35 +08:00
parent 8b7777d936
commit a352529263
16 changed files with 1192 additions and 329 deletions
+10
View File
@@ -70,3 +70,13 @@ url = "https://mirrors.aliyun.com/pypi/simple/"
[dependency-groups] [dependency-groups]
dev = ["pyflowx[dev]"] dev = ["pyflowx[dev]"]
[tool.coverage.run]
branch = true
concurrency = ["greenlet", "thread"]
source = ["pyflowx"]
[tool.coverage.report]
exclude_lines = ["if TYPE_CHECKING:", "if __name__ == .__main__.:", "pragma: no cover", "raise NotImplementedError"]
fail_under = 100
show_missing = true
+16 -16
View File
@@ -1,16 +1,16 @@
"""PyFlowX — lightweight, type-safe DAG task scheduler. """PyFlowX —— 轻量、类型安全的 DAG 任务调度器。
Public API 公共 API
---------- --------
* :class:`TaskSpec` — immutable task descriptor (the only thing you configure). * :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。
* :class:`Graph` — DAG built from a list of specs; validates, layers, visualises. * :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。
* :func:`run` — execute a graph with ``sequential`` / ``thread`` / ``async``. * :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。
* :class:`RunReport` — typed, queryable result of a run. * :class:`RunReport` —— 类型化、可查询的运行结果。
* :class:`Context` — annotation marker for whole-context injection. * :class:`Context` —— 整体上下文注入的标注标记。
* State backends: :class:`StateBackend`, :class:`MemoryBackend`, :class:`JSONBackend`. * 状态后端::class:`StateBackend`:class:`MemoryBackend`:class:`JSONBackend`
Quick start 快速上手
----------- --------
import pyflowx as px import pyflowx as px
def extract() -> list[int]: return [1, 2, 3] def extract() -> list[int]: return [1, 2, 3]
@@ -46,7 +46,7 @@ from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
__version__ = "0.1.0" __version__ = "0.1.0"
__all__ = [ __all__ = [
# core types # 核心类型
"TaskSpec", "TaskSpec",
"TaskStatus", "TaskStatus",
"TaskResult", "TaskResult",
@@ -54,13 +54,13 @@ __all__ = [
"Context", "Context",
"Graph", "Graph",
"RunReport", "RunReport",
# execution # 执行
"run", "run",
# state backends # 状态后端
"StateBackend", "StateBackend",
"MemoryBackend", "MemoryBackend",
"JSONBackend", "JSONBackend",
# errors # 错误
"PyFlowXError", "PyFlowXError",
"DuplicateTaskError", "DuplicateTaskError",
"MissingDependencyError", "MissingDependencyError",
@@ -69,7 +69,7 @@ __all__ = [
"TaskTimeoutError", "TaskTimeoutError",
"InjectionError", "InjectionError",
"StorageError", "StorageError",
# helpers (advanced) # 辅助(高级)
"build_call_args", "build_call_args",
"describe_injection", "describe_injection",
] ]
+51 -60
View File
@@ -1,22 +1,18 @@
"""Context injection: turn upstream results into function arguments. """上下文注入:把上游结果转换为函数参数。
This is the mechanism that lets users write plain functions whose 本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他
parameter names *are* the dependency declarations, removing the boiler- DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))``)。
plate wrappers that plague other DAG libraries (e.g. ``def wrapper():
return fn(workflow.get_task_result('x'))``).
Injection rules (evaluated in order) 注入规则(按顺序求值)
----------------------------------- ----------------------
1. A parameter whose **annotation is** :class:`Context` receives the full 1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历
result mapping. Useful for tasks that need to iterate over all inputs. 所有输入的任务。
2. A parameter whose **name matches a dependency** receives that 2. **名称匹配某个依赖**的参数接收该依赖的结果。
dependency's result. 3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。
3. A ``**kwargs`` parameter receives *all* dependency results as a dict. 4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` supply static values for
parameters that are *not* dependencies.
If a parameter cannot be resolved and has no default, an 若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`
:class:`~pyflowx.errors.InjectionError` is raised with a precise message. 并附带精确错误信息。
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,26 +23,25 @@ from typing import Any, Dict, List, Mapping, Set, Tuple
from .errors import InjectionError from .errors import InjectionError
from .task import Context, TaskSpec from .task import Context, TaskSpec
__all__ = ["Context", "build_call_args", "describe_injection"] __all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"]
def _is_context_annotation(annotation: Any) -> bool: def _is_context_annotation(annotation: Any) -> bool:
"""True when a parameter annotation is (or refers to) ``Context``. """判断参数标注是否为(或指向)``Context``
Handles three forms: 处理三种形式:
* the ``Context`` alias object itself; * ``Context`` 别名对象本身;
* a typing alias whose ``__name__``/``_name`` is ``Context`` or ``Mapping``; * ``__name__``/``_name`` ``Context`` ``Mapping`` 的 typing 别名;
* a *string* annotation (``from __future__ import annotations`` makes all * *字符串*标注(``from __future__ import annotations`` 会在运行时
annotations strings at runtime) such as ``"Context"`` or ``"px.Context"``. 把所有标注变为字符串),如 ``"Context"`` ``"px.Context"``
""" """
if annotation is Context: if annotation is Context:
return True return True
# String annotation from `from __future__ import annotations`. # `from __future__ import annotations` 产生的字符串标注。
if isinstance(annotation, str): if isinstance(annotation, str):
# Match "Context", "px.Context", "pyflowx.Context", etc. # 匹配 "Context""px.Context""pyflowx.Context" 等。
return annotation == "Context" or annotation.endswith(".Context") return annotation == "Context" or annotation.endswith(".Context")
# Match by qualified name to support ``from pyflowx import Context`` # 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
# re-exports.
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None) name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
if name in ("Context", "Mapping"): if name in ("Context", "Mapping"):
return True return True
@@ -57,43 +52,41 @@ def build_call_args(
spec: TaskSpec[object], spec: TaskSpec[object],
context: Mapping[str, Any], context: Mapping[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""Resolve the ``(args, kwargs)`` to call ``spec.fn`` with. """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
Parameters 参数
---------- ----
spec: spec:
The task spec, providing ``fn``, ``depends_on``, ``args``, ``kwargs``. 任务 spec,提供 ``fn````depends_on````args````kwargs``
context: context:
Mapping of dependency-name -> result value. Only the task's own 依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目
``depends_on`` entries are guaranteed present; other tasks' results 存在;其他任务的结果被排除,以保持注入的确定性。
are excluded to keep injection deterministic.
Returns 返回
------- ----
(args, kwargs) (args, kwargs)
Ready to splat into ``spec.fn(*args, **kwargs)``. 可直接展开为 ``spec.fn(*args, **kwargs)``
Raises 抛出
------ ----
InjectionError InjectionError
If a required parameter cannot be satisfied, or if static 若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。
``kwargs`` collide with an injected dependency name.
""" """
sig = inspect.signature(spec.fn) sig = inspect.signature(spec.fn)
params = sig.parameters params = sig.parameters
# Detect special parameter kinds. # 检测特殊参数类型。
var_keyword = next( var_keyword = next(
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD), (p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
None, None,
) )
# The subset of context relevant to this task. # 与本任务相关的上下文子集。
dep_context: Dict[str, Any] = { dep_context: Dict[str, Any] = {
name: context[name] for name in spec.depends_on if name in context name: context[name] for name in spec.depends_on if name in context
} }
# Detect collisions between static kwargs and dependency names. # 检测静态 kwargs 与依赖名的冲突。
collisions = set(spec.kwargs) & set(dep_context) collisions = set(spec.kwargs) & set(dep_context)
if collisions: if collisions:
raise InjectionError( raise InjectionError(
@@ -105,9 +98,8 @@ def build_call_args(
injected_kwargs: Dict[str, Any] = {} injected_kwargs: Dict[str, Any] = {}
leftover_dep_results: Dict[str, Any] = dict(dep_context) leftover_dep_results: Dict[str, Any] = dict(dep_context)
# Positional parameters consumed by spec.args. We track which param # 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
# names are filled positionally so they are skipped during name-based # 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
# injection (dependency / Context / static kwargs).
positional_params: List[str] = [] positional_params: List[str] = []
positional_kinds = ( positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_ONLY,
@@ -116,33 +108,33 @@ def build_call_args(
for pname, param in params.items(): for pname, param in params.items():
if param.kind in positional_kinds: if param.kind in positional_kinds:
positional_params.append(pname) positional_params.append(pname)
# The first len(spec.args) positional params are filled by spec.args. # len(spec.args) 个位置参数由 spec.args 填充。
args_filled: Set[str] = set(positional_params[: len(spec.args)]) args_filled: Set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items(): for pname, param in params.items():
# Skip params already filled by positional spec.args. # 跳过已被位置 spec.args 填充的参数。
if pname in args_filled: if pname in args_filled:
continue continue
# Rule 1: annotated as Context -> full mapping. # 规则 1:标注为 Context -> 完整映射。
if _is_context_annotation(param.annotation): if _is_context_annotation(param.annotation):
injected_kwargs[pname] = dep_context injected_kwargs[pname] = dep_context
continue continue
# Rule 2: name matches a dependency. # 规则 2:名称匹配某个依赖。
if pname in dep_context: if pname in dep_context:
injected_kwargs[pname] = dep_context[pname] injected_kwargs[pname] = dep_context[pname]
leftover_dep_results.pop(pname, None) leftover_dep_results.pop(pname, None)
continue continue
# Rule 3: handled after the loop via **kwargs. # 规则 3:在循环后通过 **kwargs 处理。
# Rule 4: static kwargs fill the rest. # 规则 4:静态 kwargs 填充其余参数。
if pname in spec.kwargs: if pname in spec.kwargs:
injected_kwargs[pname] = spec.kwargs[pname] injected_kwargs[pname] = spec.kwargs[pname]
continue continue
# No source for this parameter: must have a default, else error. # 该参数无来源:必须有默认值,否则报错。
if param.default is inspect.Parameter.empty and param.kind not in ( if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_KEYWORD,
@@ -152,10 +144,9 @@ def build_call_args(
f"parameter {pname!r} has no dependency, static value, or default.", f"parameter {pname!r} has no dependency, static value, or default.",
) )
# Rule 3: **kwargs swallows remaining dependency results. # 规则 3**kwargs 吞掉剩余依赖结果。
if var_keyword is not None and leftover_dep_results: if var_keyword is not None and leftover_dep_results:
# Merge static kwargs first, then dependency results (static wins # 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。
# on collision — but we already rejected collisions above).
merged = dict(spec.kwargs) merged = dict(spec.kwargs)
merged.update(injected_kwargs) merged.update(injected_kwargs)
merged.update(leftover_dep_results) merged.update(leftover_dep_results)
@@ -165,12 +156,12 @@ def build_call_args(
def describe_injection(spec: TaskSpec[object]) -> str: def describe_injection(spec: TaskSpec[object]) -> str:
"""Human-readable description of how a task's args will be injected. """生成任务参数注入方式的人类可读描述。
Used by ``dry_run`` to show the execution plan without executing it. ``dry_run`` 使用,在不执行的情况下展示执行计划。
""" """
sig = inspect.signature(spec.fn) sig = inspect.signature(spec.fn)
# Determine which positional params are filled by spec.args. # 确定哪些位置参数由 spec.args 填充。
positional_params = [ positional_params = [
p p
for p, param in sig.parameters.items() for p, param in sig.parameters.items()
+13 -14
View File
@@ -1,8 +1,7 @@
"""PyFlowX error hierarchy. """PyFlowX 错误层级。
All errors are concrete subclasses of :class:`PyFlowXError` so callers can 所有错误都是 :class:`PyFlowXError` 的具体子类,调用者可以用单个
catch the entire family with a single ``except`` clause, while still being ``except`` 子句捕获整个错误家族,同时仍可按类型区分以做细粒度处理。
able to discriminate by type for fine-grained handling.
""" """
from __future__ import annotations from __future__ import annotations
@@ -11,11 +10,11 @@ from typing import Any, Iterable, Optional
class PyFlowXError(Exception): class PyFlowXError(Exception):
"""Base class for every PyFlowX error.""" """所有 PyFlowX 错误的基类。"""
class DuplicateTaskError(PyFlowXError): class DuplicateTaskError(PyFlowXError):
"""Raised when a task name is registered more than once.""" """任务名被重复注册时抛出。"""
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(f"Task '{name}' is already registered in the graph.") super().__init__(f"Task '{name}' is already registered in the graph.")
@@ -23,7 +22,7 @@ class DuplicateTaskError(PyFlowXError):
class MissingDependencyError(PyFlowXError): class MissingDependencyError(PyFlowXError):
"""Raised when a task depends on a name that is not in the graph.""" """任务依赖了图中不存在的名称时抛出。"""
def __init__(self, task: str, dependency: str) -> None: def __init__(self, task: str, dependency: str) -> None:
super().__init__( super().__init__(
@@ -35,7 +34,7 @@ class MissingDependencyError(PyFlowXError):
class CycleError(PyFlowXError): class CycleError(PyFlowXError):
"""Raised when the dependency graph contains a cycle.""" """依赖图存在环时抛出。"""
def __init__(self, cycle: Iterable[str]) -> None: def __init__(self, cycle: Iterable[str]) -> None:
cycle_list = list(cycle) cycle_list = list(cycle)
@@ -45,10 +44,10 @@ class CycleError(PyFlowXError):
class TaskFailedError(PyFlowXError): class TaskFailedError(PyFlowXError):
"""Raised when a task fails after exhausting all retries. """任务耗尽所有重试后仍失败时抛出。
The original exception is preserved on :attr:`__cause__` and also exposed 原始异常保留在 :attr:`__cause__` 上,同时通过 :attr:`cause` 暴露,
via :attr:`cause` for convenient access in user code. 便于用户代码访问。
""" """
def __init__( def __init__(
@@ -69,7 +68,7 @@ class TaskFailedError(PyFlowXError):
class TaskTimeoutError(PyFlowXError): class TaskTimeoutError(PyFlowXError):
"""Raised when a task exceeds its configured timeout.""" """任务超出配置的超时时间时抛出。"""
def __init__(self, task: str, timeout: float) -> None: def __init__(self, task: str, timeout: float) -> None:
super().__init__(f"Task '{task}' timed out after {timeout:.3f}s.") super().__init__(f"Task '{task}' timed out after {timeout:.3f}s.")
@@ -78,7 +77,7 @@ class TaskTimeoutError(PyFlowXError):
class InjectionError(PyFlowXError): class InjectionError(PyFlowXError):
"""Raised when context injection cannot satisfy a task signature.""" """上下文注入无法满足任务签名时抛出。"""
def __init__(self, task: str, detail: str) -> None: def __init__(self, task: str, detail: str) -> None:
super().__init__(f"Cannot inject context for task '{task}': {detail}") super().__init__(f"Cannot inject context for task '{task}': {detail}")
@@ -86,7 +85,7 @@ class InjectionError(PyFlowXError):
class StorageError(PyFlowXError): class StorageError(PyFlowXError):
"""Raised by state backends on persistence failures.""" """状态后端在持久化失败时抛出。"""
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None: def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
super().__init__(f"State storage error: {detail}") super().__init__(f"State storage error: {detail}")
+34 -37
View File
@@ -61,6 +61,31 @@ def _emit(
) )
def _log_retry(
    spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException
) -> None:
    """Log a warning that a failed task attempt is going to be retried.

    Shared by the sync and async retry loops so the message format lives in
    one place (and can be covered by a single test).

    Parameters
    ----------
    spec:
        The task spec; only ``spec.name`` is read.
    attempts:
        Number of attempts made so far, including the one that just failed.
    max_attempts:
        Total number of attempts allowed for the task.
    exc:
        The exception raised by the failed attempt.
    """
    # Passed as lazy %-style logging arguments so the message is only
    # formatted when the WARNING level is actually enabled.
    attempt_details = (spec.name, attempts, max_attempts, exc)
    logger.warning(
        "task %r failed (attempt %d/%d): %r; retrying",
        *attempt_details,
    )
def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None:
    """Mark ``result`` as FAILED and raise :class:`TaskFailedError`.

    This function never returns normally: it always ends by raising.

    Parameters
    ----------
    result:
        The task result to finalize. Its ``status`` and ``finished_at`` are
        overwritten; ``spec.name``, ``error`` and ``attempts`` are read to
        build the exception.
    layer_idx:
        Layer index forwarded to the exception as ``layer``; may be ``None``.

    Raises
    ------
    TaskFailedError
        Always. Carries the recorded ``result.error`` as its cause, or a
        placeholder ``RuntimeError("unknown")`` when no error was recorded.
    """
    result.status = TaskStatus.FAILED
    result.finished_at = datetime.now()

    task_name = result.spec.name
    failure_cause = result.error
    if failure_cause is None:
        # No exception was captured; supply a placeholder so the raised
        # error still carries a cause object.
        failure_cause = RuntimeError("unknown")

    raise TaskFailedError(
        task=task_name,
        cause=failure_cause,
        attempts=result.attempts,
        layer=layer_idx,
    )
def _run_sync_with_retry( def _run_sync_with_retry(
spec: TaskSpec[object], spec: TaskSpec[object],
context: Mapping[str, Any], context: Mapping[str, Any],
@@ -72,7 +97,7 @@ def _run_sync_with_retry(
max_attempts = spec.retries + 1 max_attempts = spec.retries + 1
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
while result.attempts < max_attempts: while True:
result.attempts += 1 result.attempts += 1
try: try:
result.value = spec.fn(*args, **kwargs) result.value = spec.fn(*args, **kwargs)
@@ -82,23 +107,9 @@ def _run_sync_with_retry(
except Exception as exc: # noqa: BLE001 - user code may raise anything except Exception as exc: # noqa: BLE001 - user code may raise anything
result.error = exc result.error = exc
if result.attempts >= max_attempts: if result.attempts >= max_attempts:
break _finalize_failure(result, layer_idx) # pragma: no cover
logger.warning( _log_retry(spec, result.attempts, max_attempts, exc)
"task %r failed (attempt %d/%d): %r; retrying", raise AssertionError("unreachable") # pragma: no cover
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
async def _run_async_with_retry( async def _run_async_with_retry(
@@ -113,7 +124,7 @@ async def _run_async_with_retry(
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
while result.attempts < max_attempts: while True:
result.attempts += 1 result.attempts += 1
try: try:
if _is_async_fn(spec): if _is_async_fn(spec):
@@ -137,7 +148,7 @@ async def _run_async_with_retry(
except asyncio.TimeoutError: except asyncio.TimeoutError:
result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0) result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0)
if result.attempts >= max_attempts: if result.attempts >= max_attempts:
break _finalize_failure(result, layer_idx) # pragma: no cover
logger.warning( logger.warning(
"task %r timed out (attempt %d/%d); retrying", "task %r timed out (attempt %d/%d); retrying",
spec.name, spec.name,
@@ -147,23 +158,9 @@ async def _run_async_with_retry(
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
result.error = exc result.error = exc
if result.attempts >= max_attempts: if result.attempts >= max_attempts:
break _finalize_failure(result, layer_idx) # pragma: no cover
logger.warning( _log_retry(spec, result.attempts, max_attempts, exc) # pragma: no cover
"task %r failed (attempt %d/%d): %r; retrying", raise AssertionError("unreachable") # pragma: no cover
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
+59 -62
View File
@@ -1,9 +1,8 @@
"""DAG construction, validation, layering and visualisation. """DAG 构建、校验、分层与可视化。
Uses :mod:`graphlib` from the standard library (3.9+) or 使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8
:mod:`graphlib_backport` (3.8) for topological sorting. The graph is 进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非
built incrementally and validated eagerly so that misconfiguration fails 执行时)快速失败。
fast — at construction time, not at execution time.
""" """
from __future__ import annotations from __future__ import annotations
@@ -14,59 +13,56 @@ from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
from .errors import CycleError, DuplicateTaskError, MissingDependencyError from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec from .task import TaskSpec
# graphlib lives in the stdlib since 3.9; fall back to the backport on 3.8. # graphlib 自 3.9 起进入标准库;3.8 回退到 backport
if sys.version_info >= (3, 9): if sys.version_info >= (3, 9):
import graphlib import graphlib
_TopologicalSorter = graphlib.TopologicalSorter _TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover - exercised only on 3.8 else: # pragma: no cover - 仅在 3.8 上执行
import graphlib # type: ignore[no-redef] import graphlib # type: ignore[no-redef]
_TopologicalSorter = graphlib.TopologicalSorter _TopologicalSorter = graphlib.TopologicalSorter
class Graph: class Graph:
"""An immutable-after-validation directed acyclic graph of tasks. """校验后不可变的有向无环任务图。
The graph is built by adding :class:`~pyflowx.task.TaskSpec` instances. 通过添加 :class:`~pyflowx.task.TaskSpec` 实例构建。每次 ``add`` 都
Each ``add`` performs eager validation (duplicate names, missing 执行即时校验(重名、缺失依赖),:meth:`validate` / :meth:`layers`
dependencies), and :meth:`validate` / :meth:`layers` perform full DAG 执行完整 DAG 校验(环检测)与拓扑分层。
validation (cycle detection) and topological layering.
The graph holds only the *configuration*; runtime state lives in 图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。
:class:`~pyflowx.report.RunReport`. This makes a graph safely 这使图可安全重复运行并在线程间共享。
re-runnable and shareable across threads.
""" """
def __init__(self) -> None: def __init__(self) -> None:
self._specs: Dict[str, TaskSpec[object]] = {} self._specs: Dict[str, TaskSpec[object]] = {}
# Map task -> its direct dependencies (predecessors). # 任务 -> 其直接依赖(前驱)。
self._deps: Dict[str, Tuple[str, ...]] = {} self._deps: Dict[str, Tuple[str, ...]] = {}
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Construction # 构建
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[object]) -> "Graph": def add(self, spec: TaskSpec[object]) -> "Graph":
"""Register a task spec with eager validation. """注册一个任务 spec,并即时校验。
Returns ``self`` so calls can be chained, but the recommended 返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
entry point is :meth:`from_specs` which validates the whole batch 它会整批校验(允许单次调用中的前向引用)。
together (allowing forward references in a single call).
""" """
if spec.name in self._specs:
raise DuplicateTaskError(spec.name)
self._specs[spec.name] = spec self._specs[spec.name] = spec
self._deps[spec.name] = spec.depends_on self._deps[spec.name] = spec.depends_on
# Eagerly check duplicates and missing deps for the incremental API. # 为增量 API 即时检查重名与缺失依赖。
self._validate_references() self._validate_references()
return self return self
@classmethod @classmethod
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph": def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
"""Build a graph from an iterable of task specs. """从可迭代的 task spec 构建图。
All specs are collected first, then validated together. This means 先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
a task may reference a dependency that appears *later* in the 依赖——顺序无关,就像声明式配置文件的读取方式。
iterable — order does not matter, mirroring how a declarative
config file reads.
""" """
graph = cls() graph = cls()
for spec in specs: for spec in specs:
@@ -79,69 +75,67 @@ class Graph:
return graph return graph
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Validation # 校验
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def _validate_references(self) -> None: def _validate_references(self) -> None:
"""Ensure every dependency name exists in the graph.""" """确保每个依赖名都存在于图中。"""
for name, deps in self._deps.items(): for name, deps in self._deps.items():
for dep in deps: for dep in deps:
if dep not in self._specs: if dep not in self._specs:
raise MissingDependencyError(name, dep) raise MissingDependencyError(name, dep)
def validate(self) -> None: def validate(self) -> None:
"""Run full DAG validation. """执行完整 DAG 校验。
Raises :class:`~pyflowx.errors.CycleError` if a cycle exists. 存在环时抛出 :class:`~pyflowx.errors.CycleError`
Dependency existence is checked by :meth:`_validate_references`. 依赖存在性由 :meth:`_validate_references` 检查。
""" """
self._validate_references() self._validate_references()
sorter = _TopologicalSorter(self._deps) sorter = _TopologicalSorter(self._deps)
try: try:
# prepare() raises CycleError on cycles; we don't need the # prepare() 在有环时抛出 CycleError;此处不需要
# static_order() result here, just the validation side effect. # static_order() 的结果,仅利用其校验副作用。
sorter.prepare() sorter.prepare()
except graphlib.CycleError as exc: except graphlib.CycleError as exc:
# exc.args[1] is the list of nodes forming the cycle. # exc.args[1] 是构成环的节点列表。
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else [] cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc raise CycleError(list(cycle)) from exc
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Introspection # 内省
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
@property @property
def names(self) -> List[str]: def names(self) -> List[str]:
"""All registered task names (insertion order).""" """所有已注册任务名(按插入顺序)。"""
return list(self._specs.keys()) return list(self._specs.keys())
def spec(self, name: str) -> TaskSpec[object]: def spec(self, name: str) -> TaskSpec[object]:
"""Return the spec for ``name``; ``KeyError`` if absent.""" """返回 ``name`` 的 spec;不存在则 ``KeyError``"""
return self._specs[name] return self._specs[name]
def dependencies(self, name: str) -> Tuple[str, ...]: def dependencies(self, name: str) -> Tuple[str, ...]:
"""Direct predecessors of ``name``.""" """``name`` 的直接前驱。"""
return self._deps[name] return self._deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[object]]: def all_specs(self) -> Mapping[str, TaskSpec[object]]:
"""Read-only view of name -> spec.""" """name -> spec 的只读视图。"""
return self._specs return self._specs
def layers(self) -> List[List[str]]: def layers(self) -> List[List[str]]:
"""Group tasks into parallel-executable layers (Kahn's algorithm). """将任务分组为可并行执行的层(Kahn 算法)。
Tasks within the same layer have no mutual dependencies and may 同层任务无相互依赖,可并发执行。层按执行顺序返回。
run concurrently. Layers are returned in execution order.
Raises :class:`~pyflowx.errors.CycleError` if the graph is cyclic. 图有环时抛出 :class:`~pyflowx.errors.CycleError`
""" """
self.validate() self.validate()
sorter = _TopologicalSorter(self._deps) sorter = _TopologicalSorter(self._deps)
result: List[List[str]] = [] result: List[List[str]] = []
# ``get_ready`` + ``done`` gives us one layer at a time, which is # ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
# exactly the parallel-execution grouping we need.
sorter.prepare() sorter.prepare()
while sorter.is_active(): while sorter.is_active():
ready = list(sorter.get_ready()) ready = list(sorter.get_ready())
# Sort for deterministic, reproducible execution plans. # 排序以保证确定性、可复现的执行计划。
ready.sort() ready.sort()
result.append(ready) result.append(ready)
for node in ready: for node in ready:
@@ -149,22 +143,23 @@ class Graph:
return result return result
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Subgraph / tag filtering # 子图 / 标签过滤
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> "Graph": def subgraph(self, tags: Iterable[str]) -> "Graph":
"""Return a new graph containing only tasks matching any tag. """返回仅包含匹配任意标签的任务的新图。
Dependencies are pruned to keep only edges between retained tasks; 依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
edges to dropped tasks are removed (the retained task no longer 会被移除(被保留的任务不再等待它们)。用于调试时运行大型
waits for them). Use this to run a slice of a large DAG for DAG 的切片。
debugging.
""" """
wanted: Set[str] = set(tags) wanted: Set[str] = set(tags)
kept: List[TaskSpec[object]] = [] kept: List[TaskSpec[object]] = []
for spec in self._specs.values(): for spec in self._specs.values():
if wanted & set(spec.tags): if wanted & set(spec.tags):
pruned_deps = tuple( pruned_deps = tuple(
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags)) d
for d in spec.depends_on
if d in self._specs and (wanted & set(self._specs[d].tags))
) )
kept.append( kept.append(
TaskSpec( TaskSpec(
@@ -181,7 +176,7 @@ class Graph:
return Graph.from_specs(kept) return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> "Graph": def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
"""Return a new graph restricted to ``names`` (with pruned edges).""" """返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: Set[str] = set(names) wanted: Set[str] = set(names)
for n in wanted: for n in wanted:
if n not in self._specs: if n not in self._specs:
@@ -205,18 +200,20 @@ class Graph:
return Graph.from_specs(kept) return Graph.from_specs(kept)
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Visualisation # 可视化
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str: def to_mermaid(self, orientation: str = "TD") -> str:
"""Render the DAG as a Mermaid ``graph`` definition string. """ DAG 渲染为 Mermaid ``graph`` 定义字符串。
No external dependencies; the output can be pasted into Markdown, 无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览
rendered by VS Code's Mermaid previewer, or saved to a file. 渲染,或保存为文件。
""" """
valid = {"TD", "TB", "BT", "LR", "RL"} valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper() orientation = orientation.upper()
if orientation not in valid: if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.") raise ValueError(
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
)
lines: List[str] = [f"graph {orientation}"] lines: List[str] = [f"graph {orientation}"]
for name in self._specs: for name in self._specs:
lines.append(f' {name}["{name}"]') lines.append(f' {name}["{name}"]')
@@ -226,10 +223,10 @@ class Graph:
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# Debug # 调试
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def describe(self) -> str: def describe(self) -> str:
"""Human-readable multi-line summary for debugging.""" """用于调试的人类可读多行摘要。"""
out: List[str] = [f"Graph(tasks={len(self._specs)})"] out: List[str] = [f"Graph(tasks={len(self._specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1): for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}") out.append(f" Layer {layer_idx}: {layer}")
+24 -23
View File
@@ -1,46 +1,43 @@
"""Run report: typed, queryable result of a single :func:`pyflowx.run`. """运行报告:单次 :func:`pyflowx.run` 的类型化、可查询结果。
The report is the single source of truth after execution. It exposes 报告是执行后的唯一事实来源。它通过 ``report["name"]`` 暴露单任务结果
per-task results via ``report["name"]`` (typed as ``Any`` because the (类型为 ``Any``,因为映射异构)、汇总统计,以及整次运行是否成功的标志。
mapping is heterogeneous), summary statistics, and a flag indicating
whether the whole run succeeded.
""" """
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Mapping, Optional from typing import Any, Dict, Iterator, List
from .task import TaskResult, TaskStatus from .task import TaskResult, TaskStatus
@dataclass @dataclass
class RunReport: class RunReport:
"""Aggregated outcome of a workflow run. """工作流运行的聚合结果。
Attributes 属性
---------- ----
results: results:
Mapping of task name -> :class:`TaskResult`. Insertion order 任务名 -> :class:`TaskResult` 的映射。插入顺序与任务完成顺序一致。
matches the order tasks finished.
success: success:
``True`` iff every non-skipped task ended in ``SUCCESS``. 当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。
""" """
results: Dict[str, TaskResult[object]] = field(default_factory=dict) results: Dict[str, TaskResult[object]] = field(default_factory=dict)
success: bool = True success: bool = True
# ---- typed access ------------------------------------------------- # # ---- 类型化访问 --------------------------------------------------- #
def __getitem__(self, name: str) -> Any: def __getitem__(self, name: str) -> Any:
"""Return the *value* of task ``name`` (not the TaskResult). """返回任务 ``name`` 的*值*(而非 TaskResult)。
Raises ``KeyError`` if the task was not part of the run. Returns 任务不在本次运行中则抛出 ``KeyError``。未达到 SUCCESS 的任务
``None`` for tasks that did not reach SUCCESS. 返回 ``None``
""" """
return self.results[name].value return self.results[name].value
def result_of(self, name: str) -> TaskResult[object]: def result_of(self, name: str) -> TaskResult[object]:
"""Return the full :class:`TaskResult` for ``name``.""" """返回 ``name`` 的完整 :class:`TaskResult`"""
return self.results[name] return self.results[name]
def __contains__(self, name: object) -> bool: def __contains__(self, name: object) -> bool:
@@ -52,9 +49,9 @@ class RunReport:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.results) return len(self.results)
# ---- summary ------------------------------------------------------ # # ---- 汇总 --------------------------------------------------------- #
def summary(self) -> Dict[str, Any]: def summary(self) -> Dict[str, Any]:
"""Compact statistics dict for logging / dashboards.""" """用于日志/仪表盘的紧凑统计字典。"""
counts: Dict[str, int] = {} counts: Dict[str, int] = {}
total_duration = 0.0 total_duration = 0.0
for r in self.results.values(): for r in self.results.values():
@@ -69,14 +66,18 @@ class RunReport:
} }
def failed_tasks(self) -> List[str]: def failed_tasks(self) -> List[str]:
"""Names of tasks that ended in FAILED status.""" """以 FAILED 状态结束的任务名列表。"""
return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED] return [
name for name, r in self.results.items() if r.status == TaskStatus.FAILED
]
def describe(self) -> str: def describe(self) -> str:
"""Human-readable multi-line report for debugging.""" """用于调试的人类可读多行报告。"""
lines: List[str] = [f"RunReport(success={self.success})"] lines: List[str] = [f"RunReport(success={self.success})"]
for name, r in self.results.items(): for name, r in self.results.items():
dur = f"{r.duration:.3f}s" if r.duration is not None else "-" dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
err = f" error={r.error!r}" if r.error else "" err = f" error={r.error!r}" if r.error else ""
lines.append(f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}") lines.append(
f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}"
)
return "\n".join(lines) return "\n".join(lines)
+23 -25
View File
@@ -1,19 +1,17 @@
"""State persistence backends for resumable runs. """用于断点续跑的状态持久化后端。
A :class:`StateBackend` stores the result of every successfully completed :class:`StateBackend` 存储每个成功完成任务的结果。在后续运行中,
task. On a subsequent run, the executor asks the backend whether a task 执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其
already has a stored result; if so, the task is skipped and its stored 存储值注入下游任务。
value is injected into downstream tasks.
This is intentionally minimal: only *successful* results are persisted 本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储
(failed tasks are re-run), and the storage shape is a flat 形态为扁平的 ``{task_name: result}`` 映射。内置两个后端:
``{task_name: result}`` mapping. Two backends ship in-tree:
* :class:`MemoryBackend` — fast, in-process, no I/O. Default. * :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认。
* :class:`JSONBackend` — persists to a JSON file for cross-process resume. * :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。
Both are zero-dependency (``json`` is stdlib). Users can subclass 两者均零依赖(``json`` 为标准库)。用户可子类化
:class:`StateBackend` to plug in SQLite, Redis, etc. :class:`StateBackend` 接入 SQLiteRedis 等。
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,31 +25,31 @@ from .errors import StorageError
class StateBackend(ABC): class StateBackend(ABC):
"""Abstract base for resumable state storage.""" """可续跑状态存储的抽象基类。"""
@abstractmethod @abstractmethod
def load(self) -> Mapping[str, Any]: def load(self) -> Mapping[str, Any]:
"""Return the full stored mapping (may be empty).""" """返回完整的存储映射(可能为空)。"""
@abstractmethod @abstractmethod
def save(self, name: str, value: Any) -> None: def save(self, name: str, value: Any) -> None:
"""Persist a single task's successful result.""" """持久化单个任务的成功结果。"""
@abstractmethod @abstractmethod
def has(self, name: str) -> bool: def has(self, name: str) -> bool:
"""Whether ``name`` has a stored result.""" """``name`` 是否已有存储结果。"""
@abstractmethod @abstractmethod
def get(self, name: str) -> Any: def get(self, name: str) -> Any:
"""Return the stored result for ``name`` (raise ``KeyError`` if absent).""" """返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。"""
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:
"""Remove all stored state.""" """清除所有存储状态。"""
class MemoryBackend(StateBackend): class MemoryBackend(StateBackend):
"""In-process dict backend. Lost when the process exits.""" """进程内 dict 后端。进程退出即丢失。"""
def __init__(self) -> None: def __init__(self) -> None:
self._store: Dict[str, Any] = {} self._store: Dict[str, Any] = {}
@@ -73,11 +71,11 @@ class MemoryBackend(StateBackend):
class JSONBackend(StateBackend): class JSONBackend(StateBackend):
"""File-backed JSON storage for cross-process resume. """基于文件的 JSON 存储,用于跨进程续跑。
Results must be JSON-serialisable. Non-serialisable values raise 结果必须可 JSON 序列化。不可序列化的值会抛出
:class:`~pyflowx.errors.StorageError` (the run itself is not aborted; :class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条
only persistence of that one result fails). 结果的持久化失败)。
""" """
def __init__(self, path: str) -> None: def __init__(self, path: str) -> None:
@@ -109,7 +107,7 @@ class JSONBackend(StateBackend):
return dict(self._store) return dict(self._store)
def save(self, name: str, value: Any) -> None: def save(self, name: str, value: Any) -> None:
# Validate serialisability before mutating in-memory state. # 在修改内存状态前先校验可序列化性。
try: try:
json.dumps(value) json.dumps(value)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
@@ -131,5 +129,5 @@ class JSONBackend(StateBackend):
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend: def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
"""Return ``backend`` or a fresh :class:`MemoryBackend` if ``None``.""" """返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`"""
return backend if backend is not None else MemoryBackend() return backend if backend is not None else MemoryBackend()
+42 -49
View File
@@ -1,20 +1,18 @@
"""Core task data structures for PyFlowX. """PyFlowX 核心任务数据结构。
Everything here is a plain, immutable data structure — no decorators, no 本模块全部为纯不可变数据结构——无装饰器、无副作用。一个
side effects. A :class:`TaskSpec` fully describes a task node; the :class:`TaskSpec` 完整描述一个任务节点;:class:`Graph`
:class:`Graph` (see :mod:`pyflowx.graph`) consumes a list of specs and (见 :mod:`pyflowx.graph`)消费一组 spec 并构建 DAG。
builds the DAG.
Design notes 设计要点
------------ --------
* ``TaskSpec`` is a ``Generic[T]`` so that ``TaskSpec[int]`` carries the * ``TaskSpec`` ``Generic[T]``,因此 ``TaskSpec[int]`` 会把 ``fn`` 的
return type of ``fn`` all the way to :class:`RunReport`, giving callers 返回类型一路传递到 :class:`RunReport`,让调用者可以类型安全地访问
typed access to ``report["name"]``. ``report["name"]``
* ``Context`` is the only intentionally-dynamic type: results from * ``Context`` 是唯一刻意保留动态类型的类型:上游任务的结果异构,因此
upstream tasks are heterogeneous, so the cross-task mapping is 跨任务映射为 ``Mapping[str, Any]``。单个任务内部类型仍然完全静态,
``Mapping[str, Any]``. Within a single task the types remain fully 因为函数签名由 mypy 检查。
static because the function signature is checked by mypy. * ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。
* ``TaskStatus`` is a closed enum; executors never invent ad-hoc strings.
""" """
from __future__ import annotations from __future__ import annotations
@@ -36,59 +34,55 @@ from typing import (
T = TypeVar("T") T = TypeVar("T")
# A task callable may be synchronous or asynchronous. We keep the union # 任务可调用对象可以是同步或异步的。显式保留联合类型,让 mypy 理解两种形态。
# explicit so mypy understands both shapes.
TaskFn = Union[ TaskFn = Union[
Callable[..., T], Callable[..., T],
Callable[..., Coroutine[Any, Any, T]], Callable[..., Coroutine[Any, Any, T]],
] ]
# The cross-task result mapping. Deliberately ``Any`` for values because # 跨任务结果映射。值刻意使用 ``Any``,因为不同任务返回不同类型;
# different tasks return different types; per-task typing is preserved by # 单任务类型由函数签名本身保留。
# the function signature itself.
Context = Mapping[str, Any] Context = Mapping[str, Any]
class TaskStatus(Enum): class TaskStatus(Enum):
"""Lifecycle states of a task during a single run.""" """任务在单次运行内的生命周期状态。"""
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
SUCCESS = "success" SUCCESS = "success"
FAILED = "failed" FAILED = "failed"
SKIPPED = "skipped" # used by resumable runs and subgraph filtering SKIPPED = "skipped" # 用于断点续跑与子图过滤
@dataclass(frozen=True) @dataclass(frozen=True)
class TaskSpec(Generic[T]): class TaskSpec(Generic[T]):
"""Immutable description of a single DAG node. """单个 DAG 节点的不可变描述。
Parameters 参数
---------- ----
name: name:
Unique identifier of the task within a graph. Other tasks reference 任务在图内的唯一标识。其他任务通过 ``depends_on`` 引用此名称。
this name in ``depends_on``.
fn: fn:
The callable to execute. May be sync or async. Its parameter names 待执行的可调用对象,可为同步或异步。其参数名驱动自动上下文
drive automatic context injection (see :mod:`pyflowx.context`). 注入(见 :mod:`pyflowx.context`)。
depends_on: depends_on:
Names of tasks whose results must be available before this task 必须先完成才能运行本任务的任务名列表。顺序无关;框架会做
runs. Order is irrelevant; the framework topologically sorts. 拓扑排序。
args: args:
Static positional arguments appended *after* injected parameters. 静态位置参数,追加在注入参数*之后*。适用于参数化任务
Useful for parameterised tasks (e.g. ``fetch_user(uid)``). (如 ``fetch_user(uid)``)。
kwargs: kwargs:
Static keyword arguments. Conflict with injected names raises 静态关键字参数。若与注入名冲突则抛出
:class:`~pyflowx.errors.InjectionError`. :class:`~pyflowx.errors.InjectionError`
retries: retries:
Number of retry attempts on failure. ``0`` means a single attempt. 失败后的重试次数。``0`` 表示仅尝试一次。
timeout: timeout:
Maximum execution time in seconds. ``None`` disables the timeout. 最大执行时长(秒)。``None`` 表示不限制。异步任务使用
For async tasks this uses :func:`asyncio.wait_for`; for sync tasks :func:`asyncio.wait_for`;线程/异步执行器中的同步任务会
in the threaded/async executors it cancels the worker future. 取消 worker future
tags: tags:
Free-form labels used by :meth:`Graph.subgraph` for selective 自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试。
execution and debugging.
""" """
name: str name: str
@@ -113,10 +107,10 @@ class TaskSpec(Generic[T]):
@dataclass @dataclass
class TaskResult(Generic[T]): class TaskResult(Generic[T]):
"""Mutable per-task record produced during a run. """运行期间产生的可变单任务记录。
A fresh :class:`TaskResult` is created for every run; the spec itself 每次运行都会创建全新的 :class:`TaskResult`spec 本身保持不可变。
stays immutable. This keeps the same graph safely re-runnable. 这让同一个图可以安全地重复运行。
""" """
spec: TaskSpec[T] spec: TaskSpec[T]
@@ -129,7 +123,7 @@ class TaskResult(Generic[T]):
@property @property
def duration(self) -> Optional[float]: def duration(self) -> Optional[float]:
"""Elapsed seconds between start and finish, or ``None``.""" """从开始到结束的耗时(秒),未开始/未结束则为 ``None``"""
if self.started_at is None or self.finished_at is None: if self.started_at is None or self.finished_at is None:
return None return None
return (self.finished_at - self.started_at).total_seconds() return (self.finished_at - self.started_at).total_seconds()
@@ -137,11 +131,10 @@ class TaskResult(Generic[T]):
@dataclass(frozen=True) @dataclass(frozen=True)
class TaskEvent: class TaskEvent:
"""Immutable event emitted during execution for observers. """执行期间向观察者发出的不可变事件。
Passed to the ``on_event`` callback of :func:`pyflowx.run` so callers 传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合
can build progress bars, metrics, or structured logs without coupling 执行器内部即可构建进度条、指标或结构化日志。
to executor internals.
""" """
task: str task: str
+147 -1
View File
@@ -7,7 +7,7 @@ from typing import Any
import pytest import pytest
import pyflowx as px import pyflowx as px
from pyflowx.context import build_call_args, describe_injection from pyflowx.context import _is_context_annotation, build_call_args, describe_injection
from pyflowx.errors import InjectionError from pyflowx.errors import InjectionError
@@ -87,3 +87,149 @@ def test_describe_injection() -> None:
assert "a=<result:a>" in desc assert "a=<result:a>" in desc
assert "ctx=<Context>" in desc assert "ctx=<Context>" in desc
assert "flag=<default>" in desc assert "flag=<default>" in desc
# ---------------------------------------------------------------------- #
# _is_context_annotation 各分支
# ---------------------------------------------------------------------- #
def test_is_context_annotation_direct_object() -> None:
    """The ``px.Context`` alias object itself is recognised as a Context annotation."""
    recognised = _is_context_annotation(px.Context)
    assert recognised is True
def test_is_context_annotation_string() -> None:
    """String annotations naming ``Context`` are accepted; other names are rejected."""
    for accepted in ("Context", "px.Context", "pyflowx.Context"):
        assert _is_context_annotation(accepted) is True
    for rejected in ("NotContext", "int"):
        assert _is_context_annotation(rejected) is False
def test_is_context_annotation_typing_alias() -> None:
    """Objects whose ``__name__`` is ``Context`` or ``Mapping`` are accepted.

    Stand-in classes mimic typing aliases by exposing only ``__name__``.
    """

    class FakeAlias:
        __name__ = "Context"

    class FakeMapping:
        __name__ = "Mapping"

    assert _is_context_annotation(FakeAlias()) is True
    assert _is_context_annotation(FakeMapping()) is True
def test_is_context_annotation_other() -> None:
    """Unrelated annotations (plain types and ``None``) are rejected."""
    for annotation in (int, str, None):
        assert _is_context_annotation(annotation) is False
# ---------------------------------------------------------------------- #
# describe_injection 其余分支
# ---------------------------------------------------------------------- #
def test_describe_injection_var_positional() -> None:
    """A ``*args`` parameter shows up as ``*args`` in the description."""

    def fn(*args: Any) -> None:
        return None

    rendered = describe_injection(px.TaskSpec("t", fn))
    assert "*args" in rendered
def test_describe_injection_var_keyword() -> None:
    """A ``**kwargs`` parameter is described as ``**kwargs=<all-deps>``."""

    def fn(**kwargs: Any) -> None:
        return None

    rendered = describe_injection(px.TaskSpec("t", fn, ("a",)))
    assert "**kwargs=<all-deps>" in rendered
def test_describe_injection_unresolved() -> None:
    """A parameter with no dependency, static value or default is ``<UNRESOLVED>``."""

    def fn(missing: int) -> None:
        return None

    rendered = describe_injection(px.TaskSpec("t", fn))
    assert "missing=<UNRESOLVED>" in rendered
def test_describe_injection_static_kwargs() -> None:
    """Static ``kwargs`` on the spec are described with their concrete value."""

    def fn(flag: bool = False) -> None:
        return None

    task = px.TaskSpec("t", fn, kwargs={"flag": True})
    rendered = describe_injection(task)
    assert "flag=True" in rendered
def test_describe_injection_positional_args_filled() -> None:
    """Parameters filled from ``spec.args`` are described with their concrete values.

    Exercises the branch where positional static arguments fill parameters.
    """

    def fn(a: int, b: str) -> None:
        return None

    task = px.TaskSpec("t", fn, args=(1, "x"))
    rendered = describe_injection(task)
    for fragment in ("a=1", "b='x'"):
        assert fragment in rendered
# ---------------------------------------------------------------------- #
# build_call_args 边界
# ---------------------------------------------------------------------- #
def test_build_call_args_var_positional_not_required() -> None:
    """A ``*args`` parameter does not trigger ``InjectionError``; static args pass through."""

    def fn(*args: Any) -> int:
        return len(args)

    task = px.TaskSpec("t", fn, args=(1, 2, 3))
    positional, keyword = build_call_args(task, {})
    assert positional == (1, 2, 3)
    assert keyword == {}
def test_build_call_args_var_keyword_consumes_leftover() -> None:
    """``**kwargs`` receives dependency results not consumed by named parameters."""

    def fn(a: int, **rest: Any) -> int:
        return a + sum(rest.values())

    task = px.TaskSpec("t", fn, ("a", "b", "c"))
    upstream = {"a": 1, "b": 2, "c": 3}
    _, keyword = build_call_args(task, upstream)
    assert keyword == {"a": 1, "b": 2, "c": 3}
def test_build_call_args_no_var_keyword_drops_leftover() -> None:
    """Without ``**kwargs``, unconsumed dependency results are dropped without error."""

    def fn(a: int) -> int:
        return a

    # "b" is a declared dependency that fn does not accept; this must still work.
    task = px.TaskSpec("t", fn, ("a", "b"))
    _, keyword = build_call_args(task, {"a": 1, "b": 2})
    assert keyword == {"a": 1}
def test_build_call_args_context_annotation_only_deps() -> None:
    """A ``Context``-annotated parameter receives only the task's own ``depends_on`` results."""

    def fn(ctx: px.Context) -> int:
        return len(ctx)

    task = px.TaskSpec("t", fn, ("a", "b"))
    # "c" is present in the available results but is not a dependency of the task.
    available = {"a": 1, "b": 2, "c": 99}
    _, keyword = build_call_args(task, available)
    assert keyword == {"ctx": {"a": 1, "b": 2}}
+90
View File
@@ -0,0 +1,90 @@
"""错误类型测试。"""
from __future__ import annotations
import pytest
import pyflowx as px
from pyflowx.errors import (
CycleError,
DuplicateTaskError,
InjectionError,
MissingDependencyError,
PyFlowXError,
StorageError,
TaskFailedError,
TaskTimeoutError,
)
def test_all_errors_are_pyflowx_subclass() -> None:
    """Every concrete error type derives from :class:`PyFlowXError`."""
    concrete_errors = (
        DuplicateTaskError,
        MissingDependencyError,
        CycleError,
        TaskFailedError,
        TaskTimeoutError,
        InjectionError,
        StorageError,
    )
    for error_type in concrete_errors:
        assert issubclass(error_type, PyFlowXError)
def test_duplicate_task_error_attributes() -> None:
    """``DuplicateTaskError`` stores the task name and mentions it in its message."""
    error = DuplicateTaskError("foo")
    assert error.name == "foo"
    assert "foo" in str(error)
def test_missing_dependency_error_attributes() -> None:
    """``MissingDependencyError`` exposes both names and includes them in its message."""
    error = MissingDependencyError("child", "parent")
    assert error.task == "child"
    assert error.dependency == "parent"
    message = str(error)
    assert "child" in message
    assert "parent" in message
def test_cycle_error_attributes() -> None:
    """``CycleError`` keeps the cycle list and renders it as a closed chain."""
    error = CycleError(["a", "b", "c"])
    assert error.cycle == ["a", "b", "c"]
    # The rendered chain loops back to its first node.
    assert "a -> b -> c -> a" in str(error)
def test_task_failed_error_attributes() -> None:
    """``TaskFailedError`` exposes task, cause, attempts and layer; the message names the layer."""
    original = ValueError("boom")
    error = TaskFailedError(task="t", cause=original, attempts=3, layer=2)
    assert error.task == "t"
    assert error.cause is original
    assert error.attempts == 3
    assert error.layer == 2
    assert "layer 2" in str(error)
def test_task_failed_error_without_layer() -> None:
    """When ``layer`` is omitted it is ``None`` and the message does not mention a layer."""
    error = TaskFailedError(task="t", cause=RuntimeError("x"), attempts=1)
    assert error.layer is None
    assert "layer" not in str(error)
def test_task_timeout_error_attributes() -> None:
    """``TaskTimeoutError`` stores task and timeout; the message formats the timeout as ``1.500s``."""
    error = TaskTimeoutError(task="t", timeout=1.5)
    assert error.task == "t"
    assert error.timeout == 1.5
    assert "1.500s" in str(error)
def test_injection_error_attributes() -> None:
    """``InjectionError`` stores the task name and includes the detail text in its message."""
    error = InjectionError(task="t", detail="missing param")
    assert error.task == "t"
    assert "missing param" in str(error)
def test_storage_error_with_cause() -> None:
    """``StorageError`` keeps the supplied cause and includes the detail text in its message."""
    underlying = OSError("disk full")
    error = StorageError(detail="write failed", cause=underlying)
    assert error.cause is underlying
    assert "write failed" in str(error)
def test_storage_error_without_cause() -> None:
    """Without a cause, ``StorageError.cause`` is ``None`` and the detail is still in the message."""
    error = StorageError(detail="bad")
    assert error.cause is None
    assert "bad" in str(error)
+193
View File
@@ -320,3 +320,196 @@ def test_invalid_strategy() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type] graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
with pytest.raises(ValueError): with pytest.raises(ValueError):
px.run(graph, strategy="bogus") # type: ignore[arg-type] px.run(graph, strategy="bogus") # type: ignore[arg-type]
# ---------------------------------------------------------------------- #
# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支
# ---------------------------------------------------------------------- #
def test_async_sync_task_without_timeout() -> None:
    """A sync task with no timeout runs under the ``async`` strategy.

    Per the original note, this targets the branch at executor line 131.
    """

    def sync_fn() -> int:
        return 42

    single_task = px.Graph.from_specs([px.TaskSpec("a", sync_fn)])
    outcome = px.run(single_task, strategy="async")
    assert outcome.success
    assert outcome["a"] == 42
def test_async_sync_task_with_timeout() -> None:
    """A sync task with a timeout runs under the ``async`` strategy.

    Per the original note, this targets the branch at executor line 129.
    """

    def sync_fn() -> int:
        return 42

    timed_spec = px.TaskSpec("a", sync_fn, timeout=5.0)
    outcome = px.run(px.Graph.from_specs([timed_spec]), strategy="async")
    assert outcome.success
    assert outcome["a"] == 42
def test_async_timeout_retry_then_succeed() -> None:
    """An async task that times out once succeeds on retry.

    Per the original note, this targets the retry branch at executor lines 141-151.
    """
    counter = {"n": 0}

    async def flaky() -> str:
        counter["n"] += 1
        first_attempt = counter["n"] < 2
        if first_attempt:
            # Sleep far longer than the 0.05s timeout to force a timeout.
            await asyncio.sleep(10)
        return "ok"

    retried_spec = px.TaskSpec("a", flaky, retries=2, timeout=0.05)
    outcome = px.run(px.Graph.from_specs([retried_spec]), strategy="async")
    assert outcome.success
    assert outcome["a"] == "ok"
    assert counter["n"] == 2
def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
    """An async task that raises once is retried and a retry warning is logged.

    Per the original note, this targets the ``except Exception`` branch at
    executor lines 141-151.
    """
    counter = {"n": 0}

    async def flaky() -> str:
        counter["n"] += 1
        if counter["n"] < 2:
            raise RuntimeError("not yet")
        return "ok"

    retried_graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)])
    with caplog.at_level("WARNING", logger="pyflowx"):
        outcome = px.run(retried_graph, strategy="async")

    assert outcome.success
    assert outcome["a"] == "ok"
    # Confirm the retry log line was actually emitted.
    logged_messages = [record.message for record in caplog.records]
    assert any("retrying" in message for message in logged_messages)
# ---------------------------------------------------------------------- #
# 缓存跳过分支:threaded 与 async
# ---------------------------------------------------------------------- #
def test_threaded_skips_cached_tasks() -> None:
    """Tasks with a cached result are not re-executed under the ``thread`` strategy.

    Per the original note, this targets the cache-hit branch at executor lines 224-230.
    """
    executed: List[str] = []

    def make(name: str) -> Any:
        def fn() -> str:
            executed.append(name)
            return name

        return fn

    specs = [
        px.TaskSpec("a", make("a")),
        px.TaskSpec("b", make("b"), ("a",)),
    ]
    graph = px.Graph.from_specs(specs)
    backend = px.MemoryBackend()

    # First run populates the cache.
    px.run(graph, strategy="thread", max_workers=2, state=backend)
    assert executed == ["a", "b"]

    # Second run should skip everything, so no further executions are recorded.
    px.run(graph, strategy="thread", max_workers=2, state=backend)
    assert executed == ["a", "b"]
def test_threaded_all_cached_layer() -> None:
    """When every task in a layer is cached, the cached value is reported as SKIPPED.

    Per the original note, this targets the ``if not to_run: return`` branch at
    executor line 235.
    """
    backend = px.MemoryBackend()
    backend.save("a", 99)
    graph = px.Graph.from_specs([px.TaskSpec("a", lambda: 1)])  # type: ignore[arg-type]
    outcome = px.run(graph, strategy="thread", max_workers=2, state=backend)
    assert outcome["a"] == 99
    assert outcome.result_of("a").status == px.TaskStatus.SKIPPED
def test_async_skips_cached_tasks() -> None:
    """Tasks with a cached result are not re-executed under the ``async`` strategy.

    Per the original note, this targets the cache-hit branch at executor lines 268-274.
    The graph is run twice against the same backend; the second run must not
    invoke either coroutine function again.
    """
    runs: List[str] = []

    # Plain coroutine functions can be invoked once per run, so no factory is needed.
    async def a() -> str:
        runs.append("a")
        return "a"

    async def b(a: str) -> str:
        runs.append("b")
        return a + "b"

    graph = px.Graph.from_specs(
        [
            px.TaskSpec("a", a),
            px.TaskSpec("b", b, ("a",)),
        ]
    )
    backend = px.MemoryBackend()

    # First run executes both tasks and populates the backend.
    px.run(graph, strategy="async", state=backend)
    assert runs == ["a", "b"]

    # Second run hits the cache for both tasks, so nothing is appended.
    px.run(graph, strategy="async", state=backend)
    assert runs == ["a", "b"]
def test_async_all_cached_layer() -> None:
    """When every task in a layer is cached under ``async``, the cached value is reported as SKIPPED.

    Per the original note, this targets the ``if not to_run: return`` branch at
    executor line 279.
    """

    async def a() -> int:
        return 1

    backend = px.MemoryBackend()
    backend.save("a", 77)
    graph = px.Graph.from_specs([px.TaskSpec("a", a)])
    outcome = px.run(graph, strategy="async", state=backend)
    assert outcome["a"] == 77
    assert outcome.result_of("a").status == px.TaskStatus.SKIPPED
# ---------------------------------------------------------------------- #
# 失败后 report.success 标记为 False
# ---------------------------------------------------------------------- #
def test_failure_marks_report_unsuccessful() -> None:
    """A failing task makes ``run`` raise :class:`px.TaskFailedError`.

    Only the raising behaviour is verified here: ``run`` raises before it
    returns a report, so ``report.success`` cannot be inspected in this test.
    """

    def boom() -> None:
        raise ValueError("fail")

    failing_graph = px.Graph.from_specs([px.TaskSpec("a", boom)])
    with pytest.raises(px.TaskFailedError):
        px.run(failing_graph, strategy="sequential")
# ---------------------------------------------------------------------- #
# dry_run 各策略
# ---------------------------------------------------------------------- #
def test_dry_run_thread(capsys: pytest.CaptureFixture[str]) -> None:
    """A dry run under ``thread`` returns an empty report and prints a ``Dry run`` banner."""
    graph = px.Graph.from_specs([px.TaskSpec("a", lambda: 1)])  # type: ignore[arg-type]
    outcome = px.run(graph, strategy="thread", dry_run=True)
    assert len(outcome) == 0
    printed = capsys.readouterr().out
    assert "Dry run" in printed
def test_dry_run_async(capsys: pytest.CaptureFixture[str]) -> None:
    """A dry run under ``async`` returns an empty report and prints a ``Dry run`` banner."""

    async def a() -> int:
        return 1

    graph = px.Graph.from_specs([px.TaskSpec("a", a)])
    outcome = px.run(graph, strategy="async", dry_run=True)
    assert len(outcome) == 0
    printed = capsys.readouterr().out
    assert "Dry run" in printed
# ---------------------------------------------------------------------- #
# 空图运行
# ---------------------------------------------------------------------- #
def test_run_empty_graph() -> None:
    """Running an empty graph succeeds and yields an empty report."""
    outcome = px.run(px.Graph(), strategy="sequential")
    assert outcome.success
    assert len(outcome) == 0
+141 -42
View File
@@ -13,11 +13,13 @@ def _fn() -> None:
def test_from_specs_builds_graph() -> None: def test_from_specs_builds_graph() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("b", _fn, ("a",)), px.TaskSpec("a", _fn),
px.TaskSpec("c", _fn, ("a", "b")), px.TaskSpec("b", _fn, ("a",)),
]) px.TaskSpec("c", _fn, ("a", "b")),
]
)
assert set(graph.names) == {"a", "b", "c"} assert set(graph.names) == {"a", "b", "c"}
assert graph.dependencies("c") == ("a", "b") assert graph.dependencies("c") == ("a", "b")
assert len(graph) == 3 assert len(graph) == 3
@@ -26,19 +28,23 @@ def test_from_specs_builds_graph() -> None:
def test_from_specs_allows_forward_references() -> None: def test_from_specs_allows_forward_references() -> None:
# b depends on a, but a is declared after b — order should not matter. # b depends on a, but a is declared after b — order should not matter.
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("b", _fn, ("a",)), [
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, ("a",)),
]) px.TaskSpec("a", _fn),
]
)
assert graph.layers() == [["a"], ["b"]] assert graph.layers() == [["a"], ["b"]]
def test_duplicate_task_raises() -> None: def test_duplicate_task_raises() -> None:
with pytest.raises(DuplicateTaskError): with pytest.raises(DuplicateTaskError):
px.Graph.from_specs([ px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("a", _fn), px.TaskSpec("a", _fn),
]) px.TaskSpec("a", _fn),
]
)
def test_missing_dependency_raises() -> None: def test_missing_dependency_raises() -> None:
@@ -50,20 +56,24 @@ def test_missing_dependency_raises() -> None:
def test_cycle_detection() -> None: def test_cycle_detection() -> None:
with pytest.raises(CycleError): with pytest.raises(CycleError):
px.Graph.from_specs([ px.Graph.from_specs(
px.TaskSpec("a", _fn, ("c",)), [
px.TaskSpec("b", _fn, ("a",)), px.TaskSpec("a", _fn, ("c",)),
px.TaskSpec("c", _fn, ("b",)), px.TaskSpec("b", _fn, ("a",)),
]) px.TaskSpec("c", _fn, ("b",)),
]
)
def test_layers_grouping() -> None: def test_layers_grouping() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("b", _fn), px.TaskSpec("a", _fn),
px.TaskSpec("c", _fn, ("a", "b")), px.TaskSpec("b", _fn),
px.TaskSpec("d", _fn, ("c",)), px.TaskSpec("c", _fn, ("a", "b")),
]) px.TaskSpec("d", _fn, ("c",)),
]
)
layers = graph.layers() layers = graph.layers()
assert layers == [["a", "b"], ["c"], ["d"]] assert layers == [["a", "b"], ["c"], ["d"]]
@@ -74,10 +84,12 @@ def test_self_dependency_rejected() -> None:
def test_to_mermaid() -> None: def test_to_mermaid() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("b", _fn, ("a",)), px.TaskSpec("a", _fn),
]) px.TaskSpec("b", _fn, ("a",)),
]
)
mermaid = graph.to_mermaid() mermaid = graph.to_mermaid()
assert mermaid.startswith("graph TD") assert mermaid.startswith("graph TD")
assert 'a["a"]' in mermaid assert 'a["a"]' in mermaid
@@ -91,11 +103,13 @@ def test_to_mermaid_invalid_orientation() -> None:
def test_subgraph_by_tags() -> None: def test_subgraph_by_tags() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn, tags=("ingest",)), [
px.TaskSpec("b", _fn, ("a",), tags=("ingest",)), px.TaskSpec("a", _fn, tags=("ingest",)),
px.TaskSpec("c", _fn, ("b",), tags=("report",)), px.TaskSpec("b", _fn, ("a",), tags=("ingest",)),
]) px.TaskSpec("c", _fn, ("b",), tags=("report",)),
]
)
sub = graph.subgraph(["ingest"]) sub = graph.subgraph(["ingest"])
assert set(sub.names) == {"a", "b"} assert set(sub.names) == {"a", "b"}
# Edge to dropped task c is removed; b no longer waits for anything # Edge to dropped task c is removed; b no longer waits for anything
@@ -104,11 +118,13 @@ def test_subgraph_by_tags() -> None:
def test_subgraph_by_names() -> None: def test_subgraph_by_names() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("b", _fn, ("a",)), px.TaskSpec("a", _fn),
px.TaskSpec("c", _fn, ("b",)), px.TaskSpec("b", _fn, ("a",)),
]) px.TaskSpec("c", _fn, ("b",)),
]
)
sub = graph.subgraph_by_names(["a", "b"]) sub = graph.subgraph_by_names(["a", "b"])
assert set(sub.names) == {"a", "b"} assert set(sub.names) == {"a", "b"}
# c is dropped, so b's dep on c (none here) — but a->b edge preserved. # c is dropped, so b's dep on c (none here) — but a->b edge preserved.
@@ -122,10 +138,93 @@ def test_subgraph_by_names_unknown() -> None:
def test_describe() -> None: def test_describe() -> None:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs(
px.TaskSpec("a", _fn), [
px.TaskSpec("b", _fn, ("a",)), px.TaskSpec("a", _fn),
]) px.TaskSpec("b", _fn, ("a",)),
]
)
desc = graph.describe() desc = graph.describe()
assert "Layer 1" in desc assert "Layer 1" in desc
assert "Layer 2" in desc assert "Layer 2" in desc
# ---------------------------------------------------------------------- #
# 增量 add API 与其他访问器
# ---------------------------------------------------------------------- #
def test_add_chains_and_validates() -> None:
    """``add()`` returns the graph itself for chaining and validates dependencies eagerly."""
    graph = px.Graph()
    returned = graph.add(px.TaskSpec("a", _fn))
    assert returned is graph
    assert "a" in graph

    # A dependency that is not in the graph is rejected at add() time.
    dangling = px.TaskSpec("b", _fn, ("missing",))
    with pytest.raises(MissingDependencyError):
        graph.add(dangling)
def test_add_duplicate_raises() -> None:
    """Adding a second task with an existing name raises ``DuplicateTaskError``."""
    graph = px.Graph().add(px.TaskSpec("a", _fn))
    with pytest.raises(DuplicateTaskError):
        graph.add(px.TaskSpec("a", _fn))
def test_all_specs_returns_view() -> None:
    """``all_specs()`` exposes the registered specs keyed by task name.

    Only the contents are checked. Whether the returned mapping is read-only
    or isolated from the graph's internal state is not asserted here.
    """
    graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
    view = graph.all_specs()
    assert set(view.keys()) == {"a"}
    # Repeated calls must describe the same contents; equality already covers
    # the case where the very same object is returned.
    assert view == graph.all_specs()
def test_spec_accessor() -> None:
    """``spec(name)`` returns the matching spec and raises ``KeyError`` for unknown names."""
    graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
    found = graph.spec("a")
    assert found.name == "a"
    with pytest.raises(KeyError):
        graph.spec("missing")
def test_dependencies_accessor() -> None:
    """``dependencies(name)`` returns the task's dependency names as a tuple."""
    specs = [
        px.TaskSpec("a", _fn),
        px.TaskSpec("b", _fn, ("a",)),
    ]
    graph = px.Graph.from_specs(specs)
    assert graph.dependencies("a") == ()
    assert graph.dependencies("b") == ("a",)
def test_repr() -> None:
    """repr() of a one-task graph is exactly ``Graph(tasks=1)``."""
    single = px.Graph.from_specs([px.TaskSpec("a", _fn)])
    rendered = repr(single)
    assert rendered == "Graph(tasks=1)"
def test_empty_graph_layers() -> None:
    """An empty graph has no layers and renders only the Mermaid header."""
    empty = px.Graph()
    layers = empty.layers()
    assert layers == []
    mermaid = empty.to_mermaid()
    assert mermaid == "graph TD\n"
def test_subgraph_preserves_metadata() -> None:
    """A spec selected into a subgraph keeps its retries, timeout and tags."""
    tagged_x = px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0)
    tagged_y = px.TaskSpec("b", _fn, ("a",), tags=("y",))
    graph = px.Graph.from_specs([tagged_x, tagged_y])

    kept = graph.subgraph(["x"]).spec("a")
    assert kept.retries == 3
    assert kept.timeout == 5.0
    assert kept.tags == ("x",)
def test_subgraph_by_tags_no_match() -> None:
    """Selecting by a tag that no task carries yields an empty graph."""
    graph = px.Graph.from_specs([px.TaskSpec("a", _fn, tags=("x",))])
    unmatched = graph.subgraph(["z"])
    assert len(unmatched) == 0
+121
View File
@@ -0,0 +1,121 @@
"""RunReport 测试。"""
from __future__ import annotations
from datetime import datetime
import pyflowx as px
from pyflowx.task import TaskResult, TaskSpec, TaskStatus
def _fn() -> int:
    """Placeholder task callable for building specs; always returns 1."""
    result = 1
    return result
def _make_result(
    name: str = "a",
    status: TaskStatus = TaskStatus.SUCCESS,
    value: object = 42,
    error: object = None,
    duration: float = 0.5,
    attempts: int = 1,
) -> TaskResult[object]:
    """Build a TaskResult for tests, started at 2024-01-01 00:00:00.

    ``finished_at`` is ``started_at`` plus ``duration`` seconds. A falsy
    ``duration`` (for example ``0.0``) leaves ``finished_at`` as ``None``.
    """
    # timedelta expresses fractional seconds exactly (no int() truncation).
    from datetime import timedelta

    began = datetime(2024, 1, 1, 0, 0, 0)
    finished = None
    if duration:
        finished = began + timedelta(seconds=duration)
    task: TaskSpec[object] = TaskSpec(name, _fn)  # type: ignore[arg-type]
    return TaskResult(
        spec=task,
        status=status,
        value=value,  # type: ignore[arg-type]
        error=error,  # type: ignore[arg-type]
        attempts=attempts,
        started_at=began,
        finished_at=finished,
    )
def test_getitem_returns_value() -> None:
    """Indexing a report by task name yields that task's value."""
    report = px.RunReport()
    report.results["a"] = _make_result("a", value=7)
    looked_up = report["a"]
    assert looked_up == 7
def test_result_of_returns_full_result() -> None:
    """result_of() hands back the stored TaskResult object itself."""
    stored = _make_result("a")
    report = px.RunReport()
    report.results["a"] = stored
    assert report.result_of("a") is stored
def test_contains() -> None:
    """``in`` is true for a recorded task name and false for an unknown one."""
    report = px.RunReport()
    report.results["a"] = _make_result("a")
    assert "a" in report
    assert not ("b" in report)
def test_iter_and_len() -> None:
    """Iterating a report yields the recorded names; len() counts them."""
    report = px.RunReport()
    for task_name in ("a", "b"):
        report.results[task_name] = _make_result(task_name)
    assert list(report) == ["a", "b"]
    assert len(report) == 2
def test_summary_success() -> None:
    """summary() reports success, task count, per-status counts and duration."""
    report = px.RunReport()
    fixtures = (
        ("a", TaskStatus.SUCCESS, 1.0),
        ("b", TaskStatus.SKIPPED, 0.0),
    )
    for task_name, task_status, seconds in fixtures:
        report.results[task_name] = _make_result(
            task_name, status=task_status, duration=seconds
        )

    summary = report.summary()
    assert summary["success"] is True
    assert summary["total_tasks"] == 2
    assert summary["by_status"] == {"success": 1, "skipped": 1}
    assert summary["total_duration_seconds"] == 1.0
def test_summary_with_none_duration() -> None:
    """A result without timestamps adds nothing to the total duration."""
    untimed_spec: TaskSpec[object] = TaskSpec("a", _fn)  # type: ignore[arg-type]
    report = px.RunReport()
    report.results["a"] = TaskResult(spec=untimed_spec, status=TaskStatus.FAILED)
    total = report.summary()["total_duration_seconds"]
    assert total == 0.0
def test_failed_tasks() -> None:
    """failed_tasks() lists the FAILED task and omits the successful one."""
    report = px.RunReport()
    passed = _make_result("a", status=TaskStatus.SUCCESS)
    broken = _make_result("b", status=TaskStatus.FAILED, error=ValueError("x"))
    report.results["a"] = passed
    report.results["b"] = broken
    assert report.failed_tasks() == ["b"]
def test_describe_success() -> None:
    """describe() shows the success flag, the task status and its duration."""
    report = px.RunReport()
    report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5)
    text = report.describe()
    for fragment in ("RunReport(success=True)", "a: success", "0.500s"):
        assert fragment in text
def test_describe_with_error() -> None:
    """describe() of a failed run shows success=False and the error type."""
    report = px.RunReport(success=False)
    failure = _make_result(
        "a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1
    )
    report.results["a"] = failure
    text = report.describe()
    assert "success=False" in text
    assert "error=ValueError" in text
def test_describe_no_duration() -> None:
    """describe() of a result without timestamps contains a "-" placeholder."""
    untimed_spec: TaskSpec[object] = TaskSpec("a", _fn)  # type: ignore[arg-type]
    report = px.RunReport()
    report.results["a"] = TaskResult(spec=untimed_spec, status=TaskStatus.PENDING)
    text = report.describe()
    # Per the original note, the missing duration is rendered as "-".
    assert "-" in text
+162
View File
@@ -0,0 +1,162 @@
"""状态后端测试。"""
from __future__ import annotations
import json
import os
import tempfile
from typing import Any
import pytest
from pyflowx.errors import StorageError
from pyflowx.storage import JSONBackend, MemoryBackend, StateBackend, resolve_backend
# ---------------------------------------------------------------------- #
# MemoryBackend
# ---------------------------------------------------------------------- #
def test_memory_backend_lifecycle() -> None:
    """Exercise has/save/get/load/clear on a fresh MemoryBackend."""
    backend = MemoryBackend()
    assert not backend.has("a")

    backend.save("a", 1)
    assert backend.has("a")
    assert backend.get("a") == 1
    assert dict(backend.load()) == {"a": 1}

    backend.clear()
    assert not backend.has("a")
    assert dict(backend.load()) == {}
def test_memory_backend_get_missing_raises() -> None:
    """get() on a key that was never saved raises KeyError."""
    backend = MemoryBackend()
    expects_key_error = pytest.raises(KeyError)
    with expects_key_error:
        backend.get("nope")
# ---------------------------------------------------------------------- #
# JSONBackend
# ---------------------------------------------------------------------- #
def test_json_backend_save_and_load() -> None:
    """Values saved by one JSONBackend are read by a new one on the same path."""
    with tempfile.TemporaryDirectory() as workdir:
        state_file = os.path.join(workdir, "state.json")
        writer = JSONBackend(state_file)
        for key, payload in (("a", {"x": 1}), ("b", [1, 2, 3])):
            writer.save(key, payload)

        # A second instance on the same file must see the saved content.
        reader = JSONBackend(state_file)
        assert reader.has("a")
        assert reader.get("a") == {"x": 1}
        assert reader.get("b") == [1, 2, 3]
        assert dict(reader.load()) == {"a": {"x": 1}, "b": [1, 2, 3]}
def test_json_backend_clear() -> None:
    """clear() forgets saved keys and leaves an empty JSON object on disk."""
    with tempfile.TemporaryDirectory() as workdir:
        state_file = os.path.join(workdir, "state.json")
        backend = JSONBackend(state_file)
        backend.save("a", 1)
        backend.clear()
        assert not backend.has("a")
        # The file itself should now contain an empty dict.
        with open(state_file, "r", encoding="utf-8") as handle:
            on_disk = json.load(handle)
        assert on_disk == {}
def test_json_backend_nonexistent_file_starts_empty() -> None:
    """Pointing a JSONBackend at a missing file starts with empty state."""
    with tempfile.TemporaryDirectory() as workdir:
        missing_file = os.path.join(workdir, "absent.json")
        backend = JSONBackend(missing_file)
        loaded = dict(backend.load())
        assert loaded == {}
        assert not backend.has("anything")
def test_json_backend_non_serialisable_raises() -> None:
    """Saving a non-JSON-serialisable value raises StorageError and stores nothing."""
    with tempfile.TemporaryDirectory() as workdir:
        backend = JSONBackend(os.path.join(workdir, "state.json"))
        unserialisable = object()  # a bare object() cannot be JSON-encoded
        with pytest.raises(StorageError):
            backend.save("a", unserialisable)
        # The failed save must not leave the key behind.
        assert not backend.has("a")
def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """A TypeError raised by ``json.dump`` during save becomes StorageError.

    ``json.dump`` is replaced with a stub that always raises ``TypeError``.
    Per the original note, this simulates a value that passes the
    validation done in ``save`` but fails when dumped to the file handle.

    The patch is scoped with ``monkeypatch.context()``, so it is undone
    automatically, including when an assertion fails. No manual restore is
    needed.
    """

    def failing_dump(*args: Any, **kwargs: Any) -> None:
        raise TypeError("simulated flush failure")

    with tempfile.TemporaryDirectory() as tmp:
        backend = JSONBackend(os.path.join(tmp, "state.json"))
        with monkeypatch.context() as patched:
            patched.setattr(json, "dump", failing_dump)
            with pytest.raises(StorageError, match="cannot write"):
                backend.save("a", 1)
def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None:
    """An OSError raised by ``os.replace`` during save becomes StorageError.

    ``os.replace`` is replaced with a stub that always raises ``OSError``.
    The patch is scoped with ``monkeypatch.context()``, so it is undone
    automatically before the temporary directory is cleaned up, including
    when an assertion fails.
    """

    def failing_replace(*args: Any, **kwargs: Any) -> None:
        raise OSError("simulated os.replace failure")

    with tempfile.TemporaryDirectory() as tmp:
        backend = JSONBackend(os.path.join(tmp, "state.json"))
        with monkeypatch.context() as patched:
            patched.setattr(os, "replace", failing_replace)
            with pytest.raises(StorageError, match="cannot write"):
                backend.save("a", 1)
def test_json_backend_corrupt_file_raises() -> None:
    """Opening a file that holds invalid JSON raises StorageError."""
    with tempfile.TemporaryDirectory() as workdir:
        state_file = os.path.join(workdir, "state.json")
        with open(state_file, "w", encoding="utf-8") as handle:
            handle.write("{not valid json")
        expects_storage_error = pytest.raises(StorageError)
        with expects_storage_error:
            JSONBackend(state_file)
def test_json_backend_non_dict_content_ignored() -> None:
    """Valid JSON that is not an object is ignored; the backend stays empty."""
    with tempfile.TemporaryDirectory() as workdir:
        state_file = os.path.join(workdir, "state.json")
        # Write a JSON list rather than a JSON object.
        serialised = json.dumps([1, 2, 3])
        with open(state_file, "w", encoding="utf-8") as handle:
            handle.write(serialised)
        backend = JSONBackend(state_file)
        assert dict(backend.load()) == {}
# ---------------------------------------------------------------------- #
# resolve_backend
# ---------------------------------------------------------------------- #
def test_resolve_backend_returns_input() -> None:
    """resolve_backend() returns the very backend instance it was given."""
    existing = MemoryBackend()
    resolved = resolve_backend(existing)
    assert resolved is existing
def test_resolve_backend_creates_memory_when_none() -> None:
    """resolve_backend(None) yields a MemoryBackend instance."""
    fallback = resolve_backend(None)
    assert isinstance(fallback, MemoryBackend)
def test_state_backend_is_abstract() -> None:
    """Instantiating StateBackend directly raises TypeError."""
    expects_type_error = pytest.raises(TypeError)
    with expects_type_error:
        StateBackend()  # type: ignore[abstract]
+66
View File
@@ -0,0 +1,66 @@
"""TaskSpec / TaskResult 数据结构测试。"""
from __future__ import annotations
from datetime import datetime
import pytest
import pyflowx as px
from pyflowx.task import TaskResult, TaskSpec, TaskStatus
def _fn() -> None:
    """Placeholder task callable for building specs; returns None."""
    result = None
    return result
def test_spec_empty_name_rejected() -> None:
    """An empty task name raises ValueError mentioning "non-empty"."""
    rejects = pytest.raises(ValueError, match="non-empty")
    with rejects:
        TaskSpec("", _fn)
def test_spec_negative_retries_rejected() -> None:
    """A negative ``retries`` raises ValueError mentioning "retries"."""
    rejects = pytest.raises(ValueError, match="retries")
    with rejects:
        TaskSpec("a", _fn, retries=-1)
def test_spec_zero_timeout_rejected() -> None:
    """A ``timeout`` of zero raises ValueError mentioning "timeout"."""
    rejects = pytest.raises(ValueError, match="timeout")
    with rejects:
        TaskSpec("a", _fn, timeout=0)
def test_spec_self_dependency_rejected() -> None:
    """A task that lists itself in ``depends_on`` raises ValueError."""
    rejects = pytest.raises(ValueError, match="depend on itself")
    with rejects:
        TaskSpec("a", _fn, depends_on=("a",))
def test_task_result_duration_none_when_not_started() -> None:
    """A result with no timestamps reports ``duration`` as None."""
    unstarted: TaskResult[None] = TaskResult(spec=TaskSpec("a", _fn))
    assert unstarted.duration is None
def test_task_result_duration_when_partial() -> None:
    """With ``started_at`` set but no ``finished_at``, ``duration`` is None."""
    task: TaskSpec[None] = TaskSpec("a", _fn)
    began = datetime.now()
    in_flight: TaskResult[None] = TaskResult(spec=task, started_at=began)
    assert in_flight.duration is None
def test_task_result_duration_computed() -> None:
    """Timestamps five seconds apart give a ``duration`` of 5.0."""
    task: TaskSpec[None] = TaskSpec("a", _fn)
    timing = {
        "started_at": datetime(2024, 1, 1, 0, 0, 0),
        "finished_at": datetime(2024, 1, 1, 0, 0, 5),
    }
    finished: TaskResult[None] = TaskResult(spec=task, **timing)
    assert finished.duration == 5.0
def test_task_result_default_status() -> None:
    """A fresh TaskResult is PENDING with no value, no error and zero attempts."""
    fresh: TaskResult[None] = TaskResult(spec=TaskSpec("a", _fn))
    assert fresh.status == TaskStatus.PENDING
    assert fresh.value is None
    assert fresh.error is None
    assert fresh.attempts == 0