diff --git a/src/pyflowx/executors.py b/src/pyflowx/executors.py index 133cbfc..29781c8 100644 --- a/src/pyflowx/executors.py +++ b/src/pyflowx/executors.py @@ -12,11 +12,12 @@ 架构 ---- -本模块通过 **Mixin** 组合消除同步/异步任务执行器之间的重复代码: +本模块通过 **模块级函数** 消除同步/异步任务执行器之间的重复代码: -* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。 -* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。 -* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。 +* 模块级跳过/重试函数(:func:`_prepare_for_execution` / :func:`_should_retry` + / :func:`_mark_success` / :func:`_handle_failure` / :func:`_finalize_failure`) + —— 上游跳过 / 条件跳过的预检、重试决策、成功/失败后处理。 +* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,调用上述函数。 * 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` / :func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、 信号量构建、结果存储。 @@ -86,6 +87,22 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None: ) +def _emit_running(on_event: EventCallback | None, spec: TaskSpec[Any]) -> None: + """触发 RUNNING 事件(任务开始执行时)。""" + if on_event is None: + return + on_event( + TaskEvent( + task=spec.name, + status=TaskStatus.RUNNING, + attempts=0, + error=None, + duration=None, + reason=None, + ) + ) + + def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: """安全调用钩子(异常仅记录,不影响任务状态)。""" hook: Callable[..., None] | None = getattr(hooks, fn_name, None) @@ -129,11 +146,16 @@ def _apply_cached( backend: StateBackend, on_event: EventCallback | None, ) -> bool: - """若 ``name`` 命中缓存,写入 context/report 并返回 True。""" + """若 ``name`` 命中缓存,写入 context/report 并返回 True。 + + 单次 ``backend.get`` + ``KeyError`` 回退,避免 ``has`` + ``get`` 双重 + 哈希查找与双重 TTL 判断。 + """ storage_key = spec.storage_key(context) - if not backend.has(storage_key): + try: + cached = backend.get(storage_key) + except KeyError: return False - cached = backend.get(storage_key) context[name] = cached result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中") report.results[name] = result @@ -142,154 +164,146 @@ def _apply_cached( return True -def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: - """按优先级降序排序(稳定排序)。""" - return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority) +def _sort_by_priority(layer: list[str], specs: Mapping[str, TaskSpec[Any]]) -> list[str]: + """按优先级降序排序(稳定排序)。 - -# ---------------------------------------------------------------------- # -# Mixin:任务级跳过 / 重试 / 成功处理 -# ---------------------------------------------------------------------- # -class _TaskSkipMixin: - """任务级跳过预检共享逻辑。 - - 将"上游被跳过/失败"与"条件不满足"两类跳过判断统一为单一入口, - 被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。 + 接受预构建的 ``{name: spec}`` 映射,避免在排序键函数中重复调用 + ``graph.resolved_spec``(即便有缓存也省去 N 次字典查询)。 """ + return sorted(layer, key=lambda n: -specs[n].priority) - @staticmethod - def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None: - """硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。 - 软依赖不影响本检查——软依赖被跳过时注入默认值。 - """ - if report is None or spec.allow_upstream_skip: - return None - for dep in spec.depends_on: - if dep not in report.results: - continue - dep_status = report.results[dep].status - if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): - return f"上游任务 '{dep}' 状态为 {dep_status.value}" +# ---------------------------------------------------------------------- # +# 任务级跳过 / 重试 / 成功处理:模块级函数 +# ---------------------------------------------------------------------- # +def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None: + """硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。 + + 软依赖不影响本检查——软依赖被跳过时注入默认值。 + """ + if report is None or spec.allow_upstream_skip: return None + for dep in spec.depends_on: + if dep not in report.results: + continue + dep_status = report.results[dep].status + if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): + return f"上游任务 '{dep}' 状态为 {dep_status.value}" + return None - @staticmethod - def _prepare_for_execution( - spec: TaskSpec[Any], - context: Mapping[str, Any], - report: RunReport | None, - on_event: EventCallback | None, - ) -> TaskResult[Any] | None: - """执行前预检:上游跳过 / 条件跳过。 - 返回 SKIPPED TaskResult 或 ``None``(继续执行)。 - 条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。 - """ - # 1. 上游被跳过/失败 - skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report) - # 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute) - if skip_reason is None: - should_run, cond_reason = spec.should_execute(context) - if not should_run: - skip_reason = cond_reason or "条件不满足" - if skip_reason is None: - return None - # 构造 SKIPPED 结果 - result: TaskResult[Any] = TaskResult( - spec=spec, - status=TaskStatus.SKIPPED, - finished_at=datetime.now(), - reason=skip_reason, +def _prepare_for_execution( + spec: TaskSpec[Any], + context: Mapping[str, Any], + report: RunReport | None, + on_event: EventCallback | None, +) -> TaskResult[Any] | None: + """执行前预检:上游跳过 / 条件跳过。 + + 返回 SKIPPED TaskResult 或 ``None``(继续执行)。 + 条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。 + """ + # 1. 上游被跳过/失败 + skip_reason = _upstream_skip_reason(spec, report) + # 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute) + if skip_reason is None: + should_run, cond_reason = spec.should_execute(context) + if not should_run: + skip_reason = cond_reason or "条件不满足" + if skip_reason is None: + return None + # 构造 SKIPPED 结果 + result: TaskResult[Any] = TaskResult( + spec=spec, + status=TaskStatus.SKIPPED, + finished_at=datetime.now(), + reason=skip_reason, + ) + _emit(on_event, result) + logger.info("task %r skipped (%s)", spec.name, skip_reason) + return result + + +def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool: + """是否应继续重试。""" + return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc) + + +def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None: + """标记任务成功并触发 post_run 钩子。""" + result.value = value + result.status = TaskStatus.SUCCESS + result.finished_at = datetime.now() + _run_hooks(spec.hooks, "post_run", spec, value) + + +def _finalize_failure( + result: TaskResult[Any], + layer_idx: int | None, + on_event: EventCallback | None, + continue_on_error: bool, +) -> None: + """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" + result.status = TaskStatus.FAILED + result.finished_at = datetime.now() + _emit(on_event, result) + if continue_on_error: + logger.warning( + "task %r failed but continue_on_error=True; continuing.", + result.spec.name, ) - _emit(on_event, result) - logger.info("task %r skipped (%s)", spec.name, skip_reason) - return result + return + raise TaskFailedError( + task=result.spec.name, + cause=result.error if result.error is not None else RuntimeError("unknown"), + attempts=result.attempts, + layer=layer_idx, + ) -class _TaskRetryMixin: - """任务级重试决策与失败/成功后处理共享逻辑。""" +def _handle_failure( + spec: TaskSpec[Any], + result: TaskResult[Any], + exc: BaseException, + layer_idx: int | None, + on_event: EventCallback | None, +) -> bool: + """统一处理失败:超时转换、重试决策、finalize。 - @staticmethod - def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool: - """是否应继续重试。""" - return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc) - - @staticmethod - def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None: - """标记任务成功并触发 post_run 钩子。""" - result.value = value - result.status = TaskStatus.SUCCESS - result.finished_at = datetime.now() - _run_hooks(spec.hooks, "post_run", spec, value) - - @staticmethod - def _finalize_failure( - result: TaskResult[Any], - layer_idx: int | None, - on_event: EventCallback | None, - continue_on_error: bool, - ) -> None: - """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" - result.status = TaskStatus.FAILED - result.finished_at = datetime.now() - _emit(on_event, result) - if continue_on_error: - logger.warning( - "task %r failed but continue_on_error=True; continuing.", - result.spec.name, - ) - return - raise TaskFailedError( - task=result.spec.name, - cause=result.error if result.error is not None else RuntimeError("unknown"), - attempts=result.attempts, - layer=layer_idx, + Returns + ------- + bool + ``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。 + """ + # asyncio.TimeoutError → TaskTimeoutError(统一异常类型) + if isinstance(exc, asyncio.TimeoutError): + exc = TaskTimeoutError(spec.name, spec.timeout or 0.0) + logger.warning( + "task %r timed out (attempt %d/%d); retrying", + spec.name, + result.attempts, + spec.retry.max_attempts, ) - - @staticmethod - def _handle_failure( - spec: TaskSpec[Any], - result: TaskResult[Any], - exc: BaseException, - layer_idx: int | None, - on_event: EventCallback | None, - ) -> bool: - """统一处理失败:超时转换、重试决策、finalize。 - - Returns - ------- - bool - ``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。 - """ - # asyncio.TimeoutError → TaskTimeoutError(统一异常类型) - if isinstance(exc, asyncio.TimeoutError): - exc = TaskTimeoutError(spec.name, spec.timeout or 0.0) - logger.warning( - "task %r timed out (attempt %d/%d); retrying", - spec.name, - result.attempts, - spec.retry.max_attempts, - ) - else: - logger.warning( - "task %r failed (attempt %d/%d): %r; retrying", - spec.name, - result.attempts, - spec.retry.max_attempts, - exc, - ) - result.error = exc - if _TaskRetryMixin._should_retry(spec, result.attempts, exc): - return False - _run_hooks(spec.hooks, "on_failure", spec, exc) - _TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error) - return True + else: + logger.warning( + "task %r failed (attempt %d/%d): %r; retrying", + spec.name, + result.attempts, + spec.retry.max_attempts, + exc, + ) + result.error = exc + if _should_retry(spec, result.attempts, exc): + return False + _run_hooks(spec.hooks, "on_failure", spec, exc) + _finalize_failure(result, layer_idx, on_event, spec.continue_on_error) + return True # ---------------------------------------------------------------------- # -# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin) +# 任务执行器:同步 / 异步(调用模块级跳过/重试函数) # ---------------------------------------------------------------------- # -class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): +class SyncTaskRunner: """同步任务执行器:带重试与跳过预检。""" @staticmethod @@ -300,7 +314,7 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): on_event: EventCallback | None = None, report: RunReport | None = None, ) -> TaskResult[Any]: - skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) + skipped = _prepare_for_execution(spec, context, report, on_event) if skipped is not None: return skipped @@ -309,23 +323,24 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): args, kwargs = build_call_args(spec, context) _run_hooks(spec.hooks, "pre_run", spec) + _emit_running(on_event, spec) while True: result.attempts += 1 try: with spec.env_context(): value = spec.effective_fn(*args, **kwargs) - _TaskRetryMixin._mark_success(spec, result, value) + _mark_success(spec, result, value) return result except Exception as exc: - if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): + if _handle_failure(spec, result, exc, layer_idx, on_event): return result wait = spec.retry.wait_seconds(result.attempts) if wait > 0: time.sleep(wait) -class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): +class AsyncTaskRunner: """异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。""" @staticmethod @@ -337,7 +352,7 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): report: RunReport | None = None, semaphore: asyncio.Semaphore | None = None, ) -> TaskResult[Any]: - skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) + skipped = _prepare_for_execution(spec, context, report, on_event) if skipped is not None: return skipped @@ -348,15 +363,16 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): loop = asyncio.get_event_loop() _run_hooks(spec.hooks, "pre_run", spec) + _emit_running(on_event, spec) while True: result.attempts += 1 try: value = await _execute_async_task(spec, args, kwargs, loop) - _TaskRetryMixin._mark_success(spec, result, value) + _mark_success(spec, result, value) return result except Exception as exc: - if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): + if _handle_failure(spec, result, exc, layer_idx, on_event): return result wait = spec.retry.wait_seconds(result.attempts) if wait > 0: @@ -401,13 +417,19 @@ def _filter_and_sort( backend: StateBackend, on_event: EventCallback | None, ) -> list[str]: - """过滤掉已命中缓存的任务,按优先级排序返回待运行列表。""" + """过滤掉已命中缓存的任务,按优先级排序返回待运行列表。 + + 预构建 ``{name: spec}`` 映射,过滤与排序共享同一份 resolved spec, + 避免 ``_sort_by_priority`` 内重复调用 ``graph.resolved_spec``。 + """ + specs: dict[str, TaskSpec[Any]] = {} to_run: list[str] = [] for name in layer: spec = graph.resolved_spec(name) + specs[name] = spec if not _apply_cached(name, spec, context, report, backend, on_event): to_run.append(name) - return _sort_by_priority(to_run, graph) + return _sort_by_priority(to_run, specs) def _store_result( @@ -619,7 +641,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback: def _verbose_callback(event: TaskEvent) -> None: dur = f" ({event.duration:.3f}s)" if event.duration is not None else "" - if event.status == TaskStatus.RUNNING: # pragma: no cover + if event.status == TaskStatus.RUNNING: print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True) elif event.status == TaskStatus.SUCCESS: print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True) @@ -684,6 +706,10 @@ def run( _print_dry_run(graph, layers) return RunReport(success=True) + # 入口统一校验一次:所有策略共用,避免 layers() / dependency 路径 + # 各自重复调用 validate()。 + graph.validate() + effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event backend = resolve_backend(state) report = RunReport() @@ -705,7 +731,6 @@ def run( layers = graph.layers() asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits)) elif strategy == "dependency": - graph.validate() asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits)) else: raise ValueError(f"Unknown strategy: {strategy!r}") diff --git a/src/pyflowx/graph.py b/src/pyflowx/graph.py index 82b59d7..42058a7 100644 --- a/src/pyflowx/graph.py +++ b/src/pyflowx/graph.py @@ -231,8 +231,11 @@ class Graph: 同层任务无相互硬依赖,可并发执行。软依赖不参与分层。 层按执行顺序返回。图有环时抛出 :class:`CycleError`。 + + .. note:: + 本方法假定图已通过 :meth:`validate` 校验(由 :func:`pyflowx.run` + 在入口统一执行一次)。若直接调用本方法,需自行先校验。 """ - self.validate() sorter = _TopologicalSorter(self.deps) result: list[list[str]] = [] sorter.prepare() diff --git a/src/pyflowx/storage.py b/src/pyflowx/storage.py index b3354a8..16f9dbe 100644 --- a/src/pyflowx/storage.py +++ b/src/pyflowx/storage.py @@ -175,13 +175,6 @@ class MemoryBackend(_TTLStateBackendMixin): def _clear_raw(self) -> None: self._store.clear() - def _expired(self, key: str) -> bool: - """键是否已过期(兼容旧测试 API)。""" - entry = self._get_raw(key) - if entry is None: - return False - return self._is_expired(entry[1]) - class JSONBackend(_TTLStateBackendMixin): """基于文件的 JSON 存储,用于跨进程续跑。 @@ -283,10 +276,6 @@ class JSONBackend(_TTLStateBackendMixin): self._defer_flush = False self._flush() - def _expired(self, entry: Mapping[str, Any]) -> bool: - """带元数据的条目是否已过期(兼容旧测试 API)。""" - return self._is_expired(float(entry.get("ts", 0))) - def resolve_backend(backend: StateBackend | None) -> StateBackend: """返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。""" diff --git a/src/pyflowx/task.py b/src/pyflowx/task.py index fd5b0c8..f4de1c5 100644 --- a/src/pyflowx/task.py +++ b/src/pyflowx/task.py @@ -17,6 +17,7 @@ from __future__ import annotations +import logging import os import shutil import sys @@ -68,6 +69,8 @@ TaskCmd = Union[ Strategy = Union[str, "StrategyKind"] StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束 +logger = logging.getLogger("pyflowx") + # 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。 Condition = Callable[[Context], bool] @@ -378,12 +381,20 @@ class TaskSpec(Generic[T]): def storage_key(self, context: Context) -> str: """计算状态后端存储键。""" - if self.cache_key is not None: - try: - return f"{self.name}:{self.cache_key(context)}" - except Exception: - return self.name - return self.name + if self.cache_key is None: + return self.name + try: + return f"{self.name}:{self.cache_key(context)}" + except (TypeError, ValueError, KeyError, AttributeError) as exc: + # cache_key 抛出预期内的数据/类型异常时回退到 name,但仍记录警告 + # 以便用户发现 cache_key 实现中的 bug。 + logger.warning( + "task %r: cache_key 回退到 name(%s: %s)", + self.name, + type(exc).__name__, + exc, + ) + return self.name # 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。 diff --git a/tests/test_executors_edge_cases.py b/tests/test_executors_edge_cases.py index 598fa1c..0893409 100644 --- a/tests/test_executors_edge_cases.py +++ b/tests/test_executors_edge_cases.py @@ -99,7 +99,10 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]): def test_verbose_run_with_user_callback(): - """Test px.run with verbose=True and user callback both called.""" + """Test px.run with verbose=True and user callback both called. + + 预期事件序列:RUNNING(开始)→ SUCCESS(完成)。 + """ events = [] def on_event(event: px.TaskEvent): @@ -109,8 +112,9 @@ def test_verbose_run_with_user_callback(): graph = px.Graph.from_specs([spec]) report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event) assert report.success - assert len(events) == 1 - assert events[0].status == px.TaskStatus.SUCCESS + assert len(events) == 2 + assert events[0].status == px.TaskStatus.RUNNING + assert events[1].status == px.TaskStatus.SUCCESS def test_verbose_event_callback_success(): diff --git a/tests/test_storage.py b/tests/test_storage.py index cc9c2b4..3a681ea 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -70,9 +70,9 @@ def test_memory_backend_ttl_load_filters_expired() -> None: def test_memory_backend_expired_key_not_in_store() -> None: - """_expired 对不存在键返回 False.""" + """不存在的键 has 返回 False.""" b = MemoryBackend(ttl=1.0) - assert b._expired("nonexistent") is False + assert b.has("nonexistent") is False def test_memory_backend_no_ttl_never_expired() -> None: @@ -244,35 +244,35 @@ def test_json_backend_ttl_load_filters_expired() -> None: def test_json_backend_expired_no_ttl() -> None: - """无 TTL 时 _expired 返回 False.""" + """无 TTL 时永不过期.""" with tempfile.TemporaryDirectory() as tmp: path = str(Path(tmp) / "state.json") b = JSONBackend(path) b.save("a", 1) # 手动修改 ts 为很久以前 b._store["a"]["ts"] = time.time() - 1000 - assert b._expired(b._store["a"]) is False # 无 TTL,永不过期 + assert b.has("a") is True # 无 TTL,永不过期 def test_json_backend_expired_with_ttl() -> None: - """有 TTL 时 _expired 检查是否过期.""" + """有 TTL 时过期键 has 返回 False.""" with tempfile.TemporaryDirectory() as tmp: path = str(Path(tmp) / "state.json") b = JSONBackend(path, ttl=1.0) b.save("a", 1) # 手动修改 ts 为很久以前 b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL - assert b._expired(b._store["a"]) is True + assert b.has("a") is False def test_json_backend_expired_missing_ts() -> None: - """entry 缺少 ts 时使用默认值 0.""" + """entry 缺少 ts 时视为过期.""" with tempfile.TemporaryDirectory() as tmp: path = str(Path(tmp) / "state.json") b = JSONBackend(path, ttl=1.0) b._store["a"] = {"value": 1} # 缺少 ts # ts 默认为 0,已经过了很久 - assert b._expired(b._store["a"]) is True + assert b.has("a") is False def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/tests/test_task.py b/tests/test_task.py index 8ce6d51..ab88bfc 100644 --- a/tests/test_task.py +++ b/tests/test_task.py @@ -203,10 +203,10 @@ def test_is_cmd_available_callable_returns_true() -> None: # storage_key 异常处理 # ---------------------------------------------------------------------- # def test_storage_key_cache_key_exception_returns_name() -> None: - """cache_key 抛异常时应返回任务名.""" + """cache_key 抛预期异常(TypeError/ValueError/KeyError/AttributeError)时应返回任务名.""" def bad_cache_key(_ctx): - raise RuntimeError("cache key error") + raise ValueError("cache key error") spec = TaskSpec("a", _fn, cache_key=bad_cache_key) key = spec.storage_key({})