feat: 新增多项核心功能并优化默认执行策略

1.  将CliRunner默认执行策略从sequential改为dependency
2.  新增RunReport的任务状态查询和时长统计方法
3.  实现task装饰器并补充executor参数文档
4.  新增进程池执行器支持CPU密集型任务
5.  新增Graph.chain链式构建和add_subgraph子图合并功能
6.  新增流式任务传递、进程池执行、命名空间等多类测试用例
7.  补充tests目录路径导入配置
This commit is contained in:
2026-06-28 15:10:15 +08:00
parent 232e7293d9
commit 40f641611b
15 changed files with 907 additions and 15 deletions
+2
View File
@@ -94,6 +94,7 @@ from .task import (
TaskResult, TaskResult,
TaskSpec, TaskSpec,
TaskStatus, TaskStatus,
task,
task_template, task_template,
) )
@@ -139,5 +140,6 @@ __all__ = [
"describe_injection", "describe_injection",
"run", "run",
"run_command", "run_command",
"task",
"task_template", "task_template",
] ]
+65 -9
View File
@@ -58,6 +58,31 @@ from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# 进程池复用:同一次 run() 内的 process 任务共享一个 ProcessPoolExecutor。
# 模块级缓存避免每次任务都创建/销毁进程池的开销。
_process_pool: concurrent.futures.ProcessPoolExecutor | None = None
_process_pool_lock = threading.Lock()
def _get_process_pool() -> concurrent.futures.ProcessPoolExecutor:
    """Return the shared process pool, creating it on first use.

    The executor is cached in the module-level ``_process_pool`` variable, so
    every later call returns the same instance. Creation is guarded by
    ``_process_pool_lock``.
    """
    global _process_pool  # noqa: PLW0603
    # Fast path: once the pool exists it is returned without taking the lock.
    if _process_pool is not None:
        return _process_pool
    with _process_pool_lock:
        # Re-check under the lock: another thread may have created the pool
        # between the unlocked check above and acquiring the lock.
        if _process_pool is None:
            _process_pool = concurrent.futures.ProcessPoolExecutor()
        return _process_pool
