chore: 发布 pyflowx 0.2.10,新增性能计时器与多项重构

1. 新增 perf_timer 工具与配套测试用例
2. 重构任务条件跳过逻辑,优化失败条件展示
3. 重构 Graph 子图生成逻辑,提取公共依赖修剪函数
4. 重构条件模块,统一条件名称与失败原因获取逻辑
5. 重构存储后端,提取 TTL 共享逻辑并优化实现
6. 重构执行器模块,使用 Mixin 复用代码,拆分任务与层执行逻辑
7. 删除冗余的 which 命令测试文件
8. 更新依赖锁文件
This commit is contained in:
2026-06-27 20:15:35 +08:00
parent c3b86b603d
commit d58fc5536e
9 changed files with 701 additions and 607 deletions
+27 -13
View File
@@ -42,6 +42,19 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
return _cond return _cond
def _cond_reason(cond: Condition) -> str | list[str] | None:
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
reason = getattr(cond, "_reason", None)
if reason is not None:
return reason
return getattr(cond, "__name__", repr(cond))
def _cond_name(cond: Condition) -> str:
"""获取条件的可读名称。"""
return getattr(cond, "__name__", repr(cond))
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 模块级静态条件常量 # 模块级静态条件常量
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -61,21 +74,25 @@ class BuiltinConditions:
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# 静态条件 # 静态条件
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
@staticmethod
def IS_WINDOWS() -> Condition: def IS_WINDOWS() -> Condition:
"""检查是否为 Windows 平台.""" """检查是否为 Windows 平台."""
return _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS") return IS_WINDOWS
@staticmethod
def IS_LINUX() -> Condition: def IS_LINUX() -> Condition:
"""检查是否为 Linux 平台.""" """检查是否为 Linux 平台."""
return _static(lambda: Constants.IS_LINUX, "IS_LINUX") return IS_LINUX
@staticmethod
def IS_MACOS() -> Condition: def IS_MACOS() -> Condition:
"""检查是否为 macOS 平台.""" """检查是否为 macOS 平台."""
return _static(lambda: Constants.IS_MACOS, "IS_MACOS") return IS_MACOS
@staticmethod
def IS_POSIX() -> Condition: def IS_POSIX() -> Condition:
"""检查是否为 POSIX 平台.""" """检查是否为 POSIX 平台."""
return _static(lambda: Constants.IS_POSIX, "IS_POSIX") return IS_POSIX
@staticmethod @staticmethod
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition: def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
@@ -214,12 +231,12 @@ class BuiltinConditions:
result = condition(ctx) result = condition(ctx)
if result: if result:
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因 # inner 为 True 时 NOT 会失败,记录 inner 的具体原因
inner_reason = getattr(condition, "_reason", None) inner_reason = _cond_reason(condition)
if inner_reason is not None: if inner_reason is not None:
_cond._reason = inner_reason # type: ignore[attr-defined] _cond._reason = inner_reason # type: ignore[attr-defined]
return not result return not result
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})" _cond.__name__ = f"NOT({_cond_name(condition)})"
return _cond return _cond
@staticmethod @staticmethod
@@ -229,8 +246,7 @@ class BuiltinConditions:
def _cond(ctx: Context) -> bool: def _cond(ctx: Context) -> bool:
return all(c(ctx) for c in conditions) return all(c(ctx) for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions] _cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})"
_cond.__name__ = f"AND({', '.join(names)})"
return _cond return _cond
@staticmethod @staticmethod
@@ -241,14 +257,12 @@ class BuiltinConditions:
matched: list[str] = [] matched: list[str] = []
for c in conditions: for c in conditions:
if c(ctx): if c(ctx):
matched.append( reason = _cond_reason(c)
getattr(c, "_reason", None) or getattr(c, "__name__", repr(c)), matched.append(reason if isinstance(reason, str) else str(reason))
)
if matched: if matched:
_cond._reason = matched # type: ignore[attr-defined] _cond._reason = matched # type: ignore[attr-defined]
return True return True
return False return False
names = [getattr(c, "__name__", repr(c)) for c in conditions] _cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
_cond.__name__ = f"OR({', '.join(names)})"
return _cond return _cond
+305 -350
View File
@@ -10,6 +10,17 @@
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动, * ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
无需等待同层其他任务。最大化并行度。 无需等待同层其他任务。最大化并行度。
架构
----
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
所有策略共享统一异步内核,支持: 所有策略共享统一异步内核,支持:
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
* 软依赖注入与默认值 * 软依赖注入与默认值
@@ -30,6 +41,7 @@ import concurrent.futures
import inspect import inspect
import logging import logging
import threading import threading
import time
from datetime import datetime from datetime import datetime
from typing import Any, Awaitable, Callable, Literal, Mapping, cast from typing import Any, Awaitable, Callable, Literal, Mapping, cast
@@ -48,7 +60,7 @@ Strategy = Literal["sequential", "thread", "async", "dependency"]
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 辅助 # 无状态公共辅助
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _is_async_fn(spec: TaskSpec[Any]) -> bool: def _is_async_fn(spec: TaskSpec[Any]) -> bool:
"""判断 ``spec.effective_fn`` 是否为协程函数。""" """判断 ``spec.effective_fn`` 是否为协程函数。"""
@@ -71,17 +83,6 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
) )
def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
attempt,
max_attempts,
exc,
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None: def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。""" """安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None) hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -93,87 +94,6 @@ def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
logger.warning("hook %s raised: %r", fn_name, exc) logger.warning("hook %s raised: %r", fn_name, exc)
def _check_upstream_skipped(
spec: TaskSpec[Any],
report: RunReport | None,
) -> tuple[bool, str | None]:
"""检查硬依赖上游任务是否被 SKIPPED 或 FAILED。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None: # pragma: no cover
return False, None # pragma: no cover
if spec.allow_upstream_skip: # pragma: no cover
return False, None # pragma: no cover
for dep in spec.depends_on:
if dep not in report.results: # pragma: no cover
continue # pragma: no cover
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
return False, None # pragma: no cover
def _format_reason(reason: Any) -> str:
"""将 _reason 格式化为可读字符串."""
if isinstance(reason, list):
return ", ".join(str(r) for r in reason)
return str(reason)
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
"""求值所有条件,返回跳过原因或 ``None``。
条件接收上下文映射(硬依赖 + 软依赖结果)。
"""
failed_conditions: list[str] = []
for condition in spec.conditions:
try:
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
continue
if not ok:
reason = getattr(condition, "_reason", None)
if reason is not None:
failed_conditions.append(_format_reason(reason))
else:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions:
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
if spec.skip_if_missing and not spec._is_cmd_available():
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
return f"命令不存在: {cmd_name}"
return None
def _make_skipped_result(
spec: TaskSpec[Any],
reason: str,
on_event: EventCallback | None,
) -> TaskResult[Any]:
"""构造 SKIPPED 的 TaskResult。"""
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, reason)
return result
def _build_context( def _build_context(
spec: TaskSpec[Any], spec: TaskSpec[Any],
global_context: Mapping[str, Any], global_context: Mapping[str, Any],
@@ -185,19 +105,16 @@ def _build_context(
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。 软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
""" """
ctx: dict[str, Any] = {} ctx: dict[str, Any] = {}
for dep in spec.depends_on: for dep in spec.depends_on:
if dep in global_context: if dep in global_context:
ctx[dep] = global_context[dep] ctx[dep] = global_context[dep]
for dep in spec.soft_depends_on: for dep in spec.soft_depends_on:
if dep in global_context: if dep in global_context:
ctx[dep] = global_context[dep] ctx[dep] = global_context[dep]
elif dep in spec.defaults: # pragma: no cover elif dep in spec.defaults:
ctx[dep] = spec.defaults[dep] # pragma: no cover ctx[dep] = spec.defaults[dep]
else: else:
ctx[dep] = None ctx[dep] = None
return ctx return ctx
@@ -222,33 +139,93 @@ def _apply_cached(
return True return True
def _prepare_for_execution( def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
# ---------------------------------------------------------------------- #
# Mixin:任务级跳过 / 重试 / 成功处理
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""任务级跳过预检共享逻辑。
"上游被跳过/失败"与"条件不满足"两类跳过判断统一为单一入口,
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
"""
@staticmethod
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None or spec.allow_upstream_skip:
return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any], spec: TaskSpec[Any],
context: Mapping[str, Any], context: Mapping[str, Any],
report: RunReport | None, report: RunReport | None,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> TaskResult[Any] | None: ) -> TaskResult[Any] | None:
"""执行前预检:上游跳过 / 条件跳过。 """执行前预检:上游跳过 / 条件跳过。
返回 SKIPPED TaskResult 或 ``None``(继续执行)。 返回 SKIPPED TaskResult 或 ``None``(继续执行)。
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
""" """
should_skip, skip_reason = _check_upstream_skipped(spec, report) # 1. 上游被跳过/失败
if should_skip: skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event) # 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute
if skip_reason is None:
skip_reason = _evaluate_conditions(spec, context) should_run, cond_reason = spec.should_execute(context)
if skip_reason is not None: if not should_run:
return _make_skipped_result(spec, skip_reason, on_event) skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None return None
# 构造 SKIPPED 结果
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
def _finalize_failure( class _TaskRetryMixin:
"""任务级重试决策与失败/成功后处理共享逻辑。"""
@staticmethod
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""是否应继续重试。"""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
@staticmethod
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
    """Record a successful run on ``result`` and fire the ``post_run`` hook.

    Sets the status to SUCCESS, stores ``value`` and stamps ``finished_at``
    before the hook is invoked with ``(spec, value)``.
    """
    result.status = TaskStatus.SUCCESS
    result.value = value
    result.finished_at = datetime.now()
    # The hook runs last so the result is fully populated before it fires.
    _run_hooks(spec.hooks, "post_run", spec, value)
@staticmethod
def _finalize_failure(
result: TaskResult[Any], result: TaskResult[Any],
layer_idx: int | None, layer_idx: int | None,
on_event: EventCallback | None = None, on_event: EventCallback | None,
continue_on_error: bool = False, continue_on_error: bool,
) -> None: ) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。""" """标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED result.status = TaskStatus.FAILED
result.finished_at = datetime.now() result.finished_at = datetime.now()
@@ -266,41 +243,66 @@ def _finalize_failure(
layer=layer_idx, layer=layer_idx,
) )
@staticmethod
def _handle_failure(
    spec: TaskSpec[Any],
    result: TaskResult[Any],
    exc: BaseException,
    layer_idx: int | None,
    on_event: EventCallback | None,
) -> bool:
    """Handle one failed attempt: normalise timeouts, then retry or finalize.

    ``asyncio.TimeoutError`` is converted to ``TaskTimeoutError`` so callers
    see a single timeout exception type. The exception (after conversion) is
    stored on ``result.error``.

    The "retrying" warning is emitted only when a retry will actually take
    place; a final failure no longer produces a misleading "retrying" line.

    Returns
    -------
    bool
        ``True`` when the failure has been finalized (no more retries);
        ``False`` when the caller should retry.

    Notes
    -----
    ``_finalize_failure`` is documented as not raising only when
    ``continue_on_error`` is true, so this call may raise otherwise.
    """
    timed_out = isinstance(exc, asyncio.TimeoutError)
    if timed_out:
        exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
    result.error = exc

    if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
        # Log only on the retry path so the message matches what happens next.
        if timed_out:
            logger.warning(
                "task %r timed out (attempt %d/%d); retrying",
                spec.name,
                result.attempts,
                spec.retry.max_attempts,
            )
        else:
            logger.warning(
                "task %r failed (attempt %d/%d): %r; retrying",
                spec.name,
                result.attempts,
                spec.retry.max_attempts,
                exc,
            )
        return False

    _run_hooks(spec.hooks, "on_failure", spec, exc)
    _TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
    return True
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 同步执行内核 # 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin)
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _run_sync_with_retry( class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
"""同步任务执行器:带重试与跳过预检。"""
@staticmethod
def run(
spec: TaskSpec[Any], spec: TaskSpec[Any],
context: Mapping[str, Any], context: Mapping[str, Any],
layer_idx: int | None, layer_idx: int | None,
on_event: EventCallback | None = None, on_event: EventCallback | None = None,
report: RunReport | None = None, report: RunReport | None = None,
) -> TaskResult[Any]: ) -> TaskResult[Any]:
"""执行同步任务并带重试;返回填充好的 TaskResult。""" skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None: if skipped is not None:
return skipped return skipped
result: TaskResult[Any] = TaskResult(spec=spec) result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now() result.started_at = datetime.now()
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context) args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec) _run_hooks(spec.hooks, "pre_run", spec)
@@ -309,25 +311,60 @@ def _run_sync_with_retry(
result.attempts += 1 result.attempts += 1
try: try:
with spec.env_context(): with spec.env_context():
result.value = spec.effective_fn(*args, **kwargs) value = spec.effective_fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS _TaskRetryMixin._mark_success(spec, result, value)
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result return result
except Exception as exc: except Exception as exc:
result.error = exc if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result return result
_log_retry(spec, result.attempts, max_attempts, exc) wait = spec.retry.wait_seconds(result.attempts)
_sleep_for_retry(spec, result.attempts) if wait > 0:
# pragma: no cover time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
    """Asynchronous task runner with skip pre-check and retry handling."""

    @staticmethod
    async def run(
        spec: TaskSpec[Any],
        context: Mapping[str, Any],
        layer_idx: int | None,
        on_event: EventCallback | None = None,
        report: RunReport | None = None,
        semaphore: asyncio.Semaphore | None = None,
    ) -> TaskResult[Any]:
        """Run one task on the event loop and return its populated result.

        The skip pre-check runs first, before any semaphore is acquired, so a
        skipped task never occupies a concurrency slot. When ``semaphore`` is
        given, the whole attempt loop (including waits between retries) runs
        while holding it.

        Returns the SKIPPED result from the pre-check, or the result filled
        in by the success / failure handlers.
        """
        skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
        if skipped is not None:
            return skipped

        async def _inner() -> TaskResult[Any]:
            result: TaskResult[Any] = TaskResult(spec=spec)
            result.started_at = datetime.now()
            args, kwargs = build_call_args(spec, context)
            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_running_loop() is the documented call for this situation.
            loop = asyncio.get_running_loop()

            _run_hooks(spec.hooks, "pre_run", spec)

            while True:
                result.attempts += 1
                try:
                    value = await _execute_async_task(spec, args, kwargs, loop)
                    _TaskRetryMixin._mark_success(spec, result, value)
                    return result
                except Exception as exc:
                    # True means the failure was finalized; False means retry.
                    if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
                        return result
                    wait = spec.retry.wait_seconds(result.attempts)
                    if wait > 0:
                        await asyncio.sleep(wait)

        if semaphore is not None:
            async with semaphore:
                return await _inner()
        return await _inner()
# ---------------------------------------------------------------------- #
# 异步执行内核
# ---------------------------------------------------------------------- #
async def _execute_async_task( async def _execute_async_task(
spec: TaskSpec[Any], spec: TaskSpec[Any],
args: tuple[Any, ...], args: tuple[Any, ...],
@@ -339,9 +376,7 @@ async def _execute_async_task(
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs)) coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None: if spec.timeout is not None:
return await asyncio.wait_for(coro, timeout=spec.timeout) return await asyncio.wait_for(coro, timeout=spec.timeout)
else:
return await coro return await coro
else:
def fn_call() -> Any: def fn_call() -> Any:
with spec.env_context(): with spec.env_context():
@@ -349,87 +384,89 @@ async def _execute_async_task(
if spec.timeout is not None: if spec.timeout is not None:
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout) return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
else:
return await loop.run_in_executor(None, fn_call) return await loop.run_in_executor(None, fn_call)
async def _run_async_with_retry( # ---------------------------------------------------------------------- #
spec: TaskSpec[Any], # Mixin:层执行共享逻辑
context: Mapping[str, Any], # ---------------------------------------------------------------------- #
layer_idx: int | None, class _LayerMixin:
on_event: EventCallback | None = None, """层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
if semaphore is not None: 四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
async with semaphore: 消除"过滤缓存→排序→运行→存结果"的样板代码。
return await _run_async_inner(spec, context, layer_idx, on_event, report) """
return await _run_async_inner(spec, context, layer_idx, on_event, report)
@staticmethod
def _filter_and_sort(
    layer: list[str],
    graph: Graph,
    context: dict[str, Any],
    report: RunReport,
    backend: StateBackend,
    on_event: EventCallback | None,
) -> list[str]:
    """Return the tasks of ``layer`` that still need to run, by priority.

    Each task is offered to ``_apply_cached`` in layer order; tasks for which
    it returns a truthy value are dropped. The remainder is ordered with
    ``_sort_by_priority``.
    """
    pending = [
        name
        for name in layer
        if not _apply_cached(name, graph.resolved_spec(name), context, report, backend, on_event)
    ]
    return _sort_by_priority(pending, graph)
async def _run_async_inner( @staticmethod
spec: TaskSpec[Any], def _store_result(
context: Mapping[str, Any], name: str,
layer_idx: int | None, result: TaskResult[Any],
on_event: EventCallback | None = None, graph: Graph,
report: RunReport | None = None, # noqa: ARG001 context: dict[str, Any],
) -> TaskResult[Any]: report: RunReport,
"""异步执行内核的内部实现(已获取 semaphore 后)。""" backend: StateBackend,
result: TaskResult[Any] = TaskResult(spec=spec) on_event: EventCallback | None,
result.started_at = datetime.now() context_snapshot: Mapping[str, Any] | None = None,
max_attempts = spec.retry.max_attempts ) -> None:
args, kwargs = build_call_args(spec, context) """存储任务结果到 context/report/backend 并触发事件。"""
loop = asyncio.get_event_loop() context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
_run_hooks(spec.hooks, "pre_run", spec) @staticmethod
def _build_semaphores(
to_run: list[str],
graph: Graph,
sem_factory: Callable[[int], Any],
concurrency_limits: Mapping[str, int],
) -> dict[str, Any]:
"""为每个 ``concurrency_key`` 创建一个信号量。"""
semaphores: dict[str, Any] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = sem_factory(limit)
return semaphores
while True: @staticmethod
result.attempts += 1 def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
try: """获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
result.value = await _execute_async_task(spec, args, kwargs, loop) if spec.concurrency_key is None:
result.status = TaskStatus.SUCCESS return None
result.finished_at = datetime.now() return semaphores.get(spec.concurrency_key)
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result
except asyncio.TimeoutError:
exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
await _async_sleep_for_retry(spec, result.attempts)
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc)
await _async_sleep_for_retry(spec, result.attempts)
# pragma: no cover
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 层执行器 # 层执行器
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]: class SequentialLayerRunner(_LayerMixin):
"""按优先级降序排序(稳定排序)。""" """逐个运行某层的任务(按优先级排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
@staticmethod
def _execute_layer_sequential( def execute(
layer: list[str], layer: list[str],
graph: Graph, graph: Graph,
context: dict[str, Any], context: dict[str, Any],
@@ -437,22 +474,19 @@ def _execute_layer_sequential(
backend: StateBackend, backend: StateBackend,
layer_idx: int, layer_idx: int,
on_event: EventCallback | None, on_event: EventCallback | None,
) -> None: ) -> None:
"""逐个运行某层的任务(按优先级排序)。""" for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
for name in _sort_by_priority(layer, graph):
spec = graph.resolved_spec(name) spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
task_ctx = _build_context(spec, context, report) task_ctx = _build_context(spec, context, report)
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
context[name] = result.value SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
def _execute_layer_threaded( class ThreadedLayerRunner(_LayerMixin):
"""在线程池中并发运行某层的任务。"""
@staticmethod
def execute(
layer: list[str], layer: list[str],
graph: Graph, graph: Graph,
context: dict[str, Any], context: dict[str, Any],
@@ -462,70 +496,48 @@ def _execute_layer_threaded(
on_event: EventCallback | None, on_event: EventCallback | None,
max_workers: int, max_workers: int,
concurrency_limits: Mapping[str, int], concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
"""在线程池中并发运行某层的任务。""" to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
if not to_run: if not to_run:
return return
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建线程信号量
semaphores: dict[str, threading.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = threading.Semaphore(limit)
context_snapshot = dict(context) context_snapshot = dict(context)
lock = threading.Lock() lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]: def _run_threaded_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name) spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report) task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None sem = ThreadedLayerRunner._get_sem(semaphores, spec)
if sem is not None: if sem is not None:
sem.acquire() sem.acquire()
try: try:
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report) return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
finally: finally:
if sem is not None: if sem is not None:
sem.release() sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool: with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {} future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
for name in to_run: pool.submit(_run_threaded_task, name): name for name in to_run
fut = pool.submit(_run_threaded_task, name) }
future_to_name[fut] = name
completed: dict[str, TaskResult[Any]] = {} completed: dict[str, TaskResult[Any]] = {}
try: try:
for fut in concurrent.futures.as_completed(future_to_name): for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut] name = future_to_name[fut]
result = fut.result() completed[name] = fut.result()
completed[name] = result
finally: finally:
with lock: with lock:
for name, result in completed.items(): for name, result in completed.items():
context[name] = result.value ThreadedLayerRunner._store_result(
if result.status == TaskStatus.SUCCESS: name, result, graph, context, report, backend, on_event, context_snapshot
spec = graph.resolved_spec(name) )
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
async def _execute_layer_async( class AsyncLayerRunner(_LayerMixin):
"""在事件循环上并发运行某层的任务。"""
@staticmethod
async def execute(
layer: list[str], layer: list[str],
graph: Graph, graph: Graph,
context: dict[str, Any], context: dict[str, Any],
@@ -534,76 +546,41 @@ async def _execute_layer_async(
layer_idx: int, layer_idx: int,
on_event: EventCallback | None, on_event: EventCallback | None,
concurrency_limits: Mapping[str, int], concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
"""在事件循环上并发运行某层的任务。""" to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
if not to_run: if not to_run:
return return
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建异步信号量
semaphores: dict[str, asyncio.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
context_snapshot = dict(context) context_snapshot = dict(context)
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]: async def _run_async_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name) spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report) task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None sem = AsyncLayerRunner._get_sem(semaphores, spec)
if sem is not None: return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
async with sem:
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
coros = [_run_async_task_wrapped(name) for name in to_run] results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
results = await asyncio.gather(*coros)
for name, result in zip(to_run, results): for name, result in zip(to_run, results):
context[name] = result.value AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
# ---------------------------------------------------------------------- # class DependencyRunner(_LayerMixin):
# 依赖驱动调度 """依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
# ---------------------------------------------------------------------- #
async def _drive_dependency_async( 所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
@staticmethod
async def execute(
graph: Graph, graph: Graph,
context: dict[str, Any], context: dict[str, Any],
report: RunReport, report: RunReport,
backend: StateBackend, backend: StateBackend,
on_event: EventCallback | None, on_event: EventCallback | None,
concurrency_limits: Mapping[str, int], concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。 all_names = list(graph.all_specs().keys())
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
all_names = set(graph.all_specs().keys())
semaphores: dict[str, asyncio.Semaphore] = {}
for name in all_names:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {} futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]: async def _run_task(name: str) -> TaskResult[Any]:
@@ -621,24 +598,14 @@ async def _drive_dependency_async(
if _apply_cached(name, spec, context, report, backend, on_event): if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name] return report.results[name]
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None sem = DependencyRunner._get_sem(semaphores, spec)
if sem is not None: result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
async with sem: DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
else:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
return result return result
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for name in all_names: for name in all_names:
futures[name] = loop.create_task(_run_task(name)) futures[name] = loop.create_task(_run_task(name))
await asyncio.gather(*futures.values()) await asyncio.gather(*futures.values())
@@ -729,9 +696,9 @@ def run(
elif strategy == "thread": elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits) _drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async": elif strategy == "async":
_drive_async(graph, layers, context, report, backend, effective_callback, limits) asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency": elif strategy == "dependency":
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits)) asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else: else:
raise ValueError(f"Unknown strategy: {strategy!r}") raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError: except TaskFailedError:
@@ -759,7 +726,7 @@ def _drive_sequential(
on_event: EventCallback | None, on_event: EventCallback | None,
) -> None: ) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event) SequentialLayerRunner.execute(layer, graph, context, report, backend, idx, on_event)
def _drive_threaded( def _drive_threaded(
@@ -774,19 +741,7 @@ def _drive_threaded(
) -> None: ) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer))) workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits) ThreadedLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
def _drive_async(
graph: Graph,
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
async def _async_drive( async def _async_drive(
@@ -799,4 +754,4 @@ async def _async_drive(
concurrency_limits: Mapping[str, int], concurrency_limits: Mapping[str, int],
) -> None: ) -> None:
for idx, layer in enumerate(layers, 1): for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits) await AsyncLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
+19 -16
View File
@@ -49,6 +49,15 @@ class GraphDefaults:
verbose: bool = False verbose: bool = False
def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]:
"""返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。"""
return replace(
spec,
depends_on=tuple(d for d in spec.depends_on if keep(d)),
soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)),
)
@dataclass @dataclass
class Graph: class Graph:
"""校验后的有向无环任务图。 """校验后的有向无环任务图。
@@ -225,16 +234,13 @@ class Graph:
def subgraph(self, tags: Iterable[str]) -> Graph: def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。""" """返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
wanted: set[str] = set(tags) wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values(): def _dep_kept(dep: str) -> bool:
if wanted & set(spec.tags): return dep in self.specs and bool(wanted & set(self.specs[dep].tags))
pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags)) kept: list[TaskSpec[Any]] = [
) _prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags)
pruned_soft = tuple( ]
d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults) return Graph.from_specs(kept, defaults=self.defaults)
def subgraph_by_names(self, names: Iterable[str]) -> Graph: def subgraph_by_names(self, names: Iterable[str]) -> Graph:
@@ -243,12 +249,9 @@ class Graph:
for n in wanted: for n in wanted:
if n not in self.specs: if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}") raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[Any]] = [] kept: list[TaskSpec[Any]] = [
for spec in self.specs.values(): _prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted
if spec.name in wanted: ]
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
return Graph.from_specs(kept, defaults=self.defaults) return Graph.from_specs(kept, defaults=self.defaults)
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
+110 -40
View File
@@ -17,6 +17,7 @@ import json
import sys import sys
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path from pathlib import Path
from typing import Any, Mapping from typing import Any, Mapping
@@ -55,7 +56,74 @@ class StateBackend(ABC):
"""清除所有存储状态。""" """清除所有存储状态。"""
class MemoryBackend(StateBackend): class _TTLStateBackendMixin(StateBackend):
"""TTL 状态后端共享逻辑。
将 ``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现
委托给四个原始存取原语::meth:`_get_raw`、:meth:`_put_raw`、
:meth:`_iter_raw`、:meth:`_clear_raw`,并基于 :meth:`_now` 与
``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`。
子类需设置 ``self._ttl`` 并实现上述四个原语;如需自定义时间源
(如 ``time.monotonic``)可覆盖 :meth:`_now`。
"""
_ttl: float | None
# ---- 原语:由子类实现 ---- #
@abstractmethod
def _get_raw(self, key: str) -> tuple[Any, float] | None:
    """Return the stored ``(value, ts)`` pair for ``key``, or ``None`` when the key is absent."""
