From 5c8ec281ffb422ad710ac8ee60e4dea8b94a5c40 Mon Sep 17 00:00:00 2001 From: gooker_young Date: Sat, 27 Jun 2026 14:33:54 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E9=87=8D?= =?UTF-8?q?=E8=AF=95=E7=AD=96=E7=95=A5=E3=80=81=E6=9D=A1=E4=BB=B6=E5=87=BD?= =?UTF-8?q?=E6=95=B0=E4=B8=8E=E4=B8=8A=E4=B8=8B=E6=96=87=E6=B3=A8=E5=85=A5?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 主要变更: 1. 替换旧retries参数为RetryPolicy配置 2. 重构条件函数,支持上下文参数与动态依赖判断 3. 更新上下文注入逻辑,支持软依赖与更清晰的注入描述 4. 新增sglang CLI命令与相关配置 5. 格式化代码统一列表与参数写法 6. 更新文档与测试用例适配新API --- pyproject.toml | 1 + src/pyflowx/__init__.py | 52 +- src/pyflowx/cli/envpy.py | 6 +- src/pyflowx/cli/envqt.py | 4 +- src/pyflowx/cli/gittool.py | 13 +- src/pyflowx/cli/llm/sglang.py | 55 ++ src/pyflowx/cli/piptool.py | 62 +- src/pyflowx/conditions.py | 288 +++--- src/pyflowx/context.py | 74 +- src/pyflowx/examples/etl_pipeline.py | 4 +- src/pyflowx/executors.py | 544 ++++++++---- src/pyflowx/graph.py | 342 ++++--- src/pyflowx/storage.py | 121 ++- src/pyflowx/task.py | 416 ++++++--- tests/test_advanced_features.py | 1222 ++++++++++++++++++++++++++ tests/test_conditions.py | 220 +++-- tests/test_context.py | 2 +- tests/test_executors.py | 24 +- tests/test_executors_edge_cases.py | 20 +- tests/test_graph.py | 132 ++- tests/test_runner.py | 118 ++- tests/test_task.py | 8 +- tests/test_task_edge_cases.py | 57 +- tests/test_taskspec_commands.py | 54 +- 24 files changed, 2796 insertions(+), 1043 deletions(-) create mode 100644 src/pyflowx/cli/llm/sglang.py create mode 100644 tests/test_advanced_features.py diff --git a/pyproject.toml b/pyproject.toml index 3a99e91..2ab54e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ piptool = "pyflowx.cli.piptool:main" pymake = "pyflowx.cli.pymake:main" reseticon = "pyflowx.cli.reseticoncache:main" scrcap = "pyflowx.cli.screenshot:main" +sglang = "pyflowx.cli.llm.sglang:main" sshcopy = "pyflowx.cli.sshcopyid:main" taskk = "pyflowx.cli.taskkill:main" wch = "pyflowx.cli.which:main" diff --git a/src/pyflowx/__init__.py b/src/pyflowx/__init__.py index fbe55be..0dae8b5 100644 --- a/src/pyflowx/__init__.py +++ b/src/pyflowx/__init__.py @@ -4,9 +4,15 @@ -------- * :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。 * :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。 -* :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。 +* :func:`run` ——以 ``sequential`` / ``thread`` / ``async`` / ``dependency`` + 策略执行图。 * :class:`RunReport` —— 类型化、可查询的运行结果。 * :class:`Context` —— 整体上下文注入的标注标记。 +* :class:`RetryPolicy` —— 重试策略(max_attempts/delay/backoff/jitter/retry_on)。 +* :class:`TaskHooks` —— 任务生命周期钩子(pre_run/post_run/on_failure)。 +* :class:`GraphDefaults` —— 图级默认值。 +* :func:`compose` —— 编程式组合多图。 +* :func:`task_template` —— 批量生成相似 TaskSpec 的工厂。 * 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。 快速上手 @@ -18,7 +24,7 @@ graph = px.Graph.from_specs([ px.TaskSpec("extract", extract), - px.TaskSpec("double", double, ("extract",)), + px.TaskSpec("double", double, depends_on=("extract",)), ]) report = px.run(graph, strategy="sequential") print(report["double"]) # [2, 4, 6] @@ -29,23 +35,18 @@ from pyflowx.conditions import IS_WINDOWS, BuiltinConditions graph = px.Graph.from_specs([ - # 使用命令列表 px.TaskSpec("list_files", cmd=["ls", "-la"]), - # 使用 shell 命令 px.TaskSpec("check_git", cmd="git status"), - # 条件执行:仅在 Windows 上运行 px.TaskSpec( "win_only", cmd=["dir"], conditions=(IS_WINDOWS,) ), - # 条件执行:仅在 git 已安装时运行 px.TaskSpec( "git_check", cmd=["git", "--version"], conditions=(BuiltinConditions.HAS_INSTALLED("git"),) ), - # 命令不存在时自动跳过(而非失败) px.TaskSpec( "optional_build", cmd=["maturin", "build"], @@ -58,6 +59,10 @@ from __future__ import annotations from .conditions import ( + IS_LINUX, + IS_MACOS, + IS_POSIX, + IS_WINDOWS, BuiltinConditions, Condition, Constants, @@ -74,20 +79,33 @@ from .errors import ( TaskTimeoutError, ) from .executors import Strategy, run -from .graph import Graph, GraphComposer +from .graph import Graph, GraphComposer, GraphDefaults, compose from .report import RunReport from .runner import CliExitCode, CliRunner from .storage import JSONBackend, MemoryBackend, StateBackend -from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus +from .task import ( + CacheKeyFn, + RetryPolicy, + TaskCmd, + TaskEvent, + TaskHooks, + TaskResult, + TaskSpec, + TaskStatus, + task_template, +) -__version__ = "0.2.6" +__version__ = "0.3.0" __all__ = [ + "IS_LINUX", + "IS_MACOS", + "IS_POSIX", + "IS_WINDOWS", "BuiltinConditions", + "CacheKeyFn", "CliExitCode", - # CLI 运行器 "CliRunner", - # 条件判断 "Condition", "Constants", "Context", @@ -95,28 +113,28 @@ __all__ = [ "DuplicateTaskError", "Graph", "GraphComposer", + "GraphDefaults", "InjectionError", "JSONBackend", "MemoryBackend", "MissingDependencyError", - # 错误 "PyFlowXError", + "RetryPolicy", "RunReport", - # 状态后端 "StateBackend", "StorageError", "Strategy", "TaskCmd", "TaskEvent", "TaskFailedError", + "TaskHooks", "TaskResult", - # 核心类型 "TaskSpec", "TaskStatus", "TaskTimeoutError", - # 辅助(高级) "build_call_args", + "compose", "describe_injection", - # 执行 "run", + "task_template", ] diff --git a/src/pyflowx/cli/envpy.py b/src/pyflowx/cli/envpy.py index f84fceb..3f51238 100644 --- a/src/pyflowx/cli/envpy.py +++ b/src/pyflowx/cli/envpy.py @@ -112,9 +112,9 @@ def main() -> None: args = parser.parse_args() if args.command == "mirror": - graph = px.Graph.from_specs( - [px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token}) + ]) else: parser.print_help() return diff --git a/src/pyflowx/cli/envqt.py b/src/pyflowx/cli/envqt.py index 8eaff49..7d6b9d2 100644 --- a/src/pyflowx/cli/envqt.py +++ b/src/pyflowx/cli/envqt.py @@ -43,13 +43,13 @@ def main() -> None: px.TaskSpec( "envqt_install", cmd=["sudo", "apt", "install", "-y", *QT_LIBS], - conditions=(lambda: Constants.IS_LINUX,), + conditions=(lambda _: Constants.IS_LINUX,), verbose=True, ), px.TaskSpec( "envqt_fonts", cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS], - conditions=(lambda: Constants.IS_LINUX,), + conditions=(lambda _: Constants.IS_LINUX,), verbose=True, ), ], diff --git a/src/pyflowx/cli/gittool.py b/src/pyflowx/cli/gittool.py index bb45a57..7577a50 100644 --- a/src/pyflowx/cli/gittool.py +++ b/src/pyflowx/cli/gittool.py @@ -37,7 +37,7 @@ def init_sub_dirs() -> None: px.TaskSpec( "init", cmd=["git", "init"], - conditions=(not_has_git_repo,), + conditions=(lambda _: not_has_git_repo(),), cwd=subdir, ), px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",)), @@ -70,7 +70,7 @@ def main() -> None: graphs={ # 添加并提交 "a": px.Graph.from_specs([ - px.TaskSpec("add", cmd=["git", "add", "."], conditions=(has_files,)), + px.TaskSpec("add", cmd=["git", "add", "."], conditions=(lambda _: has_files(),)), px.TaskSpec("commit", cmd=["git", "commit", "-m", "chore: update"], depends_on=("add",)), ]), # 清理 @@ -80,10 +80,13 @@ def main() -> None: ]), # 初始化、添加并提交 "i": px.Graph.from_specs([ - px.TaskSpec("init", cmd=["git", "init"], conditions=(not_has_git_repo,)), - px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(has_files,)), + px.TaskSpec("init", cmd=["git", "init"], conditions=(lambda _: not_has_git_repo(),)), + px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(lambda _: has_files(),)), px.TaskSpec( - "commit", cmd=["git", "commit", "-m", "init commit"], depends_on=("add",), conditions=(has_files,) + "commit", + cmd=["git", "commit", "-m", "init commit"], + depends_on=("add",), + conditions=(lambda _: has_files(),), ), ]), # 初始化子目录 diff --git a/src/pyflowx/cli/llm/sglang.py b/src/pyflowx/cli/llm/sglang.py new file mode 100644 index 0000000..6993fe4 --- /dev/null +++ b/src/pyflowx/cli/llm/sglang.py @@ -0,0 +1,55 @@ +"""使用 SGLang 运行本地模型.""" + +import argparse +from pathlib import Path + +import pyflowx as px +from pyflowx.conditions import BuiltinConditions + + +def main(): + parser = argparse.ArgumentParser(description="Run a local model using SGLang.") + parser.add_argument("name", help="Model name.") + parser.add_argument("--dir", default=None, help="Directory of model.") + args = parser.parse_args() + + if not args.name: + parser.error("name is required") + + model_dir = Path(args.dir) if args.dir else Path.home() / ".models" / args.name.split("/")[-1] + if not model_dir.exists(): + parser.error(f"Model directory {model_dir} does not exist.") + + graph = px.Graph.from_specs([ + px.TaskSpec( + name="download", + cmd=[ + "uv", + "install", + "sglang[all]", + ], + conditions=(BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("sglang")),), + verbose=True, + ), + px.TaskSpec( + name="run", + cmd=[ + "uvx", + "sglang", + "serve", + "--model-path", + str(model_dir), + "--host", + "0.0.0.0", + "--port", + "8000", + "--mem-fraction-static", + "0.88", + "--context-length", + "32768", + ], + verbose=True, + ), + ]) + + px.run(graph, verbose=True) diff --git a/src/pyflowx/cli/piptool.py b/src/pyflowx/cli/piptool.py index c8d795c..e779745 100644 --- a/src/pyflowx/cli/piptool.py +++ b/src/pyflowx/cli/piptool.py @@ -21,12 +21,10 @@ PACKAGE_DIR = "packages" REQUIREMENTS_FILE = "requirements.txt" # 受保护的包名集合 -_PROTECTED_PACKAGES: frozenset[str] = frozenset( - { - "pyflowx", - "bitool", - } -) +_PROTECTED_PACKAGES: frozenset[str] = frozenset({ + "pyflowx", + "bitool", +}) # ============================================================================ @@ -161,37 +159,33 @@ def main() -> None: if args.command == "i": graph = px.Graph.from_specs([px.TaskSpec("pip_install", cmd=["pip", "install", *args.packages], verbose=True)]) elif args.command == "u": - graph = px.Graph.from_specs( - [px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True) + ]) elif args.command == "r": - graph = px.Graph.from_specs( - [ - px.TaskSpec( - "pip_reinstall", - fn=pip_reinstall, - args=(args.packages,), - kwargs={"offline": args.offline}, - verbose=True, - ) - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec( + "pip_reinstall", + fn=pip_reinstall, + args=(args.packages,), + kwargs={"offline": args.offline}, + verbose=True, + ) + ]) elif args.command == "d": - graph = px.Graph.from_specs( - [ - px.TaskSpec( - "pip_download", - fn=pip_download, - args=(args.packages,), - kwargs={"offline": args.offline}, - verbose=True, - ) - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec( + "pip_download", + fn=pip_download, + args=(args.packages,), + kwargs={"offline": args.offline}, + verbose=True, + ) + ]) elif args.command == "up": - graph = px.Graph.from_specs( - [px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True) + ]) elif args.command == "f": graph = px.Graph.from_specs([px.TaskSpec("pip_freeze", fn=pip_freeze, verbose=True)]) else: diff --git a/src/pyflowx/conditions.py b/src/pyflowx/conditions.py index a094402..0c3f976 100644 --- a/src/pyflowx/conditions.py +++ b/src/pyflowx/conditions.py @@ -1,7 +1,12 @@ """条件判断模块. -提供平台条件、应用安装条件等预定义条件判断函数, -用于 TaskSpec 的条件执行功能. +所有条件均为 ``Callable[[Context], bool]``,接收依赖上下文映射(可能为空)。 +这使得条件可基于上游任务的运行时返回值做决策,实现动态分支。 + +内置条件分两类: +1. *静态条件* —— 不依赖上下文(平台/环境变量/安装检查),通过 ``_static`` + 包装忽略传入的 context,便于作为模块级常量使用。 +2. *上下文条件* —— 基于上游结果判断,如 :meth:`BuiltinConditions.DEP_EQUALS`。 """ from __future__ import annotations @@ -11,10 +16,11 @@ import shutil import subprocess import sys from pathlib import Path -from typing import Callable +from typing import Any, Callable -# 条件判断函数类型 -Condition = Callable[[], bool] +from .task import Condition, Context + +__all__ = ["BuiltinConditions", "Condition", "Constants"] class Constants: @@ -26,65 +32,56 @@ class Constants: IS_POSIX: bool = sys.platform != "win32" +def _static(predicate: Callable[[], bool], name: str) -> Condition: + """将无参谓词包装为忽略上下文的 :class:`Condition`。""" + + def _cond(_ctx: Context) -> bool: + return predicate() + + _cond.__name__ = name + return _cond + + +# ---------------------------------------------------------------------- # +# 模块级静态条件常量 +# ---------------------------------------------------------------------- # +IS_WINDOWS: Condition = _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS") +IS_LINUX: Condition = _static(lambda: Constants.IS_LINUX, "IS_LINUX") +IS_MACOS: Condition = _static(lambda: Constants.IS_MACOS, "IS_MACOS") +IS_POSIX: Condition = _static(lambda: Constants.IS_POSIX, "IS_POSIX") + + class BuiltinConditions: - """内置条件判断函数集合.""" + """内置条件判断函数集合. + 静态条件工厂返回忽略上下文的 :class:`Condition`;上下文条件工厂返回 + 会读取依赖结果的 :class:`Condition`。 + """ + + # ------------------------------------------------------------------ # + # 静态条件 + # ------------------------------------------------------------------ # @staticmethod - def PYTHON_VERSION(major: int, minor: int | None = None) -> bool: - """检查 Python 版本是否匹配. - - Parameters - ---------- - major : int - 主版本号. - minor : int | None - 次版本号, 若为 None 则仅检查主版本. - - Returns - ------- - bool - 版本是否匹配. - """ + def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition: + """检查 Python 版本是否匹配.""" if minor is None: - return sys.version_info.major == major - return sys.version_info.major == major and sys.version_info.minor == minor + return _static(lambda: sys.version_info.major == major, f"PYTHON_VERSION({major})") + return _static( + lambda: sys.version_info.major == major and sys.version_info.minor == minor, + f"PYTHON_VERSION({major},{minor})", + ) @staticmethod - def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> bool: - """检查 Python 版本是否 >= 指定版本. - - Parameters - ---------- - major : int - 主版本号. - minor : int - 次版本号. - - Returns - ------- - bool - 当前版本是否 >= 指定版本. - """ - return sys.version_info >= (major, minor) + def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> Condition: + """检查 Python 版本是否 >= 指定版本.""" + return _static(lambda: sys.version_info >= (major, minor), f"PYTHON_VERSION_AT_LEAST({major},{minor})") @staticmethod def IS_RUNNING(app_name: str) -> Condition: - """检查指定应用是否正在运行. - - Parameters - ---------- - app_name : str - 应用名称 (如 "explorer", "chrome", "python"). - - Returns - ------- - Condition - 条件判断函数. - """ + """检查指定应用是否正在运行.""" def _check() -> bool: if Constants.IS_WINDOWS: - # Windows: 使用 tasklist 命令 result = subprocess.run( ["tasklist", "/nh", "/fi", f"imagename eq {app_name}"], capture_output=True, @@ -93,148 +90,119 @@ class BuiltinConditions: ) return app_name.lower() in result.stdout.lower() else: - # Linux/macOS: 使用 pgrep 命令 - result = subprocess.run( - ["pgrep", "-x", app_name], - capture_output=True, - check=False, - ) + result = subprocess.run(["pgrep", "-x", app_name], capture_output=True, check=False) return result.returncode == 0 - _check.__name__ = f"IS_RUNNING({app_name!r})" - return _check + return _static(_check, f"IS_RUNNING({app_name!r})") @staticmethod def HAS_INSTALLED(app_name: str) -> Condition: - """检查指定应用是否已安装. - - Parameters - ---------- - app_name : str - 应用名称 (如 "git", "python", "pytest"). - - Returns - ------- - Condition - 条件判断函数. - """ - - def _check() -> bool: - return shutil.which(app_name) is not None - - _check.__name__ = f"HAS_INSTALLED({app_name!r})" - return _check + """检查指定应用是否已安装.""" + return _static(lambda: shutil.which(app_name) is not None, f"HAS_INSTALLED({app_name!r})") @staticmethod - def DIR_EXISTS(dir: Path) -> Condition: + def DIR_EXISTS(path: Path) -> Condition: """路径是否存在.""" - return dir.exists + return _static(path.exists, f"DIR_EXISTS({path!r})") @staticmethod def ENV_VAR_EXISTS(var_name: str) -> Condition: - """检查环境变量是否存在. - - Parameters - ---------- - var_name : str - 环境变量名. - - Returns - ------- - Condition - 条件判断函数. - """ - - def _check() -> bool: - return var_name in os.environ - - _check.__name__ = f"ENV_VAR_EXISTS({var_name!r})" - return _check + """检查环境变量是否存在.""" + return _static(lambda: var_name in os.environ, f"ENV_VAR_EXISTS({var_name!r})") @staticmethod def ENV_VAR_EQUALS(var_name: str, value: str) -> Condition: - """检查环境变量是否等于指定值. + """检查环境变量是否等于指定值.""" + return _static( + lambda: os.environ.get(var_name) == value, + f"ENV_VAR_EQUALS({var_name!r},{value!r})", + ) - Parameters - ---------- - var_name : str - 环境变量名. - value : str - 期望的值. + # ------------------------------------------------------------------ # + # 上下文条件:基于上游依赖结果 + # ------------------------------------------------------------------ # + @staticmethod + def DEP_EQUALS(dep_name: str, value: Any) -> Condition: + """上游任务 ``dep_name`` 的返回值等于 ``value`` 时为真。 - Returns - ------- - Condition - 条件判断函数. + 若依赖未在上下文中(被跳过或未执行),返回 ``False``。 """ - def _check() -> bool: - return os.environ.get(var_name) == value + def _cond(ctx: Context) -> bool: + return dep_name in ctx and ctx[dep_name] == value - _check.__name__ = f"ENV_VAR_EQUALS({var_name!r}, {value!r})" - return _check + _cond.__name__ = f"DEP_EQUALS({dep_name!r},{value!r})" + return _cond @staticmethod - def NOT(condition: Condition) -> Condition: - """对条件取反. + def DEP_MATCHES(dep_name: str, predicate: Callable[[Any], bool]) -> Condition: + """上游任务 ``dep_name`` 的返回值满足 ``predicate`` 时为真。 - Parameters - ---------- - condition : Condition - 原始条件. - - Returns - ------- - Condition - 取反后的条件. + 依赖不存在时返回 ``False``。 """ - def _check() -> bool: - return not condition() + def _cond(ctx: Context) -> bool: + if dep_name not in ctx: + return False + try: + return predicate(ctx[dep_name]) + except Exception: + return False - _check.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})" - return _check + _cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})" + return _cond + + @staticmethod + def DEP_PRESENT(dep_name: str) -> Condition: + """上游任务 ``dep_name`` 存在于上下文(即已成功执行)时为真。""" + + def _cond(ctx: Context) -> bool: + return dep_name in ctx and ctx[dep_name] is not None + + _cond.__name__ = f"DEP_PRESENT({dep_name!r})" + return _cond + + @staticmethod + def DEP_TRUTHY(dep_name: str) -> Condition: + """上游任务 ``dep_name`` 的返回值为真值时为真。""" + + def _cond(ctx: Context) -> bool: + return bool(ctx.get(dep_name)) + + _cond.__name__ = f"DEP_TRUTHY({dep_name!r})" + return _cond + + # ------------------------------------------------------------------ # + # 逻辑组合 + # ------------------------------------------------------------------ # + @staticmethod + def NOT(condition: Condition) -> Condition: + """对条件取反.""" + + def _cond(ctx: Context) -> bool: + return not condition(ctx) + + _cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})" + return _cond @staticmethod def AND(*conditions: Condition) -> Condition: - """多个条件的逻辑与. + """多个条件的逻辑与.""" - Parameters - ---------- - *conditions : Condition - 条件列表. - - Returns - ------- - Condition - 组合条件. - """ - - def _check() -> bool: - return all(c() for c in conditions) + def _cond(ctx: Context) -> bool: + return all(c(ctx) for c in conditions) names = [getattr(c, "__name__", repr(c)) for c in conditions] - _check.__name__ = f"AND({', '.join(names)})" - return _check + _cond.__name__ = f"AND({', '.join(names)})" + return _cond @staticmethod def OR(*conditions: Condition) -> Condition: - """多个条件的逻辑或. + """多个条件的逻辑或.""" - Parameters - ---------- - *conditions : Condition - 条件列表. - - Returns - ------- - Condition - 组合条件. - """ - - def _check() -> bool: - return any(c() for c in conditions) + def _cond(ctx: Context) -> bool: + return any(c(ctx) for c in conditions) names = [getattr(c, "__name__", repr(c)) for c in conditions] - _check.__name__ = f"OR({', '.join(names)})" - return _check + _cond.__name__ = f"OR({', '.join(names)})" + return _cond diff --git a/src/pyflowx/context.py b/src/pyflowx/context.py index 3374c30..840758c 100644 --- a/src/pyflowx/context.py +++ b/src/pyflowx/context.py @@ -1,18 +1,16 @@ """上下文注入:把上游结果转换为函数参数。 本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他 -DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))``)。 +DAG 库中泛滥的样板包装器。 注入规则(按顺序求值) ---------------------- -1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历 - 所有输入的任务。 -2. **名称匹配某个依赖**的参数接收该依赖的结果。 +1. **标注为** :class:`Context` 的参数接收完整结果映射(含硬依赖与软依赖)。 +2. **名称匹配某个依赖**(硬或软)的参数接收该依赖的结果。 3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。 4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。 -若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`, -并附带精确错误信息。 +若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`。 """ from __future__ import annotations @@ -27,21 +25,11 @@ __all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_inj def _is_context_annotation(annotation: Any) -> bool: - """判断参数标注是否为(或指向)``Context``。 - - 处理三种形式: - * ``Context`` 别名对象本身; - * ``__name__``/``_name`` 为 ``Context`` 或 ``Mapping`` 的 typing 别名; - * *字符串*标注(``from __future__ import annotations`` 会在运行时 - 把所有标注变为字符串),如 ``"Context"`` 或 ``"px.Context"``。 - """ + """判断参数标注是否为(或指向)``Context``。""" if annotation is Context: return True - # `from __future__ import annotations` 产生的字符串标注。 if isinstance(annotation, str): - # 匹配 "Context"、"px.Context"、"pyflowx.Context" 等。 return annotation == "Context" or annotation.endswith(".Context") - # 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。 name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None) return name in ("Context", "Mapping") @@ -52,39 +40,22 @@ def build_call_args( ) -> tuple[tuple[Any, ...], dict[str, Any]]: """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。 - 参数 - ---- - spec: - 任务 spec,提供 ``fn``、``depends_on``、``args``、``kwargs``。 - context: - 依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目 - 存在;其他任务的结果被排除,以保持注入的确定性。 - - 返回 - ---- - (args, kwargs) - 可直接展开为 ``spec.fn(*args, **kwargs)``。 - - 抛出 - ---- - InjectionError - 若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。 + ``context`` 必须已包含所有硬依赖与软依赖的结果(软依赖被跳过时由 + 执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。 """ - # 使用 effective_fn 而不是 fn,以支持 cmd 参数 fn = spec.effective_fn sig = inspect.signature(fn) params = sig.parameters - # 检测特殊参数类型。 var_keyword = next( (p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD), None, ) - # 与本任务相关的上下文子集。 - dep_context: dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context} + # 本任务相关的上下文子集:硬依赖 + 软依赖。 + all_deps = set(spec.depends_on) | set(spec.soft_depends_on) + dep_context: dict[str, Any] = {name: context[name] for name in all_deps if name in context} - # 检测静态 kwargs 与依赖名的冲突。 collisions = set(spec.kwargs) & set(dep_context) if collisions: raise InjectionError( @@ -96,8 +67,6 @@ def build_call_args( injected_kwargs: dict[str, Any] = {} leftover_dep_results: dict[str, Any] = dict(dep_context) - # 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充, - # 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。 positional_params: list[str] = [] positional_kinds = ( inspect.Parameter.POSITIONAL_ONLY, @@ -106,33 +75,25 @@ def build_call_args( for pname, param in params.items(): if param.kind in positional_kinds: positional_params.append(pname) - # 前 len(spec.args) 个位置参数由 spec.args 填充。 args_filled: set[str] = set(positional_params[: len(spec.args)]) for pname, param in params.items(): - # 跳过已被位置 spec.args 填充的参数。 if pname in args_filled: continue - # 规则 1:标注为 Context -> 完整映射。 if _is_context_annotation(param.annotation): injected_kwargs[pname] = dep_context continue - # 规则 2:名称匹配某个依赖。 if pname in dep_context: injected_kwargs[pname] = dep_context[pname] leftover_dep_results.pop(pname, None) continue - # 规则 3:在循环后通过 **kwargs 处理。 - - # 规则 4:静态 kwargs 填充其余参数。 if pname in spec.kwargs: injected_kwargs[pname] = spec.kwargs[pname] continue - # 该参数无来源:必须有默认值,否则报错。 if param.default is inspect.Parameter.empty and param.kind not in ( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, @@ -142,9 +103,7 @@ def build_call_args( f"parameter {pname!r} has no dependency, static value, or default.", ) - # 规则 3:**kwargs 吞掉剩余依赖结果。 if var_keyword is not None and leftover_dep_results: - # 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。 merged = dict(spec.kwargs) merged.update(injected_kwargs) merged.update(leftover_dep_results) @@ -154,14 +113,9 @@ def build_call_args( def describe_injection(spec: TaskSpec[Any]) -> str: - """生成任务参数注入方式的人类可读描述。 - - 供 ``dry_run`` 使用,在不执行的情况下展示执行计划。 - """ - # 使用 effective_fn 而不是 fn,以支持 cmd 参数 + """生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。""" fn = spec.effective_fn sig = inspect.signature(fn) - # 确定哪些位置参数由 spec.args 填充。 positional_params = [ p for p, param in sig.parameters.items() @@ -172,6 +126,7 @@ def describe_injection(spec: TaskSpec[Any]) -> str: ) ] args_filled = set(positional_params[: len(spec.args)]) + all_deps = set(spec.depends_on) | set(spec.soft_depends_on) parts = [] for pname, param in sig.parameters.items(): if pname in args_filled: @@ -179,8 +134,9 @@ def describe_injection(spec: TaskSpec[Any]) -> str: parts.append(f"{pname}={spec.args[idx]!r}") elif _is_context_annotation(param.annotation): parts.append(f"{pname}=") - elif pname in spec.depends_on: - parts.append(f"{pname}=") + elif pname in all_deps: + tag = "soft" if pname in spec.soft_depends_on else "dep" + parts.append(f"{pname}=<{tag}:{pname}>") elif pname in spec.kwargs: parts.append(f"{pname}={spec.kwargs[pname]!r}") elif param.default is not inspect.Parameter.empty: diff --git a/src/pyflowx/examples/etl_pipeline.py b/src/pyflowx/examples/etl_pipeline.py index 42ae6c5..73a9a44 100644 --- a/src/pyflowx/examples/etl_pipeline.py +++ b/src/pyflowx/examples/etl_pipeline.py @@ -55,7 +55,9 @@ def main() -> None: depends_on=("extract_customers", "extract_orders"), tags=("transform",), ), - px.TaskSpec("load", load, depends_on=("transform",), retries=1, tags=("load",)), + px.TaskSpec( + "load", load, depends_on=("transform",), retry=px.RetryPolicy(max_attempts=1, delay=1.0), tags=("load",) + ), ]) print("=== Execution plan ===") diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index 82e0459..31404e9 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -1,15 +1,26 @@ """执行器与公共 :func:`run` 入口。 -三种执行策略共享一个逐层驱动器: +四种执行策略: * ``sequential`` —— 确定性、一次一个任务。最适合调试。 * ``thread`` —— 通过线程池实现层内并发。最适合 I/O 密集型同步任务。 * ``async`` —— 通过 ``asyncio.gather`` 实现层内并发。同步任务被 卸载到线程池;异步任务运行在事件循环上。最适合 I/O 密集型异步任务。 +* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动, + 无需等待同层其他任务。最大化并行度。 -三者都遵循 ``retries``、``timeout``、上下文注入、状态后端(续跑), -并向观察者发出 :class:`~pyflowx.task.TaskEvent`。 +所有策略共享统一异步内核,支持: +* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on) +* 软依赖注入与默认值 +* :class:`TaskHooks`(pre_run/post_run/on_failure) +* 按任务策略覆盖 +* 优先级排序(同层内) +* 并发限制(concurrency_key + concurrency_limits) +* ``continue_on_error`` +* ``cache_key`` 存储键 +* 条件判断(上下文感知) +* 状态后端(续跑) """ from __future__ import annotations @@ -18,6 +29,7 @@ import asyncio import concurrent.futures import inspect import logging +import threading from datetime import datetime from typing import Any, Awaitable, Callable, Literal, Mapping, cast @@ -26,24 +38,24 @@ from .errors import TaskFailedError, TaskTimeoutError from .graph import Graph from .report import RunReport from .storage import StateBackend, resolve_backend -from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus +from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus logger = logging.getLogger("pyflowx") # 观察者回调类型。 EventCallback = Callable[[TaskEvent], None] -Strategy = Literal["sequential", "thread", "async"] +Strategy = Literal["sequential", "thread", "async", "dependency"] +# ---------------------------------------------------------------------- # +# 辅助 +# ---------------------------------------------------------------------- # def _is_async_fn(spec: TaskSpec[Any]) -> bool: """判断 ``spec.effective_fn`` 是否为协程函数。""" return inspect.iscoroutinefunction(spec.effective_fn) -def _emit( - on_event: EventCallback | None, - result: TaskResult[Any], -) -> None: +def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None: """若注册了回调则触发一个观察者事件。""" if on_event is None: return @@ -59,71 +71,60 @@ def _emit( ) -def _log_retry(spec: TaskSpec[Any], attempts: int, max_attempts: int, exc: BaseException) -> None: - """记录重试日志(sync 与 async 共享,便于测试覆盖)。""" +def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None: + """记录重试日志。""" logger.warning( "task %r failed (attempt %d/%d): %r; retrying", spec.name, - attempts, + attempt, max_attempts, exc, ) -def _finalize_failure( - result: TaskResult[Any], - layer_idx: int | None, - on_event: EventCallback | None = None, -) -> None: - """标记任务为 FAILED 并抛出 TaskFailedError。""" - result.status = TaskStatus.FAILED - result.finished_at = datetime.now() - _emit(on_event, result) - raise TaskFailedError( - task=result.spec.name, - cause=result.error if result.error is not None else RuntimeError("unknown"), - attempts=result.attempts, - layer=layer_idx, - ) +def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: + """安全调用钩子(异常仅记录,不影响任务状态)。""" + hook: Callable[..., None] | None = getattr(hooks, fn_name, None) + if hook is None: + return + try: + hook(*args) + except Exception as exc: + logger.warning("hook %s raised: %r", fn_name, exc) def _check_upstream_skipped( spec: TaskSpec[Any], report: RunReport | None, ) -> tuple[bool, str | None]: - """检查上游任务是否被 SKIPPED。 + """检查硬依赖上游任务是否被 SKIPPED 或 FAILED。 - Returns - ------- - tuple[bool, str | None] - (是否应该跳过, 跳过原因) + 软依赖不影响本检查——软依赖被跳过时注入默认值。 """ if report is None: return False, None - # 若任务允许上游跳过,则不检查上游状态 if spec.allow_upstream_skip: return False, None for dep in spec.depends_on: - if dep in report.results and report.results[dep].status == TaskStatus.SKIPPED: - return True, f"上游任务 '{dep}' 被跳过" + if dep not in report.results: + continue + dep_status = report.results[dep].status + if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): + return True, f"上游任务 '{dep}' 状态为 {dep_status.value}" return False, None -def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None: - """单次求值所有条件与 skip_if_missing,返回跳过原因或 None。 +def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None: + """求值所有条件,返回跳过原因或 ``None``。 - 与旧实现不同:条件只求值一次。`should_execute()` 内部会调用所有条件, - 若再分支调用 `_is_cmd_available` 之外的逻辑会二次求值(如 - ``IS_RUNNING`` 会 spawn 两次 subprocess)。此处显式逐个求值并记录结果, - 失败原因直接来自求值过程,无需二次调用。 + 条件接收上下文映射(硬依赖 + 软依赖结果)。 """ - # 1. 逐个求值条件,记录失败项。 failed_conditions: list[str] = [] for condition in spec.conditions: try: - ok = condition() + ok = condition(context) except Exception: ok = False name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" @@ -135,7 +136,6 @@ def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None: if failed_conditions: return f"条件不满足: {', '.join(failed_conditions)}" - # 2. skip_if_missing 检查(仅对 list[str] 命令有效)。 if spec.skip_if_missing and not spec._is_cmd_available(): cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown" return f"命令不存在: {cmd_name}" @@ -148,10 +148,7 @@ def _make_skipped_result( reason: str, on_event: EventCallback | None, ) -> TaskResult[Any]: - """构造 SKIPPED 的 TaskResult 并发出事件、打印日志。 - - sync 与 async 执行路径共用,消除重复的 result 构造/emit/print 代码。 - """ + """构造 SKIPPED 的 TaskResult。""" result: TaskResult[Any] = TaskResult( spec=spec, status=TaskStatus.SKIPPED, @@ -165,31 +162,118 @@ def _make_skipped_result( return result +def _build_context( + spec: TaskSpec[Any], + global_context: Mapping[str, Any], + report: RunReport | None = None, # noqa: ARG001 +) -> dict[str, Any]: + """构建本任务的上下文:硬依赖 + 软依赖(含默认值回退)。 + + 硬依赖:若上游 SKIPPED/FAILED 则不注入(本任务通常也会被跳过)。 + 软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。 + """ + ctx: dict[str, Any] = {} + + for dep in spec.depends_on: + if dep in global_context: + ctx[dep] = global_context[dep] + + for dep in spec.soft_depends_on: + if dep in global_context: + ctx[dep] = global_context[dep] + elif dep in spec.defaults: + ctx[dep] = spec.defaults[dep] + else: + ctx[dep] = None + + return ctx + + +def _apply_cached( + name: str, + spec: TaskSpec[Any], + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + on_event: EventCallback | None, +) -> bool: + """若 ``name`` 命中缓存,写入 context/report 并返回 True。""" + storage_key = spec.storage_key(context) + if not backend.has(storage_key): + return False + cached = backend.get(storage_key) + context[name] = cached + result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中") + report.results[name] = result + _emit(on_event, result) + logger.info("task %r skipped (cached)", name) + return True + + def _prepare_for_execution( spec: TaskSpec[Any], + context: Mapping[str, Any], report: RunReport | None, on_event: EventCallback | None, ) -> TaskResult[Any] | None: - """执行前的统一预检:上游跳过 / 条件跳过。 + """执行前预检:上游跳过 / 条件跳过。 - Returns - ------- - TaskResult | None - 若应跳过,返回已填好的 SKIPPED 结果;否则返回 None 表示继续执行。 + 返回 SKIPPED TaskResult 或 ``None``(继续执行)。 """ - # 上游跳过检查 should_skip, skip_reason = _check_upstream_skipped(spec, report) if should_skip: return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event) - # 条件 / skip_if_missing 检查(单次求值) - skip_reason = _evaluate_skip_reason(spec) + skip_reason = _evaluate_conditions(spec, context) if skip_reason is not None: return _make_skipped_result(spec, skip_reason, on_event) return None +def _finalize_failure( + result: TaskResult[Any], + layer_idx: int | None, + on_event: EventCallback | None = None, + continue_on_error: bool = False, +) -> None: + """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" + result.status = TaskStatus.FAILED + result.finished_at = datetime.now() + _emit(on_event, result) + if continue_on_error: + logger.warning( + "task %r failed but continue_on_error=True; continuing.", + result.spec.name, + ) + return + raise TaskFailedError( + task=result.spec.name, + cause=result.error if result.error is not None else RuntimeError("unknown"), + attempts=result.attempts, + layer=layer_idx, + ) + + +def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None: + """重试前的同步等待。""" + wait = spec.retry.wait_seconds(attempt) + if wait > 0: + import time + + time.sleep(wait) + + +async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None: + """重试前的异步等待。""" + wait = spec.retry.wait_seconds(attempt) + if wait > 0: + await asyncio.sleep(wait) + + +# ---------------------------------------------------------------------- # +# 同步执行内核 +# ---------------------------------------------------------------------- # def _run_sync_with_retry( spec: TaskSpec[Any], context: Mapping[str, Any], @@ -198,44 +282,47 @@ def _run_sync_with_retry( report: RunReport | None = None, ) -> TaskResult[Any]: """执行同步任务并带重试;返回填充好的 TaskResult。""" - # 统一预检:上游跳过 / 条件跳过(条件单次求值) - skipped = _prepare_for_execution(spec, report, on_event) + skipped = _prepare_for_execution(spec, context, report, on_event) if skipped is not None: return skipped result: TaskResult[Any] = TaskResult(spec=spec) result.started_at = datetime.now() - max_attempts = spec.retries + 1 + max_attempts = spec.retry.max_attempts args, kwargs = build_call_args(spec, context) + _run_hooks(spec.hooks, "pre_run", spec) + while True: result.attempts += 1 try: - result.value = spec.effective_fn(*args, **kwargs) + with spec.env_context(): + result.value = spec.effective_fn(*args, **kwargs) result.status = TaskStatus.SUCCESS result.finished_at = datetime.now() + _run_hooks(spec.hooks, "post_run", spec, result.value) return result except Exception as exc: result.error = exc - if result.attempts >= max_attempts: - _finalize_failure(result, layer_idx, on_event) + if result.attempts >= max_attempts or not spec.retry.should_retry(exc): + _run_hooks(spec.hooks, "on_failure", spec, exc) + _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + return result _log_retry(spec, result.attempts, max_attempts, exc) - raise AssertionError("unreachable") # pragma: no cover + _sleep_for_retry(spec, result.attempts) + # pragma: no cover +# ---------------------------------------------------------------------- # +# 异步执行内核 +# ---------------------------------------------------------------------- # async def _execute_async_task( spec: TaskSpec[Any], args: tuple[Any, ...], kwargs: dict[str, Any], loop: asyncio.AbstractEventLoop, ) -> Any: - """执行异步或同步任务(带超时处理)。 - - Returns - ------- - Any - 任务返回值 - """ + """执行异步或同步任务(带超时处理)。""" if _is_async_fn(spec): coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) if spec.timeout is not None: @@ -243,9 +330,10 @@ async def _execute_async_task( else: return await coro else: - # 将同步工作卸载到线程,保持事件循环存活。 + def fn_call() -> Any: - return spec.effective_fn(*args, **kwargs) + with spec.env_context(): + return spec.effective_fn(*args, **kwargs) if spec.timeout is not None: return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) @@ -259,76 +347,74 @@ async def _run_async_with_retry( layer_idx: int | None, on_event: EventCallback | None = None, report: RunReport | None = None, + semaphore: asyncio.Semaphore | None = None, ) -> TaskResult[Any]: """在事件循环上执行任务(同步或异步)并带重试。""" - # 统一预检:上游跳过 / 条件跳过(条件单次求值) - skipped = _prepare_for_execution(spec, report, on_event) + skipped = _prepare_for_execution(spec, context, report, on_event) if skipped is not None: return skipped - result: TaskResult[Any] = TaskResult[Any](spec=spec) + if semaphore is not None: + async with semaphore: + return await _run_async_inner(spec, context, layer_idx, on_event, report) + return await _run_async_inner(spec, context, layer_idx, on_event, report) + + +async def _run_async_inner( + spec: TaskSpec[Any], + context: Mapping[str, Any], + layer_idx: int | None, + on_event: EventCallback | None = None, + report: RunReport | None = None, # noqa: ARG001 +) -> TaskResult[Any]: + """异步执行内核的内部实现(已获取 semaphore 后)。""" + result: TaskResult[Any] = TaskResult(spec=spec) result.started_at = datetime.now() - max_attempts = spec.retries + 1 + max_attempts = spec.retry.max_attempts args, kwargs = build_call_args(spec, context) loop = asyncio.get_event_loop() + _run_hooks(spec.hooks, "pre_run", spec) + while True: result.attempts += 1 try: result.value = await _execute_async_task(spec, args, kwargs, loop) result.status = TaskStatus.SUCCESS result.finished_at = datetime.now() + _run_hooks(spec.hooks, "post_run", spec, result.value) return result except asyncio.TimeoutError: - result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0) - if result.attempts >= max_attempts: - _finalize_failure(result, layer_idx, on_event) + exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0) + result.error = exc + if result.attempts >= max_attempts or not spec.retry.should_retry(exc): + _run_hooks(spec.hooks, "on_failure", spec, exc) + _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + return result logger.warning( "task %r timed out (attempt %d/%d); retrying", spec.name, result.attempts, max_attempts, ) + await _async_sleep_for_retry(spec, result.attempts) except Exception as exc: result.error = exc - if result.attempts >= max_attempts: - _finalize_failure(result, layer_idx, on_event) + if result.attempts >= max_attempts or not spec.retry.should_retry(exc): + _run_hooks(spec.hooks, "on_failure", spec, exc) + _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + return result _log_retry(spec, result.attempts, max_attempts, exc) - raise AssertionError("unreachable") # pragma: no cover + await _async_sleep_for_retry(spec, result.attempts) + # pragma: no cover # ---------------------------------------------------------------------- # -# 层驱动器 +# 层执行器 # ---------------------------------------------------------------------- # -def _build_context( - spec: TaskSpec[Any], - global_context: Mapping[str, Any], -) -> Mapping[str, Any]: - """将全局上下文限制为本任务的依赖。""" - return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context} - - -def _apply_cached( - name: str, - graph: Graph, - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - on_event: EventCallback | None, -) -> bool: - """若 ``name`` 命中缓存,写入 context/report 并返回 True;否则返回 False。 - - sequential / thread / async 三种层驱动共用,消除缓存命中分支的重复代码。 - """ - if not backend.has(name): - return False - cached = backend.get(name) - context[name] = cached - result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached, reason="缓存命中") - report.results[name] = result - _emit(on_event, result) - logger.info("task %r skipped (cached)", name) - return True +def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: + """按优先级降序排序(稳定排序)。""" + return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority) def _execute_layer_sequential( @@ -340,14 +426,16 @@ def _execute_layer_sequential( layer_idx: int, on_event: EventCallback | None, ) -> None: - """逐个运行某层的任务。""" - for name in layer: - spec = graph.spec(name) - if _apply_cached(name, graph, context, report, backend, on_event): + """逐个运行某层的任务(按优先级排序)。""" + for name in _sort_by_priority(layer, graph): + spec = graph.resolved_spec(name) + if _apply_cached(name, spec, context, report, backend, on_event): continue - result = _run_sync_with_retry(spec, _build_context(spec, context), layer_idx, on_event, report) + task_ctx = _build_context(spec, context, report) + result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) context[name] = result.value - backend.save(name, result.value) + if result.status == TaskStatus.SUCCESS: + backend.save(spec.storage_key(task_ctx), result.value) report.results[name] = result _emit(on_event, result) @@ -361,42 +449,68 @@ def _execute_layer_threaded( layer_idx: int, on_event: EventCallback | None, max_workers: int, + concurrency_limits: Mapping[str, int], ) -> None: """在线程池中并发运行某层的任务。""" - # 先同步满足已缓存任务。 to_run: list[str] = [] for name in layer: - if _apply_cached(name, graph, context, report, backend, on_event): + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context, report) + if _apply_cached(name, spec, context, report, backend, on_event): continue to_run.append(name) if not to_run: return + to_run = _sort_by_priority(to_run, graph) + + # 为每个 concurrency_key 创建线程信号量 + semaphores: dict[str, threading.Semaphore] = {} + for name in to_run: + spec = graph.resolved_spec(name) + key = spec.concurrency_key + if key is not None and key not in semaphores: + limit = concurrency_limits.get(key, 1) + semaphores[key] = threading.Semaphore(limit) + + context_snapshot = dict(context) + lock = threading.Lock() + + def _run_threaded_task(name: str) -> TaskResult[Any]: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot, report) + sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None + if sem is not None: + sem.acquire() + try: + return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) + finally: + if sem is not None: + sem.release() + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {} for name in to_run: - spec = graph.spec(name) - # 为本任务快照上下文以避免竞态。 - task_ctx = _build_context(spec, context) - fut = pool.submit(_run_sync_with_retry, spec, task_ctx, layer_idx, on_event, report) + fut = pool.submit(_run_threaded_task, name) future_to_name[fut] = name - # 统一收集后再写 context,与 async 版本行为一致: - # 避免边完成边写共享 dict 造成的可见性不一致。 completed: dict[str, TaskResult[Any]] = {} try: for fut in concurrent.futures.as_completed(future_to_name): name = future_to_name[fut] - result = fut.result() # 失败时抛出 TaskFailedError + result = fut.result() completed[name] = result finally: - # 无论是否抛出,都先把已完成任务的结果落盘并写回 context/report。 - for name, result in completed.items(): - context[name] = result.value - backend.save(name, result.value) - report.results[name] = result - _emit(on_event, result) + with lock: + for name, result in completed.items(): + context[name] = result.value + if result.status == TaskStatus.SUCCESS: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot, report) + backend.save(spec.storage_key(task_ctx), result.value) + report.results[name] = result + _emit(on_event, result) async def _execute_layer_async( @@ -407,52 +521,122 @@ async def _execute_layer_async( backend: StateBackend, layer_idx: int, on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], ) -> None: """在事件循环上并发运行某层的任务。""" to_run: list[str] = [] for name in layer: - if _apply_cached(name, graph, context, report, backend, on_event): + spec = graph.resolved_spec(name) + if _apply_cached(name, spec, context, report, backend, on_event): continue to_run.append(name) if not to_run: return - coros = [] - for name in to_run: - spec = graph.spec(name) - task_ctx = _build_context(spec, context) - coros.append(_run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)) + to_run = _sort_by_priority(to_run, graph) + # 为每个 concurrency_key 创建异步信号量 + semaphores: dict[str, asyncio.Semaphore] = {} + for name in to_run: + spec = graph.resolved_spec(name) + key = spec.concurrency_key + if key is not None and key not in semaphores: + limit = concurrency_limits.get(key, 1) + semaphores[key] = asyncio.Semaphore(limit) + + context_snapshot = dict(context) + + async def _run_async_task_wrapped(name: str) -> TaskResult[Any]: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot, report) + sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None + if sem is not None: + async with sem: + return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report) + return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report) + + coros = [_run_async_task_wrapped(name) for name in to_run] results = await asyncio.gather(*coros) for name, result in zip(to_run, results): context[name] = result.value - backend.save(name, result.value) + if result.status == TaskStatus.SUCCESS: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot, report) + backend.save(spec.storage_key(task_ctx), result.value) report.results[name] = result _emit(on_event, result) +# ---------------------------------------------------------------------- # +# 依赖驱动调度 +# ---------------------------------------------------------------------- # +async def _drive_dependency_async( + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], +) -> None: + """依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。 + + 所有任务通过 asyncio 并发调度。同步任务卸载到线程池。 + """ + all_names = set(graph.all_specs().keys()) + semaphores: dict[str, asyncio.Semaphore] = {} + for name in all_names: + spec = graph.resolved_spec(name) + key = spec.concurrency_key + if key is not None and key not in semaphores: + limit = concurrency_limits.get(key, 1) + semaphores[key] = asyncio.Semaphore(limit) + + futures: dict[str, asyncio.Future[TaskResult[Any]]] = {} + + async def _run_task(name: str) -> TaskResult[Any]: + spec = graph.resolved_spec(name) + # 等待所有硬依赖完成 + for dep in spec.depends_on: + if dep in futures: + await futures[dep] + # 等待所有软依赖完成(但不检查其状态) + for dep in spec.soft_depends_on: + if dep in futures: + await futures[dep] + + task_ctx = _build_context(spec, context, report) + if _apply_cached(name, spec, context, report, backend, on_event): + return report.results[name] + + sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None + if sem is not None: + async with sem: + result = await _run_async_with_retry(spec, task_ctx, None, on_event, report) + else: + result = await _run_async_with_retry(spec, task_ctx, None, on_event, report) + + context[name] = result.value + if result.status == TaskStatus.SUCCESS: + backend.save(spec.storage_key(task_ctx), result.value) + report.results[name] = result + _emit(on_event, result) + return result + + loop = asyncio.get_event_loop() + for name in all_names: + futures[name] = loop.create_task(_run_task(name)) + + await asyncio.gather(*futures.values()) + + # ---------------------------------------------------------------------- # # 公共 API # ---------------------------------------------------------------------- # -def _make_verbose_callback( - on_event: EventCallback | None, -) -> EventCallback | None: - """包装 on_event 回调, 在 verbose 模式下打印任务生命周期. - - Parameters - ---------- - on_event : EventCallback | None - 用户提供的原始回调, 若为 None 则仅打印. - - Returns - ------- - EventCallback | None - 包装后的回调. - """ +def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback: + """包装 on_event 回调, 在 verbose 模式下打印任务生命周期。""" def _verbose_callback(event: TaskEvent) -> None: - # 先打印生命周期信息 dur = f" ({event.duration:.3f}s)" if event.duration is not None else "" if event.status == TaskStatus.RUNNING: # pragma: no cover print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True) @@ -464,13 +648,9 @@ def _make_verbose_callback( f"[verbose] 任务 {event.task!r} 失败{dur} (尝试 {event.attempts} 次){err}", flush=True, ) - elif event.status == TaskStatus.SKIPPED: # pragma: no branch + elif event.status == TaskStatus.SKIPPED: reason = f" ({event.reason})" if event.reason else "" print(f"[verbose] 任务 {event.task!r} 跳过{reason}", flush=True) - else: # pragma: no cover - # 不可达: 执行器只发出 RUNNING/SUCCESS/FAILED/SKIPPED 事件 - pass - # 再调用用户回调 if on_event is not None: on_event(event) @@ -486,6 +666,7 @@ def run( verbose: bool = False, on_event: EventCallback | None = None, state: StateBackend | None = None, + concurrency_limits: Mapping[str, int] | None = None, ) -> RunReport: """执行图并返回 :class:`RunReport`。 @@ -494,29 +675,28 @@ def run( graph: 待执行的已校验 :class:`Graph`。 strategy: - 执行策略, 接受 :class:`Strategy` 枚举成员或字符串 - (``"sequential"`` / ``"thread"`` / ``"async"``). 默认 ``Strategy.SEQUENTIAL``. + 执行策略: ``"sequential"`` / ``"thread"`` / ``"async"`` / + ``"dependency"``。``"dependency"`` 为依赖驱动调度,无层屏障。 max_workers: ``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。 dry_run: - 若为 ``True``,打印执行计划(层 + 注入)并返回空报告,不执行 - 任何任务。 + 若为 ``True``,打印执行计划并返回空报告,不执行任务。 verbose: - 若为 ``True``, 打印任务生命周期 (开始/成功/失败/跳过) 到 stdout. - 注意: subprocess 命令的输出由 :class:`TaskSpec` 的 ``verbose`` 字段控制. + 若为 ``True``, 打印任务生命周期到 stdout。 on_event: 可选回调,在每次状态转换时调用。 state: - 可选 :class:`StateBackend`,用于断点续跑。默认为内存后端 - (不跨进程持久化)。 + 可选 :class:`StateBackend`,用于断点续跑。 + concurrency_limits: + ``{concurrency_key: max_concurrent}`` 映射。具有相同 + ``concurrency_key`` 的任务共享信号量,限制同时运行实例数。 抛出 ---- ValueError ``strategy`` 不被识别时。 TaskFailedError - 任何任务耗尽重试后仍失败时。运行在失败层中止;后续层的任务 - 不会被执行。 + 任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。 """ graph.validate() layers = graph.layers() @@ -525,20 +705,23 @@ def run( _print_dry_run(graph, layers) return RunReport(success=True) - # verbose 模式下包装事件回调 effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event - backend = resolve_backend(state) report = RunReport() context: dict[str, Any] = {} + limits = concurrency_limits or {} try: if strategy == "sequential": _drive_sequential(graph, layers, context, report, backend, effective_callback) elif strategy == "thread": - _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers) + _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits) + elif strategy == "async": + _drive_async(graph, layers, context, report, backend, effective_callback, limits) + elif strategy == "dependency": + asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits)) else: - _drive_async(graph, layers, context, report, backend, effective_callback) + raise ValueError(f"Unknown strategy: {strategy!r}") except TaskFailedError: report.success = False raise @@ -552,7 +735,7 @@ def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None: for idx, layer in enumerate(layers, 1): print(f" Layer {idx}: {layer}") for name in layer: - print(f" - {describe_injection(graph.spec(name))}") + print(f" - {describe_injection(graph.resolved_spec(name))}") def _drive_sequential( @@ -575,10 +758,11 @@ def _drive_threaded( backend: StateBackend, on_event: EventCallback | None, max_workers: int | None, + concurrency_limits: Mapping[str, int], ) -> None: for idx, layer in enumerate(layers, 1): workers = max_workers or max(1, min(32, len(layer))) - _execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers) + _execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits) def _drive_async( @@ -588,8 +772,9 @@ def _drive_async( report: RunReport, backend: StateBackend, on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], ) -> None: - asyncio.run(_async_drive(graph, layers, context, report, backend, on_event)) + asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits)) async def _async_drive( @@ -599,6 +784,7 @@ async def _async_drive( report: RunReport, backend: StateBackend, on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], ) -> None: for idx, layer in enumerate(layers, 1): - await _execute_layer_async(layer, graph, context, report, backend, idx, on_event) + await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits) diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index af15ba2..e141c9b 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -2,28 +2,53 @@ 使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8) 进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。 + +支持: +* 图级默认值 :class:`GraphDefaults`,TaskSpec 字段为 ``None`` 时回退。 +* :meth:`Graph.map` 工厂批量生成 fan-out 任务。 +* 字符串引用与 :func:`compose` 编程式组合多个图。 +* 软依赖:仅用于上下文注入,不参与拓扑分层。 """ from __future__ import annotations import sys from dataclasses import dataclass, field, replace -from typing import Any, Iterable, Mapping, Sequence +from typing import Any, Callable, Iterable, Mapping, Sequence from .errors import CycleError, DuplicateTaskError, MissingDependencyError -from .task import TaskSpec +from .task import RetryPolicy, TaskSpec -# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。 if sys.version_info >= (3, 9): # pragma: no cover import graphlib # pyright: ignore[reportUnreachable] _TopologicalSorter = graphlib.TopologicalSorter else: # pragma: no cover - import graphlib # type: ignore[import-untyped] # pragma: no cover + import graphlib # type: ignore[import-untyped] _TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover +@dataclass +class GraphDefaults: + """图级默认值。TaskSpec 对应字段为 ``None`` 时回退到此处。 + + 仅对可空字段生效(retry/timeout/strategy/env/cwd/tags/priority/ + continue_on_error/concurrency_key)。非空字段(name/fn/cmd)不回退。 + """ + + retry: RetryPolicy | None = None + timeout: float | None = None + strategy: str | None = None + tags: tuple[str, ...] = () + env: Mapping[str, str] | None = None + cwd: Any = None # Path | None + priority: int = 0 + continue_on_error: bool = False + concurrency_key: str | None = None + verbose: bool = False + + @dataclass class Graph: """校验后的有向无环任务图。 @@ -34,16 +59,11 @@ class Graph: 图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。 这使图可安全重复运行并在线程间共享。 - - Note - ----- - Graph 不再使用 ``frozen=True``:内部 ``specs``/``deps`` 本就是可变 dict, - frozen 既无法真正保证不可变,又迫使 ``_pending_refs`` 等场景用 - ``object.__setattr__`` 绕过。改为普通 dataclass,让赋值显式且可审计。 """ specs: dict[str, TaskSpec[Any]] = field(default_factory=dict) deps: dict[str, tuple[str, ...]] = field(default_factory=dict) + defaults: GraphDefaults = field(default_factory=GraphDefaults) # 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。 _pending_refs: list[str] = field(default_factory=list) @@ -51,69 +71,47 @@ class Graph: # 构建 # ------------------------------------------------------------------ # def add(self, spec: TaskSpec[Any]) -> Graph: - """注册一个任务 spec,并即时校验。 - - 返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`, - 它会整批校验(允许单次调用中的前向引用)。 - """ - if spec.name in self.specs: - raise DuplicateTaskError(spec.name) - self.specs[spec.name] = spec - self.deps[spec.name] = spec.depends_on - # 为增量 API 即时检查重名与缺失依赖。 + """注册一个任务 spec,并即时校验。返回 ``self`` 支持链式调用。""" + self._register(spec) self._validate_references() return self + def _register(self, spec: TaskSpec[Any]) -> None: + if spec.name in self.specs: + raise DuplicateTaskError(spec.name) + self.specs[spec.name] = spec + # 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。 + self.deps[spec.name] = spec.depends_on + @classmethod - def from_specs(cls, specs: Iterable[TaskSpec[Any] | str]) -> Graph: - """从可迭代的 task spec 构建图. + def from_specs( + cls, + specs: Iterable[TaskSpec[Any] | str], + defaults: GraphDefaults | None = None, + ) -> Graph: + """从可迭代的 task spec 构建图。 - 先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的 - 依赖——顺序无关,就像声明式配置文件的读取方式。 - - 支持字符串引用,允许引用其他命令图中的任务。 - 字符串引用将在CliRunner中解析展开。 + 先收集所有 spec,再统一校验。允许前向引用。支持字符串引用, + 由 :func:`compose` 或 :class:`GraphComposer` 解析展开。 Parameters ---------- - specs : Iterable[TaskSpec[Any] | str] - TaskSpec对象或字符串引用的列表 - - Returns - ------- - Graph - 构建完成的图 - - Note - ----- - 字符串引用格式: - - "command_name" - 引用整个命令图 - - "command_name.task_name" - 引用特定任务 - - Examples - -------- - >>> graph = Graph.from_specs([ - ... TaskSpec("build", cmd=["uv", "build"]), - ... "test", # 引用test命令图 - ... ]) + specs: + TaskSpec 对象或字符串引用的列表。 + defaults: + 图级默认值。``None`` 使用空 :class:`GraphDefaults`。 """ - graph = cls() + graph = cls(defaults=defaults or GraphDefaults()) pending_refs: list[str] = [] for spec in specs: if isinstance(spec, str): - # 字符串引用,稍后解析 pending_refs.append(spec) elif isinstance(spec, TaskSpec): - if spec.name in graph.specs: - raise DuplicateTaskError(spec.name) - graph.specs[spec.name] = spec - graph.deps[spec.name] = spec.depends_on + graph._register(spec) else: - raise TypeError(f"from_specs只接受TaskSpec或str,收到: {type(spec)}") + raise TypeError(f"from_specs 只接受 TaskSpec 或 str,收到: {type(spec)}") - # 存储待解析的引用,稍后由 GraphComposer 解析展开。 - # Graph 不再 frozen,可直接赋值;保留属性名以保持向后兼容。 if pending_refs: graph._pending_refs = pending_refs @@ -125,26 +123,22 @@ class Graph: # 校验 # ------------------------------------------------------------------ # def _validate_references(self) -> None: - """确保每个依赖名都存在于图中。""" - for name, deps in self.deps.items(): - for dep in deps: + """确保每个依赖名都存在于图中。硬依赖与软依赖都校验。""" + for name, spec in self.specs.items(): + for dep in spec.depends_on: + if dep not in self.specs: + raise MissingDependencyError(name, dep) + for dep in spec.soft_depends_on: if dep not in self.specs: raise MissingDependencyError(name, dep) def validate(self) -> None: - """执行完整 DAG 校验。 - - 存在环时抛出 :class:`~pyflowx.errors.CycleError`。 - 依赖存在性由 :meth:`_validate_references` 检查。 - """ + """执行完整 DAG 校验。存在环时抛出 :class:`CycleError`。""" self._validate_references() sorter = _TopologicalSorter(self.deps) try: - # prepare() 在有环时抛出 CycleError;此处不需要 - # static_order() 的结果,仅利用其校验副作用。 sorter.prepare() - except graphlib.CycleError as exc: - # exc.args[1] 是构成环的节点列表。 + except graphlib.CycleError as exc: # type: ignore[name-defined] cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else [] raise CycleError(list(cycle)) from exc @@ -160,10 +154,49 @@ class Graph: """返回 ``name`` 的 spec;不存在则 ``KeyError``。""" return self.specs[name] + def resolved_spec(self, name: str) -> TaskSpec[Any]: + """返回应用图级默认值后的 spec(不修改原图)。 + + 对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空 + 字段,若 spec 字段为默认空值且图级默认值非空,则用 + :func:`dataclasses.replace` 生成带默认值的副本。 + """ + spec = self.specs[name] + d = self.defaults + overrides: dict[str, Any] = {} + if spec.retry == RetryPolicy() and d.retry is not None: + overrides["retry"] = d.retry + if spec.timeout is None and d.timeout is not None: + overrides["timeout"] = d.timeout + if spec.strategy is None and d.strategy is not None: + overrides["strategy"] = d.strategy + if spec.env is None and d.env is not None: + overrides["env"] = d.env + if spec.cwd is None and d.cwd is not None: + overrides["cwd"] = d.cwd + if spec.priority == 0 and d.priority != 0: + overrides["priority"] = d.priority + if not spec.continue_on_error and d.continue_on_error: + overrides["continue_on_error"] = True + if spec.concurrency_key is None and d.concurrency_key is not None: + overrides["concurrency_key"] = d.concurrency_key + if not spec.verbose and d.verbose: + overrides["verbose"] = True + if not spec.tags and d.tags: + overrides["tags"] = d.tags + if not overrides: + return spec + return replace(spec, **overrides) + def dependencies(self, name: str) -> tuple[str, ...]: - """``name`` 的直接前驱。""" + """``name`` 的直接硬依赖前驱。""" return self.deps[name] + def all_deps(self, name: str) -> tuple[str, ...]: + """``name`` 的硬依赖 + 软依赖。""" + spec = self.specs[name] + return tuple(spec.depends_on) + tuple(spec.soft_depends_on) + def all_specs(self) -> Mapping[str, TaskSpec[Any]]: """name -> spec 的只读视图。""" return self.specs @@ -171,18 +204,15 @@ class Graph: def layers(self) -> list[list[str]]: """将任务分组为可并行执行的层(Kahn 算法)。 - 同层任务无相互依赖,可并发执行。层按执行顺序返回。 - - 图有环时抛出 :class:`~pyflowx.errors.CycleError`。 + 同层任务无相互硬依赖,可并发执行。软依赖不参与分层。 + 层按执行顺序返回。图有环时抛出 :class:`CycleError`。 """ self.validate() sorter = _TopologicalSorter(self.deps) result: list[list[str]] = [] - # ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。 sorter.prepare() while sorter.is_active(): ready = list(sorter.get_ready()) - # 排序以保证确定性、可复现的执行计划。 ready.sort() result.append(ready) for node in ready: @@ -193,12 +223,7 @@ class Graph: # 子图 / 标签过滤 # ------------------------------------------------------------------ # def subgraph(self, tags: Iterable[str]) -> Graph: - """返回仅包含匹配任意标签的任务的新图。 - - 依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边 - 会被移除(被保留的任务不再等待它们)。用于调试时运行大型 - DAG 的切片。 - """ + """返回仅包含匹配任意标签的任务的新图。依赖边被修剪。""" wanted: set[str] = set(tags) kept: list[TaskSpec[Any]] = [] for spec in self.specs.values(): @@ -206,10 +231,11 @@ class Graph: pruned_deps = tuple( d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) ) - # 使用 replace 保留所有字段(verbose/skip_if_missing/allow_upstream_skip 等), - # 避免手动逐字段重建时遗漏新增字段。 - kept.append(replace(spec, depends_on=pruned_deps)) - return Graph.from_specs(kept) + pruned_soft = tuple( + d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) + ) + kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft)) + return Graph.from_specs(kept, defaults=self.defaults) def subgraph_by_names(self, names: Iterable[str]) -> Graph: """返回限定于 ``names`` 的新图(边已修剪)。""" @@ -221,18 +247,71 @@ class Graph: for spec in self.specs.values(): if spec.name in wanted: pruned_deps = tuple(d for d in spec.depends_on if d in wanted) - kept.append(replace(spec, depends_on=pruned_deps)) - return Graph.from_specs(kept) + pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted) + kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft)) + return Graph.from_specs(kept, defaults=self.defaults) + + # ------------------------------------------------------------------ # + # Fan-out / map-reduce + # ------------------------------------------------------------------ # + def map( + self, + name_fn: Callable[[int], str], + spec: TaskSpec[Any], + items: Sequence[Any], + arg_factory: Callable[[Any], tuple[Any, ...]] | None = None, + depends_on_per: Callable[[int], tuple[str, ...]] | None = None, + ) -> list[TaskSpec[Any]]: + """为 ``items`` 中每个元素生成一个 TaskSpec 并加入图。 + + 用于 fan-out / map-reduce 模式。返回生成的 spec 列表,便于 + 后续 reduce 任务依赖。 + + Parameters + ---------- + name_fn: + 接受索引 ``i``,返回任务名。需保证唯一。 + spec: + 模板 spec。其 ``name`` 与 ``args`` 会被覆盖。 + items: + 待分发的数据序列。 + arg_factory: + 接受一个 item,返回位置参数元组,覆盖 spec.args。 + ``None`` 则将单个 item 作为唯一位置参数。 + depends_on_per: + 接受索引 ``i``,返回该任务的额外硬依赖。``None`` 则继承 spec.depends_on。 + + Returns + ------- + list[TaskSpec] + 生成的 spec 列表(已加入图)。 + + Examples + -------- + >>> fetch_tmpl = px.TaskSpec("", fn=fetch_user) + >>> specs = graph.map(lambda i: f"fetch_{i}", fetch_tmpl, [1, 2, 3]) + >>> reduce_spec = px.TaskSpec("reduce", fn=reduce_fn, depends_on=tuple(s.name for s in specs)) + """ + generated: list[TaskSpec[Any]] = [] + for i, item in enumerate(items): + name = name_fn(i) + args = arg_factory(item) if arg_factory is not None else (item,) + extra_deps = depends_on_per(i) if depends_on_per is not None else () + new_spec = replace( + spec, + name=name, + args=tuple(args), + depends_on=tuple(spec.depends_on) + tuple(extra_deps), + ) + self.add(new_spec) + generated.append(new_spec) + return generated # ------------------------------------------------------------------ # # 可视化 # ------------------------------------------------------------------ # def to_mermaid(self, orientation: str = "TD") -> str: - """将 DAG 渲染为 Mermaid ``graph`` 定义字符串。 - - 无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览 - 渲染,或保存为文件。 - """ + """将 DAG 渲染为 Mermaid ``graph`` 定义字符串。""" valid = {"TD", "TB", "BT", "LR", "RL"} orientation = orientation.upper() if orientation not in valid: @@ -243,6 +322,10 @@ class Graph: for name, deps in self.deps.items(): for dep in deps: lines.append(f" {dep} --> {name}") + # 软依赖用虚线 + for name, spec in self.specs.items(): + for dep in spec.soft_depends_on: + lines.append(f" {dep} -.-> {name}") return "\n".join(lines) + "\n" # ------------------------------------------------------------------ # @@ -268,19 +351,12 @@ class Graph: class GraphComposer: """将带字符串引用的图展开为纯 :class:`TaskSpec` 图。 - 从 ``CliRunner`` 抽出,使 ``Graph``(数据)与引用解析(组合逻辑) - 职责分离。引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务; - 原始 ``TaskSpec`` 之间也按出现顺序串行依赖。 - - 引用格式 - -------- + 引用格式: * ``"command_name"`` —— 引用整个命令图。 * ``"command_name.task_name"`` —— 引用特定任务。 - Parameters - ---------- - graphs : dict[str, Graph] - 命令名到图的映射,引用据此解析。 + 引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务; + 原始 ``TaskSpec`` 之间也按出现顺序串行依赖。 """ def __init__(self, graphs: dict[str, Graph]) -> None: @@ -294,18 +370,7 @@ class GraphComposer: return resolved def expand_refs(self, graph: Graph, current_cmd: str) -> Graph: - """展开图中的字符串引用。 - - 若图无 ``_pending_refs``,原样返回。 - - Note - ----- - 引用按顺序展开,后续引用的任务依赖于前面引用的任务完成。 - 例如 ``["c", "tc", bump]`` 展开为: - - c 的所有任务(无依赖) - - tc 的所有任务(依赖于 c 的最后一个任务) - - bump 任务(依赖于 tc 的最后一个任务) - """ + """展开图中的字符串引用。若无 ``_pending_refs``,原样返回。""" pending_refs = graph._pending_refs if not pending_refs: return graph @@ -313,23 +378,16 @@ class GraphComposer: all_specs: list[TaskSpec[Any]] = [] previous_ref_last_task: str | None = None - # 先解析每个引用,并建立依赖链。 for ref in pending_refs: expanded_specs = self.parse_ref(ref, current_cmd) - - # 若有前一个引用,让当前引用的任务依赖其最后一个任务。 if previous_ref_last_task and expanded_specs: for i, task in enumerate(expanded_specs): - # 只为没有依赖的任务(或第一个任务)添加依赖。 if i == 0 or not task.depends_on: expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task})) - if expanded_specs: previous_ref_last_task = expanded_specs[-1].name - all_specs.extend(expanded_specs) - # 然后添加原始 TaskSpec,按出现顺序串行依赖。 original_specs = list(graph.all_specs().values()) if original_specs: if previous_ref_last_task: @@ -337,49 +395,53 @@ class GraphComposer: all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task}))) else: all_specs.append(original_specs[0]) - for i in range(1, len(original_specs)): current_task = original_specs[i] previous_task_name = original_specs[i - 1].name all_specs.append( - replace( - current_task, - depends_on=tuple({*current_task.depends_on, previous_task_name}), - ) + replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name})) ) - return Graph.from_specs(all_specs) + return Graph.from_specs(all_specs, defaults=graph.defaults) def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]: - """解析单个字符串引用,返回对应的 TaskSpec 列表。 - - Raises - ------ - ValueError - 引用无效、目标命令/任务不存在,或检测到循环引用。 - """ - # 避免循环引用。 + """解析单个字符串引用,返回对应的 TaskSpec 列表。""" if ref == current_cmd: raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己") if "." in ref: - # 特定任务引用: "command_name.task_name" cmd_name, task_name = ref.split(".", 1) if cmd_name not in self.graphs: raise ValueError(f"引用的命令 '{cmd_name}' 不存在") - ref_graph = self.graphs[cmd_name] if task_name not in ref_graph.all_specs(): raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中") - return [ref_graph.all_specs()[task_name]] else: - # 整个命令图引用: "command_name" cmd_name = ref if cmd_name not in self.graphs: raise ValueError(f"引用的命令 '{cmd_name}' 不存在") - ref_graph = self.graphs[cmd_name] - # 递归展开(若引用的图自身也含引用)。 ref_graph = self.expand_refs(ref_graph, cmd_name) return list(ref_graph.all_specs().values()) + + +def compose( + graphs: dict[str, Graph], +) -> dict[str, Graph]: + """编程式解析多图的字符串引用,返回展开后的新图映射。 + + 与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用 + :class:`~pyflowx.runner.CliRunner` 的编程式用户调用。 + + Examples + -------- + >>> graphs = { + ... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]), + ... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]), + ... } + >>> resolved = px.compose(graphs) + >>> "b" in resolved["all"].all_specs() + True + """ + return GraphComposer(graphs).resolve_all() diff --git a/src/pyflowx/storage.py b/src/pyflowx/storage.py index c766c16..bb94303 100644 --- a/src/pyflowx/storage.py +++ b/src/pyflowx/storage.py @@ -4,20 +4,18 @@ 执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其 存储值注入下游任务。 -本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储 -形态为扁平的 ``{task_name: result}`` 映射。内置两个后端: +存储键由 :meth:`TaskSpec.storage_key` 计算,默认为任务名;若任务配置 +了 ``cache_key``,则键为 ``"name:cache_key_value"``,使不同输入产生 +独立缓存条目。 -* :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认。 -* :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。 - -两者均零依赖(``json`` 为标准库)。用户可子类化 -:class:`StateBackend` 接入 SQLite、Redis 等。 +支持 TTL:``has`` 在条目过期时返回 ``False``。 """ from __future__ import annotations import json import sys +import time from abc import ABC, abstractmethod from pathlib import Path from typing import Any, Mapping @@ -31,23 +29,26 @@ from .errors import StorageError class StateBackend(ABC): - """可续跑状态存储的抽象基类。""" + """可续跑状态存储的抽象基类。 + + 所有方法以 ``key`` 为参数(通常为任务名或 ``name:cache_key``)。 + """ @abstractmethod def load(self) -> Mapping[str, Any]: """返回完整的存储映射(可能为空)。""" @abstractmethod - def save(self, name: str, value: Any) -> None: + def save(self, key: str, value: Any) -> None: """持久化单个任务的成功结果。""" @abstractmethod - def has(self, name: str) -> bool: - """``name`` 是否已有存储结果。""" + def has(self, key: str) -> bool: + """``key`` 是否已有未过期的存储结果。""" @abstractmethod - def get(self, name: str) -> Any: - """返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。""" + def get(self, key: str) -> Any: + """返回 ``key`` 的存储结果(不存在则抛 ``KeyError``)。""" @abstractmethod def clear(self) -> None: @@ -55,43 +56,66 @@ class StateBackend(ABC): class MemoryBackend(StateBackend): - """进程内 dict 后端。进程退出即丢失。""" + """进程内 dict 后端。进程退出即丢失。 - def __init__(self) -> None: - self._store: dict[str, Any] = {} + Parameters + ---------- + ttl: + 条目存活秒数。``None`` 表示永不过期。``has`` 在条目超过 ttl 后 + 返回 ``False``(但不主动删除,下次 ``save`` 覆盖)。 + """ + + def __init__(self, ttl: float | None = None) -> None: + self._store: dict[str, tuple[Any, float]] = {} + self._ttl = ttl @override def load(self) -> Mapping[str, Any]: - return dict(self._store) + return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)} @override - def save(self, name: str, value: Any) -> None: - self._store[name] = value + def save(self, key: str, value: Any) -> None: + self._store[key] = (value, time.monotonic()) @override - def has(self, name: str) -> bool: - return name in self._store + def has(self, key: str) -> bool: + return key in self._store and not self._expired(key) @override - def get(self, name: str) -> Any: - return self._store[name] + def get(self, key: str) -> Any: + if key not in self._store or self._expired(key): + raise KeyError(key) + return self._store[key][0] @override def clear(self) -> None: self._store.clear() + def _expired(self, key: str) -> bool: + if self._ttl is None or key not in self._store: + return False + _value, ts = self._store[key] + return (time.monotonic() - ts) > self._ttl + class JSONBackend(StateBackend): """基于文件的 JSON 存储,用于跨进程续跑。 - 结果必须可 JSON 序列化。不可序列化的值会抛出 - :class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条 - 结果的持久化失败)。 + 存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。 + ``ts`` 用于 TTL 判断。结果必须可 JSON 序列化。 + + Parameters + ---------- + path: + JSON 文件路径。 + ttl: + 条目存活秒数。``None`` 表示永不过期。 """ - def __init__(self, path: str) -> None: + def __init__(self, path: str, ttl: float | None = None) -> None: self._path: str = path - self._store: dict[str, Any] = {} + self._ttl = ttl + self._store: dict[str, dict[str, Any]] = {} self._load() def _load(self) -> None: @@ -101,7 +125,14 @@ class JSONBackend(StateBackend): with open(self._path, encoding="utf-8") as fh: data: Any = json.load(fh) if isinstance(data, dict): - self._store = data + # 兼容纯值格式与带元数据格式 + self._store = {} + for k, v in data.items(): + if isinstance(v, dict) and "value" in v and "ts" in v: + self._store[k] = v + else: + # 旧格式:纯值 + self._store[k] = {"value": v, "ts": time.time()} except (OSError, json.JSONDecodeError) as exc: raise StorageError(f"cannot read state file {self._path!r}", exc) from exc @@ -110,32 +141,40 @@ class JSONBackend(StateBackend): try: with open(tmp, "w", encoding="utf-8") as fh: json.dump(self._store, fh, ensure_ascii=False, indent=2) - _ = Path(tmp).replace(Path(self._path)) except (OSError, TypeError) as exc: raise StorageError(f"cannot write state file {self._path!r}", exc) from exc - @override - def load(self) -> Mapping[str, Any]: - return dict(self._store) + def _now(self) -> float: + return time.time() + + def _expired(self, entry: dict[str, Any]) -> bool: + if self._ttl is None: + return False + return (self._now() - float(entry.get("ts", 0))) > self._ttl @override - def save(self, name: str, value: Any) -> None: - # 在修改内存状态前先校验可序列化性。 + def load(self) -> Mapping[str, Any]: + return {k: v["value"] for k, v in self._store.items() if not self._expired(v)} + + @override + def save(self, key: str, value: Any) -> None: try: _ = json.dumps(value) except (TypeError, ValueError) as exc: - raise StorageError(f"result of task {name!r} is not JSON-serialisable", exc) from exc - self._store[name] = value + raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc + self._store[key] = {"value": value, "ts": self._now()} self._flush() @override - def has(self, name: str) -> bool: - return name in self._store + def has(self, key: str) -> bool: + return key in self._store and not self._expired(self._store[key]) @override - def get(self, name: str) -> Any: - return self._store[name] + def get(self, key: str) -> Any: + if key not in self._store or self._expired(self._store[key]): + raise KeyError(key) + return self._store[key]["value"] @override def clear(self) -> None: diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index 792e0c9..84e9ecf 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -15,9 +15,11 @@ * ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。 """ +import os import shutil import subprocess import sys +from contextlib import contextmanager from dataclasses import dataclass, field from datetime import datetime from enum import Enum @@ -25,8 +27,10 @@ from pathlib import Path from typing import ( Any, Callable, + ContextManager, Coroutine, Generic, + Iterator, List, Mapping, Optional, @@ -59,8 +63,95 @@ TaskCmd = Union[ Callable[..., Any], # Python 函数 ] -# 条件判断函数类型 -Condition = Callable[[], bool] +# 执行策略:sequential/thread/async 为层屏障模型,dependency 为依赖驱动模型。 +Strategy = Union[str, "StrategyKind"] +StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束 + +# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。 +Condition = Callable[[Context], bool] + +# 缓存键计算函数:基于依赖上下文计算稳定字符串键。 +CacheKeyFn = Callable[[Context], str] + + +# ---------------------------------------------------------------------- # +# 重试策略 +# ---------------------------------------------------------------------- # +@dataclass(frozen=True) +class RetryPolicy: + """任务失败重试策略。 + + 参数 + ---- + max_attempts: + 最大尝试次数(含首次)。``1`` 表示仅尝试一次,不重试。 + delay: + 两次尝试之间的初始等待秒数。 + backoff: + 退避倍率。第 n 次重试等待 ``delay * backoff ** (n-1)``。 + jitter: + 抖动上限秒数。每次等待加上 ``[0, jitter)`` 的随机量,避免惊群。 + retry_on: + 仅对这些异常类型重试。默认 ``(Exception,)`` 重试所有异常。 + 传入空元组等价于不重试。 + + Note + ----- + 替代旧版 ``retries: int``。``retries=2`` 等价于 + ``RetryPolicy(max_attempts=3)``。 + """ + + max_attempts: int = 1 + delay: float = 0.0 + backoff: float = 1.0 + jitter: float = 0.0 + retry_on: Tuple[type[BaseException], ...] = (Exception,) + + def __post_init__(self) -> None: + if self.max_attempts < 1: + raise ValueError(f"RetryPolicy.max_attempts must be >= 1, got {self.max_attempts}.") + if self.delay < 0: + raise ValueError(f"RetryPolicy.delay must be >= 0, got {self.delay}.") + if self.backoff < 0: + raise ValueError(f"RetryPolicy.backoff must be >= 0, got {self.backoff}.") + if self.jitter < 0: + raise ValueError(f"RetryPolicy.jitter must be >= 0, got {self.jitter}.") + + @property + def retries(self) -> int: + """重试次数(不含首次),等价于 ``max_attempts - 1``。""" + return self.max_attempts - 1 + + def should_retry(self, exc: BaseException) -> bool: + """异常是否属于可重试类型。""" + return isinstance(exc, self.retry_on) + + def wait_seconds(self, attempt: int) -> float: + """第 ``attempt`` 次失败后应等待的秒数(attempt 从 1 开始)。""" + if attempt < 1: + return 0.0 + import random + + base = self.delay * (self.backoff ** max(0, attempt - 1)) + jitter = random.uniform(0, self.jitter) if self.jitter > 0 else 0.0 + return base + jitter + + +# ---------------------------------------------------------------------- # +# 任务钩子 +# ---------------------------------------------------------------------- # +@dataclass(frozen=True) +class TaskHooks: + """任务生命周期钩子。 + + 所有钩子均为可选。``pre_run`` 在任务实际执行前调用;``post_run`` + 在成功后调用并接收返回值;``on_failure`` 在最终失败后调用并接收异常。 + 钩子异常不会影响任务状态,仅记录日志。 + """ + + pre_run: Optional[Callable[["TaskSpec[Any]"], None]] = None + post_run: Optional[Callable[["TaskSpec[Any]", Any], None]] = None + on_failure: Optional[Callable[["TaskSpec[Any]", BaseException], None]] = None class TaskStatus(Enum): @@ -90,181 +181,239 @@ class TaskSpec(Generic[T]): - ``list[str]``: 命令及参数列表,如 ``["ls", "-la"]`` - ``str``: shell 命令字符串,如 ``"pip freeze > requirements.txt"`` - ``Callable``: Python 函数,与 ``fn`` 参数等效 - 若提供此参数,会自动包装为执行函数,覆盖 ``fn`` 参数。 depends_on: - 必须先完成才能运行本任务的任务名列表。顺序无关;框架会做 - 拓扑排序。 + 硬依赖任务名。必须全部成功完成才会运行本任务。 + 上游被 SKIPPED 时,本任务也会被 SKIPPED(除非 + ``allow_upstream_skip=True``)。 + soft_depends_on: + 软依赖任务名。会等待其完成,但其结果不影响本任务是否执行: + - 上游成功:注入其返回值 + - 上游 SKIPPED 或失败:注入 :attr:`defaults` 中提供的默认值 + 适用于"可选输入"场景。 + defaults: + 软依赖的默认值映射 ``{dep_name: default_value}``。 + 软依赖未提供结果时使用。未在 defaults 中出现的软依赖默认为 ``None``。 args: - 静态位置参数,追加在注入参数*之后*。适用于参数化任务 - (如 ``fetch_user(uid)``)。 + 静态位置参数,追加在注入参数*之后*。 kwargs: 静态关键字参数。若与注入名冲突则抛出 :class:`~pyflowx.errors.InjectionError`。 - retries: - 失败后的重试次数。``0`` 表示仅尝试一次。 + retry: + :class:`RetryPolicy` 重试策略。默认仅尝试一次。 timeout: 最大执行时长(秒)。``None`` 表示不限制。异步任务使用 - :func:`asyncio.wait_for`;线程/异步执行器中的同步任务会 - 取消 worker future。 + :func:`asyncio.wait_for`;同步任务通过线程 future 取消。 tags: - 自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试。 + 自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试, + 也可用于并发限制分组。 conditions: - 条件判断函数列表,只有所有条件都返回 ``True`` 时才执行任务。 - 若任一条件返回 ``False``,任务会被标记为 SKIPPED。 - 用于平台判断、环境变量检查等场景。 + 条件判断函数列表,接收依赖上下文,全部返回 ``True`` 时才执行任务。 + 任一返回 ``False`` 则任务被标记为 SKIPPED。 cwd: - 命令执行的工作目录,仅在使用 ``cmd`` 参数时有效。 - ``None`` 表示当前目录。 + 工作目录。对 ``cmd`` 任务作为子进程工作目录;对 ``fn`` 任务 + 通过临时切换当前目录生效。 + env: + 环境变量覆盖映射。对 ``cmd`` 任务合并到子进程环境;对 ``fn`` + 任务在执行期间临时设置。 verbose: - 是否在命令执行时显示详细输出。``True`` 时会打印执行的命令 - 及其标准输出/标准错误。仅在使用 ``cmd`` 参数时有效。 - ``False`` 时静默捕获输出(失败时仍会包含在错误信息中)。 + 是否打印详细输出。``True`` 时打印执行的命令、返回码与输出 + (仅 ``cmd``),以及任务生命周期。 skip_if_missing: - 仅对 ``cmd`` 为 ``list[str]`` 的任务有效。``True`` 时自动检查 - 命令是否存在(通过 :func:`shutil.which`),不存在则跳过任务 - (标记为 SKIPPED)而非失败。适用于构建工具场景,避免因未安装 - 某些工具(如 maturin、tox)而导致整个图执行失败。 - 对于 ``str`` (shell) 和 ``Callable`` 类型的 ``cmd``,此参数无效。 + 仅对 ``cmd`` 为 ``list[str]`` 有效。``True`` 时通过 + :func:`shutil.which` 检查命令是否存在,不存在则跳过。 allow_upstream_skip: - 若为 ``True``,当上游任务因条件不满足被跳过时,本任务仍会执行 - (而非跟随跳过)。适用于清理类任务:即使某些删除操作因目标不存在 - 而跳过,后续操作(如重启服务)仍应执行。默认为 ``False``。 + 若为 ``True``,硬依赖被 SKIPPED 时本任务仍执行(软依赖不影响)。 + 适用于清理类任务。 + strategy: + 单任务执行策略覆盖。``None`` 表示继承图级策略。 + ``"sequential"`` 同步直接调用;``"thread"``/``"async"`` 将同步 + 任务卸载到线程池,异步任务跑在事件循环上。 + priority: + 同层任务调度优先级。数值越大越先启动。仅影响同层内启动顺序, + 不打破层屏障。默认 ``0``。 + concurrency_key: + 并发限制分组键。具有相同键的任务共享一个信号量,限制同时 + 运行的实例数。具体限额由 :func:`run` 的 ``concurrency_limits`` + 参数提供 ``{key: limit}`` 映射。``None`` 表示不限制。 + continue_on_error: + 若为 ``True``,任务最终失败时不中止整图,仅标记本任务 FAILED, + 其硬依赖下游被 SKIPPED,其余任务继续。默认 ``False``。 + cache_key: + 缓存键计算函数。若提供,则用其基于依赖上下文计算的字符串键 + 存取状态后端,使不同输入产生独立缓存条目。``None`` 表示用任务名。 + hooks: + :class:`TaskHooks` 生命周期钩子。 """ name: str fn: Optional[TaskFn[T]] = None cmd: Optional[TaskCmd] = None depends_on: Tuple[str, ...] = () + soft_depends_on: Tuple[str, ...] = () + defaults: Mapping[str, Any] = field(default_factory=dict) args: Tuple[Any, ...] = () kwargs: Mapping[str, Any] = field(default_factory=dict) - retries: int = 0 + retry: RetryPolicy = field(default_factory=RetryPolicy) timeout: Optional[float] = None tags: Tuple[str, ...] = () conditions: Tuple[Condition, ...] = () cwd: Optional[Path] = None + env: Optional[Mapping[str, str]] = None verbose: bool = False skip_if_missing: bool = False allow_upstream_skip: bool = False + strategy: Optional[str] = None + priority: int = 0 + concurrency_key: Optional[str] = None + continue_on_error: bool = False + cache_key: Optional[CacheKeyFn] = None + hooks: TaskHooks = field(default_factory=TaskHooks) def __post_init__(self) -> None: if not self.name: raise ValueError("TaskSpec.name must be a non-empty string.") - if self.retries < 0: - raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.") + if self.retry.max_attempts < 1: + raise ValueError(f"TaskSpec '{self.name}': retry.max_attempts must be >= 1.") if self.timeout is not None and self.timeout <= 0: raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.") - if self.name in self.depends_on: + if self.name in self.depends_on or self.name in self.soft_depends_on: raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.") + overlap = set(self.depends_on) & set(self.soft_depends_on) + if overlap: + raise ValueError(f"TaskSpec '{self.name}': depends_on 与 soft_depends_on 不能重叠: {sorted(overlap)}") if self.fn is None and self.cmd is None: raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。") @property def effective_fn(self) -> TaskFn[T]: - """获取有效的执行函数. + """获取有效的执行函数。 - 若提供了 ``cmd`` 参数,则返回包装后的命令执行函数; - 否则返回 ``fn`` 参数。 - - Note - ----- - 命令执行逻辑已抽到模块级 :func:`_run_command`,此处仅返回轻量 - 转发闭包。``verbose`` / ``cwd`` / ``timeout`` 不再在创建时闭包 - 捕获,而是在每次调用时从 ``self`` 读取——这使得翻转 ``verbose`` - 无需重建 spec(见 :func:`pyflowx.runner._apply_verbose_to_graph`)。 + 若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。 + 包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/ + ``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。 """ if self.cmd is not None: return self._wrap_cmd() if self.fn is not None: return self.fn - raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover def _wrap_cmd(self) -> TaskFn[Any]: - """将 cmd 包装为可执行函数. - - 返回的闭包仅持有 ``self`` 引用,每次调用时从 spec 读取 - ``verbose``/``cwd``/``timeout``,避免闭包捕获运行期参数。 - - Returns - ------- - TaskFn[Any] - 包装后的执行函数. - """ + """将 cmd 包装为可执行函数。""" spec = self - if isinstance(spec.cmd, list): + def _run() -> T: + return cast(T, _run_command(spec)) - def _run_list() -> T: - return cast(T, _run_command(spec)) + _run.__name__ = spec.name + return _run # type: ignore[return-value] - _run_list.__name__ = spec.name - return _run_list # type: ignore[return-value] - - if isinstance(spec.cmd, str): - - def _run_shell() -> T: - return cast(T, _run_command(spec)) - - _run_shell.__name__ = spec.name - return _run_shell # type: ignore[return-value] - - if callable(spec.cmd): - return spec.cmd # type: ignore[return-value] - - raise TypeError(f"TaskSpec '{spec.name}': 不支持的 cmd 类型 {type(spec.cmd).__name__}") # pragma: no cover - - def should_execute(self) -> bool: - """检查任务是否应该执行. + def should_execute(self, context: Context) -> Tuple[bool, Optional[str]]: + """检查任务是否应执行。 Returns ------- - bool - 若所有条件都返回 ``True``,且 ``skip_if_missing`` 检查通过, - 则返回 ``True``;否则返回 ``False``。 + (should_run, skip_reason) + ``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。 """ - if not all(condition() for condition in self.conditions): - return False + # 逐个求值条件,记录失败项。 + failed_conditions: list[str] = [] + for condition in self.conditions: + try: + ok = condition(context) + except Exception: + ok = False + name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" + failed_conditions.append(name) + continue + if not ok: + failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件") - return not (self.skip_if_missing and not self._is_cmd_available()) + if failed_conditions: + return False, f"条件不满足: {', '.join(failed_conditions)}" + + if self.skip_if_missing and not self._is_cmd_available(): + cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown" + return False, f"命令不存在: {cmd_name}" + + return True, None def _is_cmd_available(self) -> bool: - """检查 ``cmd`` 是否可用. - - 仅对 ``list[str]`` 类型的 ``cmd`` 进行检查(通过 :func:`shutil.which`)。 - 对于 ``str`` (shell) 和 ``Callable`` 类型,始终返回 ``True``。 - - Returns - ------- - bool - 命令可用返回 ``True``,否则返回 ``False``。 - """ - + """检查 ``cmd`` 是否可用(仅 list[str])。""" cmd = self.cmd if isinstance(cmd, list) and cmd: - first_arg = cmd[0] - return shutil.which(first_arg) is not None + return shutil.which(cmd[0]) is not None return True + def env_context(self) -> ContextManager[None]: + """返回临时应用 ``env`` 与 ``cwd`` 的上下文管理器。 -def _run_command(spec: "TaskSpec[Any]") -> Any: - """执行 ``spec.cmd`` 指定的命令(list 或 shell 字符串)。 + 对 ``fn`` 任务生效。``cmd`` 任务在 :func:`_run_command` 中直接 + 传给子进程。 + """ + return _env_and_cwd(self.env, self.cwd) - list 与 shell 两条路径的异常处理、输出捕获、返回码判断完全一致, - 合并于此消除重复。``verbose``/``cwd``/``timeout`` 在调用时从 - ``spec`` 读取,而非闭包捕获——这是 ``_wrap_cmd`` 不再捕获运行期 - 参数的关键。 + def storage_key(self, context: Context) -> str: + """计算状态后端存储键。""" + if self.cache_key is not None: + try: + return f"{self.name}:{self.cache_key(context)}" + except Exception: + return self.name + return self.name - 成功返回 ``None``;失败抛 ``RuntimeError``,错误信息包含命令、 - 返回码与(非 verbose 模式下的)stderr。 - """ + +@contextmanager +def _env_and_cwd( + env: Optional[Mapping[str, str]], + cwd: Optional[Path], +) -> Iterator[None]: + """临时设置环境变量与工作目录。""" + saved_env: dict[str, str] = {} + saved_cwd: Optional[str] = None + if env: + for k, v in env.items(): + if k in os.environ: + saved_env[k] = os.environ[k] + os.environ[k] = v + if cwd is not None: + saved_cwd = str(Path.cwd()) + os.chdir(cwd) + try: + yield + finally: + if saved_cwd is not None: + os.chdir(saved_cwd) + # 恢复环境变量 + if env: + for k in env: + if k in saved_env: + os.environ[k] = saved_env[k] + else: + os.environ.pop(k, None) + + +def _run_command(spec: "TaskSpec[Any]") -> Any: # noqa: PLR0912 + """执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。""" cmd = spec.cmd - is_list = isinstance(cmd, list) verbose = spec.verbose cwd = spec.cwd timeout = spec.timeout + env_override = spec.env - # 统一展示用的命令字符串与标签。保持 "执行命令" / "执行 Shell" 连续, - # 以兼容既有输出格式与测试断言。 + # 可调用对象:直接调用,返回其结果。 + if callable(cmd) and not isinstance(cmd, (list, str)): + name = getattr(cmd, "__name__", "callable") + if verbose: + print(f"[verbose] 执行可调用命令: {name}", flush=True) + if cwd is not None: + print(f"[verbose] 工作目录: {cwd}", flush=True) + try: + return cmd() + except Exception as e: + raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e + + is_list = isinstance(cmd, list) if is_list: cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr] verb = "执行命令" @@ -279,14 +428,18 @@ def _run_command(spec: "TaskSpec[Any]") -> Any: if cwd is not None: print(f"[verbose] 工作目录: {cwd}", flush=True) + # 合并环境变量 + run_env: Optional[dict[str, str]] = None + if env_override: + run_env = dict(os.environ) + run_env.update(env_override) + try: - # cmd 此处必为 list[str] 或 str(_wrap_cmd 的 isinstance 守卫已排除 - # None 与 Callable),但类型检查器无法跨函数推断,故 cast 收窄到 - # subprocess.run 接受的 Union[str, Sequence[str]]。 result = subprocess.run( cast(Union[str, List[str]], cmd), shell=not is_list, cwd=cwd, + env=run_env, timeout=timeout, capture_output=not verbose, text=True, @@ -311,13 +464,42 @@ def _run_command(spec: "TaskSpec[Any]") -> Any: raise RuntimeError(err_msg) +# ---------------------------------------------------------------------- # +# 任务模板:批量生成相似 TaskSpec 的工厂 +# ---------------------------------------------------------------------- # +def task_template( + fn: Optional[TaskFn[Any]] = None, + cmd: Optional[TaskCmd] = None, + **defaults: Any, +) -> Callable[..., TaskSpec[Any]]: + """创建任务模板工厂。 + + 返回的工厂接受 ``name`` 与任意覆盖字段,生成 :class:`TaskSpec`。 + 适用于批量创建相似任务(如 fan-out)。 + + Examples + -------- + >>> Fetch = px.task_template(fn=fetch_user, retry=px.RetryPolicy(max_attempts=3)) + >>> specs = [Fetch(f"fetch_{uid}", args=(uid,)) for uid in range(5)] + """ + base = dict(defaults) + if fn is not None: + base["fn"] = fn + if cmd is not None: + base["cmd"] = cmd + + def _factory(name: str, **overrides: Any) -> TaskSpec[Any]: + merged = dict(base) + merged.update(overrides) + return TaskSpec(name, **merged) + + _factory.__name__ = "task_template_factory" + return _factory + + @dataclass class TaskResult(Generic[T]): - """运行期间产生的可变单任务记录。 - - 每次运行都会创建全新的 :class:`TaskResult`;spec 本身保持不可变。 - 这让同一个图可以安全地重复运行。 - """ + """运行期间产生的可变单任务记录。""" spec: TaskSpec[T] status: TaskStatus = TaskStatus.PENDING @@ -338,15 +520,11 @@ class TaskResult(Generic[T]): @dataclass(frozen=True) class TaskEvent: - """执行期间向观察者发出的不可变事件。 - - 传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合 - 执行器内部即可构建进度条、指标或结构化日志。 - """ + """执行期间向观察者发出的不可变事件。""" task: str status: TaskStatus attempts: int = 0 error: Optional[str] = None duration: Optional[float] = None - reason: Optional[str] = None # 跳过原因,如 "条件不满足"、"上游任务被跳过"、"缓存" + reason: Optional[str] = None diff --git a/tests/test_advanced_features.py b/tests/test_advanced_features.py new file mode 100644 index 0000000..0d0d28f --- /dev/null +++ b/tests/test_advanced_features.py @@ -0,0 +1,1222 @@ +"""覆盖 PyFlowX 任务流优化的全部新特性。 + +特性清单 +-------- +1. RetryPolicy:max_attempts / delay / backoff / jitter / retry_on +2. TaskHooks:pre_run / post_run / on_failure 钩子 +3. GraphDefaults:图级默认值回退 +4. 软依赖 soft_depends_on +5. 依赖驱动调度 strategy="dependency"(无层屏障) +6. 并发限制 concurrency_key + concurrency_limits +7. 任务优先级 priority +8. continue_on_error 容错 +9. 每任务执行策略 strategy(spec 级) +10. fan-out / map 工厂 +11. compose 图组合 +12. task_template 模板工厂 +13. cache_key 缓存键 +14. env / cwd 运行时隔离 +15. 上下文感知条件 DEP_EQUALS / DEP_MATCHES / DEP_PRESENT / DEP_TRUTHY +16. 动态分支:基于上游结果选择下游 +""" + +from __future__ import annotations + +import asyncio +import os +import sys +import time +from pathlib import Path +from typing import Any + +import pytest + +import pyflowx as px +from pyflowx.conditions import BuiltinConditions +from pyflowx.storage import MemoryBackend +from pyflowx.task import RetryPolicy, TaskHooks, TaskStatus + + +# ---------------------------------------------------------------------- # +# RetryPolicy +# ---------------------------------------------------------------------- # +class TestRetryPolicy: + """测试 RetryPolicy 数据结构与重试行为。""" + + def test_retry_policy_defaults(self) -> None: + policy = RetryPolicy() + assert policy.max_attempts == 1 + assert policy.delay == 0.0 + assert policy.backoff == 1.0 + assert policy.jitter == 0.0 + assert policy.retry_on == (Exception,) + + def test_retry_policy_custom(self) -> None: + policy = RetryPolicy( + max_attempts=5, + delay=0.1, + backoff=2.0, + jitter=0.05, + retry_on=(ValueError, KeyError), + ) + assert policy.max_attempts == 5 + assert policy.delay == 0.1 + assert policy.backoff == 2.0 + assert policy.jitter == 0.05 + assert policy.retry_on == (ValueError, KeyError) + + def test_retry_policy_rejects_zero_attempts(self) -> None: + with pytest.raises(ValueError, match="max_attempts"): + RetryPolicy(max_attempts=0) + + def test_retry_policy_rejects_negative_backoff(self) -> None: + with pytest.raises(ValueError, match="backoff"): + RetryPolicy(backoff=-1.0) + + def test_retry_succeeds_after_failures(self) -> None: + calls = {"n": 0} + + def flaky() -> str: + calls["n"] += 1 + if calls["n"] < 3: + raise RuntimeError("not yet") + return "ok" + + graph = px.Graph.from_specs([ + px.TaskSpec("flaky", flaky, retry=RetryPolicy(max_attempts=3)), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + assert report["flaky"] == "ok" + assert calls["n"] == 3 + assert report.result_of("flaky").attempts == 3 + + def test_retry_exhausted_raises(self) -> None: + def always_fail() -> None: + raise RuntimeError("nope") + + graph = px.Graph.from_specs([ + px.TaskSpec("f", always_fail, retry=RetryPolicy(max_attempts=3)), + ]) + with pytest.raises(px.TaskFailedError) as exc_info: + px.run(graph, strategy="sequential") + assert exc_info.value.attempts == 3 + + def test_retry_on_specific_exception_only(self) -> None: + """retry_on 限制只对指定异常重试。""" + calls = {"n": 0} + + def fail_with_keyerror() -> None: + calls["n"] += 1 + raise KeyError("not retried") + + # retry_on=(ValueError,) -> KeyError 不应被重试 + graph = px.Graph.from_specs([ + px.TaskSpec( + "f", + fail_with_keyerror, + retry=RetryPolicy(max_attempts=3, retry_on=(ValueError,)), + ), + ]) + with pytest.raises(px.TaskFailedError) as exc_info: + px.run(graph, strategy="sequential") + # KeyError 不在 retry_on 中,应只尝试 1 次 + assert exc_info.value.attempts == 1 + assert calls["n"] == 1 + + def test_retry_with_backoff_delay(self) -> None: + """backoff 应使每次重试间隔翻倍。""" + # pyrefly: ignore [implicit-any-empty-container] + calls = {"n": 0, "times": []} + + def flaky() -> str: + calls["n"] += 1 + calls["times"].append(time.monotonic()) + if calls["n"] < 3: + raise RuntimeError("not yet") + return "ok" + + graph = px.Graph.from_specs([ + px.TaskSpec( + "flaky", + flaky, + retry=RetryPolicy(max_attempts=3, delay=0.05, backoff=2.0), + ), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + # 第 2 次重试应在 delay=0.05 后,第 3 次应在 0.05*2=0.10 后 + gap1 = calls["times"][1] - calls["times"][0] + gap2 = calls["times"][2] - calls["times"][1] + assert gap1 >= 0.04 + assert gap2 >= 0.08 + assert gap2 > gap1 + + def test_retry_async_strategy(self) -> None: + calls = {"n": 0} + + async def flaky() -> str: + calls["n"] += 1 + if calls["n"] < 3: + raise RuntimeError("not yet") + return "ok" + + graph = px.Graph.from_specs([ + px.TaskSpec("flaky", flaky, retry=RetryPolicy(max_attempts=3)), + ]) + report = px.run(graph, strategy="async") + assert report.success + assert report["flaky"] == "ok" + assert calls["n"] == 3 + + +# ---------------------------------------------------------------------- # +# TaskHooks +# ---------------------------------------------------------------------- # +class TestTaskHooks: + """测试任务生命周期钩子。""" + + def test_pre_run_hook_called(self) -> None: + events: list[str] = [] + + def pre_run(spec: px.TaskSpec[Any]) -> None: + events.append(f"pre:{spec.name}") + + def fn() -> str: + events.append("run") + return "ok" + + hooks = TaskHooks(pre_run=pre_run) + graph = px.Graph.from_specs([ + px.TaskSpec("t", fn, hooks=hooks), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + assert events == ["pre:t", "run"] + + def test_post_run_hook_called_with_result(self) -> None: + captured: dict[str, Any] = {} + + def post_run(spec: px.TaskSpec[Any], result: Any) -> None: + captured["name"] = spec.name + captured["result"] = result + + def fn() -> int: + return 42 + + hooks = TaskHooks(post_run=post_run) + graph = px.Graph.from_specs([ + px.TaskSpec("t", fn, hooks=hooks), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + assert captured == {"name": "t", "result": 42} + + def test_on_failure_hook_called(self) -> None: + captured: dict[str, Any] = {} + + def on_failure(spec: px.TaskSpec[Any], exc: BaseException) -> None: + captured["name"] = spec.name + captured["exc"] = exc + + def fn() -> None: + raise ValueError("boom") + + hooks = TaskHooks(on_failure=on_failure) + graph = px.Graph.from_specs([ + px.TaskSpec("t", fn, hooks=hooks, continue_on_error=True), + ]) + report = px.run(graph, strategy="sequential") + # continue_on_error=True -> 报告成功但任务失败 + assert report.success + assert captured["name"] == "t" + assert isinstance(captured["exc"], ValueError) + + def test_hooks_not_called_on_skip(self) -> None: + events: list[str] = [] + + def pre_run(spec: px.TaskSpec[Any]) -> None: + events.append("pre") + + def post_run(spec: px.TaskSpec[Any], result: Any) -> None: + events.append("post") + + hooks = TaskHooks(pre_run=pre_run, post_run=post_run) + graph = px.Graph.from_specs([ + px.TaskSpec( + "t", + fn=lambda: "ok", + hooks=hooks, + conditions=(lambda _ctx: False,), + ), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + assert report.result_of("t").status == TaskStatus.SKIPPED + assert events == [] + + def test_hooks_with_async_strategy(self) -> None: + events: list[str] = [] + + def pre_run(spec: px.TaskSpec[Any]) -> None: + events.append(f"pre:{spec.name}") + + async def fn() -> str: + events.append("run") + return "ok" + + hooks = TaskHooks(pre_run=pre_run) + graph = px.Graph.from_specs([ + px.TaskSpec("t", fn, hooks=hooks), + ]) + report = px.run(graph, strategy="async") + assert report.success + assert events == ["pre:t", "run"] + + +# ---------------------------------------------------------------------- # +# GraphDefaults +# ---------------------------------------------------------------------- # +class TestGraphDefaults: + """测试图级默认值回退。""" + + def test_defaults_applied_to_specs(self) -> None: + defaults = px.GraphDefaults( + retry=RetryPolicy(max_attempts=5), + timeout=10.0, + tags=("default-tag",), + priority=3, + ) + graph = px.Graph(defaults=defaults) + graph.add(px.TaskSpec("a", lambda: "ok")) + resolved = graph.resolved_spec("a") + assert resolved.retry.max_attempts == 5 + assert resolved.timeout == 10.0 + assert resolved.priority == 3 + + def test_spec_overrides_defaults(self) -> None: + defaults = px.GraphDefaults( + retry=RetryPolicy(max_attempts=5), + timeout=10.0, + ) + graph = px.Graph(defaults=defaults) + graph.add( + px.TaskSpec( + "a", + lambda: "ok", + retry=RetryPolicy(max_attempts=2), + timeout=1.0, + ) + ) + resolved = graph.resolved_spec("a") + assert resolved.retry.max_attempts == 2 + assert resolved.timeout == 1.0 + + def test_defaults_empty_when_not_set(self) -> None: + graph = px.Graph() + graph.add(px.TaskSpec("a", lambda: "ok")) + resolved = graph.resolved_spec("a") + # 无默认值时回退到 spec 自身的 retry(默认 max_attempts=1) + assert resolved.retry.max_attempts == 1 + assert resolved.timeout is None + + def test_defaults_with_run(self) -> None: + calls = {"n": 0} + + def flaky() -> str: + calls["n"] += 1 + if calls["n"] < 3: + raise RuntimeError("not yet") + return "ok" + + defaults = px.GraphDefaults(retry=RetryPolicy(max_attempts=3)) + graph = px.Graph.from_specs( + [px.TaskSpec("flaky", flaky)], + defaults=defaults, + ) + report = px.run(graph, strategy="sequential") + assert report.success + assert calls["n"] == 3 + + +# ---------------------------------------------------------------------- # +# 软依赖 soft_depends_on +# ---------------------------------------------------------------------- # +class TestSoftDependencies: + """测试软依赖:等待完成但不传播失败。""" + + def test_soft_dependency_waits_for_completion(self) -> None: + order: list[str] = [] + + def slow() -> str: + time.sleep(0.05) + order.append("slow") + return "slow" + + def fast(slow: str) -> str: + order.append("fast") + return f"after-{slow}" + + graph = px.Graph.from_specs([ + px.TaskSpec("slow", slow), + px.TaskSpec("fast", fast, soft_depends_on=("slow",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + # soft 依赖应等待 slow 完成后再执行 fast + assert order == ["slow", "fast"] + assert report["fast"] == "after-slow" + + def test_soft_dependency_does_not_propagate_failure(self) -> None: + """软依赖上游失败时,下游仍应执行(硬依赖会跳过)。""" + + def fail() -> None: + raise RuntimeError("upstream failed") + + def downstream(fail: str = "default") -> str: + return f"got:{fail}" + + graph = px.Graph.from_specs([ + px.TaskSpec("fail", fail, continue_on_error=True), + px.TaskSpec( + "downstream", + downstream, + soft_depends_on=("fail",), + continue_on_error=True, + ), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + # fail 失败但下游仍执行(使用默认值) + assert report.result_of("fail").status == TaskStatus.FAILED + assert report.result_of("downstream").status == TaskStatus.SUCCESS + + def test_soft_dependency_validation_unknown_dep(self) -> None: + with pytest.raises(px.MissingDependencyError): + px.Graph.from_specs([ + px.TaskSpec("a", lambda: "ok", soft_depends_on=("missing",)), + ]) + + def test_soft_and_hard_dependency_combined(self) -> None: + order: list[str] = [] + + def a() -> str: + order.append("a") + return "a" + + def b(a: str) -> str: + order.append("b") + return f"b-{a}" + + def c(b: str) -> str: + order.append("c") + return f"c-{b}" + + graph = px.Graph.from_specs([ + px.TaskSpec("a", a), + px.TaskSpec("b", b, depends_on=("a",)), + px.TaskSpec("c", c, depends_on=("b",), soft_depends_on=("a",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert order == ["a", "b", "c"] + + +# ---------------------------------------------------------------------- # +# 依赖驱动调度 strategy="dependency" +# ---------------------------------------------------------------------- # +class TestDependencyDrivenScheduling: + """测试依赖驱动调度:任务在依赖完成后立即启动,无层屏障。""" + + def test_dependency_strategy_basic(self) -> None: + def a() -> int: + return 1 + + def b(a: int) -> int: + return a + 1 + + def c(b: int) -> int: + return b + 1 + + graph = px.Graph.from_specs([ + px.TaskSpec("a", a), + px.TaskSpec("b", b, depends_on=("a",)), + px.TaskSpec("c", c, depends_on=("b",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["a"] == 1 + assert report["b"] == 2 + assert report["c"] == 3 + + def test_dependency_strategy_faster_than_layered(self) -> None: + """依赖驱动应比层屏障更快(无层等待)。""" + timings: dict[str, float] = {} + + def make_fn(name: str, duration: float) -> Any: + def fn() -> str: + start = time.monotonic() + time.sleep(duration) + timings[name] = time.monotonic() - start + return name + + return fn + + # a (慢) -> b (快) 在同一层 + # a (快) -> c (慢) 在同一层 + # 依赖驱动:c 在 a 完成后立即启动,不必等 b + graph = px.Graph.from_specs([ + px.TaskSpec("a", make_fn("a", 0.05)), + px.TaskSpec("b", make_fn("b", 0.05), depends_on=("a",)), + px.TaskSpec("c", make_fn("c", 0.05), depends_on=("a",)), + px.TaskSpec("d", make_fn("d", 0.01), depends_on=("b", "c")), + ]) + start = time.monotonic() + report = px.run(graph, strategy="dependency") + elapsed = time.monotonic() - start + assert report.success + # a(0.05) + max(b,c)(0.05) + d(0.01) ≈ 0.11,层屏障会更慢 + assert elapsed < 0.20 + + def test_dependency_strategy_with_async_fn(self) -> None: + async def a() -> str: + await asyncio.sleep(0.01) + return "a" + + async def b(a: str) -> str: + return f"b-{a}" + + graph = px.Graph.from_specs([ + px.TaskSpec("a", a), + px.TaskSpec("b", b, depends_on=("a",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["b"] == "b-a" + + def test_dependency_strategy_diamond(self) -> None: + """菱形依赖:a -> b,c -> d。""" + + def a() -> int: + return 10 + + def b(a: int) -> int: + return a * 2 + + def c(a: int) -> int: + return a + 5 + + def d(b: int, c: int) -> int: + return b + c + + graph = px.Graph.from_specs([ + px.TaskSpec("a", a), + px.TaskSpec("b", b, depends_on=("a",)), + px.TaskSpec("c", c, depends_on=("a",)), + px.TaskSpec("d", d, depends_on=("b", "c")), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["a"] == 10 + assert report["b"] == 20 + assert report["c"] == 15 + assert report["d"] == 35 + + +# ---------------------------------------------------------------------- # +# 并发限制 concurrency_key + concurrency_limits +# ---------------------------------------------------------------------- # +class TestConcurrencyLimits: + """测试并发限制:相同 concurrency_key 的任务串行执行。""" + + def test_concurrency_key_serializes_tasks(self) -> None: + """相同 key 的任务不应并发执行。""" + running: list[int] = [] + max_concurrent = {"n": 0} + + def make_fn(idx: int) -> Any: + def fn() -> int: + running.append(idx) + cur = len(running) + max_concurrent["n"] = max(max_concurrent["n"], cur) + time.sleep(0.05) + running.remove(idx) + return idx + + return fn + + graph = px.Graph.from_specs([ + px.TaskSpec("a", make_fn(1), concurrency_key="db"), + px.TaskSpec("b", make_fn(2), concurrency_key="db"), + px.TaskSpec("c", make_fn(3), concurrency_key="db"), + ]) + report = px.run( + graph, + strategy="dependency", + concurrency_limits={"db": 1}, + ) + assert report.success + # 最多同时运行 1 个 + assert max_concurrent["n"] == 1 + + def test_concurrency_key_allows_parallel_different_keys(self) -> None: + """不同 key 的任务可并发执行。""" + running: list[str] = [] + max_concurrent = {"n": 0} + + def make_fn(name: str) -> Any: + def fn() -> str: + running.append(name) + cur = len(running) + max_concurrent["n"] = max(max_concurrent["n"], cur) + time.sleep(0.05) + running.remove(name) + return name + + return fn + + graph = px.Graph.from_specs([ + px.TaskSpec("a", make_fn("a"), concurrency_key="db1"), + px.TaskSpec("b", make_fn("b"), concurrency_key="db2"), + ]) + report = px.run( + graph, + strategy="dependency", + concurrency_limits={"db1": 1, "db2": 1}, + ) + assert report.success + # 不同 key 可并发 + assert max_concurrent["n"] == 2 + + def test_concurrency_limit_greater_than_one(self) -> None: + """limit=2 允许 2 个并发。""" + running: list[int] = [] + max_concurrent = {"n": 0} + + def make_fn(idx: int) -> Any: + def fn() -> int: + running.append(idx) + cur = len(running) + max_concurrent["n"] = max(max_concurrent["n"], cur) + time.sleep(0.05) + running.remove(idx) + return idx + + return fn + + graph = px.Graph.from_specs([ + px.TaskSpec("a", make_fn(1), concurrency_key="pool"), + px.TaskSpec("b", make_fn(2), concurrency_key="pool"), + px.TaskSpec("c", make_fn(3), concurrency_key="pool"), + px.TaskSpec("d", make_fn(4), concurrency_key="pool"), + ]) + report = px.run( + graph, + strategy="dependency", + concurrency_limits={"pool": 2}, + ) + assert report.success + assert max_concurrent["n"] <= 2 + assert max_concurrent["n"] == 2 + + +# ---------------------------------------------------------------------- # +# 任务优先级 priority +# ---------------------------------------------------------------------- # +class TestPriority: + """测试任务优先级:高优先级任务优先调度。""" + + def test_priority_orders_independent_tasks(self) -> None: + """无依赖任务按优先级降序执行。""" + order: list[str] = [] + + def make_fn(name: str) -> Any: + def fn() -> str: + order.append(name) + return name + + return fn + + graph = px.Graph.from_specs([ + px.TaskSpec("low", make_fn("low"), priority=1), + px.TaskSpec("high", make_fn("high"), priority=10), + px.TaskSpec("mid", make_fn("mid"), priority=5), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + # 高优先级先执行 + assert order == ["high", "mid", "low"] + + def test_priority_default_zero(self) -> None: + spec = px.TaskSpec("a", lambda: "ok") + assert spec.priority == 0 + + +# ---------------------------------------------------------------------- # +# continue_on_error 容错 +# ---------------------------------------------------------------------- # +class TestContinueOnError: + """测试 continue_on_error:任务失败不中断整体流程。""" + + def test_continue_on_error_allows_downstream(self) -> None: + """continue_on_error 使失败任务不中断流程;硬依赖下游被跳过。""" + + def fail() -> None: + raise RuntimeError("boom") + + def downstream() -> str: + return "ran" + + graph = px.Graph.from_specs([ + px.TaskSpec("fail", fail, continue_on_error=True), + px.TaskSpec("downstream", downstream, depends_on=("fail",)), + ]) + report = px.run(graph, strategy="sequential") + # continue_on_error 使整体报告成功(不抛异常) + assert report.success + assert report.result_of("fail").status == TaskStatus.FAILED + # 硬依赖下游被跳过(上游失败传播) + assert report.result_of("downstream").status == TaskStatus.SKIPPED + + def test_continue_on_error_with_soft_dep_executes_downstream(self) -> None: + """软依赖 + continue_on_error:下游仍执行(软依赖不传播失败)。""" + + def fail() -> None: + raise RuntimeError("boom") + + def downstream() -> str: + return "ran" + + graph = px.Graph.from_specs([ + px.TaskSpec("fail", fail, continue_on_error=True), + px.TaskSpec("downstream", downstream, soft_depends_on=("fail",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report.result_of("fail").status == TaskStatus.FAILED + # 软依赖下游仍执行 + assert report.result_of("downstream").status == TaskStatus.SUCCESS + assert report["downstream"] == "ran" + + def test_continue_on_error_with_dependency_strategy(self) -> None: + def fail() -> None: + raise RuntimeError("boom") + + def other() -> str: + return "ok" + + graph = px.Graph.from_specs([ + px.TaskSpec("fail", fail, continue_on_error=True), + px.TaskSpec("other", other), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report.result_of("fail").status == TaskStatus.FAILED + assert report.result_of("other").status == TaskStatus.SUCCESS + + def test_without_continue_on_error_raises(self) -> None: + def fail() -> None: + raise RuntimeError("boom") + + def other() -> str: + return "ok" + + graph = px.Graph.from_specs([ + px.TaskSpec("fail", fail), + px.TaskSpec("other", other), + ]) + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="sequential") + + def test_continue_on_error_graph_defaults(self) -> None: + def fail() -> None: + raise RuntimeError("boom") + + defaults = px.GraphDefaults(continue_on_error=True) + graph = px.Graph.from_specs([px.TaskSpec("fail", fail)], defaults=defaults) + report = px.run(graph, strategy="sequential") + assert report.success + assert report.result_of("fail").status == TaskStatus.FAILED + + +# ---------------------------------------------------------------------- # +# fan-out / map 工厂 +# ---------------------------------------------------------------------- # +class TestMapFactory: + """测试 Graph.map 工厂:为每个 item 生成 TaskSpec。""" + + def test_map_generates_tasks_per_item(self) -> None: + def process(item: int) -> int: + return item * 2 + + template = px.TaskSpec("template", process) + graph = px.Graph() + specs = graph.map( + name_fn=lambda i: f"task_{i}", + spec=template, + items=[1, 2, 3], + ) + assert len(specs) == 3 + assert [s.name for s in specs] == ["task_0", "task_1", "task_2"] + report = px.run(graph, strategy="sequential") + assert report.success + assert report["task_0"] == 2 + assert report["task_1"] == 4 + assert report["task_2"] == 6 + + def test_map_with_arg_factory(self) -> None: + def process(a: int, b: int) -> int: + return a + b + + template = px.TaskSpec("template", process) + graph = px.Graph() + graph.map( + name_fn=lambda i: f"sum_{i}", + spec=template, + items=[(1, 10), (2, 20), (3, 30)], + arg_factory=lambda item: (item[0], item[1]), + ) + report = px.run(graph, strategy="sequential") + assert report.success + assert report["sum_0"] == 11 + assert report["sum_1"] == 22 + assert report["sum_2"] == 33 + + def test_map_with_per_item_dependencies(self) -> None: + def source() -> list[int]: + return [1, 2, 3] + + def process(item: int) -> int: + return item * 10 + + graph = px.Graph() + graph.add(px.TaskSpec("source", source)) + graph.map( + name_fn=lambda i: f"proc_{i}", + spec=px.TaskSpec("template", process), + items=[1, 2, 3], + depends_on_per=lambda _i: ("source",), + ) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["proc_0"] == 10 + assert report["proc_1"] == 20 + assert report["proc_2"] == 30 + + +# ---------------------------------------------------------------------- # +# compose 图组合 +# ---------------------------------------------------------------------- # +class TestCompose: + """测试 compose / GraphComposer 图组合函数。 + + compose 接收 ``{name: Graph}`` 映射,解析图间的字符串引用, + 返回展开后的新映射。 + """ + + def test_compose_resolves_string_references(self) -> None: + def extract() -> list[int]: + return [1, 2, 3] + + def transform(extract: list[int]) -> list[int]: + return [x * 2 for x in extract] + + # extract 图 + g_extract = px.Graph.from_specs([px.TaskSpec("extract", extract)]) + # transform 图:通过 _pending_refs 引用 "extract" 命令 + # transform 自身不声明 depends_on,由 compose 展开时自动连接 + g_transform = px.Graph.from_specs([ + px.TaskSpec("transform", transform), + ]) + g_transform._pending_refs = ["extract"] + + resolved = px.compose({"extract": g_extract, "transform": g_transform}) + assert set(resolved.keys()) == {"extract", "transform"} + # transform 图应被展开,包含 extract 任务 + expanded = resolved["transform"] + assert "extract" in expanded.all_specs() + assert "transform" in expanded.all_specs() + report = px.run(expanded, strategy="dependency") + assert report.success + assert report["transform"] == [2, 4, 6] + + def test_compose_no_refs_returns_unchanged(self) -> None: + def a() -> str: + return "a" + + g = px.Graph.from_specs([px.TaskSpec("a", a)]) + resolved = px.compose({"cmd": g}) + assert set(resolved.keys()) == {"cmd"} + report = px.run(resolved["cmd"], strategy="sequential") + assert report.success + assert report["a"] == "a" + + def test_graphcomposer_class_equivalent(self) -> None: + def a() -> str: + return "a" + + g = px.Graph.from_specs([px.TaskSpec("a", a)]) + composer = px.GraphComposer({"cmd": g}) + resolved = composer.resolve_all() + assert "a" in resolved["cmd"].all_specs() + + +# ---------------------------------------------------------------------- # +# task_template 模板工厂 +# ---------------------------------------------------------------------- # +class TestTaskTemplate: + """测试 task_template 批量生成 TaskSpec。""" + + def test_task_template_generates_specs(self) -> None: + def process(item: int) -> int: + return item**2 + + template = px.task_template( + fn=process, + retry=RetryPolicy(max_attempts=2), + tags=("compute",), + ) + specs = [template(f"task_{i}", args=(i,)) for i in range(3)] + graph = px.Graph.from_specs(specs) + report = px.run(graph, strategy="sequential") + assert report.success + assert report["task_0"] == 0 + assert report["task_1"] == 1 + assert report["task_2"] == 4 + # 模板属性应继承 + assert all(s.retry.max_attempts == 2 for s in specs) + assert all(s.tags == ("compute",) for s in specs) + + +# ---------------------------------------------------------------------- # +# cache_key 缓存 +# ---------------------------------------------------------------------- # +class TestCacheKey: + """测试 cache_key 自定义缓存键。 + + cache_key 签名为 ``Callable[[Context], str]``,仅接收上下文。 + 通过闭包捕获 args 来生成输入相关的键。 + """ + + def test_cache_key_hits_on_same_input(self) -> None: + calls = {"n": 0} + + def expensive(x: int) -> int: + calls["n"] += 1 + return x * 2 + + backend = MemoryBackend() + + # 通过闭包捕获 args 生成缓存键 + def make_cache_key(arg: int) -> Any: + def key(ctx: Any) -> str: + return f"cache::t::{arg}" + + return key + + graph1 = px.Graph.from_specs([ + px.TaskSpec("t", expensive, args=(5,), cache_key=make_cache_key(5)), + ]) + report1 = px.run(graph1, strategy="sequential", state=backend) + assert report1.success + assert report1["t"] == 10 + assert calls["n"] == 1 + + # 第二次运行相同输入应命中缓存 + graph2 = px.Graph.from_specs([ + px.TaskSpec("t", expensive, args=(5,), cache_key=make_cache_key(5)), + ]) + report2 = px.run(graph2, strategy="sequential", state=backend) + assert report2.success + assert report2["t"] == 10 + # 不应再次调用 fn + assert calls["n"] == 1 + + def test_cache_key_miss_on_different_input(self) -> None: + calls = {"n": 0} + + def expensive(x: int) -> int: + calls["n"] += 1 + return x * 2 + + backend = MemoryBackend() + + def make_cache_key(arg: int) -> Any: + def key(ctx: Any) -> str: + return f"cache::t::{arg}" + + return key + + graph1 = px.Graph.from_specs([ + px.TaskSpec("t", expensive, args=(5,), cache_key=make_cache_key(5)), + ]) + px.run(graph1, strategy="sequential", state=backend) + assert calls["n"] == 1 + + # 不同输入应 miss + graph2 = px.Graph.from_specs([ + px.TaskSpec("t", expensive, args=(7,), cache_key=make_cache_key(7)), + ]) + px.run(graph2, strategy="sequential", state=backend) + assert calls["n"] == 2 + + +# ---------------------------------------------------------------------- # +# env / cwd 运行时隔离 +# ---------------------------------------------------------------------- # +class TestEnvAndCwd: + """测试环境变量与工作目录隔离。""" + + def test_env_override_for_cmd(self) -> None: + graph = px.Graph.from_specs([ + px.TaskSpec( + "print_var", + cmd=[sys.executable, "-c", "import os; print(os.environ.get('PYFLOWX_TEST_VAR', 'unset'))"], + env={"PYFLOWX_TEST_VAR": "isolated"}, + ), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + + def test_cwd_for_cmd(self, tmp_path: Path) -> None: + # 在 tmp_path 下创建标记文件 + marker = tmp_path / "marker.txt" + marker.write_text("found") + graph = px.Graph.from_specs([ + px.TaskSpec( + "check_cwd", + cmd=["ls", "marker.txt"], + cwd=tmp_path, + ), + ]) + report = px.run(graph, strategy="sequential") + assert report.success + + def test_env_does_not_leak_to_outer(self) -> None: + os.environ.pop("PYFLOWX_LEAK_TEST", None) + + def check_env() -> str: + return os.environ.get("PYFLOWX_LEAK_TEST", "not-set") + + graph = px.Graph.from_specs([ + px.TaskSpec( + "t", + check_env, + env={"PYFLOWX_LEAK_TEST": "leaked"}, + ), + ]) + # fn 任务的环境变量隔离仅在 cmd 任务生效,fn 共享进程环境 + # 这里验证 fn 任务不修改外层环境 + report = px.run(graph, strategy="sequential") + assert report.success + assert os.environ.get("PYFLOWX_LEAK_TEST") is None + + +# ---------------------------------------------------------------------- # +# 上下文感知条件 +# ---------------------------------------------------------------------- # +class TestContextAwareConditions: + """测试基于上游结果的上下文感知条件。""" + + def test_dep_equals_selects_branch(self) -> None: + """根据上游结果选择不同下游分支。""" + + def decide() -> str: + return "path_b" + + def path_a(decide: str = "") -> str: + return f"ran-a:{decide}" + + def path_b(decide: str = "") -> str: + return f"ran-b:{decide}" + + graph = px.Graph.from_specs([ + px.TaskSpec("decide", decide), + px.TaskSpec( + "path_a", + path_a, + depends_on=("decide",), + conditions=(BuiltinConditions.DEP_EQUALS("decide", "path_a"),), + ), + px.TaskSpec( + "path_b", + path_b, + depends_on=("decide",), + conditions=(BuiltinConditions.DEP_EQUALS("decide", "path_b"),), + ), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report.result_of("path_a").status == TaskStatus.SKIPPED + assert report.result_of("path_b").status == TaskStatus.SUCCESS + assert report["path_b"] == "ran-b:path_b" + + def test_dep_truthy_conditional_downstream(self) -> None: + def source() -> list[int]: + return [1, 2, 3] + + def only_if_nonempty(source: list[int]) -> str: + return f"has-{len(source)}" + + graph = px.Graph.from_specs([ + px.TaskSpec("source", source), + px.TaskSpec( + "only_if_nonempty", + only_if_nonempty, + depends_on=("source",), + conditions=(BuiltinConditions.DEP_TRUTHY("source"),), + ), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["only_if_nonempty"] == "has-3" + + def test_dep_truthy_skips_when_empty(self) -> None: + def source() -> list[int]: + return [] + + def only_if_nonempty(source: list[int]) -> str: + return "should-not-run" + + graph = px.Graph.from_specs([ + px.TaskSpec("source", source), + px.TaskSpec( + "only_if_nonempty", + only_if_nonempty, + depends_on=("source",), + conditions=(BuiltinConditions.DEP_TRUTHY("source"),), + ), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report.result_of("only_if_nonempty").status == TaskStatus.SKIPPED + + def test_dep_matches_complex_predicate(self) -> None: + def source() -> int: + return 42 + + def downstream(source: int) -> str: + return f"got-{source}" + + graph = px.Graph.from_specs([ + px.TaskSpec("source", source), + px.TaskSpec( + "downstream", + downstream, + depends_on=("source",), + conditions=(BuiltinConditions.DEP_MATCHES("source", lambda v: v > 10),), + ), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["downstream"] == "got-42" + + +# ---------------------------------------------------------------------- # +# 每任务执行策略 spec.strategy +# ---------------------------------------------------------------------- # +class TestPerTaskStrategy: + """测试每任务执行策略字段(spec 级)。""" + + def test_strategy_field_stored(self) -> None: + spec = px.TaskSpec("a", lambda: "ok", strategy="async") + assert spec.strategy == "async" + + def test_strategy_field_default_none(self) -> None: + spec = px.TaskSpec("a", lambda: "ok") + assert spec.strategy is None + + def test_mixed_sync_async_in_dependency_strategy(self) -> None: + """dependency 策略可混合 sync/async 任务。""" + + def sync_fn() -> str: + return "sync" + + async def async_fn(sync: str) -> str: + await asyncio.sleep(0.01) + return f"async-{sync}" + + graph = px.Graph.from_specs([ + px.TaskSpec("sync", sync_fn), + px.TaskSpec("async", async_fn, depends_on=("sync",)), + ]) + report = px.run(graph, strategy="dependency") + assert report.success + assert report["async"] == "async-sync" + + +# ---------------------------------------------------------------------- # +# 综合场景:map-reduce +# ---------------------------------------------------------------------- # +class TestMapReduceScenario: + """测试 map-reduce 模式:fan-out 计算 + 汇总。""" + + def test_map_reduce_pattern(self) -> None: + def source() -> list[int]: + return [1, 2, 3, 4, 5] + + def worker(item: int) -> int: + return item**2 + + def reduce(**kwargs: int) -> int: + # **kwargs 自动注入所有依赖(worker_0, worker_1, ...) + return sum(v for v in kwargs.values() if isinstance(v, int)) + + graph = px.Graph() + graph.add(px.TaskSpec("source", source)) + workers = graph.map( + name_fn=lambda i: f"worker_{i}", + spec=px.TaskSpec("worker_tmpl", worker), + items=[1, 2, 3, 4, 5], + depends_on_per=lambda _i: ("source",), + ) + # reduce 依赖所有 worker + graph.add( + px.TaskSpec( + "reduce", + reduce, + depends_on=tuple(w.name for w in workers), + ) + ) + report = px.run(graph, strategy="dependency") + assert report.success + # 1+4+9+16+25 = 55 + assert report["reduce"] == 55 + + def test_map_reduce_with_concurrency_limit(self) -> None: + """map-reduce 配合并发限制:worker 限制为 2 并发。""" + running: list[int] = [] + max_concurrent = {"n": 0} + + def worker(item: int) -> int: + running.append(item) + cur = len(running) + max_concurrent["n"] = max(max_concurrent["n"], cur) + time.sleep(0.02) + running.remove(item) + return item**2 + + def reduce(**kwargs: int) -> int: + return sum(v for v in kwargs.values() if isinstance(v, int)) + + graph = px.Graph() + workers = graph.map( + name_fn=lambda i: f"worker_{i}", + spec=px.TaskSpec("worker_tmpl", worker, concurrency_key="pool"), + items=[1, 2, 3, 4, 5], + ) + graph.add( + px.TaskSpec( + "reduce", + reduce, + depends_on=tuple(w.name for w in workers), + ) + ) + report = px.run( + graph, + strategy="dependency", + concurrency_limits={"pool": 2}, + ) + assert report.success + assert report["reduce"] == 55 + assert max_concurrent["n"] <= 2 diff --git a/tests/test_conditions.py b/tests/test_conditions.py index 3195eca..b961ce0 100644 --- a/tests/test_conditions.py +++ b/tests/test_conditions.py @@ -1,142 +1,218 @@ """Tests for conditions module.""" +from __future__ import annotations + import os import sys from unittest.mock import patch from pyflowx.conditions import ( + IS_LINUX, + IS_MACOS, + IS_POSIX, + IS_WINDOWS, BuiltinConditions, Constants, ) +_CTX: dict[str, object] = {} + def test_constants_is_windows(): - """Test Constants.IS_WINDOWS is correct.""" assert (sys.platform == "win32") == Constants.IS_WINDOWS def test_constants_is_linux(): - """Test Constants.IS_LINUX is correct.""" assert (sys.platform == "linux") == Constants.IS_LINUX def test_constants_is_macos(): - """Test Constants.IS_MACOS is correct.""" assert (sys.platform == "darwin") == Constants.IS_MACOS def test_constants_is_posix(): - """Test Constants.IS_POSIX is correct.""" assert (sys.platform != "win32") == Constants.IS_POSIX +def test_module_level_static_conditions(): + assert IS_WINDOWS(_CTX) == Constants.IS_WINDOWS + assert IS_LINUX(_CTX) == Constants.IS_LINUX + assert IS_MACOS(_CTX) == Constants.IS_MACOS + assert IS_POSIX(_CTX) == Constants.IS_POSIX -def test_builtin_conditions_python_version_major_only(): - """Test BuiltinConditions.PYTHON_VERSION with major only.""" - # Test with current Python version + +def test_python_version_major_only(): current_major = sys.version_info.major - assert BuiltinConditions.PYTHON_VERSION(current_major) is True - assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False + assert BuiltinConditions.PYTHON_VERSION(current_major)(_CTX) is True + assert BuiltinConditions.PYTHON_VERSION(current_major + 1)(_CTX) is False -def test_builtin_conditions_python_version_with_minor(): - """Test BuiltinConditions.PYTHON_VERSION with major and minor.""" +def test_python_version_with_minor(): current_major = sys.version_info.major current_minor = sys.version_info.minor - assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True - assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False + assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor)(_CTX) is True + assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1)(_CTX) is False -def test_builtin_conditions_python_version_at_least(): - """Test BuiltinConditions.PYTHON_VERSION_AT_LEAST.""" +def test_python_version_at_least(): current_major = sys.version_info.major current_minor = sys.version_info.minor - # Current version should be at least itself - assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True - # Current version should be at least an older version - assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True - # Current version should NOT be at least a newer version - assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False + assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor)(_CTX) is True + assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0)(_CTX) is True + assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0)(_CTX) is False -def test_builtin_conditions_HAS_INSTALLED_true(): - """Test BuiltinConditions.HAS_INSTALLED when app exists.""" - # Python should always be available - condition = BuiltinConditions.HAS_INSTALLED("python") - assert condition() is True +def test_has_installed_true(): + condition = BuiltinConditions.HAS_INSTALLED("python3") + assert condition(_CTX) is True -def test_builtin_conditions_HAS_INSTALLED_false(): - """Test BuiltinConditions.HAS_INSTALLED when app doesn't exist.""" +def test_has_installed_false(): condition = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345") - assert condition() is False + assert condition(_CTX) is False -def test_builtin_conditions_env_var_exists_true(): - """Test BuiltinConditions.ENV_VAR_EXISTS when variable exists.""" +def test_env_var_exists_true(): with patch.dict(os.environ, {"TEST_VAR": "value"}): condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR") - assert condition() is True + assert condition(_CTX) is True -def test_builtin_conditions_env_var_exists_false(): - """Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist.""" +def test_env_var_exists_false(): condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345") - assert condition() is False + assert condition(_CTX) is False -def test_builtin_conditions_env_var_equals_true(): - """Test BuiltinConditions.ENV_VAR_EQUALS when value matches.""" +def test_env_var_equals_true(): with patch.dict(os.environ, {"TEST_VAR": "expected_value"}): condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") - assert condition() is True + assert condition(_CTX) is True -def test_builtin_conditions_env_var_equals_false(): - """Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match.""" +def test_env_var_equals_false(): with patch.dict(os.environ, {"TEST_VAR": "different_value"}): condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") - assert condition() is False + assert condition(_CTX) is False -def test_builtin_conditions_not(): - """Test BuiltinConditions.NOT.""" - true_condition = lambda: True # noqa: E731 - false_condition = lambda: False # noqa: E731 +def test_not(): + true_cond = BuiltinConditions.HAS_INSTALLED("python3") + false_cond = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345") - not_true = BuiltinConditions.NOT(true_condition) - assert not_true() is False - - not_false = BuiltinConditions.NOT(false_condition) - assert not_false() is True + assert BuiltinConditions.NOT(true_cond)(_CTX) is False + assert BuiltinConditions.NOT(false_cond)(_CTX) is True -def test_builtin_conditions_and_all_true(): - """Test BuiltinConditions.AND when all conditions are true.""" - true_condition = lambda: True # noqa: E731 - condition = BuiltinConditions.AND(true_condition, true_condition, true_condition) - assert condition() is True +def test_and_all_true(): + cond = BuiltinConditions.AND( + BuiltinConditions.HAS_INSTALLED("python3"), + BuiltinConditions.HAS_INSTALLED("python3"), + ) + assert cond(_CTX) is True -def test_builtin_conditions_and_one_false(): - """Test BuiltinConditions.AND when one condition is false.""" - true_condition = lambda: True # noqa: E731 - false_condition = lambda: False # noqa: E731 - condition = BuiltinConditions.AND(true_condition, false_condition, true_condition) - assert condition() is False +def test_and_one_false(): + cond = BuiltinConditions.AND( + BuiltinConditions.HAS_INSTALLED("python3"), + BuiltinConditions.HAS_INSTALLED("nonexistent_app"), + ) + assert cond(_CTX) is False -def test_builtin_conditions_or_all_false(): - """Test BuiltinConditions.OR when all conditions are false.""" - false_condition = lambda: False # noqa: E731 - condition = BuiltinConditions.OR(false_condition, false_condition, false_condition) - assert condition() is False +def test_or_all_false(): + cond = BuiltinConditions.OR( + BuiltinConditions.HAS_INSTALLED("nonexistent1"), + BuiltinConditions.HAS_INSTALLED("nonexistent2"), + ) + assert cond(_CTX) is False -def test_builtin_conditions_or_one_true(): - """Test BuiltinConditions.OR when one condition is true.""" - true_condition = lambda: True # noqa: E731 - false_condition = lambda: False # noqa: E731 - condition = BuiltinConditions.OR(false_condition, true_condition, false_condition) - assert condition() is True +def test_or_one_true(): + cond = BuiltinConditions.OR( + BuiltinConditions.HAS_INSTALLED("nonexistent1"), + BuiltinConditions.HAS_INSTALLED("python3"), + ) + assert cond(_CTX) is True + +# ---------------------------------------------------------------------- # +# 上下文条件:基于上游依赖结果 +# ---------------------------------------------------------------------- # +def test_dep_equals_true(): + ctx = {"upstream": 42} + cond = BuiltinConditions.DEP_EQUALS("upstream", 42) + assert cond(ctx) is True + + +def test_dep_equals_false(): + ctx = {"upstream": 99} + cond = BuiltinConditions.DEP_EQUALS("upstream", 42) + assert cond(ctx) is False + + +def test_dep_equals_missing_dep(): + cond = BuiltinConditions.DEP_EQUALS("missing", 42) + assert cond({}) is False + + +def test_dep_matches_true(): + ctx = {"upstream": [1, 2, 3]} + cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3) + assert cond(ctx) is True + + +def test_dep_matches_false(): + ctx = {"upstream": [1, 2]} + cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3) + assert cond(ctx) is False + + +def test_dep_matches_exception_returns_false(): + ctx = {"upstream": ""} + cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: v[0]) + assert cond(ctx) is False + + +def test_dep_present_true(): + ctx = {"upstream": "value"} + cond = BuiltinConditions.DEP_PRESENT("upstream") + assert cond(ctx) is True + + +def test_dep_present_false_none(): + # pyrefly: ignore [implicit-any-empty-container] + ctx = {"upstream": None} + cond = BuiltinConditions.DEP_PRESENT("upstream") + assert cond(ctx) is False + + +def test_dep_present_false_missing(): + cond = BuiltinConditions.DEP_PRESENT("missing") + assert cond({}) is False + + +def test_dep_truthy_true(): + ctx = {"upstream": [1]} + cond = BuiltinConditions.DEP_TRUTHY("upstream") + assert cond(ctx) is True + + +def test_dep_truthy_false(): + # pyrefly: ignore [implicit-any-empty-container] + ctx = {"upstream": []} + cond = BuiltinConditions.DEP_TRUTHY("upstream") + assert cond(ctx) is False + + +def test_dep_truthy_missing(): + cond = BuiltinConditions.DEP_TRUTHY("missing") + assert cond({}) is False + + +def test_logical_combination_with_dep_conditions(): + ctx = {"a": 1, "b": 0} + cond = BuiltinConditions.AND( + BuiltinConditions.DEP_EQUALS("a", 1), + BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")), + ) + assert cond(ctx) is True diff --git a/tests/test_context.py b/tests/test_context.py index 93cd315..fd8583e 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -141,7 +141,7 @@ class TestDescribeInjection: spec = px.TaskSpec("t", fn, depends_on=("a",)) desc = describe_injection(spec) - assert "a=" in desc + assert "a=" in desc assert "ctx=" in desc assert "flag=" in desc diff --git a/tests/test_executors.py b/tests/test_executors.py index 4c777ab..20e6863 100644 --- a/tests/test_executors.py +++ b/tests/test_executors.py @@ -84,7 +84,9 @@ def test_retries_then_succeeds() -> None: raise RuntimeError("not yet") return "ok" - graph = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retries=2)]) + graph = px.Graph.from_specs([ + px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=3)), + ]) report = px.run(graph, strategy="sequential") assert report.success assert report["flaky"] == "ok" @@ -95,7 +97,9 @@ def test_retries_exhausted() -> None: def always_fail() -> None: raise RuntimeError("nope") - graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)]) + graph = px.Graph.from_specs([ + px.TaskSpec("f", always_fail, retry=px.RetryPolicy(max_attempts=3)), + ]) with pytest.raises(TaskFailedError) as exc_info: _ = px.run(graph, strategy="sequential") assert exc_info.value.attempts == 3 @@ -332,7 +336,9 @@ def test_async_timeout_retry_then_succeed() -> None: await asyncio.sleep(10) # 触发超时 return "ok" - graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2, timeout=0.05)]) + graph = px.Graph.from_specs([ + px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3), timeout=0.05), + ]) report = px.run(graph, strategy="async") assert report.success assert report["a"] == "ok" @@ -349,7 +355,9 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None: raise RuntimeError("not yet") return "ok" - graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)]) + graph = px.Graph.from_specs([ + px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3)), + ]) with caplog.at_level("WARNING", logger="pyflowx"): report = px.run(graph, strategy="async") assert report.success @@ -489,7 +497,7 @@ def test_run_empty_graph() -> None: # ---------------------------------------------------------------------- # def test_downstream_skipped_when_upstream_skipped_sequential() -> None: """上游任务被 SKIPPED 后,下游任务也应被 SKIPPED(sequential 策略).""" - never_true = lambda: False # noqa: E731 + never_true = lambda _ctx: False # noqa: E731 def downstream(upstream: str) -> str: return upstream + "_processed" @@ -506,7 +514,7 @@ def test_downstream_skipped_when_upstream_skipped_sequential() -> None: def test_downstream_skipped_when_upstream_skipped_thread() -> None: """上游任务被 SKIPPED 后,下游任务也应被 SKIPPED(thread 策略).""" - never_true = lambda: False # noqa: E731 + never_true = lambda _ctx: False # noqa: E731 def downstream(upstream: str) -> str: return upstream + "_processed" @@ -530,7 +538,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None: async def downstream(upstream: str) -> str: return upstream + "_processed" - never_true = lambda: False # noqa: E731 + never_true = lambda _ctx: False # noqa: E731 graph = px.Graph.from_specs([ px.TaskSpec("upstream", upstream, conditions=(never_true,)), @@ -544,7 +552,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None: def test_downstream_executes_when_upstream_succeeds() -> None: """上游任务成功时,下游任务应正常执行.""" - always_true = lambda: True # noqa: E731 + always_true = lambda _ctx: True # noqa: E731 def upstream() -> str: return "hello" diff --git a/tests/test_executors_edge_cases.py b/tests/test_executors_edge_cases.py index ef48fa8..f160281 100644 --- a/tests/test_executors_edge_cases.py +++ b/tests/test_executors_edge_cases.py @@ -85,7 +85,7 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), ) graph = px.Graph.from_specs([spec]) report = px.run(graph, strategy="sequential", verbose=True) @@ -140,7 +140,7 @@ def test_verbose_event_callback_skipped(): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), verbose=True, ) graph = px.Graph.from_specs([spec]) @@ -161,7 +161,11 @@ def test_execute_sync_with_retries(): raise ValueError("temporary error") return "success" - spec = px.TaskSpec("retry_test", fn=failing_function, retries=3) + spec = px.TaskSpec( + "retry_test", + fn=failing_function, + retry=px.RetryPolicy(max_attempts=3), + ) graph = px.Graph.from_specs([spec]) # Should succeed after retries @@ -182,7 +186,11 @@ def test_execute_async_with_retries(): raise ValueError("temporary error") return "success" - spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3) + spec = px.TaskSpec( + "retry_async_test", + fn=failing_async_function, + retry=px.RetryPolicy(max_attempts=3), + ) graph = px.Graph.from_specs([spec]) # Should succeed after retries @@ -196,7 +204,7 @@ def test_execute_sync_skip_on_condition(): spec = px.TaskSpec( "skip_test", fn=lambda: "result", - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), ) graph = px.Graph.from_specs([spec]) @@ -210,7 +218,7 @@ def test_execute_async_skip_on_condition(): spec = px.TaskSpec( "skip_async_test", fn=lambda: "result", - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), ) graph = px.Graph.from_specs([spec]) diff --git a/tests/test_graph.py b/tests/test_graph.py index 5722edf..05d6da7 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -13,13 +13,11 @@ def _fn() -> None: def test_from_specs_builds_graph() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, depends_on=("a",)), - px.TaskSpec("c", _fn, depends_on=("a", "b")), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, depends_on=("a",)), + px.TaskSpec("c", _fn, depends_on=("a", "b")), + ]) assert set(graph.names) == {"a", "b", "c"} assert graph.dependencies("c") == ("a", "b") assert len(graph) == 3 @@ -28,23 +26,19 @@ def test_from_specs_builds_graph() -> None: def test_from_specs_allows_forward_references() -> None: # b depends on a, but a is declared after b — order should not matter. - graph = px.Graph.from_specs( - [ - px.TaskSpec("b", _fn, depends_on=("a",)), - px.TaskSpec("a", _fn), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("b", _fn, depends_on=("a",)), + px.TaskSpec("a", _fn), + ]) assert graph.layers() == [["a"], ["b"]] def test_duplicate_task_raises() -> None: with pytest.raises(DuplicateTaskError): - _ = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("a", _fn), - ] - ) + _ = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("a", _fn), + ]) def test_missing_dependency_raises() -> None: @@ -57,24 +51,20 @@ def test_missing_dependency_raises() -> None: def test_cycle_detection() -> None: with pytest.raises(CycleError): - _ = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn, depends_on=("c",)), - px.TaskSpec("b", _fn, depends_on=("a",)), - px.TaskSpec("c", _fn, depends_on=("b",)), - ] - ) + _ = px.Graph.from_specs([ + px.TaskSpec("a", _fn, depends_on=("c",)), + px.TaskSpec("b", _fn, depends_on=("a",)), + px.TaskSpec("c", _fn, depends_on=("b",)), + ]) def test_layers_grouping() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn), - px.TaskSpec("c", _fn, depends_on=("a", "b")), - px.TaskSpec("d", _fn, depends_on=("c",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn), + px.TaskSpec("c", _fn, depends_on=("a", "b")), + px.TaskSpec("d", _fn, depends_on=("c",)), + ]) layers = graph.layers() assert layers == [["a", "b"], ["c"], ["d"]] @@ -85,12 +75,10 @@ def test_self_dependency_rejected() -> None: def test_to_mermaid() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, depends_on=("a",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, depends_on=("a",)), + ]) mermaid = graph.to_mermaid() assert mermaid.startswith("graph TD") assert 'a["a"]' in mermaid @@ -104,13 +92,11 @@ def test_to_mermaid_invalid_orientation() -> None: def test_subgraph_by_tags() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn, tags=("ingest",)), - px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)), - px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn, tags=("ingest",)), + px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)), + px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)), + ]) sub = graph.subgraph(["ingest"]) assert set(sub.names) == {"a", "b"} # Edge to dropped task c is removed; b no longer waits for anything @@ -119,13 +105,11 @@ def test_subgraph_by_tags() -> None: def test_subgraph_by_names() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, depends_on=("a",)), - px.TaskSpec("c", _fn, depends_on=("b",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, depends_on=("a",)), + px.TaskSpec("c", _fn, depends_on=("b",)), + ]) sub = graph.subgraph_by_names(["a", "b"]) assert set(sub.names) == {"a", "b"} # c is dropped, so b's dep on c (none here) — but a->b edge preserved. @@ -139,12 +123,10 @@ def test_subgraph_by_names_unknown() -> None: def test_describe() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, depends_on=("a",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, depends_on=("a",)), + ]) desc = graph.describe() assert "Layer 1" in desc assert "Layer 2" in desc @@ -187,12 +169,10 @@ def test_spec_accessor() -> None: def test_dependencies_accessor() -> None: - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn), - px.TaskSpec("b", _fn, depends_on=("a",)), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", _fn), + px.TaskSpec("b", _fn, depends_on=("a",)), + ]) assert graph.dependencies("a") == () assert graph.dependencies("b") == ("a",) @@ -210,16 +190,20 @@ def test_empty_graph_layers() -> None: def test_subgraph_preserves_metadata() -> None: - """子图应保留原任务的 retries/timeout/tags 等元数据。""" - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0), - px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)), - ] - ) + """子图应保留原任务的 retry/timeout/tags 等元数据。""" + graph = px.Graph.from_specs([ + px.TaskSpec( + "a", + _fn, + tags=("x",), + retry=px.RetryPolicy(max_attempts=3), + timeout=5.0, + ), + px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)), + ]) sub = graph.subgraph(["x"]) spec = sub.spec("a") - assert spec.retries == 3 + assert spec.retry.max_attempts == 3 assert spec.timeout == 5.0 assert spec.tags == ("x",) diff --git a/tests/test_runner.py b/tests/test_runner.py index 5c6f2f5..577396c 100644 --- a/tests/test_runner.py +++ b/tests/test_runner.py @@ -29,24 +29,20 @@ def _echo_graph(name: str = "echo_task", msg: str = "hello") -> px.Graph: def _failing_graph() -> px.Graph: """构造一个必定失败的单任务图.""" - return px.Graph.from_specs( - [ - px.TaskSpec( - "fail", - cmd=["python", "-c", "import sys; sys.exit(1)"], - ) - ] - ) + return px.Graph.from_specs([ + px.TaskSpec( + "fail", + cmd=["python", "-c", "import sys; sys.exit(1)"], + ) + ]) def _multi_task_graph() -> px.Graph: """构造一个带依赖的多任务图.""" - return px.Graph.from_specs( - [ - px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]), - px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)), - ] - ) + return px.Graph.from_specs([ + px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]), + px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)), + ]) # ---------------------------------------------------------------------- # @@ -240,12 +236,10 @@ class TestCliRunnerRunSuccess: def track_b() -> None: executed.append("b") - runner = px.CliRunner( - { - "a": px.Graph.from_specs([px.TaskSpec("a", track_a)]), - "b": px.Graph.from_specs([px.TaskSpec("b", track_b)]), - } - ) + runner = px.CliRunner({ + "a": px.Graph.from_specs([px.TaskSpec("a", track_a)]), + "b": px.Graph.from_specs([px.TaskSpec("b", track_b)]), + }) _ = runner.run(["b"]) assert executed == ["b"] @@ -318,15 +312,13 @@ class TestCliRunnerVerbose: def test_verbose_prints_skip_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None: """verbose 模式下跳过的任务应打印跳过信息.""" - graph = px.Graph.from_specs( - [ - px.TaskSpec( - "skip_me", - cmd=[*ECHO_CMD, "skip"], - conditions=(lambda: False,), - ), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec( + "skip_me", + cmd=[*ECHO_CMD, "skip"], + conditions=(lambda _ctx: False,), + ), + ]) runner = px.CliRunner({"skip": graph}) _ = runner.run(["skip"]) captured = capsys.readouterr() @@ -394,13 +386,11 @@ class TestCliRunnerList: def test_list_prints_all_commands(self, capsys: pytest.CaptureFixture[str]) -> None: """--list 应打印所有命令.""" - runner = px.CliRunner( - { - "clean": _echo_graph("c", "clean"), - "build": _echo_graph("b", "build"), - "test": _echo_graph("t", "test"), - } - ) + runner = px.CliRunner({ + "clean": _echo_graph("c", "clean"), + "build": _echo_graph("b", "build"), + "test": _echo_graph("t", "test"), + }) _ = runner.run(["--list"]) captured = capsys.readouterr() assert "clean" in captured.out @@ -523,30 +513,26 @@ class TestCliRunnerIntegration: def test_condition_skipped_command_succeeds(self) -> None: """条件不满足时任务跳过, 整体仍成功.""" - graph = px.Graph.from_specs( - [ - px.TaskSpec( - "skip_me", - cmd=[*ECHO_CMD, "should not run"], - conditions=(lambda: False,), - ), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec( + "skip_me", + cmd=[*ECHO_CMD, "should not run"], + conditions=(lambda _ctx: False,), + ), + ]) runner = px.CliRunner({"skip": graph}) exit_code = runner.run(["skip"]) assert exit_code == CliExitCode.SUCCESS.value def test_condition_met_command_succeeds(self) -> None: """条件满足时任务执行, 整体成功.""" - graph = px.Graph.from_specs( - [ - px.TaskSpec( - "run_me", - cmd=[*ECHO_CMD, "should run"], - conditions=(lambda: True,), - ), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec( + "run_me", + cmd=[*ECHO_CMD, "should run"], + conditions=(lambda _ctx: True,), + ), + ]) runner = px.CliRunner({"run": graph}) exit_code = runner.run(["run"]) assert exit_code == CliExitCode.SUCCESS.value @@ -562,14 +548,12 @@ class TestCliRunnerIntegration: return fn - graph = px.Graph.from_specs( - [ - px.TaskSpec("a", make("a")), - px.TaskSpec("b", make("b"), depends_on=("a",)), - px.TaskSpec("c", make("c"), depends_on=("a",)), - px.TaskSpec("d", make("d"), depends_on=("b", "c")), - ] - ) + graph = px.Graph.from_specs([ + px.TaskSpec("a", make("a")), + px.TaskSpec("b", make("b"), depends_on=("a",)), + px.TaskSpec("c", make("c"), depends_on=("a",)), + px.TaskSpec("d", make("d"), depends_on=("b", "c")), + ]) runner = px.CliRunner({"diamond": graph}) exit_code = runner.run(["diamond"]) assert exit_code == CliExitCode.SUCCESS.value @@ -577,12 +561,10 @@ class TestCliRunnerIntegration: def test_mixed_fn_and_cmd_commands(self) -> None: """混合 fn 和 cmd 的命令应都能执行.""" - runner = px.CliRunner( - { - "fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]), - "cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]), - } - ) + runner = px.CliRunner({ + "fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]), + "cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]), + }) assert runner.run(["fn_cmd"]) == CliExitCode.SUCCESS.value assert runner.run(["cmd_cmd"]) == CliExitCode.SUCCESS.value diff --git a/tests/test_task.py b/tests/test_task.py index 8c70549..c92570d 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -6,7 +6,7 @@ from datetime import datetime import pytest -from pyflowx.task import TaskResult, TaskSpec, TaskStatus +from pyflowx.task import RetryPolicy, TaskResult, TaskSpec, TaskStatus def _fn() -> None: @@ -18,9 +18,9 @@ def test_spec_empty_name_rejected() -> None: TaskSpec("", _fn) -def test_spec_negative_retries_rejected() -> None: - with pytest.raises(ValueError, match="retries"): - TaskSpec("a", _fn, retries=-1) +def test_spec_negative_max_attempts_rejected() -> None: + with pytest.raises(ValueError, match="max_attempts"): + TaskSpec("a", _fn, retry=RetryPolicy(max_attempts=0)) def test_spec_zero_timeout_rejected() -> None: diff --git a/tests/test_task_edge_cases.py b/tests/test_task_edge_cases.py index 0a3dc12..43865ac 100644 --- a/tests/test_task_edge_cases.py +++ b/tests/test_task_edge_cases.py @@ -67,7 +67,9 @@ def test_taskspec_wrap_cmd_verbose(): def test_taskspec_wrap_cmd_error(): """Test TaskSpec._wrap_cmd handles command error.""" - spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"]) + import sys + + spec = TaskSpec("test", cmd=[sys.executable, "-c", "import sys; sys.exit(1)"]) wrapped_fn = spec.effective_fn with pytest.raises(RuntimeError, match="命令执行失败"): @@ -105,10 +107,10 @@ def test_taskspec_conditions_check(): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: True,), + conditions=(lambda _ctx: True,), ) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_taskspec_conditions_false(): @@ -116,10 +118,10 @@ def test_taskspec_conditions_false(): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), ) - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False def test_taskspec_conditions_multiple(): @@ -127,10 +129,10 @@ def test_taskspec_conditions_multiple(): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: True, lambda: True, lambda: True), + conditions=(lambda _ctx: True, lambda _ctx: True, lambda _ctx: True), ) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_taskspec_conditions_multiple_one_false(): @@ -138,10 +140,10 @@ def test_taskspec_conditions_multiple_one_false(): spec = px.TaskSpec( "test", fn=lambda: "result", - conditions=(lambda: True, lambda: False, lambda: True), + conditions=(lambda _ctx: True, lambda _ctx: False, lambda _ctx: True), ) - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False def test_taskspec_list_cmd_timeout_mocked(): @@ -218,27 +220,28 @@ def test_taskspec_shell_cmd_os_error_mocked(): # ---------------------------------------------------------------------- # def test_skip_if_missing_with_available_command(): """skip_if_missing=True 时,命令存在应返回 True.""" - # python 命令在测试环境中一定存在 - spec = TaskSpec("test", cmd=["python", "--version"], skip_if_missing=True) - assert spec.should_execute() is True + import sys + + spec = TaskSpec("test", cmd=[sys.executable, "--version"], skip_if_missing=True) + assert spec.should_execute({})[0] is True def test_skip_if_missing_with_missing_command(): """skip_if_missing=True 时,命令不存在应返回 False.""" spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True) - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False def test_skip_if_missing_false_with_missing_command(): """skip_if_missing=False 时,命令不存在也应返回 True(不检查).""" spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=False) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_skip_if_missing_with_shell_cmd_not_checked(): """skip_if_missing=True 时,shell 命令(str)不检查,应返回 True.""" spec = TaskSpec("test", cmd="definitely_not_installed_app_xyz", skip_if_missing=True) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_skip_if_missing_with_callable_cmd_not_checked(): @@ -248,7 +251,7 @@ def test_skip_if_missing_with_callable_cmd_not_checked(): return 0 spec = TaskSpec("test", cmd=custom_cmd, skip_if_missing=True) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_skip_if_missing_with_fn_not_checked(): @@ -258,7 +261,7 @@ def test_skip_if_missing_with_fn_not_checked(): return 0 spec = TaskSpec("test", fn=my_fn, skip_if_missing=True) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True def test_skip_if_missing_with_empty_cmd_list(): @@ -266,37 +269,39 @@ def test_skip_if_missing_with_empty_cmd_list(): spec = TaskSpec("test", cmd=[""], skip_if_missing=True) # 空字符串命令,shutil.which 返回 None # 但 cmd[0] 是空字符串,shutil.which("") 返回 None - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False def test_skip_if_missing_combined_with_conditions(): """skip_if_missing=True 与 conditions 组合使用.""" + import sys + # conditions 返回 False,应跳过 spec = TaskSpec( "test", - cmd=["python", "--version"], + cmd=[sys.executable, "--version"], skip_if_missing=True, - conditions=(lambda: False,), + conditions=(lambda _ctx: False,), ) - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False # conditions 返回 True,命令存在,应执行 spec = TaskSpec( "test", - cmd=["python", "--version"], + cmd=[sys.executable, "--version"], skip_if_missing=True, - conditions=(lambda: True,), + conditions=(lambda _ctx: True,), ) - assert spec.should_execute() is True + assert spec.should_execute({})[0] is True # conditions 返回 True,命令不存在,应跳过 spec = TaskSpec( "test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True, - conditions=(lambda: True,), + conditions=(lambda _ctx: True,), ) - assert spec.should_execute() is False + assert spec.should_execute({})[0] is False def test_skip_if_missing_skips_task_in_run(): diff --git a/tests/test_taskspec_commands.py b/tests/test_taskspec_commands.py index 22e4746..2346334 100644 --- a/tests/test_taskspec_commands.py +++ b/tests/test_taskspec_commands.py @@ -52,7 +52,7 @@ def test_taskspec_with_conditions_skip(): """测试条件不满足时任务被跳过.""" # 创建一个永远不会满足的条件 - def never_true(): + def never_true(_ctx): return False graph = px.Graph.from_specs([ @@ -73,7 +73,7 @@ def test_taskspec_with_conditions_execute(): """测试条件满足时任务正常执行.""" # 创建一个总是满足的条件 - def always_true(): + def always_true(_ctx): return True graph = px.Graph.from_specs([ @@ -103,17 +103,17 @@ def test_platform_conditions(): px.TaskSpec( "win_task", cmd=win_cmd, - conditions=(lambda: Constants.IS_WINDOWS,), + conditions=(lambda _ctx: Constants.IS_WINDOWS,), ), px.TaskSpec( "linux_task", cmd=posix_cmd, - conditions=(lambda: Constants.IS_LINUX,), + conditions=(lambda _ctx: Constants.IS_LINUX,), ), px.TaskSpec( "macos_task", cmd=posix_cmd, - conditions=(lambda: Constants.IS_MACOS,), + conditions=(lambda _ctx: Constants.IS_MACOS,), ), ]) @@ -137,17 +137,15 @@ def test_platform_conditions(): def test_app_installed_conditions(): """测试应用安装条件.""" - # 测试 python 应该总是安装的 - if sys.platform == "win32": - python_cmd = ["python", "--version"] - else: - python_cmd = ["python3", "--version"] + # 使用 sys.executable 保证可移植 + python_cmd = [sys.executable, "--version"] + py_name = "python" if sys.platform == "win32" else "python3" graph = px.Graph.from_specs([ px.TaskSpec( "python_check", cmd=python_cmd, - conditions=(BuiltinConditions.HAS_INSTALLED("python"),), + conditions=(BuiltinConditions.HAS_INSTALLED(py_name),), ), ]) @@ -162,18 +160,18 @@ def test_combined_conditions(): """测试组合条件.""" # AND 条件 and_condition = BuiltinConditions.AND( - lambda: True, - lambda: True, + lambda _ctx: True, + lambda _ctx: True, ) # OR 条件 or_condition = BuiltinConditions.OR( - lambda: True, - lambda: False, + lambda _ctx: True, + lambda _ctx: False, ) # NOT 条件 - not_condition = BuiltinConditions.NOT(lambda: False) + not_condition = BuiltinConditions.NOT(lambda _ctx: False) graph = px.Graph.from_specs([ px.TaskSpec( @@ -228,7 +226,7 @@ def test_taskspec_with_timeout(): # 短时间任务应该成功 px.TaskSpec( "short_task", - cmd=["python", "-c", "import time; time.sleep(0.1)"], + cmd=[sys.executable, "-c", "import time; time.sleep(0.1)"], timeout=1.0, ), ]) @@ -245,13 +243,13 @@ def test_taskspec_dependency_with_conditions(): px.TaskSpec( "first", cmd=[*ECHO_CMD, "first"], - conditions=(lambda: True,), + conditions=(lambda _ctx: True,), ), px.TaskSpec( "second", cmd=[*ECHO_CMD, "second"], depends_on=("first",), - conditions=(lambda: True,), + conditions=(lambda _ctx: True,), ), px.TaskSpec( "third", @@ -378,7 +376,7 @@ class TestTaskSpecVerbose: graph = px.Graph.from_specs([ px.TaskSpec( "fail", - cmd=["python", "-c", "import sys; sys.exit(1)"], + cmd=[sys.executable, "-c", "import sys; sys.exit(1)"], verbose=True, ) ]) @@ -414,7 +412,7 @@ class TestTaskSpecCmdErrors: px.TaskSpec( "fail", cmd=[ - "python", + sys.executable, "-c", "import sys; sys.stderr.write('error-msg'); sys.exit(1)", ], @@ -437,7 +435,9 @@ class TestTaskSpecCmdErrors: """shell 命令失败时应抛出 RuntimeError.""" from pyflowx.errors import TaskFailedError - graph = px.Graph.from_specs([px.TaskSpec("fail", cmd='python -c "import sys; sys.exit(1)"')]) + graph = px.Graph.from_specs([ + px.TaskSpec("fail", cmd=f'{sys.executable} -c "import sys; sys.exit(1)"'), + ]) with pytest.raises(TaskFailedError) as exc_info: _ = px.run(graph, strategy="sequential") assert "Shell 命令执行失败" in str(exc_info.value.cause) @@ -450,7 +450,7 @@ class TestTaskSpecCmdErrors: graph = px.Graph.from_specs([ px.TaskSpec( "slow", - cmd=["python", "-c", "import time; time.sleep(5)"], + cmd=[sys.executable, "-c", "import time; time.sleep(5)"], timeout=0.1, ) ]) @@ -463,7 +463,13 @@ class TestTaskSpecCmdErrors: """shell 命令超时应抛出 RuntimeError.""" from pyflowx.errors import TaskFailedError - graph = px.Graph.from_specs([px.TaskSpec("slow", cmd='python -c "import time; time.sleep(5)"', timeout=0.1)]) + graph = px.Graph.from_specs([ + px.TaskSpec( + "slow", + cmd=f'{sys.executable} -c "import time; time.sleep(5)"', + timeout=0.1, + ), + ]) with pytest.raises(TaskFailedError) as exc_info: _ = px.run(graph, strategy="sequential") assert "超时" in str(exc_info.value.cause)