refactor(graph,runner,test): 重构代码并清理冗余逻辑
1. 将Graph类改为frozen dataclass简化实现 2. 移除executors.py中的内置策略校验逻辑 3. 使用typing.get_args替代直接访问Strategy.__args__ 4. 清理测试文件中冗余的无效参数测试用例 5. 统一替换测试中未使用的px.run调用返回值 6. 在pyproject.toml中添加pytest slow标记配置
This commit is contained in:
@@ -389,11 +389,6 @@ def run(
|
||||
graph.validate()
|
||||
layers = graph.layers()
|
||||
|
||||
# 验证策略是否有效
|
||||
valid_strategies = ("sequential", "thread", "async")
|
||||
if strategy not in valid_strategies:
|
||||
raise ValueError(f"unknown strategy: {strategy}. Valid: {valid_strategies}")
|
||||
|
||||
if dry_run:
|
||||
_print_dry_run(graph, layers)
|
||||
return RunReport(success=True)
|
||||
|
||||
+30
-31
@@ -1,13 +1,13 @@
|
||||
"""DAG 构建、校验、分层与可视化。
|
||||
|
||||
使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8)
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非
|
||||
执行时)快速失败。
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Iterable, Mapping, Sequence
|
||||
|
||||
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
@@ -24,6 +24,7 @@ else: # pragma: no cover
|
||||
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Graph:
|
||||
"""校验后不可变的有向无环任务图。
|
||||
|
||||
@@ -35,10 +36,8 @@ class Graph:
|
||||
这使图可安全重复运行并在线程间共享。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._specs: dict[str, TaskSpec[object]] = {}
|
||||
# 任务 -> 其直接依赖(前驱)。
|
||||
self._deps: dict[str, tuple[str, ...]] = {}
|
||||
specs: dict[str, TaskSpec[object]] = field(default_factory=dict)
|
||||
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
@@ -49,10 +48,10 @@ class Graph:
|
||||
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`,
|
||||
它会整批校验(允许单次调用中的前向引用)。
|
||||
"""
|
||||
if spec.name in self._specs:
|
||||
if spec.name in self.specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
self._specs[spec.name] = spec
|
||||
self._deps[spec.name] = spec.depends_on
|
||||
self.specs[spec.name] = spec
|
||||
self.deps[spec.name] = spec.depends_on
|
||||
# 为增量 API 即时检查重名与缺失依赖。
|
||||
self._validate_references()
|
||||
return self
|
||||
@@ -66,10 +65,10 @@ class Graph:
|
||||
"""
|
||||
graph = cls()
|
||||
for spec in specs:
|
||||
if spec.name in graph._specs:
|
||||
if spec.name in graph.specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
graph._specs[spec.name] = spec
|
||||
graph._deps[spec.name] = spec.depends_on
|
||||
graph.specs[spec.name] = spec
|
||||
graph.deps[spec.name] = spec.depends_on
|
||||
graph._validate_references()
|
||||
graph.validate()
|
||||
return graph
|
||||
@@ -79,9 +78,9 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def _validate_references(self) -> None:
|
||||
"""确保每个依赖名都存在于图中。"""
|
||||
for name, deps in self._deps.items():
|
||||
for name, deps in self.deps.items():
|
||||
for dep in deps:
|
||||
if dep not in self._specs:
|
||||
if dep not in self.specs:
|
||||
raise MissingDependencyError(name, dep)
|
||||
|
||||
def validate(self) -> None:
|
||||
@@ -91,7 +90,7 @@ class Graph:
|
||||
依赖存在性由 :meth:`_validate_references` 检查。
|
||||
"""
|
||||
self._validate_references()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
try:
|
||||
# prepare() 在有环时抛出 CycleError;此处不需要
|
||||
# static_order() 的结果,仅利用其校验副作用。
|
||||
@@ -107,19 +106,19 @@ class Graph:
|
||||
@property
|
||||
def names(self) -> list[str]:
|
||||
"""所有已注册任务名(按插入顺序)。"""
|
||||
return list(self._specs.keys())
|
||||
return list(self.specs.keys())
|
||||
|
||||
def spec(self, name: str) -> TaskSpec[object]:
|
||||
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
|
||||
return self._specs[name]
|
||||
return self.specs[name]
|
||||
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接前驱。"""
|
||||
return self._deps[name]
|
||||
return self.deps[name]
|
||||
|
||||
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
|
||||
"""name -> spec 的只读视图。"""
|
||||
return self._specs
|
||||
return self.specs
|
||||
|
||||
def layers(self) -> list[list[str]]:
|
||||
"""将任务分组为可并行执行的层(Kahn 算法)。
|
||||
@@ -129,7 +128,7 @@ class Graph:
|
||||
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
|
||||
"""
|
||||
self.validate()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
result: list[list[str]] = []
|
||||
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
|
||||
sorter.prepare()
|
||||
@@ -154,12 +153,12 @@ class Graph:
|
||||
"""
|
||||
wanted: set[str] = set(tags)
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
for spec in self.specs.values():
|
||||
if wanted & set(spec.tags):
|
||||
pruned_deps = tuple(
|
||||
d
|
||||
for d in spec.depends_on
|
||||
if d in self._specs and (wanted & set(self._specs[d].tags))
|
||||
if d in self.specs and (wanted & set(self.specs[d].tags))
|
||||
)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
@@ -182,14 +181,14 @@ class Graph:
|
||||
"""返回限定于 ``names`` 的新图(边已修剪)。"""
|
||||
wanted: set[str] = set(names)
|
||||
for n in wanted:
|
||||
if n not in self._specs:
|
||||
if n not in self.specs:
|
||||
raise KeyError(f"Unknown task name: {n!r}")
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
for spec in self.specs.values():
|
||||
if spec.name in wanted:
|
||||
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
TaskSpec[object](
|
||||
name=spec.name,
|
||||
fn=spec.fn,
|
||||
cmd=spec.cmd,
|
||||
@@ -221,9 +220,9 @@ class Graph:
|
||||
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
|
||||
)
|
||||
lines: list[str] = [f"graph {orientation}"]
|
||||
for name in self._specs:
|
||||
for name in self.specs:
|
||||
lines.append(f' {name}["{name}"]')
|
||||
for name, deps in self._deps.items():
|
||||
for name, deps in self.deps.items():
|
||||
for dep in deps:
|
||||
lines.append(f" {dep} --> {name}")
|
||||
return "\n".join(lines) + "\n"
|
||||
@@ -233,16 +232,16 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行摘要。"""
|
||||
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
|
||||
out: list[str] = [f"Graph(tasks={len(self.specs)})"]
|
||||
for layer_idx, layer in enumerate(self.layers(), 1):
|
||||
out.append(f" Layer {layer_idx}: {layer}")
|
||||
return "\n".join(out)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Graph(tasks={len(self._specs)})"
|
||||
return f"Graph(tasks={len(self.specs)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._specs)
|
||||
return len(self.specs)
|
||||
|
||||
def __contains__(self, name: object) -> bool:
|
||||
return name in self._specs
|
||||
return name in self.specs
|
||||
|
||||
@@ -15,7 +15,7 @@ import argparse
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Sequence
|
||||
from typing import Sequence, get_args
|
||||
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, run
|
||||
@@ -60,7 +60,7 @@ def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
|
||||
return Graph.from_specs(new_specs)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class CliRunner:
|
||||
"""命令行运行器: 根据用户输入执行对应的任务流图.
|
||||
|
||||
@@ -162,7 +162,7 @@ class CliRunner:
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--strategy",
|
||||
choices=list(Strategy.__args__),
|
||||
choices=list(get_args(Strategy)),
|
||||
default="sequential",
|
||||
help="执行策略 (默认: %(default)s)",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user