chore: 发布 pyflowx 0.2.10,新增性能计时器与多项重构
1. 新增 perf_timer 工具与配套测试用例 2. 重构任务条件跳过逻辑,优化失败条件展示 3. 重构 Graph 子图生成逻辑,提取公共依赖修剪函数 4. 重构条件模块,统一条件名称与失败原因获取逻辑 5. 重构存储后端,提取 TTL 共享逻辑并优化实现 6. 重构执行器模块,使用 Mixin 复用代码,拆分任务与层执行逻辑 7. 删除冗余的 which 命令测试文件 8. 更新依赖锁文件
This commit is contained in:
+27
-13
@@ -42,6 +42,19 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
|
|||||||
return _cond
|
return _cond
|
||||||
|
|
||||||
|
|
||||||
|
def _cond_reason(cond: Condition) -> str | list[str] | None:
|
||||||
|
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
|
||||||
|
reason = getattr(cond, "_reason", None)
|
||||||
|
if reason is not None:
|
||||||
|
return reason
|
||||||
|
return getattr(cond, "__name__", repr(cond))
|
||||||
|
|
||||||
|
|
||||||
|
def _cond_name(cond: Condition) -> str:
|
||||||
|
"""获取条件的可读名称。"""
|
||||||
|
return getattr(cond, "__name__", repr(cond))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 模块级静态条件常量
|
# 模块级静态条件常量
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
@@ -61,21 +74,25 @@ class BuiltinConditions:
|
|||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# 静态条件
|
# 静态条件
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
|
@staticmethod
|
||||||
def IS_WINDOWS() -> Condition:
|
def IS_WINDOWS() -> Condition:
|
||||||
"""检查是否为 Windows 平台."""
|
"""检查是否为 Windows 平台."""
|
||||||
return _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS")
|
return IS_WINDOWS
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def IS_LINUX() -> Condition:
|
def IS_LINUX() -> Condition:
|
||||||
"""检查是否为 Linux 平台."""
|
"""检查是否为 Linux 平台."""
|
||||||
return _static(lambda: Constants.IS_LINUX, "IS_LINUX")
|
return IS_LINUX
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def IS_MACOS() -> Condition:
|
def IS_MACOS() -> Condition:
|
||||||
"""检查是否为 macOS 平台."""
|
"""检查是否为 macOS 平台."""
|
||||||
return _static(lambda: Constants.IS_MACOS, "IS_MACOS")
|
return IS_MACOS
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def IS_POSIX() -> Condition:
|
def IS_POSIX() -> Condition:
|
||||||
"""检查是否为 POSIX 平台."""
|
"""检查是否为 POSIX 平台."""
|
||||||
return _static(lambda: Constants.IS_POSIX, "IS_POSIX")
|
return IS_POSIX
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
|
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
|
||||||
@@ -214,12 +231,12 @@ class BuiltinConditions:
|
|||||||
result = condition(ctx)
|
result = condition(ctx)
|
||||||
if result:
|
if result:
|
||||||
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
|
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
|
||||||
inner_reason = getattr(condition, "_reason", None)
|
inner_reason = _cond_reason(condition)
|
||||||
if inner_reason is not None:
|
if inner_reason is not None:
|
||||||
_cond._reason = inner_reason # type: ignore[attr-defined]
|
_cond._reason = inner_reason # type: ignore[attr-defined]
|
||||||
return not result
|
return not result
|
||||||
|
|
||||||
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
|
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
||||||
return _cond
|
return _cond
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -229,8 +246,7 @@ class BuiltinConditions:
|
|||||||
def _cond(ctx: Context) -> bool:
|
def _cond(ctx: Context) -> bool:
|
||||||
return all(c(ctx) for c in conditions)
|
return all(c(ctx) for c in conditions)
|
||||||
|
|
||||||
names = [getattr(c, "__name__", repr(c)) for c in conditions]
|
_cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})"
|
||||||
_cond.__name__ = f"AND({', '.join(names)})"
|
|
||||||
return _cond
|
return _cond
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -241,14 +257,12 @@ class BuiltinConditions:
|
|||||||
matched: list[str] = []
|
matched: list[str] = []
|
||||||
for c in conditions:
|
for c in conditions:
|
||||||
if c(ctx):
|
if c(ctx):
|
||||||
matched.append(
|
reason = _cond_reason(c)
|
||||||
getattr(c, "_reason", None) or getattr(c, "__name__", repr(c)),
|
matched.append(reason if isinstance(reason, str) else str(reason))
|
||||||
)
|
|
||||||
if matched:
|
if matched:
|
||||||
_cond._reason = matched # type: ignore[attr-defined]
|
_cond._reason = matched # type: ignore[attr-defined]
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
names = [getattr(c, "__name__", repr(c)) for c in conditions]
|
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
||||||
_cond.__name__ = f"OR({', '.join(names)})"
|
|
||||||
return _cond
|
return _cond
|
||||||
|
|||||||
+296
-341
@@ -10,6 +10,17 @@
|
|||||||
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
|
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
|
||||||
无需等待同层其他任务。最大化并行度。
|
无需等待同层其他任务。最大化并行度。
|
||||||
|
|
||||||
|
架构
|
||||||
|
----
|
||||||
|
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
|
||||||
|
|
||||||
|
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
|
||||||
|
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
|
||||||
|
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
|
||||||
|
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
|
||||||
|
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
|
||||||
|
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
|
||||||
|
|
||||||
所有策略共享统一异步内核,支持:
|
所有策略共享统一异步内核,支持:
|
||||||
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
||||||
* 软依赖注入与默认值
|
* 软依赖注入与默认值
|
||||||
@@ -30,6 +41,7 @@ import concurrent.futures
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Awaitable, Callable, Literal, Mapping, cast
|
from typing import Any, Awaitable, Callable, Literal, Mapping, cast
|
||||||
|
|
||||||
@@ -48,7 +60,7 @@ Strategy = Literal["sequential", "thread", "async", "dependency"]
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 辅助
|
# 无状态公共辅助
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
def _is_async_fn(spec: TaskSpec[Any]) -> bool:
|
def _is_async_fn(spec: TaskSpec[Any]) -> bool:
|
||||||
"""判断 ``spec.effective_fn`` 是否为协程函数。"""
|
"""判断 ``spec.effective_fn`` 是否为协程函数。"""
|
||||||
@@ -71,17 +83,6 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
|
|
||||||
"""记录重试日志。"""
|
|
||||||
logger.warning(
|
|
||||||
"task %r failed (attempt %d/%d): %r; retrying",
|
|
||||||
spec.name,
|
|
||||||
attempt,
|
|
||||||
max_attempts,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
|
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
|
||||||
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
|
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
|
||||||
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
|
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
|
||||||
@@ -93,87 +94,6 @@ def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
|
|||||||
logger.warning("hook %s raised: %r", fn_name, exc)
|
logger.warning("hook %s raised: %r", fn_name, exc)
|
||||||
|
|
||||||
|
|
||||||
def _check_upstream_skipped(
|
|
||||||
spec: TaskSpec[Any],
|
|
||||||
report: RunReport | None,
|
|
||||||
) -> tuple[bool, str | None]:
|
|
||||||
"""检查硬依赖上游任务是否被 SKIPPED 或 FAILED。
|
|
||||||
|
|
||||||
软依赖不影响本检查——软依赖被跳过时注入默认值。
|
|
||||||
"""
|
|
||||||
if report is None: # pragma: no cover
|
|
||||||
return False, None # pragma: no cover
|
|
||||||
|
|
||||||
if spec.allow_upstream_skip: # pragma: no cover
|
|
||||||
return False, None # pragma: no cover
|
|
||||||
|
|
||||||
for dep in spec.depends_on:
|
|
||||||
if dep not in report.results: # pragma: no cover
|
|
||||||
continue # pragma: no cover
|
|
||||||
dep_status = report.results[dep].status
|
|
||||||
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
|
|
||||||
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
|
|
||||||
return False, None # pragma: no cover
|
|
||||||
|
|
||||||
|
|
||||||
def _format_reason(reason: Any) -> str:
|
|
||||||
"""将 _reason 格式化为可读字符串."""
|
|
||||||
if isinstance(reason, list):
|
|
||||||
return ", ".join(str(r) for r in reason)
|
|
||||||
return str(reason)
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
|
|
||||||
"""求值所有条件,返回跳过原因或 ``None``。
|
|
||||||
|
|
||||||
条件接收上下文映射(硬依赖 + 软依赖结果)。
|
|
||||||
"""
|
|
||||||
failed_conditions: list[str] = []
|
|
||||||
for condition in spec.conditions:
|
|
||||||
try:
|
|
||||||
ok = condition(context)
|
|
||||||
except Exception:
|
|
||||||
ok = False
|
|
||||||
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
|
|
||||||
failed_conditions.append(name)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not ok:
|
|
||||||
reason = getattr(condition, "_reason", None)
|
|
||||||
if reason is not None:
|
|
||||||
failed_conditions.append(_format_reason(reason))
|
|
||||||
else:
|
|
||||||
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
|
|
||||||
|
|
||||||
if failed_conditions:
|
|
||||||
if len(failed_conditions) <= 2:
|
|
||||||
return f"条件不满足: {', '.join(failed_conditions)}"
|
|
||||||
return f"条件不满足: {', '.join(failed_conditions[:2])} 等{len(failed_conditions)}个条件"
|
|
||||||
|
|
||||||
if spec.skip_if_missing and not spec._is_cmd_available():
|
|
||||||
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
|
|
||||||
return f"命令不存在: {cmd_name}"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _make_skipped_result(
|
|
||||||
spec: TaskSpec[Any],
|
|
||||||
reason: str,
|
|
||||||
on_event: EventCallback | None,
|
|
||||||
) -> TaskResult[Any]:
|
|
||||||
"""构造 SKIPPED 的 TaskResult。"""
|
|
||||||
result: TaskResult[Any] = TaskResult(
|
|
||||||
spec=spec,
|
|
||||||
status=TaskStatus.SKIPPED,
|
|
||||||
finished_at=datetime.now(),
|
|
||||||
reason=reason,
|
|
||||||
)
|
|
||||||
_emit(on_event, result)
|
|
||||||
logger.info("task %r skipped (%s)", spec.name, reason)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _build_context(
|
def _build_context(
|
||||||
spec: TaskSpec[Any],
|
spec: TaskSpec[Any],
|
||||||
global_context: Mapping[str, Any],
|
global_context: Mapping[str, Any],
|
||||||
@@ -185,19 +105,16 @@ def _build_context(
|
|||||||
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
|
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
|
||||||
"""
|
"""
|
||||||
ctx: dict[str, Any] = {}
|
ctx: dict[str, Any] = {}
|
||||||
|
|
||||||
for dep in spec.depends_on:
|
for dep in spec.depends_on:
|
||||||
if dep in global_context:
|
if dep in global_context:
|
||||||
ctx[dep] = global_context[dep]
|
ctx[dep] = global_context[dep]
|
||||||
|
|
||||||
for dep in spec.soft_depends_on:
|
for dep in spec.soft_depends_on:
|
||||||
if dep in global_context:
|
if dep in global_context:
|
||||||
ctx[dep] = global_context[dep]
|
ctx[dep] = global_context[dep]
|
||||||
elif dep in spec.defaults: # pragma: no cover
|
elif dep in spec.defaults:
|
||||||
ctx[dep] = spec.defaults[dep] # pragma: no cover
|
ctx[dep] = spec.defaults[dep]
|
||||||
else:
|
else:
|
||||||
ctx[dep] = None
|
ctx[dep] = None
|
||||||
|
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
|
||||||
@@ -222,6 +139,38 @@ def _apply_cached(
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
|
||||||
|
"""按优先级降序排序(稳定排序)。"""
|
||||||
|
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
# Mixin:任务级跳过 / 重试 / 成功处理
|
||||||
|
# ---------------------------------------------------------------------- #
|
||||||
|
class _TaskSkipMixin:
|
||||||
|
"""任务级跳过预检共享逻辑。
|
||||||
|
|
||||||
|
将"上游被跳过/失败"与"条件不满足"两类跳过判断统一为单一入口,
|
||||||
|
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
|
||||||
|
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
|
||||||
|
|
||||||
|
软依赖不影响本检查——软依赖被跳过时注入默认值。
|
||||||
|
"""
|
||||||
|
if report is None or spec.allow_upstream_skip:
|
||||||
|
return None
|
||||||
|
for dep in spec.depends_on:
|
||||||
|
if dep not in report.results:
|
||||||
|
continue
|
||||||
|
dep_status = report.results[dep].status
|
||||||
|
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
|
||||||
|
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _prepare_for_execution(
|
def _prepare_for_execution(
|
||||||
spec: TaskSpec[Any],
|
spec: TaskSpec[Any],
|
||||||
context: Mapping[str, Any],
|
context: Mapping[str, Any],
|
||||||
@@ -231,23 +180,51 @@ def _prepare_for_execution(
|
|||||||
"""执行前预检:上游跳过 / 条件跳过。
|
"""执行前预检:上游跳过 / 条件跳过。
|
||||||
|
|
||||||
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
|
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
|
||||||
|
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
|
||||||
"""
|
"""
|
||||||
should_skip, skip_reason = _check_upstream_skipped(spec, report)
|
# 1. 上游被跳过/失败
|
||||||
if should_skip:
|
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
|
||||||
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event)
|
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute)
|
||||||
|
if skip_reason is None:
|
||||||
skip_reason = _evaluate_conditions(spec, context)
|
should_run, cond_reason = spec.should_execute(context)
|
||||||
if skip_reason is not None:
|
if not should_run:
|
||||||
return _make_skipped_result(spec, skip_reason, on_event)
|
skip_reason = cond_reason or "条件不满足"
|
||||||
|
if skip_reason is None:
|
||||||
return None
|
return None
|
||||||
|
# 构造 SKIPPED 结果
|
||||||
|
result: TaskResult[Any] = TaskResult(
|
||||||
|
spec=spec,
|
||||||
|
status=TaskStatus.SKIPPED,
|
||||||
|
finished_at=datetime.now(),
|
||||||
|
reason=skip_reason,
|
||||||
|
)
|
||||||
|
_emit(on_event, result)
|
||||||
|
logger.info("task %r skipped (%s)", spec.name, skip_reason)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class _TaskRetryMixin:
|
||||||
|
"""任务级重试决策与失败/成功后处理共享逻辑。"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
|
||||||
|
"""是否应继续重试。"""
|
||||||
|
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
|
||||||
|
"""标记任务成功并触发 post_run 钩子。"""
|
||||||
|
result.value = value
|
||||||
|
result.status = TaskStatus.SUCCESS
|
||||||
|
result.finished_at = datetime.now()
|
||||||
|
_run_hooks(spec.hooks, "post_run", spec, value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _finalize_failure(
|
def _finalize_failure(
|
||||||
result: TaskResult[Any],
|
result: TaskResult[Any],
|
||||||
layer_idx: int | None,
|
layer_idx: int | None,
|
||||||
on_event: EventCallback | None = None,
|
on_event: EventCallback | None,
|
||||||
continue_on_error: bool = False,
|
continue_on_error: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
|
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
|
||||||
result.status = TaskStatus.FAILED
|
result.status = TaskStatus.FAILED
|
||||||
@@ -266,41 +243,66 @@ def _finalize_failure(
|
|||||||
layer=layer_idx,
|
layer=layer_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _handle_failure(
|
||||||
|
spec: TaskSpec[Any],
|
||||||
|
result: TaskResult[Any],
|
||||||
|
exc: BaseException,
|
||||||
|
layer_idx: int | None,
|
||||||
|
on_event: EventCallback | None,
|
||||||
|
) -> bool:
|
||||||
|
"""统一处理失败:超时转换、重试决策、finalize。
|
||||||
|
|
||||||
def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
|
Returns
|
||||||
"""重试前的同步等待。"""
|
-------
|
||||||
wait = spec.retry.wait_seconds(attempt)
|
bool
|
||||||
if wait > 0:
|
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
|
||||||
import time
|
"""
|
||||||
|
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
|
||||||
time.sleep(wait)
|
if isinstance(exc, asyncio.TimeoutError):
|
||||||
|
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
|
||||||
|
logger.warning(
|
||||||
async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
|
"task %r timed out (attempt %d/%d); retrying",
|
||||||
"""重试前的异步等待。"""
|
spec.name,
|
||||||
wait = spec.retry.wait_seconds(attempt)
|
result.attempts,
|
||||||
if wait > 0:
|
spec.retry.max_attempts,
|
||||||
await asyncio.sleep(wait)
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"task %r failed (attempt %d/%d): %r; retrying",
|
||||||
|
spec.name,
|
||||||
|
result.attempts,
|
||||||
|
spec.retry.max_attempts,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
result.error = exc
|
||||||
|
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
|
||||||
|
return False
|
||||||
|
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
||||||
|
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 同步执行内核
|
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin)
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
def _run_sync_with_retry(
|
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||||
|
"""同步任务执行器:带重试与跳过预检。"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run(
|
||||||
spec: TaskSpec[Any],
|
spec: TaskSpec[Any],
|
||||||
context: Mapping[str, Any],
|
context: Mapping[str, Any],
|
||||||
layer_idx: int | None,
|
layer_idx: int | None,
|
||||||
on_event: EventCallback | None = None,
|
on_event: EventCallback | None = None,
|
||||||
report: RunReport | None = None,
|
report: RunReport | None = None,
|
||||||
) -> TaskResult[Any]:
|
) -> TaskResult[Any]:
|
||||||
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
|
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
|
||||||
skipped = _prepare_for_execution(spec, context, report, on_event)
|
|
||||||
if skipped is not None:
|
if skipped is not None:
|
||||||
return skipped
|
return skipped
|
||||||
|
|
||||||
result: TaskResult[Any] = TaskResult(spec=spec)
|
result: TaskResult[Any] = TaskResult(spec=spec)
|
||||||
result.started_at = datetime.now()
|
result.started_at = datetime.now()
|
||||||
max_attempts = spec.retry.max_attempts
|
|
||||||
args, kwargs = build_call_args(spec, context)
|
args, kwargs = build_call_args(spec, context)
|
||||||
|
|
||||||
_run_hooks(spec.hooks, "pre_run", spec)
|
_run_hooks(spec.hooks, "pre_run", spec)
|
||||||
@@ -309,25 +311,60 @@ def _run_sync_with_retry(
|
|||||||
result.attempts += 1
|
result.attempts += 1
|
||||||
try:
|
try:
|
||||||
with spec.env_context():
|
with spec.env_context():
|
||||||
result.value = spec.effective_fn(*args, **kwargs)
|
value = spec.effective_fn(*args, **kwargs)
|
||||||
result.status = TaskStatus.SUCCESS
|
_TaskRetryMixin._mark_success(spec, result, value)
|
||||||
result.finished_at = datetime.now()
|
|
||||||
_run_hooks(spec.hooks, "post_run", spec, result.value)
|
|
||||||
return result
|
return result
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
result.error = exc
|
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
|
||||||
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
|
|
||||||
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
|
||||||
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
|
||||||
return result
|
return result
|
||||||
_log_retry(spec, result.attempts, max_attempts, exc)
|
wait = spec.retry.wait_seconds(result.attempts)
|
||||||
_sleep_for_retry(spec, result.attempts)
|
if wait > 0:
|
||||||
# pragma: no cover
|
time.sleep(wait)
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
|
||||||
|
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def run(
|
||||||
|
spec: TaskSpec[Any],
|
||||||
|
context: Mapping[str, Any],
|
||||||
|
layer_idx: int | None,
|
||||||
|
on_event: EventCallback | None = None,
|
||||||
|
report: RunReport | None = None,
|
||||||
|
semaphore: asyncio.Semaphore | None = None,
|
||||||
|
) -> TaskResult[Any]:
|
||||||
|
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
|
||||||
|
if skipped is not None:
|
||||||
|
return skipped
|
||||||
|
|
||||||
|
async def _inner() -> TaskResult[Any]:
|
||||||
|
result: TaskResult[Any] = TaskResult(spec=spec)
|
||||||
|
result.started_at = datetime.now()
|
||||||
|
args, kwargs = build_call_args(spec, context)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
_run_hooks(spec.hooks, "pre_run", spec)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result.attempts += 1
|
||||||
|
try:
|
||||||
|
value = await _execute_async_task(spec, args, kwargs, loop)
|
||||||
|
_TaskRetryMixin._mark_success(spec, result, value)
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
|
||||||
|
return result
|
||||||
|
wait = spec.retry.wait_seconds(result.attempts)
|
||||||
|
if wait > 0:
|
||||||
|
await asyncio.sleep(wait)
|
||||||
|
|
||||||
|
if semaphore is not None:
|
||||||
|
async with semaphore:
|
||||||
|
return await _inner()
|
||||||
|
return await _inner()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# 异步执行内核
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
async def _execute_async_task(
|
async def _execute_async_task(
|
||||||
spec: TaskSpec[Any],
|
spec: TaskSpec[Any],
|
||||||
args: tuple[Any, ...],
|
args: tuple[Any, ...],
|
||||||
@@ -339,9 +376,7 @@ async def _execute_async_task(
|
|||||||
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
|
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
|
||||||
if spec.timeout is not None:
|
if spec.timeout is not None:
|
||||||
return await asyncio.wait_for(coro, timeout=spec.timeout)
|
return await asyncio.wait_for(coro, timeout=spec.timeout)
|
||||||
else:
|
|
||||||
return await coro
|
return await coro
|
||||||
else:
|
|
||||||
|
|
||||||
def fn_call() -> Any:
|
def fn_call() -> Any:
|
||||||
with spec.env_context():
|
with spec.env_context():
|
||||||
@@ -349,87 +384,89 @@ async def _execute_async_task(
|
|||||||
|
|
||||||
if spec.timeout is not None:
|
if spec.timeout is not None:
|
||||||
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
|
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
|
||||||
else:
|
|
||||||
return await loop.run_in_executor(None, fn_call)
|
return await loop.run_in_executor(None, fn_call)
|
||||||
|
|
||||||
|
|
||||||
async def _run_async_with_retry(
|
# ---------------------------------------------------------------------- #
|
||||||
spec: TaskSpec[Any],
|
# Mixin:层执行共享逻辑
|
||||||
context: Mapping[str, Any],
|
# ---------------------------------------------------------------------- #
|
||||||
layer_idx: int | None,
|
class _LayerMixin:
|
||||||
on_event: EventCallback | None = None,
|
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
|
||||||
report: RunReport | None = None,
|
|
||||||
semaphore: asyncio.Semaphore | None = None,
|
|
||||||
) -> TaskResult[Any]:
|
|
||||||
"""在事件循环上执行任务(同步或异步)并带重试。"""
|
|
||||||
skipped = _prepare_for_execution(spec, context, report, on_event)
|
|
||||||
if skipped is not None:
|
|
||||||
return skipped
|
|
||||||
|
|
||||||
if semaphore is not None:
|
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
|
||||||
async with semaphore:
|
消除"过滤缓存→排序→运行→存结果"的样板代码。
|
||||||
return await _run_async_inner(spec, context, layer_idx, on_event, report)
|
"""
|
||||||
return await _run_async_inner(spec, context, layer_idx, on_event, report)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _filter_and_sort(
|
||||||
|
layer: list[str],
|
||||||
|
graph: Graph,
|
||||||
|
context: dict[str, Any],
|
||||||
|
report: RunReport,
|
||||||
|
backend: StateBackend,
|
||||||
|
on_event: EventCallback | None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
|
||||||
|
to_run: list[str] = []
|
||||||
|
for name in layer:
|
||||||
|
spec = graph.resolved_spec(name)
|
||||||
|
if not _apply_cached(name, spec, context, report, backend, on_event):
|
||||||
|
to_run.append(name)
|
||||||
|
return _sort_by_priority(to_run, graph)
|
||||||
|
|
||||||
async def _run_async_inner(
|
@staticmethod
|
||||||
spec: TaskSpec[Any],
|
def _store_result(
|
||||||
context: Mapping[str, Any],
|
name: str,
|
||||||
layer_idx: int | None,
|
result: TaskResult[Any],
|
||||||
on_event: EventCallback | None = None,
|
graph: Graph,
|
||||||
report: RunReport | None = None, # noqa: ARG001
|
context: dict[str, Any],
|
||||||
) -> TaskResult[Any]:
|
report: RunReport,
|
||||||
"""异步执行内核的内部实现(已获取 semaphore 后)。"""
|
backend: StateBackend,
|
||||||
result: TaskResult[Any] = TaskResult(spec=spec)
|
on_event: EventCallback | None,
|
||||||
result.started_at = datetime.now()
|
context_snapshot: Mapping[str, Any] | None = None,
|
||||||
max_attempts = spec.retry.max_attempts
|
) -> None:
|
||||||
args, kwargs = build_call_args(spec, context)
|
"""存储任务结果到 context/report/backend 并触发事件。"""
|
||||||
loop = asyncio.get_event_loop()
|
context[name] = result.value
|
||||||
|
if result.status == TaskStatus.SUCCESS:
|
||||||
|
spec = graph.resolved_spec(name)
|
||||||
|
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
|
||||||
|
backend.save(spec.storage_key(task_ctx), result.value)
|
||||||
|
report.results[name] = result
|
||||||
|
_emit(on_event, result)
|
||||||
|
|
||||||
_run_hooks(spec.hooks, "pre_run", spec)
|
@staticmethod
|
||||||
|
def _build_semaphores(
|
||||||
|
to_run: list[str],
|
||||||
|
graph: Graph,
|
||||||
|
sem_factory: Callable[[int], Any],
|
||||||
|
concurrency_limits: Mapping[str, int],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""为每个 ``concurrency_key`` 创建一个信号量。"""
|
||||||
|
semaphores: dict[str, Any] = {}
|
||||||
|
for name in to_run:
|
||||||
|
spec = graph.resolved_spec(name)
|
||||||
|
key = spec.concurrency_key
|
||||||
|
if key is not None and key not in semaphores:
|
||||||
|
limit = concurrency_limits.get(key, 1)
|
||||||
|
semaphores[key] = sem_factory(limit)
|
||||||
|
return semaphores
|
||||||
|
|
||||||
while True:
|
@staticmethod
|
||||||
result.attempts += 1
|
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
||||||
try:
|
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
||||||
result.value = await _execute_async_task(spec, args, kwargs, loop)
|
if spec.concurrency_key is None:
|
||||||
result.status = TaskStatus.SUCCESS
|
return None
|
||||||
result.finished_at = datetime.now()
|
return semaphores.get(spec.concurrency_key)
|
||||||
_run_hooks(spec.hooks, "post_run", spec, result.value)
|
|
||||||
return result
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
|
|
||||||
result.error = exc
|
|
||||||
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
|
|
||||||
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
|
||||||
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
|
||||||
return result
|
|
||||||
logger.warning(
|
|
||||||
"task %r timed out (attempt %d/%d); retrying",
|
|
||||||
spec.name,
|
|
||||||
result.attempts,
|
|
||||||
max_attempts,
|
|
||||||
)
|
|
||||||
await _async_sleep_for_retry(spec, result.attempts)
|
|
||||||
except Exception as exc:
|
|
||||||
result.error = exc
|
|
||||||
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
|
|
||||||
_run_hooks(spec.hooks, "on_failure", spec, exc)
|
|
||||||
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
|
|
||||||
return result
|
|
||||||
_log_retry(spec, result.attempts, max_attempts, exc)
|
|
||||||
await _async_sleep_for_retry(spec, result.attempts)
|
|
||||||
# pragma: no cover
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 层执行器
|
# 层执行器
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
|
class SequentialLayerRunner(_LayerMixin):
|
||||||
"""按优先级降序排序(稳定排序)。"""
|
"""逐个运行某层的任务(按优先级排序)。"""
|
||||||
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _execute_layer_sequential(
|
def execute(
|
||||||
layer: list[str],
|
layer: list[str],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
@@ -438,21 +475,18 @@ def _execute_layer_sequential(
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""逐个运行某层的任务(按优先级排序)。"""
|
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||||
for name in _sort_by_priority(layer, graph):
|
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
|
||||||
continue
|
|
||||||
task_ctx = _build_context(spec, context, report)
|
task_ctx = _build_context(spec, context, report)
|
||||||
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
|
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||||
context[name] = result.value
|
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||||
if result.status == TaskStatus.SUCCESS:
|
|
||||||
backend.save(spec.storage_key(task_ctx), result.value)
|
|
||||||
report.results[name] = result
|
|
||||||
_emit(on_event, result)
|
|
||||||
|
|
||||||
|
|
||||||
def _execute_layer_threaded(
|
class ThreadedLayerRunner(_LayerMixin):
|
||||||
|
"""在线程池中并发运行某层的任务。"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def execute(
|
||||||
layer: list[str],
|
layer: list[str],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
@@ -463,69 +497,47 @@ def _execute_layer_threaded(
|
|||||||
max_workers: int,
|
max_workers: int,
|
||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""在线程池中并发运行某层的任务。"""
|
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||||
to_run: list[str] = []
|
|
||||||
for name in layer:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
task_ctx = _build_context(spec, context, report)
|
|
||||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
|
||||||
continue
|
|
||||||
to_run.append(name)
|
|
||||||
|
|
||||||
if not to_run:
|
if not to_run:
|
||||||
return
|
return
|
||||||
|
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||||
to_run = _sort_by_priority(to_run, graph)
|
|
||||||
|
|
||||||
# 为每个 concurrency_key 创建线程信号量
|
|
||||||
semaphores: dict[str, threading.Semaphore] = {}
|
|
||||||
for name in to_run:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
key = spec.concurrency_key
|
|
||||||
if key is not None and key not in semaphores:
|
|
||||||
limit = concurrency_limits.get(key, 1)
|
|
||||||
semaphores[key] = threading.Semaphore(limit)
|
|
||||||
|
|
||||||
context_snapshot = dict(context)
|
context_snapshot = dict(context)
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
def _run_threaded_task(name: str) -> TaskResult[Any]:
|
def _run_threaded_task(name: str) -> TaskResult[Any]:
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
task_ctx = _build_context(spec, context_snapshot, report)
|
||||||
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
|
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
|
||||||
if sem is not None:
|
if sem is not None:
|
||||||
sem.acquire()
|
sem.acquire()
|
||||||
try:
|
try:
|
||||||
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
|
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||||
finally:
|
finally:
|
||||||
if sem is not None:
|
if sem is not None:
|
||||||
sem.release()
|
sem.release()
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {}
|
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
|
||||||
for name in to_run:
|
pool.submit(_run_threaded_task, name): name for name in to_run
|
||||||
fut = pool.submit(_run_threaded_task, name)
|
}
|
||||||
future_to_name[fut] = name
|
|
||||||
|
|
||||||
completed: dict[str, TaskResult[Any]] = {}
|
completed: dict[str, TaskResult[Any]] = {}
|
||||||
try:
|
try:
|
||||||
for fut in concurrent.futures.as_completed(future_to_name):
|
for fut in concurrent.futures.as_completed(future_to_name):
|
||||||
name = future_to_name[fut]
|
name = future_to_name[fut]
|
||||||
result = fut.result()
|
completed[name] = fut.result()
|
||||||
completed[name] = result
|
|
||||||
finally:
|
finally:
|
||||||
with lock:
|
with lock:
|
||||||
for name, result in completed.items():
|
for name, result in completed.items():
|
||||||
context[name] = result.value
|
ThreadedLayerRunner._store_result(
|
||||||
if result.status == TaskStatus.SUCCESS:
|
name, result, graph, context, report, backend, on_event, context_snapshot
|
||||||
spec = graph.resolved_spec(name)
|
)
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
|
||||||
backend.save(spec.storage_key(task_ctx), result.value)
|
|
||||||
report.results[name] = result
|
|
||||||
_emit(on_event, result)
|
|
||||||
|
|
||||||
|
|
||||||
async def _execute_layer_async(
|
class AsyncLayerRunner(_LayerMixin):
|
||||||
|
"""在事件循环上并发运行某层的任务。"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def execute(
|
||||||
layer: list[str],
|
layer: list[str],
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
@@ -535,55 +547,31 @@ async def _execute_layer_async(
|
|||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""在事件循环上并发运行某层的任务。"""
|
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||||
to_run: list[str] = []
|
|
||||||
for name in layer:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
|
||||||
continue
|
|
||||||
to_run.append(name)
|
|
||||||
|
|
||||||
if not to_run:
|
if not to_run:
|
||||||
return
|
return
|
||||||
|
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||||
to_run = _sort_by_priority(to_run, graph)
|
|
||||||
|
|
||||||
# 为每个 concurrency_key 创建异步信号量
|
|
||||||
semaphores: dict[str, asyncio.Semaphore] = {}
|
|
||||||
for name in to_run:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
key = spec.concurrency_key
|
|
||||||
if key is not None and key not in semaphores:
|
|
||||||
limit = concurrency_limits.get(key, 1)
|
|
||||||
semaphores[key] = asyncio.Semaphore(limit)
|
|
||||||
|
|
||||||
context_snapshot = dict(context)
|
context_snapshot = dict(context)
|
||||||
|
|
||||||
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]:
|
async def _run_async_task(name: str) -> TaskResult[Any]:
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
task_ctx = _build_context(spec, context_snapshot, report)
|
||||||
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
|
sem = AsyncLayerRunner._get_sem(semaphores, spec)
|
||||||
if sem is not None:
|
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||||
async with sem:
|
|
||||||
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
|
|
||||||
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
|
|
||||||
|
|
||||||
coros = [_run_async_task_wrapped(name) for name in to_run]
|
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
|
||||||
results = await asyncio.gather(*coros)
|
|
||||||
for name, result in zip(to_run, results):
|
for name, result in zip(to_run, results):
|
||||||
context[name] = result.value
|
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
|
||||||
if result.status == TaskStatus.SUCCESS:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
|
||||||
backend.save(spec.storage_key(task_ctx), result.value)
|
|
||||||
report.results[name] = result
|
|
||||||
_emit(on_event, result)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
class DependencyRunner(_LayerMixin):
|
||||||
# 依赖驱动调度
|
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
async def _drive_dependency_async(
|
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def execute(
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
report: RunReport,
|
report: RunReport,
|
||||||
@@ -591,19 +579,8 @@ async def _drive_dependency_async(
|
|||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。
|
all_names = list(graph.all_specs().keys())
|
||||||
|
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||||
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
|
||||||
"""
|
|
||||||
all_names = set(graph.all_specs().keys())
|
|
||||||
semaphores: dict[str, asyncio.Semaphore] = {}
|
|
||||||
for name in all_names:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
key = spec.concurrency_key
|
|
||||||
if key is not None and key not in semaphores:
|
|
||||||
limit = concurrency_limits.get(key, 1)
|
|
||||||
semaphores[key] = asyncio.Semaphore(limit)
|
|
||||||
|
|
||||||
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
||||||
|
|
||||||
async def _run_task(name: str) -> TaskResult[Any]:
|
async def _run_task(name: str) -> TaskResult[Any]:
|
||||||
@@ -621,24 +598,14 @@ async def _drive_dependency_async(
|
|||||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
if _apply_cached(name, spec, context, report, backend, on_event):
|
||||||
return report.results[name]
|
return report.results[name]
|
||||||
|
|
||||||
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
|
sem = DependencyRunner._get_sem(semaphores, spec)
|
||||||
if sem is not None:
|
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
|
||||||
async with sem:
|
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||||
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
|
|
||||||
else:
|
|
||||||
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
|
|
||||||
|
|
||||||
context[name] = result.value
|
|
||||||
if result.status == TaskStatus.SUCCESS:
|
|
||||||
backend.save(spec.storage_key(task_ctx), result.value)
|
|
||||||
report.results[name] = result
|
|
||||||
_emit(on_event, result)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
for name in all_names:
|
for name in all_names:
|
||||||
futures[name] = loop.create_task(_run_task(name))
|
futures[name] = loop.create_task(_run_task(name))
|
||||||
|
|
||||||
await asyncio.gather(*futures.values())
|
await asyncio.gather(*futures.values())
|
||||||
|
|
||||||
|
|
||||||
@@ -729,9 +696,9 @@ def run(
|
|||||||
elif strategy == "thread":
|
elif strategy == "thread":
|
||||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
||||||
elif strategy == "async":
|
elif strategy == "async":
|
||||||
_drive_async(graph, layers, context, report, backend, effective_callback, limits)
|
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
||||||
elif strategy == "dependency":
|
elif strategy == "dependency":
|
||||||
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits))
|
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown strategy: {strategy!r}")
|
raise ValueError(f"Unknown strategy: {strategy!r}")
|
||||||
except TaskFailedError:
|
except TaskFailedError:
|
||||||
@@ -759,7 +726,7 @@ def _drive_sequential(
|
|||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
for idx, layer in enumerate(layers, 1):
|
for idx, layer in enumerate(layers, 1):
|
||||||
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
|
SequentialLayerRunner.execute(layer, graph, context, report, backend, idx, on_event)
|
||||||
|
|
||||||
|
|
||||||
def _drive_threaded(
|
def _drive_threaded(
|
||||||
@@ -774,19 +741,7 @@ def _drive_threaded(
|
|||||||
) -> None:
|
) -> None:
|
||||||
for idx, layer in enumerate(layers, 1):
|
for idx, layer in enumerate(layers, 1):
|
||||||
workers = max_workers or max(1, min(32, len(layer)))
|
workers = max_workers or max(1, min(32, len(layer)))
|
||||||
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
|
ThreadedLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
|
||||||
|
|
||||||
|
|
||||||
def _drive_async(
|
|
||||||
graph: Graph,
|
|
||||||
layers: list[list[str]],
|
|
||||||
context: dict[str, Any],
|
|
||||||
report: RunReport,
|
|
||||||
backend: StateBackend,
|
|
||||||
on_event: EventCallback | None,
|
|
||||||
concurrency_limits: Mapping[str, int],
|
|
||||||
) -> None:
|
|
||||||
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
|
|
||||||
|
|
||||||
|
|
||||||
async def _async_drive(
|
async def _async_drive(
|
||||||
@@ -799,4 +754,4 @@ async def _async_drive(
|
|||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
for idx, layer in enumerate(layers, 1):
|
for idx, layer in enumerate(layers, 1):
|
||||||
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
|
await AsyncLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
|
||||||
|
|||||||
+19
-16
@@ -49,6 +49,15 @@ class GraphDefaults:
|
|||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]:
|
||||||
|
"""返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。"""
|
||||||
|
return replace(
|
||||||
|
spec,
|
||||||
|
depends_on=tuple(d for d in spec.depends_on if keep(d)),
|
||||||
|
soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Graph:
|
class Graph:
|
||||||
"""校验后的有向无环任务图。
|
"""校验后的有向无环任务图。
|
||||||
@@ -225,16 +234,13 @@ class Graph:
|
|||||||
def subgraph(self, tags: Iterable[str]) -> Graph:
|
def subgraph(self, tags: Iterable[str]) -> Graph:
|
||||||
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
|
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
|
||||||
wanted: set[str] = set(tags)
|
wanted: set[str] = set(tags)
|
||||||
kept: list[TaskSpec[Any]] = []
|
|
||||||
for spec in self.specs.values():
|
def _dep_kept(dep: str) -> bool:
|
||||||
if wanted & set(spec.tags):
|
return dep in self.specs and bool(wanted & set(self.specs[dep].tags))
|
||||||
pruned_deps = tuple(
|
|
||||||
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
|
kept: list[TaskSpec[Any]] = [
|
||||||
)
|
_prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags)
|
||||||
pruned_soft = tuple(
|
]
|
||||||
d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
|
|
||||||
)
|
|
||||||
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
|
|
||||||
return Graph.from_specs(kept, defaults=self.defaults)
|
return Graph.from_specs(kept, defaults=self.defaults)
|
||||||
|
|
||||||
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
|
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
|
||||||
@@ -243,12 +249,9 @@ class Graph:
|
|||||||
for n in wanted:
|
for n in wanted:
|
||||||
if n not in self.specs:
|
if n not in self.specs:
|
||||||
raise KeyError(f"Unknown task name: {n!r}")
|
raise KeyError(f"Unknown task name: {n!r}")
|
||||||
kept: list[TaskSpec[Any]] = []
|
kept: list[TaskSpec[Any]] = [
|
||||||
for spec in self.specs.values():
|
_prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted
|
||||||
if spec.name in wanted:
|
]
|
||||||
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
|
|
||||||
pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
|
|
||||||
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
|
|
||||||
return Graph.from_specs(kept, defaults=self.defaults)
|
return Graph.from_specs(kept, defaults=self.defaults)
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
|
|||||||
+110
-40
@@ -17,6 +17,7 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Iterator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
|
|
||||||
@@ -55,7 +56,74 @@ class StateBackend(ABC):
|
|||||||
"""清除所有存储状态。"""
|
"""清除所有存储状态。"""
|
||||||
|
|
||||||
|
|
||||||
class MemoryBackend(StateBackend):
|
class _TTLStateBackendMixin(StateBackend):
|
||||||
|
"""TTL 状态后端共享逻辑。
|
||||||
|
|
||||||
|
将 ``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现
|
||||||
|
委托给四个原始存取原语::meth:`_get_raw`、:meth:`_put_raw`、
|
||||||
|
:meth:`_iter_raw`、:meth:`_clear_raw`,并基于 :meth:`_now` 与
|
||||||
|
``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`。
|
||||||
|
|
||||||
|
子类需设置 ``self._ttl`` 并实现上述四个原语;如需自定义时间源
|
||||||
|
(如 ``time.monotonic``)可覆盖 :meth:`_now`。
|
||||||
|
"""
|
||||||
|
|
||||||
|
_ttl: float | None
|
||||||
|
|
||||||
|
# ---- 原语:由子类实现 ---- #
|
||||||
|
@abstractmethod
|
||||||
|
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||||
|
"""返回 ``(value, ts)``;键不存在时返回 ``None``。"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||||
|
"""写入一条记录。"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||||
|
"""迭代所有记录(不做过期过滤),yield ``(key, value, ts)``。"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _clear_raw(self) -> None:
|
||||||
|
"""清空所有记录。"""
|
||||||
|
|
||||||
|
# ---- 共享实现 ---- #
|
||||||
|
def _now(self) -> float:
|
||||||
|
"""当前时间戳,默认为 wall-clock 秒。"""
|
||||||
|
return time.time()
|
||||||
|
|
||||||
|
def _is_expired(self, ts: float) -> bool:
|
||||||
|
"""时间戳 ``ts`` 是否已过期。"""
|
||||||
|
if self._ttl is None:
|
||||||
|
return False
|
||||||
|
return (self._now() - ts) > self._ttl
|
||||||
|
|
||||||
|
@override
|
||||||
|
def load(self) -> Mapping[str, Any]:
|
||||||
|
return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)}
|
||||||
|
|
||||||
|
@override
|
||||||
|
def save(self, key: str, value: Any) -> None:
|
||||||
|
self._put_raw(key, value, self._now())
|
||||||
|
|
||||||
|
@override
|
||||||
|
def has(self, key: str) -> bool:
|
||||||
|
entry = self._get_raw(key)
|
||||||
|
return entry is not None and not self._is_expired(entry[1])
|
||||||
|
|
||||||
|
@override
|
||||||
|
def get(self, key: str) -> Any:
|
||||||
|
entry = self._get_raw(key)
|
||||||
|
if entry is None or self._is_expired(entry[1]):
|
||||||
|
raise KeyError(key)
|
||||||
|
return entry[0]
|
||||||
|
|
||||||
|
@override
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._clear_raw()
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBackend(_TTLStateBackendMixin):
|
||||||
"""进程内 dict 后端。进程退出即丢失。
|
"""进程内 dict 后端。进程退出即丢失。
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -70,35 +138,35 @@ class MemoryBackend(StateBackend):
|
|||||||
self._ttl = ttl
|
self._ttl = ttl
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def load(self) -> Mapping[str, Any]:
|
def _now(self) -> float:
|
||||||
return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)}
|
return time.monotonic()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def save(self, key: str, value: Any) -> None:
|
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||||
self._store[key] = (value, time.monotonic())
|
return self._store.get(key)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def has(self, key: str) -> bool:
|
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||||
return key in self._store and not self._expired(key)
|
self._store[key] = (value, ts)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def get(self, key: str) -> Any:
|
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||||
if key not in self._store or self._expired(key):
|
for k, (v, ts) in self._store.items():
|
||||||
raise KeyError(key)
|
yield k, v, ts
|
||||||
return self._store[key][0]
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def clear(self) -> None:
|
def _clear_raw(self) -> None:
|
||||||
self._store.clear()
|
self._store.clear()
|
||||||
|
|
||||||
def _expired(self, key: str) -> bool:
|
def _expired(self, key: str) -> bool:
|
||||||
if self._ttl is None or key not in self._store:
|
"""键是否已过期(兼容旧测试 API)。"""
|
||||||
|
entry = self._get_raw(key)
|
||||||
|
if entry is None:
|
||||||
return False
|
return False
|
||||||
_value, ts = self._store[key]
|
return self._is_expired(entry[1])
|
||||||
return (time.monotonic() - ts) > self._ttl
|
|
||||||
|
|
||||||
|
|
||||||
class JSONBackend(StateBackend):
|
class JSONBackend(_TTLStateBackendMixin):
|
||||||
"""基于文件的 JSON 存储,用于跨进程续跑。
|
"""基于文件的 JSON 存储,用于跨进程续跑。
|
||||||
|
|
||||||
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
|
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
|
||||||
@@ -144,17 +212,30 @@ class JSONBackend(StateBackend):
|
|||||||
except (OSError, TypeError) as exc:
|
except (OSError, TypeError) as exc:
|
||||||
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
|
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
|
||||||
|
|
||||||
def _now(self) -> float:
|
@override
|
||||||
return time.time()
|
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||||
|
entry = self._store.get(key)
|
||||||
def _expired(self, entry: dict[str, Any]) -> bool:
|
if entry is None:
|
||||||
if self._ttl is None:
|
return None
|
||||||
return False
|
return entry["value"], float(entry.get("ts", 0))
|
||||||
return (self._now() - float(entry.get("ts", 0))) > self._ttl
|
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def load(self) -> Mapping[str, Any]:
|
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||||
return {k: v["value"] for k, v in self._store.items() if not self._expired(v)}
|
self._store[key] = {"value": value, "ts": ts}
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||||
|
for k, entry in self._store.items():
|
||||||
|
yield k, entry["value"], float(entry.get("ts", 0))
|
||||||
|
|
||||||
|
@override
|
||||||
|
def _clear_raw(self) -> None:
|
||||||
|
self._store.clear()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def clear(self) -> None:
|
||||||
|
super().clear()
|
||||||
|
self._flush()
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def save(self, key: str, value: Any) -> None:
|
def save(self, key: str, value: Any) -> None:
|
||||||
@@ -162,23 +243,12 @@ class JSONBackend(StateBackend):
|
|||||||
_ = json.dumps(value)
|
_ = json.dumps(value)
|
||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
||||||
self._store[key] = {"value": value, "ts": self._now()}
|
super().save(key, value)
|
||||||
self._flush()
|
self._flush()
|
||||||
|
|
||||||
@override
|
def _expired(self, entry: Mapping[str, Any]) -> bool:
|
||||||
def has(self, key: str) -> bool:
|
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
|
||||||
return key in self._store and not self._expired(self._store[key])
|
return self._is_expired(float(entry.get("ts", 0)))
|
||||||
|
|
||||||
@override
|
|
||||||
def get(self, key: str) -> Any:
|
|
||||||
if key not in self._store or self._expired(self._store[key]):
|
|
||||||
raise KeyError(key)
|
|
||||||
return self._store[key]["value"]
|
|
||||||
|
|
||||||
@override
|
|
||||||
def clear(self) -> None:
|
|
||||||
self._store.clear()
|
|
||||||
self._flush()
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_backend(backend: StateBackend | None) -> StateBackend:
|
def resolve_backend(backend: StateBackend | None) -> StateBackend:
|
||||||
|
|||||||
+10
-3
@@ -74,6 +74,13 @@ Condition = Callable[[Context], bool]
|
|||||||
CacheKeyFn = Callable[[Context], str]
|
CacheKeyFn = Callable[[Context], str]
|
||||||
|
|
||||||
|
|
||||||
|
def _format_skip_reason(failed_conditions: list[str]) -> str:
|
||||||
|
"""格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。"""
|
||||||
|
if len(failed_conditions) <= 2:
|
||||||
|
return f"条件不满足: {', '.join(failed_conditions)}"
|
||||||
|
return f"条件不满足: {', '.join(failed_conditions[:2])} 等{len(failed_conditions)}个条件"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 重试策略
|
# 重试策略
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
@@ -315,6 +322,7 @@ class TaskSpec(Generic[T]):
|
|||||||
-------
|
-------
|
||||||
(should_run, skip_reason)
|
(should_run, skip_reason)
|
||||||
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
|
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
|
||||||
|
失败条件超过 2 个时仅展示前 2 个并附总数。
|
||||||
"""
|
"""
|
||||||
# 逐个求值条件,记录失败项。
|
# 逐个求值条件,记录失败项。
|
||||||
failed_conditions: list[str] = []
|
failed_conditions: list[str] = []
|
||||||
@@ -323,8 +331,7 @@ class TaskSpec(Generic[T]):
|
|||||||
ok = condition(context)
|
ok = condition(context)
|
||||||
except Exception:
|
except Exception:
|
||||||
ok = False
|
ok = False
|
||||||
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
|
failed_conditions.append("匿名条件(执行错误)")
|
||||||
failed_conditions.append(name)
|
|
||||||
continue
|
continue
|
||||||
if not ok:
|
if not ok:
|
||||||
reason = getattr(condition, "_reason", None)
|
reason = getattr(condition, "_reason", None)
|
||||||
@@ -336,7 +343,7 @@ class TaskSpec(Generic[T]):
|
|||||||
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
|
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
|
||||||
|
|
||||||
if failed_conditions:
|
if failed_conditions:
|
||||||
return False, f"条件不满足: {', '.join(failed_conditions)}"
|
return False, _format_skip_reason(failed_conditions)
|
||||||
|
|
||||||
if self.skip_if_missing and not self._is_cmd_available():
|
if self.skip_if_missing and not self._is_cmd_available():
|
||||||
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
|
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
"""常用工具函数."""
|
||||||
|
|
||||||
|
__all__ = ["perf_timer"]
|
||||||
|
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Callable, ParamSpec, TypedDict
|
||||||
|
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
R = TypeVar("R")
|
||||||
|
|
||||||
|
|
||||||
|
class _PerformanceMetrics(TypedDict):
|
||||||
|
"""性能指标."""
|
||||||
|
|
||||||
|
count: int
|
||||||
|
total_time: float
|
||||||
|
|
||||||
|
|
||||||
|
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
|
||||||
|
lambda: _PerformanceMetrics(
|
||||||
|
count=0,
|
||||||
|
total_time=0.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
|
||||||
|
"""性能计时器装饰器."""
|
||||||
|
scale: dict[str, float] = {
|
||||||
|
"s": 1.0,
|
||||||
|
"ms": 1000.0,
|
||||||
|
"us": 1000000.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
@functools.wraps(func)
|
||||||
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
start_time = time.time()
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
end_time = time.time()
|
||||||
|
|
||||||
|
_perf_metrics[func.__name__]["count"] += 1
|
||||||
|
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
|
||||||
|
|
||||||
|
if not report:
|
||||||
|
logging.info(
|
||||||
|
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
if report:
|
||||||
|
import atexit
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
|
||||||
|
|
||||||
|
@atexit.register
|
||||||
|
def _() -> None:
|
||||||
|
for name, metrics in _perf_metrics.items():
|
||||||
|
logging.info(f"{name}: {metrics['count']} times, {metrics['total_time']:.{precision}f}{unit}")
|
||||||
|
|
||||||
|
return decorator
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
"""Tests for cli.which module."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import shutil
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import pyflowx as px
|
|
||||||
from pyflowx.cli import which
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
# main function
|
|
||||||
# ---------------------------------------------------------------------- #
|
|
||||||
class TestMain:
|
|
||||||
"""Test main function."""
|
|
||||||
|
|
||||||
def test_main_with_single_command(self) -> None:
|
|
||||||
"""main() should handle single command argument."""
|
|
||||||
with patch("sys.argv", ["which", "python"]), patch.object(
|
|
||||||
shutil, "which", return_value="/usr/bin/python"
|
|
||||||
), patch.object(px, "run") as mock_run:
|
|
||||||
which.main()
|
|
||||||
# Should create a graph with one task
|
|
||||||
assert mock_run.called
|
|
||||||
graph = mock_run.call_args[0][0]
|
|
||||||
assert isinstance(graph, px.Graph)
|
|
||||||
|
|
||||||
def test_main_with_multiple_commands(self) -> None:
|
|
||||||
"""main() should handle multiple command arguments."""
|
|
||||||
with patch("sys.argv", ["which", "python", "pip", "node"]), patch.object(
|
|
||||||
shutil, "which", return_value="/usr/bin/cmd"
|
|
||||||
), patch.object(px, "run") as mock_run:
|
|
||||||
which.main()
|
|
||||||
# Should create a graph with three tasks
|
|
||||||
assert mock_run.called
|
|
||||||
graph = mock_run.call_args[0][0]
|
|
||||||
assert isinstance(graph, px.Graph)
|
|
||||||
|
|
||||||
def test_main_with_no_args_shows_help(self) -> None:
|
|
||||||
"""main() with no args should show help and exit."""
|
|
||||||
with patch("sys.argv", ["which"]), pytest.raises(SystemExit) as exc_info:
|
|
||||||
which.main()
|
|
||||||
assert exc_info.value.code == 2
|
|
||||||
|
|
||||||
def test_main_creates_task_specs_with_correct_names(self) -> None:
|
|
||||||
"""main() should create TaskSpecs with correct names."""
|
|
||||||
with patch("sys.argv", ["which", "git", "npm"]), patch.object(
|
|
||||||
shutil, "which", return_value="/usr/bin/cmd"
|
|
||||||
), patch.object(px, "run") as mock_run:
|
|
||||||
which.main()
|
|
||||||
graph = mock_run.call_args[0][0]
|
|
||||||
# Check that task names are correct
|
|
||||||
task_names = list(graph.all_specs().keys())
|
|
||||||
assert "which_git" in task_names
|
|
||||||
assert "which_npm" in task_names
|
|
||||||
|
|
||||||
def test_main_uses_thread_strategy(self) -> None:
|
|
||||||
"""main() should use thread strategy."""
|
|
||||||
with patch("sys.argv", ["which", "python"]), patch.object(
|
|
||||||
shutil, "which", return_value="/usr/bin/python"
|
|
||||||
), patch.object(px, "run") as mock_run:
|
|
||||||
which.main()
|
|
||||||
assert mock_run.call_args[1]["strategy"] == "thread"
|
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
|
from pyflowx.utils import _perf_metrics, perf_timer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_perf_metrics():
|
||||||
|
"""重置性能指标."""
|
||||||
|
_perf_metrics.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformanceTimer:
|
||||||
|
def test_perf_timer(self):
|
||||||
|
|
||||||
|
@perf_timer()
|
||||||
|
def test_func():
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
assert _perf_metrics["test_func"] is not None
|
||||||
|
assert _perf_metrics["test_func"]["count"] == 1
|
||||||
|
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||||
|
|
||||||
|
def test_perf_timer_report(self, mocker: MockerFixture):
|
||||||
|
mock_log = mocker.patch("logging.info")
|
||||||
|
|
||||||
|
@perf_timer(report=True, unit="ms", precision=3)
|
||||||
|
def test_func():
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
test_func()
|
||||||
|
|
||||||
|
assert _perf_metrics["test_func"] is not None
|
||||||
|
assert _perf_metrics["test_func"]["count"] == 1
|
||||||
|
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||||
|
|
||||||
|
assert mock_log.call_count == 1
|
||||||
@@ -5603,7 +5603,7 @@ pycountry = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyflowx"
|
name = "pyflowx"
|
||||||
version = "0.2.9"
|
version = "0.2.10"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
|
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
|
||||||
|
|||||||
Reference in New Issue
Block a user