From d58fc5536e95696db7c588ff1c68371c9b9675e5 Mon Sep 17 00:00:00 2001 From: gooker_young Date: Sat, 27 Jun 2026 20:15:35 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=8F=91=E5=B8=83=20pyflowx=200.2.10?= =?UTF-8?q?=EF=BC=8C=E6=96=B0=E5=A2=9E=E6=80=A7=E8=83=BD=E8=AE=A1=E6=97=B6?= =?UTF-8?q?=E5=99=A8=E4=B8=8E=E5=A4=9A=E9=A1=B9=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 新增 perf_timer 工具与配套测试用例 2. 重构任务条件跳过逻辑,优化失败条件展示 3. 重构 Graph 子图生成逻辑,提取公共依赖修剪函数 4. 重构条件模块,统一条件名称与失败原因获取逻辑 5. 重构存储后端,提取 TTL 共享逻辑并优化实现 6. 重构执行器模块,使用 Mixin 复用代码,拆分任务与层执行逻辑 7. 删除冗余的 which 命令测试文件 8. 更新依赖锁文件 --- src/pyflowx/conditions.py | 40 +- src/pyflowx/executors.py | 891 ++++++++++++++++++-------------------- src/pyflowx/graph.py | 35 +- src/pyflowx/storage.py | 150 +++++-- src/pyflowx/task.py | 13 +- src/pyflowx/utils.py | 70 +++ tests/cli/test_which.py | 66 --- tests/test_utils.py | 41 ++ uv.lock | 2 +- 9 files changed, 701 insertions(+), 607 deletions(-) create mode 100644 src/pyflowx/utils.py delete mode 100644 tests/cli/test_which.py create mode 100644 tests/test_utils.py diff --git a/src/pyflowx/conditions.py b/src/pyflowx/conditions.py index 46d0020..d3d879a 100644 --- a/src/pyflowx/conditions.py +++ b/src/pyflowx/conditions.py @@ -42,6 +42,19 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition: return _cond +def _cond_reason(cond: Condition) -> str | list[str] | None: + """获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。""" + reason = getattr(cond, "_reason", None) + if reason is not None: + return reason + return getattr(cond, "__name__", repr(cond)) + + +def _cond_name(cond: Condition) -> str: + """获取条件的可读名称。""" + return getattr(cond, "__name__", repr(cond)) + + # ---------------------------------------------------------------------- # # 模块级静态条件常量 # ---------------------------------------------------------------------- # @@ -61,21 +74,25 @@ class BuiltinConditions: # ------------------------------------------------------------------ # # 静态条件 # ------------------------------------------------------------------ # + @staticmethod def IS_WINDOWS() -> Condition: """检查是否为 Windows 平台.""" - return _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS") + return IS_WINDOWS + @staticmethod def IS_LINUX() -> Condition: """检查是否为 Linux 平台.""" - return _static(lambda: Constants.IS_LINUX, "IS_LINUX") + return IS_LINUX + @staticmethod def IS_MACOS() -> Condition: """检查是否为 macOS 平台.""" - return _static(lambda: Constants.IS_MACOS, "IS_MACOS") + return IS_MACOS + @staticmethod def IS_POSIX() -> Condition: """检查是否为 POSIX 平台.""" - return _static(lambda: Constants.IS_POSIX, "IS_POSIX") + return IS_POSIX @staticmethod def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition: @@ -214,12 +231,12 @@ class BuiltinConditions: result = condition(ctx) if result: # inner 为 True 时 NOT 会失败,记录 inner 的具体原因 - inner_reason = getattr(condition, "_reason", None) + inner_reason = _cond_reason(condition) if inner_reason is not None: _cond._reason = inner_reason # type: ignore[attr-defined] return not result - _cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})" + _cond.__name__ = f"NOT({_cond_name(condition)})" return _cond @staticmethod @@ -229,8 +246,7 @@ class BuiltinConditions: def _cond(ctx: Context) -> bool: return all(c(ctx) for c in conditions) - names = [getattr(c, "__name__", repr(c)) for c in conditions] - _cond.__name__ = f"AND({', '.join(names)})" + _cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})" return _cond @staticmethod @@ -241,14 +257,12 @@ class BuiltinConditions: matched: list[str] = [] for c in conditions: if c(ctx): - matched.append( - getattr(c, "_reason", None) or getattr(c, "__name__", repr(c)), - ) + reason = _cond_reason(c) + matched.append(reason if isinstance(reason, str) else str(reason)) if matched: _cond._reason = matched # type: ignore[attr-defined] return True return False - names = [getattr(c, "__name__", repr(c)) for c in conditions] - _cond.__name__ = f"OR({', '.join(names)})" + _cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})" return _cond diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index 1818563..1356856 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -10,6 +10,17 @@ * ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动, 无需等待同层其他任务。最大化并行度。 +架构 +---- +本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码: + +* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。 +* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。 +* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。 +* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。 +* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` / + :class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。 + 所有策略共享统一异步内核,支持: * :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on) * 软依赖注入与默认值 @@ -30,6 +41,7 @@ import concurrent.futures import inspect import logging import threading +import time from datetime import datetime from typing import Any, Awaitable, Callable, Literal, Mapping, cast @@ -48,7 +60,7 @@ Strategy = Literal["sequential", "thread", "async", "dependency"] # ---------------------------------------------------------------------- # -# 辅助 +# 无状态公共辅助 # ---------------------------------------------------------------------- # def _is_async_fn(spec: TaskSpec[Any]) -> bool: """判断 ``spec.effective_fn`` 是否为协程函数。""" @@ -71,17 +83,6 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None: ) -def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None: - """记录重试日志。""" - logger.warning( - "task %r failed (attempt %d/%d): %r; retrying", - spec.name, - attempt, - max_attempts, - exc, - ) - - def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: """安全调用钩子(异常仅记录,不影响任务状态)。""" hook: Callable[..., None] | None = getattr(hooks, fn_name, None) @@ -93,87 +94,6 @@ def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: logger.warning("hook %s raised: %r", fn_name, exc) -def _check_upstream_skipped( - spec: TaskSpec[Any], - report: RunReport | None, -) -> tuple[bool, str | None]: - """检查硬依赖上游任务是否被 SKIPPED 或 FAILED。 - - 软依赖不影响本检查——软依赖被跳过时注入默认值。 - """ - if report is None: # pragma: no cover - return False, None # pragma: no cover - - if spec.allow_upstream_skip: # pragma: no cover - return False, None # pragma: no cover - - for dep in spec.depends_on: - if dep not in report.results: # pragma: no cover - continue # pragma: no cover - dep_status = report.results[dep].status - if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): - return True, f"上游任务 '{dep}' 状态为 {dep_status.value}" - return False, None # pragma: no cover - - -def _format_reason(reason: Any) -> str: - """将 _reason 格式化为可读字符串.""" - if isinstance(reason, list): - return ", ".join(str(r) for r in reason) - return str(reason) - - -def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None: - """求值所有条件,返回跳过原因或 ``None``。 - - 条件接收上下文映射(硬依赖 + 软依赖结果)。 - """ - failed_conditions: list[str] = [] - for condition in spec.conditions: - try: - ok = condition(context) - except Exception: - ok = False - name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" - failed_conditions.append(name) - continue - - if not ok: - reason = getattr(condition, "_reason", None) - if reason is not None: - failed_conditions.append(_format_reason(reason)) - else: - failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件") - - if failed_conditions: - if len(failed_conditions) <= 2: - return f"条件不满足: {', '.join(failed_conditions)}" - return f"条件不满足: {', '.join(failed_conditions[:2])} 等{len(failed_conditions)}个条件" - - if spec.skip_if_missing and not spec._is_cmd_available(): - cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown" - return f"命令不存在: {cmd_name}" - - return None - - -def _make_skipped_result( - spec: TaskSpec[Any], - reason: str, - on_event: EventCallback | None, -) -> TaskResult[Any]: - """构造 SKIPPED 的 TaskResult。""" - result: TaskResult[Any] = TaskResult( - spec=spec, - status=TaskStatus.SKIPPED, - finished_at=datetime.now(), - reason=reason, - ) - _emit(on_event, result) - logger.info("task %r skipped (%s)", spec.name, reason) - return result - - def _build_context( spec: TaskSpec[Any], global_context: Mapping[str, Any], @@ -185,19 +105,16 @@ def _build_context( 软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。 """ ctx: dict[str, Any] = {} - for dep in spec.depends_on: if dep in global_context: ctx[dep] = global_context[dep] - for dep in spec.soft_depends_on: if dep in global_context: ctx[dep] = global_context[dep] - elif dep in spec.defaults: # pragma: no cover - ctx[dep] = spec.defaults[dep] # pragma: no cover + elif dep in spec.defaults: + ctx[dep] = spec.defaults[dep] else: ctx[dep] = None - return ctx @@ -222,112 +139,232 @@ def _apply_cached( return True -def _prepare_for_execution( - spec: TaskSpec[Any], - context: Mapping[str, Any], - report: RunReport | None, - on_event: EventCallback | None, -) -> TaskResult[Any] | None: - """执行前预检:上游跳过 / 条件跳过。 +def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: + """按优先级降序排序(稳定排序)。""" + return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority) - 返回 SKIPPED TaskResult 或 ``None``(继续执行)。 + +# ---------------------------------------------------------------------- # +# Mixin:任务级跳过 / 重试 / 成功处理 +# ---------------------------------------------------------------------- # +class _TaskSkipMixin: + """任务级跳过预检共享逻辑。 + + 将"上游被跳过/失败"与"条件不满足"两类跳过判断统一为单一入口, + 被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。 """ - should_skip, skip_reason = _check_upstream_skipped(spec, report) - if should_skip: - return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event) - skip_reason = _evaluate_conditions(spec, context) - if skip_reason is not None: - return _make_skipped_result(spec, skip_reason, on_event) + @staticmethod + def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None: + """硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。 - return None + 软依赖不影响本检查——软依赖被跳过时注入默认值。 + """ + if report is None or spec.allow_upstream_skip: + return None + for dep in spec.depends_on: + if dep not in report.results: + continue + dep_status = report.results[dep].status + if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): + return f"上游任务 '{dep}' 状态为 {dep_status.value}" + return None + @staticmethod + def _prepare_for_execution( + spec: TaskSpec[Any], + context: Mapping[str, Any], + report: RunReport | None, + on_event: EventCallback | None, + ) -> TaskResult[Any] | None: + """执行前预检:上游跳过 / 条件跳过。 -def _finalize_failure( - result: TaskResult[Any], - layer_idx: int | None, - on_event: EventCallback | None = None, - continue_on_error: bool = False, -) -> None: - """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" - result.status = TaskStatus.FAILED - result.finished_at = datetime.now() - _emit(on_event, result) - if continue_on_error: - logger.warning( - "task %r failed but continue_on_error=True; continuing.", - result.spec.name, + 返回 SKIPPED TaskResult 或 ``None``(继续执行)。 + 条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。 + """ + # 1. 上游被跳过/失败 + skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report) + # 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute) + if skip_reason is None: + should_run, cond_reason = spec.should_execute(context) + if not should_run: + skip_reason = cond_reason or "条件不满足" + if skip_reason is None: + return None + # 构造 SKIPPED 结果 + result: TaskResult[Any] = TaskResult( + spec=spec, + status=TaskStatus.SKIPPED, + finished_at=datetime.now(), + reason=skip_reason, ) - return - raise TaskFailedError( - task=result.spec.name, - cause=result.error if result.error is not None else RuntimeError("unknown"), - attempts=result.attempts, - layer=layer_idx, - ) + _emit(on_event, result) + logger.info("task %r skipped (%s)", spec.name, skip_reason) + return result -def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None: - """重试前的同步等待。""" - wait = spec.retry.wait_seconds(attempt) - if wait > 0: - import time +class _TaskRetryMixin: + """任务级重试决策与失败/成功后处理共享逻辑。""" - time.sleep(wait) + @staticmethod + def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool: + """是否应继续重试。""" + return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc) + @staticmethod + def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None: + """标记任务成功并触发 post_run 钩子。""" + result.value = value + result.status = TaskStatus.SUCCESS + result.finished_at = datetime.now() + _run_hooks(spec.hooks, "post_run", spec, value) -async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None: - """重试前的异步等待。""" - wait = spec.retry.wait_seconds(attempt) - if wait > 0: - await asyncio.sleep(wait) + @staticmethod + def _finalize_failure( + result: TaskResult[Any], + layer_idx: int | None, + on_event: EventCallback | None, + continue_on_error: bool, + ) -> None: + """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" + result.status = TaskStatus.FAILED + result.finished_at = datetime.now() + _emit(on_event, result) + if continue_on_error: + logger.warning( + "task %r failed but continue_on_error=True; continuing.", + result.spec.name, + ) + return + raise TaskFailedError( + task=result.spec.name, + cause=result.error if result.error is not None else RuntimeError("unknown"), + attempts=result.attempts, + layer=layer_idx, + ) + + @staticmethod + def _handle_failure( + spec: TaskSpec[Any], + result: TaskResult[Any], + exc: BaseException, + layer_idx: int | None, + on_event: EventCallback | None, + ) -> bool: + """统一处理失败:超时转换、重试决策、finalize。 + + Returns + ------- + bool + ``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。 + """ + # asyncio.TimeoutError → TaskTimeoutError(统一异常类型) + if isinstance(exc, asyncio.TimeoutError): + exc = TaskTimeoutError(spec.name, spec.timeout or 0.0) + logger.warning( + "task %r timed out (attempt %d/%d); retrying", + spec.name, + result.attempts, + spec.retry.max_attempts, + ) + else: + logger.warning( + "task %r failed (attempt %d/%d): %r; retrying", + spec.name, + result.attempts, + spec.retry.max_attempts, + exc, + ) + result.error = exc + if _TaskRetryMixin._should_retry(spec, result.attempts, exc): + return False + _run_hooks(spec.hooks, "on_failure", spec, exc) + _TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + return True # ---------------------------------------------------------------------- # -# 同步执行内核 +# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin) # ---------------------------------------------------------------------- # -def _run_sync_with_retry( - spec: TaskSpec[Any], - context: Mapping[str, Any], - layer_idx: int | None, - on_event: EventCallback | None = None, - report: RunReport | None = None, -) -> TaskResult[Any]: - """执行同步任务并带重试;返回填充好的 TaskResult。""" - skipped = _prepare_for_execution(spec, context, report, on_event) - if skipped is not None: - return skipped +class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): + """同步任务执行器:带重试与跳过预检。""" - result: TaskResult[Any] = TaskResult(spec=spec) - result.started_at = datetime.now() - max_attempts = spec.retry.max_attempts - args, kwargs = build_call_args(spec, context) + @staticmethod + def run( + spec: TaskSpec[Any], + context: Mapping[str, Any], + layer_idx: int | None, + on_event: EventCallback | None = None, + report: RunReport | None = None, + ) -> TaskResult[Any]: + skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) + if skipped is not None: + return skipped - _run_hooks(spec.hooks, "pre_run", spec) + result: TaskResult[Any] = TaskResult(spec=spec) + result.started_at = datetime.now() + args, kwargs = build_call_args(spec, context) - while True: - result.attempts += 1 - try: - with spec.env_context(): - result.value = spec.effective_fn(*args, **kwargs) - result.status = TaskStatus.SUCCESS - result.finished_at = datetime.now() - _run_hooks(spec.hooks, "post_run", spec, result.value) - return result - except Exception as exc: - result.error = exc - if result.attempts >= max_attempts or not spec.retry.should_retry(exc): - _run_hooks(spec.hooks, "on_failure", spec, exc) - _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + _run_hooks(spec.hooks, "pre_run", spec) + + while True: + result.attempts += 1 + try: + with spec.env_context(): + value = spec.effective_fn(*args, **kwargs) + _TaskRetryMixin._mark_success(spec, result, value) return result - _log_retry(spec, result.attempts, max_attempts, exc) - _sleep_for_retry(spec, result.attempts) - # pragma: no cover + except Exception as exc: + if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): + return result + wait = spec.retry.wait_seconds(result.attempts) + if wait > 0: + time.sleep(wait) + + +class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): + """异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。""" + + @staticmethod + async def run( + spec: TaskSpec[Any], + context: Mapping[str, Any], + layer_idx: int | None, + on_event: EventCallback | None = None, + report: RunReport | None = None, + semaphore: asyncio.Semaphore | None = None, + ) -> TaskResult[Any]: + skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) + if skipped is not None: + return skipped + + async def _inner() -> TaskResult[Any]: + result: TaskResult[Any] = TaskResult(spec=spec) + result.started_at = datetime.now() + args, kwargs = build_call_args(spec, context) + loop = asyncio.get_event_loop() + + _run_hooks(spec.hooks, "pre_run", spec) + + while True: + result.attempts += 1 + try: + value = await _execute_async_task(spec, args, kwargs, loop) + _TaskRetryMixin._mark_success(spec, result, value) + return result + except Exception as exc: + if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): + return result + wait = spec.retry.wait_seconds(result.attempts) + if wait > 0: + await asyncio.sleep(wait) + + if semaphore is not None: + async with semaphore: + return await _inner() + return await _inner() -# ---------------------------------------------------------------------- # -# 异步执行内核 -# ---------------------------------------------------------------------- # async def _execute_async_task( spec: TaskSpec[Any], args: tuple[Any, ...], @@ -339,307 +376,237 @@ async def _execute_async_task( coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) if spec.timeout is not None: return await asyncio.wait_for(coro, timeout=spec.timeout) - else: - return await coro - else: + return await coro - def fn_call() -> Any: - with spec.env_context(): - return spec.effective_fn(*args, **kwargs) + def fn_call() -> Any: + with spec.env_context(): + return spec.effective_fn(*args, **kwargs) - if spec.timeout is not None: - return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) - else: - return await loop.run_in_executor(None, fn_call) + if spec.timeout is not None: + return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) + return await loop.run_in_executor(None, fn_call) -async def _run_async_with_retry( - spec: TaskSpec[Any], - context: Mapping[str, Any], - layer_idx: int | None, - on_event: EventCallback | None = None, - report: RunReport | None = None, - semaphore: asyncio.Semaphore | None = None, -) -> TaskResult[Any]: - """在事件循环上执行任务(同步或异步)并带重试。""" - skipped = _prepare_for_execution(spec, context, report, on_event) - if skipped is not None: - return skipped +# ---------------------------------------------------------------------- # +# Mixin:层执行共享逻辑 +# ---------------------------------------------------------------------- # +class _LayerMixin: + """层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。 - if semaphore is not None: - async with semaphore: - return await _run_async_inner(spec, context, layer_idx, on_event, report) - return await _run_async_inner(spec, context, layer_idx, on_event, report) + 四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin + 消除"过滤缓存→排序→运行→存结果"的样板代码。 + """ + @staticmethod + def _filter_and_sort( + layer: list[str], + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + on_event: EventCallback | None, + ) -> list[str]: + """过滤掉已命中缓存的任务,按优先级排序返回待运行列表。""" + to_run: list[str] = [] + for name in layer: + spec = graph.resolved_spec(name) + if not _apply_cached(name, spec, context, report, backend, on_event): + to_run.append(name) + return _sort_by_priority(to_run, graph) -async def _run_async_inner( - spec: TaskSpec[Any], - context: Mapping[str, Any], - layer_idx: int | None, - on_event: EventCallback | None = None, - report: RunReport | None = None, # noqa: ARG001 -) -> TaskResult[Any]: - """异步执行内核的内部实现(已获取 semaphore 后)。""" - result: TaskResult[Any] = TaskResult(spec=spec) - result.started_at = datetime.now() - max_attempts = spec.retry.max_attempts - args, kwargs = build_call_args(spec, context) - loop = asyncio.get_event_loop() + @staticmethod + def _store_result( + name: str, + result: TaskResult[Any], + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + on_event: EventCallback | None, + context_snapshot: Mapping[str, Any] | None = None, + ) -> None: + """存储任务结果到 context/report/backend 并触发事件。""" + context[name] = result.value + if result.status == TaskStatus.SUCCESS: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report) + backend.save(spec.storage_key(task_ctx), result.value) + report.results[name] = result + _emit(on_event, result) - _run_hooks(spec.hooks, "pre_run", spec) + @staticmethod + def _build_semaphores( + to_run: list[str], + graph: Graph, + sem_factory: Callable[[int], Any], + concurrency_limits: Mapping[str, int], + ) -> dict[str, Any]: + """为每个 ``concurrency_key`` 创建一个信号量。""" + semaphores: dict[str, Any] = {} + for name in to_run: + spec = graph.resolved_spec(name) + key = spec.concurrency_key + if key is not None and key not in semaphores: + limit = concurrency_limits.get(key, 1) + semaphores[key] = sem_factory(limit) + return semaphores - while True: - result.attempts += 1 - try: - result.value = await _execute_async_task(spec, args, kwargs, loop) - result.status = TaskStatus.SUCCESS - result.finished_at = datetime.now() - _run_hooks(spec.hooks, "post_run", spec, result.value) - return result - except asyncio.TimeoutError: - exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0) - result.error = exc - if result.attempts >= max_attempts or not spec.retry.should_retry(exc): - _run_hooks(spec.hooks, "on_failure", spec, exc) - _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) - return result - logger.warning( - "task %r timed out (attempt %d/%d); retrying", - spec.name, - result.attempts, - max_attempts, - ) - await _async_sleep_for_retry(spec, result.attempts) - except Exception as exc: - result.error = exc - if result.attempts >= max_attempts or not spec.retry.should_retry(exc): - _run_hooks(spec.hooks, "on_failure", spec, exc) - _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) - return result - _log_retry(spec, result.attempts, max_attempts, exc) - await _async_sleep_for_retry(spec, result.attempts) - # pragma: no cover + @staticmethod + def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None: + """获取任务对应的信号量(无 concurrency_key 则返回 None)。""" + if spec.concurrency_key is None: + return None + return semaphores.get(spec.concurrency_key) # ---------------------------------------------------------------------- # # 层执行器 # ---------------------------------------------------------------------- # -def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: - """按优先级降序排序(稳定排序)。""" - return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority) - - -def _execute_layer_sequential( - layer: list[str], - graph: Graph, - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - layer_idx: int, - on_event: EventCallback | None, -) -> None: +class SequentialLayerRunner(_LayerMixin): """逐个运行某层的任务(按优先级排序)。""" - for name in _sort_by_priority(layer, graph): - spec = graph.resolved_spec(name) - if _apply_cached(name, spec, context, report, backend, on_event): - continue - task_ctx = _build_context(spec, context, report) - result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) - context[name] = result.value - if result.status == TaskStatus.SUCCESS: - backend.save(spec.storage_key(task_ctx), result.value) - report.results[name] = result - _emit(on_event, result) + + @staticmethod + def execute( + layer: list[str], + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + layer_idx: int, + on_event: EventCallback | None, + ) -> None: + for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event): + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context, report) + result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report) + SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event) -def _execute_layer_threaded( - layer: list[str], - graph: Graph, - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - layer_idx: int, - on_event: EventCallback | None, - max_workers: int, - concurrency_limits: Mapping[str, int], -) -> None: +class ThreadedLayerRunner(_LayerMixin): """在线程池中并发运行某层的任务。""" - to_run: list[str] = [] - for name in layer: - spec = graph.resolved_spec(name) - task_ctx = _build_context(spec, context, report) - if _apply_cached(name, spec, context, report, backend, on_event): - continue - to_run.append(name) - if not to_run: - return + @staticmethod + def execute( + layer: list[str], + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + layer_idx: int, + on_event: EventCallback | None, + max_workers: int, + concurrency_limits: Mapping[str, int], + ) -> None: + to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event) + if not to_run: + return + semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits) + context_snapshot = dict(context) + lock = threading.Lock() - to_run = _sort_by_priority(to_run, graph) - - # 为每个 concurrency_key 创建线程信号量 - semaphores: dict[str, threading.Semaphore] = {} - for name in to_run: - spec = graph.resolved_spec(name) - key = spec.concurrency_key - if key is not None and key not in semaphores: - limit = concurrency_limits.get(key, 1) - semaphores[key] = threading.Semaphore(limit) - - context_snapshot = dict(context) - lock = threading.Lock() - - def _run_threaded_task(name: str) -> TaskResult[Any]: - spec = graph.resolved_spec(name) - task_ctx = _build_context(spec, context_snapshot, report) - sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None - if sem is not None: - sem.acquire() - try: - return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) - finally: - if sem is not None: - sem.release() - - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: - future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {} - for name in to_run: - fut = pool.submit(_run_threaded_task, name) - future_to_name[fut] = name - - completed: dict[str, TaskResult[Any]] = {} - try: - for fut in concurrent.futures.as_completed(future_to_name): - name = future_to_name[fut] - result = fut.result() - completed[name] = result - finally: - with lock: - for name, result in completed.items(): - context[name] = result.value - if result.status == TaskStatus.SUCCESS: - spec = graph.resolved_spec(name) - task_ctx = _build_context(spec, context_snapshot, report) - backend.save(spec.storage_key(task_ctx), result.value) - report.results[name] = result - _emit(on_event, result) - - -async def _execute_layer_async( - layer: list[str], - graph: Graph, - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - layer_idx: int, - on_event: EventCallback | None, - concurrency_limits: Mapping[str, int], -) -> None: - """在事件循环上并发运行某层的任务。""" - to_run: list[str] = [] - for name in layer: - spec = graph.resolved_spec(name) - if _apply_cached(name, spec, context, report, backend, on_event): - continue - to_run.append(name) - - if not to_run: - return - - to_run = _sort_by_priority(to_run, graph) - - # 为每个 concurrency_key 创建异步信号量 - semaphores: dict[str, asyncio.Semaphore] = {} - for name in to_run: - spec = graph.resolved_spec(name) - key = spec.concurrency_key - if key is not None and key not in semaphores: - limit = concurrency_limits.get(key, 1) - semaphores[key] = asyncio.Semaphore(limit) - - context_snapshot = dict(context) - - async def _run_async_task_wrapped(name: str) -> TaskResult[Any]: - spec = graph.resolved_spec(name) - task_ctx = _build_context(spec, context_snapshot, report) - sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None - if sem is not None: - async with sem: - return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report) - return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report) - - coros = [_run_async_task_wrapped(name) for name in to_run] - results = await asyncio.gather(*coros) - for name, result in zip(to_run, results): - context[name] = result.value - if result.status == TaskStatus.SUCCESS: + def _run_threaded_task(name: str) -> TaskResult[Any]: spec = graph.resolved_spec(name) task_ctx = _build_context(spec, context_snapshot, report) - backend.save(spec.storage_key(task_ctx), result.value) - report.results[name] = result - _emit(on_event, result) + sem = ThreadedLayerRunner._get_sem(semaphores, spec) + if sem is not None: + sem.acquire() + try: + return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report) + finally: + if sem is not None: + sem.release() + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: + future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = { + pool.submit(_run_threaded_task, name): name for name in to_run + } + completed: dict[str, TaskResult[Any]] = {} + try: + for fut in concurrent.futures.as_completed(future_to_name): + name = future_to_name[fut] + completed[name] = fut.result() + finally: + with lock: + for name, result in completed.items(): + ThreadedLayerRunner._store_result( + name, result, graph, context, report, backend, on_event, context_snapshot + ) -# ---------------------------------------------------------------------- # -# 依赖驱动调度 -# ---------------------------------------------------------------------- # -async def _drive_dependency_async( - graph: Graph, - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - on_event: EventCallback | None, - concurrency_limits: Mapping[str, int], -) -> None: - """依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。 +class AsyncLayerRunner(_LayerMixin): + """在事件循环上并发运行某层的任务。""" + + @staticmethod + async def execute( + layer: list[str], + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + layer_idx: int, + on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], + ) -> None: + to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event) + if not to_run: + return + semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits) + context_snapshot = dict(context) + + async def _run_async_task(name: str) -> TaskResult[Any]: + spec = graph.resolved_spec(name) + task_ctx = _build_context(spec, context_snapshot, report) + sem = AsyncLayerRunner._get_sem(semaphores, spec) + return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem) + + results = await asyncio.gather(*[_run_async_task(name) for name in to_run]) + for name, result in zip(to_run, results): + AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot) + + +class DependencyRunner(_LayerMixin): + """依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。 所有任务通过 asyncio 并发调度。同步任务卸载到线程池。 """ - all_names = set(graph.all_specs().keys()) - semaphores: dict[str, asyncio.Semaphore] = {} - for name in all_names: - spec = graph.resolved_spec(name) - key = spec.concurrency_key - if key is not None and key not in semaphores: - limit = concurrency_limits.get(key, 1) - semaphores[key] = asyncio.Semaphore(limit) - futures: dict[str, asyncio.Future[TaskResult[Any]]] = {} + @staticmethod + async def execute( + graph: Graph, + context: dict[str, Any], + report: RunReport, + backend: StateBackend, + on_event: EventCallback | None, + concurrency_limits: Mapping[str, int], + ) -> None: + all_names = list(graph.all_specs().keys()) + semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits) + futures: dict[str, asyncio.Future[TaskResult[Any]]] = {} - async def _run_task(name: str) -> TaskResult[Any]: - spec = graph.resolved_spec(name) - # 等待所有硬依赖完成 - for dep in spec.depends_on: - if dep in futures: - await futures[dep] - # 等待所有软依赖完成(但不检查其状态) - for dep in spec.soft_depends_on: - if dep in futures: - await futures[dep] + async def _run_task(name: str) -> TaskResult[Any]: + spec = graph.resolved_spec(name) + # 等待所有硬依赖完成 + for dep in spec.depends_on: + if dep in futures: + await futures[dep] + # 等待所有软依赖完成(但不检查其状态) + for dep in spec.soft_depends_on: + if dep in futures: + await futures[dep] - task_ctx = _build_context(spec, context, report) - if _apply_cached(name, spec, context, report, backend, on_event): - return report.results[name] + task_ctx = _build_context(spec, context, report) + if _apply_cached(name, spec, context, report, backend, on_event): + return report.results[name] - sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None - if sem is not None: - async with sem: - result = await _run_async_with_retry(spec, task_ctx, None, on_event, report) - else: - result = await _run_async_with_retry(spec, task_ctx, None, on_event, report) + sem = DependencyRunner._get_sem(semaphores, spec) + result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem) + DependencyRunner._store_result(name, result, graph, context, report, backend, on_event) + return result - context[name] = result.value - if result.status == TaskStatus.SUCCESS: - backend.save(spec.storage_key(task_ctx), result.value) - report.results[name] = result - _emit(on_event, result) - return result - - loop = asyncio.get_event_loop() - for name in all_names: - futures[name] = loop.create_task(_run_task(name)) - - await asyncio.gather(*futures.values()) + loop = asyncio.get_event_loop() + for name in all_names: + futures[name] = loop.create_task(_run_task(name)) + await asyncio.gather(*futures.values()) # ---------------------------------------------------------------------- # @@ -729,9 +696,9 @@ def run( elif strategy == "thread": _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits) elif strategy == "async": - _drive_async(graph, layers, context, report, backend, effective_callback, limits) + asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits)) elif strategy == "dependency": - asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits)) + asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits)) else: raise ValueError(f"Unknown strategy: {strategy!r}") except TaskFailedError: @@ -759,7 +726,7 @@ def _drive_sequential( on_event: EventCallback | None, ) -> None: for idx, layer in enumerate(layers, 1): - _execute_layer_sequential(layer, graph, context, report, backend, idx, on_event) + SequentialLayerRunner.execute(layer, graph, context, report, backend, idx, on_event) def _drive_threaded( @@ -774,19 +741,7 @@ def _drive_threaded( ) -> None: for idx, layer in enumerate(layers, 1): workers = max_workers or max(1, min(32, len(layer))) - _execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits) - - -def _drive_async( - graph: Graph, - layers: list[list[str]], - context: dict[str, Any], - report: RunReport, - backend: StateBackend, - on_event: EventCallback | None, - concurrency_limits: Mapping[str, int], -) -> None: - asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits)) + ThreadedLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits) async def _async_drive( @@ -799,4 +754,4 @@ async def _async_drive( concurrency_limits: Mapping[str, int], ) -> None: for idx, layer in enumerate(layers, 1): - await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits) + await AsyncLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, concurrency_limits) diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index e141c9b..0e8a673 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -49,6 +49,15 @@ class GraphDefaults: verbose: bool = False +def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]: + """返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。""" + return replace( + spec, + depends_on=tuple(d for d in spec.depends_on if keep(d)), + soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)), + ) + + @dataclass class Graph: """校验后的有向无环任务图。 @@ -225,16 +234,13 @@ class Graph: def subgraph(self, tags: Iterable[str]) -> Graph: """返回仅包含匹配任意标签的任务的新图。依赖边被修剪。""" wanted: set[str] = set(tags) - kept: list[TaskSpec[Any]] = [] - for spec in self.specs.values(): - if wanted & set(spec.tags): - pruned_deps = tuple( - d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) - ) - pruned_soft = tuple( - d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) - ) - kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft)) + + def _dep_kept(dep: str) -> bool: + return dep in self.specs and bool(wanted & set(self.specs[dep].tags)) + + kept: list[TaskSpec[Any]] = [ + _prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags) + ] return Graph.from_specs(kept, defaults=self.defaults) def subgraph_by_names(self, names: Iterable[str]) -> Graph: @@ -243,12 +249,9 @@ class Graph: for n in wanted: if n not in self.specs: raise KeyError(f"Unknown task name: {n!r}") - kept: list[TaskSpec[Any]] = [] - for spec in self.specs.values(): - if spec.name in wanted: - pruned_deps = tuple(d for d in spec.depends_on if d in wanted) - pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted) - kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft)) + kept: list[TaskSpec[Any]] = [ + _prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted + ] return Graph.from_specs(kept, defaults=self.defaults) # ------------------------------------------------------------------ # diff --git a/src/pyflowx/storage.py b/src/pyflowx/storage.py index def7b7b..4cbc6d7 100644 --- a/src/pyflowx/storage.py +++ b/src/pyflowx/storage.py @@ -17,6 +17,7 @@ import json import sys import time from abc import ABC, abstractmethod +from collections.abc import Iterator from pathlib import Path from typing import Any, Mapping @@ -55,7 +56,74 @@ class StateBackend(ABC): """清除所有存储状态。""" -class MemoryBackend(StateBackend): +class _TTLStateBackendMixin(StateBackend): + """TTL 状态后端共享逻辑。 + + 将 ``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现 + 委托给四个原始存取原语::meth:`_get_raw`、:meth:`_put_raw`、 + :meth:`_iter_raw`、:meth:`_clear_raw`,并基于 :meth:`_now` 与 + ``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`。 + + 子类需设置 ``self._ttl`` 并实现上述四个原语;如需自定义时间源 + (如 ``time.monotonic``)可覆盖 :meth:`_now`。 + """ + + _ttl: float | None + + # ---- 原语:由子类实现 ---- # + @abstractmethod + def _get_raw(self, key: str) -> tuple[Any, float] | None: + """返回 ``(value, ts)``;键不存在时返回 ``None``。""" + + @abstractmethod + def _put_raw(self, key: str, value: Any, ts: float) -> None: + """写入一条记录。""" + + @abstractmethod + def _iter_raw(self) -> Iterator[tuple[str, Any, float]]: + """迭代所有记录(不做过期过滤),yield ``(key, value, ts)``。""" + + @abstractmethod + def _clear_raw(self) -> None: + """清空所有记录。""" + + # ---- 共享实现 ---- # + def _now(self) -> float: + """当前时间戳,默认为 wall-clock 秒。""" + return time.time() + + def _is_expired(self, ts: float) -> bool: + """时间戳 ``ts`` 是否已过期。""" + if self._ttl is None: + return False + return (self._now() - ts) > self._ttl + + @override + def load(self) -> Mapping[str, Any]: + return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)} + + @override + def save(self, key: str, value: Any) -> None: + self._put_raw(key, value, self._now()) + + @override + def has(self, key: str) -> bool: + entry = self._get_raw(key) + return entry is not None and not self._is_expired(entry[1]) + + @override + def get(self, key: str) -> Any: + entry = self._get_raw(key) + if entry is None or self._is_expired(entry[1]): + raise KeyError(key) + return entry[0] + + @override + def clear(self) -> None: + self._clear_raw() + + +class MemoryBackend(_TTLStateBackendMixin): """进程内 dict 后端。进程退出即丢失。 Parameters @@ -70,35 +138,35 @@ class MemoryBackend(StateBackend): self._ttl = ttl @override - def load(self) -> Mapping[str, Any]: - return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)} + def _now(self) -> float: + return time.monotonic() @override - def save(self, key: str, value: Any) -> None: - self._store[key] = (value, time.monotonic()) + def _get_raw(self, key: str) -> tuple[Any, float] | None: + return self._store.get(key) @override - def has(self, key: str) -> bool: - return key in self._store and not self._expired(key) + def _put_raw(self, key: str, value: Any, ts: float) -> None: + self._store[key] = (value, ts) @override - def get(self, key: str) -> Any: - if key not in self._store or self._expired(key): - raise KeyError(key) - return self._store[key][0] + def _iter_raw(self) -> Iterator[tuple[str, Any, float]]: + for k, (v, ts) in self._store.items(): + yield k, v, ts @override - def clear(self) -> None: + def _clear_raw(self) -> None: self._store.clear() def _expired(self, key: str) -> bool: - if self._ttl is None or key not in self._store: + """键是否已过期(兼容旧测试 API)。""" + entry = self._get_raw(key) + if entry is None: return False - _value, ts = self._store[key] - return (time.monotonic() - ts) > self._ttl + return self._is_expired(entry[1]) -class JSONBackend(StateBackend): +class JSONBackend(_TTLStateBackendMixin): """基于文件的 JSON 存储,用于跨进程续跑。 存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。 @@ -144,17 +212,30 @@ class JSONBackend(StateBackend): except (OSError, TypeError) as exc: raise StorageError(f"cannot write state file {self._path!r}", exc) from exc - def _now(self) -> float: - return time.time() - - def _expired(self, entry: dict[str, Any]) -> bool: - if self._ttl is None: - return False - return (self._now() - float(entry.get("ts", 0))) > self._ttl + @override + def _get_raw(self, key: str) -> tuple[Any, float] | None: + entry = self._store.get(key) + if entry is None: + return None + return entry["value"], float(entry.get("ts", 0)) @override - def load(self) -> Mapping[str, Any]: - return {k: v["value"] for k, v in self._store.items() if not self._expired(v)} + def _put_raw(self, key: str, value: Any, ts: float) -> None: + self._store[key] = {"value": value, "ts": ts} + + @override + def _iter_raw(self) -> Iterator[tuple[str, Any, float]]: + for k, entry in self._store.items(): + yield k, entry["value"], float(entry.get("ts", 0)) + + @override + def _clear_raw(self) -> None: + self._store.clear() + + @override + def clear(self) -> None: + super().clear() + self._flush() @override def save(self, key: str, value: Any) -> None: @@ -162,23 +243,12 @@ class JSONBackend(StateBackend): _ = json.dumps(value) except (TypeError, ValueError) as exc: raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc - self._store[key] = {"value": value, "ts": self._now()} + super().save(key, value) self._flush() - @override - def has(self, key: str) -> bool: - return key in self._store and not self._expired(self._store[key]) - - @override - def get(self, key: str) -> Any: - if key not in self._store or self._expired(self._store[key]): - raise KeyError(key) - return self._store[key]["value"] - - @override - def clear(self) -> None: - self._store.clear() - self._flush() + def _expired(self, entry: Mapping[str, Any]) -> bool: + """带元数据的条目是否已过期(兼容旧测试 API)。""" + return self._is_expired(float(entry.get("ts", 0))) def resolve_backend(backend: StateBackend | None) -> StateBackend: diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index ac94cce..c4afc76 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -74,6 +74,13 @@ Condition = Callable[[Context], bool] CacheKeyFn = Callable[[Context], str] +def _format_skip_reason(failed_conditions: list[str]) -> str: + """格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。""" + if len(failed_conditions) <= 2: + return f"条件不满足: {', '.join(failed_conditions)}" + return f"条件不满足: {', '.join(failed_conditions[:2])} 等{len(failed_conditions)}个条件" + + # ---------------------------------------------------------------------- # # 重试策略 # ---------------------------------------------------------------------- # @@ -315,6 +322,7 @@ class TaskSpec(Generic[T]): ------- (should_run, skip_reason) ``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。 + 失败条件超过 2 个时仅展示前 2 个并附总数。 """ # 逐个求值条件,记录失败项。 failed_conditions: list[str] = [] @@ -323,8 +331,7 @@ class TaskSpec(Generic[T]): ok = condition(context) except Exception: ok = False - name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" - failed_conditions.append(name) + failed_conditions.append("匿名条件(执行错误)") continue if not ok: reason = getattr(condition, "_reason", None) @@ -336,7 +343,7 @@ class TaskSpec(Generic[T]): failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件") if failed_conditions: - return False, f"条件不满足: {', '.join(failed_conditions)}" + return False, _format_skip_reason(failed_conditions) if self.skip_if_missing and not self._is_cmd_available(): cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown" diff --git a/src/pyflowx/utils.py b/src/pyflowx/utils.py new file mode 100644 index 0000000..e1303b8 --- /dev/null +++ b/src/pyflowx/utils.py @@ -0,0 +1,70 @@ +"""常用工具函数.""" + +__all__ = ["perf_timer"] + + +import functools +import logging +import time +from collections import defaultdict +from typing import Callable, ParamSpec, TypedDict + +from typing_extensions import TypeVar + +P = ParamSpec("P") +R = TypeVar("R") + + +class _PerformanceMetrics(TypedDict): + """性能指标.""" + + count: int + total_time: float + + +_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict( + lambda: _PerformanceMetrics( + count=0, + total_time=0.0, + ) +) + + +def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False): + """性能计时器装饰器.""" + scale: dict[str, float] = { + "s": 1.0, + "ms": 1000.0, + "us": 1000000.0, + } + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + + _perf_metrics[func.__name__]["count"] += 1 + _perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit] + + if not report: + logging.info( + f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}" + ) + return result + + return wrapper + + if report: + import atexit + + logging.basicConfig(level=logging.INFO) + logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}") + + @atexit.register + def _() -> None: + for name, metrics in _perf_metrics.items(): + logging.info(f"{name}: {metrics['count']} times, {metrics['total_time']:.{precision}f}{unit}") + + return decorator diff --git a/tests/cli/test_which.py b/tests/cli/test_which.py deleted file mode 100644 index 71954a1..0000000 --- a/tests/cli/test_which.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Tests for cli.which module.""" - -from __future__ import annotations - -import shutil -from unittest.mock import patch - -import pytest - -import pyflowx as px -from pyflowx.cli import which - - -# ---------------------------------------------------------------------- # -# main function -# ---------------------------------------------------------------------- # -class TestMain: - """Test main function.""" - - def test_main_with_single_command(self) -> None: - """main() should handle single command argument.""" - with patch("sys.argv", ["which", "python"]), patch.object( - shutil, "which", return_value="/usr/bin/python" - ), patch.object(px, "run") as mock_run: - which.main() - # Should create a graph with one task - assert mock_run.called - graph = mock_run.call_args[0][0] - assert isinstance(graph, px.Graph) - - def test_main_with_multiple_commands(self) -> None: - """main() should handle multiple command arguments.""" - with patch("sys.argv", ["which", "python", "pip", "node"]), patch.object( - shutil, "which", return_value="/usr/bin/cmd" - ), patch.object(px, "run") as mock_run: - which.main() - # Should create a graph with three tasks - assert mock_run.called - graph = mock_run.call_args[0][0] - assert isinstance(graph, px.Graph) - - def test_main_with_no_args_shows_help(self) -> None: - """main() with no args should show help and exit.""" - with patch("sys.argv", ["which"]), pytest.raises(SystemExit) as exc_info: - which.main() - assert exc_info.value.code == 2 - - def test_main_creates_task_specs_with_correct_names(self) -> None: - """main() should create TaskSpecs with correct names.""" - with patch("sys.argv", ["which", "git", "npm"]), patch.object( - shutil, "which", return_value="/usr/bin/cmd" - ), patch.object(px, "run") as mock_run: - which.main() - graph = mock_run.call_args[0][0] - # Check that task names are correct - task_names = list(graph.all_specs().keys()) - assert "which_git" in task_names - assert "which_npm" in task_names - - def test_main_uses_thread_strategy(self) -> None: - """main() should use thread strategy.""" - with patch("sys.argv", ["which", "python"]), patch.object( - shutil, "which", return_value="/usr/bin/python" - ), patch.object(px, "run") as mock_run: - which.main() - assert mock_run.call_args[1]["strategy"] == "thread" diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..d7b3f2e --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,41 @@ +import time + +import pytest +from pytest_mock import MockerFixture + +from pyflowx.utils import _perf_metrics, perf_timer + + +@pytest.fixture(autouse=True) +def reset_perf_metrics(): + """重置性能指标.""" + _perf_metrics.clear() + + +class TestPerformanceTimer: + def test_perf_timer(self): + + @perf_timer() + def test_func(): + time.sleep(0.1) + + test_func() + + assert _perf_metrics["test_func"] is not None + assert _perf_metrics["test_func"]["count"] == 1 + assert _perf_metrics["test_func"]["total_time"] >= 0.1 + + def test_perf_timer_report(self, mocker: MockerFixture): + mock_log = mocker.patch("logging.info") + + @perf_timer(report=True, unit="ms", precision=3) + def test_func(): + time.sleep(0.1) + + test_func() + + assert _perf_metrics["test_func"] is not None + assert _perf_metrics["test_func"]["count"] == 1 + assert _perf_metrics["test_func"]["total_time"] >= 0.1 + + assert mock_log.call_count == 1 diff --git a/uv.lock b/uv.lock index 0fbb470..07d0620 100644 --- a/uv.lock +++ b/uv.lock @@ -5603,7 +5603,7 @@ pycountry = [ [[package]] name = "pyflowx" -version = "0.2.9" +version = "0.2.10" source = { editable = "." } dependencies = [ { name = "graphlib-backport", marker = "python_full_version < '3.9'" },