def _run_in_process(fn: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
"""模块级函数:在进程池中执行任务(须可 pickle)。
env_context 等上下文管理器无法跨进程传递,进程池任务的 ``env``/``cwd``
不生效;如需设置环境,应在 ``fn`` 内部自行处理。
"""
return fn(*args, **kwargs)
# 观察者回调类型。 # 观察者回调类型。
EventCallback = Callable[[TaskEvent], None] EventCallback = Callable[[TaskEvent], None]
Strategy = Literal["sequential", "thread", "async", "dependency"] Strategy = Literal["sequential", "thread", "async", "dependency"]
@@ -391,19 +416,50 @@ async def _execute_async_task(
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
) -> Any: ) -> Any:
"""执行异步或同步任务(带超时处理)。""" """执行异步或同步任务(带超时处理)。"""
# 异步任务直接 await
if _is_async_fn(spec): if _is_async_fn(spec):
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None: return await asyncio.wait_for(coro, timeout=spec.timeout) if spec.timeout is not None else await coro
return await asyncio.wait_for(coro, timeout=spec.timeout)
return await coro # 同步任务:根据 executor 选择执行器
fut = _submit_sync_task(spec, args, kwargs, loop)
return await asyncio.wait_for(fut, timeout=spec.timeout) if spec.timeout is not None else await fut
def _submit_sync_task(
spec: TaskSpec[Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
loop: asyncio.AbstractEventLoop,
) -> asyncio.Future[Any]:
"""提交同步任务到对应执行器,返回 Future。
* ``inline``:直接在事件循环线程调用(阻塞循环,最快)。
* ``process``:进程池执行(绕过 GILfn 须可 pickle)。
* ``thread``(默认):线程池执行。
"""
def fn_call() -> Any: def fn_call() -> Any:
with spec.env_context(): with spec.env_context():
return spec.effective_fn(*args, **kwargs) return spec.effective_fn(*args, **kwargs)
if spec.timeout is not None: # inline:直接在事件循环线程调用,无线程池开销,但会阻塞循环。
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) if spec.executor == "inline":
return await loop.run_in_executor(None, fn_call) result = fn_call()
fut: asyncio.Future[Any] = loop.create_future()
fut.set_result(result)
return fut
# process:进程池执行,绕过 GIL,适合 CPU 密集型任务(fn 须可 pickle)。
if spec.executor == "process":
from functools import partial
pool = _get_process_pool()
proc_fn = partial(_run_in_process, spec.effective_fn, args, kwargs)
return loop.run_in_executor(pool, proc_fn)
# thread(默认):线程池执行。
return loop.run_in_executor(None, fn_call)
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -662,7 +718,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
def run( def run(
graph: Graph, graph: Graph,
strategy: Strategy = "sequential", strategy: Strategy = "dependency",
*, *,
max_workers: int | None = None, max_workers: int | None = None,
dry_run: bool = False, dry_run: bool = False,
@@ -678,8 +734,8 @@ def run(
graph: graph:
待执行的已校验 :class:`Graph`。 待执行的已校验 :class:`Graph`。
strategy: strategy:
执行策略: ``"sequential"`` / ``"thread"`` / ``"async"`` / 执行策略: ``"dependency"``(默认,依赖驱动无层屏障,最大并行度)/
``"dependency"````"dependency"`` 为依赖驱动调度,无层屏障 ``"sequential"`` / ``"thread"`` / ``"async"``(层屏障模型)
max_workers: max_workers:
``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。 ``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。
dry_run: dry_run:
+138 -2
View File
@@ -17,12 +17,13 @@ __all__ = [
"GraphDefaults", "GraphDefaults",
] ]
import inspect
import sys import sys
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import Any, Callable, Iterable, Mapping, Sequence from typing import Any, Callable, Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import RetryPolicy, TaskSpec from .task import Context, RetryPolicy, TaskSpec
if sys.version_info >= (3, 9): # pragma: no cover if sys.version_info >= (3, 9): # pragma: no cover
import graphlib # pyright: ignore[reportUnreachable] import graphlib # pyright: ignore[reportUnreachable]
@@ -63,6 +64,74 @@ def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[An
) )
def _make_namespaced_fn(orig_fn: Any, ns: str, dep_names: set[str]) -> Any:
"""包装 fn,使其能接收带 ``ns:`` 前缀的依赖名,调用时映射回原参数名。
命名空间合并后,依赖名带前缀(如 ``build:extract``),但 Python 参数名
不能含 ``:``。wrapper 用 ``**kwargs`` 接收所有依赖,内部把带前缀的依赖名
映射回原参数名后调用原 fn。
无依赖参数时直接返回原 fn。
"""
if not dep_names or orig_fn is None:
return orig_fn
try:
orig_sig = inspect.signature(orig_fn)
except (TypeError, ValueError):
return orig_fn
# 带前缀依赖名 -> 原参数名
name_map: dict[str, str] = {f"{ns}:{orig}": orig for orig in dep_names}
prefix = f"{ns}:"
# 检查原 fn 是否有 Context 标注参数
context_param_name: str | None = None
for p in orig_sig.parameters.values():
ann = p.annotation
if ann is not Context and not (isinstance(ann, str) and ann.endswith("Context")):
continue
context_param_name = p.name
break
if context_param_name is not None:
def wrapper(ctx: Any = None, **kwargs: Any) -> Any:
# ctx 是 dep_context,键为带前缀的依赖名;映射回原始键
orig_ctx: dict[str, Any] = {}
for k, v in (ctx or {}).items():
orig_ctx[name_map.get(k, k)] = v
# kwargs 中带前缀的依赖也映射回原参数名
for k, v in kwargs.items():
if k in name_map:
orig_ctx[name_map[k]] = v
return orig_fn(**{context_param_name: orig_ctx})
ctx_param = inspect.Parameter("ctx", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Context)
kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=[ctx_param, kw_param],
return_annotation=orig_sig.return_annotation,
)
else:
def wrapper(**kwargs: Any) -> Any: # type: ignore[no-redef]
orig_kwargs: dict[str, Any] = {}
for k, v in kwargs.items():
if k.startswith(prefix):
orig_kwargs[k[len(prefix) :]] = v
return orig_fn(**orig_kwargs)
kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined]
parameters=[kw_param],
return_annotation=orig_sig.return_annotation,
)
wrapper.__name__ = f"{ns}_{getattr(orig_fn, '__name__', 'fn')}"
wrapper.__doc__ = getattr(orig_fn, "__doc__", None)
return wrapper
@dataclass @dataclass
class Graph: class Graph:
"""校验后的有向无环任务图。 """校验后的有向无环任务图。
@@ -78,6 +147,7 @@ class Graph:
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict) specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict) deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
defaults: GraphDefaults = field(default_factory=GraphDefaults) defaults: GraphDefaults = field(default_factory=GraphDefaults)
namespace: str | None = None
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。 # 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
_pending_refs: list[str] = field(default_factory=list) _pending_refs: list[str] = field(default_factory=list)
@@ -95,6 +165,28 @@ class Graph:
self._validate_references() self._validate_references()
return self return self
def chain(self, *specs: TaskSpec[Any]) -> Graph:
    """Register ``specs`` in order, making each one depend on the one before it.

    ``chain(a, b, c)`` registers ``a``, then ``b`` depending on ``a``, then
    ``c`` depending on ``b``. When a spec already declares ``depends_on``, the
    predecessor's name is placed in front of the existing entries unless it is
    already listed. Every spec is registered through :meth:`add`.

    Returns ``self`` so calls can be chained.

    Examples
    --------
    >>> graph = px.Graph().chain(extract, transform, load)
    """
    predecessor: str | None = None
    for spec in specs:
        linked = spec
        if predecessor is not None:
            deps = spec.depends_on
            if predecessor not in deps:
                # The implicit link goes first, explicit dependencies follow.
                deps = (predecessor, *deps)
            linked = replace(spec, depends_on=deps)
        self.add(linked)
        predecessor = linked.name
    return self
