Files
pyflowx/src/pyflowx/graph.py
T
zhou cd38e1246a chore: 版本升级到0.1.3并批量优化代码
变更包括:
1. 更新pyproject.toml行长度限制为120
2. 简化多处异常提示字符串的换行写法
3. 批量使用Any类型泛型优化类型标注
4. 重构cli/pymake.py的配置与任务定义
5. 删除冗余的测试代码与废弃的pymake测试文件
6. 修复示例代码的类型注解
2026-06-21 14:58:19 +08:00

244 lines
9.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""DAG 构建、校验、分层与可视化。
使用标准库的 :mod:`graphlib`3.9+)或 :mod:`graphlib_backport`3.8
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
"""
from __future__ import annotations
import sys
from dataclasses import dataclass, field
from typing import Any, Iterable, Mapping, Sequence
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec
# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。
if sys.version_info >= (3, 9): # pragma: no cover
import graphlib # pyright: ignore[reportUnreachable]
_TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover
import graphlib # type: ignore[import-untyped] # pragma: no cover
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
@dataclass(frozen=True)
class Graph:
"""校验后不可变的有向无环任务图。
通过添加 :class:`~pyflowx.task.TaskSpec` 实例构建。每次 ``add`` 都
执行即时校验(重名、缺失依赖),:meth:`validate` / :meth:`layers`
执行完整 DAG 校验(环检测)与拓扑分层。
图仅持有*配置*;运行时状态存于 :class:`~pyflowx.report.RunReport`。
这使图可安全重复运行并在线程间共享。
"""
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
# ------------------------------------------------------------------ #
# 构建
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[Any]) -> Graph:
"""注册一个任务 spec,并即时校验。
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
它会整批校验(允许单次调用中的前向引用)。
"""
if spec.name in self.specs:
raise DuplicateTaskError(spec.name)
self.specs[spec.name] = spec
self.deps[spec.name] = spec.depends_on
# 为增量 API 即时检查重名与缺失依赖。
self._validate_references()
return self
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[Any]]) -> Graph:
"""从可迭代的 task spec 构建图。
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
依赖——顺序无关,就像声明式配置文件的读取方式。
"""
graph = cls()
for spec in specs:
if spec.name in graph.specs:
raise DuplicateTaskError(spec.name)
graph.specs[spec.name] = spec
graph.deps[spec.name] = spec.depends_on
graph._validate_references()
graph.validate()
return graph
# ------------------------------------------------------------------ #
# 校验
# ------------------------------------------------------------------ #
def _validate_references(self) -> None:
"""确保每个依赖名都存在于图中。"""
for name, deps in self.deps.items():
for dep in deps:
if dep not in self.specs:
raise MissingDependencyError(name, dep)
def validate(self) -> None:
"""执行完整 DAG 校验。
存在环时抛出 :class:`~pyflowx.errors.CycleError`。
依赖存在性由 :meth:`_validate_references` 检查。
"""
self._validate_references()
sorter = _TopologicalSorter(self.deps)
try:
# prepare() 在有环时抛出 CycleError;此处不需要
# static_order() 的结果,仅利用其校验副作用。
sorter.prepare()
except graphlib.CycleError as exc:
# exc.args[1] 是构成环的节点列表。
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc
# ------------------------------------------------------------------ #
# 内省
# ------------------------------------------------------------------ #
@property
def names(self) -> list[str]:
"""所有已注册任务名(按插入顺序)。"""
return list(self.specs.keys())
def spec(self, name: str) -> TaskSpec[Any]:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self.specs[name]
def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。"""
return self.deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[Any]]:
"""name -> spec 的只读视图。"""
return self.specs
def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。
同层任务无相互依赖,可并发执行。层按执行顺序返回。
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
"""
self.validate()
sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare()
while sorter.is_active():
ready = list(sorter.get_ready())
# 排序以保证确定性、可复现的执行计划。
ready.sort()
result.append(ready)
for node in ready:
sorter.done(node)
return result
# ------------------------------------------------------------------ #
# 子图 / 标签过滤
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
DAG 的切片。
"""
wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(
TaskSpec[Any](
name=spec.name,
fn=spec.fn,
cmd=spec.cmd,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
conditions=spec.conditions,
cwd=spec.cwd,
)
)
return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: set[str] = set(names)
for n in wanted:
if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(
TaskSpec[Any](
name=spec.name,
fn=spec.fn,
cmd=spec.cmd,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
conditions=spec.conditions,
cwd=spec.cwd,
)
)
return Graph.from_specs(kept)
# ------------------------------------------------------------------ #
# 可视化
# ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str:
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。
无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览
渲染,或保存为文件。
"""
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
lines: list[str] = [f"graph {orientation}"]
for name in self.specs:
lines.append(f' {name}["{name}"]')
for name, deps in self.deps.items():
for dep in deps:
lines.append(f" {dep} --> {name}")
return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ #
# 调试
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""用于调试的人类可读多行摘要。"""
out: list[str] = [f"Graph(tasks={len(self.specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
def __repr__(self) -> str:
return f"Graph(tasks={len(self.specs)})"
def __len__(self) -> int:
return len(self.specs)
def __contains__(self, name: Any) -> bool:
return name in self.specs