@abstractmethod
def _put_raw(self, key: str, value: Any, ts: float) -> None:
    """Write one record."""
@abstractmethod
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
    """Iterate over all records without expiry filtering, yielding ``(key, value, ts)``."""
@abstractmethod
def _clear_raw(self) -> None:
    """Remove all records."""
# ---- 共享实现 ---- #
def _now(self) -> float:
    """Return the current timestamp; this default uses wall-clock seconds from ``time.time()``."""
    return time.time()
def _is_expired(self, ts: float) -> bool:
"""时间戳 ``ts`` 是否已过期。"""
if self._ttl is None:
return False
return (self._now() - ts) > self._ttl
@override
def load(self) -> Mapping[str, Any]:
return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)}
@override
def save(self, key: str, value: Any) -> None:
    """Store ``value`` under ``key``, stamped with the current ``_now()`` timestamp."""
    self._put_raw(key, value, self._now())
@override
def has(self, key: str) -> bool:
entry = self._get_raw(key)
return entry is not None and not self._is_expired(entry[1])
@override
def get(self, key: str) -> Any:
entry = self._get_raw(key)
if entry is None or self._is_expired(entry[1]):
raise KeyError(key)
return entry[0]
@override
def clear(self) -> None:
    """Remove every stored record by delegating to ``_clear_raw``."""
    self._clear_raw()