def _register(self, spec: TaskSpec[Any]) -> None: def _register(self, spec: TaskSpec[Any]) -> None:
if spec.name in self.specs: if spec.name in self.specs:
raise DuplicateTaskError(spec.name) raise DuplicateTaskError(spec.name)
@@ -108,6 +200,8 @@ class Graph:
cls, cls,
specs: Iterable[TaskSpec[Any] | str], specs: Iterable[TaskSpec[Any] | str],
defaults: GraphDefaults | None = None, defaults: GraphDefaults | None = None,
*,
namespace: str | None = None,
) -> Graph: ) -> Graph:
"""从可迭代的 task spec 构建图。 """从可迭代的 task spec 构建图。
@@ -120,8 +214,10 @@ class Graph:
TaskSpec 对象或字符串引用的列表。 TaskSpec 对象或字符串引用的列表。
defaults: defaults:
图级默认值。``None`` 使用空 :class:`GraphDefaults`。 图级默认值。``None`` 使用空 :class:`GraphDefaults`。
namespace:
可选命名空间,用于 :meth:`add_subgraph` 合并时加前缀。
""" """
graph = cls(defaults=defaults or GraphDefaults()) graph = cls(defaults=defaults or GraphDefaults(), namespace=namespace)
pending_refs: list[str] = [] pending_refs: list[str] = []
for spec in specs: for spec in specs:
@@ -139,6 +235,46 @@ class Graph:
graph.validate() graph.validate()
return graph return graph
def add_subgraph(self, sub: Graph, *, namespace: str | None = None) -> Graph:
    """Merge ``sub`` into this graph, prefixing its task names with a namespace.

    Every task of ``sub`` is registered here as ``f"{ns}:{original_name}"``.
    Entries of ``depends_on`` / ``soft_depends_on`` that name tasks inside
    ``sub`` receive the same prefix; names that are not tasks of ``sub`` are
    kept as they are. Task functions are wrapped with
    ``_make_namespaced_fn`` so they can accept the prefixed dependency names.

    The merge is all-or-nothing with respect to name clashes: all renamed
    specs are built and checked against the existing task names before any of
    them is registered.

    Parameters
    ----------
    sub:
        Sub-graph to merge.
    namespace:
        Prefix to use. Falls back to ``sub.namespace`` when not given.

    Returns
    -------
    ``self``, so calls can be chained.

    Raises
    ------
    ValueError
        If neither ``namespace`` nor ``sub.namespace`` is set.
    DuplicateTaskError
        If a prefixed task name already exists in this graph.
    """
    ns = namespace or sub.namespace
    if not ns:
        raise ValueError("add_subgraph 需要 namespace 或子图自带 namespace")

    sub_names = set(sub.specs)

    def _rename(name: str) -> str:
        # Only names of tasks inside the sub-graph are prefixed.
        return f"{ns}:{name}" if name in sub_names else name

    renamed: list[Any] = []
    for spec in sub.specs.values():
        # Dependencies that live inside the sub-graph get renamed, so the
        # task function must be wrapped to accept the prefixed names.
        internal_deps = (set(spec.depends_on) | set(spec.soft_depends_on)) & sub_names
        new_fn = _make_namespaced_fn(spec.fn, ns, internal_deps) if spec.fn else spec.fn
        renamed.append(
            replace(
                spec,
                name=_rename(spec.name),
                fn=new_fn,
                depends_on=tuple(_rename(d) for d in spec.depends_on),
                soft_depends_on=tuple(_rename(d) for d in spec.soft_depends_on),
            )
        )

    # Detect clashes before touching this graph so that a failure cannot
    # leave it half-merged.
    for new_spec in renamed:
        if new_spec.name in self.specs:
            raise DuplicateTaskError(new_spec.name)

    for new_spec in renamed:
        self._register(new_spec)

    self._validate_references()
    self.validate()
    return self
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# 校验 # 校验
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
+16
View File
@@ -69,6 +69,22 @@ class RunReport:
"""以 FAILED 状态结束的任务名列表。""" """以 FAILED 状态结束的任务名列表。"""
return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED] return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED]
def succeeded_tasks(self) -> list[str]:
    """Return the names of the tasks whose result status is ``TaskStatus.SUCCESS``.

    Names are listed in the iteration order of ``self.results``.
    """
    return self.tasks_by_status(TaskStatus.SUCCESS)
def skipped_tasks(self) -> list[str]:
    """Return the names of the tasks whose result status is ``TaskStatus.SKIPPED``.

    Names are listed in the iteration order of ``self.results``.
    """
    return self.tasks_by_status(TaskStatus.SKIPPED)
def tasks_by_status(self, status: TaskStatus) -> list[str]:
    """Return the names of the tasks whose result status equals ``status``.

    Names are listed in the iteration order of ``self.results``; the list is
    empty when no task matches.
    """
    matching: list[str] = []
    for task_name, result in self.results.items():
        if result.status == status:
            matching.append(task_name)
    return matching
def durations(self) -> dict[str, float]:
    """Map every task name in ``self.results`` to its duration in seconds.

    A result whose ``duration`` is missing (``None``) or otherwise falsy is
    reported as ``0.0``.
    """
    timings: dict[str, float] = {}
    for task_name, result in self.results.items():
        elapsed = result.duration
        timings[task_name] = elapsed if elapsed else 0.0
    return timings
