refactor: 重构项目代码结构,拆分职责模块
1. 抽离图组合逻辑到pyflowx.compose,原graph.py仅保留单图DAG逻辑 2. 抽离命令执行逻辑到pyflowx.command,移除task.py内的_run_command 3. 重构上下文签名缓存,优化性能 4. 移除废弃的utils.perf_timer相关代码 5. 为JSONBackend添加batch批量落盘优化 6. 调整导入路径与公开API,更新测试用例 7. 简化条件判断逻辑,移除冗余代码
This commit is contained in:
@@ -58,6 +58,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .command import run_command
|
||||
from .compose import GraphComposer, compose
|
||||
from .conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
@@ -79,7 +81,7 @@ from .errors import (
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph, GraphComposer, GraphDefaults, compose
|
||||
from .graph import Graph, GraphDefaults
|
||||
from .report import RunReport
|
||||
from .runner import CliExitCode, CliRunner
|
||||
from .storage import JSONBackend, MemoryBackend, StateBackend
|
||||
@@ -136,5 +138,6 @@ __all__ = [
|
||||
"compose",
|
||||
"describe_injection",
|
||||
"run",
|
||||
"run_command",
|
||||
"task_template",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
"""命令执行器:把 :class:`~pyflowx.task.TaskSpec` 的 ``cmd`` 字段(list /
|
||||
shell 字符串 / 可调用对象)转换为统一执行入口。
|
||||
|
||||
历史背景:原 ``task.py`` 的模块文档声明其为"纯数据结构",但 ``_run_command``
|
||||
属于命令执行逻辑,违反单一职责。此处将其抽离,``TaskSpec`` 仅持有配置,
|
||||
执行逻辑集中于本模块,便于独立测试与维护。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any, List, Union, cast
|
||||
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["run_command"]
|
||||
|
||||
|
||||
def run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。
|
||||
|
||||
与原 ``TaskSpec._run_command`` 行为一致:
|
||||
|
||||
- 可调用对象:直接调用,异常包装为 :class:`RuntimeError`。
|
||||
- list / str:通过 :func:`subprocess.run` 执行,非零返回码抛
|
||||
:class:`RuntimeError`(``verbose=False`` 时附 stderr)。
|
||||
- ``verbose=True`` 时打印执行信息与返回码到 stdout。
|
||||
- ``cwd`` / ``env`` 通过 subprocess 参数隔离(进程级状态仅在 fn 任务路径
|
||||
使用,cmd 路径不依赖 ``os.chdir`` / ``os.environ``)。
|
||||
"""
|
||||
cmd = spec.cmd
|
||||
verbose = spec.verbose
|
||||
cwd = spec.cwd
|
||||
timeout = spec.timeout
|
||||
env_override = spec.env
|
||||
|
||||
# 可调用对象:直接调用,返回其结果。
|
||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
||||
name = getattr(cmd, "__name__", "callable")
|
||||
if verbose:
|
||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
try:
|
||||
return cmd()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
||||
|
||||
is_list = isinstance(cmd, list)
|
||||
if is_list:
|
||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
||||
verb = "执行命令"
|
||||
label = "命令"
|
||||
else:
|
||||
cmd_str = cast(str, cmd)
|
||||
verb = "执行 Shell"
|
||||
label = "Shell 命令"
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
|
||||
# 合并环境变量
|
||||
run_env: dict[str, str] | None = None
|
||||
if env_override:
|
||||
run_env = dict(os.environ)
|
||||
run_env.update(env_override)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cast(Union[str, List[str]], cmd),
|
||||
shell=not is_list,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
timeout=timeout,
|
||||
capture_output=not verbose,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None
|
||||
|
||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
err_msg += f"\n{result.stderr.strip()}"
|
||||
raise RuntimeError(err_msg)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""图组合:将带字符串引用的多个图展开为纯 :class:`~pyflowx.graph.Graph`。
|
||||
|
||||
历史背景:原 ``graph.py`` 同时承载 DAG 构建/校验/分层与多图组合逻辑,
|
||||
职责过载。组合逻辑(:class:`GraphComposer` / :func:`compose`)与单图 DAG
|
||||
模型正交,此处抽离为独立模块,便于按需导入与独立演进。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["GraphComposer", "compose"]
|
||||
|
||||
|
||||
class GraphComposer:
|
||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
||||
|
||||
引用格式:
|
||||
* ``"command_name"`` —— 引用整个命令图。
|
||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
||||
|
||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
||||
self.graphs = graphs
|
||||
|
||||
def resolve_all(self) -> dict[str, Graph]:
|
||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
||||
resolved: dict[str, Graph] = {}
|
||||
for cmd_name, graph in self.graphs.items():
|
||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
||||
return resolved
|
||||
|
||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
||||
pending_refs = graph._pending_refs
|
||||
if not pending_refs:
|
||||
return graph
|
||||
|
||||
all_specs: list[TaskSpec[Any]] = []
|
||||
previous_ref_last_task: str | None = None
|
||||
|
||||
for ref in pending_refs:
|
||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
||||
if previous_ref_last_task and expanded_specs:
|
||||
for i, task in enumerate(expanded_specs):
|
||||
if i == 0 or not task.depends_on:
|
||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
||||
if expanded_specs:
|
||||
previous_ref_last_task = expanded_specs[-1].name
|
||||
all_specs.extend(expanded_specs)
|
||||
|
||||
original_specs = list(graph.all_specs().values())
|
||||
if original_specs:
|
||||
if previous_ref_last_task:
|
||||
first = original_specs[0]
|
||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
||||
else:
|
||||
all_specs.append(original_specs[0])
|
||||
for i in range(1, len(original_specs)):
|
||||
current_task = original_specs[i]
|
||||
previous_task_name = original_specs[i - 1].name
|
||||
all_specs.append(
|
||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
||||
)
|
||||
|
||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
||||
|
||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
||||
if ref == current_cmd:
|
||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
||||
|
||||
if "." in ref:
|
||||
cmd_name, task_name = ref.split(".", 1)
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
if task_name not in ref_graph.all_specs():
|
||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
||||
return [ref_graph.all_specs()[task_name]]
|
||||
else:
|
||||
cmd_name = ref
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
||||
return list(ref_graph.all_specs().values())
|
||||
|
||||
|
||||
def compose(
|
||||
graphs: dict[str, Graph],
|
||||
) -> dict[str, Graph]:
|
||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
||||
|
||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graphs = {
|
||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
... }
|
||||
>>> resolved = px.compose(graphs)
|
||||
>>> "b" in resolved["all"].all_specs()
|
||||
True
|
||||
"""
|
||||
return GraphComposer(graphs).resolve_all()
|
||||
@@ -42,14 +42,6 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
|
||||
return _cond
|
||||
|
||||
|
||||
def _cond_reason(cond: Condition) -> str | list[str] | None:
|
||||
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
|
||||
reason = getattr(cond, "_reason", None)
|
||||
if reason is not None:
|
||||
return reason
|
||||
return getattr(cond, "__name__", repr(cond))
|
||||
|
||||
|
||||
def _cond_name(cond: Condition) -> str:
|
||||
"""获取条件的可读名称。"""
|
||||
return getattr(cond, "__name__", repr(cond))
|
||||
@@ -228,13 +220,7 @@ class BuiltinConditions:
|
||||
"""对条件取反."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
result = condition(ctx)
|
||||
if result:
|
||||
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
|
||||
inner_reason = _cond_reason(condition)
|
||||
if inner_reason is not None:
|
||||
_cond._reason = inner_reason # type: ignore[attr-defined]
|
||||
return not result
|
||||
return not condition(ctx)
|
||||
|
||||
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
||||
return _cond
|
||||
@@ -254,15 +240,7 @@ class BuiltinConditions:
|
||||
"""多个条件的逻辑或."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
matched: list[str] = []
|
||||
for c in conditions:
|
||||
if c(ctx):
|
||||
reason = _cond_reason(c)
|
||||
matched.append(reason if isinstance(reason, str) else str(reason))
|
||||
if matched:
|
||||
_cond._reason = matched # type: ignore[attr-defined]
|
||||
return True
|
||||
return False
|
||||
return any(c(ctx) for c in conditions)
|
||||
|
||||
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
||||
return _cond
|
||||
|
||||
+21
-2
@@ -16,6 +16,7 @@ DAG 库中泛滥的样板包装器。
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from functools import lru_cache
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .errors import InjectionError
|
||||
@@ -24,6 +25,24 @@ from .task import Context, TaskSpec
|
||||
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1024)
|
||||
def _cached_signature(fn: Any) -> inspect.Signature:
|
||||
"""缓存 ``inspect.signature`` 结果(按 fn 对象键控)。
|
||||
|
||||
``fn`` 对象在 :meth:`TaskSpec.effective_fn` 缓存后稳定,签名重复内省
|
||||
属纯开销。对不可哈希的可调用对象,调用方回退到直接内省。
|
||||
"""
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _signature(fn: Any) -> inspect.Signature:
|
||||
"""获取签名,优先走缓存;``fn`` 不可哈希时回退到直接内省。"""
|
||||
try:
|
||||
return _cached_signature(fn)
|
||||
except TypeError:
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _is_context_annotation(annotation: Any) -> bool:
|
||||
"""判断参数标注是否为(或指向)``Context``。"""
|
||||
if annotation is Context:
|
||||
@@ -44,7 +63,7 @@ def build_call_args(
|
||||
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
|
||||
"""
|
||||
fn = spec.effective_fn
|
||||
sig = inspect.signature(fn)
|
||||
sig = _signature(fn)
|
||||
params = sig.parameters
|
||||
|
||||
var_keyword = next(
|
||||
@@ -115,7 +134,7 @@ def build_call_args(
|
||||
def describe_injection(spec: TaskSpec[Any]) -> str:
|
||||
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
|
||||
fn = spec.effective_fn
|
||||
sig = inspect.signature(fn)
|
||||
sig = _signature(fn)
|
||||
positional_params = [
|
||||
p
|
||||
for p, param in sig.parameters.items()
|
||||
|
||||
+57
-49
@@ -12,14 +12,17 @@
|
||||
|
||||
架构
|
||||
----
|
||||
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
|
||||
本模块通过 **Mixin** 组合消除同步/异步任务执行器之间的重复代码:
|
||||
|
||||
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
|
||||
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
|
||||
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
|
||||
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
|
||||
* 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` /
|
||||
:func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、
|
||||
信号量构建、结果存储。
|
||||
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
|
||||
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
|
||||
:class:`AsyncLayerRunner` —— 层级执行器,调用上述模块级辅助。
|
||||
* :class:`DependencyRunner` —— 依赖驱动调度(非层模型),同样调用模块级辅助。
|
||||
|
||||
所有策略共享统一异步内核,支持:
|
||||
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
||||
@@ -388,16 +391,8 @@ async def _execute_async_task(
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Mixin:层执行共享逻辑
|
||||
# 共享辅助:缓存过滤、优先级排序、信号量构建、结果存储
|
||||
# ---------------------------------------------------------------------- #
|
||||
class _LayerMixin:
|
||||
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
|
||||
|
||||
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
|
||||
消除"过滤缓存→排序→运行→存结果"的样板代码。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _filter_and_sort(
|
||||
layer: list[str],
|
||||
graph: Graph,
|
||||
@@ -414,27 +409,29 @@ class _LayerMixin:
|
||||
to_run.append(name)
|
||||
return _sort_by_priority(to_run, graph)
|
||||
|
||||
@staticmethod
|
||||
|
||||
def _store_result(
|
||||
name: str,
|
||||
result: TaskResult[Any],
|
||||
graph: Graph,
|
||||
spec: TaskSpec[Any],
|
||||
task_ctx: dict[str, Any],
|
||||
context: dict[str, Any],
|
||||
report: RunReport,
|
||||
backend: StateBackend,
|
||||
on_event: EventCallback | None,
|
||||
context_snapshot: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""存储任务结果到 context/report/backend 并触发事件。"""
|
||||
"""存储任务结果到 context/report/backend 并触发事件。
|
||||
|
||||
``spec`` 与 ``task_ctx`` 由调用方在执行前已构建,直接复用避免重复
|
||||
``resolved_spec`` / ``_build_context`` 调用。
|
||||
"""
|
||||
context[name] = result.value
|
||||
if result.status == TaskStatus.SUCCESS:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
|
||||
backend.save(spec.storage_key(task_ctx), result.value)
|
||||
report.results[name] = result
|
||||
_emit(on_event, result)
|
||||
|
||||
@staticmethod
|
||||
|
||||
def _build_semaphores(
|
||||
to_run: list[str],
|
||||
graph: Graph,
|
||||
@@ -451,7 +448,7 @@ class _LayerMixin:
|
||||
semaphores[key] = sem_factory(limit)
|
||||
return semaphores
|
||||
|
||||
@staticmethod
|
||||
|
||||
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
||||
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
||||
if spec.concurrency_key is None:
|
||||
@@ -462,7 +459,7 @@ class _LayerMixin:
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 层执行器
|
||||
# ---------------------------------------------------------------------- #
|
||||
class SequentialLayerRunner(_LayerMixin):
|
||||
class SequentialLayerRunner:
|
||||
"""逐个运行某层的任务(按优先级排序)。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -475,14 +472,14 @@ class SequentialLayerRunner(_LayerMixin):
|
||||
layer_idx: int,
|
||||
on_event: EventCallback | None,
|
||||
) -> None:
|
||||
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||
for name in _filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context, report)
|
||||
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||
|
||||
|
||||
class ThreadedLayerRunner(_LayerMixin):
|
||||
class ThreadedLayerRunner:
|
||||
"""在线程池中并发运行某层的任务。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -497,43 +494,43 @@ class ThreadedLayerRunner(_LayerMixin):
|
||||
max_workers: int,
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
if not to_run:
|
||||
return
|
||||
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||
context_snapshot = dict(context)
|
||||
lock = threading.Lock()
|
||||
|
||||
def _run_threaded_task(name: str) -> TaskResult[Any]:
|
||||
def _run_threaded_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot, report)
|
||||
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
if sem is not None:
|
||||
sem.acquire()
|
||||
try:
|
||||
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
return task_ctx, SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||
finally:
|
||||
if sem is not None:
|
||||
sem.release()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
|
||||
future_to_name: dict[concurrent.futures.Future[tuple[dict[str, Any], TaskResult[Any]]], str] = {
|
||||
pool.submit(_run_threaded_task, name): name for name in to_run
|
||||
}
|
||||
completed: dict[str, TaskResult[Any]] = {}
|
||||
completed: dict[str, tuple[dict[str, Any], TaskResult[Any]]] = {}
|
||||
try:
|
||||
for fut in concurrent.futures.as_completed(future_to_name):
|
||||
name = future_to_name[fut]
|
||||
completed[name] = fut.result()
|
||||
finally:
|
||||
with lock:
|
||||
for name, result in completed.items():
|
||||
ThreadedLayerRunner._store_result(
|
||||
name, result, graph, context, report, backend, on_event, context_snapshot
|
||||
for name, (task_ctx, result) in completed.items():
|
||||
_store_result(
|
||||
name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event
|
||||
)
|
||||
|
||||
|
||||
class AsyncLayerRunner(_LayerMixin):
|
||||
class AsyncLayerRunner:
|
||||
"""在事件循环上并发运行某层的任务。"""
|
||||
|
||||
@staticmethod
|
||||
@@ -547,27 +544,32 @@ class AsyncLayerRunner(_LayerMixin):
|
||||
on_event: EventCallback | None,
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||
if not to_run:
|
||||
return
|
||||
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||
context_snapshot = dict(context)
|
||||
|
||||
async def _run_async_task(name: str) -> TaskResult[Any]:
|
||||
async def _run_async_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||
spec = graph.resolved_spec(name)
|
||||
task_ctx = _build_context(spec, context_snapshot, report)
|
||||
sem = AsyncLayerRunner._get_sem(semaphores, spec)
|
||||
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
result = await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||
return task_ctx, result
|
||||
|
||||
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
|
||||
for name, result in zip(to_run, results):
|
||||
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
|
||||
for name, (task_ctx, result) in zip(to_run, results):
|
||||
_store_result(name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event)
|
||||
|
||||
|
||||
class DependencyRunner(_LayerMixin):
|
||||
class DependencyRunner:
|
||||
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
|
||||
|
||||
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
||||
|
||||
本类不继承层 Mixin:依赖驱动调度不是层模型,直接调用模块级共享辅助
|
||||
函数(:func:`_build_semaphores` / :func:`_get_sem` / :func:`_store_result`),
|
||||
职责更清晰。
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@@ -580,7 +582,7 @@ class DependencyRunner(_LayerMixin):
|
||||
concurrency_limits: Mapping[str, int],
|
||||
) -> None:
|
||||
all_names = list(graph.all_specs().keys())
|
||||
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||
semaphores = _build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
||||
|
||||
async def _run_task(name: str) -> TaskResult[Any]:
|
||||
@@ -598,9 +600,9 @@ class DependencyRunner(_LayerMixin):
|
||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
||||
return report.results[name]
|
||||
|
||||
sem = DependencyRunner._get_sem(semaphores, spec)
|
||||
sem = _get_sem(semaphores, spec)
|
||||
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
|
||||
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
|
||||
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||
return result
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -677,10 +679,8 @@ def run(
|
||||
TaskFailedError
|
||||
任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
|
||||
"""
|
||||
graph.validate()
|
||||
layers = graph.layers()
|
||||
|
||||
if dry_run:
|
||||
layers = graph.layers()
|
||||
_print_dry_run(graph, layers)
|
||||
return RunReport(success=True)
|
||||
|
||||
@@ -690,14 +690,22 @@ def run(
|
||||
context: dict[str, Any] = {}
|
||||
limits = concurrency_limits or {}
|
||||
|
||||
# backend.batch():将每任务一次落盘降为整次运行一次(JSONBackend);
|
||||
# MemoryBackend 为 no-op。即使中途抛出 TaskFailedError,batch 退出时
|
||||
# 仍会 flush 一次,保留已成功任务的结果以便断点续跑。
|
||||
with backend.batch():
|
||||
try:
|
||||
if strategy == "sequential":
|
||||
layers = graph.layers()
|
||||
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
||||
elif strategy == "thread":
|
||||
layers = graph.layers()
|
||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
||||
elif strategy == "async":
|
||||
layers = graph.layers()
|
||||
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
||||
elif strategy == "dependency":
|
||||
graph.validate()
|
||||
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
||||
else:
|
||||
raise ValueError(f"Unknown strategy: {strategy!r}")
|
||||
|
||||
+13
-102
@@ -82,6 +82,10 @@ class Graph:
|
||||
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
|
||||
_pending_refs: list[str] = field(default_factory=list)
|
||||
|
||||
# resolved_spec 缓存:避免执行期每个任务多次重复 dataclasses.replace 判断。
|
||||
# 在 specs / defaults 变更时失效。
|
||||
_resolved_cache: dict[str, TaskSpec[Any]] = field(default_factory=dict)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -97,6 +101,7 @@ class Graph:
|
||||
self.specs[spec.name] = spec
|
||||
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
|
||||
self.deps[spec.name] = spec.depends_on
|
||||
self._resolved_cache.clear()
|
||||
|
||||
@classmethod
|
||||
def from_specs(
|
||||
@@ -175,7 +180,12 @@ class Graph:
|
||||
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
|
||||
字段,若 spec 字段为默认空值且图级默认值非空,则用
|
||||
:func:`dataclasses.replace` 生成带默认值的副本。
|
||||
|
||||
结果按 ``name`` 缓存;specs / defaults 变更时缓存失效。
|
||||
"""
|
||||
cached = self._resolved_cache.get(name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
spec = self.specs[name]
|
||||
d = self.defaults
|
||||
overrides: dict[str, Any] = {}
|
||||
@@ -199,9 +209,9 @@ class Graph:
|
||||
overrides["verbose"] = True
|
||||
if not spec.tags and d.tags:
|
||||
overrides["tags"] = d.tags
|
||||
if not overrides:
|
||||
return spec
|
||||
return replace(spec, **overrides)
|
||||
resolved = spec if not overrides else replace(spec, **overrides)
|
||||
self._resolved_cache[name] = resolved
|
||||
return resolved
|
||||
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接硬依赖前驱。"""
|
||||
@@ -355,102 +365,3 @@ class Graph:
|
||||
|
||||
def __contains__(self, name: Any) -> bool:
|
||||
return name in self.specs
|
||||
|
||||
|
||||
class GraphComposer:
|
||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
||||
|
||||
引用格式:
|
||||
* ``"command_name"`` —— 引用整个命令图。
|
||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
||||
|
||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
||||
self.graphs = graphs
|
||||
|
||||
def resolve_all(self) -> dict[str, Graph]:
|
||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
||||
resolved: dict[str, Graph] = {}
|
||||
for cmd_name, graph in self.graphs.items():
|
||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
||||
return resolved
|
||||
|
||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
||||
pending_refs = graph._pending_refs
|
||||
if not pending_refs:
|
||||
return graph
|
||||
|
||||
all_specs: list[TaskSpec[Any]] = []
|
||||
previous_ref_last_task: str | None = None
|
||||
|
||||
for ref in pending_refs:
|
||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
||||
if previous_ref_last_task and expanded_specs:
|
||||
for i, task in enumerate(expanded_specs):
|
||||
if i == 0 or not task.depends_on:
|
||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
||||
if expanded_specs:
|
||||
previous_ref_last_task = expanded_specs[-1].name
|
||||
all_specs.extend(expanded_specs)
|
||||
|
||||
original_specs = list(graph.all_specs().values())
|
||||
if original_specs:
|
||||
if previous_ref_last_task:
|
||||
first = original_specs[0]
|
||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
||||
else:
|
||||
all_specs.append(original_specs[0])
|
||||
for i in range(1, len(original_specs)):
|
||||
current_task = original_specs[i]
|
||||
previous_task_name = original_specs[i - 1].name
|
||||
all_specs.append(
|
||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
||||
)
|
||||
|
||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
||||
|
||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
||||
if ref == current_cmd:
|
||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
||||
|
||||
if "." in ref:
|
||||
cmd_name, task_name = ref.split(".", 1)
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
if task_name not in ref_graph.all_specs():
|
||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
||||
return [ref_graph.all_specs()[task_name]]
|
||||
else:
|
||||
cmd_name = ref
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
||||
return list(ref_graph.all_specs().values())
|
||||
|
||||
|
||||
def compose(
|
||||
graphs: dict[str, Graph],
|
||||
) -> dict[str, Graph]:
|
||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
||||
|
||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graphs = {
|
||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
... }
|
||||
>>> resolved = px.compose(graphs)
|
||||
>>> "b" in resolved["all"].all_specs()
|
||||
True
|
||||
"""
|
||||
return GraphComposer(graphs).resolve_all()
|
||||
|
||||
@@ -17,9 +17,10 @@ import sys
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any, Sequence, get_args
|
||||
|
||||
from .compose import GraphComposer
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph, GraphComposer
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["CliExitCode", "CliRunner"]
|
||||
|
||||
+38
-1
@@ -18,8 +18,9 @@ import sys
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
from typing import Any, ContextManager, Mapping
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -55,6 +56,22 @@ class StateBackend(ABC):
|
||||
def clear(self) -> None:
|
||||
"""清除所有存储状态。"""
|
||||
|
||||
def flush(self) -> None: # noqa: B027
|
||||
"""将内存中暂存的状态持久化到外部介质。
|
||||
|
||||
默认无操作(如 :class:`MemoryBackend` 无需落盘)。
|
||||
:class:`JSONBackend` 在 :meth:`batch` 期间会延迟落盘,需在退出时调用。
|
||||
"""
|
||||
|
||||
def batch(self) -> ContextManager[None]:
|
||||
"""返回一个上下文管理器,期间 :meth:`save` 可延迟 :meth:`flush`。
|
||||
|
||||
默认实现为 no-op(如 :class:`MemoryBackend`)。:class:`JSONBackend`
|
||||
覆盖为:进入时标记延迟,退出时统一 flush 一次,将每任务一次落盘
|
||||
(N 次写入)降为整次运行一次(O(N) 而非 O(N²))。
|
||||
"""
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class _TTLStateBackendMixin(StateBackend):
|
||||
"""TTL 状态后端共享逻辑。
|
||||
@@ -184,6 +201,7 @@ class JSONBackend(_TTLStateBackendMixin):
|
||||
self._path: str = path
|
||||
self._ttl = ttl
|
||||
self._store: dict[str, dict[str, Any]] = {}
|
||||
self._defer_flush: bool = False
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
@@ -244,6 +262,25 @@ class JSONBackend(_TTLStateBackendMixin):
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
||||
super().save(key, value)
|
||||
if not self._defer_flush:
|
||||
self._flush()
|
||||
|
||||
@override
|
||||
def flush(self) -> None:
|
||||
self._flush()
|
||||
|
||||
@override
|
||||
@contextmanager
|
||||
def batch(self) -> Iterator[None]:
|
||||
"""进入批量模式:``save`` 暂不落盘,退出时统一 flush 一次。
|
||||
|
||||
将整次运行 N 个任务的 N 次全量落盘降为 1 次。
|
||||
"""
|
||||
self._defer_flush = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._defer_flush = False
|
||||
self._flush()
|
||||
|
||||
def _expired(self, entry: Mapping[str, Any]) -> bool:
|
||||
|
||||
+32
-76
@@ -19,12 +19,13 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -291,13 +292,16 @@ class TaskSpec(Generic[T]):
|
||||
if self.fn is None and self.cmd is None:
|
||||
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def effective_fn(self) -> TaskFn[T]:
|
||||
"""获取有效的执行函数。
|
||||
|
||||
若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。
|
||||
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
|
||||
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
|
||||
|
||||
结果按实例缓存(:func:`functools.cached_property`):frozen dataclass
|
||||
字段不可变,``_wrap_cmd`` 生成的闭包稳定,无需每次访问重建。
|
||||
"""
|
||||
if self.cmd is not None:
|
||||
return self._wrap_cmd()
|
||||
@@ -306,11 +310,17 @@ class TaskSpec(Generic[T]):
|
||||
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
|
||||
|
||||
def _wrap_cmd(self) -> TaskFn[Any]:
|
||||
"""将 cmd 包装为可执行函数。"""
|
||||
"""将 cmd 包装为可执行函数。
|
||||
|
||||
实际执行逻辑位于 :mod:`pyflowx.command`,避免 :class:`TaskSpec`
|
||||
作为纯数据结构混入命令执行逻辑。
|
||||
"""
|
||||
from .command import run_command
|
||||
|
||||
spec = self
|
||||
|
||||
def _run() -> T:
|
||||
return cast(T, _run_command(spec))
|
||||
return cast(T, run_command(spec))
|
||||
|
||||
_run.__name__ = spec.name
|
||||
return _run # type: ignore[return-value]
|
||||
@@ -376,12 +386,29 @@ class TaskSpec(Generic[T]):
|
||||
return self.name
|
||||
|
||||
|
||||
# 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。
|
||||
# ``fn`` 任务在 thread/async 策略下并发执行时,若各自配置了不同的
|
||||
# ``cwd``/``env``,会相互覆盖(os.chdir 与 os.environ 均为进程全局)。
|
||||
# 该锁仅包裹"切换→执行→恢复"区间,保证正确性;不使用 cwd/env 的任务不受影响。
|
||||
_env_cwd_lock = threading.RLock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _env_and_cwd(
|
||||
env: Mapping[str, str] | None,
|
||||
cwd: Path | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""临时设置环境变量与工作目录。"""
|
||||
"""临时设置环境变量与工作目录。
|
||||
|
||||
``os.environ`` 与 ``os.chdir`` 是进程级全局状态,在 thread/async 策略下
|
||||
并发执行多个带 ``env``/``cwd`` 的 ``fn`` 任务时会相互覆盖。本函数通过
|
||||
模块级 :data:`_env_cwd_lock` 串行化"切换→执行→恢复"区间,确保正确性。
|
||||
无 ``env`` 且无 ``cwd`` 时直接 yield,不获取锁。
|
||||
"""
|
||||
if not env and cwd is None:
|
||||
yield
|
||||
return
|
||||
with _env_cwd_lock:
|
||||
saved_env: dict[str, str] = {}
|
||||
saved_cwd: str | None = None
|
||||
if env:
|
||||
@@ -406,77 +433,6 @@ def _env_and_cwd(
|
||||
os.environ.pop(k, None)
|
||||
|
||||
|
||||
def _run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。"""
|
||||
cmd = spec.cmd
|
||||
verbose = spec.verbose
|
||||
cwd = spec.cwd
|
||||
timeout = spec.timeout
|
||||
env_override = spec.env
|
||||
|
||||
# 可调用对象:直接调用,返回其结果。
|
||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
||||
name = getattr(cmd, "__name__", "callable")
|
||||
if verbose:
|
||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
try:
|
||||
return cmd()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
||||
|
||||
is_list = isinstance(cmd, list)
|
||||
if is_list:
|
||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
||||
verb = "执行命令"
|
||||
label = "命令"
|
||||
else:
|
||||
cmd_str = cast(str, cmd)
|
||||
verb = "执行 Shell"
|
||||
label = "Shell 命令"
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
|
||||
# 合并环境变量
|
||||
run_env: dict[str, str] | None = None
|
||||
if env_override:
|
||||
run_env = dict(os.environ)
|
||||
run_env.update(env_override)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cast(Union[str, List[str]], cmd),
|
||||
shell=not is_list,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
timeout=timeout,
|
||||
capture_output=not verbose,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None
|
||||
|
||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
err_msg += f"\n{result.stderr.strip()}"
|
||||
raise RuntimeError(err_msg)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 任务模板:批量生成相似 TaskSpec 的工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
"""常用工具函数."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = ["perf_timer"]
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, TypedDict
|
||||
|
||||
try:
|
||||
from typing_extensions import ParamSpec, TypeVar
|
||||
except ImportError:
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class _PerformanceMetrics(TypedDict):
|
||||
"""性能指标."""
|
||||
|
||||
count: int
|
||||
total_time: float
|
||||
|
||||
|
||||
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
|
||||
lambda: _PerformanceMetrics(
|
||||
count=0,
|
||||
total_time=0.0,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _generate_report(unit: str, precision: int) -> str:
|
||||
"""生成性能指标报告,返回报告字符串."""
|
||||
if not _perf_metrics:
|
||||
return ""
|
||||
|
||||
lines: list[str] = []
|
||||
lines.append("=" * 50)
|
||||
lines.append("性能指标报告 (Performance Metrics Report)")
|
||||
lines.append("-" * 50)
|
||||
|
||||
# 按总耗时排序,最耗时的函数排在前面
|
||||
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
|
||||
|
||||
for name, metrics in sorted_metrics:
|
||||
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
|
||||
lines.append(
|
||||
f"{name}: "
|
||||
f"调用次数={metrics['count']}, "
|
||||
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
|
||||
f"平均耗时={avg_time:.{precision}f}{unit}"
|
||||
)
|
||||
|
||||
lines.append("=" * 50)
|
||||
report_str = "\n".join(lines)
|
||||
|
||||
# 同时输出到日志
|
||||
logging.info("\n".join(lines))
|
||||
|
||||
return report_str
|
||||
|
||||
|
||||
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
|
||||
"""性能计时器装饰器."""
|
||||
scale: dict[str, float] = {
|
||||
"s": 1.0,
|
||||
"ms": 1000.0,
|
||||
"us": 1000000.0,
|
||||
}
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
|
||||
_perf_metrics[func.__name__]["count"] += 1
|
||||
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
|
||||
|
||||
if not report:
|
||||
logging.info(
|
||||
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
|
||||
)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
if report:
|
||||
import atexit
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
|
||||
|
||||
@atexit.register
|
||||
def _report_at_exit() -> None:
|
||||
"""在程序退出时报告性能指标."""
|
||||
_generate_report(unit, precision)
|
||||
|
||||
# 将报告生成逻辑提取为独立函数,便于测试
|
||||
|
||||
return decorator
|
||||
+1
-1
@@ -5,8 +5,8 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.compose import GraphComposer, compose
|
||||
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
from pyflowx.graph import GraphComposer, compose
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
|
||||
+7
-7
@@ -345,14 +345,14 @@ def test_task_result_default_status() -> None:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _run_command callable 命令测试
|
||||
# run_command callable 命令测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
||||
"""callable 命令 verbose 模式应打印信息."""
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
import pyflowx.task as task_module
|
||||
from pyflowx.command import run_command
|
||||
|
||||
result = task_module._run_command(spec)
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
result = run_command(spec)
|
||||
assert result == "result"
|
||||
captured = capsys.readouterr()
|
||||
assert "执行可调用命令" in captured.out
|
||||
@@ -361,8 +361,8 @@ def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str
|
||||
|
||||
def test_run_command_callable_exception() -> None:
|
||||
"""callable 命令抛异常应转为 RuntimeError."""
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
import pyflowx.task as task_module
|
||||
from pyflowx.command import run_command
|
||||
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
||||
task_module._run_command(spec)
|
||||
run_command(spec)
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from pyflowx.utils import _perf_metrics, perf_timer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_perf_metrics():
|
||||
"""重置性能指标."""
|
||||
_perf_metrics.clear()
|
||||
|
||||
|
||||
class TestPerformanceTimer:
|
||||
def test_perf_timer(self):
|
||||
|
||||
@perf_timer()
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
test_func()
|
||||
|
||||
assert _perf_metrics["test_func"] is not None
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
|
||||
def test_perf_timer_report(self, mocker: MockerFixture):
|
||||
mock_log = mocker.patch("logging.info")
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
test_func()
|
||||
|
||||
assert _perf_metrics["test_func"] is not None
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
|
||||
assert mock_log.call_count == 1
|
||||
|
||||
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
||||
mock_log = mocker.patch("logging.info")
|
||||
|
||||
from pyflowx.utils import _generate_report
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func():
|
||||
time.sleep(0.1)
|
||||
|
||||
@perf_timer(report=True, unit="ms", precision=3)
|
||||
def test_func2():
|
||||
time.sleep(0.2)
|
||||
|
||||
test_func()
|
||||
test_func2()
|
||||
|
||||
_generate_report("ms", 3)
|
||||
|
||||
assert mock_log.call_count == 3
|
||||
assert _perf_metrics["test_func"]["count"] == 1
|
||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
||||
assert _perf_metrics["test_func2"]["count"] == 1
|
||||
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
|
||||
Reference in New Issue
Block a user