class MemoryBackend(_TTLStateBackendMixin):
"""进程内 dict 后端。进程退出即丢失。 """进程内 dict 后端。进程退出即丢失。
Parameters Parameters
@@ -70,35 +138,35 @@ class MemoryBackend(StateBackend):
self._ttl = ttl self._ttl = ttl
@override @override
def load(self) -> Mapping[str, Any]: def _now(self) -> float:
return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)} return time.monotonic()
@override @override
def save(self, key: str, value: Any) -> None: def _get_raw(self, key: str) -> tuple[Any, float] | None:
self._store[key] = (value, time.monotonic()) return self._store.get(key)
@override @override
def has(self, key: str) -> bool: def _put_raw(self, key: str, value: Any, ts: float) -> None:
return key in self._store and not self._expired(key) self._store[key] = (value, ts)
@override @override
def get(self, key: str) -> Any: def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
if key not in self._store or self._expired(key): for k, (v, ts) in self._store.items():
raise KeyError(key) yield k, v, ts
return self._store[key][0]
@override @override
def clear(self) -> None: def _clear_raw(self) -> None:
self._store.clear() self._store.clear()
def _expired(self, key: str) -> bool: def _expired(self, key: str) -> bool:
if self._ttl is None or key not in self._store: """键是否已过期(兼容旧测试 API)。"""
entry = self._get_raw(key)
if entry is None:
return False return False
_value, ts = self._store[key] return self._is_expired(entry[1])
return (time.monotonic() - ts) > self._ttl
class JSONBackend(StateBackend): class JSONBackend(_TTLStateBackendMixin):
"""基于文件的 JSON 存储,用于跨进程续跑。 """基于文件的 JSON 存储,用于跨进程续跑。
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。 存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
@@ -144,17 +212,30 @@ class JSONBackend(StateBackend):
except (OSError, TypeError) as exc: except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
def _now(self) -> float: @override
return time.time() def _get_raw(self, key: str) -> tuple[Any, float] | None:
entry = self._store.get(key)
def _expired(self, entry: dict[str, Any]) -> bool: if entry is None:
if self._ttl is None: return None
return False return entry["value"], float(entry.get("ts", 0))
return (self._now() - float(entry.get("ts", 0))) > self._ttl
@override @override
def load(self) -> Mapping[str, Any]: def _put_raw(self, key: str, value: Any, ts: float) -> None:
return {k: v["value"] for k, v in self._store.items() if not self._expired(v)} self._store[key] = {"value": value, "ts": ts}
@override
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
for k, entry in self._store.items():
yield k, entry["value"], float(entry.get("ts", 0))
@override
def _clear_raw(self) -> None:
self._store.clear()
@override
def clear(self) -> None:
super().clear()
self._flush()
@override @override
def save(self, key: str, value: Any) -> None: def save(self, key: str, value: Any) -> None:
@@ -162,23 +243,12 @@ class JSONBackend(StateBackend):
_ = json.dumps(value) _ = json.dumps(value)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
self._store[key] = {"value": value, "ts": self._now()} super().save(key, value)
self._flush() self._flush()
@override def _expired(self, entry: Mapping[str, Any]) -> bool:
def has(self, key: str) -> bool: """带元数据的条目是否已过期(兼容旧测试 API)。"""
return key in self._store and not self._expired(self._store[key]) return self._is_expired(float(entry.get("ts", 0)))
@override
def get(self, key: str) -> Any:
if key not in self._store or self._expired(self._store[key]):
raise KeyError(key)
return self._store[key]["value"]
@override
def clear(self) -> None:
self._store.clear()
self._flush()
def resolve_backend(backend: StateBackend | None) -> StateBackend: def resolve_backend(backend: StateBackend | None) -> StateBackend:
+10 -3
View File
@@ -74,6 +74,13 @@ Condition = Callable[[Context], bool]
CacheKeyFn = Callable[[Context], str] CacheKeyFn = Callable[[Context], str]
def _format_skip_reason(failed_conditions: list[str]) -> str:
"""格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。"""
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# 重试策略 # 重试策略
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -315,6 +322,7 @@ class TaskSpec(Generic[T]):
------- -------
(should_run, skip_reason) (should_run, skip_reason)
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。 ``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
失败条件超过 2 个时仅展示前 2 个并附总数。
""" """
# 逐个求值条件,记录失败项。 # 逐个求值条件,记录失败项。
failed_conditions: list[str] = [] failed_conditions: list[str] = []
@@ -323,8 +331,7 @@ class TaskSpec(Generic[T]):
ok = condition(context) ok = condition(context)
except Exception: except Exception:
ok = False ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)" failed_conditions.append("匿名条件(执行错误)")
failed_conditions.append(name)
continue continue
if not ok: if not ok:
reason = getattr(condition, "_reason", None) reason = getattr(condition, "_reason", None)
@@ -336,7 +343,7 @@ class TaskSpec(Generic[T]):
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件") failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions: if failed_conditions:
return False, f"条件不满足: {', '.join(failed_conditions)}" return False, _format_skip_reason(failed_conditions)
if self.skip_if_missing and not self._is_cmd_available(): if self.skip_if_missing and not self._is_cmd_available():
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown" cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
+70
View File
@@ -0,0 +1,70 @@
"""常用工具函数."""
__all__ = ["perf_timer"]
import functools
import logging
import time
from collections import defaultdict
from typing import Callable, ParamSpec, TypedDict
from typing_extensions import TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class _PerformanceMetrics(TypedDict):
    """Performance metrics."""

    # Number of recorded calls.
    count: int
    # Accumulated elapsed time of the recorded calls.
    total_time: float
