refactor: 全面迁移至 Python 3.9+ 原生泛型类型语法
- 将所有 `Optional[T]` 替换为 `T | None` - 将所有 `List[T]`/`Dict[K, V]`/`Tuple[Ts, ...]` 替换为对应原生泛型 - 调整类型导入,移除冗余的 typing 导入项 - 更新项目依赖,添加 typing-extensions 兼容旧版本 Python - 重构部分函数签名与内部实现以匹配新类型语法
This commit is contained in:
+47
-1
@@ -10,7 +10,10 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Software Development :: Libraries :: Application Frameworks",
|
||||
]
|
||||
dependencies = ["graphlib_backport >= 1.0.0; python_version < '3.9'"]
|
||||
dependencies = [
|
||||
"graphlib_backport >= 1.0.0; python_version < '3.9'",
|
||||
"typing-extensions>=4.13.2",
|
||||
]
|
||||
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
|
||||
keywords = ["async", "dag", "scheduler", "task", "workflow"]
|
||||
license = { text = "MIT" }
|
||||
@@ -86,3 +89,46 @@ reportImplicitStringConcatenation = "error"
|
||||
reportMissingTypeStubs = "none"
|
||||
reportUnusedCallResult = "warning"
|
||||
typeCheckingMode = "recommended" # 类型检查严格度:off / basic / standard / recommended(默认) / strict / all
|
||||
|
||||
# Ruff 配置 - 与 .pre-commit-config.yaml 保持一致
|
||||
[tool.ruff]
|
||||
target-version = "py38"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
"ARG", # flake8-unused-arguments
|
||||
"SIM", # flake8-simplify
|
||||
"PTH", # flake8-use-pathlib
|
||||
"PL", # Pylint
|
||||
"RUF", # Ruff-specific rules
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"PLR0913", # too many arguments
|
||||
"PLR2004", # magic value comparison
|
||||
"PTH123", # pathlib open() replacement
|
||||
"SIM108", # use ternary operator
|
||||
"RUF001", # ambiguous unicode characters in string
|
||||
"RUF002", # ambiguous unicode characters in docstring
|
||||
"RUF003", # ambiguous unicode characters in comment
|
||||
"RUF012", # mutable class attributes (intentional for config)
|
||||
"PLC0415", # import should be at top-level (intentional for lazy imports)
|
||||
"PLR0915", # too many statements (intentional for complex methods)
|
||||
"PTH119", # os.path.basename (intentional for sys.argv)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["pyflowx"]
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
indent-style = "space"
|
||||
docstring-code-format = true
|
||||
|
||||
+33
-33
@@ -81,43 +81,43 @@ from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus
|
||||
__version__ = "0.1.2"
|
||||
|
||||
__all__ = [
|
||||
# 核心类型
|
||||
"TaskSpec",
|
||||
"TaskStatus",
|
||||
"TaskResult",
|
||||
"TaskEvent",
|
||||
"Context",
|
||||
"TaskCmd",
|
||||
"Graph",
|
||||
"RunReport",
|
||||
# 执行
|
||||
"run",
|
||||
"Strategy",
|
||||
# CLI 运行器
|
||||
"CliRunner",
|
||||
"CliExitCode",
|
||||
# 状态后端
|
||||
"StateBackend",
|
||||
"MemoryBackend",
|
||||
"JSONBackend",
|
||||
# 错误
|
||||
"PyFlowXError",
|
||||
"DuplicateTaskError",
|
||||
"MissingDependencyError",
|
||||
"CycleError",
|
||||
"TaskFailedError",
|
||||
"TaskTimeoutError",
|
||||
"InjectionError",
|
||||
"StorageError",
|
||||
# 条件判断
|
||||
"Condition",
|
||||
"Constants",
|
||||
"BuiltinConditions",
|
||||
"IS_WINDOWS",
|
||||
"IS_LINUX",
|
||||
"IS_MACOS",
|
||||
"IS_POSIX",
|
||||
"IS_WINDOWS",
|
||||
"BuiltinConditions",
|
||||
"CliExitCode",
|
||||
# CLI 运行器
|
||||
"CliRunner",
|
||||
# 条件判断
|
||||
"Condition",
|
||||
"Constants",
|
||||
"Context",
|
||||
"CycleError",
|
||||
"DuplicateTaskError",
|
||||
"Graph",
|
||||
"InjectionError",
|
||||
"JSONBackend",
|
||||
"MemoryBackend",
|
||||
"MissingDependencyError",
|
||||
# 错误
|
||||
"PyFlowXError",
|
||||
"RunReport",
|
||||
# 状态后端
|
||||
"StateBackend",
|
||||
"StorageError",
|
||||
"Strategy",
|
||||
"TaskCmd",
|
||||
"TaskEvent",
|
||||
"TaskFailedError",
|
||||
"TaskResult",
|
||||
# 核心类型
|
||||
"TaskSpec",
|
||||
"TaskStatus",
|
||||
"TaskTimeoutError",
|
||||
# 辅助(高级)
|
||||
"build_call_args",
|
||||
"describe_injection",
|
||||
# 执行
|
||||
"run",
|
||||
]
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import sys
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable
|
||||
|
||||
# 条件判断函数类型
|
||||
Condition = Callable[[], bool]
|
||||
@@ -47,7 +47,7 @@ class BuiltinConditions:
|
||||
return Constants.IS_POSIX
|
||||
|
||||
@staticmethod
|
||||
def PYTHON_VERSION(major: int, minor: Optional[int] = None) -> bool:
|
||||
def PYTHON_VERSION(major: int, minor: int | None = None) -> bool:
|
||||
"""检查 Python 版本是否匹配.
|
||||
|
||||
Parameters
|
||||
|
||||
+11
-11
@@ -18,12 +18,12 @@ DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Mapping, Set, Tuple
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .errors import InjectionError
|
||||
from .task import Context, TaskSpec
|
||||
|
||||
__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"]
|
||||
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
||||
|
||||
|
||||
def _is_context_annotation(annotation: Any) -> bool:
|
||||
@@ -43,15 +43,13 @@ def _is_context_annotation(annotation: Any) -> bool:
|
||||
return annotation == "Context" or annotation.endswith(".Context")
|
||||
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
|
||||
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
|
||||
if name in ("Context", "Mapping"):
|
||||
return True
|
||||
return False
|
||||
return name in ("Context", "Mapping")
|
||||
|
||||
|
||||
def build_call_args(
|
||||
spec: TaskSpec[object],
|
||||
context: Mapping[str, Any],
|
||||
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
|
||||
|
||||
参数
|
||||
@@ -84,7 +82,9 @@ def build_call_args(
|
||||
)
|
||||
|
||||
# 与本任务相关的上下文子集。
|
||||
dep_context: Dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context}
|
||||
dep_context: dict[str, Any] = {
|
||||
name: context[name] for name in spec.depends_on if name in context
|
||||
}
|
||||
|
||||
# 检测静态 kwargs 与依赖名的冲突。
|
||||
collisions = set(spec.kwargs) & set(dep_context)
|
||||
@@ -95,12 +95,12 @@ def build_call_args(
|
||||
"rename the static kwarg or the dependency.",
|
||||
)
|
||||
|
||||
injected_kwargs: Dict[str, Any] = {}
|
||||
leftover_dep_results: Dict[str, Any] = dict(dep_context)
|
||||
injected_kwargs: dict[str, Any] = {}
|
||||
leftover_dep_results: dict[str, Any] = dict(dep_context)
|
||||
|
||||
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
|
||||
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
|
||||
positional_params: List[str] = []
|
||||
positional_params: list[str] = []
|
||||
positional_kinds = (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
@@ -109,7 +109,7 @@ def build_call_args(
|
||||
if param.kind in positional_kinds:
|
||||
positional_params.append(pname)
|
||||
# 前 len(spec.args) 个位置参数由 spec.args 填充。
|
||||
args_filled: Set[str] = set(positional_params[: len(spec.args)])
|
||||
args_filled: set[str] = set(positional_params[: len(spec.args)])
|
||||
|
||||
for pname, param in params.items():
|
||||
# 跳过已被位置 spec.args 填充的参数。
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Optional
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
class PyFlowXError(Exception):
|
||||
@@ -55,10 +55,12 @@ class TaskFailedError(PyFlowXError):
|
||||
task: str,
|
||||
cause: BaseException,
|
||||
attempts: int,
|
||||
layer: Optional[int] = None,
|
||||
layer: int | None = None,
|
||||
) -> None:
|
||||
location = f" (layer {layer})" if layer is not None else ""
|
||||
super().__init__(f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}")
|
||||
super().__init__(
|
||||
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
|
||||
)
|
||||
self.task = task
|
||||
self.cause = cause
|
||||
self.attempts = attempts
|
||||
@@ -85,6 +87,6 @@ class InjectionError(PyFlowXError):
|
||||
class StorageError(PyFlowXError):
|
||||
"""状态后端在持久化失败时抛出。"""
|
||||
|
||||
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
|
||||
def __init__(self, detail: str, cause: BaseException | None = None) -> None:
|
||||
super().__init__(f"State storage error: {detail}")
|
||||
self.cause: Any = cause
|
||||
|
||||
@@ -10,7 +10,7 @@ Shows:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
@@ -20,13 +20,13 @@ async def fetch_user(uid: int) -> dict:
|
||||
return {"id": uid, "name": f"User{uid}"}
|
||||
|
||||
|
||||
async def fetch_posts(uid: int) -> List[int]:
|
||||
async def fetch_posts(uid: int) -> list[int]:
|
||||
await asyncio.sleep(0.2)
|
||||
return [uid, uid + 1]
|
||||
|
||||
|
||||
# Context annotation → receives the full mapping of upstream results.
|
||||
def aggregate(ctx: px.Context) -> Dict[str, Any]:
|
||||
def aggregate(ctx: px.Context) -> dict[str, Any]:
|
||||
return dict(ctx)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ def main() -> None:
|
||||
print("=== Dry run ===")
|
||||
px.run(graph, strategy="async", dry_run=True)
|
||||
|
||||
events: List[px.TaskEvent] = []
|
||||
events: list[px.TaskEvent] = []
|
||||
print("\n=== Async execution ===")
|
||||
report = px.run(graph, strategy="async", on_event=events.append)
|
||||
|
||||
|
||||
@@ -10,21 +10,19 @@ Demonstrates the core PyFlowX workflow:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# --- task functions: pure, testable, no framework coupling ------------- #
|
||||
|
||||
|
||||
def extract_customers() -> List[dict]:
|
||||
def extract_customers() -> list[dict]:
|
||||
return [
|
||||
{"id": "C001", "name": "Alice"},
|
||||
{"id": "C002", "name": "Bob"},
|
||||
]
|
||||
|
||||
|
||||
def extract_orders() -> List[dict]:
|
||||
def extract_orders() -> list[dict]:
|
||||
return [
|
||||
{"id": "O001", "customer_id": "C001", "amount": 150.0},
|
||||
{"id": "O002", "customer_id": "C002", "amount": 200.5},
|
||||
@@ -33,9 +31,9 @@ def extract_orders() -> List[dict]:
|
||||
|
||||
# Parameter names match dependency names → automatic injection.
|
||||
def transform(
|
||||
extract_customers: List[dict],
|
||||
extract_orders: List[dict],
|
||||
) -> List[dict]:
|
||||
extract_customers: list[dict],
|
||||
extract_orders: list[dict],
|
||||
) -> list[dict]:
|
||||
cmap = {c["id"]: c for c in extract_customers}
|
||||
return [
|
||||
{**o, "customer_name": cmap[o["customer_id"]]["name"]}
|
||||
@@ -44,7 +42,7 @@ def transform(
|
||||
]
|
||||
|
||||
|
||||
def load(transform: List[dict]) -> int:
|
||||
def load(transform: list[dict]) -> int:
|
||||
print(f" loaded {len(transform)} records")
|
||||
return len(transform)
|
||||
|
||||
|
||||
+74
-52
@@ -20,7 +20,7 @@ import enum
|
||||
import inspect
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union, cast
|
||||
from typing import Any, Awaitable, Callable, Mapping, cast
|
||||
|
||||
from .context import build_call_args, describe_injection
|
||||
from .errors import TaskFailedError, TaskTimeoutError
|
||||
@@ -53,7 +53,7 @@ class Strategy(enum.Enum):
|
||||
ASYNC = "async"
|
||||
|
||||
|
||||
def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
|
||||
def _normalize_strategy(strategy: str | Strategy) -> Strategy:
|
||||
"""将字符串或 Strategy 归一化为 Strategy 枚举.
|
||||
|
||||
Parameters
|
||||
@@ -79,7 +79,9 @@ def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy:
|
||||
return Strategy(strategy)
|
||||
except ValueError:
|
||||
valid = ", ".join(repr(s.value) for s in Strategy)
|
||||
raise ValueError(f"unknown strategy {strategy!r}; expected one of {valid}.") from None
|
||||
raise ValueError(
|
||||
f"unknown strategy {strategy!r}; expected one of {valid}."
|
||||
) from None
|
||||
raise TypeError(f"strategy must be str or Strategy, got {type(strategy).__name__}")
|
||||
|
||||
|
||||
@@ -89,7 +91,7 @@ def _is_async_fn(spec: TaskSpec[object]) -> bool:
|
||||
|
||||
|
||||
def _emit(
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
result: TaskResult[object],
|
||||
) -> None:
|
||||
"""若注册了回调则触发一个观察者事件。"""
|
||||
@@ -106,7 +108,9 @@ def _emit(
|
||||
)
|
||||
|
||||
|
||||
def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException) -> None:
|
||||
def _log_retry(
|
||||
spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException
|
||||
) -> None:
|
||||
"""记录重试日志(sync 与 async 共享,便于测试覆盖)。"""
|
||||
logger.warning(
|
||||
"task %r failed (attempt %d/%d): %r; retrying",
|
||||
@@ -117,7 +121,7 @@ def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: Ba
|
||||
)
|
||||
|
||||
|
||||
def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None:
|
||||
def _finalize_failure(result: TaskResult[object], layer_idx: int | None) -> None:
|
||||
"""标记任务为 FAILED 并抛出 TaskFailedError。"""
|
||||
result.status = TaskStatus.FAILED
|
||||
result.finished_at = datetime.now()
|
||||
@@ -132,7 +136,7 @@ def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> N
|
||||
def _run_sync_with_retry(
|
||||
spec: TaskSpec[object],
|
||||
context: Mapping[str, Any],
|
||||
layer_idx: Optional[int],
|
||||
layer_idx: int | None,
|
||||
) -> TaskResult[object]:
|
||||
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
|
||||
result: TaskResult[object] = TaskResult(spec=spec)
|
||||
@@ -155,7 +159,7 @@ def _run_sync_with_retry(
|
||||
result.status = TaskStatus.SUCCESS
|
||||
result.finished_at = datetime.now()
|
||||
return result
|
||||
except Exception as exc: # noqa: BLE001 - 用户代码可能抛任何异常
|
||||
except Exception as exc:
|
||||
result.error = exc
|
||||
if result.attempts >= max_attempts:
|
||||
_finalize_failure(result, layer_idx) # pragma: no cover
|
||||
@@ -166,7 +170,7 @@ def _run_sync_with_retry(
|
||||
async def _run_async_with_retry(
|
||||
spec: TaskSpec[object],
|
||||
context: Mapping[str, Any],
|
||||
layer_idx: Optional[int],
|
||||
layer_idx: int | None,
|
||||
) -> TaskResult[object]:
|
||||
"""在事件循环上执行任务(同步或异步)并带重试。"""
|
||||
result: TaskResult[object] = TaskResult(spec=spec)
|
||||
@@ -198,7 +202,9 @@ async def _run_async_with_retry(
|
||||
return spec.effective_fn(*args, **kwargs)
|
||||
|
||||
if spec.timeout is not None:
|
||||
result.value = await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
|
||||
result.value = await asyncio.wait_for(
|
||||
loop.run_in_executor(None, fn_call), timeout=spec.timeout
|
||||
)
|
||||
else:
|
||||
result.value = await loop.run_in_executor(None, fn_call)
|
||||
result.status = TaskStatus.SUCCESS
|
||||
@@ -214,7 +220,7 @@ async def _run_async_with_retry(
|
||||
result.attempts,
|
||||
max_attempts,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
except Exception as exc:
|
||||
result.error = exc
|
||||
if result.attempts >= max_attempts:
|
||||
_finalize_failure(result, layer_idx) # pragma: no cover
|
||||
@@ -230,17 +236,19 @@ def _build_context(
|
||||
global_context: Mapping[str, Any],
|
||||
) -> Mapping[str, Any]:
|
||||
"""将全局上下文限制为本任务的依赖。"""
|
||||
return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context}
|
||||
return {
|
||||
dep: global_context[dep] for dep in spec.depends_on if dep in global_context
|
||||
}
|
||||
|
||||
|
||||
def _execute_layer_sequential(
|
||||
layer: List[str],
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
layer_idx: int,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
"""逐个运行某层的任务。"""
|
||||
for name in layer:
|
||||
@@ -261,23 +269,25 @@ def _execute_layer_sequential(
|
||||
|
||||
|
||||
def _execute_layer_threaded(
|
||||
layer: List[str],
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
layer_idx: int,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
max_workers: int,
|
||||
) -> None:
|
||||
"""在线程池中并发运行某层的任务。"""
|
||||
# 先同步满足已缓存任务。
|
||||
to_run: List[str] = []
|
||||
to_run: list[str] = []
|
||||
for name in layer:
|
||||
if backend.has(name):
|
||||
cached = backend.get(name)
|
||||
context[name] = cached
|
||||
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
|
||||
result = TaskResult(
|
||||
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
|
||||
)
|
||||
report.results[name] = result
|
||||
_emit(on_event, result)
|
||||
else:
|
||||
@@ -287,7 +297,7 @@ def _execute_layer_threaded(
|
||||
return
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {}
|
||||
future_to_name: dict[concurrent.futures.Future[TaskResult[object]], str] = {}
|
||||
for name in to_run:
|
||||
spec = graph.spec(name)
|
||||
# 为本任务快照上下文以避免竞态。
|
||||
@@ -305,21 +315,23 @@ def _execute_layer_threaded(
|
||||
|
||||
|
||||
async def _execute_layer_async(
|
||||
layer: List[str],
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
context: Dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
layer_idx: int,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
"""在事件循环上并发运行某层的任务。"""
|
||||
to_run: List[str] = []
|
||||
to_run: list[str] = []
|
||||
for name in layer:
|
||||
if backend.has(name):
|
||||
cached = backend.get(name)
|
||||
context[name] = cached
|
||||
result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached)
|
||||
result = TaskResult(
|
||||
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
|
||||
)
|
||||
report.results[name] = result
|
||||
_emit(on_event, result)
|
||||
else:
|
||||
@@ -346,8 +358,8 @@ async def _execute_layer_async(
|
||||
# 公共 API
|
||||
# ---------------------------------------------------------------------- #
|
||||
def _make_verbose_callback(
|
||||
on_event: Optional[EventCallback],
|
||||
) -> Optional[EventCallback]:
|
||||
on_event: EventCallback | None,
|
||||
) -> EventCallback | None:
|
||||
"""包装 on_event 回调, 在 verbose 模式下打印任务生命周期.
|
||||
|
||||
Parameters
|
||||
@@ -385,13 +397,13 @@ def _make_verbose_callback(
|
||||
|
||||
def run(
|
||||
graph: Graph,
|
||||
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
|
||||
strategy: str | Strategy = Strategy.SEQUENTIAL,
|
||||
*,
|
||||
max_workers: Optional[int] = None,
|
||||
max_workers: int | None = None,
|
||||
dry_run: bool = False,
|
||||
verbose: bool = False,
|
||||
on_event: Optional[EventCallback] = None,
|
||||
state: Optional[StateBackend] = None,
|
||||
on_event: EventCallback | None = None,
|
||||
state: StateBackend | None = None,
|
||||
) -> RunReport:
|
||||
"""执行图并返回 :class:`RunReport`。
|
||||
|
||||
@@ -434,17 +446,23 @@ def run(
|
||||
return RunReport(success=True)
|
||||
|
||||
# verbose 模式下包装事件回调
|
||||
effective_callback: Optional[EventCallback] = _make_verbose_callback(on_event) if verbose else on_event
|
||||
effective_callback: EventCallback | None = (
|
||||
_make_verbose_callback(on_event) if verbose else on_event
|
||||
)
|
||||
|
||||
backend = resolve_backend(state)
|
||||
report = RunReport()
|
||||
context: Dict[str, Any] = {}
|
||||
context: dict[str, Any] = {}
|
||||
|
||||
try:
|
||||
if normalized == Strategy.SEQUENTIAL:
|
||||
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
||||
_drive_sequential(
|
||||
graph, layers, context, report, backend, effective_callback
|
||||
)
|
||||
elif normalized == Strategy.THREAD:
|
||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers)
|
||||
_drive_threaded(
|
||||
graph, layers, context, report, backend, effective_callback, max_workers
|
||||
)
|
||||
else:
|
||||
_drive_async(graph, layers, context, report, backend, effective_callback)
|
||||
except TaskFailedError:
|
||||
@@ -454,7 +472,7 @@ def run(
|
||||
return report
|
||||
|
||||
|
||||
def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
|
||||
def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None:
|
||||
"""打印执行计划但不运行任何任务。"""
|
||||
print(f"Dry run: {len(graph)} tasks, {len(layers)} layers")
|
||||
for idx, layer in enumerate(layers, 1):
|
||||
@@ -465,11 +483,11 @@ def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
|
||||
|
||||
def _drive_sequential(
|
||||
graph: Graph,
|
||||
layers: List[List[str]],
|
||||
context: Dict[str, Any],
|
||||
layers: list[list[str]],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
for idx, layer in enumerate(layers, 1):
|
||||
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
|
||||
@@ -477,36 +495,40 @@ def _drive_sequential(
|
||||
|
||||
def _drive_threaded(
|
||||
graph: Graph,
|
||||
layers: List[List[str]],
|
||||
context: Dict[str, Any],
|
||||
layers: list[list[str]],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: Optional[EventCallback],
|
||||
max_workers: Optional[int],
|
||||
on_event: EventCallback | None,
|
||||
max_workers: int | None,
|
||||
) -> None:
|
||||
for idx, layer in enumerate(layers, 1):
|
||||
workers = max_workers or max(1, min(32, len(layer)))
|
||||
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers)
|
||||
_execute_layer_threaded(
|
||||
layer, graph, context, report, backend, idx, on_event, workers
|
||||
)
|
||||
|
||||
|
||||
def _drive_async(
|
||||
graph: Graph,
|
||||
layers: List[List[str]],
|
||||
context: Dict[str, Any],
|
||||
layers: list[list[str]],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
|
||||
|
||||
|
||||
async def _async_drive(
|
||||
graph: Graph,
|
||||
layers: List[List[str]],
|
||||
context: Dict[str, Any],
|
||||
layers: list[list[str]],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: Optional[EventCallback],
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
for idx, layer in enumerate(layers, 1):
|
||||
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event)
|
||||
await _execute_layer_async(
|
||||
layer, graph, context, report, backend, idx, on_event
|
||||
)
|
||||
|
||||
+23
-19
@@ -8,7 +8,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
|
||||
from typing import Iterable, Mapping, Sequence
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -38,14 +38,14 @@ class Graph:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._specs: Dict[str, TaskSpec[object]] = {}
|
||||
self._specs: dict[str, TaskSpec[object]] = {}
|
||||
# 任务 -> 其直接依赖(前驱)。
|
||||
self._deps: Dict[str, Tuple[str, ...]] = {}
|
||||
self._deps: dict[str, tuple[str, ...]] = {}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
# ------------------------------------------------------------------ #
|
||||
def add(self, spec: TaskSpec[object]) -> "Graph":
|
||||
def add(self, spec: TaskSpec[object]) -> Graph:
|
||||
"""注册一个任务 spec,并即时校验。
|
||||
|
||||
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`,
|
||||
@@ -60,7 +60,7 @@ class Graph:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
|
||||
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> Graph:
|
||||
"""从可迭代的 task spec 构建图。
|
||||
|
||||
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
|
||||
@@ -107,7 +107,7 @@ class Graph:
|
||||
# 内省
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def names(self) -> List[str]:
|
||||
def names(self) -> list[str]:
|
||||
"""所有已注册任务名(按插入顺序)。"""
|
||||
return list(self._specs.keys())
|
||||
|
||||
@@ -115,7 +115,7 @@ class Graph:
|
||||
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
|
||||
return self._specs[name]
|
||||
|
||||
def dependencies(self, name: str) -> Tuple[str, ...]:
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接前驱。"""
|
||||
return self._deps[name]
|
||||
|
||||
@@ -123,7 +123,7 @@ class Graph:
|
||||
"""name -> spec 的只读视图。"""
|
||||
return self._specs
|
||||
|
||||
def layers(self) -> List[List[str]]:
|
||||
def layers(self) -> list[list[str]]:
|
||||
"""将任务分组为可并行执行的层(Kahn 算法)。
|
||||
|
||||
同层任务无相互依赖,可并发执行。层按执行顺序返回。
|
||||
@@ -132,7 +132,7 @@ class Graph:
|
||||
"""
|
||||
self.validate()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
result: List[List[str]] = []
|
||||
result: list[list[str]] = []
|
||||
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
|
||||
sorter.prepare()
|
||||
while sorter.is_active():
|
||||
@@ -147,19 +147,21 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
# 子图 / 标签过滤
|
||||
# ------------------------------------------------------------------ #
|
||||
def subgraph(self, tags: Iterable[str]) -> "Graph":
|
||||
def subgraph(self, tags: Iterable[str]) -> Graph:
|
||||
"""返回仅包含匹配任意标签的任务的新图。
|
||||
|
||||
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
|
||||
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
|
||||
DAG 的切片。
|
||||
"""
|
||||
wanted: Set[str] = set(tags)
|
||||
kept: List[TaskSpec[object]] = []
|
||||
wanted: set[str] = set(tags)
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
if wanted & set(spec.tags):
|
||||
pruned_deps = tuple(
|
||||
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags))
|
||||
d
|
||||
for d in spec.depends_on
|
||||
if d in self._specs and (wanted & set(self._specs[d].tags))
|
||||
)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
@@ -178,13 +180,13 @@ class Graph:
|
||||
)
|
||||
return Graph.from_specs(kept)
|
||||
|
||||
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
|
||||
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
|
||||
"""返回限定于 ``names`` 的新图(边已修剪)。"""
|
||||
wanted: Set[str] = set(names)
|
||||
wanted: set[str] = set(names)
|
||||
for n in wanted:
|
||||
if n not in self._specs:
|
||||
raise KeyError(f"Unknown task name: {n!r}")
|
||||
kept: List[TaskSpec[object]] = []
|
||||
kept: list[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
if spec.name in wanted:
|
||||
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
|
||||
@@ -217,8 +219,10 @@ class Graph:
|
||||
valid = {"TD", "TB", "BT", "LR", "RL"}
|
||||
orientation = orientation.upper()
|
||||
if orientation not in valid:
|
||||
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
|
||||
lines: List[str] = [f"graph {orientation}"]
|
||||
raise ValueError(
|
||||
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
|
||||
)
|
||||
lines: list[str] = [f"graph {orientation}"]
|
||||
for name in self._specs:
|
||||
lines.append(f' {name}["{name}"]')
|
||||
for name, deps in self._deps.items():
|
||||
@@ -231,7 +235,7 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行摘要。"""
|
||||
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
|
||||
out: list[str] = [f"Graph(tasks={len(self._specs)})"]
|
||||
for layer_idx, layer in enumerate(self.layers(), 1):
|
||||
out.append(f" Layer {layer_idx}: {layer}")
|
||||
return "\n".join(out)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List
|
||||
from typing import Any, Iterator
|
||||
|
||||
from .task import TaskResult, TaskStatus
|
||||
|
||||
@@ -24,7 +24,7 @@ class RunReport:
|
||||
当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。
|
||||
"""
|
||||
|
||||
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
|
||||
results: dict[str, TaskResult[object]] = field(default_factory=dict)
|
||||
success: bool = True
|
||||
|
||||
# ---- 类型化访问 --------------------------------------------------- #
|
||||
@@ -50,9 +50,9 @@ class RunReport:
|
||||
return len(self.results)
|
||||
|
||||
# ---- 汇总 --------------------------------------------------------- #
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
def summary(self) -> dict[str, Any]:
|
||||
"""用于日志/仪表盘的紧凑统计字典。"""
|
||||
counts: Dict[str, int] = {}
|
||||
counts: dict[str, int] = {}
|
||||
total_duration = 0.0
|
||||
for r in self.results.values():
|
||||
counts[r.status.value] = counts.get(r.status.value, 0) + 1
|
||||
@@ -65,7 +65,7 @@ class RunReport:
|
||||
"total_duration_seconds": round(total_duration, 6),
|
||||
}
|
||||
|
||||
def failed_tasks(self) -> List[str]:
|
||||
def failed_tasks(self) -> list[str]:
|
||||
"""以 FAILED 状态结束的任务名列表。"""
|
||||
return [
|
||||
name for name, r in self.results.items() if r.status == TaskStatus.FAILED
|
||||
@@ -73,7 +73,7 @@ class RunReport:
|
||||
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行报告。"""
|
||||
lines: List[str] = [f"RunReport(success={self.success})"]
|
||||
lines: list[str] = [f"RunReport(success={self.success})"]
|
||||
for name, r in self.results.items():
|
||||
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
|
||||
err = f" error={r.error!r}" if r.error else ""
|
||||
|
||||
+18
-14
@@ -23,13 +23,13 @@ import argparse
|
||||
import dataclasses
|
||||
import enum
|
||||
import sys
|
||||
from typing import Dict, List, Optional, Sequence, Union
|
||||
from typing import Sequence
|
||||
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, _normalize_strategy, run
|
||||
from .graph import Graph
|
||||
|
||||
__all__ = ["CliRunner", "CliExitCode"]
|
||||
__all__ = ["CliExitCode", "CliRunner"]
|
||||
|
||||
|
||||
class CliExitCode(enum.IntEnum):
|
||||
@@ -92,12 +92,16 @@ class CliRunner:
|
||||
基本用法::
|
||||
|
||||
runner = px.CliRunner(
|
||||
clean=px.Graph.from_specs([
|
||||
px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]),
|
||||
]),
|
||||
build=px.Graph.from_specs([
|
||||
px.TaskSpec("uv_build", cmd=["uv", "build"]),
|
||||
]),
|
||||
clean=px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]),
|
||||
]
|
||||
),
|
||||
build=px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("uv_build", cmd=["uv", "build"]),
|
||||
]
|
||||
),
|
||||
)
|
||||
runner.run() # 解析 sys.argv
|
||||
|
||||
@@ -114,7 +118,7 @@ class CliRunner:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
strategy: Union[str, Strategy] = Strategy.SEQUENTIAL,
|
||||
strategy: str | Strategy = Strategy.SEQUENTIAL,
|
||||
description: str = "",
|
||||
verbose: bool = True,
|
||||
**graphs: Graph,
|
||||
@@ -127,7 +131,7 @@ class CliRunner:
|
||||
raise TypeError(
|
||||
f"CliRunner 命令 {name!r} 的值必须是 Graph 实例, 实际是 {type(graph).__name__}"
|
||||
)
|
||||
self._graphs: Dict[str, Graph] = dict(graphs)
|
||||
self._graphs: dict[str, Graph] = dict(graphs)
|
||||
self._strategy: Strategy = _normalize_strategy(strategy)
|
||||
self._description: str = description
|
||||
self._verbose: bool = verbose
|
||||
@@ -136,12 +140,12 @@ class CliRunner:
|
||||
# 内省
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def commands(self) -> List[str]:
|
||||
def commands(self) -> list[str]:
|
||||
"""可用的命令列表 (按插入顺序)."""
|
||||
return list(self._graphs.keys())
|
||||
|
||||
@property
|
||||
def graphs(self) -> Dict[str, Graph]:
|
||||
def graphs(self) -> dict[str, Graph]:
|
||||
"""命令名到图的映射 (只读副本)."""
|
||||
return dict(self._graphs)
|
||||
|
||||
@@ -225,7 +229,7 @@ class CliRunner:
|
||||
# ------------------------------------------------------------------ #
|
||||
# 执行
|
||||
# ------------------------------------------------------------------ #
|
||||
def run(self, args: Optional[Sequence[str]] = None) -> int:
|
||||
def run(self, args: Sequence[str] | None = None) -> int:
|
||||
"""解析参数并执行对应的图.
|
||||
|
||||
Parameters
|
||||
@@ -293,7 +297,7 @@ class CliRunner:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
return CliExitCode.FAILURE.value
|
||||
|
||||
def run_cli(self, args: Optional[Sequence[str]] = None) -> None:
|
||||
def run_cli(self, args: Sequence[str] | None = None) -> None:
|
||||
"""运行并以退出码退出进程.
|
||||
|
||||
作为 CLI 工具运行时的入口点, 等价于 ``sys.exit(self.run(args))``.
|
||||
|
||||
+24
-12
@@ -17,9 +17,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .errors import StorageError
|
||||
|
||||
@@ -52,20 +54,25 @@ class MemoryBackend(StateBackend):
|
||||
"""进程内 dict 后端。进程退出即丢失。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: Dict[str, Any] = {}
|
||||
self._store: dict[str, Any] = {}
|
||||
|
||||
@override
|
||||
def load(self) -> Mapping[str, Any]:
|
||||
return dict(self._store)
|
||||
|
||||
@override
|
||||
def save(self, name: str, value: Any) -> None:
|
||||
self._store[name] = value
|
||||
|
||||
@override
|
||||
def has(self, name: str) -> bool:
|
||||
return name in self._store
|
||||
|
||||
@override
|
||||
def get(self, name: str) -> Any:
|
||||
return self._store[name]
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
self._store.clear()
|
||||
|
||||
@@ -79,16 +86,16 @@ class JSONBackend(StateBackend):
|
||||
"""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self._path = path
|
||||
self._store: Dict[str, Any] = {}
|
||||
self._path: str = path
|
||||
self._store: dict[str, Any] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
if not os.path.exists(self._path):
|
||||
if not Path(self._path).exists():
|
||||
return
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
with open(self._path, encoding="utf-8") as fh:
|
||||
data: Any = json.load(fh)
|
||||
if isinstance(data, dict):
|
||||
self._store = data
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
@@ -99,13 +106,15 @@ class JSONBackend(StateBackend):
|
||||
try:
|
||||
with open(tmp, "w", encoding="utf-8") as fh:
|
||||
json.dump(self._store, fh, ensure_ascii=False, indent=2)
|
||||
os.replace(tmp, self._path)
|
||||
Path(tmp).replace(Path(self._path))
|
||||
except (OSError, TypeError) as exc:
|
||||
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
|
||||
|
||||
@override
|
||||
def load(self) -> Mapping[str, Any]:
|
||||
return dict(self._store)
|
||||
|
||||
@override
|
||||
def save(self, name: str, value: Any) -> None:
|
||||
# 在修改内存状态前先校验可序列化性。
|
||||
try:
|
||||
@@ -115,19 +124,22 @@ class JSONBackend(StateBackend):
|
||||
f"result of task {name!r} is not JSON-serialisable", exc
|
||||
) from exc
|
||||
self._store[name] = value
|
||||
self._flush()
|
||||
_ = self._flush()
|
||||
|
||||
@override
|
||||
def has(self, name: str) -> bool:
|
||||
return name in self._store
|
||||
|
||||
@override
|
||||
def get(self, name: str) -> Any:
|
||||
return self._store[name]
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
self._store.clear()
|
||||
self._flush()
|
||||
_ = self._flush()
|
||||
|
||||
|
||||
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
|
||||
def resolve_backend(backend: StateBackend | None) -> StateBackend:
|
||||
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
|
||||
return backend if backend is not None else MemoryBackend()
|
||||
|
||||
+16
-11
@@ -30,6 +30,7 @@ from typing import (
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
@@ -48,7 +49,7 @@ Context = Mapping[str, Any]
|
||||
TaskCmd = Union[
|
||||
List[str], # 命令列表, 如 ["ls", "-la"]
|
||||
str, # shell 命令字符串
|
||||
Callable[..., T], # Python 函数
|
||||
Callable[..., Any], # Python 函数
|
||||
]
|
||||
|
||||
# 条件判断函数类型
|
||||
@@ -151,12 +152,12 @@ class TaskSpec(Generic[T]):
|
||||
return self.fn
|
||||
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。")
|
||||
|
||||
def _wrap_cmd(self) -> TaskFn[T]:
|
||||
def _wrap_cmd(self) -> TaskFn[Any]:
|
||||
"""将 cmd 包装为可执行函数.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TaskFn[T]
|
||||
TaskFn[Any]
|
||||
包装后的执行函数.
|
||||
"""
|
||||
cmd = self.cmd
|
||||
@@ -184,17 +185,19 @@ class TaskSpec(Generic[T]):
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"命令未找到: {cmd_str}")
|
||||
raise RuntimeError(f"命令未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"命令执行超时: {cmd_str} ({timeout}s)")
|
||||
raise RuntimeError(
|
||||
f"命令执行超时: {cmd_str} ({timeout}s)"
|
||||
) from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}")
|
||||
raise RuntimeError(f"命令执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None # type: ignore[return-value]
|
||||
return cast(T, None) # type: ignore[return-value]
|
||||
|
||||
err_msg = f"命令执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
@@ -224,17 +227,19 @@ class TaskSpec(Generic[T]):
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"Shell 命令未找到: {cmd}")
|
||||
raise RuntimeError(f"Shell 命令未找到: {cmd}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"Shell 命令执行超时: {cmd} ({timeout}s)")
|
||||
raise RuntimeError(
|
||||
f"Shell 命令执行超时: {cmd} ({timeout}s)"
|
||||
) from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}")
|
||||
raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None # type: ignore[return-value]
|
||||
return cast(T, None) # type: ignore[return-value]
|
||||
|
||||
err_msg = f"Shell 命令执行失败: `{cmd}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
|
||||
@@ -0,0 +1,165 @@
|
||||
"""Tests for pymake CLI."""
|
||||
|
||||
from pyflowx.cli.pymake import _build_graphs, _get_maturin_build_command, conf
|
||||
|
||||
|
||||
def test_pymake_config_attributes():
|
||||
"""Test PymakeConfig has expected attributes."""
|
||||
assert hasattr(conf, "PROJECT_ROOT")
|
||||
assert hasattr(conf, "BUILD_TOOL")
|
||||
assert hasattr(conf, "BUILD_COMMAND")
|
||||
assert hasattr(conf, "MATURIN_TOOL")
|
||||
assert hasattr(conf, "MATURIN_BUILD_COMMAND")
|
||||
assert hasattr(conf, "MATURIN_DEV_COMMAND")
|
||||
assert hasattr(conf, "TIMEOUT")
|
||||
|
||||
|
||||
def test_pymake_config_values():
|
||||
"""Test PymakeConfig values are correct."""
|
||||
assert conf.BUILD_TOOL == "uv"
|
||||
assert conf.BUILD_COMMAND == ["uv", "build"]
|
||||
assert conf.MATURIN_TOOL == "maturin"
|
||||
assert conf.TIMEOUT == 600
|
||||
|
||||
|
||||
def test_get_maturin_build_command_basic():
|
||||
"""Test _get_maturin_build_command returns base command."""
|
||||
cmd = _get_maturin_build_command()
|
||||
assert "maturin" in cmd
|
||||
assert "build" in cmd
|
||||
assert "-r" in cmd
|
||||
|
||||
|
||||
def test_build_graphs_returns_dict():
|
||||
"""Test _build_graphs returns a dictionary."""
|
||||
graphs = _build_graphs()
|
||||
assert isinstance(graphs, dict)
|
||||
assert len(graphs) > 0
|
||||
|
||||
|
||||
def test_build_graphs_has_expected_commands():
|
||||
"""Test _build_graphs has expected command keys."""
|
||||
graphs = _build_graphs()
|
||||
expected_commands = [
|
||||
"b",
|
||||
"bc",
|
||||
"ba",
|
||||
"ic",
|
||||
"ip",
|
||||
"ia",
|
||||
"cp",
|
||||
"cc",
|
||||
"ca",
|
||||
"t",
|
||||
"lint",
|
||||
]
|
||||
for cmd in expected_commands:
|
||||
assert cmd in graphs, f"Expected command '{cmd}' not found in graphs"
|
||||
|
||||
|
||||
def test_build_graphs_values_are_graphs():
|
||||
"""Test _build_graphs values are Graph instances."""
|
||||
import pyflowx as px
|
||||
|
||||
graphs = _build_graphs()
|
||||
for name, graph in graphs.items():
|
||||
assert isinstance(graph, px.Graph), (
|
||||
f"Graph for command '{name}' is not a Graph instance"
|
||||
)
|
||||
|
||||
|
||||
def test_build_command_graph_structure():
|
||||
"""Test 'b' command graph has correct structure."""
|
||||
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["b"]
|
||||
assert len(graph.all_specs()) == 1
|
||||
spec = graph.spec("uv_build")
|
||||
assert spec.cmd == conf.BUILD_COMMAND
|
||||
|
||||
|
||||
def test_build_all_command_graph_structure():
|
||||
"""Test 'ba' command graph has correct dependencies."""
|
||||
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["ba"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 2
|
||||
# Check dependency
|
||||
uv_build_spec = graph.spec("uv_build")
|
||||
assert "maturin_build" in uv_build_spec.depends_on
|
||||
|
||||
|
||||
def test_maturin_build_command_graph_structure():
|
||||
"""Test 'bc' command graph has correct structure."""
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["bc"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 1
|
||||
spec = graph.spec("maturin_build")
|
||||
assert spec.cmd == _get_maturin_build_command()
|
||||
|
||||
|
||||
def test_install_all_command_graph_structure():
|
||||
"""Test 'ia' command graph has correct dependencies."""
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["ia"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 2
|
||||
uv_install_spec = graph.spec("uv_install")
|
||||
assert "maturin_dev" in uv_install_spec.depends_on
|
||||
|
||||
|
||||
def test_clean_all_command_graph_structure():
|
||||
"""Test 'ca' command graph has correct structure."""
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["ca"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 2
|
||||
|
||||
|
||||
def test_test_command_graph_structure():
|
||||
"""Test 't' command graph has correct structure."""
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["t"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 1
|
||||
spec = graph.spec("pytest")
|
||||
assert "pytest" in spec.cmd
|
||||
|
||||
|
||||
def test_lint_command_graph_structure():
|
||||
"""Test 'lint' command graph has correct structure."""
|
||||
graphs = _build_graphs()
|
||||
graph = graphs["lint"]
|
||||
specs = graph.all_specs()
|
||||
assert len(specs) == 1
|
||||
spec = graph.spec("ruff_check")
|
||||
assert "ruff" in spec.cmd
|
||||
|
||||
|
||||
def test_pymake_config_dirs_to_ignore():
|
||||
"""Test PymakeConfig has correct dirs to ignore."""
|
||||
assert ".venv" in conf.DIRS_TO_IGNORE
|
||||
assert ".git" in conf.DIRS_TO_IGNORE
|
||||
assert ".tox" in conf.DIRS_TO_IGNORE
|
||||
|
||||
|
||||
def test_pymake_config_python_build_dirs():
|
||||
"""Test PymakeConfig has correct Python build dirs."""
|
||||
assert "dist" in conf.PYTHON_BUILD_DIRS
|
||||
assert "build" in conf.PYTHON_BUILD_DIRS
|
||||
|
||||
|
||||
def test_maturin_build_options_win7():
|
||||
"""Test MATURIN_BUILD_OPTIONS_WIN7 has expected options."""
|
||||
assert "--target" in conf.MATURIN_BUILD_OPTIONS_WIN7
|
||||
assert "x86_64-win7-windows-msvc" in conf.MATURIN_BUILD_OPTIONS_WIN7
|
||||
assert "-Zbuild-std" in conf.MATURIN_BUILD_OPTIONS_WIN7
|
||||
|
||||
|
||||
def test_doc_build_command():
|
||||
"""Test DOC_BUILD_COMMAND has expected structure."""
|
||||
assert "sphinx-build" in conf.DOC_BUILD_COMMAND
|
||||
assert "-b" in conf.DOC_BUILD_COMMAND
|
||||
assert "html" in conf.DOC_BUILD_COMMAND
|
||||
@@ -0,0 +1,178 @@
|
||||
"""Tests for conditions module."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import patch
|
||||
|
||||
from pyflowx.conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
IS_POSIX,
|
||||
IS_WINDOWS,
|
||||
BuiltinConditions,
|
||||
Constants,
|
||||
)
|
||||
|
||||
|
||||
def test_constants_is_windows():
|
||||
"""Test Constants.IS_WINDOWS is correct."""
|
||||
assert (sys.platform == "win32") == Constants.IS_WINDOWS
|
||||
|
||||
|
||||
def test_constants_is_linux():
|
||||
"""Test Constants.IS_LINUX is correct."""
|
||||
assert (sys.platform == "linux") == Constants.IS_LINUX
|
||||
|
||||
|
||||
def test_constants_is_macos():
|
||||
"""Test Constants.IS_MACOS is correct."""
|
||||
assert (sys.platform == "darwin") == Constants.IS_MACOS
|
||||
|
||||
|
||||
def test_constants_is_posix():
|
||||
"""Test Constants.IS_POSIX is correct."""
|
||||
assert (sys.platform != "win32") == Constants.IS_POSIX
|
||||
|
||||
|
||||
def test_builtin_conditions_is_windows():
|
||||
"""Test BuiltinConditions.IS_WINDOWS."""
|
||||
result = BuiltinConditions.IS_WINDOWS()
|
||||
assert result == Constants.IS_WINDOWS
|
||||
|
||||
|
||||
def test_builtin_conditions_is_linux():
|
||||
"""Test BuiltinConditions.IS_LINUX."""
|
||||
result = BuiltinConditions.IS_LINUX()
|
||||
assert result == Constants.IS_LINUX
|
||||
|
||||
|
||||
def test_builtin_conditions_is_macos():
|
||||
"""Test BuiltinConditions.IS_MACOS."""
|
||||
result = BuiltinConditions.IS_MACOS()
|
||||
assert result == Constants.IS_MACOS
|
||||
|
||||
|
||||
def test_builtin_conditions_is_posix():
|
||||
"""Test BuiltinConditions.IS_POSIX."""
|
||||
result = BuiltinConditions.IS_POSIX()
|
||||
assert result == Constants.IS_POSIX
|
||||
|
||||
|
||||
def test_builtin_conditions_python_version_major_only():
|
||||
"""Test BuiltinConditions.PYTHON_VERSION with major only."""
|
||||
# Test with current Python version
|
||||
current_major = sys.version_info.major
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False
|
||||
|
||||
|
||||
def test_builtin_conditions_python_version_with_minor():
|
||||
"""Test BuiltinConditions.PYTHON_VERSION with major and minor."""
|
||||
current_major = sys.version_info.major
|
||||
current_minor = sys.version_info.minor
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False
|
||||
|
||||
|
||||
def test_builtin_conditions_python_version_at_least():
|
||||
"""Test BuiltinConditions.PYTHON_VERSION_AT_LEAST."""
|
||||
current_major = sys.version_info.major
|
||||
current_minor = sys.version_info.minor
|
||||
# Current version should be at least itself
|
||||
assert (
|
||||
BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True
|
||||
)
|
||||
# Current version should be at least an older version
|
||||
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True
|
||||
# Current version should NOT be at least a newer version
|
||||
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False
|
||||
|
||||
|
||||
def test_builtin_conditions_has_app_installed_true():
|
||||
"""Test BuiltinConditions.HAS_APP_INSTALLED when app exists."""
|
||||
# Python should always be available
|
||||
condition = BuiltinConditions.HAS_APP_INSTALLED("python")
|
||||
assert condition() is True
|
||||
|
||||
|
||||
def test_builtin_conditions_has_app_installed_false():
|
||||
"""Test BuiltinConditions.HAS_APP_INSTALLED when app doesn't exist."""
|
||||
condition = BuiltinConditions.HAS_APP_INSTALLED("nonexistent_app_12345")
|
||||
assert condition() is False
|
||||
|
||||
|
||||
def test_builtin_conditions_env_var_exists_true():
|
||||
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable exists."""
|
||||
with patch.dict(os.environ, {"TEST_VAR": "value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR")
|
||||
assert condition() is True
|
||||
|
||||
|
||||
def test_builtin_conditions_env_var_exists_false():
|
||||
"""Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist."""
|
||||
condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345")
|
||||
assert condition() is False
|
||||
|
||||
|
||||
def test_builtin_conditions_env_var_equals_true():
|
||||
"""Test BuiltinConditions.ENV_VAR_EQUALS when value matches."""
|
||||
with patch.dict(os.environ, {"TEST_VAR": "expected_value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
|
||||
assert condition() is True
|
||||
|
||||
|
||||
def test_builtin_conditions_env_var_equals_false():
|
||||
"""Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match."""
|
||||
with patch.dict(os.environ, {"TEST_VAR": "different_value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
|
||||
assert condition() is False
|
||||
|
||||
|
||||
def test_builtin_conditions_not():
|
||||
"""Test BuiltinConditions.NOT."""
|
||||
true_condition = lambda: True # noqa: E731
|
||||
false_condition = lambda: False # noqa: E731
|
||||
|
||||
not_true = BuiltinConditions.NOT(true_condition)
|
||||
assert not_true() is False
|
||||
|
||||
not_false = BuiltinConditions.NOT(false_condition)
|
||||
assert not_false() is True
|
||||
|
||||
|
||||
def test_builtin_conditions_and_all_true():
|
||||
"""Test BuiltinConditions.AND when all conditions are true."""
|
||||
true_condition = lambda: True # noqa: E731
|
||||
condition = BuiltinConditions.AND(true_condition, true_condition, true_condition)
|
||||
assert condition() is True
|
||||
|
||||
|
||||
def test_builtin_conditions_and_one_false():
|
||||
"""Test BuiltinConditions.AND when one condition is false."""
|
||||
true_condition = lambda: True # noqa: E731
|
||||
false_condition = lambda: False # noqa: E731
|
||||
condition = BuiltinConditions.AND(true_condition, false_condition, true_condition)
|
||||
assert condition() is False
|
||||
|
||||
|
||||
def test_builtin_conditions_or_all_false():
|
||||
"""Test BuiltinConditions.OR when all conditions are false."""
|
||||
false_condition = lambda: False # noqa: E731
|
||||
condition = BuiltinConditions.OR(false_condition, false_condition, false_condition)
|
||||
assert condition() is False
|
||||
|
||||
|
||||
def test_builtin_conditions_or_one_true():
|
||||
"""Test BuiltinConditions.OR when one condition is true."""
|
||||
true_condition = lambda: True # noqa: E731
|
||||
false_condition = lambda: False # noqa: E731
|
||||
condition = BuiltinConditions.OR(false_condition, true_condition, false_condition)
|
||||
assert condition() is True
|
||||
|
||||
|
||||
def test_exported_conditions():
|
||||
"""Test exported condition functions."""
|
||||
assert IS_WINDOWS() == Constants.IS_WINDOWS
|
||||
assert IS_LINUX() == Constants.IS_LINUX
|
||||
assert IS_MACOS() == Constants.IS_MACOS
|
||||
assert IS_POSIX() == Constants.IS_POSIX
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Tests for executors module edge cases."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskStatus
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
def test_execute_sync_with_timeout():
|
||||
"""Test execute task with timeout correctly."""
|
||||
# Note: timeout for Python functions only works in async strategy
|
||||
# For sync functions, timeout is not enforced in sequential strategy
|
||||
# This test verifies that the task runs without timeout error
|
||||
spec = px.TaskSpec("quick", fn=lambda: "result", timeout=10)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed without timeout error
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_execute_async_with_timeout():
|
||||
"""Test execute async task with timeout correctly."""
|
||||
|
||||
async def slow_async_function():
|
||||
await asyncio.sleep(2)
|
||||
return "result"
|
||||
|
||||
spec = px.TaskSpec("slow_async", fn=slow_async_function, timeout=0.5)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# This should timeout
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="async")
|
||||
|
||||
|
||||
def test_verbose_event_callback_running():
|
||||
"""Test verbose event callback for RUNNING status."""
|
||||
# Create a graph with verbose callback
|
||||
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_verbose_event_callback_success():
|
||||
"""Test verbose event callback for SUCCESS status."""
|
||||
# Create a graph with verbose callback
|
||||
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_verbose_event_callback_failed():
|
||||
"""Test verbose event callback for FAILED status."""
|
||||
# Create a graph with verbose callback and failing task
|
||||
|
||||
def raise_error():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("test", fn=raise_error, verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should print without error
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
def test_verbose_event_callback_skipped():
|
||||
"""Test verbose event callback for SKIPPED status."""
|
||||
# Create a graph with verbose callback and skipped task
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: False,),
|
||||
verbose=True,
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_execute_sync_with_retries():
|
||||
"""Test execute task with retries."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
def failing_function():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValueError("temporary error")
|
||||
return "success"
|
||||
|
||||
spec = px.TaskSpec("retry_test", fn=failing_function, retries=3)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed after retries
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["retry_test"].attempts == 3
|
||||
|
||||
|
||||
def test_execute_async_with_retries():
|
||||
"""Test execute async task with retries."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def failing_async_function():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValueError("temporary error")
|
||||
return "success"
|
||||
|
||||
spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed after retries
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report.results["retry_async_test"].attempts == 3
|
||||
|
||||
|
||||
def test_execute_sync_skip_on_condition():
|
||||
"""Test execute task skips task when condition is false."""
|
||||
spec = px.TaskSpec(
|
||||
"skip_test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: False,),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["skip_test"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_execute_async_skip_on_condition():
|
||||
"""Test execute async task skips task when condition is false."""
|
||||
spec = px.TaskSpec(
|
||||
"skip_async_test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: False,),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report.results["skip_async_test"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_execute_sync_with_error():
|
||||
"""Test execute task handles errors correctly."""
|
||||
|
||||
def error_function():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("error_test", fn=error_function)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
def test_execute_async_with_error():
|
||||
"""Test execute async task handles errors correctly."""
|
||||
|
||||
async def error_async_function():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("error_async_test", fn=error_async_function)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="async")
|
||||
@@ -0,0 +1,156 @@
|
||||
"""Tests for task module edge cases."""
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskSpec
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_list():
|
||||
"""Test TaskSpec._wrap_cmd with command list."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
assert wrapped_fn is not None
|
||||
assert wrapped_fn.__name__ == "test"
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_string():
|
||||
"""Test TaskSpec._wrap_cmd with command string."""
|
||||
if sys.platform == "win32":
|
||||
cmd_str = "cmd /c echo hello"
|
||||
else:
|
||||
cmd_str = "echo hello"
|
||||
spec = TaskSpec("test", cmd=cmd_str)
|
||||
wrapped_fn = spec.effective_fn
|
||||
assert wrapped_fn is not None
|
||||
assert wrapped_fn.__name__ == "test"
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_timeout():
|
||||
"""Test TaskSpec._wrap_cmd with timeout."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], timeout=0.1)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Should not raise timeout error for quick command
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_cwd():
|
||||
"""Test TaskSpec._wrap_cmd with working directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], cwd=tmpdir)
|
||||
wrapped_fn = spec.effective_fn
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_verbose():
|
||||
"""Test TaskSpec._wrap_cmd with verbose=True."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], verbose=True)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Should print verbose output
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_error():
|
||||
"""Test TaskSpec._wrap_cmd handles command error."""
|
||||
spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with pytest.raises(RuntimeError, match="命令执行失败"):
|
||||
wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_file_not_found():
|
||||
"""Test TaskSpec._wrap_cmd handles file not found."""
|
||||
spec = TaskSpec("test", cmd=["nonexistent_command"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with pytest.raises(RuntimeError, match="命令未找到"):
|
||||
wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_shell_file_not_found():
|
||||
"""Test TaskSpec._wrap_cmd handles shell command file not found."""
|
||||
spec = TaskSpec("test", cmd="nonexistent_shell_command")
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Shell commands don't raise FileNotFoundError
|
||||
# They just return non-zero exit code
|
||||
with pytest.raises(RuntimeError):
|
||||
wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_no_fn_no_cmd():
|
||||
"""Test TaskSpec raises error when no fn or cmd."""
|
||||
with pytest.raises(ValueError, match="必须提供 fn 或 cmd 参数"):
|
||||
TaskSpec("test")
|
||||
|
||||
|
||||
def test_taskspec_cmd_overrides_fn():
|
||||
"""Test TaskSpec cmd overrides fn."""
|
||||
|
||||
def my_fn():
|
||||
return "fn_result"
|
||||
|
||||
spec = TaskSpec("test", fn=my_fn, cmd=[*ECHO_CMD, "hello"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# cmd should override fn
|
||||
assert wrapped_fn.__name__ == "test"
|
||||
|
||||
|
||||
def test_taskspec_conditions_check():
|
||||
"""Test TaskSpec.should_execute with conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: True,),
|
||||
)
|
||||
|
||||
assert spec.should_execute() is True
|
||||
|
||||
|
||||
def test_taskspec_conditions_false():
|
||||
"""Test TaskSpec.should_execute with false conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: False,),
|
||||
)
|
||||
|
||||
assert spec.should_execute() is False
|
||||
|
||||
|
||||
def test_taskspec_conditions_multiple():
|
||||
"""Test TaskSpec.should_execute with multiple conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: True, lambda: True, lambda: True),
|
||||
)
|
||||
|
||||
assert spec.should_execute() is True
|
||||
|
||||
|
||||
def test_taskspec_conditions_multiple_one_false():
|
||||
"""Test TaskSpec.should_execute with one false condition."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda: True, lambda: False, lambda: True),
|
||||
)
|
||||
|
||||
assert spec.should_execute() is False
|
||||
@@ -1907,6 +1907,8 @@ version = "0.1.2"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
|
||||
{ name = "typing-extensions", version = "4.13.2", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version < '3.9'" },
|
||||
{ name = "typing-extensions", version = "4.15.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version >= '3.9'" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -1961,6 +1963,7 @@ requires-dist = [
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" },
|
||||
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.25.0" },
|
||||
{ name = "tox-uv", marker = "extra == 'dev'", specifier = ">=1.13.1" },
|
||||
{ name = "typing-extensions", specifier = ">=4.13.2" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
|
||||
@@ -2759,6 +2762,7 @@ name = "typing-extensions"
|
||||
version = "4.15.0"
|
||||
source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.15'",
|
||||
"python_full_version >= '3.10' and python_full_version < '3.15'",
|
||||
"python_full_version > '3.9' and python_full_version < '3.10'",
|
||||
"python_full_version == '3.9'",
|
||||
|
||||
Reference in New Issue
Block a user