refactor(executors): 重构执行器逻辑,移除重复mixin并优化分层排序

主要变更:
1.  将任务跳过/重试逻辑从类mixin改为模块级函数,减少代码重复
2.  优化_graph.layers()的前置校验逻辑,统一在run入口执行
3.  重构存储过期检查API,移除废弃的_expired方法
4.  优化TaskSpec.cache_key异常处理,增加指定异常捕获并记录警告
5.  修复verbose模式下的事件回调逻辑,正确触发RUNNING事件
6.  调整测试用例以适配新的API和行为变更
This commit is contained in:
2026-06-28 08:25:15 +08:00
parent bdd70e9c43
commit 9999071119
7 changed files with 213 additions and 181 deletions
+175 -150
View File
@@ -12,11 +12,12 @@
架构 架构
---- ----
本模块通过 **Mixin** 组合消除同步/异步任务执行器之间的重复代码: 本模块通过 **模块级函数** 消除同步/异步任务执行器之间的重复代码:
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。 * 模块级跳过/重试函数(:func:`_prepare_for_execution` / :func:`_should_retry`
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。 / :func:`_mark_success` / :func:`_handle_failure` / :func:`_finalize_failure`
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin —— 上游跳过 / 条件跳过的预检、重试决策、成功/失败后处理
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,调用上述函数。
* 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` / * 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` /
:func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、 :func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、
信号量构建、结果存储。 信号量构建、结果存储。
@@ -86,6 +87,22 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
) )
def _emit_running(on_event: EventCallback | None, spec: TaskSpec[Any]) -> None:
"""触发 RUNNING 事件(任务开始执行时)。"""
if on_event is None:
return
on_event(
TaskEvent(
task=spec.name,
status=TaskStatus.RUNNING,
attempts=0,
error=None,
duration=None,
reason=None,
)
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。""" """安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None) hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -129,11 +146,16 @@ def _apply_cached(
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> bool: ) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。""" """若 ``name`` 命中缓存,写入 context/report 并返回 True。
单次 ``backend.get`` + ``KeyError`` 回退,避免 ``has`` + ``get`` 双重
哈希查找与双重 TTL 判断。
"""
storage_key = spec.storage_key(context) storage_key = spec.storage_key(context)
if not backend.has(storage_key): try:
cached = backend.get(storage_key)
except KeyError:
return False return False
cached = backend.get(storage_key)
context[name] = cached context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中") result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result report.results[name] = result
@@ -142,154 +164,146 @@ def _apply_cached(
return True return True
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: def _sort_by_priority(layer: list[str], specs: Mapping[str, TaskSpec[Any]]) -> list[str]:
"""按优先级降序排序(稳定排序)。""" """按优先级降序排序(稳定排序)。
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
接受预构建的 ``{name: spec}`` 映射,避免在排序键函数中重复调用
# ---------------------------------------------------------------------- # ``graph.resolved_spec``(即便有缓存也省去 N 次字典查询)。
# Mixin:任务级跳过 / 重试 / 成功处理
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""任务级跳过预检共享逻辑。
"上游被跳过/失败""条件不满足"两类跳过判断统一为单一入口,
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
""" """
return sorted(layer, key=lambda n: -specs[n].priority)
@staticmethod
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
软依赖不影响本检查——软依赖被跳过时注入默认值。 # ---------------------------------------------------------------------- #
""" # 任务级跳过 / 重试 / 成功处理:模块级函数
if report is None or spec.allow_upstream_skip: # ---------------------------------------------------------------------- #
return None def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
for dep in spec.depends_on: """硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
if dep not in report.results:
continue 软依赖不影响本检查——软依赖被跳过时注入默认值。
dep_status = report.results[dep].status """
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED): if report is None or spec.allow_upstream_skip:
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
) -> TaskResult[Any] | None:
"""执行前预检:上游跳过 / 条件跳过。
返回 SKIPPED TaskResult 或 ``None``(继续执行)。 def _prepare_for_execution(
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。 spec: TaskSpec[Any],
""" context: Mapping[str, Any],
# 1. 上游被跳过/失败 report: RunReport | None,
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report) on_event: EventCallback | None,
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute ) -> TaskResult[Any] | None:
if skip_reason is None: """执行前预检:上游跳过 / 条件跳过。
should_run, cond_reason = spec.should_execute(context)
if not should_run: 返回 SKIPPED TaskResult 或 ``None``(继续执行)。
skip_reason = cond_reason or "条件不满足" 条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
if skip_reason is None: """
return None # 1. 上游被跳过/失败
# 构造 SKIPPED 结果 skip_reason = _upstream_skip_reason(spec, report)
result: TaskResult[Any] = TaskResult( # 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute
spec=spec, if skip_reason is None:
status=TaskStatus.SKIPPED, should_run, cond_reason = spec.should_execute(context)
finished_at=datetime.now(), if not should_run:
reason=skip_reason, skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None
# 构造 SKIPPED 结果
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""是否应继续重试。"""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""标记任务成功并触发 post_run 钩子。"""
result.value = value
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
) )
_emit(on_event, result) return
logger.info("task %r skipped (%s)", spec.name, skip_reason) raise TaskFailedError(
return result task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
class _TaskRetryMixin: def _handle_failure(
"""任务级重试决策与失败/成功后处理共享逻辑。""" spec: TaskSpec[Any],
result: TaskResult[Any],
exc: BaseException,
layer_idx: int | None,
on_event: EventCallback | None,
) -> bool:
"""统一处理失败:超时转换、重试决策、finalize。
@staticmethod Returns
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool: -------
"""是否应继续重试。""" bool
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc) ``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
"""
@staticmethod # asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None: if isinstance(exc, asyncio.TimeoutError):
"""标记任务成功并触发 post_run 钩子。""" exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
result.value = value logger.warning(
result.status = TaskStatus.SUCCESS "task %r timed out (attempt %d/%d); retrying",
result.finished_at = datetime.now() spec.name,
_run_hooks(spec.hooks, "post_run", spec, value) result.attempts,
spec.retry.max_attempts,
@staticmethod
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
_emit(on_event, result)
if continue_on_error:
logger.warning(
"task %r failed but continue_on_error=True; continuing.",
result.spec.name,
)
return
raise TaskFailedError(
task=result.spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
) )
else:
@staticmethod logger.warning(
def _handle_failure( "task %r failed (attempt %d/%d): %r; retrying",
spec: TaskSpec[Any], spec.name,
result: TaskResult[Any], result.attempts,
exc: BaseException, spec.retry.max_attempts,
layer_idx: int | None, exc,
on_event: EventCallback | None, )
) -> bool: result.error = exc
"""统一处理失败:超时转换、重试决策、finalize。 if _should_retry(spec, result.attempts, exc):
return False
Returns _run_hooks(spec.hooks, "on_failure", spec, exc)
------- _finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
bool return True
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
"""
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
if isinstance(exc, asyncio.TimeoutError):
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
)
else:
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
exc,
)
result.error = exc
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin # 任务执行器:同步 / 异步(调用模块级跳过/重试函数
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): class SyncTaskRunner:
"""同步任务执行器:带重试与跳过预检。""" """同步任务执行器:带重试与跳过预检。"""
@staticmethod @staticmethod
@@ -300,7 +314,7 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
on_event: EventCallback | None = None, on_event: EventCallback | None = None,
report: RunReport | None = None, report: RunReport | None = None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None: if skipped is not None:
return skipped return skipped
@@ -309,23 +323,24 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec) _run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True: while True:
result.attempts += 1 result.attempts += 1
try: try:
with spec.env_context(): with spec.env_context():
value = spec.effective_fn(*args, **kwargs) value = spec.effective_fn(*args, **kwargs)
_TaskRetryMixin._mark_success(spec, result, value) _mark_success(spec, result, value)
return result return result
except Exception as exc: except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): if _handle_failure(spec, result, exc, layer_idx, on_event):
return result return result
wait = spec.retry.wait_seconds(result.attempts) wait = spec.retry.wait_seconds(result.attempts)
if wait > 0: if wait > 0:
time.sleep(wait) time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin): class AsyncTaskRunner:
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。""" """异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
@staticmethod @staticmethod
@@ -337,7 +352,7 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
report: RunReport | None = None, report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None, semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event) skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None: if skipped is not None:
return skipped return skipped
@@ -348,15 +363,16 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec) _run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True: while True:
result.attempts += 1 result.attempts += 1
try: try:
value = await _execute_async_task(spec, args, kwargs, loop) value = await _execute_async_task(spec, args, kwargs, loop)
_TaskRetryMixin._mark_success(spec, result, value) _mark_success(spec, result, value)
return result return result
except Exception as exc: except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event): if _handle_failure(spec, result, exc, layer_idx, on_event):
return result return result
wait = spec.retry.wait_seconds(result.attempts) wait = spec.retry.wait_seconds(result.attempts)
if wait > 0: if wait > 0:
@@ -401,13 +417,19 @@ def _filter_and_sort(
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> list[str]: ) -> list[str]:
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。""" """过滤掉已命中缓存的任务,按优先级排序返回待运行列表。
预构建 ``{name: spec}`` 映射,过滤与排序共享同一份 resolved spec
避免 ``_sort_by_priority`` 内重复调用 ``graph.resolved_spec``。
"""
specs: dict[str, TaskSpec[Any]] = {}
to_run: list[str] = [] to_run: list[str] = []
for name in layer: for name in layer:
spec = graph.resolved_spec(name) spec = graph.resolved_spec(name)
specs[name] = spec
if not _apply_cached(name, spec, context, report, backend, on_event): if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name) to_run.append(name)
return _sort_by_priority(to_run, graph) return _sort_by_priority(to_run, specs)
def _store_result( def _store_result(
@@ -619,7 +641,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
def _verbose_callback(event: TaskEvent) -> None: def _verbose_callback(event: TaskEvent) -> None:
dur = f" ({event.duration:.3f}s)" if event.duration is not None else "" dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
if event.status == TaskStatus.RUNNING: # pragma: no cover if event.status == TaskStatus.RUNNING:
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True) print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
elif event.status == TaskStatus.SUCCESS: elif event.status == TaskStatus.SUCCESS:
print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True) print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True)
@@ -684,6 +706,10 @@ def run(
_print_dry_run(graph, layers) _print_dry_run(graph, layers)
return RunReport(success=True) return RunReport(success=True)
# 入口统一校验一次:所有策略共用,避免 layers() / dependency 路径
# 各自重复调用 validate()。
graph.validate()
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
backend = resolve_backend(state) backend = resolve_backend(state)
report = RunReport() report = RunReport()
@@ -705,7 +731,6 @@ def run(
layers = graph.layers() layers = graph.layers()
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits)) asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency": elif strategy == "dependency":
graph.validate()
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits)) asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else: else:
raise ValueError(f"Unknown strategy: {strategy!r}") raise ValueError(f"Unknown strategy: {strategy!r}")
+4 -1
View File
@@ -231,8 +231,11 @@ class Graph:
同层任务无相互硬依赖,可并发执行。软依赖不参与分层。 同层任务无相互硬依赖,可并发执行。软依赖不参与分层。
层按执行顺序返回。图有环时抛出 :class:`CycleError`。 层按执行顺序返回。图有环时抛出 :class:`CycleError`。
.. note::
本方法假定图已通过 :meth:`validate` 校验(由 :func:`pyflowx.run`
在入口统一执行一次)。若直接调用本方法,需自行先校验。
""" """
self.validate()
sorter = _TopologicalSorter(self.deps) sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = [] result: list[list[str]] = []
sorter.prepare() sorter.prepare()
-11
View File
@@ -175,13 +175,6 @@ class MemoryBackend(_TTLStateBackendMixin):
def _clear_raw(self) -> None: def _clear_raw(self) -> None:
self._store.clear() self._store.clear()
def _expired(self, key: str) -> bool:
"""键是否已过期(兼容旧测试 API)。"""
entry = self._get_raw(key)
if entry is None:
return False
return self._is_expired(entry[1])
class JSONBackend(_TTLStateBackendMixin): class JSONBackend(_TTLStateBackendMixin):
"""基于文件的 JSON 存储,用于跨进程续跑。 """基于文件的 JSON 存储,用于跨进程续跑。
@@ -283,10 +276,6 @@ class JSONBackend(_TTLStateBackendMixin):
self._defer_flush = False self._defer_flush = False
self._flush() self._flush()
def _expired(self, entry: Mapping[str, Any]) -> bool:
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
return self._is_expired(float(entry.get("ts", 0)))
def resolve_backend(backend: StateBackend | None) -> StateBackend: def resolve_backend(backend: StateBackend | None) -> StateBackend:
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。""" """返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
+17 -6
View File
@@ -17,6 +17,7 @@
from __future__ import annotations from __future__ import annotations
import logging
import os import os
import shutil import shutil
import sys import sys
@@ -68,6 +69,8 @@ TaskCmd = Union[
Strategy = Union[str, "StrategyKind"] Strategy = Union[str, "StrategyKind"]
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束 StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
logger = logging.getLogger("pyflowx")
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。 # 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
Condition = Callable[[Context], bool] Condition = Callable[[Context], bool]
@@ -378,12 +381,20 @@ class TaskSpec(Generic[T]):
def storage_key(self, context: Context) -> str: def storage_key(self, context: Context) -> str:
"""计算状态后端存储键。""" """计算状态后端存储键。"""
if self.cache_key is not None: if self.cache_key is None:
try: return self.name
return f"{self.name}:{self.cache_key(context)}" try:
except Exception: return f"{self.name}:{self.cache_key(context)}"
return self.name except (TypeError, ValueError, KeyError, AttributeError) as exc:
return self.name # cache_key 抛出预期内的数据/类型异常时回退到 name,但仍记录警告
# 以便用户发现 cache_key 实现中的 bug。
logger.warning(
"task %r: cache_key 回退到 name%s: %s",
self.name,
type(exc).__name__,
exc,
)
return self.name
# 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。 # 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。
+7 -3
View File
@@ -99,7 +99,10 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
def test_verbose_run_with_user_callback(): def test_verbose_run_with_user_callback():
"""Test px.run with verbose=True and user callback both called.""" """Test px.run with verbose=True and user callback both called.
预期事件序列:RUNNING(开始)→ SUCCESS(完成)。
"""
events = [] events = []
def on_event(event: px.TaskEvent): def on_event(event: px.TaskEvent):
@@ -109,8 +112,9 @@ def test_verbose_run_with_user_callback():
graph = px.Graph.from_specs([spec]) graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event) report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event)
assert report.success assert report.success
assert len(events) == 1 assert len(events) == 2
assert events[0].status == px.TaskStatus.SUCCESS assert events[0].status == px.TaskStatus.RUNNING
assert events[1].status == px.TaskStatus.SUCCESS
def test_verbose_event_callback_success(): def test_verbose_event_callback_success():
+8 -8
View File
@@ -70,9 +70,9 @@ def test_memory_backend_ttl_load_filters_expired() -> None:
def test_memory_backend_expired_key_not_in_store() -> None: def test_memory_backend_expired_key_not_in_store() -> None:
"""_expired 对不存在键返回 False.""" """不存在的键 has 返回 False."""
b = MemoryBackend(ttl=1.0) b = MemoryBackend(ttl=1.0)
assert b._expired("nonexistent") is False assert b.has("nonexistent") is False
def test_memory_backend_no_ttl_never_expired() -> None: def test_memory_backend_no_ttl_never_expired() -> None:
@@ -244,35 +244,35 @@ def test_json_backend_ttl_load_filters_expired() -> None:
def test_json_backend_expired_no_ttl() -> None: def test_json_backend_expired_no_ttl() -> None:
"""无 TTL 时 _expired 返回 False.""" """无 TTL 时永不过期."""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
b.save("a", 1) b.save("a", 1)
# 手动修改 ts 为很久以前 # 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 1000 b._store["a"]["ts"] = time.time() - 1000
assert b._expired(b._store["a"]) is False # 无 TTL,永不过期 assert b.has("a") is True # 无 TTL,永不过期
def test_json_backend_expired_with_ttl() -> None: def test_json_backend_expired_with_ttl() -> None:
"""有 TTL 时 _expired 检查是否过期.""" """有 TTL 时过期键 has 返回 False."""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0) b = JSONBackend(path, ttl=1.0)
b.save("a", 1) b.save("a", 1)
# 手动修改 ts 为很久以前 # 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
assert b._expired(b._store["a"]) is True assert b.has("a") is False
def test_json_backend_expired_missing_ts() -> None: def test_json_backend_expired_missing_ts() -> None:
"""entry 缺少 ts 时使用默认值 0.""" """entry 缺少 ts 时视为过期."""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0) b = JSONBackend(path, ttl=1.0)
b._store["a"] = {"value": 1} # 缺少 ts b._store["a"] = {"value": 1} # 缺少 ts
# ts 默认为 0,已经过了很久 # ts 默认为 0,已经过了很久
assert b._expired(b._store["a"]) is True assert b.has("a") is False
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None: def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
+2 -2
View File
@@ -203,10 +203,10 @@ def test_is_cmd_available_callable_returns_true() -> None:
# storage_key 异常处理 # storage_key 异常处理
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_storage_key_cache_key_exception_returns_name() -> None: def test_storage_key_cache_key_exception_returns_name() -> None:
"""cache_key 抛异常时应返回任务名.""" """cache_key 抛预期异常(TypeError/ValueError/KeyError/AttributeError时应返回任务名."""
def bad_cache_key(_ctx): def bad_cache_key(_ctx):
raise RuntimeError("cache key error") raise ValueError("cache key error")
spec = TaskSpec("a", _fn, cache_key=bad_cache_key) spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
key = spec.storage_key({}) key = spec.storage_key({})