refactor: 全面迁移至 Python 3.9+ 原生泛型类型语法

- 将所有 `Optional[T]` 替换为 `T | None`
- 将所有 `List[T]`/`Dict[K, V]`/`Tuple[Ts, ...]` 替换为对应原生泛型
- 调整类型导入,移除冗余的 typing 导入项
- 更新项目依赖,添加 typing-extensions 兼容旧版本 Python
- 重构部分函数签名与内部实现以匹配新类型语法
This commit is contained in:
2026-06-20 17:52:42 +08:00
parent c06d0284c4
commit 08eb743ea9
18 changed files with 962 additions and 177 deletions
+47 -1
View File
@@ -10,7 +10,10 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
dependencies = ["graphlib_backport >= 1.0.0; python_version < '3.9'"]
dependencies = [
"graphlib_backport >= 1.0.0; python_version < '3.9'",
"typing-extensions>=4.13.2",
]
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
keywords = ["async", "dag", "scheduler", "task", "workflow"]
license = { text = "MIT" }
@@ -86,3 +89,46 @@ reportImplicitStringConcatenation = "error"
reportMissingTypeStubs = "none"
reportUnusedCallResult = "warning"
typeCheckingMode = "recommended" # 类型检查严格度:off / basic / standard / recommended(默认) / strict / all
# Ruff 配置 - 与 .pre-commit-config.yaml 保持一致
[tool.ruff]
target-version = "py38"
line-length = 88
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
"ARG", # flake8-unused-arguments
"SIM", # flake8-simplify
"PTH", # flake8-use-pathlib
"PL", # Pylint
"RUF", # Ruff-specific rules
]
ignore = [
"E501", # line too long (handled by formatter)
"PLR0913", # too many arguments
"PLR2004", # magic value comparison
"PTH123", # pathlib open() replacement
"SIM108", # use ternary operator
"RUF001", # ambiguous unicode characters in string
"RUF002", # ambiguous unicode characters in docstring
"RUF003", # ambiguous unicode characters in comment
"RUF012", # mutable class attributes (intentional for config)
"PLC0415", # import should be at top-level (intentional for lazy imports)
"PLR0915", # too many statements (intentional for complex methods)
"PTH119", # os.path.basename (intentional for sys.argv)
]
[tool.ruff.lint.isort]
known-first-party = ["pyflowx"]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
docstring-code-format = true
+33 -33
View File
@@ -81,43 +81,43 @@ from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus
__version__ = "0.1.2"
__all__ = [
# 核心类型
"TaskSpec",
"TaskStatus",
"TaskResult",
"TaskEvent",
"Context",
"TaskCmd",
"Graph",
"RunReport",
# 执行
"run",
"Strategy",
# CLI 运行器
"CliRunner",
"CliExitCode",
# 状态后端
"StateBackend",
"MemoryBackend",
"JSONBackend",
# 错误
"PyFlowXError",
"DuplicateTaskError",
"MissingDependencyError",
"CycleError",
"TaskFailedError",
"TaskTimeoutError",
"InjectionError",
"StorageError",
# 条件判断
"Condition",
"Constants",
"BuiltinConditions",
"IS_WINDOWS",
"IS_LINUX",
"IS_MACOS",
"IS_POSIX",
"IS_WINDOWS",
"BuiltinConditions",
"CliExitCode",
# CLI 运行器
"CliRunner",
# 条件判断
"Condition",
"Constants",
"Context",
"CycleError",
"DuplicateTaskError",
"Graph",
"InjectionError",
"JSONBackend",
"MemoryBackend",
"MissingDependencyError",
# 错误
"PyFlowXError",
"RunReport",
# 状态后端
"StateBackend",
"StorageError",
"Strategy",
"TaskCmd",
"TaskEvent",
"TaskFailedError",
"TaskResult",
# 核心类型
"TaskSpec",
"TaskStatus",
"TaskTimeoutError",
# 辅助(高级)
"build_call_args",
"describe_injection",
# 执行
"run",
]
+2 -2
View File
@@ -8,7 +8,7 @@ from __future__ import annotations
import shutil
import sys
from typing import Callable, Optional
from typing import Callable
# 条件判断函数类型
Condition = Callable[[], bool]
@@ -47,7 +47,7 @@ class BuiltinConditions:
return Constants.IS_POSIX
@staticmethod
def PYTHON_VERSION(major: int, minor: Optional[int] = None) -> bool:
def PYTHON_VERSION(major: int, minor: int | None = None) -> bool:
"""检查 Python 版本是否匹配.
Parameters
+11 -11
View File
@@ -18,12 +18,12 @@ DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get
from __future__ import annotations
import inspect
from typing import Any, Dict, List, Mapping, Set, Tuple
from typing import Any, Mapping
from .errors import InjectionError
from .task import Context, TaskSpec
__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"]
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
def _is_context_annotation(annotation: Any) -> bool:
@@ -43,15 +43,13 @@ def _is_context_annotation(annotation: Any) -> bool:
return annotation == "Context" or annotation.endswith(".Context")
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
if name in ("Context", "Mapping"):
return True
return False
return name in ("Context", "Mapping")
def build_call_args(
spec: TaskSpec[object],
context: Mapping[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
参数
@@ -84,7 +82,9 @@ def build_call_args(
)
# 与本任务相关的上下文子集。
dep_context: Dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context}
dep_context: dict[str, Any] = {
name: context[name] for name in spec.depends_on if name in context
}
# 检测静态 kwargs 与依赖名的冲突。
collisions = set(spec.kwargs) & set(dep_context)
@@ -95,12 +95,12 @@ def build_call_args(
"rename the static kwarg or the dependency.",
)
injected_kwargs: Dict[str, Any] = {}
leftover_dep_results: Dict[str, Any] = dict(dep_context)
injected_kwargs: dict[str, Any] = {}
leftover_dep_results: dict[str, Any] = dict(dep_context)
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
positional_params: List[str] = []
positional_params: list[str] = []
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
@@ -109,7 +109,7 @@ def build_call_args(
if param.kind in positional_kinds:
positional_params.append(pname)
# 前 len(spec.args) 个位置参数由 spec.args 填充。
args_filled: Set[str] = set(positional_params[: len(spec.args)])
args_filled: set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items():
# 跳过已被位置 spec.args 填充的参数。
+6 -4
View File
@@ -6,7 +6,7 @@
from __future__ import annotations
from typing import Any, Iterable, Optional
from typing import Any, Iterable
class PyFlowXError(Exception):
@@ -55,10 +55,12 @@ class TaskFailedError(PyFlowXError):
task: str,
cause: BaseException,
attempts: int,
layer: Optional[int] = None,
layer: int | None = None,
) -> None:
location = f" (layer {layer})" if layer is not None else ""
super().__init__(f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}")
super().__init__(
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
)
self.task = task
self.cause = cause
self.attempts = attempts
@@ -85,6 +87,6 @@ class InjectionError(PyFlowXError):
class StorageError(PyFlowXError):
"""状态后端在持久化失败时抛出。"""
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
def __init__(self, detail: str, cause: BaseException | None = None) -> None:
super().__init__(f"State storage error: {detail}")
self.cause: Any = cause
+4 -4
View File
@@ -10,7 +10,7 @@ Shows:
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
from typing import Any
import pyflowx as px
@@ -20,13 +20,13 @@ async def fetch_user(uid: int) -> dict:
return {"id": uid, "name": f"User{uid}"}
async def fetch_posts(uid: int) -> List[int]:
async def fetch_posts(uid: int) -> list[int]:
await asyncio.sleep(0.2)
return [uid, uid + 1]
# Context annotation → receives the full mapping of upstream results.
def aggregate(ctx: px.Context) -> Dict[str, Any]:
def aggregate(ctx: px.Context) -> dict[str, Any]:
return dict(ctx)
@@ -43,7 +43,7 @@ def main() -> None:
print("=== Dry run ===")
px.run(graph, strategy="async", dry_run=True)
events: List[px.TaskEvent] = []
events: list[px.TaskEvent] = []
print("\n=== Async execution ===")
report = px.run(graph, strategy="async", on_event=events.append)
+6 -8
View File
@@ -10,21 +10,19 @@ Demonstrates the core PyFlowX workflow:
from __future__ import annotations
from typing import List
import pyflowx as px
# --- task functions: pure, testable, no framework coupling ------------- #
def extract_customers() -> List[dict]:
def extract_customers() -> list[dict]:
return [
{"id": "C001", "name": "Alice"},
{"id": "C002", "name": "Bob"},
]
def extract_orders() -> List[dict]:
def extract_orders() -> list[dict]:
return [
{"id": "O001", "customer_id": "C001", "amount": 150.0},
{"id": "O002", "customer_id": "C002", "amount": 200.5},
@@ -33,9 +31,9 @@ def extract_orders() -> List[dict]:
# Parameter names match dependency names → automatic injection.
def transform(
extract_customers: List[dict],
extract_orders: List[dict],
) -> List[dict]:
extract_customers: list[dict],
extract_orders: list[dict],
) -> list[dict]:
cmap = {c["id"]: c for c in extract_customers}
return [
{**o, "customer_name": cmap[o["customer_id"]]["name"]}
@@ -44,7 +42,7 @@ def transform(
]
def load(transform: List[dict]) -> int:
def load(transform: list[dict]) -> int:
print(f" loaded {len(transform)} records")
return len(transform)
+74 -52
View File
@@ -20,7 +20,7 @@ import enum
import inspect
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union, cast
from typing import Any, Awaitable, Callable, Mapping, cast
from .context import build_call_args, describe_injection
from .errors import TaskFailedError, TaskTimeoutError
@@ -53,7 +53,7 @@ class Strategy(enum.Enum):
ASYNC = "async"
def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
def _normalize_strategy(strategy: str | Strategy) -> Strategy:
"""将字符串或 Strategy 归一化为 Strategy 枚举.
Parameters
@@ -79,7 +79,9 @@ def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
return Strategy(strategy)
except ValueError:
valid = ", ".join(repr(s.value) for s in Strategy)
raise ValueError(f"unknown strategy {strategy!r}; expected one of {valid}.") from None
raise ValueError(
f"unknown strategy {strategy!r}; expected one of {valid}."
) from None
raise TypeError(f"strategy must be str or Strategy, got {type(strategy).__name__}")
@@ -89,7 +91,7 @@ def _is_async_fn(spec: TaskSpec[object]) -> bool:
def _emit(
on_event: Optional[EventCallback],
on_event: EventCallback | None,
result: TaskResult[object],
) -> None:
"""若注册了回调则触发一个观察者事件。"""
@@ -106,7 +108,9 @@ def _emit(
)
def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException) -> None:
def _log_retry(
spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException
) -> None:
"""记录重试日志(sync 与 async 共享,便于测试覆盖)。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
@@ -117,7 +121,7 @@ def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: Ba
)
def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None:
def _finalize_failure(result: TaskResult[object], layer_idx: int | None) -> None:
"""标记任务为 FAILED 并抛出 TaskFailedError。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
@@ -132,7 +136,7 @@ def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> N
def _run_sync_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
layer_idx: int | None,
) -> TaskResult[object]:
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
result: TaskResult[object] = TaskResult(spec=spec)
@@ -155,7 +159,7 @@ def _run_sync_with_retry(
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except Exception as exc: # noqa: BLE001 - 用户代码可能抛任何异常
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx) # pragma: no cover
@@ -166,7 +170,7 @@ def _run_sync_with_retry(
async def _run_async_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
layer_idx: int | None,
) -> TaskResult[object]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
result: TaskResult[object] = TaskResult(spec=spec)
@@ -198,7 +202,9 @@ async def _run_async_with_retry(
return spec.effective_fn(*args, **kwargs)
if spec.timeout is not None:
result.value = await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
result.value = await asyncio.wait_for(
loop.run_in_executor(None, fn_call), timeout=spec.timeout
)
else:
result.value = await loop.run_in_executor(None, fn_call)
result.status = TaskStatus.SUCCESS
@@ -214,7 +220,7 @@ async def _run_async_with_retry(
result.attempts,
max_attempts,
)
except Exception as exc: # noqa: BLE001
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts:
_finalize_failure(result, layer_idx) # pragma: no cover
@@ -230,17 +236,19 @@ def _build_context(
global_context: Mapping[str, Any],
) -> Mapping[str, Any]:
"""将全局上下文限制为本任务的依赖。"""
return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context}
return {
dep: global_context[dep] for dep in spec.depends_on if dep in global_context
}
def _execute_layer_sequential(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
"""逐个运行某层的任务。"""
for name in layer:
@@ -261,23 +269,25 @@ def _execute_layer_sequential(
def _execute_layer_threaded(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
max_workers: int,
) -> None:
"""在线程池中并发运行某层的任务。"""
# 先同步满足已缓存任务。
to_run: List[str] = []
to_run: list[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
@@ -287,7 +297,7 @@ def _execute_layer_threaded(
return
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {}
future_to_name: dict[concurrent.futures.Future[TaskResult[object]], str] = {}
for name in to_run:
spec = graph.spec(name)
# 为本任务快照上下文以避免竞态。
@@ -305,21 +315,23 @@ def _execute_layer_threaded(
async def _execute_layer_async(
layer: List[str],
layer: list[str],
graph: Graph,
context: Dict[str, Any],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
"""在事件循环上并发运行某层的任务。"""
to_run: List[str] = []
to_run: list[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
@@ -346,8 +358,8 @@ async def _execute_layer_async(
# 公共 API
# ---------------------------------------------------------------------- #
def _make_verbose_callback(
on_event: Optional[EventCallback],
) -> Optional[EventCallback]:
on_event: EventCallback | None,
) -> EventCallback | None:
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期.
Parameters
@@ -385,13 +397,13 @@ def _make_verbose_callback(
def run(
graph: Graph,
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
strategy: str | Strategy = Strategy.SEQUENTIAL,
*,
max_workers: Optional[int] = None,
max_workers: int | None = None,
dry_run: bool = False,
verbose: bool = False,
on_event: Optional[EventCallback] = None,
state: Optional[StateBackend] = None,
on_event: EventCallback | None = None,
state: StateBackend | None = None,
) -> RunReport:
"""执行图并返回 :class:`RunReport`。
@@ -434,17 +446,23 @@ def run(
return RunReport(success=True)
# verbose 模式下包装事件回调
effective_callback: Optional[EventCallback] = _make_verbose_callback(on_event) if verbose else on_event
effective_callback: EventCallback | None = (
_make_verbose_callback(on_event) if verbose else on_event
)
backend = resolve_backend(state)
report = RunReport()
context: Dict[str, Any] = {}
context: dict[str, Any] = {}
try:
if normalized == Strategy.SEQUENTIAL:
_drive_sequential(graph, layers, context, report, backend, effective_callback)
_drive_sequential(
graph, layers, context, report, backend, effective_callback
)
elif normalized == Strategy.THREAD:
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers)
_drive_threaded(
graph, layers, context, report, backend, effective_callback, max_workers
)
else:
_drive_async(graph, layers, context, report, backend, effective_callback)
except TaskFailedError:
@@ -454,7 +472,7 @@ def run(
return report
def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None:
"""打印执行计划但不运行任何任务。"""
print(f"Dry run: {len(graph)} tasks, {len(layers)} layers")
for idx, layer in enumerate(layers, 1):
@@ -465,11 +483,11 @@ def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
def _drive_sequential(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
@@ -477,36 +495,40 @@ def _drive_sequential(
def _drive_threaded(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
max_workers: Optional[int],
on_event: EventCallback | None,
max_workers: int | None,
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers)
_execute_layer_threaded(
layer, graph, context, report, backend, idx, on_event, workers
)
def _drive_async(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
async def _async_drive(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event)
await _execute_layer_async(
layer, graph, context, report, backend, idx, on_event
)
+23 -19
View File
@@ -8,7 +8,7 @@
from __future__ import annotations
import sys
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
from typing import Iterable, Mapping, Sequence
from typing_extensions import override
@@ -38,14 +38,14 @@ class Graph:
"""
def __init__(self) -> None:
self._specs: Dict[str, TaskSpec[object]] = {}
self._specs: dict[str, TaskSpec[object]] = {}
# 任务 -> 其直接依赖(前驱)。
self._deps: Dict[str, Tuple[str, ...]] = {}
self._deps: dict[str, tuple[str, ...]] = {}
# ------------------------------------------------------------------ #
# 构建
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[object]) -> "Graph":
def add(self, spec: TaskSpec[object]) -> Graph:
"""注册一个任务 spec,并即时校验。
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`
@@ -60,7 +60,7 @@ class Graph:
return self
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> Graph:
"""从可迭代的 task spec 构建图。
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
@@ -107,7 +107,7 @@ class Graph:
# 内省
# ------------------------------------------------------------------ #
@property
def names(self) -> List[str]:
def names(self) -> list[str]:
"""所有已注册任务名(按插入顺序)。"""
return list(self._specs.keys())
@@ -115,7 +115,7 @@ class Graph:
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
return self._specs[name]
def dependencies(self, name: str) -> Tuple[str, ...]:
def dependencies(self, name: str) -> tuple[str, ...]:
"""``name`` 的直接前驱。"""
return self._deps[name]
@@ -123,7 +123,7 @@ class Graph:
"""name -> spec 的只读视图。"""
return self._specs
def layers(self) -> List[List[str]]:
def layers(self) -> list[list[str]]:
"""将任务分组为可并行执行的层(Kahn 算法)。
同层任务无相互依赖,可并发执行。层按执行顺序返回。
@@ -132,7 +132,7 @@ class Graph:
"""
self.validate()
sorter = _TopologicalSorter(self._deps)
result: List[List[str]] = []
result: list[list[str]] = []
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
sorter.prepare()
while sorter.is_active():
@@ -147,19 +147,21 @@ class Graph:
# ------------------------------------------------------------------ #
# 子图 / 标签过滤
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> "Graph":
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
DAG 的切片。
"""
wanted: Set[str] = set(tags)
kept: List[TaskSpec[object]] = []
wanted: set[str] = set(tags)
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags))
d
for d in spec.depends_on
if d in self._specs and (wanted & set(self._specs[d].tags))
)
kept.append(
TaskSpec(
@@ -178,13 +180,13 @@ class Graph:
)
return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
"""返回限定于 ``names`` 的新图(边已修剪)。"""
wanted: Set[str] = set(names)
wanted: set[str] = set(names)
for n in wanted:
if n not in self._specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: List[TaskSpec[object]] = []
kept: list[TaskSpec[object]] = []
for spec in self._specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
@@ -217,8 +219,10 @@ class Graph:
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
lines: List[str] = [f"graph {orientation}"]
raise ValueError(
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
)
lines: list[str] = [f"graph {orientation}"]
for name in self._specs:
lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items():
@@ -231,7 +235,7 @@ class Graph:
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""用于调试的人类可读多行摘要。"""
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
+6 -6
View File
@@ -7,7 +7,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List
from typing import Any, Iterator
from .task import TaskResult, TaskStatus
@@ -24,7 +24,7 @@ class RunReport:
当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。
"""
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
results: dict[str, TaskResult[object]] = field(default_factory=dict)
success: bool = True
# ---- 类型化访问 --------------------------------------------------- #
@@ -50,9 +50,9 @@ class RunReport:
return len(self.results)
# ---- 汇总 --------------------------------------------------------- #
def summary(self) -> Dict[str, Any]:
def summary(self) -> dict[str, Any]:
"""用于日志/仪表盘的紧凑统计字典。"""
counts: Dict[str, int] = {}
counts: dict[str, int] = {}
total_duration = 0.0
for r in self.results.values():
counts[r.status.value] = counts.get(r.status.value, 0) + 1
@@ -65,7 +65,7 @@ class RunReport:
"total_duration_seconds": round(total_duration, 6),
}
def failed_tasks(self) -> List[str]:
def failed_tasks(self) -> list[str]:
"""以 FAILED 状态结束的任务名列表。"""
return [
name for name, r in self.results.items() if r.status == TaskStatus.FAILED
@@ -73,7 +73,7 @@ class RunReport:
def describe(self) -> str:
"""用于调试的人类可读多行报告。"""
lines: List[str] = [f"RunReport(success={self.success})"]
lines: list[str] = [f"RunReport(success={self.success})"]
for name, r in self.results.items():
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
err = f" error={r.error!r}" if r.error else ""
+16 -12
View File
@@ -23,13 +23,13 @@ import argparse
import dataclasses
import enum
import sys
from typing import Dict, List, Optional, Sequence, Union
from typing import Sequence
from .errors import PyFlowXError
from .executors import Strategy, _normalize_strategy, run
from .graph import Graph
__all__ = ["CliRunner", "CliExitCode"]
__all__ = ["CliExitCode", "CliRunner"]
class CliExitCode(enum.IntEnum):
@@ -92,12 +92,16 @@ class CliRunner:
基本用法::
runner = px.CliRunner(
clean=px.Graph.from_specs([
clean=px.Graph.from_specs(
[
px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]),
]),
build=px.Graph.from_specs([
]
),
build=px.Graph.from_specs(
[
px.TaskSpec("uv_build", cmd=["uv", "build"]),
]),
]
),
)
runner.run() # 解析 sys.argv
@@ -114,7 +118,7 @@ class CliRunner:
def __init__(
self,
*,
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
strategy: str | Strategy = Strategy.SEQUENTIAL,
description: str = "",
verbose: bool = True,
**graphs: Graph,
@@ -127,7 +131,7 @@ class CliRunner:
raise TypeError(
f"CliRunner 命令 {name!r} 的值必须是 Graph 实例, 实际是 {type(graph).__name__}"
)
self._graphs: Dict[str, Graph] = dict(graphs)
self._graphs: dict[str, Graph] = dict(graphs)
self._strategy: Strategy = _normalize_strategy(strategy)
self._description: str = description
self._verbose: bool = verbose
@@ -136,12 +140,12 @@ class CliRunner:
# 内省
# ------------------------------------------------------------------ #
@property
def commands(self) -> List[str]:
def commands(self) -> list[str]:
"""可用的命令列表 (按插入顺序)."""
return list(self._graphs.keys())
@property
def graphs(self) -> Dict[str, Graph]:
def graphs(self) -> dict[str, Graph]:
"""命令名到图的映射 (只读副本)."""
return dict(self._graphs)
@@ -225,7 +229,7 @@ class CliRunner:
# ------------------------------------------------------------------ #
# 执行
# ------------------------------------------------------------------ #
def run(self, args: Optional[Sequence[str]] = None) -> int:
def run(self, args: Sequence[str] | None = None) -> int:
"""解析参数并执行对应的图.
Parameters
@@ -293,7 +297,7 @@ class CliRunner:
print(f"错误: {e}", file=sys.stderr)
return CliExitCode.FAILURE.value
def run_cli(self, args: Optional[Sequence[str]] = None) -> None:
def run_cli(self, args: Sequence[str] | None = None) -> None:
"""运行并以退出码退出进程.
作为 CLI 工具运行时的入口点, 等价于 ``sys.exit(self.run(args))``.
+24 -12
View File
@@ -17,9 +17,11 @@
from __future__ import annotations
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Optional
from pathlib import Path
from typing import Any, Mapping
from typing_extensions import override
from .errors import StorageError
@@ -52,20 +54,25 @@ class MemoryBackend(StateBackend):
"""进程内 dict 后端。进程退出即丢失。"""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
self._store: dict[str, Any] = {}
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
@override
def save(self, name: str, value: Any) -> None:
self._store[name] = value
@override
def has(self, name: str) -> bool:
return name in self._store
@override
def get(self, name: str) -> Any:
return self._store[name]
@override
def clear(self) -> None:
self._store.clear()
@@ -79,16 +86,16 @@ class JSONBackend(StateBackend):
"""
def __init__(self, path: str) -> None:
self._path = path
self._store: Dict[str, Any] = {}
self._path: str = path
self._store: dict[str, Any] = {}
self._load()
def _load(self) -> None:
if not os.path.exists(self._path):
if not Path(self._path).exists():
return
try:
with open(self._path, "r", encoding="utf-8") as fh:
data = json.load(fh)
with open(self._path, encoding="utf-8") as fh:
data: Any = json.load(fh)
if isinstance(data, dict):
self._store = data
except (OSError, json.JSONDecodeError) as exc:
@@ -99,13 +106,15 @@ class JSONBackend(StateBackend):
try:
with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2)
os.replace(tmp, self._path)
Path(tmp).replace(Path(self._path))
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
@override
def load(self) -> Mapping[str, Any]:
return dict(self._store)
@override
def save(self, name: str, value: Any) -> None:
# 在修改内存状态前先校验可序列化性。
try:
@@ -115,19 +124,22 @@ class JSONBackend(StateBackend):
f"result of task {name!r} is not JSON-serialisable", exc
) from exc
self._store[name] = value
self._flush()
_ = self._flush()
@override
def has(self, name: str) -> bool:
return name in self._store
@override
def get(self, name: str) -> Any:
return self._store[name]
@override
def clear(self) -> None:
self._store.clear()
self._flush()
_ = self._flush()
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
def resolve_backend(backend: StateBackend | None) -> StateBackend:
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
return backend if backend is not None else MemoryBackend()
+16 -11
View File
@@ -30,6 +30,7 @@ from typing import (
Tuple,
TypeVar,
Union,
cast,
)
T = TypeVar("T")
@@ -48,7 +49,7 @@ Context = Mapping[str, Any]
TaskCmd = Union[
List[str], # 命令列表, 如 ["ls", "-la"]
str, # shell 命令字符串
Callable[..., T], # Python 函数
Callable[..., Any], # Python 函数
]
# 条件判断函数类型
@@ -151,12 +152,12 @@ class TaskSpec(Generic[T]):
return self.fn
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。")
def _wrap_cmd(self) -> TaskFn[T]:
def _wrap_cmd(self) -> TaskFn[Any]:
"""将 cmd 包装为可执行函数.
Returns
-------
TaskFn[T]
TaskFn[Any]
包装后的执行函数.
"""
cmd = self.cmd
@@ -184,17 +185,19 @@ class TaskSpec(Generic[T]):
check=False,
)
except FileNotFoundError:
raise RuntimeError(f"命令未找到: {cmd_str}")
raise RuntimeError(f"命令未找到: {cmd_str}") from None
except subprocess.TimeoutExpired:
raise RuntimeError(f"命令执行超时: {cmd_str} ({timeout}s)")
raise RuntimeError(
f"命令执行超时: {cmd_str} ({timeout}s)"
) from None
except OSError as e:
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}")
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}") from e
if verbose:
print(f"[verbose] 返回码: {result.returncode}", flush=True)
if result.returncode == 0:
return None # type: ignore[return-value]
return cast(T, None) # type: ignore[return-value]
err_msg = f"命令执行失败: `{cmd_str}`, 返回码: {result.returncode}"
if not verbose and result.stderr.strip():
@@ -224,17 +227,19 @@ class TaskSpec(Generic[T]):
check=False,
)
except FileNotFoundError:
raise RuntimeError(f"Shell 命令未找到: {cmd}")
raise RuntimeError(f"Shell 命令未找到: {cmd}") from None
except subprocess.TimeoutExpired:
raise RuntimeError(f"Shell 命令执行超时: {cmd} ({timeout}s)")
raise RuntimeError(
f"Shell 命令执行超时: {cmd} ({timeout}s)"
) from None
except OSError as e:
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}")
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}") from e
if verbose:
print(f"[verbose] 返回码: {result.returncode}", flush=True)
if result.returncode == 0:
return None # type: ignore[return-value]
return cast(T, None) # type: ignore[return-value]
err_msg = f"Shell 命令执行失败: `{cmd}`, 返回码: {result.returncode}"
if not verbose and result.stderr.strip():
+165
View File
@@ -0,0 +1,165 @@
"""Tests for pymake CLI."""
from pyflowx.cli.pymake import _build_graphs, _get_maturin_build_command, conf
def test_pymake_config_attributes():
"""Test PymakeConfig has expected attributes."""
assert hasattr(conf, "PROJECT_ROOT")
assert hasattr(conf, "BUILD_TOOL")
assert hasattr(conf, "BUILD_COMMAND")
assert hasattr(conf, "MATURIN_TOOL")
assert hasattr(conf, "MATURIN_BUILD_COMMAND")
assert hasattr(conf, "MATURIN_DEV_COMMAND")
assert hasattr(conf, "TIMEOUT")
def test_pymake_config_values():
"""Test PymakeConfig values are correct."""
assert conf.BUILD_TOOL == "uv"
assert conf.BUILD_COMMAND == ["uv", "build"]
assert conf.MATURIN_TOOL == "maturin"
assert conf.TIMEOUT == 600
def test_get_maturin_build_command_basic():
"""Test _get_maturin_build_command returns base command."""
cmd = _get_maturin_build_command()
assert "maturin" in cmd
assert "build" in cmd
assert "-r" in cmd
def test_build_graphs_returns_dict():
"""Test _build_graphs returns a dictionary."""
graphs = _build_graphs()
assert isinstance(graphs, dict)
assert len(graphs) > 0
def test_build_graphs_has_expected_commands():
"""Test _build_graphs has expected command keys."""
graphs = _build_graphs()
expected_commands = [
"b",
"bc",
"ba",
"ic",
"ip",
"ia",
"cp",
"cc",
"ca",
"t",
"lint",
]
for cmd in expected_commands:
assert cmd in graphs, f"Expected command '{cmd}' not found in graphs"
def test_build_graphs_values_are_graphs():
"""Test _build_graphs values are Graph instances."""
import pyflowx as px
graphs = _build_graphs()
for name, graph in graphs.items():
assert isinstance(graph, px.Graph), (
f"Graph for command '{name}' is not a Graph instance"
)
def test_build_command_graph_structure():
"""Test 'b' command graph has correct structure."""
graphs = _build_graphs()
graph = graphs["b"]
assert len(graph.all_specs()) == 1
spec = graph.spec("uv_build")
assert spec.cmd == conf.BUILD_COMMAND
def test_build_all_command_graph_structure():
"""Test 'ba' command graph has correct dependencies."""
graphs = _build_graphs()
graph = graphs["ba"]
specs = graph.all_specs()
assert len(specs) == 2
# Check dependency
uv_build_spec = graph.spec("uv_build")
assert "maturin_build" in uv_build_spec.depends_on
def test_maturin_build_command_graph_structure():
"""Test 'bc' command graph has correct structure."""
graphs = _build_graphs()
graph = graphs["bc"]
specs = graph.all_specs()
assert len(specs) == 1
spec = graph.spec("maturin_build")
assert spec.cmd == _get_maturin_build_command()
def test_install_all_command_graph_structure():
"""Test 'ia' command graph has correct dependencies."""
graphs = _build_graphs()
graph = graphs["ia"]
specs = graph.all_specs()
assert len(specs) == 2
uv_install_spec = graph.spec("uv_install")
assert "maturin_dev" in uv_install_spec.depends_on
def test_clean_all_command_graph_structure():
"""Test 'ca' command graph has correct structure."""
graphs = _build_graphs()
graph = graphs["ca"]
specs = graph.all_specs()
assert len(specs) == 2
def test_test_command_graph_structure():
"""Test 't' command graph has correct structure."""
graphs = _build_graphs()
graph = graphs["t"]
specs = graph.all_specs()
assert len(specs) == 1
spec = graph.spec("pytest")
assert "pytest" in spec.cmd
def test_lint_command_graph_structure():
"""Test 'lint' command graph has correct structure."""
graphs = _build_graphs()
graph = graphs["lint"]
specs = graph.all_specs()
assert len(specs) == 1
spec = graph.spec("ruff_check")
assert "ruff" in spec.cmd
def test_pymake_config_dirs_to_ignore():
"""Test PymakeConfig has correct dirs to ignore."""
assert ".venv" in conf.DIRS_TO_IGNORE
assert ".git" in conf.DIRS_TO_IGNORE
assert ".tox" in conf.DIRS_TO_IGNORE
def test_pymake_config_python_build_dirs():
"""Test PymakeConfig has correct Python build dirs."""
assert "dist" in conf.PYTHON_BUILD_DIRS
assert "build" in conf.PYTHON_BUILD_DIRS
def test_maturin_build_options_win7():
"""Test MATURIN_BUILD_OPTIONS_WIN7 has expected options."""
assert "--target" in conf.MATURIN_BUILD_OPTIONS_WIN7
assert "x86_64-win7-windows-msvc" in conf.MATURIN_BUILD_OPTIONS_WIN7
assert "-Zbuild-std" in conf.MATURIN_BUILD_OPTIONS_WIN7
def test_doc_build_command():
"""Test DOC_BUILD_COMMAND has expected structure."""
assert "sphinx-build" in conf.DOC_BUILD_COMMAND
assert "-b" in conf.DOC_BUILD_COMMAND
assert "html" in conf.DOC_BUILD_COMMAND
+178
View File
@@ -0,0 +1,178 @@
"""Tests for conditions module."""
import os
import sys
from unittest.mock import patch
from pyflowx.conditions import (
IS_LINUX,
IS_MACOS,
IS_POSIX,
IS_WINDOWS,
BuiltinConditions,
Constants,
)
def test_constants_is_windows():
"""Test Constants.IS_WINDOWS is correct."""
assert (sys.platform == "win32") == Constants.IS_WINDOWS
def test_constants_is_linux():
"""Test Constants.IS_LINUX is correct."""
assert (sys.platform == "linux") == Constants.IS_LINUX
def test_constants_is_macos():
"""Test Constants.IS_MACOS is correct."""
assert (sys.platform == "darwin") == Constants.IS_MACOS
def test_constants_is_posix():
"""Test Constants.IS_POSIX is correct."""
assert (sys.platform != "win32") == Constants.IS_POSIX
def test_builtin_conditions_is_windows():
"""Test BuiltinConditions.IS_WINDOWS."""
result = BuiltinConditions.IS_WINDOWS()
assert result == Constants.IS_WINDOWS
def test_builtin_conditions_is_linux():
"""Test BuiltinConditions.IS_LINUX."""
result = BuiltinConditions.IS_LINUX()
assert result == Constants.IS_LINUX
def test_builtin_conditions_is_macos():
"""Test BuiltinConditions.IS_MACOS."""
result = BuiltinConditions.IS_MACOS()
assert result == Constants.IS_MACOS
def test_builtin_conditions_is_posix():
"""Test BuiltinConditions.IS_POSIX."""
result = BuiltinConditions.IS_POSIX()
assert result == Constants.IS_POSIX
def test_builtin_conditions_python_version_major_only():
"""Test BuiltinConditions.PYTHON_VERSION with major only."""
# Test with current Python version
current_major = sys.version_info.major
assert BuiltinConditions.PYTHON_VERSION(current_major) is True
assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False
def test_builtin_conditions_python_version_with_minor():
"""Test BuiltinConditions.PYTHON_VERSION with major and minor."""
current_major = sys.version_info.major
current_minor = sys.version_info.minor
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False
def test_builtin_conditions_python_version_at_least():
"""Test BuiltinConditions.PYTHON_VERSION_AT_LEAST."""
current_major = sys.version_info.major
current_minor = sys.version_info.minor
# Current version should be at least itself
assert (
BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True
)
# Current version should be at least an older version
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True
# Current version should NOT be at least a newer version
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False
def test_builtin_conditions_has_app_installed_true():
"""Test BuiltinConditions.HAS_APP_INSTALLED when app exists."""
# Python should always be available
condition = BuiltinConditions.HAS_APP_INSTALLED("python")
assert condition() is True
def test_builtin_conditions_has_app_installed_false():
"""Test BuiltinConditions.HAS_APP_INSTALLED when app doesn't exist."""
condition = BuiltinConditions.HAS_APP_INSTALLED("nonexistent_app_12345")
assert condition() is False
def test_builtin_conditions_env_var_exists_true():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable exists."""
with patch.dict(os.environ, {"TEST_VAR": "value"}):
condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR")
assert condition() is True
def test_builtin_conditions_env_var_exists_false():
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist."""
condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345")
assert condition() is False
def test_builtin_conditions_env_var_equals_true():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value matches."""
with patch.dict(os.environ, {"TEST_VAR": "expected_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is True
def test_builtin_conditions_env_var_equals_false():
"""Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match."""
with patch.dict(os.environ, {"TEST_VAR": "different_value"}):
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
assert condition() is False
def test_builtin_conditions_not():
"""Test BuiltinConditions.NOT."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
not_true = BuiltinConditions.NOT(true_condition)
assert not_true() is False
not_false = BuiltinConditions.NOT(false_condition)
assert not_false() is True
def test_builtin_conditions_and_all_true():
"""Test BuiltinConditions.AND when all conditions are true."""
true_condition = lambda: True # noqa: E731
condition = BuiltinConditions.AND(true_condition, true_condition, true_condition)
assert condition() is True
def test_builtin_conditions_and_one_false():
"""Test BuiltinConditions.AND when one condition is false."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.AND(true_condition, false_condition, true_condition)
assert condition() is False
def test_builtin_conditions_or_all_false():
"""Test BuiltinConditions.OR when all conditions are false."""
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.OR(false_condition, false_condition, false_condition)
assert condition() is False
def test_builtin_conditions_or_one_true():
"""Test BuiltinConditions.OR when one condition is true."""
true_condition = lambda: True # noqa: E731
false_condition = lambda: False # noqa: E731
condition = BuiltinConditions.OR(false_condition, true_condition, false_condition)
assert condition() is True
def test_exported_conditions():
"""Test exported condition functions."""
assert IS_WINDOWS() == Constants.IS_WINDOWS
assert IS_LINUX() == Constants.IS_LINUX
assert IS_MACOS() == Constants.IS_MACOS
assert IS_POSIX() == Constants.IS_POSIX
+189
View File
@@ -0,0 +1,189 @@
"""Tests for executors module edge cases."""
import asyncio
import sys
import pytest
import pyflowx as px
from pyflowx.task import TaskStatus
# 跨平台的 echo 命令
if sys.platform == "win32":
ECHO_CMD = ["cmd", "/c", "echo"]
else:
ECHO_CMD = ["echo"]
def test_execute_sync_with_timeout():
"""Test execute task with timeout correctly."""
# Note: timeout for Python functions only works in async strategy
# For sync functions, timeout is not enforced in sequential strategy
# This test verifies that the task runs without timeout error
spec = px.TaskSpec("quick", fn=lambda: "result", timeout=10)
graph = px.Graph.from_specs([spec])
# Should succeed without timeout error
report = px.run(graph, strategy="sequential")
assert report.success
def test_execute_async_with_timeout():
"""Test execute async task with timeout correctly."""
async def slow_async_function():
await asyncio.sleep(2)
return "result"
spec = px.TaskSpec("slow_async", fn=slow_async_function, timeout=0.5)
graph = px.Graph.from_specs([spec])
# This should timeout
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="async")
def test_verbose_event_callback_running():
"""Test verbose event callback for RUNNING status."""
# Create a graph with verbose callback
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential")
# Should print without error
assert report.success
def test_verbose_event_callback_success():
"""Test verbose event callback for SUCCESS status."""
# Create a graph with verbose callback
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential")
# Should print without error
assert report.success
def test_verbose_event_callback_failed():
"""Test verbose event callback for FAILED status."""
# Create a graph with verbose callback and failing task
def raise_error():
raise ValueError("test error")
spec = px.TaskSpec("test", fn=raise_error, verbose=True)
graph = px.Graph.from_specs([spec])
# Should print without error
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential")
def test_verbose_event_callback_skipped():
"""Test verbose event callback for SKIPPED status."""
# Create a graph with verbose callback and skipped task
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: False,),
verbose=True,
)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential")
# Should print without error
assert report.success
def test_execute_sync_with_retries():
"""Test execute task with retries."""
call_count = 0
def failing_function():
nonlocal call_count
call_count += 1
if call_count < 3:
raise ValueError("temporary error")
return "success"
spec = px.TaskSpec("retry_test", fn=failing_function, retries=3)
graph = px.Graph.from_specs([spec])
# Should succeed after retries
report = px.run(graph, strategy="sequential")
assert report.success
assert report.results["retry_test"].attempts == 3
def test_execute_async_with_retries():
"""Test execute async task with retries."""
call_count = 0
async def failing_async_function():
nonlocal call_count
call_count += 1
if call_count < 3:
raise ValueError("temporary error")
return "success"
spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3)
graph = px.Graph.from_specs([spec])
# Should succeed after retries
report = px.run(graph, strategy="async")
assert report.success
assert report.results["retry_async_test"].attempts == 3
def test_execute_sync_skip_on_condition():
"""Test execute task skips task when condition is false."""
spec = px.TaskSpec(
"skip_test",
fn=lambda: "result",
conditions=(lambda: False,),
)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential")
assert report.success
assert report.results["skip_test"].status == TaskStatus.SKIPPED
def test_execute_async_skip_on_condition():
"""Test execute async task skips task when condition is false."""
spec = px.TaskSpec(
"skip_async_test",
fn=lambda: "result",
conditions=(lambda: False,),
)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="async")
assert report.success
assert report.results["skip_async_test"].status == TaskStatus.SKIPPED
def test_execute_sync_with_error():
"""Test execute task handles errors correctly."""
def error_function():
raise ValueError("test error")
spec = px.TaskSpec("error_test", fn=error_function)
graph = px.Graph.from_specs([spec])
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential")
def test_execute_async_with_error():
"""Test execute async task handles errors correctly."""
async def error_async_function():
raise ValueError("test error")
spec = px.TaskSpec("error_async_test", fn=error_async_function)
graph = px.Graph.from_specs([spec])
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="async")
+156
View File
@@ -0,0 +1,156 @@
"""Tests for task module edge cases."""
import sys
import tempfile
import pytest
import pyflowx as px
from pyflowx.task import TaskSpec
# 跨平台的 echo 命令
if sys.platform == "win32":
ECHO_CMD = ["cmd", "/c", "echo"]
else:
ECHO_CMD = ["echo"]
def test_taskspec_wrap_cmd_with_list():
"""Test TaskSpec._wrap_cmd with command list."""
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"])
wrapped_fn = spec.effective_fn
assert wrapped_fn is not None
assert wrapped_fn.__name__ == "test"
def test_taskspec_wrap_cmd_with_string():
"""Test TaskSpec._wrap_cmd with command string."""
if sys.platform == "win32":
cmd_str = "cmd /c echo hello"
else:
cmd_str = "echo hello"
spec = TaskSpec("test", cmd=cmd_str)
wrapped_fn = spec.effective_fn
assert wrapped_fn is not None
assert wrapped_fn.__name__ == "test"
def test_taskspec_wrap_cmd_with_timeout():
"""Test TaskSpec._wrap_cmd with timeout."""
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], timeout=0.1)
wrapped_fn = spec.effective_fn
# Should not raise timeout error for quick command
result = wrapped_fn()
assert result is None
def test_taskspec_wrap_cmd_with_cwd():
"""Test TaskSpec._wrap_cmd with working directory."""
with tempfile.TemporaryDirectory() as tmpdir:
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], cwd=tmpdir)
wrapped_fn = spec.effective_fn
result = wrapped_fn()
assert result is None
def test_taskspec_wrap_cmd_verbose():
"""Test TaskSpec._wrap_cmd with verbose=True."""
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], verbose=True)
wrapped_fn = spec.effective_fn
# Should print verbose output
result = wrapped_fn()
assert result is None
def test_taskspec_wrap_cmd_error():
"""Test TaskSpec._wrap_cmd handles command error."""
spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"])
wrapped_fn = spec.effective_fn
with pytest.raises(RuntimeError, match="命令执行失败"):
wrapped_fn()
def test_taskspec_wrap_cmd_file_not_found():
"""Test TaskSpec._wrap_cmd handles file not found."""
spec = TaskSpec("test", cmd=["nonexistent_command"])
wrapped_fn = spec.effective_fn
with pytest.raises(RuntimeError, match="命令未找到"):
wrapped_fn()
def test_taskspec_wrap_cmd_shell_file_not_found():
"""Test TaskSpec._wrap_cmd handles shell command file not found."""
spec = TaskSpec("test", cmd="nonexistent_shell_command")
wrapped_fn = spec.effective_fn
# Shell commands don't raise FileNotFoundError
# They just return non-zero exit code
with pytest.raises(RuntimeError):
wrapped_fn()
def test_taskspec_no_fn_no_cmd():
"""Test TaskSpec raises error when no fn or cmd."""
with pytest.raises(ValueError, match="必须提供 fn 或 cmd 参数"):
TaskSpec("test")
def test_taskspec_cmd_overrides_fn():
"""Test TaskSpec cmd overrides fn."""
def my_fn():
return "fn_result"
spec = TaskSpec("test", fn=my_fn, cmd=[*ECHO_CMD, "hello"])
wrapped_fn = spec.effective_fn
# cmd should override fn
assert wrapped_fn.__name__ == "test"
def test_taskspec_conditions_check():
"""Test TaskSpec.should_execute with conditions."""
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True,),
)
assert spec.should_execute() is True
def test_taskspec_conditions_false():
"""Test TaskSpec.should_execute with false conditions."""
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: False,),
)
assert spec.should_execute() is False
def test_taskspec_conditions_multiple():
"""Test TaskSpec.should_execute with multiple conditions."""
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True, lambda: True, lambda: True),
)
assert spec.should_execute() is True
def test_taskspec_conditions_multiple_one_false():
"""Test TaskSpec.should_execute with one false condition."""
spec = px.TaskSpec(
"test",
fn=lambda: "result",
conditions=(lambda: True, lambda: False, lambda: True),
)
assert spec.should_execute() is False
Generated
+4
View File
@@ -1907,6 +1907,8 @@ version = "0.1.2"
source = { editable = "." }
dependencies = [
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
{ name = "typing-extensions", version = "4.13.2", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version < '3.9'" },
{ name = "typing-extensions", version = "4.15.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version >= '3.9'" },
]
[package.optional-dependencies]
@@ -1961,6 +1963,7 @@ requires-dist = [
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" },
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.25.0" },
{ name = "tox-uv", marker = "extra == 'dev'", specifier = ">=1.13.1" },
{ name = "typing-extensions", specifier = ">=4.13.2" },
]
provides-extras = ["dev"]
@@ -2759,6 +2762,7 @@ name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
resolution-markers = [
"python_full_version >= '3.15'",
"python_full_version >= '3.10' and python_full_version < '3.15'",
"python_full_version > '3.9' and python_full_version < '3.10'",
"python_full_version == '3.9'",