6 Commits

Author SHA1 Message Date
zhou c15b38516a bump version to 0.2.11
Release / build (push) Failing after 29m15s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
2026-06-27 23:08:32 +08:00
zhou 7d4e8a40ce refactor(cli): 重构CLI模块结构,整理系统工具与开发工具
1. 将原cli根目录下的clearscreen、taskkill、which工具迁移到cli/system子目录
2. 新增cli/dev子目录并添加envdev环境配置工具
3. 更新pyproject.toml中的脚本入口点映射
4. 调整tests/cli下的测试文件导入路径
5. 整理tasks/system.py的__all__导出顺序
2026-06-27 22:01:02 +08:00
zhou 1b2d6d6a2c chore: 更新依赖配置并移除 pysnooper 2026-06-27 21:53:20 +08:00
zhou df890f0f16 chore: 移除独立的envpy和envrs命令,合并功能到envdev
将原来envpy和envrs的环境配置功能整合到envdev命令中,删除了冗余的独立CLI模块和测试文件,统一管理Python、Conda和Rust的环境配置。
2026-06-27 21:22:36 +08:00
zhou b62a544569 chore: 调整Python版本与依赖适配,新增性能报告测试与工具函数
1.  将Python版本从3.13降级到3.11
2.  为typing-extensions添加版本适配标记
3.  简化dev依赖组,移除pysnooper
4.  重构perf_timer,提取_generate_report独立函数
5.  新增性能报告生成与测试用例
2026-06-27 20:47:29 +08:00
zhou d58fc5536e chore: 发布 pyflowx 0.2.10,新增性能计时器与多项重构
1. 新增 perf_timer 工具与配套测试用例
2. 重构任务条件跳过逻辑,优化失败条件展示
3. 重构 Graph 子图生成逻辑,提取公共依赖修剪函数
4. 重构条件模块,统一条件名称与失败原因获取逻辑
5. 重构存储后端,提取 TTL 共享逻辑并优化实现
6. 重构执行器模块,使用 Mixin 复用代码,拆分任务与层执行逻辑
7. 删除冗余的 which 命令测试文件
8. 更新依赖锁文件
2026-06-27 20:15:35 +08:00
25 changed files with 899 additions and 1244 deletions
+1 -1
View File
@@ -1 +1 @@
3.13
3.11
+9 -12
View File
@@ -13,7 +13,7 @@ classifiers = [
]
dependencies = [
"graphlib_backport >= 1.0.0; python_version < '3.9'",
"typing-extensions>=4.13.2",
"typing-extensions>=4.13.2; python_version < '3.10'",
]
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
keywords = ["async", "dag", "scheduler", "task", "workflow"]
@@ -21,16 +21,12 @@ license = { text = "MIT" }
name = "pyflowx"
readme = "README.md"
requires-python = ">=3.8"
version = "0.2.10"
version = "0.2.11"
[project.scripts]
autofmt = "pyflowx.cli.autofmt:main"
bumpversion = "pyflowx.cli.bumpversion:main"
clr = "pyflowx.cli.clearscreen:main"
emlman = "pyflowx.cli.emlmanager:main"
envdev = "pyflowx.cli.envdev:main"
envpy = "pyflowx.cli.envpy:main"
envrs = "pyflowx.cli.envrs:main"
filedate = "pyflowx.cli.filedate:main"
filelvl = "pyflowx.cli.filelevel:main"
foldback = "pyflowx.cli.folderback:main"
@@ -46,8 +42,12 @@ reseticon = "pyflowx.cli.reseticoncache:main"
scrcap = "pyflowx.cli.screenshot:main"
sglang = "pyflowx.cli.llm.sglang:main"
sshcopy = "pyflowx.cli.sshcopyid:main"
taskk = "pyflowx.cli.taskkill:main"
wch = "pyflowx.cli.which:main"
# dev
envdev = "pyflowx.cli.dev.envdev:main"
# system
clr = "pyflowx.cli.system.clearscreen:main"
taskk = "pyflowx.cli.system.taskkill:main"
wch = "pyflowx.cli.system.which:main"
[project.optional-dependencies]
dev = [
@@ -93,10 +93,7 @@ packages = ["src/pyflowx"]
pyflowx = { workspace = true }
[dependency-groups]
dev = [
"pyflowx[dev,office,llm]",
"pysnooper>=1.2.3",
]
dev = ["pyflowx[dev,office,llm]"]
[tool.coverage.run]
branch = true
+1 -1
View File
@@ -95,7 +95,7 @@ from .task import (
task_template,
)
__version__ = "0.3.4"
__version__ = "0.3.5"
__all__ = [
"IS_LINUX",
View File
@@ -127,6 +127,37 @@ CHINESE_FONTS: list[str] = [
"fonts-noto-color-emoji",
]
# ============================================================================
# Rust 配置
# ============================================================================
RustMirrorType = Literal["tsinghua", "ustc", "aliyun"]
RustVersionType = Literal["stable", "nightly", "beta"]
DEFAULT_RUST_VERSION: RustVersionType = "stable"
DEFAULT_MIRROR: RustMirrorType = "tsinghua"
RUSTUP_MIRRORS: dict[RustMirrorType, dict[str, str]] = {
"tsinghua": {
"RUSTUP_DIST_SERVER": "https://mirrors.tuna.tsinghua.edu.cn/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/",
},
"aliyun": {
"RUSTUP_DIST_SERVER": "https://mirrors.aliyun.com/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.aliyun.com/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.aliyun.com/crates.io-index/",
},
"ustc": {
"RUSTUP_DIST_SERVER": "https://mirrors.ustc.edu.cn/rust-static",
"RUSTUP_UPDATE_ROOT": "https://mirrors.ustc.edu.cn/rust-static/rustup",
"TOML_REGISTRY": "https://mirrors.ustc.edu.cn/crates.io-index/",
},
}
RUSTUP_DOWNLOAD_URL_LINUX = "https://mirrors.aliyun.com/repo/rust/rustup-init.sh"
RUSTUP_DOWNLOAD_URL_WINDOWS = "https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe"
RUST_CONFIG_PATH = Path.home() / ".cargo" / "config.toml"
RUST_SCCACHE_DIR: Path = Path.home() / ".cargo" / "sccache"
RUST_SCCACHE_CACHE_SIZE: str = "20G"
def main() -> None:
"""主函数."""
@@ -147,14 +178,34 @@ def main() -> None:
choices=get_args(CondaMirrorType),
help="Conda 镜镜像源",
)
parser.add_argument(
"--rust-mirror",
nargs="?",
type=str,
default=DEFAULT_MIRROR,
choices=get_args(RustMirrorType),
help="Rust 镜像源",
)
parser.add_argument(
"--rust-version",
nargs="?",
type=str,
default=DEFAULT_RUST_VERSION,
choices=get_args(RustVersionType),
help=f"Rust 版本, 推荐: {get_args(RustVersionType)}",
)
args = parser.parse_args()
python_mirror = args.python_mirror
conda_mirror_urls = CONDA_MIRROR_URLS[args.conda_mirror]
rust_mirror = args.rust_mirror
rust_version = args.rust_version
# 确保配置文件目录存在
PIP_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
CONDA_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
RUST_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
RUST_SCCACHE_DIR.mkdir(parents=True, exist_ok=True)
# 使用 conditions 自动控制任务执行
graph = px.Graph.from_specs([
@@ -222,5 +273,59 @@ def main() -> None:
str(CONDA_CONFIG_PATH),
"show_channel_urls: true\nchannels:\n - " + "\n - ".join(conda_mirror_urls) + "\n - defaults",
),
# 设置 Rust 镜像源
*setenv_group({
"RUSTUP_DIST_SERVER": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_DIST_SERVER"],
"RUSTUP_UPDATE_ROOT": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_UPDATE_ROOT"],
"RUST_SCCACHE_DIR": str(RUST_SCCACHE_DIR),
"RUST_SCCACHE_CACHE_SIZE": RUST_SCCACHE_CACHE_SIZE,
}),
# 写入 Rust 配置(仅当未配置)
write_file(
str(RUST_CONFIG_PATH),
f"""
[source.crates-io]
replace-with = '{rust_mirror}'
[source.{rust_mirror}]
registry = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
[registries.{rust_mirror}]
index = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
""",
),
# 下载 Rustup 安装脚本
px.TaskSpec(
"download_rustup",
cmd=["curl", "-fsSL", RUSTUP_DOWNLOAD_URL_LINUX, "-o", "rustup-init.sh"],
conditions=(BuiltinConditions.IS_LINUX(), BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup"))),
verbose=True,
),
px.TaskSpec(
"download_rustup_win",
cmd=[
"powershell",
"-Command",
"Invoke-WebRequest",
"-Uri",
RUSTUP_DOWNLOAD_URL_WINDOWS,
"-OutFile",
"rustup-init.exe",
],
conditions=(
BuiltinConditions.IS_WINDOWS(),
BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup")),
),
verbose=True,
),
# 安装 Rust 工具链
px.TaskSpec(
"install_rust",
cmd=["rustup", "toolchain", "install", rust_version],
conditions=(BuiltinConditions.HAS_INSTALLED("rustup"),),
depends_on=("setenv_rustup_dist_server",),
allow_upstream_skip=True,
verbose=True,
),
])
px.run(graph, strategy="thread", verbose=True)
-122
View File
@@ -1,122 +0,0 @@
"""Python 环境配置工具.
用于设置 pip 镜像源, 支持清华和阿里云等国内镜像源,
同时配置 UV 和 Conda 的镜像源.
"""
from __future__ import annotations
import argparse
import os
from pathlib import Path
import pyflowx as px
from pyflowx.conditions import Constants
# ============================================================================
# 配置
# ============================================================================
PIP_INDEX_URLS: dict[str, str] = {
"tsinghua": "https://pypi.tuna.tsinghua.edu.cn/simple",
"aliyun": "https://mirrors.aliyun.com/pypi/simple/",
}
PIP_TRUSTED_HOSTS: dict[str, str] = {
"tsinghua": "pypi.tuna.tsinghua.edu.cn",
"aliyun": "mirrors.aliyun.com",
}
UV_INDEX_URL: str = "https://mirrors.aliyun.com/pypi/simple/"
UV_PYTHON_INSTALL_MIRROR: str = "https://registry.npmmirror.com/-/binary/python-build-standalone"
CONDA_MIRROR_URLS: dict[str, list[str]] = {
"tsinghua": [
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
],
"aliyun": [
"https://mirrors.aliyun.com/anaconda/pkgs/main/",
"https://mirrors.aliyun.com/anaconda/pkgs/free/",
"https://mirrors.aliyun.com/anaconda/cloud/conda-forge/",
],
}
# ============================================================================
# 辅助函数
# ============================================================================
def set_pip_mirror(mirror: str = "tsinghua", token: str | None = None) -> None:
"""设置 pip 镜像源.
Parameters
----------
mirror : str
镜像源名称: tsinghua, aliyun
token : str | None
PyPI token for publishing
"""
index_url = PIP_INDEX_URLS.get(mirror, PIP_INDEX_URLS["tsinghua"])
trusted_host = PIP_TRUSTED_HOSTS.get(mirror, "")
# 设置环境变量
os.environ["PIP_INDEX_URL"] = index_url
os.environ["UV_INDEX_URL"] = UV_INDEX_URL
os.environ["UV_DEFAULT_INDEX"] = UV_INDEX_URL
os.environ["UV_PYTHON_INSTALL_MIRROR"] = UV_PYTHON_INSTALL_MIRROR
# 写入 pip 配置文件
pip_dir = Path.home() / "pip"
pip_dir.mkdir(exist_ok=True)
pip_conf = pip_dir / ("pip.ini" if Constants.IS_WINDOWS else "pip.conf")
pip_conf.write_text(f"[global]\nindex-url = {index_url}\n[install]\ntrusted-host = {trusted_host}\n")
# 写入 conda 配置文件
condarc = Path.home() / ".condarc"
conda_urls = CONDA_MIRROR_URLS.get(mirror, CONDA_MIRROR_URLS["tsinghua"])
condarc.write_text(
"show_channel_urls: true\nchannels:\n" + "\n".join(f" - {url}" for url in conda_urls) + "\n - defaults\n"
)
# 写入 pypirc 配置文件 (如果有 token)
if token:
pypirc = Path.home() / ".pypirc"
pypirc.write_text(
f"[pypi]\nrepository: https://upload.pypi.org/legacy/\nusername: __token__\npassword: {token}\n"
)
print(f"已设置 pip 镜像源: {mirror} ({index_url})")
# ============================================================================
# CLI Runner
# ============================================================================
def main() -> None:
"""Python 环境配置工具主函数."""
parser = argparse.ArgumentParser(
description="EnvPy - Python 环境配置工具",
usage="envpy <command> [options]",
)
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 设置镜像源命令
mirror_parser = subparsers.add_parser("mirror", help="设置 pip 镜像源")
mirror_parser.add_argument("name", choices=["tsinghua", "aliyun"], help="镜像源名称")
mirror_parser.add_argument("--token", type=str, help="PyPI token for publishing")
args = parser.parse_args()
if args.command == "mirror":
graph = px.Graph.from_specs([
px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})
])
else:
parser.print_help()
return
px.run(graph, strategy="thread")
-150
View File
@@ -1,150 +0,0 @@
"""Rust 环境配置工具.
配置 Rustup 和 Cargo 的国内镜像源,
加速 Rust 工具链和依赖包的下载.
"""
from __future__ import annotations
import argparse
import os
import subprocess
from pathlib import Path
from typing import Literal, get_args
import pyflowx as px
# ============================================================================
# 配置
# ============================================================================
RUSTUP_MIRRORS: dict[str, dict[str, str]] = {
"aliyun": {
"RUSTUP_DIST_SERVER": "https://mirrors.aliyun.com/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.aliyun.com/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.aliyun.com/crates.io-index/",
},
"ustc": {
"RUSTUP_DIST_SERVER": "https://mirrors.ustc.edu.cn/rust-static",
"RUSTUP_UPDATE_ROOT": "https://mirrors.ustc.edu.cn/rust-static/rustup",
"TOML_REGISTRY": "https://mirrors.ustc.edu.cn/crates.io-index/",
},
"tsinghua": {
"RUSTUP_DIST_SERVER": "https://mirrors.tuna.tsinghua.edu.cn/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/",
},
}
UsableRustVersion = Literal["stable", "nightly", "beta"]
UsableMirror = Literal["aliyun", "ustc", "tsinghua"]
DEFAULT_RUST_VERSION: UsableRustVersion = "stable"
DEFAULT_MIRROR: UsableMirror = "tsinghua"
# ============================================================================
# 辅助函数
# ============================================================================
def set_rust_mirror(mirror: UsableMirror = DEFAULT_MIRROR) -> None:
"""设置 Rust 镜像源.
Parameters
----------
mirror : str
镜像源名称: aliyun, ustc, tsinghua
"""
mirror_dict = RUSTUP_MIRRORS.get(mirror, RUSTUP_MIRRORS[DEFAULT_MIRROR])
server = mirror_dict["RUSTUP_DIST_SERVER"]
update_root = mirror_dict["RUSTUP_UPDATE_ROOT"]
toml_registry = mirror_dict["TOML_REGISTRY"]
# 设置环境变量
os.environ["RUSTUP_DIST_SERVER"] = server
os.environ["RUSTUP_UPDATE_ROOT"] = update_root
# 写入 cargo 配置
cargo_dir = Path.home() / ".cargo"
cargo_dir.mkdir(exist_ok=True)
cargo_config = cargo_dir / "config.toml"
cargo_config.write_text(
f"""[source.crates-io]
replace-with = '{mirror}'
[source.{mirror}]
registry = "sparse+{toml_registry}"
[registries.{mirror}]
index = "sparse+{toml_registry}"
"""
)
print(f"已设置 Rust 镜像源: {mirror}")
def install_rust(version: UsableRustVersion = DEFAULT_RUST_VERSION) -> None:
"""安装 Rust 工具链.
Parameters
----------
version : str
Rust 版本: stable, nightly, beta
"""
try:
subprocess.run(["rustup", "toolchain", "install", version], check=True)
print(f"已安装 Rust {version}")
except FileNotFoundError:
print("未找到 rustup,请先安装 Rust: https://rustup.rs")
raise
# ============================================================================
# CLI Runner
# ============================================================================
def main() -> None:
"""Rust 环境配置工具主函数."""
parser = argparse.ArgumentParser(
description="EnvRs - Rust 环境配置工具",
usage="envrs <command> [options]",
)
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 设置镜像源命令
mirror_parser = subparsers.add_parser("mirror", help="设置 Rust 镜像源")
mirror_parser.add_argument(
"name",
nargs="?",
default=DEFAULT_MIRROR,
choices=get_args(UsableMirror),
help=f"镜像源名称 ({get_args(UsableMirror)})",
)
# 安装 Rust 命令
install_parser = subparsers.add_parser("install", help="安装 Rust 工具链")
install_parser.add_argument(
"version",
nargs="?",
default=DEFAULT_RUST_VERSION,
choices=get_args(UsableRustVersion),
help=f"Rust 版本 ({get_args(UsableRustVersion)})",
)
args = parser.parse_args()
if args.command == "mirror":
graph = px.Graph.from_specs([
px.TaskSpec("set_rust_mirror", fn=set_rust_mirror, args=(args.name,), verbose=True)
])
elif args.command == "install":
graph = px.Graph.from_specs([
px.TaskSpec("install_rust", cmd=["rustup", "toolchain", "install", args.version], verbose=True)
])
else:
parser.print_help()
return
px.run(graph, strategy="thread", verbose=True)
View File
@@ -35,6 +35,6 @@ def main() -> None:
[
px.TaskSpec(f"kill_{proc_name}", cmd=[*cmd, f"{proc_name}*"], verbose=True)
for proc_name in args.process_names
]
],
)
px.run(graph, strategy="thread")
+27 -13
View File
@@ -42,6 +42,19 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
return _cond
def _cond_reason(cond: Condition) -> str | list[str] | None:
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
reason = getattr(cond, "_reason", None)
if reason is not None:
return reason
return getattr(cond, "__name__", repr(cond))
def _cond_name(cond: Condition) -> str:
"""获取条件的可读名称。"""
return getattr(cond, "__name__", repr(cond))
# ---------------------------------------------------------------------- #
# 模块级静态条件常量
# ---------------------------------------------------------------------- #
@@ -61,21 +74,25 @@ class BuiltinConditions:
# ------------------------------------------------------------------ #
# 静态条件
# ------------------------------------------------------------------ #
@staticmethod
def IS_WINDOWS() -> Condition:
"""检查是否为 Windows 平台."""
return _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS")
return IS_WINDOWS
@staticmethod
def IS_LINUX() -> Condition:
"""检查是否为 Linux 平台."""
return _static(lambda: Constants.IS_LINUX, "IS_LINUX")
return IS_LINUX
@staticmethod
def IS_MACOS() -> Condition:
"""检查是否为 macOS 平台."""
return _static(lambda: Constants.IS_MACOS, "IS_MACOS")
return IS_MACOS
@staticmethod
def IS_POSIX() -> Condition:
"""检查是否为 POSIX 平台."""
return _static(lambda: Constants.IS_POSIX, "IS_POSIX")
return IS_POSIX
@staticmethod
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
@@ -214,12 +231,12 @@ class BuiltinConditions:
result = condition(ctx)
if result:
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
inner_reason = getattr(condition, "_reason", None)
inner_reason = _cond_reason(condition)
if inner_reason is not None:
_cond._reason = inner_reason # type: ignore[attr-defined]
return not result
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
_cond.__name__ = f"NOT({_cond_name(condition)})"
return _cond
@staticmethod
@@ -229,8 +246,7 @@ class BuiltinConditions:
def _cond(ctx: Context) -> bool:
return all(c(ctx) for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_cond.__name__ = f"AND({', '.join(names)})"
_cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})"
return _cond
@staticmethod
@@ -241,14 +257,12 @@ class BuiltinConditions:
matched: list[str] = []
for c in conditions:
if c(ctx):
matched.append(
getattr(c, "_reason", None) or getattr(c, "__name__", repr(c)),
)
reason = _cond_reason(c)
matched.append(reason if isinstance(reason, str) else str(reason))
if matched:
_cond._reason = matched # type: ignore[attr-defined]
return True
return False
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_cond.__name__ = f"OR({', '.join(names)})"
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
return _cond
+305 -350
View File
@@ -10,6 +10,17 @@
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
无需等待同层其他任务。最大化并行度。
架构
----
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
所有策略共享统一异步内核,支持:
* :class:`RetryPolicy`max_attempts/delay/backoff/jitter/retry_on
* 软依赖注入与默认值
@@ -30,6 +41,7 @@ import concurrent.futures
import inspect
import logging
import threading
import time
from datetime import datetime
from typing import Any, Awaitable, Callable, Literal, Mapping, cast
@@ -48,7 +60,7 @@ Strategy = Literal["sequential", "thread", "async", "dependency"]
# ---------------------------------------------------------------------- #
# 辅助
# 无状态公共辅助
# ---------------------------------------------------------------------- #
def _is_async_fn(spec: TaskSpec[Any]) -> bool:
"""判断 ``spec.effective_fn`` 是否为协程函数。"""
@@ -71,17 +83,6 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
)
def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
attempt,
max_attempts,
exc,
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -93,87 +94,6 @@ def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
logger.warning("hook %s raised: %r", fn_name, exc)
def _check_upstream_skipped(
spec: TaskSpec[Any],
report: RunReport | None,
) -> tuple[bool, str | None]:
"""检查硬依赖上游任务是否被 SKIPPED 或 FAILED。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None: # pragma: no cover
return False, None # pragma: no cover
if spec.allow_upstream_skip: # pragma: no cover
return False, None # pragma: no cover
for dep in spec.depends_on:
if dep not in report.results: # pragma: no cover
continue # pragma: no cover
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
return False, None # pragma: no cover
def _format_reason(reason: Any) -> str:
"""将 _reason 格式化为可读字符串."""
if isinstance(reason, list):
return ", ".join(str(r) for r in reason)
return str(reason)
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
"""求值所有条件,返回跳过原因或 ``None``。
条件接收上下文映射(硬依赖 + 软依赖结果)。
"""
failed_conditions: list[str] = []
for condition in spec.conditions:
try:
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
continue
if not ok:
reason = getattr(condition, "_reason", None)
if reason is not None:
failed_conditions.append(_format_reason(reason))
else:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions:
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
if spec.skip_if_missing and not spec._is_cmd_available():
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
return f"命令不存在: {cmd_name}"
return None
def _make_skipped_result(
spec: TaskSpec[Any],
reason: str,
on_event: EventCallback | None,
) -> TaskResult[Any]:
"""构造 SKIPPED 的 TaskResult。"""
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, reason)
return result
def _build_context(
spec: TaskSpec[Any],
global_context: Mapping[str, Any],
@@ -185,19 +105,16 @@ def _build_context(
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
"""
ctx: dict[str, Any] = {}
for dep in spec.depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
for dep in spec.soft_depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
elif dep in spec.defaults: # pragma: no cover
ctx[dep] = spec.defaults[dep] # pragma: no cover
elif dep in spec.defaults:
ctx[dep] = spec.defaults[dep]
else:
ctx[dep] = None
return ctx
@@ -222,33 +139,93 @@ def _apply_cached(
return True
def _prepare_for_execution(
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
# ---------------------------------------------------------------------- #
# Mixin:任务级跳过 / 重试 / 成功处理
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""任务级跳过预检共享逻辑。
"上游被跳过/失败""条件不满足"两类跳过判断统一为单一入口,
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
"""
@staticmethod
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None or spec.allow_upstream_skip:
return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
) -> TaskResult[Any] | None:
) -> TaskResult[Any] | None:
"""执行前预检:上游跳过 / 条件跳过。
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
"""
should_skip, skip_reason = _check_upstream_skipped(spec, report)
if should_skip:
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event)
skip_reason = _evaluate_conditions(spec, context)
if skip_reason is not None:
return _make_skipped_result(spec, skip_reason, on_event)
# 1. 上游被跳过/失败
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute
if skip_reason is None:
should_run, cond_reason = spec.should_execute(context)
if not should_run:
skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None
# 构造 SKIPPED 结果
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
def _finalize_failure(
class _TaskRetryMixin:
"""任务级重试决策与失败/成功后处理共享逻辑。"""
@staticmethod
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""是否应继续重试。"""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
@staticmethod
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""标记任务成功并触发 post_run 钩子。"""
result.value = value
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
@staticmethod
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
continue_on_error: bool = False,
) -> None:
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
@@ -266,41 +243,66 @@ def _finalize_failure(
layer=layer_idx,
)
@staticmethod
def _handle_failure(
spec: TaskSpec[Any],
result: TaskResult[Any],
exc: BaseException,
layer_idx: int | None,
on_event: EventCallback | None,
) -> bool:
"""统一处理失败:超时转换、重试决策、finalize。
def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的同步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
import time
time.sleep(wait)
async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的异步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
await asyncio.sleep(wait)
Returns
-------
bool
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
"""
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
if isinstance(exc, asyncio.TimeoutError):
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
)
else:
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
exc,
)
result.error = exc
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
# ---------------------------------------------------------------------- #
# 同步执行内核
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin
# ---------------------------------------------------------------------- #
def _run_sync_with_retry(
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
"""同步任务执行器:带重试与跳过预检。"""
@staticmethod
def run(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
) -> TaskResult[Any]:
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
skipped = _prepare_for_execution(spec, context, report, on_event)
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
@@ -309,25 +311,60 @@ def _run_sync_with_retry(
result.attempts += 1
try:
with spec.env_context():
result.value = spec.effective_fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
value = spec.effective_fn(*args, **kwargs)
_TaskRetryMixin._mark_success(spec, result, value)
return result
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
return result
_log_retry(spec, result.attempts, max_attempts, exc)
_sleep_for_retry(spec, result.attempts)
# pragma: no cover
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
@staticmethod
async def run(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
async def _inner() -> TaskResult[Any]:
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
while True:
result.attempts += 1
try:
value = await _execute_async_task(spec, args, kwargs, loop)
_TaskRetryMixin._mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
await asyncio.sleep(wait)
if semaphore is not None:
async with semaphore:
return await _inner()
return await _inner()
# ---------------------------------------------------------------------- #
# 异步执行内核
# ---------------------------------------------------------------------- #
async def _execute_async_task(
spec: TaskSpec[Any],
args: tuple[Any, ...],
@@ -339,9 +376,7 @@ async def _execute_async_task(
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None:
return await asyncio.wait_for(coro, timeout=spec.timeout)
else:
return await coro
else:
def fn_call() -> Any:
with spec.env_context():
@@ -349,87 +384,89 @@ async def _execute_async_task(
if spec.timeout is not None:
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
else:
return await loop.run_in_executor(None, fn_call)
async def _run_async_with_retry(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
# ---------------------------------------------------------------------- #
# Mixin:层执行共享逻辑
# ---------------------------------------------------------------------- #
class _LayerMixin:
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
if semaphore is not None:
async with semaphore:
return await _run_async_inner(spec, context, layer_idx, on_event, report)
return await _run_async_inner(spec, context, layer_idx, on_event, report)
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
消除"过滤缓存→排序→运行→存结果"的样板代码。
"""
@staticmethod
def _filter_and_sort(
layer: list[str],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> list[str]:
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name)
return _sort_by_priority(to_run, graph)
async def _run_async_inner(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None, # noqa: ARG001
) -> TaskResult[Any]:
"""异步执行内核的内部实现(已获取 semaphore 后)。"""
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
@staticmethod
def _store_result(
name: str,
result: TaskResult[Any],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
context_snapshot: Mapping[str, Any] | None = None,
) -> None:
"""存储任务结果到 context/report/backend 并触发事件。"""
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
_run_hooks(spec.hooks, "pre_run", spec)
@staticmethod
def _build_semaphores(
to_run: list[str],
graph: Graph,
sem_factory: Callable[[int], Any],
concurrency_limits: Mapping[str, int],
) -> dict[str, Any]:
"""为每个 ``concurrency_key`` 创建一个信号量。"""
semaphores: dict[str, Any] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = sem_factory(limit)
return semaphores
while True:
result.attempts += 1
try:
result.value = await _execute_async_task(spec, args, kwargs, loop)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result
except asyncio.TimeoutError:
exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
await _async_sleep_for_retry(spec, result.attempts)
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc)
await _async_sleep_for_retry(spec, result.attempts)
# pragma: no cover
@staticmethod
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
if spec.concurrency_key is None:
return None
return semaphores.get(spec.concurrency_key)
# ---------------------------------------------------------------------- #
# 层执行器
# ---------------------------------------------------------------------- #
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
class SequentialLayerRunner(_LayerMixin):
"""逐个运行某层的任务(按优先级排序)。"""
def _execute_layer_sequential(
@staticmethod
def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -437,22 +474,19 @@ def _execute_layer_sequential(
backend: StateBackend,
layer_idx: int,
on_event: EventCallback | None,
) -> None:
"""逐个运行某层的任务(按优先级排序)。"""
for name in _sort_by_priority(layer, graph):
) -> None:
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
task_ctx = _build_context(spec, context, report)
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
def _execute_layer_threaded(
class ThreadedLayerRunner(_LayerMixin):
"""在线程池中并发运行某层的任务。"""
@staticmethod
def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -462,70 +496,48 @@ def _execute_layer_threaded(
on_event: EventCallback | None,
max_workers: int,
concurrency_limits: Mapping[str, int],
) -> None:
"""在线程池中并发运行某层的任务。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
) -> None:
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建线程信号量
semaphores: dict[str, threading.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = threading.Semaphore(limit)
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
context_snapshot = dict(context)
lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
if sem is not None:
sem.acquire()
try:
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
finally:
if sem is not None:
sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {}
for name in to_run:
fut = pool.submit(_run_threaded_task, name)
future_to_name[fut] = name
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
pool.submit(_run_threaded_task, name): name for name in to_run
}
completed: dict[str, TaskResult[Any]] = {}
try:
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
result = fut.result()
completed[name] = result
completed[name] = fut.result()
finally:
with lock:
for name, result in completed.items():
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
ThreadedLayerRunner._store_result(
name, result, graph, context, report, backend, on_event, context_snapshot
)
async def _execute_layer_async(
class AsyncLayerRunner(_LayerMixin):
"""在事件循环上并发运行某层的任务。"""
@staticmethod
async def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -534,76 +546,41 @@ async def _execute_layer_async(
layer_idx: int,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""在事件循环上并发运行某层的任务。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
) -> None:
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建异步信号量
semaphores: dict[str, asyncio.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
context_snapshot = dict(context)
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]:
async def _run_async_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
sem = AsyncLayerRunner._get_sem(semaphores, spec)
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
coros = [_run_async_task_wrapped(name) for name in to_run]
results = await asyncio.gather(*coros)
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
for name, result in zip(to_run, results):
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
# ---------------------------------------------------------------------- #
# 依赖驱动调度
# ---------------------------------------------------------------------- #
async def _drive_dependency_async(
class DependencyRunner(_LayerMixin):
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
@staticmethod
async def execute(
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
all_names = set(graph.all_specs().keys())
semaphores: dict[str, asyncio.Semaphore] = {}
for name in all_names:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
) -> None:
all_names = list(graph.all_specs().keys())
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]:
@@ -621,24 +598,14 @@ async def _drive_dependency_async(
if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name]
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
else:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
sem = DependencyRunner._get_sem(semaphores, spec)
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
return result
loop = asyncio.get_event_loop()
for name in all_names:
futures[name] = loop.create_task(_run_task(name))
await asyncio.gather(*futures.values())
@@ -729,9 +696,9 @@ def run(
elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
_drive_async(graph, layers, context, report, backend, effective_callback, limits)
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency":
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits))
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else:
raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError:
@@ -759,7 +726,7 @@ def _drive_sequential(
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
SequentialLayerRunner.execute(layer, graph, context, report, backend, idx, on_event)
def _drive_threaded(
@@ -774,19 +741,7 @@ def _drive_threaded(
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
def _drive_async(
graph: Graph,
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
ThreadedLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
async def _async_drive(
@@ -799,4 +754,4 @@ async def _async_drive(
concurrency_limits: Mapping[str, int],
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
await AsyncLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
+25 -16
View File
@@ -12,6 +12,11 @@
from __future__ import annotations
__all__ = [
"Graph",
"GraphDefaults",
]
import sys
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Iterable, Mapping, Sequence
@@ -49,6 +54,15 @@ class GraphDefaults:
verbose: bool = False
def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]:
"""返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。"""
return replace(
spec,
depends_on=tuple(d for d in spec.depends_on if keep(d)),
soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)),
)
@dataclass
class Graph:
"""校验后的有向无环任务图。
@@ -64,6 +78,7 @@ class Graph:
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
defaults: GraphDefaults = field(default_factory=GraphDefaults)
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
_pending_refs: list[str] = field(default_factory=list)
@@ -225,16 +240,13 @@ class Graph:
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
pruned_soft = tuple(
d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
def _dep_kept(dep: str) -> bool:
return dep in self.specs and bool(wanted & set(self.specs[dep].tags))
kept: list[TaskSpec[Any]] = [
_prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags)
]
return Graph.from_specs(kept, defaults=self.defaults)
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
@@ -243,12 +255,9 @@ class Graph:
for n in wanted:
if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
kept: list[TaskSpec[Any]] = [
_prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted
]
return Graph.from_specs(kept, defaults=self.defaults)
# ------------------------------------------------------------------ #
+110 -40
View File
@@ -17,6 +17,7 @@ import json
import sys
import time
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Mapping
@@ -55,7 +56,74 @@ class StateBackend(ABC):
"""清除所有存储状态。"""
class MemoryBackend(StateBackend):
class _TTLStateBackendMixin(StateBackend):
"""TTL 状态后端共享逻辑。
``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现
委托给四个原始存取原语:meth:`_get_raw`:meth:`_put_raw`
:meth:`_iter_raw`:meth:`_clear_raw`并基于 :meth:`_now`
``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`
子类需设置 ``self._ttl`` 并实现上述四个原语如需自定义时间源
``time.monotonic``可覆盖 :meth:`_now`
"""
_ttl: float | None
# ---- 原语:由子类实现 ---- #
@abstractmethod
def _get_raw(self, key: str) -> tuple[Any, float] | None:
"""返回 ``(value, ts)``;键不存在时返回 ``None``。"""
@abstractmethod
def _put_raw(self, key: str, value: Any, ts: float) -> None:
"""写入一条记录。"""
@abstractmethod
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
"""迭代所有记录(不做过期过滤),yield ``(key, value, ts)``。"""
@abstractmethod
def _clear_raw(self) -> None:
"""清空所有记录。"""
# ---- 共享实现 ---- #
def _now(self) -> float:
"""当前时间戳,默认为 wall-clock 秒。"""
return time.time()
def _is_expired(self, ts: float) -> bool:
"""时间戳 ``ts`` 是否已过期。"""
if self._ttl is None:
return False
return (self._now() - ts) > self._ttl
@override
def load(self) -> Mapping[str, Any]:
return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)}
@override
def save(self, key: str, value: Any) -> None:
self._put_raw(key, value, self._now())
@override
def has(self, key: str) -> bool:
entry = self._get_raw(key)
return entry is not None and not self._is_expired(entry[1])
@override
def get(self, key: str) -> Any:
entry = self._get_raw(key)
if entry is None or self._is_expired(entry[1]):
raise KeyError(key)
return entry[0]
@override
def clear(self) -> None:
self._clear_raw()
class MemoryBackend(_TTLStateBackendMixin):
"""进程内 dict 后端。进程退出即丢失。
Parameters
@@ -70,35 +138,35 @@ class MemoryBackend(StateBackend):
self._ttl = ttl
@override
def load(self) -> Mapping[str, Any]:
return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)}
def _now(self) -> float:
return time.monotonic()
@override
def save(self, key: str, value: Any) -> None:
self._store[key] = (value, time.monotonic())
def _get_raw(self, key: str) -> tuple[Any, float] | None:
return self._store.get(key)
@override
def has(self, key: str) -> bool:
return key in self._store and not self._expired(key)
def _put_raw(self, key: str, value: Any, ts: float) -> None:
self._store[key] = (value, ts)
@override
def get(self, key: str) -> Any:
if key not in self._store or self._expired(key):
raise KeyError(key)
return self._store[key][0]
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
for k, (v, ts) in self._store.items():
yield k, v, ts
@override
def clear(self) -> None:
def _clear_raw(self) -> None:
self._store.clear()
def _expired(self, key: str) -> bool:
if self._ttl is None or key not in self._store:
"""键是否已过期(兼容旧测试 API)。"""
entry = self._get_raw(key)
if entry is None:
return False
_value, ts = self._store[key]
return (time.monotonic() - ts) > self._ttl
return self._is_expired(entry[1])
class JSONBackend(StateBackend):
class JSONBackend(_TTLStateBackendMixin):
"""基于文件的 JSON 存储,用于跨进程续跑。
存储格式``{key: {"value": v, "ts": epoch_seconds}}``
@@ -144,17 +212,30 @@ class JSONBackend(StateBackend):
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
def _now(self) -> float:
return time.time()
def _expired(self, entry: dict[str, Any]) -> bool:
if self._ttl is None:
return False
return (self._now() - float(entry.get("ts", 0))) > self._ttl
@override
def _get_raw(self, key: str) -> tuple[Any, float] | None:
entry = self._store.get(key)
if entry is None:
return None
return entry["value"], float(entry.get("ts", 0))
@override
def load(self) -> Mapping[str, Any]:
return {k: v["value"] for k, v in self._store.items() if not self._expired(v)}
def _put_raw(self, key: str, value: Any, ts: float) -> None:
self._store[key] = {"value": value, "ts": ts}
@override
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
for k, entry in self._store.items():
yield k, entry["value"], float(entry.get("ts", 0))
@override
def _clear_raw(self) -> None:
self._store.clear()
@override
def clear(self) -> None:
super().clear()
self._flush()
@override
def save(self, key: str, value: Any) -> None:
@@ -162,23 +243,12 @@ class JSONBackend(StateBackend):
_ = json.dumps(value)
except (TypeError, ValueError) as exc:
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
self._store[key] = {"value": value, "ts": self._now()}
super().save(key, value)
self._flush()
@override
def has(self, key: str) -> bool:
return key in self._store and not self._expired(self._store[key])
@override
def get(self, key: str) -> Any:
if key not in self._store or self._expired(self._store[key]):
raise KeyError(key)
return self._store[key]["value"]
@override
def clear(self) -> None:
self._store.clear()
self._flush()
def _expired(self, entry: Mapping[str, Any]) -> bool:
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
return self._is_expired(float(entry.get("ts", 0)))
def resolve_backend(backend: StateBackend | None) -> StateBackend:
+10 -3
View File
@@ -74,6 +74,13 @@ Condition = Callable[[Context], bool]
CacheKeyFn = Callable[[Context], str]
def _format_skip_reason(failed_conditions: list[str]) -> str:
"""格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。"""
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
# ---------------------------------------------------------------------- #
# 重试策略
# ---------------------------------------------------------------------- #
@@ -315,6 +322,7 @@ class TaskSpec(Generic[T]):
-------
(should_run, skip_reason)
``should_run`` False ``skip_reason`` 描述跳过原因
失败条件超过 2 个时仅展示前 2 个并附总数
"""
# 逐个求值条件,记录失败项。
failed_conditions: list[str] = []
@@ -323,8 +331,7 @@ class TaskSpec(Generic[T]):
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
failed_conditions.append("匿名条件(执行错误)")
continue
if not ok:
reason = getattr(condition, "_reason", None)
@@ -336,7 +343,7 @@ class TaskSpec(Generic[T]):
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions:
return False, f"条件不满足: {', '.join(failed_conditions)}"
return False, _format_skip_reason(failed_conditions)
if self.skip_if_missing and not self._is_cmd_available():
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
+9 -3
View File
@@ -6,6 +6,15 @@
from __future__ import annotations
__all__ = [
"clr",
"reset_icon_cache",
"setenv",
"setenv_group",
"which",
"write_file",
]
import os
import subprocess
from pathlib import Path
@@ -111,6 +120,3 @@ def write_file(path: str, content: str, encoding: str = "utf-8") -> px.TaskSpec:
print(f"写入文件 {path} 失败: {e}")
return px.TaskSpec(f"write_file_{path}", fn=write, verbose=True)
__all__ = ["clr", "reset_icon_cache", "setenv", "setenv_group", "which", "write_file"]
+107
View File
@@ -0,0 +1,107 @@
"""常用工具函数."""
from __future__ import annotations
__all__ = ["perf_timer"]
import functools
import logging
import time
from collections import defaultdict
from typing import Callable, TypedDict
try:
from typing_extensions import ParamSpec, TypeVar
except ImportError:
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class _PerformanceMetrics(TypedDict):
"""性能指标."""
count: int
total_time: float
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
lambda: _PerformanceMetrics(
count=0,
total_time=0.0,
)
)
def _generate_report(unit: str, precision: int) -> str:
"""生成性能指标报告,返回报告字符串."""
if not _perf_metrics:
return ""
lines: list[str] = []
lines.append("=" * 50)
lines.append("性能指标报告 (Performance Metrics Report)")
lines.append("-" * 50)
# 按总耗时排序,最耗时的函数排在前面
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
for name, metrics in sorted_metrics:
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
lines.append(
f"{name}: "
f"调用次数={metrics['count']}, "
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
f"平均耗时={avg_time:.{precision}f}{unit}"
)
lines.append("=" * 50)
report_str = "\n".join(lines)
# 同时输出到日志
logging.info("\n".join(lines))
return report_str
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
"""性能计时器装饰器."""
scale: dict[str, float] = {
"s": 1.0,
"ms": 1000.0,
"us": 1000000.0,
}
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
_perf_metrics[func.__name__]["count"] += 1
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
if not report:
logging.info(
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
)
return result
return wrapper
if report:
import atexit
logging.basicConfig(level=logging.INFO)
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
@atexit.register
def _report_at_exit() -> None:
"""在程序退出时报告性能指标."""
_generate_report(unit, precision)
# 将报告生成逻辑提取为独立函数,便于测试
return decorator
+1 -1
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from unittest.mock import patch
import pyflowx as px
from pyflowx.cli import clearscreen
from pyflowx.cli.system import clearscreen
# ---------------------------------------------------------------------- #
-110
View File
@@ -1,110 +0,0 @@
"""Tests for cli.envpy module."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import envpy
# ---------------------------------------------------------------------- #
# set_pip_mirror
# ---------------------------------------------------------------------- #
class TestSetPipMirror:
"""Test set_pip_mirror function."""
def test_set_pip_mirror_tsinghua(self, tmp_path: Path) -> None:
"""Should set tsinghua mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua")
# Check pip config
pip_config = tmp_path / "pip" / "pip.ini"
if envpy.Constants.IS_WINDOWS:
assert pip_config.exists() or (tmp_path / "pip" / "pip.conf").exists()
def test_set_pip_mirror_aliyun(self, tmp_path: Path) -> None:
"""Should set aliyun mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("aliyun")
# Check pip config
pip_dir = tmp_path / "pip"
assert pip_dir.exists()
def test_set_pip_mirror_with_token(self, tmp_path: Path) -> None:
"""Should set mirror with token."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua", token="test_token")
# Check that token is set
def test_set_pip_mirror_creates_pip_dir(self, tmp_path: Path) -> None:
"""Should create pip directory if it doesn't exist."""
pip_dir = tmp_path / "pip"
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua")
assert pip_dir.exists()
assert pip_dir.is_dir()
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_mirror_tsinghua(self) -> None:
"""main() should handle mirror tsinghua command."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.called
def test_main_mirror_aliyun(self) -> None:
"""main() should handle mirror aliyun command."""
with patch("sys.argv", ["envpy", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.called
def test_main_mirror_with_token(self) -> None:
"""main() should handle mirror with token."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua", "--token", "test_token"]), patch.object(
px, "run"
) as mock_run, patch.object(envpy, "set_pip_mirror"):
envpy.main()
assert mock_run.called
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and return."""
with patch("sys.argv", ["envpy"]):
envpy.main()
# Should print help and return
def test_main_invalid_mirror_shows_error(self) -> None:
"""main() with invalid mirror should show error."""
with patch("sys.argv", ["envpy", "mirror", "invalid"]), pytest.raises(SystemExit) as exc_info:
envpy.main()
assert exc_info.value.code == 2
def test_main_creates_task_spec_with_correct_name(self) -> None:
"""main() should create TaskSpec with correct name."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
graph = mock_run.call_args[0][0]
task_names = list(graph.all_specs().keys())
assert "set_pip_mirror" in task_names
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.call_args[1]["strategy"] == "thread"
-210
View File
@@ -1,210 +0,0 @@
"""Tests for cli.envrs module."""
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import pyflowx as px
from pyflowx.cli import envrs
# ---------------------------------------------------------------------- #
# set_rust_mirror
# ---------------------------------------------------------------------- #
class TestSetRustMirror:
"""Test set_rust_mirror function."""
def test_set_rust_mirror_aliyun(self, tmp_path: Path) -> None:
"""Should set aliyun mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
# Check environment variables
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.aliyun.com/rustup"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.aliyun.com/rustup/rustup"
# Check cargo config
cargo_config = tmp_path / ".cargo" / "config.toml"
assert cargo_config.exists()
content = cargo_config.read_text()
assert "aliyun" in content
def test_set_rust_mirror_ustc(self, tmp_path: Path) -> None:
"""Should set ustc mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("ustc")
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.ustc.edu.cn/rust-static"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.ustc.edu.cn/rust-static/rustup"
def test_set_rust_mirror_tsinghua(self, tmp_path: Path) -> None:
"""Should set tsinghua mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("tsinghua")
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.tuna.tsinghua.edu.cn/rustup"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"
def test_set_rust_mirror_unknown_uses_default(self, tmp_path: Path) -> None:
"""Should use default mirror for unknown mirror name."""
with patch.object(Path, "home", return_value=tmp_path):
# pyrefly: ignore [bad-argument-type]
envrs.set_rust_mirror("unknown")
# Should use default mirror (tsinghua)
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.tuna.tsinghua.edu.cn/rustup"
def test_set_rust_mirror_creates_cargo_dir(self, tmp_path: Path) -> None:
"""Should create .cargo directory if it doesn't exist."""
cargo_dir = tmp_path / ".cargo"
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
assert cargo_dir.exists()
assert cargo_dir.is_dir()
def test_set_rust_mirror_prints_message(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
"""Should print mirror name."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
captured = capsys.readouterr()
assert "已设置 Rust 镜像源: aliyun" in captured.out
# ---------------------------------------------------------------------- #
# install_rust
# ---------------------------------------------------------------------- #
class TestInstallRust:
"""Test install_rust function."""
def test_install_rust_stable(self) -> None:
"""Should install stable Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("stable")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "stable"], check=True)
def test_install_rust_nightly(self) -> None:
"""Should install nightly Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("nightly")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "nightly"], check=True)
def test_install_rust_beta(self) -> None:
"""Should install beta Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("beta")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "beta"], check=True)
def test_install_rust_file_not_found(self) -> None:
"""Should raise FileNotFoundError when rustup not found."""
with patch("subprocess.run", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError):
envrs.install_rust("stable")
def test_install_rust_prints_message(self, capsys: pytest.CaptureFixture[str]) -> None:
"""Should print installation message."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("stable")
captured = capsys.readouterr()
assert "已安装 Rust stable" in captured.out
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_mirror_aliyun(self) -> None:
"""main() should handle mirror aliyun command."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_ustc(self) -> None:
"""main() should handle mirror ustc command."""
with patch("sys.argv", ["envrs", "mirror", "ustc"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_tsinghua(self) -> None:
"""main() should handle mirror tsinghua command."""
with patch("sys.argv", ["envrs", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_default(self) -> None:
"""main() should use default mirror when not specified."""
with patch("sys.argv", ["envrs", "mirror"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_install_stable(self) -> None:
"""main() should handle install stable command."""
with patch("sys.argv", ["envrs", "install", "stable"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_nightly(self) -> None:
"""main() should handle install nightly command."""
with patch("sys.argv", ["envrs", "install", "nightly"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_beta(self) -> None:
"""main() should handle install beta command."""
with patch("sys.argv", ["envrs", "install", "beta"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_default(self) -> None:
"""main() should use default version when not specified."""
with patch("sys.argv", ["envrs", "install"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and return."""
with patch("sys.argv", ["envrs"]):
envrs.main()
# Should print help and return
def test_main_invalid_version_shows_error(self) -> None:
"""main() with invalid version should show error."""
with patch("sys.argv", ["envrs", "install", "invalid"]), pytest.raises(SystemExit) as exc_info:
envrs.main()
assert exc_info.value.code == 2
def test_main_invalid_mirror_shows_error(self) -> None:
"""main() with invalid mirror should show error."""
with patch("sys.argv", ["envrs", "mirror", "invalid"]), pytest.raises(SystemExit) as exc_info:
envrs.main()
assert exc_info.value.code == 2
def test_main_creates_task_spec_with_verbose(self) -> None:
"""main() should create TaskSpec with verbose=True."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
graph = mock_run.call_args[0][0]
specs = graph.all_specs()
for spec in specs.values():
assert spec.verbose is True
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.call_args[1]["strategy"] == "thread"
+1 -1
View File
@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import taskkill
from pyflowx.cli.system import taskkill
from pyflowx.conditions import Constants
-66
View File
@@ -1,66 +0,0 @@
"""Tests for cli.which module."""
from __future__ import annotations
import shutil
from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import which
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_with_single_command(self) -> None:
"""main() should handle single command argument."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with one task
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_multiple_commands(self) -> None:
"""main() should handle multiple command arguments."""
with patch("sys.argv", ["which", "python", "pip", "node"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with three tasks
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and exit."""
with patch("sys.argv", ["which"]), pytest.raises(SystemExit) as exc_info:
which.main()
assert exc_info.value.code == 2
def test_main_creates_task_specs_with_correct_names(self) -> None:
"""main() should create TaskSpecs with correct names."""
with patch("sys.argv", ["which", "git", "npm"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
graph = mock_run.call_args[0][0]
# Check that task names are correct
task_names = list(graph.all_specs().keys())
assert "which_git" in task_names
assert "which_npm" in task_names
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
assert mock_run.call_args[1]["strategy"] == "thread"
+65
View File
@@ -0,0 +1,65 @@
import time
import pytest
from pytest_mock import MockerFixture
from pyflowx.utils import _perf_metrics, perf_timer
@pytest.fixture(autouse=True)
def reset_perf_metrics():
"""重置性能指标."""
_perf_metrics.clear()
class TestPerformanceTimer:
def test_perf_timer(self):
@perf_timer()
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
def test_perf_timer_report(self, mocker: MockerFixture):
mock_log = mocker.patch("logging.info")
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert mock_log.call_count == 1
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
mock_log = mocker.patch("logging.info")
from pyflowx.utils import _generate_report
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
@perf_timer(report=True, unit="ms", precision=3)
def test_func2():
time.sleep(0.2)
test_func()
test_func2()
_generate_report("ms", 3)
assert mock_log.call_count == 3
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert _perf_metrics["test_func2"]["count"] == 1
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
Generated
+4 -26
View File
@@ -5603,12 +5603,12 @@ pycountry = [
[[package]]
name = "pyflowx"
version = "0.2.9"
version = "0.2.10"
source = { editable = "." }
dependencies = [
{ name = "graphlib-backport", marker = "python_full_version < '3.9'" },
{ name = "typing-extensions", version = "4.13.2", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version < '3.9'" },
{ name = "typing-extensions", version = "4.15.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version >= '3.9'" },
{ name = "typing-extensions", version = "4.15.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }, marker = "python_full_version == '3.9.*'" },
]
[package.optional-dependencies]
@@ -5658,7 +5658,6 @@ office = [
[package.dev-dependencies]
dev = [
{ name = "pyflowx", extra = ["dev", "llm", "office"] },
{ name = "pysnooper" },
]
[package.metadata]
@@ -5682,15 +5681,12 @@ requires-dist = [
{ name = "sglang", extras = ["all"], marker = "python_full_version >= '3.10' and sys_platform == 'linux' and extra == 'llm'", specifier = "==0.5.10rc0" },
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.25.0" },
{ name = "tox-uv", marker = "extra == 'dev'", specifier = ">=1.13.1" },
{ name = "typing-extensions", specifier = ">=4.13.2" },
{ name = "typing-extensions", marker = "python_full_version < '3.10'", specifier = ">=4.13.2" },
]
provides-extras = ["dev", "llm", "office"]
[package.metadata.requires-dev]
dev = [
{ name = "pyflowx", extras = ["dev", "office", "llm"], editable = "." },
{ name = "pysnooper", specifier = ">=1.2.3" },
]
dev = [{ name = "pyflowx", extras = ["dev", "office", "llm"], editable = "." }]
[[package]]
name = "pygments"
@@ -6007,15 +6003,6 @@ wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/42/3d/4c6bcb3d456835f51445d3662a428f56c3ea5643ec798c577030ae34298c/pyrefly-1.1.1-py3-none-win_arm64.whl", hash = "sha256:83baf0db71e172665db1fca0ced50b8f7773f5192ca57e8ac6773a772b6d2fc5" },
]
[[package]]
name = "pysnooper"
version = "1.2.3"
source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d2/4a/be3c144f58de6b78911c417cc4a3b3fe5eb6d13cae4c12daf3ca17a8d473/pysnooper-1.2.3.tar.gz", hash = "sha256:1fa1425444a7af45108aaed860b5ca8b62b25bba25b0b037c059ba353d8f1e74" }
wheels = [
{ url = "https://mirrors.aliyun.com/pypi/packages/69/87/df62c8a998216e6749b67d548dae0967906036c61457510ef49667927c49/PySnooper-1.2.3-py2.py3-none-any.whl", hash = "sha256:546372f0e72da89f8d1b89e758b7c05a478d65288569a1ca2cc1620e7b1b1944" },
]
[[package]]
name = "pytesseract"
version = "0.3.13"
@@ -8439,22 +8426,13 @@ name = "typing-extensions"
version = "4.15.0"
source = { registry = "https://mirrors.aliyun.com/pypi/simple/" }
resolution-markers = [
"python_full_version >= '3.15' and sys_platform == 'darwin'",
"python_full_version >= '3.15' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"python_full_version >= '3.15' and sys_platform == 'win32'",
"python_full_version >= '3.15' and sys_platform == 'emscripten'",
"(python_full_version >= '3.15' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.15' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
"python_full_version == '3.14.*' and sys_platform == 'darwin'",
"python_full_version == '3.13.*' and sys_platform == 'darwin'",
"python_full_version == '3.12.*' and sys_platform == 'darwin'",
"python_full_version == '3.14.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"python_full_version == '3.14.*' and sys_platform == 'win32'",
"python_full_version == '3.14.*' and sys_platform == 'emscripten'",
"(python_full_version == '3.14.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.14.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
"python_full_version == '3.13.*' and sys_platform == 'win32'",
"python_full_version == '3.13.*' and sys_platform == 'emscripten'",
"(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'emscripten' and sys_platform != 'linux' and sys_platform != 'win32')",
"python_full_version == '3.12.*' and sys_platform == 'win32'",
"python_full_version == '3.12.*' and sys_platform == 'emscripten'",