refactor: 全面迁移至 Python 3.9+ 原生泛型类型语法

- 将所有 `Optional[T]` 替换为 `T | None`
- 将所有 `List[T]`/`Dict[K, V]`/`Tuple[Ts, ...]` 替换为对应原生泛型
- 调整类型导入,移除冗余的 typing 导入项
- 更新项目依赖,添加 typing-extensions 兼容旧版本 Python
- 重构部分函数签名与内部实现以匹配新类型语法
This commit is contained in:
2026-06-20 17:52:42 +08:00
parent c06d0284c4
commit 08eb743ea9
18 changed files with 962 additions and 177 deletions
+33 -33
View File
@@ -81,43 +81,43 @@ from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus
__version__ = "0.1.2"
__all__ = [
# 核心类型
"TaskSpec",
"TaskStatus",
"TaskResult",
"TaskEvent",
"Context",
"TaskCmd",
"Graph",
"RunReport",
# 执行
"run",
"Strategy",
# CLI 运行器
"CliRunner",
"CliExitCode",
# 状态后端
"StateBackend",
"MemoryBackend",
"JSONBackend",
# 错误
"PyFlowXError",
"DuplicateTaskError",
"MissingDependencyError",
"CycleError",
"TaskFailedError",
"TaskTimeoutError",
"InjectionError",
"StorageError",
# 条件判断
"Condition",
"Constants",
"BuiltinConditions",
"IS_WINDOWS",
"IS_LINUX",
"IS_MACOS",
"IS_POSIX",
"IS_WINDOWS",
"BuiltinConditions",
"CliExitCode",
# CLI 运行器
"CliRunner",
# 条件判断
"Condition",
"Constants",
"Context",
"CycleError",
"DuplicateTaskError",
"Graph",
"InjectionError",
"JSONBackend",
"MemoryBackend",
"MissingDependencyError",
# 错误
"PyFlowXError",
"RunReport",
# 状态后端
"StateBackend",
"StorageError",
"Strategy",
"TaskCmd",
"TaskEvent",
"TaskFailedError",
"TaskResult",
# 核心类型
"TaskSpec",
"TaskStatus",
"TaskTimeoutError",
# 辅助(高级)
"build_call_args",
"describe_injection",
# 执行
"run",
]
+2 -2
View File
@@ -8,7 +8,7 @@ from __future__ import annotations
import shutil
import sys
from typing import Callable, Optional
from typing import Callable
# 条件判断函数类型
Condition = Callable[[], bool]
@@ -47,7 +47,7 @@ class BuiltinConditions:
return Constants.IS_POSIX
@staticmethod
def PYTHON_VERSION(major: int, minor: Optional[int] = None) -> bool:
def PYTHON_VERSION(major: int, minor: int | None = None) -> bool:
"""检查 Python 版本是否匹配.
Parameters
+11 -11
View File
@@ -18,12 +18,12 @@ DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get
from __future__ import annotations
import inspect
from typing import Any, Dict, List, Mapping, Set, Tuple
from typing import Any, Mapping
from .errors import InjectionError
from .task import Context, TaskSpec
__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"]
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
def _is_context_annotation(annotation: Any) -> bool:
@@ -43,15 +43,13 @@ def _is_context_annotation(annotation: Any) -> bool:
return annotation == "Context" or annotation.endswith(".Context")
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
if name in ("Context", "Mapping"):
return True
return False
return name in ("Context", "Mapping")
def build_call_args(
spec: TaskSpec[object],
context: Mapping[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
参数
@@ -84,7 +82,9 @@ def build_call_args(
)
# 与本任务相关的上下文子集。
dep_context: Dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context}
dep_context: dict[str, Any] = {
name: context[name] for name in spec.depends_on if name in context
}
# 检测静态 kwargs 与依赖名的冲突。
collisions = set(spec.kwargs) & set(dep_context)
@@ -95,12 +95,12 @@ def build_call_args(
"rename the static kwarg or the dependency.",
)
injected_kwargs: Dict[str, Any] = {}
leftover_dep_results: Dict[str, Any] = dict(dep_context)
injected_kwargs: dict[str, Any] = {}
leftover_dep_results: dict[str, Any] = dict(dep_context)
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
positional_params: List[str] = []
positional_params: list[str] = []
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -109,7 +109,7 @@ def build_call_args(
if param.kind in positional_kinds:
positional_params.append(pname)
# 前 len(spec.args) 个位置参数由 spec.args 填充。
args_filled: Set[str] = set(positional_params[: len(spec.args)])
args_filled: set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items():
# 跳过已被位置 spec.args 填充的参数。
+6 -4
View File
@@ -6,7 +6,7 @@
from __future__ import annotations
from typing import Any, Iterable, Optional
from typing import Any, Iterable
class PyFlowXError(Exception):
@@ -55,10 +55,12 @@ class TaskFailedError(PyFlowXError):
task: str,
cause: BaseException,
attempts: int,
layer: Optional[int] = None,
layer: int | None = None,
) -> None:
location = f" (layer {layer})" if layer is not None else ""
super().__init__(f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}")
super().__init__(
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
)
self.task = task
self.cause = cause
self.attempts = attempts
@@ -85,6 +87,6 @@ class InjectionError(PyFlowXError):
class StorageError(PyFlowXError):
"""状态后端在持久化失败时抛出。"""
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
def __init__(self, detail: str, cause: BaseException | None = None) -> None:
super().__init__(f"State storage error: {detail}")
self.cause: Any = cause
+4 -4
View File
@@ -10,7 +10,7 @@ Shows:
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from typing import Any
import pyflowx as px
@@ -20,13 +20,13 @@ async def fetch_user(uid: int) -> dict:
return {"id": uid, "name": f"User{uid}"}
async def fetch_posts(uid: int) -> List[int]:
async def fetch_posts(uid: int) -> list[int]:
await asyncio.sleep(0.2)
return [uid, uid + 1]
# Context annotation → receives the full mapping of upstream results.
def aggregate(ctx: px.Context) -> Dict[str, Any]:
def aggregate(ctx: px.Context) -> dict[str, Any]:
return dict(ctx)
@@ -43,7 +43,7 @@ def main() -> None:
print("=== Dry run ===")
px.run(graph, strategy="async", dry_run=True)
events: List[px.TaskEvent] = []
events: list[px.TaskEvent] = []
print("\n=== Async execution ===")
report = px.run(graph, strategy="async", on_event=events.append)
+6 -8
View File
@@ -10,21 +10,19 @@ Demonstrates the core PyFlowX workflow:
from __future__ import annotations
from typing import List
import pyflowx as px
# --- task functions: pure, testable, no framework coupling ------------- #
def extract_customers() -> List[dict]:
def extract_customers() -> list[dict]:
return [
{"id": "C001", "name": "Alice"},
{"id": "C002", "name": "Bob"},
]
def extract_orders() -> List[dict]:
def extract_orders() -> list[dict]:
return [
{"id": "O001", "customer_id": "C001", "amount": 150.0},
{"id": "O002", "customer_id": "C002", "amount": 200.5},
@@ -33,9 +31,9 @@ def extract_orders() -> List[dict]:
# Parameter names match dependency names → automatic injection.
def transform(
extract_customers: List[dict],
extract_orders: List[dict],
) -> List[dict]:
extract_customers: list[dict],
extract_orders: list[dict],
) -> list[dict]:
cmap = {c["id"]: c for c in extract_customers}
return [
{**o, "customer_name": cmap[o["customer_id"]]["name"]}
@@ -44,7 +42,7 @@ def transform(
]
def load(transform: List[dict]) -> int:
def load(transform: list[dict]) -> int:
print(f" loaded {len(transform)} records")
return len(transform)
+74 -52
View File
@@ -20,7 +20,7 @@ import enum
import inspect
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union, cast
from typing import Any, Awaitable, Callable, Mapping, cast
from .context import build_call_args, describe_injection
from .errors import TaskFailedError, TaskTimeoutError
@@ -53,7 +53,7 @@ class Strategy(enum.Enum):
ASYNC = "async"
def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
def _normalize_strategy(strategy: str | Strategy) -> Strategy:
"""将字符串或 Strategy 归一化为 Strategy 枚举.
Parameters
@@ -79,7 +79,9 @@ def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
return Strategy(strategy)
except ValueError:
valid = ", ".join(repr(s.value) for s in Strategy)
raise ValueError(f"unknown strategy {strategy!r}; expected one of {valid}.") from None
raise ValueError(
f"unknown strategy {strategy!r}; expected one of {valid}."
) from None
raise TypeError(f"strategy must be str or Strategy, got {type(strategy).__name__}")
@@ -89,7 +91,7 @@ def _is_async_fn(spec: TaskSpec[object]) -> bool:
def _emit(
on_event: Optional[EventCallback],
on_event: EventCallback | None,
result: TaskResult[object],
) -> None:
"""若注册了回调则触发一个观察者事件。"""
@@ -106,7 +108,9 @@ def _emit(
)
def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException) -> None:
def _log_retry(
spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException
) -> None:
"""记录重试日志(sync 与 async 共享,便于测试覆盖)。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
@@ -117,7 +121,7 @@ def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: Ba
)
def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None:
def _finalize_failure(result: TaskResult[object], layer_idx: int | None) -> None:
"""标记任务为 FAILED 并抛出 TaskFailedError。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
@@ -132,7 +136,7 @@ def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> N
def _run_sync_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
layer_idx: int | None,
) -> TaskResult[object]:
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
result: TaskResult[object] = TaskResult(spec=spec)
@@ -155,7 +159,7 @@ def _run_sync_with_retry(
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except Exception as exc: # noqa: BLE001 - 用户代码可能抛任何异常
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx) # pragma: no cover
@@ -166,7 +170,7 @@ def _run_sync_with_retry(
async def _run_async_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
layer_idx: int | None,
) -> TaskResult[object]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
result: TaskResult[object] = TaskResult(spec=spec)
@@ -198,7 +202,9 @@ async def _run_async_with_retry(
return spec.effective_fn(*args, **kwargs)
if spec.timeout is not None:
result.value = await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
result.value = await asyncio.wait_for(
loop.run_in_executor(None, fn_call), timeout=spec.timeout
)
else:
result.value = await loop.run_in_executor(None, fn_call)
result.status = TaskStatus.SUCCESS
@@ -214,7 +220,7 @@ async def _run_async_with_retry(
result.attempts,
max_attempts,
)
except Exception as exc: # noqa: BLE001
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx) # pragma: no cover
@@ -230,17 +236,19 @@ def _build_context(
global_context: Mapping[str, Any],
) -> Mapping[str, Any]:
"""将全局上下文限制为本任务的依赖。"""
return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context}
return {
dep: global_context[dep] for dep in spec.depends_on if dep in global_context
}
def _execute_layer_sequential(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
"""逐个运行某层的任务。"""
for name in layer:
@@ -261,23 +269,25 @@ def _execute_layer_sequential(
def _execute_layer_threaded(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
max_workers: int,
) -> None:
"""在线程池中并发运行某层的任务。"""
# 先同步满足已缓存任务。
to_run: List[str] = []
to_run: list[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
@@ -287,7 +297,7 @@ def _execute_layer_threaded(
return
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {}
future_to_name: dict[concurrent.futures.Future[TaskResult[object]], str] = {}
for name in to_run:
spec = graph.spec(name)
# 为本任务快照上下文以避免竞态。
@@ -305,21 +315,23 @@ def _execute_layer_threaded(
async def _execute_layer_async(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
"""在事件循环上并发运行某层的任务。"""
to_run: List[str] = []
to_run: list[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
@@ -346,8 +358,8 @@ async def _execute_layer_async(
# 公共 API
# ---------------------------------------------------------------------- #
def _make_verbose_callback(
on_event: Optional[EventCallback],
) -> Optional[EventCallback]:
on_event: EventCallback | None,
) -> EventCallback | None:
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期.
Parameters
@@ -385,13 +397,13 @@ def _make_verbose_callback(
def run(
graph: Graph,
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
strategy: str | Strategy = Strategy.SEQUENTIAL,
*,
max_workers: Optional[int] = None,
max_workers: int | None = None,
dry_run: bool = False,
verbose: bool = False,
on_event: Optional[EventCallback] = None,
state: Optional[StateBackend] = None,
on_event: EventCallback | None = None,
state: StateBackend | None = None,
) -> RunReport:
"""执行图并返回 :class:`RunReport`。
@@ -434,17 +446,23 @@ def run(
return RunReport(success=True)
# verbose 模式下包装事件回调
effective_callback: Optional[EventCallback] = _make_verbose_callback(on_event) if verbose else on_event
effective_callback: EventCallback | None = (
_make_verbose_callback(on_event) if verbose else on_event
)
backend = resolve_backend(state)
report = RunReport()
context: Dict[str, Any] = {}
context: dict[str, Any] = {}
try:
if normalized == Strategy.SEQUENTIAL:
_drive_sequential(graph, layers, context, report, backend, effective_callback)
_drive_sequential(
graph, layers, context, report, backend, effective_callback
)
elif normalized == Strategy.THREAD:
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers)
_drive_threaded(
graph, layers, context, report, backend, effective_callback, max_workers
)
else:
_drive_async(graph, layers, context, report, backend, effective_callback)
except TaskFailedError:
@@ -454,7 +472,7 @@ def run(
return report
def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None:
"""打印执行计划但不运行任何任务。"""
print(f"Dry run: {len(graph)} tasks, {len(layers)} layers")
for idx, layer in enumerate(layers, 1):
@@ -465,11 +483,11 @@ def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
def _drive_sequential(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
@@ -477,36 +495,40 @@ def _drive_sequential(
def _drive_threaded(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
max_workers: Optional[int],
on_event: EventCallback | None,
max_workers: int | None,
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers)
_execute_layer_threaded(
layer, graph, context, report, backend, idx, on_event, workers
)
def _drive_async(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
async def _async_drive(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event)
await _execute_layer_async(
layer, graph, context, report, backend, idx, on_event
)
+23 -19
View File
@@ -8,7 +8,7 @@
from __future__ import annotations
import sys
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
from typing import Iterable, Mapping, Sequence
from typing_extensions import override
@@ -38,14 +38,14 @@ class Graph:
"""
def __init__(self) -> None:
self._specs: Dict[str, TaskSpec[object]] = {}
self._specs: dict[str, TaskSpec[object]] = {}
# 任务 -> 其直接依赖(前驱)。
self._deps: Dict[str, Tuple[str, ...]] = {}
self._deps: dict[str, tuple[str, ...]] = {}
# ------------------------------------------------------------------ #
# 构建
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[object]) -> "Graph":
def add(self, spec: TaskSpec[object]) -> Graph:
"""注册一个任务 spec,并即时校验。
返回 ``self`` 以支持链式调用但推荐入口是 :meth:`from_specs`
@@ -60,7 +60,7 @@ class Graph:
return self
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> Graph:
"""从可迭代的 task spec 构建图。
先收集所有 spec再统一校验这意味着任务可以引用*后出现*
@@ -107,7 +107,7 @@ class Graph:
# 内省
# ------------------------------------------------------------------ #
@property
def names(self) -> List[str]:
def names(self) -> list[str]:
"""所有已注册任务名(按插入顺序)。"""
return list(self._specs.keys())
@@ -115,7 +115,7 @@ class Graph:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self._specs[name]
def dependencies(self, name: str) -> Tuple[str, ...]:
def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。"""
return self._deps[name]
@@ -123,7 +123,7 @@ class Graph:
"""name -> spec 的只读视图。"""
return self._specs
def layers(self) -> List[List[str]]:
def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。
同层任务无相互依赖可并发执行层按执行顺序返回
@@ -132,7 +132,7 @@ class Graph:
"""
self.validate()
sorter = _TopologicalSorter(self._deps)
result: List[List[str]] = []
result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare()
while sorter.is_active():
@@ -147,19 +147,21 @@ class Graph:
# ------------------------------------------------------------------ #
# 子图 / 标签过滤
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> "Graph":
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。
依赖会被修剪仅保留被保留任务之间的边指向被丢弃任务的边
会被移除被保留的任务不再等待它们用于调试时运行大型
DAG 的切片
"""
wanted: Set[str] = set(tags)
kept: List[TaskSpec[object]] = []
wanted: set[str] = set(tags)
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags))
d
for d in spec.depends_on
if d in self._specs and (wanted & set(self._specs[d].tags))
)
kept.append(
TaskSpec(
@@ -178,13 +180,13 @@ class Graph:
)
return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: Set[str] = set(names)
wanted: set[str] = set(names)
for n in wanted:
if n not in self._specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: List[TaskSpec[object]] = []
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
@@ -217,8 +219,10 @@ class Graph:
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
lines: List[str] = [f"graph {orientation}"]
raise ValueError(
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
)
lines: list[str] = [f"graph {orientation}"]
for name in self._specs:
lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items():
@@ -231,7 +235,7 @@ class Graph:
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""用于调试的人类可读多行摘要。"""
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
+6 -6
View File
@@ -7,7 +7,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List
from typing import Any, Iterator
from .task import TaskResult, TaskStatus
@@ -24,7 +24,7 @@ class RunReport:
当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``
"""
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
results: dict[str, TaskResult[object]] = field(default_factory=dict)
success: bool = True
# ---- 类型化访问 --------------------------------------------------- #
@@ -50,9 +50,9 @@ class RunReport:
return len(self.results)
# ---- 汇总 --------------------------------------------------------- #
def summary(self) -> Dict[str, Any]:
def summary(self) -> dict[str, Any]:
"""用于日志/仪表盘的紧凑统计字典。"""
counts: Dict[str, int] = {}
counts: dict[str, int] = {}
total_duration = 0.0
for r in self.results.values():
counts[r.status.value] = counts.get(r.status.value, 0) + 1
@@ -65,7 +65,7 @@ class RunReport:
"total_duration_seconds": round(total_duration, 6),
}
def failed_tasks(self) -> List[str]:
def failed_tasks(self) -> list[str]:
"""以 FAILED 状态结束的任务名列表。"""
return [
name for name, r in self.results.items() if r.status == TaskStatus.FAILED
@@ -73,7 +73,7 @@ class RunReport:
def describe(self) -> str:
"""用于调试的人类可读多行报告。"""
lines: List[str] = [f"RunReport(success={self.success})"]
lines: list[str] = [f"RunReport(success={self.success})"]
for name, r in self.results.items():
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
err = f" error={r.error!r}" if r.error else ""
+18 -14
View File
@@ -23,13 +23,13 @@ import argparse
import dataclasses
import enum
import sys
from typing import Dict, List, Optional, Sequence, Union
from typing import Sequence
from .errors import PyFlowXError
from .executors import Strategy, _normalize_strategy, run
from .graph import Graph
__all__ = ["CliRunner", "CliExitCode"]
__all__ = ["CliExitCode", "CliRunner"]
class CliExitCode(enum.IntEnum):
@@ -92,12 +92,16 @@ class CliRunner:
基本用法::
runner = px.CliRunner(
clean=px.Graph.from_specs([
px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]),
]),
build=px.Graph.from_specs([
px.TaskSpec("uv_build", cmd=["uv", "build"]),
]),
clean=px.Graph.from_specs(
[
px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]),
]
),
build=px.Graph.from_specs(
[
px.TaskSpec("uv_build", cmd=["uv", "build"]),
]
),
)
runner.run() # 解析 sys.argv
@@ -114,7 +118,7 @@ class CliRunner:
def __init__(
self,
*,
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
strategy: str | Strategy = Strategy.SEQUENTIAL,
description: str = "",
verbose: bool = True,
**graphs: Graph,
@@ -127,7 +131,7 @@ class CliRunner:
raise TypeError(
f"CliRunner 命令 {name!r} 的值必须是 Graph 实例, 实际是 {type(graph).__name__}"
)
self._graphs: Dict[str, Graph] = dict(graphs)
self._graphs: dict[str, Graph] = dict(graphs)
self._strategy: Strategy = _normalize_strategy(strategy)
self._description: str = description
self._verbose: bool = verbose
@@ -136,12 +140,12 @@ class CliRunner:
# 内省
# ------------------------------------------------------------------ #
@property
def commands(self) -> List[str]:
def commands(self) -> list[str]:
"""可用的命令列表 (按插入顺序)."""
return list(self._graphs.keys())
@property
def graphs(self) -> Dict[str, Graph]:
def graphs(self) -> dict[str, Graph]:
"""命令名到图的映射 (只读副本)."""
return dict(self._graphs)
@@ -225,7 +229,7 @@ class CliRunner:
# ------------------------------------------------------------------ #
# 执行
# ------------------------------------------------------------------ #
def run(self, args: Optional[Sequence[str]] = None) -> int:
def run(self, args: Sequence[str] | None = None) -> int:
"""解析参数并执行对应的图.
Parameters
@@ -293,7 +297,7 @@ class CliRunner:
print(f"错误: {e}", file=sys.stderr)
return CliExitCode.FAILURE.value
def run_cli(self, args: Optional[Sequence[str]] = None) -> None:
def run_cli(self, args: Sequence[str] | None = None) -> None:
"""运行并以退出码退出进程.
作为 CLI 工具运行时的入口点, 等价于 ``sys.exit(self.run(args))``.
+24 -12
View File
@@ -17,9 +17,11 @@
from __future__ import annotations
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Optional
from pathlib import Path
from typing import Any, Mapping
from typing_extensions import override
from .errors import StorageError
@@ -52,20 +54,25 @@ class MemoryBackend(StateBackend):
"""进程内 dict 后端。进程退出即丢失。"""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
self._store: dict[str, Any] = {}
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
@override
def save(self, name: str, value: Any) -> None:
self._store[name] = value
@override
def has(self, name: str) -> bool:
return name in self._store
@override
def get(self, name: str) -> Any:
return self._store[name]
@override
def clear(self) -> None:
self._store.clear()
@@ -79,16 +86,16 @@ class JSONBackend(StateBackend):
"""
def __init__(self, path: str) -> None:
self._path = path
self._store: Dict[str, Any] = {}
self._path: str = path
self._store: dict[str, Any] = {}
self._load()
def _load(self) -> None:
if not os.path.exists(self._path):
if not Path(self._path).exists():
return
try:
with open(self._path, "r", encoding="utf-8") as fh:
data = json.load(fh)
with open(self._path, encoding="utf-8") as fh:
data: Any = json.load(fh)
if isinstance(data, dict):
self._store = data
except (OSError, json.JSONDecodeError) as exc:
@@ -99,13 +106,15 @@ class JSONBackend(StateBackend):
try:
with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2)
os.replace(tmp, self._path)
Path(tmp).replace(Path(self._path))
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
@override
def save(self, name: str, value: Any) -> None:
# 在修改内存状态前先校验可序列化性。
try:
@@ -115,19 +124,22 @@ class JSONBackend(StateBackend):
f"result of task {name!r} is not JSON-serialisable", exc
) from exc
self._store[name] = value
self._flush()
_ = self._flush()
@override
def has(self, name: str) -> bool:
return name in self._store
@override
def get(self, name: str) -> Any:
return self._store[name]
@override
def clear(self) -> None:
self._store.clear()
self._flush()
_ = self._flush()
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
def resolve_backend(backend: StateBackend | None) -> StateBackend:
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
return backend if backend is not None else MemoryBackend()
+16 -11
View File
@@ -30,6 +30,7 @@ from typing import (
Tuple,
TypeVar,
Union,
cast,
)
T = TypeVar("T")
@@ -48,7 +49,7 @@ Context = Mapping[str, Any]
TaskCmd = Union[
List[str], # 命令列表, 如 ["ls", "-la"]
str, # shell 命令字符串
Callable[..., T], # Python 函数
Callable[..., Any], # Python 函数
]
# 条件判断函数类型
@@ -151,12 +152,12 @@ class TaskSpec(Generic[T]):
return self.fn
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。")
def _wrap_cmd(self) -> TaskFn[T]:
def _wrap_cmd(self) -> TaskFn[Any]:
"""将 cmd 包装为可执行函数.
Returns
-------
TaskFn[T]
TaskFn[Any]
包装后的执行函数.
"""
cmd = self.cmd
@@ -184,17 +185,19 @@ class TaskSpec(Generic[T]):
check=False,
)
except FileNotFoundError:
raise RuntimeError(f"命令未找到: {cmd_str}")
raise RuntimeError(f"命令未找到: {cmd_str}") from None
except subprocess.TimeoutExpired:
raise RuntimeError(f"命令执行超时: {cmd_str} ({timeout}s)")
raise RuntimeError(
f"命令执行超时: {cmd_str} ({timeout}s)"
) from None
except OSError as e:
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}")
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}") from e
if verbose:
print(f"[verbose] 返回码: {result.returncode}", flush=True)
if result.returncode == 0:
return None # type: ignore[return-value]
return cast(T, None) # type: ignore[return-value]
err_msg = f"命令执行失败: `{cmd_str}`, 返回码: {result.returncode}"
if not verbose and result.stderr.strip():
@@ -224,17 +227,19 @@ class TaskSpec(Generic[T]):
check=False,
)
except FileNotFoundError:
raise RuntimeError(f"Shell 命令未找到: {cmd}")
raise RuntimeError(f"Shell 命令未找到: {cmd}") from None
except subprocess.TimeoutExpired:
raise RuntimeError(f"Shell 命令执行超时: {cmd} ({timeout}s)")
raise RuntimeError(
f"Shell 命令执行超时: {cmd} ({timeout}s)"
) from None
except OSError as e:
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}")
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}") from e
if verbose:
print(f"[verbose] 返回码: {result.returncode}", flush=True)
if result.returncode == 0:
return None # type: ignore[return-value]
return cast(T, None) # type: ignore[return-value]
err_msg = f"Shell 命令执行失败: `{cmd}`, 返回码: {result.returncode}"
if not verbose and result.stderr.strip():