refactor: 重构重试策略、条件函数与上下文注入逻辑

主要变更:
1. 替换旧retries参数为RetryPolicy配置
2. 重构条件函数,支持上下文参数与动态依赖判断
3. 更新上下文注入逻辑,支持软依赖与更清晰的注入描述
4. 新增sglang CLI命令与相关配置
5. 格式化代码统一列表与参数写法
6. 更新文档与测试用例适配新API
This commit is contained in:
2026-06-27 14:33:54 +08:00
parent 6f01cde8ac
commit 5c8ec281ff
24 changed files with 2796 additions and 1043 deletions
+1
View File
@@ -44,6 +44,7 @@ piptool = "pyflowx.cli.piptool:main"
pymake = "pyflowx.cli.pymake:main" pymake = "pyflowx.cli.pymake:main"
reseticon = "pyflowx.cli.reseticoncache:main" reseticon = "pyflowx.cli.reseticoncache:main"
scrcap = "pyflowx.cli.screenshot:main" scrcap = "pyflowx.cli.screenshot:main"
sglang = "pyflowx.cli.llm.sglang:main"
sshcopy = "pyflowx.cli.sshcopyid:main" sshcopy = "pyflowx.cli.sshcopyid:main"
taskk = "pyflowx.cli.taskkill:main" taskk = "pyflowx.cli.taskkill:main"
wch = "pyflowx.cli.which:main" wch = "pyflowx.cli.which:main"
+35 -17
View File
@@ -4,9 +4,15 @@
-------- --------
* :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。 * :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。
* :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。 * :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。
* :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。 * :func:`run` ——以 ``sequential`` / ``thread`` / ``async`` / ``dependency``
策略执行图。
* :class:`RunReport` —— 类型化、可查询的运行结果。 * :class:`RunReport` —— 类型化、可查询的运行结果。
* :class:`Context` —— 整体上下文注入的标注标记。 * :class:`Context` —— 整体上下文注入的标注标记。
* :class:`RetryPolicy` —— 重试策略(max_attempts/delay/backoff/jitter/retry_on)。
* :class:`TaskHooks` —— 任务生命周期钩子(pre_run/post_run/on_failure)。
* :class:`GraphDefaults` —— 图级默认值。
* :func:`compose` —— 编程式组合多图。
* :func:`task_template` —— 批量生成相似 TaskSpec 的工厂。
* 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。 * 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。
快速上手 快速上手
@@ -18,7 +24,7 @@
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec("extract", extract), px.TaskSpec("extract", extract),
px.TaskSpec("double", double, ("extract",)), px.TaskSpec("double", double, depends_on=("extract",)),
]) ])
report = px.run(graph, strategy="sequential") report = px.run(graph, strategy="sequential")
print(report["double"]) # [2, 4, 6] print(report["double"]) # [2, 4, 6]
@@ -29,23 +35,18 @@
from pyflowx.conditions import IS_WINDOWS, BuiltinConditions from pyflowx.conditions import IS_WINDOWS, BuiltinConditions
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
# 使用命令列表
px.TaskSpec("list_files", cmd=["ls", "-la"]), px.TaskSpec("list_files", cmd=["ls", "-la"]),
# 使用 shell 命令
px.TaskSpec("check_git", cmd="git status"), px.TaskSpec("check_git", cmd="git status"),
# 条件执行:仅在 Windows 上运行
px.TaskSpec( px.TaskSpec(
"win_only", "win_only",
cmd=["dir"], cmd=["dir"],
conditions=(IS_WINDOWS,) conditions=(IS_WINDOWS,)
), ),
# 条件执行:仅在 git 已安装时运行
px.TaskSpec( px.TaskSpec(
"git_check", "git_check",
cmd=["git", "--version"], cmd=["git", "--version"],
conditions=(BuiltinConditions.HAS_INSTALLED("git"),) conditions=(BuiltinConditions.HAS_INSTALLED("git"),)
), ),
# 命令不存在时自动跳过(而非失败)
px.TaskSpec( px.TaskSpec(
"optional_build", "optional_build",
cmd=["maturin", "build"], cmd=["maturin", "build"],
@@ -58,6 +59,10 @@
from __future__ import annotations from __future__ import annotations
from .conditions import ( from .conditions import (
IS_LINUX,
IS_MACOS,
IS_POSIX,
IS_WINDOWS,
BuiltinConditions, BuiltinConditions,
Condition, Condition,
Constants, Constants,
@@ -74,20 +79,33 @@ from .errors import (
TaskTimeoutError, TaskTimeoutError,
) )
from .executors import Strategy, run from .executors import Strategy, run
from .graph import Graph, GraphComposer from .graph import Graph, GraphComposer, GraphDefaults, compose
from .report import RunReport from .report import RunReport
from .runner import CliExitCode, CliRunner from .runner import CliExitCode, CliRunner
from .storage import JSONBackend, MemoryBackend, StateBackend from .storage import JSONBackend, MemoryBackend, StateBackend
from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus from .task import (
CacheKeyFn,
RetryPolicy,
TaskCmd,
TaskEvent,
TaskHooks,
TaskResult,
TaskSpec,
TaskStatus,
task_template,
)
__version__ = "0.2.6" __version__ = "0.3.0"
__all__ = [ __all__ = [
"IS_LINUX",
"IS_MACOS",
"IS_POSIX",
"IS_WINDOWS",
"BuiltinConditions", "BuiltinConditions",
"CacheKeyFn",
"CliExitCode", "CliExitCode",
# CLI 运行器
"CliRunner", "CliRunner",
# 条件判断
"Condition", "Condition",
"Constants", "Constants",
"Context", "Context",
@@ -95,28 +113,28 @@ __all__ = [
"DuplicateTaskError", "DuplicateTaskError",
"Graph", "Graph",
"GraphComposer", "GraphComposer",
"GraphDefaults",
"InjectionError", "InjectionError",
"JSONBackend", "JSONBackend",
"MemoryBackend", "MemoryBackend",
"MissingDependencyError", "MissingDependencyError",
# 错误
"PyFlowXError", "PyFlowXError",
"RetryPolicy",
"RunReport", "RunReport",
# 状态后端
"StateBackend", "StateBackend",
"StorageError", "StorageError",
"Strategy", "Strategy",
"TaskCmd", "TaskCmd",
"TaskEvent", "TaskEvent",
"TaskFailedError", "TaskFailedError",
"TaskHooks",
"TaskResult", "TaskResult",
# 核心类型
"TaskSpec", "TaskSpec",
"TaskStatus", "TaskStatus",
"TaskTimeoutError", "TaskTimeoutError",
# 辅助(高级)
"build_call_args", "build_call_args",
"compose",
"describe_injection", "describe_injection",
# 执行
"run", "run",
"task_template",
] ]
+3 -3
View File
@@ -112,9 +112,9 @@ def main() -> None:
args = parser.parse_args() args = parser.parse_args()
if args.command == "mirror": if args.command == "mirror":
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})] px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})
) ])
else: else:
parser.print_help() parser.print_help()
return return
+2 -2
View File
@@ -43,13 +43,13 @@ def main() -> None:
px.TaskSpec( px.TaskSpec(
"envqt_install", "envqt_install",
cmd=["sudo", "apt", "install", "-y", *QT_LIBS], cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
conditions=(lambda: Constants.IS_LINUX,), conditions=(lambda _: Constants.IS_LINUX,),
verbose=True, verbose=True,
), ),
px.TaskSpec( px.TaskSpec(
"envqt_fonts", "envqt_fonts",
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS], cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(lambda: Constants.IS_LINUX,), conditions=(lambda _: Constants.IS_LINUX,),
verbose=True, verbose=True,
), ),
], ],
+8 -5
View File
@@ -37,7 +37,7 @@ def init_sub_dirs() -> None:
px.TaskSpec( px.TaskSpec(
"init", "init",
cmd=["git", "init"], cmd=["git", "init"],
conditions=(not_has_git_repo,), conditions=(lambda _: not_has_git_repo(),),
cwd=subdir, cwd=subdir,
), ),
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",)), px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",)),
@@ -70,7 +70,7 @@ def main() -> None:
graphs={ graphs={
# 添加并提交 # 添加并提交
"a": px.Graph.from_specs([ "a": px.Graph.from_specs([
px.TaskSpec("add", cmd=["git", "add", "."], conditions=(has_files,)), px.TaskSpec("add", cmd=["git", "add", "."], conditions=(lambda _: has_files(),)),
px.TaskSpec("commit", cmd=["git", "commit", "-m", "chore: update"], depends_on=("add",)), px.TaskSpec("commit", cmd=["git", "commit", "-m", "chore: update"], depends_on=("add",)),
]), ]),
# 清理 # 清理
@@ -80,10 +80,13 @@ def main() -> None:
]), ]),
# 初始化、添加并提交 # 初始化、添加并提交
"i": px.Graph.from_specs([ "i": px.Graph.from_specs([
px.TaskSpec("init", cmd=["git", "init"], conditions=(not_has_git_repo,)), px.TaskSpec("init", cmd=["git", "init"], conditions=(lambda _: not_has_git_repo(),)),
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(has_files,)), px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(lambda _: has_files(),)),
px.TaskSpec( px.TaskSpec(
"commit", cmd=["git", "commit", "-m", "init commit"], depends_on=("add",), conditions=(has_files,) "commit",
cmd=["git", "commit", "-m", "init commit"],
depends_on=("add",),
conditions=(lambda _: has_files(),),
), ),
]), ]),
# 初始化子目录 # 初始化子目录
+55
View File
@@ -0,0 +1,55 @@
"""使用 SGLang 运行本地模型."""
import argparse
from pathlib import Path
import pyflowx as px
from pyflowx.conditions import BuiltinConditions
def main():
parser = argparse.ArgumentParser(description="Run a local model using SGLang.")
parser.add_argument("name", help="Model name.")
parser.add_argument("--dir", default=None, help="Directory of model.")
args = parser.parse_args()
if not args.name:
parser.error("name is required")
model_dir = Path(args.dir) if args.dir else Path.home() / ".models" / args.name.split("/")[-1]
if not model_dir.exists():
parser.error(f"Model directory {model_dir} does not exist.")
graph = px.Graph.from_specs([
px.TaskSpec(
name="download",
cmd=[
"uv",
"install",
"sglang[all]",
],
conditions=(BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("sglang")),),
verbose=True,
),
px.TaskSpec(
name="run",
cmd=[
"uvx",
"sglang",
"serve",
"--model-path",
str(model_dir),
"--host",
"0.0.0.0",
"--port",
"8000",
"--mem-fraction-static",
"0.88",
"--context-length",
"32768",
],
verbose=True,
),
])
px.run(graph, verbose=True)
+28 -34
View File
@@ -21,12 +21,10 @@ PACKAGE_DIR = "packages"
REQUIREMENTS_FILE = "requirements.txt" REQUIREMENTS_FILE = "requirements.txt"
# 受保护的包名集合 # 受保护的包名集合
_PROTECTED_PACKAGES: frozenset[str] = frozenset( _PROTECTED_PACKAGES: frozenset[str] = frozenset({
{ "pyflowx",
"pyflowx", "bitool",
"bitool", })
}
)
# ============================================================================ # ============================================================================
@@ -161,37 +159,33 @@ def main() -> None:
if args.command == "i": if args.command == "i":
graph = px.Graph.from_specs([px.TaskSpec("pip_install", cmd=["pip", "install", *args.packages], verbose=True)]) graph = px.Graph.from_specs([px.TaskSpec("pip_install", cmd=["pip", "install", *args.packages], verbose=True)])
elif args.command == "u": elif args.command == "u":
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)] px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)
) ])
elif args.command == "r": elif args.command == "r":
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "pip_reinstall",
"pip_reinstall", fn=pip_reinstall,
fn=pip_reinstall, args=(args.packages,),
args=(args.packages,), kwargs={"offline": args.offline},
kwargs={"offline": args.offline}, verbose=True,
verbose=True, )
) ])
]
)
elif args.command == "d": elif args.command == "d":
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "pip_download",
"pip_download", fn=pip_download,
fn=pip_download, args=(args.packages,),
args=(args.packages,), kwargs={"offline": args.offline},
kwargs={"offline": args.offline}, verbose=True,
verbose=True, )
) ])
]
)
elif args.command == "up": elif args.command == "up":
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)] px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)
) ])
elif args.command == "f": elif args.command == "f":
graph = px.Graph.from_specs([px.TaskSpec("pip_freeze", fn=pip_freeze, verbose=True)]) graph = px.Graph.from_specs([px.TaskSpec("pip_freeze", fn=pip_freeze, verbose=True)])
else: else:
+128 -160
View File
@@ -1,7 +1,12 @@
"""条件判断模块. """条件判断模块.
提供平台条件、应用安装条件等预定义条件判断函数, 所有条件均为 ``Callable[[Context], bool]``,接收依赖上下文映射(可能为空)。
用于 TaskSpec 的条件执行功能. 这使得条件可基于上游任务的运行时返回值做决策,实现动态分支。
内置条件分两类:
1. *静态条件* —— 不依赖上下文(平台/环境变量/安装检查),通过 ``_static``
包装忽略传入的 context,便于作为模块级常量使用。
2. *上下文条件* —— 基于上游结果判断,如 :meth:`BuiltinConditions.DEP_EQUALS`。
""" """
from __future__ import annotations from __future__ import annotations
@@ -11,10 +16,11 @@ import shutil
import subprocess import subprocess
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Callable from typing import Any, Callable
# 条件判断函数类型 from .task import Condition, Context
Condition = Callable[[], bool]
__all__ = ["BuiltinConditions", "Condition", "Constants"]
class Constants: class Constants:
@@ -26,65 +32,56 @@ class Constants:
IS_POSIX: bool = sys.platform != "win32" IS_POSIX: bool = sys.platform != "win32"
def _static(predicate: Callable[[], bool], name: str) -> Condition:
"""将无参谓词包装为忽略上下文的 :class:`Condition`。"""
def _cond(_ctx: Context) -> bool:
return predicate()
_cond.__name__ = name
return _cond
# ---------------------------------------------------------------------- #
# 模块级静态条件常量
# ---------------------------------------------------------------------- #
IS_WINDOWS: Condition = _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS")
IS_LINUX: Condition = _static(lambda: Constants.IS_LINUX, "IS_LINUX")
IS_MACOS: Condition = _static(lambda: Constants.IS_MACOS, "IS_MACOS")
IS_POSIX: Condition = _static(lambda: Constants.IS_POSIX, "IS_POSIX")
class BuiltinConditions: class BuiltinConditions:
"""内置条件判断函数集合.""" """内置条件判断函数集合.
静态条件工厂返回忽略上下文的 :class:`Condition`;上下文条件工厂返回
会读取依赖结果的 :class:`Condition`。
"""
# ------------------------------------------------------------------ #
# 静态条件
# ------------------------------------------------------------------ #
@staticmethod @staticmethod
def PYTHON_VERSION(major: int, minor: int | None = None) -> bool: def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
"""检查 Python 版本是否匹配. """检查 Python 版本是否匹配."""
Parameters
----------
major : int
主版本号.
minor : int | None
次版本号, 若为 None 则仅检查主版本.
Returns
-------
bool
版本是否匹配.
"""
if minor is None: if minor is None:
return sys.version_info.major == major return _static(lambda: sys.version_info.major == major, f"PYTHON_VERSION({major})")
return sys.version_info.major == major and sys.version_info.minor == minor return _static(
lambda: sys.version_info.major == major and sys.version_info.minor == minor,
f"PYTHON_VERSION({major},{minor})",
)
@staticmethod @staticmethod
def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> bool: def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> Condition:
"""检查 Python 版本是否 >= 指定版本. """检查 Python 版本是否 >= 指定版本."""
return _static(lambda: sys.version_info >= (major, minor), f"PYTHON_VERSION_AT_LEAST({major},{minor})")
Parameters
----------
major : int
主版本号.
minor : int
次版本号.
Returns
-------
bool
当前版本是否 >= 指定版本.
"""
return sys.version_info >= (major, minor)
@staticmethod @staticmethod
def IS_RUNNING(app_name: str) -> Condition: def IS_RUNNING(app_name: str) -> Condition:
"""检查指定应用是否正在运行. """检查指定应用是否正在运行."""
Parameters
----------
app_name : str
应用名称 (如 "explorer", "chrome", "python").
Returns
-------
Condition
条件判断函数.
"""
def _check() -> bool: def _check() -> bool:
if Constants.IS_WINDOWS: if Constants.IS_WINDOWS:
# Windows: 使用 tasklist 命令
result = subprocess.run( result = subprocess.run(
["tasklist", "/nh", "/fi", f"imagename eq {app_name}"], ["tasklist", "/nh", "/fi", f"imagename eq {app_name}"],
capture_output=True, capture_output=True,
@@ -93,148 +90,119 @@ class BuiltinConditions:
) )
return app_name.lower() in result.stdout.lower() return app_name.lower() in result.stdout.lower()
else: else:
# Linux/macOS: 使用 pgrep 命令 result = subprocess.run(["pgrep", "-x", app_name], capture_output=True, check=False)
result = subprocess.run(
["pgrep", "-x", app_name],
capture_output=True,
check=False,
)
return result.returncode == 0 return result.returncode == 0
_check.__name__ = f"IS_RUNNING({app_name!r})" return _static(_check, f"IS_RUNNING({app_name!r})")
return _check
@staticmethod @staticmethod
def HAS_INSTALLED(app_name: str) -> Condition: def HAS_INSTALLED(app_name: str) -> Condition:
"""检查指定应用是否已安装. """检查指定应用是否已安装."""
return _static(lambda: shutil.which(app_name) is not None, f"HAS_INSTALLED({app_name!r})")
Parameters
----------
app_name : str
应用名称 (如 "git", "python", "pytest").
Returns
-------
Condition
条件判断函数.
"""
def _check() -> bool:
return shutil.which(app_name) is not None
_check.__name__ = f"HAS_INSTALLED({app_name!r})"
return _check
@staticmethod @staticmethod
def DIR_EXISTS(dir: Path) -> Condition: def DIR_EXISTS(path: Path) -> Condition:
"""路径是否存在.""" """路径是否存在."""
return dir.exists return _static(path.exists, f"DIR_EXISTS({path!r})")
@staticmethod @staticmethod
def ENV_VAR_EXISTS(var_name: str) -> Condition: def ENV_VAR_EXISTS(var_name: str) -> Condition:
"""检查环境变量是否存在. """检查环境变量是否存在."""
return _static(lambda: var_name in os.environ, f"ENV_VAR_EXISTS({var_name!r})")
Parameters
----------
var_name : str
环境变量名.
Returns
-------
Condition
条件判断函数.
"""
def _check() -> bool:
return var_name in os.environ
_check.__name__ = f"ENV_VAR_EXISTS({var_name!r})"
return _check
@staticmethod @staticmethod
def ENV_VAR_EQUALS(var_name: str, value: str) -> Condition: def ENV_VAR_EQUALS(var_name: str, value: str) -> Condition:
"""检查环境变量是否等于指定值. """检查环境变量是否等于指定值."""
return _static(
lambda: os.environ.get(var_name) == value,
f"ENV_VAR_EQUALS({var_name!r},{value!r})",
)
Parameters # ------------------------------------------------------------------ #
---------- # 上下文条件:基于上游依赖结果
var_name : str # ------------------------------------------------------------------ #
环境变量名. @staticmethod
value : str def DEP_EQUALS(dep_name: str, value: Any) -> Condition:
期望的值. """上游任务 ``dep_name`` 的返回值等于 ``value`` 时为真。
Returns 若依赖未在上下文中(被跳过或未执行),返回 ``False``。
-------
Condition
条件判断函数.
""" """
def _check() -> bool: def _cond(ctx: Context) -> bool:
return os.environ.get(var_name) == value return dep_name in ctx and ctx[dep_name] == value
_check.__name__ = f"ENV_VAR_EQUALS({var_name!r}, {value!r})" _cond.__name__ = f"DEP_EQUALS({dep_name!r},{value!r})"
return _check return _cond
@staticmethod @staticmethod
def NOT(condition: Condition) -> Condition: def DEP_MATCHES(dep_name: str, predicate: Callable[[Any], bool]) -> Condition:
"""对条件取反. """上游任务 ``dep_name`` 的返回值满足 ``predicate`` 时为真。
Parameters 依赖不存在时返回 ``False``。
----------
condition : Condition
原始条件.
Returns
-------
Condition
取反后的条件.
""" """
def _check() -> bool: def _cond(ctx: Context) -> bool:
return not condition() if dep_name not in ctx:
return False
try:
return predicate(ctx[dep_name])
except Exception:
return False
_check.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})" _cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})"
return _check return _cond
@staticmethod
def DEP_PRESENT(dep_name: str) -> Condition:
"""上游任务 ``dep_name`` 存在于上下文(即已成功执行)时为真。"""
def _cond(ctx: Context) -> bool:
return dep_name in ctx and ctx[dep_name] is not None
_cond.__name__ = f"DEP_PRESENT({dep_name!r})"
return _cond
@staticmethod
def DEP_TRUTHY(dep_name: str) -> Condition:
"""上游任务 ``dep_name`` 的返回值为真值时为真。"""
def _cond(ctx: Context) -> bool:
return bool(ctx.get(dep_name))
_cond.__name__ = f"DEP_TRUTHY({dep_name!r})"
return _cond
# ------------------------------------------------------------------ #
# 逻辑组合
# ------------------------------------------------------------------ #
@staticmethod
def NOT(condition: Condition) -> Condition:
"""对条件取反."""
def _cond(ctx: Context) -> bool:
return not condition(ctx)
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
return _cond
@staticmethod @staticmethod
def AND(*conditions: Condition) -> Condition: def AND(*conditions: Condition) -> Condition:
"""多个条件的逻辑与. """多个条件的逻辑与."""
Parameters def _cond(ctx: Context) -> bool:
---------- return all(c(ctx) for c in conditions)
*conditions : Condition
条件列表.
Returns
-------
Condition
组合条件.
"""
def _check() -> bool:
return all(c() for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions] names = [getattr(c, "__name__", repr(c)) for c in conditions]
_check.__name__ = f"AND({', '.join(names)})" _cond.__name__ = f"AND({', '.join(names)})"
return _check return _cond
@staticmethod @staticmethod
def OR(*conditions: Condition) -> Condition: def OR(*conditions: Condition) -> Condition:
"""多个条件的逻辑或. """多个条件的逻辑或."""
Parameters def _cond(ctx: Context) -> bool:
---------- return any(c(ctx) for c in conditions)
*conditions : Condition
条件列表.
Returns
-------
Condition
组合条件.
"""
def _check() -> bool:
return any(c() for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions] names = [getattr(c, "__name__", repr(c)) for c in conditions]
_check.__name__ = f"OR({', '.join(names)})" _cond.__name__ = f"OR({', '.join(names)})"
return _check return _cond
+15 -59
View File
@@ -1,18 +1,16 @@
"""上下文注入:把上游结果转换为函数参数。 """上下文注入:把上游结果转换为函数参数。
本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他 本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他
DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))`` DAG 库中泛滥的样板包装器。
注入规则(按顺序求值) 注入规则(按顺序求值)
---------------------- ----------------------
1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历 1. **标注为** :class:`Context` 的参数接收完整结果映射(含硬依赖与软依赖)。
所有输入的任务 2. **名称匹配某个依赖**(硬或软)的参数接收该依赖的结果
2. **名称匹配某个依赖**的参数接收该依赖的结果。
3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。 3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。 4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。
若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError` 若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`
并附带精确错误信息。
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,21 +25,11 @@ __all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_inj
def _is_context_annotation(annotation: Any) -> bool: def _is_context_annotation(annotation: Any) -> bool:
"""判断参数标注是否为(或指向)``Context``。 """判断参数标注是否为(或指向)``Context``。"""
处理三种形式:
* ``Context`` 别名对象本身;
* ``__name__``/``_name`` 为 ``Context`` 或 ``Mapping`` 的 typing 别名;
* *字符串*标注(``from __future__ import annotations`` 会在运行时
把所有标注变为字符串),如 ``"Context"`` 或 ``"px.Context"``。
"""
if annotation is Context: if annotation is Context:
return True return True
# `from __future__ import annotations` 产生的字符串标注。
if isinstance(annotation, str): if isinstance(annotation, str):
# 匹配 "Context"、"px.Context"、"pyflowx.Context" 等。
return annotation == "Context" or annotation.endswith(".Context") return annotation == "Context" or annotation.endswith(".Context")
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None) name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
return name in ("Context", "Mapping") return name in ("Context", "Mapping")
@@ -52,39 +40,22 @@ def build_call_args(
) -> tuple[tuple[Any, ...], dict[str, Any]]: ) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。 """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
参数 ``context`` 必须已包含所有硬依赖与软依赖的结果(软依赖被跳过时由
---- 执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
spec:
任务 spec,提供 ``fn``、``depends_on``、``args``、``kwargs``。
context:
依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目
存在;其他任务的结果被排除,以保持注入的确定性。
返回
----
(args, kwargs)
可直接展开为 ``spec.fn(*args, **kwargs)``。
抛出
----
InjectionError
若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。
""" """
# 使用 effective_fn 而不是 fn,以支持 cmd 参数
fn = spec.effective_fn fn = spec.effective_fn
sig = inspect.signature(fn) sig = inspect.signature(fn)
params = sig.parameters params = sig.parameters
# 检测特殊参数类型。
var_keyword = next( var_keyword = next(
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD), (p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
None, None,
) )
# 本任务相关的上下文子集。 # 本任务相关的上下文子集:硬依赖 + 软依赖
dep_context: dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context} all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
dep_context: dict[str, Any] = {name: context[name] for name in all_deps if name in context}
# 检测静态 kwargs 与依赖名的冲突。
collisions = set(spec.kwargs) & set(dep_context) collisions = set(spec.kwargs) & set(dep_context)
if collisions: if collisions:
raise InjectionError( raise InjectionError(
@@ -96,8 +67,6 @@ def build_call_args(
injected_kwargs: dict[str, Any] = {} injected_kwargs: dict[str, Any] = {}
leftover_dep_results: dict[str, Any] = dict(dep_context) leftover_dep_results: dict[str, Any] = dict(dep_context)
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
positional_params: list[str] = [] positional_params: list[str] = []
positional_kinds = ( positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_ONLY,
@@ -106,33 +75,25 @@ def build_call_args(
for pname, param in params.items(): for pname, param in params.items():
if param.kind in positional_kinds: if param.kind in positional_kinds:
positional_params.append(pname) positional_params.append(pname)
# 前 len(spec.args) 个位置参数由 spec.args 填充。
args_filled: set[str] = set(positional_params[: len(spec.args)]) args_filled: set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items(): for pname, param in params.items():
# 跳过已被位置 spec.args 填充的参数。
if pname in args_filled: if pname in args_filled:
continue continue
# 规则 1:标注为 Context -> 完整映射。
if _is_context_annotation(param.annotation): if _is_context_annotation(param.annotation):
injected_kwargs[pname] = dep_context injected_kwargs[pname] = dep_context
continue continue
# 规则 2:名称匹配某个依赖。
if pname in dep_context: if pname in dep_context:
injected_kwargs[pname] = dep_context[pname] injected_kwargs[pname] = dep_context[pname]
leftover_dep_results.pop(pname, None) leftover_dep_results.pop(pname, None)
continue continue
# 规则 3:在循环后通过 **kwargs 处理。
# 规则 4:静态 kwargs 填充其余参数。
if pname in spec.kwargs: if pname in spec.kwargs:
injected_kwargs[pname] = spec.kwargs[pname] injected_kwargs[pname] = spec.kwargs[pname]
continue continue
# 该参数无来源:必须有默认值,否则报错。
if param.default is inspect.Parameter.empty and param.kind not in ( if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_KEYWORD,
@@ -142,9 +103,7 @@ def build_call_args(
f"parameter {pname!r} has no dependency, static value, or default.", f"parameter {pname!r} has no dependency, static value, or default.",
) )
# 规则 3:**kwargs 吞掉剩余依赖结果。
if var_keyword is not None and leftover_dep_results: if var_keyword is not None and leftover_dep_results:
# 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。
merged = dict(spec.kwargs) merged = dict(spec.kwargs)
merged.update(injected_kwargs) merged.update(injected_kwargs)
merged.update(leftover_dep_results) merged.update(leftover_dep_results)
@@ -154,14 +113,9 @@ def build_call_args(
def describe_injection(spec: TaskSpec[Any]) -> str: def describe_injection(spec: TaskSpec[Any]) -> str:
"""生成任务参数注入方式的人类可读描述。 """生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
供 ``dry_run`` 使用,在不执行的情况下展示执行计划。
"""
# 使用 effective_fn 而不是 fn,以支持 cmd 参数
fn = spec.effective_fn fn = spec.effective_fn
sig = inspect.signature(fn) sig = inspect.signature(fn)
# 确定哪些位置参数由 spec.args 填充。
positional_params = [ positional_params = [
p p
for p, param in sig.parameters.items() for p, param in sig.parameters.items()
@@ -172,6 +126,7 @@ def describe_injection(spec: TaskSpec[Any]) -> str:
) )
] ]
args_filled = set(positional_params[: len(spec.args)]) args_filled = set(positional_params[: len(spec.args)])
all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
parts = [] parts = []
for pname, param in sig.parameters.items(): for pname, param in sig.parameters.items():
if pname in args_filled: if pname in args_filled:
@@ -179,8 +134,9 @@ def describe_injection(spec: TaskSpec[Any]) -> str:
parts.append(f"{pname}={spec.args[idx]!r}") parts.append(f"{pname}={spec.args[idx]!r}")
elif _is_context_annotation(param.annotation): elif _is_context_annotation(param.annotation):
parts.append(f"{pname}=<Context>") parts.append(f"{pname}=<Context>")
elif pname in spec.depends_on: elif pname in all_deps:
parts.append(f"{pname}=<result:{pname}>") tag = "soft" if pname in spec.soft_depends_on else "dep"
parts.append(f"{pname}=<{tag}:{pname}>")
elif pname in spec.kwargs: elif pname in spec.kwargs:
parts.append(f"{pname}={spec.kwargs[pname]!r}") parts.append(f"{pname}={spec.kwargs[pname]!r}")
elif param.default is not inspect.Parameter.empty: elif param.default is not inspect.Parameter.empty:
+3 -1
View File
@@ -55,7 +55,9 @@ def main() -> None:
depends_on=("extract_customers", "extract_orders"), depends_on=("extract_customers", "extract_orders"),
tags=("transform",), tags=("transform",),
), ),
px.TaskSpec("load", load, depends_on=("transform",), retries=1, tags=("load",)), px.TaskSpec(
"load", load, depends_on=("transform",), retry=px.RetryPolicy(max_attempts=1, delay=1.0), tags=("load",)
),
]) ])
print("=== Execution plan ===") print("=== Execution plan ===")
+365 -179
View File
@@ -1,15 +1,26 @@
"""执行器与公共 :func:`run` 入口。 """执行器与公共 :func:`run` 入口。
种执行策略共享一个逐层驱动器 种执行策略:
* ``sequential`` —— 确定性、一次一个任务。最适合调试。 * ``sequential`` —— 确定性、一次一个任务。最适合调试。
* ``thread`` —— 通过线程池实现层内并发。最适合 I/O 密集型同步任务。 * ``thread`` —— 通过线程池实现层内并发。最适合 I/O 密集型同步任务。
* ``async`` —— 通过 ``asyncio.gather`` 实现层内并发。同步任务被 * ``async`` —— 通过 ``asyncio.gather`` 实现层内并发。同步任务被
卸载到线程池;异步任务运行在事件循环上。最适合 卸载到线程池;异步任务运行在事件循环上。最适合
I/O 密集型异步任务。 I/O 密集型异步任务。
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
无需等待同层其他任务。最大化并行度。
三者都遵循 ``retries``、``timeout``、上下文注入、状态后端(续跑), 所有策略共享统一异步内核,支持:
并向观察者发出 :class:`~pyflowx.task.TaskEvent`。 * :class:`RetryPolicy`max_attempts/delay/backoff/jitter/retry_on
* 软依赖注入与默认值
* :class:`TaskHooks`pre_run/post_run/on_failure
* 按任务策略覆盖
* 优先级排序(同层内)
* 并发限制(concurrency_key + concurrency_limits
* ``continue_on_error``
* ``cache_key`` 存储键
* 条件判断(上下文感知)
* 状态后端(续跑)
""" """
from __future__ import annotations from __future__ import annotations
@@ -18,6 +29,7 @@ import asyncio
import concurrent.futures import concurrent.futures
import inspect import inspect
import logging import logging
import threading
from datetime import datetime from datetime import datetime
from typing import Any, Awaitable, Callable, Literal, Mapping, cast from typing import Any, Awaitable, Callable, Literal, Mapping, cast
@@ -26,24 +38,24 @@ from .errors import TaskFailedError, TaskTimeoutError
from .graph import Graph from .graph import Graph
from .report import RunReport from .report import RunReport
from .storage import StateBackend, resolve_backend from .storage import StateBackend, resolve_backend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger("pyflowx") logger = logging.getLogger("pyflowx")
# 观察者回调类型。 # 观察者回调类型。
EventCallback = Callable[[TaskEvent], None] EventCallback = Callable[[TaskEvent], None]
Strategy = Literal["sequential", "thread", "async"] Strategy = Literal["sequential", "thread", "async", "dependency"]
# ---------------------------------------------------------------------- #
# 辅助
# ---------------------------------------------------------------------- #
def _is_async_fn(spec: TaskSpec[Any]) -> bool: def _is_async_fn(spec: TaskSpec[Any]) -> bool:
"""判断 ``spec.effective_fn`` 是否为协程函数。""" """判断 ``spec.effective_fn`` 是否为协程函数。"""
return inspect.iscoroutinefunction(spec.effective_fn) return inspect.iscoroutinefunction(spec.effective_fn)
def _emit( def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
on_event: EventCallback | None,
result: TaskResult[Any],
) -> None:
"""若注册了回调则触发一个观察者事件。""" """若注册了回调则触发一个观察者事件。"""
if on_event is None: if on_event is None:
return return
@@ -59,71 +71,60 @@ def _emit(
) )
def _log_retry(spec: TaskSpec[Any], attempts: int, max_attempts: int, exc: BaseException) -> None: def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志sync 与 async 共享,便于测试覆盖)""" """记录重试日志。"""
logger.warning( logger.warning(
"task %r failed (attempt %d/%d): %r; retrying", "task %r failed (attempt %d/%d): %r; retrying",
spec.name, spec.name,
attempts, attempt,
max_attempts, max_attempts,
exc, exc,
) )
def _finalize_failure( def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
result: TaskResult[Any], """安全调用钩子(异常仅记录,不影响任务状态)。"""
layer_idx: int | None, hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
on_event: EventCallback | None = None, if hook is None:
) -> None: return
"""标记任务为 FAILED 并抛出 TaskFailedError。""" try:
result.status = TaskStatus.FAILED hook(*args)
result.finished_at = datetime.now() except Exception as exc:
_emit(on_event, result) logger.warning("hook %s raised: %r", fn_name, exc)
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
def _check_upstream_skipped( def _check_upstream_skipped(
spec: TaskSpec[Any], spec: TaskSpec[Any],
report: RunReport | None, report: RunReport | None,
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
"""检查上游任务是否被 SKIPPED。 """检查硬依赖上游任务是否被 SKIPPED 或 FAILED
Returns 软依赖不影响本检查——软依赖被跳过时注入默认值。
-------
tuple[bool, str | None]
(是否应该跳过, 跳过原因)
""" """
if report is None: if report is None:
return False, None return False, None
# 若任务允许上游跳过,则不检查上游状态
if spec.allow_upstream_skip: if spec.allow_upstream_skip:
return False, None return False, None
for dep in spec.depends_on: for dep in spec.depends_on:
if dep in report.results and report.results[dep].status == TaskStatus.SKIPPED: if dep not in report.results:
return True, f"上游任务 '{dep}' 被跳过" continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
return False, None return False, None
def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None: def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
"""单次求值所有条件与 skip_if_missing,返回跳过原因或 None。 """求值所有条件,返回跳过原因或 ``None``
与旧实现不同:条件只求值一次。`should_execute()` 内部会调用所有条件, 条件接收上下文映射(硬依赖 + 软依赖结果)。
若再分支调用 `_is_cmd_available` 之外的逻辑会二次求值(如
``IS_RUNNING`` 会 spawn 两次 subprocess)。此处显式逐个求值并记录结果,
失败原因直接来自求值过程,无需二次调用。
""" """
# 1. 逐个求值条件,记录失败项。
failed_conditions: list[str] = [] failed_conditions: list[str] = []
for condition in spec.conditions: for condition in spec.conditions:
try: try:
ok = condition() ok = condition(context)
except Exception: except Exception:
ok = False ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
@@ -135,7 +136,6 @@ def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None:
if failed_conditions: if failed_conditions:
return f"条件不满足: {', '.join(failed_conditions)}" return f"条件不满足: {', '.join(failed_conditions)}"
# 2. skip_if_missing 检查(仅对 list[str] 命令有效)。
if spec.skip_if_missing and not spec._is_cmd_available(): if spec.skip_if_missing and not spec._is_cmd_available():
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown" cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
return f"命令不存在: {cmd_name}" return f"命令不存在: {cmd_name}"
@@ -148,10 +148,7 @@ def _make_skipped_result(
reason: str, reason: str,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
"""构造 SKIPPED 的 TaskResult 并发出事件、打印日志。 """构造 SKIPPED 的 TaskResult"""
sync 与 async 执行路径共用,消除重复的 result 构造/emit/print 代码。
"""
result: TaskResult[Any] = TaskResult( result: TaskResult[Any] = TaskResult(
spec=spec, spec=spec,
status=TaskStatus.SKIPPED, status=TaskStatus.SKIPPED,
@@ -165,31 +162,118 @@ def _make_skipped_result(
return result return result
def _build_context(
spec: TaskSpec[Any],
global_context: Mapping[str, Any],
report: RunReport | None = None, # noqa: ARG001
) -> dict[str, Any]:
"""构建本任务的上下文:硬依赖 + 软依赖(含默认值回退)。
硬依赖:若上游 SKIPPED/FAILED 则不注入(本任务通常也会被跳过)。
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
"""
ctx: dict[str, Any] = {}
for dep in spec.depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
for dep in spec.soft_depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
elif dep in spec.defaults:
ctx[dep] = spec.defaults[dep]
else:
ctx[dep] = None
return ctx
def _apply_cached(
name: str,
spec: TaskSpec[Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。"""
storage_key = spec.storage_key(context)
if not backend.has(storage_key):
return False
cached = backend.get(storage_key)
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
return True
def _prepare_for_execution( def _prepare_for_execution(
spec: TaskSpec[Any], spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None, report: RunReport | None,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> TaskResult[Any] | None: ) -> TaskResult[Any] | None:
"""执行前的统一预检:上游跳过 / 条件跳过。 """执行前预检:上游跳过 / 条件跳过。
Returns 返回 SKIPPED TaskResult 或 ``None``(继续执行)。
-------
TaskResult | None
若应跳过,返回已填好的 SKIPPED 结果;否则返回 None 表示继续执行。
""" """
# 上游跳过检查
should_skip, skip_reason = _check_upstream_skipped(spec, report) should_skip, skip_reason = _check_upstream_skipped(spec, report)
if should_skip: if should_skip:
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event) return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event)
# 条件 / skip_if_missing 检查(单次求值) skip_reason = _evaluate_conditions(spec, context)
skip_reason = _evaluate_skip_reason(spec)
if skip_reason is not None: if skip_reason is not None:
return _make_skipped_result(spec, skip_reason, on_event) return _make_skipped_result(spec, skip_reason, on_event)
return None return None
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
continue_on_error: bool = False,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
)
return
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的同步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
import time
time.sleep(wait)
async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的异步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
await asyncio.sleep(wait)
# ---------------------------------------------------------------------- #
# 同步执行内核
# ---------------------------------------------------------------------- #
def _run_sync_with_retry( def _run_sync_with_retry(
spec: TaskSpec[Any], spec: TaskSpec[Any],
context: Mapping[str, Any], context: Mapping[str, Any],
@@ -198,44 +282,47 @@ def _run_sync_with_retry(
report: RunReport | None = None, report: RunReport | None = None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
"""执行同步任务并带重试;返回填充好的 TaskResult。""" """执行同步任务并带重试;返回填充好的 TaskResult。"""
# 统一预检:上游跳过 / 条件跳过(条件单次求值) skipped = _prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, report, on_event)
if skipped is not None: if skipped is not None:
return skipped return skipped
result: TaskResult[Any] = TaskResult(spec=spec) result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now() result.started_at = datetime.now()
max_attempts = spec.retries + 1 max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
while True: while True:
result.attempts += 1 result.attempts += 1
try: try:
result.value = spec.effective_fn(*args, **kwargs) with spec.env_context():
result.value = spec.effective_fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now() result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result return result
except Exception as exc: except Exception as exc:
result.error = exc result.error = exc
if result.attempts >= max_attempts: if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_finalize_failure(result, layer_idx, on_event) _run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc) _log_retry(spec, result.attempts, max_attempts, exc)
raise AssertionError("unreachable") # pragma: no cover _sleep_for_retry(spec, result.attempts)
# pragma: no cover
# ---------------------------------------------------------------------- #
# 异步执行内核
# ---------------------------------------------------------------------- #
async def _execute_async_task( async def _execute_async_task(
spec: TaskSpec[Any], spec: TaskSpec[Any],
args: tuple[Any, ...], args: tuple[Any, ...],
kwargs: dict[str, Any], kwargs: dict[str, Any],
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
) -> Any: ) -> Any:
"""执行异步或同步任务(带超时处理)。 """执行异步或同步任务(带超时处理)。"""
Returns
-------
Any
任务返回值
"""
if _is_async_fn(spec): if _is_async_fn(spec):
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None: if spec.timeout is not None:
@@ -243,9 +330,10 @@ async def _execute_async_task(
else: else:
return await coro return await coro
else: else:
# 将同步工作卸载到线程,保持事件循环存活。
def fn_call() -> Any: def fn_call() -> Any:
return spec.effective_fn(*args, **kwargs) with spec.env_context():
return spec.effective_fn(*args, **kwargs)
if spec.timeout is not None: if spec.timeout is not None:
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
@@ -259,76 +347,74 @@ async def _run_async_with_retry(
layer_idx: int | None, layer_idx: int | None,
on_event: EventCallback | None = None, on_event: EventCallback | None = None,
report: RunReport | None = None, report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
"""在事件循环上执行任务(同步或异步)并带重试。""" """在事件循环上执行任务(同步或异步)并带重试。"""
# 统一预检:上游跳过 / 条件跳过(条件单次求值) skipped = _prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, report, on_event)
if skipped is not None: if skipped is not None:
return skipped return skipped
result: TaskResult[Any] = TaskResult[Any](spec=spec) if semaphore is not None:
async with semaphore:
return await _run_async_inner(spec, context, layer_idx, on_event, report)
return await _run_async_inner(spec, context, layer_idx, on_event, report)
async def _run_async_inner(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None, # noqa: ARG001
) -> TaskResult[Any]:
"""异步执行内核的内部实现(已获取 semaphore 后)。"""
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now() result.started_at = datetime.now()
max_attempts = spec.retries + 1 max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
while True: while True:
result.attempts += 1 result.attempts += 1
try: try:
result.value = await _execute_async_task(spec, args, kwargs, loop) result.value = await _execute_async_task(spec, args, kwargs, loop)
result.status = TaskStatus.SUCCESS result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now() result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result return result
except asyncio.TimeoutError: except asyncio.TimeoutError:
result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0) exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
if result.attempts >= max_attempts: result.error = exc
_finalize_failure(result, layer_idx, on_event) if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
logger.warning( logger.warning(
"task %r timed out (attempt %d/%d); retrying", "task %r timed out (attempt %d/%d); retrying",
spec.name, spec.name,
result.attempts, result.attempts,
max_attempts, max_attempts,
) )
await _async_sleep_for_retry(spec, result.attempts)
except Exception as exc: except Exception as exc:
result.error = exc result.error = exc
if result.attempts >= max_attempts: if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_finalize_failure(result, layer_idx, on_event) _run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc) _log_retry(spec, result.attempts, max_attempts, exc)
raise AssertionError("unreachable") # pragma: no cover await _async_sleep_for_retry(spec, result.attempts)
# pragma: no cover
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 层驱动 # 层执行
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _build_context( def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
spec: TaskSpec[Any], """按优先级降序排序(稳定排序)。"""
global_context: Mapping[str, Any], return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
) -> Mapping[str, Any]:
"""将全局上下文限制为本任务的依赖。"""
return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context}
def _apply_cached(
name: str,
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True;否则返回 False。
sequential / thread / async 三种层驱动共用,消除缓存命中分支的重复代码。
"""
if not backend.has(name):
return False
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
return True
def _execute_layer_sequential( def _execute_layer_sequential(
@@ -340,14 +426,16 @@ def _execute_layer_sequential(
layer_idx: int, layer_idx: int,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> None: ) -> None:
"""逐个运行某层的任务。""" """逐个运行某层的任务(按优先级排序)"""
for name in layer: for name in _sort_by_priority(layer, graph):
spec = graph.spec(name) spec = graph.resolved_spec(name)
if _apply_cached(name, graph, context, report, backend, on_event): if _apply_cached(name, spec, context, report, backend, on_event):
continue continue
result = _run_sync_with_retry(spec, _build_context(spec, context), layer_idx, on_event, report) task_ctx = _build_context(spec, context, report)
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
context[name] = result.value context[name] = result.value
backend.save(name, result.value) if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result report.results[name] = result
_emit(on_event, result) _emit(on_event, result)
@@ -361,42 +449,68 @@ def _execute_layer_threaded(
layer_idx: int, layer_idx: int,
on_event: EventCallback | None, on_event: EventCallback | None,
max_workers: int, max_workers: int,
concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
"""在线程池中并发运行某层的任务。""" """在线程池中并发运行某层的任务。"""
# 先同步满足已缓存任务。
to_run: list[str] = [] to_run: list[str] = []
for name in layer: for name in layer:
if _apply_cached(name, graph, context, report, backend, on_event): spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
continue continue
to_run.append(name) to_run.append(name)
if not to_run: if not to_run:
return return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建线程信号量
semaphores: dict[str, threading.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = threading.Semaphore(limit)
context_snapshot = dict(context)
lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
sem.acquire()
try:
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
finally:
if sem is not None:
sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {} future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {}
for name in to_run: for name in to_run:
spec = graph.spec(name) fut = pool.submit(_run_threaded_task, name)
# 为本任务快照上下文以避免竞态。
task_ctx = _build_context(spec, context)
fut = pool.submit(_run_sync_with_retry, spec, task_ctx, layer_idx, on_event, report)
future_to_name[fut] = name future_to_name[fut] = name
# 统一收集后再写 context,与 async 版本行为一致:
# 避免边完成边写共享 dict 造成的可见性不一致。
completed: dict[str, TaskResult[Any]] = {} completed: dict[str, TaskResult[Any]] = {}
try: try:
for fut in concurrent.futures.as_completed(future_to_name): for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut] name = future_to_name[fut]
result = fut.result() # 失败时抛出 TaskFailedError result = fut.result()
completed[name] = result completed[name] = result
finally: finally:
# 无论是否抛出,都先把已完成任务的结果落盘并写回 context/report。 with lock:
for name, result in completed.items(): for name, result in completed.items():
context[name] = result.value context[name] = result.value
backend.save(name, result.value) if result.status == TaskStatus.SUCCESS:
report.results[name] = result spec = graph.resolved_spec(name)
_emit(on_event, result) task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
async def _execute_layer_async( async def _execute_layer_async(
@@ -407,52 +521,122 @@ async def _execute_layer_async(
backend: StateBackend, backend: StateBackend,
layer_idx: int, layer_idx: int,
on_event: EventCallback | None, on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
"""在事件循环上并发运行某层的任务。""" """在事件循环上并发运行某层的任务。"""
to_run: list[str] = [] to_run: list[str] = []
for name in layer: for name in layer:
if _apply_cached(name, graph, context, report, backend, on_event): spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue continue
to_run.append(name) to_run.append(name)
if not to_run: if not to_run:
return return
coros = [] to_run = _sort_by_priority(to_run, graph)
for name in to_run:
spec = graph.spec(name)
task_ctx = _build_context(spec, context)
coros.append(_run_async_with_retry(spec, task_ctx, layer_idx, on_event, report))
# 为每个 concurrency_key 创建异步信号量
semaphores: dict[str, asyncio.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
context_snapshot = dict(context)
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
coros = [_run_async_task_wrapped(name) for name in to_run]
results = await asyncio.gather(*coros) results = await asyncio.gather(*coros)
for name, result in zip(to_run, results): for name, result in zip(to_run, results):
context[name] = result.value context[name] = result.value
backend.save(name, result.value) if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result report.results[name] = result
_emit(on_event, result) _emit(on_event, result)
# ---------------------------------------------------------------------- #
# 依赖驱动调度
# ---------------------------------------------------------------------- #
async def _drive_dependency_async(
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
all_names = set(graph.all_specs().keys())
semaphores: dict[str, asyncio.Semaphore] = {}
for name in all_names:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
# 等待所有硬依赖完成
for dep in spec.depends_on:
if dep in futures:
await futures[dep]
# 等待所有软依赖完成(但不检查其状态)
for dep in spec.soft_depends_on:
if dep in futures:
await futures[dep]
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name]
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
else:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
return result
loop = asyncio.get_event_loop()
for name in all_names:
futures[name] = loop.create_task(_run_task(name))
await asyncio.gather(*futures.values())
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 公共 API # 公共 API
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _make_verbose_callback( def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
on_event: EventCallback | None, """包装 on_event 回调, 在 verbose 模式下打印任务生命周期。"""
) -> EventCallback | None:
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期.
Parameters
----------
on_event : EventCallback | None
用户提供的原始回调, 若为 None 则仅打印.
Returns
-------
EventCallback | None
包装后的回调.
"""
def _verbose_callback(event: TaskEvent) -> None: def _verbose_callback(event: TaskEvent) -> None:
# 先打印生命周期信息
dur = f" ({event.duration:.3f}s)" if event.duration is not None else "" dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
if event.status == TaskStatus.RUNNING: # pragma: no cover if event.status == TaskStatus.RUNNING: # pragma: no cover
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True) print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
@@ -464,13 +648,9 @@ def _make_verbose_callback(
f"[verbose] 任务 {event.task!r} 失败{dur} (尝试 {event.attempts} 次){err}", f"[verbose] 任务 {event.task!r} 失败{dur} (尝试 {event.attempts} 次){err}",
flush=True, flush=True,
) )
elif event.status == TaskStatus.SKIPPED: # pragma: no branch elif event.status == TaskStatus.SKIPPED:
reason = f" ({event.reason})" if event.reason else "" reason = f" ({event.reason})" if event.reason else ""
print(f"[verbose] 任务 {event.task!r} 跳过{reason}", flush=True) print(f"[verbose] 任务 {event.task!r} 跳过{reason}", flush=True)
else: # pragma: no cover
# 不可达: 执行器只发出 RUNNING/SUCCESS/FAILED/SKIPPED 事件
pass
# 再调用用户回调
if on_event is not None: if on_event is not None:
on_event(event) on_event(event)
@@ -486,6 +666,7 @@ def run(
verbose: bool = False, verbose: bool = False,
on_event: EventCallback | None = None, on_event: EventCallback | None = None,
state: StateBackend | None = None, state: StateBackend | None = None,
concurrency_limits: Mapping[str, int] | None = None,
) -> RunReport: ) -> RunReport:
"""执行图并返回 :class:`RunReport`。 """执行图并返回 :class:`RunReport`。
@@ -494,29 +675,28 @@ def run(
graph: graph:
待执行的已校验 :class:`Graph`。 待执行的已校验 :class:`Graph`。
strategy: strategy:
执行策略, 接受 :class:`Strategy` 枚举成员或字符串 执行策略: ``"sequential"`` / ``"thread"`` / ``"async"`` /
(``"sequential"`` / ``"thread"`` / ``"async"``). 默认 ``Strategy.SEQUENTIAL``. ``"dependency"````"dependency"`` 为依赖驱动调度,无层屏障。
max_workers: max_workers:
``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。 ``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。
dry_run: dry_run:
若为 ``True``,打印执行计划(层 + 注入)并返回空报告,不执行 若为 ``True``,打印执行计划并返回空报告,不执行任务。
任何任务。
verbose: verbose:
若为 ``True``, 打印任务生命周期 (开始/成功/失败/跳过) 到 stdout. 若为 ``True``, 打印任务生命周期到 stdout
注意: subprocess 命令的输出由 :class:`TaskSpec` 的 ``verbose`` 字段控制.
on_event: on_event:
可选回调,在每次状态转换时调用。 可选回调,在每次状态转换时调用。
state: state:
可选 :class:`StateBackend`,用于断点续跑。默认为内存后端 可选 :class:`StateBackend`,用于断点续跑。
(不跨进程持久化)。 concurrency_limits:
``{concurrency_key: max_concurrent}`` 映射。具有相同
``concurrency_key`` 的任务共享信号量,限制同时运行实例数。
抛出 抛出
---- ----
ValueError ValueError
``strategy`` 不被识别时。 ``strategy`` 不被识别时。
TaskFailedError TaskFailedError
任何任务耗尽重试后仍失败时。运行在失败层中止;后续层的任务 任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
不会被执行。
""" """
graph.validate() graph.validate()
layers = graph.layers() layers = graph.layers()
@@ -525,20 +705,23 @@ def run(
_print_dry_run(graph, layers) _print_dry_run(graph, layers)
return RunReport(success=True) return RunReport(success=True)
# verbose 模式下包装事件回调
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
backend = resolve_backend(state) backend = resolve_backend(state)
report = RunReport() report = RunReport()
context: dict[str, Any] = {} context: dict[str, Any] = {}
limits = concurrency_limits or {}
try: try:
if strategy == "sequential": if strategy == "sequential":
_drive_sequential(graph, layers, context, report, backend, effective_callback) _drive_sequential(graph, layers, context, report, backend, effective_callback)
elif strategy == "thread": elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers) _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
_drive_async(graph, layers, context, report, backend, effective_callback, limits)
elif strategy == "dependency":
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits))
else: else:
_drive_async(graph, layers, context, report, backend, effective_callback) raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError: except TaskFailedError:
report.success = False report.success = False
raise raise
@@ -552,7 +735,7 @@ def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
print(f" Layer {idx}: {layer}") print(f" Layer {idx}: {layer}")
for name in layer: for name in layer:
print(f" - {describe_injection(graph.spec(name))}") print(f" - {describe_injection(graph.resolved_spec(name))}")
def _drive_sequential( def _drive_sequential(
@@ -575,10 +758,11 @@ def _drive_threaded(
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
max_workers: int | None, max_workers: int | None,
concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer))) workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers) _execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
def _drive_async( def _drive_async(
@@ -588,8 +772,9 @@ def _drive_async(
report: RunReport, report: RunReport,
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event)) asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
async def _async_drive( async def _async_drive(
@@ -599,6 +784,7 @@ async def _async_drive(
report: RunReport, report: RunReport,
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event) await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
+202 -140
View File
@@ -2,28 +2,53 @@
使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8 使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。 进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
支持:
* 图级默认值 :class:`GraphDefaults`TaskSpec 字段为 ``None`` 时回退。
* :meth:`Graph.map` 工厂批量生成 fan-out 任务。
* 字符串引用与 :func:`compose` 编程式组合多个图。
* 软依赖:仅用于上下文注入,不参与拓扑分层。
""" """
from __future__ import annotations from __future__ import annotations
import sys import sys
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import Any, Iterable, Mapping, Sequence from typing import Any, Callable, Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec from .task import RetryPolicy, TaskSpec
# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。
if sys.version_info >= (3, 9): # pragma: no cover if sys.version_info >= (3, 9): # pragma: no cover
import graphlib # pyright: ignore[reportUnreachable] import graphlib # pyright: ignore[reportUnreachable]
_TopologicalSorter = graphlib.TopologicalSorter _TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover else: # pragma: no cover
import graphlib # type: ignore[import-untyped] # pragma: no cover import graphlib # type: ignore[import-untyped]
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover _TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
@dataclass
class GraphDefaults:
"""图级默认值。TaskSpec 对应字段为 ``None`` 时回退到此处。
仅对可空字段生效(retry/timeout/strategy/env/cwd/tags/priority/
continue_on_error/concurrency_key)。非空字段(name/fn/cmd)不回退。
"""
retry: RetryPolicy | None = None
timeout: float | None = None
strategy: str | None = None
tags: tuple[str, ...] = ()
env: Mapping[str, str] | None = None
cwd: Any = None # Path | None
priority: int = 0
continue_on_error: bool = False
concurrency_key: str | None = None
verbose: bool = False
@dataclass @dataclass
class Graph: class Graph:
"""校验后的有向无环任务图。 """校验后的有向无环任务图。
@@ -34,16 +59,11 @@ class Graph:
图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。 图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。
这使图可安全重复运行并在线程间共享。 这使图可安全重复运行并在线程间共享。
Note
-----
Graph 不再使用 ``frozen=True``:内部 ``specs``/``deps`` 本就是可变 dict
frozen 既无法真正保证不可变,又迫使 ``_pending_refs`` 等场景用
``object.__setattr__`` 绕过。改为普通 dataclass,让赋值显式且可审计。
""" """
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict) specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict) deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
defaults: GraphDefaults = field(default_factory=GraphDefaults)
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。 # 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
_pending_refs: list[str] = field(default_factory=list) _pending_refs: list[str] = field(default_factory=list)
@@ -51,69 +71,47 @@ class Graph:
# 构建 # 构建
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[Any]) -> Graph: def add(self, spec: TaskSpec[Any]) -> Graph:
"""注册一个任务 spec,并即时校验。 """注册一个任务 spec,并即时校验。返回 ``self`` 支持链式调用。"""
self._register(spec)
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
它会整批校验(允许单次调用中的前向引用)。
"""
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self.specs[spec.name] = spec
self.deps[spec.name] = spec.depends_on
# 为增量 API 即时检查重名与缺失依赖。
self._validate_references() self._validate_references()
return self return self
def _register(self, spec: TaskSpec[Any]) -> None:
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self.specs[spec.name] = spec
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
self.deps[spec.name] = spec.depends_on
@classmethod @classmethod
def from_specs(cls, specs: Iterable[TaskSpec[Any] | str]) -> Graph: def from_specs(
"""从可迭代的 task spec 构建图. cls,
specs: Iterable[TaskSpec[Any] | str],
defaults: GraphDefaults | None = None,
) -> Graph:
"""从可迭代的 task spec 构建图。
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的 先收集所有 spec,再统一校验。允许前向引用。支持字符串引用,
依赖——顺序无关,就像声明式配置文件的读取方式 由 :func:`compose` 或 :class:`GraphComposer` 解析展开
支持字符串引用,允许引用其他命令图中的任务。
字符串引用将在CliRunner中解析展开。
Parameters Parameters
---------- ----------
specs : Iterable[TaskSpec[Any] | str] specs:
TaskSpec对象或字符串引用的列表 TaskSpec 对象或字符串引用的列表
defaults:
Returns 图级默认值。``None`` 使用空 :class:`GraphDefaults`。
-------
Graph
构建完成的图
Note
-----
字符串引用格式:
- "command_name" - 引用整个命令图
- "command_name.task_name" - 引用特定任务
Examples
--------
>>> graph = Graph.from_specs([
... TaskSpec("build", cmd=["uv", "build"]),
... "test", # 引用test命令图
... ])
""" """
graph = cls() graph = cls(defaults=defaults or GraphDefaults())
pending_refs: list[str] = [] pending_refs: list[str] = []
for spec in specs: for spec in specs:
if isinstance(spec, str): if isinstance(spec, str):
# 字符串引用,稍后解析
pending_refs.append(spec) pending_refs.append(spec)
elif isinstance(spec, TaskSpec): elif isinstance(spec, TaskSpec):
if spec.name in graph.specs: graph._register(spec)
raise DuplicateTaskError(spec.name)
graph.specs[spec.name] = spec
graph.deps[spec.name] = spec.depends_on
else: else:
raise TypeError(f"from_specs只接受TaskSpecstr,收到: {type(spec)}") raise TypeError(f"from_specs 只接受 TaskSpecstr,收到: {type(spec)}")
# 存储待解析的引用,稍后由 GraphComposer 解析展开。
# Graph 不再 frozen,可直接赋值;保留属性名以保持向后兼容。
if pending_refs: if pending_refs:
graph._pending_refs = pending_refs graph._pending_refs = pending_refs
@@ -125,26 +123,22 @@ class Graph:
# 校验 # 校验
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def _validate_references(self) -> None: def _validate_references(self) -> None:
"""确保每个依赖名都存在于图中。""" """确保每个依赖名都存在于图中。硬依赖与软依赖都校验。"""
for name, deps in self.deps.items(): for name, spec in self.specs.items():
for dep in deps: for dep in spec.depends_on:
if dep not in self.specs:
raise MissingDependencyError(name, dep)
for dep in spec.soft_depends_on:
if dep not in self.specs: if dep not in self.specs:
raise MissingDependencyError(name, dep) raise MissingDependencyError(name, dep)
def validate(self) -> None: def validate(self) -> None:
"""执行完整 DAG 校验。 """执行完整 DAG 校验。存在环时抛出 :class:`CycleError`。"""
存在环时抛出 :class:`~pyflowx.errors.CycleError`。
依赖存在性由 :meth:`_validate_references` 检查。
"""
self._validate_references() self._validate_references()
sorter = _TopologicalSorter(self.deps) sorter = _TopologicalSorter(self.deps)
try: try:
# prepare() 在有环时抛出 CycleError;此处不需要
# static_order() 的结果,仅利用其校验副作用。
sorter.prepare() sorter.prepare()
except graphlib.CycleError as exc: except graphlib.CycleError as exc: # type: ignore[name-defined]
# exc.args[1] 是构成环的节点列表。
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else [] cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc raise CycleError(list(cycle)) from exc
@@ -160,10 +154,49 @@ class Graph:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。""" """返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self.specs[name] return self.specs[name]
def resolved_spec(self, name: str) -> TaskSpec[Any]:
"""返回应用图级默认值后的 spec(不修改原图)。
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
字段,若 spec 字段为默认空值且图级默认值非空,则用
:func:`dataclasses.replace` 生成带默认值的副本。
"""
spec = self.specs[name]
d = self.defaults
overrides: dict[str, Any] = {}
if spec.retry == RetryPolicy() and d.retry is not None:
overrides["retry"] = d.retry
if spec.timeout is None and d.timeout is not None:
overrides["timeout"] = d.timeout
if spec.strategy is None and d.strategy is not None:
overrides["strategy"] = d.strategy
if spec.env is None and d.env is not None:
overrides["env"] = d.env
if spec.cwd is None and d.cwd is not None:
overrides["cwd"] = d.cwd
if spec.priority == 0 and d.priority != 0:
overrides["priority"] = d.priority
if not spec.continue_on_error and d.continue_on_error:
overrides["continue_on_error"] = True
if spec.concurrency_key is None and d.concurrency_key is not None:
overrides["concurrency_key"] = d.concurrency_key
if not spec.verbose and d.verbose:
overrides["verbose"] = True
if not spec.tags and d.tags:
overrides["tags"] = d.tags
if not overrides:
return spec
return replace(spec, **overrides)
def dependencies(self, name: str) -> tuple[str, ...]: def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。""" """``name`` 的直接硬依赖前驱。"""
return self.deps[name] return self.deps[name]
def all_deps(self, name: str) -> tuple[str, ...]:
"""``name`` 的硬依赖 + 软依赖。"""
spec = self.specs[name]
return tuple(spec.depends_on) + tuple(spec.soft_depends_on)
def all_specs(self) -> Mapping[str, TaskSpec[Any]]: def all_specs(self) -> Mapping[str, TaskSpec[Any]]:
"""name -> spec 的只读视图。""" """name -> spec 的只读视图。"""
return self.specs return self.specs
@@ -171,18 +204,15 @@ class Graph:
def layers(self) -> list[list[str]]: def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。 """将任务分组为可并行执行的层(Kahn 算法)。
同层任务无相互依赖,可并发执行。层按执行顺序返回 同层任务无相互依赖,可并发执行。软依赖不参与分层
层按执行顺序返回。图有环时抛出 :class:`CycleError`。
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
""" """
self.validate() self.validate()
sorter = _TopologicalSorter(self.deps) sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = [] result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare() sorter.prepare()
while sorter.is_active(): while sorter.is_active():
ready = list(sorter.get_ready()) ready = list(sorter.get_ready())
# 排序以保证确定性、可复现的执行计划。
ready.sort() ready.sort()
result.append(ready) result.append(ready)
for node in ready: for node in ready:
@@ -193,12 +223,7 @@ class Graph:
# 子图 / 标签过滤 # 子图 / 标签过滤
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> Graph: def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。 """返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
DAG 的切片。
"""
wanted: set[str] = set(tags) wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = [] kept: list[TaskSpec[Any]] = []
for spec in self.specs.values(): for spec in self.specs.values():
@@ -206,10 +231,11 @@ class Graph:
pruned_deps = tuple( pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
) )
# 使用 replace 保留所有字段(verbose/skip_if_missing/allow_upstream_skip 等), pruned_soft = tuple(
# 避免手动逐字段重建时遗漏新增字段。 d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
kept.append(replace(spec, depends_on=pruned_deps)) )
return Graph.from_specs(kept) kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults)
def subgraph_by_names(self, names: Iterable[str]) -> Graph: def subgraph_by_names(self, names: Iterable[str]) -> Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。""" """返回限定于 ``names`` 的新图(边已修剪)。"""
@@ -221,18 +247,71 @@ class Graph:
for spec in self.specs.values(): for spec in self.specs.values():
if spec.name in wanted: if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted) pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps)) pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
return Graph.from_specs(kept) kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults)
# ------------------------------------------------------------------ #
# Fan-out / map-reduce
# ------------------------------------------------------------------ #
def map(
self,
name_fn: Callable[[int], str],
spec: TaskSpec[Any],
items: Sequence[Any],
arg_factory: Callable[[Any], tuple[Any, ...]] | None = None,
depends_on_per: Callable[[int], tuple[str, ...]] | None = None,
) -> list[TaskSpec[Any]]:
"""为 ``items`` 中每个元素生成一个 TaskSpec 并加入图。
用于 fan-out / map-reduce 模式。返回生成的 spec 列表,便于
后续 reduce 任务依赖。
Parameters
----------
name_fn:
接受索引 ``i``,返回任务名。需保证唯一。
spec:
模板 spec。其 ``name`` 与 ``args`` 会被覆盖。
items:
待分发的数据序列。
arg_factory:
接受一个 item,返回位置参数元组,覆盖 spec.args。
``None`` 则将单个 item 作为唯一位置参数。
depends_on_per:
接受索引 ``i``,返回该任务的额外硬依赖。``None`` 则继承 spec.depends_on。
Returns
-------
list[TaskSpec]
生成的 spec 列表(已加入图)。
Examples
--------
>>> fetch_tmpl = px.TaskSpec("", fn=fetch_user)
>>> specs = graph.map(lambda i: f"fetch_{i}", fetch_tmpl, [1, 2, 3])
>>> reduce_spec = px.TaskSpec("reduce", fn=reduce_fn, depends_on=tuple(s.name for s in specs))
"""
generated: list[TaskSpec[Any]] = []
for i, item in enumerate(items):
name = name_fn(i)
args = arg_factory(item) if arg_factory is not None else (item,)
extra_deps = depends_on_per(i) if depends_on_per is not None else ()
new_spec = replace(
spec,
name=name,
args=tuple(args),
depends_on=tuple(spec.depends_on) + tuple(extra_deps),
)
self.add(new_spec)
generated.append(new_spec)
return generated
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# 可视化 # 可视化
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str: def to_mermaid(self, orientation: str = "TD") -> str:
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。 """将 DAG 渲染为 Mermaid ``graph`` 定义字符串。"""
无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览
渲染,或保存为文件。
"""
valid = {"TD", "TB", "BT", "LR", "RL"} valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper() orientation = orientation.upper()
if orientation not in valid: if orientation not in valid:
@@ -243,6 +322,10 @@ class Graph:
for name, deps in self.deps.items(): for name, deps in self.deps.items():
for dep in deps: for dep in deps:
lines.append(f" {dep} --> {name}") lines.append(f" {dep} --> {name}")
# 软依赖用虚线
for name, spec in self.specs.items():
for dep in spec.soft_depends_on:
lines.append(f" {dep} -.-> {name}")
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
@@ -268,19 +351,12 @@ class Graph:
class GraphComposer: class GraphComposer:
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。 """将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
从 ``CliRunner`` 抽出,使 ``Graph``(数据)与引用解析(组合逻辑) 引用格式:
职责分离。引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
引用格式
--------
* ``"command_name"`` —— 引用整个命令图。 * ``"command_name"`` —— 引用整个命令图。
* ``"command_name.task_name"`` —— 引用特定任务。 * ``"command_name.task_name"`` —— 引用特定任务。
Parameters 引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
---------- 原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
graphs : dict[str, Graph]
命令名到图的映射,引用据此解析。
""" """
def __init__(self, graphs: dict[str, Graph]) -> None: def __init__(self, graphs: dict[str, Graph]) -> None:
@@ -294,18 +370,7 @@ class GraphComposer:
return resolved return resolved
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph: def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
"""展开图中的字符串引用。 """展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
若图无 ``_pending_refs``,原样返回。
Note
-----
引用按顺序展开,后续引用的任务依赖于前面引用的任务完成。
例如 ``["c", "tc", bump]`` 展开为:
- c 的所有任务(无依赖)
- tc 的所有任务(依赖于 c 的最后一个任务)
- bump 任务(依赖于 tc 的最后一个任务)
"""
pending_refs = graph._pending_refs pending_refs = graph._pending_refs
if not pending_refs: if not pending_refs:
return graph return graph
@@ -313,23 +378,16 @@ class GraphComposer:
all_specs: list[TaskSpec[Any]] = [] all_specs: list[TaskSpec[Any]] = []
previous_ref_last_task: str | None = None previous_ref_last_task: str | None = None
# 先解析每个引用,并建立依赖链。
for ref in pending_refs: for ref in pending_refs:
expanded_specs = self.parse_ref(ref, current_cmd) expanded_specs = self.parse_ref(ref, current_cmd)
# 若有前一个引用,让当前引用的任务依赖其最后一个任务。
if previous_ref_last_task and expanded_specs: if previous_ref_last_task and expanded_specs:
for i, task in enumerate(expanded_specs): for i, task in enumerate(expanded_specs):
# 只为没有依赖的任务(或第一个任务)添加依赖。
if i == 0 or not task.depends_on: if i == 0 or not task.depends_on:
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task})) expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
if expanded_specs: if expanded_specs:
previous_ref_last_task = expanded_specs[-1].name previous_ref_last_task = expanded_specs[-1].name
all_specs.extend(expanded_specs) all_specs.extend(expanded_specs)
# 然后添加原始 TaskSpec,按出现顺序串行依赖。
original_specs = list(graph.all_specs().values()) original_specs = list(graph.all_specs().values())
if original_specs: if original_specs:
if previous_ref_last_task: if previous_ref_last_task:
@@ -337,49 +395,53 @@ class GraphComposer:
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task}))) all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
else: else:
all_specs.append(original_specs[0]) all_specs.append(original_specs[0])
for i in range(1, len(original_specs)): for i in range(1, len(original_specs)):
current_task = original_specs[i] current_task = original_specs[i]
previous_task_name = original_specs[i - 1].name previous_task_name = original_specs[i - 1].name
all_specs.append( all_specs.append(
replace( replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
current_task,
depends_on=tuple({*current_task.depends_on, previous_task_name}),
)
) )
return Graph.from_specs(all_specs) return Graph.from_specs(all_specs, defaults=graph.defaults)
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]: def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
"""解析单个字符串引用,返回对应的 TaskSpec 列表。 """解析单个字符串引用,返回对应的 TaskSpec 列表。"""
Raises
------
ValueError
引用无效、目标命令/任务不存在,或检测到循环引用。
"""
# 避免循环引用。
if ref == current_cmd: if ref == current_cmd:
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己") raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
if "." in ref: if "." in ref:
# 特定任务引用: "command_name.task_name"
cmd_name, task_name = ref.split(".", 1) cmd_name, task_name = ref.split(".", 1)
if cmd_name not in self.graphs: if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在") raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name] ref_graph = self.graphs[cmd_name]
if task_name not in ref_graph.all_specs(): if task_name not in ref_graph.all_specs():
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}'") raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}'")
return [ref_graph.all_specs()[task_name]] return [ref_graph.all_specs()[task_name]]
else: else:
# 整个命令图引用: "command_name"
cmd_name = ref cmd_name = ref
if cmd_name not in self.graphs: if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在") raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name] ref_graph = self.graphs[cmd_name]
# 递归展开(若引用的图自身也含引用)。
ref_graph = self.expand_refs(ref_graph, cmd_name) ref_graph = self.expand_refs(ref_graph, cmd_name)
return list(ref_graph.all_specs().values()) return list(ref_graph.all_specs().values())
def compose(
graphs: dict[str, Graph],
) -> dict[str, Graph]:
"""编程式解析多图的字符串引用,返回展开后的新图映射。
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
Examples
--------
>>> graphs = {
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
... }
>>> resolved = px.compose(graphs)
>>> "b" in resolved["all"].all_specs()
True
"""
return GraphComposer(graphs).resolve_all()
+80 -41
View File
@@ -4,20 +4,18 @@
执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其 执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其
存储值注入下游任务。 存储值注入下游任务。
本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储 存储键由 :meth:`TaskSpec.storage_key` 计算,默认为任务名;若任务配置
形态为扁平的 ``{task_name: result}`` 映射。内置两个后端: 了 ``cache_key``,则键为 ``"name:cache_key_value"``,使不同输入产生
独立缓存条目。
* :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认 支持 TTL``has`` 在条目过期时返回 ``False``
* :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。
两者均零依赖(``json`` 为标准库)。用户可子类化
:class:`StateBackend` 接入 SQLite、Redis 等。
""" """
from __future__ import annotations from __future__ import annotations
import json import json
import sys import sys
import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import Any, Mapping from typing import Any, Mapping
@@ -31,23 +29,26 @@ from .errors import StorageError
class StateBackend(ABC): class StateBackend(ABC):
"""可续跑状态存储的抽象基类。""" """可续跑状态存储的抽象基类。
所有方法以 ``key`` 为参数(通常为任务名或 ``name:cache_key``)。
"""
@abstractmethod @abstractmethod
def load(self) -> Mapping[str, Any]: def load(self) -> Mapping[str, Any]:
"""返回完整的存储映射(可能为空)。""" """返回完整的存储映射(可能为空)。"""
@abstractmethod @abstractmethod
def save(self, name: str, value: Any) -> None: def save(self, key: str, value: Any) -> None:
"""持久化单个任务的成功结果。""" """持久化单个任务的成功结果。"""
@abstractmethod @abstractmethod
def has(self, name: str) -> bool: def has(self, key: str) -> bool:
"""``name`` 是否已有存储结果。""" """``key`` 是否已有未过期的存储结果。"""
@abstractmethod @abstractmethod
def get(self, name: str) -> Any: def get(self, key: str) -> Any:
"""返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。""" """返回 ``key`` 的存储结果(不存在则抛 ``KeyError``)。"""
@abstractmethod @abstractmethod
def clear(self) -> None: def clear(self) -> None:
@@ -55,43 +56,66 @@ class StateBackend(ABC):
class MemoryBackend(StateBackend): class MemoryBackend(StateBackend):
"""进程内 dict 后端。进程退出即丢失。""" """进程内 dict 后端。进程退出即丢失。
def __init__(self) -> None: Parameters
self._store: dict[str, Any] = {} ----------
ttl:
条目存活秒数。``None`` 表示永不过期。``has`` 在条目超过 ttl 后
返回 ``False``(但不主动删除,下次 ``save`` 覆盖)。
"""
def __init__(self, ttl: float | None = None) -> None:
self._store: dict[str, tuple[Any, float]] = {}
self._ttl = ttl
@override @override
def load(self) -> Mapping[str, Any]: def load(self) -> Mapping[str, Any]:
return dict(self._store) return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)}
@override @override
def save(self, name: str, value: Any) -> None: def save(self, key: str, value: Any) -> None:
self._store[name] = value self._store[key] = (value, time.monotonic())
@override @override
def has(self, name: str) -> bool: def has(self, key: str) -> bool:
return name in self._store return key in self._store and not self._expired(key)
@override @override
def get(self, name: str) -> Any: def get(self, key: str) -> Any:
return self._store[name] if key not in self._store or self._expired(key):
raise KeyError(key)
return self._store[key][0]
@override @override
def clear(self) -> None: def clear(self) -> None:
self._store.clear() self._store.clear()
def _expired(self, key: str) -> bool:
if self._ttl is None or key not in self._store:
return False
_value, ts = self._store[key]
return (time.monotonic() - ts) > self._ttl
class JSONBackend(StateBackend): class JSONBackend(StateBackend):
"""基于文件的 JSON 存储,用于跨进程续跑。 """基于文件的 JSON 存储,用于跨进程续跑。
结果必须可 JSON 序列化。不可序列化的值会抛出 存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
:class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条 ``ts`` 用于 TTL 判断。结果必须可 JSON 序列化。
结果的持久化失败)。
Parameters
----------
path:
JSON 文件路径。
ttl:
条目存活秒数。``None`` 表示永不过期。
""" """
def __init__(self, path: str) -> None: def __init__(self, path: str, ttl: float | None = None) -> None:
self._path: str = path self._path: str = path
self._store: dict[str, Any] = {} self._ttl = ttl
self._store: dict[str, dict[str, Any]] = {}
self._load() self._load()
def _load(self) -> None: def _load(self) -> None:
@@ -101,7 +125,14 @@ class JSONBackend(StateBackend):
with open(self._path, encoding="utf-8") as fh: with open(self._path, encoding="utf-8") as fh:
data: Any = json.load(fh) data: Any = json.load(fh)
if isinstance(data, dict): if isinstance(data, dict):
self._store = data # 兼容纯值格式与带元数据格式
self._store = {}
for k, v in data.items():
if isinstance(v, dict) and "value" in v and "ts" in v:
self._store[k] = v
else:
# 旧格式:纯值
self._store[k] = {"value": v, "ts": time.time()}
except (OSError, json.JSONDecodeError) as exc: except (OSError, json.JSONDecodeError) as exc:
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
@@ -110,32 +141,40 @@ class JSONBackend(StateBackend):
try: try:
with open(tmp, "w", encoding="utf-8") as fh: with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2) json.dump(self._store, fh, ensure_ascii=False, indent=2)
_ = Path(tmp).replace(Path(self._path)) _ = Path(tmp).replace(Path(self._path))
except (OSError, TypeError) as exc: except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
@override def _now(self) -> float:
def load(self) -> Mapping[str, Any]: return time.time()
return dict(self._store)
def _expired(self, entry: dict[str, Any]) -> bool:
if self._ttl is None:
return False
return (self._now() - float(entry.get("ts", 0))) > self._ttl
@override @override
def save(self, name: str, value: Any) -> None: def load(self) -> Mapping[str, Any]:
# 在修改内存状态前先校验可序列化性。 return {k: v["value"] for k, v in self._store.items() if not self._expired(v)}
@override
def save(self, key: str, value: Any) -> None:
try: try:
_ = json.dumps(value) _ = json.dumps(value)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise StorageError(f"result of task {name!r} is not JSON-serialisable", exc) from exc raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
self._store[name] = value self._store[key] = {"value": value, "ts": self._now()}
self._flush() self._flush()
@override @override
def has(self, name: str) -> bool: def has(self, key: str) -> bool:
return name in self._store return key in self._store and not self._expired(self._store[key])
@override @override
def get(self, name: str) -> Any: def get(self, key: str) -> Any:
return self._store[name] if key not in self._store or self._expired(self._store[key]):
raise KeyError(key)
return self._store[key]["value"]
@override @override
def clear(self) -> None: def clear(self) -> None:
+297 -119
View File
@@ -15,9 +15,11 @@
* ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。 * ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。
""" """
import os
import shutil import shutil
import subprocess import subprocess
import sys import sys
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -25,8 +27,10 @@ from pathlib import Path
from typing import ( from typing import (
Any, Any,
Callable, Callable,
ContextManager,
Coroutine, Coroutine,
Generic, Generic,
Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
@@ -59,8 +63,95 @@ TaskCmd = Union[
Callable[..., Any], # Python 函数 Callable[..., Any], # Python 函数
] ]
# 条件判断函数类型 # 执行策略:sequential/thread/async 为层屏障模型,dependency 为依赖驱动模型。
Condition = Callable[[], bool] Strategy = Union[str, "StrategyKind"]
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
Condition = Callable[[Context], bool]
# 缓存键计算函数:基于依赖上下文计算稳定字符串键。
CacheKeyFn = Callable[[Context], str]
# ---------------------------------------------------------------------- #
# 重试策略
# ---------------------------------------------------------------------- #
@dataclass(frozen=True)
class RetryPolicy:
"""任务失败重试策略。
参数
----
max_attempts:
最大尝试次数(含首次)。``1`` 表示仅尝试一次,不重试。
delay:
两次尝试之间的初始等待秒数。
backoff:
退避倍率。第 n 次重试等待 ``delay * backoff ** (n-1)``。
jitter:
抖动上限秒数。每次等待加上 ``[0, jitter)`` 的随机量,避免惊群。
retry_on:
仅对这些异常类型重试。默认 ``(Exception,)`` 重试所有异常。
传入空元组等价于不重试。
Note
-----
替代旧版 ``retries: int``。``retries=2`` 等价于
``RetryPolicy(max_attempts=3)``。
"""
max_attempts: int = 1
delay: float = 0.0
backoff: float = 1.0
jitter: float = 0.0
retry_on: Tuple[type[BaseException], ...] = (Exception,)
def __post_init__(self) -> None:
if self.max_attempts < 1:
raise ValueError(f"RetryPolicy.max_attempts must be >= 1, got {self.max_attempts}.")
if self.delay < 0:
raise ValueError(f"RetryPolicy.delay must be >= 0, got {self.delay}.")
if self.backoff < 0:
raise ValueError(f"RetryPolicy.backoff must be >= 0, got {self.backoff}.")
if self.jitter < 0:
raise ValueError(f"RetryPolicy.jitter must be >= 0, got {self.jitter}.")
@property
def retries(self) -> int:
"""重试次数(不含首次),等价于 ``max_attempts - 1``。"""
return self.max_attempts - 1
def should_retry(self, exc: BaseException) -> bool:
"""异常是否属于可重试类型。"""
return isinstance(exc, self.retry_on)
def wait_seconds(self, attempt: int) -> float:
"""第 ``attempt`` 次失败后应等待的秒数(attempt 从 1 开始)。"""
if attempt < 1:
return 0.0
import random
base = self.delay * (self.backoff ** max(0, attempt - 1))
jitter = random.uniform(0, self.jitter) if self.jitter > 0 else 0.0
return base + jitter
# ---------------------------------------------------------------------- #
# 任务钩子
# ---------------------------------------------------------------------- #
@dataclass(frozen=True)
class TaskHooks:
"""任务生命周期钩子。
所有钩子均为可选。``pre_run`` 在任务实际执行前调用;``post_run``
在成功后调用并接收返回值;``on_failure`` 在最终失败后调用并接收异常。
钩子异常不会影响任务状态,仅记录日志。
"""
pre_run: Optional[Callable[["TaskSpec[Any]"], None]] = None
post_run: Optional[Callable[["TaskSpec[Any]", Any], None]] = None
on_failure: Optional[Callable[["TaskSpec[Any]", BaseException], None]] = None
class TaskStatus(Enum): class TaskStatus(Enum):
@@ -90,181 +181,239 @@ class TaskSpec(Generic[T]):
- ``list[str]``: 命令及参数列表,如 ``["ls", "-la"]`` - ``list[str]``: 命令及参数列表,如 ``["ls", "-la"]``
- ``str``: shell 命令字符串,如 ``"pip freeze > requirements.txt"`` - ``str``: shell 命令字符串,如 ``"pip freeze > requirements.txt"``
- ``Callable``: Python 函数,与 ``fn`` 参数等效 - ``Callable``: Python 函数,与 ``fn`` 参数等效
若提供此参数,会自动包装为执行函数,覆盖 ``fn`` 参数。
depends_on: depends_on:
必须先完成才运行本任务的任务名列表。顺序无关;框架会做 硬依赖任务名。必须全部成功完成才运行本任务
拓扑排序。 上游被 SKIPPED 时,本任务也会被 SKIPPED(除非
``allow_upstream_skip=True``)。
soft_depends_on:
软依赖任务名。会等待其完成,但其结果不影响本任务是否执行:
- 上游成功:注入其返回值
- 上游 SKIPPED 或失败:注入 :attr:`defaults` 中提供的默认值
适用于"可选输入"场景。
defaults:
软依赖的默认值映射 ``{dep_name: default_value}``。
软依赖未提供结果时使用。未在 defaults 中出现的软依赖默认为 ``None``。
args: args:
静态位置参数,追加在注入参数*之后*。适用于参数化任务 静态位置参数,追加在注入参数*之后*。
(如 ``fetch_user(uid)``)。
kwargs: kwargs:
静态关键字参数。若与注入名冲突则抛出 静态关键字参数。若与注入名冲突则抛出
:class:`~pyflowx.errors.InjectionError`。 :class:`~pyflowx.errors.InjectionError`。
retries: retry:
失败后的重试次数。``0`` 表示仅尝试一次。 :class:`RetryPolicy` 重试策略。默认仅尝试一次。
timeout: timeout:
最大执行时长(秒)。``None`` 表示不限制。异步任务使用 最大执行时长(秒)。``None`` 表示不限制。异步任务使用
:func:`asyncio.wait_for`线程/异步执行器中的同步任务会 :func:`asyncio.wait_for`同步任务通过线程 future 取消。
取消 worker future。
tags: tags:
自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试 自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试
也可用于并发限制分组。
conditions: conditions:
条件判断函数列表,只有所有条件都返回 ``True`` 时才执行任务。 条件判断函数列表,接收依赖上下文,全部返回 ``True`` 时才执行任务。
任一条件返回 ``False``任务被标记为 SKIPPED。 任一返回 ``False``任务被标记为 SKIPPED。
用于平台判断、环境变量检查等场景。
cwd: cwd:
命令执行的工作目录,仅在使用 ``cmd`` 参数时有效。 工作目录。对 ``cmd`` 任务作为子进程工作目录;对 ``fn`` 任务
``None`` 表示当前目录。 通过临时切换当前目录生效
env:
环境变量覆盖映射。对 ``cmd`` 任务合并到子进程环境;对 ``fn``
任务在执行期间临时设置。
verbose: verbose:
是否在命令执行时显示详细输出。``True`` 时打印执行的命令 是否打印详细输出。``True`` 时打印执行的命令、返回码与输出
及其标准输出/标准错误。仅在使用 ``cmd`` 参数时有效 (仅 ``cmd``),以及任务生命周期
``False`` 时静默捕获输出(失败时仍会包含在错误信息中)。
skip_if_missing: skip_if_missing:
仅对 ``cmd`` 为 ``list[str]`` 的任务有效。``True`` 时自动检查 仅对 ``cmd`` 为 ``list[str]`` 有效。``True`` 时通过
命令是否存在(通过 :func:`shutil.which`,不存在则跳过任务 :func:`shutil.which` 检查命令是否存在,不存在则跳过
(标记为 SKIPPED)而非失败。适用于构建工具场景,避免因未安装
某些工具(如 maturin、tox)而导致整个图执行失败。
对于 ``str`` (shell) 和 ``Callable`` 类型的 ``cmd``,此参数无效。
allow_upstream_skip: allow_upstream_skip:
若为 ``True``当上游任务因条件不满足被跳过时,本任务仍执行 若为 ``True``硬依赖被 SKIPPED 时本任务仍执行(软依赖不影响)。
(而非跟随跳过)。适用于清理类任务:即使某些删除操作因目标不存在 适用于清理类任务。
而跳过,后续操作(如重启服务)仍应执行。默认为 ``False``。 strategy:
单任务执行策略覆盖。``None`` 表示继承图级策略。
``"sequential"`` 同步直接调用;``"thread"``/``"async"`` 将同步
任务卸载到线程池,异步任务跑在事件循环上。
priority:
同层任务调度优先级。数值越大越先启动。仅影响同层内启动顺序,
不打破层屏障。默认 ``0``。
concurrency_key:
并发限制分组键。具有相同键的任务共享一个信号量,限制同时
运行的实例数。具体限额由 :func:`run` 的 ``concurrency_limits``
参数提供 ``{key: limit}`` 映射。``None`` 表示不限制。
continue_on_error:
若为 ``True``,任务最终失败时不中止整图,仅标记本任务 FAILED,
其硬依赖下游被 SKIPPED,其余任务继续。默认 ``False``。
cache_key:
缓存键计算函数。若提供,则用其基于依赖上下文计算的字符串键
存取状态后端,使不同输入产生独立缓存条目。``None`` 表示用任务名。
hooks:
:class:`TaskHooks` 生命周期钩子。
""" """
name: str name: str
fn: Optional[TaskFn[T]] = None fn: Optional[TaskFn[T]] = None
cmd: Optional[TaskCmd] = None cmd: Optional[TaskCmd] = None
depends_on: Tuple[str, ...] = () depends_on: Tuple[str, ...] = ()
soft_depends_on: Tuple[str, ...] = ()
defaults: Mapping[str, Any] = field(default_factory=dict)
args: Tuple[Any, ...] = () args: Tuple[Any, ...] = ()
kwargs: Mapping[str, Any] = field(default_factory=dict) kwargs: Mapping[str, Any] = field(default_factory=dict)
retries: int = 0 retry: RetryPolicy = field(default_factory=RetryPolicy)
timeout: Optional[float] = None timeout: Optional[float] = None
tags: Tuple[str, ...] = () tags: Tuple[str, ...] = ()
conditions: Tuple[Condition, ...] = () conditions: Tuple[Condition, ...] = ()
cwd: Optional[Path] = None cwd: Optional[Path] = None
env: Optional[Mapping[str, str]] = None
verbose: bool = False verbose: bool = False
skip_if_missing: bool = False skip_if_missing: bool = False
allow_upstream_skip: bool = False allow_upstream_skip: bool = False
strategy: Optional[str] = None
priority: int = 0
concurrency_key: Optional[str] = None
continue_on_error: bool = False
cache_key: Optional[CacheKeyFn] = None
hooks: TaskHooks = field(default_factory=TaskHooks)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if not self.name: if not self.name:
raise ValueError("TaskSpec.name must be a non-empty string.") raise ValueError("TaskSpec.name must be a non-empty string.")
if self.retries < 0: if self.retry.max_attempts < 1:
raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.") raise ValueError(f"TaskSpec '{self.name}': retry.max_attempts must be >= 1.")
if self.timeout is not None and self.timeout <= 0: if self.timeout is not None and self.timeout <= 0:
raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.") raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.")
if self.name in self.depends_on: if self.name in self.depends_on or self.name in self.soft_depends_on:
raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.") raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.")
overlap = set(self.depends_on) & set(self.soft_depends_on)
if overlap:
raise ValueError(f"TaskSpec '{self.name}': depends_on 与 soft_depends_on 不能重叠: {sorted(overlap)}")
if self.fn is None and self.cmd is None: if self.fn is None and self.cmd is None:
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。") raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
@property @property
def effective_fn(self) -> TaskFn[T]: def effective_fn(self) -> TaskFn[T]:
"""获取有效的执行函数. """获取有效的执行函数
若提供 ``cmd`` 参数,则返回包装后的命令执行函数; 若提供 ``cmd``返回包装后的命令执行函数;否则返回 ``fn``。
否则返回 ``fn`` 参数。 包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
Note
-----
命令执行逻辑已抽到模块级 :func:`_run_command`,此处仅返回轻量
转发闭包。``verbose`` / ``cwd`` / ``timeout`` 不再在创建时闭包
捕获,而是在每次调用时从 ``self`` 读取——这使得翻转 ``verbose``
无需重建 spec(见 :func:`pyflowx.runner._apply_verbose_to_graph`)。
""" """
if self.cmd is not None: if self.cmd is not None:
return self._wrap_cmd() return self._wrap_cmd()
if self.fn is not None: if self.fn is not None:
return self.fn return self.fn
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
def _wrap_cmd(self) -> TaskFn[Any]: def _wrap_cmd(self) -> TaskFn[Any]:
"""将 cmd 包装为可执行函数. """将 cmd 包装为可执行函数"""
返回的闭包仅持有 ``self`` 引用,每次调用时从 spec 读取
``verbose``/``cwd``/``timeout``,避免闭包捕获运行期参数。
Returns
-------
TaskFn[Any]
包装后的执行函数.
"""
spec = self spec = self
if isinstance(spec.cmd, list): def _run() -> T:
return cast(T, _run_command(spec))
def _run_list() -> T: _run.__name__ = spec.name
return cast(T, _run_command(spec)) return _run # type: ignore[return-value]
_run_list.__name__ = spec.name def should_execute(self, context: Context) -> Tuple[bool, Optional[str]]:
return _run_list # type: ignore[return-value] """检查任务是否应执行。
if isinstance(spec.cmd, str):
def _run_shell() -> T:
return cast(T, _run_command(spec))
_run_shell.__name__ = spec.name
return _run_shell # type: ignore[return-value]
if callable(spec.cmd):
return spec.cmd # type: ignore[return-value]
raise TypeError(f"TaskSpec '{spec.name}': 不支持的 cmd 类型 {type(spec.cmd).__name__}") # pragma: no cover
def should_execute(self) -> bool:
"""检查任务是否应该执行.
Returns Returns
------- -------
bool (should_run, skip_reason)
若所有条件都返回 ``True``,且 ``skip_if_missing`` 检查通过, ``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
则返回 ``True``;否则返回 ``False``。
""" """
if not all(condition() for condition in self.conditions): # 逐个求值条件,记录失败项。
return False failed_conditions: list[str] = []
for condition in self.conditions:
try:
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
continue
if not ok:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
return not (self.skip_if_missing and not self._is_cmd_available()) if failed_conditions:
return False, f"条件不满足: {', '.join(failed_conditions)}"
if self.skip_if_missing and not self._is_cmd_available():
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
return False, f"命令不存在: {cmd_name}"
return True, None
def _is_cmd_available(self) -> bool: def _is_cmd_available(self) -> bool:
"""检查 ``cmd`` 是否可用. """检查 ``cmd`` 是否可用(仅 list[str])。"""
仅对 ``list[str]`` 类型的 ``cmd`` 进行检查(通过 :func:`shutil.which`)。
对于 ``str`` (shell) 和 ``Callable`` 类型,始终返回 ``True``。
Returns
-------
bool
命令可用返回 ``True``,否则返回 ``False``。
"""
cmd = self.cmd cmd = self.cmd
if isinstance(cmd, list) and cmd: if isinstance(cmd, list) and cmd:
first_arg = cmd[0] return shutil.which(cmd[0]) is not None
return shutil.which(first_arg) is not None
return True return True
def env_context(self) -> ContextManager[None]:
"""返回临时应用 ``env`` 与 ``cwd`` 的上下文管理器。
def _run_command(spec: "TaskSpec[Any]") -> Any: 对 ``fn`` 任务生效。``cmd`` 任务在 :func:`_run_command` 中直接
"""执行 ``spec.cmd`` 指定的命令(list 或 shell 字符串) 传给子进程
"""
return _env_and_cwd(self.env, self.cwd)
list 与 shell 两条路径的异常处理、输出捕获、返回码判断完全一致, def storage_key(self, context: Context) -> str:
合并于此消除重复。``verbose``/``cwd``/``timeout`` 在调用时从 """计算状态后端存储键。"""
``spec`` 读取,而非闭包捕获——这是 ``_wrap_cmd`` 不再捕获运行期 if self.cache_key is not None:
参数的关键。 try:
return f"{self.name}:{self.cache_key(context)}"
except Exception:
return self.name
return self.name
成功返回 ``None``;失败抛 ``RuntimeError``,错误信息包含命令、
返回码与(非 verbose 模式下的)stderr。 @contextmanager
""" def _env_and_cwd(
env: Optional[Mapping[str, str]],
cwd: Optional[Path],
) -> Iterator[None]:
"""临时设置环境变量与工作目录。"""
saved_env: dict[str, str] = {}
saved_cwd: Optional[str] = None
if env:
for k, v in env.items():
if k in os.environ:
saved_env[k] = os.environ[k]
os.environ[k] = v
if cwd is not None:
saved_cwd = str(Path.cwd())
os.chdir(cwd)
try:
yield
finally:
if saved_cwd is not None:
os.chdir(saved_cwd)
# 恢复环境变量
if env:
for k in env:
if k in saved_env:
os.environ[k] = saved_env[k]
else:
os.environ.pop(k, None)
def _run_command(spec: "TaskSpec[Any]") -> Any: # noqa: PLR0912
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。"""
cmd = spec.cmd cmd = spec.cmd
is_list = isinstance(cmd, list)
verbose = spec.verbose verbose = spec.verbose
cwd = spec.cwd cwd = spec.cwd
timeout = spec.timeout timeout = spec.timeout
env_override = spec.env
# 统一展示用的命令字符串与标签。保持 "执行命令" / "执行 Shell" 连续, # 可调用对象:直接调用,返回其结果。
# 以兼容既有输出格式与测试断言。 if callable(cmd) and not isinstance(cmd, (list, str)):
name = getattr(cmd, "__name__", "callable")
if verbose:
print(f"[verbose] 执行可调用命令: {name}", flush=True)
if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True)
try:
return cmd()
except Exception as e:
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
is_list = isinstance(cmd, list)
if is_list: if is_list:
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr] cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
verb = "执行命令" verb = "执行命令"
@@ -279,14 +428,18 @@ def _run_command(spec: "TaskSpec[Any]") -> Any:
if cwd is not None: if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True) print(f"[verbose] 工作目录: {cwd}", flush=True)
# 合并环境变量
run_env: Optional[dict[str, str]] = None
if env_override:
run_env = dict(os.environ)
run_env.update(env_override)
try: try:
# cmd 此处必为 list[str] 或 str_wrap_cmd 的 isinstance 守卫已排除
# None 与 Callable),但类型检查器无法跨函数推断,故 cast 收窄到
# subprocess.run 接受的 Union[str, Sequence[str]]。
result = subprocess.run( result = subprocess.run(
cast(Union[str, List[str]], cmd), cast(Union[str, List[str]], cmd),
shell=not is_list, shell=not is_list,
cwd=cwd, cwd=cwd,
env=run_env,
timeout=timeout, timeout=timeout,
capture_output=not verbose, capture_output=not verbose,
text=True, text=True,
@@ -311,13 +464,42 @@ def _run_command(spec: "TaskSpec[Any]") -> Any:
raise RuntimeError(err_msg) raise RuntimeError(err_msg)
# ---------------------------------------------------------------------- #
# 任务模板:批量生成相似 TaskSpec 的工厂
# ---------------------------------------------------------------------- #
def task_template(
fn: Optional[TaskFn[Any]] = None,
cmd: Optional[TaskCmd] = None,
**defaults: Any,
) -> Callable[..., TaskSpec[Any]]:
"""创建任务模板工厂。
返回的工厂接受 ``name`` 与任意覆盖字段,生成 :class:`TaskSpec`。
适用于批量创建相似任务(如 fan-out)。
Examples
--------
>>> Fetch = px.task_template(fn=fetch_user, retry=px.RetryPolicy(max_attempts=3))
>>> specs = [Fetch(f"fetch_{uid}", args=(uid,)) for uid in range(5)]
"""
base = dict(defaults)
if fn is not None:
base["fn"] = fn
if cmd is not None:
base["cmd"] = cmd
def _factory(name: str, **overrides: Any) -> TaskSpec[Any]:
merged = dict(base)
merged.update(overrides)
return TaskSpec(name, **merged)
_factory.__name__ = "task_template_factory"
return _factory
@dataclass @dataclass
class TaskResult(Generic[T]): class TaskResult(Generic[T]):
"""运行期间产生的可变单任务记录。 """运行期间产生的可变单任务记录。"""
每次运行都会创建全新的 :class:`TaskResult`spec 本身保持不可变。
这让同一个图可以安全地重复运行。
"""
spec: TaskSpec[T] spec: TaskSpec[T]
status: TaskStatus = TaskStatus.PENDING status: TaskStatus = TaskStatus.PENDING
@@ -338,15 +520,11 @@ class TaskResult(Generic[T]):
@dataclass(frozen=True) @dataclass(frozen=True)
class TaskEvent: class TaskEvent:
"""执行期间向观察者发出的不可变事件。 """执行期间向观察者发出的不可变事件。"""
传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合
执行器内部即可构建进度条、指标或结构化日志。
"""
task: str task: str
status: TaskStatus status: TaskStatus
attempts: int = 0 attempts: int = 0
error: Optional[str] = None error: Optional[str] = None
duration: Optional[float] = None duration: Optional[float] = None
reason: Optional[str] = None # 跳过原因,如 "条件不满足"、"上游任务被跳过"、"缓存" reason: Optional[str] = None
File diff suppressed because it is too large Load Diff
+148 -72
View File
@@ -1,142 +1,218 @@
"""Tests for conditions module.""" """Tests for conditions module."""
from __future__ import annotations
import os import os
import sys import sys
from unittest.mock import patch from unittest.mock import patch
from pyflowx.conditions import ( from pyflowx.conditions import (
IS_LINUX,
IS_MACOS,
IS_POSIX,
IS_WINDOWS,
BuiltinConditions, BuiltinConditions,
Constants, Constants,
) )
_CTX: dict[str, object] = {}
def test_constants_is_windows(): def test_constants_is_windows():
"""Test Constants.IS_WINDOWS is correct."""
assert (sys.platform == "win32") == Constants.IS_WINDOWS assert (sys.platform == "win32") == Constants.IS_WINDOWS
def test_constants_is_linux(): def test_constants_is_linux():
"""Test Constants.IS_LINUX is correct."""
assert (sys.platform == "linux") == Constants.IS_LINUX assert (sys.platform == "linux") == Constants.IS_LINUX
def test_constants_is_macos(): def test_constants_is_macos():
"""Test Constants.IS_MACOS is correct."""
assert (sys.platform == "darwin") == Constants.IS_MACOS assert (sys.platform == "darwin") == Constants.IS_MACOS
def test_constants_is_posix(): def test_constants_is_posix():
"""Test Constants.IS_POSIX is correct."""
assert (sys.platform != "win32") == Constants.IS_POSIX assert (sys.platform != "win32") == Constants.IS_POSIX
def test_module_level_static_conditions():
assert IS_WINDOWS(_CTX) == Constants.IS_WINDOWS
assert IS_LINUX(_CTX) == Constants.IS_LINUX
assert IS_MACOS(_CTX) == Constants.IS_MACOS
assert IS_POSIX(_CTX) == Constants.IS_POSIX
def test_builtin_conditions_python_version_major_only():
"""Test BuiltinConditions.PYTHON_VERSION with major only.""" def test_python_version_major_only():
# Test with current Python version
current_major = sys.version_info.major current_major = sys.version_info.major
assert BuiltinConditions.PYTHON_VERSION(current_major) is True assert BuiltinConditions.PYTHON_VERSION(current_major)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False assert BuiltinConditions.PYTHON_VERSION(current_major + 1)(_CTX) is False
def test_builtin_conditions_python_version_with_minor(): def test_python_version_with_minor():
"""Test BuiltinConditions.PYTHON_VERSION with major and minor."""
current_major = sys.version_info.major current_major = sys.version_info.major
current_minor = sys.version_info.minor current_minor = sys.version_info.minor
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1)(_CTX) is False
def test_builtin_conditions_python_version_at_least(): def test_python_version_at_least():
"""Test BuiltinConditions.PYTHON_VERSION_AT_LEAST."""
current_major = sys.version_info.major current_major = sys.version_info.major
current_minor = sys.version_info.minor current_minor = sys.version_info.minor
# Current version should be at least itself assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0)(_CTX) is True
# Current version should be at least an older version assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0)(_CTX) is False
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True
# Current version should NOT be at least a newer version
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False
def test_builtin_conditions_HAS_INSTALLED_true(): def test_has_installed_true():
"""Test BuiltinConditions.HAS_INSTALLED when app exists.""" condition = BuiltinConditions.HAS_INSTALLED("python3")
# Python should always be available assert condition(_CTX) is True
condition = BuiltinConditions.HAS_INSTALLED("python")
assert condition() is True
def test_builtin_conditions_HAS_INSTALLED_false(): def test_has_installed_false():
"""Test BuiltinConditions.HAS_INSTALLED when app doesn't exist."""
condition = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345") condition = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
assert condition() is False assert condition(_CTX) is False
def test_builtin_conditions_env_var_exists_true(): def test_env_var_exists_true():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable exists."""
with patch.dict(os.environ, {"TEST_VAR": "value"}): with patch.dict(os.environ, {"TEST_VAR": "value"}):
condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR") condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR")
assert condition() is True assert condition(_CTX) is True
def test_builtin_conditions_env_var_exists_false(): def test_env_var_exists_false():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist."""
condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345") condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345")
assert condition() is False assert condition(_CTX) is False
def test_builtin_conditions_env_var_equals_true(): def test_env_var_equals_true():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value matches."""
with patch.dict(os.environ, {"TEST_VAR": "expected_value"}): with patch.dict(os.environ, {"TEST_VAR": "expected_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is True assert condition(_CTX) is True
def test_builtin_conditions_env_var_equals_false(): def test_env_var_equals_false():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match."""
with patch.dict(os.environ, {"TEST_VAR": "different_value"}): with patch.dict(os.environ, {"TEST_VAR": "different_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is False assert condition(_CTX) is False
def test_builtin_conditions_not(): def test_not():
"""Test BuiltinConditions.NOT.""" true_cond = BuiltinConditions.HAS_INSTALLED("python3")
true_condition = lambda: True # noqa: E731 false_cond = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
false_condition = lambda: False # noqa: E731
not_true = BuiltinConditions.NOT(true_condition) assert BuiltinConditions.NOT(true_cond)(_CTX) is False
assert not_true() is False assert BuiltinConditions.NOT(false_cond)(_CTX) is True
not_false = BuiltinConditions.NOT(false_condition)
assert not_false() is True
def test_builtin_conditions_and_all_true(): def test_and_all_true():
"""Test BuiltinConditions.AND when all conditions are true.""" cond = BuiltinConditions.AND(
true_condition = lambda: True # noqa: E731 BuiltinConditions.HAS_INSTALLED("python3"),
condition = BuiltinConditions.AND(true_condition, true_condition, true_condition) BuiltinConditions.HAS_INSTALLED("python3"),
assert condition() is True )
assert cond(_CTX) is True
def test_builtin_conditions_and_one_false(): def test_and_one_false():
"""Test BuiltinConditions.AND when one condition is false.""" cond = BuiltinConditions.AND(
true_condition = lambda: True # noqa: E731 BuiltinConditions.HAS_INSTALLED("python3"),
false_condition = lambda: False # noqa: E731 BuiltinConditions.HAS_INSTALLED("nonexistent_app"),
condition = BuiltinConditions.AND(true_condition, false_condition, true_condition) )
assert condition() is False assert cond(_CTX) is False
def test_builtin_conditions_or_all_false(): def test_or_all_false():
"""Test BuiltinConditions.OR when all conditions are false.""" cond = BuiltinConditions.OR(
false_condition = lambda: False # noqa: E731 BuiltinConditions.HAS_INSTALLED("nonexistent1"),
condition = BuiltinConditions.OR(false_condition, false_condition, false_condition) BuiltinConditions.HAS_INSTALLED("nonexistent2"),
assert condition() is False )
assert cond(_CTX) is False
def test_builtin_conditions_or_one_true(): def test_or_one_true():
"""Test BuiltinConditions.OR when one condition is true.""" cond = BuiltinConditions.OR(
true_condition = lambda: True # noqa: E731 BuiltinConditions.HAS_INSTALLED("nonexistent1"),
false_condition = lambda: False # noqa: E731 BuiltinConditions.HAS_INSTALLED("python3"),
condition = BuiltinConditions.OR(false_condition, true_condition, false_condition) )
assert condition() is True assert cond(_CTX) is True
# ---------------------------------------------------------------------- #
# 上下文条件:基于上游依赖结果
# ---------------------------------------------------------------------- #
def test_dep_equals_true():
ctx = {"upstream": 42}
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
assert cond(ctx) is True
def test_dep_equals_false():
ctx = {"upstream": 99}
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
assert cond(ctx) is False
def test_dep_equals_missing_dep():
cond = BuiltinConditions.DEP_EQUALS("missing", 42)
assert cond({}) is False
def test_dep_matches_true():
ctx = {"upstream": [1, 2, 3]}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
assert cond(ctx) is True
def test_dep_matches_false():
ctx = {"upstream": [1, 2]}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
assert cond(ctx) is False
def test_dep_matches_exception_returns_false():
ctx = {"upstream": ""}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: v[0])
assert cond(ctx) is False
def test_dep_present_true():
ctx = {"upstream": "value"}
cond = BuiltinConditions.DEP_PRESENT("upstream")
assert cond(ctx) is True
def test_dep_present_false_none():
# pyrefly: ignore [implicit-any-empty-container]
ctx = {"upstream": None}
cond = BuiltinConditions.DEP_PRESENT("upstream")
assert cond(ctx) is False
def test_dep_present_false_missing():
cond = BuiltinConditions.DEP_PRESENT("missing")
assert cond({}) is False
def test_dep_truthy_true():
ctx = {"upstream": [1]}
cond = BuiltinConditions.DEP_TRUTHY("upstream")
assert cond(ctx) is True
def test_dep_truthy_false():
# pyrefly: ignore [implicit-any-empty-container]
ctx = {"upstream": []}
cond = BuiltinConditions.DEP_TRUTHY("upstream")
assert cond(ctx) is False
def test_dep_truthy_missing():
cond = BuiltinConditions.DEP_TRUTHY("missing")
assert cond({}) is False
def test_logical_combination_with_dep_conditions():
ctx = {"a": 1, "b": 0}
cond = BuiltinConditions.AND(
BuiltinConditions.DEP_EQUALS("a", 1),
BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")),
)
assert cond(ctx) is True
+1 -1
View File
@@ -141,7 +141,7 @@ class TestDescribeInjection:
spec = px.TaskSpec("t", fn, depends_on=("a",)) spec = px.TaskSpec("t", fn, depends_on=("a",))
desc = describe_injection(spec) desc = describe_injection(spec)
assert "a=<result:a>" in desc assert "a=<dep:a>" in desc
assert "ctx=<Context>" in desc assert "ctx=<Context>" in desc
assert "flag=<default>" in desc assert "flag=<default>" in desc
+16 -8
View File
@@ -84,7 +84,9 @@ def test_retries_then_succeeds() -> None:
raise RuntimeError("not yet") raise RuntimeError("not yet")
return "ok" return "ok"
graph = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retries=2)]) graph = px.Graph.from_specs([
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=3)),
])
report = px.run(graph, strategy="sequential") report = px.run(graph, strategy="sequential")
assert report.success assert report.success
assert report["flaky"] == "ok" assert report["flaky"] == "ok"
@@ -95,7 +97,9 @@ def test_retries_exhausted() -> None:
def always_fail() -> None: def always_fail() -> None:
raise RuntimeError("nope") raise RuntimeError("nope")
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)]) graph = px.Graph.from_specs([
px.TaskSpec("f", always_fail, retry=px.RetryPolicy(max_attempts=3)),
])
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
assert exc_info.value.attempts == 3 assert exc_info.value.attempts == 3
@@ -332,7 +336,9 @@ def test_async_timeout_retry_then_succeed() -> None:
await asyncio.sleep(10) # 触发超时 await asyncio.sleep(10) # 触发超时
return "ok" return "ok"
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2, timeout=0.05)]) graph = px.Graph.from_specs([
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3), timeout=0.05),
])
report = px.run(graph, strategy="async") report = px.run(graph, strategy="async")
assert report.success assert report.success
assert report["a"] == "ok" assert report["a"] == "ok"
@@ -349,7 +355,9 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
raise RuntimeError("not yet") raise RuntimeError("not yet")
return "ok" return "ok"
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)]) graph = px.Graph.from_specs([
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3)),
])
with caplog.at_level("WARNING", logger="pyflowx"): with caplog.at_level("WARNING", logger="pyflowx"):
report = px.run(graph, strategy="async") report = px.run(graph, strategy="async")
assert report.success assert report.success
@@ -489,7 +497,7 @@ def test_run_empty_graph() -> None:
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_downstream_skipped_when_upstream_skipped_sequential() -> None: def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDsequential 策略).""" """上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDsequential 策略)."""
never_true = lambda: False # noqa: E731 never_true = lambda _ctx: False # noqa: E731
def downstream(upstream: str) -> str: def downstream(upstream: str) -> str:
return upstream + "_processed" return upstream + "_processed"
@@ -506,7 +514,7 @@ def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
def test_downstream_skipped_when_upstream_skipped_thread() -> None: def test_downstream_skipped_when_upstream_skipped_thread() -> None:
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDthread 策略).""" """上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDthread 策略)."""
never_true = lambda: False # noqa: E731 never_true = lambda _ctx: False # noqa: E731
def downstream(upstream: str) -> str: def downstream(upstream: str) -> str:
return upstream + "_processed" return upstream + "_processed"
@@ -530,7 +538,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None:
async def downstream(upstream: str) -> str: async def downstream(upstream: str) -> str:
return upstream + "_processed" return upstream + "_processed"
never_true = lambda: False # noqa: E731 never_true = lambda _ctx: False # noqa: E731
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec("upstream", upstream, conditions=(never_true,)), px.TaskSpec("upstream", upstream, conditions=(never_true,)),
@@ -544,7 +552,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None:
def test_downstream_executes_when_upstream_succeeds() -> None: def test_downstream_executes_when_upstream_succeeds() -> None:
"""上游任务成功时,下游任务应正常执行.""" """上游任务成功时,下游任务应正常执行."""
always_true = lambda: True # noqa: E731 always_true = lambda _ctx: True # noqa: E731
def upstream() -> str: def upstream() -> str:
return "hello" return "hello"
+14 -6
View File
@@ -85,7 +85,7 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
) )
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True) report = px.run(graph, strategy="sequential", verbose=True)
@@ -140,7 +140,7 @@ def test_verbose_event_callback_skipped():
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
verbose=True, verbose=True,
) )
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
@@ -161,7 +161,11 @@ def test_execute_sync_with_retries():
raise ValueError("temporary error") raise ValueError("temporary error")
return "success" return "success"
spec = px.TaskSpec("retry_test", fn=failing_function, retries=3) spec = px.TaskSpec(
"retry_test",
fn=failing_function,
retry=px.RetryPolicy(max_attempts=3),
)
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
# Should succeed after retries # Should succeed after retries
@@ -182,7 +186,11 @@ def test_execute_async_with_retries():
raise ValueError("temporary error") raise ValueError("temporary error")
return "success" return "success"
spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3) spec = px.TaskSpec(
"retry_async_test",
fn=failing_async_function,
retry=px.RetryPolicy(max_attempts=3),
)
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
# Should succeed after retries # Should succeed after retries
@@ -196,7 +204,7 @@ def test_execute_sync_skip_on_condition():
spec = px.TaskSpec( spec = px.TaskSpec(
"skip_test", "skip_test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
) )
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
@@ -210,7 +218,7 @@ def test_execute_async_skip_on_condition():
spec = px.TaskSpec( spec = px.TaskSpec(
"skip_async_test", "skip_async_test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
) )
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
+58 -74
View File
@@ -13,13 +13,11 @@ def _fn() -> None:
def test_from_specs_builds_graph() -> None: def test_from_specs_builds_graph() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), px.TaskSpec("c", _fn, depends_on=("a", "b")),
px.TaskSpec("c", _fn, depends_on=("a", "b")), ])
]
)
assert set(graph.names) == {"a", "b", "c"} assert set(graph.names) == {"a", "b", "c"}
assert graph.dependencies("c") == ("a", "b") assert graph.dependencies("c") == ("a", "b")
assert len(graph) == 3 assert len(graph) == 3
@@ -28,23 +26,19 @@ def test_from_specs_builds_graph() -> None:
def test_from_specs_allows_forward_references() -> None: def test_from_specs_allows_forward_references() -> None:
# b depends on a, but a is declared after b — order should not matter. # b depends on a, but a is declared after b — order should not matter.
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), ])
]
)
assert graph.layers() == [["a"], ["b"]] assert graph.layers() == [["a"], ["b"]]
def test_duplicate_task_raises() -> None: def test_duplicate_task_raises() -> None:
with pytest.raises(DuplicateTaskError): with pytest.raises(DuplicateTaskError):
_ = px.Graph.from_specs( _ = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), ])
]
)
def test_missing_dependency_raises() -> None: def test_missing_dependency_raises() -> None:
@@ -57,24 +51,20 @@ def test_missing_dependency_raises() -> None:
def test_cycle_detection() -> None: def test_cycle_detection() -> None:
with pytest.raises(CycleError): with pytest.raises(CycleError):
_ = px.Graph.from_specs( _ = px.Graph.from_specs([
[ px.TaskSpec("a", _fn, depends_on=("c",)),
px.TaskSpec("a", _fn, depends_on=("c",)), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), px.TaskSpec("c", _fn, depends_on=("b",)),
px.TaskSpec("c", _fn, depends_on=("b",)), ])
]
)
def test_layers_grouping() -> None: def test_layers_grouping() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn),
px.TaskSpec("b", _fn), px.TaskSpec("c", _fn, depends_on=("a", "b")),
px.TaskSpec("c", _fn, depends_on=("a", "b")), px.TaskSpec("d", _fn, depends_on=("c",)),
px.TaskSpec("d", _fn, depends_on=("c",)), ])
]
)
layers = graph.layers() layers = graph.layers()
assert layers == [["a", "b"], ["c"], ["d"]] assert layers == [["a", "b"], ["c"], ["d"]]
@@ -85,12 +75,10 @@ def test_self_dependency_rejected() -> None:
def test_to_mermaid() -> None: def test_to_mermaid() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), ])
]
)
mermaid = graph.to_mermaid() mermaid = graph.to_mermaid()
assert mermaid.startswith("graph TD") assert mermaid.startswith("graph TD")
assert 'a["a"]' in mermaid assert 'a["a"]' in mermaid
@@ -104,13 +92,11 @@ def test_to_mermaid_invalid_orientation() -> None:
def test_subgraph_by_tags() -> None: def test_subgraph_by_tags() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn, tags=("ingest",)),
px.TaskSpec("a", _fn, tags=("ingest",)), px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)), px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)),
px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)), ])
]
)
sub = graph.subgraph(["ingest"]) sub = graph.subgraph(["ingest"])
assert set(sub.names) == {"a", "b"} assert set(sub.names) == {"a", "b"}
# Edge to dropped task c is removed; b no longer waits for anything # Edge to dropped task c is removed; b no longer waits for anything
@@ -119,13 +105,11 @@ def test_subgraph_by_tags() -> None:
def test_subgraph_by_names() -> None: def test_subgraph_by_names() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), px.TaskSpec("c", _fn, depends_on=("b",)),
px.TaskSpec("c", _fn, depends_on=("b",)), ])
]
)
sub = graph.subgraph_by_names(["a", "b"]) sub = graph.subgraph_by_names(["a", "b"])
assert set(sub.names) == {"a", "b"} assert set(sub.names) == {"a", "b"}
# c is dropped, so b's dep on c (none here) — but a->b edge preserved. # c is dropped, so b's dep on c (none here) — but a->b edge preserved.
@@ -139,12 +123,10 @@ def test_subgraph_by_names_unknown() -> None:
def test_describe() -> None: def test_describe() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), ])
]
)
desc = graph.describe() desc = graph.describe()
assert "Layer 1" in desc assert "Layer 1" in desc
assert "Layer 2" in desc assert "Layer 2" in desc
@@ -187,12 +169,10 @@ def test_spec_accessor() -> None:
def test_dependencies_accessor() -> None: def test_dependencies_accessor() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("b", _fn, depends_on=("a",)), ])
]
)
assert graph.dependencies("a") == () assert graph.dependencies("a") == ()
assert graph.dependencies("b") == ("a",) assert graph.dependencies("b") == ("a",)
@@ -210,16 +190,20 @@ def test_empty_graph_layers() -> None:
def test_subgraph_preserves_metadata() -> None: def test_subgraph_preserves_metadata() -> None:
"""子图应保留原任务的 retries/timeout/tags 等元数据。""" """子图应保留原任务的 retry/timeout/tags 等元数据。"""
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0), "a",
px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)), _fn,
] tags=("x",),
) retry=px.RetryPolicy(max_attempts=3),
timeout=5.0,
),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)),
])
sub = graph.subgraph(["x"]) sub = graph.subgraph(["x"])
spec = sub.spec("a") spec = sub.spec("a")
assert spec.retries == 3 assert spec.retry.max_attempts == 3
assert spec.timeout == 5.0 assert spec.timeout == 5.0
assert spec.tags == ("x",) assert spec.tags == ("x",)
+50 -68
View File
@@ -29,24 +29,20 @@ def _echo_graph(name: str = "echo_task", msg: str = "hello") -> px.Graph:
def _failing_graph() -> px.Graph: def _failing_graph() -> px.Graph:
"""构造一个必定失败的单任务图.""" """构造一个必定失败的单任务图."""
return px.Graph.from_specs( return px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "fail",
"fail", cmd=["python", "-c", "import sys; sys.exit(1)"],
cmd=["python", "-c", "import sys; sys.exit(1)"], )
) ])
]
)
def _multi_task_graph() -> px.Graph: def _multi_task_graph() -> px.Graph:
"""构造一个带依赖的多任务图.""" """构造一个带依赖的多任务图."""
return px.Graph.from_specs( return px.Graph.from_specs([
[ px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]),
px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]), px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)),
px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)), ])
]
)
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -240,12 +236,10 @@ class TestCliRunnerRunSuccess:
def track_b() -> None: def track_b() -> None:
executed.append("b") executed.append("b")
runner = px.CliRunner( runner = px.CliRunner({
{ "a": px.Graph.from_specs([px.TaskSpec("a", track_a)]),
"a": px.Graph.from_specs([px.TaskSpec("a", track_a)]), "b": px.Graph.from_specs([px.TaskSpec("b", track_b)]),
"b": px.Graph.from_specs([px.TaskSpec("b", track_b)]), })
}
)
_ = runner.run(["b"]) _ = runner.run(["b"])
assert executed == ["b"] assert executed == ["b"]
@@ -318,15 +312,13 @@ class TestCliRunnerVerbose:
def test_verbose_prints_skip_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None: def test_verbose_prints_skip_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
"""verbose 模式下跳过的任务应打印跳过信息.""" """verbose 模式下跳过的任务应打印跳过信息."""
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "skip_me",
"skip_me", cmd=[*ECHO_CMD, "skip"],
cmd=[*ECHO_CMD, "skip"], conditions=(lambda _ctx: False,),
conditions=(lambda: False,), ),
), ])
]
)
runner = px.CliRunner({"skip": graph}) runner = px.CliRunner({"skip": graph})
_ = runner.run(["skip"]) _ = runner.run(["skip"])
captured = capsys.readouterr() captured = capsys.readouterr()
@@ -394,13 +386,11 @@ class TestCliRunnerList:
def test_list_prints_all_commands(self, capsys: pytest.CaptureFixture[str]) -> None: def test_list_prints_all_commands(self, capsys: pytest.CaptureFixture[str]) -> None:
"""--list 应打印所有命令.""" """--list 应打印所有命令."""
runner = px.CliRunner( runner = px.CliRunner({
{ "clean": _echo_graph("c", "clean"),
"clean": _echo_graph("c", "clean"), "build": _echo_graph("b", "build"),
"build": _echo_graph("b", "build"), "test": _echo_graph("t", "test"),
"test": _echo_graph("t", "test"), })
}
)
_ = runner.run(["--list"]) _ = runner.run(["--list"])
captured = capsys.readouterr() captured = capsys.readouterr()
assert "clean" in captured.out assert "clean" in captured.out
@@ -523,30 +513,26 @@ class TestCliRunnerIntegration:
def test_condition_skipped_command_succeeds(self) -> None: def test_condition_skipped_command_succeeds(self) -> None:
"""条件不满足时任务跳过, 整体仍成功.""" """条件不满足时任务跳过, 整体仍成功."""
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "skip_me",
"skip_me", cmd=[*ECHO_CMD, "should not run"],
cmd=[*ECHO_CMD, "should not run"], conditions=(lambda _ctx: False,),
conditions=(lambda: False,), ),
), ])
]
)
runner = px.CliRunner({"skip": graph}) runner = px.CliRunner({"skip": graph})
exit_code = runner.run(["skip"]) exit_code = runner.run(["skip"])
assert exit_code == CliExitCode.SUCCESS.value assert exit_code == CliExitCode.SUCCESS.value
def test_condition_met_command_succeeds(self) -> None: def test_condition_met_command_succeeds(self) -> None:
"""条件满足时任务执行, 整体成功.""" """条件满足时任务执行, 整体成功."""
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec(
px.TaskSpec( "run_me",
"run_me", cmd=[*ECHO_CMD, "should run"],
cmd=[*ECHO_CMD, "should run"], conditions=(lambda _ctx: True,),
conditions=(lambda: True,), ),
), ])
]
)
runner = px.CliRunner({"run": graph}) runner = px.CliRunner({"run": graph})
exit_code = runner.run(["run"]) exit_code = runner.run(["run"])
assert exit_code == CliExitCode.SUCCESS.value assert exit_code == CliExitCode.SUCCESS.value
@@ -562,14 +548,12 @@ class TestCliRunnerIntegration:
return fn return fn
graph = px.Graph.from_specs( graph = px.Graph.from_specs([
[ px.TaskSpec("a", make("a")),
px.TaskSpec("a", make("a")), px.TaskSpec("b", make("b"), depends_on=("a",)),
px.TaskSpec("b", make("b"), depends_on=("a",)), px.TaskSpec("c", make("c"), depends_on=("a",)),
px.TaskSpec("c", make("c"), depends_on=("a",)), px.TaskSpec("d", make("d"), depends_on=("b", "c")),
px.TaskSpec("d", make("d"), depends_on=("b", "c")), ])
]
)
runner = px.CliRunner({"diamond": graph}) runner = px.CliRunner({"diamond": graph})
exit_code = runner.run(["diamond"]) exit_code = runner.run(["diamond"])
assert exit_code == CliExitCode.SUCCESS.value assert exit_code == CliExitCode.SUCCESS.value
@@ -577,12 +561,10 @@ class TestCliRunnerIntegration:
def test_mixed_fn_and_cmd_commands(self) -> None: def test_mixed_fn_and_cmd_commands(self) -> None:
"""混合 fn 和 cmd 的命令应都能执行.""" """混合 fn 和 cmd 的命令应都能执行."""
runner = px.CliRunner( runner = px.CliRunner({
{ "fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]),
"fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]), "cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]),
"cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]), })
}
)
assert runner.run(["fn_cmd"]) == CliExitCode.SUCCESS.value assert runner.run(["fn_cmd"]) == CliExitCode.SUCCESS.value
assert runner.run(["cmd_cmd"]) == CliExitCode.SUCCESS.value assert runner.run(["cmd_cmd"]) == CliExitCode.SUCCESS.value
+4 -4
View File
@@ -6,7 +6,7 @@ from datetime import datetime
import pytest import pytest
from pyflowx.task import TaskResult, TaskSpec, TaskStatus from pyflowx.task import RetryPolicy, TaskResult, TaskSpec, TaskStatus
def _fn() -> None: def _fn() -> None:
@@ -18,9 +18,9 @@ def test_spec_empty_name_rejected() -> None:
TaskSpec("", _fn) TaskSpec("", _fn)
def test_spec_negative_retries_rejected() -> None: def test_spec_negative_max_attempts_rejected() -> None:
with pytest.raises(ValueError, match="retries"): with pytest.raises(ValueError, match="max_attempts"):
TaskSpec("a", _fn, retries=-1) TaskSpec("a", _fn, retry=RetryPolicy(max_attempts=0))
def test_spec_zero_timeout_rejected() -> None: def test_spec_zero_timeout_rejected() -> None:
+31 -26
View File
@@ -67,7 +67,9 @@ def test_taskspec_wrap_cmd_verbose():
def test_taskspec_wrap_cmd_error(): def test_taskspec_wrap_cmd_error():
"""Test TaskSpec._wrap_cmd handles command error.""" """Test TaskSpec._wrap_cmd handles command error."""
spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"]) import sys
spec = TaskSpec("test", cmd=[sys.executable, "-c", "import sys; sys.exit(1)"])
wrapped_fn = spec.effective_fn wrapped_fn = spec.effective_fn
with pytest.raises(RuntimeError, match="命令执行失败"): with pytest.raises(RuntimeError, match="命令执行失败"):
@@ -105,10 +107,10 @@ def test_taskspec_conditions_check():
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: True,), conditions=(lambda _ctx: True,),
) )
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_taskspec_conditions_false(): def test_taskspec_conditions_false():
@@ -116,10 +118,10 @@ def test_taskspec_conditions_false():
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
) )
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
def test_taskspec_conditions_multiple(): def test_taskspec_conditions_multiple():
@@ -127,10 +129,10 @@ def test_taskspec_conditions_multiple():
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: True, lambda: True, lambda: True), conditions=(lambda _ctx: True, lambda _ctx: True, lambda _ctx: True),
) )
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_taskspec_conditions_multiple_one_false(): def test_taskspec_conditions_multiple_one_false():
@@ -138,10 +140,10 @@ def test_taskspec_conditions_multiple_one_false():
spec = px.TaskSpec( spec = px.TaskSpec(
"test", "test",
fn=lambda: "result", fn=lambda: "result",
conditions=(lambda: True, lambda: False, lambda: True), conditions=(lambda _ctx: True, lambda _ctx: False, lambda _ctx: True),
) )
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
def test_taskspec_list_cmd_timeout_mocked(): def test_taskspec_list_cmd_timeout_mocked():
@@ -218,27 +220,28 @@ def test_taskspec_shell_cmd_os_error_mocked():
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_skip_if_missing_with_available_command(): def test_skip_if_missing_with_available_command():
"""skip_if_missing=True 时,命令存在应返回 True.""" """skip_if_missing=True 时,命令存在应返回 True."""
# python 命令在测试环境中一定存在 import sys
spec = TaskSpec("test", cmd=["python", "--version"], skip_if_missing=True)
assert spec.should_execute() is True spec = TaskSpec("test", cmd=[sys.executable, "--version"], skip_if_missing=True)
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_missing_command(): def test_skip_if_missing_with_missing_command():
"""skip_if_missing=True 时,命令不存在应返回 False.""" """skip_if_missing=True 时,命令不存在应返回 False."""
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True) spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True)
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
def test_skip_if_missing_false_with_missing_command(): def test_skip_if_missing_false_with_missing_command():
"""skip_if_missing=False 时,命令不存在也应返回 True(不检查).""" """skip_if_missing=False 时,命令不存在也应返回 True(不检查)."""
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=False) spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=False)
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_shell_cmd_not_checked(): def test_skip_if_missing_with_shell_cmd_not_checked():
"""skip_if_missing=True 时,shell 命令(str)不检查,应返回 True.""" """skip_if_missing=True 时,shell 命令(str)不检查,应返回 True."""
spec = TaskSpec("test", cmd="definitely_not_installed_app_xyz", skip_if_missing=True) spec = TaskSpec("test", cmd="definitely_not_installed_app_xyz", skip_if_missing=True)
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_callable_cmd_not_checked(): def test_skip_if_missing_with_callable_cmd_not_checked():
@@ -248,7 +251,7 @@ def test_skip_if_missing_with_callable_cmd_not_checked():
return 0 return 0
spec = TaskSpec("test", cmd=custom_cmd, skip_if_missing=True) spec = TaskSpec("test", cmd=custom_cmd, skip_if_missing=True)
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_fn_not_checked(): def test_skip_if_missing_with_fn_not_checked():
@@ -258,7 +261,7 @@ def test_skip_if_missing_with_fn_not_checked():
return 0 return 0
spec = TaskSpec("test", fn=my_fn, skip_if_missing=True) spec = TaskSpec("test", fn=my_fn, skip_if_missing=True)
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_empty_cmd_list(): def test_skip_if_missing_with_empty_cmd_list():
@@ -266,37 +269,39 @@ def test_skip_if_missing_with_empty_cmd_list():
spec = TaskSpec("test", cmd=[""], skip_if_missing=True) spec = TaskSpec("test", cmd=[""], skip_if_missing=True)
# 空字符串命令,shutil.which 返回 None # 空字符串命令,shutil.which 返回 None
# 但 cmd[0] 是空字符串,shutil.which("") 返回 None # 但 cmd[0] 是空字符串,shutil.which("") 返回 None
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
def test_skip_if_missing_combined_with_conditions(): def test_skip_if_missing_combined_with_conditions():
"""skip_if_missing=True 与 conditions 组合使用.""" """skip_if_missing=True 与 conditions 组合使用."""
import sys
# conditions 返回 False,应跳过 # conditions 返回 False,应跳过
spec = TaskSpec( spec = TaskSpec(
"test", "test",
cmd=["python", "--version"], cmd=[sys.executable, "--version"],
skip_if_missing=True, skip_if_missing=True,
conditions=(lambda: False,), conditions=(lambda _ctx: False,),
) )
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
# conditions 返回 True,命令存在,应执行 # conditions 返回 True,命令存在,应执行
spec = TaskSpec( spec = TaskSpec(
"test", "test",
cmd=["python", "--version"], cmd=[sys.executable, "--version"],
skip_if_missing=True, skip_if_missing=True,
conditions=(lambda: True,), conditions=(lambda _ctx: True,),
) )
assert spec.should_execute() is True assert spec.should_execute({})[0] is True
# conditions 返回 True,命令不存在,应跳过 # conditions 返回 True,命令不存在,应跳过
spec = TaskSpec( spec = TaskSpec(
"test", "test",
cmd=["definitely_not_installed_app_xyz"], cmd=["definitely_not_installed_app_xyz"],
skip_if_missing=True, skip_if_missing=True,
conditions=(lambda: True,), conditions=(lambda _ctx: True,),
) )
assert spec.should_execute() is False assert spec.should_execute({})[0] is False
def test_skip_if_missing_skips_task_in_run(): def test_skip_if_missing_skips_task_in_run():
+30 -24
View File
@@ -52,7 +52,7 @@ def test_taskspec_with_conditions_skip():
"""测试条件不满足时任务被跳过.""" """测试条件不满足时任务被跳过."""
# 创建一个永远不会满足的条件 # 创建一个永远不会满足的条件
def never_true(): def never_true(_ctx):
return False return False
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
@@ -73,7 +73,7 @@ def test_taskspec_with_conditions_execute():
"""测试条件满足时任务正常执行.""" """测试条件满足时任务正常执行."""
# 创建一个总是满足的条件 # 创建一个总是满足的条件
def always_true(): def always_true(_ctx):
return True return True
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
@@ -103,17 +103,17 @@ def test_platform_conditions():
px.TaskSpec( px.TaskSpec(
"win_task", "win_task",
cmd=win_cmd, cmd=win_cmd,
conditions=(lambda: Constants.IS_WINDOWS,), conditions=(lambda _ctx: Constants.IS_WINDOWS,),
), ),
px.TaskSpec( px.TaskSpec(
"linux_task", "linux_task",
cmd=posix_cmd, cmd=posix_cmd,
conditions=(lambda: Constants.IS_LINUX,), conditions=(lambda _ctx: Constants.IS_LINUX,),
), ),
px.TaskSpec( px.TaskSpec(
"macos_task", "macos_task",
cmd=posix_cmd, cmd=posix_cmd,
conditions=(lambda: Constants.IS_MACOS,), conditions=(lambda _ctx: Constants.IS_MACOS,),
), ),
]) ])
@@ -137,17 +137,15 @@ def test_platform_conditions():
def test_app_installed_conditions(): def test_app_installed_conditions():
"""测试应用安装条件.""" """测试应用安装条件."""
# 测试 python 应该总是安装的 # 使用 sys.executable 保证可移植
if sys.platform == "win32": python_cmd = [sys.executable, "--version"]
python_cmd = ["python", "--version"] py_name = "python" if sys.platform == "win32" else "python3"
else:
python_cmd = ["python3", "--version"]
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec( px.TaskSpec(
"python_check", "python_check",
cmd=python_cmd, cmd=python_cmd,
conditions=(BuiltinConditions.HAS_INSTALLED("python"),), conditions=(BuiltinConditions.HAS_INSTALLED(py_name),),
), ),
]) ])
@@ -162,18 +160,18 @@ def test_combined_conditions():
"""测试组合条件.""" """测试组合条件."""
# AND 条件 # AND 条件
and_condition = BuiltinConditions.AND( and_condition = BuiltinConditions.AND(
lambda: True, lambda _ctx: True,
lambda: True, lambda _ctx: True,
) )
# OR 条件 # OR 条件
or_condition = BuiltinConditions.OR( or_condition = BuiltinConditions.OR(
lambda: True, lambda _ctx: True,
lambda: False, lambda _ctx: False,
) )
# NOT 条件 # NOT 条件
not_condition = BuiltinConditions.NOT(lambda: False) not_condition = BuiltinConditions.NOT(lambda _ctx: False)
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec( px.TaskSpec(
@@ -228,7 +226,7 @@ def test_taskspec_with_timeout():
# 短时间任务应该成功 # 短时间任务应该成功
px.TaskSpec( px.TaskSpec(
"short_task", "short_task",
cmd=["python", "-c", "import time; time.sleep(0.1)"], cmd=[sys.executable, "-c", "import time; time.sleep(0.1)"],
timeout=1.0, timeout=1.0,
), ),
]) ])
@@ -245,13 +243,13 @@ def test_taskspec_dependency_with_conditions():
px.TaskSpec( px.TaskSpec(
"first", "first",
cmd=[*ECHO_CMD, "first"], cmd=[*ECHO_CMD, "first"],
conditions=(lambda: True,), conditions=(lambda _ctx: True,),
), ),
px.TaskSpec( px.TaskSpec(
"second", "second",
cmd=[*ECHO_CMD, "second"], cmd=[*ECHO_CMD, "second"],
depends_on=("first",), depends_on=("first",),
conditions=(lambda: True,), conditions=(lambda _ctx: True,),
), ),
px.TaskSpec( px.TaskSpec(
"third", "third",
@@ -378,7 +376,7 @@ class TestTaskSpecVerbose:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec( px.TaskSpec(
"fail", "fail",
cmd=["python", "-c", "import sys; sys.exit(1)"], cmd=[sys.executable, "-c", "import sys; sys.exit(1)"],
verbose=True, verbose=True,
) )
]) ])
@@ -414,7 +412,7 @@ class TestTaskSpecCmdErrors:
px.TaskSpec( px.TaskSpec(
"fail", "fail",
cmd=[ cmd=[
"python", sys.executable,
"-c", "-c",
"import sys; sys.stderr.write('error-msg'); sys.exit(1)", "import sys; sys.stderr.write('error-msg'); sys.exit(1)",
], ],
@@ -437,7 +435,9 @@ class TestTaskSpecCmdErrors:
"""shell 命令失败时应抛出 RuntimeError.""" """shell 命令失败时应抛出 RuntimeError."""
from pyflowx.errors import TaskFailedError from pyflowx.errors import TaskFailedError
graph = px.Graph.from_specs([px.TaskSpec("fail", cmd='python -c "import sys; sys.exit(1)"')]) graph = px.Graph.from_specs([
px.TaskSpec("fail", cmd=f'{sys.executable} -c "import sys; sys.exit(1)"'),
])
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
assert "Shell 命令执行失败" in str(exc_info.value.cause) assert "Shell 命令执行失败" in str(exc_info.value.cause)
@@ -450,7 +450,7 @@ class TestTaskSpecCmdErrors:
graph = px.Graph.from_specs([ graph = px.Graph.from_specs([
px.TaskSpec( px.TaskSpec(
"slow", "slow",
cmd=["python", "-c", "import time; time.sleep(5)"], cmd=[sys.executable, "-c", "import time; time.sleep(5)"],
timeout=0.1, timeout=0.1,
) )
]) ])
@@ -463,7 +463,13 @@ class TestTaskSpecCmdErrors:
"""shell 命令超时应抛出 RuntimeError.""" """shell 命令超时应抛出 RuntimeError."""
from pyflowx.errors import TaskFailedError from pyflowx.errors import TaskFailedError
graph = px.Graph.from_specs([px.TaskSpec("slow", cmd='python -c "import time; time.sleep(5)"', timeout=0.1)]) graph = px.Graph.from_specs([
px.TaskSpec(
"slow",
cmd=f'{sys.executable} -c "import time; time.sleep(5)"',
timeout=0.1,
),
])
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
assert "超时" in str(exc_info.value.cause) assert "超时" in str(exc_info.value.cause)