7 Commits

Author SHA1 Message Date
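zhou 232e7293d9 refactor(system): simplify the write_file implementation by using pathlib instead of manual file operations. 2026-06-28 11:20:58 +08:00
zhou a1bae58e56 refactor: improve logging configuration and code details
1. Use __name__ everywhere instead of hard-coded logger names
2. Use pathlib instead of os.path to handle the program name
3. Narrow exception handling and improve the log message format
4. Tighten the exception scope of the file-content check
2026-06-28 10:57:51 +08:00
zhou cbc7cc0a75 docs: split the testing standards into a standalone skill document and update the main standard
Move the testing section of the original python-standards.md into the new pyflowx-testing/SKILL.md, update the main standard to point to the new document, and tidy up the overall document structure and content.
2026-06-28 10:19:26 +08:00
zhou d0ff7d7b4d docs: update the README and add the Python development standards document
This commit substantially improves the PyFlowX README: it documents new features including the four execution strategies, soft dependencies, concurrency limits, and task hooks; adds usage examples for new functionality such as task templates, graph composition, and cache keys; and updates the execution parameters, the strategy comparison table, and the module structure documentation. It also adds the .trae/rules/python-standards.md standards document, which unifies the project's code style, type checking, test writing, and other development standards.
2026-06-28 09:34:45 +08:00
zhou d154f67ce0 +trae ignore 2026-06-28 08:44:23 +08:00
zhou 9999071119 refactor(executors): refactor the executor logic, remove the duplicated mixins, and improve layered sorting
Main changes:
1.  Turn the task skip/retry logic from class mixins into module-level functions to reduce duplication
2.  Improve the pre-validation for _graph.layers() and run it once at the run entry point
3.  Refactor the storage expiry-check API and remove the deprecated _expired method
4.  Improve TaskSpec.cache_key exception handling: catch specific exceptions and log a warning
5.  Fix the event callback logic in verbose mode so the RUNNING event fires correctly
6.  Adjust the test cases to the new API and behavior changes
2026-06-28 08:25:15 +08:00
zhou bdd70e9c43 refactor: restructure the project code and split modules by responsibility
1. Extract the graph composition logic into pyflowx.compose; the original graph.py keeps only the single-graph DAG logic
2. Extract the command execution logic into pyflowx.command and remove _run_command from task.py
3. Refactor the context signature cache to improve performance
4. Remove the deprecated utils.perf_timer code
5. Add a batch flush optimization to JSONBackend
6. Adjust import paths and the public API, and update the test cases
7. Simplify conditional logic and remove redundant code
2026-06-28 02:28:38 +08:00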
26 changed files with 1340 additions and 729 deletions
+135
@@ -0,0 +1,135 @@
---
name: "pyflowx-testing"
description: "Test-writing standards and mock usage guide for the PyFlowX project. Invoke when writing or reviewing tests, choosing mock tools, designing fixtures, or handling asyncio tests."
---
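# PyFlowX Testing Standards
This skill expands the testing section of `.trae/rules/python-standards.md` in detail.
The rules file keeps only pointers to the hard constraints; this file provides the full operating guide.
## General rules
- **Coverage ≥ 95%** (branch coverage); it must not decrease.
- **Test the public API first**: tests use public interfaces (`has`/`get`) and do not access private methods
  (such as `_expired`). Private methods kept only for legacy tests should be deleted and the tests migrated.
  Exception: when internal state such as `_store`/`_flush` cannot be triggered through the public API (for example, simulating expiry or
  fault injection), a test may temporarily access private attributes and must state the reason in its docstring.
- **Naming**: `test_<unit under test>_<scenario>`, e.g. `test_storage_key_cache_key_exception_returns_name`.
- **One assertion focus per test**; multiple assertions must be semantically related.
- **slow marker**: mark slow tests with `@pytest.mark.slow`; CI can skip them with `-m "not slow"`.
- **Test code is also linted with ruff**: `tests/**` ignores `ARG001`/`ARG002` (unused fixture arguments).
- **Assertion style**: use plain `assert` with comparison operators (`assert x == 1`),
  not `self.assertEqual`; pytest produces clearer diffs.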
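## Mock tool selection (mandatory)
**Priority**: `monkeypatch` > inline stub > `unittest.mock` > `pytest-mock`
| Scenario | Tool | Example |
|------|------|------|
| Replace a module attribute / environment variable / working directory | `monkeypatch` | `monkeypatch.setattr(subprocess, "run", fake_run)` |
| Temporarily set `os.environ["KEY"]` | `monkeypatch.setenv` | `monkeypatch.setenv("LOCALAPPDATA", "C:\\...")` |
| Change the cwd | `monkeypatch.chdir` | `monkeypatch.chdir(tmp_path)` |
| One-off stub function | inline lambda / closure | `ran = []; monkeypatch.setattr(subprocess, "run", lambda *c, **__: ran.append(c))` |
| Complex spy (records call counts / arguments / return sequences) | `unittest.mock.MagicMock` | only when a lambda cannot express it |
| `with patch(...)` context | **forbidden** (use monkeypatch) | monkeypatch's automatic teardown is safer |
**Forbidden**:
- Do not use the `mocker` fixture from `pytest-mock` (the project declares it in the dev dependencies, but the
  test code does not actually use it; to keep the style uniform, new code keeps using `monkeypatch`).
- Do not use the `unittest.mock.patch` decorator (`@patch("x.y")`); it hides dependencies and
  does not fit the pytest fixture pattern; use `monkeypatch.setattr` instead.
- Do not use `mock.patch.object` as a context manager, unless the code under test is itself a
  contextmanager (even then `monkeypatch.setattr` is still simpler).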
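## monkeypatch usage rules
- **Type annotation**: annotate the fixture parameter as `monkeypatch: pytest.MonkeyPatch`.
- **Scope**: monkeypatch undoes its changes automatically when the test ends; **do not** restore manually with
  `monkeypatch.setattr(x, "y", original)` (redundant and easy to forget).
  Exception: when a single test needs to restore midway, call `monkeypatch.undo()` to undo everything.
- **Patch target**: replace "the object the code under test sees", not the global object itself.
  - Wrong: `monkeypatch.setattr("os.path.exists", fake)`: replaces the global and affects other modules.
  - Right: `monkeypatch.setattr(pyflowx.command.shutil, "which", fake)`:
    replaces the `shutil.which` referenced by the module under test.
- **Attribute vs string path**: prefer the attribute form `monkeypatch.setattr(obj, "attr", val)`
  over the string path `monkeypatch.setattr("pkg.mod.obj.attr", val)`;
  the former supports IDE navigation and refactoring.
- **Recording calls**: use a closure, `ran: list[tuple] = []` + `lambda *a, **k: ran.append((a, k))`,
  instead of `MagicMock`; it reads better and needs no import.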
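## Stub and spy patterns
- **Lightweight stub**: define `class MockResult: returncode = 0; stdout = ""` inline
  instead of `MagicMock(return_value=...)`; the type is explicit and no mock dependency is introduced.
- **State collection**: a closure plus a list is easier to assert on than `mock.call_args_list`:
```python
calls: list[list[str]] = []
def fake_run(cmd: list[str], **_: Any) -> MockResult:
    calls.append(cmd)
    return MockResult()
monkeypatch.setattr(subprocess, "run", fake_run)
# ... exercise the code under test here, then assert on the recorded calls ...
assert calls == [["clear"]]
```
- **Side-effect sequences**: when a different value must be returned on each call, use `itertools.cycle` or
  a manual counter rather than `side_effect=[...]` (a mock-specific API).
- **Exception injection**: `def raise_oserror(*a, **k): raise OSError("...")`;
  verify with `pytest.raises(OSError)` rather than `side_effect=OSError`.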
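
The side-effect-sequence rule can be sketched without any mock API. This is a minimal illustration, not project code: `FakeResult` and `make_fake_run` are hypothetical names, and the stub only mimics the small part of `subprocess.run` that a test would read.

```python
from typing import Any, Callable, List


class FakeResult:
    """Hypothetical lightweight stand-in for a completed process."""

    def __init__(self, returncode: int) -> None:
        self.returncode = returncode
        self.stdout = ""


def make_fake_run(returncodes: List[int], calls: List[List[str]]) -> Callable[..., FakeResult]:
    """Build a stub that records each command and returns the codes in order."""
    remaining = iter(returncodes)

    def fake_run(cmd: List[str], **_: Any) -> FakeResult:
        calls.append(cmd)
        # Each call consumes the next return code, replacing side_effect=[...].
        return FakeResult(next(remaining))

    return fake_run


calls: List[List[str]] = []
fake_run = make_fake_run([1, 0], calls)
first = fake_run(["make", "build"])
second = fake_run(["make", "build"], check=False)
```

In a real test the stub would be installed with `monkeypatch.setattr(subprocess, "run", fake_run)`, and the assertions would target `calls` and the behavior of the code under test.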
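## Exception assertions
- **`pytest.raises`**: the `match=` regex is required (unless the exception message is entirely unpredictable),
  to avoid accidentally catching another exception of the same type:
```python
with pytest.raises(StorageError, match="cannot write"):
    b.save("a", 1)
```
- **Exception chains**: to verify `__cause__`, use `exc_info.value.__cause__`
  to confirm that the `raise X from Y` causal chain is intact.
- **Forbidden**: `try/except + assert False`; use `pytest.raises` instead.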
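## Fixture rules
- **`tmp_path`**: handles temporary files with automatic cleanup; managing them manually with `tempfile.mkdtemp()` is forbidden.
- **`monkeypatch`**: mocks environment variables, the cwd, and module attributes (see above).
- **`capsys`/`capfd`**: capture stdout/stderr to verify logs or command output.
- **autouse fixtures**: use only when needed globally (such as `packtool_tmp_workdir` in `conftest.py`,
  which switches to tmp_path automatically); otherwise declare the parameter explicitly.
- **Fixture naming**: `snake_case`, describing "what it provides" rather than "what it tests"
  (`sample_graph` is better than `test_data`).
- **Fixture scope**: `function` by default; use `module`/`session` only when construction is expensive and
  the fixture is read-only, with a comment stating that it has no side effects.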
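## asyncio tests
- **fixture `loop_scope="function"`** (the default is already configured in pyproject).
- **async tests**: `async def test_x():`; pytest-asyncio drives them automatically.
- **await check**: a test of an async function must `await` the result; merely checking that a coroutine object is returned is forbidden.
- **Async mocks**: use `AsyncMock` (in `unittest.mock` since 3.8) or
  `async def fake(): return value`; `MagicMock(return_value=coro)` is forbidden.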
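
As a rough sketch of the async-stub rule, a plain `async def` fake can be injected and awaited end to end. `load_profile` and `fake_fetch` are hypothetical names introduced only for this example.

```python
import asyncio
from typing import Awaitable, Callable, Dict


async def load_profile(fetch: Callable[[int], Awaitable[Dict[str, object]]], uid: int) -> str:
    """Hypothetical code under test: awaits an injected async dependency."""
    user = await fetch(uid)
    return str(user["name"])


async def fake_fetch(uid: int) -> Dict[str, object]:
    """Plain async stub; no MagicMock(return_value=coro) is involved."""
    return {"id": uid, "name": f"user-{uid}"}


# Inside a pytest-asyncio test this would be `await load_profile(fake_fetch, 7)`;
# asyncio.run is used here only so that the sketch runs standalone.
result = asyncio.run(load_profile(fake_fetch, 7))
```

The assertion is made on the awaited value, never on the coroutine object itself.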
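## Parametrization
- **`@pytest.mark.parametrize`**: supply readable identifiers with the `ids` argument:
```python
@pytest.mark.parametrize(
    ("strategy", "expected_workers"),
    [("sequential", 1), ("thread", 8), ("async", 1)],
    ids=["seq", "thread-8", "async"],
)
def test_worker_count(strategy: str, expected_workers: int) -> None:  # hypothetical test name
    ...
```
- **Parameter naming**: give the parameter tuple meaningful names rather than `("a", "b")`.
- **Combinatorial explosion**: when there are more than 20 parameter combinations, split the test so that a single test function does not become bloated.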
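## Test organization
- **File naming**: `test_<module under test>.py` (`test_storage.py` corresponds to `storage.py`).
- **Class grouping**: group with `class TestXxx:` only when the test logic is tightly related; use module-level functions by default.
- **docstring**: one sentence per test function stating "which scenario is tested"; for complex scenarios add "why".
- **setup/teardown**: prefer fixtures; use `setup_method`/`teardown_method` only when
  a fixture cannot express it (rare).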
+15
@@ -0,0 +1,15 @@
# PYTHON
.coverage
.pytest_cache/
.ruff_cache/
.tox/
.venv/
__pycache__/
# NODEJS
node_modules/
# IDE
.idea
.trae
.vscode
+11
@@ -0,0 +1,11 @@
---
alwaysApply: true
scene: git_message
---
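Write rules here to customize the style of AI-generated commit messages.
## Commit message format
- Commit messages must be written in Chinese.
- Commit messages must include the change type (for example "fix", "feat", "refactor").
- Commit messages must be as concise and clear as possible and must not exceed one paragraph.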
+157
@@ -0,0 +1,157 @@
# Python Development Standards
This standard draws on Python best practices and serves as the single reference for writing and reviewing Python code.
For detailed operating guides, see the corresponding skills under `.agents/skills/`.
## Toolchain (pyproject.toml is authoritative)
| Tool | Purpose | Key configuration |
|------|------|---------|
| **ruff** | lint + format | `line-length=120`, `target-version="py38"` |
| **pyrefly** | type checking | `preset="strict"`, `python-version="3.8"` |
| **pytest** | testing | `asyncio_default_fixture_loop_scope="function"`, marker `slow` |
| **coverage** | coverage | `branch=true`, `fail_under=95`, `concurrency=["thread"]` |
| **pre-commit** | pre-commit checks | ruff `--fix` + trailing-whitespace + end-of-file-fixer |
Verification (required after every change):
```bash
uvx --from pyflowx pymake tc
uvx --from pyflowx pymake cov
```
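## Compatibility
- **Minimum Python 3.8**: use `from __future__ import annotations` to defer annotation evaluation;
  by version, use `typing.List` (3.8) → built-in generics (3.9) → `X | Y` (3.10) → `typing.override` (3.12).
- **Version guards**: introduce newer APIs under `if sys.version_info >= (3, X):`; mark the fallback branch for older versions with `# pragma: no cover`.
- **Zero runtime dependencies**: depend only on the standard library (3.8 needs `graphlib_backport` and `typing-extensions`).
  Add new dependencies with care and prefer the standard library.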
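## Type annotations
- **Public APIs must have complete type annotations**, including return types; private functions should be annotated too.
- Use `TypeVar` for generics; PEP 696 `default=` is supported by the standard library only on 3.13+, so use `typing_extensions.TypeVar` on 3.8 to 3.12.
- Use `Mapping`/`Sequence` for read-only parameters and `dict`/`list` for mutable return values.
- Use `Any` only for genuinely dynamic cases (such as the heterogeneous cross-task mapping `Context`); types inside a task must be fully static.
- Bare `# type: ignore` is forbidden; when one is really needed, add the specific rule code (e.g. `# type: ignore[union-attr]`).
- **`TYPE_CHECKING` guard**: put imports needed only for type checking inside an `if TYPE_CHECKING:` block to avoid circular dependencies.
- **Type narrowing**: use `assert isinstance(x, Y)` to help pyrefly infer types; use `cast()` only where the type system cannot express the intent.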
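## Data structures
- **Prefer immutability**: use `@dataclass(frozen=True)` for configuration/descriptor classes; mark mutable class attributes with a `RUF012` exemption.
- **Caching**: use `functools.cached_property` at instance level and `functools.lru_cache` when keyed by arguments;
  unhashable arguments need a try/except fallback. After modifying a cached data source, the cache must be cleared manually.
- **Abstract base classes**: define interfaces with `abc.ABC` + `@abstractmethod` (e.g. `StateBackend`).
- **Enums**: use `enum.Enum` for status/flag values (e.g. `TaskStatus`); bare strings and magic numbers are forbidden; enum members use `UPPER_SNAKE`.
- **`__repr__`**: mutable classes implement `__repr__` (including key fields); `frozen=True` dataclasses generate it automatically.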
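## Modules and imports
- **Single responsibility**: each module does one thing (`task.py` data structures, `executors.py` execution, `command.py` commands, `compose.py` composition). Crossing responsibility boundaries is forbidden.
- **Import order** (ruff isort): `__future__` → standard library → third party → local, with a blank line between groups.
- **Lazy imports**: use only to break circular dependencies, importing inside the function body with an explanatory comment; top-level imports are the default.
- **`__all__`**: define `__all__` to declare exported symbols explicitly, placed directly after the `__future__` import.
- **No star imports**: `from x import *` pollutes the namespace and breaks type checking (aggregation in `__init__.py` controlled through `__all__` is the exception).
- **Avoid `utils.py`/`helpers.py`**: place code in the module that owns the responsibility.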
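## Function design
- **Module-level functions over mixins**: put shared logic in module-level functions; classes hold only state and thin methods.
- **Use static methods sparingly**: put pure functions at module level.
- **At most 5 parameters** is a good target; beyond that, wrap them in a dataclass parameter object.
- **Single responsibility**: a function does one thing; consider splitting functions that grow too long.
- **Catch narrowly**: catch only expected exceptions (e.g. `(TypeError, ValueError, KeyError, AttributeError)`);
  `except Exception`, which masks bugs, is **forbidden**; after catching, at least log with `logger.warning`.
- **Mutable default arguments**: `def f(x=[])` is a classic pitfall; use a `None` sentinel or `field(default_factory=list)`.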
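
The mutable-default rule can be illustrated with both remedies it names. `append_tag` and `TaskOptions` are hypothetical names used only here.

```python
from dataclasses import dataclass, field
from typing import List, Optional


def append_tag(tag: str, tags: Optional[List[str]] = None) -> List[str]:
    """Use a None sentinel so every call without `tags` gets a fresh list."""
    if tags is None:
        tags = []
    tags.append(tag)
    return tags


@dataclass
class TaskOptions:
    """default_factory builds a new list per instance instead of sharing one."""

    tags: List[str] = field(default_factory=list)


a = append_tag("api")
b = append_tag("db")
opts_one = TaskOptions()
opts_two = TaskOptions()
```

With `def append_tag(tag, tags=[])` the second call would have returned both tags, because the default list is created once at function definition time.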
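## Exception handling
- **Custom exception family**: inherit from a common base class (e.g. `PyFlowXError`) and classify by error scenario.
- **Exception wrapping**: `raise NewError(...) from exc` preserves the causal chain.
- **Do not swallow exceptions**: a caught exception must be handled (logged / wrapped / re-raised); an empty `except: pass` is forbidden.
- **Hook/callback exceptions**: exceptions from third-party callbacks are only logged and do not affect the main flow.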
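
A minimal sketch of an exception family with wrapping, assuming a hypothetical base class `DemoError` (the project's real hierarchy under `PyFlowXError` is not shown in this section):

```python
class DemoError(Exception):
    """Hypothetical common base class for the package's errors."""


class DemoStorageError(DemoError):
    """Raised when persisting state fails."""


def save_text(path: str, payload: str) -> None:
    """Write payload to path, wrapping low-level I/O errors in the package family."""
    try:
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(payload)
    except OSError as exc:
        # `from exc` keeps the original error reachable through __cause__.
        raise DemoStorageError(f"cannot write {path}") from exc
```

Callers can then catch `DemoError` for anything raised by the package while still inspecting `__cause__` for the underlying `OSError`.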
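## Concurrency and thread safety
- **Process-global state** (`os.environ`/`os.chdir`) must be serialized with a global lock (`threading.RLock`) in concurrent scenarios.
- **Condition evaluation must not carry mutable state**: composite conditions (NOT/AND/OR) must not modify a shared `_reason`, to avoid races.
- **Batched I/O**: replace repeated disk writes inside a loop with a single batched write (wrap deferred flushing in a `contextmanager`).
- **Semaphore throttling**: `concurrency_key` + `Semaphore` throttle by group.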
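
One way to serialize changes to process-global state is a context manager guarded by a module-level `RLock`. This is a sketch under that assumption; `_ENV_LOCK` and `patched_environ` are hypothetical names, not the project's API.

```python
import os
import threading
from contextlib import contextmanager
from typing import Dict, Iterator, Optional

# One lock for the whole process: os.environ is shared by every thread.
_ENV_LOCK = threading.RLock()


@contextmanager
def patched_environ(overrides: Dict[str, str]) -> Iterator[None]:
    """Apply environment overrides while holding the global lock, then restore."""
    with _ENV_LOCK:
        saved: Dict[str, Optional[str]] = {key: os.environ.get(key) for key in overrides}
        os.environ.update(overrides)
        try:
            yield
        finally:
            for key, old in saved.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old
```

Because the lock is re-entrant, nested use from the same thread does not deadlock; other threads wait until the outermost block exits.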
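## Testing
For the detailed operating guide, see the `.agents/skills/pyflowx-testing` skill. Hard constraints:
- **Coverage ≥ 95%** (branch coverage); it must not decrease.
- **Test the public API first**: use public interfaces (`has`/`get`) and do not access private methods;
  scenarios such as fault injection may temporarily access private attributes, with the reason stated in the docstring.
- **Naming**: `test_<unit under test>_<scenario>`.
- **Assertions**: plain `assert x == 1`; `self.assertEqual` is forbidden; `pytest.raises` requires `match=`.
- **Mock priority**: `monkeypatch` > inline stub > `unittest.mock` > `pytest-mock`;
  the `@patch` decorator, the `mock.patch.object` context manager, and the `mocker` fixture from `pytest-mock` are forbidden.
- **fixture**: prefer `tmp_path`/`monkeypatch`/`capsys`; use autouse only when needed globally.
- **slow marker**: mark slow tests with `@pytest.mark.slow`; CI can skip them with `-m "not slow"`.
- **Test code is also linted with ruff**: `tests/**` ignores `ARG001`/`ARG002`.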
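## Code style
- **Line width 120** (handled by the ruff formatter).
- **docstring**: required for public APIs; Chinese narrative + Chinese comments is this project's established style.
- **Printing and logging**: print and log in Chinese; avoid English.
- **Naming**: `snake_case` for functions/variables, `PascalCase` for classes, `UPPER_SNAKE` for constants, `_leading_underscore` for private names.
- **String quotes**: ruff defaults to double quotes.
- **Single trailing `\n`** and **no trailing whitespace** (enforced by pre-commit).
- **No emoji**: unless the user explicitly asks for it.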
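## Pythonic style
- **Compare `None`/`True`/`False` with `is`**: use `is` for singletons and `==` for values (PEP 8 E711/E712).
- **EAFP over LBYL**: try first and handle the exception, rather than checking before acting (this avoids a race window).
- **truthiness**: `if items:` is better than `if len(items) > 0:`.
- **String formatting**: prefer f-strings; `%` is only for deferred formatting in `logging`.
- **Comprehensions** over `map`+`filter`; split more than 2 levels of nesting into explicit loops.
- **`enumerate`** instead of `range(len())`; **`zip`** for parallel iteration (with `strict=True` on 3.10+).
- **Unpacking**: `a, b = pair` is better than index access; use `_` for ignored values.
- **Walrus operator `:=`** (3.8+): combines assignment and test, but do not overuse it.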
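## Logging
- **`logging.getLogger(__name__)`**: one logger per module; leftover debugging `print` calls are forbidden.
- **Structured context**: pass fields with `extra={...}`; `logger.warning("task %r failed: %s", name, exc)` is preferred over an f-string (deferred formatting).
- **Log levels**: `DEBUG` diagnostics / `INFO` key flow / `WARNING` recoverable problems / `ERROR` needs human intervention.
- **Never log passwords/keys**: redact before logging.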
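
The logging rules above can be sketched as follows. `report_failure` is a hypothetical helper, and the `task` key passed through `extra` is an illustrative field name.

```python
import logging

# One logger per module, named after the module itself.
logger = logging.getLogger(__name__)


def report_failure(name: str, exc: Exception) -> None:
    """Log a recoverable failure with deferred %-formatting and a structured field."""
    # The message is only formatted if a handler actually emits the record.
    logger.warning("task %r failed: %s", name, exc, extra={"task": name})
```

Fields passed through `extra` become attributes of the `LogRecord`, so a structured handler or formatter can read `record.task` without parsing the message text.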
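## Paths and resources
- **Prefer `pathlib.Path`**: `Path("a") / "b"` rather than `os.path.join` (enforced by ruff `PTH`);
  building paths by string concatenation is forbidden. Annotate with `Path` and wrap `str` immediately at boundaries.
- **`with` statements**: always use `with` or `contextlib.contextmanager` for files, locks, connections, and temporary directories;
  use `contextlib.ExitStack` for multiple resources.
- **Explicit close**: long-lived objects (connection pools, thread pools) implement `close()`, but `with` is preferred.
- **Batch operations**: replace repeated acquire/release inside a loop with a single batched operation.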
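
A small sketch of the path and resource rules, combining `pathlib.Path` at the boundary with `contextlib.ExitStack` for a variable number of open files. `concat_files` is a hypothetical helper written for this example.

```python
from contextlib import ExitStack
from pathlib import Path
from typing import Iterable, Union


def concat_files(sources: Iterable[Union[str, Path]], target: Union[str, Path]) -> Path:
    """Concatenate text files; every handle is closed when the stack exits."""
    target_path = Path(target)  # wrap str input at the boundary
    with ExitStack() as stack:
        out = stack.enter_context(target_path.open("w", encoding="utf-8"))
        for src in sources:
            fh = stack.enter_context(Path(src).open(encoding="utf-8"))
            out.write(fh.read())
    return target_path
```

`ExitStack` closes the files in reverse order of opening, including when an exception interrupts the loop partway through.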
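## Security
- **No `eval`/`exec`**: never use them on untrusted input; use `ast.literal_eval` or a dedicated parser.
- **`subprocess`**: `shell=True` is forbidden unless the command is fully trusted; prefer the `list[str]` form.
- **Keep credentials out of the repository**: put keys/tokens/passwords in `.env` or environment variables; `.gitignore` must include `.env`.
- **Log redaction**: strip fields such as `Authorization` and `password` when logging requests/responses.
- **Dependency audit**: review newly added dependencies after `uv lock` to avoid packages with known CVEs.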
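## Performance notes
- **Avoid repeated computation**: lookups inside loops should be cached or served from a prebuilt mapping (e.g. `{name: spec}`).
- **Avoid double lookups**: replace `has(k)` + `get(k)` with a single `get(k)` and a `KeyError` fallback.
- **Validate once**: validate at the entry point and do not repeat it downstream (e.g. `run()` calls `validate()` once and `layers()` does not repeat it).
- **Event emission**: the task lifecycle must emit `RUNNING`, then `SUCCESS`/`FAILED`/`SKIPPED`;
  do not leave dead branches (`# pragma: no cover` is a cleanup signal: activate the branch or delete it).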
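
The double-lookup rule can be sketched with a toy backend. `TinyBackend` and `lookup` are hypothetical; the only assumption carried over from the text is that `get` raises `KeyError` on a miss.

```python
from typing import Any, Dict, Tuple


class TinyBackend:
    """Hypothetical in-memory backend whose get() raises KeyError on a miss."""

    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}
        self.lookups = 0

    def save(self, key: str, value: Any) -> None:
        self._store[key] = value

    def get(self, key: str) -> Any:
        self.lookups += 1
        return self._store[key]


def lookup(backend: TinyBackend, key: str) -> Tuple[bool, Any]:
    """Single get() with a KeyError fallback instead of has() followed by get()."""
    try:
        return True, backend.get(key)
    except KeyError:
        return False, None


backend = TinyBackend()
backend.save("a", 1)
hit = lookup(backend, "a")
miss = lookup(backend, "b")
```

Returning a `(found, value)` pair keeps a cached `None` distinguishable from a miss.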
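## Git and commits
- **No automatic commit/push**: unless the user explicitly asks for it.
- **Do not modify git config**.
- **Do not run destructive commands** (`push --force`/`reset --hard`/`clean -f`) unless the user explicitly asks for it.
- **staging**: add files by name; do not use `git add -A`/`git add .`, to avoid adding sensitive files by mistake.
- **commit message**: concise, focused on "why" rather than "what"; follow the repository's existing style.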
+143 -20
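@@ -14,18 +14,25 @@ PyFlowX makes "task dependencies" as simple as possible: **the parameter name is the dependency
## Features
- **Zero boilerplate**: parameter names are the dependencies, and the framework injects upstream results automatically
- **Three execution strategies**: `sequential` (debugging) / `thread` (I/O-bound, synchronous) / `async` (I/O-bound, asynchronous)
- **Four execution strategies**: `sequential` (debugging) / `thread` (I/O-bound, synchronous) / `async` (I/O-bound, asynchronous) / `dependency` (dependency-driven, maximizes parallelism)
- **Type safety**: `TaskSpec[T]` carries the return type all the way to `RunReport`; passes mypy strict
- **DAG validation**: duplicate names, missing dependencies, and cycles are checked at build time
- **Automatic layering**: tasks are grouped with Kahn's algorithm; tasks in the same layer can run in parallel
- **Retries and timeouts**: each task configures `retries` and `timeout` independently
- **Checkpoint resume**: `MemoryBackend` / `JSONBackend`; successful results can be cached and reused
- **Retries and timeouts**: each task configures `RetryPolicy` (max_attempts/delay/backoff/jitter/retry_on) and `timeout` independently
- **Soft dependencies**: `soft_depends_on` is used only for context injection and does not take part in topological layering
- **Concurrency limits**: `concurrency_key` + `concurrency_limits` throttle by group
- **Task hooks**: `TaskHooks` (pre_run/post_run/on_failure) lifecycle callbacks
- **Checkpoint resume**: `MemoryBackend` / `JSONBackend`; successful results can be cached and reused; `batch()` flushes to disk in batches
- **Cache keys**: a `cache_key` function computes a stable key from the inputs, so different inputs get independent cache entries
- **Command tasks**: the `cmd` parameter runs external commands directly; supports list/shell/callable
- **Conditional execution**: the `conditions` parameter skips tasks by platform, environment variable, installed application, and similar conditions
- **Graph composition**: `compose` / `GraphComposer` programmatically expand string references across graphs
- **Task templates**: the `task_template` factory generates similar TaskSpecs in bulk
- **Graph-level defaults**: `GraphDefaults` configures retry/timeout/concurrency and more in one place
- **CLI runner**: `CliRunner` maps multiple graphs to command-line subcommands, replacing a Makefile
- **Observability**: `on_event` callback, `dry_run` preview, `verbose` lifecycle logs, Mermaid visualization
- **Observability**: `on_event` callback (RUNNING/SUCCESS/FAILED/SKIPPED), `dry_run` preview, `verbose` lifecycle logs, Mermaid visualization
- **Zero runtime dependencies**: standard library only (3.8 needs `graphlib_backport`)
- **95% test coverage**: branch coverage >= 95%
- **97% test coverage**: branch coverage >= 95%
## Installation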
@@ -67,23 +74,31 @@ print(report["double"]) # [2, 4, 6]
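### TaskSpec: task description
`TaskSpec` is an immutable task descriptor and the only thing you need to configure:
`TaskSpec` is an immutable task descriptor (`Generic[T]`; the return type is carried all the way to `RunReport`) and the only thing you need to configure:
```python
px.TaskSpec(
    name="fetch_user",  # unique identifier
    fn=fetch_user,  # sync or async function
    cmd=["curl", "..."],  # or: run a command (overrides fn)
    depends_on=("auth",),  # names of the tasks this one depends on
    depends_on=("auth",),  # dependencies (take part in topological layering)
    soft_depends_on=("cache",),  # soft dependencies (injection only, no layering)
    args=(uid,),  # static positional arguments (appended after injected arguments)
    kwargs={"timeout": 30},  # static keyword arguments
    retries=3,  # retries on failure (0 = run once)
    retry=px.RetryPolicy(max_attempts=3, delay=1.0, backoff=2.0),  # retry policy
    timeout=30.0,  # timeout in seconds (None = unlimited)
    tags=("api", "user"),  # free-form tags used for subgraph filtering
    conditions=(is_prod,),  # condition functions (run only if all return True)
    priority=10,  # priority within a layer (higher runs first, default 0)
    concurrency_key="db",  # concurrency group key (throttled via concurrency_limits)
    cache_key=lambda ctx: str(ctx.get("uid")),  # cache key function (independent cache per input)
    hooks=px.TaskHooks(pre_run=..., post_run=..., on_failure=...),  # lifecycle hooks
    cwd=Path("/tmp"),  # working directory for the command (cmd mode only)
    env={"DEBUG": "1"},  # environment overrides (apply in both fn and cmd modes)
    verbose=True,  # print command output (cmd mode only)
    skip_if_missing=True,  # skip automatically when the command is missing (list[str] cmd only)
    allow_upstream_skip=False,  # whether to still run when an upstream task is SKIPPED/FAILED
    continue_on_error=False,  # whether a failure of this task leaves the overall run uninterrupted
)
```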
@@ -97,18 +112,54 @@ px.TaskSpec(
### Graph: DAG construction
```python
graph = px.Graph.from_specs([...])  # validate the whole batch (recommended)
# Graph-level defaults: used as a fallback when a TaskSpec field is None
defaults = px.GraphDefaults(retry=px.RetryPolicy(max_attempts=2), timeout=60.0)
graph = px.Graph.from_specs([...], defaults=defaults)  # validate the whole batch (recommended)
# or build incrementally
graph = px.Graph()
graph = px.Graph(defaults=defaults)
graph.add(px.TaskSpec("a", fn_a))
graph.add(px.TaskSpec("b", fn_b, ("a",)))
graph.validate()  # explicit validation (cycle detection)
graph.layers()  # topological layering
graph.layers()  # topological layering (the run() entry point validates once; call validate() yourself before calling this directly)
graph.to_mermaid()  # Mermaid visualization
graph.describe()  # human-readable summary
graph.subgraph(("api",))  # slice by tag
graph.subgraph_by_names(("a", "b"))  # slice by name
graph.map("fetch", [1, 2, 3], lambda i: TaskSpec(f"fetch_{i}", ...))  # bulk fan-out
```
### Graph composition: compose
`compose` / `GraphComposer` expand multiple graphs containing string references into plain `Graph` objects:
```python
graphs = {
    "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
    "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
}
resolved = px.compose(graphs)  # the "build" reference in the "all" graph is expanded
```
Reference format: `"command_name"` (the whole graph) or `"command_name.task_name"` (a specific task).
`CliRunner` calls `compose` automatically internally.
### Task templates: task_template
The `task_template` factory generates similar TaskSpecs in bulk:
```python
fetch = px.task_template(
    fn=fetch_url,
    retry=px.RetryPolicy(max_attempts=5),
    timeout=30.0,
    tags=("api",),
)
graph = px.Graph.from_specs([
    fetch("users", url="https://api.example.com/users"),
    fetch("posts", url="https://api.example.com/posts"),
])
```
### run: execution
@@ -116,12 +167,14 @@ graph.subgraph_by_names(("a", "b")) # slice by name
```python
report = px.run(
    graph,
    strategy="async",  # sequential | thread | async
    strategy="async",  # sequential | thread | async | dependency
    max_workers=8,  # thread pool size for the thread strategy
    concurrency_limits={"db": 2},  # throttle by concurrency_key
    dry_run=False,  # True = only print the plan
    verbose=False,  # True = print task lifecycle logs
    on_event=callback,  # state-transition callback
    on_event=callback,  # state-transition callback (RUNNING/SUCCESS/FAILED/SKIPPED)
    state=px.JSONBackend("state.json"),  # checkpoint-resume backend
    continue_on_error=False,  # True = a single task failure does not interrupt the run
)
```
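@@ -141,7 +194,7 @@ report.describe() # human-readable report
Evaluated in order:
1. Parameters **annotated as `Context`** → receive the full mapping of upstream results
2. Parameters whose **name matches a dependency** → receive that dependency's result
2. Parameters whose **name matches a dependency** → receive that dependency's result (including soft dependencies; a default value is injected when one is missing)
3. A **`**kwargs`** parameter → receives all dependency results (dict)
4. **`TaskSpec.args` / `kwargs`** → provide static values for non-dependency parameters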
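@@ -170,8 +223,11 @@ def fetch_user(uid: int) -> dict: # uid comes from TaskSpec.args
| `sequential` | serial | debugging, CPU-bound | called directly | event loop |
| `thread` | thread pool | I/O-bound, synchronous | thread pool | not supported |
| `async` | event loop | I/O-bound, asynchronous | offloaded to a thread pool | event loop |
| `dependency` | dependency-driven | maximum parallelism | offloaded to a thread pool | event loop |
All strategies honor `retries`, `timeout`, context injection, and the state backend, and emit `TaskEvent`s.
All strategies honor `RetryPolicy`, `timeout`, context injection, the state backend, and `concurrency_limits`,
and emit `TaskEvent`s (RUNNING/SUCCESS/FAILED/SKIPPED). The `dependency` strategy has no layer barrier:
a task starts as soon as all of its hard dependencies have completed.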
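
For contrast with the barrier-free `dependency` strategy, the layered strategies rely on grouping tasks into layers. The sketch below shows Kahn-style layering on a plain dependency mapping; it is an illustration under simplifying assumptions (unique dependency names, every dependency present as a key), and `kahn_layers` is a hypothetical function, not PyFlowX's `Graph.layers()`.

```python
from typing import Dict, List, Sequence


def kahn_layers(deps: Dict[str, Sequence[str]]) -> List[List[str]]:
    """Group task names into layers; each layer depends only on earlier layers."""
    indegree = {name: len(upstream) for name, upstream in deps.items()}
    dependents: Dict[str, List[str]] = {name: [] for name in deps}
    for name, upstream in deps.items():
        for dep in upstream:
            dependents[dep].append(name)

    layers: List[List[str]] = []
    ready = sorted(name for name, degree in indegree.items() if degree == 0)
    placed = 0
    while ready:
        layers.append(ready)
        placed += len(ready)
        next_ready: List[str] = []
        for name in ready:
            for child in dependents[name]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    next_ready.append(child)
        ready = sorted(next_ready)

    # Tasks left unplaced can only be part of a cycle.
    if placed != len(deps):
        raise ValueError("cycle detected")
    return layers


example = kahn_layers({"a": [], "b": ["a"], "c": ["a"], "d": ["b", "c"]})
```

With layers, `d` waits for the whole second layer to finish; a dependency-driven scheduler would instead start each task the moment its own upstream tasks are done.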
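## Command tasks
@@ -275,12 +331,25 @@ python examples/async_aggregation.py
from pyflowx import JSONBackend
# First run: successful results are written to state.json
backend = JSONBackend("state.json")
backend = JSONBackend("state.json", ttl=3600)  # ttl in seconds; expired entries are ignored automatically
report = px.run(graph, strategy="sequential", state=backend)
# Second run: cached tasks are skipped automatically
# Second run: cached tasks are skipped automatically (status SKIPPED)
report = px.run(graph, strategy="sequential", state=backend)
# cached tasks have status SKIPPED in report.results
```
`run()` wraps the whole execution in `backend.batch()` internally: every `save` is deferred and flushed to disk once when the run ends
(for `JSONBackend` this reduces disk writes from O(N²) to O(N); it is a no-op for `MemoryBackend`).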
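
The deferred-flush idea can be sketched with a toy store. This is an assumption-laden illustration, not the real `JSONBackend`: `BatchedJSONStore`, its `flush_count` counter, and its internals are hypothetical.

```python
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, Iterator, Union


class BatchedJSONStore:
    """Toy JSON store: save() rewrites the file unless a batch is open."""

    def __init__(self, path: Union[str, Path]) -> None:
        self._path = Path(path)
        self._data: Dict[str, Any] = {}
        self._depth = 0
        self.flush_count = 0  # exposed only to make the effect observable

    def _flush(self) -> None:
        # Rewriting the whole file costs time proportional to its size.
        self._path.write_text(json.dumps(self._data), encoding="utf-8")
        self.flush_count += 1

    def save(self, key: str, value: Any) -> None:
        self._data[key] = value
        if self._depth == 0:
            self._flush()

    @contextmanager
    def batch(self) -> Iterator[None]:
        """Defer flushing until the outermost batch exits."""
        self._depth += 1
        try:
            yield
        finally:
            self._depth -= 1
            if self._depth == 0:
                self._flush()
```

Without the batch, N saves rewrite a file that grows to N entries, which is where the quadratic total comes from; inside a batch the file is written once.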
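**Cache keys**: the default storage key is the task name. With a `cache_key` function configured, the key becomes `"name:cache_key_value"`,
so different inputs get independent cache entries:
```python
px.TaskSpec(
    "fetch_user",
    fn=fetch_user,
    cache_key=lambda ctx: str(ctx.get("uid")),  # independent cache per uid
)
```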
## Error handling
@@ -321,14 +390,52 @@ except px.PyFlowXError:
PyFlowX focuses on keeping **single-machine DAG scheduling** as simple as possible and suits ETL, data processing, CI pipelines, and similar scenarios.
## Advanced features
### Concurrency limits
Throttle by `concurrency_key` group to avoid overwhelming downstream resources:
```python
graph = px.Graph.from_specs([
    px.TaskSpec("q1", fn=query_db, concurrency_key="db"),
    px.TaskSpec("q2", fn=query_db, concurrency_key="db"),
    px.TaskSpec("q3", fn=query_db, concurrency_key="db"),
])
# at most 2 tasks in the "db" group run at the same time
px.run(graph, strategy="async", concurrency_limits={"db": 2})
```
### Task hooks
`TaskHooks` fire during the task lifecycle (exceptions are only logged and do not affect the task status):
```python
hooks = px.TaskHooks(
    pre_run=lambda spec: print(f"start {spec.name}"),
    post_run=lambda spec, value: print(f"done {spec.name}"),
    on_failure=lambda spec, exc: alert(spec.name, exc),
)
px.TaskSpec("task", fn=work, hooks=hooks)
```
### Priority
Within a layer, tasks run in descending `priority` order (stable sort):
```python
px.TaskSpec("low", fn=work, priority=0)
px.TaskSpec("high", fn=work, priority=10)  # runs first within the layer
```
## Development
```bash
# install development dependencies
uv sync --extra dev
# run tests (with coverage)
uv run pytest --cov=pyflowx --cov-fail-under=100
# run tests (with coverage, threshold 95%)
uv run pytest --cov=pyflowx --cov-fail-under=95
# type checking
uv run mypy
@@ -338,6 +445,22 @@ uv run ruff check src tests examples
uv run ruff format --check src tests examples
```
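## Module structure
| Module | Responsibility |
|------|------|
| `task.py` | Pure data structures: `TaskSpec`, `RetryPolicy`, `TaskHooks`, `TaskStatus` |
| `graph.py` | DAG construction, validation, layering, visualization |
| `compose.py` | Multi-graph composition: `GraphComposer` / `compose` |
| `context.py` | Context injection: parameter name → dependency resolution |
| `command.py` | Command execution: `run_command` (list/shell/Callable) |
| `conditions.py` | Conditional execution: built-in conditions and combinators |
| `executors.py` | Executors and the `run` entry point: the four strategies share module-level helpers |
| `storage.py` | State backends: `MemoryBackend` / `JSONBackend` (batch flush) |
| `runner.py` | CLI runner: `CliRunner` |
| `report.py` | Run results: `RunReport` / `TaskResult` |
| `errors.py` | Error family: `PyFlowXError` subclasses |
## License
MIT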
+4 -1
@@ -58,6 +58,8 @@
from __future__ import annotations
from .command import run_command
from .compose import GraphComposer, compose
from .conditions import (
IS_LINUX,
IS_MACOS,
@@ -79,7 +81,7 @@ from .errors import (
TaskTimeoutError,
)
from .executors import Strategy, run
from .graph import Graph, GraphComposer, GraphDefaults, compose
from .graph import Graph, GraphDefaults
from .report import RunReport
from .runner import CliExitCode, CliRunner
from .storage import JSONBackend, MemoryBackend, StateBackend
@@ -136,5 +138,6 @@ __all__ = [
"compose",
"describe_injection",
"run",
"run_command",
"task_template",
]
+6 -6
@@ -240,7 +240,7 @@ def _parse_email_date(date_str: str) -> str:
try:
dt = parsedate_to_datetime(date_str)
return dt.isoformat()
except Exception:
except (ValueError, TypeError, OverflowError):
return date_str
@@ -277,11 +277,11 @@ def _extract_email_body_part(part: Any) -> str:
decoded_text = payload.decode(charset, errors="replace")
except (UnicodeDecodeError, LookupError) as decode_error:
# 如果指定编码失败,尝试常见编码
logger.warning(f"字符编码 {charset} 解码失败: {decode_error}")
logger.warning("字符编码 %s 解码失败: %s", charset, decode_error)
for fallback_charset in ["utf-8", "gbk", "gb2312", "latin-1"]:
try:
decoded_text = payload.decode(fallback_charset, errors="replace")
logger.info(f"成功使用备用编码 {fallback_charset} 解码")
logger.info("成功使用备用编码 %s 解码", fallback_charset)
break
except (UnicodeDecodeError, LookupError):
continue
@@ -293,15 +293,15 @@ def _extract_email_body_part(part: Any) -> str:
# 限制长度并返回
result = decoded_text[:MAX_BODY_LENGTH]
if len(decoded_text) > MAX_BODY_LENGTH:
logger.debug(f"正文内容过长,截取前{MAX_BODY_LENGTH}字符")
logger.debug("正文内容过长,截取前%d字符", MAX_BODY_LENGTH)
return result
except AttributeError as attr_error:
logger.error(f"邮件部分对象属性错误: {attr_error}")
logger.error("邮件部分对象属性错误: %s", attr_error)
return ""
except Exception as unexpected_error:
logger.error(f"提取邮件正文时发生未知错误: {unexpected_error}")
logger.error("提取邮件正文时发生未知错误: %s", unexpected_error)
return ""
+98
@@ -0,0 +1,98 @@
"""Command executor: turns the ``cmd`` field of :class:`~pyflowx.task.TaskSpec` (list /
shell string / callable) into a single execution entry point.

Background: the module docstring of the original ``task.py`` declared it "pure data
structures", but ``_run_command`` is command-execution logic and violated single
responsibility. It is extracted here: ``TaskSpec`` only holds configuration, and the
execution logic lives in this module, which makes it easier to test and maintain independently.
"""

from __future__ import annotations

import os
import subprocess
from typing import Any, List, Union, cast

from .task import TaskSpec

__all__ = ["run_command"]


def run_command(spec: TaskSpec[Any]) -> Any:  # noqa: PLR0912
    """Run the command given by ``spec.cmd`` (list / shell string / callable).

    Behaves the same as the original ``TaskSpec._run_command``:

    - Callable: called directly; exceptions are wrapped in :class:`RuntimeError`.
    - list / str: run through :func:`subprocess.run`; a non-zero return code raises
      :class:`RuntimeError` (with stderr attached when ``verbose=False``).
    - With ``verbose=True``, execution info and the return code are printed to stdout.
    - ``cwd`` / ``env`` are isolated through subprocess arguments (process-level state is
      used only on the fn task path; the cmd path does not rely on ``os.chdir`` / ``os.environ``).
    """
    cmd = spec.cmd
    verbose = spec.verbose
    cwd = spec.cwd
    timeout = spec.timeout
    env_override = spec.env

    # Callable: call it directly and return its result.
    if callable(cmd) and not isinstance(cmd, (list, str)):
        name = getattr(cmd, "__name__", "callable")
        if verbose:
            print(f"[verbose] 执行可调用命令: {name}", flush=True)
            if cwd is not None:
                print(f"[verbose] 工作目录: {cwd}", flush=True)
        try:
            return cmd()
        except Exception as e:
            raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e

    is_list = isinstance(cmd, list)
    if is_list:
        cmd_str = " ".join(arg for arg in cmd)  # type: ignore[union-attr]
        verb = "执行命令"
        label = "命令"
    else:
        cmd_str = cast(str, cmd)
        verb = "执行 Shell"
        label = "Shell 命令"
    if verbose:
        print(f"[verbose] {verb}: {cmd_str}", flush=True)
        if cwd is not None:
            print(f"[verbose] 工作目录: {cwd}", flush=True)

    # Merge the environment overrides into a copy of the current environment.
    run_env: dict[str, str] | None = None
    if env_override:
        run_env = dict(os.environ)
        run_env.update(env_override)

    try:
        result = subprocess.run(
            cast(Union[str, List[str]], cmd),
            shell=not is_list,
            cwd=cwd,
            env=run_env,
            timeout=timeout,
            capture_output=not verbose,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        raise RuntimeError(f"{label}未找到: {cmd_str}") from None
    except subprocess.TimeoutExpired:
        raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
    except OSError as e:
        raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e

    if verbose:
        print(f"[verbose] 返回码: {result.returncode}", flush=True)
    if result.returncode == 0:
        return None
    err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
    if not verbose and result.stderr.strip():
        err_msg += f"\n{result.stderr.strip()}"
    raise RuntimeError(err_msg)
+115
@@ -0,0 +1,115 @@
"""Graph composition: expands multiple graphs containing string references into plain
:class:`~pyflowx.graph.Graph` objects.

Background: the original ``graph.py`` carried both DAG construction/validation/layering
and the multi-graph composition logic, which overloaded it. The composition logic
(:class:`GraphComposer` / :func:`compose`) is orthogonal to the single-graph DAG
model, so it is extracted into its own module for on-demand import and independent evolution.
"""

from __future__ import annotations

from dataclasses import replace
from typing import Any

from .graph import Graph
from .task import TaskSpec

__all__ = ["GraphComposer", "compose"]


class GraphComposer:
    """Expands graphs containing string references into plain :class:`TaskSpec` graphs.

    Reference formats:

    * ``"command_name"``: references an entire command graph.
    * ``"command_name.task_name"``: references a specific task.

    References are expanded in order; tasks from a later reference depend on the last task
    of the preceding reference, and the original ``TaskSpec`` entries also depend on each
    other serially in order of appearance.
    """

    def __init__(self, graphs: dict[str, Graph]) -> None:
        self.graphs = graphs

    def resolve_all(self) -> dict[str, Graph]:
        """Resolve the string references of every graph and return a new mapping of expanded graphs."""
        resolved: dict[str, Graph] = {}
        for cmd_name, graph in self.graphs.items():
            resolved[cmd_name] = self.expand_refs(graph, cmd_name)
        return resolved

    def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
        """Expand the string references in a graph. Returns it unchanged if there are no ``_pending_refs``."""
        pending_refs = graph._pending_refs
        if not pending_refs:
            return graph

        all_specs: list[TaskSpec[Any]] = []
        previous_ref_last_task: str | None = None
        for ref in pending_refs:
            expanded_specs = self.parse_ref(ref, current_cmd)
            if previous_ref_last_task and expanded_specs:
                for i, task in enumerate(expanded_specs):
                    if i == 0 or not task.depends_on:
                        expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
            if expanded_specs:
                previous_ref_last_task = expanded_specs[-1].name
            all_specs.extend(expanded_specs)

        original_specs = list(graph.all_specs().values())
        if original_specs:
            if previous_ref_last_task:
                first = original_specs[0]
                all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
            else:
                all_specs.append(original_specs[0])
            for i in range(1, len(original_specs)):
                current_task = original_specs[i]
                previous_task_name = original_specs[i - 1].name
                all_specs.append(
                    replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
                )
        return Graph.from_specs(all_specs, defaults=graph.defaults)

    def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
        """Parse a single string reference and return the corresponding list of TaskSpecs."""
        if ref == current_cmd:
            raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
        if "." in ref:
            cmd_name, task_name = ref.split(".", 1)
            if cmd_name not in self.graphs:
                raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
            ref_graph = self.graphs[cmd_name]
            if task_name not in ref_graph.all_specs():
                raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}'")
            return [ref_graph.all_specs()[task_name]]
        else:
            cmd_name = ref
            if cmd_name not in self.graphs:
                raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
            ref_graph = self.graphs[cmd_name]
            ref_graph = self.expand_refs(ref_graph, cmd_name)
            return list(ref_graph.all_specs().values())


def compose(
    graphs: dict[str, Graph],
) -> dict[str, Graph]:
    """Programmatically resolve string references across graphs and return a new mapping of expanded graphs.

    Equivalent to :class:`GraphComposer`, but exposed as a standalone function for
    programmatic users who do not use :class:`~pyflowx.runner.CliRunner`.

    Examples
    --------
    >>> graphs = {
    ...     "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
    ...     "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
    ... }
    >>> resolved = px.compose(graphs)
    >>> "b" in resolved["all"].all_specs()
    True
    """
    return GraphComposer(graphs).resolve_all()
+8 -26
@@ -11,6 +11,7 @@
from __future__ import annotations
import logging
import os
import shutil
import subprocess
@@ -20,6 +21,8 @@ from typing import Any, Callable
from .task import Condition, Context
logger = logging.getLogger(__name__)
__all__ = ["BuiltinConditions", "Condition", "Constants"]
@@ -42,14 +45,6 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
return _cond
def _cond_reason(cond: Condition) -> str | list[str] | None:
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
reason = getattr(cond, "_reason", None)
if reason is not None:
return reason
return getattr(cond, "__name__", repr(cond))
def _cond_name(cond: Condition) -> str:
"""获取条件的可读名称。"""
return getattr(cond, "__name__", repr(cond))
@@ -161,7 +156,7 @@ class BuiltinConditions:
return False
try:
return content in p.read_text(encoding="utf-8")
except Exception:
except (OSError, UnicodeDecodeError):
return False
return _static(_check, f"FILE_CONTENT_EXISTS({path!r},{content!r})")
@@ -194,7 +189,8 @@ class BuiltinConditions:
return False
try:
return predicate(ctx[dep_name])
except Exception:
except Exception as exc:
logger.warning("DEP_MATCHES predicate %r raised: %r", dep_name, exc)
return False
_cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})"
@@ -228,13 +224,7 @@ class BuiltinConditions:
"""对条件取反."""
def _cond(ctx: Context) -> bool:
result = condition(ctx)
if result:
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
inner_reason = _cond_reason(condition)
if inner_reason is not None:
_cond._reason = inner_reason # type: ignore[attr-defined]
return not result
return not condition(ctx)
_cond.__name__ = f"NOT({_cond_name(condition)})"
return _cond
@@ -254,15 +244,7 @@ class BuiltinConditions:
"""多个条件的逻辑或."""
def _cond(ctx: Context) -> bool:
matched: list[str] = []
for c in conditions:
if c(ctx):
reason = _cond_reason(c)
matched.append(reason if isinstance(reason, str) else str(reason))
if matched:
_cond._reason = matched # type: ignore[attr-defined]
return True
return False
return any(c(ctx) for c in conditions)
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
return _cond
+21 -2
@@ -16,6 +16,7 @@ DAG 库中泛滥的样板包装器。
from __future__ import annotations
import inspect
from functools import lru_cache
from typing import Any, Mapping
from .errors import InjectionError
@@ -24,6 +25,24 @@ from .task import Context, TaskSpec
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
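@lru_cache(maxsize=1024)
def _cached_signature(fn: Any) -> inspect.Signature:
    """Cache the result of ``inspect.signature`` (keyed by the fn object).

    The ``fn`` object is stable once :meth:`TaskSpec.effective_fn` has cached it, so repeated
    signature introspection is pure overhead. For unhashable callables, the caller falls back
    to direct introspection.
    """
    return inspect.signature(fn)
def _signature(fn: Any) -> inspect.Signature:
    """Get the signature, preferring the cache; fall back to direct introspection when ``fn`` is unhashable."""
    try:
        return _cached_signature(fn)
    except TypeError:
        return inspect.signature(fn)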
def _is_context_annotation(annotation: Any) -> bool:
"""判断参数标注是否为(或指向)``Context``。"""
if annotation is Context:
@@ -44,7 +63,7 @@ def build_call_args(
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
"""
fn = spec.effective_fn
sig = inspect.signature(fn)
sig = _signature(fn)
params = sig.parameters
var_keyword = next(
@@ -115,7 +134,7 @@ def build_call_args(
def describe_injection(spec: TaskSpec[Any]) -> str:
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
fn = spec.effective_fn
sig = inspect.signature(fn)
sig = _signature(fn)
positional_params = [
p
for p, param in sig.parameters.items()
+288 -255
@@ -12,14 +12,18 @@
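Architecture
------------
This module removes duplicated code between sync/async and the per-layer executors through **mixin** composition:
This module removes duplicated code between the sync and async task executors through **module-level functions**:
* :class:`_TaskSkipMixin`: pre-checks for upstream skip / condition skip.
* :class:`_TaskRetryMixin`: retry decisions, success/failure post-processing, finalize.
* :class:`_LayerMixin`: cache filtering, priority sorting, semaphore construction, result storage.
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner`: task-level executors that combine the mixins above.
* Module-level skip/retry functions (:func:`_prepare_for_execution` / :func:`_should_retry`
  / :func:`_mark_success` / :func:`_handle_failure` / :func:`_finalize_failure`):
  pre-checks for upstream skip / condition skip, retry decisions, success/failure post-processing.
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner`: task-level executors that call the functions above.
* Shared module-level helpers (:func:`_filter_and_sort` / :func:`_store_result` /
  :func:`_build_semaphores` / :func:`_get_sem`): cache filtering, priority sorting,
  semaphore construction, result storage.
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
  :class:`AsyncLayerRunner` / :class:`DependencyRunner`: layer-level executors that combine :class:`_LayerMixin`.
  :class:`AsyncLayerRunner`: layer-level executors that call the module-level helpers above.
* :class:`DependencyRunner`: dependency-driven scheduling (not a layer model); it also calls the module-level helpers.
All strategies share a unified async kernel and support:
* :class:`RetryPolicy` (max_attempts/delay/backoff/jitter/retry_on)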
@@ -52,7 +56,7 @@ from .report import RunReport
from .storage import StateBackend, resolve_backend
from .task import TaskEvent, TaskHooks, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger("pyflowx")
logger = logging.getLogger(__name__)
# Observer callback type.
EventCallback = Callable[[TaskEvent], None]
@@ -83,6 +87,22 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
)
def _emit_running(on_event: EventCallback | None, spec: TaskSpec[Any]) -> None:
"""Emit a RUNNING event (when the task starts executing)."""
if on_event is None:
return
on_event(
TaskEvent(
task=spec.name,
status=TaskStatus.RUNNING,
attempts=0,
error=None,
duration=None,
reason=None,
)
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""Call a hook safely (exceptions are only logged and do not affect task status)."""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -126,11 +146,16 @@ def _apply_cached(
backend: StateBackend,
on_event: EventCallback | None,
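) -> bool:
"""If ``name`` hits the cache, write it to context/report and return True."""
"""If ``name`` hits the cache, write it to context/report and return True.
A single ``backend.get`` with a ``KeyError`` fallback avoids the double hash
lookup and double TTL check of ``has`` + ``get``.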
"""
storage_key = spec.storage_key(context)
if not backend.has(storage_key):
try:
cached = backend.get(storage_key)
except KeyError:
return False
cached = backend.get(storage_key)
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
@@ -139,154 +164,146 @@ def _apply_cached(
return True
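The switch from `has` + `get` to a single `get` guarded by `KeyError` can be illustrated with a toy backend. This is a sketch under assumptions: `TinyBackend`, `lookup_twice`, and `lookup_once` are hypothetical and only mimic the contract implied by the hunk, namely that `get` raises `KeyError` on a miss.

```python
from __future__ import annotations

from typing import Any


class TinyBackend:
    """Hypothetical stand-in for a state backend; it counts store lookups."""

    def __init__(self) -> None:
        self._store: dict[str, Any] = {}
        self.lookups = 0

    def save(self, key: str, value: Any) -> None:
        self._store[key] = value

    def has(self, key: str) -> bool:
        self.lookups += 1
        return key in self._store

    def get(self, key: str) -> Any:
        self.lookups += 1
        return self._store[key]  # raises KeyError on a miss


def lookup_twice(backend: TinyBackend, key: str) -> tuple[bool, Any]:
    """Old shape: membership test first, then a second lookup for the value."""
    if not backend.has(key):
        return False, None
    return True, backend.get(key)


def lookup_once(backend: TinyBackend, key: str) -> tuple[bool, Any]:
    """New shape: one lookup, with KeyError signalling a miss."""
    try:
        return True, backend.get(key)
    except KeyError:
        return False, None
```

Returning a `(hit, value)` pair keeps a cached `None` distinguishable from a miss, which a plain `dict.get` style default would not.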
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""Sort by priority in descending order (stable sort)."""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
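def _sort_by_priority(layer: list[str], specs: Mapping[str, TaskSpec[Any]]) -> list[str]:
"""Sort by priority in descending order (stable sort).
# ---------------------------------------------------------------------- #
# Mixin: task-level skip / retry / success handling
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""Shared logic for task-level skip pre-checks.
The two skip checks, "upstream skipped/failed" and "condition not met", share a single entry point
that :class:`SyncTaskRunner` and :class:`AsyncTaskRunner` both reuse.
Accepts a prebuilt ``{name: spec}`` mapping so that the sort key function does not call
``graph.resolved_spec`` repeatedly (even with caching, this saves N lookups).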
"""
return sorted(layer, key=lambda n: -specs[n].priority)
@staticmethod
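def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""Return a reason string when a hard dependency is SKIPPED/FAILED, otherwise ``None``.
Soft dependencies do not affect this check; when a soft dependency is skipped, its default value is injected.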
"""
if report is None or spec.allow_upstream_skip:
return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
# ---------------------------------------------------------------------- #
# Task-level skip / retry / success handling: module-level functions
# ---------------------------------------------------------------------- #
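def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""Return a reason string when a hard dependency is SKIPPED/FAILED, otherwise ``None``.
Soft dependencies do not affect this check; when a soft dependency is skipped, its default value is injected.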
"""
if report is None or spec.allow_upstream_skip:
return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
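) -> TaskResult[Any] | None:
"""Pre-execution check: upstream skip / condition skip.
Returns a SKIPPED TaskResult, or ``None`` to continue with execution.
The condition check is delegated to :meth:`TaskSpec.should_execute` to avoid a duplicate implementation.
"""
# 1. Upstream was skipped/failed
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
# 2. Condition / skip_if_missing (single source: TaskSpec.should_execute)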
if skip_reason is None:
should_run, cond_reason = spec.should_execute(context)
if not should_run:
skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None
# Build the SKIPPED result
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
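) -> TaskResult[Any] | None:
"""Pre-execution check: upstream skip / condition skip.
Returns a SKIPPED TaskResult, or ``None`` to continue with execution.
The condition check is delegated to :meth:`TaskSpec.should_execute` to avoid a duplicate implementation.
"""
# 1. Upstream was skipped/failed
skip_reason = _upstream_skip_reason(spec, report)
# 2. Condition / skip_if_missing (single source: TaskSpec.should_execute)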
if skip_reason is None:
should_run, cond_reason = spec.should_execute(context)
if not should_run:
skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None
# Build the SKIPPED result
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""Return whether the task should keep retrying."""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""Mark the task as successful and fire the post_run hook."""
result.value = value
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""Mark the task as FAILED. Does not raise when ``continue_on_error`` is true."""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
return
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
class _TaskRetryMixin:
"""Shared logic for task-level retry decisions and failure/success post-processing."""
def _handle_failure(
spec: TaskSpec[Any],
result: TaskResult[Any],
exc: BaseException,
layer_idx: int | None,
on_event: EventCallback | None,
) -> bool:
"""Handle a failure in one place: timeout conversion, retry decision, finalize.
@staticmethod
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""Return whether the task should keep retrying."""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
@staticmethod
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""Mark the task as successful and fire the post_run hook."""
result.value = value
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
@staticmethod
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""Mark the task as FAILED. Does not raise when ``continue_on_error`` is true."""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
)
return
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
Returns
-------
bool
``True`` means the result has been finalized (no more retries); ``False`` means the caller should retry.
"""
# asyncio.TimeoutError → TaskTimeoutError (unified exception type)
if isinstance(exc, asyncio.TimeoutError):
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
)
@staticmethod
def _handle_failure(
spec: TaskSpec[Any],
result: TaskResult[Any],
exc: BaseException,
layer_idx: int | None,
on_event: EventCallback | None,
) -> bool:
"""Handle a failure in one place: timeout conversion, retry decision, finalize.
Returns
-------
bool
``True`` means the result has been finalized (no more retries); ``False`` means the caller should retry.
"""
# asyncio.TimeoutError → TaskTimeoutError (unified exception type)
if isinstance(exc, asyncio.TimeoutError):
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
)
else:
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
exc,
)
result.error = exc
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
else:
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
exc,
)
result.error = exc
if _should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
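The retry contract used by the runners (increment the attempt counter, run, and on failure either finalize or wait and loop) can be sketched independently of PyFlowX. Everything below is hypothetical: `SketchRetry` only imitates the `max_attempts` / `should_retry` / `wait_seconds` surface visible in this diff, and its exponential backoff formula is an assumption rather than the library's actual policy.

```python
from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class SketchRetry:
    """Hypothetical retry policy; field names follow the diff, semantics are assumed."""

    max_attempts: int = 3
    delay: float = 0.0
    backoff: float = 2.0
    retry_on: tuple[type[BaseException], ...] = (Exception,)

    def should_retry(self, exc: BaseException) -> bool:
        return isinstance(exc, self.retry_on)

    def wait_seconds(self, attempts: int) -> float:
        # Assumed exponential backoff: delay, delay*backoff, delay*backoff**2, ...
        return self.delay * (self.backoff ** (attempts - 1))


def run_with_retry(fn: Callable[[], Any], retry: SketchRetry) -> tuple[Any, int]:
    """Run fn until it succeeds or the policy refuses another attempt."""
    attempts = 0
    while True:
        attempts += 1
        try:
            return fn(), attempts
        except Exception as exc:
            # Same decision shape as _should_retry in the diff.
            if not (attempts < retry.max_attempts and retry.should_retry(exc)):
                raise
            wait = retry.wait_seconds(attempts)
            if wait > 0:
                time.sleep(wait)
```

Unlike the runners in the diff, this sketch re-raises the last exception instead of recording it on a result object, which keeps the example short.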
# ---------------------------------------------------------------------- #
# Task executors: sync / async (reuse _TaskSkipMixin + _TaskRetryMixin)
# Task executors: sync / async (call the module-level skip/retry functions)
# ---------------------------------------------------------------------- #
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
class SyncTaskRunner:
"""Synchronous task executor with retries and skip pre-checks."""
@staticmethod
@@ -297,7 +314,7 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
on_event: EventCallback | None = None,
report: RunReport | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
@@ -306,23 +323,24 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True:
result.attempts += 1
try:
with spec.env_context():
value = spec.effective_fn(*args, **kwargs)
_TaskRetryMixin._mark_success(spec, result, value)
_mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
if _handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
class AsyncTaskRunner:
"""Asynchronous task executor: runs sync or async tasks on the event loop, with retries and skip pre-checks."""
@staticmethod
@@ -334,7 +352,7 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
@@ -345,15 +363,16 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True:
result.attempts += 1
try:
value = await _execute_async_task(spec, args, kwargs, loop)
_TaskRetryMixin._mark_success(spec, result, value)
_mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
if _handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
@@ -388,81 +407,81 @@ async def _execute_async_task(
# ---------------------------------------------------------------------- #
# Mixin: shared layer-execution logic
# Shared helpers: cache filtering, priority sorting, semaphore construction, result storage
# ---------------------------------------------------------------------- #
class _LayerMixin:
"""Shared layer-execution logic: cache filtering, priority sorting, semaphore construction, result storage.
def _filter_and_sort(
layer: list[str],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
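) -> list[str]:
"""Filter out tasks that hit the cache and return the remaining tasks sorted by priority.
The four layer executors (sequential/threaded/async/dependency) compose this Mixin
to eliminate the "filter cache → sort → run → store result" boilerplate.
Prebuilds a ``{name: spec}`` mapping so that filtering and sorting share the same resolved specs,
avoiding repeated ``graph.resolved_spec`` calls inside ``_sort_by_priority``.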
"""
specs: dict[str, TaskSpec[Any]] = {}
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
specs[name] = spec
if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name)
return _sort_by_priority(to_run, specs)
@staticmethod
def _filter_and_sort(
layer: list[str],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> list[str]:
"""Filter out tasks that hit the cache and return the remaining tasks sorted by priority."""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name)
return _sort_by_priority(to_run, graph)
@staticmethod
def _store_result(
name: str,
result: TaskResult[Any],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
context_snapshot: Mapping[str, Any] | None = None,
) -> None:
"""Store the task result in context/report/backend and emit the event."""
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
def _store_result(
name: str,
result: TaskResult[Any],
spec: TaskSpec[Any],
task_ctx: dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> None:
"""Store the task result in context/report/backend and emit the event.
@staticmethod
def _build_semaphores(
to_run: list[str],
graph: Graph,
sem_factory: Callable[[int], Any],
concurrency_limits: Mapping[str, int],
) -> dict[str, Any]:
"""Create one semaphore per ``concurrency_key``."""
semaphores: dict[str, Any] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = sem_factory(limit)
return semaphores
``spec`` and ``task_ctx`` are already built by the caller before execution, so they are reused directly
to avoid repeated ``resolved_spec`` / ``_build_context`` calls.
"""
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
@staticmethod
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
"""Return the semaphore for the task (``None`` when it has no concurrency_key)."""
if spec.concurrency_key is None:
return None
return semaphores.get(spec.concurrency_key)
def _build_semaphores(
to_run: list[str],
graph: Graph,
sem_factory: Callable[[int], Any],
concurrency_limits: Mapping[str, int],
) -> dict[str, Any]:
"""Create one semaphore per ``concurrency_key``."""
semaphores: dict[str, Any] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = sem_factory(limit)
return semaphores
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
"""Return the semaphore for the task (``None`` when it has no concurrency_key)."""
if spec.concurrency_key is None:
return None
return semaphores.get(spec.concurrency_key)
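The per-key semaphore idea behind `_build_semaphores` and `_get_sem` can be demonstrated with plain threads. This is a minimal sketch under assumptions: `build_semaphores` and `peak_concurrency` are hypothetical helpers that take bare key strings instead of task specs, and the default limit of 1 for an unlisted key mirrors the `concurrency_limits.get(key, 1)` call shown above.

```python
from __future__ import annotations

import threading
import time
from typing import Any, Callable, Mapping, Optional, Sequence


def build_semaphores(
    keys: Sequence[Optional[str]],
    sem_factory: Callable[[int], Any],
    limits: Mapping[str, int],
) -> dict[str, Any]:
    """Create one semaphore per distinct key; unlisted keys default to a limit of 1."""
    semaphores: dict[str, Any] = {}
    for key in keys:
        if key is not None and key not in semaphores:
            semaphores[key] = sem_factory(limits.get(key, 1))
    return semaphores


def peak_concurrency(task_keys: Sequence[Optional[str]], limits: Mapping[str, int]) -> int:
    """Run one short task per key and report the highest number running at once."""
    semaphores = build_semaphores(task_keys, threading.Semaphore, limits)
    state_lock = threading.Lock()
    running = 0
    peak = 0

    def worker(key: Optional[str]) -> None:
        nonlocal running, peak
        sem = semaphores.get(key) if key is not None else None
        if sem is not None:
            sem.acquire()
        try:
            with state_lock:
                running += 1
                peak = max(peak, running)
            time.sleep(0.01)  # simulated work
            with state_lock:
                running -= 1
        finally:
            if sem is not None:
                sem.release()

    threads = [threading.Thread(target=worker, args=(k,)) for k in task_keys]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return peak
```

Passing the factory (`threading.Semaphore` here, `asyncio.Semaphore` in the async runners) is what lets one helper serve both execution models.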
# ---------------------------------------------------------------------- #
# Layer executors
# ---------------------------------------------------------------------- #
class SequentialLayerRunner(_LayerMixin):
class SequentialLayerRunner:
"""Run the tasks of a layer one by one (sorted by priority)."""
@staticmethod
@@ -475,14 +494,14 @@ class SequentialLayerRunner(_LayerMixin):
layer_idx: int,
on_event: EventCallback | None,
) -> None:
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
for name in _filter_and_sort(layer, graph, context, report, backend, on_event):
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
class ThreadedLayerRunner(_LayerMixin):
class ThreadedLayerRunner:
"""Run the tasks of a layer concurrently in a thread pool."""
@staticmethod
@@ -497,43 +516,43 @@ class ThreadedLayerRunner(_LayerMixin):
max_workers: int,
concurrency_limits: Mapping[str, int],
) -> None:
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
semaphores = _build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
context_snapshot = dict(context)
lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]:
def _run_threaded_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
sem = _get_sem(semaphores, spec)
if sem is not None:
sem.acquire()
try:
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
return task_ctx, SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
finally:
if sem is not None:
sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
future_to_name: dict[concurrent.futures.Future[tuple[dict[str, Any], TaskResult[Any]]], str] = {
pool.submit(_run_threaded_task, name): name for name in to_run
}
completed: dict[str, TaskResult[Any]] = {}
completed: dict[str, tuple[dict[str, Any], TaskResult[Any]]] = {}
try:
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
completed[name] = fut.result()
finally:
with lock:
for name, result in completed.items():
ThreadedLayerRunner._store_result(
name, result, graph, context, report, backend, on_event, context_snapshot
for name, (task_ctx, result) in completed.items():
_store_result(
name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event
)
class AsyncLayerRunner(_LayerMixin):
class AsyncLayerRunner:
"""Run the tasks of a layer concurrently on the event loop."""
@staticmethod
@@ -547,27 +566,32 @@ class AsyncLayerRunner(_LayerMixin):
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
semaphores = _build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
context_snapshot = dict(context)
async def _run_async_task(name: str) -> TaskResult[Any]:
async def _run_async_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = AsyncLayerRunner._get_sem(semaphores, spec)
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
sem = _get_sem(semaphores, spec)
result = await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
return task_ctx, result
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
for name, result in zip(to_run, results):
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
for name, (task_ctx, result) in zip(to_run, results):
_store_result(name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event)
class DependencyRunner(_LayerMixin):
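class DependencyRunner:
"""Dependency-driven scheduling: a task starts as soon as its hard/soft dependencies complete, with no layer barrier.
All tasks are scheduled concurrently through asyncio. Sync tasks are offloaded to a thread pool.
This class does not inherit the layer Mixin: dependency-driven scheduling is not a layer model, so it calls
the module-level shared helpers (:func:`_build_semaphores` / :func:`_get_sem` / :func:`_store_result`) directly,
which keeps responsibilities clearer.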
"""
@staticmethod
@@ -580,7 +604,7 @@ class DependencyRunner(_LayerMixin):
concurrency_limits: Mapping[str, int],
) -> None:
all_names = list(graph.all_specs().keys())
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
semaphores = _build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]:
@@ -598,9 +622,9 @@ class DependencyRunner(_LayerMixin):
if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name]
sem = DependencyRunner._get_sem(semaphores, spec)
sem = _get_sem(semaphores, spec)
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
return result
loop = asyncio.get_event_loop()
@@ -617,7 +641,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
def _verbose_callback(event: TaskEvent) -> None:
dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
if event.status == TaskStatus.RUNNING: # pragma: no cover
if event.status == TaskStatus.RUNNING:
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
elif event.status == TaskStatus.SUCCESS:
print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True)
@@ -677,33 +701,42 @@ def run(
TaskFailedError
Raised when any task still fails after exhausting its retries (unless ``continue_on_error=True``).
"""
graph.validate()
layers = graph.layers()
if dry_run:
layers = graph.layers()
_print_dry_run(graph, layers)
return RunReport(success=True)
# Validate once at the entry point, shared by all strategies, so that the layers() / dependency
# paths do not each call validate() again.
graph.validate()
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
backend = resolve_backend(state)
report = RunReport()
context: dict[str, Any] = {}
limits = concurrency_limits or {}
try:
if strategy == "sequential":
_drive_sequential(graph, layers, context, report, backend, effective_callback)
elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency":
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else:
raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError:
report.success = False
raise
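# backend.batch(): reduces one disk write per task to one per run (JSONBackend);
# it is a no-op for MemoryBackend. Even if TaskFailedError is raised midway, batch still
# flushes once on exit, preserving the results of already-successful tasks for resuming.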
with backend.batch():
try:
if strategy == "sequential":
layers = graph.layers()
_drive_sequential(graph, layers, context, report, backend, effective_callback)
elif strategy == "thread":
layers = graph.layers()
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
layers = graph.layers()
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency":
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else:
raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError:
report.success = False
raise
return report
+17 -103
@@ -82,6 +82,10 @@ class Graph:
# Pending string references (consumed by GraphComposer); empty means there are none.
_pending_refs: list[str] = field(default_factory=list)
# resolved_spec cache: avoids repeating the dataclasses.replace check several times per task during execution.
# Invalidated when specs / defaults change.
_resolved_cache: dict[str, TaskSpec[Any]] = field(default_factory=dict)
# ------------------------------------------------------------------ #
# Construction
# ------------------------------------------------------------------ #
@@ -97,6 +101,7 @@ class Graph:
self.specs[spec.name] = spec
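# Topological dependencies contain hard dependencies only; soft dependencies are used for injection and do not affect layering.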
self.deps[spec.name] = spec.depends_on
self._resolved_cache.clear()
@classmethod
def from_specs(
@@ -175,7 +180,12 @@ class Graph:
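For nullable fields such as ``retry``/``timeout``/``strategy``/``env``/``cwd``,
if the spec field holds its empty default and the graph-level default is non-empty,
:func:`dataclasses.replace` produces a copy that carries the default.
Results are cached by ``name``; the cache is invalidated when specs / defaults change.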
"""
cached = self._resolved_cache.get(name)
if cached is not None:
return cached
spec = self.specs[name]
d = self.defaults
overrides: dict[str, Any] = {}
@@ -199,9 +209,9 @@ class Graph:
overrides["verbose"] = True
if not spec.tags and d.tags:
overrides["tags"] = d.tags
if not overrides:
return spec
return replace(spec, **overrides)
resolved = spec if not overrides else replace(spec, **overrides)
self._resolved_cache[name] = resolved
return resolved
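The memoize-and-invalidate pattern added to `resolved_spec` can be reduced to a few lines. The sketch below is hypothetical (`MiniSpec` and `MiniGraph` are not PyFlowX types) and models a single graph-level default, `default_timeout`. As in the hunks shown here, the cache is cleared only in `add`; how a change to the defaults invalidates it is not visible in this diff, so the sketch does not model that.

```python
from __future__ import annotations

from dataclasses import dataclass, field, replace
from typing import Optional


@dataclass(frozen=True)
class MiniSpec:
    """Hypothetical task spec with one nullable field."""

    name: str
    timeout: Optional[float] = None


@dataclass
class MiniGraph:
    """Hypothetical graph holding specs plus one graph-level default."""

    default_timeout: Optional[float] = None
    specs: dict[str, MiniSpec] = field(default_factory=dict)
    _resolved_cache: dict[str, MiniSpec] = field(default_factory=dict)

    def add(self, spec: MiniSpec) -> None:
        self.specs[spec.name] = spec
        # Any change to the spec set invalidates every cached resolution.
        self._resolved_cache.clear()

    def resolved_spec(self, name: str) -> MiniSpec:
        cached = self._resolved_cache.get(name)
        if cached is not None:
            return cached
        spec = self.specs[name]
        if spec.timeout is None and self.default_timeout is not None:
            spec = replace(spec, timeout=self.default_timeout)
        self._resolved_cache[name] = spec
        return spec
```

Clearing the whole cache on every `add` is coarse but cheap, since graphs are typically built once and then resolved many times during a run.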
def dependencies(self, name: str) -> tuple[str, ...]:
"""Direct hard-dependency predecessors of ``name``."""
@@ -221,8 +231,11 @@ class Graph:
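Tasks in the same layer have no hard dependencies on each other and can run concurrently. Soft dependencies do not take part in layering.
Layers are returned in execution order. Raises :class:`CycleError` if the graph contains a cycle.
.. note::
This method assumes the graph has already been validated via :meth:`validate` (:func:`pyflowx.run`
does this once at its entry point). Callers that invoke this method directly must validate first.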
"""
self.validate()
sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = []
sorter.prepare()
@@ -355,102 +368,3 @@ class Graph:
def __contains__(self, name: Any) -> bool:
return name in self.specs
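class GraphComposer:
"""Expand a graph containing string references into a plain :class:`TaskSpec` graph.
Reference formats:
* ``"command_name"``: references an entire command graph.
* ``"command_name.task_name"``: references a specific task.
References are expanded in order; the tasks of a later reference depend on the last task of the preceding reference.
The original ``TaskSpec`` entries are also chained serially in order of appearance.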
def __init__(self, graphs: dict[str, Graph]) -> None:
self.graphs = graphs
def resolve_all(self) -> dict[str, Graph]:
"""Resolve the string references of all graphs and return a mapping of the expanded graphs."""
resolved: dict[str, Graph] = {}
for cmd_name, graph in self.graphs.items():
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
return resolved
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
"""Expand the string references in a graph. Returns the graph unchanged if it has no ``_pending_refs``."""
pending_refs = graph._pending_refs
if not pending_refs:
return graph
all_specs: list[TaskSpec[Any]] = []
previous_ref_last_task: str | None = None
for ref in pending_refs:
expanded_specs = self.parse_ref(ref, current_cmd)
if previous_ref_last_task and expanded_specs:
for i, task in enumerate(expanded_specs):
if i == 0 or not task.depends_on:
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
if expanded_specs:
previous_ref_last_task = expanded_specs[-1].name
all_specs.extend(expanded_specs)
original_specs = list(graph.all_specs().values())
if original_specs:
if previous_ref_last_task:
first = original_specs[0]
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
else:
all_specs.append(original_specs[0])
for i in range(1, len(original_specs)):
current_task = original_specs[i]
previous_task_name = original_specs[i - 1].name
all_specs.append(
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
)
return Graph.from_specs(all_specs, defaults=graph.defaults)
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
"""Parse a single string reference and return the corresponding list of TaskSpec objects."""
if ref == current_cmd:
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
if "." in ref:
cmd_name, task_name = ref.split(".", 1)
if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name]
if task_name not in ref_graph.all_specs():
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}'")
return [ref_graph.all_specs()[task_name]]
else:
cmd_name = ref
if cmd_name not in self.graphs:
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
ref_graph = self.graphs[cmd_name]
ref_graph = self.expand_refs(ref_graph, cmd_name)
return list(ref_graph.all_specs().values())
def compose(
graphs: dict[str, Graph],
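) -> dict[str, Graph]:
"""Programmatically resolve string references across multiple graphs and return a mapping of the expanded graphs.
Equivalent to :class:`GraphComposer`, but exposed as a standalone function for programmatic users
who do not use :class:`~pyflowx.runner.CliRunner`.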
Examples
--------
>>> graphs = {
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
... }
>>> resolved = px.compose(graphs)
>>> "b" in resolved["all"].all_specs()
True
"""
return GraphComposer(graphs).resolve_all()
+4 -4
@@ -15,11 +15,13 @@ import argparse
import enum
import sys
from dataclasses import dataclass, field, replace
from pathlib import Path
from typing import Any, Sequence, get_args
from .compose import GraphComposer
from .errors import PyFlowXError
from .executors import Strategy, run
from .graph import Graph, GraphComposer
from .graph import Graph
from .task import TaskSpec
__all__ = ["CliExitCode", "CliRunner"]
@@ -137,9 +139,7 @@ class CliRunner:
# ------------------------------------------------------------------ #
def _prog_name(self) -> str:
"""Derive the program name from sys.argv[0]."""
import os
return os.path.basename(sys.argv[0]) if sys.argv else "pyflowx"
return Path(sys.argv[0]).name if sys.argv else "pyflowx"
def create_parser(self) -> argparse.ArgumentParser:
"""Create the argument parser.
+37 -11
@@ -18,8 +18,9 @@ import sys
import time
from abc import ABC, abstractmethod
from collections.abc import Iterator
from contextlib import contextmanager, nullcontext
from pathlib import Path
from typing import Any, Mapping
from typing import Any, ContextManager, Mapping
if sys.version_info >= (3, 12):
from typing import override
@@ -55,6 +56,22 @@ class StateBackend(ABC):
def clear(self) -> None:
"""Clear all stored state."""
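def flush(self) -> None: # noqa: B027
"""Persist state held in memory to an external medium.
No-op by default (e.g. :class:`MemoryBackend` has nothing to write to disk).
:class:`JSONBackend` defers disk writes during :meth:`batch`, so a flush is needed on exit.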
"""
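def batch(self) -> ContextManager[None]:
"""Return a context manager within which :meth:`save` may defer :meth:`flush`.
The default implementation is a no-op (e.g. :class:`MemoryBackend`). :class:`JSONBackend`
overrides it to mark flushing as deferred on entry and flush once on exit. This reduces one full write per task
(N writes of the whole store, O(N²) serialized entries in total) to a single write per run (O(N)).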
"""
return nullcontext()
class _TTLStateBackendMixin(StateBackend):
"""Shared logic for TTL state backends.
@@ -158,13 +175,6 @@ class MemoryBackend(_TTLStateBackendMixin):
def _clear_raw(self) -> None:
self._store.clear()
def _expired(self, key: str) -> bool:
"""Return whether the key has expired (kept for the legacy test API)."""
entry = self._get_raw(key)
if entry is None:
return False
return self._is_expired(entry[1])
class JSONBackend(_TTLStateBackendMixin):
"""File-based JSON storage for resuming runs across processes.
@@ -184,6 +194,7 @@ class JSONBackend(_TTLStateBackendMixin):
self._path: str = path
self._ttl = ttl
self._store: dict[str, dict[str, Any]] = {}
self._defer_flush: bool = False
self._load()
def _load(self) -> None:
@@ -244,11 +255,26 @@ class JSONBackend(_TTLStateBackendMixin):
except (TypeError, ValueError) as exc:
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
super().save(key, value)
if not self._defer_flush:
self._flush()
@override
def flush(self) -> None:
self._flush()
def _expired(self, entry: Mapping[str, Any]) -> bool:
"""Return whether an entry with metadata has expired (kept for the legacy test API)."""
return self._is_expired(float(entry.get("ts", 0)))
@override
@contextmanager
def batch(self) -> Iterator[None]:
"""Enter batch mode: ``save`` does not write to disk, and a single flush happens on exit.
Reduces the N full-store writes for the N tasks of a run to one.
"""
self._defer_flush = True
try:
yield
finally:
self._defer_flush = False
self._flush()
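The effect of the deferred-flush flag can be measured with a small stand-alone store. This is a sketch, not `JSONBackend` itself: `SketchJSONStore` and `count_writes` are hypothetical, TTL handling is omitted, and the `writes` counter exists only to make the difference observable.

```python
from __future__ import annotations

import json
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator


class SketchJSONStore:
    """Hypothetical JSON-file store that rewrites the whole file on each flush."""

    def __init__(self, path: Path) -> None:
        self._path = path
        self._store: dict[str, Any] = {}
        self._defer_flush = False
        self.writes = 0  # number of full-file writes, for illustration only

    def _flush(self) -> None:
        self._path.write_text(json.dumps(self._store), encoding="utf-8")
        self.writes += 1

    def save(self, key: str, value: Any) -> None:
        self._store[key] = value
        if not self._defer_flush:
            self._flush()

    @contextmanager
    def batch(self) -> Iterator[None]:
        """Defer flushing until the block exits, even if it exits with an exception."""
        self._defer_flush = True
        try:
            yield
        finally:
            self._defer_flush = False
            self._flush()


def count_writes(n_tasks: int, batched: bool) -> int:
    """Save n_tasks results and return how many full-file writes occurred."""
    with tempfile.TemporaryDirectory() as tmp:
        store = SketchJSONStore(Path(tmp) / "state.json")
        if batched:
            with store.batch():
                for i in range(n_tasks):
                    store.save(f"task{i}", i)
        else:
            for i in range(n_tasks):
                store.save(f"task{i}", i)
        return store.writes
```

Because the flush sits in a `finally` clause, results saved before a failure still reach disk, which is the property the `run()` comment relies on for resuming. The trade-off is that a hard crash of the process inside the batch loses everything saved since the batch began.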
def resolve_backend(backend: StateBackend | None) -> StateBackend:
+68 -101
@@ -17,14 +17,16 @@
from __future__ import annotations
import logging
import os
import shutil
import subprocess
import sys
import threading
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from functools import cached_property
from pathlib import Path
from typing import (
Any,
@@ -67,6 +69,8 @@ TaskCmd = Union[
Strategy = Union[str, "StrategyKind"]
StrategyKind = Any # Placeholder to avoid a circular import; the executors module constrains it with Literal
logger = logging.getLogger(__name__)
# Condition function type: takes the dependency context (possibly an empty mapping) and returns whether the task should run.
Condition = Callable[[Context], bool]
@@ -291,13 +295,16 @@ class TaskSpec(Generic[T]):
if self.fn is None and self.cmd is None:
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
@property
@cached_property
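def effective_fn(self) -> TaskFn[T]:
"""Return the effective callable to execute.
If ``cmd`` is provided, returns the wrapped command-execution function; otherwise returns ``fn``.
The wrapper reads ``verbose``/``cwd``/``env``/
``timeout`` from ``self`` on every call, so the closure does not capture run-time parameters and flipping a field does not require rebuilding the spec.
The result is cached per instance (:func:`functools.cached_property`): the fields of a frozen dataclass
are immutable and the closure produced by ``_wrap_cmd`` is stable, so it need not be rebuilt on every access.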
"""
if self.cmd is not None:
return self._wrap_cmd()
@@ -306,11 +313,17 @@ class TaskSpec(Generic[T]):
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
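def _wrap_cmd(self) -> TaskFn[Any]:
"""Wrap cmd as an executable function."""
"""Wrap cmd as an executable function.
The actual execution logic lives in :mod:`pyflowx.command`, so that :class:`TaskSpec`,
a plain data structure, does not mix in command execution logic.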
"""
from .command import run_command
spec = self
def _run() -> T:
return cast(T, _run_command(spec))
return cast(T, run_command(spec))
_run.__name__ = spec.name
return _run # type: ignore[return-value]
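The switch from `@property` to `@cached_property` on a frozen dataclass can be checked with a small standalone example. This is a sketch under assumptions: `Spec` and `runner` are hypothetical names, and the class is assumed not to use `slots=True`, since `functools.cached_property` stores its value in the instance `__dict__`.

```python
from dataclasses import dataclass
from functools import cached_property
from typing import Callable


@dataclass(frozen=True)
class Spec:
    """Hypothetical minimal spec; it only illustrates the caching behaviour."""

    name: str

    @cached_property
    def runner(self) -> Callable[[], str]:
        # Built on first access, then served from the instance __dict__.
        # cached_property writes to __dict__ directly, so the frozen
        # dataclass __setattr__ guard is not triggered.
        spec = self

        def _run() -> str:
            return f"ran {spec.name}"

        _run.__name__ = spec.name
        return _run
```

Repeated access returns the same closure object, whereas a plain `@property` would build a new one each time.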
@@ -368,12 +381,27 @@ class TaskSpec(Generic[T]):
def storage_key(self, context: Context) -> str:
"""Compute the storage key for the state backend."""
if self.cache_key is not None:
try:
return f"{self.name}:{self.cache_key(context)}"
except Exception:
return self.name
return self.name
if self.cache_key is None:
return self.name
try:
return f"{self.name}:{self.cache_key(context)}"
except (TypeError, ValueError, KeyError, AttributeError) as exc:
# When cache_key raises an expected data/type exception, fall back to name but still log a warning
# so that users can spot bugs in their cache_key implementation.
logger.warning(
"task %r: cache_key 回退到 name(%s): %s",
self.name,
type(exc).__name__,
exc,
)
return self.name
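The effect of narrowing the `except` clause can be seen in a standalone version of the fallback. This is a minimal sketch, assuming a free function instead of the `TaskSpec` method; `storage_key` here takes `name` and `cache_key` as plain arguments, `raise_runtime` is a hypothetical helper, and the log message is written in English for the example.

```python
import logging
from typing import Any, Callable, Mapping, Optional

logger = logging.getLogger(__name__)


def storage_key(
    name: str,
    cache_key: Optional[Callable[[Mapping[str, Any]], Any]],
    context: Mapping[str, Any],
) -> str:
    """Return ``name`` or ``name:<cache_key>``, falling back on expected errors."""
    if cache_key is None:
        return name
    try:
        return f"{name}:{cache_key(context)}"
    except (TypeError, ValueError, KeyError, AttributeError) as exc:
        # Expected data/type errors: fall back, but leave a trace in the log.
        logger.warning(
            "task %r: cache_key fell back to name (%s): %s",
            name,
            type(exc).__name__,
            exc,
        )
        return name


def raise_runtime(_ctx: Mapping[str, Any]) -> str:
    """Hypothetical cache_key that fails with an exception outside the caught set."""
    raise RuntimeError("unexpected")
```

With this shape a `KeyError` from a missing context entry falls back to the task name, while a `RuntimeError` is no longer swallowed and reaches the caller.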
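# Global lock: serialises temporary changes to process-wide state (os.environ / os.chdir).
# When ``fn`` tasks run concurrently under the thread/async strategies and each configures a different
# ``cwd``/``env``, they overwrite one another (os.chdir and os.environ are both process-global).
# The lock wraps only the "switch → run → restore" window to guarantee correctness; tasks that use neither cwd nor env are unaffected.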
_env_cwd_lock = threading.RLock()
@contextmanager
@@ -381,100 +409,39 @@ def _env_and_cwd(
env: Mapping[str, str] | None,
cwd: Path | None,
) -> Generator[None, None, None]:
"""Temporarily set environment variables and the working directory."""
saved_env: dict[str, str] = {}
saved_cwd: str | None = None
if env:
for k, v in env.items():
if k in os.environ:
saved_env[k] = os.environ[k]
os.environ[k] = v
if cwd is not None:
saved_cwd = str(Path.cwd())
os.chdir(cwd)
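try:
"""Temporarily set environment variables and the working directory.
``os.environ`` and ``os.chdir`` are process-wide global state; under the thread/async strategies,
concurrently running ``fn`` tasks that set ``env``/``cwd`` would overwrite one another. This function
serialises the "switch → run → restore" window through the module-level :data:`_env_cwd_lock` to guarantee correctness.
With neither ``env`` nor ``cwd`` it yields immediately without acquiring the lock.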
"""
if not env and cwd is None:
yield
finally:
if saved_cwd is not None:
os.chdir(saved_cwd)
# Restore environment variables
return
with _env_cwd_lock:
saved_env: dict[str, str] = {}
saved_cwd: str | None = None
if env:
for k in env:
if k in saved_env:
os.environ[k] = saved_env[k]
else:
os.environ.pop(k, None)
def _run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
"""Run the command given by ``spec.cmd`` (list / shell string / callable)."""
cmd = spec.cmd
verbose = spec.verbose
cwd = spec.cwd
timeout = spec.timeout
env_override = spec.env
# Callable: call it directly and return its result.
if callable(cmd) and not isinstance(cmd, (list, str)):
name = getattr(cmd, "__name__", "callable")
if verbose:
print(f"[verbose] 执行可调用命令: {name}", flush=True)
if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True)
try:
return cmd()
except Exception as e:
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
is_list = isinstance(cmd, list)
if is_list:
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
verb = "执行命令"
label = "命令"
else:
cmd_str = cast(str, cmd)
verb = "执行 Shell"
label = "Shell 命令"
if verbose:
print(f"[verbose] {verb}: {cmd_str}", flush=True)
for k, v in env.items():
if k in os.environ:
saved_env[k] = os.environ[k]
os.environ[k] = v
if cwd is not None:
print(f"[verbose] 工作目录: {cwd}", flush=True)
# Merge environment variables
run_env: dict[str, str] | None = None
if env_override:
run_env = dict(os.environ)
run_env.update(env_override)
try:
result = subprocess.run(
cast(Union[str, List[str]], cmd),
shell=not is_list,
cwd=cwd,
env=run_env,
timeout=timeout,
capture_output=not verbose,
text=True,
check=False,
)
except FileNotFoundError:
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
except subprocess.TimeoutExpired:
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
except OSError as e:
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
if verbose:
print(f"[verbose] 返回码: {result.returncode}", flush=True)
if result.returncode == 0:
return None
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
if not verbose and result.stderr.strip():
err_msg += f"\n{result.stderr.strip()}"
raise RuntimeError(err_msg)
saved_cwd = str(Path.cwd())
os.chdir(cwd)
try:
yield
finally:
if saved_cwd is not None:
os.chdir(saved_cwd)
# Restore environment variables
if env:
for k in env:
if k in saved_env:
os.environ[k] = saved_env[k]
else:
os.environ.pop(k, None)
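The lock-guarded "switch → run → restore" window described in the docstring can be sketched as a self-contained context manager. This is an illustration under assumptions, not the PyFlowX code: `env_and_cwd` and `_lock` are hypothetical names, and saving the previous values with `os.environ.get` is a simplification chosen for the example.

```python
import os
import threading
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, Mapping, Optional

# Re-entrant so that a task already holding the lock can nest another switch.
_lock = threading.RLock()


@contextmanager
def env_and_cwd(
    env: Optional[Mapping[str, str]] = None,
    cwd: Optional[Path] = None,
) -> Iterator[None]:
    """Temporarily apply env vars and a working directory, then restore them."""
    if not env and cwd is None:
        # Nothing to change: skip the lock so unrelated tasks stay concurrent.
        yield
        return
    with _lock:
        # Remember the previous value of each key (None means "was unset").
        saved_env: Dict[str, Optional[str]] = {k: os.environ.get(k) for k in (env or {})}
        saved_cwd = Path.cwd() if cwd is not None else None
        try:
            os.environ.update(env or {})
            if cwd is not None:
                os.chdir(cwd)
            yield
        finally:
            if saved_cwd is not None:
                os.chdir(saved_cwd)
            for key, old in saved_env.items():
                if old is None:
                    os.environ.pop(key, None)
                else:
                    os.environ[key] = old
```

The trade-off is that tasks which do set `env` or `cwd` run one at a time while inside the block; only tasks that set neither keep full concurrency.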
# ---------------------------------------------------------------------- #
+2 -5
@@ -113,10 +113,7 @@ def write_file(path: str, content: str, encoding: str = "utf-8") -> px.TaskSpec:
"""Task that writes a file."""
def write():
try:
with open(path, "w", encoding=encoding) as f:
f.write(content)
except Exception as e:
print(f"写入文件 {path} 失败: {e}")
p = Path(path)
p.write_text(content, encoding=encoding)
return px.TaskSpec(f"write_file_{path}", fn=write, verbose=True)
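Dropping the `try`/`except` around the write changes what a caller observes on failure. The sketch below shows that difference with standard-library calls only; `write` and `error_for_missing_parent` are hypothetical helper names, not part of PyFlowX.

```python
import tempfile
from pathlib import Path


def write(path: str, content: str, encoding: str = "utf-8") -> None:
    """Write text with pathlib; any OSError propagates to the caller."""
    Path(path).write_text(content, encoding=encoding)


def error_for_missing_parent() -> str:
    """Attempt a write under a directory that does not exist and name the error."""
    with tempfile.TemporaryDirectory() as tmp:
        target = Path(tmp) / "no_such_dir" / "out.txt"
        try:
            write(str(target), "x")
        except FileNotFoundError as exc:
            return type(exc).__name__
    return "no error"
```

Because the exception is no longer printed and discarded, a runner that wraps the task can mark it as failed instead of reporting success for a file that was never written.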
-107
@@ -1,107 +0,0 @@
"""Common utility functions."""
from __future__ import annotations
__all__ = ["perf_timer"]
import functools
import logging
import time
from collections import defaultdict
from typing import Callable, TypedDict
try:
from typing_extensions import ParamSpec, TypeVar
except ImportError:
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class _PerformanceMetrics(TypedDict):
"""Performance metrics."""
count: int
total_time: float
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
lambda: _PerformanceMetrics(
count=0,
total_time=0.0,
)
)
def _generate_report(unit: str, precision: int) -> str:
"""Generate the performance metrics report and return it as a string."""
if not _perf_metrics:
return ""
lines: list[str] = []
lines.append("=" * 50)
lines.append("性能指标报告 (Performance Metrics Report)")
lines.append("-" * 50)
# Sort by total time so the slowest functions come first
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
for name, metrics in sorted_metrics:
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
lines.append(
f"{name}: "
f"调用次数={metrics['count']}, "
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
f"平均耗时={avg_time:.{precision}f}{unit}"
)
lines.append("=" * 50)
report_str = "\n".join(lines)
# Also emit the report to the log
logging.info("\n".join(lines))
return report_str
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
"""Performance timer decorator."""
scale: dict[str, float] = {
"s": 1.0,
"ms": 1000.0,
"us": 1000000.0,
}
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
_perf_metrics[func.__name__]["count"] += 1
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
if not report:
logging.info(
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
)
return result
return wrapper
if report:
import atexit
logging.basicConfig(level=logging.INFO)
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
@atexit.register
def _report_at_exit() -> None:
"""Report performance metrics at program exit."""
_generate_report(unit, precision)
# The report generation logic is extracted into a standalone function to ease testing
return decorator
+56
@@ -301,3 +301,59 @@ def test_dir_exists_false(tmp_path: Path):
missing = tmp_path / "nonexistent"
cond = BuiltinConditions.DIR_EXISTS(missing)
assert cond({}) is False
def test_builtin_is_windows_returns_module_condition():
"""BuiltinConditions.IS_WINDOWS() should return the module-level IS_WINDOWS."""
assert BuiltinConditions.IS_WINDOWS() is IS_WINDOWS
def test_builtin_is_linux_returns_module_condition():
"""BuiltinConditions.IS_LINUX() should return the module-level IS_LINUX."""
assert BuiltinConditions.IS_LINUX() is IS_LINUX
def test_builtin_is_macos_returns_module_condition():
"""BuiltinConditions.IS_MACOS() should return the module-level IS_MACOS."""
assert BuiltinConditions.IS_MACOS() is IS_MACOS
def test_builtin_is_posix_returns_module_condition():
"""BuiltinConditions.IS_POSIX() should return the module-level IS_POSIX."""
assert BuiltinConditions.IS_POSIX() is IS_POSIX
def test_file_content_exists_missing_file(tmp_path: Path):
"""FILE_CONTENT_EXISTS returns False when the file does not exist."""
cond = BuiltinConditions.FILE_CONTENT_EXISTS(tmp_path / "missing.txt", "x")
assert cond({}) is False
def test_file_content_exists_contains_content(tmp_path: Path):
"""FILE_CONTENT_EXISTS returns True when the file contains the content."""
f = tmp_path / "f.txt"
f.write_text("hello world", encoding="utf-8")
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "world")
assert cond({}) is True
def test_file_content_exists_not_contains_content(tmp_path: Path):
"""FILE_CONTENT_EXISTS returns False when the file does not contain the content."""
f = tmp_path / "f.txt"
f.write_text("hello", encoding="utf-8")
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "missing")
assert cond({}) is False
def test_file_content_exists_decode_error_returns_false(tmp_path: Path):
"""FILE_CONTENT_EXISTS should return False for a non-UTF-8 file (the decode error is swallowed)."""
f = tmp_path / "bin.dat"
f.write_bytes(b"\xff\xfe\x00bad")
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "x")
assert cond({}) is False
def test_dep_matches_missing_dep_returns_false():
"""DEP_MATCHES should return False when the dependency is missing (covers the ``if not in ctx`` branch)."""
cond = BuiltinConditions.DEP_MATCHES("missing", lambda _v: True)
assert cond({}) is False
+7 -3
@@ -99,7 +99,10 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
def test_verbose_run_with_user_callback():
"""Test px.run with verbose=True and user callback both called."""
"""Test px.run with verbose=True and user callback both called.
Expected event sequence: RUNNING (start) → SUCCESS (finish).
"""
events = []
def on_event(event: px.TaskEvent):
@@ -109,8 +112,9 @@ def test_verbose_run_with_user_callback():
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event)
assert report.success
assert len(events) == 1
assert events[0].status == px.TaskStatus.SUCCESS
assert len(events) == 2
assert events[0].status == px.TaskStatus.RUNNING
assert events[1].status == px.TaskStatus.SUCCESS
def test_verbose_event_callback_success():
+74 -1
@@ -5,8 +5,8 @@ from __future__ import annotations
import pytest
import pyflowx as px
from pyflowx.compose import GraphComposer, compose
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
from pyflowx.graph import GraphComposer, compose
def _fn() -> None:
@@ -319,6 +319,79 @@ def test_compose_function() -> None:
assert "a1" in resolved["cmd_b"]
def test_graph_composer_expand_refs_multiple_refs_chain() -> None:
"""expand_refs with multiple refs should chain dependencies: the first task of a later ref depends on the last task of the previous ref."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
graph_c = px.Graph.from_specs([px.TaskSpec("c1", _fn)])
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
graph_b._pending_refs = ["cmd_a", "cmd_c"]
composer = GraphComposer({"cmd_a": graph_a, "cmd_c": graph_c, "cmd_b": graph_b})
resolved = composer.resolve_all()
# c1 should depend on a1 (the first task of the later ref depends on the last task of the earlier ref)
assert "a1" in resolved["cmd_b"]
assert "c1" in resolved["cmd_b"]
assert "b1" in resolved["cmd_b"]
c1_spec = resolved["cmd_b"].all_specs()["c1"]
assert "a1" in c1_spec.depends_on
def test_graph_composer_expand_refs_ref_returns_empty() -> None:
"""When expand_refs references an empty graph, previous_ref_last_task stays None and original_specs takes the else branch."""
graph_empty = px.Graph.from_specs([])
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
graph_b._pending_refs = ["empty_cmd"]
composer = GraphComposer({"empty_cmd": graph_empty, "cmd_b": graph_b})
resolved = composer.resolve_all()
# b1 is kept, with no extra dependencies
assert "b1" in resolved["cmd_b"]
b1_spec = resolved["cmd_b"].all_specs()["b1"]
assert b1_spec.depends_on == ()
def test_graph_composer_expand_refs_multiple_original_specs_serialized() -> None:
"""expand_refs should serialise multiple original_specs, with the first depending on the last task of the ref."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
graph_b = px.Graph.from_specs([
px.TaskSpec("b1", _fn),
px.TaskSpec("b2", _fn),
px.TaskSpec("b3", _fn),
])
graph_b._pending_refs = ["cmd_a"]
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
resolved = composer.resolve_all()
specs = resolved["cmd_b"].all_specs()
# b1 depends on a1 (the last task of the ref)
assert "a1" in specs["b1"].depends_on
# b2 depends on b1, b3 depends on b2 (serial)
assert "b1" in specs["b2"].depends_on
assert "b2" in specs["b3"].depends_on
def test_graph_composer_parse_ref_dot_notation_success() -> None:
"""parse_ref in 'cmd.task' form should return the single matching TaskSpec."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn)])
composer = GraphComposer({"cmd_a": graph_a})
result = composer.parse_ref("cmd_a.a2", "cmd_b")
assert len(result) == 1
assert result[0].name == "a2"
def test_graph_composer_parse_ref_dot_notation_cmd_not_found() -> None:
"""parse_ref in 'missing.task' form should detect that the command does not exist."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
composer = GraphComposer({"cmd_a": graph_a})
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
_ = composer.parse_ref("missing.task", "cmd_b")
# ---------------------------------------------------------------------- #
# resolved_spec defaults tests
# ---------------------------------------------------------------------- #
+8 -8
@@ -70,9 +70,9 @@ def test_memory_backend_ttl_load_filters_expired() -> None:
def test_memory_backend_expired_key_not_in_store() -> None:
"""_expired returns False for a missing key."""
"""has returns False for a missing key."""
b = MemoryBackend(ttl=1.0)
assert b._expired("nonexistent") is False
assert b.has("nonexistent") is False
def test_memory_backend_no_ttl_never_expired() -> None:
@@ -244,35 +244,35 @@ def test_json_backend_ttl_load_filters_expired() -> None:
def test_json_backend_expired_no_ttl() -> None:
"""_expired returns False when there is no TTL."""
"""Entries never expire when there is no TTL."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path)
b.save("a", 1)
# Manually set ts far in the past
b._store["a"]["ts"] = time.time() - 1000
assert b._expired(b._store["a"]) is False # no TTL, never expires
assert b.has("a") is True # no TTL, never expires
def test_json_backend_expired_with_ttl() -> None:
"""_expired checks expiry when a TTL is set."""
"""has returns False for an expired key when a TTL is set."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b.save("a", 1)
# Manually set ts far in the past
b._store["a"]["ts"] = time.time() - 10 # 10 seconds ago, past the TTL
assert b._expired(b._store["a"]) is True
assert b.has("a") is False
def test_json_backend_expired_missing_ts() -> None:
"""A missing ts in the entry defaults to 0."""
"""An entry without ts is treated as expired."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b._store["a"] = {"value": 1} # ts is missing
# ts defaults to 0, which is long past
assert b._expired(b._store["a"]) is True
assert b.has("a") is False
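The migrated tests check expiry through the public `has` method instead of the private `_expired` helper. The following minimal sketch shows that shape on a toy in-memory store; `TTLStore` and the `ts` override on `save` are hypothetical and exist only so the example can simulate old entries without touching private state.

```python
import time
from typing import Any, Dict, Optional, Tuple


class TTLStore:
    """Hypothetical in-memory store with an optional time-to-live in seconds."""

    def __init__(self, ttl: Optional[float] = None) -> None:
        self._ttl = ttl
        self._store: Dict[str, Tuple[Any, float]] = {}

    def _is_expired(self, ts: float) -> bool:
        # Without a TTL nothing ever expires.
        return self._ttl is not None and (time.time() - ts) > self._ttl

    def save(self, key: str, value: Any, ts: Optional[float] = None) -> None:
        # ``ts`` is an illustration-only override for simulating old entries.
        self._store[key] = (value, time.time() if ts is None else ts)

    def has(self, key: str) -> bool:
        entry = self._store.get(key)
        if entry is None:
            return False
        return not self._is_expired(entry[1])
```

Asserting on `has` keeps the tests valid if the internal expiry helper is renamed or removed.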
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
+56 -1
@@ -2,11 +2,12 @@
import os
import subprocess
from pathlib import Path
import pytest
from pyflowx.conditions import Constants
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, which
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, setenv_group, which, write_file
def test_clr_creates_task_spec() -> None:
@@ -189,3 +190,57 @@ def test_which_not_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.Capture
spec.fn()
captured = capsys.readouterr()
assert "nonexistent_cmd -> 未找到" in captured.out
def test_write_file_creates_task_spec() -> None:
"""write_file() should create a TaskSpec with verbose enabled."""
spec = write_file("/tmp/unused", "x")
assert spec.name == "write_file_/tmp/unused"
assert spec.verbose is True
def test_write_file_writes_content(tmp_path: Path) -> None:
"""write_file() should write the content to the given file."""
f = tmp_path / "out.txt"
spec = write_file(str(f), "hello world")
assert spec.fn is not None
spec.fn()
assert f.read_text(encoding="utf-8") == "hello world"
def test_write_file_with_encoding(tmp_path: Path) -> None:
"""write_file() should support a custom encoding."""
f = tmp_path / "out.txt"
spec = write_file(str(f), "中文", encoding="utf-8")
assert spec.fn is not None
spec.fn()
assert f.read_text(encoding="utf-8") == "中文"
def test_write_file_failure_propagates(tmp_path: Path) -> None:
"""write_file() should raise on write failure (the exception is not swallowed)."""
# Writing when the parent directory does not exist should raise FileNotFoundError
missing = tmp_path / "no_such_dir" / "out.txt"
spec = write_file(str(missing), "x")
assert spec.fn is not None
with pytest.raises(FileNotFoundError):
spec.fn()
def test_setenv_group_creates_specs() -> None:
"""setenv_group() should create a TaskSpec for each environment variable."""
envs = {"VAR_A": "1", "VAR_B": "2"}
specs = setenv_group(envs)
assert len(specs) == 2
assert specs[0].name == "setenv_var_a"
assert specs[1].name == "setenv_var_b"
def test_setenv_group_default_mode(monkeypatch: pytest.MonkeyPatch) -> None:
"""setenv_group(default=True) should not overwrite an existing environment variable."""
monkeypatch.setenv("PYFLOWX_GROUP_EXISTS", "original")
specs = setenv_group({"PYFLOWX_GROUP_EXISTS": "new"}, default=True)
for spec in specs:
assert spec.fn is not None
spec.fn()
assert os.environ["PYFLOWX_GROUP_EXISTS"] == "original"
+9 -9
@@ -203,10 +203,10 @@ def test_is_cmd_available_callable_returns_true() -> None:
# storage_key exception handling
# ---------------------------------------------------------------------- #
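def test_storage_key_cache_key_exception_returns_name() -> None:
"""storage_key should return the task name when cache_key raises."""
"""storage_key should return the task name when cache_key raises an expected exception (TypeError/ValueError/KeyError/AttributeError)."""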
def bad_cache_key(_ctx):
raise RuntimeError("cache key error")
raise ValueError("cache key error")
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
key = spec.storage_key({})
@@ -345,14 +345,14 @@ def test_task_result_default_status() -> None:
# ---------------------------------------------------------------------- #
# _run_command callable command tests
# run_command callable command tests
# ---------------------------------------------------------------------- #
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
"""A callable command in verbose mode should print information."""
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
import pyflowx.task as task_module
from pyflowx.command import run_command
result = task_module._run_command(spec)
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
result = run_command(spec)
assert result == "result"
captured = capsys.readouterr()
assert "执行可调用命令" in captured.out
@@ -361,8 +361,8 @@ def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str
def test_run_command_callable_exception() -> None:
"""An exception raised by a callable command should be converted to RuntimeError."""
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
import pyflowx.task as task_module
from pyflowx.command import run_command
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
task_module._run_command(spec)
run_command(spec)
-65
@@ -1,65 +0,0 @@
import time
import pytest
from pytest_mock import MockerFixture
from pyflowx.utils import _perf_metrics, perf_timer
@pytest.fixture(autouse=True)
def reset_perf_metrics():
"""Reset the performance metrics."""
_perf_metrics.clear()
class TestPerformanceTimer:
def test_perf_timer(self):
@perf_timer()
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
def test_perf_timer_report(self, mocker: MockerFixture):
mock_log = mocker.patch("logging.info")
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert mock_log.call_count == 1
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
mock_log = mocker.patch("logging.info")
from pyflowx.utils import _generate_report
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
@perf_timer(report=True, unit="ms", precision=3)
def test_func2():
time.sleep(0.2)
test_func()
test_func2()
_generate_report("ms", 3)
assert mock_log.call_count == 3
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert _perf_metrics["test_func2"]["count"] == 1
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
Generated
+1 -1
@@ -5603,7 +5603,7 @@ pycountry = [
[[package]]
name = "pyflowx"
version = "0.2.10"
version = "0.2.11"
source = { editable = "." }
dependencies = [
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },