refactor: 重构重试策略、条件函数与上下文注入逻辑

主要变更:
1. 替换旧retries参数为RetryPolicy配置
2. 重构条件函数,支持上下文参数与动态依赖判断
3. 更新上下文注入逻辑,支持软依赖与更清晰的注入描述
4. 新增sglang CLI命令与相关配置
5. 格式化代码统一列表与参数写法
6. 更新文档与测试用例适配新API
This commit is contained in:
2026-06-27 14:33:54 +08:00
parent 6f01cde8ac
commit 5c8ec281ff
24 changed files with 2796 additions and 1043 deletions
+1
View File
@@ -44,6 +44,7 @@ piptool = "pyflowx.cli.piptool:main"
pymake = "pyflowx.cli.pymake:main"
reseticon = "pyflowx.cli.reseticoncache:main"
scrcap = "pyflowx.cli.screenshot:main"
sglang = "pyflowx.cli.llm.sglang:main"
sshcopy = "pyflowx.cli.sshcopyid:main"
taskk = "pyflowx.cli.taskkill:main"
wch = "pyflowx.cli.which:main"
+35 -17
View File
@@ -4,9 +4,15 @@
--------
* :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。
* :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。
* :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。
* :func:`run` ——以 ``sequential`` / ``thread`` / ``async`` / ``dependency``
策略执行图。
* :class:`RunReport` —— 类型化、可查询的运行结果。
* :class:`Context` —— 整体上下文注入的标注标记。
* :class:`RetryPolicy` —— 重试策略(max_attempts/delay/backoff/jitter/retry_on)。
* :class:`TaskHooks` —— 任务生命周期钩子(pre_run/post_run/on_failure)。
* :class:`GraphDefaults` —— 图级默认值。
* :func:`compose` —— 编程式组合多图。
* :func:`task_template` —— 批量生成相似 TaskSpec 的工厂。
* 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。
快速上手
@@ -18,7 +24,7 @@
graph = px.Graph.from_specs([
px.TaskSpec("extract", extract),
px.TaskSpec("double", double, ("extract",)),
px.TaskSpec("double", double, depends_on=("extract",)),
])
report = px.run(graph, strategy="sequential")
print(report["double"]) # [2, 4, 6]
@@ -29,23 +35,18 @@
from pyflowx.conditions import IS_WINDOWS, BuiltinConditions
graph = px.Graph.from_specs([
# 使用命令列表
px.TaskSpec("list_files", cmd=["ls", "-la"]),
# 使用 shell 命令
px.TaskSpec("check_git", cmd="git status"),
# 条件执行:仅在 Windows 上运行
px.TaskSpec(
"win_only",
cmd=["dir"],
conditions=(IS_WINDOWS,)
),
# 条件执行:仅在 git 已安装时运行
px.TaskSpec(
"git_check",
cmd=["git", "--version"],
conditions=(BuiltinConditions.HAS_INSTALLED("git"),)
),
# 命令不存在时自动跳过(而非失败)
px.TaskSpec(
"optional_build",
cmd=["maturin", "build"],
@@ -58,6 +59,10 @@
from __future__ import annotations
from .conditions import (
IS_LINUX,
IS_MACOS,
IS_POSIX,
IS_WINDOWS,
BuiltinConditions,
Condition,
Constants,
@@ -74,20 +79,33 @@ from .errors import (
TaskTimeoutError,
)
from .executors import Strategy, run
from .graph import Graph, GraphComposer
from .graph import Graph, GraphComposer, GraphDefaults, compose
from .report import RunReport
from .runner import CliExitCode, CliRunner
from .storage import JSONBackend, MemoryBackend, StateBackend
from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus
from .task import (
CacheKeyFn,
RetryPolicy,
TaskCmd,
TaskEvent,
TaskHooks,
TaskResult,
TaskSpec,
TaskStatus,
task_template,
)
__version__ = "0.2.6"
__version__ = "0.3.0"
__all__ = [
"IS_LINUX",
"IS_MACOS",
"IS_POSIX",
"IS_WINDOWS",
"BuiltinConditions",
"CacheKeyFn",
"CliExitCode",
# CLI 运行器
"CliRunner",
# 条件判断
"Condition",
"Constants",
"Context",
@@ -95,28 +113,28 @@ __all__ = [
"DuplicateTaskError",
"Graph",
"GraphComposer",
"GraphDefaults",
"InjectionError",
"JSONBackend",
"MemoryBackend",
"MissingDependencyError",
# 错误
"PyFlowXError",
"RetryPolicy",
"RunReport",
# 状态后端
"StateBackend",
"StorageError",
"Strategy",
"TaskCmd",
"TaskEvent",
"TaskFailedError",
"TaskHooks",
"TaskResult",
# 核心类型
"TaskSpec",
"TaskStatus",
"TaskTimeoutError",
# 辅助(高级)
"build_call_args",
"compose",
"describe_injection",
# 执行
"run",
"task_template",
]
+3 -3
View File
@@ -112,9 +112,9 @@ def main() -> None:
args = parser.parse_args()
if args.command == "mirror":
graph = px.Graph.from_specs(
[px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})]
)
graph = px.Graph.from_specs([
px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})
])
else:
parser.print_help()
return
+2 -2
View File
@@ -43,13 +43,13 @@ def main() -> None:
px.TaskSpec(
"envqt_install",
cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
conditions=(lambda: Constants.IS_LINUX,),
conditions=(lambda _: Constants.IS_LINUX,),
verbose=True,
),
px.TaskSpec(
"envqt_fonts",
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(lambda: Constants.IS_LINUX,),
conditions=(lambda _: Constants.IS_LINUX,),
verbose=True,
),
],
+8 -5
View File
@@ -37,7 +37,7 @@ def init_sub_dirs() -> None:
px.TaskSpec(
"init",
cmd=["git", "init"],
conditions=(not_has_git_repo,),
conditions=(lambda _: not_has_git_repo(),),
cwd=subdir,
),
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",)),
@@ -70,7 +70,7 @@ def main() -> None:
graphs={
# 添加并提交
"a": px.Graph.from_specs([
px.TaskSpec("add", cmd=["git", "add", "."], conditions=(has_files,)),
px.TaskSpec("add", cmd=["git", "add", "."], conditions=(lambda _: has_files(),)),
px.TaskSpec("commit", cmd=["git", "commit", "-m", "chore: update"], depends_on=("add",)),
]),
# 清理
@@ -80,10 +80,13 @@ def main() -> None:
]),
# 初始化、添加并提交
"i": px.Graph.from_specs([
px.TaskSpec("init", cmd=["git", "init"], conditions=(not_has_git_repo,)),
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(has_files,)),
px.TaskSpec("init", cmd=["git", "init"], conditions=(lambda _: not_has_git_repo(),)),
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(lambda _: has_files(),)),
px.TaskSpec(
"commit", cmd=["git", "commit", "-m", "init commit"], depends_on=("add",), conditions=(has_files,)
"commit",
cmd=["git", "commit", "-m", "init commit"],
depends_on=("add",),
conditions=(lambda _: has_files(),),
),
]),
# 初始化子目录
+55
View File
@@ -0,0 +1,55 @@
"""使用 SGLang 运行本地模型."""
import argparse
from pathlib import Path
import pyflowx as px
from pyflowx.conditions import BuiltinConditions
def main():
parser = argparse.ArgumentParser(description="Run a local model using SGLang.")
parser.add_argument("name", help="Model name.")
parser.add_argument("--dir", default=None, help="Directory of model.")
args = parser.parse_args()
if not args.name:
parser.error("name is required")
model_dir = Path(args.dir) if args.dir else Path.home() / ".models" / args.name.split("/")[-1]
if not model_dir.exists():
parser.error(f"Model directory {model_dir} does not exist.")
graph = px.Graph.from_specs([
px.TaskSpec(
name="download",
cmd=[
"uv",
"install",
"sglang[all]",
],
conditions=(BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("sglang")),),
verbose=True,
),
px.TaskSpec(
name="run",
cmd=[
"uvx",
"sglang",
"serve",
"--model-path",
str(model_dir),
"--host",
"0.0.0.0",
"--port",
"8000",
"--mem-fraction-static",
"0.88",
"--context-length",
"32768",
],
verbose=True,
),
])
px.run(graph, verbose=True)
+28 -34
View File
@@ -21,12 +21,10 @@ PACKAGE_DIR = "packages"
REQUIREMENTS_FILE = "requirements.txt"
# 受保护的包名集合
_PROTECTED_PACKAGES: frozenset[str] = frozenset(
{
"pyflowx",
"bitool",
}
)
_PROTECTED_PACKAGES: frozenset[str] = frozenset({
"pyflowx",
"bitool",
})
# ============================================================================
@@ -161,37 +159,33 @@ def main() -> None:
if args.command == "i":
graph = px.Graph.from_specs([px.TaskSpec("pip_install", cmd=["pip", "install", *args.packages], verbose=True)])
elif args.command == "u":
graph = px.Graph.from_specs(
[px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)]
)
graph = px.Graph.from_specs([
px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)
])
elif args.command == "r":
graph = px.Graph.from_specs(
[
px.TaskSpec(
"pip_reinstall",
fn=pip_reinstall,
args=(args.packages,),
kwargs={"offline": args.offline},
verbose=True,
)
]
)
graph = px.Graph.from_specs([
px.TaskSpec(
"pip_reinstall",
fn=pip_reinstall,
args=(args.packages,),
kwargs={"offline": args.offline},
verbose=True,
)
])
elif args.command == "d":
graph = px.Graph.from_specs(
[
px.TaskSpec(
"pip_download",
fn=pip_download,
args=(args.packages,),
kwargs={"offline": args.offline},
verbose=True,
)
]
)
graph = px.Graph.from_specs([
px.TaskSpec(
"pip_download",
fn=pip_download,
args=(args.packages,),
kwargs={"offline": args.offline},
verbose=True,
)
])
elif args.command == "up":
graph = px.Graph.from_specs(
[px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)]
)
graph = px.Graph.from_specs([
px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)
])
elif args.command == "f":
graph = px.Graph.from_specs([px.TaskSpec("pip_freeze", fn=pip_freeze, verbose=True)])
else:
+128 -160
View File
@@ -1,7 +1,12 @@
"""条件判断模块.
提供平台条件、应用安装条件等预定义条件判断函数,
用于 TaskSpec 的条件执行功能.
所有条件均为 ``Callable[[Context], bool]``,接收依赖上下文映射(可能为空)。
这使得条件可基于上游任务的运行时返回值做决策,实现动态分支。
内置条件分两类:
1. *静态条件* —— 不依赖上下文(平台/环境变量/安装检查),通过 ``_static``
包装忽略传入的 context,便于作为模块级常量使用。
2. *上下文条件* —— 基于上游结果判断,如 :meth:`BuiltinConditions.DEP_EQUALS`。
"""
from __future__ import annotations
@@ -11,10 +16,11 @@ import shutil
import subprocess
import sys
from pathlib import Path
from typing import Callable
from typing import Any, Callable
# 条件判断函数类型
Condition = Callable[[], bool]
from .task import Condition, Context
__all__ = ["BuiltinConditions", "Condition", "Constants"]
class Constants:
@@ -26,65 +32,56 @@ class Constants:
IS_POSIX: bool = sys.platform != "win32"
def _static(predicate: Callable[[], bool], name: str) -> Condition:
"""将无参谓词包装为忽略上下文的 :class:`Condition`。"""
def _cond(_ctx: Context) -> bool:
return predicate()
_cond.__name__ = name
return _cond
# ---------------------------------------------------------------------- #
# 模块级静态条件常量
# ---------------------------------------------------------------------- #
IS_WINDOWS: Condition = _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS")
IS_LINUX: Condition = _static(lambda: Constants.IS_LINUX, "IS_LINUX")
IS_MACOS: Condition = _static(lambda: Constants.IS_MACOS, "IS_MACOS")
IS_POSIX: Condition = _static(lambda: Constants.IS_POSIX, "IS_POSIX")
class BuiltinConditions:
"""内置条件判断函数集合."""
"""内置条件判断函数集合.
静态条件工厂返回忽略上下文的 :class:`Condition`;上下文条件工厂返回
会读取依赖结果的 :class:`Condition`。
"""
# ------------------------------------------------------------------ #
# 静态条件
# ------------------------------------------------------------------ #
@staticmethod
def PYTHON_VERSION(major: int, minor: int | None = None) -> bool:
"""检查 Python 版本是否匹配.
Parameters
----------
major : int
主版本号.
minor : int | None
次版本号, 若为 None 则仅检查主版本.
Returns
-------
bool
版本是否匹配.
"""
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
"""检查 Python 版本是否匹配."""
if minor is None:
return sys.version_info.major == major
return sys.version_info.major == major and sys.version_info.minor == minor
return _static(lambda: sys.version_info.major == major, f"PYTHON_VERSION({major})")
return _static(
lambda: sys.version_info.major == major and sys.version_info.minor == minor,
f"PYTHON_VERSION({major},{minor})",
)
@staticmethod
def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> bool:
"""检查 Python 版本是否 >= 指定版本.
Parameters
----------
major : int
主版本号.
minor : int
次版本号.
Returns
-------
bool
当前版本是否 >= 指定版本.
"""
return sys.version_info >= (major, minor)
def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> Condition:
"""检查 Python 版本是否 >= 指定版本."""
return _static(lambda: sys.version_info >= (major, minor), f"PYTHON_VERSION_AT_LEAST({major},{minor})")
@staticmethod
def IS_RUNNING(app_name: str) -> Condition:
"""检查指定应用是否正在运行.
Parameters
----------
app_name : str
应用名称 (如 "explorer", "chrome", "python").
Returns
-------
Condition
条件判断函数.
"""
"""检查指定应用是否正在运行."""
def _check() -> bool:
if Constants.IS_WINDOWS:
# Windows: 使用 tasklist 命令
result = subprocess.run(
["tasklist", "/nh", "/fi", f"imagename eq {app_name}"],
capture_output=True,
@@ -93,148 +90,119 @@ class BuiltinConditions:
)
return app_name.lower() in result.stdout.lower()
else:
# Linux/macOS: 使用 pgrep 命令
result = subprocess.run(
["pgrep", "-x", app_name],
capture_output=True,
check=False,
)
result = subprocess.run(["pgrep", "-x", app_name], capture_output=True, check=False)
return result.returncode == 0
_check.__name__ = f"IS_RUNNING({app_name!r})"
return _check
return _static(_check, f"IS_RUNNING({app_name!r})")
@staticmethod
def HAS_INSTALLED(app_name: str) -> Condition:
"""检查指定应用是否已安装.
Parameters
----------
app_name : str
应用名称 (如 "git", "python", "pytest").
Returns
-------
Condition
条件判断函数.
"""
def _check() -> bool:
return shutil.which(app_name) is not None
_check.__name__ = f"HAS_INSTALLED({app_name!r})"
return _check
"""检查指定应用是否已安装."""
return _static(lambda: shutil.which(app_name) is not None, f"HAS_INSTALLED({app_name!r})")
@staticmethod
def DIR_EXISTS(dir: Path) -> Condition:
def DIR_EXISTS(path: Path) -> Condition:
"""路径是否存在."""
return dir.exists
return _static(path.exists, f"DIR_EXISTS({path!r})")
@staticmethod
def ENV_VAR_EXISTS(var_name: str) -> Condition:
"""检查环境变量是否存在.
Parameters
----------
var_name : str
环境变量名.
Returns
-------
Condition
条件判断函数.
"""
def _check() -> bool:
return var_name in os.environ
_check.__name__ = f"ENV_VAR_EXISTS({var_name!r})"
return _check
"""检查环境变量是否存在."""
return _static(lambda: var_name in os.environ, f"ENV_VAR_EXISTS({var_name!r})")
@staticmethod
def ENV_VAR_EQUALS(var_name: str, value: str) -> Condition:
"""检查环境变量是否等于指定值.
"""检查环境变量是否等于指定值."""
return _static(
lambda: os.environ.get(var_name) == value,
f"ENV_VAR_EQUALS({var_name!r},{value!r})",
)
Parameters
----------
var_name : str
环境变量名.
value : str
期望的值.
# ------------------------------------------------------------------ #
# 上下文条件:基于上游依赖结果
# ------------------------------------------------------------------ #
@staticmethod
def DEP_EQUALS(dep_name: str, value: Any) -> Condition:
"""上游任务 ``dep_name`` 的返回值等于 ``value`` 时为真。
Returns
-------
Condition
条件判断函数.
若依赖未在上下文中(被跳过或未执行),返回 ``False``。
"""
def _check() -> bool:
return os.environ.get(var_name) == value
def _cond(ctx: Context) -> bool:
return dep_name in ctx and ctx[dep_name] == value
_check.__name__ = f"ENV_VAR_EQUALS({var_name!r}, {value!r})"
return _check
_cond.__name__ = f"DEP_EQUALS({dep_name!r},{value!r})"
return _cond
@staticmethod
def NOT(condition: Condition) -> Condition:
"""对条件取反.
def DEP_MATCHES(dep_name: str, predicate: Callable[[Any], bool]) -> Condition:
"""上游任务 ``dep_name`` 的返回值满足 ``predicate`` 时为真。
Parameters
----------
condition : Condition
原始条件.
Returns
-------
Condition
取反后的条件.
依赖不存在时返回 ``False``。
"""
def _check() -> bool:
return not condition()
def _cond(ctx: Context) -> bool:
if dep_name not in ctx:
return False
try:
return predicate(ctx[dep_name])
except Exception:
return False
_check.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
return _check
_cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})"
return _cond
@staticmethod
def DEP_PRESENT(dep_name: str) -> Condition:
"""上游任务 ``dep_name`` 存在于上下文(即已成功执行)时为真。"""
def _cond(ctx: Context) -> bool:
return dep_name in ctx and ctx[dep_name] is not None
_cond.__name__ = f"DEP_PRESENT({dep_name!r})"
return _cond
@staticmethod
def DEP_TRUTHY(dep_name: str) -> Condition:
"""上游任务 ``dep_name`` 的返回值为真值时为真。"""
def _cond(ctx: Context) -> bool:
return bool(ctx.get(dep_name))
_cond.__name__ = f"DEP_TRUTHY({dep_name!r})"
return _cond
# ------------------------------------------------------------------ #
# 逻辑组合
# ------------------------------------------------------------------ #
@staticmethod
def NOT(condition: Condition) -> Condition:
"""对条件取反."""
def _cond(ctx: Context) -> bool:
return not condition(ctx)
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
return _cond
@staticmethod
def AND(*conditions: Condition) -> Condition:
"""多个条件的逻辑与.
"""多个条件的逻辑与."""
Parameters
----------
*conditions : Condition
条件列表.
Returns
-------
Condition
组合条件.
"""
def _check() -> bool:
return all(c() for c in conditions)
def _cond(ctx: Context) -> bool:
return all(c(ctx) for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_check.__name__ = f"AND({', '.join(names)})"
return _check
_cond.__name__ = f"AND({', '.join(names)})"
return _cond
@staticmethod
def OR(*conditions: Condition) -> Condition:
"""多个条件的逻辑或.
"""多个条件的逻辑或."""
Parameters
----------
*conditions : Condition
条件列表.
Returns
-------
Condition
组合条件.
"""
def _check() -> bool:
return any(c() for c in conditions)
def _cond(ctx: Context) -> bool:
return any(c(ctx) for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_check.__name__ = f"OR({', '.join(names)})"
return _check
_cond.__name__ = f"OR({', '.join(names)})"
return _cond
+15 -59
View File
@@ -1,18 +1,16 @@
"""上下文注入:把上游结果转换为函数参数。
本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他
DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))``
DAG 库中泛滥的样板包装器。
注入规则(按顺序求值)
----------------------
1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历
所有输入的任务
2. **名称匹配某个依赖**的参数接收该依赖的结果。
1. **标注为** :class:`Context` 的参数接收完整结果映射(含硬依赖与软依赖)。
2. **名称匹配某个依赖**(硬或软)的参数接收该依赖的结果
3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。
若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`
并附带精确错误信息。
若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`
"""
from __future__ import annotations
@@ -27,21 +25,11 @@ __all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_inj
def _is_context_annotation(annotation: Any) -> bool:
"""判断参数标注是否为(或指向)``Context``。
处理三种形式:
* ``Context`` 别名对象本身;
* ``__name__``/``_name`` 为 ``Context`` 或 ``Mapping`` 的 typing 别名;
* *字符串*标注(``from __future__ import annotations`` 会在运行时
把所有标注变为字符串),如 ``"Context"`` 或 ``"px.Context"``。
"""
"""判断参数标注是否为(或指向)``Context``。"""
if annotation is Context:
return True
# `from __future__ import annotations` 产生的字符串标注。
if isinstance(annotation, str):
# 匹配 "Context"、"px.Context"、"pyflowx.Context" 等。
return annotation == "Context" or annotation.endswith(".Context")
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
return name in ("Context", "Mapping")
@@ -52,39 +40,22 @@ def build_call_args(
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
参数
----
spec:
任务 spec,提供 ``fn``、``depends_on``、``args``、``kwargs``。
context:
依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目
存在;其他任务的结果被排除,以保持注入的确定性。
返回
----
(args, kwargs)
可直接展开为 ``spec.fn(*args, **kwargs)``。
抛出
----
InjectionError
若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。
``context`` 必须已包含所有硬依赖与软依赖的结果(软依赖被跳过时由
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
"""
# 使用 effective_fn 而不是 fn,以支持 cmd 参数
fn = spec.effective_fn
sig = inspect.signature(fn)
params = sig.parameters
# 检测特殊参数类型。
var_keyword = next(
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# 本任务相关的上下文子集。
dep_context: dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context}
# 本任务相关的上下文子集:硬依赖 + 软依赖
all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
dep_context: dict[str, Any] = {name: context[name] for name in all_deps if name in context}
# 检测静态 kwargs 与依赖名的冲突。
collisions = set(spec.kwargs) & set(dep_context)
if collisions:
raise InjectionError(
@@ -96,8 +67,6 @@ def build_call_args(
injected_kwargs: dict[str, Any] = {}
leftover_dep_results: dict[str, Any] = dict(dep_context)
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
positional_params: list[str] = []
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
@@ -106,33 +75,25 @@ def build_call_args(
for pname, param in params.items():
if param.kind in positional_kinds:
positional_params.append(pname)
# 前 len(spec.args) 个位置参数由 spec.args 填充。
args_filled: set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items():
# 跳过已被位置 spec.args 填充的参数。
if pname in args_filled:
continue
# 规则 1:标注为 Context -> 完整映射。
if _is_context_annotation(param.annotation):
injected_kwargs[pname] = dep_context
continue
# 规则 2:名称匹配某个依赖。
if pname in dep_context:
injected_kwargs[pname] = dep_context[pname]
leftover_dep_results.pop(pname, None)
continue
# 规则 3:在循环后通过 **kwargs 处理。
# 规则 4:静态 kwargs 填充其余参数。
if pname in spec.kwargs:
injected_kwargs[pname] = spec.kwargs[pname]
continue
# 该参数无来源:必须有默认值,否则报错。
if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
@@ -142,9 +103,7 @@ def build_call_args(
f"parameter {pname!r} has no dependency, static value, or default.",
)
# 规则 3:**kwargs 吞掉剩余依赖结果。
if var_keyword is not None and leftover_dep_results:
# 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。
merged = dict(spec.kwargs)
merged.update(injected_kwargs)
merged.update(leftover_dep_results)
@@ -154,14 +113,9 @@ def build_call_args(
def describe_injection(spec: TaskSpec[Any]) -> str:
"""生成任务参数注入方式的人类可读描述。
供 ``dry_run`` 使用,在不执行的情况下展示执行计划。
"""
# 使用 effective_fn 而不是 fn,以支持 cmd 参数
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
fn = spec.effective_fn
sig = inspect.signature(fn)
# 确定哪些位置参数由 spec.args 填充。
positional_params = [
p
for p, param in sig.parameters.items()
@@ -172,6 +126,7 @@ def describe_injection(spec: TaskSpec[Any]) -> str:
)
]
args_filled = set(positional_params[: len(spec.args)])
all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
parts = []
for pname, param in sig.parameters.items():
if pname in args_filled:
@@ -179,8 +134,9 @@ def describe_injection(spec: TaskSpec[Any]) -> str:
parts.append(f"{pname}={spec.args[idx]!r}")
elif _is_context_annotation(param.annotation):
parts.append(f"{pname}=<Context>")
elif pname in spec.depends_on:
parts.append(f"{pname}=<result:{pname}>")
elif pname in all_deps:
tag = "soft" if pname in spec.soft_depends_on else "dep"
parts.append(f"{pname}=<{tag}:{pname}>")
elif pname in spec.kwargs:
parts.append(f"{pname}={spec.kwargs[pname]!r}")
elif param.default is not inspect.Parameter.empty:
+3 -1
View File
@@ -55,7 +55,9 @@ def main() -> None:
depends_on=("extract_customers", "extract_orders"),
tags=("transform",),
),
px.TaskSpec("load", load, depends_on=("transform",), retries=1, tags=("load",)),
px.TaskSpec(
"load", load, depends_on=("transform",), retry=px.RetryPolicy(max_attempts=1, delay=1.0), tags=("load",)
),
])
print("=== Execution plan ===")
+365 -179
View File
@@ -1,15 +1,26 @@
"""执行器与公共 :func:`run` 入口。
种执行策略共享一个逐层驱动器
种执行策略:
* ``sequential`` —— 确定性、一次一个任务。最适合调试。
* ``thread`` —— 通过线程池实现层内并发。最适合 I/O 密集型同步任务。
* ``async`` —— 通过 ``asyncio.gather`` 实现层内并发。同步任务被
卸载到线程池;异步任务运行在事件循环上。最适合
I/O 密集型异步任务。
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
无需等待同层其他任务。最大化并行度。
三者都遵循 ``retries``、``timeout``、上下文注入、状态后端(续跑),
并向观察者发出 :class:`~pyflowx.task.TaskEvent`。
所有策略共享统一异步内核,支持:
* :class:`RetryPolicy`max_attempts/delay/backoff/jitter/retry_on
* 软依赖注入与默认值
* :class:`TaskHooks`pre_run/post_run/on_failure
* 按任务策略覆盖
* 优先级排序(同层内)
* 并发限制(concurrency_key + concurrency_limits
* ``continue_on_error``
* ``cache_key`` 存储键
* 条件判断(上下文感知)
* 状态后端(续跑)
"""
from __future__ import annotations
@@ -18,6 +29,7 @@ import asyncio
import concurrent.futures
import inspect
import logging
import threading
from datetime import datetime
from typing import Any, Awaitable, Callable, Literal, Mapping, cast
@@ -26,24 +38,24 @@ from .errors import TaskFailedError, TaskTimeoutError
from .graph import Graph
from .report import RunReport
from .storage import StateBackend, resolve_backend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger("pyflowx")
# 观察者回调类型。
EventCallback = Callable[[TaskEvent], None]
Strategy = Literal["sequential", "thread", "async"]
Strategy = Literal["sequential", "thread", "async", "dependency"]
# ---------------------------------------------------------------------- #
# 辅助
# ---------------------------------------------------------------------- #
def _is_async_fn(spec: TaskSpec[Any]) -> bool:
"""判断 ``spec.effective_fn`` 是否为协程函数。"""
return inspect.iscoroutinefunction(spec.effective_fn)
def _emit(
on_event: EventCallback | None,
result: TaskResult[Any],
) -> None:
def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
"""若注册了回调则触发一个观察者事件。"""
if on_event is None:
return
@@ -59,71 +71,60 @@ def _emit(
)
def _log_retry(spec: TaskSpec[Any], attempts: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志sync 与 async 共享,便于测试覆盖)"""
def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
attempts,
attempt,
max_attempts,
exc,
)
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
) -> None:
"""标记任务为 FAILED 并抛出 TaskFailedError。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
if hook is None:
return
try:
hook(*args)
except Exception as exc:
logger.warning("hook %s raised: %r", fn_name, exc)
def _check_upstream_skipped(
spec: TaskSpec[Any],
report: RunReport | None,
) -> tuple[bool, str | None]:
"""检查上游任务是否被 SKIPPED。
"""检查硬依赖上游任务是否被 SKIPPED 或 FAILED
Returns
-------
tuple[bool, str | None]
(是否应该跳过, 跳过原因)
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None:
return False, None
# 若任务允许上游跳过,则不检查上游状态
if spec.allow_upstream_skip:
return False, None
for dep in spec.depends_on:
if dep in report.results and report.results[dep].status == TaskStatus.SKIPPED:
return True, f"上游任务 '{dep}' 被跳过"
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
return False, None
def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None:
"""单次求值所有条件与 skip_if_missing,返回跳过原因或 None。
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
"""求值所有条件,返回跳过原因或 ``None``
与旧实现不同:条件只求值一次。`should_execute()` 内部会调用所有条件,
若再分支调用 `_is_cmd_available` 之外的逻辑会二次求值(如
``IS_RUNNING`` 会 spawn 两次 subprocess)。此处显式逐个求值并记录结果,
失败原因直接来自求值过程,无需二次调用。
条件接收上下文映射(硬依赖 + 软依赖结果)。
"""
# 1. 逐个求值条件,记录失败项。
failed_conditions: list[str] = []
for condition in spec.conditions:
try:
ok = condition()
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
@@ -135,7 +136,6 @@ def _evaluate_skip_reason(spec: TaskSpec[Any]) -> str | None:
if failed_conditions:
return f"条件不满足: {', '.join(failed_conditions)}"
# 2. skip_if_missing 检查(仅对 list[str] 命令有效)。
if spec.skip_if_missing and not spec._is_cmd_available():
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
return f"命令不存在: {cmd_name}"
@@ -148,10 +148,7 @@ def _make_skipped_result(
reason: str,
on_event: EventCallback | None,
) -> TaskResult[Any]:
"""构造 SKIPPED 的 TaskResult 并发出事件、打印日志。
sync 与 async 执行路径共用,消除重复的 result 构造/emit/print 代码。
"""
"""构造 SKIPPED 的 TaskResult"""
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
@@ -165,31 +162,118 @@ def _make_skipped_result(
return result
def _build_context(
spec: TaskSpec[Any],
global_context: Mapping[str, Any],
report: RunReport | None = None, # noqa: ARG001
) -> dict[str, Any]:
"""构建本任务的上下文:硬依赖 + 软依赖(含默认值回退)。
硬依赖:若上游 SKIPPED/FAILED 则不注入(本任务通常也会被跳过)。
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
"""
ctx: dict[str, Any] = {}
for dep in spec.depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
for dep in spec.soft_depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
elif dep in spec.defaults:
ctx[dep] = spec.defaults[dep]
else:
ctx[dep] = None
return ctx
def _apply_cached(
name: str,
spec: TaskSpec[Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。"""
storage_key = spec.storage_key(context)
if not backend.has(storage_key):
return False
cached = backend.get(storage_key)
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
return True
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
) -> TaskResult[Any] | None:
"""执行前的统一预检:上游跳过 / 条件跳过。
"""执行前预检:上游跳过 / 条件跳过。
Returns
-------
TaskResult | None
若应跳过,返回已填好的 SKIPPED 结果;否则返回 None 表示继续执行。
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
"""
# 上游跳过检查
should_skip, skip_reason = _check_upstream_skipped(spec, report)
if should_skip:
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event)
# 条件 / skip_if_missing 检查(单次求值)
skip_reason = _evaluate_skip_reason(spec)
skip_reason = _evaluate_conditions(spec, context)
if skip_reason is not None:
return _make_skipped_result(spec, skip_reason, on_event)
return None
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
continue_on_error: bool = False,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
)
return
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的同步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
import time
time.sleep(wait)
async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的异步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
await asyncio.sleep(wait)
# ---------------------------------------------------------------------- #
# 同步执行内核
# ---------------------------------------------------------------------- #
def _run_sync_with_retry(
spec: TaskSpec[Any],
context: Mapping[str, Any],
@@ -198,44 +282,47 @@ def _run_sync_with_retry(
report: RunReport | None = None,
) -> TaskResult[Any]:
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
# 统一预检:上游跳过 / 条件跳过(条件单次求值)
skipped = _prepare_for_execution(spec, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
while True:
result.attempts += 1
try:
result.value = spec.effective_fn(*args, **kwargs)
with spec.env_context():
result.value = spec.effective_fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx, on_event)
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc)
raise AssertionError("unreachable") # pragma: no cover
_sleep_for_retry(spec, result.attempts)
# pragma: no cover
# ---------------------------------------------------------------------- #
# 异步执行内核
# ---------------------------------------------------------------------- #
async def _execute_async_task(
spec: TaskSpec[Any],
args: tuple[Any, ...],
kwargs: dict[str, Any],
loop: asyncio.AbstractEventLoop,
) -> Any:
"""执行异步或同步任务(带超时处理)。
Returns
-------
Any
任务返回值
"""
"""执行异步或同步任务(带超时处理)。"""
if _is_async_fn(spec):
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None:
@@ -243,9 +330,10 @@ async def _execute_async_task(
else:
return await coro
else:
# 将同步工作卸载到线程,保持事件循环存活。
def fn_call() -> Any:
return spec.effective_fn(*args, **kwargs)
with spec.env_context():
return spec.effective_fn(*args, **kwargs)
if spec.timeout is not None:
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
@@ -259,76 +347,74 @@ async def _run_async_with_retry(
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
# 统一预检:上游跳过 / 条件跳过(条件单次求值)
skipped = _prepare_for_execution(spec, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
result: TaskResult[Any] = TaskResult[Any](spec=spec)
if semaphore is not None:
async with semaphore:
return await _run_async_inner(spec, context, layer_idx, on_event, report)
return await _run_async_inner(spec, context, layer_idx, on_event, report)
async def _run_async_inner(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None, # noqa: ARG001
) -> TaskResult[Any]:
"""异步执行内核的内部实现(已获取 semaphore 后)。"""
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
while True:
result.attempts += 1
try:
result.value = await _execute_async_task(spec, args, kwargs, loop)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result
except asyncio.TimeoutError:
result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0)
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx, on_event)
exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
await _async_sleep_for_retry(spec, result.attempts)
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx, on_event)
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc)
raise AssertionError("unreachable") # pragma: no cover
await _async_sleep_for_retry(spec, result.attempts)
# pragma: no cover
# ---------------------------------------------------------------------- #
# 层驱动
# 层执行
# ---------------------------------------------------------------------- #
def _build_context(
spec: TaskSpec[Any],
global_context: Mapping[str, Any],
) -> Mapping[str, Any]:
"""将全局上下文限制为本任务的依赖。"""
return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context}
def _apply_cached(
name: str,
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True;否则返回 False。
sequential / thread / async 三种层驱动共用,消除缓存命中分支的重复代码。
"""
if not backend.has(name):
return False
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
return True
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
def _execute_layer_sequential(
@@ -340,14 +426,16 @@ def _execute_layer_sequential(
layer_idx: int,
on_event: EventCallback | None,
) -> None:
"""逐个运行某层的任务。"""
for name in layer:
spec = graph.spec(name)
if _apply_cached(name, graph, context, report, backend, on_event):
"""逐个运行某层的任务(按优先级排序)"""
for name in _sort_by_priority(layer, graph):
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
result = _run_sync_with_retry(spec, _build_context(spec, context), layer_idx, on_event, report)
task_ctx = _build_context(spec, context, report)
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
context[name] = result.value
backend.save(name, result.value)
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
@@ -361,42 +449,68 @@ def _execute_layer_threaded(
layer_idx: int,
on_event: EventCallback | None,
max_workers: int,
concurrency_limits: Mapping[str, int],
) -> None:
"""在线程池中并发运行某层的任务。"""
# 先同步满足已缓存任务。
to_run: list[str] = []
for name in layer:
if _apply_cached(name, graph, context, report, backend, on_event):
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
if not to_run:
return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建线程信号量
semaphores: dict[str, threading.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = threading.Semaphore(limit)
context_snapshot = dict(context)
lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
sem.acquire()
try:
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
finally:
if sem is not None:
sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {}
for name in to_run:
spec = graph.spec(name)
# 为本任务快照上下文以避免竞态。
task_ctx = _build_context(spec, context)
fut = pool.submit(_run_sync_with_retry, spec, task_ctx, layer_idx, on_event, report)
fut = pool.submit(_run_threaded_task, name)
future_to_name[fut] = name
# 统一收集后再写 context,与 async 版本行为一致:
# 避免边完成边写共享 dict 造成的可见性不一致。
completed: dict[str, TaskResult[Any]] = {}
try:
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
result = fut.result() # 失败时抛出 TaskFailedError
result = fut.result()
completed[name] = result
finally:
# 无论是否抛出,都先把已完成任务的结果落盘并写回 context/report。
for name, result in completed.items():
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
with lock:
for name, result in completed.items():
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
async def _execute_layer_async(
@@ -407,52 +521,122 @@ async def _execute_layer_async(
backend: StateBackend,
layer_idx: int,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""在事件循环上并发运行某层的任务。"""
to_run: list[str] = []
for name in layer:
if _apply_cached(name, graph, context, report, backend, on_event):
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
if not to_run:
return
coros = []
for name in to_run:
spec = graph.spec(name)
task_ctx = _build_context(spec, context)
coros.append(_run_async_with_retry(spec, task_ctx, layer_idx, on_event, report))
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建异步信号量
semaphores: dict[str, asyncio.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
context_snapshot = dict(context)
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
coros = [_run_async_task_wrapped(name) for name in to_run]
results = await asyncio.gather(*coros)
for name, result in zip(to_run, results):
context[name] = result.value
backend.save(name, result.value)
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
# ---------------------------------------------------------------------- #
# 依赖驱动调度
# ---------------------------------------------------------------------- #
async def _drive_dependency_async(
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
all_names = set(graph.all_specs().keys())
semaphores: dict[str, asyncio.Semaphore] = {}
for name in all_names:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
# 等待所有硬依赖完成
for dep in spec.depends_on:
if dep in futures:
await futures[dep]
# 等待所有软依赖完成(但不检查其状态)
for dep in spec.soft_depends_on:
if dep in futures:
await futures[dep]
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name]
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
else:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
return result
loop = asyncio.get_event_loop()
for name in all_names:
futures[name] = loop.create_task(_run_task(name))
await asyncio.gather(*futures.values())
# ---------------------------------------------------------------------- #
# 公共 API
# ---------------------------------------------------------------------- #
def _make_verbose_callback(
on_event: EventCallback | None,
) -> EventCallback | None:
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期.
Parameters
----------
on_event : EventCallback | None
用户提供的原始回调, 若为 None 则仅打印.
Returns
-------
EventCallback | None
包装后的回调.
"""
def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期。"""
def _verbose_callback(event: TaskEvent) -> None:
# 先打印生命周期信息
dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
if event.status == TaskStatus.RUNNING: # pragma: no cover
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
@@ -464,13 +648,9 @@ def _make_verbose_callback(
f"[verbose] 任务 {event.task!r} 失败{dur} (尝试 {event.attempts} 次){err}",
flush=True,
)
elif event.status == TaskStatus.SKIPPED: # pragma: no branch
elif event.status == TaskStatus.SKIPPED:
reason = f" ({event.reason})" if event.reason else ""
print(f"[verbose] 任务 {event.task!r} 跳过{reason}", flush=True)
else: # pragma: no cover
# 不可达: 执行器只发出 RUNNING/SUCCESS/FAILED/SKIPPED 事件
pass
# 再调用用户回调
if on_event is not None:
on_event(event)
@@ -486,6 +666,7 @@ def run(
verbose: bool = False,
on_event: EventCallback | None = None,
state: StateBackend | None = None,
concurrency_limits: Mapping[str, int] | None = None,
) -> RunReport:
"""执行图并返回 :class:`RunReport`。
@@ -494,29 +675,28 @@ def run(
graph:
待执行的已校验 :class:`Graph`。
strategy:
执行策略, 接受 :class:`Strategy` 枚举成员或字符串
(``"sequential"`` / ``"thread"`` / ``"async"``). 默认 ``Strategy.SEQUENTIAL``.
执行策略: ``"sequential"`` / ``"thread"`` / ``"async"`` /
``"dependency"````"dependency"`` 为依赖驱动调度,无层屏障。
max_workers:
``"thread"`` 的线程池大小。默认 ``min(32, len(layer))``。
dry_run:
若为 ``True``,打印执行计划(层 + 注入)并返回空报告,不执行
任何任务。
若为 ``True``,打印执行计划并返回空报告,不执行任务。
verbose:
若为 ``True``, 打印任务生命周期 (开始/成功/失败/跳过) 到 stdout.
注意: subprocess 命令的输出由 :class:`TaskSpec` 的 ``verbose`` 字段控制.
若为 ``True``, 打印任务生命周期到 stdout
on_event:
可选回调,在每次状态转换时调用。
state:
可选 :class:`StateBackend`,用于断点续跑。默认为内存后端
(不跨进程持久化)。
可选 :class:`StateBackend`,用于断点续跑。
concurrency_limits:
``{concurrency_key: max_concurrent}`` 映射。具有相同
``concurrency_key`` 的任务共享信号量,限制同时运行实例数。
抛出
----
ValueError
``strategy`` 不被识别时。
TaskFailedError
任何任务耗尽重试后仍失败时。运行在失败层中止;后续层的任务
不会被执行。
任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
"""
graph.validate()
layers = graph.layers()
@@ -525,20 +705,23 @@ def run(
_print_dry_run(graph, layers)
return RunReport(success=True)
# verbose 模式下包装事件回调
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
backend = resolve_backend(state)
report = RunReport()
context: dict[str, Any] = {}
limits = concurrency_limits or {}
try:
if strategy == "sequential":
_drive_sequential(graph, layers, context, report, backend, effective_callback)
elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers)
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
_drive_async(graph, layers, context, report, backend, effective_callback, limits)
elif strategy == "dependency":
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits))
else:
_drive_async(graph, layers, context, report, backend, effective_callback)
raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError:
report.success = False
raise
@@ -552,7 +735,7 @@ def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None:
for idx, layer in enumerate(layers, 1):
print(f" Layer {idx}: {layer}")
for name in layer:
print(f" - {describe_injection(graph.spec(name))}")
print(f" - {describe_injection(graph.resolved_spec(name))}")
def _drive_sequential(
@@ -575,10 +758,11 @@ def _drive_threaded(
backend: StateBackend,
on_event: EventCallback | None,
max_workers: int | None,
concurrency_limits: Mapping[str, int],
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers)
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
def _drive_async(
@@ -588,8 +772,9 @@ def _drive_async(
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
async def _async_drive(
@@ -599,6 +784,7 @@ async def _async_drive(
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event)
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
+202 -140
View File
@@ -2,28 +2,53 @@
使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
支持:
* 图级默认值 :class:`GraphDefaults`TaskSpec 字段为 ``None`` 时回退。
* :meth:`Graph.map` 工厂批量生成 fan-out 任务。
* 字符串引用与 :func:`compose` 编程式组合多个图。
* 软依赖:仅用于上下文注入,不参与拓扑分层。
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field, replace
from typing import Any, Iterable, Mapping, Sequence
from typing import Any, Callable, Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec
from .task import RetryPolicy, TaskSpec
# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。
if sys.version_info >= (3, 9): # pragma: no cover
import graphlib # pyright: ignore[reportUnreachable]
_TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover
import graphlib # type: ignore[import-untyped] # pragma: no cover
import graphlib # type: ignore[import-untyped]
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
@dataclass
class GraphDefaults:
"""图级默认值。TaskSpec 对应字段为 ``None`` 时回退到此处。
仅对可空字段生效(retry/timeout/strategy/env/cwd/tags/priority/
continue_on_error/concurrency_key)。非空字段(name/fn/cmd)不回退。
"""
retry: RetryPolicy | None = None
timeout: float | None = None
strategy: str | None = None
tags: tuple[str, ...] = ()
env: Mapping[str, str] | None = None
cwd: Any = None # Path | None
priority: int = 0
continue_on_error: bool = False
concurrency_key: str | None = None
verbose: bool = False
@dataclass
class Graph:
"""校验后的有向无环任务图。
@@ -34,16 +59,11 @@ class Graph:
图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。
这使图可安全重复运行并在线程间共享。
Note
-----
Graph 不再使用 ``frozen=True``:内部 ``specs``/``deps`` 本就是可变 dict
frozen 既无法真正保证不可变,又迫使 ``_pending_refs`` 等场景用
``object.__setattr__`` 绕过。改为普通 dataclass,让赋值显式且可审计。
"""
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
defaults: GraphDefaults = field(default_factory=GraphDefaults)
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
_pending_refs: list[str] = field(default_factory=list)
@@ -51,69 +71,47 @@ class Graph:
# 构建
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[Any]) -> Graph:
"""注册一个任务 spec,并即时校验。
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
它会整批校验(允许单次调用中的前向引用)。
"""
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self.specs[spec.name] = spec
self.deps[spec.name] = spec.depends_on
# 为增量 API 即时检查重名与缺失依赖。
"""注册一个任务 spec,并即时校验。返回 ``self`` 支持链式调用。"""
self._register(spec)
self._validate_references()
return self
def _register(self, spec: TaskSpec[Any]) -> None:
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self.specs[spec.name] = spec
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
self.deps[spec.name] = spec.depends_on
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[Any] | str]) -> Graph:
"""从可迭代的 task spec 构建图.
def from_specs(
cls,
specs: Iterable[TaskSpec[Any] | str],
defaults: GraphDefaults | None = None,
) -> Graph:
"""从可迭代的 task spec 构建图。
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
依赖——顺序无关,就像声明式配置文件的读取方式
支持字符串引用,允许引用其他命令图中的任务。
字符串引用将在CliRunner中解析展开。
先收集所有 spec,再统一校验。允许前向引用。支持字符串引用,
由 :func:`compose` 或 :class:`GraphComposer` 解析展开
Parameters
----------
specs : Iterable[TaskSpec[Any] | str]
TaskSpec对象或字符串引用的列表
Returns
-------
Graph
构建完成的图
Note
-----
字符串引用格式:
- "command_name" - 引用整个命令图
- "command_name.task_name" - 引用特定任务
Examples
--------
>>> graph = Graph.from_specs([
... TaskSpec("build", cmd=["uv", "build"]),
... "test", # 引用test命令图
... ])
specs:
TaskSpec 对象或字符串引用的列表
defaults:
图级默认值。``None`` 使用空 :class:`GraphDefaults`。
"""
graph = cls()
graph = cls(defaults=defaults or GraphDefaults())
pending_refs: list[str] = []
for spec in specs:
if isinstance(spec, str):
# 字符串引用,稍后解析
pending_refs.append(spec)
elif isinstance(spec, TaskSpec):
if spec.name in graph.specs:
raise DuplicateTaskError(spec.name)
graph.specs[spec.name] = spec
graph.deps[spec.name] = spec.depends_on
graph._register(spec)
else:
raise TypeError(f"from_specs只接受TaskSpecstr,收到: {type(spec)}")
raise TypeError(f"from_specs 只接受 TaskSpecstr,收到: {type(spec)}")
# 存储待解析的引用,稍后由 GraphComposer 解析展开。
# Graph 不再 frozen,可直接赋值;保留属性名以保持向后兼容。
if pending_refs:
graph._pending_refs = pending_refs
@@ -125,26 +123,22 @@ class Graph:
# 校验
# ------------------------------------------------------------------ #
def _validate_references(self) -> None:
"""确保每个依赖名都存在于图中。"""
for name, deps in self.deps.items():
for dep in deps:
"""确保每个依赖名都存在于图中。硬依赖与软依赖都校验。"""
for name, spec in self.specs.items():
for dep in spec.depends_on:
if dep not in self.specs:
raise MissingDependencyError(name, dep)
for dep in spec.soft_depends_on:
if dep not in self.specs:
raise MissingDependencyError(name, dep)
def validate(self) -> None:
"""执行完整 DAG 校验。
存在环时抛出 :class:`~pyflowx.errors.CycleError`。
依赖存在性由 :meth:`_validate_references` 检查。
"""
"""执行完整 DAG 校验。存在环时抛出 :class:`CycleError`。"""
self._validate_references()
sorter = _TopologicalSorter(self.deps)
try:
# prepare() 在有环时抛出 CycleError;此处不需要
# static_order() 的结果,仅利用其校验副作用。
sorter.prepare()
except graphlib.CycleError as exc:
# exc.args[1] 是构成环的节点列表。
except graphlib.CycleError as exc: # type: ignore[name-defined]
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc
@@ -160,10 +154,49 @@ class Graph:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self.specs[name]
def resolved_spec(self, name: str) -> TaskSpec[Any]:
"""返回应用图级默认值后的 spec(不修改原图)。
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
字段,若 spec 字段为默认空值且图级默认值非空,则用
:func:`dataclasses.replace` 生成带默认值的副本。
"""
spec = self.specs[name]
d = self.defaults
overrides: dict[str, Any] = {}
if spec.retry == RetryPolicy() and d.retry is not None:
overrides["retry"] = d.retry
if spec.timeout is None and d.timeout is not None:
overrides["timeout"] = d.timeout
if spec.strategy is None and d.strategy is not None:
overrides["strategy"] = d.strategy
if spec.env is None and d.env is not None:
overrides["env"] = d.env
if spec.cwd is None and d.cwd is not None:
overrides["cwd"] = d.cwd
if spec.priority == 0 and d.priority != 0:
overrides["priority"] = d.priority
if not spec.continue_on_error and d.continue_on_error:
overrides["continue_on_error"] = True
if spec.concurrency_key is None and d.concurrency_key is not None:
overrides["concurrency_key"] = d.concurrency_key
if not spec.verbose and d.verbose:
overrides["verbose"] = True
if not spec.tags and d.tags:
overrides["tags"] = d.tags
if not overrides:
return spec
return replace(spec, **overrides)
def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。"""
"""``name`` 的直接硬依赖前驱。"""
return self.deps[name]
def all_deps(self, name: str) -> tuple[str, ...]:
"""``name`` 的硬依赖 + 软依赖。"""
spec = self.specs[name]
return tuple(spec.depends_on) + tuple(spec.soft_depends_on)
def all_specs(self) -> Mapping[str, TaskSpec[Any]]:
"""name -> spec 的只读视图。"""
return self.specs
@@ -171,18 +204,15 @@ class Graph:
def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。
同层任务无相互依赖,可并发执行。层按执行顺序返回
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
同层任务无相互依赖,可并发执行。软依赖不参与分层
层按执行顺序返回。图有环时抛出 :class:`CycleError`。
"""
self.validate()
sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare()
while sorter.is_active():
ready = list(sorter.get_ready())
# 排序以保证确定性、可复现的执行计划。
ready.sort()
result.append(ready)
for node in ready:
@@ -193,12 +223,7 @@ class Graph:
# 子图 / 标签过滤
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
DAG 的切片。
"""
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
@@ -206,10 +231,11 @@ class Graph:
pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
# 使用 replace 保留所有字段(verbose/skip_if_missing/allow_upstream_skip 等),
# 避免手动逐字段重建时遗漏新增字段。
kept.append(replace(spec, depends_on=pruned_deps))
return Graph.from_specs(kept)
pruned_soft = tuple(
d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults)
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。"""
@@ -221,18 +247,71 @@ class Graph:
for spec in self.specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps))
return Graph.from_specs(kept)
pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults)
# ------------------------------------------------------------------ #
# Fan-out / map-reduce
# ------------------------------------------------------------------ #
def map(
self,
name_fn: Callable[[int], str],
spec: TaskSpec[Any],
items: Sequence[Any],
arg_factory: Callable[[Any], tuple[Any, ...]] | None = None,
depends_on_per: Callable[[int], tuple[str, ...]] | None = None,
) -> list[TaskSpec[Any]]:
"""为 ``items`` 中每个元素生成一个 TaskSpec 并加入图。
用于 fan-out / map-reduce 模式。返回生成的 spec 列表,便于
后续 reduce 任务依赖。
Parameters
----------
name_fn:
接受索引 ``i``,返回任务名。需保证唯一。
spec:
模板 spec。其 ``name`` 与 ``args`` 会被覆盖。
items:
待分发的数据序列。
arg_factory:
接受一个 item,返回位置参数元组,覆盖 spec.args。
``None`` 则将单个 item 作为唯一位置参数。
depends_on_per:
接受索引 ``i``,返回该任务的额外硬依赖。``None`` 则继承 spec.depends_on。
Returns
-------
list[TaskSpec]
生成的 spec 列表(已加入图)。
Examples
--------
>>> fetch_tmpl = px.TaskSpec("", fn=fetch_user)
>>> specs = graph.map(lambda i: f"fetch_{i}", fetch_tmpl, [1, 2, 3])
>>> reduce_spec = px.TaskSpec("reduce", fn=reduce_fn, depends_on=tuple(s.name for s in specs))
"""
generated: list[TaskSpec[Any]] = []
for i, item in enumerate(items):
name = name_fn(i)
args = arg_factory(item) if arg_factory is not None else (item,)
extra_deps = depends_on_per(i) if depends_on_per is not None else ()
new_spec = replace(
spec,
name=name,
args=tuple(args),
depends_on=tuple(spec.depends_on) + tuple(extra_deps),
)
self.add(new_spec)
generated.append(new_spec)
return generated
# ------------------------------------------------------------------ #
# 可视化
# ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str:
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。
无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览
渲染,或保存为文件。
"""
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。"""
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
@@ -243,6 +322,10 @@ class Graph:
for name, deps in self.deps.items():
for dep in deps:
lines.append(f" {dep} --> {name}")
# 软依赖用虚线
for name, spec in self.specs.items():
for dep in spec.soft_depends_on:
lines.append(f" {dep} -.-> {name}")
return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ #
@@ -268,19 +351,12 @@ class Graph:
class GraphComposer:
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
从 ``CliRunner`` 抽出,使 ``Graph``(数据)与引用解析(组合逻辑)
职责分离。引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
引用格式
--------
引用格式:
* ``"command_name"`` —— 引用整个命令图。
* ``"command_name.task_name"`` —— 引用特定任务。
Parameters
----------
graphs : dict[str, Graph]
命令名到图的映射,引用据此解析。
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
"""
def __init__(self, graphs: dict[str, Graph]) -> None:
@@ -294,18 +370,7 @@ class GraphComposer:
return resolved
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
"""展开图中的字符串引用。
若图无 ``_pending_refs``,原样返回。
Note
-----
引用按顺序展开,后续引用的任务依赖于前面引用的任务完成。
例如 ``["c", "tc", bump]`` 展开为:
- c 的所有任务(无依赖)
- tc 的所有任务(依赖于 c 的最后一个任务)
- bump 任务(依赖于 tc 的最后一个任务)
"""
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
pending_refs = graph._pending_refs
if not pending_refs:
return graph
@@ -313,23 +378,16 @@ class GraphComposer:
all_specs: list[TaskSpec[Any]] = []
previous_ref_last_task: str | None = None
# 先解析每个引用,并建立依赖链。
for ref in pending_refs:
expanded_specs = self.parse_ref(ref, current_cmd)
# 若有前一个引用,让当前引用的任务依赖其最后一个任务。
if previous_ref_last_task and expanded_specs:
for i, task in enumerate(expanded_specs):
# 只为没有依赖的任务(或第一个任务)添加依赖。
if i == 0 or not task.depends_on:
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
if expanded_specs:
previous_ref_last_task = expanded_specs[-1].name
all_specs.extend(expanded_specs)
# 然后添加原始 TaskSpec,按出现顺序串行依赖。
original_specs = list(graph.all_specs().values())
if original_specs:
if previous_ref_last_task:
@@ -337,49 +395,53 @@ class GraphComposer:
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
else:
all_specs.append(original_specs[0])
for i in range(1, len(original_specs)):
current_task = original_specs[i]
previous_task_name = original_specs[i - 1].name
all_specs.append(
replace(
current_task,
depends_on=tuple({*current_task.depends_on, previous_task_name}),
)
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
)
return Graph.from_specs(all_specs)
return Graph.from_specs(all_specs, defaults=graph.defaults)
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
"""解析单个字符串引用,返回对应的 TaskSpec 列表。
Raises
------
ValueError
引用无效、目标命令/任务不存在,或检测到循环引用。
"""
# 避免循环引用。
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
if ref == current_cmd:
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
if "." in ref:
# 特定任务引用: "command_name.task_name"
cmd_name, task_name = ref.split(".", 1)
if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name]
if task_name not in ref_graph.all_specs():
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}'")
return [ref_graph.all_specs()[task_name]]
else:
# 整个命令图引用: "command_name"
cmd_name = ref
if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name]
# 递归展开(若引用的图自身也含引用)。
ref_graph = self.expand_refs(ref_graph, cmd_name)
return list(ref_graph.all_specs().values())
def compose(
graphs: dict[str, Graph],
) -> dict[str, Graph]:
"""编程式解析多图的字符串引用,返回展开后的新图映射。
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
Examples
--------
>>> graphs = {
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
... }
>>> resolved = px.compose(graphs)
>>> "b" in resolved["all"].all_specs()
True
"""
return GraphComposer(graphs).resolve_all()
+80 -41
View File
@@ -4,20 +4,18 @@
执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其
存储值注入下游任务。
本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储
形态为扁平的 ``{task_name: result}`` 映射。内置两个后端:
存储键由 :meth:`TaskSpec.storage_key` 计算,默认为任务名;若任务配置
了 ``cache_key``,则键为 ``"name:cache_key_value"``,使不同输入产生
独立缓存条目。
* :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认
* :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。
两者均零依赖(``json`` 为标准库)。用户可子类化
:class:`StateBackend` 接入 SQLite、Redis 等。
支持 TTL``has`` 在条目过期时返回 ``False``
"""
from __future__ import annotations
import json
import sys
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Mapping
@@ -31,23 +29,26 @@ from .errors import StorageError
class StateBackend(ABC):
"""可续跑状态存储的抽象基类。"""
"""可续跑状态存储的抽象基类。
所有方法以 ``key`` 为参数(通常为任务名或 ``name:cache_key``)。
"""
@abstractmethod
def load(self) -> Mapping[str, Any]:
"""返回完整的存储映射(可能为空)。"""
@abstractmethod
def save(self, name: str, value: Any) -> None:
def save(self, key: str, value: Any) -> None:
"""持久化单个任务的成功结果。"""
@abstractmethod
def has(self, name: str) -> bool:
"""``name`` 是否已有存储结果。"""
def has(self, key: str) -> bool:
"""``key`` 是否已有未过期的存储结果。"""
@abstractmethod
def get(self, name: str) -> Any:
"""返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。"""
def get(self, key: str) -> Any:
"""返回 ``key`` 的存储结果(不存在则抛 ``KeyError``)。"""
@abstractmethod
def clear(self) -> None:
@@ -55,43 +56,66 @@ class StateBackend(ABC):
class MemoryBackend(StateBackend):
"""进程内 dict 后端。进程退出即丢失。"""
"""进程内 dict 后端。进程退出即丢失。
def __init__(self) -> None:
self._store: dict[str, Any] = {}
Parameters
----------
ttl:
条目存活秒数。``None`` 表示永不过期。``has`` 在条目超过 ttl 后
返回 ``False``(但不主动删除,下次 ``save`` 覆盖)。
"""
def __init__(self, ttl: float | None = None) -> None:
self._store: dict[str, tuple[Any, float]] = {}
self._ttl = ttl
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)}
@override
def save(self, name: str, value: Any) -> None:
self._store[name] = value
def save(self, key: str, value: Any) -> None:
self._store[key] = (value, time.monotonic())
@override
def has(self, name: str) -> bool:
return name in self._store
def has(self, key: str) -> bool:
return key in self._store and not self._expired(key)
@override
def get(self, name: str) -> Any:
return self._store[name]
def get(self, key: str) -> Any:
if key not in self._store or self._expired(key):
raise KeyError(key)
return self._store[key][0]
@override
def clear(self) -> None:
self._store.clear()
def _expired(self, key: str) -> bool:
if self._ttl is None or key not in self._store:
return False
_value, ts = self._store[key]
return (time.monotonic() - ts) > self._ttl
class JSONBackend(StateBackend):
"""基于文件的 JSON 存储,用于跨进程续跑。
结果必须可 JSON 序列化。不可序列化的值会抛出
:class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条
结果的持久化失败)。
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
``ts`` 用于 TTL 判断。结果必须可 JSON 序列化。
Parameters
----------
path:
JSON 文件路径。
ttl:
条目存活秒数。``None`` 表示永不过期。
"""
def __init__(self, path: str) -> None:
def __init__(self, path: str, ttl: float | None = None) -> None:
self._path: str = path
self._store: dict[str, Any] = {}
self._ttl = ttl
self._store: dict[str, dict[str, Any]] = {}
self._load()
def _load(self) -> None:
@@ -101,7 +125,14 @@ class JSONBackend(StateBackend):
with open(self._path, encoding="utf-8") as fh:
data: Any = json.load(fh)
if isinstance(data, dict):
self._store = data
# 兼容纯值格式与带元数据格式
self._store = {}
for k, v in data.items():
if isinstance(v, dict) and "value" in v and "ts" in v:
self._store[k] = v
else:
# 旧格式:纯值
self._store[k] = {"value": v, "ts": time.time()}
except (OSError, json.JSONDecodeError) as exc:
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
@@ -110,32 +141,40 @@ class JSONBackend(StateBackend):
try:
with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2)
_ = Path(tmp).replace(Path(self._path))
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
def _now(self) -> float:
return time.time()
def _expired(self, entry: dict[str, Any]) -> bool:
if self._ttl is None:
return False
return (self._now() - float(entry.get("ts", 0))) > self._ttl
@override
def save(self, name: str, value: Any) -> None:
# 在修改内存状态前先校验可序列化性。
def load(self) -> Mapping[str, Any]:
return {k: v["value"] for k, v in self._store.items() if not self._expired(v)}
@override
def save(self, key: str, value: Any) -> None:
try:
_ = json.dumps(value)
except (TypeError, ValueError) as exc:
raise StorageError(f"result of task {name!r} is not JSON-serialisable", exc) from exc
self._store[name] = value
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
self._store[key] = {"value": value, "ts": self._now()}
self._flush()
@override
def has(self, name: str) -> bool:
return name in self._store
def has(self, key: str) -> bool:
return key in self._store and not self._expired(self._store[key])
@override
def get(self, name: str) -> Any:
return self._store[name]
def get(self, key: str) -> Any:
if key not in self._store or self._expired(self._store[key]):
raise KeyError(key)
return self._store[key]["value"]
@override
def clear(self) -> None:
+297 -119
View File
@@ -15,9 +15,11 @@
* ``TaskStatus`` 是封闭枚举;执行器绝不发明临时字符串。
"""
import os
import shutil
import subprocess
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -25,8 +27,10 @@ from pathlib import Path
from typing import (
Any,
Callable,
ContextManager,
Coroutine,
Generic,
Iterator,
List,
Mapping,
Optional,
@@ -59,8 +63,95 @@ TaskCmd = Union[
Callable[..., Any], # Python 函数
]
# 条件判断函数类型
Condition = Callable[[], bool]
# 执行策略:sequential/thread/async 为层屏障模型,dependency 为依赖驱动模型。
Strategy = Union[str, "StrategyKind"]
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
Condition = Callable[[Context], bool]
# 缓存键计算函数:基于依赖上下文计算稳定字符串键。
CacheKeyFn = Callable[[Context], str]
# ---------------------------------------------------------------------- #
# 重试策略
# ---------------------------------------------------------------------- #
@dataclass(frozen=True)
class RetryPolicy:
"""任务失败重试策略。
参数
----
max_attempts:
最大尝试次数(含首次)。``1`` 表示仅尝试一次,不重试。
delay:
两次尝试之间的初始等待秒数。
backoff:
退避倍率。第 n 次重试等待 ``delay * backoff ** (n-1)``。
jitter:
抖动上限秒数。每次等待加上 ``[0, jitter)`` 的随机量,避免惊群。
retry_on:
仅对这些异常类型重试。默认 ``(Exception,)`` 重试所有异常。
传入空元组等价于不重试。
Note
-----
替代旧版 ``retries: int``。``retries=2`` 等价于
``RetryPolicy(max_attempts=3)``。
"""
max_attempts: int = 1
delay: float = 0.0
backoff: float = 1.0
jitter: float = 0.0
retry_on: Tuple[type[BaseException], ...] = (Exception,)
def __post_init__(self) -> None:
if self.max_attempts < 1:
raise ValueError(f"RetryPolicy.max_attempts must be >= 1, got {self.max_attempts}.")
if self.delay < 0:
raise ValueError(f"RetryPolicy.delay must be >= 0, got {self.delay}.")
if self.backoff < 0:
raise ValueError(f"RetryPolicy.backoff must be >= 0, got {self.backoff}.")
if self.jitter < 0:
raise ValueError(f"RetryPolicy.jitter must be >= 0, got {self.jitter}.")
@property
def retries(self) -> int:
"""重试次数(不含首次),等价于 ``max_attempts - 1``。"""
return self.max_attempts - 1
def should_retry(self, exc: BaseException) -> bool:
"""异常是否属于可重试类型。"""
return isinstance(exc, self.retry_on)
def wait_seconds(self, attempt: int) -> float:
"""第 ``attempt`` 次失败后应等待的秒数(attempt 从 1 开始)。"""
if attempt < 1:
return 0.0
import random
base = self.delay * (self.backoff ** max(0, attempt - 1))
jitter = random.uniform(0, self.jitter) if self.jitter > 0 else 0.0
return base + jitter
# ---------------------------------------------------------------------- #
# 任务钩子
# ---------------------------------------------------------------------- #
@dataclass(frozen=True)
class TaskHooks:
"""任务生命周期钩子。
所有钩子均为可选。``pre_run`` 在任务实际执行前调用;``post_run``
在成功后调用并接收返回值;``on_failure`` 在最终失败后调用并接收异常。
钩子异常不会影响任务状态,仅记录日志。
"""
pre_run: Optional[Callable[["TaskSpec[Any]"], None]] = None
post_run: Optional[Callable[["TaskSpec[Any]", Any], None]] = None
on_failure: Optional[Callable[["TaskSpec[Any]", BaseException], None]] = None
class TaskStatus(Enum):
@@ -90,181 +181,239 @@ class TaskSpec(Generic[T]):
- ``list[str]``: 命令及参数列表,如 ``["ls", "-la"]``
- ``str``: shell 命令字符串,如 ``"pip freeze > requirements.txt"``
- ``Callable``: Python 函数,与 ``fn`` 参数等效
若提供此参数,会自动包装为执行函数,覆盖 ``fn`` 参数。
depends_on:
必须先完成才运行本任务的任务名列表。顺序无关;框架会做
拓扑排序。
硬依赖任务名。必须全部成功完成才运行本任务
上游被 SKIPPED 时,本任务也会被 SKIPPED(除非
``allow_upstream_skip=True``)。
soft_depends_on:
软依赖任务名。会等待其完成,但其结果不影响本任务是否执行:
- 上游成功:注入其返回值
- 上游 SKIPPED 或失败:注入 :attr:`defaults` 中提供的默认值
适用于"可选输入"场景。
defaults:
软依赖的默认值映射 ``{dep_name: default_value}``。
软依赖未提供结果时使用。未在 defaults 中出现的软依赖默认为 ``None``。
args:
静态位置参数,追加在注入参数*之后*。适用于参数化任务
(如 ``fetch_user(uid)``)。
静态位置参数,追加在注入参数*之后*。
kwargs:
静态关键字参数。若与注入名冲突则抛出
:class:`~pyflowx.errors.InjectionError`。
retries:
失败后的重试次数。``0`` 表示仅尝试一次。
retry:
:class:`RetryPolicy` 重试策略。默认仅尝试一次。
timeout:
最大执行时长(秒)。``None`` 表示不限制。异步任务使用
:func:`asyncio.wait_for`线程/异步执行器中的同步任务会
取消 worker future。
:func:`asyncio.wait_for`同步任务通过线程 future 取消。
tags:
自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试
自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试
也可用于并发限制分组。
conditions:
条件判断函数列表,只有所有条件都返回 ``True`` 时才执行任务。
任一条件返回 ``False``任务被标记为 SKIPPED。
用于平台判断、环境变量检查等场景。
条件判断函数列表,接收依赖上下文,全部返回 ``True`` 时才执行任务。
任一返回 ``False``任务被标记为 SKIPPED。
cwd:
命令执行的工作目录,仅在使用 ``cmd`` 参数时有效。
``None`` 表示当前目录。
工作目录。对 ``cmd`` 任务作为子进程工作目录;对 ``fn`` 任务
通过临时切换当前目录生效
env:
环境变量覆盖映射。对 ``cmd`` 任务合并到子进程环境;对 ``fn``
任务在执行期间临时设置。
verbose:
是否在命令执行时显示详细输出。``True`` 时打印执行的命令
及其标准输出/标准错误。仅在使用 ``cmd`` 参数时有效
``False`` 时静默捕获输出(失败时仍会包含在错误信息中)。
是否打印详细输出。``True`` 时打印执行的命令、返回码与输出
(仅 ``cmd``),以及任务生命周期
skip_if_missing:
仅对 ``cmd`` 为 ``list[str]`` 的任务有效。``True`` 时自动检查
命令是否存在(通过 :func:`shutil.which`,不存在则跳过任务
(标记为 SKIPPED)而非失败。适用于构建工具场景,避免因未安装
某些工具(如 maturin、tox)而导致整个图执行失败。
对于 ``str`` (shell) 和 ``Callable`` 类型的 ``cmd``,此参数无效。
仅对 ``cmd`` 为 ``list[str]`` 有效。``True`` 时通过
:func:`shutil.which` 检查命令是否存在,不存在则跳过
allow_upstream_skip:
若为 ``True``当上游任务因条件不满足被跳过时,本任务仍执行
(而非跟随跳过)。适用于清理类任务:即使某些删除操作因目标不存在
而跳过,后续操作(如重启服务)仍应执行。默认为 ``False``。
若为 ``True``硬依赖被 SKIPPED 时本任务仍执行(软依赖不影响)。
适用于清理类任务。
strategy:
单任务执行策略覆盖。``None`` 表示继承图级策略。
``"sequential"`` 同步直接调用;``"thread"``/``"async"`` 将同步
任务卸载到线程池,异步任务跑在事件循环上。
priority:
同层任务调度优先级。数值越大越先启动。仅影响同层内启动顺序,
不打破层屏障。默认 ``0``。
concurrency_key:
并发限制分组键。具有相同键的任务共享一个信号量,限制同时
运行的实例数。具体限额由 :func:`run` 的 ``concurrency_limits``
参数提供 ``{key: limit}`` 映射。``None`` 表示不限制。
continue_on_error:
若为 ``True``,任务最终失败时不中止整图,仅标记本任务 FAILED,
其硬依赖下游被 SKIPPED,其余任务继续。默认 ``False``。
cache_key:
缓存键计算函数。若提供,则用其基于依赖上下文计算的字符串键
存取状态后端,使不同输入产生独立缓存条目。``None`` 表示用任务名。
hooks:
:class:`TaskHooks` 生命周期钩子。
"""
name: str
fn: Optional[TaskFn[T]] = None
cmd: Optional[TaskCmd] = None
depends_on: Tuple[str, ...] = ()
soft_depends_on: Tuple[str, ...] = ()
defaults: Mapping[str, Any] = field(default_factory=dict)
args: Tuple[Any, ...] = ()
kwargs: Mapping[str, Any] = field(default_factory=dict)
retries: int = 0
retry: RetryPolicy = field(default_factory=RetryPolicy)
timeout: Optional[float] = None
tags: Tuple[str, ...] = ()
conditions: Tuple[Condition, ...] = ()
cwd: Optional[Path] = None
env: Optional[Mapping[str, str]] = None
verbose: bool = False
skip_if_missing: bool = False
allow_upstream_skip: bool = False
strategy: Optional[str] = None
priority: int = 0
concurrency_key: Optional[str] = None
continue_on_error: bool = False
cache_key: Optional[CacheKeyFn] = None
hooks: TaskHooks = field(default_factory=TaskHooks)
def __post_init__(self) -> None:
if not self.name:
raise ValueError("TaskSpec.name must be a non-empty string.")
if self.retries < 0:
raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.")
if self.retry.max_attempts < 1:
raise ValueError(f"TaskSpec '{self.name}': retry.max_attempts must be >= 1.")
if self.timeout is not None and self.timeout <= 0:
raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.")
if self.name in self.depends_on:
if self.name in self.depends_on or self.name in self.soft_depends_on:
raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.")
overlap = set(self.depends_on) & set(self.soft_depends_on)
if overlap:
raise ValueError(f"TaskSpec '{self.name}': depends_on 与 soft_depends_on 不能重叠: {sorted(overlap)}")
if self.fn is None and self.cmd is None:
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
@property
def effective_fn(self) -> TaskFn[T]:
"""获取有效的执行函数.
"""获取有效的执行函数
若提供 ``cmd`` 参数,则返回包装后的命令执行函数;
否则返回 ``fn`` 参数。
Note
-----
命令执行逻辑已抽到模块级 :func:`_run_command`,此处仅返回轻量
转发闭包。``verbose`` / ``cwd`` / ``timeout`` 不再在创建时闭包
捕获,而是在每次调用时从 ``self`` 读取——这使得翻转 ``verbose``
无需重建 spec(见 :func:`pyflowx.runner._apply_verbose_to_graph`)。
若提供 ``cmd``返回包装后的命令执行函数;否则返回 ``fn``。
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
"""
if self.cmd is not None:
return self._wrap_cmd()
if self.fn is not None:
return self.fn
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
def _wrap_cmd(self) -> TaskFn[Any]:
"""将 cmd 包装为可执行函数.
返回的闭包仅持有 ``self`` 引用,每次调用时从 spec 读取
``verbose``/``cwd``/``timeout``,避免闭包捕获运行期参数。
Returns
-------
TaskFn[Any]
包装后的执行函数.
"""
"""将 cmd 包装为可执行函数"""
spec = self
if isinstance(spec.cmd, list):
def _run() -> T:
return cast(T, _run_command(spec))
def _run_list() -> T:
return cast(T, _run_command(spec))
_run.__name__ = spec.name
return _run # type: ignore[return-value]
_run_list.__name__ = spec.name
return _run_list # type: ignore[return-value]
if isinstance(spec.cmd, str):
def _run_shell() -> T:
return cast(T, _run_command(spec))
_run_shell.__name__ = spec.name
return _run_shell # type: ignore[return-value]
if callable(spec.cmd):
return spec.cmd # type: ignore[return-value]
raise TypeError(f"TaskSpec '{spec.name}': 不支持的 cmd 类型 {type(spec.cmd).__name__}") # pragma: no cover
def should_execute(self) -> bool:
"""检查任务是否应该执行.
def should_execute(self, context: Context) -> Tuple[bool, Optional[str]]:
"""检查任务是否应执行。
Returns
-------
bool
若所有条件都返回 ``True``,且 ``skip_if_missing`` 检查通过,
则返回 ``True``;否则返回 ``False``。
(should_run, skip_reason)
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
"""
if not all(condition() for condition in self.conditions):
return False
# 逐个求值条件,记录失败项。
failed_conditions: list[str] = []
for condition in self.conditions:
try:
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
continue
if not ok:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
return not (self.skip_if_missing and not self._is_cmd_available())
if failed_conditions:
return False, f"条件不满足: {', '.join(failed_conditions)}"
if self.skip_if_missing and not self._is_cmd_available():
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
return False, f"命令不存在: {cmd_name}"
return True, None
def _is_cmd_available(self) -> bool:
"""检查 ``cmd`` 是否可用.
仅对 ``list[str]`` 类型的 ``cmd`` 进行检查(通过 :func:`shutil.which`)。
对于 ``str`` (shell) 和 ``Callable`` 类型,始终返回 ``True``。
Returns
-------
bool
命令可用返回 ``True``,否则返回 ``False``。
"""
"""检查 ``cmd`` 是否可用(仅 list[str])。"""
cmd = self.cmd
if isinstance(cmd, list) and cmd:
first_arg = cmd[0]
return shutil.which(first_arg) is not None
return shutil.which(cmd[0]) is not None
return True
def env_context(self) -> ContextManager[None]:
"""返回临时应用 ``env`` 与 ``cwd`` 的上下文管理器。
def _run_command(spec: "TaskSpec[Any]") -> Any:
"""执行 ``spec.cmd`` 指定的命令(list 或 shell 字符串)
对 ``fn`` 任务生效。``cmd`` 任务在 :func:`_run_command` 中直接
传给子进程
"""
return _env_and_cwd(self.env, self.cwd)
list 与 shell 两条路径的异常处理、输出捕获、返回码判断完全一致,
合并于此消除重复。``verbose``/``cwd``/``timeout`` 在调用时从
``spec`` 读取,而非闭包捕获——这是 ``_wrap_cmd`` 不再捕获运行期
参数的关键。
def storage_key(self, context: Context) -> str:
"""计算状态后端存储键。"""
if self.cache_key is not None:
try:
return f"{self.name}:{self.cache_key(context)}"
except Exception:
return self.name
return self.name
成功返回 ``None``;失败抛 ``RuntimeError``,错误信息包含命令、
返回码与(非 verbose 模式下的)stderr。
"""
@contextmanager
def _env_and_cwd(
env: Optional[Mapping[str, str]],
cwd: Optional[Path],
) -> Iterator[None]:
"""临时设置环境变量与工作目录。"""
saved_env: dict[str, str] = {}
saved_cwd: Optional[str] = None
if env:
for k, v in env.items():
if k in os.environ:
saved_env[k] = os.environ[k]
os.environ[k] = v
if cwd is not None:
saved_cwd = str(Path.cwd())
os.chdir(cwd)
try:
yield
finally:
if saved_cwd is not None:
os.chdir(saved_cwd)
# 恢复环境变量
if env:
for k in env:
if k in saved_env:
os.environ[k] = saved_env[k]
else:
os.environ.pop(k, None)
def _run_command(spec: "TaskSpec[Any]") -> Any: # noqa: PLR0912
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。"""
cmd = spec.cmd
is_list = isinstance(cmd, list)
verbose = spec.verbose
cwd = spec.cwd
timeout = spec.timeout
env_override = spec.env
# 统一展示用的命令字符串与标签。保持 "执行命令" / "执行 Shell" 连续,
# 以兼容既有输出格式与测试断言。
# 可调用对象:直接调用,返回其结果。
if callable(cmd) and not isinstance(cmd, (list, str)):
name = getattr(cmd, "__name__", "callable")
if verbose:
print(f"[verbose] 执行可调用命令: {name}", flush=True)
if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True)
try:
return cmd()
except Exception as e:
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
is_list = isinstance(cmd, list)
if is_list:
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
verb = "执行命令"
@@ -279,14 +428,18 @@ def _run_command(spec: "TaskSpec[Any]") -> Any:
if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True)
# 合并环境变量
run_env: Optional[dict[str, str]] = None
if env_override:
run_env = dict(os.environ)
run_env.update(env_override)
try:
# cmd 此处必为 list[str] 或 str_wrap_cmd 的 isinstance 守卫已排除
# None 与 Callable),但类型检查器无法跨函数推断,故 cast 收窄到
# subprocess.run 接受的 Union[str, Sequence[str]]。
result = subprocess.run(
cast(Union[str, List[str]], cmd),
shell=not is_list,
cwd=cwd,
env=run_env,
timeout=timeout,
capture_output=not verbose,
text=True,
@@ -311,13 +464,42 @@ def _run_command(spec: "TaskSpec[Any]") -> Any:
raise RuntimeError(err_msg)
# ---------------------------------------------------------------------- #
# 任务模板:批量生成相似 TaskSpec 的工厂
# ---------------------------------------------------------------------- #
def task_template(
fn: Optional[TaskFn[Any]] = None,
cmd: Optional[TaskCmd] = None,
**defaults: Any,
) -> Callable[..., TaskSpec[Any]]:
"""创建任务模板工厂。
返回的工厂接受 ``name`` 与任意覆盖字段,生成 :class:`TaskSpec`。
适用于批量创建相似任务(如 fan-out)。
Examples
--------
>>> Fetch = px.task_template(fn=fetch_user, retry=px.RetryPolicy(max_attempts=3))
>>> specs = [Fetch(f"fetch_{uid}", args=(uid,)) for uid in range(5)]
"""
base = dict(defaults)
if fn is not None:
base["fn"] = fn
if cmd is not None:
base["cmd"] = cmd
def _factory(name: str, **overrides: Any) -> TaskSpec[Any]:
merged = dict(base)
merged.update(overrides)
return TaskSpec(name, **merged)
_factory.__name__ = "task_template_factory"
return _factory
@dataclass
class TaskResult(Generic[T]):
"""运行期间产生的可变单任务记录。
每次运行都会创建全新的 :class:`TaskResult`spec 本身保持不可变。
这让同一个图可以安全地重复运行。
"""
"""运行期间产生的可变单任务记录。"""
spec: TaskSpec[T]
status: TaskStatus = TaskStatus.PENDING
@@ -338,15 +520,11 @@ class TaskResult(Generic[T]):
@dataclass(frozen=True)
class TaskEvent:
"""执行期间向观察者发出的不可变事件。
传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合
执行器内部即可构建进度条、指标或结构化日志。
"""
"""执行期间向观察者发出的不可变事件。"""
task: str
status: TaskStatus
attempts: int = 0
error: Optional[str] = None
duration: Optional[float] = None
reason: Optional[str] = None # 跳过原因,如 "条件不满足"、"上游任务被跳过"、"缓存"
reason: Optional[str] = None
File diff suppressed because it is too large Load Diff
+148 -72
View File
@@ -1,142 +1,218 @@
"""Tests for conditions module."""
from __future__ import annotations
import os
import sys
from unittest.mock import patch
from pyflowx.conditions import (
IS_LINUX,
IS_MACOS,
IS_POSIX,
IS_WINDOWS,
BuiltinConditions,
Constants,
)
_CTX: dict[str, object] = {}
def test_constants_is_windows():
"""Test Constants.IS_WINDOWS is correct."""
assert (sys.platform == "win32") == Constants.IS_WINDOWS
def test_constants_is_linux():
"""Test Constants.IS_LINUX is correct."""
assert (sys.platform == "linux") == Constants.IS_LINUX
def test_constants_is_macos():
"""Test Constants.IS_MACOS is correct."""
assert (sys.platform == "darwin") == Constants.IS_MACOS
def test_constants_is_posix():
"""Test Constants.IS_POSIX is correct."""
assert (sys.platform != "win32") == Constants.IS_POSIX
def test_module_level_static_conditions():
assert IS_WINDOWS(_CTX) == Constants.IS_WINDOWS
assert IS_LINUX(_CTX) == Constants.IS_LINUX
assert IS_MACOS(_CTX) == Constants.IS_MACOS
assert IS_POSIX(_CTX) == Constants.IS_POSIX
def test_builtin_conditions_python_version_major_only():
"""Test BuiltinConditions.PYTHON_VERSION with major only."""
# Test with current Python version
def test_python_version_major_only():
current_major = sys.version_info.major
assert BuiltinConditions.PYTHON_VERSION(current_major) is True
assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False
assert BuiltinConditions.PYTHON_VERSION(current_major)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION(current_major + 1)(_CTX) is False
def test_builtin_conditions_python_version_with_minor():
"""Test BuiltinConditions.PYTHON_VERSION with major and minor."""
def test_python_version_with_minor():
current_major = sys.version_info.major
current_minor = sys.version_info.minor
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1)(_CTX) is False
def test_builtin_conditions_python_version_at_least():
"""Test BuiltinConditions.PYTHON_VERSION_AT_LEAST."""
def test_python_version_at_least():
current_major = sys.version_info.major
current_minor = sys.version_info.minor
# Current version should be at least itself
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True
# Current version should be at least an older version
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True
# Current version should NOT be at least a newer version
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0)(_CTX) is True
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0)(_CTX) is False
def test_builtin_conditions_HAS_INSTALLED_true():
"""Test BuiltinConditions.HAS_INSTALLED when app exists."""
# Python should always be available
condition = BuiltinConditions.HAS_INSTALLED("python")
assert condition() is True
def test_has_installed_true():
condition = BuiltinConditions.HAS_INSTALLED("python3")
assert condition(_CTX) is True
def test_builtin_conditions_HAS_INSTALLED_false():
"""Test BuiltinConditions.HAS_INSTALLED when app doesn't exist."""
def test_has_installed_false():
condition = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
assert condition() is False
assert condition(_CTX) is False
def test_builtin_conditions_env_var_exists_true():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable exists."""
def test_env_var_exists_true():
with patch.dict(os.environ, {"TEST_VAR": "value"}):
condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR")
assert condition() is True
assert condition(_CTX) is True
def test_builtin_conditions_env_var_exists_false():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist."""
def test_env_var_exists_false():
condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345")
assert condition() is False
assert condition(_CTX) is False
def test_builtin_conditions_env_var_equals_true():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value matches."""
def test_env_var_equals_true():
with patch.dict(os.environ, {"TEST_VAR": "expected_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is True
assert condition(_CTX) is True
def test_builtin_conditions_env_var_equals_false():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match."""
def test_env_var_equals_false():
with patch.dict(os.environ, {"TEST_VAR": "different_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is False
assert condition(_CTX) is False
def test_builtin_conditions_not():
"""Test BuiltinConditions.NOT."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
def test_not():
true_cond = BuiltinConditions.HAS_INSTALLED("python3")
false_cond = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
not_true = BuiltinConditions.NOT(true_condition)
assert not_true() is False
not_false = BuiltinConditions.NOT(false_condition)
assert not_false() is True
assert BuiltinConditions.NOT(true_cond)(_CTX) is False
assert BuiltinConditions.NOT(false_cond)(_CTX) is True
def test_builtin_conditions_and_all_true():
"""Test BuiltinConditions.AND when all conditions are true."""
true_condition = lambda: True # noqa: E731
condition = BuiltinConditions.AND(true_condition, true_condition, true_condition)
assert condition() is True
def test_and_all_true():
cond = BuiltinConditions.AND(
BuiltinConditions.HAS_INSTALLED("python3"),
BuiltinConditions.HAS_INSTALLED("python3"),
)
assert cond(_CTX) is True
def test_builtin_conditions_and_one_false():
"""Test BuiltinConditions.AND when one condition is false."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.AND(true_condition, false_condition, true_condition)
assert condition() is False
def test_and_one_false():
cond = BuiltinConditions.AND(
BuiltinConditions.HAS_INSTALLED("python3"),
BuiltinConditions.HAS_INSTALLED("nonexistent_app"),
)
assert cond(_CTX) is False
def test_builtin_conditions_or_all_false():
"""Test BuiltinConditions.OR when all conditions are false."""
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.OR(false_condition, false_condition, false_condition)
assert condition() is False
def test_or_all_false():
cond = BuiltinConditions.OR(
BuiltinConditions.HAS_INSTALLED("nonexistent1"),
BuiltinConditions.HAS_INSTALLED("nonexistent2"),
)
assert cond(_CTX) is False
def test_builtin_conditions_or_one_true():
"""Test BuiltinConditions.OR when one condition is true."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.OR(false_condition, true_condition, false_condition)
assert condition() is True
def test_or_one_true():
cond = BuiltinConditions.OR(
BuiltinConditions.HAS_INSTALLED("nonexistent1"),
BuiltinConditions.HAS_INSTALLED("python3"),
)
assert cond(_CTX) is True
# ---------------------------------------------------------------------- #
# 上下文条件:基于上游依赖结果
# ---------------------------------------------------------------------- #
def test_dep_equals_true():
ctx = {"upstream": 42}
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
assert cond(ctx) is True
def test_dep_equals_false():
ctx = {"upstream": 99}
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
assert cond(ctx) is False
def test_dep_equals_missing_dep():
cond = BuiltinConditions.DEP_EQUALS("missing", 42)
assert cond({}) is False
def test_dep_matches_true():
ctx = {"upstream": [1, 2, 3]}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
assert cond(ctx) is True
def test_dep_matches_false():
ctx = {"upstream": [1, 2]}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
assert cond(ctx) is False
def test_dep_matches_exception_returns_false():
ctx = {"upstream": ""}
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: v[0])
assert cond(ctx) is False
def test_dep_present_true():
ctx = {"upstream": "value"}
cond = BuiltinConditions.DEP_PRESENT("upstream")
assert cond(ctx) is True
def test_dep_present_false_none():
# pyrefly: ignore [implicit-any-empty-container]
ctx = {"upstream": None}
cond = BuiltinConditions.DEP_PRESENT("upstream")
assert cond(ctx) is False
def test_dep_present_false_missing():
cond = BuiltinConditions.DEP_PRESENT("missing")
assert cond({}) is False
def test_dep_truthy_true():
ctx = {"upstream": [1]}
cond = BuiltinConditions.DEP_TRUTHY("upstream")
assert cond(ctx) is True
def test_dep_truthy_false():
# pyrefly: ignore [implicit-any-empty-container]
ctx = {"upstream": []}
cond = BuiltinConditions.DEP_TRUTHY("upstream")
assert cond(ctx) is False
def test_dep_truthy_missing():
cond = BuiltinConditions.DEP_TRUTHY("missing")
assert cond({}) is False
def test_logical_combination_with_dep_conditions():
ctx = {"a": 1, "b": 0}
cond = BuiltinConditions.AND(
BuiltinConditions.DEP_EQUALS("a", 1),
BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")),
)
assert cond(ctx) is True
+1 -1
View File
@@ -141,7 +141,7 @@ class TestDescribeInjection:
spec = px.TaskSpec("t", fn, depends_on=("a",))
desc = describe_injection(spec)
assert "a=<result:a>" in desc
assert "a=<dep:a>" in desc
assert "ctx=<Context>" in desc
assert "flag=<default>" in desc
+16 -8
View File
@@ -84,7 +84,9 @@ def test_retries_then_succeeds() -> None:
raise RuntimeError("not yet")
return "ok"
graph = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retries=2)])
graph = px.Graph.from_specs([
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=3)),
])
report = px.run(graph, strategy="sequential")
assert report.success
assert report["flaky"] == "ok"
@@ -95,7 +97,9 @@ def test_retries_exhausted() -> None:
def always_fail() -> None:
raise RuntimeError("nope")
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
graph = px.Graph.from_specs([
px.TaskSpec("f", always_fail, retry=px.RetryPolicy(max_attempts=3)),
])
with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential")
assert exc_info.value.attempts == 3
@@ -332,7 +336,9 @@ def test_async_timeout_retry_then_succeed() -> None:
await asyncio.sleep(10) # 触发超时
return "ok"
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2, timeout=0.05)])
graph = px.Graph.from_specs([
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3), timeout=0.05),
])
report = px.run(graph, strategy="async")
assert report.success
assert report["a"] == "ok"
@@ -349,7 +355,9 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
raise RuntimeError("not yet")
return "ok"
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)])
graph = px.Graph.from_specs([
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3)),
])
with caplog.at_level("WARNING", logger="pyflowx"):
report = px.run(graph, strategy="async")
assert report.success
@@ -489,7 +497,7 @@ def test_run_empty_graph() -> None:
# ---------------------------------------------------------------------- #
def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDsequential 策略)."""
never_true = lambda: False # noqa: E731
never_true = lambda _ctx: False # noqa: E731
def downstream(upstream: str) -> str:
return upstream + "_processed"
@@ -506,7 +514,7 @@ def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
def test_downstream_skipped_when_upstream_skipped_thread() -> None:
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPEDthread 策略)."""
never_true = lambda: False # noqa: E731
never_true = lambda _ctx: False # noqa: E731
def downstream(upstream: str) -> str:
return upstream + "_processed"
@@ -530,7 +538,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None:
async def downstream(upstream: str) -> str:
return upstream + "_processed"
never_true = lambda: False # noqa: E731
never_true = lambda _ctx: False # noqa: E731
graph = px.Graph.from_specs([
px.TaskSpec("upstream", upstream, conditions=(never_true,)),
@@ -544,7 +552,7 @@ def test_downstream_skipped_when_upstream_skipped_async() -> None:
def test_downstream_executes_when_upstream_succeeds() -> None:
"""上游任务成功时,下游任务应正常执行."""
always_true = lambda: True # noqa: E731
always_true = lambda _ctx: True # noqa: E731
def upstream() -> str:
return "hello"
+14 -6
View File
@@ -85,7 +85,7 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True)
@@ -140,7 +140,7 @@ def test_verbose_event_callback_skipped():
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
verbose=True,
)
graph = px.Graph.from_specs([spec])
@@ -161,7 +161,11 @@ def test_execute_sync_with_retries():
raise ValueError("temporary error")
return "success"
spec = px.TaskSpec("retry_test", fn=failing_function, retries=3)
spec = px.TaskSpec(
"retry_test",
fn=failing_function,
retry=px.RetryPolicy(max_attempts=3),
)
graph = px.Graph.from_specs([spec])
# Should succeed after retries
@@ -182,7 +186,11 @@ def test_execute_async_with_retries():
raise ValueError("temporary error")
return "success"
spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3)
spec = px.TaskSpec(
"retry_async_test",
fn=failing_async_function,
retry=px.RetryPolicy(max_attempts=3),
)
graph = px.Graph.from_specs([spec])
# Should succeed after retries
@@ -196,7 +204,7 @@ def test_execute_sync_skip_on_condition():
spec = px.TaskSpec(
"skip_test",
fn=lambda: "result",
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
)
graph = px.Graph.from_specs([spec])
@@ -210,7 +218,7 @@ def test_execute_async_skip_on_condition():
spec = px.TaskSpec(
"skip_async_test",
fn=lambda: "result",
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
)
graph = px.Graph.from_specs([spec])
+58 -74
View File
@@ -13,13 +13,11 @@ def _fn() -> None:
def test_from_specs_builds_graph() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("a", "b")),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("a", "b")),
])
assert set(graph.names) == {"a", "b", "c"}
assert graph.dependencies("c") == ("a", "b")
assert len(graph) == 3
@@ -28,23 +26,19 @@ def test_from_specs_builds_graph() -> None:
def test_from_specs_allows_forward_references() -> None:
# b depends on a, but a is declared after b — order should not matter.
graph = px.Graph.from_specs(
[
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("a", _fn),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("a", _fn),
])
assert graph.layers() == [["a"], ["b"]]
def test_duplicate_task_raises() -> None:
with pytest.raises(DuplicateTaskError):
_ = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn),
]
)
_ = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn),
])
def test_missing_dependency_raises() -> None:
@@ -57,24 +51,20 @@ def test_missing_dependency_raises() -> None:
def test_cycle_detection() -> None:
with pytest.raises(CycleError):
_ = px.Graph.from_specs(
[
px.TaskSpec("a", _fn, depends_on=("c",)),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("b",)),
]
)
_ = px.Graph.from_specs([
px.TaskSpec("a", _fn, depends_on=("c",)),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("b",)),
])
def test_layers_grouping() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn),
px.TaskSpec("c", _fn, depends_on=("a", "b")),
px.TaskSpec("d", _fn, depends_on=("c",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn),
px.TaskSpec("c", _fn, depends_on=("a", "b")),
px.TaskSpec("d", _fn, depends_on=("c",)),
])
layers = graph.layers()
assert layers == [["a", "b"], ["c"], ["d"]]
@@ -85,12 +75,10 @@ def test_self_dependency_rejected() -> None:
def test_to_mermaid() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
])
mermaid = graph.to_mermaid()
assert mermaid.startswith("graph TD")
assert 'a["a"]' in mermaid
@@ -104,13 +92,11 @@ def test_to_mermaid_invalid_orientation() -> None:
def test_subgraph_by_tags() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn, tags=("ingest",)),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)),
px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn, tags=("ingest",)),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)),
px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)),
])
sub = graph.subgraph(["ingest"])
assert set(sub.names) == {"a", "b"}
# Edge to dropped task c is removed; b no longer waits for anything
@@ -119,13 +105,11 @@ def test_subgraph_by_tags() -> None:
def test_subgraph_by_names() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("b",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
px.TaskSpec("c", _fn, depends_on=("b",)),
])
sub = graph.subgraph_by_names(["a", "b"])
assert set(sub.names) == {"a", "b"}
# c is dropped, so b's dep on c (none here) — but a->b edge preserved.
@@ -139,12 +123,10 @@ def test_subgraph_by_names_unknown() -> None:
def test_describe() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
])
desc = graph.describe()
assert "Layer 1" in desc
assert "Layer 2" in desc
@@ -187,12 +169,10 @@ def test_spec_accessor() -> None:
def test_dependencies_accessor() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, depends_on=("a",)),
])
assert graph.dependencies("a") == ()
assert graph.dependencies("b") == ("a",)
@@ -210,16 +190,20 @@ def test_empty_graph_layers() -> None:
def test_subgraph_preserves_metadata() -> None:
"""子图应保留原任务的 retries/timeout/tags 等元数据。"""
graph = px.Graph.from_specs(
[
px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)),
]
)
"""子图应保留原任务的 retry/timeout/tags 等元数据。"""
graph = px.Graph.from_specs([
px.TaskSpec(
"a",
_fn,
tags=("x",),
retry=px.RetryPolicy(max_attempts=3),
timeout=5.0,
),
px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)),
])
sub = graph.subgraph(["x"])
spec = sub.spec("a")
assert spec.retries == 3
assert spec.retry.max_attempts == 3
assert spec.timeout == 5.0
assert spec.tags == ("x",)
+50 -68
View File
@@ -29,24 +29,20 @@ def _echo_graph(name: str = "echo_task", msg: str = "hello") -> px.Graph:
def _failing_graph() -> px.Graph:
"""构造一个必定失败的单任务图."""
return px.Graph.from_specs(
[
px.TaskSpec(
"fail",
cmd=["python", "-c", "import sys; sys.exit(1)"],
)
]
)
return px.Graph.from_specs([
px.TaskSpec(
"fail",
cmd=["python", "-c", "import sys; sys.exit(1)"],
)
])
def _multi_task_graph() -> px.Graph:
"""构造一个带依赖的多任务图."""
return px.Graph.from_specs(
[
px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]),
px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)),
]
)
return px.Graph.from_specs([
px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]),
px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)),
])
# ---------------------------------------------------------------------- #
@@ -240,12 +236,10 @@ class TestCliRunnerRunSuccess:
def track_b() -> None:
executed.append("b")
runner = px.CliRunner(
{
"a": px.Graph.from_specs([px.TaskSpec("a", track_a)]),
"b": px.Graph.from_specs([px.TaskSpec("b", track_b)]),
}
)
runner = px.CliRunner({
"a": px.Graph.from_specs([px.TaskSpec("a", track_a)]),
"b": px.Graph.from_specs([px.TaskSpec("b", track_b)]),
})
_ = runner.run(["b"])
assert executed == ["b"]
@@ -318,15 +312,13 @@ class TestCliRunnerVerbose:
def test_verbose_prints_skip_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
"""verbose 模式下跳过的任务应打印跳过信息."""
graph = px.Graph.from_specs(
[
px.TaskSpec(
"skip_me",
cmd=[*ECHO_CMD, "skip"],
conditions=(lambda: False,),
),
]
)
graph = px.Graph.from_specs([
px.TaskSpec(
"skip_me",
cmd=[*ECHO_CMD, "skip"],
conditions=(lambda _ctx: False,),
),
])
runner = px.CliRunner({"skip": graph})
_ = runner.run(["skip"])
captured = capsys.readouterr()
@@ -394,13 +386,11 @@ class TestCliRunnerList:
def test_list_prints_all_commands(self, capsys: pytest.CaptureFixture[str]) -> None:
"""--list 应打印所有命令."""
runner = px.CliRunner(
{
"clean": _echo_graph("c", "clean"),
"build": _echo_graph("b", "build"),
"test": _echo_graph("t", "test"),
}
)
runner = px.CliRunner({
"clean": _echo_graph("c", "clean"),
"build": _echo_graph("b", "build"),
"test": _echo_graph("t", "test"),
})
_ = runner.run(["--list"])
captured = capsys.readouterr()
assert "clean" in captured.out
@@ -523,30 +513,26 @@ class TestCliRunnerIntegration:
def test_condition_skipped_command_succeeds(self) -> None:
"""条件不满足时任务跳过, 整体仍成功."""
graph = px.Graph.from_specs(
[
px.TaskSpec(
"skip_me",
cmd=[*ECHO_CMD, "should not run"],
conditions=(lambda: False,),
),
]
)
graph = px.Graph.from_specs([
px.TaskSpec(
"skip_me",
cmd=[*ECHO_CMD, "should not run"],
conditions=(lambda _ctx: False,),
),
])
runner = px.CliRunner({"skip": graph})
exit_code = runner.run(["skip"])
assert exit_code == CliExitCode.SUCCESS.value
def test_condition_met_command_succeeds(self) -> None:
"""条件满足时任务执行, 整体成功."""
graph = px.Graph.from_specs(
[
px.TaskSpec(
"run_me",
cmd=[*ECHO_CMD, "should run"],
conditions=(lambda: True,),
),
]
)
graph = px.Graph.from_specs([
px.TaskSpec(
"run_me",
cmd=[*ECHO_CMD, "should run"],
conditions=(lambda _ctx: True,),
),
])
runner = px.CliRunner({"run": graph})
exit_code = runner.run(["run"])
assert exit_code == CliExitCode.SUCCESS.value
@@ -562,14 +548,12 @@ class TestCliRunnerIntegration:
return fn
graph = px.Graph.from_specs(
[
px.TaskSpec("a", make("a")),
px.TaskSpec("b", make("b"), depends_on=("a",)),
px.TaskSpec("c", make("c"), depends_on=("a",)),
px.TaskSpec("d", make("d"), depends_on=("b", "c")),
]
)
graph = px.Graph.from_specs([
px.TaskSpec("a", make("a")),
px.TaskSpec("b", make("b"), depends_on=("a",)),
px.TaskSpec("c", make("c"), depends_on=("a",)),
px.TaskSpec("d", make("d"), depends_on=("b", "c")),
])
runner = px.CliRunner({"diamond": graph})
exit_code = runner.run(["diamond"])
assert exit_code == CliExitCode.SUCCESS.value
@@ -577,12 +561,10 @@ class TestCliRunnerIntegration:
def test_mixed_fn_and_cmd_commands(self) -> None:
"""混合 fn 和 cmd 的命令应都能执行."""
runner = px.CliRunner(
{
"fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]),
"cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]),
}
)
runner = px.CliRunner({
"fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]),
"cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]),
})
assert runner.run(["fn_cmd"]) == CliExitCode.SUCCESS.value
assert runner.run(["cmd_cmd"]) == CliExitCode.SUCCESS.value
+4 -4
View File
@@ -6,7 +6,7 @@ from datetime import datetime
import pytest
from pyflowx.task import TaskResult, TaskSpec, TaskStatus
from pyflowx.task import RetryPolicy, TaskResult, TaskSpec, TaskStatus
def _fn() -> None:
@@ -18,9 +18,9 @@ def test_spec_empty_name_rejected() -> None:
TaskSpec("", _fn)
def test_spec_negative_retries_rejected() -> None:
with pytest.raises(ValueError, match="retries"):
TaskSpec("a", _fn, retries=-1)
def test_spec_negative_max_attempts_rejected() -> None:
with pytest.raises(ValueError, match="max_attempts"):
TaskSpec("a", _fn, retry=RetryPolicy(max_attempts=0))
def test_spec_zero_timeout_rejected() -> None:
+31 -26
View File
@@ -67,7 +67,9 @@ def test_taskspec_wrap_cmd_verbose():
def test_taskspec_wrap_cmd_error():
"""Test TaskSpec._wrap_cmd handles command error."""
spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"])
import sys
spec = TaskSpec("test", cmd=[sys.executable, "-c", "import sys; sys.exit(1)"])
wrapped_fn = spec.effective_fn
with pytest.raises(RuntimeError, match="命令执行失败"):
@@ -105,10 +107,10 @@ def test_taskspec_conditions_check():
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True,),
conditions=(lambda _ctx: True,),
)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_taskspec_conditions_false():
@@ -116,10 +118,10 @@ def test_taskspec_conditions_false():
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
)
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
def test_taskspec_conditions_multiple():
@@ -127,10 +129,10 @@ def test_taskspec_conditions_multiple():
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True, lambda: True, lambda: True),
conditions=(lambda _ctx: True, lambda _ctx: True, lambda _ctx: True),
)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_taskspec_conditions_multiple_one_false():
@@ -138,10 +140,10 @@ def test_taskspec_conditions_multiple_one_false():
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True, lambda: False, lambda: True),
conditions=(lambda _ctx: True, lambda _ctx: False, lambda _ctx: True),
)
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
def test_taskspec_list_cmd_timeout_mocked():
@@ -218,27 +220,28 @@ def test_taskspec_shell_cmd_os_error_mocked():
# ---------------------------------------------------------------------- #
def test_skip_if_missing_with_available_command():
"""skip_if_missing=True 时,命令存在应返回 True."""
# python 命令在测试环境中一定存在
spec = TaskSpec("test", cmd=["python", "--version"], skip_if_missing=True)
assert spec.should_execute() is True
import sys
spec = TaskSpec("test", cmd=[sys.executable, "--version"], skip_if_missing=True)
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_missing_command():
"""skip_if_missing=True 时,命令不存在应返回 False."""
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True)
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
def test_skip_if_missing_false_with_missing_command():
"""skip_if_missing=False 时,命令不存在也应返回 True(不检查)."""
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=False)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_shell_cmd_not_checked():
"""skip_if_missing=True 时,shell 命令(str)不检查,应返回 True."""
spec = TaskSpec("test", cmd="definitely_not_installed_app_xyz", skip_if_missing=True)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_callable_cmd_not_checked():
@@ -248,7 +251,7 @@ def test_skip_if_missing_with_callable_cmd_not_checked():
return 0
spec = TaskSpec("test", cmd=custom_cmd, skip_if_missing=True)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_fn_not_checked():
@@ -258,7 +261,7 @@ def test_skip_if_missing_with_fn_not_checked():
return 0
spec = TaskSpec("test", fn=my_fn, skip_if_missing=True)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
def test_skip_if_missing_with_empty_cmd_list():
@@ -266,37 +269,39 @@ def test_skip_if_missing_with_empty_cmd_list():
spec = TaskSpec("test", cmd=[""], skip_if_missing=True)
# 空字符串命令,shutil.which 返回 None
# 但 cmd[0] 是空字符串,shutil.which("") 返回 None
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
def test_skip_if_missing_combined_with_conditions():
"""skip_if_missing=True 与 conditions 组合使用."""
import sys
# conditions 返回 False,应跳过
spec = TaskSpec(
"test",
cmd=["python", "--version"],
cmd=[sys.executable, "--version"],
skip_if_missing=True,
conditions=(lambda: False,),
conditions=(lambda _ctx: False,),
)
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
# conditions 返回 True,命令存在,应执行
spec = TaskSpec(
"test",
cmd=["python", "--version"],
cmd=[sys.executable, "--version"],
skip_if_missing=True,
conditions=(lambda: True,),
conditions=(lambda _ctx: True,),
)
assert spec.should_execute() is True
assert spec.should_execute({})[0] is True
# conditions 返回 True,命令不存在,应跳过
spec = TaskSpec(
"test",
cmd=["definitely_not_installed_app_xyz"],
skip_if_missing=True,
conditions=(lambda: True,),
conditions=(lambda _ctx: True,),
)
assert spec.should_execute() is False
assert spec.should_execute({})[0] is False
def test_skip_if_missing_skips_task_in_run():
+30 -24
View File
@@ -52,7 +52,7 @@ def test_taskspec_with_conditions_skip():
"""测试条件不满足时任务被跳过."""
# 创建一个永远不会满足的条件
def never_true():
def never_true(_ctx):
return False
graph = px.Graph.from_specs([
@@ -73,7 +73,7 @@ def test_taskspec_with_conditions_execute():
"""测试条件满足时任务正常执行."""
# 创建一个总是满足的条件
def always_true():
def always_true(_ctx):
return True
graph = px.Graph.from_specs([
@@ -103,17 +103,17 @@ def test_platform_conditions():
px.TaskSpec(
"win_task",
cmd=win_cmd,
conditions=(lambda: Constants.IS_WINDOWS,),
conditions=(lambda _ctx: Constants.IS_WINDOWS,),
),
px.TaskSpec(
"linux_task",
cmd=posix_cmd,
conditions=(lambda: Constants.IS_LINUX,),
conditions=(lambda _ctx: Constants.IS_LINUX,),
),
px.TaskSpec(
"macos_task",
cmd=posix_cmd,
conditions=(lambda: Constants.IS_MACOS,),
conditions=(lambda _ctx: Constants.IS_MACOS,),
),
])
@@ -137,17 +137,15 @@ def test_platform_conditions():
def test_app_installed_conditions():
"""测试应用安装条件."""
# 测试 python 应该总是安装的
if sys.platform == "win32":
python_cmd = ["python", "--version"]
else:
python_cmd = ["python3", "--version"]
# 使用 sys.executable 保证可移植
python_cmd = [sys.executable, "--version"]
py_name = "python" if sys.platform == "win32" else "python3"
graph = px.Graph.from_specs([
px.TaskSpec(
"python_check",
cmd=python_cmd,
conditions=(BuiltinConditions.HAS_INSTALLED("python"),),
conditions=(BuiltinConditions.HAS_INSTALLED(py_name),),
),
])
@@ -162,18 +160,18 @@ def test_combined_conditions():
"""测试组合条件."""
# AND 条件
and_condition = BuiltinConditions.AND(
lambda: True,
lambda: True,
lambda _ctx: True,
lambda _ctx: True,
)
# OR 条件
or_condition = BuiltinConditions.OR(
lambda: True,
lambda: False,
lambda _ctx: True,
lambda _ctx: False,
)
# NOT 条件
not_condition = BuiltinConditions.NOT(lambda: False)
not_condition = BuiltinConditions.NOT(lambda _ctx: False)
graph = px.Graph.from_specs([
px.TaskSpec(
@@ -228,7 +226,7 @@ def test_taskspec_with_timeout():
# 短时间任务应该成功
px.TaskSpec(
"short_task",
cmd=["python", "-c", "import time; time.sleep(0.1)"],
cmd=[sys.executable, "-c", "import time; time.sleep(0.1)"],
timeout=1.0,
),
])
@@ -245,13 +243,13 @@ def test_taskspec_dependency_with_conditions():
px.TaskSpec(
"first",
cmd=[*ECHO_CMD, "first"],
conditions=(lambda: True,),
conditions=(lambda _ctx: True,),
),
px.TaskSpec(
"second",
cmd=[*ECHO_CMD, "second"],
depends_on=("first",),
conditions=(lambda: True,),
conditions=(lambda _ctx: True,),
),
px.TaskSpec(
"third",
@@ -378,7 +376,7 @@ class TestTaskSpecVerbose:
graph = px.Graph.from_specs([
px.TaskSpec(
"fail",
cmd=["python", "-c", "import sys; sys.exit(1)"],
cmd=[sys.executable, "-c", "import sys; sys.exit(1)"],
verbose=True,
)
])
@@ -414,7 +412,7 @@ class TestTaskSpecCmdErrors:
px.TaskSpec(
"fail",
cmd=[
"python",
sys.executable,
"-c",
"import sys; sys.stderr.write('error-msg'); sys.exit(1)",
],
@@ -437,7 +435,9 @@ class TestTaskSpecCmdErrors:
"""shell 命令失败时应抛出 RuntimeError."""
from pyflowx.errors import TaskFailedError
graph = px.Graph.from_specs([px.TaskSpec("fail", cmd='python -c "import sys; sys.exit(1)"')])
graph = px.Graph.from_specs([
px.TaskSpec("fail", cmd=f'{sys.executable} -c "import sys; sys.exit(1)"'),
])
with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential")
assert "Shell 命令执行失败" in str(exc_info.value.cause)
@@ -450,7 +450,7 @@ class TestTaskSpecCmdErrors:
graph = px.Graph.from_specs([
px.TaskSpec(
"slow",
cmd=["python", "-c", "import time; time.sleep(5)"],
cmd=[sys.executable, "-c", "import time; time.sleep(5)"],
timeout=0.1,
)
])
@@ -463,7 +463,13 @@ class TestTaskSpecCmdErrors:
"""shell 命令超时应抛出 RuntimeError."""
from pyflowx.errors import TaskFailedError
graph = px.Graph.from_specs([px.TaskSpec("slow", cmd='python -c "import time; time.sleep(5)"', timeout=0.1)])
graph = px.Graph.from_specs([
px.TaskSpec(
"slow",
cmd=f'{sys.executable} -c "import time; time.sleep(5)"',
timeout=0.1,
),
])
with pytest.raises(TaskFailedError) as exc_info:
_ = px.run(graph, strategy="sequential")
assert "超时" in str(exc_info.value.cause)