Compare commits
7 Commits
v0.2.11
...
232e7293d9
| Author | SHA1 | Date | |
|---|---|---|---|
| 232e7293d9 | |||
| a1bae58e56 | |||
| cbc7cc0a75 | |||
| d0ff7d7b4d | |||
| d154f67ce0 | |||
| 9999071119 | |||
| bdd70e9c43 |
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: "pyflowx-testing"
|
||||
description: "PyFlowX 项目的测试编写规范与 mock 使用指南。在编写或审查测试、选择 mock 工具、设计 fixture、处理 asyncio 测试时调用。"
|
||||
---
|
||||
|
||||
# PyFlowX 测试规范
|
||||
|
||||
本技能是 `.trae/rules/python-standards.md` 测试章节的详细展开。
|
||||
规则文件仅保留硬约束指针,本文件提供完整操作指南。
|
||||
|
||||
## 总则
|
||||
|
||||
- **覆盖率 ≥ 95%**(branch coverage),不得下降。
|
||||
- **公共 API 优先测试**:测试用公共接口(`has`/`get`),不访问私有方法
|
||||
(如 `_expired`)。兼容旧测试的私有方法应删除并迁移测试。
|
||||
例外:`_store`/`_flush` 等内部状态在无法用公共 API 触发时(如模拟过期、
|
||||
故障注入),可临时访问私有属性,并在 docstring 注明原因。
|
||||
- **命名**:`test_<被测对象>_<场景>`,如 `test_storage_key_cache_key_exception_returns_name`。
|
||||
- **每个测试一个断言重点**;多个断言要语义相关。
|
||||
- **slow 标记**:耗时测试加 `@pytest.mark.slow`,CI 可 `-m "not slow"` 跳过。
|
||||
- **测试代码也跑 ruff**:`tests/**` 忽略 `ARG001`/`ARG002`(未用 fixture 参数)。
|
||||
- **断言风格**:用原生 `assert` + 比较运算符(`assert x == 1`),
|
||||
不用 `self.assertEqual`;pytest 会生成更清晰的 diff。
|
||||
|
||||
## Mock 工具选择(强制)
|
||||
|
||||
**优先级**:`monkeypatch` > 内联 stub > `unittest.mock` > `pytest-mock`。
|
||||
|
||||
| 场景 | 工具 | 示例 |
|
||||
|------|------|------|
|
||||
| 替换模块属性 / 环境变量 / 工作目录 | `monkeypatch` | `monkeypatch.setattr(subprocess, "run", fake_run)` |
|
||||
| `os.environ["KEY"]` 临时设置 | `monkeypatch.setenv` | `monkeypatch.setenv("LOCALAPPDATA", "C:\\...")` |
|
||||
| 切换 cwd | `monkeypatch.chdir` | `monkeypatch.chdir(tmp_path)` |
|
||||
| 一次性 stub 函数 | 内联 lambda / 闭包 | `ran = []; monkeypatch.setattr(subprocess, "run", lambda *c, **__: ran.append(c))` |
|
||||
| 复杂 spy(记录调用次数/参数/返回序列) | `unittest.mock.MagicMock` | 仅当 lambda 不足以表达时 |
|
||||
| `with patch(...)` 上下文 | **禁用**(用 monkeypatch) | monkeypatch 自动 teardown 更安全 |
|
||||
|
||||
**禁止**:
|
||||
- 不用 `pytest-mock` 的 `mocker` fixture(项目虽在 dev 依赖声明,但实际
|
||||
测试代码未使用;为保持风格统一,新代码继续用 `monkeypatch`)。
|
||||
- 不用 `unittest.mock.patch` 装饰器(`@patch("x.y")`),它隐藏依赖且
|
||||
与 pytest fixture 模式不兼容;用 `monkeypatch.setattr` 替代。
|
||||
- 不用 `mock.patch.object` 作为上下文管理器,除非被测代码本身就是
|
||||
contextmanager(此时用 `monkeypatch.setattr` 仍更简单)。
|
||||
|
||||
## monkeypatch 使用规范
|
||||
|
||||
- **类型注解**:fixture 参数标注 `monkeypatch: pytest.MonkeyPatch`。
|
||||
- **作用域**:monkeypatch 自动在测试结束时撤销,**禁止**手动
|
||||
`monkeypatch.setattr(x, "y", original)` 恢复(多余且容易遗漏)。
|
||||
例外:在单个测试内需要中途恢复时,用 `monkeypatch.undo()` 全量撤销。
|
||||
- **替换目标**:替换"被测代码看到的对象",而非全局对象本身。
|
||||
- 错误:`monkeypatch.setattr("os.path.exists", fake)` —— 替换全局,影响其他模块。
|
||||
- 正确:`monkeypatch.setattr(pyflowx.command.shutil, "which", fake)` ——
|
||||
替换被测模块引用的 `shutil.which`。
|
||||
- **属性 vs 字符串路径**:优先属性访问形式 `monkeypatch.setattr(obj, "attr", val)`
|
||||
而非字符串路径 `monkeypatch.setattr("pkg.mod.obj.attr", val)`,
|
||||
前者有 IDE 跳转与重构支持。
|
||||
- **记录调用**:用闭包 `ran: list[tuple] = []` + `lambda *a, **k: ran.append((a, k))`
|
||||
替代 `MagicMock`,可读性更好且无需导入。
|
||||
|
||||
## Stub 与 Spy 模式
|
||||
|
||||
- **轻量 stub**:内联定义 `class MockResult: returncode = 0; stdout = ""`,
|
||||
替代 `MagicMock(return_value=...)`,类型明确且不引入 mock 依赖。
|
||||
- **状态收集**:闭包 + list 比 `mock.call_args_list` 更易断言:
|
||||
```python
|
||||
calls: list[list[str]] = []
|
||||
|
||||
|
||||
def fake_run(cmd: list[str], **_: Any) -> MockResult:
|
||||
calls.append(cmd)
|
||||
return MockResult()
|
||||
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
assert calls == [["clear"]]
|
||||
```
|
||||
- **副作用序列**:需要按调用次数返回不同值时,用 `itertools.cycle` 或
|
||||
手动计数器,而非 `side_effect=[...]`(mock 专有 API)。
|
||||
- **异常注入**:`def raise_oserror(*a, **k): raise OSError("...")`,
|
||||
用 `pytest.raises(OSError)` 验证,而非 `side_effect=OSError`。
|
||||
|
||||
## 异常断言
|
||||
|
||||
- **`pytest.raises`**:必填 `match=` 正则(除非异常消息完全不可预测),
|
||||
避免误捕获同类异常:
|
||||
```python
|
||||
with pytest.raises(StorageError, match="cannot write"):
|
||||
b.save("a", 1)
|
||||
```
|
||||
- **异常链**:验证 `__cause__` 时用 `exc_info.value.__cause__`,
|
||||
确认 `raise X from Y` 因果链完整。
|
||||
- **禁止** `try/except + assert False`:用 `pytest.raises` 替代。
|
||||
|
||||
## Fixture 规范
|
||||
|
||||
- **`tmp_path`**:处理临时文件,自动清理,禁止 `tempfile.mkdtemp()` 手动管理。
|
||||
- **`monkeypatch`**:环境变量、cwd、模块属性 mock(见上)。
|
||||
- **`capsys`/`capfd`**:捕获 stdout/stderr,验证日志或命令输出。
|
||||
- **autouse fixture**:仅在全局必需时用(如 `conftest.py` 的
|
||||
`packtool_tmp_workdir` 自动切到 tmp_path);否则显式声明参数。
|
||||
- **fixture 命名**:`snake_case`,描述"提供什么"而非"测试什么"
|
||||
(`sample_graph` 优于 `test_data`)。
|
||||
- **fixture 作用域**:默认 `function`;`module`/`session` 仅当构造昂贵且
|
||||
只读时,并加注释说明无副作用。
|
||||
|
||||
## asyncio 测试
|
||||
|
||||
- **fixture `loop_scope="function"`**(pyproject 已配置默认值)。
|
||||
- **async 测试**:`async def test_x():`,pytest-asyncio 自动驱动。
|
||||
- **await 检查**:测试异步函数必须 `await` 结果,禁止仅验证返回 coroutine 对象。
|
||||
- **异步 mock**:用 `AsyncMock`(3.8+ 在 `unittest.mock`)或
|
||||
`async def fake(): return value`,禁用 `MagicMock(return_value=coro)`。
|
||||
|
||||
## 参数化
|
||||
|
||||
- **`@pytest.mark.parametrize`**:用 `ids` 参数提供可读标识:
|
||||
```python
|
||||
@pytest.mark.parametrize(
|
||||
("strategy", "expected_workers"),
|
||||
[("sequential", 1), ("thread", 8), ("async", 1)],
|
||||
ids=["seq", "thread-8", "async"],
|
||||
)
|
||||
```
|
||||
- **参数命名**:参数元组用有意义名称,而非 `("a", "b")`。
|
||||
- **组合爆炸**:参数组合 > 20 时拆分测试,避免单个测试函数臃肿。
|
||||
|
||||
## 测试组织
|
||||
|
||||
- **文件命名**:`test_<被测模块>.py`(`test_storage.py` 对应 `storage.py`)。
|
||||
- **类分组**:仅在测试逻辑强相关时用 `class TestXxx:` 分组;默认用模块级函数。
|
||||
- **docstring**:每个测试函数一句话说明"测试什么场景",复杂场景补充"为什么"。
|
||||
- **setup/teardown**:优先 fixture;`setup_method`/`teardown_method` 仅在
|
||||
无法用 fixture 表达时(罕见)。
|
||||
@@ -0,0 +1,15 @@
|
||||
# PYTHON
|
||||
.coverage
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
.tox/
|
||||
.venv/
|
||||
__pycache__/
|
||||
|
||||
# NODEJS
|
||||
node_modules/
|
||||
|
||||
# IDE
|
||||
.idea
|
||||
.trae
|
||||
.vscode
|
||||
@@ -0,0 +1,11 @@
|
||||
---
|
||||
alwaysApply: true
|
||||
scene: git_message
|
||||
---
|
||||
|
||||
在此处编写规则,自定义 AI 生成提交信息的风格。
|
||||
|
||||
## 提交信息格式
|
||||
- 提交信息必须使用中文。
|
||||
- 提交信息必须包含变更的类型(例如 "fix"、"feat"、"refactor" 等)。
|
||||
- 提交信息必须尽简洁明了,不要超过一段落。
|
||||
@@ -0,0 +1,157 @@
|
||||
# Python 开发规范
|
||||
|
||||
本规范结合 Python 最佳实践,作为编写与审查 Python 代码的统一标准。
|
||||
详细操作指南见 `.agents/skills/` 下相应技能。
|
||||
|
||||
## 工具链(以 pyproject.toml 为准)
|
||||
|
||||
| 工具 | 用途 | 配置要点 |
|
||||
|------|------|---------|
|
||||
| **ruff** | lint + format | `line-length=120`,`target-version="py38"` |
|
||||
| **pyrefly** | 类型检查 | `preset="strict"`,`python-version="3.8"` |
|
||||
| **pytest** | 测试 | `asyncio_default_fixture_loop_scope="function"`,marker `slow` |
|
||||
| **coverage** | 覆盖率 | `branch=true`,`fail_under=95`,`concurrency=["thread"]` |
|
||||
| **pre-commit** | 提交前检查 | ruff `--fix` + trailing-whitespace + end-of-file-fixer |
|
||||
|
||||
验证(每次修改后必做):
|
||||
|
||||
```bash
|
||||
uvx --from pyflowx pymake tc
|
||||
uvx --from pyflowx pymake cov
|
||||
```
|
||||
|
||||
## 兼容性
|
||||
|
||||
- **最低 Python 3.8**:用 `from __future__ import annotations` 延迟注解求值;
|
||||
按版本用 `typing.List`(3.8) → 内置泛型(3.9) → `X | Y`(3.10) → `typing.override`(3.12)。
|
||||
- **版本守卫**:`if sys.version_info >= (3, X):` 引入高版本 API;低版本回退分支加 `# pragma: no cover`。
|
||||
- **零运行时依赖**:仅依赖标准库(3.8 需 `graphlib_backport`、`typing-extensions`)。
|
||||
新增依赖须审慎,优先用标准库。
|
||||
|
||||
## 类型注解
|
||||
|
||||
- **公共 API 必须有完整类型注解**,包括返回类型;私有函数也应有注解。
|
||||
- 泛型用 `TypeVar`;PEP 696 `default=` 仅 3.13+ 标准库支持,3.8–3.12 用 `typing_extensions.TypeVar`。
|
||||
- `Mapping`/`Sequence` 用于只读参数,`dict`/`list` 用于可变返回。
|
||||
- `Any` 仅用于真正动态场景(如 `Context` 跨任务异构映射);任务内部类型必须完全静态。
|
||||
- 禁用裸 `# type: ignore`;确需时加具体规则码(如 `# type: ignore[union-attr]`)。
|
||||
- **`TYPE_CHECKING` 守卫**:仅类型检查需要的导入放 `if TYPE_CHECKING:` 块内,避免循环依赖。
|
||||
- **类型收窄**:用 `assert isinstance(x, Y)` 辅助 pyrefly 推断;`cast()` 仅用于类型系统无法表达的场景。
|
||||
|
||||
## 数据结构
|
||||
|
||||
- **不可变优先**:配置/描述类用 `@dataclass(frozen=True)`;可变类属性标注 `RUF012` 豁免。
|
||||
- **缓存**:实例级用 `functools.cached_property`,按参数键控用 `functools.lru_cache`;
|
||||
不可哈希参数需 try/except 回退。修改被缓存数据源后必须手动清空缓存。
|
||||
- **抽象基类**:接口用 `abc.ABC` + `@abstractmethod`(如 `StateBackend`)。
|
||||
- **枚举**:状态/标志值用 `enum.Enum`(如 `TaskStatus`),禁止裸字符串/魔术数字;枚举值用 `UPPER_SNAKE`。
|
||||
- **`__repr__`**:可变类实现 `__repr__`(含关键字段);`frozen=True` dataclass 自动生成。
|
||||
|
||||
## 模块与导入
|
||||
|
||||
- **单一职责**:每模块只做一件事(`task.py` 数据结构、`executors.py` 执行、`command.py` 命令、`compose.py` 组合)。禁止跨职责边界。
|
||||
- **导入顺序**(ruff isort):`__future__` → 标准库 → 第三方 → 本地,各组间空行。
|
||||
- **惰性导入**:仅为打破循环依赖时使用,函数体内导入并注释说明;顶层导入是默认。
|
||||
- **`__all__`**:定义 `__all__` 显式声明导出符号,位置仅次于 `__future__` 之后。
|
||||
- **禁用 star imports**:`from x import *` 污染命名空间、破坏类型检查(`__init__.py` 聚合经 `__all__` 控制为例外)。
|
||||
- **避免 `utils.py`/`helpers.py`**:按职责归入对应模块。
|
||||
|
||||
## 函数设计
|
||||
|
||||
- **模块级函数优于 Mixin**:共享逻辑用模块级函数,类只持有状态与薄方法。
|
||||
- **静态方法慎用**:纯函数直接放模块级。
|
||||
- **参数 ≤ 5 个**为宜;超出用 dataclass 封装参数对象。
|
||||
- **单一职责**:一个函数做一件事;过长函数考虑拆分。
|
||||
- **异常范围要窄**:只捕获预期异常(如 `(TypeError, ValueError, KeyError, AttributeError)`),
|
||||
**禁止** `except Exception` 掩盖 bug;捕获后至少 `logger.warning` 记录。
|
||||
- **可变默认参数**:`def f(x=[])` 是经典坑;用 `None` 哨兵或 `field(default_factory=list)`。
|
||||
|
||||
## 异常处理
|
||||
|
||||
- **自定义异常家族**:继承公共基类(如 `PyFlowXError`),按错误场景分类。
|
||||
- **异常包装**:`raise NewError(...) from exc` 保留因果链。
|
||||
- **不要吞异常**:捕获后必须处理(记录/包装/重抛),禁止空 `except: pass`。
|
||||
- **钩子/回调异常**:第三方回调异常仅记录,不影响主流程。
|
||||
|
||||
## 并发与线程安全
|
||||
|
||||
- **进程全局状态**(`os.environ`/`os.chdir`)在并发场景下必须用全局锁(`threading.RLock`)序列化。
|
||||
- **条件评估不可有可变状态**:组合条件(NOT/AND/OR)不得修改共享 `_reason`,避免竞态。
|
||||
- **批量 I/O**:循环内多次写盘改为批量一次(`contextmanager` 包裹延迟落盘)。
|
||||
- **信号量限流**:`concurrency_key` + `Semaphore` 按组限流。
|
||||
|
||||
## 测试
|
||||
|
||||
详细操作指南见 `.agents/skills/pyflowx-testing` 技能。硬约束:
|
||||
|
||||
- **覆盖率 ≥ 95%**(branch coverage),不得下降。
|
||||
- **公共 API 优先测试**:用公共接口(`has`/`get`),不访问私有方法;
|
||||
故障注入等场景可临时访问私有属性,docstring 注明原因。
|
||||
- **命名**:`test_<被测对象>_<场景>`。
|
||||
- **断言**:原生 `assert x == 1`,禁用 `self.assertEqual`;`pytest.raises` 必填 `match=`。
|
||||
- **Mock 优先级**:`monkeypatch` > 内联 stub > `unittest.mock` > `pytest-mock`。
|
||||
禁用 `@patch` 装饰器、`mock.patch.object` 上下文、`pytest-mock` 的 `mocker` fixture。
|
||||
- **fixture**:`tmp_path`/`monkeypatch`/`capsys` 优先;autouse 仅全局必需时用。
|
||||
- **slow 标记**:耗时测试加 `@pytest.mark.slow`,CI 可 `-m "not slow"` 跳过。
|
||||
- **测试代码也跑 ruff**:`tests/**` 忽略 `ARG001`/`ARG002`。
|
||||
|
||||
## 代码风格
|
||||
|
||||
- **行宽 120**(ruff formatter 处理)。
|
||||
- **docstring**:公共 API 必须有;中文叙述 + 中文注释是本项目既有风格。
|
||||
- **打印和日志**:使用中文打印和日志,避免使用英文。
|
||||
- **命名**:`snake_case` 函数/变量,`PascalCase` 类,`UPPER_SNAKE` 常量,`_leading_underscore` 私有。
|
||||
- **字符串引号**:ruff 默认双引号。
|
||||
- **末尾单 `\n`**、**无尾随空格**(pre-commit 强制)。
|
||||
- **不用 emoji**:除非用户明确要求。
|
||||
|
||||
## Pythonic 风格
|
||||
|
||||
- **`is` 比较 `None`/`True`/`False`**:单例用 `is`,值用 `==`(PEP 8 E711/E712)。
|
||||
- **EAFP 优于 LBYL**:先尝试再处理异常,而非先检查再执行(避免竞态窗口)。
|
||||
- **truthiness**:`if items:` 优于 `if len(items) > 0:`。
|
||||
- **字符串格式化**:首选 f-string;`%` 仅用于 `logging` 延迟格式化。
|
||||
- **推导式**优于 `map`+`filter`;> 2 层拆为显式循环。
|
||||
- **`enumerate`** 替代 `range(len())`;**`zip`** 并行迭代(3.10+ 用 `strict=True`)。
|
||||
- **解包** `a, b = pair` 优于索引访问;忽略值用 `_`。
|
||||
- **海象运算符 `:=`**(3.8+):赋值+判断合一,但不滥用。
|
||||
|
||||
## 日志
|
||||
|
||||
- **`logging.getLogger(__name__)`**:每模块独立 logger,禁用 `print` 调试残留。
|
||||
- **结构化上下文**:`extra={...}` 传字段;`logger.warning("task %r failed: %s", name, exc)` 优于 f-string(延迟格式化)。
|
||||
- **日志级别**:`DEBUG` 诊断 / `INFO` 关键流程 / `WARNING` 可恢复异常 / `ERROR` 需人工介入。
|
||||
- **禁止日志密码/密钥**:脱敏后再记录。
|
||||
|
||||
## 路径与资源
|
||||
|
||||
- **优先 `pathlib.Path`**:`Path("a") / "b"` 而非 `os.path.join`(ruff `PTH` 强制);
|
||||
禁止字符串拼接路径。类型注解用 `Path`,边界 `str` 立即包装。
|
||||
- **`with` 语句**:文件、锁、连接、临时目录一律用 `with` 或 `contextlib.contextmanager`;
|
||||
多资源用 `contextlib.ExitStack`。
|
||||
- **显式关闭**:长生命周期对象(连接池、线程池)实现 `close()`,但优先 `with`。
|
||||
- **批量操作**:循环内多次 acquire/release 改为批量一次。
|
||||
|
||||
## 安全
|
||||
|
||||
- **禁用 `eval`/`exec`**:处理不可信输入时绝不使用;用 `ast.literal_eval` 或专用解析器。
|
||||
- **`subprocess`**:禁用 `shell=True` 除非命令完全可信;优先 `list[str]` 形式。
|
||||
- **凭证不入仓**:密钥/token/密码放 `.env` 或环境变量,`.gitignore` 必须包含 `.env`。
|
||||
- **日志脱敏**:记录请求/响应时移除 `Authorization`、`password` 等字段。
|
||||
- **依赖审计**:`uv lock` 后审阅新增依赖,避免引入已知 CVE 的包。
|
||||
|
||||
## 性能要点
|
||||
|
||||
- **避免重复计算**:循环内查询应缓存或预构建映射(如 `{name: spec}`)。
|
||||
- **避免双重查找**:`has(k)` + `get(k)` 改为单次 `get(k)` + `KeyError` 回退。
|
||||
- **统一校验**:入口校验一次,下游路径不重复(如 `run()` 统一 `validate()`,`layers()` 不再重复)。
|
||||
- **事件 emit**:任务生命周期必须 emit `RUNNING` → `SUCCESS`/`FAILED`/`SKIPPED`,
|
||||
不要留死分支(`# pragma: no cover` 是清理信号,应激活或删除)。
|
||||
|
||||
## Git 与提交
|
||||
|
||||
- **不自动提交/push**:除非用户明确要求。
|
||||
- **不修改 git config**。
|
||||
- **不运行破坏性命令**(`push --force`/`reset --hard`/`clean -f`)除非用户明确要求。
|
||||
- **staging**:按文件名添加,不用 `git add -A`/`git add .`,避免误加敏感文件。
|
||||
- **commit message**:简洁,聚焦"为什么"而非"是什么";遵循仓库既有风格。
|
||||
@@ -14,18 +14,25 @@ PyFlowX 把"任务依赖"这件事做到极致简单:**参数名就是依赖
|
||||
## 特性
|
||||
|
||||
- **零样板** —— 参数名即依赖,框架自动注入上游结果
|
||||
- **三种执行策略** —— `sequential`(调试)/ `thread`(I/O 密集同步)/ `async`(I/O 密集异步)
|
||||
- **四种执行策略** —— `sequential`(调试)/ `thread`(I/O 密集同步)/ `async`(I/O 密集异步)/ `dependency`(依赖驱动,最大化并行)
|
||||
- **类型安全** —— `TaskSpec[T]` 把返回类型一路传到 `RunReport`,mypy strict 通过
|
||||
- **DAG 校验** —— 构建时即时校验重名、缺失依赖、环
|
||||
- **自动分层** —— Kahn 算法分组,同层任务可并行
|
||||
- **重试与超时** —— 每个任务独立配置 `retries` 与 `timeout`
|
||||
- **断点续跑** —— `MemoryBackend` / `JSONBackend`,成功结果可缓存复用
|
||||
- **重试与超时** —— 每个任务独立配置 `RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)与 `timeout`
|
||||
- **软依赖** —— `soft_depends_on` 仅用于上下文注入,不参与拓扑分层
|
||||
- **并发限制** —— `concurrency_key` + `concurrency_limits` 按组限流
|
||||
- **任务钩子** —— `TaskHooks`(pre_run/post_run/on_failure)生命周期回调
|
||||
- **断点续跑** —— `MemoryBackend` / `JSONBackend`,成功结果可缓存复用;`batch()` 批量落盘
|
||||
- **缓存键** —— `cache_key` 函数基于输入计算稳定键,使不同输入产生独立缓存
|
||||
- **命令任务** —— `cmd` 参数直接执行外部命令,支持列表/shell/可调用对象
|
||||
- **条件执行** —— `conditions` 参数按平台、环境变量、应用安装等条件跳过任务
|
||||
- **图组合** —— `compose` / `GraphComposer` 编程式展开多图字符串引用
|
||||
- **任务模板** —— `task_template` 工厂批量生成相似 TaskSpec
|
||||
- **图级默认值** —— `GraphDefaults` 统一配置 retry/timeout/concurrency 等
|
||||
- **CLI 运行器** —— `CliRunner` 把多个图映射为命令行子命令,替代 Makefile
|
||||
- **可观测** —— `on_event` 回调、`dry_run` 预览、`verbose` 生命周期日志、Mermaid 可视化
|
||||
- **可观测** —— `on_event` 回调(RUNNING/SUCCESS/FAILED/SKIPPED)、`dry_run` 预览、`verbose` 生命周期日志、Mermaid 可视化
|
||||
- **零运行时依赖** —— 仅依赖标准库(3.8 需 `graphlib_backport`)
|
||||
- **95% 测试覆盖** —— 分支覆盖率>= 95%
|
||||
- **97% 测试覆盖** —— 分支覆盖率 >= 95%
|
||||
|
||||
## 安装
|
||||
|
||||
@@ -67,23 +74,31 @@ print(report["double"]) # [2, 4, 6]
|
||||
|
||||
### TaskSpec —— 任务描述
|
||||
|
||||
`TaskSpec` 是不可变的任务描述符,是唯一需要配置的东西:
|
||||
`TaskSpec` 是不可变的任务描述符(`Generic[T]`,返回类型一路传到 `RunReport`),是唯一需要配置的东西:
|
||||
|
||||
```python
|
||||
px.TaskSpec(
|
||||
name="fetch_user", # 唯一标识
|
||||
fn=fetch_user, # 同步或异步函数
|
||||
cmd=["curl", "..."], # 或: 执行命令(覆盖 fn)
|
||||
depends_on=("auth",), # 依赖的任务名
|
||||
depends_on=("auth",), # 硬依赖(参与拓扑分层)
|
||||
soft_depends_on=("cache",), # 软依赖(仅注入,不参与分层)
|
||||
args=(uid,), # 静态位置参数(追加在注入参数后)
|
||||
kwargs={"timeout": 30}, # 静态关键字参数
|
||||
retries=3, # 失败重试次数(0 = 仅一次)
|
||||
retry=px.RetryPolicy(max_attempts=3, delay=1.0, backoff=2.0), # 重试策略
|
||||
timeout=30.0, # 超时秒数(None = 不限制)
|
||||
tags=("api", "user"), # 自由标签,用于子图过滤
|
||||
conditions=(is_prod,), # 条件函数列表(全部为 True 才执行)
|
||||
priority=10, # 同层内优先级(高优先执行,默认 0)
|
||||
concurrency_key="db", # 并发分组键(配合 concurrency_limits 限流)
|
||||
cache_key=lambda ctx: str(ctx.get("uid")), # 缓存键函数(不同输入独立缓存)
|
||||
hooks=px.TaskHooks(pre_run=..., post_run=..., on_failure=...), # 生命周期钩子
|
||||
cwd=Path("/tmp"), # 命令工作目录(仅 cmd 模式)
|
||||
env={"DEBUG": "1"}, # 环境变量覆盖(fn 与 cmd 模式均生效)
|
||||
verbose=True, # 打印命令输出(仅 cmd 模式)
|
||||
skip_if_missing=True, # 命令不存在时自动跳过(仅 list[str] cmd)
|
||||
allow_upstream_skip=False, # 上游 SKIPPED/FAILED 时是否仍执行
|
||||
continue_on_error=False, # 本任务失败是否不中断整体
|
||||
)
|
||||
```
|
||||
|
||||
@@ -97,18 +112,54 @@ px.TaskSpec(
|
||||
### Graph —— DAG 构建
|
||||
|
||||
```python
|
||||
graph = px.Graph.from_specs([...]) # 整批校验(推荐)
|
||||
# 图级默认值:TaskSpec 字段为 None 时回退
|
||||
defaults = px.GraphDefaults(retry=px.RetryPolicy(max_attempts=2), timeout=60.0)
|
||||
|
||||
graph = px.Graph.from_specs([...], defaults=defaults) # 整批校验(推荐)
|
||||
# 或增量构建
|
||||
graph = px.Graph()
|
||||
graph = px.Graph(defaults=defaults)
|
||||
graph.add(px.TaskSpec("a", fn_a))
|
||||
graph.add(px.TaskSpec("b", fn_b, ("a",)))
|
||||
|
||||
graph.validate() # 显式校验(环检测)
|
||||
graph.layers() # 拓扑分层
|
||||
graph.layers() # 拓扑分层(run() 入口已统一校验,直接调用需自行先 validate)
|
||||
graph.to_mermaid() # Mermaid 可视化
|
||||
graph.describe() # 人类可读摘要
|
||||
graph.subgraph(("api",)) # 按标签切片
|
||||
graph.subgraph_by_names(("a", "b")) # 按名称切片
|
||||
graph.map("fetch", [1, 2, 3], lambda i: TaskSpec(f"fetch_{i}", ...)) # 批量 fan-out
|
||||
```
|
||||
|
||||
### 图组合 —— compose
|
||||
|
||||
`compose` / `GraphComposer` 把带字符串引用的多个图展开为纯 `Graph`:
|
||||
|
||||
```python
|
||||
graphs = {
|
||||
"build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
"all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
}
|
||||
resolved = px.compose(graphs) # "all" 图中的 "build" 引用被展开
|
||||
```
|
||||
|
||||
引用格式:`"command_name"`(整个图)或 `"command_name.task_name"`(特定任务)。
|
||||
`CliRunner` 内部自动调用 `compose`。
|
||||
|
||||
### 任务模板 —— task_template
|
||||
|
||||
`task_template` 工厂批量生成相似 TaskSpec:
|
||||
|
||||
```python
|
||||
fetch = px.task_template(
|
||||
fn=fetch_url,
|
||||
retry=px.RetryPolicy(max_attempts=5),
|
||||
timeout=30.0,
|
||||
tags=("api",),
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
fetch("users", url="https://api.example.com/users"),
|
||||
fetch("posts", url="https://api.example.com/posts"),
|
||||
])
|
||||
```
|
||||
|
||||
### run —— 执行
|
||||
@@ -116,12 +167,14 @@ graph.subgraph_by_names(("a", "b")) # 按名称切片
|
||||
```python
|
||||
report = px.run(
|
||||
graph,
|
||||
strategy="async", # sequential | thread | async
|
||||
strategy="async", # sequential | thread | async | dependency
|
||||
max_workers=8, # thread 策略的线程池大小
|
||||
concurrency_limits={"db": 2}, # 按 concurrency_key 限流
|
||||
dry_run=False, # True = 仅打印计划
|
||||
verbose=False, # True = 打印任务生命周期日志
|
||||
on_event=callback, # 状态转换回调
|
||||
on_event=callback, # 状态转换回调(RUNNING/SUCCESS/FAILED/SKIPPED)
|
||||
state=px.JSONBackend("state.json"), # 断点续跑后端
|
||||
continue_on_error=False, # True = 单任务失败不中断整体
|
||||
)
|
||||
```
|
||||
|
||||
@@ -141,7 +194,7 @@ report.describe() # 人类可读报告
|
||||
按顺序求值:
|
||||
|
||||
1. **标注为 `Context`** 的参数 → 接收完整上游结果映射
|
||||
2. **名称匹配依赖** 的参数 → 接收该依赖的结果
|
||||
2. **名称匹配依赖** 的参数 → 接收该依赖的结果(含软依赖,缺失时注入默认值)
|
||||
3. **`**kwargs`** 参数 → 接收所有依赖结果(dict)
|
||||
4. **`TaskSpec.args` / `kwargs`** → 为非依赖参数提供静态值
|
||||
|
||||
@@ -170,8 +223,11 @@ def fetch_user(uid: int) -> dict: # uid 来自 TaskSpec.args
|
||||
| `sequential` | 串行 | 调试、CPU 密集 | 直接调用 | 事件循环 |
|
||||
| `thread` | 线程池 | I/O 密集同步 | 线程池 | 不支持 |
|
||||
| `async` | 事件循环 | I/O 密集异步 | 卸载到线程池 | 事件循环 |
|
||||
| `dependency` | 依赖驱动 | 最大化并行度 | 卸载到线程池 | 事件循环 |
|
||||
|
||||
所有策略都遵循 `retries`、`timeout`、上下文注入、状态后端,并发出 `TaskEvent`。
|
||||
所有策略都遵循 `RetryPolicy`、`timeout`、上下文注入、状态后端、`concurrency_limits`,
|
||||
并发出 `TaskEvent`(RUNNING/SUCCESS/FAILED/SKIPPED)。`dependency` 策略无层屏障:
|
||||
任务在其所有硬依赖完成后立即启动。
|
||||
|
||||
## 命令任务
|
||||
|
||||
@@ -275,12 +331,25 @@ python examples/async_aggregation.py
|
||||
from pyflowx import JSONBackend
|
||||
|
||||
# 第一次运行:成功结果写入 state.json
|
||||
backend = JSONBackend("state.json")
|
||||
backend = JSONBackend("state.json", ttl=3600) # ttl 秒数,过期条目自动忽略
|
||||
report = px.run(graph, strategy="sequential", state=backend)
|
||||
|
||||
# 第二次运行:已缓存任务自动跳过
|
||||
# 第二次运行:已缓存任务自动跳过(状态为 SKIPPED)
|
||||
report = px.run(graph, strategy="sequential", state=backend)
|
||||
# report.results 中缓存任务状态为 SKIPPED
|
||||
```
|
||||
|
||||
`run()` 内部以 `backend.batch()` 包裹整个执行:所有 `save` 延迟到运行结束时统一落盘一次
|
||||
(`JSONBackend` 从 O(N²) 降为 O(N) 磁盘写入;`MemoryBackend` 为 no-op)。
|
||||
|
||||
**缓存键**:默认存储键为任务名。配置 `cache_key` 函数后,键为 `"name:cache_key_value"`,
|
||||
使不同输入产生独立缓存条目:
|
||||
|
||||
```python
|
||||
px.TaskSpec(
|
||||
"fetch_user",
|
||||
fn=fetch_user,
|
||||
cache_key=lambda ctx: str(ctx.get("uid")), # 不同 uid 独立缓存
|
||||
)
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
@@ -321,14 +390,52 @@ except px.PyFlowXError:
|
||||
|
||||
PyFlowX 专注于**单机 DAG 调度**的极致简洁,适合 ETL、数据处理、CI 流水线等场景。
|
||||
|
||||
## 高级特性
|
||||
|
||||
### 并发限制
|
||||
|
||||
按 `concurrency_key` 分组限流,避免压垮下游资源:
|
||||
|
||||
```python
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("q1", fn=query_db, concurrency_key="db"),
|
||||
px.TaskSpec("q2", fn=query_db, concurrency_key="db"),
|
||||
px.TaskSpec("q3", fn=query_db, concurrency_key="db"),
|
||||
])
|
||||
# 同一时刻最多 2 个 "db" 组任务运行
|
||||
px.run(graph, strategy="async", concurrency_limits={"db": 2})
|
||||
```
|
||||
|
||||
### 任务钩子
|
||||
|
||||
`TaskHooks` 在任务生命周期触发(异常仅记录,不影响任务状态):
|
||||
|
||||
```python
|
||||
hooks = px.TaskHooks(
|
||||
pre_run=lambda spec: print(f"start {spec.name}"),
|
||||
post_run=lambda spec, value: print(f"done {spec.name}"),
|
||||
on_failure=lambda spec, exc: alert(spec.name, exc),
|
||||
)
|
||||
px.TaskSpec("task", fn=work, hooks=hooks)
|
||||
```
|
||||
|
||||
### 优先级
|
||||
|
||||
同层内按 `priority` 降序执行(稳定排序):
|
||||
|
||||
```python
|
||||
px.TaskSpec("low", fn=work, priority=0)
|
||||
px.TaskSpec("high", fn=work, priority=10) # 同层内先执行
|
||||
```
|
||||
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
# 安装开发依赖
|
||||
uv sync --extra dev
|
||||
|
||||
# 运行测试(含覆盖率)
|
||||
uv run pytest --cov=pyflowx --cov-fail-under=100
|
||||
# 运行测试(含覆盖率,阈值 95%)
|
||||
uv run pytest --cov=pyflowx --cov-fail-under=95
|
||||
|
||||
# 类型检查
|
||||
uv run mypy
|
||||
@@ -338,6 +445,22 @@ uv run ruff check src tests examples
|
||||
uv run ruff format --check src tests examples
|
||||
```
|
||||
|
||||
## 模块结构
|
||||
|
||||
| 模块 | 职责 |
|
||||
|------|------|
|
||||
| `task.py` | 纯数据结构:`TaskSpec`、`RetryPolicy`、`TaskHooks`、`TaskStatus` |
|
||||
| `graph.py` | DAG 构建、校验、分层、可视化 |
|
||||
| `compose.py` | 多图组合:`GraphComposer` / `compose` |
|
||||
| `context.py` | 上下文注入:参数名→依赖解析 |
|
||||
| `command.py` | 命令执行:`run_command`(list/shell/Callable) |
|
||||
| `conditions.py` | 条件执行:内置条件与组合器 |
|
||||
| `executors.py` | 执行器与 `run` 入口:四种策略共享模块级辅助 |
|
||||
| `storage.py` | 状态后端:`MemoryBackend` / `JSONBackend`(batch flush) |
|
||||
| `runner.py` | CLI 运行器:`CliRunner` |
|
||||
| `report.py` | 运行结果:`RunReport` / `TaskResult` |
|
||||
| `errors.py` | 错误家族:`PyFlowXError` 子类 |
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
|
||||
@@ -58,6 +58,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .command import run_command
|
||||
from .compose import GraphComposer, compose
|
||||
from .conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
@@ -79,7 +81,7 @@ from .errors import (
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph, GraphComposer, GraphDefaults, compose
|
||||
from .graph import Graph, GraphDefaults
|
||||
from .report import RunReport
|
||||
from .runner import CliExitCode, CliRunner
|
||||
from .storage import JSONBackend, MemoryBackend, StateBackend
|
||||
@@ -136,5 +138,6 @@ __all__ = [
|
||||
"compose",
|
||||
"describe_injection",
|
||||
"run",
|
||||
"run_command",
|
||||
"task_template",
|
||||
]
|
||||
|
||||
@@ -240,7 +240,7 @@ def _parse_email_date(date_str: str) -> str:
|
||||
try:
|
||||
dt = parsedate_to_datetime(date_str)
|
||||
return dt.isoformat()
|
||||
except Exception:
|
||||
except (ValueError, TypeError, OverflowError):
|
||||
return date_str
|
||||
|
||||
|
||||
@@ -277,11 +277,11 @@ def _extract_email_body_part(part: Any) -> str:
|
||||
decoded_text = payload.decode(charset, errors="replace")
|
||||
except (UnicodeDecodeError, LookupError) as decode_error:
|
||||
# 如果指定编码失败,尝试常见编码
|
||||
logger.warning(f"字符编码 {charset} 解码失败: {decode_error}")
|
||||
logger.warning("字符编码 %s 解码失败: %s", charset, decode_error)
|
||||
for fallback_charset in ["utf-8", "gbk", "gb2312", "latin-1"]:
|
||||
try:
|
||||
decoded_text = payload.decode(fallback_charset, errors="replace")
|
||||
logger.info(f"成功使用备用编码 {fallback_charset} 解码")
|
||||
logger.info("成功使用备用编码 %s 解码", fallback_charset)
|
||||
break
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
@@ -293,15 +293,15 @@ def _extract_email_body_part(part: Any) -> str:
|
||||
# 限制长度并返回
|
||||
result = decoded_text[:MAX_BODY_LENGTH]
|
||||
if len(decoded_text) > MAX_BODY_LENGTH:
|
||||
logger.debug(f"正文内容过长,截取前{MAX_BODY_LENGTH}字符")
|
||||
logger.debug("正文内容过长,截取前%d字符", MAX_BODY_LENGTH)
|
||||
|
||||
return result
|
||||
|
||||
except AttributeError as attr_error:
|
||||
logger.error(f"邮件部分对象属性错误: {attr_error}")
|
||||
logger.error("邮件部分对象属性错误: %s", attr_error)
|
||||
return ""
|
||||
except Exception as unexpected_error:
|
||||
logger.error(f"提取邮件正文时发生未知错误: {unexpected_error}")
|
||||
logger.error("提取邮件正文时发生未知错误: %s", unexpected_error)
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
"""命令执行器:把 :class:`~pyflowx.task.TaskSpec` 的 ``cmd`` 字段(list /
|
||||
shell 字符串 / 可调用对象)转换为统一执行入口。
|
||||
|
||||
历史背景:原 ``task.py`` 的模块文档声明其为"纯数据结构",但 ``_run_command``
|
||||
属于命令执行逻辑,违反单一职责。此处将其抽离,``TaskSpec`` 仅持有配置,
|
||||
执行逻辑集中于本模块,便于独立测试与维护。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any, List, Union, cast
|
||||
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["run_command"]
|
||||
|
||||
|
||||
def run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。
|
||||
|
||||
与原 ``TaskSpec._run_command`` 行为一致:
|
||||
|
||||
- 可调用对象:直接调用,异常包装为 :class:`RuntimeError`。
|
||||
- list / str:通过 :func:`subprocess.run` 执行,非零返回码抛
|
||||
:class:`RuntimeError`(``verbose=False`` 时附 stderr)。
|
||||
- ``verbose=True`` 时打印执行信息与返回码到 stdout。
|
||||
- ``cwd`` / ``env`` 通过 subprocess 参数隔离(进程级状态仅在 fn 任务路径
|
||||
使用,cmd 路径不依赖 ``os.chdir`` / ``os.environ``)。
|
||||
"""
|
||||
cmd = spec.cmd
|
||||
verbose = spec.verbose
|
||||
cwd = spec.cwd
|
||||
timeout = spec.timeout
|
||||
env_override = spec.env
|
||||
|
||||
# 可调用对象:直接调用,返回其结果。
|
||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
||||
name = getattr(cmd, "__name__", "callable")
|
||||
if verbose:
|
||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
try:
|
||||
return cmd()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
||||
|
||||
is_list = isinstance(cmd, list)
|
||||
if is_list:
|
||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
||||
verb = "执行命令"
|
||||
label = "命令"
|
||||
else:
|
||||
cmd_str = cast(str, cmd)
|
||||
verb = "执行 Shell"
|
||||
label = "Shell 命令"
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
|
||||
# 合并环境变量
|
||||
run_env: dict[str, str] | None = None
|
||||
if env_override:
|
||||
run_env = dict(os.environ)
|
||||
run_env.update(env_override)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cast(Union[str, List[str]], cmd),
|
||||
shell=not is_list,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
timeout=timeout,
|
||||
capture_output=not verbose,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None
|
||||
|
||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
err_msg += f"\n{result.stderr.strip()}"
|
||||
raise RuntimeError(err_msg)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""图组合:将带字符串引用的多个图展开为纯 :class:`~pyflowx.graph.Graph`。
|
||||
|
||||
历史背景:原 ``graph.py`` 同时承载 DAG 构建/校验/分层与多图组合逻辑,
|
||||
职责过载。组合逻辑(:class:`GraphComposer` / :func:`compose`)与单图 DAG
|
||||
模型正交,此处抽离为独立模块,便于按需导入与独立演进。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["GraphComposer", "compose"]
|
||||
|
||||
|
||||
class GraphComposer:
|
||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
||||
|
||||
引用格式:
|
||||
* ``"command_name"`` —— 引用整个命令图。
|
||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
||||
|
||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
||||
self.graphs = graphs
|
||||
|
||||
def resolve_all(self) -> dict[str, Graph]:
|
||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
||||
resolved: dict[str, Graph] = {}
|
||||
for cmd_name, graph in self.graphs.items():
|
||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
||||
return resolved
|
||||
|
||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
||||
pending_refs = graph._pending_refs
|
||||
if not pending_refs:
|
||||
return graph
|
||||
|
||||
all_specs: list[TaskSpec[Any]] = []
|
||||
previous_ref_last_task: str | None = None
|
||||
|
||||
for ref in pending_refs:
|
||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
||||
if previous_ref_last_task and expanded_specs:
|
||||
for i, task in enumerate(expanded_specs):
|
||||
if i == 0 or not task.depends_on:
|
||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
||||
if expanded_specs:
|
||||
previous_ref_last_task = expanded_specs[-1].name
|
||||
all_specs.extend(expanded_specs)
|
||||
|
||||
original_specs = list(graph.all_specs().values())
|
||||
if original_specs:
|
||||
if previous_ref_last_task:
|
||||
first = original_specs[0]
|
||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
||||
else:
|
||||
all_specs.append(original_specs[0])
|
||||
for i in range(1, len(original_specs)):
|
||||
current_task = original_specs[i]
|
||||
previous_task_name = original_specs[i - 1].name
|
||||
all_specs.append(
|
||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
||||
)
|
||||
|
||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
||||
|
||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
||||
if ref == current_cmd:
|
||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
||||
|
||||
if "." in ref:
|
||||
cmd_name, task_name = ref.split(".", 1)
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
if task_name not in ref_graph.all_specs():
|
||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
||||
return [ref_graph.all_specs()[task_name]]
|
||||
else:
|
||||
cmd_name = ref
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
||||
return list(ref_graph.all_specs().values())
|
||||
|
||||
|
||||
def compose(
|
||||
graphs: dict[str, Graph],
|
||||
) -> dict[str, Graph]:
|
||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
||||
|
||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graphs = {
|
||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
... }
|
||||
>>> resolved = px.compose(graphs)
|
||||
>>> "b" in resolved["all"].all_specs()
|
||||
True
|
||||
"""
|
||||
return GraphComposer(graphs).resolve_all()
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
@@ -20,6 +21,8 @@ from typing import Any, Callable
|
||||
|
||||
from .task import Condition, Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["BuiltinConditions", "Condition", "Constants"]
|
||||
|
||||
|
||||
@@ -42,14 +45,6 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
|
||||
return _cond
|
||||
|
||||
|
||||
def _cond_reason(cond: Condition) -> str | list[str] | None:
|
||||
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
|
||||
reason = getattr(cond, "_reason", None)
|
||||
if reason is not None:
|
||||
return reason
|
||||
return getattr(cond, "__name__", repr(cond))
|
||||
|
||||
|
||||
def _cond_name(cond: Condition) -> str:
|
||||
"""获取条件的可读名称。"""
|
||||
return getattr(cond, "__name__", repr(cond))
|
||||
@@ -161,7 +156,7 @@ class BuiltinConditions:
|
||||
return False
|
||||
try:
|
||||
return content in p.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
except (OSError, UnicodeDecodeError):
|
||||
return False
|
||||
|
||||
return _static(_check, f"FILE_CONTENT_EXISTS({path!r},{content!r})")
|
||||
@@ -194,7 +189,8 @@ class BuiltinConditions:
|
||||
return False
|
||||
try:
|
||||
return predicate(ctx[dep_name])
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
logger.warning("DEP_MATCHES predicate %r raised: %r", dep_name, exc)
|
||||
return False
|
||||
|
||||
_cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})"
|
||||
@@ -228,13 +224,7 @@ class BuiltinConditions:
|
||||
"""对条件取反."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
result = condition(ctx)
|
||||
if result:
|
||||
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
|
||||
inner_reason = _cond_reason(condition)
|
||||
if inner_reason is not None:
|
||||
_cond._reason = inner_reason # type: ignore[attr-defined]
|
||||
return not result
|
||||
return not condition(ctx)
|
||||
|
||||
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
||||
return _cond
|
||||
@@ -254,15 +244,7 @@ class BuiltinConditions:
|
||||
"""多个条件的逻辑或."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
matched: list[str] = []
|
||||
for c in conditions:
|
||||
if c(ctx):
|
||||
reason = _cond_reason(c)
|
||||
matched.append(reason if isinstance(reason, str) else str(reason))
|
||||
if matched:
|
||||
_cond._reason = matched # type: ignore[attr-defined]
|
||||
return True
|
||||
return False
|
||||
return any(c(ctx) for c in conditions)
|
||||
|
||||
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
||||
return _cond
|
||||
|
||||
+21
-2
@@ -16,6 +16,7 @@ DAG 库中泛滥的样板包装器。
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from functools import lru_cache
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .errors import InjectionError
|
||||
@@ -24,6 +25,24 @@ from .task import Context, TaskSpec
|
||||
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1024)
|
||||
def _cached_signature(fn: Any) -> inspect.Signature:
|
||||
"""缓存 ``inspect.signature`` 结果(按 fn 对象键控)。
|
||||
|
||||
``fn`` 对象在 :meth:`TaskSpec.effective_fn` 缓存后稳定,签名重复内省
|
||||
属纯开销。对不可哈希的可调用对象,调用方回退到直接内省。
|
||||
"""
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _signature(fn: Any) -> inspect.Signature:
|
||||
"""获取签名,优先走缓存;``fn`` 不可哈希时回退到直接内省。"""
|
||||
try:
|
||||
return _cached_signature(fn)
|
||||
except TypeError:
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _is_context_annotation(annotation: Any) -> bool:
|
||||
"""判断参数标注是否为(或指向)``Context``。"""
|
||||
if annotation is Context:
|
||||
@@ -44,7 +63,7 @@ def build_call_args(
|
||||
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
|
||||
"""
|
||||
fn = spec.effective_fn
|
||||
sig = inspect.signature(fn)
|
||||
sig = _signature(fn)
|
||||
params = sig.parameters
|
||||
|
||||
var_keyword = next(
|
||||
@@ -115,7 +134,7 @@ def build_call_args(
|
||||
def describe_injection(spec: TaskSpec[Any]) -> str:
|
||||
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
|
||||
fn = spec.effective_fn
|
||||
sig = inspect.signature(fn)
|
||||
sig = _signature(fn)
|
||||
positional_params = [
|
||||
p
|
||||
for p, param in sig.parameters.items()
|
||||
|
||||
+288
-255
@@ -12,14 +12,18 @@
|
||||
|
||||
架构
|
||||
----
|
||||
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
|
||||
本模块通过 **模块级函数** 消除同步/异步任务执行器之间的重复代码:
|
||||
|
||||
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
|
||||
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
|
||||
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
|
||||
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
|
||||
* 模块级跳过/重试函数(:func:`_prepare_for_execution` / :func:`_should_retry`
|
||||
/ :func:`_mark_success` / :func:`_handle_failure` / :func:`_finalize_failure`)
|
||||
—— 上游跳过 / 条件跳过的预检、重试决策、成功/失败后处理。
|
||||
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,调用上述函数。
|
||||
* 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` /
|
||||
:func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、
|
||||
信号量构建、结果存储。
|
||||
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
|
||||
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
|
||||
:class:`AsyncLayerRunner` —— 层级执行器,调用上述模块级辅助。
|
||||
* :class:`DependencyRunner` —— 依赖驱动调度(非层模型),同样调用模块级辅助。
|
||||
|
||||
所有策略共享统一异步内核,支持:
|
||||
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
||||
@@ -52,7 +56,7 @@ from .report import RunReport
|
||||
from .storage import StateBackend, resolve_backend
|
||||
from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus
|
||||
|
||||
logger = logging.getLogger("pyflowx")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 观察者回调类型。
|
||||
EventCallback = Callable[[TaskEvent], None]
|
||||
@@ -83,6 +87,22 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _emit_running(on_event: EventCallback | None, spec: TaskSpec[Any]) -> None:
|
||||
"""触发 RUNNING 事件(任务开始执行时)。"""
|
||||
if on_event is None:
|
||||
return
|
||||
on_event(
|
||||
TaskEvent(
|
||||
task=spec.name,
|
||||
status=TaskStatus.RUNNING,
|
||||
attempts=0,
|
||||
error=None,
|
||||
duration=None,
|
||||
reason=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
|
||||
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
|
||||
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
|
||||
@@ -126,11 +146,16 @@ def _apply_cached(
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
) -> bool:
|
||||
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。"""
|
||||
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。
|
||||
|
||||
单次 ``backend.get`` + ``KeyError`` 回退,避免 ``has`` + ``get`` 双重
|
||||
哈希查找与双重 TTL 判断。
|
||||
"""
|
||||
storage_key = spec.storage_key(context)
|
||||
if not backend.has(storage_key):
|
||||
try:
|
||||
cached = backend.get(storage_key)
|
||||
except KeyError:
|
||||
return False
|
||||
cached = backend.get(storage_key)
|
||||
context[name] = cached
|
||||
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
|
||||
report.results[name] = result
|
||||
@@ -139,154 +164,146 @@ def _apply_cached(
|
||||
return True
|
||||
|
||||
|
||||
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
|
||||
"""按优先级降序排序(稳定排序)。"""
|
||||
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
|
||||
def _sort_by_priority(layer: list[str], specs: Mapping[str, TaskSpec[Any]]) -> list[str]:
|
||||
"""按优先级降序排序(稳定排序)。
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Mixin:任务级跳过 / 重试 / 成功处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
class _TaskSkipMixin:
|
||||
"""任务级跳过预检共享逻辑。
|
||||
|
||||
将"上游被跳过/失败"与"条件不满足"两类跳过判断统一为单一入口,
|
||||
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
|
||||
接受预构建的 ``{name: spec}`` 映射,避免在排序键函数中重复调用
|
||||
``graph.resolved_spec``(即便有缓存也省去 N 次字典查询)。
|
||||
"""
|
||||
return sorted(layer, key=lambda n: -specs[n].priority)
|
||||
|
||||
@staticmethod
|
||||
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
|
||||
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
|
||||
|
||||
软依赖不影响本检查——软依赖被跳过时注入默认值。
|
||||
"""
|
||||
if report is None or spec.allow_upstream_skip:
|
||||
return None
|
||||
for dep in spec.depends_on:
|
||||
if dep not in report.results:
|
||||
continue
|
||||
dep_status = report.results[dep].status
|
||||
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
|
||||
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 任务级跳过 / 重试 / 成功处理:模块级函数
|
||||
# ---------------------------------------------------------------------- #
|
||||
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
|
||||
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
|
||||
|
||||
软依赖不影响本检查——软依赖被跳过时注入默认值。
|
||||
"""
|
||||
if report is None or spec.allow_upstream_skip:
|
||||
return None
|
||||
for dep in spec.depends_on:
|
||||
if dep not in report.results:
|
||||
continue
|
||||
dep_status = report.results[dep].status
|
||||
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
|
||||
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _prepare_for_execution(
|
||||
spec: TaskSpec[Any],
|
||||
context: Mapping[str, Any],
|
||||
report: RunReport | None,
|
||||
on_event: EventCallback | None,
|
||||
) -> TaskResult[Any] | None:
|
||||
"""执行前预检:上游跳过 / 条件跳过。
|
||||
|
||||
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
|
||||
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
|
||||
"""
|
||||
# 1. 上游被跳过/失败
|
||||
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
|
||||
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute)
|
||||
if skip_reason is None:
|
||||
should_run, cond_reason = spec.should_execute(context)
|
||||
if not should_run:
|
||||
skip_reason = cond_reason or "条件不满足"
|
||||
if skip_reason is None:
|
||||
return None
|
||||
# 构造 SKIPPED 结果
|
||||
result: TaskResult[Any] = TaskResult(
|
||||
spec=spec,
|
||||
status=TaskStatus.SKIPPED,
|
||||
finished_at=datetime.now(),
|
||||
reason=skip_reason,
|
||||
def _prepare_for_execution(
|
||||
spec: TaskSpec[Any],
|
||||
context: Mapping[str, Any],
|
||||
report: RunReport | None,
|
||||
on_event: EventCallback | None,
|
||||
) -> TaskResult[Any] | None:
|
||||
"""执行前预检:上游跳过 / 条件跳过。
|
||||
|
||||
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
|
||||
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
|
||||
"""
|
||||
# 1. 上游被跳过/失败
|
||||
skip_reason = _upstream_skip_reason(spec, report)
|
||||
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute)
|
||||
if skip_reason is None:
|
||||
should_run, cond_reason = spec.should_execute(context)
|
||||
if not should_run:
|
||||
skip_reason = cond_reason or "条件不满足"
|
||||
if skip_reason is None:
|
||||
return None
|
||||
# 构造 SKIPPED 结果
|
||||
result: TaskResult[Any] = TaskResult(
|
||||
spec=spec,
|
||||
status=TaskStatus.SKIPPED,
|
||||
finished_at=datetime.now(),
|
||||
reason=skip_reason,
|
||||
)
|
||||
_emit(on_event, result)
|
||||
logger.info("task %r skipped (%s)", spec.name, skip_reason)
|
||||
return result
|
||||
|
||||
|
||||
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
|
||||
"""是否应继续重试。"""
|
||||
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
|
||||
|
||||
|
||||
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
|
||||
"""标记任务成功并触发 post_run 钩子。"""
|
||||
result.value = value
|
||||
result.status = TaskStatus.SUCCESS
|
||||
result.finished_at = datetime.now()
|
||||
_run_hooks(spec.hooks, "post_run", spec, value)
|
||||
|
||||
|
||||
def _finalize_failure(
|
||||
result: TaskResult[Any],
|
||||
layer_idx: int | None,
|
||||
on_event: EventCallback | None,
|
||||
continue_on_error: bool,
|
||||
) -> None:
|
||||
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
|
||||
result.status = TaskStatus.FAILED
|
||||
result.finished_at = datetime.now()
|
||||
_emit(on_event, result)
|
||||
if continue_on_error:
|
||||
logger.warning(
|
||||
"task %r failed but continue_on_error=True; continuing.",
|
||||
result.spec.name,
|
||||
)
|
||||
_emit(on_event, result)
|
||||
logger.info("task %r skipped (%s)", spec.name, skip_reason)
|
||||
return result
|
||||
return
|
||||
raise TaskFailedError(
|
||||
task=result.spec.name,
|
||||
cause=result.error if result.error is not None else RuntimeError("unknown"),
|
||||
attempts=result.attempts,
|
||||
layer=layer_idx,
|
||||
)
|
||||
|
||||
|
||||
class _TaskRetryMixin:
|
||||
"""任务级重试决策与失败/成功后处理共享逻辑。"""
|
||||
def _handle_failure(
|
||||
spec: TaskSpec[Any],
|
||||
result: TaskResult[Any],
|
||||
exc: BaseException,
|
||||
layer_idx: int | None,
|
||||
on_event: EventCallback | None,
|
||||
) -> bool:
|
||||
"""统一处理失败:超时转换、重试决策、finalize。
|
||||
|
||||
@staticmethod
|
||||
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
|
||||
"""是否应继续重试。"""
|
||||
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
|
||||
|
||||
@staticmethod
|
||||
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
|
||||
"""标记任务成功并触发 post_run 钩子。"""
|
||||
result.value = value
|
||||
result.status = TaskStatus.SUCCESS
|
||||
result.finished_at = datetime.now()
|
||||
_run_hooks(spec.hooks, "post_run", spec, value)
|
||||
|
||||
@staticmethod
|
||||
def _finalize_failure(
|
||||
result: TaskResult[Any],
|
||||
layer_idx: int | None,
|
||||
on_event: EventCallback | None,
|
||||
continue_on_error: bool,
|
||||
) -> None:
|
||||
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
|
||||
result.status = TaskStatus.FAILED
|
||||
result.finished_at = datetime.now()
|
||||
_emit(on_event, result)
|
||||
if continue_on_error:
|
||||
logger.warning(
|
||||
"task %r failed but continue_on_error=True; continuing.",
|
||||
result.spec.name,
|
||||
)
|
||||
return
|
||||
raise TaskFailedError(
|
||||
task=result.spec.name,
|
||||
cause=result.error if result.error is not None else RuntimeError("unknown"),
|
||||
attempts=result.attempts,
|
||||
layer=layer_idx,
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
|
||||
"""
|
||||
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
|
||||
if isinstance(exc, asyncio.TimeoutError):
|
||||
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
|
||||
logger.warning(
|
||||
"task %r timed out (attempt %d/%d); retrying",
|
||||
spec.name,
|
||||
result.attempts,
|
||||
spec.retry.max_attempts,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _handle_failure(
|
||||
spec: TaskSpec[Any],
|
||||
result: TaskResult[Any],
|
||||
exc: BaseException,
|
||||
layer_idx: int | None,
|
||||
on_event: EventCallback | None,
|
||||
) -> bool:
|
||||
"""统一处理失败:超时转换、重试决策、finalize。
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
|
||||
"""
|
||||
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
|
||||
if isinstance(exc, asyncio.TimeoutError):
|
||||
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
|
||||
logger.warning(
|
||||
"task %r timed out (attempt %d/%d); retrying",
|
||||
spec.name,
|
||||
result.attempts,
|
||||
spec.retry.max_attempts,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"task %r failed (attempt %d/%d): %r; retrying",
|
||||
spec.name,
|
||||
result.attempts,
|
||||
spec.retry.max_attempts,
|
||||
exc,
|
||||
)
|
||||
result.error = exc
|
||||
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
|
||||
return False
|
||||
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
||||
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
"task %r failed (attempt %d/%d): %r; retrying",
|
||||
spec.name,
|
||||
result.attempts,
|
||||
spec.retry.max_attempts,
|
||||
exc,
|
||||
)
|
||||
result.error = exc
|
||||
if _should_retry(spec, result.attempts, exc):
|
||||
return False
|
||||
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
||||
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin)
|
||||
# 任务执行器:同步 / 异步(调用模块级跳过/重试函数)
|
||||
# ---------------------------------------------------------------------- #
|
||||
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
class SyncTaskRunner:
|
||||
"""同步任务执行器:带重试与跳过预检。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -297,7 +314,7 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
on_event: EventCallback | None = None,
|
||||
report: RunReport | None = None,
|
||||
) -> TaskResult[Any]:
|
||||
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
|
||||
skipped = _prepare_for_execution(spec, context, report, on_event)
|
||||
if skipped is not None:
|
||||
return skipped
|
||||
|
||||
@@ -306,23 +323,24 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
args, kwargs = build_call_args(spec, context)
|
||||
|
||||
_run_hooks(spec.hooks, "pre_run", spec)
|
||||
_emit_running(on_event, spec)
|
||||
|
||||
while True:
|
||||
result.attempts += 1
|
||||
try:
|
||||
with spec.env_context():
|
||||
value = spec.effective_fn(*args, **kwargs)
|
||||
_TaskRetryMixin._mark_success(spec, result, value)
|
||||
_mark_success(spec, result, value)
|
||||
return result
|
||||
except Exception as exc:
|
||||
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
|
||||
if _handle_failure(spec, result, exc, layer_idx, on_event):
|
||||
return result
|
||||
wait = spec.retry.wait_seconds(result.attempts)
|
||||
if wait > 0:
|
||||
time.sleep(wait)
|
||||
|
||||
|
||||
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
class AsyncTaskRunner:
|
||||
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -334,7 +352,7 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
report: RunReport | None = None,
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
) -> TaskResult[Any]:
|
||||
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
|
||||
skipped = _prepare_for_execution(spec, context, report, on_event)
|
||||
if skipped is not None:
|
||||
return skipped
|
||||
|
||||
@@ -345,15 +363,16 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
_run_hooks(spec.hooks, "pre_run", spec)
|
||||
_emit_running(on_event, spec)
|
||||
|
||||
while True:
|
||||
result.attempts += 1
|
||||
try:
|
||||
value = await _execute_async_task(spec, args, kwargs, loop)
|
||||
_TaskRetryMixin._mark_success(spec, result, value)
|
||||
_mark_success(spec, result, value)
|
||||
return result
|
||||
except Exception as exc:
|
||||
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
|
||||
if _handle_failure(spec, result, exc, layer_idx, on_event):
|
||||
return result
|
||||
wait = spec.retry.wait_seconds(result.attempts)
|
||||
if wait > 0:
|
||||
@@ -388,81 +407,81 @@ async def _execute_async_task(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Mixin:层执行共享逻辑
|
||||
# 共享辅助:缓存过滤、优先级排序、信号量构建、结果存储
|
||||
# ---------------------------------------------------------------------- #
|
||||
class _LayerMixin:
|
||||
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
|
||||
def _filter_and_sort(
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
) -> list[str]:
|
||||
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。
|
||||
|
||||
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
|
||||
消除"过滤缓存→排序→运行→存结果"的样板代码。
|
||||
预构建 ``{name: spec}`` 映射,过滤与排序共享同一份 resolved spec,
|
||||
避免 ``_sort_by_priority`` 内重复调用 ``graph.resolved_spec``。
|
||||
"""
|
||||
specs: dict[str, TaskSpec[Any]] = {}
|
||||
to_run: list[str] = []
|
||||
for name in layer:
|
||||
spec = graph.resolved_spec(name)
|
||||
specs[name] = spec
|
||||
if not _apply_cached(name, spec, context, report, backend, on_event):
|
||||
to_run.append(name)
|
||||
return _sort_by_priority(to_run, specs)
|
||||
|
||||
@staticmethod
|
||||
def _filter_and_sort(
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
) -> list[str]:
|
||||
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
|
||||
to_run: list[str] = []
|
||||
for name in layer:
|
||||
spec = graph.resolved_spec(name)
|
||||
if not _apply_cached(name, spec, context, report, backend, on_event):
|
||||
to_run.append(name)
|
||||
return _sort_by_priority(to_run, graph)
|
||||
|
||||
@staticmethod
|
||||
def _store_result(
|
||||
name: str,
|
||||
result: TaskResult[Any],
|
||||
graph: Graph,
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
context_snapshot: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""存储任务结果到 context/report/backend 并触发事件。"""
|
||||
context[name] = result.value
|
||||
if result.status == TaskStatus.SUCCESS:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
|
||||
backend.save(spec.storage_key(task_ctx), result.value)
|
||||
report.results[name] = result
|
||||
_emit(on_event, result)
|
||||
def _store_result(
|
||||
name: str,
|
||||
result: TaskResult[Any],
|
||||
spec: TaskSpec[Any],
|
||||
task_ctx: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
"""存储任务结果到 context/report/backend 并触发事件。
|
||||
|
||||
@staticmethod
|
||||
def _build_semaphores(
|
||||
to_run: list[str],
|
||||
graph: Graph,
|
||||
sem_factory: Callable[[int], Any],
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> dict[str, Any]:
|
||||
"""为每个 ``concurrency_key`` 创建一个信号量。"""
|
||||
semaphores: dict[str, Any] = {}
|
||||
for name in to_run:
|
||||
spec = graph.resolved_spec(name)
|
||||
key = spec.concurrency_key
|
||||
if key is not None and key not in semaphores:
|
||||
limit = concurrency_limits.get(key, 1)
|
||||
semaphores[key] = sem_factory(limit)
|
||||
return semaphores
|
||||
``spec`` 与 ``task_ctx`` 由调用方在执行前已构建,直接复用避免重复
|
||||
``resolved_spec`` / ``_build_context`` 调用。
|
||||
"""
|
||||
context[name] = result.value
|
||||
if result.status == TaskStatus.SUCCESS:
|
||||
backend.save(spec.storage_key(task_ctx), result.value)
|
||||
report.results[name] = result
|
||||
_emit(on_event, result)
|
||||
|
||||
@staticmethod
|
||||
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
||||
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
||||
if spec.concurrency_key is None:
|
||||
return None
|
||||
return semaphores.get(spec.concurrency_key)
|
||||
|
||||
def _build_semaphores(
|
||||
to_run: list[str],
|
||||
graph: Graph,
|
||||
sem_factory: Callable[[int], Any],
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> dict[str, Any]:
|
||||
"""为每个 ``concurrency_key`` 创建一个信号量。"""
|
||||
semaphores: dict[str, Any] = {}
|
||||
for name in to_run:
|
||||
spec = graph.resolved_spec(name)
|
||||
key = spec.concurrency_key
|
||||
if key is not None and key not in semaphores:
|
||||
limit = concurrency_limits.get(key, 1)
|
||||
semaphores[key] = sem_factory(limit)
|
||||
return semaphores
|
||||
|
||||
|
||||
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
||||
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
||||
if spec.concurrency_key is None:
|
||||
return None
|
||||
return semaphores.get(spec.concurrency_key)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 层执行器
|
||||
# ---------------------------------------------------------------------- #
|
||||
class SequentialLayerRunner(_LayerMixin):
|
||||
class SequentialLayerRunner:
|
||||
"""逐个运行某层的任务(按优先级排序)。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -475,14 +494,14 @@ class SequentialLayerRunner(_LayerMixin):
|
||||
layer_idx: int,
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||
for name in _filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context, report)
|
||||
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||
|
||||
|
||||
class ThreadedLayerRunner(_LayerMixin):
|
||||
class ThreadedLayerRunner:
|
||||
"""在线程池中并发运行某层的任务。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -497,43 +516,43 @@ class ThreadedLayerRunner(_LayerMixin):
|
||||
max_workers: int,
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
if not to_run:
|
||||
return
|
||||
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||
context_snapshot = dict(context)
|
||||
lock = threading.Lock()
|
||||
|
||||
def _run_threaded_task(name: str) -> TaskResult[Any]:
|
||||
def _run_threaded_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot, report)
|
||||
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
if sem is not None:
|
||||
sem.acquire()
|
||||
try:
|
||||
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
return task_ctx, SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
finally:
|
||||
if sem is not None:
|
||||
sem.release()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
|
||||
future_to_name: dict[concurrent.futures.Future[tuple[dict[str, Any], TaskResult[Any]]], str] = {
|
||||
pool.submit(_run_threaded_task, name): name for name in to_run
|
||||
}
|
||||
completed: dict[str, TaskResult[Any]] = {}
|
||||
completed: dict[str, tuple[dict[str, Any], TaskResult[Any]]] = {}
|
||||
try:
|
||||
for fut in concurrent.futures.as_completed(future_to_name):
|
||||
name = future_to_name[fut]
|
||||
completed[name] = fut.result()
|
||||
finally:
|
||||
with lock:
|
||||
for name, result in completed.items():
|
||||
ThreadedLayerRunner._store_result(
|
||||
name, result, graph, context, report, backend, on_event, context_snapshot
|
||||
for name, (task_ctx, result) in completed.items():
|
||||
_store_result(
|
||||
name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event
|
||||
)
|
||||
|
||||
|
||||
class AsyncLayerRunner(_LayerMixin):
|
||||
class AsyncLayerRunner:
|
||||
"""在事件循环上并发运行某层的任务。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -547,27 +566,32 @@ class AsyncLayerRunner(_LayerMixin):
|
||||
on_event: EventCallback | None,
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
if not to_run:
|
||||
return
|
||||
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||
context_snapshot = dict(context)
|
||||
|
||||
async def _run_async_task(name: str) -> TaskResult[Any]:
|
||||
async def _run_async_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot, report)
|
||||
sem = AsyncLayerRunner._get_sem(semaphores, spec)
|
||||
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
result = await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||
return task_ctx, result
|
||||
|
||||
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
|
||||
for name, result in zip(to_run, results):
|
||||
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
|
||||
for name, (task_ctx, result) in zip(to_run, results):
|
||||
_store_result(name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event)
|
||||
|
||||
|
||||
class DependencyRunner(_LayerMixin):
|
||||
class DependencyRunner:
|
||||
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
|
||||
|
||||
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
||||
|
||||
本类不继承层 Mixin:依赖驱动调度不是层模型,直接调用模块级共享辅助
|
||||
函数(:func:`_build_semaphores` / :func:`_get_sem` / :func:`_store_result`),
|
||||
职责更清晰。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -580,7 +604,7 @@ class DependencyRunner(_LayerMixin):
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
all_names = list(graph.all_specs().keys())
|
||||
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
||||
|
||||
async def _run_task(name: str) -> TaskResult[Any]:
|
||||
@@ -598,9 +622,9 @@ class DependencyRunner(_LayerMixin):
|
||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
||||
return report.results[name]
|
||||
|
||||
sem = DependencyRunner._get_sem(semaphores, spec)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
|
||||
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -617,7 +641,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
|
||||
|
||||
def _verbose_callback(event: TaskEvent) -> None:
|
||||
dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
|
||||
if event.status == TaskStatus.RUNNING: # pragma: no cover
|
||||
if event.status == TaskStatus.RUNNING:
|
||||
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
|
||||
elif event.status == TaskStatus.SUCCESS:
|
||||
print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True)
|
||||
@@ -677,33 +701,42 @@ def run(
|
||||
TaskFailedError
|
||||
任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
|
||||
"""
|
||||
graph.validate()
|
||||
layers = graph.layers()
|
||||
|
||||
if dry_run:
|
||||
layers = graph.layers()
|
||||
_print_dry_run(graph, layers)
|
||||
return RunReport(success=True)
|
||||
|
||||
# 入口统一校验一次:所有策略共用,避免 layers() / dependency 路径
|
||||
# 各自重复调用 validate()。
|
||||
graph.validate()
|
||||
|
||||
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
|
||||
backend = resolve_backend(state)
|
||||
report = RunReport()
|
||||
context: dict[str, Any] = {}
|
||||
limits = concurrency_limits or {}
|
||||
|
||||
try:
|
||||
if strategy == "sequential":
|
||||
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
||||
elif strategy == "thread":
|
||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
||||
elif strategy == "async":
|
||||
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
||||
elif strategy == "dependency":
|
||||
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy!r}")
|
||||
except TaskFailedError:
|
||||
report.success = False
|
||||
raise
|
||||
# backend.batch():将每任务一次落盘降为整次运行一次(JSONBackend);
|
||||
# MemoryBackend 为 no-op。即使中途抛出 TaskFailedError,batch 退出时
|
||||
# 仍会 flush 一次,保留已成功任务的结果以便断点续跑。
|
||||
with backend.batch():
|
||||
try:
|
||||
if strategy == "sequential":
|
||||
layers = graph.layers()
|
||||
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
||||
elif strategy == "thread":
|
||||
layers = graph.layers()
|
||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
||||
elif strategy == "async":
|
||||
layers = graph.layers()
|
||||
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
||||
elif strategy == "dependency":
|
||||
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy!r}")
|
||||
except TaskFailedError:
|
||||
report.success = False
|
||||
raise
|
||||
|
||||
return report
|
||||
|
||||
|
||||
+17
-103
@@ -82,6 +82,10 @@ class Graph:
|
||||
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
|
||||
_pending_refs: list[str] = field(default_factory=list)
|
||||
|
||||
# resolved_spec 缓存:避免执行期每个任务多次重复 dataclasses.replace 判断。
|
||||
# 在 specs / defaults 变更时失效。
|
||||
_resolved_cache: dict[str, TaskSpec[Any]] = field(default_factory=dict)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -97,6 +101,7 @@ class Graph:
|
||||
self.specs[spec.name] = spec
|
||||
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
|
||||
self.deps[spec.name] = spec.depends_on
|
||||
self._resolved_cache.clear()
|
||||
|
||||
@classmethod
|
||||
def from_specs(
|
||||
@@ -175,7 +180,12 @@ class Graph:
|
||||
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
|
||||
字段,若 spec 字段为默认空值且图级默认值非空,则用
|
||||
:func:`dataclasses.replace` 生成带默认值的副本。
|
||||
|
||||
结果按 ``name`` 缓存;specs / defaults 变更时缓存失效。
|
||||
"""
|
||||
cached = self._resolved_cache.get(name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
spec = self.specs[name]
|
||||
d = self.defaults
|
||||
overrides: dict[str, Any] = {}
|
||||
@@ -199,9 +209,9 @@ class Graph:
|
||||
overrides["verbose"] = True
|
||||
if not spec.tags and d.tags:
|
||||
overrides["tags"] = d.tags
|
||||
if not overrides:
|
||||
return spec
|
||||
return replace(spec, **overrides)
|
||||
resolved = spec if not overrides else replace(spec, **overrides)
|
||||
self._resolved_cache[name] = resolved
|
||||
return resolved
|
||||
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接硬依赖前驱。"""
|
||||
@@ -221,8 +231,11 @@ class Graph:
|
||||
|
||||
同层任务无相互硬依赖,可并发执行。软依赖不参与分层。
|
||||
层按执行顺序返回。图有环时抛出 :class:`CycleError`。
|
||||
|
||||
.. note::
|
||||
本方法假定图已通过 :meth:`validate` 校验(由 :func:`pyflowx.run`
|
||||
在入口统一执行一次)。若直接调用本方法,需自行先校验。
|
||||
"""
|
||||
self.validate()
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
result: list[list[str]] = []
|
||||
sorter.prepare()
|
||||
@@ -355,102 +368,3 @@ class Graph:
|
||||
|
||||
def __contains__(self, name: Any) -> bool:
|
||||
return name in self.specs
|
||||
|
||||
|
||||
class GraphComposer:
|
||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
||||
|
||||
引用格式:
|
||||
* ``"command_name"`` —— 引用整个命令图。
|
||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
||||
|
||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
||||
self.graphs = graphs
|
||||
|
||||
def resolve_all(self) -> dict[str, Graph]:
|
||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
||||
resolved: dict[str, Graph] = {}
|
||||
for cmd_name, graph in self.graphs.items():
|
||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
||||
return resolved
|
||||
|
||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
||||
pending_refs = graph._pending_refs
|
||||
if not pending_refs:
|
||||
return graph
|
||||
|
||||
all_specs: list[TaskSpec[Any]] = []
|
||||
previous_ref_last_task: str | None = None
|
||||
|
||||
for ref in pending_refs:
|
||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
||||
if previous_ref_last_task and expanded_specs:
|
||||
for i, task in enumerate(expanded_specs):
|
||||
if i == 0 or not task.depends_on:
|
||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
||||
if expanded_specs:
|
||||
previous_ref_last_task = expanded_specs[-1].name
|
||||
all_specs.extend(expanded_specs)
|
||||
|
||||
original_specs = list(graph.all_specs().values())
|
||||
if original_specs:
|
||||
if previous_ref_last_task:
|
||||
first = original_specs[0]
|
||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
||||
else:
|
||||
all_specs.append(original_specs[0])
|
||||
for i in range(1, len(original_specs)):
|
||||
current_task = original_specs[i]
|
||||
previous_task_name = original_specs[i - 1].name
|
||||
all_specs.append(
|
||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
||||
)
|
||||
|
||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
||||
|
||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
||||
if ref == current_cmd:
|
||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
||||
|
||||
if "." in ref:
|
||||
cmd_name, task_name = ref.split(".", 1)
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
if task_name not in ref_graph.all_specs():
|
||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
||||
return [ref_graph.all_specs()[task_name]]
|
||||
else:
|
||||
cmd_name = ref
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
||||
return list(ref_graph.all_specs().values())
|
||||
|
||||
|
||||
def compose(
|
||||
graphs: dict[str, Graph],
|
||||
) -> dict[str, Graph]:
|
||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
||||
|
||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graphs = {
|
||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
... }
|
||||
>>> resolved = px.compose(graphs)
|
||||
>>> "b" in resolved["all"].all_specs()
|
||||
True
|
||||
"""
|
||||
return GraphComposer(graphs).resolve_all()
|
||||
|
||||
@@ -15,11 +15,13 @@ import argparse
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, get_args
|
||||
|
||||
from .compose import GraphComposer
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph, GraphComposer
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["CliExitCode", "CliRunner"]
|
||||
@@ -137,9 +139,7 @@ class CliRunner:
|
||||
# ------------------------------------------------------------------ #
|
||||
def _prog_name(self) -> str:
|
||||
"""从 sys.argv[0] 推导程序名."""
|
||||
import os
|
||||
|
||||
return os.path.basename(sys.argv[0]) if sys.argv else "pyflowx"
|
||||
return Path(sys.argv[0]).name if sys.argv else "pyflowx"
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
"""创建参数解析器.
|
||||
|
||||
+37
-11
@@ -18,8 +18,9 @@ import sys
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, ContextManager, Mapping
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -55,6 +56,22 @@ class StateBackend(ABC):
|
||||
def clear(self) -> None:
|
||||
"""清除所有存储状态。"""
|
||||
|
||||
def flush(self) -> None: # noqa: B027
|
||||
"""将内存中暂存的状态持久化到外部介质。
|
||||
|
||||
默认无操作(如 :class:`MemoryBackend` 无需落盘)。
|
||||
:class:`JSONBackend` 在 :meth:`batch` 期间会延迟落盘,需在退出时调用。
|
||||
"""
|
||||
|
||||
def batch(self) -> ContextManager[None]:
|
||||
"""返回一个上下文管理器,期间 :meth:`save` 可延迟 :meth:`flush`。
|
||||
|
||||
默认实现为 no-op(如 :class:`MemoryBackend`)。:class:`JSONBackend`
|
||||
覆盖为:进入时标记延迟,退出时统一 flush 一次,将每任务一次落盘
|
||||
(N 次写入)降为整次运行一次(O(N) 而非 O(N²))。
|
||||
"""
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class _TTLStateBackendMixin(StateBackend):
|
||||
"""TTL 状态后端共享逻辑。
|
||||
@@ -158,13 +175,6 @@ class MemoryBackend(_TTLStateBackendMixin):
|
||||
def _clear_raw(self) -> None:
|
||||
self._store.clear()
|
||||
|
||||
def _expired(self, key: str) -> bool:
|
||||
"""键是否已过期(兼容旧测试 API)。"""
|
||||
entry = self._get_raw(key)
|
||||
if entry is None:
|
||||
return False
|
||||
return self._is_expired(entry[1])
|
||||
|
||||
|
||||
class JSONBackend(_TTLStateBackendMixin):
|
||||
"""基于文件的 JSON 存储,用于跨进程续跑。
|
||||
@@ -184,6 +194,7 @@ class JSONBackend(_TTLStateBackendMixin):
|
||||
self._path: str = path
|
||||
self._ttl = ttl
|
||||
self._store: dict[str, dict[str, Any]] = {}
|
||||
self._defer_flush: bool = False
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
@@ -244,11 +255,26 @@ class JSONBackend(_TTLStateBackendMixin):
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
||||
super().save(key, value)
|
||||
if not self._defer_flush:
|
||||
self._flush()
|
||||
|
||||
@override
|
||||
def flush(self) -> None:
|
||||
self._flush()
|
||||
|
||||
def _expired(self, entry: Mapping[str, Any]) -> bool:
|
||||
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
|
||||
return self._is_expired(float(entry.get("ts", 0)))
|
||||
@override
|
||||
@contextmanager
|
||||
def batch(self) -> Iterator[None]:
|
||||
"""进入批量模式:``save`` 暂不落盘,退出时统一 flush 一次。
|
||||
|
||||
将整次运行 N 个任务的 N 次全量落盘降为 1 次。
|
||||
"""
|
||||
self._defer_flush = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._defer_flush = False
|
||||
self._flush()
|
||||
|
||||
|
||||
def resolve_backend(backend: StateBackend | None) -> StateBackend:
|
||||
|
||||
+68
-101
@@ -17,14 +17,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -67,6 +69,8 @@ TaskCmd = Union[
|
||||
Strategy = Union[str, "StrategyKind"]
|
||||
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
|
||||
Condition = Callable[[Context], bool]
|
||||
|
||||
@@ -291,13 +295,16 @@ class TaskSpec(Generic[T]):
|
||||
if self.fn is None and self.cmd is None:
|
||||
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def effective_fn(self) -> TaskFn[T]:
|
||||
"""获取有效的执行函数。
|
||||
|
||||
若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。
|
||||
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
|
||||
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
|
||||
|
||||
结果按实例缓存(:func:`functools.cached_property`):frozen dataclass
|
||||
字段不可变,``_wrap_cmd`` 生成的闭包稳定,无需每次访问重建。
|
||||
"""
|
||||
if self.cmd is not None:
|
||||
return self._wrap_cmd()
|
||||
@@ -306,11 +313,17 @@ class TaskSpec(Generic[T]):
|
||||
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
|
||||
|
||||
def _wrap_cmd(self) -> TaskFn[Any]:
|
||||
"""将 cmd 包装为可执行函数。"""
|
||||
"""将 cmd 包装为可执行函数。
|
||||
|
||||
实际执行逻辑位于 :mod:`pyflowx.command`,避免 :class:`TaskSpec`
|
||||
作为纯数据结构混入命令执行逻辑。
|
||||
"""
|
||||
from .command import run_command
|
||||
|
||||
spec = self
|
||||
|
||||
def _run() -> T:
|
||||
return cast(T, _run_command(spec))
|
||||
return cast(T, run_command(spec))
|
||||
|
||||
_run.__name__ = spec.name
|
||||
return _run # type: ignore[return-value]
|
||||
@@ -368,12 +381,27 @@ class TaskSpec(Generic[T]):
|
||||
|
||||
def storage_key(self, context: Context) -> str:
|
||||
"""计算状态后端存储键。"""
|
||||
if self.cache_key is not None:
|
||||
try:
|
||||
return f"{self.name}:{self.cache_key(context)}"
|
||||
except Exception:
|
||||
return self.name
|
||||
return self.name
|
||||
if self.cache_key is None:
|
||||
return self.name
|
||||
try:
|
||||
return f"{self.name}:{self.cache_key(context)}"
|
||||
except (TypeError, ValueError, KeyError, AttributeError) as exc:
|
||||
# cache_key 抛出预期内的数据/类型异常时回退到 name,但仍记录警告
|
||||
# 以便用户发现 cache_key 实现中的 bug。
|
||||
logger.warning(
|
||||
"task %r: cache_key 回退到 name(%s: %s)",
|
||||
self.name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return self.name
|
||||
|
||||
|
||||
# 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。
|
||||
# ``fn`` 任务在 thread/async 策略下并发执行时,若各自配置了不同的
|
||||
# ``cwd``/``env``,会相互覆盖(os.chdir 与 os.environ 均为进程全局)。
|
||||
# 该锁仅包裹"切换→执行→恢复"区间,保证正确性;不使用 cwd/env 的任务不受影响。
|
||||
_env_cwd_lock = threading.RLock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -381,100 +409,39 @@ def _env_and_cwd(
|
||||
env: Mapping[str, str] | None,
|
||||
cwd: Path | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""临时设置环境变量与工作目录。"""
|
||||
saved_env: dict[str, str] = {}
|
||||
saved_cwd: str | None = None
|
||||
if env:
|
||||
for k, v in env.items():
|
||||
if k in os.environ:
|
||||
saved_env[k] = os.environ[k]
|
||||
os.environ[k] = v
|
||||
if cwd is not None:
|
||||
saved_cwd = str(Path.cwd())
|
||||
os.chdir(cwd)
|
||||
try:
|
||||
"""临时设置环境变量与工作目录。
|
||||
|
||||
``os.environ`` 与 ``os.chdir`` 是进程级全局状态,在 thread/async 策略下
|
||||
并发执行多个带 ``env``/``cwd`` 的 ``fn`` 任务时会相互覆盖。本函数通过
|
||||
模块级 :data:`_env_cwd_lock` 串行化"切换→执行→恢复"区间,确保正确性。
|
||||
无 ``env`` 且无 ``cwd`` 时直接 yield,不获取锁。
|
||||
"""
|
||||
if not env and cwd is None:
|
||||
yield
|
||||
finally:
|
||||
if saved_cwd is not None:
|
||||
os.chdir(saved_cwd)
|
||||
# 恢复环境变量
|
||||
return
|
||||
with _env_cwd_lock:
|
||||
saved_env: dict[str, str] = {}
|
||||
saved_cwd: str | None = None
|
||||
if env:
|
||||
for k in env:
|
||||
if k in saved_env:
|
||||
os.environ[k] = saved_env[k]
|
||||
else:
|
||||
os.environ.pop(k, None)
|
||||
|
||||
|
||||
def _run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。"""
|
||||
cmd = spec.cmd
|
||||
verbose = spec.verbose
|
||||
cwd = spec.cwd
|
||||
timeout = spec.timeout
|
||||
env_override = spec.env
|
||||
|
||||
# 可调用对象:直接调用,返回其结果。
|
||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
||||
name = getattr(cmd, "__name__", "callable")
|
||||
if verbose:
|
||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
try:
|
||||
return cmd()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
||||
|
||||
is_list = isinstance(cmd, list)
|
||||
if is_list:
|
||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
||||
verb = "执行命令"
|
||||
label = "命令"
|
||||
else:
|
||||
cmd_str = cast(str, cmd)
|
||||
verb = "执行 Shell"
|
||||
label = "Shell 命令"
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
||||
for k, v in env.items():
|
||||
if k in os.environ:
|
||||
saved_env[k] = os.environ[k]
|
||||
os.environ[k] = v
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
|
||||
# 合并环境变量
|
||||
run_env: dict[str, str] | None = None
|
||||
if env_override:
|
||||
run_env = dict(os.environ)
|
||||
run_env.update(env_override)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cast(Union[str, List[str]], cmd),
|
||||
shell=not is_list,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
timeout=timeout,
|
||||
capture_output=not verbose,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None
|
||||
|
||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
err_msg += f"\n{result.stderr.strip()}"
|
||||
raise RuntimeError(err_msg)
|
||||
saved_cwd = str(Path.cwd())
|
||||
os.chdir(cwd)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if saved_cwd is not None:
|
||||
os.chdir(saved_cwd)
|
||||
# 恢复环境变量
|
||||
if env:
|
||||
for k in env:
|
||||
if k in saved_env:
|
||||
os.environ[k] = saved_env[k]
|
||||
else:
|
||||
os.environ.pop(k, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -113,10 +113,7 @@ def write_file(path: str, content: str, encoding: str = "utf-8") -> px.TaskSpec:
|
||||
"""写入文件任务."""
|
||||
|
||||
def write():
|
||||
try:
|
||||
with open(path, "w", encoding=encoding) as f:
|
||||
f.write(content)
|
||||
except Exception as e:
|
||||
print(f"写入文件 {path} 失败: {e}")
|
||||
p = Path(path)
|
||||
p.write_text(content, encoding=encoding)
|
||||
|
||||
return px.TaskSpec(f"write_file_{path}", fn=write, verbose=True)
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""常用工具函数."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["perf_timer"]
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, TypedDict
|
||||
|
||||
try:
|
||||
from typing_extensions import ParamSpec, TypeVar
|
||||
except ImportError:
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class _PerformanceMetrics(TypedDict):
|
||||
"""性能指标."""
|
||||
|
||||
count: int
|
||||
total_time: float
|
||||
|
||||
|
||||
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
|
||||
lambda: _PerformanceMetrics(
|
||||
count=0,
|
||||
total_time=0.0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _generate_report(unit: str, precision: int) -> str:
|
||||
"""生成性能指标报告,返回报告字符串."""
|
||||
if not _perf_metrics:
|
||||
return ""
|
||||
|
||||
lines: list[str] = []
|
||||
lines.append("=" * 50)
|
||||
lines.append("性能指标报告 (Performance Metrics Report)")
|
||||
lines.append("-" * 50)
|
||||
|
||||
# 按总耗时排序,最耗时的函数排在前面
|
||||
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
|
||||
|
||||
for name, metrics in sorted_metrics:
|
||||
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
|
||||
lines.append(
|
||||
f"{name}: "
|
||||
f"调用次数={metrics['count']}, "
|
||||
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
|
||||
f"平均耗时={avg_time:.{precision}f}{unit}"
|
||||
)
|
||||
|
||||
lines.append("=" * 50)
|
||||
report_str = "\n".join(lines)
|
||||
|
||||
# 同时输出到日志
|
||||
logging.info("\n".join(lines))
|
||||
|
||||
return report_str
|
||||
|
||||
|
||||
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
|
||||
"""性能计时器装饰器."""
|
||||
scale: dict[str, float] = {
|
||||
"s": 1.0,
|
||||
"ms": 1000.0,
|
||||
"us": 1000000.0,
|
||||
}
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
_perf_metrics[func.__name__]["count"] += 1
|
||||
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
|
||||
|
||||
if not report:
|
||||
logging.info(
|
||||
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
if report:
|
||||
import atexit
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
|
||||
|
||||
@atexit.register
|
||||
def _report_at_exit() -> None:
|
||||
"""在程序退出时报告性能指标."""
|
||||
_generate_report(unit, precision)
|
||||
|
||||
# 将报告生成逻辑提取为独立函数,便于测试
|
||||
|
||||
return decorator
|
||||
@@ -301,3 +301,59 @@ def test_dir_exists_false(tmp_path: Path):
|
||||
missing = tmp_path / "nonexistent"
|
||||
cond = BuiltinConditions.DIR_EXISTS(missing)
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_builtin_is_windows_returns_module_condition():
|
||||
"""BuiltinConditions.IS_WINDOWS() 应返回模块级 IS_WINDOWS."""
|
||||
assert BuiltinConditions.IS_WINDOWS() is IS_WINDOWS
|
||||
|
||||
|
||||
def test_builtin_is_linux_returns_module_condition():
|
||||
"""BuiltinConditions.IS_LINUX() 应返回模块级 IS_LINUX."""
|
||||
assert BuiltinConditions.IS_LINUX() is IS_LINUX
|
||||
|
||||
|
||||
def test_builtin_is_macos_returns_module_condition():
|
||||
"""BuiltinConditions.IS_MACOS() 应返回模块级 IS_MACOS."""
|
||||
assert BuiltinConditions.IS_MACOS() is IS_MACOS
|
||||
|
||||
|
||||
def test_builtin_is_posix_returns_module_condition():
|
||||
"""BuiltinConditions.IS_POSIX() 应返回模块级 IS_POSIX."""
|
||||
assert BuiltinConditions.IS_POSIX() is IS_POSIX
|
||||
|
||||
|
||||
def test_file_content_exists_missing_file(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件不存在时返回 False."""
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(tmp_path / "missing.txt", "x")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_file_content_exists_contains_content(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件包含内容时返回 True."""
|
||||
f = tmp_path / "f.txt"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "world")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_file_content_exists_not_contains_content(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件不包含内容时返回 False."""
|
||||
f = tmp_path / "f.txt"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "missing")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_file_content_exists_decode_error_returns_false(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 读取非 UTF-8 文件应返回 False(解码异常被吞)."""
|
||||
f = tmp_path / "bin.dat"
|
||||
f.write_bytes(b"\xff\xfe\x00bad")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "x")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dep_matches_missing_dep_returns_false():
|
||||
"""DEP_MATCHES 依赖不存在时应返回 False(覆盖 if not in ctx 分支)."""
|
||||
cond = BuiltinConditions.DEP_MATCHES("missing", lambda _v: True)
|
||||
assert cond({}) is False
|
||||
|
||||
@@ -99,7 +99,10 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
|
||||
|
||||
|
||||
def test_verbose_run_with_user_callback():
|
||||
"""Test px.run with verbose=True and user callback both called."""
|
||||
"""Test px.run with verbose=True and user callback both called.
|
||||
|
||||
预期事件序列:RUNNING(开始)→ SUCCESS(完成)。
|
||||
"""
|
||||
events = []
|
||||
|
||||
def on_event(event: px.TaskEvent):
|
||||
@@ -109,8 +112,9 @@ def test_verbose_run_with_user_callback():
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event)
|
||||
assert report.success
|
||||
assert len(events) == 1
|
||||
assert events[0].status == px.TaskStatus.SUCCESS
|
||||
assert len(events) == 2
|
||||
assert events[0].status == px.TaskStatus.RUNNING
|
||||
assert events[1].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_verbose_event_callback_success():
|
||||
|
||||
+74
-1
@@ -5,8 +5,8 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.compose import GraphComposer, compose
|
||||
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
from pyflowx.graph import GraphComposer, compose
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
@@ -319,6 +319,79 @@ def test_compose_function() -> None:
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_multiple_refs_chain() -> None:
|
||||
"""expand_refs 多个 ref 应串联依赖:后一个 ref 首任务依赖前一个 ref 末任务."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_c = px.Graph.from_specs([px.TaskSpec("c1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a", "cmd_c"]
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_c": graph_c, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# c1 应依赖 a1(后 ref 首任务依赖前 ref 末任务)
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
assert "c1" in resolved["cmd_b"]
|
||||
assert "b1" in resolved["cmd_b"]
|
||||
c1_spec = resolved["cmd_b"].all_specs()["c1"]
|
||||
assert "a1" in c1_spec.depends_on
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_ref_returns_empty() -> None:
|
||||
"""expand_refs 引用空图时,previous_ref_last_task 保持 None,original_specs 走 else 分支."""
|
||||
graph_empty = px.Graph.from_specs([])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["empty_cmd"]
|
||||
|
||||
composer = GraphComposer({"empty_cmd": graph_empty, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# b1 保留,无额外依赖
|
||||
assert "b1" in resolved["cmd_b"]
|
||||
b1_spec = resolved["cmd_b"].all_specs()["b1"]
|
||||
assert b1_spec.depends_on == ()
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_multiple_original_specs_serialized() -> None:
|
||||
"""expand_refs 多个 original_specs 应串行依赖,且首个依赖 ref 末任务."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([
|
||||
px.TaskSpec("b1", _fn),
|
||||
px.TaskSpec("b2", _fn),
|
||||
px.TaskSpec("b3", _fn),
|
||||
])
|
||||
graph_b._pending_refs = ["cmd_a"]
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
specs = resolved["cmd_b"].all_specs()
|
||||
# b1 依赖 a1(ref 末任务)
|
||||
assert "a1" in specs["b1"].depends_on
|
||||
# b2 依赖 b1,b3 依赖 b2(串行)
|
||||
assert "b1" in specs["b2"].depends_on
|
||||
assert "b2" in specs["b3"].depends_on
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_dot_notation_success() -> None:
|
||||
"""parse_ref 'cmd.task' 形式应返回对应单个 TaskSpec."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a})
|
||||
|
||||
result = composer.parse_ref("cmd_a.a2", "cmd_b")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "a2"
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_dot_notation_cmd_not_found() -> None:
|
||||
"""parse_ref 'missing.task' 形式应检测命令不存在."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a})
|
||||
|
||||
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
|
||||
_ = composer.parse_ref("missing.task", "cmd_b")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# resolved_spec defaults 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -70,9 +70,9 @@ def test_memory_backend_ttl_load_filters_expired() -> None:
|
||||
|
||||
|
||||
def test_memory_backend_expired_key_not_in_store() -> None:
|
||||
"""_expired 对不存在键返回 False."""
|
||||
"""不存在的键 has 返回 False."""
|
||||
b = MemoryBackend(ttl=1.0)
|
||||
assert b._expired("nonexistent") is False
|
||||
assert b.has("nonexistent") is False
|
||||
|
||||
|
||||
def test_memory_backend_no_ttl_never_expired() -> None:
|
||||
@@ -244,35 +244,35 @@ def test_json_backend_ttl_load_filters_expired() -> None:
|
||||
|
||||
|
||||
def test_json_backend_expired_no_ttl() -> None:
|
||||
"""无 TTL 时 _expired 返回 False."""
|
||||
"""无 TTL 时永不过期."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 1000
|
||||
assert b._expired(b._store["a"]) is False # 无 TTL,永不过期
|
||||
assert b.has("a") is True # 无 TTL,永不过期
|
||||
|
||||
|
||||
def test_json_backend_expired_with_ttl() -> None:
|
||||
"""有 TTL 时 _expired 检查是否过期."""
|
||||
"""有 TTL 时过期键 has 返回 False."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
|
||||
assert b._expired(b._store["a"]) is True
|
||||
assert b.has("a") is False
|
||||
|
||||
|
||||
def test_json_backend_expired_missing_ts() -> None:
|
||||
"""entry 缺少 ts 时使用默认值 0."""
|
||||
"""entry 缺少 ts 时视为过期."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b._store["a"] = {"value": 1} # 缺少 ts
|
||||
# ts 默认为 0,已经过了很久
|
||||
assert b._expired(b._store["a"]) is True
|
||||
assert b.has("a") is False
|
||||
|
||||
|
||||
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
||||
+56
-1
@@ -2,11 +2,12 @@
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.conditions import Constants
|
||||
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, which
|
||||
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, setenv_group, which, write_file
|
||||
|
||||
|
||||
def test_clr_creates_task_spec() -> None:
|
||||
@@ -189,3 +190,57 @@ def test_which_not_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.Capture
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "nonexistent_cmd -> 未找到" in captured.out
|
||||
|
||||
|
||||
def test_write_file_creates_task_spec() -> None:
|
||||
"""write_file() 应创建带 verbose 的 TaskSpec。"""
|
||||
spec = write_file("/tmp/unused", "x")
|
||||
assert spec.name == "write_file_/tmp/unused"
|
||||
assert spec.verbose is True
|
||||
|
||||
|
||||
def test_write_file_writes_content(tmp_path: Path) -> None:
|
||||
"""write_file() 应将内容写入指定文件."""
|
||||
f = tmp_path / "out.txt"
|
||||
spec = write_file(str(f), "hello world")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert f.read_text(encoding="utf-8") == "hello world"
|
||||
|
||||
|
||||
def test_write_file_with_encoding(tmp_path: Path) -> None:
|
||||
"""write_file() 应支持指定编码."""
|
||||
f = tmp_path / "out.txt"
|
||||
spec = write_file(str(f), "中文", encoding="utf-8")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert f.read_text(encoding="utf-8") == "中文"
|
||||
|
||||
|
||||
def test_write_file_failure_propagates(tmp_path: Path) -> None:
|
||||
"""write_file() 写入失败应抛出异常(不吞异常)."""
|
||||
# 父目录不存在时写入应抛 FileNotFoundError
|
||||
missing = tmp_path / "no_such_dir" / "out.txt"
|
||||
spec = write_file(str(missing), "x")
|
||||
assert spec.fn is not None
|
||||
with pytest.raises(FileNotFoundError):
|
||||
spec.fn()
|
||||
|
||||
|
||||
def test_setenv_group_creates_specs() -> None:
|
||||
"""setenv_group() 应为每个环境变量创建 TaskSpec."""
|
||||
envs = {"VAR_A": "1", "VAR_B": "2"}
|
||||
specs = setenv_group(envs)
|
||||
assert len(specs) == 2
|
||||
assert specs[0].name == "setenv_var_a"
|
||||
assert specs[1].name == "setenv_var_b"
|
||||
|
||||
|
||||
def test_setenv_group_default_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv_group(default=True) 不应覆盖已存在的环境变量."""
|
||||
monkeypatch.setenv("PYFLOWX_GROUP_EXISTS", "original")
|
||||
specs = setenv_group({"PYFLOWX_GROUP_EXISTS": "new"}, default=True)
|
||||
for spec in specs:
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_GROUP_EXISTS"] == "original"
|
||||
|
||||
+9
-9
@@ -203,10 +203,10 @@ def test_is_cmd_available_callable_returns_true() -> None:
|
||||
# storage_key 异常处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_storage_key_cache_key_exception_returns_name() -> None:
|
||||
"""cache_key 抛异常时应返回任务名."""
|
||||
"""cache_key 抛预期异常(TypeError/ValueError/KeyError/AttributeError)时应返回任务名."""
|
||||
|
||||
def bad_cache_key(_ctx):
|
||||
raise RuntimeError("cache key error")
|
||||
raise ValueError("cache key error")
|
||||
|
||||
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
|
||||
key = spec.storage_key({})
|
||||
@@ -345,14 +345,14 @@ def test_task_result_default_status() -> None:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _run_command callable 命令测试
|
||||
# run_command callable 命令测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
||||
"""callable 命令 verbose 模式应打印信息."""
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
import pyflowx.task as task_module
|
||||
from pyflowx.command import run_command
|
||||
|
||||
result = task_module._run_command(spec)
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
result = run_command(spec)
|
||||
assert result == "result"
|
||||
captured = capsys.readouterr()
|
||||
assert "执行可调用命令" in captured.out
|
||||
@@ -361,8 +361,8 @@ def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str
|
||||
|
||||
def test_run_command_callable_exception() -> None:
|
||||
"""callable 命令抛异常应转为 RuntimeError."""
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
import pyflowx.task as task_module
|
||||
from pyflowx.command import run_command
|
||||
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
||||
task_module._run_command(spec)
|
||||
run_command(spec)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from pyflowx.utils import _perf_metrics, perf_timer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_perf_metrics():
|
||||
"""重置性能指标."""
|
||||
_perf_metrics.clear()
|
||||
|
||||
|
||||
class TestPerformanceTimer:
|
||||
def test_perf_timer(self):
|
||||
|
||||
@perf_timer()
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
test_func()
|
||||
|
||||
assert _perf_metrics["test_func"] is not None
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
|
||||
def test_perf_timer_report(self, mocker: MockerFixture):
|
||||
mock_log = mocker.patch("logging.info")
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
test_func()
|
||||
|
||||
assert _perf_metrics["test_func"] is not None
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
|
||||
assert mock_log.call_count == 1
|
||||
|
||||
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
mock_log = mocker.patch("logging.info")
|
||||
|
||||
from pyflowx.utils import _generate_report
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func2():
|
||||
time.sleep(0.2)
|
||||
|
||||
test_func()
|
||||
test_func2()
|
||||
|
||||
_generate_report("ms", 3)
|
||||
|
||||
assert mock_log.call_count == 3
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
assert _perf_metrics["test_func2"]["count"] == 1
|
||||
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
|
||||
Reference in New Issue
Block a user