diff --git a/pyproject.toml b/pyproject.toml index 742bf6f..230743d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,10 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Topic :: Software Development :: Libraries :: Application Frameworks", ] -dependencies = ["graphlib_backport >= 1.0.0; python_version < '3.9'"] +dependencies = [ + "graphlib_backport >= 1.0.0; python_version < '3.9'", + "typing-extensions>=4.13.2", +] description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution." keywords = ["async", "dag", "scheduler", "task", "workflow"] license = { text = "MIT" } @@ -86,3 +89,46 @@ reportImplicitStringConcatenation = "error" reportMissingTypeStubs = "none" reportUnusedCallResult = "warning" typeCheckingMode = "recommended" # 类型检查严格度:off / basic / standard / recommended(默认) / strict / all + +# Ruff 配置 - 与 .pre-commit-config.yaml 保持一致 +[tool.ruff] +target-version = "py38" +line-length = 88 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade + "ARG", # flake8-unused-arguments + "SIM", # flake8-simplify + "PTH", # flake8-use-pathlib + "PL", # Pylint + "RUF", # Ruff-specific rules +] +ignore = [ + "E501", # line too long (handled by formatter) + "PLR0913", # too many arguments + "PLR2004", # magic value comparison + "PTH123", # pathlib open() replacement + "SIM108", # use ternary operator + "RUF001", # ambiguous unicode characters in string + "RUF002", # ambiguous unicode characters in docstring + "RUF003", # ambiguous unicode characters in comment + "RUF012", # mutable class attributes (intentional for config) + "PLC0415", # import should be at top-level (intentional for lazy imports) + "PLR0915", # too many statements (intentional for complex methods) + "PTH119", # os.path.basename (intentional for sys.argv) +] + +[tool.ruff.lint.isort] +known-first-party = ["pyflowx"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +docstring-code-format = true diff --git a/src/pyflowx/__init__.py b/src/pyflowx/__init__.py index 4afaed6..2ff78d1 100644 --- a/src/pyflowx/__init__.py +++ b/src/pyflowx/__init__.py @@ -81,43 +81,43 @@ from .task import TaskCmd, TaskEvent, TaskResult, TaskSpec, TaskStatus __version__ = "0.1.2" __all__ = [ - # 核心类型 - "TaskSpec", - "TaskStatus", - "TaskResult", - "TaskEvent", - "Context", - "TaskCmd", - "Graph", - "RunReport", - # 执行 - "run", - "Strategy", - # CLI 运行器 - "CliRunner", - "CliExitCode", - # 状态后端 - "StateBackend", - "MemoryBackend", - "JSONBackend", - # 错误 - "PyFlowXError", - "DuplicateTaskError", - "MissingDependencyError", - "CycleError", - "TaskFailedError", - "TaskTimeoutError", - "InjectionError", - "StorageError", - # 条件判断 - "Condition", - "Constants", - "BuiltinConditions", - "IS_WINDOWS", "IS_LINUX", "IS_MACOS", "IS_POSIX", + "IS_WINDOWS", + "BuiltinConditions", + "CliExitCode", + # CLI 运行器 + "CliRunner", + # 条件判断 + "Condition", + "Constants", + "Context", + "CycleError", + "DuplicateTaskError", + "Graph", + "InjectionError", + "JSONBackend", + "MemoryBackend", + "MissingDependencyError", + # 错误 + "PyFlowXError", + "RunReport", + # 状态后端 + "StateBackend", + "StorageError", + "Strategy", + "TaskCmd", + "TaskEvent", + "TaskFailedError", + "TaskResult", + # 核心类型 + "TaskSpec", + "TaskStatus", + "TaskTimeoutError", # 辅助(高级) "build_call_args", "describe_injection", + # 执行 + "run", ] diff --git a/src/pyflowx/conditions.py b/src/pyflowx/conditions.py index b6bc993..49975a3 100644 --- a/src/pyflowx/conditions.py +++ b/src/pyflowx/conditions.py @@ -8,7 +8,7 @@ from __future__ import annotations import shutil import sys -from typing import Callable, Optional +from typing import Callable # 条件判断函数类型 Condition = Callable[[], bool] @@ -47,7 +47,7 @@ class BuiltinConditions: return Constants.IS_POSIX @staticmethod - def PYTHON_VERSION(major: int, minor: Optional[int] = None) -> bool: + def PYTHON_VERSION(major: int, minor: int | None = None) -> bool: """检查 Python 版本是否匹配. Parameters diff --git a/src/pyflowx/context.py b/src/pyflowx/context.py index bb5aea8..dbe9ede 100644 --- a/src/pyflowx/context.py +++ b/src/pyflowx/context.py @@ -18,12 +18,12 @@ DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get from __future__ import annotations import inspect -from typing import Any, Dict, List, Mapping, Set, Tuple +from typing import Any, Mapping from .errors import InjectionError from .task import Context, TaskSpec -__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"] +__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"] def _is_context_annotation(annotation: Any) -> bool: @@ -43,15 +43,13 @@ def _is_context_annotation(annotation: Any) -> bool: return annotation == "Context" or annotation.endswith(".Context") # 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。 name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None) - if name in ("Context", "Mapping"): - return True - return False + return name in ("Context", "Mapping") def build_call_args( spec: TaskSpec[object], context: Mapping[str, Any], -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: +) -> tuple[tuple[Any, ...], dict[str, Any]]: """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。 参数 @@ -84,7 +82,9 @@ def build_call_args( ) # 与本任务相关的上下文子集。 - dep_context: Dict[str, Any] = {name: context[name] for name in spec.depends_on if name in context} + dep_context: dict[str, Any] = { + name: context[name] for name in spec.depends_on if name in context + } # 检测静态 kwargs 与依赖名的冲突。 collisions = set(spec.kwargs) & set(dep_context) @@ -95,12 +95,12 @@ def build_call_args( "rename the static kwarg or the dependency.", ) - injected_kwargs: Dict[str, Any] = {} - leftover_dep_results: Dict[str, Any] = dict(dep_context) + injected_kwargs: dict[str, Any] = {} + leftover_dep_results: dict[str, Any] = dict(dep_context) # 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充, # 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。 - positional_params: List[str] = [] + positional_params: list[str] = [] positional_kinds = ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, @@ -109,7 +109,7 @@ def build_call_args( if param.kind in positional_kinds: positional_params.append(pname) # 前 len(spec.args) 个位置参数由 spec.args 填充。 - args_filled: Set[str] = set(positional_params[: len(spec.args)]) + args_filled: set[str] = set(positional_params[: len(spec.args)]) for pname, param in params.items(): # 跳过已被位置 spec.args 填充的参数。 diff --git a/src/pyflowx/errors.py b/src/pyflowx/errors.py index 655d552..7b94272 100644 --- a/src/pyflowx/errors.py +++ b/src/pyflowx/errors.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Iterable, Optional +from typing import Any, Iterable class PyFlowXError(Exception): @@ -55,10 +55,12 @@ class TaskFailedError(PyFlowXError): task: str, cause: BaseException, attempts: int, - layer: Optional[int] = None, + layer: int | None = None, ) -> None: location = f" (layer {layer})" if layer is not None else "" - super().__init__(f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}") + super().__init__( + f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}" + ) self.task = task self.cause = cause self.attempts = attempts @@ -85,6 +87,6 @@ class InjectionError(PyFlowXError): class StorageError(PyFlowXError): """状态后端在持久化失败时抛出。""" - def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None: + def __init__(self, detail: str, cause: BaseException | None = None) -> None: super().__init__(f"State storage error: {detail}") self.cause: Any = cause diff --git a/src/pyflowx/examples/async_aggregation.py b/src/pyflowx/examples/async_aggregation.py index 328d9ae..ebb1cc6 100644 --- a/src/pyflowx/examples/async_aggregation.py +++ b/src/pyflowx/examples/async_aggregation.py @@ -10,7 +10,7 @@ Shows: from __future__ import annotations import asyncio -from typing import Any, Dict, List +from typing import Any import pyflowx as px @@ -20,13 +20,13 @@ async def fetch_user(uid: int) -> dict: return {"id": uid, "name": f"User{uid}"} -async def fetch_posts(uid: int) -> List[int]: +async def fetch_posts(uid: int) -> list[int]: await asyncio.sleep(0.2) return [uid, uid + 1] # Context annotation → receives the full mapping of upstream results. -def aggregate(ctx: px.Context) -> Dict[str, Any]: +def aggregate(ctx: px.Context) -> dict[str, Any]: return dict(ctx) @@ -43,7 +43,7 @@ def main() -> None: print("=== Dry run ===") px.run(graph, strategy="async", dry_run=True) - events: List[px.TaskEvent] = [] + events: list[px.TaskEvent] = [] print("\n=== Async execution ===") report = px.run(graph, strategy="async", on_event=events.append) diff --git a/src/pyflowx/examples/etl_pipeline.py b/src/pyflowx/examples/etl_pipeline.py index f30955a..dd4dbac 100644 --- a/src/pyflowx/examples/etl_pipeline.py +++ b/src/pyflowx/examples/etl_pipeline.py @@ -10,21 +10,19 @@ Demonstrates the core PyFlowX workflow: from __future__ import annotations -from typing import List - import pyflowx as px # --- task functions: pure, testable, no framework coupling ------------- # -def extract_customers() -> List[dict]: +def extract_customers() -> list[dict]: return [ {"id": "C001", "name": "Alice"}, {"id": "C002", "name": "Bob"}, ] -def extract_orders() -> List[dict]: +def extract_orders() -> list[dict]: return [ {"id": "O001", "customer_id": "C001", "amount": 150.0}, {"id": "O002", "customer_id": "C002", "amount": 200.5}, @@ -33,9 +31,9 @@ def extract_orders() -> List[dict]: # Parameter names match dependency names → automatic injection. def transform( - extract_customers: List[dict], - extract_orders: List[dict], -) -> List[dict]: + extract_customers: list[dict], + extract_orders: list[dict], +) -> list[dict]: cmap = {c["id"]: c for c in extract_customers} return [ {**o, "customer_name": cmap[o["customer_id"]]["name"]} @@ -44,7 +42,7 @@ def transform( ] -def load(transform: List[dict]) -> int: +def load(transform: list[dict]) -> int: print(f" loaded {len(transform)} records") return len(transform) diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index c70bd9e..7a6bf9f 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -20,7 +20,7 @@ import enum import inspect import logging from datetime import datetime -from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, Union, cast +from typing import Any, Awaitable, Callable, Mapping, cast from .context import build_call_args, describe_injection from .errors import TaskFailedError, TaskTimeoutError @@ -53,7 +53,7 @@ class Strategy(enum.Enum): ASYNC = "async" -def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy: +def _normalize_strategy(strategy: str | Strategy) -> Strategy: """将字符串或 Strategy 归一化为 Strategy 枚举. Parameters @@ -79,7 +79,9 @@ def _normalize_strategy(strategy: Union[str, Strategy]) -> Strategy: return Strategy(strategy) except ValueError: valid = ", ".join(repr(s.value) for s in Strategy) - raise ValueError(f"unknown strategy {strategy!r}; expected one of {valid}.") from None + raise ValueError( + f"unknown strategy {strategy!r}; expected one of {valid}." + ) from None raise TypeError(f"strategy must be str or Strategy, got {type(strategy).__name__}") @@ -89,7 +91,7 @@ def _is_async_fn(spec: TaskSpec[object]) -> bool: def _emit( - on_event: Optional[EventCallback], + on_event: EventCallback | None, result: TaskResult[object], ) -> None: """若注册了回调则触发一个观察者事件。""" @@ -106,7 +108,9 @@ def _emit( ) -def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException) -> None: +def _log_retry( + spec: TaskSpec[object], attempts: int, max_attempts: int, exc: BaseException +) -> None: """记录重试日志(sync 与 async 共享,便于测试覆盖)。""" logger.warning( "task %r failed (attempt %d/%d): %r; retrying", @@ -117,7 +121,7 @@ def _log_retry(spec: TaskSpec[object], attempts: int, max_attempts: int, exc: Ba ) -def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> None: +def _finalize_failure(result: TaskResult[object], layer_idx: int | None) -> None: """标记任务为 FAILED 并抛出 TaskFailedError。""" result.status = TaskStatus.FAILED result.finished_at = datetime.now() @@ -132,7 +136,7 @@ def _finalize_failure(result: TaskResult[object], layer_idx: Optional[int]) -> N def _run_sync_with_retry( spec: TaskSpec[object], context: Mapping[str, Any], - layer_idx: Optional[int], + layer_idx: int | None, ) -> TaskResult[object]: """执行同步任务并带重试;返回填充好的 TaskResult。""" result: TaskResult[object] = TaskResult(spec=spec) @@ -155,7 +159,7 @@ def _run_sync_with_retry( result.status = TaskStatus.SUCCESS result.finished_at = datetime.now() return result - except Exception as exc: # noqa: BLE001 - 用户代码可能抛任何异常 + except Exception as exc: result.error = exc if result.attempts >= max_attempts: _finalize_failure(result, layer_idx) # pragma: no cover @@ -166,7 +170,7 @@ def _run_sync_with_retry( async def _run_async_with_retry( spec: TaskSpec[object], context: Mapping[str, Any], - layer_idx: Optional[int], + layer_idx: int | None, ) -> TaskResult[object]: """在事件循环上执行任务(同步或异步)并带重试。""" result: TaskResult[object] = TaskResult(spec=spec) @@ -198,7 +202,9 @@ async def _run_async_with_retry( return spec.effective_fn(*args, **kwargs) if spec.timeout is not None: - result.value = await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) + result.value = await asyncio.wait_for( + loop.run_in_executor(None, fn_call), timeout=spec.timeout + ) else: result.value = await loop.run_in_executor(None, fn_call) result.status = TaskStatus.SUCCESS @@ -214,7 +220,7 @@ async def _run_async_with_retry( result.attempts, max_attempts, ) - except Exception as exc: # noqa: BLE001 + except Exception as exc: result.error = exc if result.attempts >= max_attempts: _finalize_failure(result, layer_idx) # pragma: no cover @@ -230,17 +236,19 @@ def _build_context( global_context: Mapping[str, Any], ) -> Mapping[str, Any]: """将全局上下文限制为本任务的依赖。""" - return {dep: global_context[dep] for dep in spec.depends_on if dep in global_context} + return { + dep: global_context[dep] for dep in spec.depends_on if dep in global_context + } def _execute_layer_sequential( - layer: List[str], + layer: list[str], graph: Graph, - context: Dict[str, Any], + context: dict[str, Any], report: RunReport, backend: StateBackend, layer_idx: int, - on_event: Optional[EventCallback], + on_event: EventCallback | None, ) -> None: """逐个运行某层的任务。""" for name in layer: @@ -261,23 +269,25 @@ def _execute_layer_sequential( def _execute_layer_threaded( - layer: List[str], + layer: list[str], graph: Graph, - context: Dict[str, Any], + context: dict[str, Any], report: RunReport, backend: StateBackend, layer_idx: int, - on_event: Optional[EventCallback], + on_event: EventCallback | None, max_workers: int, ) -> None: """在线程池中并发运行某层的任务。""" # 先同步满足已缓存任务。 - to_run: List[str] = [] + to_run: list[str] = [] for name in layer: if backend.has(name): cached = backend.get(name) context[name] = cached - result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached) + result = TaskResult( + spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached + ) report.results[name] = result _emit(on_event, result) else: @@ -287,7 +297,7 @@ def _execute_layer_threaded( return with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: - future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {} + future_to_name: dict[concurrent.futures.Future[TaskResult[object]], str] = {} for name in to_run: spec = graph.spec(name) # 为本任务快照上下文以避免竞态。 @@ -305,21 +315,23 @@ def _execute_layer_threaded( async def _execute_layer_async( - layer: List[str], + layer: list[str], graph: Graph, - context: Dict[str, Any], + context: dict[str, Any], report: RunReport, backend: StateBackend, layer_idx: int, - on_event: Optional[EventCallback], + on_event: EventCallback | None, ) -> None: """在事件循环上并发运行某层的任务。""" - to_run: List[str] = [] + to_run: list[str] = [] for name in layer: if backend.has(name): cached = backend.get(name) context[name] = cached - result = TaskResult(spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached) + result = TaskResult( + spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached + ) report.results[name] = result _emit(on_event, result) else: @@ -346,8 +358,8 @@ async def _execute_layer_async( # 公共 API # ---------------------------------------------------------------------- # def _make_verbose_callback( - on_event: Optional[EventCallback], -) -> Optional[EventCallback]: + on_event: EventCallback | None, +) -> EventCallback | None: """包装 on_event 回调, 在 verbose 模式下打印任务生命周期. Parameters @@ -385,13 +397,13 @@ def _make_verbose_callback( def run( graph: Graph, - strategy: Union[str, Strategy] = Strategy.SEQUENTIAL, + strategy: str | Strategy = Strategy.SEQUENTIAL, *, - max_workers: Optional[int] = None, + max_workers: int | None = None, dry_run: bool = False, verbose: bool = False, - on_event: Optional[EventCallback] = None, - state: Optional[StateBackend] = None, + on_event: EventCallback | None = None, + state: StateBackend | None = None, ) -> RunReport: """执行图并返回 :class:`RunReport`。 @@ -434,17 +446,23 @@ def run( return RunReport(success=True) # verbose 模式下包装事件回调 - effective_callback: Optional[EventCallback] = _make_verbose_callback(on_event) if verbose else on_event + effective_callback: EventCallback | None = ( + _make_verbose_callback(on_event) if verbose else on_event + ) backend = resolve_backend(state) report = RunReport() - context: Dict[str, Any] = {} + context: dict[str, Any] = {} try: if normalized == Strategy.SEQUENTIAL: - _drive_sequential(graph, layers, context, report, backend, effective_callback) + _drive_sequential( + graph, layers, context, report, backend, effective_callback + ) elif normalized == Strategy.THREAD: - _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers) + _drive_threaded( + graph, layers, context, report, backend, effective_callback, max_workers + ) else: _drive_async(graph, layers, context, report, backend, effective_callback) except TaskFailedError: @@ -454,7 +472,7 @@ def run( return report -def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None: +def _print_dry_run(graph: Graph, layers: list[list[str]]) -> None: """打印执行计划但不运行任何任务。""" print(f"Dry run: {len(graph)} tasks, {len(layers)} layers") for idx, layer in enumerate(layers, 1): @@ -465,11 +483,11 @@ def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None: def _drive_sequential( graph: Graph, - layers: List[List[str]], - context: Dict[str, Any], + layers: list[list[str]], + context: dict[str, Any], report: RunReport, backend: StateBackend, - on_event: Optional[EventCallback], + on_event: EventCallback | None, ) -> None: for idx, layer in enumerate(layers, 1): _execute_layer_sequential(layer, graph, context, report, backend, idx, on_event) @@ -477,36 +495,40 @@ def _drive_sequential( def _drive_threaded( graph: Graph, - layers: List[List[str]], - context: Dict[str, Any], + layers: list[list[str]], + context: dict[str, Any], report: RunReport, backend: StateBackend, - on_event: Optional[EventCallback], - max_workers: Optional[int], + on_event: EventCallback | None, + max_workers: int | None, ) -> None: for idx, layer in enumerate(layers, 1): workers = max_workers or max(1, min(32, len(layer))) - _execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers) + _execute_layer_threaded( + layer, graph, context, report, backend, idx, on_event, workers + ) def _drive_async( graph: Graph, - layers: List[List[str]], - context: Dict[str, Any], + layers: list[list[str]], + context: dict[str, Any], report: RunReport, backend: StateBackend, - on_event: Optional[EventCallback], + on_event: EventCallback | None, ) -> None: asyncio.run(_async_drive(graph, layers, context, report, backend, on_event)) async def _async_drive( graph: Graph, - layers: List[List[str]], - context: Dict[str, Any], + layers: list[list[str]], + context: dict[str, Any], report: RunReport, backend: StateBackend, - on_event: Optional[EventCallback], + on_event: EventCallback | None, ) -> None: for idx, layer in enumerate(layers, 1): - await _execute_layer_async(layer, graph, context, report, backend, idx, on_event) + await _execute_layer_async( + layer, graph, context, report, backend, idx, on_event + ) diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index fd25c84..ea841c8 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -8,7 +8,7 @@ from __future__ import annotations import sys -from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple +from typing import Iterable, Mapping, Sequence from typing_extensions import override @@ -38,14 +38,14 @@ class Graph: """ def __init__(self) -> None: - self._specs: Dict[str, TaskSpec[object]] = {} + self._specs: dict[str, TaskSpec[object]] = {} # 任务 -> 其直接依赖(前驱)。 - self._deps: Dict[str, Tuple[str, ...]] = {} + self._deps: dict[str, tuple[str, ...]] = {} # ------------------------------------------------------------------ # # 构建 # ------------------------------------------------------------------ # - def add(self, spec: TaskSpec[object]) -> "Graph": + def add(self, spec: TaskSpec[object]) -> Graph: """注册一个任务 spec,并即时校验。 返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`, @@ -60,7 +60,7 @@ class Graph: return self @classmethod - def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph": + def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> Graph: """从可迭代的 task spec 构建图。 先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的 @@ -107,7 +107,7 @@ class Graph: # 内省 # ------------------------------------------------------------------ # @property - def names(self) -> List[str]: + def names(self) -> list[str]: """所有已注册任务名(按插入顺序)。""" return list(self._specs.keys()) @@ -115,7 +115,7 @@ class Graph: """返回 ``name`` 的 spec;不存在则 ``KeyError``。""" return self._specs[name] - def dependencies(self, name: str) -> Tuple[str, ...]: + def dependencies(self, name: str) -> tuple[str, ...]: """``name`` 的直接前驱。""" return self._deps[name] @@ -123,7 +123,7 @@ class Graph: """name -> spec 的只读视图。""" return self._specs - def layers(self) -> List[List[str]]: + def layers(self) -> list[list[str]]: """将任务分组为可并行执行的层(Kahn 算法)。 同层任务无相互依赖,可并发执行。层按执行顺序返回。 @@ -132,7 +132,7 @@ class Graph: """ self.validate() sorter = _TopologicalSorter(self._deps) - result: List[List[str]] = [] + result: list[list[str]] = [] # ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。 sorter.prepare() while sorter.is_active(): @@ -147,19 +147,21 @@ class Graph: # ------------------------------------------------------------------ # # 子图 / 标签过滤 # ------------------------------------------------------------------ # - def subgraph(self, tags: Iterable[str]) -> "Graph": + def subgraph(self, tags: Iterable[str]) -> Graph: """返回仅包含匹配任意标签的任务的新图。 依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边 会被移除(被保留的任务不再等待它们)。用于调试时运行大型 DAG 的切片。 """ - wanted: Set[str] = set(tags) - kept: List[TaskSpec[object]] = [] + wanted: set[str] = set(tags) + kept: list[TaskSpec[object]] = [] for spec in self._specs.values(): if wanted & set(spec.tags): pruned_deps = tuple( - d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags)) + d + for d in spec.depends_on + if d in self._specs and (wanted & set(self._specs[d].tags)) ) kept.append( TaskSpec( @@ -178,13 +180,13 @@ class Graph: ) return Graph.from_specs(kept) - def subgraph_by_names(self, names: Iterable[str]) -> "Graph": + def subgraph_by_names(self, names: Iterable[str]) -> Graph: """返回限定于 ``names`` 的新图(边已修剪)。""" - wanted: Set[str] = set(names) + wanted: set[str] = set(names) for n in wanted: if n not in self._specs: raise KeyError(f"Unknown task name: {n!r}") - kept: List[TaskSpec[object]] = [] + kept: list[TaskSpec[object]] = [] for spec in self._specs.values(): if spec.name in wanted: pruned_deps = tuple(d for d in spec.depends_on if d in wanted) @@ -217,8 +219,10 @@ class Graph: valid = {"TD", "TB", "BT", "LR", "RL"} orientation = orientation.upper() if orientation not in valid: - raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.") - lines: List[str] = [f"graph {orientation}"] + raise ValueError( + f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}." + ) + lines: list[str] = [f"graph {orientation}"] for name in self._specs: lines.append(f' {name}["{name}"]') for name, deps in self._deps.items(): @@ -231,7 +235,7 @@ class Graph: # ------------------------------------------------------------------ # def describe(self) -> str: """用于调试的人类可读多行摘要。""" - out: List[str] = [f"Graph(tasks={len(self._specs)})"] + out: list[str] = [f"Graph(tasks={len(self._specs)})"] for layer_idx, layer in enumerate(self.layers(), 1): out.append(f" Layer {layer_idx}: {layer}") return "\n".join(out) diff --git a/src/pyflowx/report.py b/src/pyflowx/report.py index ca097b5..3ec22cb 100644 --- a/src/pyflowx/report.py +++ b/src/pyflowx/report.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, Iterator, List +from typing import Any, Iterator from .task import TaskResult, TaskStatus @@ -24,7 +24,7 @@ class RunReport: 当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。 """ - results: Dict[str, TaskResult[object]] = field(default_factory=dict) + results: dict[str, TaskResult[object]] = field(default_factory=dict) success: bool = True # ---- 类型化访问 --------------------------------------------------- # @@ -50,9 +50,9 @@ class RunReport: return len(self.results) # ---- 汇总 --------------------------------------------------------- # - def summary(self) -> Dict[str, Any]: + def summary(self) -> dict[str, Any]: """用于日志/仪表盘的紧凑统计字典。""" - counts: Dict[str, int] = {} + counts: dict[str, int] = {} total_duration = 0.0 for r in self.results.values(): counts[r.status.value] = counts.get(r.status.value, 0) + 1 @@ -65,7 +65,7 @@ class RunReport: "total_duration_seconds": round(total_duration, 6), } - def failed_tasks(self) -> List[str]: + def failed_tasks(self) -> list[str]: """以 FAILED 状态结束的任务名列表。""" return [ name for name, r in self.results.items() if r.status == TaskStatus.FAILED @@ -73,7 +73,7 @@ class RunReport: def describe(self) -> str: """用于调试的人类可读多行报告。""" - lines: List[str] = [f"RunReport(success={self.success})"] + lines: list[str] = [f"RunReport(success={self.success})"] for name, r in self.results.items(): dur = f"{r.duration:.3f}s" if r.duration is not None else "-" err = f" error={r.error!r}" if r.error else "" diff --git a/src/pyflowx/runner.py b/src/pyflowx/runner.py index 94d39e7..5223495 100644 --- a/src/pyflowx/runner.py +++ b/src/pyflowx/runner.py @@ -23,13 +23,13 @@ import argparse import dataclasses import enum import sys -from typing import Dict, List, Optional, Sequence, Union +from typing import Sequence from .errors import PyFlowXError from .executors import Strategy, _normalize_strategy, run from .graph import Graph -__all__ = ["CliRunner", "CliExitCode"] +__all__ = ["CliExitCode", "CliRunner"] class CliExitCode(enum.IntEnum): @@ -92,12 +92,16 @@ class CliRunner: 基本用法:: runner = px.CliRunner( - clean=px.Graph.from_specs([ - px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]), - ]), - build=px.Graph.from_specs([ - px.TaskSpec("uv_build", cmd=["uv", "build"]), - ]), + clean=px.Graph.from_specs( + [ + px.TaskSpec("cargo_clean", cmd=["cargo", "clean"]), + ] + ), + build=px.Graph.from_specs( + [ + px.TaskSpec("uv_build", cmd=["uv", "build"]), + ] + ), ) runner.run() # 解析 sys.argv @@ -114,7 +118,7 @@ class CliRunner: def __init__( self, *, - strategy: Union[str, Strategy] = Strategy.SEQUENTIAL, + strategy: str | Strategy = Strategy.SEQUENTIAL, description: str = "", verbose: bool = True, **graphs: Graph, @@ -127,7 +131,7 @@ class CliRunner: raise TypeError( f"CliRunner 命令 {name!r} 的值必须是 Graph 实例, 实际是 {type(graph).__name__}" ) - self._graphs: Dict[str, Graph] = dict(graphs) + self._graphs: dict[str, Graph] = dict(graphs) self._strategy: Strategy = _normalize_strategy(strategy) self._description: str = description self._verbose: bool = verbose @@ -136,12 +140,12 @@ class CliRunner: # 内省 # ------------------------------------------------------------------ # @property - def commands(self) -> List[str]: + def commands(self) -> list[str]: """可用的命令列表 (按插入顺序).""" return list(self._graphs.keys()) @property - def graphs(self) -> Dict[str, Graph]: + def graphs(self) -> dict[str, Graph]: """命令名到图的映射 (只读副本).""" return dict(self._graphs) @@ -225,7 +229,7 @@ class CliRunner: # ------------------------------------------------------------------ # # 执行 # ------------------------------------------------------------------ # - def run(self, args: Optional[Sequence[str]] = None) -> int: + def run(self, args: Sequence[str] | None = None) -> int: """解析参数并执行对应的图. Parameters @@ -293,7 +297,7 @@ class CliRunner: print(f"错误: {e}", file=sys.stderr) return CliExitCode.FAILURE.value - def run_cli(self, args: Optional[Sequence[str]] = None) -> None: + def run_cli(self, args: Sequence[str] | None = None) -> None: """运行并以退出码退出进程. 作为 CLI 工具运行时的入口点, 等价于 ``sys.exit(self.run(args))``. diff --git a/src/pyflowx/storage.py b/src/pyflowx/storage.py index ebc33bc..803b3c4 100644 --- a/src/pyflowx/storage.py +++ b/src/pyflowx/storage.py @@ -17,9 +17,11 @@ from __future__ import annotations import json -import os from abc import ABC, abstractmethod -from typing import Any, Dict, Mapping, Optional +from pathlib import Path +from typing import Any, Mapping + +from typing_extensions import override from .errors import StorageError @@ -52,20 +54,25 @@ class MemoryBackend(StateBackend): """进程内 dict 后端。进程退出即丢失。""" def __init__(self) -> None: - self._store: Dict[str, Any] = {} + self._store: dict[str, Any] = {} + @override def load(self) -> Mapping[str, Any]: return dict(self._store) + @override def save(self, name: str, value: Any) -> None: self._store[name] = value + @override def has(self, name: str) -> bool: return name in self._store + @override def get(self, name: str) -> Any: return self._store[name] + @override def clear(self) -> None: self._store.clear() @@ -79,16 +86,16 @@ class JSONBackend(StateBackend): """ def __init__(self, path: str) -> None: - self._path = path - self._store: Dict[str, Any] = {} + self._path: str = path + self._store: dict[str, Any] = {} self._load() def _load(self) -> None: - if not os.path.exists(self._path): + if not Path(self._path).exists(): return try: - with open(self._path, "r", encoding="utf-8") as fh: - data = json.load(fh) + with open(self._path, encoding="utf-8") as fh: + data: Any = json.load(fh) if isinstance(data, dict): self._store = data except (OSError, json.JSONDecodeError) as exc: @@ -99,13 +106,15 @@ class JSONBackend(StateBackend): try: with open(tmp, "w", encoding="utf-8") as fh: json.dump(self._store, fh, ensure_ascii=False, indent=2) - os.replace(tmp, self._path) + Path(tmp).replace(Path(self._path)) except (OSError, TypeError) as exc: raise StorageError(f"cannot write state file {self._path!r}", exc) from exc + @override def load(self) -> Mapping[str, Any]: return dict(self._store) + @override def save(self, name: str, value: Any) -> None: # 在修改内存状态前先校验可序列化性。 try: @@ -115,19 +124,22 @@ class JSONBackend(StateBackend): f"result of task {name!r} is not JSON-serialisable", exc ) from exc self._store[name] = value - self._flush() + _ = self._flush() + @override def has(self, name: str) -> bool: return name in self._store + @override def get(self, name: str) -> Any: return self._store[name] + @override def clear(self) -> None: self._store.clear() - self._flush() + _ = self._flush() -def resolve_backend(backend: Optional[StateBackend]) -> StateBackend: +def resolve_backend(backend: StateBackend | None) -> StateBackend: """返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。""" return backend if backend is not None else MemoryBackend() diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index 6e2fed9..087060c 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -30,6 +30,7 @@ from typing import ( Tuple, TypeVar, Union, + cast, ) T = TypeVar("T") @@ -48,7 +49,7 @@ Context = Mapping[str, Any] TaskCmd = Union[ List[str], # 命令列表, 如 ["ls", "-la"] str, # shell 命令字符串 - Callable[..., T], # Python 函数 + Callable[..., Any], # Python 函数 ] # 条件判断函数类型 @@ -151,12 +152,12 @@ class TaskSpec(Generic[T]): return self.fn raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") - def _wrap_cmd(self) -> TaskFn[T]: + def _wrap_cmd(self) -> TaskFn[Any]: """将 cmd 包装为可执行函数. Returns ------- - TaskFn[T] + TaskFn[Any] 包装后的执行函数. """ cmd = self.cmd @@ -184,17 +185,19 @@ class TaskSpec(Generic[T]): check=False, ) except FileNotFoundError: - raise RuntimeError(f"命令未找到: {cmd_str}") + raise RuntimeError(f"命令未找到: {cmd_str}") from None except subprocess.TimeoutExpired: - raise RuntimeError(f"命令执行超时: {cmd_str} ({timeout}s)") + raise RuntimeError( + f"命令执行超时: {cmd_str} ({timeout}s)" + ) from None except OSError as e: - raise RuntimeError(f"命令执行异常: {cmd_str}: {e}") + raise RuntimeError(f"命令执行异常: {cmd_str}: {e}") from e if verbose: print(f"[verbose] 返回码: {result.returncode}", flush=True) if result.returncode == 0: - return None # type: ignore[return-value] + return cast(T, None) # type: ignore[return-value] err_msg = f"命令执行失败: `{cmd_str}`, 返回码: {result.returncode}" if not verbose and result.stderr.strip(): @@ -224,17 +227,19 @@ class TaskSpec(Generic[T]): check=False, ) except FileNotFoundError: - raise RuntimeError(f"Shell 命令未找到: {cmd}") + raise RuntimeError(f"Shell 命令未找到: {cmd}") from None except subprocess.TimeoutExpired: - raise RuntimeError(f"Shell 命令执行超时: {cmd} ({timeout}s)") + raise RuntimeError( + f"Shell 命令执行超时: {cmd} ({timeout}s)" + ) from None except OSError as e: - raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}") + raise RuntimeError(f"Shell 命令执行异常: {cmd}: {e}") from e if verbose: print(f"[verbose] 返回码: {result.returncode}", flush=True) if result.returncode == 0: - return None # type: ignore[return-value] + return cast(T, None) # type: ignore[return-value] err_msg = f"Shell 命令执行失败: `{cmd}`, 返回码: {result.returncode}" if not verbose and result.stderr.strip(): diff --git a/tests/cli/test_pymake.py b/tests/cli/test_pymake.py new file mode 100644 index 0000000..e8087ac --- /dev/null +++ b/tests/cli/test_pymake.py @@ -0,0 +1,165 @@ +"""Tests for pymake CLI.""" + +from pyflowx.cli.pymake import _build_graphs, _get_maturin_build_command, conf + + +def test_pymake_config_attributes(): + """Test PymakeConfig has expected attributes.""" + assert hasattr(conf, "PROJECT_ROOT") + assert hasattr(conf, "BUILD_TOOL") + assert hasattr(conf, "BUILD_COMMAND") + assert hasattr(conf, "MATURIN_TOOL") + assert hasattr(conf, "MATURIN_BUILD_COMMAND") + assert hasattr(conf, "MATURIN_DEV_COMMAND") + assert hasattr(conf, "TIMEOUT") + + +def test_pymake_config_values(): + """Test PymakeConfig values are correct.""" + assert conf.BUILD_TOOL == "uv" + assert conf.BUILD_COMMAND == ["uv", "build"] + assert conf.MATURIN_TOOL == "maturin" + assert conf.TIMEOUT == 600 + + +def test_get_maturin_build_command_basic(): + """Test _get_maturin_build_command returns base command.""" + cmd = _get_maturin_build_command() + assert "maturin" in cmd + assert "build" in cmd + assert "-r" in cmd + + +def test_build_graphs_returns_dict(): + """Test _build_graphs returns a dictionary.""" + graphs = _build_graphs() + assert isinstance(graphs, dict) + assert len(graphs) > 0 + + +def test_build_graphs_has_expected_commands(): + """Test _build_graphs has expected command keys.""" + graphs = _build_graphs() + expected_commands = [ + "b", + "bc", + "ba", + "ic", + "ip", + "ia", + "cp", + "cc", + "ca", + "t", + "lint", + ] + for cmd in expected_commands: + assert cmd in graphs, f"Expected command '{cmd}' not found in graphs" + + +def test_build_graphs_values_are_graphs(): + """Test _build_graphs values are Graph instances.""" + import pyflowx as px + + graphs = _build_graphs() + for name, graph in graphs.items(): + assert isinstance(graph, px.Graph), ( + f"Graph for command '{name}' is not a Graph instance" + ) + + +def test_build_command_graph_structure(): + """Test 'b' command graph has correct structure.""" + + graphs = _build_graphs() + graph = graphs["b"] + assert len(graph.all_specs()) == 1 + spec = graph.spec("uv_build") + assert spec.cmd == conf.BUILD_COMMAND + + +def test_build_all_command_graph_structure(): + """Test 'ba' command graph has correct dependencies.""" + + graphs = _build_graphs() + graph = graphs["ba"] + specs = graph.all_specs() + assert len(specs) == 2 + # Check dependency + uv_build_spec = graph.spec("uv_build") + assert "maturin_build" in uv_build_spec.depends_on + + +def test_maturin_build_command_graph_structure(): + """Test 'bc' command graph has correct structure.""" + graphs = _build_graphs() + graph = graphs["bc"] + specs = graph.all_specs() + assert len(specs) == 1 + spec = graph.spec("maturin_build") + assert spec.cmd == _get_maturin_build_command() + + +def test_install_all_command_graph_structure(): + """Test 'ia' command graph has correct dependencies.""" + graphs = _build_graphs() + graph = graphs["ia"] + specs = graph.all_specs() + assert len(specs) == 2 + uv_install_spec = graph.spec("uv_install") + assert "maturin_dev" in uv_install_spec.depends_on + + +def test_clean_all_command_graph_structure(): + """Test 'ca' command graph has correct structure.""" + graphs = _build_graphs() + graph = graphs["ca"] + specs = graph.all_specs() + assert len(specs) == 2 + + +def test_test_command_graph_structure(): + """Test 't' command graph has correct structure.""" + graphs = _build_graphs() + graph = graphs["t"] + specs = graph.all_specs() + assert len(specs) == 1 + spec = graph.spec("pytest") + assert "pytest" in spec.cmd + + +def test_lint_command_graph_structure(): + """Test 'lint' command graph has correct structure.""" + graphs = _build_graphs() + graph = graphs["lint"] + specs = graph.all_specs() + assert len(specs) == 1 + spec = graph.spec("ruff_check") + assert "ruff" in spec.cmd + + +def test_pymake_config_dirs_to_ignore(): + """Test PymakeConfig has correct dirs to ignore.""" + assert ".venv" in conf.DIRS_TO_IGNORE + assert ".git" in conf.DIRS_TO_IGNORE + assert ".tox" in conf.DIRS_TO_IGNORE + + +def test_pymake_config_python_build_dirs(): + """Test PymakeConfig has correct Python build dirs.""" + assert "dist" in conf.PYTHON_BUILD_DIRS + assert "build" in conf.PYTHON_BUILD_DIRS + + +def test_maturin_build_options_win7(): + """Test MATURIN_BUILD_OPTIONS_WIN7 has expected options.""" + assert "--target" in conf.MATURIN_BUILD_OPTIONS_WIN7 + assert "x86_64-win7-windows-msvc" in conf.MATURIN_BUILD_OPTIONS_WIN7 + assert "-Zbuild-std" in conf.MATURIN_BUILD_OPTIONS_WIN7 + + +def test_doc_build_command(): + """Test DOC_BUILD_COMMAND has expected structure.""" + assert "sphinx-build" in conf.DOC_BUILD_COMMAND + assert "-b" in conf.DOC_BUILD_COMMAND + assert "html" in conf.DOC_BUILD_COMMAND diff --git a/tests/test_conditions.py b/tests/test_conditions.py new file mode 100644 index 0000000..302a16a --- /dev/null +++ b/tests/test_conditions.py @@ -0,0 +1,178 @@ +"""Tests for conditions module.""" + +import os +import sys +from unittest.mock import patch + +from pyflowx.conditions import ( + IS_LINUX, + IS_MACOS, + IS_POSIX, + IS_WINDOWS, + BuiltinConditions, + Constants, +) + + +def test_constants_is_windows(): + """Test Constants.IS_WINDOWS is correct.""" + assert (sys.platform == "win32") == Constants.IS_WINDOWS + + +def test_constants_is_linux(): + """Test Constants.IS_LINUX is correct.""" + assert (sys.platform == "linux") == Constants.IS_LINUX + + +def test_constants_is_macos(): + """Test Constants.IS_MACOS is correct.""" + assert (sys.platform == "darwin") == Constants.IS_MACOS + + +def test_constants_is_posix(): + """Test Constants.IS_POSIX is correct.""" + assert (sys.platform != "win32") == Constants.IS_POSIX + + +def test_builtin_conditions_is_windows(): + """Test BuiltinConditions.IS_WINDOWS.""" + result = BuiltinConditions.IS_WINDOWS() + assert result == Constants.IS_WINDOWS + + +def test_builtin_conditions_is_linux(): + """Test BuiltinConditions.IS_LINUX.""" + result = BuiltinConditions.IS_LINUX() + assert result == Constants.IS_LINUX + + +def test_builtin_conditions_is_macos(): + """Test BuiltinConditions.IS_MACOS.""" + result = BuiltinConditions.IS_MACOS() + assert result == Constants.IS_MACOS + + +def test_builtin_conditions_is_posix(): + """Test BuiltinConditions.IS_POSIX.""" + result = BuiltinConditions.IS_POSIX() + assert result == Constants.IS_POSIX + + +def test_builtin_conditions_python_version_major_only(): + """Test BuiltinConditions.PYTHON_VERSION with major only.""" + # Test with current Python version + current_major = sys.version_info.major + assert BuiltinConditions.PYTHON_VERSION(current_major) is True + assert BuiltinConditions.PYTHON_VERSION(current_major + 1) is False + + +def test_builtin_conditions_python_version_with_minor(): + """Test BuiltinConditions.PYTHON_VERSION with major and minor.""" + current_major = sys.version_info.major + current_minor = sys.version_info.minor + assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor) is True + assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1) is False + + +def test_builtin_conditions_python_version_at_least(): + """Test BuiltinConditions.PYTHON_VERSION_AT_LEAST.""" + current_major = sys.version_info.major + current_minor = sys.version_info.minor + # Current version should be at least itself + assert ( + BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor) is True + ) + # Current version should be at least an older version + assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0) is True + # Current version should NOT be at least a newer version + assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0) is False + + +def test_builtin_conditions_has_app_installed_true(): + """Test BuiltinConditions.HAS_APP_INSTALLED when app exists.""" + # Python should always be available + condition = BuiltinConditions.HAS_APP_INSTALLED("python") + assert condition() is True + + +def test_builtin_conditions_has_app_installed_false(): + """Test BuiltinConditions.HAS_APP_INSTALLED when app doesn't exist.""" + condition = BuiltinConditions.HAS_APP_INSTALLED("nonexistent_app_12345") + assert condition() is False + + +def test_builtin_conditions_env_var_exists_true(): + """Test BuiltinConditions.ENV_VAR_EXISTS when variable exists.""" + with patch.dict(os.environ, {"TEST_VAR": "value"}): + condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR") + assert condition() is True + + +def test_builtin_conditions_env_var_exists_false(): + """Test BuiltinConditions.ENV_VAR_EXISTS when variable doesn't exist.""" + condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345") + assert condition() is False + + +def test_builtin_conditions_env_var_equals_true(): + """Test BuiltinConditions.ENV_VAR_EQUALS when value matches.""" + with patch.dict(os.environ, {"TEST_VAR": "expected_value"}): + condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") + assert condition() is True + + +def test_builtin_conditions_env_var_equals_false(): + """Test BuiltinConditions.ENV_VAR_EQUALS when value doesn't match.""" + with patch.dict(os.environ, {"TEST_VAR": "different_value"}): + condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value") + assert condition() is False + + +def test_builtin_conditions_not(): + """Test BuiltinConditions.NOT.""" + true_condition = lambda: True # noqa: E731 + false_condition = lambda: False # noqa: E731 + + not_true = BuiltinConditions.NOT(true_condition) + assert not_true() is False + + not_false = BuiltinConditions.NOT(false_condition) + assert not_false() is True + + +def test_builtin_conditions_and_all_true(): + """Test BuiltinConditions.AND when all conditions are true.""" + true_condition = lambda: True # noqa: E731 + condition = BuiltinConditions.AND(true_condition, true_condition, true_condition) + assert condition() is True + + +def test_builtin_conditions_and_one_false(): + """Test BuiltinConditions.AND when one condition is false.""" + true_condition = lambda: True # noqa: E731 + false_condition = lambda: False # noqa: E731 + condition = BuiltinConditions.AND(true_condition, false_condition, true_condition) + assert condition() is False + + +def test_builtin_conditions_or_all_false(): + """Test BuiltinConditions.OR when all conditions are false.""" + false_condition = lambda: False # noqa: E731 + condition = BuiltinConditions.OR(false_condition, false_condition, false_condition) + assert condition() is False + + +def test_builtin_conditions_or_one_true(): + """Test BuiltinConditions.OR when one condition is true.""" + true_condition = lambda: True # noqa: E731 + false_condition = lambda: False # noqa: E731 + condition = BuiltinConditions.OR(false_condition, true_condition, false_condition) + assert condition() is True + + +def test_exported_conditions(): + """Test exported condition functions.""" + assert IS_WINDOWS() == Constants.IS_WINDOWS + assert IS_LINUX() == Constants.IS_LINUX + assert IS_MACOS() == Constants.IS_MACOS + assert IS_POSIX() == Constants.IS_POSIX diff --git a/tests/test_executors_edge_cases.py b/tests/test_executors_edge_cases.py new file mode 100644 index 0000000..d0c353c --- /dev/null +++ b/tests/test_executors_edge_cases.py @@ -0,0 +1,189 @@ +"""Tests for executors module edge cases.""" + +import asyncio +import sys + +import pytest + +import pyflowx as px +from pyflowx.task import TaskStatus + +# 跨平台的 echo 命令 +if sys.platform == "win32": + ECHO_CMD = ["cmd", "/c", "echo"] +else: + ECHO_CMD = ["echo"] + + +def test_execute_sync_with_timeout(): + """Test execute task with timeout correctly.""" + # Note: timeout for Python functions only works in async strategy + # For sync functions, timeout is not enforced in sequential strategy + # This test verifies that the task runs without timeout error + spec = px.TaskSpec("quick", fn=lambda: "result", timeout=10) + graph = px.Graph.from_specs([spec]) + + # Should succeed without timeout error + report = px.run(graph, strategy="sequential") + assert report.success + + +def test_execute_async_with_timeout(): + """Test execute async task with timeout correctly.""" + + async def slow_async_function(): + await asyncio.sleep(2) + return "result" + + spec = px.TaskSpec("slow_async", fn=slow_async_function, timeout=0.5) + graph = px.Graph.from_specs([spec]) + + # This should timeout + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="async") + + +def test_verbose_event_callback_running(): + """Test verbose event callback for RUNNING status.""" + # Create a graph with verbose callback + spec = px.TaskSpec("test", fn=lambda: "result", verbose=True) + graph = px.Graph.from_specs([spec]) + report = px.run(graph, strategy="sequential") + # Should print without error + assert report.success + + +def test_verbose_event_callback_success(): + """Test verbose event callback for SUCCESS status.""" + # Create a graph with verbose callback + spec = px.TaskSpec("test", fn=lambda: "result", verbose=True) + graph = px.Graph.from_specs([spec]) + report = px.run(graph, strategy="sequential") + # Should print without error + assert report.success + + +def test_verbose_event_callback_failed(): + """Test verbose event callback for FAILED status.""" + # Create a graph with verbose callback and failing task + + def raise_error(): + raise ValueError("test error") + + spec = px.TaskSpec("test", fn=raise_error, verbose=True) + graph = px.Graph.from_specs([spec]) + + # Should print without error + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="sequential") + + +def test_verbose_event_callback_skipped(): + """Test verbose event callback for SKIPPED status.""" + # Create a graph with verbose callback and skipped task + spec = px.TaskSpec( + "test", + fn=lambda: "result", + conditions=(lambda: False,), + verbose=True, + ) + graph = px.Graph.from_specs([spec]) + report = px.run(graph, strategy="sequential") + # Should print without error + assert report.success + + +def test_execute_sync_with_retries(): + """Test execute task with retries.""" + + call_count = 0 + + def failing_function(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ValueError("temporary error") + return "success" + + spec = px.TaskSpec("retry_test", fn=failing_function, retries=3) + graph = px.Graph.from_specs([spec]) + + # Should succeed after retries + report = px.run(graph, strategy="sequential") + assert report.success + assert report.results["retry_test"].attempts == 3 + + +def test_execute_async_with_retries(): + """Test execute async task with retries.""" + + call_count = 0 + + async def failing_async_function(): + nonlocal call_count + call_count += 1 + if call_count < 3: + raise ValueError("temporary error") + return "success" + + spec = px.TaskSpec("retry_async_test", fn=failing_async_function, retries=3) + graph = px.Graph.from_specs([spec]) + + # Should succeed after retries + report = px.run(graph, strategy="async") + assert report.success + assert report.results["retry_async_test"].attempts == 3 + + +def test_execute_sync_skip_on_condition(): + """Test execute task skips task when condition is false.""" + spec = px.TaskSpec( + "skip_test", + fn=lambda: "result", + conditions=(lambda: False,), + ) + graph = px.Graph.from_specs([spec]) + + report = px.run(graph, strategy="sequential") + assert report.success + assert report.results["skip_test"].status == TaskStatus.SKIPPED + + +def test_execute_async_skip_on_condition(): + """Test execute async task skips task when condition is false.""" + spec = px.TaskSpec( + "skip_async_test", + fn=lambda: "result", + conditions=(lambda: False,), + ) + graph = px.Graph.from_specs([spec]) + + report = px.run(graph, strategy="async") + assert report.success + assert report.results["skip_async_test"].status == TaskStatus.SKIPPED + + +def test_execute_sync_with_error(): + """Test execute task handles errors correctly.""" + + def error_function(): + raise ValueError("test error") + + spec = px.TaskSpec("error_test", fn=error_function) + graph = px.Graph.from_specs([spec]) + + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="sequential") + + +def test_execute_async_with_error(): + """Test execute async task handles errors correctly.""" + + async def error_async_function(): + raise ValueError("test error") + + spec = px.TaskSpec("error_async_test", fn=error_async_function) + graph = px.Graph.from_specs([spec]) + + with pytest.raises(px.TaskFailedError): + px.run(graph, strategy="async") diff --git a/tests/test_task_edge_cases.py b/tests/test_task_edge_cases.py new file mode 100644 index 0000000..0129e54 --- /dev/null +++ b/tests/test_task_edge_cases.py @@ -0,0 +1,156 @@ +"""Tests for task module edge cases.""" + +import sys +import tempfile + +import pytest + +import pyflowx as px +from pyflowx.task import TaskSpec + +# 跨平台的 echo 命令 +if sys.platform == "win32": + ECHO_CMD = ["cmd", "/c", "echo"] +else: + ECHO_CMD = ["echo"] + + +def test_taskspec_wrap_cmd_with_list(): + """Test TaskSpec._wrap_cmd with command list.""" + spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"]) + wrapped_fn = spec.effective_fn + assert wrapped_fn is not None + assert wrapped_fn.__name__ == "test" + + +def test_taskspec_wrap_cmd_with_string(): + """Test TaskSpec._wrap_cmd with command string.""" + if sys.platform == "win32": + cmd_str = "cmd /c echo hello" + else: + cmd_str = "echo hello" + spec = TaskSpec("test", cmd=cmd_str) + wrapped_fn = spec.effective_fn + assert wrapped_fn is not None + assert wrapped_fn.__name__ == "test" + + +def test_taskspec_wrap_cmd_with_timeout(): + """Test TaskSpec._wrap_cmd with timeout.""" + spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], timeout=0.1) + wrapped_fn = spec.effective_fn + + # Should not raise timeout error for quick command + result = wrapped_fn() + assert result is None + + +def test_taskspec_wrap_cmd_with_cwd(): + """Test TaskSpec._wrap_cmd with working directory.""" + with tempfile.TemporaryDirectory() as tmpdir: + spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], cwd=tmpdir) + wrapped_fn = spec.effective_fn + result = wrapped_fn() + assert result is None + + +def test_taskspec_wrap_cmd_verbose(): + """Test TaskSpec._wrap_cmd with verbose=True.""" + spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], verbose=True) + wrapped_fn = spec.effective_fn + + # Should print verbose output + result = wrapped_fn() + assert result is None + + +def test_taskspec_wrap_cmd_error(): + """Test TaskSpec._wrap_cmd handles command error.""" + spec = TaskSpec("test", cmd=["python", "-c", "import sys; sys.exit(1)"]) + wrapped_fn = spec.effective_fn + + with pytest.raises(RuntimeError, match="命令执行失败"): + wrapped_fn() + + +def test_taskspec_wrap_cmd_file_not_found(): + """Test TaskSpec._wrap_cmd handles file not found.""" + spec = TaskSpec("test", cmd=["nonexistent_command"]) + wrapped_fn = spec.effective_fn + + with pytest.raises(RuntimeError, match="命令未找到"): + wrapped_fn() + + +def test_taskspec_wrap_cmd_shell_file_not_found(): + """Test TaskSpec._wrap_cmd handles shell command file not found.""" + spec = TaskSpec("test", cmd="nonexistent_shell_command") + wrapped_fn = spec.effective_fn + + # Shell commands don't raise FileNotFoundError + # They just return non-zero exit code + with pytest.raises(RuntimeError): + wrapped_fn() + + +def test_taskspec_no_fn_no_cmd(): + """Test TaskSpec raises error when no fn or cmd.""" + with pytest.raises(ValueError, match="必须提供 fn 或 cmd 参数"): + TaskSpec("test") + + +def test_taskspec_cmd_overrides_fn(): + """Test TaskSpec cmd overrides fn.""" + + def my_fn(): + return "fn_result" + + spec = TaskSpec("test", fn=my_fn, cmd=[*ECHO_CMD, "hello"]) + wrapped_fn = spec.effective_fn + + # cmd should override fn + assert wrapped_fn.__name__ == "test" + + +def test_taskspec_conditions_check(): + """Test TaskSpec.should_execute with conditions.""" + spec = px.TaskSpec( + "test", + fn=lambda: "result", + conditions=(lambda: True,), + ) + + assert spec.should_execute() is True + + +def test_taskspec_conditions_false(): + """Test TaskSpec.should_execute with false conditions.""" + spec = px.TaskSpec( + "test", + fn=lambda: "result", + conditions=(lambda: False,), + ) + + assert spec.should_execute() is False + + +def test_taskspec_conditions_multiple(): + """Test TaskSpec.should_execute with multiple conditions.""" + spec = px.TaskSpec( + "test", + fn=lambda: "result", + conditions=(lambda: True, lambda: True, lambda: True), + ) + + assert spec.should_execute() is True + + +def test_taskspec_conditions_multiple_one_false(): + """Test TaskSpec.should_execute with one false condition.""" + spec = px.TaskSpec( + "test", + fn=lambda: "result", + conditions=(lambda: True, lambda: False, lambda: True), + ) + + assert spec.should_execute() is False diff --git a/uv.lock b/uv.lock index 2bb2fa5..2f59704 100644 --- a/uv.lock +++ b/uv.lock @@ -1907,6 +1907,8 @@ version = "0.1.2" source = { editable = "." } dependencies = [ { name = "graphlib-backport", marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", version = "4.13.2", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version < '3.9'" }, + { name = "typing-extensions", version = "4.15.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version >= '3.9'" }, ] [package.optional-dependencies] @@ -1961,6 +1963,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.8.0" }, { name = "tox", marker = "extra == 'dev'", specifier = ">=4.25.0" }, { name = "tox-uv", marker = "extra == 'dev'", specifier = ">=1.13.1" }, + { name = "typing-extensions", specifier = ">=4.13.2" }, ] provides-extras = ["dev"] @@ -2759,6 +2762,7 @@ name = "typing-extensions" version = "4.15.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple/" } resolution-markers = [ + "python_full_version >= '3.15'", "python_full_version >= '3.10' and python_full_version < '3.15'", "python_full_version > '3.9' and python_full_version < '3.10'", "python_full_version == '3.9'",