# Metrics store; a key that has not been seen yet starts at zero calls and zero time.
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
    lambda: _PerformanceMetrics(
        count=0,
        total_time=0.0,
    )
)
def perf_timer(
    unit: str = "ms", precision: int = 4, report: bool = False
) -> "Callable[[Callable[P, R]], Callable[P, R]]":
    """Create a decorator that records call count and elapsed time of a function.

    Metrics accumulate in the module-level ``_perf_metrics`` mapping under
    ``func.__name__``, so functions that share a name share one entry.

    Parameters
    ----------
    unit
        Unit of the accumulated total and of the log output: ``"s"``, ``"ms"``
        or ``"us"``.
    precision
        Number of decimal places used when formatting times in log messages.
    report
        When false, every call logs the accumulated total via ``logging.info``.
        When true, per-call logging is skipped; instead ``logging.basicConfig``
        is called with level INFO and a summary of all metrics is logged at
        interpreter exit.

    Returns
    -------
    A decorator that wraps a callable and returns the callable's own result.

    Raises
    ------
    ValueError
        If ``unit`` is not one of the supported units.
    """
    scale: dict[str, float] = {
        "s": 1.0,
        "ms": 1000.0,
        "us": 1000000.0,
    }
    # Fail at decoration time rather than with a KeyError on the first call.
    if unit not in scale:
        raise ValueError(f"Unsupported unit {unit!r}; expected one of {sorted(scale)}")
    factor = scale[unit]

    def decorator(func: Callable[P, R]) -> Callable[P, R]:
        key = func.__name__

        @functools.wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # perf_counter is monotonic and high resolution; time.time() can
            # jump when the system clock is adjusted.
            start = time.perf_counter()
            result = func(*args, **kwargs)
            elapsed = time.perf_counter() - start

            metrics = _perf_metrics[key]
            metrics["count"] += 1
            metrics["total_time"] += elapsed * factor

            if not report:
                logging.info(f"{key} {unit}: {metrics['total_time']:.{precision}f}{unit}")
            return result

        return wrapper

    if report:
        import atexit

        logging.basicConfig(level=logging.INFO)
        logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")

        # NOTE: a hook is registered on every perf_timer(report=True) call, and
        # each hook logs every entry of _perf_metrics with this call's unit label.
        @atexit.register
        def _() -> None:
            for name, metrics in _perf_metrics.items():
                logging.info(f"{name}: {metrics['count']} times, {metrics['total_time']:.{precision}f}{unit}")

    return decorator