def describe(self) -> str: def describe(self) -> str:
"""用于调试的人类可读多行报告。""" """用于调试的人类可读多行报告。"""
lines: list[str] = [f"RunReport(success={self.success})"] lines: list[str] = [f"RunReport(success={self.success})"]
+1 -1
View File
@@ -114,7 +114,7 @@ class CliRunner:
""" """
graphs: dict[str, Graph] = field(default_factory=dict) graphs: dict[str, Graph] = field(default_factory=dict)
strategy: Strategy = field(default="sequential") strategy: Strategy = field(default="dependency")
description: str = field(default_factory=str) description: str = field(default_factory=str)
verbose: bool = field(default_factory=lambda: True) verbose: bool = field(default_factory=lambda: True)
+88
View File
@@ -254,6 +254,10 @@ class TaskSpec(Generic[T]):
存取状态后端使不同输入产生独立缓存条目``None`` 表示用任务名 存取状态后端使不同输入产生独立缓存条目``None`` 表示用任务名
hooks: hooks:
:class:`TaskHooks` 生命周期钩子 :class:`TaskHooks` 生命周期钩子
executor:
同步任务的执行器:``"thread"``(默认,线程池)/ ``"process"``
(进程池,绕过 GIL,适合 CPU 密集型;``fn`` 须可 pickle)/
``"inline"``(直接在事件循环线程调用,最快但会阻塞循环)。
""" """
name: str name: str
@@ -279,6 +283,7 @@ class TaskSpec(Generic[T]):
continue_on_error: bool = False continue_on_error: bool = False
cache_key: CacheKeyFn | None = None cache_key: CacheKeyFn | None = None
hooks: TaskHooks = field(default_factory=TaskHooks) hooks: TaskHooks = field(default_factory=TaskHooks)
executor: str = "thread" # "thread" | "process" | "inline"
def __post_init__(self) -> None: def __post_init__(self) -> None:
if not self.name: if not self.name:
@@ -447,6 +452,89 @@ def _env_and_cwd(
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 任务模板:批量生成相似 TaskSpec 的工厂 # 任务模板:批量生成相似 TaskSpec 的工厂
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _task_noop() -> None:
"""task(cmd=...) 形式下的占位 fn(cmd 任务执行期不调用 fn)。"""
return None
def task(
    fn: TaskFn[Any] | None = None,
    *,
    cmd: TaskCmd | None = None,
    depends_on: tuple[str, ...] = (),
    soft_depends_on: tuple[str, ...] = (),
    defaults: Mapping[str, Any] | None = None,
    args: tuple[Any, ...] = (),
    kwargs: Mapping[str, Any] | None = None,
    retry: RetryPolicy | None = None,
    timeout: float | None = None,
    tags: tuple[str, ...] = (),
    conditions: tuple[Condition, ...] = (),
    cwd: str | Path | None = None,
    env: Mapping[str, str] | None = None,
    verbose: bool = False,
    skip_if_missing: bool = False,
    allow_upstream_skip: bool = False,
    strategy: str | None = None,
    priority: int = 0,
    concurrency_key: str | None = None,
    continue_on_error: bool = False,
    cache_key: CacheKeyFn | None = None,
    hooks: TaskHooks | None = None,
    name: str | None = None,
    executor: str = "thread",
) -> Any:
    """Turn a function into a :class:`TaskSpec`, usable with or without arguments.

    Three call forms are supported:

    * ``@task`` -- decorate a function directly; returns the ``TaskSpec``.
    * ``@task(depends_on=...)`` -- called with options only (neither ``fn``
      nor ``cmd``); returns a decorator that builds the ``TaskSpec``.
    * ``task(cmd=..., name=...)`` -- build a command task directly; ``name``
      is mandatory and ``_task_noop`` is stored as the placeholder ``fn``.

    Parameters
    ----------
    fn:
        Function to wrap. ``None`` selects one of the option-only forms.
    name:
        Task name. Defaults to ``fn.__name__``.
    executor:
        Forwarded to ``TaskSpec.executor`` (``"thread"`` / ``"process"`` /
        ``"inline"``); defaults to ``"thread"``.
    cmd, depends_on, soft_depends_on, args, timeout, tags, conditions, verbose,
    skip_if_missing, allow_upstream_skip, strategy, priority, concurrency_key,
    continue_on_error, cache_key:
        Forwarded unchanged to the ``TaskSpec`` field of the same name.
    defaults, kwargs:
        Copied into a new ``dict`` (empty when omitted or empty).
    retry, hooks:
        ``None`` is replaced by a fresh ``RetryPolicy()`` / ``TaskHooks()``.
    cwd:
        A ``str`` is converted to :class:`~pathlib.Path`.
    env:
        Copied into a new ``dict``; an omitted or empty mapping becomes ``None``.

    Raises
    ------
    ValueError
        If ``cmd`` is given without ``fn`` and without ``name``, or if no name
        is given and the function has no ``__name__`` to fall back on.

    Examples
    --------
    >>> @px.task
    ... def extract(): return [1, 2, 3]
    >>> @px.task(depends_on=("extract",))
    ... def double(extract): return [x * 2 for x in extract]
    >>> graph = px.Graph.from_specs([extract, double])
    """

    def _decorate(func: TaskFn[Any]) -> TaskSpec[Any]:
        # Callables such as functools.partial carry no __name__; fail with a
        # clear message instead of an AttributeError.
        spec_name = name or getattr(func, "__name__", None)
        if not spec_name:
            raise ValueError("task 无法从 fn 推断名称,需要显式提供 name")
        return TaskSpec(
            name=spec_name,
            fn=func,
            cmd=cmd,
            depends_on=depends_on,
            soft_depends_on=soft_depends_on,
            defaults=dict(defaults) if defaults else {},
            args=args,
            kwargs=dict(kwargs) if kwargs else {},
            retry=retry if retry is not None else RetryPolicy(),
            timeout=timeout,
            tags=tags,
            conditions=conditions,
            cwd=Path(cwd) if isinstance(cwd, str) else cwd,
            env=dict(env) if env else None,
            verbose=verbose,
            skip_if_missing=skip_if_missing,
            allow_upstream_skip=allow_upstream_skip,
            strategy=strategy,
            priority=priority,
            concurrency_key=concurrency_key,
            continue_on_error=continue_on_error,
            cache_key=cache_key,
            hooks=hooks if hooks is not None else TaskHooks(),
            executor=executor,
        )

    if fn is None and cmd is None:
        # Option-only call, e.g. @task(depends_on=...): wait for the function.
        return _decorate
    if fn is None:
        # task(cmd=..., name=...): no function to decorate, build directly.
        if name is None:
            raise ValueError("task(cmd=...) 需要显式提供 name")
        return _decorate(_task_noop)
    return _decorate(fn)
def task_template( def task_template(
fn: TaskFn[Any] | None = None, fn: TaskFn[Any] | None = None,
cmd: TaskCmd | None = None, cmd: TaskCmd | None = None,
+26
View File
@@ -0,0 +1,26 @@
"""进程池测试辅助:模块级函数(须可 pickle)。"""
from __future__ import annotations
import time
def cpu_heavy(n: int) -> int:
    """Return the sum of ``i * i`` for every ``i`` in ``range(n)`` (CPU-bound work)."""
    total = 0
    for i in range(n):
        total += i * i
    return total
def add(a: int, b: int) -> int:
    """Return ``a + b``."""
    total = a + b
    return total
def sub(a: int, b: int) -> int:
    """Return ``a - b``."""
    difference = a - b
    return difference
def slow_sleep(seconds: float) -> int:
    """Sleep for ``seconds`` and return ``int(seconds)``; used by the timeout tests."""
    time.sleep(seconds)
    whole_seconds = int(seconds)
    return whole_seconds
+7
View File
@@ -1,9 +1,16 @@
from __future__ import annotations from __future__ import annotations
import sys
from pathlib import Path from pathlib import Path
import pytest import pytest
# 将 tests 目录加入 sys.path,使进程池测试能 import _proc_helper 模块级辅助函数。
# 进程池 pickle 要求被调用函数为模块级,conftest.py 在 xdist worker 中也会执行。
_TESTS_DIR = str(Path(__file__).resolve().parent)
if _TESTS_DIR not in sys.path:
sys.path.insert(0, _TESTS_DIR)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def packtool_tmp_workdir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: def packtool_tmp_workdir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+101
View File
@@ -0,0 +1,101 @@
"""Tests for Graph.chain DSL."""
from __future__ import annotations
import pyflowx as px
from pyflowx.task import TaskSpec
def _fn() -> None:
return None
def test_chain_basic_linkage() -> None:
"""chain(a, b, c) 应建立 a->b->c 依赖."""
a = TaskSpec("a", _fn)
b = TaskSpec("b", _fn)
c = TaskSpec("c", _fn)
graph = px.Graph().chain(a, b, c)
assert graph.all_specs()["b"].depends_on == ("a",)
assert graph.all_specs()["c"].depends_on == ("b",)
assert graph.all_specs()["a"].depends_on == ()
def test_chain_single_spec() -> None:
"""chain(a) 应只注册 a,无依赖."""
a = TaskSpec("a", _fn)
graph = px.Graph().chain(a)
assert "a" in graph
assert graph.all_specs()["a"].depends_on == ()
def test_chain_preserves_existing_deps() -> None:
"""chain 应保留 spec 已有的 depends_on."""
a = TaskSpec("a", _fn)
b = TaskSpec("b", _fn)
c = TaskSpec("c", _fn, depends_on=("b",))
graph = px.Graph().chain(a, b, c)
# c 已有 depends_on=('b',),前驱是 b,已在依赖中,不重复添加
assert graph.all_specs()["c"].depends_on == ("b",)
def test_chain_merges_existing_deps() -> None:
    """chain must put the predecessor in front of dependencies a spec already declares."""
    x = TaskSpec("x", _fn)
    a = TaskSpec("a", _fn)
    c = TaskSpec("c", _fn, depends_on=("x",))
    # Register x on its own first so c's explicit dependency is already known,
    # then chain a -> c: the predecessor of c is a, which c does not list yet.
    graph = px.Graph().chain(x).chain(a, c)
    assert graph.all_specs()["a"].depends_on == ()
    assert graph.all_specs()["c"].depends_on == ("a", "x")
def test_chain_returns_self() -> None:
"""chain 返回 self 支持链式调用."""
a = TaskSpec("a", _fn)
graph = px.Graph()
assert graph.chain(a) is graph
def test_chain_execution_order() -> None:
"""chain 应保证执行顺序."""
order: list[str] = []
def make(name: str):
def fn() -> str:
order.append(name)
return name
return fn
a = TaskSpec("a", make("a"))
b = TaskSpec("b", make("b"))
c = TaskSpec("c", make("c"))
graph = px.Graph().chain(a, b, c)
report = px.run(graph)
assert report.success
assert order == ["a", "b", "c"]
def test_chain_with_decorator_specs() -> None:
"""chain 应与 @task 装饰器配合."""
@px.task
def extract() -> int:
return 1
@px.task
def transform(extract: int) -> int:
return extract + 10
@px.task
def load(transform: int) -> int:
return transform + 100
graph = px.Graph().chain(extract, transform, load)
report = px.run(graph)
assert report.success
assert report["load"] == 111
+62
View File
@@ -0,0 +1,62 @@
"""Tests for process executor (spec.executor='process')."""
from __future__ import annotations
import pytest
# pyrefly: ignore[missing-import]
from _proc_helper import add, cpu_heavy, slow_sleep, sub
import pyflowx as px
from pyflowx.errors import TaskFailedError
def test_process_executor_runs_cpu_task() -> None:
"""executor='process' 应在进程池中执行 CPU 密集型任务."""
spec = px.TaskSpec("cpu", fn=cpu_heavy, args=(1000,), executor="process")
graph = px.Graph.from_specs([spec])
report = px.run(graph)
assert report.success
assert report["cpu"] == sum(i * i for i in range(1000))
def test_process_executor_with_dependency() -> None:
"""进程池任务应支持依赖注入."""
spec1 = px.TaskSpec("a", fn=cpu_heavy, args=(100,), executor="process")
spec2 = px.TaskSpec("b", fn=add, args=(3, 4), executor="process", depends_on=("a",))
graph = px.Graph.from_specs([spec1, spec2])
report = px.run(graph)
assert report.success
assert report["b"] == 7
def test_process_executor_default_is_thread() -> None:
"""TaskSpec.executor 默认应为 'thread'."""
spec = px.TaskSpec("x", fn=lambda: None)
assert spec.executor == "thread"
def test_inline_executor_runs_in_event_loop() -> None:
"""executor='inline' 应直接在事件循环线程调用."""
spec = px.TaskSpec("inline", fn=add, args=(10, 20), executor="inline")
graph = px.Graph.from_specs([spec])
report = px.run(graph)
assert report.success
assert report["inline"] == 30
def test_process_executor_with_kwargs() -> None:
"""进程池任务应支持 kwargs 注入."""
spec = px.TaskSpec("kw", fn=sub, args=(10,), kwargs={"b": 3}, executor="process")
graph = px.Graph.from_specs([spec])
report = px.run(graph)
assert report.success
assert report["kw"] == 7
def test_process_executor_timeout() -> None:
"""进程池任务超时应抛 TaskFailedError."""
spec = px.TaskSpec("slow", fn=slow_sleep, args=(10.0,), executor="process", timeout=0.1)
graph = px.Graph.from_specs([spec])
with pytest.raises(TaskFailedError):
px.run(graph)
+152
View File
@@ -0,0 +1,152 @@
"""Tests for Graph namespace and add_subgraph."""
from __future__ import annotations
import pytest
import pyflowx as px
def _fn() -> None:
return None
def test_graph_namespace_field_default_none() -> None:
"""Graph 默认 namespace 为 None."""
graph = px.Graph()
assert graph.namespace is None
def test_graph_from_specs_with_namespace() -> None:
"""from_specs(namespace=...) 应设置 graph.namespace."""
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="ns1")
assert graph.namespace == "ns1"
def test_add_subgraph_prefixes_task_names() -> None:
"""add_subgraph 应给子图任务名加命名空间前缀."""
sub = px.Graph.from_specs(
[px.TaskSpec("extract", _fn), px.TaskSpec("build", _fn, depends_on=("extract",))],
namespace="build",
)
main = px.Graph.from_specs([px.TaskSpec("start", _fn)])
main.add_subgraph(sub)
assert "start" in main
assert "build:extract" in main
assert "build:build" in main
def test_add_subgraph_renames_internal_deps() -> None:
"""add_subgraph 应给子图内部依赖名加前缀."""
sub = px.Graph.from_specs(
[px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",))],
namespace="ns",
)
main = px.Graph()
main.add_subgraph(sub)
b_spec = main.all_specs()["ns:b"]
assert b_spec.depends_on == ("ns:a",)
def test_add_subgraph_all_internal_deps_prefixed() -> None:
"""add_subgraph 子图内所有任务(含被依赖的)都加前缀."""
sub = px.Graph.from_specs(
[px.TaskSpec("ext", _fn), px.TaskSpec("b", _fn, depends_on=("ext",))],
namespace="ns",
)
main = px.Graph()
main.add_subgraph(sub)
b_spec = main.all_specs()["ns:b"]
assert b_spec.depends_on == ("ns:ext",)
assert "ns:ext" in main
def test_add_subgraph_requires_namespace() -> None:
"""add_subgraph 无 namespace 时应抛 ValueError."""
sub = px.Graph.from_specs([px.TaskSpec("a", _fn)]) # 无 namespace
main = px.Graph()
with pytest.raises(ValueError, match="namespace"):
main.add_subgraph(sub)
def test_add_subgraph_explicit_namespace_overrides() -> None:
"""add_subgraph(namespace=...) 应覆盖子图自带 namespace."""
sub = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="original")
main = px.Graph()
main.add_subgraph(sub, namespace="override")
assert "override:a" in main
assert "original:a" not in main
def test_add_subgraph_internal_injection_works() -> None:
"""子图内部依赖注入应通过 wrapper 正常工作."""
sub = px.Graph.from_specs(
[
px.TaskSpec("extract", lambda: [1, 2, 3]),
px.TaskSpec("build", lambda extract: [x * 2 for x in extract], depends_on=("extract",)),
],
namespace="build",
)
main = px.Graph()
main.add_subgraph(sub)
report = px.run(main)
assert report.success
assert report["build:build"] == [2, 4, 6]
def test_add_subgraph_cross_namespace_ref_via_context() -> None:
"""跨命名空间引用应通过 Context 标注接收."""
def consumer(ctx: px.Context) -> str:
return f"got {ctx['ns:data']}"
sub = px.Graph.from_specs(
[px.TaskSpec("data", lambda: "data_value")],
namespace="ns",
)
main = px.Graph()
main.add_subgraph(sub)
main.add(px.TaskSpec("consumer", consumer, depends_on=("ns:data",)))
report = px.run(main)
assert report.success
assert report["consumer"] == "got data_value"
def test_add_subgraph_context_annotation_in_subgraph() -> None:
"""子图内部任务用 Context 标注时,wrapper 应正确传递."""
def sink(ctx: px.Context) -> int:
return ctx["src"]
sub = px.Graph.from_specs(
[
px.TaskSpec("src", lambda: 42),
px.TaskSpec("sink", sink, depends_on=("src",)),
],
namespace="ns",
)
main = px.Graph()
main.add_subgraph(sub)
report = px.run(main)
assert report.success
assert report["ns:sink"] == 42
def test_add_subgraph_chained() -> None:
"""多个子图可链式合并到主图."""
sub_a = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="nsA")
sub_b = px.Graph.from_specs([px.TaskSpec("b", _fn)], namespace="nsB")
main = px.Graph()
main.add_subgraph(sub_a).add_subgraph(sub_b)
assert "nsA:a" in main
assert "nsB:b" in main
+47
View File
@@ -126,3 +126,50 @@ class TestRunReportDescribe:
report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING) report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING)
desc = report.describe() desc = report.describe()
assert "-" in desc # duration 显示为 "-" assert "-" in desc # duration 显示为 "-"
class TestRunReportQueries:
"""测试 RunReport 的新查询 API."""
def test_succeeded_tasks(self) -> None:
"""succeeded_tasks 返回 SUCCESS 状态的任务名."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
report.results["b"] = _make_result("b", status=TaskStatus.FAILED)
report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS)
assert report.succeeded_tasks() == ["a", "c"]
def test_skipped_tasks(self) -> None:
"""skipped_tasks 返回 SKIPPED 状态的任务名."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SKIPPED)
report.results["b"] = _make_result("b", status=TaskStatus.SUCCESS)
assert report.skipped_tasks() == ["a"]
def test_tasks_by_status(self) -> None:
"""tasks_by_status 按指定状态过滤."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.FAILED)
report.results["b"] = _make_result("b", status=TaskStatus.FAILED)
report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS)
assert report.tasks_by_status(TaskStatus.FAILED) == ["a", "b"]
assert report.tasks_by_status(TaskStatus.SUCCESS) == ["c"]
assert report.tasks_by_status(TaskStatus.SKIPPED) == []
def test_durations(self) -> None:
"""durations 返回任务名 -> 时长映射."""
report = px.RunReport()
report.results["a"] = _make_result("a", duration=1.5)
report.results["b"] = _make_result("b", duration=2.0)
durs = report.durations()
assert durs["a"] == 1.5
assert durs["b"] == 2.0
def test_durations_no_duration(self) -> None:
"""无时长的任务应返回 0.0."""
report = px.RunReport()
spec: TaskSpec[Any] = TaskSpec[Any]("a", _fn) # type: ignore[arg-type]
report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING)
durs = report.durations()
assert durs["a"] == 0.0
+3 -3
View File
@@ -72,10 +72,10 @@ class TestCliRunnerConstruction:
) )
assert runner.commands == ["clean", "build", "test"] assert runner.commands == ["clean", "build", "test"]
def test_default_strategy_is_sequential(self) -> None: def test_default_strategy_is_dependency(self) -> None:
"""默认策略应为 Strategy.SEQUENTIAL.""" """默认策略应为 dependency(依赖驱动,最大并行度)."""
runner = px.CliRunner({"clean": _echo_graph()}) runner = px.CliRunner({"clean": _echo_graph()})
assert runner.strategy == "sequential" assert runner.strategy == "dependency"
def test_custom_strategy_string(self) -> None: def test_custom_strategy_string(self) -> None:
"""应支持通过字符串指定策略.""" """应支持通过字符串指定策略."""
+63
View File
@@ -0,0 +1,63 @@
"""Tests for streaming result passing (iterators between tasks)."""
from __future__ import annotations
from typing import Iterator
import pyflowx as px
def test_generator_passed_as_iterator() -> None:
"""上游返回生成器,下游应能惰性消费."""
@px.task
def source() -> Iterator[int]:
yield from range(5)
@px.task(depends_on=("source",))
def consume(source: Iterator[int]) -> int:
return sum(source)
graph = px.Graph.from_specs([source, consume])
report = px.run(graph)
assert report.success
assert report["consume"] == 10
def test_large_range_streaming() -> None:
"""大范围迭代器流式传递,避免中间列表."""
@px.task
def numbers() -> Iterator[int]:
yield from range(1000)
@px.task(depends_on=("numbers",))
def total(numbers: Iterator[int]) -> int:
return sum(numbers)
graph = px.Graph.from_specs([numbers, total])
report = px.run(graph)
assert report.success
assert report["total"] == sum(range(1000))
def test_chain_multiple_streams() -> None:
"""多个流式任务串联."""
@px.task
def gen() -> Iterator[int]:
yield from range(10)
@px.task(depends_on=("gen",))
def doubled(gen: Iterator[int]) -> Iterator[int]:
for x in gen:
yield x * 2
@px.task(depends_on=("doubled",))
def collect(doubled: Iterator[int]) -> list[int]:
return list(doubled)
graph = px.Graph.from_specs([gen, doubled, collect])
report = px.run(graph)
assert report.success
assert report["collect"] == [x * 2 for x in range(10)]
+136
View File
@@ -0,0 +1,136 @@
"""Tests for the @task decorator API."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Mapping
import pyflowx as px
from pyflowx.task import RetryPolicy, TaskHooks, TaskSpec
def test_task_decorator_plain() -> None:
"""@task 无参数装饰:name 取函数名,返回 TaskSpec."""
@px.task
def extract() -> list[int]:
return [1, 2, 3]
assert isinstance(extract, TaskSpec)
assert extract.name == "extract"
assert extract.fn is not None
assert extract.depends_on == ()
def test_task_decorator_with_params() -> None:
"""@task(...) 带参数装饰:传递依赖与重试."""
@px.task(depends_on=("extract",), retry=RetryPolicy(max_attempts=3))
def double(extract: list[int]) -> list[int]:
return [x * 2 for x in extract]
assert isinstance(double, TaskSpec)
assert double.name == "double"
assert double.depends_on == ("extract",)
assert double.retry.max_attempts == 3
def test_task_decorator_explicit_name() -> None:
"""@task(name=...) 应使用显式名称而非函数名."""
@px.task(name="custom_name")
def my_func() -> None:
return None
assert my_func.name == "custom_name"
def test_task_decorator_cmd_form() -> None:
"""@task(cmd=...) 应支持命令形式."""
spec = px.task(cmd=["ls", "-la"], name="list_files")
assert isinstance(spec, TaskSpec)
assert spec.name == "list_files"
assert spec.cmd == ["ls", "-la"]
def test_task_decorator_full_options() -> None:
"""@task 应支持全部 TaskSpec 字段."""
@px.task(
depends_on=("a",),
soft_depends_on=("b",),
defaults={"b": 0},
args=(1,),
kwargs={"x": 2},
retry=RetryPolicy(max_attempts=5),
timeout=10.0,
tags=("t1",),
conditions=(px.BuiltinConditions.IS_WINDOWS,), # type: ignore[arg-type]
cwd="/tmp",
env={"K": "v"},
verbose=True,
skip_if_missing=True,
allow_upstream_skip=True,
strategy="thread",
priority=3,
concurrency_key="db",
continue_on_error=True,
)
def f(a: int) -> int:
return a
assert f.depends_on == ("a",)
assert f.soft_depends_on == ("b",)
assert f.defaults == {"b": 0}
assert f.args == (1,)
assert f.kwargs == {"x": 2}
assert f.retry.max_attempts == 5
assert f.timeout == 10.0
assert f.tags == ("t1",)
assert len(f.conditions) == 1
assert isinstance(f.cwd, Path)
assert f.cwd == Path("/tmp")
assert f.env == {"K": "v"}
assert f.verbose is True
assert f.skip_if_missing is True
assert f.allow_upstream_skip is True
assert f.strategy == "thread"
assert f.priority == 3
assert f.concurrency_key == "db"
assert f.continue_on_error is True
def test_task_decorator_runs_in_graph() -> None:
"""装饰器生成的 TaskSpec 应能直接构建图并运行."""
@px.task
def extract() -> list[int]:
return [1, 2, 3]
@px.task(depends_on=("extract",))
def double(extract: list[int]) -> list[int]:
return [x * 2 for x in extract]
graph = px.Graph.from_specs([extract, double])
report = px.run(graph)
assert report.success
assert report["double"] == [2, 4, 6]
def test_task_decorator_hooks_passthrough() -> None:
"""@task(hooks=...) 应传递 TaskHooks 实例."""
hooks = TaskHooks(pre_run=lambda _spec: None)
spec = px.task(fn=lambda: None, hooks=hooks, name="h")
assert spec.hooks is hooks
def test_task_decorator_cache_key_passthrough() -> None:
"""@task(cache_key=...) 应传递缓存键函数."""
def ck(ctx: Mapping[str, Any]) -> str:
return "k"
spec = px.task(fn=lambda: None, cache_key=ck, name="c")
assert spec.cache_key is ck