refactor(graph,runner,test): 重构代码并清理冗余逻辑
1. 将Graph类改为frozen dataclass简化实现 2. 移除executors.py中的内置策略校验逻辑 3. 使用typing.get_args替代直接访问Strategy.__args__ 4. 清理测试文件中冗余的无效参数测试用例 5. 统一替换测试中未使用的px.run调用返回值 6. 在pyproject.toml中添加pytest slow标记配置
This commit is contained in:
@@ -78,6 +78,7 @@ show_missing = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
|
||||
|
||||
[tool.basedpyright]
|
||||
exclude = ["**/.git", "**/.venv", "**/__pycache__", "**/build", "**/dist"]
|
||||
|
||||
@@ -389,11 +389,6 @@ def run(
|
||||
graph.validate()
|
||||
layers = graph.layers()
|
||||
|
||||
# 验证策略是否有效
|
||||
valid_strategies = ("sequential", "thread", "async")
|
||||
if strategy not in valid_strategies:
|
||||
raise ValueError(f"unknown strategy: {strategy}. Valid: {valid_strategies}")
|
||||
|
||||
if dry_run:
|
||||
_print_dry_run(graph, layers)
|
||||
return RunReport(success=True)
|
||||
|
||||
+30
-31
@@ -1,13 +1,13 @@
|
||||
"""DAG 构建、校验、分层与可视化。
|
||||
|
||||
使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8)
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非
|
||||
执行时)快速失败。
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, Mapping, Sequence
|
||||
|
||||
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
@@ -24,6 +24,7 @@ else: # pragma: no cover
|
||||
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Graph:
|
||||
"""校验后不可变的有向无环任务图。
|
||||
|
||||
@@ -35,10 +36,8 @@ class Graph:
|
||||
这使图可安全重复运行并在线程间共享。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._specs: dict[str, TaskSpec[object]] = {}
|
||||
# 任务 -> 其直接依赖(前驱)。
|
||||
self._deps: dict[str, tuple[str, ...]] = {}
|
||||
specs: dict[str, TaskSpec[object]] = field(default_factory=dict)
|
||||
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
@@ -49,10 +48,10 @@ class Graph:
|
||||
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`,
|
||||
它会整批校验(允许单次调用中的前向引用)。
|
||||
"""
|
||||
if spec.name in self._specs:
|
||||
if spec.name in self.specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
self._specs[spec.name] = spec
|
||||
self._deps[spec.name] = spec.depends_on
|
||||
self.specs[spec.name] = spec
|
||||
self.deps[spec.name] = spec.depends_on
|
||||
# 为增量 API 即时检查重名与缺失依赖。
|
||||
self._validate_references()
|
||||
return self
|
||||
@@ -66,10 +65,10 @@ class Graph:
|
||||
"""
|
||||
graph = cls()
|
||||
for spec in specs:
|
||||
if spec.name in graph._specs:
|
||||
if spec.name in graph.specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
graph._specs[spec.name] = spec
|
||||
graph._deps[spec.name] = spec.depends_on
|
||||
graph.specs[spec.name] = spec
|
||||
graph.deps[spec.name] = spec.depends_on
|
||||
graph._validate_references()
|
||||
graph.validate()
|
||||
return graph
|
||||
@@ -79,9 +78,9 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def _validate_references(self) -> None:
|
||||
"""确保每个依赖名都存在于图中。"""
|
||||
for name, deps in self._deps.items():
|
||||
for name, deps in self.deps.items():
|
||||
for dep in deps:
|
||||
if dep not in self._specs:
|
||||
if dep not in self.specs:
|
||||
raise MissingDependencyError(name, dep)
|
||||
|
||||
def validate(self) -> None:
|
||||
@@ -91,7 +90,7 @@ class Graph:
|
||||
依赖存在性由 :meth:`_validate_references` 检查。
|
||||
"""
|
||||
self._validate_references()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
try:
|
||||
# prepare() 在有环时抛出 CycleError;此处不需要
|
||||
# static_order() 的结果,仅利用其校验副作用。
|
||||
@@ -107,19 +106,19 @@ class Graph:
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
"""所有已注册任务名(按插入顺序)。"""
|
||||
return list(self._specs.keys())
|
||||
return list(self.specs.keys())
|
||||
|
||||
def spec(self, name: str) -> TaskSpec[object]:
|
||||
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
|
||||
return self._specs[name]
|
||||
return self.specs[name]
|
||||
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接前驱。"""
|
||||
return self._deps[name]
|
||||
return self.deps[name]
|
||||
|
||||
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
|
||||
"""name -> spec 的只读视图。"""
|
||||
return self._specs
|
||||
return self.specs
|
||||
|
||||
def layers(self) -> list[list[str]]:
|
||||
"""将任务分组为可并行执行的层(Kahn 算法)。
|
||||
@@ -129,7 +128,7 @@ class Graph:
|
||||
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
|
||||
"""
|
||||
self.validate()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
result: list[list[str]] = []
|
||||
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
|
||||
sorter.prepare()
|
||||
@@ -154,12 +153,12 @@ class Graph:
|
||||
"""
|
||||
wanted: set[str] = set(tags)
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
for spec in self.specs.values():
|
||||
if wanted & set(spec.tags):
|
||||
pruned_deps = tuple(
|
||||
d
|
||||
for d in spec.depends_on
|
||||
if d in self._specs and (wanted & set(self._specs[d].tags))
|
||||
if d in self.specs and (wanted & set(self.specs[d].tags))
|
||||
)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
@@ -182,14 +181,14 @@ class Graph:
|
||||
"""返回限定于 ``names`` 的新图(边已修剪)。"""
|
||||
wanted: set[str] = set(names)
|
||||
for n in wanted:
|
||||
if n not in self._specs:
|
||||
if n not in self.specs:
|
||||
raise KeyError(f"Unknown task name: {n!r}")
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
for spec in self.specs.values():
|
||||
if spec.name in wanted:
|
||||
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
TaskSpec[object](
|
||||
name=spec.name,
|
||||
fn=spec.fn,
|
||||
cmd=spec.cmd,
|
||||
@@ -221,9 +220,9 @@ class Graph:
|
||||
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
|
||||
)
|
||||
lines: list[str] = [f"graph {orientation}"]
|
||||
for name in self._specs:
|
||||
for name in self.specs:
|
||||
lines.append(f' {name}["{name}"]')
|
||||
for name, deps in self._deps.items():
|
||||
for name, deps in self.deps.items():
|
||||
for dep in deps:
|
||||
lines.append(f" {dep} --> {name}")
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -233,16 +232,16 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行摘要。"""
|
||||
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
|
||||
out: list[str] = [f"Graph(tasks={len(self.specs)})"]
|
||||
for layer_idx, layer in enumerate(self.layers(), 1):
|
||||
out.append(f" Layer {layer_idx}: {layer}")
|
||||
return "\n".join(out)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Graph(tasks={len(self._specs)})"
|
||||
return f"Graph(tasks={len(self.specs)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._specs)
|
||||
return len(self.specs)
|
||||
|
||||
def __contains__(self, name: object) -> bool:
|
||||
return name in self._specs
|
||||
return name in self.specs
|
||||
|
||||
@@ -15,7 +15,7 @@ import argparse
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Sequence
|
||||
from typing import Sequence, get_args
|
||||
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, run
|
||||
@@ -60,7 +60,7 @@ def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
|
||||
return Graph.from_specs(new_specs)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class CliRunner:
|
||||
"""命令行运行器: 根据用户输入执行对应的任务流图.
|
||||
|
||||
@@ -162,7 +162,7 @@ class CliRunner:
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--strategy",
|
||||
choices=list(Strategy.__args__),
|
||||
choices=list(get_args(Strategy)),
|
||||
default="sequential",
|
||||
help="执行策略 (默认: %(default)s)",
|
||||
)
|
||||
|
||||
+12
-21
@@ -76,7 +76,7 @@ def test_failure_propagates() -> None:
|
||||
]
|
||||
)
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert exc_info.value.task == "boom"
|
||||
assert isinstance(exc_info.value.cause, ValueError)
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_retries_exhausted() -> None:
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert exc_info.value.attempts == 3
|
||||
|
||||
|
||||
@@ -226,7 +226,7 @@ def test_async_timeout() -> None:
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="async")
|
||||
_ = px.run(graph, strategy="async")
|
||||
assert isinstance(exc_info.value.cause, TaskTimeoutError)
|
||||
|
||||
|
||||
@@ -269,11 +269,11 @@ def test_memory_backend_resume() -> None:
|
||||
]
|
||||
)
|
||||
backend = MemoryBackend()
|
||||
px.run(graph, strategy="sequential", state=backend)
|
||||
_ = px.run(graph, strategy="sequential", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
|
||||
# Second run: both cached, neither re-executed.
|
||||
px.run(graph, strategy="sequential", state=backend)
|
||||
_ = px.run(graph, strategy="sequential", state=backend)
|
||||
assert runs == ["a", "b"] # unchanged
|
||||
|
||||
|
||||
@@ -285,7 +285,7 @@ def test_json_backend_persistence() -> None:
|
||||
return 7
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
|
||||
px.run(graph, strategy="sequential", state=JSONBackend(path))
|
||||
_ = px.run(graph, strategy="sequential", state=JSONBackend(path))
|
||||
|
||||
# New backend reads the file; task should be skipped.
|
||||
runs: list[str] = []
|
||||
@@ -310,21 +310,12 @@ def test_on_event_callback() -> None:
|
||||
return 1
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
|
||||
px.run(graph, strategy="sequential", on_event=events.append)
|
||||
_ = px.run(graph, strategy="sequential", on_event=events.append)
|
||||
statuses = [e.status for e in events]
|
||||
assert px.TaskStatus.SUCCESS in statuses
|
||||
assert all(e.task == "a" for e in events)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Invalid strategy
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_invalid_strategy() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError):
|
||||
px.run(graph, strategy="bogus") # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -410,10 +401,10 @@ def test_threaded_skips_cached_tasks() -> None:
|
||||
)
|
||||
backend = px.MemoryBackend()
|
||||
# 第一次运行填充缓存
|
||||
px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
# 第二次运行应全部跳过
|
||||
px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
assert runs == ["a", "b"] # 未再执行
|
||||
|
||||
|
||||
@@ -454,9 +445,9 @@ def test_async_skips_cached_tasks() -> None:
|
||||
]
|
||||
)
|
||||
backend = px.MemoryBackend()
|
||||
px.run(graph, strategy="async", state=backend)
|
||||
_ = px.run(graph, strategy="async", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
px.run(graph, strategy="async", state=backend)
|
||||
_ = px.run(graph, strategy="async", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
|
||||
|
||||
@@ -483,7 +474,7 @@ def test_failure_marks_report_unsuccessful() -> None:
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", boom)])
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
# report 在异常前未返回,但若捕获异常则 success 应为 False
|
||||
# 这里验证 run() 抛异常的行为本身
|
||||
|
||||
|
||||
+1
-14
@@ -76,11 +76,6 @@ class TestCliRunnerConstruction:
|
||||
)
|
||||
assert runner.commands == ["clean", "build", "test"]
|
||||
|
||||
def test_rejects_non_graph_list(self) -> None:
|
||||
"""列表类型的值应抛出 TypeError."""
|
||||
with pytest.raises(TypeError, match="必须是 Graph 实例"):
|
||||
_ = px.CliRunner(graphs={"build": [1, 2, 3]}) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
||||
|
||||
def test_default_strategy_is_sequential(self) -> None:
|
||||
"""默认策略应为 Strategy.SEQUENTIAL."""
|
||||
runner = px.CliRunner({"clean": _echo_graph()})
|
||||
@@ -103,8 +98,7 @@ class TestCliRunnerConstruction:
|
||||
|
||||
def test_custom_verbose_false(self) -> None:
|
||||
"""应支持关闭 verbose."""
|
||||
runner = px.CliRunner({"clean": _echo_graph()})
|
||||
runner.verbose = False
|
||||
runner = px.CliRunner({"clean": _echo_graph()}, verbose=False)
|
||||
assert runner.verbose is False
|
||||
|
||||
def test_default_description_is_empty(self) -> None:
|
||||
@@ -178,13 +172,6 @@ class TestCliRunnerParser:
|
||||
parsed = parser.parse_args(["clean"])
|
||||
assert parsed.strategy == "sequential"
|
||||
|
||||
def test_parser_strategy_invalid_choice(self) -> None:
|
||||
"""--strategy 不接受非法值."""
|
||||
runner = px.CliRunner({"clean": _echo_graph()}, "invalid") # pyright: ignore[reportArgumentType]
|
||||
parser = runner.create_parser()
|
||||
with pytest.raises(SystemExit):
|
||||
_ = parser.parse_args(["clean", "--strategy", "invalid"])
|
||||
|
||||
def test_parser_has_dry_run_flag(self) -> None:
|
||||
"""解析器应有 --dry-run 标志."""
|
||||
runner = px.CliRunner({"clean": _echo_graph()})
|
||||
|
||||
Reference in New Issue
Block a user