refactor(executors): 重构执行器逻辑,移除重复mixin并优化分层排序

主要变更:
1.  将任务跳过/重试逻辑从类mixin改为模块级函数,减少代码重复
2.  优化_graph.layers()的前置校验逻辑,统一在run入口执行
3.  重构存储过期检查API,移除废弃的_expired方法
4.  优化TaskSpec.cache_key异常处理,增加指定异常捕获并记录警告
5.  修复verbose模式下的事件回调逻辑,正确触发RUNNING事件
6.  调整测试用例以适配新的API和行为变更
This commit is contained in:
2026-06-28 08:25:15 +08:00
parent bdd70e9c43
commit 9999071119
7 changed files with 213 additions and 181 deletions
+69 -44
View File
@@ -12,11 +12,12 @@
架构
----
本模块通过 **Mixin** 组合消除同步/异步任务执行器之间的重复代码:
本模块通过 **模块级函数** 消除同步/异步任务执行器之间的重复代码:
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin
* 模块级跳过/重试函数(:func:`_prepare_for_execution` / :func:`_should_retry`
/ :func:`_mark_success` / :func:`_handle_failure` / :func:`_finalize_failure`
—— 上游跳过 / 条件跳过的预检、重试决策、成功/失败后处理
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,调用上述函数。
* 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` /
:func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、
信号量构建、结果存储。
@@ -86,6 +87,22 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
)
def _emit_running(on_event: EventCallback | None, spec: TaskSpec[Any]) -> None:
"""触发 RUNNING 事件(任务开始执行时)。"""
if on_event is None:
return
on_event(
TaskEvent(
task=spec.name,
status=TaskStatus.RUNNING,
attempts=0,
error=None,
duration=None,
reason=None,
)
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -129,11 +146,16 @@ def _apply_cached(
backend: StateBackend,
on_event: EventCallback | None,
) -> bool:
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。"""
"""若 ``name`` 命中缓存,写入 context/report 并返回 True。
单次 ``backend.get`` + ``KeyError`` 回退,避免 ``has`` + ``get`` 双重
哈希查找与双重 TTL 判断。
"""
storage_key = spec.storage_key(context)
if not backend.has(storage_key):
return False
try:
cached = backend.get(storage_key)
except KeyError:
return False
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached, reason="缓存命中")
report.results[name] = result
@@ -142,22 +164,18 @@ def _apply_cached(
return True
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
def _sort_by_priority(layer: list[str], specs: Mapping[str, TaskSpec[Any]]) -> list[str]:
"""按优先级降序排序(稳定排序)。
# ---------------------------------------------------------------------- #
# Mixin:任务级跳过 / 重试 / 成功处理
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""任务级跳过预检共享逻辑。
"上游被跳过/失败""条件不满足"两类跳过判断统一为单一入口,
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
接受预构建的 ``{name: spec}`` 映射,避免在排序键函数中重复调用
``graph.resolved_spec``(即便有缓存也省去 N 次字典查询)。
"""
return sorted(layer, key=lambda n: -specs[n].priority)
@staticmethod
# ---------------------------------------------------------------------- #
# 任务级跳过 / 重试 / 成功处理:模块级函数
# ---------------------------------------------------------------------- #
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
@@ -173,7 +191,7 @@ class _TaskSkipMixin:
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
@@ -186,7 +204,7 @@ class _TaskSkipMixin:
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
"""
# 1. 上游被跳过/失败
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
skip_reason = _upstream_skip_reason(spec, report)
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute
if skip_reason is None:
should_run, cond_reason = spec.should_execute(context)
@@ -206,15 +224,11 @@ class _TaskSkipMixin:
return result
class _TaskRetryMixin:
"""任务级重试决策与失败/成功后处理共享逻辑。"""
@staticmethod
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""是否应继续重试。"""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
@staticmethod
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""标记任务成功并触发 post_run 钩子。"""
result.value = value
@@ -222,7 +236,7 @@ class _TaskRetryMixin:
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
@staticmethod
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
@@ -246,7 +260,7 @@ class _TaskRetryMixin:
layer=layer_idx,
)
@staticmethod
def _handle_failure(
spec: TaskSpec[Any],
result: TaskResult[Any],
@@ -279,17 +293,17 @@ class _TaskRetryMixin:
exc,
)
result.error = exc
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
if _should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
# ---------------------------------------------------------------------- #
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin
# 任务执行器:同步 / 异步(调用模块级跳过/重试函数
# ---------------------------------------------------------------------- #
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
class SyncTaskRunner:
"""同步任务执行器:带重试与跳过预检。"""
@staticmethod
@@ -300,7 +314,7 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
on_event: EventCallback | None = None,
report: RunReport | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
@@ -309,23 +323,24 @@ class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True:
result.attempts += 1
try:
with spec.env_context():
value = spec.effective_fn(*args, **kwargs)
_TaskRetryMixin._mark_success(spec, result, value)
_mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
if _handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
class AsyncTaskRunner:
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
@staticmethod
@@ -337,7 +352,7 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
@@ -348,15 +363,16 @@ class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
_emit_running(on_event, spec)
while True:
result.attempts += 1
try:
value = await _execute_async_task(spec, args, kwargs, loop)
_TaskRetryMixin._mark_success(spec, result, value)
_mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
if _handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
@@ -401,13 +417,19 @@ def _filter_and_sort(
backend: StateBackend,
on_event: EventCallback | None,
) -> list[str]:
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。
预构建 ``{name: spec}`` 映射,过滤与排序共享同一份 resolved spec
避免 ``_sort_by_priority`` 内重复调用 ``graph.resolved_spec``。
"""
specs: dict[str, TaskSpec[Any]] = {}
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
specs[name] = spec
if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name)
return _sort_by_priority(to_run, graph)
return _sort_by_priority(to_run, specs)
def _store_result(
@@ -619,7 +641,7 @@ def _make_verbose_callback(on_event: EventCallback | None) -> EventCallback:
def _verbose_callback(event: TaskEvent) -> None:
dur = f" ({event.duration:.3f}s)" if event.duration is not None else ""
if event.status == TaskStatus.RUNNING: # pragma: no cover
if event.status == TaskStatus.RUNNING:
print(f"[verbose] 任务 {event.task!r} 开始执行...", flush=True)
elif event.status == TaskStatus.SUCCESS:
print(f"[verbose] 任务 {event.task!r} 成功{dur}", flush=True)
@@ -684,6 +706,10 @@ def run(
_print_dry_run(graph, layers)
return RunReport(success=True)
# 入口统一校验一次:所有策略共用,避免 layers() / dependency 路径
# 各自重复调用 validate()。
graph.validate()
effective_callback: EventCallback | None = _make_verbose_callback(on_event) if verbose else on_event
backend = resolve_backend(state)
report = RunReport()
@@ -705,7 +731,6 @@ def run(
layers = graph.layers()
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency":
graph.validate()
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else:
raise ValueError(f"Unknown strategy: {strategy!r}")
+4 -1
View File
@@ -231,8 +231,11 @@ class Graph:
同层任务无相互硬依赖,可并发执行。软依赖不参与分层。
层按执行顺序返回。图有环时抛出 :class:`CycleError`。
.. note::
本方法假定图已通过 :meth:`validate` 校验(由 :func:`pyflowx.run`
在入口统一执行一次)。若直接调用本方法,需自行先校验。
"""
self.validate()
sorter = _TopologicalSorter(self.deps)
result: list[list[str]] = []
sorter.prepare()
-11
View File
@@ -175,13 +175,6 @@ class MemoryBackend(_TTLStateBackendMixin):
def _clear_raw(self) -> None:
self._store.clear()
def _expired(self, key: str) -> bool:
"""键是否已过期(兼容旧测试 API)。"""
entry = self._get_raw(key)
if entry is None:
return False
return self._is_expired(entry[1])
class JSONBackend(_TTLStateBackendMixin):
"""基于文件的 JSON 存储,用于跨进程续跑。
@@ -283,10 +276,6 @@ class JSONBackend(_TTLStateBackendMixin):
self._defer_flush = False
self._flush()
def _expired(self, entry: Mapping[str, Any]) -> bool:
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
return self._is_expired(float(entry.get("ts", 0)))
def resolve_backend(backend: StateBackend | None) -> StateBackend:
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
+14 -3
View File
@@ -17,6 +17,7 @@
from __future__ import annotations
import logging
import os
import shutil
import sys
@@ -68,6 +69,8 @@ TaskCmd = Union[
Strategy = Union[str, "StrategyKind"]
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
logger = logging.getLogger("pyflowx")
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
Condition = Callable[[Context], bool]
@@ -378,11 +381,19 @@ class TaskSpec(Generic[T]):
def storage_key(self, context: Context) -> str:
"""计算状态后端存储键。"""
if self.cache_key is not None:
if self.cache_key is None:
return self.name
try:
return f"{self.name}:{self.cache_key(context)}"
except Exception:
return self.name
except (TypeError, ValueError, KeyError, AttributeError) as exc:
# cache_key 抛出预期内的数据/类型异常时回退到 name,但仍记录警告
# 以便用户发现 cache_key 实现中的 bug。
logger.warning(
"task %r: cache_key 回退到 name%s: %s",
self.name,
type(exc).__name__,
exc,
)
return self.name
+7 -3
View File
@@ -99,7 +99,10 @@ def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
def test_verbose_run_with_user_callback():
"""Test px.run with verbose=True and user callback both called."""
"""Test px.run with verbose=True and user callback both called.
预期事件序列:RUNNING(开始)→ SUCCESS(完成)。
"""
events = []
def on_event(event: px.TaskEvent):
@@ -109,8 +112,9 @@ def test_verbose_run_with_user_callback():
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event)
assert report.success
assert len(events) == 1
assert events[0].status == px.TaskStatus.SUCCESS
assert len(events) == 2
assert events[0].status == px.TaskStatus.RUNNING
assert events[1].status == px.TaskStatus.SUCCESS
def test_verbose_event_callback_success():
+8 -8
View File
@@ -70,9 +70,9 @@ def test_memory_backend_ttl_load_filters_expired() -> None:
def test_memory_backend_expired_key_not_in_store() -> None:
"""_expired 对不存在键返回 False."""
"""不存在的键 has 返回 False."""
b = MemoryBackend(ttl=1.0)
assert b._expired("nonexistent") is False
assert b.has("nonexistent") is False
def test_memory_backend_no_ttl_never_expired() -> None:
@@ -244,35 +244,35 @@ def test_json_backend_ttl_load_filters_expired() -> None:
def test_json_backend_expired_no_ttl() -> None:
"""无 TTL 时 _expired 返回 False."""
"""无 TTL 时永不过期."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path)
b.save("a", 1)
# 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 1000
assert b._expired(b._store["a"]) is False # 无 TTL,永不过期
assert b.has("a") is True # 无 TTL,永不过期
def test_json_backend_expired_with_ttl() -> None:
"""有 TTL 时 _expired 检查是否过期."""
"""有 TTL 时过期键 has 返回 False."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b.save("a", 1)
# 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
assert b._expired(b._store["a"]) is True
assert b.has("a") is False
def test_json_backend_expired_missing_ts() -> None:
"""entry 缺少 ts 时使用默认值 0."""
"""entry 缺少 ts 时视为过期."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b._store["a"] = {"value": 1} # 缺少 ts
# ts 默认为 0,已经过了很久
assert b._expired(b._store["a"]) is True
assert b.has("a") is False
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
+2 -2
View File
@@ -203,10 +203,10 @@ def test_is_cmd_available_callable_returns_true() -> None:
# storage_key 异常处理
# ---------------------------------------------------------------------- #
def test_storage_key_cache_key_exception_returns_name() -> None:
"""cache_key 抛异常时应返回任务名."""
"""cache_key 抛预期异常(TypeError/ValueError/KeyError/AttributeError时应返回任务名."""
def bad_cache_key(_ctx):
raise RuntimeError("cache key error")
raise ValueError("cache key error")
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
key = spec.storage_key({})