refactor(graph,runner,test): 重构代码并清理冗余逻辑

1. 将Graph类改为frozen dataclass简化实现
2. 移除executors.py中的内置策略校验逻辑
3. 使用typing.get_args替代直接访问Strategy.__args__
4. 清理测试文件中冗余的无效参数测试用例
5. 将测试中未使用的 px.run 调用返回值统一赋值给 `_`
6. 在pyproject.toml中添加pytest slow标记配置
This commit is contained in:
2026-06-21 14:11:57 +08:00
parent 58bafd48cc
commit febcd90a31
6 changed files with 47 additions and 74 deletions
+1
View File
@@ -78,6 +78,7 @@ show_missing = true
[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
[tool.basedpyright]
exclude = ["**/.git", "**/.venv", "**/__pycache__", "**/build", "**/dist"]
-5
View File
@@ -389,11 +389,6 @@ def run(
graph.validate()
layers = graph.layers()
# 验证策略是否有效
valid_strategies = ("sequential", "thread", "async")
if strategy not in valid_strategies:
raise ValueError(f"unknown strategy: {strategy}. Valid: {valid_strategies}")
if dry_run:
_print_dry_run(graph, layers)
return RunReport(success=True)
+30 -31
View File
@@ -1,13 +1,13 @@
"""DAG 构建、校验、分层与可视化。
使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8)
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非
执行时)快速失败。
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from typing import Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
@@ -24,6 +24,7 @@ else: # pragma: no cover
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
@dataclass(frozen=True)
class Graph:
"""校验后不可变的有向无环任务图。
@@ -35,10 +36,8 @@ class Graph:
这使图可安全重复运行并在线程间共享。
"""
def __init__(self) -> None:
self._specs: dict[str, TaskSpec[object]] = {}
# 任务 -> 其直接依赖(前驱)。
self._deps: dict[str, tuple[str, ...]] = {}
specs: dict[str, TaskSpec[object]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
# ------------------------------------------------------------------ #
# 构建
@@ -49,10 +48,10 @@ class Graph:
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`,
它会整批校验(允许单次调用中的前向引用)。
"""
if spec.name in self._specs:
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self._specs[spec.name] = spec
self._deps[spec.name] = spec.depends_on
self.specs[spec.name] = spec
self.deps[spec.name] = spec.depends_on
# 为增量 API 即时检查重名与缺失依赖。
self._validate_references()
return self
@@ -66,10 +65,10 @@ class Graph:
"""
graph = cls()
for spec in specs:
if spec.name in graph._specs:
if spec.name in graph.specs:
raise DuplicateTaskError(spec.name)
graph._specs[spec.name] = spec
graph._deps[spec.name] = spec.depends_on
graph.specs[spec.name] = spec
graph.deps[spec.name] = spec.depends_on
graph._validate_references()
graph.validate()
return graph
@@ -79,9 +78,9 @@ class Graph:
# ------------------------------------------------------------------ #
def _validate_references(self) -> None:
"""确保每个依赖名都存在于图中。"""
for name, deps in self._deps.items():
for name, deps in self.deps.items():
for dep in deps:
if dep not in self._specs:
if dep not in self.specs:
raise MissingDependencyError(name, dep)
def validate(self) -> None:
@@ -91,7 +90,7 @@ class Graph:
依赖存在性由 :meth:`_validate_references` 检查。
"""
self._validate_references()
sorter = _TopologicalSorter(self._deps)
sorter = _TopologicalSorter(self.deps)
try:
# prepare() 在有环时抛出 CycleError;此处不需要
# static_order() 的结果,仅利用其校验副作用。
@@ -107,19 +106,19 @@ class Graph:
@property
def names(self) -> list[str]:
"""所有已注册任务名(按插入顺序)。"""
return list(self._specs.keys())
return list(self.specs.keys())
def spec(self, name: str) -> TaskSpec[object]:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self._specs[name]
return self.specs[name]
def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。"""
return self._deps[name]
return self.deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
"""name -> spec 的只读视图。"""
return self._specs
return self.specs
def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。
@@ -129,7 +128,7 @@ class Graph:
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
"""
self.validate()
sorter = _TopologicalSorter(self._deps)
sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare()
@@ -154,12 +153,12 @@ class Graph:
"""
wanted: set[str] = set(tags)
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
for spec in self.specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d
for d in spec.depends_on
if d in self._specs and (wanted & set(self._specs[d].tags))
if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(
TaskSpec(
@@ -182,14 +181,14 @@ class Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: set[str] = set(names)
for n in wanted:
if n not in self._specs:
if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
for spec in self.specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(
TaskSpec(
TaskSpec[object](
name=spec.name,
fn=spec.fn,
cmd=spec.cmd,
@@ -221,9 +220,9 @@ class Graph:
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
)
lines: list[str] = [f"graph {orientation}"]
for name in self._specs:
for name in self.specs:
lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items():
for name, deps in self.deps.items():
for dep in deps:
lines.append(f" {dep} --> {name}")
return "\n".join(lines) + "\n"
@@ -233,16 +232,16 @@ class Graph:
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""用于调试的人类可读多行摘要。"""
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
out: list[str] = [f"Graph(tasks={len(self.specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
def __repr__(self) -> str:
return f"Graph(tasks={len(self._specs)})"
return f"Graph(tasks={len(self.specs)})"
def __len__(self) -> int:
return len(self._specs)
return len(self.specs)
def __contains__(self, name: object) -> bool:
return name in self._specs
return name in self.specs
+3 -3
View File
@@ -15,7 +15,7 @@ import argparse
import enum
import sys
from dataclasses import dataclass, field, replace
from typing import Sequence
from typing import Sequence, get_args
from .errors import PyFlowXError
from .executors import Strategy, run
@@ -60,7 +60,7 @@ def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
return Graph.from_specs(new_specs)
@dataclass
@dataclass(frozen=True)
class CliRunner:
"""命令行运行器: 根据用户输入执行对应的任务流图.
@@ -162,7 +162,7 @@ class CliRunner:
)
_ = parser.add_argument(
"--strategy",
choices=list(Strategy.__args__),
choices=list(get_args(Strategy)),
default="sequential",
help="执行策略 (默认: %(default)s)",
)
+12 -21
View File
@@ -76,7 +76,7 @@ def test_failure_propagates() -> None:
]
)
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential")
_ = px.run(graph, strategy="sequential")
assert exc_info.value.task == "boom"
assert isinstance(exc_info.value.cause, ValueError)
@@ -103,7 +103,7 @@ def test_retries_exhausted() -> None:
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential")
_ = px.run(graph, strategy="sequential")
assert exc_info.value.attempts == 3
@@ -226,7 +226,7 @@ def test_async_timeout() -> None:
graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)])
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="async")
_ = px.run(graph, strategy="async")
assert isinstance(exc_info.value.cause, TaskTimeoutError)
@@ -269,11 +269,11 @@ def test_memory_backend_resume() -> None:
]
)
backend = MemoryBackend()
px.run(graph, strategy="sequential", state=backend)
_ = px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"]
# Second run: both cached, neither re-executed.
px.run(graph, strategy="sequential", state=backend)
_ = px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"] # unchanged
@@ -285,7 +285,7 @@ def test_json_backend_persistence() -> None:
return 7
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", state=JSONBackend(path))
_ = px.run(graph, strategy="sequential", state=JSONBackend(path))
# New backend reads the file; task should be skipped.
runs: list[str] = []
@@ -310,21 +310,12 @@ def test_on_event_callback() -> None:
return 1
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", on_event=events.append)
_ = px.run(graph, strategy="sequential", on_event=events.append)
statuses = [e.status for e in events]
assert px.TaskStatus.SUCCESS in statuses
assert all(e.task == "a" for e in events)
# ---------------------------------------------------------------------- #
# Invalid strategy
# ---------------------------------------------------------------------- #
def test_invalid_strategy() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
with pytest.raises(ValueError):
px.run(graph, strategy="bogus") # type: ignore[arg-type]
# ---------------------------------------------------------------------- #
# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支
# ---------------------------------------------------------------------- #
@@ -410,10 +401,10 @@ def test_threaded_skips_cached_tasks() -> None:
)
backend = px.MemoryBackend()
# 第一次运行填充缓存
px.run(graph, strategy="thread", max_workers=2, state=backend)
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
assert runs == ["a", "b"]
# 第二次运行应全部跳过
px.run(graph, strategy="thread", max_workers=2, state=backend)
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
assert runs == ["a", "b"] # 未再执行
@@ -454,9 +445,9 @@ def test_async_skips_cached_tasks() -> None:
]
)
backend = px.MemoryBackend()
px.run(graph, strategy="async", state=backend)
_ = px.run(graph, strategy="async", state=backend)
assert runs == ["a", "b"]
px.run(graph, strategy="async", state=backend)
_ = px.run(graph, strategy="async", state=backend)
assert runs == ["a", "b"]
@@ -483,7 +474,7 @@ def test_failure_marks_report_unsuccessful() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", boom)])
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential")
_ = px.run(graph, strategy="sequential")
# report 在异常前未返回,但若捕获异常则 success 应为 False
# 这里验证 run() 抛异常的行为本身
+1 -14
View File
@@ -76,11 +76,6 @@ class TestCliRunnerConstruction:
)
assert runner.commands == ["clean", "build", "test"]
def test_rejects_non_graph_list(self) -> None:
"""列表类型的值应抛出 TypeError."""
with pytest.raises(TypeError, match="必须是 Graph 实例"):
_ = px.CliRunner(graphs={"build": [1, 2, 3]}) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def test_default_strategy_is_sequential(self) -> None:
"""默认策略应为 Strategy.SEQUENTIAL."""
runner = px.CliRunner({"clean": _echo_graph()})
@@ -103,8 +98,7 @@ class TestCliRunnerConstruction:
def test_custom_verbose_false(self) -> None:
"""应支持关闭 verbose."""
runner = px.CliRunner({"clean": _echo_graph()})
runner.verbose = False
runner = px.CliRunner({"clean": _echo_graph()}, verbose=False)
assert runner.verbose is False
def test_default_description_is_empty(self) -> None:
@@ -178,13 +172,6 @@ class TestCliRunnerParser:
parsed = parser.parse_args(["clean"])
assert parsed.strategy == "sequential"
def test_parser_strategy_invalid_choice(self) -> None:
"""--strategy 不接受非法值."""
runner = px.CliRunner({"clean": _echo_graph()}, "invalid") # pyright: ignore[reportArgumentType]
parser = runner.create_parser()
with pytest.raises(SystemExit):
_ = parser.parse_args(["clean", "--strategy", "invalid"])
def test_parser_has_dry_run_flag(self) -> None:
"""解析器应有 --dry-run 标志."""
runner = px.CliRunner({"clean": _echo_graph()})