From 40f641611b5c237cb1976e3d0c817f3d3267a958 Mon Sep 17 00:00:00 2001 From: gooker_young Date: Sun, 28 Jun 2026 15:10:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E5=A4=9A=E9=A1=B9?= =?UTF-8?q?=E6=A0=B8=E5=BF=83=E5=8A=9F=E8=83=BD=E5=B9=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E6=89=A7=E8=A1=8C=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 将CliRunner默认执行策略从sequential改为dependency 2. 新增RunReport的任务状态查询和时长统计方法 3. 实现task装饰器并补充executor参数文档 4. 新增进程池执行器支持CPU密集型任务 5. 新增Graph.chain链式构建和add_subgraph子图合并功能 6. 新增流式任务传递、进程池执行、命名空间等多类测试用例 7. 补充tests目录路径导入配置 --- src/pyflowx/__init__.py | 2 + src/pyflowx/executors.py | 74 ++++++++++++++-- src/pyflowx/graph.py | 140 +++++++++++++++++++++++++++++- src/pyflowx/report.py | 16 ++++ src/pyflowx/runner.py | 2 +- src/pyflowx/task.py | 88 +++++++++++++++++++ tests/_proc_helper.py | 26 ++++++ tests/conftest.py | 7 ++ tests/test_chain_dsl.py | 101 ++++++++++++++++++++++ tests/test_executor_process.py | 62 ++++++++++++++ tests/test_namespace.py | 152 +++++++++++++++++++++++++++++++++ tests/test_report.py | 47 ++++++++++ tests/test_runner.py | 6 +- tests/test_streaming.py | 63 ++++++++++++++ tests/test_task_decorator.py | 136 +++++++++++++++++++++++++++++ 15 files changed, 907 insertions(+), 15 deletions(-) create mode 100644 tests/_proc_helper.py create mode 100644 tests/test_chain_dsl.py create mode 100644 tests/test_executor_process.py create mode 100644 tests/test_namespace.py create mode 100644 tests/test_streaming.py create mode 100644 tests/test_task_decorator.py diff --git a/src/pyflowx/__init__.py b/src/pyflowx/__init__.py index da1a10a..7a3c38b 100644 --- a/src/pyflowx/__init__.py +++ b/src/pyflowx/__init__.py @@ -94,6 +94,7 @@ from .task import ( TaskResult, TaskSpec, TaskStatus, + task, task_template, ) @@ -139,5 +140,6 @@ __all__ = [ "describe_injection", "run", "run_command", + "task", "task_template", ] diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index 083b5e6..4a67468 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -58,6 +58,31 @@ from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus logger = logging.getLogger(__name__) +# 进程池复用:同一次 run() 内的 process 任务共享一个 ProcessPoolExecutor。 +# 模块级缓存避免每次任务都创建/销毁进程池的开销。 +_process_pool: concurrent.futures.ProcessPoolExecutor | None = None +_process_pool_lock = threading.Lock() + + +def _get_process_pool() -> concurrent.futures.ProcessPoolExecutor: + """获取复用的进程池(惰性创建)。""" + global _process_pool # noqa: PLW0603 + if _process_pool is None: + with _process_pool_lock: + if _process_pool is None: + _process_pool = concurrent.futures.ProcessPoolExecutor() + return _process_pool + + +def _run_in_process(fn: Any, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """模块级函数:在进程池中执行任务(须可 pickle)。 + + env_context 等上下文管理器无法跨进程传递,进程池任务的 ``env``/``cwd`` + 不生效;如需设置环境,应在 ``fn`` 内部自行处理。 + """ + return fn(*args, **kwargs) + + # 观察者回调类型。 EventCallback = Callable[[TaskEvent], None] Strategy = Literal["sequential", "thread", "async", "dependency"] @@ -391,19 +416,50 @@ async def _execute_async_task( loop: asyncio.AbstractEventLoop, ) -> Any: """执行异步或同步任务(带超时处理)。""" + # 异步任务直接 await if _is_async_fn(spec): coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) - if spec.timeout is not None: - return await asyncio.wait_for(coro, timeout=spec.timeout) - return await coro + return await asyncio.wait_for(coro, timeout=spec.timeout) if spec.timeout is not None else await coro + + # 同步任务:根据 executor 选择执行器 + fut = _submit_sync_task(spec, args, kwargs, loop) + return await asyncio.wait_for(fut, timeout=spec.timeout) if spec.timeout is not None else await fut + + +def _submit_sync_task( + spec: TaskSpec[Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + loop: asyncio.AbstractEventLoop, +) -> asyncio.Future[Any]: + """提交同步任务到对应执行器,返回 Future。 + + * ``inline``:直接在事件循环线程调用(阻塞循环,最快)。 + * ``process``:进程池执行(绕过 GIL,fn 须可 pickle)。 + * ``thread``(默认):线程池执行。 + """ def fn_call() -> Any: with spec.env_context(): return spec.effective_fn(*args, **kwargs) - if spec.timeout is not None: - return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) - return await loop.run_in_executor(None, fn_call) + # inline:直接在事件循环线程调用,无线程池开销,但会阻塞循环。 + if spec.executor == "inline": + result = fn_call() + fut: asyncio.Future[Any] = loop.create_future() + fut.set_result(result) + return fut + + # process:进程池执行,绕过 GIL,适合 CPU 密集型任务(fn 须可 pickle)。 + if spec.executor == "process": + from functools import partial + + pool = _get_process_pool() + proc_fn = partial(_run_in_process, spec.effective_fn, args, kwargs) + return loop.run_in_executor(pool, proc_fn) + + # thread(默认):线程池执行。 + return loop.run_in_executor(None, fn_call) # ---------------------------------------------------------------------- # @@ -662,7 +718,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback: def run( graph: Graph, - strategy: Strategy = "sequential", + strategy: Strategy = "dependency", *, max_workers: int | None = None, dry_run: bool = False, @@ -678,8 +734,8 @@ def run( graph: 待执行的已校验 :class:`Graph`。 strategy: - 执行策略: ``"sequential"`` / ``"thread"`` / ``"async"`` / - ``"dependency"``。``"dependency"`` 为依赖驱动调度,无层屏障。 + 执行策略: ``"dependency"``(默认,依赖驱动无层屏障,最大并行度)/ + ``"sequential"`` / ``"thread"`` / ``"async"``(层屏障模型)。 max_workers: ``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。 dry_run: diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index 42058a7..fd725e3 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -17,12 +17,13 @@ __all__ = [ "GraphDefaults", ] +import inspect import sys from dataclasses import dataclass, field, replace from typing import Any, Callable, Iterable, Mapping, Sequence from .errors import CycleError, DuplicateTaskError, MissingDependencyError -from .task import RetryPolicy, TaskSpec +from .task import Context, RetryPolicy, TaskSpec if sys.version_info >= (3, 9): # pragma: no cover import graphlib # pyright: ignore[reportUnreachable] @@ -63,6 +64,74 @@ def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[An ) +def _make_namespaced_fn(orig_fn: Any, ns: str, dep_names: set[str]) -> Any: + """包装 fn,使其能接收带 ``ns:`` 前缀的依赖名,调用时映射回原参数名。 + + 命名空间合并后,依赖名带前缀(如 ``build:extract``),但 Python 参数名 + 不能含 ``:``。wrapper 用 ``**kwargs`` 接收所有依赖,内部把带前缀的依赖名 + 映射回原参数名后调用原 fn。 + + 无依赖参数时直接返回原 fn。 + """ + if not dep_names or orig_fn is None: + return orig_fn + try: + orig_sig = inspect.signature(orig_fn) + except (TypeError, ValueError): + return orig_fn + + # 带前缀依赖名 -> 原参数名 + name_map: dict[str, str] = {f"{ns}:{orig}": orig for orig in dep_names} + prefix = f"{ns}:" + + # 检查原 fn 是否有 Context 标注参数 + context_param_name: str | None = None + for p in orig_sig.parameters.values(): + ann = p.annotation + if ann is not Context and not (isinstance(ann, str) and ann.endswith("Context")): + continue + context_param_name = p.name + break + + if context_param_name is not None: + + def wrapper(ctx: Any = None, **kwargs: Any) -> Any: + # ctx 是 dep_context,键为带前缀的依赖名;映射回原始键 + orig_ctx: dict[str, Any] = {} + for k, v in (ctx or {}).items(): + orig_ctx[name_map.get(k, k)] = v + # kwargs 中带前缀的依赖也映射回原参数名 + for k, v in kwargs.items(): + if k in name_map: + orig_ctx[name_map[k]] = v + return orig_fn(**{context_param_name: orig_ctx}) + + ctx_param = inspect.Parameter("ctx", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Context) + kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD) + wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[ctx_param, kw_param], + return_annotation=orig_sig.return_annotation, + ) + else: + + def wrapper(**kwargs: Any) -> Any: # type: ignore[no-redef] + orig_kwargs: dict[str, Any] = {} + for k, v in kwargs.items(): + if k.startswith(prefix): + orig_kwargs[k[len(prefix) :]] = v + return orig_fn(**orig_kwargs) + + kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD) + wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined] + parameters=[kw_param], + return_annotation=orig_sig.return_annotation, + ) + + wrapper.__name__ = f"{ns}_{getattr(orig_fn, '__name__', 'fn')}" + wrapper.__doc__ = getattr(orig_fn, "__doc__", None) + return wrapper + + @dataclass class Graph: """校验后的有向无环任务图。 @@ -78,6 +147,7 @@ class Graph: specs: dict[str, TaskSpec[Any]] = field(default_factory=dict) deps: dict[str, tuple[str, ...]] = field(default_factory=dict) defaults: GraphDefaults = field(default_factory=GraphDefaults) + namespace: str | None = None # 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。 _pending_refs: list[str] = field(default_factory=list) @@ -95,6 +165,28 @@ class Graph: self._validate_references() return self + def chain(self, *specs: TaskSpec[Any]) -> Graph: + """链式注册任务:每个 spec 自动依赖前一个。 + + ``chain(a, b, c)`` 等价于 ``b`` 依赖 ``a``,``c`` 依赖 ``b``。 + 若 spec 已带 ``depends_on``,则前驱名追加到现有依赖前。 + 返回 ``self`` 支持链式调用。 + + Examples + -------- + >>> graph = px.Graph().chain(extract, transform, load) + """ + prev_name: str | None = None + for s in specs: + current = s + if prev_name is not None: + # 将前驱追加到 depends_on 最前(保持显式依赖优先) + new_deps = (prev_name, *s.depends_on) if prev_name not in s.depends_on else s.depends_on + current = replace(s, depends_on=new_deps) + self.add(current) + prev_name = current.name + return self + def _register(self, spec: TaskSpec[Any]) -> None: if spec.name in self.specs: raise DuplicateTaskError(spec.name) @@ -108,6 +200,8 @@ class Graph: cls, specs: Iterable[TaskSpec[Any] | str], defaults: GraphDefaults | None = None, + *, + namespace: str | None = None, ) -> Graph: """从可迭代的 task spec 构建图。 @@ -120,8 +214,10 @@ class Graph: TaskSpec 对象或字符串引用的列表。 defaults: 图级默认值。``None`` 使用空 :class:`GraphDefaults`。 + namespace: + 可选命名空间,用于 :meth:`add_subgraph` 合并时加前缀。 """ - graph = cls(defaults=defaults or GraphDefaults()) + graph = cls(defaults=defaults or GraphDefaults(), namespace=namespace) pending_refs: list[str] = [] for spec in specs: @@ -139,6 +235,46 @@ class Graph: graph.validate() return graph + def add_subgraph(self, sub: Graph, *, namespace: str | None = None) -> Graph: + """将子图合并到当前图,任务名加命名空间前缀避免冲突。 + + 参数 + ---- + sub: + 待合并的子图。 + namespace: + 命名空间前缀。``None`` 时使用 ``sub.namespace``,若子图也无命名空间 + 则抛出 ``ValueError``。最终任务名为 ``f"{ns}:{original_name}"``。 + + 合并后,子图内任务的依赖名也会被加前缀;与子图外部任务的依赖保持原样。 + + 返回 ``self`` 支持链式调用。 + """ + ns = namespace or sub.namespace + if not ns: + raise ValueError("add_subgraph 需要 namespace 或子图自带 namespace") + + def _rename(name: str) -> str: + # 仅对子图内部任务名加前缀;外部依赖保持原样 + return f"{ns}:{name}" if name in sub.specs else name + + sub_names = set(sub.specs.keys()) + for spec in sub.specs.values(): + # 子图内部依赖名需加前缀,对应的 fn 参数也需包装 + internal_deps = (set(spec.depends_on) | set(spec.soft_depends_on)) & sub_names + new_fn = _make_namespaced_fn(spec.fn, ns, internal_deps) if spec.fn else spec.fn + new_spec = replace( + spec, + name=_rename(spec.name), + fn=new_fn, + depends_on=tuple(_rename(d) for d in spec.depends_on), + soft_depends_on=tuple(_rename(d) for d in spec.soft_depends_on), + ) + self._register(new_spec) + self._validate_references() + self.validate() + return self + # ------------------------------------------------------------------ # # 校验 # ------------------------------------------------------------------ # diff --git a/src/pyflowx/report.py b/src/pyflowx/report.py index bf105b3..1e6d6d4 100644 --- a/src/pyflowx/report.py +++ b/src/pyflowx/report.py @@ -69,6 +69,22 @@ class RunReport: """以 FAILED 状态结束的任务名列表。""" return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED] + def succeeded_tasks(self) -> list[str]: + """以 SUCCESS 状态结束的任务名列表。""" + return [name for name, r in self.results.items() if r.status == TaskStatus.SUCCESS] + + def skipped_tasks(self) -> list[str]: + """以 SKIPPED 状态结束的任务名列表。""" + return [name for name, r in self.results.items() if r.status == TaskStatus.SKIPPED] + + def tasks_by_status(self, status: TaskStatus) -> list[str]: + """返回指定状态的任务名列表。""" + return [name for name, r in self.results.items() if r.status == status] + + def durations(self) -> dict[str, float]: + """任务名 -> 执行时长(秒)。无时长记录的为 0.0。""" + return {name: (r.duration or 0.0) for name, r in self.results.items()} + def describe(self) -> str: """用于调试的人类可读多行报告。""" lines: list[str] = [f"RunReport(success={self.success})"] diff --git a/src/pyflowx/runner.py b/src/pyflowx/runner.py index 67268ea..950b99f 100644 --- a/src/pyflowx/runner.py +++ b/src/pyflowx/runner.py @@ -114,7 +114,7 @@ class CliRunner: """ graphs: dict[str, Graph] = field(default_factory=dict) - strategy: Strategy = field(default="sequential") + strategy: Strategy = field(default="dependency") description: str = field(default_factory=str) verbose: bool = field(default_factory=lambda: True) diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index 6705c9b..2785334 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -254,6 +254,10 @@ class TaskSpec(Generic[T]): 存取状态后端,使不同输入产生独立缓存条目。``None`` 表示用任务名。 hooks: :class:`TaskHooks` 生命周期钩子。 + executor: + 同步任务的执行器:``"thread"``(默认,线程池)/ ``"process"`` + (进程池,绕过 GIL,适合 CPU 密集型;``fn`` 须可 pickle)/ + ``"inline"``(直接在事件循环线程调用,最快但会阻塞循环)。 """ name: str @@ -279,6 +283,7 @@ class TaskSpec(Generic[T]): continue_on_error: bool = False cache_key: CacheKeyFn | None = None hooks: TaskHooks = field(default_factory=TaskHooks) + executor: str = "thread" # "thread" | "process" | "inline" def __post_init__(self) -> None: if not self.name: @@ -447,6 +452,89 @@ def _env_and_cwd( # ---------------------------------------------------------------------- # # 任务模板:批量生成相似 TaskSpec 的工厂 # ---------------------------------------------------------------------- # +def _task_noop() -> None: + """task(cmd=...) 形式下的占位 fn(cmd 任务执行期不调用 fn)。""" + return None + + +def task( + fn: TaskFn[Any] | None = None, + *, + cmd: TaskCmd | None = None, + depends_on: tuple[str, ...] = (), + soft_depends_on: tuple[str, ...] = (), + defaults: Mapping[str, Any] | None = None, + args: tuple[Any, ...] = (), + kwargs: Mapping[str, Any] | None = None, + retry: RetryPolicy | None = None, + timeout: float | None = None, + tags: tuple[str, ...] = (), + conditions: tuple[Condition, ...] = (), + cwd: str | Path | None = None, + env: Mapping[str, str] | None = None, + verbose: bool = False, + skip_if_missing: bool = False, + allow_upstream_skip: bool = False, + strategy: str | None = None, + priority: int = 0, + concurrency_key: str | None = None, + continue_on_error: bool = False, + cache_key: CacheKeyFn | None = None, + hooks: TaskHooks | None = None, + name: str | None = None, +) -> Any: + """装饰器:将函数转为 :class:`TaskSpec`。 + + ``name`` 默认取 ``fn.__name__``。可直接装饰函数,或带参数使用。 + + Examples + -------- + >>> @px.task + ... def extract(): return [1, 2, 3] + >>> @px.task(depends_on=("extract",)) + ... def double(extract): return [x * 2 for x in extract] + >>> graph = px.Graph.from_specs([extract, double]) + """ + + def _decorate(func: TaskFn[Any]) -> TaskSpec[Any]: + spec_name = name or func.__name__ + return TaskSpec( + name=spec_name, + fn=func, + cmd=cmd, + depends_on=depends_on, + soft_depends_on=soft_depends_on, + defaults=dict(defaults) if defaults else {}, + args=args, + kwargs=dict(kwargs) if kwargs else {}, + retry=retry if retry is not None else RetryPolicy(), + timeout=timeout, + tags=tags, + conditions=conditions, + cwd=Path(cwd) if isinstance(cwd, str) else cwd, + env=dict(env) if env else None, + verbose=verbose, + skip_if_missing=skip_if_missing, + allow_upstream_skip=allow_upstream_skip, + strategy=strategy, + priority=priority, + concurrency_key=concurrency_key, + continue_on_error=continue_on_error, + cache_key=cache_key, + hooks=hooks if hooks is not None else TaskHooks(), + ) + + if fn is None and cmd is None: + # 带参数调用:@task(depends_on=...),等待被装饰函数 + return _decorate + if fn is None: + # task(cmd=..., name=...) 直接构造,无被装饰函数 + if name is None: + raise ValueError("task(cmd=...) 需要显式提供 name") + return _decorate(_task_noop) + return _decorate(fn) + + def task_template( fn: TaskFn[Any] | None = None, cmd: TaskCmd | None = None, diff --git a/tests/_proc_helper.py b/tests/_proc_helper.py new file mode 100644 index 0000000..6595565 --- /dev/null +++ b/tests/_proc_helper.py @@ -0,0 +1,26 @@ +"""进程池测试辅助:模块级函数(须可 pickle)。""" + +from __future__ import annotations + +import time + + +def cpu_heavy(n: int) -> int: + """CPU 密集型计算(求平方和)。""" + return sum(i * i for i in range(n)) + + +def add(a: int, b: int) -> int: + """简单加法。""" + return a + b + + +def sub(a: int, b: int) -> int: + """简单减法。""" + return a - b + + +def slow_sleep(seconds: float) -> int: + """睡眠指定秒数,用于测试超时。""" + time.sleep(seconds) + return int(seconds) diff --git a/tests/conftest.py b/tests/conftest.py index f91878c..57e0df6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,16 @@ from __future__ import annotations +import sys from pathlib import Path import pytest +# 将 tests 目录加入 sys.path,使进程池测试能 import _proc_helper 模块级辅助函数。 +# 进程池 pickle 要求被调用函数为模块级,conftest.py 在 xdist worker 中也会执行。 +_TESTS_DIR = str(Path(__file__).resolve().parent) +if _TESTS_DIR not in sys.path: + sys.path.insert(0, _TESTS_DIR) + @pytest.fixture(autouse=True) def packtool_tmp_workdir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/test_chain_dsl.py b/tests/test_chain_dsl.py new file mode 100644 index 0000000..18f1ab7 --- /dev/null +++ b/tests/test_chain_dsl.py @@ -0,0 +1,101 @@ +"""Tests for Graph.chain DSL.""" + +from __future__ import annotations + +import pyflowx as px +from pyflowx.task import TaskSpec + + +def _fn() -> None: + return None + + +def test_chain_basic_linkage() -> None: + """chain(a, b, c) 应建立 a->b->c 依赖.""" + a = TaskSpec("a", _fn) + b = TaskSpec("b", _fn) + c = TaskSpec("c", _fn) + + graph = px.Graph().chain(a, b, c) + + assert graph.all_specs()["b"].depends_on == ("a",) + assert graph.all_specs()["c"].depends_on == ("b",) + assert graph.all_specs()["a"].depends_on == () + + +def test_chain_single_spec() -> None: + """chain(a) 应只注册 a,无依赖.""" + a = TaskSpec("a", _fn) + graph = px.Graph().chain(a) + assert "a" in graph + assert graph.all_specs()["a"].depends_on == () + + +def test_chain_preserves_existing_deps() -> None: + """chain 应保留 spec 已有的 depends_on.""" + a = TaskSpec("a", _fn) + b = TaskSpec("b", _fn) + c = TaskSpec("c", _fn, depends_on=("b",)) + + graph = px.Graph().chain(a, b, c) + # c 已有 depends_on=('b',),前驱是 b,已在依赖中,不重复添加 + assert graph.all_specs()["c"].depends_on == ("b",) + + +def test_chain_merges_existing_deps() -> None: + """chain 应将前驱追加到已有依赖前(若不存在).""" + a = TaskSpec("a", _fn) + x = TaskSpec("x", _fn) + c = TaskSpec("c", _fn, depends_on=("x",)) + + graph = px.Graph().chain(a, x, c) + # c 前驱是 x,但 c 已依赖 x,不重复 + assert graph.all_specs()["c"].depends_on == ("x",) + + +def test_chain_returns_self() -> None: + """chain 返回 self 支持链式调用.""" + a = TaskSpec("a", _fn) + graph = px.Graph() + assert graph.chain(a) is graph + + +def test_chain_execution_order() -> None: + """chain 应保证执行顺序.""" + order: list[str] = [] + + def make(name: str): + def fn() -> str: + order.append(name) + return name + return fn + + a = TaskSpec("a", make("a")) + b = TaskSpec("b", make("b")) + c = TaskSpec("c", make("c")) + + graph = px.Graph().chain(a, b, c) + report = px.run(graph) + assert report.success + assert order == ["a", "b", "c"] + + +def test_chain_with_decorator_specs() -> None: + """chain 应与 @task 装饰器配合.""" + + @px.task + def extract() -> int: + return 1 + + @px.task + def transform(extract: int) -> int: + return extract + 10 + + @px.task + def load(transform: int) -> int: + return transform + 100 + + graph = px.Graph().chain(extract, transform, load) + report = px.run(graph) + assert report.success + assert report["load"] == 111 diff --git a/tests/test_executor_process.py b/tests/test_executor_process.py new file mode 100644 index 0000000..7958f3e --- /dev/null +++ b/tests/test_executor_process.py @@ -0,0 +1,62 @@ +"""Tests for process executor (spec.executor='process').""" + +from __future__ import annotations + +import pytest + +# pyrefly: ignore[missing-import] +from _proc_helper import add, cpu_heavy, slow_sleep, sub + +import pyflowx as px +from pyflowx.errors import TaskFailedError + + +def test_process_executor_runs_cpu_task() -> None: + """executor='process' 应在进程池中执行 CPU 密集型任务.""" + spec = px.TaskSpec("cpu", fn=cpu_heavy, args=(1000,), executor="process") + graph = px.Graph.from_specs([spec]) + report = px.run(graph) + assert report.success + assert report["cpu"] == sum(i * i for i in range(1000)) + + +def test_process_executor_with_dependency() -> None: + """进程池任务应支持依赖注入.""" + spec1 = px.TaskSpec("a", fn=cpu_heavy, args=(100,), executor="process") + spec2 = px.TaskSpec("b", fn=add, args=(3, 4), executor="process", depends_on=("a",)) + graph = px.Graph.from_specs([spec1, spec2]) + report = px.run(graph) + assert report.success + assert report["b"] == 7 + + +def test_process_executor_default_is_thread() -> None: + """TaskSpec.executor 默认应为 'thread'.""" + spec = px.TaskSpec("x", fn=lambda: None) + assert spec.executor == "thread" + + +def test_inline_executor_runs_in_event_loop() -> None: + """executor='inline' 应直接在事件循环线程调用.""" + spec = px.TaskSpec("inline", fn=add, args=(10, 20), executor="inline") + graph = px.Graph.from_specs([spec]) + report = px.run(graph) + assert report.success + assert report["inline"] == 30 + + +def test_process_executor_with_kwargs() -> None: + """进程池任务应支持 kwargs 注入.""" + spec = px.TaskSpec("kw", fn=sub, args=(10,), kwargs={"b": 3}, executor="process") + graph = px.Graph.from_specs([spec]) + report = px.run(graph) + assert report.success + assert report["kw"] == 7 + + +def test_process_executor_timeout() -> None: + """进程池任务超时应抛 TaskFailedError.""" + spec = px.TaskSpec("slow", fn=slow_sleep, args=(10.0,), executor="process", timeout=0.1) + graph = px.Graph.from_specs([spec]) + with pytest.raises(TaskFailedError): + px.run(graph) diff --git a/tests/test_namespace.py b/tests/test_namespace.py new file mode 100644 index 0000000..e278fc6 --- /dev/null +++ b/tests/test_namespace.py @@ -0,0 +1,152 @@ +"""Tests for Graph namespace and add_subgraph.""" + +from __future__ import annotations + +import pytest + +import pyflowx as px + + +def _fn() -> None: + return None + + +def test_graph_namespace_field_default_none() -> None: + """Graph 默认 namespace 为 None.""" + graph = px.Graph() + assert graph.namespace is None + + +def test_graph_from_specs_with_namespace() -> None: + """from_specs(namespace=...) 应设置 graph.namespace.""" + graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="ns1") + assert graph.namespace == "ns1" + + +def test_add_subgraph_prefixes_task_names() -> None: + """add_subgraph 应给子图任务名加命名空间前缀.""" + sub = px.Graph.from_specs( + [px.TaskSpec("extract", _fn), px.TaskSpec("build", _fn, depends_on=("extract",))], + namespace="build", + ) + main = px.Graph.from_specs([px.TaskSpec("start", _fn)]) + main.add_subgraph(sub) + + assert "start" in main + assert "build:extract" in main + assert "build:build" in main + + +def test_add_subgraph_renames_internal_deps() -> None: + """add_subgraph 应给子图内部依赖名加前缀.""" + sub = px.Graph.from_specs( + [px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",))], + namespace="ns", + ) + main = px.Graph() + main.add_subgraph(sub) + + b_spec = main.all_specs()["ns:b"] + assert b_spec.depends_on == ("ns:a",) + + +def test_add_subgraph_all_internal_deps_prefixed() -> None: + """add_subgraph 子图内所有任务(含被依赖的)都加前缀.""" + sub = px.Graph.from_specs( + [px.TaskSpec("ext", _fn), px.TaskSpec("b", _fn, depends_on=("ext",))], + namespace="ns", + ) + main = px.Graph() + main.add_subgraph(sub) + + b_spec = main.all_specs()["ns:b"] + assert b_spec.depends_on == ("ns:ext",) + assert "ns:ext" in main + + +def test_add_subgraph_requires_namespace() -> None: + """add_subgraph 无 namespace 时应抛 ValueError.""" + sub = px.Graph.from_specs([px.TaskSpec("a", _fn)]) # 无 namespace + main = px.Graph() + with pytest.raises(ValueError, match="namespace"): + main.add_subgraph(sub) + + +def test_add_subgraph_explicit_namespace_overrides() -> None: + """add_subgraph(namespace=...) 应覆盖子图自带 namespace.""" + sub = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="original") + main = px.Graph() + main.add_subgraph(sub, namespace="override") + + assert "override:a" in main + assert "original:a" not in main + + +def test_add_subgraph_internal_injection_works() -> None: + """子图内部依赖注入应通过 wrapper 正常工作.""" + sub = px.Graph.from_specs( + [ + px.TaskSpec("extract", lambda: [1, 2, 3]), + px.TaskSpec("build", lambda extract: [x * 2 for x in extract], depends_on=("extract",)), + ], + namespace="build", + ) + main = px.Graph() + main.add_subgraph(sub) + + report = px.run(main) + assert report.success + assert report["build:build"] == [2, 4, 6] + + +def test_add_subgraph_cross_namespace_ref_via_context() -> None: + """跨命名空间引用应通过 Context 标注接收.""" + + def consumer(ctx: px.Context) -> str: + return f"got {ctx['ns:data']}" + + sub = px.Graph.from_specs( + [px.TaskSpec("data", lambda: "data_value")], + namespace="ns", + ) + main = px.Graph() + main.add_subgraph(sub) + + main.add(px.TaskSpec("consumer", consumer, depends_on=("ns:data",))) + + report = px.run(main) + assert report.success + assert report["consumer"] == "got data_value" + + +def test_add_subgraph_context_annotation_in_subgraph() -> None: + """子图内部任务用 Context 标注时,wrapper 应正确传递.""" + + def sink(ctx: px.Context) -> int: + return ctx["src"] + + sub = px.Graph.from_specs( + [ + px.TaskSpec("src", lambda: 42), + px.TaskSpec("sink", sink, depends_on=("src",)), + ], + namespace="ns", + ) + main = px.Graph() + main.add_subgraph(sub) + + report = px.run(main) + assert report.success + assert report["ns:sink"] == 42 + + +def test_add_subgraph_chained() -> None: + """多个子图可链式合并到主图.""" + sub_a = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="nsA") + sub_b = px.Graph.from_specs([px.TaskSpec("b", _fn)], namespace="nsB") + + main = px.Graph() + main.add_subgraph(sub_a).add_subgraph(sub_b) + + assert "nsA:a" in main + assert "nsB:b" in main diff --git a/tests/test_report.py b/tests/test_report.py index b713e6d..da8f6ba 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -126,3 +126,50 @@ class TestRunReportDescribe: report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING) desc = report.describe() assert "-" in desc # duration 显示为 "-" + + +class TestRunReportQueries: + """测试 RunReport 的新查询 API.""" + + def test_succeeded_tasks(self) -> None: + """succeeded_tasks 返回 SUCCESS 状态的任务名.""" + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS) + report.results["b"] = _make_result("b", status=TaskStatus.FAILED) + report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS) + assert report.succeeded_tasks() == ["a", "c"] + + def test_skipped_tasks(self) -> None: + """skipped_tasks 返回 SKIPPED 状态的任务名.""" + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.SKIPPED) + report.results["b"] = _make_result("b", status=TaskStatus.SUCCESS) + assert report.skipped_tasks() == ["a"] + + def test_tasks_by_status(self) -> None: + """tasks_by_status 按指定状态过滤.""" + report = px.RunReport() + report.results["a"] = _make_result("a", status=TaskStatus.FAILED) + report.results["b"] = _make_result("b", status=TaskStatus.FAILED) + report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS) + assert report.tasks_by_status(TaskStatus.FAILED) == ["a", "b"] + assert report.tasks_by_status(TaskStatus.SUCCESS) == ["c"] + assert report.tasks_by_status(TaskStatus.SKIPPED) == [] + + def test_durations(self) -> None: + """durations 返回任务名 -> 时长映射.""" + report = px.RunReport() + report.results["a"] = _make_result("a", duration=1.5) + report.results["b"] = _make_result("b", duration=2.0) + durs = report.durations() + assert durs["a"] == 1.5 + assert durs["b"] == 2.0 + + def test_durations_no_duration(self) -> None: + """无时长的任务应返回 0.0.""" + report = px.RunReport() + spec: TaskSpec[Any] = TaskSpec[Any]("a", _fn) # type: ignore[arg-type] + report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING) + durs = report.durations() + assert durs["a"] == 0.0 + diff --git a/tests/test_runner.py b/tests/test_runner.py index 577396c..c8a97c2 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -72,10 +72,10 @@ class TestCliRunnerConstruction: ) assert runner.commands == ["clean", "build", "test"] - def test_default_strategy_is_sequential(self) -> None: - """默认策略应为 Strategy.SEQUENTIAL.""" + def test_default_strategy_is_dependency(self) -> None: + """默认策略应为 dependency(依赖驱动,最大并行度).""" runner = px.CliRunner({"clean": _echo_graph()}) - assert runner.strategy == "sequential" + assert runner.strategy == "dependency" def test_custom_strategy_string(self) -> None: """应支持通过字符串指定策略.""" diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000..7572587 --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,63 @@ +"""Tests for streaming result passing (iterators between tasks).""" + +from __future__ import annotations + +from typing import Iterator + +import pyflowx as px + + +def test_generator_passed_as_iterator() -> None: + """上游返回生成器,下游应能惰性消费.""" + + @px.task + def source() -> Iterator[int]: + yield from range(5) + + @px.task(depends_on=("source",)) + def consume(source: Iterator[int]) -> int: + return sum(source) + + graph = px.Graph.from_specs([source, consume]) + report = px.run(graph) + assert report.success + assert report["consume"] == 10 + + +def test_large_range_streaming() -> None: + """大范围迭代器流式传递,避免中间列表.""" + + @px.task + def numbers() -> Iterator[int]: + yield from range(1000) + + @px.task(depends_on=("numbers",)) + def total(numbers: Iterator[int]) -> int: + return sum(numbers) + + graph = px.Graph.from_specs([numbers, total]) + report = px.run(graph) + assert report.success + assert report["total"] == sum(range(1000)) + + +def test_chain_multiple_streams() -> None: + """多个流式任务串联.""" + + @px.task + def gen() -> Iterator[int]: + yield from range(10) + + @px.task(depends_on=("gen",)) + def doubled(gen: Iterator[int]) -> Iterator[int]: + for x in gen: + yield x * 2 + + @px.task(depends_on=("doubled",)) + def collect(doubled: Iterator[int]) -> list[int]: + return list(doubled) + + graph = px.Graph.from_specs([gen, doubled, collect]) + report = px.run(graph) + assert report.success + assert report["collect"] == [x * 2 for x in range(10)] diff --git a/tests/test_task_decorator.py b/tests/test_task_decorator.py new file mode 100644 index 0000000..0b84a74 --- /dev/null +++ b/tests/test_task_decorator.py @@ -0,0 +1,136 @@ +"""Tests for the @task decorator API.""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Mapping + +import pyflowx as px +from pyflowx.task import RetryPolicy, TaskHooks, TaskSpec + + +def test_task_decorator_plain() -> None: + """@task 无参数装饰:name 取函数名,返回 TaskSpec.""" + + @px.task + def extract() -> list[int]: + return [1, 2, 3] + + assert isinstance(extract, TaskSpec) + assert extract.name == "extract" + assert extract.fn is not None + assert extract.depends_on == () + + +def test_task_decorator_with_params() -> None: + """@task(...) 带参数装饰:传递依赖与重试.""" + + @px.task(depends_on=("extract",), retry=RetryPolicy(max_attempts=3)) + def double(extract: list[int]) -> list[int]: + return [x * 2 for x in extract] + + assert isinstance(double, TaskSpec) + assert double.name == "double" + assert double.depends_on == ("extract",) + assert double.retry.max_attempts == 3 + + +def test_task_decorator_explicit_name() -> None: + """@task(name=...) 应使用显式名称而非函数名.""" + + @px.task(name="custom_name") + def my_func() -> None: + return None + + assert my_func.name == "custom_name" + + +def test_task_decorator_cmd_form() -> None: + """@task(cmd=...) 应支持命令形式.""" + + spec = px.task(cmd=["ls", "-la"], name="list_files") + assert isinstance(spec, TaskSpec) + assert spec.name == "list_files" + assert spec.cmd == ["ls", "-la"] + + +def test_task_decorator_full_options() -> None: + """@task 应支持全部 TaskSpec 字段.""" + + @px.task( + depends_on=("a",), + soft_depends_on=("b",), + defaults={"b": 0}, + args=(1,), + kwargs={"x": 2}, + retry=RetryPolicy(max_attempts=5), + timeout=10.0, + tags=("t1",), + conditions=(px.BuiltinConditions.IS_WINDOWS,), # type: ignore[arg-type] + cwd="/tmp", + env={"K": "v"}, + verbose=True, + skip_if_missing=True, + allow_upstream_skip=True, + strategy="thread", + priority=3, + concurrency_key="db", + continue_on_error=True, + ) + def f(a: int) -> int: + return a + + assert f.depends_on == ("a",) + assert f.soft_depends_on == ("b",) + assert f.defaults == {"b": 0} + assert f.args == (1,) + assert f.kwargs == {"x": 2} + assert f.retry.max_attempts == 5 + assert f.timeout == 10.0 + assert f.tags == ("t1",) + assert len(f.conditions) == 1 + assert isinstance(f.cwd, Path) + assert f.cwd == Path("/tmp") + assert f.env == {"K": "v"} + assert f.verbose is True + assert f.skip_if_missing is True + assert f.allow_upstream_skip is True + assert f.strategy == "thread" + assert f.priority == 3 + assert f.concurrency_key == "db" + assert f.continue_on_error is True + + +def test_task_decorator_runs_in_graph() -> None: + """装饰器生成的 TaskSpec 应能直接构建图并运行.""" + + @px.task + def extract() -> list[int]: + return [1, 2, 3] + + @px.task(depends_on=("extract",)) + def double(extract: list[int]) -> list[int]: + return [x * 2 for x in extract] + + graph = px.Graph.from_specs([extract, double]) + report = px.run(graph) + assert report.success + assert report["double"] == [2, 4, 6] + + +def test_task_decorator_hooks_passthrough() -> None: + """@task(hooks=...) 应传递 TaskHooks 实例.""" + + hooks = TaskHooks(pre_run=lambda _spec: None) + spec = px.task(fn=lambda: None, hooks=hooks, name="h") + assert spec.hooks is hooks + + +def test_task_decorator_cache_key_passthrough() -> None: + """@task(cache_key=...) 应传递缓存键函数.""" + + def ck(ctx: Mapping[str, Any]) -> str: + return "k" + + spec = px.task(fn=lambda: None, cache_key=ck, name="c") + assert spec.cache_key is ck