-66
View File
@@ -1,66 +0,0 @@
"""Tests for cli.which module."""
from __future__ import annotations
import shutil
from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import which
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_with_single_command(self) -> None:
"""main() should handle single command argument."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with one task
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_multiple_commands(self) -> None:
"""main() should handle multiple command arguments."""
with patch("sys.argv", ["which", "python", "pip", "node"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with three tasks
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and exit."""
with patch("sys.argv", ["which"]), pytest.raises(SystemExit) as exc_info:
which.main()
assert exc_info.value.code == 2
def test_main_creates_task_specs_with_correct_names(self) -> None:
"""main() should create TaskSpecs with correct names."""
with patch("sys.argv", ["which", "git", "npm"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
graph = mock_run.call_args[0][0]
# Check that task names are correct
task_names = list(graph.all_specs().keys())
assert "which_git" in task_names
assert "which_npm" in task_names
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
assert mock_run.call_args[1]["strategy"] == "thread"
+41
View File
@@ -0,0 +1,41 @@
import time
import pytest
from pytest_mock import MockerFixture
from pyflowx.utils import _perf_metrics, perf_timer
@pytest.fixture(autouse=True)
def reset_perf_metrics():
    """Reset the performance metrics."""
    _perf_metrics.clear()
class TestPerformanceTimer:
    """Tests for the ``perf_timer`` decorator."""

    def test_perf_timer(self):
        """A decorated call is counted once and its time is accumulated in the default unit."""

        @perf_timer()
        def test_func():
            time.sleep(0.1)

        test_func()

        # Membership check: indexing the defaultdict would create the entry,
        # so an ``is not None`` assertion on it could never fail.
        assert "test_func" in _perf_metrics
        assert _perf_metrics["test_func"]["count"] == 1
        # The default unit is "ms", so a 0.1 s sleep is about 100; allow some timer slack.
        assert _perf_metrics["test_func"]["total_time"] >= 90

    def test_perf_timer_report(self, mocker: MockerFixture):
        """With ``report=True`` only the message emitted at decoration time is logged."""
        mock_log = mocker.patch("logging.info")

        @perf_timer(report=True, unit="ms", precision=3)
        def test_func():
            time.sleep(0.1)

        test_func()

        assert "test_func" in _perf_metrics
        assert _perf_metrics["test_func"]["count"] == 1
        assert _perf_metrics["test_func"]["total_time"] >= 90
        # Per-call logging is skipped when report=True, leaving one logging.info call.
        assert mock_log.call_count == 1
Generated
+1 -1
View File
@@ -5603,7 +5603,7 @@ pycountry = [
[[package]] [[package]]
name = "pyflowx" name = "pyflowx"
version = "0.2.9" version = "0.2.10"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" }, { name = "graphlib-backport", marker = "python_full_version < '3.9'" },