refactor(graph,runner,test): 重构代码并清理冗余逻辑

1. 将Graph类改为frozen dataclass简化实现
2. 移除executors.py中的内置策略校验逻辑
3. 使用typing.get_args替代直接访问Strategy.__args__
4. 清理测试文件中冗余的无效参数测试用例
5. 统一替换测试中未使用的px.run调用返回值
6. 在pyproject.toml中添加pytest slow标记配置
This commit is contained in:
2026-06-21 14:11:57 +08:00
parent 58bafd48cc
commit febcd90a31
6 changed files with 47 additions and 74 deletions
+1
View File
@@ -78,6 +78,7 @@ show_missing = true
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function" asyncio_default_fixture_loop_scope = "function"
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
[tool.basedpyright] [tool.basedpyright]
exclude = ["**/.git", "**/.venv", "**/__pycache__", "**/build", "**/dist"] exclude = ["**/.git", "**/.venv", "**/__pycache__", "**/build", "**/dist"]
-5
View File
@@ -389,11 +389,6 @@ def run(
graph.validate() graph.validate()
layers = graph.layers() layers = graph.layers()
# 验证策略是否有效
valid_strategies = ("sequential", "thread", "async")
if strategy not in valid_strategies:
raise ValueError(f"unknown strategy: {strategy}. Valid: {valid_strategies}")
if dry_run: if dry_run:
_print_dry_run(graph, layers) _print_dry_run(graph, layers)
return RunReport(success=True) return RunReport(success=True)
+30 -31
View File
@@ -1,13 +1,13 @@
"""DAG 构建、校验、分层与可视化。 """DAG 构建、校验、分层与可视化。
使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8 使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非 进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
执行时)快速失败。
""" """
from __future__ import annotations from __future__ import annotations
import sys import sys
from dataclasses import dataclass, field
from typing import Iterable, Mapping, Sequence from typing import Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError from .errors import CycleError, DuplicateTaskError, MissingDependencyError
@@ -24,6 +24,7 @@ else: # pragma: no cover
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover _TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
@dataclass(frozen=True)
class Graph: class Graph:
"""校验后不可变的有向无环任务图。 """校验后不可变的有向无环任务图。
@@ -35,10 +36,8 @@ class Graph:
这使图可安全重复运行并在线程间共享。 这使图可安全重复运行并在线程间共享。
""" """
def __init__(self) -> None: specs: dict[str, TaskSpec[object]] = field(default_factory=dict)
self._specs: dict[str, TaskSpec[object]] = {} deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
# 任务 -> 其直接依赖(前驱)。
self._deps: dict[str, tuple[str, ...]] = {}
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# 构建 # 构建
@@ -49,10 +48,10 @@ class Graph:
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs` 返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
它会整批校验(允许单次调用中的前向引用)。 它会整批校验(允许单次调用中的前向引用)。
""" """
if spec.name in self._specs: if spec.name in self.specs:
raise DuplicateTaskError(spec.name) raise DuplicateTaskError(spec.name)
self._specs[spec.name] = spec self.specs[spec.name] = spec
self._deps[spec.name] = spec.depends_on self.deps[spec.name] = spec.depends_on
# 为增量 API 即时检查重名与缺失依赖。 # 为增量 API 即时检查重名与缺失依赖。
self._validate_references() self._validate_references()
return self return self
@@ -66,10 +65,10 @@ class Graph:
""" """
graph = cls() graph = cls()
for spec in specs: for spec in specs:
if spec.name in graph._specs: if spec.name in graph.specs:
raise DuplicateTaskError(spec.name) raise DuplicateTaskError(spec.name)
graph._specs[spec.name] = spec graph.specs[spec.name] = spec
graph._deps[spec.name] = spec.depends_on graph.deps[spec.name] = spec.depends_on
graph._validate_references() graph._validate_references()
graph.validate() graph.validate()
return graph return graph
@@ -79,9 +78,9 @@ class Graph:
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def _validate_references(self) -> None: def _validate_references(self) -> None:
"""确保每个依赖名都存在于图中。""" """确保每个依赖名都存在于图中。"""
for name, deps in self._deps.items(): for name, deps in self.deps.items():
for dep in deps: for dep in deps:
if dep not in self._specs: if dep not in self.specs:
raise MissingDependencyError(name, dep) raise MissingDependencyError(name, dep)
def validate(self) -> None: def validate(self) -> None:
@@ -91,7 +90,7 @@ class Graph:
依赖存在性由 :meth:`_validate_references` 检查。 依赖存在性由 :meth:`_validate_references` 检查。
""" """
self._validate_references() self._validate_references()
sorter = _TopologicalSorter(self._deps) sorter = _TopologicalSorter(self.deps)
try: try:
# prepare() 在有环时抛出 CycleError;此处不需要 # prepare() 在有环时抛出 CycleError;此处不需要
# static_order() 的结果,仅利用其校验副作用。 # static_order() 的结果,仅利用其校验副作用。
@@ -107,19 +106,19 @@ class Graph:
@property @property
def names(self) -> list[str]: def names(self) -> list[str]:
"""所有已注册任务名(按插入顺序)。""" """所有已注册任务名(按插入顺序)。"""
return list(self._specs.keys()) return list(self.specs.keys())
def spec(self, name: str) -> TaskSpec[object]: def spec(self, name: str) -> TaskSpec[object]:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。""" """返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self._specs[name] return self.specs[name]
def dependencies(self, name: str) -> tuple[str, ...]: def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。""" """``name`` 的直接前驱。"""
return self._deps[name] return self.deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[object]]: def all_specs(self) -> Mapping[str, TaskSpec[object]]:
"""name -> spec 的只读视图。""" """name -> spec 的只读视图。"""
return self._specs return self.specs
def layers(self) -> list[list[str]]: def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。 """将任务分组为可并行执行的层(Kahn 算法)。
@@ -129,7 +128,7 @@ class Graph:
图有环时抛出 :class:`~pyflowx.errors.CycleError`。 图有环时抛出 :class:`~pyflowx.errors.CycleError`。
""" """
self.validate() self.validate()
sorter = _TopologicalSorter(self._deps) sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = [] result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。 # ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare() sorter.prepare()
@@ -154,12 +153,12 @@ class Graph:
""" """
wanted: set[str] = set(tags) wanted: set[str] = set(tags)
kept: list[TaskSpec[object]] = [] kept: list[TaskSpec[object]] = []
for spec in self._specs.values(): for spec in self.specs.values():
if wanted & set(spec.tags): if wanted & set(spec.tags):
pruned_deps = tuple( pruned_deps = tuple(
d d
for d in spec.depends_on for d in spec.depends_on
if d in self._specs and (wanted & set(self._specs[d].tags)) if d in self.specs and (wanted & set(self.specs[d].tags))
) )
kept.append( kept.append(
TaskSpec( TaskSpec(
@@ -182,14 +181,14 @@ class Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。""" """返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: set[str] = set(names) wanted: set[str] = set(names)
for n in wanted: for n in wanted:
if n not in self._specs: if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}") raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[object]] = [] kept: list[TaskSpec[object]] = []
for spec in self._specs.values(): for spec in self.specs.values():
if spec.name in wanted: if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted) pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append( kept.append(
TaskSpec( TaskSpec[object](
name=spec.name, name=spec.name,
fn=spec.fn, fn=spec.fn,
cmd=spec.cmd, cmd=spec.cmd,
@@ -221,9 +220,9 @@ class Graph:
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}." f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
) )
lines: list[str] = [f"graph {orientation}"] lines: list[str] = [f"graph {orientation}"]
for name in self._specs: for name in self.specs:
lines.append(f' {name}["{name}"]') lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items(): for name, deps in self.deps.items():
for dep in deps: for dep in deps:
lines.append(f" {dep} --> {name}") lines.append(f" {dep} --> {name}")
return "\n".join(lines) + "\n" return "\n".join(lines) + "\n"
@@ -233,16 +232,16 @@ class Graph:
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
def describe(self) -> str: def describe(self) -> str:
"""用于调试的人类可读多行摘要。""" """用于调试的人类可读多行摘要。"""
out: list[str] = [f"Graph(tasks={len(self._specs)})"] out: list[str] = [f"Graph(tasks={len(self.specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1): for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}") out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out) return "\n".join(out)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"Graph(tasks={len(self._specs)})" return f"Graph(tasks={len(self.specs)})"
def __len__(self) -> int: def __len__(self) -> int:
return len(self._specs) return len(self.specs)
def __contains__(self, name: object) -> bool: def __contains__(self, name: object) -> bool:
return name in self._specs return name in self.specs
+3 -3
View File
@@ -15,7 +15,7 @@ import argparse
import enum import enum
import sys import sys
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from typing import Sequence from typing import Sequence, get_args
from .errors import PyFlowXError from .errors import PyFlowXError
from .executors import Strategy, run from .executors import Strategy, run
@@ -60,7 +60,7 @@ def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
return Graph.from_specs(new_specs) return Graph.from_specs(new_specs)
@dataclass @dataclass(frozen=True)
class CliRunner: class CliRunner:
"""命令行运行器: 根据用户输入执行对应的任务流图. """命令行运行器: 根据用户输入执行对应的任务流图.
@@ -162,7 +162,7 @@ class CliRunner:
) )
_ = parser.add_argument( _ = parser.add_argument(
"--strategy", "--strategy",
choices=list(Strategy.__args__), choices=list(get_args(Strategy)),
default="sequential", default="sequential",
help="执行策略 (默认: %(default)s)", help="执行策略 (默认: %(default)s)",
) )
+12 -21
View File
@@ -76,7 +76,7 @@ def test_failure_propagates() -> None:
] ]
) )
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
assert exc_info.value.task == "boom" assert exc_info.value.task == "boom"
assert isinstance(exc_info.value.cause, ValueError) assert isinstance(exc_info.value.cause, ValueError)
@@ -103,7 +103,7 @@ def test_retries_exhausted() -> None:
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)]) graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
assert exc_info.value.attempts == 3 assert exc_info.value.attempts == 3
@@ -226,7 +226,7 @@ def test_async_timeout() -> None:
graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)]) graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)])
with pytest.raises(TaskFailedError) as exc_info: with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="async") _ = px.run(graph, strategy="async")
assert isinstance(exc_info.value.cause, TaskTimeoutError) assert isinstance(exc_info.value.cause, TaskTimeoutError)
@@ -269,11 +269,11 @@ def test_memory_backend_resume() -> None:
] ]
) )
backend = MemoryBackend() backend = MemoryBackend()
px.run(graph, strategy="sequential", state=backend) _ = px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"] assert runs == ["a", "b"]
# Second run: both cached, neither re-executed. # Second run: both cached, neither re-executed.
px.run(graph, strategy="sequential", state=backend) _ = px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"] # unchanged assert runs == ["a", "b"] # unchanged
@@ -285,7 +285,7 @@ def test_json_backend_persistence() -> None:
return 7 return 7
graph = px.Graph.from_specs([px.TaskSpec("a", fn)]) graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", state=JSONBackend(path)) _ = px.run(graph, strategy="sequential", state=JSONBackend(path))
# New backend reads the file; task should be skipped. # New backend reads the file; task should be skipped.
runs: list[str] = [] runs: list[str] = []
@@ -310,21 +310,12 @@ def test_on_event_callback() -> None:
return 1 return 1
graph = px.Graph.from_specs([px.TaskSpec("a", fn)]) graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", on_event=events.append) _ = px.run(graph, strategy="sequential", on_event=events.append)
statuses = [e.status for e in events] statuses = [e.status for e in events]
assert px.TaskStatus.SUCCESS in statuses assert px.TaskStatus.SUCCESS in statuses
assert all(e.task == "a" for e in events) assert all(e.task == "a" for e in events)
# ---------------------------------------------------------------------- #
# Invalid strategy
# ---------------------------------------------------------------------- #
def test_invalid_strategy() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
with pytest.raises(ValueError):
px.run(graph, strategy="bogus") # type: ignore[arg-type]
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支 # 异步策略:sync 任务无 timeout 分支 + timeout 重试分支
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -410,10 +401,10 @@ def test_threaded_skips_cached_tasks() -> None:
) )
backend = px.MemoryBackend() backend = px.MemoryBackend()
# 第一次运行填充缓存 # 第一次运行填充缓存
px.run(graph, strategy="thread", max_workers=2, state=backend) _ = px.run(graph, strategy="thread", max_workers=2, state=backend)
assert runs == ["a", "b"] assert runs == ["a", "b"]
# 第二次运行应全部跳过 # 第二次运行应全部跳过
px.run(graph, strategy="thread", max_workers=2, state=backend) _ = px.run(graph, strategy="thread", max_workers=2, state=backend)
assert runs == ["a", "b"] # 未再执行 assert runs == ["a", "b"] # 未再执行
@@ -454,9 +445,9 @@ def test_async_skips_cached_tasks() -> None:
] ]
) )
backend = px.MemoryBackend() backend = px.MemoryBackend()
px.run(graph, strategy="async", state=backend) _ = px.run(graph, strategy="async", state=backend)
assert runs == ["a", "b"] assert runs == ["a", "b"]
px.run(graph, strategy="async", state=backend) _ = px.run(graph, strategy="async", state=backend)
assert runs == ["a", "b"] assert runs == ["a", "b"]
@@ -483,7 +474,7 @@ def test_failure_marks_report_unsuccessful() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", boom)]) graph = px.Graph.from_specs([px.TaskSpec("a", boom)])
with pytest.raises(px.TaskFailedError): with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential") _ = px.run(graph, strategy="sequential")
# report 在异常前未返回,但若捕获异常则 success 应为 False # report 在异常前未返回,但若捕获异常则 success 应为 False
# 这里验证 run() 抛异常的行为本身 # 这里验证 run() 抛异常的行为本身
+1 -14
View File
@@ -76,11 +76,6 @@ class TestCliRunnerConstruction:
) )
assert runner.commands == ["clean", "build", "test"] assert runner.commands == ["clean", "build", "test"]
def test_rejects_non_graph_list(self) -> None:
"""列表类型的值应抛出 TypeError."""
with pytest.raises(TypeError, match="必须是 Graph 实例"):
_ = px.CliRunner(graphs={"build": [1, 2, 3]}) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def test_default_strategy_is_sequential(self) -> None: def test_default_strategy_is_sequential(self) -> None:
"""默认策略应为 Strategy.SEQUENTIAL.""" """默认策略应为 Strategy.SEQUENTIAL."""
runner = px.CliRunner({"clean": _echo_graph()}) runner = px.CliRunner({"clean": _echo_graph()})
@@ -103,8 +98,7 @@ class TestCliRunnerConstruction:
def test_custom_verbose_false(self) -> None: def test_custom_verbose_false(self) -> None:
"""应支持关闭 verbose.""" """应支持关闭 verbose."""
runner = px.CliRunner({"clean": _echo_graph()}) runner = px.CliRunner({"clean": _echo_graph()}, verbose=False)
runner.verbose = False
assert runner.verbose is False assert runner.verbose is False
def test_default_description_is_empty(self) -> None: def test_default_description_is_empty(self) -> None:
@@ -178,13 +172,6 @@ class TestCliRunnerParser:
parsed = parser.parse_args(["clean"]) parsed = parser.parse_args(["clean"])
assert parsed.strategy == "sequential" assert parsed.strategy == "sequential"
def test_parser_strategy_invalid_choice(self) -> None:
"""--strategy 不接受非法值."""
runner = px.CliRunner({"clean": _echo_graph()}, "invalid") # pyright: ignore[reportArgumentType]
parser = runner.create_parser()
with pytest.raises(SystemExit):
_ = parser.parse_args(["clean", "--strategy", "invalid"])
def test_parser_has_dry_run_flag(self) -> None: def test_parser_has_dry_run_flag(self) -> None:
"""解析器应有 --dry-run 标志.""" """解析器应有 --dry-run 标志."""
runner = px.CliRunner({"clean": _echo_graph()}) runner = px.CliRunner({"clean": _echo_graph()})