refactor: 重构项目代码结构,拆分职责模块
1. 抽离图组合逻辑到pyflowx.compose,原graph.py仅保留单图DAG逻辑 2. 抽离命令执行逻辑到pyflowx.command,移除task.py内的_run_command 3. 重构上下文签名缓存,优化性能 4. 移除废弃的utils.perf_timer相关代码 5. 为JSONBackend添加batch批量落盘优化 6. 调整导入路径与公开API,更新测试用例 7. 简化条件判断逻辑,移除冗余代码
This commit is contained in:
@@ -58,6 +58,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from .command import run_command
|
||||||
|
from .compose import GraphComposer, compose
|
||||||
from .conditions import (
|
from .conditions import (
|
||||||
IS_LINUX,
|
IS_LINUX,
|
||||||
IS_MACOS,
|
IS_MACOS,
|
||||||
@@ -79,7 +81,7 @@ from .errors import (
|
|||||||
TaskTimeoutError,
|
TaskTimeoutError,
|
||||||
)
|
)
|
||||||
from .executors import Strategy, run
|
from .executors import Strategy, run
|
||||||
from .graph import Graph, GraphComposer, GraphDefaults, compose
|
from .graph import Graph, GraphDefaults
|
||||||
from .report import RunReport
|
from .report import RunReport
|
||||||
from .runner import CliExitCode, CliRunner
|
from .runner import CliExitCode, CliRunner
|
||||||
from .storage import JSONBackend, MemoryBackend, StateBackend
|
from .storage import JSONBackend, MemoryBackend, StateBackend
|
||||||
@@ -136,5 +138,6 @@ __all__ = [
|
|||||||
"compose",
|
"compose",
|
||||||
"describe_injection",
|
"describe_injection",
|
||||||
"run",
|
"run",
|
||||||
|
"run_command",
|
||||||
"task_template",
|
"task_template",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,98 @@
|
|||||||
|
"""命令执行器:把 :class:`~pyflowx.task.TaskSpec` 的 ``cmd`` 字段(list /
|
||||||
|
shell 字符串 / 可调用对象)转换为统一执行入口。
|
||||||
|
|
||||||
|
历史背景:原 ``task.py`` 的模块文档声明其为"纯数据结构",但 ``_run_command``
|
||||||
|
属于命令执行逻辑,违反单一职责。此处将其抽离,``TaskSpec`` 仅持有配置,
|
||||||
|
执行逻辑集中于本模块,便于独立测试与维护。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
from typing import Any, List, Union, cast
|
||||||
|
|
||||||
|
from .task import TaskSpec
|
||||||
|
|
||||||
|
__all__ = ["run_command"]
|
||||||
|
|
||||||
|
|
||||||
|
def run_command(spec: TaskSpec[Any]) -> Any:  # noqa: PLR0912
    """Execute the command described by ``spec.cmd``.

    ``spec.cmd`` is dispatched on its type:

    - Callable: invoked with no arguments and its return value is returned.
      ``spec.cwd`` is only reported in verbose output on this path; it is
      not applied before the call.
    - ``list``: run through :func:`subprocess.run` without a shell.
    - Anything else (expected to be a shell string): run through
      :func:`subprocess.run` with ``shell=True``.

    For the subprocess paths, ``spec.cwd``, ``spec.timeout`` and ``spec.env``
    are handed to :func:`subprocess.run` as arguments. When ``spec.env`` is
    non-empty it is layered on top of a copy of ``os.environ``; this function
    itself never calls ``os.chdir`` or mutates ``os.environ``.

    When ``spec.verbose`` is true the command, working directory and return
    code are printed to stdout and the child's output is not captured;
    otherwise the output is captured and stderr is appended to the error
    message on failure.

    Returns
    -------
    The callable's return value for callable commands, ``None`` for a
    subprocess that exits with return code 0.

    Raises
    ------
    RuntimeError
        The callable raised an ``Exception``; the executable was not found;
        the command timed out; another ``OSError`` occurred while launching
        it; or the process exited with a non-zero return code.
    """
    cmd = spec.cmd
    verbose = spec.verbose
    cwd = spec.cwd
    timeout = spec.timeout
    env_override = spec.env

    # Callable command: call it directly and hand back whatever it returns.
    if callable(cmd) and not isinstance(cmd, (list, str)):
        name = getattr(cmd, "__name__", "callable")
        if verbose:
            print(f"[verbose] 执行可调用命令: {name}", flush=True)
            if cwd is not None:
                print(f"[verbose] 工作目录: {cwd}", flush=True)
        try:
            return cmd()
        except Exception as e:
            raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e

    is_list = isinstance(cmd, list)
    if is_list:
        # The joined string is only a label for logs and error messages.
        # Stringify every item: subprocess accepts path-like arguments, but
        # str.join raises TypeError on anything that is not a str.
        cmd_str = " ".join(str(arg) for arg in cmd)  # type: ignore[union-attr]
        verb = "执行命令"
        label = "命令"
    else:
        cmd_str = cast(str, cmd)
        verb = "执行 Shell"
        label = "Shell 命令"

    if verbose:
        print(f"[verbose] {verb}: {cmd_str}", flush=True)
        if cwd is not None:
            print(f"[verbose] 工作目录: {cwd}", flush=True)

    # Merge environment overrides; None makes the child inherit the parent
    # environment unchanged.
    run_env: dict[str, str] | None = None
    if env_override:
        run_env = dict(os.environ)
        run_env.update(env_override)

    try:
        result = subprocess.run(
            cast(Union[str, List[str]], cmd),
            shell=not is_list,
            cwd=cwd,
            env=run_env,
            timeout=timeout,
            capture_output=not verbose,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        raise RuntimeError(f"{label}未找到: {cmd_str}") from None
    except subprocess.TimeoutExpired:
        raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
    except OSError as e:
        raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e

    if verbose:
        print(f"[verbose] 返回码: {result.returncode}", flush=True)

    if result.returncode == 0:
        return None

    err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
    if not verbose:
        # stderr is only captured in non-verbose mode; guard against None so
        # building the error message can never mask the real failure.
        stderr_text = (result.stderr or "").strip()
        if stderr_text:
            err_msg += f"\n{stderr_text}"
    raise RuntimeError(err_msg)
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
"""图组合:将带字符串引用的多个图展开为纯 :class:`~pyflowx.graph.Graph`。
|
||||||
|
|
||||||
|
历史背景:原 ``graph.py`` 同时承载 DAG 构建/校验/分层与多图组合逻辑,
|
||||||
|
职责过载。组合逻辑(:class:`GraphComposer` / :func:`compose`)与单图 DAG
|
||||||
|
模型正交,此处抽离为独立模块,便于按需导入与独立演进。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import replace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .graph import Graph
|
||||||
|
from .task import TaskSpec
|
||||||
|
|
||||||
|
__all__ = ["GraphComposer", "compose"]
|
||||||
|
|
||||||
|
|
||||||
|
class GraphComposer:
    """Expand graphs that contain string references into plain :class:`TaskSpec` graphs.

    Reference formats:

    * ``"command_name"`` -- the whole graph registered under that command.
    * ``"command_name.task_name"`` -- a single task of that graph.

    References are expanded in order: tasks of a later reference depend on the
    last task of the preceding reference, and the graph's own ``TaskSpec``
    entries are chained serially in their order of appearance.
    """

    def __init__(self, graphs: dict[str, Graph]) -> None:
        self.graphs = graphs
        # Commands whose references are currently being expanded, outermost
        # first. Used to detect indirect cycles (a -> b -> a).
        self._expanding: list[str] = []

    @staticmethod
    def _with_dependency(depends_on: Any, extra: str) -> tuple[Any, ...]:
        """Return ``depends_on`` plus ``extra``, de-duplicated, original order kept.

        An order-preserving de-duplication is used instead of a ``set`` so the
        resulting tuple is identical from run to run.
        """
        return tuple(dict.fromkeys((*depends_on, extra)))

    def resolve_all(self) -> dict[str, Graph]:
        """Expand the references of every registered graph and return the new mapping."""
        return {cmd_name: self.expand_refs(graph, cmd_name) for cmd_name, graph in self.graphs.items()}

    def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
        """Expand the string references of ``graph``.

        A graph without ``_pending_refs`` is returned unchanged. Otherwise a
        new graph is built from the expanded references followed by the
        graph's own specs, reusing ``graph.defaults``.

        Raises
        ------
        ValueError
            Propagated from :meth:`parse_ref` (unknown command or task,
            direct or indirect circular reference).
        """
        pending_refs = graph._pending_refs
        if not pending_refs:
            return graph

        all_specs: list[TaskSpec[Any]] = []
        previous_ref_last_task: str | None = None

        self._expanding.append(current_cmd)
        try:
            for ref in pending_refs:
                expanded_specs = self.parse_ref(ref, current_cmd)
                if not expanded_specs:
                    continue
                if previous_ref_last_task:
                    # The first task, and any task without dependencies of its
                    # own, must wait for the previous reference to finish.
                    for i, task in enumerate(expanded_specs):
                        if i == 0 or not task.depends_on:
                            expanded_specs[i] = replace(
                                task,
                                depends_on=self._with_dependency(task.depends_on, previous_ref_last_task),
                            )
                previous_ref_last_task = expanded_specs[-1].name
                all_specs.extend(expanded_specs)
        finally:
            # Always unwind, so a failed expansion does not poison later calls.
            self._expanding.pop()

        original_specs = list(graph.all_specs().values())
        for index, spec in enumerate(original_specs):
            # The first own spec follows the last referenced task (if any);
            # every later one follows its predecessor in the list.
            predecessor = original_specs[index - 1].name if index else previous_ref_last_task
            if index or predecessor:
                spec = replace(spec, depends_on=self._with_dependency(spec.depends_on, predecessor))
            all_specs.append(spec)

        return Graph.from_specs(all_specs, defaults=graph.defaults)

    def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
        """Resolve one string reference into the list of ``TaskSpec`` it denotes.

        Raises
        ------
        ValueError
            ``ref`` names ``current_cmd`` itself, names an unknown command or
            task, or closes an indirect reference cycle.
        """
        if ref == current_cmd:
            raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")

        cmd_name, dot, task_name = ref.partition(".")
        if cmd_name not in self.graphs:
            raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
        ref_graph = self.graphs[cmd_name]

        if dot:
            # "command.task": pick a single task; no recursive expansion.
            specs = ref_graph.all_specs()
            if task_name not in specs:
                raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
            return [specs[task_name]]

        # Whole-command reference: expanding a command that is already being
        # expanded further up the stack would recurse forever.
        if cmd_name in self._expanding:
            chain = " -> ".join([*self._expanding, cmd_name])
            raise ValueError(f"循环引用: {chain}")
        return list(self.expand_refs(ref_graph, cmd_name).all_specs().values())
|
||||||
|
|
||||||
|
|
||||||
|
def compose(
    graphs: dict[str, Graph],
) -> dict[str, Graph]:
    """Resolve the string references of several graphs programmatically.

    Builds a :class:`GraphComposer` over ``graphs`` and returns the mapping
    produced by its ``resolve_all()``. It is exposed as a standalone function
    for programmatic users who do not go through
    :class:`~pyflowx.runner.CliRunner`.

    Examples
    --------
    >>> graphs = {
    ...     "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
    ...     "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
    ... }
    >>> resolved = px.compose(graphs)
    >>> "b" in resolved["all"].all_specs()
    True
    """
    composer = GraphComposer(graphs)
    return composer.resolve_all()
|
||||||
@@ -42,14 +42,6 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
|
|||||||
return _cond
|
return _cond
|
||||||
|
|
||||||
|
|
||||||
def _cond_reason(cond: Condition) -> str | list[str] | None:
|
|
||||||
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
|
|
||||||
reason = getattr(cond, "_reason", None)
|
|
||||||
if reason is not None:
|
|
||||||
return reason
|
|
||||||
return getattr(cond, "__name__", repr(cond))
|
|
||||||
|
|
||||||
|
|
||||||
def _cond_name(cond: Condition) -> str:
|
def _cond_name(cond: Condition) -> str:
|
||||||
"""获取条件的可读名称。"""
|
"""获取条件的可读名称。"""
|
||||||
return getattr(cond, "__name__", repr(cond))
|
return getattr(cond, "__name__", repr(cond))
|
||||||
@@ -228,13 +220,7 @@ class BuiltinConditions:
|
|||||||
"""对条件取反."""
|
"""对条件取反."""
|
||||||
|
|
||||||
def _cond(ctx: Context) -> bool:
|
def _cond(ctx: Context) -> bool:
|
||||||
result = condition(ctx)
|
return not condition(ctx)
|
||||||
if result:
|
|
||||||
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
|
|
||||||
inner_reason = _cond_reason(condition)
|
|
||||||
if inner_reason is not None:
|
|
||||||
_cond._reason = inner_reason # type: ignore[attr-defined]
|
|
||||||
return not result
|
|
||||||
|
|
||||||
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
||||||
return _cond
|
return _cond
|
||||||
@@ -254,15 +240,7 @@ class BuiltinConditions:
|
|||||||
"""多个条件的逻辑或."""
|
"""多个条件的逻辑或."""
|
||||||
|
|
||||||
def _cond(ctx: Context) -> bool:
|
def _cond(ctx: Context) -> bool:
|
||||||
matched: list[str] = []
|
return any(c(ctx) for c in conditions)
|
||||||
for c in conditions:
|
|
||||||
if c(ctx):
|
|
||||||
reason = _cond_reason(c)
|
|
||||||
matched.append(reason if isinstance(reason, str) else str(reason))
|
|
||||||
if matched:
|
|
||||||
_cond._reason = matched # type: ignore[attr-defined]
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
||||||
return _cond
|
return _cond
|
||||||
|
|||||||
+21
-2
@@ -16,6 +16,7 @@ DAG 库中泛滥的样板包装器。
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
|
from functools import lru_cache
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
|
|
||||||
from .errors import InjectionError
|
from .errors import InjectionError
|
||||||
@@ -24,6 +25,24 @@ from .task import Context, TaskSpec
|
|||||||
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1024)
def _cached_signature(fn: Any) -> inspect.Signature:
    """Return ``inspect.signature(fn)``, memoized per ``fn`` object.

    The cache is bounded to 1024 entries. ``lru_cache`` needs hashable
    arguments, so an unhashable callable raises ``TypeError`` here; callers
    that may receive such objects have to introspect them directly.
    """
    signature = inspect.signature(fn)
    return signature
|
||||||
|
|
||||||
|
|
||||||
|
def _signature(fn: Any) -> inspect.Signature:
    """Return the signature of ``fn``, going through the cache when possible.

    If the cached lookup raises ``TypeError`` (for example because ``fn`` is
    not hashable), the signature is introspected directly instead.
    """
    try:
        resolved = _cached_signature(fn)
    except TypeError:
        resolved = inspect.signature(fn)
    return resolved
|
||||||
|
|
||||||
|
|
||||||
def _is_context_annotation(annotation: Any) -> bool:
|
def _is_context_annotation(annotation: Any) -> bool:
|
||||||
"""判断参数标注是否为(或指向)``Context``。"""
|
"""判断参数标注是否为(或指向)``Context``。"""
|
||||||
if annotation is Context:
|
if annotation is Context:
|
||||||
@@ -44,7 +63,7 @@ def build_call_args(
|
|||||||
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
|
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
|
||||||
"""
|
"""
|
||||||
fn = spec.effective_fn
|
fn = spec.effective_fn
|
||||||
sig = inspect.signature(fn)
|
sig = _signature(fn)
|
||||||
params = sig.parameters
|
params = sig.parameters
|
||||||
|
|
||||||
var_keyword = next(
|
var_keyword = next(
|
||||||
@@ -115,7 +134,7 @@ def build_call_args(
|
|||||||
def describe_injection(spec: TaskSpec[Any]) -> str:
|
def describe_injection(spec: TaskSpec[Any]) -> str:
|
||||||
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
|
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
|
||||||
fn = spec.effective_fn
|
fn = spec.effective_fn
|
||||||
sig = inspect.signature(fn)
|
sig = _signature(fn)
|
||||||
positional_params = [
|
positional_params = [
|
||||||
p
|
p
|
||||||
for p, param in sig.parameters.items()
|
for p, param in sig.parameters.items()
|
||||||
|
|||||||
+116
-108
@@ -12,14 +12,17 @@
|
|||||||
|
|
||||||
架构
|
架构
|
||||||
----
|
----
|
||||||
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
|
本模块通过 **Mixin** 组合消除同步/异步任务执行器之间的重复代码:
|
||||||
|
|
||||||
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
|
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
|
||||||
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
|
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
|
||||||
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
|
|
||||||
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
|
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
|
||||||
|
* 模块级共享辅助(:func:`_filter_and_sort` / :func:`_store_result` /
|
||||||
|
:func:`_build_semaphores` / :func:`_get_sem`)—— 缓存过滤、优先级排序、
|
||||||
|
信号量构建、结果存储。
|
||||||
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
|
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
|
||||||
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
|
:class:`AsyncLayerRunner` —— 层级执行器,调用上述模块级辅助。
|
||||||
|
* :class:`DependencyRunner` —— 依赖驱动调度(非层模型),同样调用模块级辅助。
|
||||||
|
|
||||||
所有策略共享统一异步内核,支持:
|
所有策略共享统一异步内核,支持:
|
||||||
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
* :class:`RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)
|
||||||
@@ -388,81 +391,75 @@ async def _execute_async_task(
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# Mixin:层执行共享逻辑
|
# 共享辅助:缓存过滤、优先级排序、信号量构建、结果存储
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
class _LayerMixin:
|
def _filter_and_sort(
|
||||||
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
|
layer: list[str],
|
||||||
|
graph: Graph,
|
||||||
|
context: dict[str, Any],
|
||||||
|
report: RunReport,
|
||||||
|
backend: StateBackend,
|
||||||
|
on_event: EventCallback | None,
|
||||||
|
) -> list[str]:
|
||||||
|
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
|
||||||
|
to_run: list[str] = []
|
||||||
|
for name in layer:
|
||||||
|
spec = graph.resolved_spec(name)
|
||||||
|
if not _apply_cached(name, spec, context, report, backend, on_event):
|
||||||
|
to_run.append(name)
|
||||||
|
return _sort_by_priority(to_run, graph)
|
||||||
|
|
||||||
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
|
|
||||||
消除"过滤缓存→排序→运行→存结果"的样板代码。
|
def _store_result(
|
||||||
|
name: str,
|
||||||
|
result: TaskResult[Any],
|
||||||
|
spec: TaskSpec[Any],
|
||||||
|
task_ctx: dict[str, Any],
|
||||||
|
context: dict[str, Any],
|
||||||
|
report: RunReport,
|
||||||
|
backend: StateBackend,
|
||||||
|
on_event: EventCallback | None,
|
||||||
|
) -> None:
|
||||||
|
"""存储任务结果到 context/report/backend 并触发事件。
|
||||||
|
|
||||||
|
``spec`` 与 ``task_ctx`` 由调用方在执行前已构建,直接复用避免重复
|
||||||
|
``resolved_spec`` / ``_build_context`` 调用。
|
||||||
"""
|
"""
|
||||||
|
context[name] = result.value
|
||||||
|
if result.status == TaskStatus.SUCCESS:
|
||||||
|
backend.save(spec.storage_key(task_ctx), result.value)
|
||||||
|
report.results[name] = result
|
||||||
|
_emit(on_event, result)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _filter_and_sort(
|
|
||||||
layer: list[str],
|
|
||||||
graph: Graph,
|
|
||||||
context: dict[str, Any],
|
|
||||||
report: RunReport,
|
|
||||||
backend: StateBackend,
|
|
||||||
on_event: EventCallback | None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
|
|
||||||
to_run: list[str] = []
|
|
||||||
for name in layer:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
if not _apply_cached(name, spec, context, report, backend, on_event):
|
|
||||||
to_run.append(name)
|
|
||||||
return _sort_by_priority(to_run, graph)
|
|
||||||
|
|
||||||
@staticmethod
|
def _build_semaphores(
|
||||||
def _store_result(
|
to_run: list[str],
|
||||||
name: str,
|
graph: Graph,
|
||||||
result: TaskResult[Any],
|
sem_factory: Callable[[int], Any],
|
||||||
graph: Graph,
|
concurrency_limits: Mapping[str, int],
|
||||||
context: dict[str, Any],
|
) -> dict[str, Any]:
|
||||||
report: RunReport,
|
"""为每个 ``concurrency_key`` 创建一个信号量。"""
|
||||||
backend: StateBackend,
|
semaphores: dict[str, Any] = {}
|
||||||
on_event: EventCallback | None,
|
for name in to_run:
|
||||||
context_snapshot: Mapping[str, Any] | None = None,
|
spec = graph.resolved_spec(name)
|
||||||
) -> None:
|
key = spec.concurrency_key
|
||||||
"""存储任务结果到 context/report/backend 并触发事件。"""
|
if key is not None and key not in semaphores:
|
||||||
context[name] = result.value
|
limit = concurrency_limits.get(key, 1)
|
||||||
if result.status == TaskStatus.SUCCESS:
|
semaphores[key] = sem_factory(limit)
|
||||||
spec = graph.resolved_spec(name)
|
return semaphores
|
||||||
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
|
|
||||||
backend.save(spec.storage_key(task_ctx), result.value)
|
|
||||||
report.results[name] = result
|
|
||||||
_emit(on_event, result)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _build_semaphores(
|
|
||||||
to_run: list[str],
|
|
||||||
graph: Graph,
|
|
||||||
sem_factory: Callable[[int], Any],
|
|
||||||
concurrency_limits: Mapping[str, int],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""为每个 ``concurrency_key`` 创建一个信号量。"""
|
|
||||||
semaphores: dict[str, Any] = {}
|
|
||||||
for name in to_run:
|
|
||||||
spec = graph.resolved_spec(name)
|
|
||||||
key = spec.concurrency_key
|
|
||||||
if key is not None and key not in semaphores:
|
|
||||||
limit = concurrency_limits.get(key, 1)
|
|
||||||
semaphores[key] = sem_factory(limit)
|
|
||||||
return semaphores
|
|
||||||
|
|
||||||
@staticmethod
|
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
||||||
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
|
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
||||||
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
|
if spec.concurrency_key is None:
|
||||||
if spec.concurrency_key is None:
|
return None
|
||||||
return None
|
return semaphores.get(spec.concurrency_key)
|
||||||
return semaphores.get(spec.concurrency_key)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# 层执行器
|
# 层执行器
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
class SequentialLayerRunner(_LayerMixin):
|
class SequentialLayerRunner:
|
||||||
"""逐个运行某层的任务(按优先级排序)。"""
|
"""逐个运行某层的任务(按优先级排序)。"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -475,14 +472,14 @@ class SequentialLayerRunner(_LayerMixin):
|
|||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
|
for name in _filter_and_sort(layer, graph, context, report, backend, on_event):
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
task_ctx = _build_context(spec, context, report)
|
task_ctx = _build_context(spec, context, report)
|
||||||
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||||
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
|
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||||
|
|
||||||
|
|
||||||
class ThreadedLayerRunner(_LayerMixin):
|
class ThreadedLayerRunner:
|
||||||
"""在线程池中并发运行某层的任务。"""
|
"""在线程池中并发运行某层的任务。"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -497,43 +494,43 @@ class ThreadedLayerRunner(_LayerMixin):
|
|||||||
max_workers: int,
|
max_workers: int,
|
||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||||
if not to_run:
|
if not to_run:
|
||||||
return
|
return
|
||||||
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
semaphores = _build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
|
||||||
context_snapshot = dict(context)
|
context_snapshot = dict(context)
|
||||||
lock = threading.Lock()
|
lock = threading.Lock()
|
||||||
|
|
||||||
def _run_threaded_task(name: str) -> TaskResult[Any]:
|
def _run_threaded_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
task_ctx = _build_context(spec, context_snapshot, report)
|
||||||
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
|
sem = _get_sem(semaphores, spec)
|
||||||
if sem is not None:
|
if sem is not None:
|
||||||
sem.acquire()
|
sem.acquire()
|
||||||
try:
|
try:
|
||||||
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
return task_ctx, SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
|
||||||
finally:
|
finally:
|
||||||
if sem is not None:
|
if sem is not None:
|
||||||
sem.release()
|
sem.release()
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
|
||||||
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
|
future_to_name: dict[concurrent.futures.Future[tuple[dict[str, Any], TaskResult[Any]]], str] = {
|
||||||
pool.submit(_run_threaded_task, name): name for name in to_run
|
pool.submit(_run_threaded_task, name): name for name in to_run
|
||||||
}
|
}
|
||||||
completed: dict[str, TaskResult[Any]] = {}
|
completed: dict[str, tuple[dict[str, Any], TaskResult[Any]]] = {}
|
||||||
try:
|
try:
|
||||||
for fut in concurrent.futures.as_completed(future_to_name):
|
for fut in concurrent.futures.as_completed(future_to_name):
|
||||||
name = future_to_name[fut]
|
name = future_to_name[fut]
|
||||||
completed[name] = fut.result()
|
completed[name] = fut.result()
|
||||||
finally:
|
finally:
|
||||||
with lock:
|
with lock:
|
||||||
for name, result in completed.items():
|
for name, (task_ctx, result) in completed.items():
|
||||||
ThreadedLayerRunner._store_result(
|
_store_result(
|
||||||
name, result, graph, context, report, backend, on_event, context_snapshot
|
name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AsyncLayerRunner(_LayerMixin):
|
class AsyncLayerRunner:
|
||||||
"""在事件循环上并发运行某层的任务。"""
|
"""在事件循环上并发运行某层的任务。"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -547,27 +544,32 @@ class AsyncLayerRunner(_LayerMixin):
|
|||||||
on_event: EventCallback | None,
|
on_event: EventCallback | None,
|
||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
|
to_run = _filter_and_sort(layer, graph, context, report, backend, on_event)
|
||||||
if not to_run:
|
if not to_run:
|
||||||
return
|
return
|
||||||
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
semaphores = _build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
|
||||||
context_snapshot = dict(context)
|
context_snapshot = dict(context)
|
||||||
|
|
||||||
async def _run_async_task(name: str) -> TaskResult[Any]:
|
async def _run_async_task(name: str) -> tuple[dict[str, Any], TaskResult[Any]]:
|
||||||
spec = graph.resolved_spec(name)
|
spec = graph.resolved_spec(name)
|
||||||
task_ctx = _build_context(spec, context_snapshot, report)
|
task_ctx = _build_context(spec, context_snapshot, report)
|
||||||
sem = AsyncLayerRunner._get_sem(semaphores, spec)
|
sem = _get_sem(semaphores, spec)
|
||||||
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
result = await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
|
||||||
|
return task_ctx, result
|
||||||
|
|
||||||
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
|
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
|
||||||
for name, result in zip(to_run, results):
|
for name, (task_ctx, result) in zip(to_run, results):
|
||||||
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
|
_store_result(name, result, graph.resolved_spec(name), task_ctx, context, report, backend, on_event)
|
||||||
|
|
||||||
|
|
||||||
class DependencyRunner(_LayerMixin):
|
class DependencyRunner:
|
||||||
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
|
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
|
||||||
|
|
||||||
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
|
||||||
|
|
||||||
|
本类不继承层 Mixin:依赖驱动调度不是层模型,直接调用模块级共享辅助
|
||||||
|
函数(:func:`_build_semaphores` / :func:`_get_sem` / :func:`_store_result`),
|
||||||
|
职责更清晰。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -580,7 +582,7 @@ class DependencyRunner(_LayerMixin):
|
|||||||
concurrency_limits: Mapping[str, int],
|
concurrency_limits: Mapping[str, int],
|
||||||
) -> None:
|
) -> None:
|
||||||
all_names = list(graph.all_specs().keys())
|
all_names = list(graph.all_specs().keys())
|
||||||
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
semaphores = _build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
|
||||||
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
|
||||||
|
|
||||||
async def _run_task(name: str) -> TaskResult[Any]:
|
async def _run_task(name: str) -> TaskResult[Any]:
|
||||||
@@ -598,9 +600,9 @@ class DependencyRunner(_LayerMixin):
|
|||||||
if _apply_cached(name, spec, context, report, backend, on_event):
|
if _apply_cached(name, spec, context, report, backend, on_event):
|
||||||
return report.results[name]
|
return report.results[name]
|
||||||
|
|
||||||
sem = DependencyRunner._get_sem(semaphores, spec)
|
sem = _get_sem(semaphores, spec)
|
||||||
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
|
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
|
||||||
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
|
_store_result(name, result, spec, task_ctx, context, report, backend, on_event)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
@@ -677,10 +679,8 @@ def run(
|
|||||||
TaskFailedError
|
TaskFailedError
|
||||||
任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
|
任何任务耗尽重试后仍失败时(除非 ``continue_on_error=True``)。
|
||||||
"""
|
"""
|
||||||
graph.validate()
|
|
||||||
layers = graph.layers()
|
|
||||||
|
|
||||||
if dry_run:
|
if dry_run:
|
||||||
|
layers = graph.layers()
|
||||||
_print_dry_run(graph, layers)
|
_print_dry_run(graph, layers)
|
||||||
return RunReport(success=True)
|
return RunReport(success=True)
|
||||||
|
|
||||||
@@ -690,20 +690,28 @@ def run(
|
|||||||
context: dict[str, Any] = {}
|
context: dict[str, Any] = {}
|
||||||
limits = concurrency_limits or {}
|
limits = concurrency_limits or {}
|
||||||
|
|
||||||
try:
|
# backend.batch():将每任务一次落盘降为整次运行一次(JSONBackend);
|
||||||
if strategy == "sequential":
|
# MemoryBackend 为 no-op。即使中途抛出 TaskFailedError,batch 退出时
|
||||||
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
# 仍会 flush 一次,保留已成功任务的结果以便断点续跑。
|
||||||
elif strategy == "thread":
|
with backend.batch():
|
||||||
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
try:
|
||||||
elif strategy == "async":
|
if strategy == "sequential":
|
||||||
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
layers = graph.layers()
|
||||||
elif strategy == "dependency":
|
_drive_sequential(graph, layers, context, report, backend, effective_callback)
|
||||||
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
elif strategy == "thread":
|
||||||
else:
|
layers = graph.layers()
|
||||||
raise ValueError(f"Unknown strategy: {strategy!r}")
|
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
|
||||||
except TaskFailedError:
|
elif strategy == "async":
|
||||||
report.success = False
|
layers = graph.layers()
|
||||||
raise
|
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
|
||||||
|
elif strategy == "dependency":
|
||||||
|
graph.validate()
|
||||||
|
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown strategy: {strategy!r}")
|
||||||
|
except TaskFailedError:
|
||||||
|
report.success = False
|
||||||
|
raise
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
|
|||||||
+13
-102
@@ -82,6 +82,10 @@ class Graph:
|
|||||||
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
|
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
|
||||||
_pending_refs: list[str] = field(default_factory=list)
|
_pending_refs: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# resolved_spec 缓存:避免执行期每个任务多次重复 dataclasses.replace 判断。
|
||||||
|
# 在 specs / defaults 变更时失效。
|
||||||
|
_resolved_cache: dict[str, TaskSpec[Any]] = field(default_factory=dict)
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# 构建
|
# 构建
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
@@ -97,6 +101,7 @@ class Graph:
|
|||||||
self.specs[spec.name] = spec
|
self.specs[spec.name] = spec
|
||||||
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
|
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
|
||||||
self.deps[spec.name] = spec.depends_on
|
self.deps[spec.name] = spec.depends_on
|
||||||
|
self._resolved_cache.clear()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_specs(
|
def from_specs(
|
||||||
@@ -175,7 +180,12 @@ class Graph:
|
|||||||
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
|
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
|
||||||
字段,若 spec 字段为默认空值且图级默认值非空,则用
|
字段,若 spec 字段为默认空值且图级默认值非空,则用
|
||||||
:func:`dataclasses.replace` 生成带默认值的副本。
|
:func:`dataclasses.replace` 生成带默认值的副本。
|
||||||
|
|
||||||
|
结果按 ``name`` 缓存;specs / defaults 变更时缓存失效。
|
||||||
"""
|
"""
|
||||||
|
cached = self._resolved_cache.get(name)
|
||||||
|
if cached is not None:
|
||||||
|
return cached
|
||||||
spec = self.specs[name]
|
spec = self.specs[name]
|
||||||
d = self.defaults
|
d = self.defaults
|
||||||
overrides: dict[str, Any] = {}
|
overrides: dict[str, Any] = {}
|
||||||
@@ -199,9 +209,9 @@ class Graph:
|
|||||||
overrides["verbose"] = True
|
overrides["verbose"] = True
|
||||||
if not spec.tags and d.tags:
|
if not spec.tags and d.tags:
|
||||||
overrides["tags"] = d.tags
|
overrides["tags"] = d.tags
|
||||||
if not overrides:
|
resolved = spec if not overrides else replace(spec, **overrides)
|
||||||
return spec
|
self._resolved_cache[name] = resolved
|
||||||
return replace(spec, **overrides)
|
return resolved
|
||||||
|
|
||||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||||
"""``name`` 的直接硬依赖前驱。"""
|
"""``name`` 的直接硬依赖前驱。"""
|
||||||
@@ -355,102 +365,3 @@ class Graph:
|
|||||||
|
|
||||||
def __contains__(self, name: Any) -> bool:
|
def __contains__(self, name: Any) -> bool:
|
||||||
return name in self.specs
|
return name in self.specs
|
||||||
|
|
||||||
|
|
||||||
class GraphComposer:
|
|
||||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
|
||||||
|
|
||||||
引用格式:
|
|
||||||
* ``"command_name"`` —— 引用整个命令图。
|
|
||||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
|
||||||
|
|
||||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
|
||||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
|
||||||
self.graphs = graphs
|
|
||||||
|
|
||||||
def resolve_all(self) -> dict[str, Graph]:
|
|
||||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
|
||||||
resolved: dict[str, Graph] = {}
|
|
||||||
for cmd_name, graph in self.graphs.items():
|
|
||||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
|
||||||
return resolved
|
|
||||||
|
|
||||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
|
||||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
|
||||||
pending_refs = graph._pending_refs
|
|
||||||
if not pending_refs:
|
|
||||||
return graph
|
|
||||||
|
|
||||||
all_specs: list[TaskSpec[Any]] = []
|
|
||||||
previous_ref_last_task: str | None = None
|
|
||||||
|
|
||||||
for ref in pending_refs:
|
|
||||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
|
||||||
if previous_ref_last_task and expanded_specs:
|
|
||||||
for i, task in enumerate(expanded_specs):
|
|
||||||
if i == 0 or not task.depends_on:
|
|
||||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
|
||||||
if expanded_specs:
|
|
||||||
previous_ref_last_task = expanded_specs[-1].name
|
|
||||||
all_specs.extend(expanded_specs)
|
|
||||||
|
|
||||||
original_specs = list(graph.all_specs().values())
|
|
||||||
if original_specs:
|
|
||||||
if previous_ref_last_task:
|
|
||||||
first = original_specs[0]
|
|
||||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
|
||||||
else:
|
|
||||||
all_specs.append(original_specs[0])
|
|
||||||
for i in range(1, len(original_specs)):
|
|
||||||
current_task = original_specs[i]
|
|
||||||
previous_task_name = original_specs[i - 1].name
|
|
||||||
all_specs.append(
|
|
||||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
|
||||||
)
|
|
||||||
|
|
||||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
|
||||||
|
|
||||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
|
||||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
|
||||||
if ref == current_cmd:
|
|
||||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
|
||||||
|
|
||||||
if "." in ref:
|
|
||||||
cmd_name, task_name = ref.split(".", 1)
|
|
||||||
if cmd_name not in self.graphs:
|
|
||||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
|
||||||
ref_graph = self.graphs[cmd_name]
|
|
||||||
if task_name not in ref_graph.all_specs():
|
|
||||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
|
||||||
return [ref_graph.all_specs()[task_name]]
|
|
||||||
else:
|
|
||||||
cmd_name = ref
|
|
||||||
if cmd_name not in self.graphs:
|
|
||||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
|
||||||
ref_graph = self.graphs[cmd_name]
|
|
||||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
|
||||||
return list(ref_graph.all_specs().values())
|
|
||||||
|
|
||||||
|
|
||||||
def compose(
|
|
||||||
graphs: dict[str, Graph],
|
|
||||||
) -> dict[str, Graph]:
|
|
||||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
|
||||||
|
|
||||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
|
||||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
|
||||||
|
|
||||||
Examples
|
|
||||||
--------
|
|
||||||
>>> graphs = {
|
|
||||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
|
||||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
|
||||||
... }
|
|
||||||
>>> resolved = px.compose(graphs)
|
|
||||||
>>> "b" in resolved["all"].all_specs()
|
|
||||||
True
|
|
||||||
"""
|
|
||||||
return GraphComposer(graphs).resolve_all()
|
|
||||||
|
|||||||
@@ -17,9 +17,10 @@ import sys
|
|||||||
from dataclasses import dataclass, field, replace
|
from dataclasses import dataclass, field, replace
|
||||||
from typing import Any, Sequence, get_args
|
from typing import Any, Sequence, get_args
|
||||||
|
|
||||||
|
from .compose import GraphComposer
|
||||||
from .errors import PyFlowXError
|
from .errors import PyFlowXError
|
||||||
from .executors import Strategy, run
|
from .executors import Strategy, run
|
||||||
from .graph import Graph, GraphComposer
|
from .graph import Graph
|
||||||
from .task import TaskSpec
|
from .task import TaskSpec
|
||||||
|
|
||||||
__all__ = ["CliExitCode", "CliRunner"]
|
__all__ = ["CliExitCode", "CliRunner"]
|
||||||
|
|||||||
+38
-1
@@ -18,8 +18,9 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager, nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Mapping
|
from typing import Any, ContextManager, Mapping
|
||||||
|
|
||||||
if sys.version_info >= (3, 12):
|
if sys.version_info >= (3, 12):
|
||||||
from typing import override
|
from typing import override
|
||||||
@@ -55,6 +56,22 @@ class StateBackend(ABC):
|
|||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""清除所有存储状态。"""
|
"""清除所有存储状态。"""
|
||||||
|
|
||||||
|
def flush(self) -> None: # noqa: B027
|
||||||
|
"""将内存中暂存的状态持久化到外部介质。
|
||||||
|
|
||||||
|
默认无操作(如 :class:`MemoryBackend` 无需落盘)。
|
||||||
|
:class:`JSONBackend` 在 :meth:`batch` 期间会延迟落盘,需在退出时调用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def batch(self) -> ContextManager[None]:
|
||||||
|
"""返回一个上下文管理器,期间 :meth:`save` 可延迟 :meth:`flush`。
|
||||||
|
|
||||||
|
默认实现为 no-op(如 :class:`MemoryBackend`)。:class:`JSONBackend`
|
||||||
|
覆盖为:进入时标记延迟,退出时统一 flush 一次,将每任务一次落盘
|
||||||
|
(N 次写入)降为整次运行一次(O(N) 而非 O(N²))。
|
||||||
|
"""
|
||||||
|
return nullcontext()
|
||||||
|
|
||||||
|
|
||||||
class _TTLStateBackendMixin(StateBackend):
|
class _TTLStateBackendMixin(StateBackend):
|
||||||
"""TTL 状态后端共享逻辑。
|
"""TTL 状态后端共享逻辑。
|
||||||
@@ -184,6 +201,7 @@ class JSONBackend(_TTLStateBackendMixin):
|
|||||||
self._path: str = path
|
self._path: str = path
|
||||||
self._ttl = ttl
|
self._ttl = ttl
|
||||||
self._store: dict[str, dict[str, Any]] = {}
|
self._store: dict[str, dict[str, Any]] = {}
|
||||||
|
self._defer_flush: bool = False
|
||||||
self._load()
|
self._load()
|
||||||
|
|
||||||
def _load(self) -> None:
|
def _load(self) -> None:
|
||||||
@@ -244,8 +262,27 @@ class JSONBackend(_TTLStateBackendMixin):
|
|||||||
except (TypeError, ValueError) as exc:
|
except (TypeError, ValueError) as exc:
|
||||||
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
||||||
super().save(key, value)
|
super().save(key, value)
|
||||||
|
if not self._defer_flush:
|
||||||
|
self._flush()
|
||||||
|
|
||||||
|
@override
|
||||||
|
def flush(self) -> None:
|
||||||
self._flush()
|
self._flush()
|
||||||
|
|
||||||
|
@override
|
||||||
|
@contextmanager
|
||||||
|
def batch(self) -> Iterator[None]:
|
||||||
|
"""进入批量模式:``save`` 暂不落盘,退出时统一 flush 一次。
|
||||||
|
|
||||||
|
将整次运行 N 个任务的 N 次全量落盘降为 1 次。
|
||||||
|
"""
|
||||||
|
self._defer_flush = True
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
self._defer_flush = False
|
||||||
|
self._flush()
|
||||||
|
|
||||||
def _expired(self, entry: Mapping[str, Any]) -> bool:
|
def _expired(self, entry: Mapping[str, Any]) -> bool:
|
||||||
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
|
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
|
||||||
return self._is_expired(float(entry.get("ts", 0)))
|
return self._is_expired(float(entry.get("ts", 0)))
|
||||||
|
|||||||
+51
-95
@@ -19,12 +19,13 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
|
||||||
import sys
|
import sys
|
||||||
|
import threading
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import cached_property
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
@@ -291,13 +292,16 @@ class TaskSpec(Generic[T]):
|
|||||||
if self.fn is None and self.cmd is None:
|
if self.fn is None and self.cmd is None:
|
||||||
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
|
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
|
||||||
|
|
||||||
@property
|
@cached_property
|
||||||
def effective_fn(self) -> TaskFn[T]:
|
def effective_fn(self) -> TaskFn[T]:
|
||||||
"""获取有效的执行函数。
|
"""获取有效的执行函数。
|
||||||
|
|
||||||
若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。
|
若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。
|
||||||
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
|
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
|
||||||
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
|
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
|
||||||
|
|
||||||
|
结果按实例缓存(:func:`functools.cached_property`):frozen dataclass
|
||||||
|
字段不可变,``_wrap_cmd`` 生成的闭包稳定,无需每次访问重建。
|
||||||
"""
|
"""
|
||||||
if self.cmd is not None:
|
if self.cmd is not None:
|
||||||
return self._wrap_cmd()
|
return self._wrap_cmd()
|
||||||
@@ -306,11 +310,17 @@ class TaskSpec(Generic[T]):
|
|||||||
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
|
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
|
||||||
|
|
||||||
def _wrap_cmd(self) -> TaskFn[Any]:
|
def _wrap_cmd(self) -> TaskFn[Any]:
|
||||||
"""将 cmd 包装为可执行函数。"""
|
"""将 cmd 包装为可执行函数。
|
||||||
|
|
||||||
|
实际执行逻辑位于 :mod:`pyflowx.command`,避免 :class:`TaskSpec`
|
||||||
|
作为纯数据结构混入命令执行逻辑。
|
||||||
|
"""
|
||||||
|
from .command import run_command
|
||||||
|
|
||||||
spec = self
|
spec = self
|
||||||
|
|
||||||
def _run() -> T:
|
def _run() -> T:
|
||||||
return cast(T, _run_command(spec))
|
return cast(T, run_command(spec))
|
||||||
|
|
||||||
_run.__name__ = spec.name
|
_run.__name__ = spec.name
|
||||||
return _run # type: ignore[return-value]
|
return _run # type: ignore[return-value]
|
||||||
@@ -376,105 +386,51 @@ class TaskSpec(Generic[T]):
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
# 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。
|
||||||
|
# ``fn`` 任务在 thread/async 策略下并发执行时,若各自配置了不同的
|
||||||
|
# ``cwd``/``env``,会相互覆盖(os.chdir 与 os.environ 均为进程全局)。
|
||||||
|
# 该锁仅包裹"切换→执行→恢复"区间,保证正确性;不使用 cwd/env 的任务不受影响。
|
||||||
|
_env_cwd_lock = threading.RLock()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _env_and_cwd(
|
def _env_and_cwd(
|
||||||
env: Mapping[str, str] | None,
|
env: Mapping[str, str] | None,
|
||||||
cwd: Path | None,
|
cwd: Path | None,
|
||||||
) -> Generator[None, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""临时设置环境变量与工作目录。"""
|
"""临时设置环境变量与工作目录。
|
||||||
saved_env: dict[str, str] = {}
|
|
||||||
saved_cwd: str | None = None
|
``os.environ`` 与 ``os.chdir`` 是进程级全局状态,在 thread/async 策略下
|
||||||
if env:
|
并发执行多个带 ``env``/``cwd`` 的 ``fn`` 任务时会相互覆盖。本函数通过
|
||||||
for k, v in env.items():
|
模块级 :data:`_env_cwd_lock` 串行化"切换→执行→恢复"区间,确保正确性。
|
||||||
if k in os.environ:
|
无 ``env`` 且无 ``cwd`` 时直接 yield,不获取锁。
|
||||||
saved_env[k] = os.environ[k]
|
"""
|
||||||
os.environ[k] = v
|
if not env and cwd is None:
|
||||||
if cwd is not None:
|
|
||||||
saved_cwd = str(Path.cwd())
|
|
||||||
os.chdir(cwd)
|
|
||||||
try:
|
|
||||||
yield
|
yield
|
||||||
finally:
|
return
|
||||||
if saved_cwd is not None:
|
with _env_cwd_lock:
|
||||||
os.chdir(saved_cwd)
|
saved_env: dict[str, str] = {}
|
||||||
# 恢复环境变量
|
saved_cwd: str | None = None
|
||||||
if env:
|
if env:
|
||||||
for k in env:
|
for k, v in env.items():
|
||||||
if k in saved_env:
|
if k in os.environ:
|
||||||
os.environ[k] = saved_env[k]
|
saved_env[k] = os.environ[k]
|
||||||
else:
|
os.environ[k] = v
|
||||||
os.environ.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
|
||||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。"""
|
|
||||||
cmd = spec.cmd
|
|
||||||
verbose = spec.verbose
|
|
||||||
cwd = spec.cwd
|
|
||||||
timeout = spec.timeout
|
|
||||||
env_override = spec.env
|
|
||||||
|
|
||||||
# 可调用对象:直接调用,返回其结果。
|
|
||||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
|
||||||
name = getattr(cmd, "__name__", "callable")
|
|
||||||
if verbose:
|
|
||||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
|
||||||
if cwd is not None:
|
|
||||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
|
||||||
try:
|
|
||||||
return cmd()
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
|
||||||
|
|
||||||
is_list = isinstance(cmd, list)
|
|
||||||
if is_list:
|
|
||||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
|
||||||
verb = "执行命令"
|
|
||||||
label = "命令"
|
|
||||||
else:
|
|
||||||
cmd_str = cast(str, cmd)
|
|
||||||
verb = "执行 Shell"
|
|
||||||
label = "Shell 命令"
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
|
||||||
if cwd is not None:
|
if cwd is not None:
|
||||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
saved_cwd = str(Path.cwd())
|
||||||
|
os.chdir(cwd)
|
||||||
# 合并环境变量
|
try:
|
||||||
run_env: dict[str, str] | None = None
|
yield
|
||||||
if env_override:
|
finally:
|
||||||
run_env = dict(os.environ)
|
if saved_cwd is not None:
|
||||||
run_env.update(env_override)
|
os.chdir(saved_cwd)
|
||||||
|
# 恢复环境变量
|
||||||
try:
|
if env:
|
||||||
result = subprocess.run(
|
for k in env:
|
||||||
cast(Union[str, List[str]], cmd),
|
if k in saved_env:
|
||||||
shell=not is_list,
|
os.environ[k] = saved_env[k]
|
||||||
cwd=cwd,
|
else:
|
||||||
env=run_env,
|
os.environ.pop(k, None)
|
||||||
timeout=timeout,
|
|
||||||
capture_output=not verbose,
|
|
||||||
text=True,
|
|
||||||
check=False,
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
|
||||||
except OSError as e:
|
|
||||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
|
||||||
|
|
||||||
if result.returncode == 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
|
||||||
if not verbose and result.stderr.strip():
|
|
||||||
err_msg += f"\n{result.stderr.strip()}"
|
|
||||||
raise RuntimeError(err_msg)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
|
|||||||
@@ -1,107 +0,0 @@
|
|||||||
"""常用工具函数."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
__all__ = ["perf_timer"]
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Callable, TypedDict
|
|
||||||
|
|
||||||
try:
|
|
||||||
from typing_extensions import ParamSpec, TypeVar
|
|
||||||
except ImportError:
|
|
||||||
from typing import ParamSpec, TypeVar
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
R = TypeVar("R")
|
|
||||||
|
|
||||||
|
|
||||||
class _PerformanceMetrics(TypedDict):
|
|
||||||
"""性能指标."""
|
|
||||||
|
|
||||||
count: int
|
|
||||||
total_time: float
|
|
||||||
|
|
||||||
|
|
||||||
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
|
|
||||||
lambda: _PerformanceMetrics(
|
|
||||||
count=0,
|
|
||||||
total_time=0.0,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_report(unit: str, precision: int) -> str:
|
|
||||||
"""生成性能指标报告,返回报告字符串."""
|
|
||||||
if not _perf_metrics:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
lines: list[str] = []
|
|
||||||
lines.append("=" * 50)
|
|
||||||
lines.append("性能指标报告 (Performance Metrics Report)")
|
|
||||||
lines.append("-" * 50)
|
|
||||||
|
|
||||||
# 按总耗时排序,最耗时的函数排在前面
|
|
||||||
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
|
|
||||||
|
|
||||||
for name, metrics in sorted_metrics:
|
|
||||||
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
|
|
||||||
lines.append(
|
|
||||||
f"{name}: "
|
|
||||||
f"调用次数={metrics['count']}, "
|
|
||||||
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
|
|
||||||
f"平均耗时={avg_time:.{precision}f}{unit}"
|
|
||||||
)
|
|
||||||
|
|
||||||
lines.append("=" * 50)
|
|
||||||
report_str = "\n".join(lines)
|
|
||||||
|
|
||||||
# 同时输出到日志
|
|
||||||
logging.info("\n".join(lines))
|
|
||||||
|
|
||||||
return report_str
|
|
||||||
|
|
||||||
|
|
||||||
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
|
|
||||||
"""性能计时器装饰器."""
|
|
||||||
scale: dict[str, float] = {
|
|
||||||
"s": 1.0,
|
|
||||||
"ms": 1000.0,
|
|
||||||
"us": 1000000.0,
|
|
||||||
}
|
|
||||||
|
|
||||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
|
||||||
@functools.wraps(func)
|
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
|
||||||
start_time = time.time()
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
_perf_metrics[func.__name__]["count"] += 1
|
|
||||||
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
|
|
||||||
|
|
||||||
if not report:
|
|
||||||
logging.info(
|
|
||||||
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
if report:
|
|
||||||
import atexit
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
|
|
||||||
|
|
||||||
@atexit.register
|
|
||||||
def _report_at_exit() -> None:
|
|
||||||
"""在程序退出时报告性能指标."""
|
|
||||||
_generate_report(unit, precision)
|
|
||||||
|
|
||||||
# 将报告生成逻辑提取为独立函数,便于测试
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
+1
-1
@@ -5,8 +5,8 @@ from __future__ import annotations
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import pyflowx as px
|
import pyflowx as px
|
||||||
|
from pyflowx.compose import GraphComposer, compose
|
||||||
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||||
from pyflowx.graph import GraphComposer, compose
|
|
||||||
|
|
||||||
|
|
||||||
def _fn() -> None:
|
def _fn() -> None:
|
||||||
|
|||||||
+7
-7
@@ -345,14 +345,14 @@ def test_task_result_default_status() -> None:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
# _run_command callable 命令测试
|
# run_command callable 命令测试
|
||||||
# ---------------------------------------------------------------------- #
|
# ---------------------------------------------------------------------- #
|
||||||
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
||||||
"""callable 命令 verbose 模式应打印信息."""
|
"""callable 命令 verbose 模式应打印信息."""
|
||||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
from pyflowx.command import run_command
|
||||||
import pyflowx.task as task_module
|
|
||||||
|
|
||||||
result = task_module._run_command(spec)
|
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||||
|
result = run_command(spec)
|
||||||
assert result == "result"
|
assert result == "result"
|
||||||
captured = capsys.readouterr()
|
captured = capsys.readouterr()
|
||||||
assert "执行可调用命令" in captured.out
|
assert "执行可调用命令" in captured.out
|
||||||
@@ -361,8 +361,8 @@ def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str
|
|||||||
|
|
||||||
def test_run_command_callable_exception() -> None:
|
def test_run_command_callable_exception() -> None:
|
||||||
"""callable 命令抛异常应转为 RuntimeError."""
|
"""callable 命令抛异常应转为 RuntimeError."""
|
||||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
from pyflowx.command import run_command
|
||||||
import pyflowx.task as task_module
|
|
||||||
|
|
||||||
|
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||||
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
||||||
task_module._run_command(spec)
|
run_command(spec)
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pytest_mock import MockerFixture
|
|
||||||
|
|
||||||
from pyflowx.utils import _perf_metrics, perf_timer
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reset_perf_metrics():
|
|
||||||
"""重置性能指标."""
|
|
||||||
_perf_metrics.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class TestPerformanceTimer:
|
|
||||||
def test_perf_timer(self):
|
|
||||||
|
|
||||||
@perf_timer()
|
|
||||||
def test_func():
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
assert _perf_metrics["test_func"] is not None
|
|
||||||
assert _perf_metrics["test_func"]["count"] == 1
|
|
||||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
|
||||||
|
|
||||||
def test_perf_timer_report(self, mocker: MockerFixture):
|
|
||||||
mock_log = mocker.patch("logging.info")
|
|
||||||
|
|
||||||
@perf_timer(report=True, unit="ms", precision=3)
|
|
||||||
def test_func():
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
test_func()
|
|
||||||
|
|
||||||
assert _perf_metrics["test_func"] is not None
|
|
||||||
assert _perf_metrics["test_func"]["count"] == 1
|
|
||||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
|
||||||
|
|
||||||
assert mock_log.call_count == 1
|
|
||||||
|
|
||||||
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
|
|
||||||
mock_log = mocker.patch("logging.info")
|
|
||||||
|
|
||||||
from pyflowx.utils import _generate_report
|
|
||||||
|
|
||||||
@perf_timer(report=True, unit="ms", precision=3)
|
|
||||||
def test_func():
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
@perf_timer(report=True, unit="ms", precision=3)
|
|
||||||
def test_func2():
|
|
||||||
time.sleep(0.2)
|
|
||||||
|
|
||||||
test_func()
|
|
||||||
test_func2()
|
|
||||||
|
|
||||||
_generate_report("ms", 3)
|
|
||||||
|
|
||||||
assert mock_log.call_count == 3
|
|
||||||
assert _perf_metrics["test_func"]["count"] == 1
|
|
||||||
assert _perf_metrics["test_func"]["total_time"] >= 0.1
|
|
||||||
assert _perf_metrics["test_func2"]["count"] == 1
|
|
||||||
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
|
|
||||||
@@ -5603,7 +5603,7 @@ pycountry = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyflowx"
|
name = "pyflowx"
|
||||||
version = "0.2.10"
|
version = "0.2.11"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
|
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
|
||||||
|
|||||||
Reference in New Issue
Block a user