24 Commits

Author SHA1 Message Date
zhou c15b38516a bump version to 0.2.11
Release / build (push) Failing after 29m15s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
2026-06-27 23:08:32 +08:00
zhou 7d4e8a40ce refactor(cli): 重构CLI模块结构,整理系统工具与开发工具
1. 将原cli根目录下的clearscreen、taskkill、which工具迁移到cli/system子目录
2. 新增cli/dev子目录并添加envdev环境配置工具
3. 更新pyproject.toml中的脚本入口点映射
4. 调整tests/cli下的测试文件导入路径
5. 整理tasks/system.py的__all__导出顺序
2026-06-27 22:01:02 +08:00
zhou 1b2d6d6a2c chore: 更新依赖配置并移除 pysnooper 2026-06-27 21:53:20 +08:00
zhou df890f0f16 chore: 移除独立的envpy和envrs命令,合并功能到envdev
将原来envpy和envrs的环境配置功能整合到envdev命令中,删除了冗余的独立CLI模块和测试文件,统一管理Python、Conda和Rust的环境配置。
2026-06-27 21:22:36 +08:00
zhou b62a544569 chore: 调整Python版本与依赖适配,新增性能报告测试与工具函数
1.  将Python版本从3.13降级到3.11
2.  为typing-extensions添加版本适配标记
3.  简化dev依赖组,移除pysnooper
4.  重构perf_timer,提取_generate_report独立函数
5.  新增性能报告生成与测试用例
2026-06-27 20:47:29 +08:00
zhou d58fc5536e chore: 发布 pyflowx 0.2.10,新增性能计时器与多项重构
1. 新增 perf_timer 工具与配套测试用例
2. 重构任务条件跳过逻辑,优化失败条件展示
3. 重构 Graph 子图生成逻辑,提取公共依赖修剪函数
4. 重构条件模块,统一条件名称与失败原因获取逻辑
5. 重构存储后端,提取 TTL 共享逻辑并优化实现
6. 重构执行器模块,使用 Mixin 复用代码,拆分任务与层执行逻辑
7. 删除冗余的 which 命令测试文件
8. 更新依赖锁文件
2026-06-27 20:15:35 +08:00
zhou c3b86b603d bump version to 0.2.10
Release / build (push) Failing after 11m58s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
2026-06-27 19:41:24 +08:00
zhou 327bd6e069 feat: 优化条件不满足时的报错信息展示
1. 新增格式化reason的工具函数统一处理报错信息
2. 支持从条件函数中提取自定义的失败原因
3. 完善NOT和OR条件的失败原因传递逻辑
4. 移除任务跳过的冗余打印输出
2026-06-27 19:40:51 +08:00
zhou 22f8d2110d chore: add pysnooper dev dependency and update configs
1. add pysnooper>=1.2.3 to dev dependencies in pyproject.toml and uv.lock
2. update type hints in task.py from Iterator to Generator
3. add more PyPI mirrors and update envdev.py comments and checks
4. fix trailing whitespace in executors.py
2026-06-27 19:35:11 +08:00
zhou 2a1f2f7175 refactor(envdev, conditions): 重构环境配置脚本,新增平台和文件条件检查
1. 移除废弃的envqt命令入口
2. 新增IS_WINDOWS、IS_LINUX等平台检测条件
3. 新增FILE_CONTENT_EXISTS文件内容检查条件
4. 使用内置条件替代硬编码的平台判断
5. 为任务添加条件控制,仅在符合场景时执行
2026-06-27 18:29:40 +08:00
zhou 9d033e1c0b refactor(system): add setenv_group and write_file task helpers
1. 为setenv和which函数添加正确的返回类型注解
2. 新增setenv_group批量设置环境变量的任务组
3. 新增write_file写入文件的任务工具函数
4. 更新__all__导出所有新增的工具函数

feat(cli/envdev): rewrite envdev cli with proper config and args
1. 重构环境开发CLI脚本,使用argparse替换原有TypedDict配置
2. 新增Python和Conda镜像源选择参数
3. 自动生成并写入Python pip和Conda配置文件
4. 优化任务依赖和命名,统一使用系统工具函数
2026-06-27 17:12:53 +08:00
zhou 336f7b7292 -envqt 2026-06-27 16:45:02 +08:00
zhou 65dcbcbf62 bump version to 0.2.9
Release / build (push) Failing after 16m3s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
2026-06-27 16:42:10 +08:00
zhou 7fa97a01e3 test(executors): add future annotations import to edge case test file
为测试文件添加from __future__ import annotations以支持更规范的类型注解写法
2026-06-27 16:33:24 +08:00
zhou 83da5135d0 test: add tests for graph all_deps and defaults inheritance
- add test_all_deps_combines_hard_and_soft to verify all_deps returns correct hard+soft deps in order
- add multiple tests for GraphDefaults field inheritance, including normal inheritance and non-override of custom values
2026-06-27 16:32:34 +08:00
zhou 7463a60649 test: 修复代码检查警告并优化测试用例
1. 为测试代码添加pyrefly忽略注释解决类型检查警告
2. 优化lambda参数命名为通配符符合PEP8规范
3. 增加断言检查任务函数非空并修正参数传递
4. 统一环境变量测试的命名和清理逻辑
2026-06-27 16:26:56 +08:00
zhou 87dd010342 test: add multiple new test cases and update python version
1. update .python-version from 3.11 to 3.13
2. add tests for IS_RUNNING and DIR_EXISTS conditions
3. add graph-related tests including string ref parsing, mermaid output, GraphComposer and compose function
4. add storage backend TTL tests for both MemoryBackend and JSONBackend
5. add new system task tests for clr, reset_icon_cache, setenv and which
6. add comprehensive task spec tests including soft dependencies, retry policy, context managers and task template
7. add executor edge case tests for various scenarios
2026-06-27 16:17:05 +08:00
zhou bdfee7bee4 ci: 简化CI/CD配置,移除冗余测试步骤和覆盖率上报
重构了GitHub Actions工作流,合并重复的CI任务,移除了预发布测试环节、多余的格式检查和安全审计任务,精简了 tox 测试命令与矩阵配置,同时删除了本地 tox 配置中的覆盖率和测试结果上报参数,优化整体流水线效率。
2026-06-27 16:00:44 +08:00
zhou b954fb1622 build(coverage): 调整coverage配置,新增cli目录到忽略白名单并提高达标阈值至95%
修改了pyproject.toml中的coverage配置:将src/pyflowx/cli/*加入omit排除列表,同时将测试覆盖率达标阈值从80提升至95
2026-06-27 15:57:00 +08:00
zhou a7b7a82dff ci: 完善CI/CD流程,添加测试覆盖率与并行测试配置
1. 为tox测试命令添加并行执行、覆盖率报告和JUnit结果输出
2. 拆分CI工作流为lint、格式检查、类型检查、安全审计、多矩阵测试和覆盖率汇总
3. 新增release前的预测试步骤,让build依赖测试通过
4. 移除低效的依赖策略测速测试用例
5. 配置多Python版本跨平台测试矩阵并上传测试 artifacts
2026-06-27 15:53:08 +08:00
zhou 40f0478146 bump version to 0.2.8
Release / build (push) Failing after 31s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
2026-06-27 15:44:09 +08:00
zhou b808b880f8 ci(github workflow): simplify release workflow
移除了冗余的预检查步骤、简化了工作流配置,更新了action版本并优化了版本提取和产物处理逻辑
2026-06-27 15:43:55 +08:00
zhou e073ff41ee ci: simplify and merge CI jobs
1. 合并lint和typecheck任务为一个job,减少重复的环境配置步骤
2. 精简测试矩阵,只保留Python3.8和3.13两个版本
3. 移除不必要的覆盖率上传和聚合检查job
4. 简化工作流触发条件,只保留push和手动触发
2026-06-27 15:43:24 +08:00
zhou ea0c51de5e build: 调整llm依赖条件并更新pyflowx版本
1. 为llm依赖添加linux平台限制
2. 移除uv.lock中的前置发布版本配置项
3. 将pyflowx版本从0.2.6升级到0.2.7
2026-06-27 15:33:33 +08:00
36 changed files with 2970 additions and 3722 deletions
+17 -96
View File
@@ -3,127 +3,48 @@ name: CI
on:
push:
branches: [ main, develop ]
pull_request:
branches: [ main, develop ]
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
# ─────────────────────────────────────────────────────────────
# lint:代码风格与格式检查(单平台即可)
# ─────────────────────────────────────────────────────────────
lint:
name: Lint (ruff)
lint-and-typecheck:
name: Lint & Typecheck
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
- uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
cache-dependency-glob: uv.lock
- name: 设置 Python 3.13
uses: actions/setup-python@v5
- uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: 安装依赖
run: uv sync --extra dev --frozen
- run: uv sync
- run: uv run ruff check src tests
- run: uv run pyrefly check .
- name: Ruff 检查
run: uv run ruff check src tests
# ─────────────────────────────────────────────────────────────
# typecheckpyrefly 严格类型检查
# ─────────────────────────────────────────────────────────────
typecheck:
name: Typecheck (pyrefly)
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
cache-dependency-glob: uv.lock
- name: 设置 Python 3.13
uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: 安装依赖
run: uv sync --extra dev --frozen
- name: pyrefly 严格类型检查
run: uv run pyrefly check .
# ─────────────────────────────────────────────────────────────
# test:多平台 × 多 Python 版本矩阵测试 + 覆盖率
# ─────────────────────────────────────────────────────────────
test:
name: Test (${{ matrix.os }} / py${{ matrix.python-version }})
name: Test (${{ matrix.os }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, windows-latest, macos-latest ]
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12', '3.13' ]
os: [ubuntu-latest, windows-latest, macos-latest]
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
- uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
cache-dependency-glob: uv.lock
- name: 设置 Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: |
3.8
3.13
- name: 安装依赖
run: uv sync --extra dev --frozen
- name: 运行测试
run: uv run pytest -v --cov=pyflowx --cov-report=xml --cov-report=term-missing
- name: 上传覆盖率
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.13'
uses: actions/upload-artifact@v4
with:
name: coverage-${{ matrix.os }}-py${{ matrix.python-version }}
path: coverage.xml
retention-days: 7
# ─────────────────────────────────────────────────────────────
# 聚合:所有检查通过后才标记完成
# ─────────────────────────────────────────────────────────────
ci-pass:
name: CI Pass
runs-on: ubuntu-latest
needs: [ lint, typecheck, test ]
if: always()
steps:
- name: 检查依赖任务结果
if: ${{ needs.lint.result != 'success' || needs.typecheck.result != 'success' || needs.test.result != 'success' }}
run: |
echo "lint: ${{ needs.lint.result }}"
echo "typecheck: ${{ needs.typecheck.result }}"
echo "test: ${{ needs.test.result }}"
exit 1
- name: 全部通过
run: echo "✅ 所有 CI 检查通过"
- run: uvx tox run -e py38,py313
+21 -153
View File
@@ -2,192 +2,60 @@ name: Release
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
tag:
description: '发布版本号(如 v0.1.0'
required: true
type: string
tags: ['v*.*.*']
permissions:
contents: write
# Trusted Publishing (OIDC) 上传 PyPI 所需
id-token: write
jobs:
# ─────────────────────────────────────────────────────────────
# 预检:版本号校验 + 与 pyproject.toml 一致性检查
# ─────────────────────────────────────────────────────────────
pre-check:
name: Pre-release Check
build:
runs-on: ubuntu-latest
outputs:
version: ${{ steps.meta.outputs.version }}
tag: ${{ steps.meta.outputs.tag }}
version: ${{ steps.version.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v4
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
with:
fetch-depth: 0
- name: 解析版本号
id: meta
run: |
if [ -n "${{ inputs.tag }}" ]; then
TAG="${{ inputs.tag }}"
else
TAG="${GITHUB_REF#refs/tags/}"
fi
# 去除前缀 v
VERSION="${TAG#v}"
echo "tag=$TAG" >> $GITHUB_OUTPUT
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "发布版本: $VERSION (tag: $TAG)"
- name: 校验版本号格式
run: |
VERSION="${{ steps.meta.outputs.version }}"
if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
echo "❌ 版本号格式错误: $VERSION(应为 x.y.z 或 x.y.z-rc.n"
exit 1
fi
- name: 校验 pyproject.toml 版本一致
run: |
# 精确提取 [project] 段的 version 字段(避免匹配到依赖的 version)
PY_VERSION=$(awk '/^\[project\]/{f=1} f&&/^version[[:space:]]*=/{gsub(/[" ]/,"",$3); print $3; exit}' pyproject.toml)
echo "pyproject.toml version: $PY_VERSION"
if [ "$PY_VERSION" != "${{ steps.meta.outputs.version }}" ]; then
echo "❌ pyproject.toml 版本($PY_VERSION) 与 tag 版本(${{ steps.meta.outputs.version }}) 不一致"
echo "请先更新 pyproject.toml 中的 version 字段"
exit 1
fi
# ─────────────────────────────────────────────────────────────
# 构建:wheel + sdist(纯 Python,单平台即可)
# ─────────────────────────────────────────────────────────────
build:
name: Build Artifacts
needs: pre-check
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
- name: 设置 Python 3.13
uses: actions/setup-python@v5
- uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: 安装依赖
run: uv sync --extra dev --frozen
- run: uv build
- name: 构建 wheel + sdist
run: uv build
- id: version
run: echo "version=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT
- name: 校验产物
run: |
echo "待上传产物:"
ls -la dist/
if [ -z "$(ls -A dist/*.whl dist/*.tar.gz 2>/dev/null)" ]; then
echo "❌ 未找到 wheel 或 sdist 产物"
exit 1
fi
- name: 上传构建产物
uses: actions/upload-artifact@v4
- uses: actions/upload-artifact@v7
with:
name: dist
path: dist/*
retention-days: 30
path: dist/
# ─────────────────────────────────────────────────────────────
# 发布:上传到 PyPITrusted Publishing / OIDC
# ─────────────────────────────────────────────────────────────
publish-pypi:
name: Publish to PyPI
needs: [pre-check, build]
needs: build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/project/pyflowx/${{ needs.pre-check.outputs.version }}
permissions:
id-token: write
environment: pypi
steps:
- name: 下载构建产物
uses: actions/download-artifact@v4
- uses: actions/download-artifact@v8
with:
name: dist
path: dist
- name: 上传到 PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
attestations: true
- uses: pypa/gh-action-pypi-publish@release/v1
# ─────────────────────────────────────────────────────────────
# 发布:创建 GitHub Release
# ─────────────────────────────────────────────────────────────
release:
name: Publish Release
needs: [pre-check, build, publish-pypi]
needs: [build, publish-pypi]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 下载构建产物
uses: actions/download-artifact@v4
- uses: actions/download-artifact@v8
with:
name: dist
path: assets
path: dist
- name: 整理发布产物
run: |
ls -la assets/
- name: 生成 Release Notes
id: notes
run: |
{
echo "## pyflowx ${{ needs.pre-check.outputs.version }}"
echo ""
echo "### 下载"
echo ""
echo "- **Wheel**: \`pyflowx-${{ needs.pre-check.outputs.version }}-py3-none-any.whl\`"
echo "- **源码包**: \`pyflowx-${{ needs.pre-check.outputs.version }}.tar.gz\`"
echo ""
echo "### 安装"
echo ""
echo '```bash'
echo "pip install pyflowx==${{ needs.pre-check.outputs.version }}"
echo '```'
echo ""
echo "### 完整变更日志"
} > RELEASE_NOTES.md
{
echo "content<<EOF"
cat RELEASE_NOTES.md
echo "EOF"
} >> $GITHUB_OUTPUT
- name: 创建 GitHub Release
uses: softprops/action-gh-release@v2
- uses: softprops/action-gh-release@v2
with:
tag_name: ${{ needs.pre-check.outputs.tag }}
name: pyflowx ${{ needs.pre-check.outputs.version }}
body: ${{ steps.notes.outputs.content }}
files: assets/*
draft: false
prerelease: ${{ contains(needs.pre-check.outputs.version, '-') }}
files: dist/*
generate_release_notes: true
+13 -12
View File
@@ -13,7 +13,7 @@ classifiers = [
]
dependencies = [
"graphlib_backport >= 1.0.0; python_version < '3.9'",
"typing-extensions>=4.13.2",
"typing-extensions>=4.13.2; python_version < '3.10'",
]
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
keywords = ["async", "dag", "scheduler", "task", "workflow"]
@@ -21,17 +21,12 @@ license = { text = "MIT" }
name = "pyflowx"
readme = "README.md"
requires-python = ">=3.8"
version = "0.2.7"
version = "0.2.11"
[project.scripts]
autofmt = "pyflowx.cli.autofmt:main"
bumpversion = "pyflowx.cli.bumpversion:main"
clr = "pyflowx.cli.clearscreen:main"
emlman = "pyflowx.cli.emlmanager:main"
envdev = "pyflowx.cli.envdev:main"
envpy = "pyflowx.cli.envpy:main"
envqt = "pyflowx.cli.envqt:main"
envrs = "pyflowx.cli.envrs:main"
filedate = "pyflowx.cli.filedate:main"
filelvl = "pyflowx.cli.filelevel:main"
foldback = "pyflowx.cli.folderback:main"
@@ -47,8 +42,12 @@ reseticon = "pyflowx.cli.reseticoncache:main"
scrcap = "pyflowx.cli.screenshot:main"
sglang = "pyflowx.cli.llm.sglang:main"
sshcopy = "pyflowx.cli.sshcopyid:main"
taskk = "pyflowx.cli.taskkill:main"
wch = "pyflowx.cli.which:main"
# dev
envdev = "pyflowx.cli.dev.envdev:main"
# system
clr = "pyflowx.cli.system.clearscreen:main"
taskk = "pyflowx.cli.system.taskkill:main"
wch = "pyflowx.cli.system.which:main"
[project.optional-dependencies]
dev = [
@@ -66,7 +65,9 @@ dev = [
"tox-uv>=1.13.1",
"tox>=4.25.0",
]
llm = ["sglang[all]==0.5.10rc0; python_version >= '3.10'"]
llm = [
"sglang[all]==0.5.10rc0; python_version >= '3.10' and sys_platform == 'linux'",
]
office = [
"pillow>=10.4.0",
"pymupdf>=1.24.11",
@@ -97,7 +98,7 @@ dev = ["pyflowx[dev,office,llm]"]
[tool.coverage.run]
branch = true
concurrency = ["thread"]
omit = ["src/pyflowx/examples/*", "tests/*"]
omit = ["src/pyflowx/cli/*", "src/pyflowx/examples/*", "tests/*"]
source = ["pyflowx"]
[tool.coverage.report]
@@ -107,7 +108,7 @@ exclude_lines = [
"pragma: no cover",
"raise NotImplementedError",
]
fail_under = 80
fail_under = 95
show_missing = true
[tool.pytest.ini_options]
+1 -1
View File
@@ -95,7 +95,7 @@ from .task import (
task_template,
)
__version__ = "0.3.1"
__version__ = "0.3.5"
__all__ = [
"IS_LINUX",
View File
+331
View File
@@ -0,0 +1,331 @@
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Literal, get_args
import pyflowx as px
from pyflowx.conditions import BuiltinConditions
from pyflowx.tasks.system import setenv_group, write_file
# ============================================================================
# Mirror 配置
# ============================================================================
DOWNLOAD_MIRROR_SCRIPT: str = "curl -sSL https://linuxmirrors.cn/main.sh -o /tmp/linuxmirrors.sh"
INSTALL_MIRROR_SCRIPT: str = "sudo bash /tmp/linuxmirrors.sh"
# ============================================================================
# Python 配置
# ============================================================================
PyMirrorType = Literal["tsinghua", "aliyun", "huaweicloud", "ustc", "zju"]
PIP_INDEX_URLS: dict[PyMirrorType, str] = {
"tsinghua": "https://pypi.tuna.tsinghua.edu.cn/simple",
"aliyun": "https://mirrors.aliyun.com/pypi/simple/",
"huaweicloud": "https://mirrors.huaweicloud.com/repository/pypi/simple/",
"ustc": "https://pypi.mirrors.ustc.edu.cn/simple/",
"zju": "https://mirrors.zju.edu.cn/pypi/simple/",
}
PIP_TRUSTED_HOSTS: dict[PyMirrorType, str] = {
"tsinghua": "pypi.tuna.tsinghua.edu.cn",
"aliyun": "mirrors.aliyun.com",
"huaweicloud": "mirrors.huaweicloud.com",
"ustc": "pypi.mirrors.ustc.edu.cn",
"zju": "mirrors.zju.edu.cn",
}
PIP_CONFIG_PATH = Path.home() / ".pip" / "pip.conf" if BuiltinConditions.IS_LINUX() else Path.home() / "pip" / "pip.ini"
UV_INDEX_URLS = PIP_INDEX_URLS
UV_PYTHON_INSTALL_MIRROR: str = "https://registry.npmmirror.com/-/binary/python-build-standalone"
# ============================================================================
# Conda 配置
# ============================================================================
CondaMirrorType = Literal["tsinghua", "ustc", "bsfu", "aliyun"]
CONDA_MIRROR_URLS: dict[CondaMirrorType, list[str]] = {
"tsinghua": [
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/",
],
"ustc": [
"https://mirrors.ustc.edu.cn/anaconda/pkgs/main/",
"https://mirrors.ustc.edu.cn/anaconda/pkgs/free/",
"https://mirrors.ustc.edu.cn/anaconda/pkgs/r/",
"https://mirrors.ustc.edu.cn/anaconda/pkgs/msys2/",
"https://mirrors.ustc.edu.cn/anaconda/pkgs/pro/",
"https://mirrors.ustc.edu.cn/anaconda/pkgs/dev/",
"https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/",
"https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/",
"https://mirrors.ustc.edu.cn/anaconda/cloud/menpo/",
"https://mirrors.ustc.edu.cn/anaconda/cloud/pytorch/",
],
"bsfu": [
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/main/",
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/free/",
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/r/",
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/msys2/",
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/pro/",
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/dev/",
"https://mirrors.bsfu.edu.cn/anaconda/cloud/conda-forge/",
"https://mirrors.bsfu.edu.cn/anaconda/cloud/bioconda/",
"https://mirrors.bsfu.edu.cn/anaconda/cloud/menpo/",
"https://mirrors.bsfu.edu.cn/anaconda/cloud/pytorch/",
],
"aliyun": [
"https://mirrors.aliyun.com/anaconda/pkgs/main/",
"https://mirrors.aliyun.com/anaconda/pkgs/free/",
"https://mirrors.aliyun.com/anaconda/pkgs/r/",
"https://mirrors.aliyun.com/anaconda/pkgs/msys2/",
"https://mirrors.aliyun.com/anaconda/pkgs/pro/",
"https://mirrors.aliyun.com/anaconda/pkgs/dev/",
"https://mirrors.aliyun.com/anaconda/cloud/conda-forge/",
"https://mirrors.aliyun.com/anaconda/cloud/bioconda/",
"https://mirrors.aliyun.com/anaconda/cloud/menpo/",
"https://mirrors.aliyun.com/anaconda/cloud/pytorch/",
],
}
CONDA_CONFIG_PATH = Path.home() / ".condarc"
# ============================================================================
# Qt 配置
# ============================================================================
QT_LIBS: list[str] = [
"build-essential",
"libgl1",
"libegl1",
"libglib2.0-0",
"libfontconfig1",
"libfreetype6",
"libxkbcommon0",
"libdbus-1-3",
"libxcb-xinerama0",
"libxcb-icccm4",
"libxcb-image0",
"libxcb-keysyms1",
"libxcb-randr0",
"libxcb-render-util0",
"libxcb-shape0",
"libxcb-xfixes0",
"libxcb-cursor0",
]
CHINESE_FONTS: list[str] = [
"fonts-noto-cjk",
"fonts-wqy-microhei",
"fonts-wqy-zenhei",
"fonts-noto-color-emoji",
]
# ============================================================================
# Rust 配置
# ============================================================================
RustMirrorType = Literal["tsinghua", "ustc", "aliyun"]
RustVersionType = Literal["stable", "nightly", "beta"]
DEFAULT_RUST_VERSION: RustVersionType = "stable"
DEFAULT_MIRROR: RustMirrorType = "tsinghua"
RUSTUP_MIRRORS: dict[RustMirrorType, dict[str, str]] = {
"tsinghua": {
"RUSTUP_DIST_SERVER": "https://mirrors.tuna.tsinghua.edu.cn/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/",
},
"aliyun": {
"RUSTUP_DIST_SERVER": "https://mirrors.aliyun.com/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.aliyun.com/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.aliyun.com/crates.io-index/",
},
"ustc": {
"RUSTUP_DIST_SERVER": "https://mirrors.ustc.edu.cn/rust-static",
"RUSTUP_UPDATE_ROOT": "https://mirrors.ustc.edu.cn/rust-static/rustup",
"TOML_REGISTRY": "https://mirrors.ustc.edu.cn/crates.io-index/",
},
}
RUSTUP_DOWNLOAD_URL_LINUX = "https://mirrors.aliyun.com/repo/rust/rustup-init.sh"
RUSTUP_DOWNLOAD_URL_WINDOWS = "https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe"
RUST_CONFIG_PATH = Path.home() / ".cargo" / "config.toml"
RUST_SCCACHE_DIR: Path = Path.home() / ".cargo" / "sccache"
RUST_SCCACHE_CACHE_SIZE: str = "20G"
def main() -> None:
"""主函数."""
parser = argparse.ArgumentParser(description="环境开发工具")
parser.add_argument(
"--python-mirror",
nargs="?",
type=str,
default="tsinghua",
choices=get_args(PyMirrorType),
help="Python 镜像源",
)
parser.add_argument(
"--conda-mirror",
nargs="?",
type=str,
default="tsinghua",
choices=get_args(CondaMirrorType),
help="Conda 镜镜像源",
)
parser.add_argument(
"--rust-mirror",
nargs="?",
type=str,
default=DEFAULT_MIRROR,
choices=get_args(RustMirrorType),
help="Rust 镜像源",
)
parser.add_argument(
"--rust-version",
nargs="?",
type=str,
default=DEFAULT_RUST_VERSION,
choices=get_args(RustVersionType),
help=f"Rust 版本, 推荐: {get_args(RustVersionType)}",
)
args = parser.parse_args()
python_mirror = args.python_mirror
conda_mirror_urls = CONDA_MIRROR_URLS[args.conda_mirror]
rust_mirror = args.rust_mirror
rust_version = args.rust_version
# 确保配置文件目录存在
PIP_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
CONDA_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
RUST_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
RUST_SCCACHE_DIR.mkdir(parents=True, exist_ok=True)
# 使用 conditions 自动控制任务执行
graph = px.Graph.from_specs([
# 系统镜像配置(仅 Linux 且未配置国内镜像)
px.TaskSpec(
"download_mirror",
cmd=DOWNLOAD_MIRROR_SCRIPT,
conditions=(
BuiltinConditions.IS_LINUX(),
BuiltinConditions.NOT(
BuiltinConditions.OR(
*[
BuiltinConditions.FILE_CONTENT_EXISTS(f, m)
for f in [
"/etc/apt/sources.list",
"/etc/apt/sources.list.d/ubuntu.sources",
]
for m in get_args(PyMirrorType)
],
)
),
),
verbose=True,
),
px.TaskSpec(
"install_mirror",
cmd=INSTALL_MIRROR_SCRIPT,
depends_on=("download_mirror",),
verbose=True,
),
# 安装 Qt 依赖(仅 Linux
px.TaskSpec(
"install_qt_libs",
cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
conditions=(BuiltinConditions.IS_LINUX(),),
depends_on=("install_mirror",),
allow_upstream_skip=True,
verbose=True,
),
# 安装中文字体(仅 Linux
px.TaskSpec(
"install_fonts",
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(BuiltinConditions.IS_LINUX(),),
depends_on=("install_mirror",),
allow_upstream_skip=True,
verbose=True,
),
# 设置 Python 环境变量
*setenv_group({
"PIP_INDEX_URL": PIP_INDEX_URLS[python_mirror],
"PIP_TRUSTED_HOSTS": PIP_TRUSTED_HOSTS[python_mirror],
"UV_INDEX_URL": UV_INDEX_URLS[python_mirror],
"UV_PYTHON_INSTALL_MIRROR": UV_PYTHON_INSTALL_MIRROR,
"UV_HTTP_TIMEOUT": "600",
"UV_LINK_MODE": "copy",
}),
# 写入 Python 配置(仅当未配置)
write_file(
str(PIP_CONFIG_PATH),
f"[global]\nindex-url = {PIP_INDEX_URLS[python_mirror]}\ntrusted-host = {PIP_TRUSTED_HOSTS[python_mirror]}",
),
# 写入 Conda 配置(仅当未配置)
write_file(
str(CONDA_CONFIG_PATH),
"show_channel_urls: true\nchannels:\n - " + "\n - ".join(conda_mirror_urls) + "\n - defaults",
),
# 设置 Rust 镜像源
*setenv_group({
"RUSTUP_DIST_SERVER": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_DIST_SERVER"],
"RUSTUP_UPDATE_ROOT": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_UPDATE_ROOT"],
"RUST_SCCACHE_DIR": str(RUST_SCCACHE_DIR),
"RUST_SCCACHE_CACHE_SIZE": RUST_SCCACHE_CACHE_SIZE,
}),
# 写入 Rust 配置(仅当未配置)
write_file(
str(RUST_CONFIG_PATH),
f"""
[source.crates-io]
replace-with = '{rust_mirror}'
[source.{rust_mirror}]
registry = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
[registries.{rust_mirror}]
index = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
""",
),
# 下载 Rustup 安装脚本
px.TaskSpec(
"download_rustup",
cmd=["curl", "-fsSL", RUSTUP_DOWNLOAD_URL_LINUX, "-o", "rustup-init.sh"],
conditions=(BuiltinConditions.IS_LINUX(), BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup"))),
verbose=True,
),
px.TaskSpec(
"download_rustup_win",
cmd=[
"powershell",
"-Command",
"Invoke-WebRequest",
"-Uri",
RUSTUP_DOWNLOAD_URL_WINDOWS,
"-OutFile",
"rustup-init.exe",
],
conditions=(
BuiltinConditions.IS_WINDOWS(),
BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup")),
),
verbose=True,
),
# 安装 Rust 工具链
px.TaskSpec(
"install_rust",
cmd=["rustup", "toolchain", "install", rust_version],
conditions=(BuiltinConditions.HAS_INSTALLED("rustup"),),
depends_on=("setenv_rustup_dist_server",),
allow_upstream_skip=True,
verbose=True,
),
])
px.run(graph, strategy="thread", verbose=True)
-59
View File
@@ -1,59 +0,0 @@
from typing import TypedDict
import pyflowx as px
class EnvConfig(TypedDict):
"""环境配置项."""
name: str
value: str
description: str
PIP_INDEX_URL_CONFIG: EnvConfig = {
"name": "PIP_INDEX_URL",
"value": "https://pypi.tuna.tsinghua.edu.cn/simple",
"description": "PIP索引URL",
}
# ============================================================================
# 配置
# ============================================================================
PIP_INDEX_URLS: dict[str, str] = {
"tsinghua": "https://pypi.tuna.tsinghua.edu.cn/simple",
"aliyun": "https://mirrors.aliyun.com/pypi/simple/",
}
PIP_TRUSTED_HOSTS: dict[str, str] = {
"tsinghua": "pypi.tuna.tsinghua.edu.cn",
"aliyun": "mirrors.aliyun.com",
}
UV_INDEX_URL: str = "https://mirrors.aliyun.com/pypi/simple/"
UV_PYTHON_INSTALL_MIRROR: str = "https://registry.npmmirror.com/-/binary/python-build-standalone"
CONDA_MIRROR_URLS: dict[str, list[str]] = {
"tsinghua": [
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
],
"aliyun": [
"https://mirrors.aliyun.com/anaconda/pkgs/main/",
"https://mirrors.aliyun.com/anaconda/pkgs/free/",
"https://mirrors.aliyun.com/anaconda/cloud/conda-forge/",
],
}
def main() -> None:
"""主函数."""
# 使用更安全的分步执行方式,便于调试和捕获错误
graph = px.Graph.from_specs([
px.TaskSpec("download", cmd="curl -sSL https://linuxmirrors.cn/main.sh -o /tmp/linuxmirrors.sh", verbose=True),
px.TaskSpec("install", cmd="sudo bash /tmp/linuxmirrors.sh", verbose=True, depends_on=("download",)),
])
px.run(graph, strategy="thread")
-122
View File
@@ -1,122 +0,0 @@
"""Python 环境配置工具.
用于设置 pip 镜像源, 支持清华和阿里云等国内镜像源,
同时配置 UV 和 Conda 的镜像源.
"""
from __future__ import annotations
import argparse
import os
from pathlib import Path
import pyflowx as px
from pyflowx.conditions import Constants
# ============================================================================
# 配置
# ============================================================================
PIP_INDEX_URLS: dict[str, str] = {
"tsinghua": "https://pypi.tuna.tsinghua.edu.cn/simple",
"aliyun": "https://mirrors.aliyun.com/pypi/simple/",
}
PIP_TRUSTED_HOSTS: dict[str, str] = {
"tsinghua": "pypi.tuna.tsinghua.edu.cn",
"aliyun": "mirrors.aliyun.com",
}
UV_INDEX_URL: str = "https://mirrors.aliyun.com/pypi/simple/"
UV_PYTHON_INSTALL_MIRROR: str = "https://registry.npmmirror.com/-/binary/python-build-standalone"
CONDA_MIRROR_URLS: dict[str, list[str]] = {
"tsinghua": [
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
],
"aliyun": [
"https://mirrors.aliyun.com/anaconda/pkgs/main/",
"https://mirrors.aliyun.com/anaconda/pkgs/free/",
"https://mirrors.aliyun.com/anaconda/cloud/conda-forge/",
],
}
# ============================================================================
# 辅助函数
# ============================================================================
def set_pip_mirror(mirror: str = "tsinghua", token: str | None = None) -> None:
"""设置 pip 镜像源.
Parameters
----------
mirror : str
镜像源名称: tsinghua, aliyun
token : str | None
PyPI token for publishing
"""
index_url = PIP_INDEX_URLS.get(mirror, PIP_INDEX_URLS["tsinghua"])
trusted_host = PIP_TRUSTED_HOSTS.get(mirror, "")
# 设置环境变量
os.environ["PIP_INDEX_URL"] = index_url
os.environ["UV_INDEX_URL"] = UV_INDEX_URL
os.environ["UV_DEFAULT_INDEX"] = UV_INDEX_URL
os.environ["UV_PYTHON_INSTALL_MIRROR"] = UV_PYTHON_INSTALL_MIRROR
# 写入 pip 配置文件
pip_dir = Path.home() / "pip"
pip_dir.mkdir(exist_ok=True)
pip_conf = pip_dir / ("pip.ini" if Constants.IS_WINDOWS else "pip.conf")
pip_conf.write_text(f"[global]\nindex-url = {index_url}\n[install]\ntrusted-host = {trusted_host}\n")
# 写入 conda 配置文件
condarc = Path.home() / ".condarc"
conda_urls = CONDA_MIRROR_URLS.get(mirror, CONDA_MIRROR_URLS["tsinghua"])
condarc.write_text(
"show_channel_urls: true\nchannels:\n" + "\n".join(f" - {url}" for url in conda_urls) + "\n - defaults\n"
)
# 写入 pypirc 配置文件 (如果有 token)
if token:
pypirc = Path.home() / ".pypirc"
pypirc.write_text(
f"[pypi]\nrepository: https://upload.pypi.org/legacy/\nusername: __token__\npassword: {token}\n"
)
print(f"已设置 pip 镜像源: {mirror} ({index_url})")
# ============================================================================
# CLI Runner
# ============================================================================
def main() -> None:
"""Python 环境配置工具主函数."""
parser = argparse.ArgumentParser(
description="EnvPy - Python 环境配置工具",
usage="envpy <command> [options]",
)
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 设置镜像源命令
mirror_parser = subparsers.add_parser("mirror", help="设置 pip 镜像源")
mirror_parser.add_argument("name", choices=["tsinghua", "aliyun"], help="镜像源名称")
mirror_parser.add_argument("--token", type=str, help="PyPI token for publishing")
args = parser.parse_args()
if args.command == "mirror":
graph = px.Graph.from_specs([
px.TaskSpec("set_pip_mirror", fn=set_pip_mirror, args=(args.name,), kwargs={"token": args.token})
])
else:
parser.print_help()
return
px.run(graph, strategy="thread")
-57
View File
@@ -1,57 +0,0 @@
"""PyQt 环境配置工具.
用于设置 PyQt 相关环境变量, 安装依赖环境.
"""
from __future__ import annotations
import pyflowx as px
from pyflowx.conditions import Constants
QT_LIBS: list[str] = [
"build-essential",
"libgl1",
"libegl1",
"libglib2.0-0",
"libfontconfig1",
"libfreetype6",
"libxkbcommon0",
"libdbus-1-3",
"libxcb-xinerama0",
"libxcb-icccm4",
"libxcb-image0",
"libxcb-keysyms1",
"libxcb-randr0",
"libxcb-render-util0",
"libxcb-shape0",
"libxcb-xfixes0",
"libxcb-cursor0",
]
CHINESE_FONTS: list[str] = [
"fonts-noto-cjk",
"fonts-wqy-microhei",
"fonts-wqy-zenhei",
"fonts-noto-color-emoji",
]
def main() -> None:
"""PyQt 环境配置工具主函数."""
graph = px.Graph.from_specs(
[
px.TaskSpec(
"envqt_install",
cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
conditions=(lambda _: Constants.IS_LINUX,),
verbose=True,
),
px.TaskSpec(
"envqt_fonts",
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(lambda _: Constants.IS_LINUX,),
verbose=True,
),
],
)
px.run(graph, strategy="thread", verbose=True)
-150
View File
@@ -1,150 +0,0 @@
"""Rust 环境配置工具.
配置 Rustup 和 Cargo 的国内镜像源,
加速 Rust 工具链和依赖包的下载.
"""
from __future__ import annotations
import argparse
import os
import subprocess
from pathlib import Path
from typing import Literal, get_args
import pyflowx as px
# ============================================================================
# 配置
# ============================================================================
RUSTUP_MIRRORS: dict[str, dict[str, str]] = {
"aliyun": {
"RUSTUP_DIST_SERVER": "https://mirrors.aliyun.com/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.aliyun.com/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.aliyun.com/crates.io-index/",
},
"ustc": {
"RUSTUP_DIST_SERVER": "https://mirrors.ustc.edu.cn/rust-static",
"RUSTUP_UPDATE_ROOT": "https://mirrors.ustc.edu.cn/rust-static/rustup",
"TOML_REGISTRY": "https://mirrors.ustc.edu.cn/crates.io-index/",
},
"tsinghua": {
"RUSTUP_DIST_SERVER": "https://mirrors.tuna.tsinghua.edu.cn/rustup",
"RUSTUP_UPDATE_ROOT": "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup",
"TOML_REGISTRY": "https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/",
},
}
UsableRustVersion = Literal["stable", "nightly", "beta"]
UsableMirror = Literal["aliyun", "ustc", "tsinghua"]
DEFAULT_RUST_VERSION: UsableRustVersion = "stable"
DEFAULT_MIRROR: UsableMirror = "tsinghua"
# ============================================================================
# 辅助函数
# ============================================================================
def set_rust_mirror(mirror: UsableMirror = DEFAULT_MIRROR) -> None:
"""设置 Rust 镜像源.
Parameters
----------
mirror : str
镜像源名称: aliyun, ustc, tsinghua
"""
mirror_dict = RUSTUP_MIRRORS.get(mirror, RUSTUP_MIRRORS[DEFAULT_MIRROR])
server = mirror_dict["RUSTUP_DIST_SERVER"]
update_root = mirror_dict["RUSTUP_UPDATE_ROOT"]
toml_registry = mirror_dict["TOML_REGISTRY"]
# 设置环境变量
os.environ["RUSTUP_DIST_SERVER"] = server
os.environ["RUSTUP_UPDATE_ROOT"] = update_root
# 写入 cargo 配置
cargo_dir = Path.home() / ".cargo"
cargo_dir.mkdir(exist_ok=True)
cargo_config = cargo_dir / "config.toml"
cargo_config.write_text(
f"""[source.crates-io]
replace-with = '{mirror}'
[source.{mirror}]
registry = "sparse+{toml_registry}"
[registries.{mirror}]
index = "sparse+{toml_registry}"
"""
)
print(f"已设置 Rust 镜像源: {mirror}")
def install_rust(version: UsableRustVersion = DEFAULT_RUST_VERSION) -> None:
"""安装 Rust 工具链.
Parameters
----------
version : str
Rust 版本: stable, nightly, beta
"""
try:
subprocess.run(["rustup", "toolchain", "install", version], check=True)
print(f"已安装 Rust {version}")
except FileNotFoundError:
print("未找到 rustup,请先安装 Rust: https://rustup.rs")
raise
# ============================================================================
# CLI Runner
# ============================================================================
def main() -> None:
"""Rust 环境配置工具主函数."""
parser = argparse.ArgumentParser(
description="EnvRs - Rust 环境配置工具",
usage="envrs <command> [options]",
)
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 设置镜像源命令
mirror_parser = subparsers.add_parser("mirror", help="设置 Rust 镜像源")
mirror_parser.add_argument(
"name",
nargs="?",
default=DEFAULT_MIRROR,
choices=get_args(UsableMirror),
help=f"镜像源名称 ({get_args(UsableMirror)})",
)
# 安装 Rust 命令
install_parser = subparsers.add_parser("install", help="安装 Rust 工具链")
install_parser.add_argument(
"version",
nargs="?",
default=DEFAULT_RUST_VERSION,
choices=get_args(UsableRustVersion),
help=f"Rust 版本 ({get_args(UsableRustVersion)})",
)
args = parser.parse_args()
if args.command == "mirror":
graph = px.Graph.from_specs([
px.TaskSpec("set_rust_mirror", fn=set_rust_mirror, args=(args.name,), verbose=True)
])
elif args.command == "install":
graph = px.Graph.from_specs([
px.TaskSpec("install_rust", cmd=["rustup", "toolchain", "install", args.version], verbose=True)
])
else:
parser.print_help()
return
px.run(graph, strategy="thread", verbose=True)
View File
@@ -35,6 +35,6 @@ def main() -> None:
[
px.TaskSpec(f"kill_{proc_name}", cmd=[*cmd, f"{proc_name}*"], verbose=True)
for proc_name in args.process_names
]
],
)
px.run(graph, strategy="thread")
+67 -7
View File
@@ -42,6 +42,19 @@ def _static(predicate: Callable[[], bool], name: str) -> Condition:
return _cond
def _cond_reason(cond: Condition) -> str | list[str] | None:
"""获取条件的失败原因:优先返回 ``_reason``,否则返回 ``__name__``。"""
reason = getattr(cond, "_reason", None)
if reason is not None:
return reason
return getattr(cond, "__name__", repr(cond))
def _cond_name(cond: Condition) -> str:
"""获取条件的可读名称。"""
return getattr(cond, "__name__", repr(cond))
# ---------------------------------------------------------------------- #
# 模块级静态条件常量
# ---------------------------------------------------------------------- #
@@ -61,6 +74,26 @@ class BuiltinConditions:
# ------------------------------------------------------------------ #
# 静态条件
# ------------------------------------------------------------------ #
@staticmethod
def IS_WINDOWS() -> Condition:
"""检查是否为 Windows 平台."""
return IS_WINDOWS
@staticmethod
def IS_LINUX() -> Condition:
"""检查是否为 Linux 平台."""
return IS_LINUX
@staticmethod
def IS_MACOS() -> Condition:
"""检查是否为 macOS 平台."""
return IS_MACOS
@staticmethod
def IS_POSIX() -> Condition:
"""检查是否为 POSIX 平台."""
return IS_POSIX
@staticmethod
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
"""检查 Python 版本是否匹配."""
@@ -118,6 +151,21 @@ class BuiltinConditions:
f"ENV_VAR_EQUALS({var_name!r},{value!r})",
)
@staticmethod
def FILE_CONTENT_EXISTS(path: Path | str, content: str) -> Condition:
"""检查文件是否包含指定内容."""
def _check() -> bool:
p = Path(path)
if not p.exists():
return False
try:
return content in p.read_text(encoding="utf-8")
except Exception:
return False
return _static(_check, f"FILE_CONTENT_EXISTS({path!r},{content!r})")
# ------------------------------------------------------------------ #
# 上下文条件:基于上游依赖结果
# ------------------------------------------------------------------ #
@@ -180,9 +228,15 @@ class BuiltinConditions:
"""对条件取反."""
def _cond(ctx: Context) -> bool:
return not condition(ctx)
result = condition(ctx)
if result:
# inner 为 True 时 NOT 会失败,记录 inner 的具体原因
inner_reason = _cond_reason(condition)
if inner_reason is not None:
_cond._reason = inner_reason # type: ignore[attr-defined]
return not result
_cond.__name__ = f"NOT({getattr(condition, '__name__', repr(condition))})"
_cond.__name__ = f"NOT({_cond_name(condition)})"
return _cond
@staticmethod
@@ -192,8 +246,7 @@ class BuiltinConditions:
def _cond(ctx: Context) -> bool:
return all(c(ctx) for c in conditions)
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_cond.__name__ = f"AND({', '.join(names)})"
_cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})"
return _cond
@staticmethod
@@ -201,8 +254,15 @@ class BuiltinConditions:
"""多个条件的逻辑或."""
def _cond(ctx: Context) -> bool:
return any(c(ctx) for c in conditions)
matched: list[str] = []
for c in conditions:
if c(ctx):
reason = _cond_reason(c)
matched.append(reason if isinstance(reason, str) else str(reason))
if matched:
_cond._reason = matched # type: ignore[attr-defined]
return True
return False
names = [getattr(c, "__name__", repr(c)) for c in conditions]
_cond.__name__ = f"OR({', '.join(names)})"
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
return _cond
+303 -338
View File
@@ -10,6 +10,17 @@
* ``dependency`` —— 依赖驱动调度:任务在其所有硬依赖完成后立即启动,
无需等待同层其他任务。最大化并行度。
架构
----
本模块通过 **Mixin** 组合消除同步/异步与各层执行器之间的重复代码:
* :class:`_TaskSkipMixin` —— 上游跳过 / 条件跳过的预检逻辑。
* :class:`_TaskRetryMixin` —— 重试决策、成功/失败后处理、finalize。
* :class:`_LayerMixin` —— 缓存过滤、优先级排序、信号量构建、结果存储。
* :class:`SyncTaskRunner` / :class:`AsyncTaskRunner` —— 任务级执行器,组合上述 Mixin。
* :class:`SequentialLayerRunner` / :class:`ThreadedLayerRunner` /
:class:`AsyncLayerRunner` / :class:`DependencyRunner` —— 层级执行器,组合 :class:`_LayerMixin`。
所有策略共享统一异步内核,支持:
* :class:`RetryPolicy`max_attempts/delay/backoff/jitter/retry_on
* 软依赖注入与默认值
@@ -30,6 +41,7 @@ import concurrent.futures
import inspect
import logging
import threading
import time
from datetime import datetime
from typing import Any, Awaitable, Callable, Literal, Mapping, cast
@@ -48,7 +60,7 @@ Strategy = Literal["sequential", "thread", "async", "dependency"]
# ---------------------------------------------------------------------- #
# 辅助
# 无状态公共辅助
# ---------------------------------------------------------------------- #
def _is_async_fn(spec: TaskSpec[Any]) -> bool:
"""判断 ``spec.effective_fn`` 是否为协程函数。"""
@@ -71,17 +83,6 @@ def _emit(on_event: EventCallback | None, result: TaskResult[Any]) -> None:
)
def _log_retry(spec: TaskSpec[Any], attempt: int, max_attempts: int, exc: BaseException) -> None:
"""记录重试日志。"""
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
attempt,
max_attempts,
exc,
)
def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
"""安全调用钩子(异常仅记录,不影响任务状态)。"""
hook: Callable[..., None] | None = getattr(hooks, fn_name, None)
@@ -93,77 +94,6 @@ def _run_hooks(hooks: TaskHooks, fn_name: str, *args: Any) -> None:
logger.warning("hook %s raised: %r", fn_name, exc)
def _check_upstream_skipped(
spec: TaskSpec[Any],
report: RunReport | None,
) -> tuple[bool, str | None]:
"""检查硬依赖上游任务是否被 SKIPPED 或 FAILED。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None:
return False, None
if spec.allow_upstream_skip:
return False, None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
return False, None
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
"""求值所有条件,返回跳过原因或 ``None``。
条件接收上下文映射(硬依赖 + 软依赖结果)。
"""
failed_conditions: list[str] = []
for condition in spec.conditions:
try:
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
continue
if not ok:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions:
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
if spec.skip_if_missing and not spec._is_cmd_available():
cmd_name = spec.cmd[0] if isinstance(spec.cmd, list) and spec.cmd else "unknown"
return f"命令不存在: {cmd_name}"
return None
def _make_skipped_result(
spec: TaskSpec[Any],
reason: str,
on_event: EventCallback | None,
) -> TaskResult[Any]:
"""构造 SKIPPED 的 TaskResult。"""
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=reason,
)
_emit(on_event, result)
if spec.verbose:
print(f"[skip] 任务 '{spec.name}' 跳过: {reason}", flush=True)
logger.info("task %r skipped (%s)", spec.name, reason)
return result
def _build_context(
spec: TaskSpec[Any],
global_context: Mapping[str, Any],
@@ -175,11 +105,9 @@ def _build_context(
软依赖:上游成功则注入其值;否则注入 ``spec.defaults`` 中的默认值(或 ``None``)。
"""
ctx: dict[str, Any] = {}
for dep in spec.depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
for dep in spec.soft_depends_on:
if dep in global_context:
ctx[dep] = global_context[dep]
@@ -187,7 +115,6 @@ def _build_context(
ctx[dep] = spec.defaults[dep]
else:
ctx[dep] = None
return ctx
@@ -212,33 +139,93 @@ def _apply_cached(
return True
def _prepare_for_execution(
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
# ---------------------------------------------------------------------- #
# Mixin:任务级跳过 / 重试 / 成功处理
# ---------------------------------------------------------------------- #
class _TaskSkipMixin:
"""任务级跳过预检共享逻辑。
"上游被跳过/失败""条件不满足"两类跳过判断统一为单一入口,
被 :class:`SyncTaskRunner` 与 :class:`AsyncTaskRunner` 复用。
"""
@staticmethod
def _upstream_skip_reason(spec: TaskSpec[Any], report: RunReport | None) -> str | None:
"""硬依赖被 SKIPPED/FAILED 时返回原因字符串,否则 ``None``。
软依赖不影响本检查——软依赖被跳过时注入默认值。
"""
if report is None or spec.allow_upstream_skip:
return None
for dep in spec.depends_on:
if dep not in report.results:
continue
dep_status = report.results[dep].status
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
return f"上游任务 '{dep}' 状态为 {dep_status.value}"
return None
@staticmethod
def _prepare_for_execution(
spec: TaskSpec[Any],
context: Mapping[str, Any],
report: RunReport | None,
on_event: EventCallback | None,
) -> TaskResult[Any] | None:
) -> TaskResult[Any] | None:
"""执行前预检:上游跳过 / 条件跳过。
返回 SKIPPED TaskResult 或 ``None``(继续执行)。
条件判断委托给 :meth:`TaskSpec.should_execute`,避免重复实现。
"""
should_skip, skip_reason = _check_upstream_skipped(spec, report)
if should_skip:
return _make_skipped_result(spec, skip_reason or "上游任务被跳过", on_event)
skip_reason = _evaluate_conditions(spec, context)
if skip_reason is not None:
return _make_skipped_result(spec, skip_reason, on_event)
# 1. 上游被跳过/失败
skip_reason = _TaskSkipMixin._upstream_skip_reason(spec, report)
# 2. 条件 / skip_if_missing(单一来源:TaskSpec.should_execute
if skip_reason is None:
should_run, cond_reason = spec.should_execute(context)
if not should_run:
skip_reason = cond_reason or "条件不满足"
if skip_reason is None:
return None
# 构造 SKIPPED 结果
result: TaskResult[Any] = TaskResult(
spec=spec,
status=TaskStatus.SKIPPED,
finished_at=datetime.now(),
reason=skip_reason,
)
_emit(on_event, result)
logger.info("task %r skipped (%s)", spec.name, skip_reason)
return result
def _finalize_failure(
class _TaskRetryMixin:
"""任务级重试决策与失败/成功后处理共享逻辑。"""
@staticmethod
def _should_retry(spec: TaskSpec[Any], attempts: int, exc: BaseException) -> bool:
"""是否应继续重试。"""
return attempts < spec.retry.max_attempts and spec.retry.should_retry(exc)
@staticmethod
def _mark_success(spec: TaskSpec[Any], result: TaskResult[Any], value: Any) -> None:
"""标记任务成功并触发 post_run 钩子。"""
result.value = value
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, value)
@staticmethod
def _finalize_failure(
result: TaskResult[Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
continue_on_error: bool = False,
) -> None:
on_event: EventCallback | None,
continue_on_error: bool,
) -> None:
"""标记任务为 FAILED。若 ``continue_on_error`` 为真则不抛出异常。"""
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
@@ -256,41 +243,66 @@ def _finalize_failure(
layer=layer_idx,
)
@staticmethod
def _handle_failure(
spec: TaskSpec[Any],
result: TaskResult[Any],
exc: BaseException,
layer_idx: int | None,
on_event: EventCallback | None,
) -> bool:
"""统一处理失败:超时转换、重试决策、finalize。
def _sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的同步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
import time
time.sleep(wait)
async def _async_sleep_for_retry(spec: TaskSpec[Any], attempt: int) -> None:
"""重试前的异步等待。"""
wait = spec.retry.wait_seconds(attempt)
if wait > 0:
await asyncio.sleep(wait)
Returns
-------
bool
``True`` 表示已 finalize(不再重试);``False`` 表示应继续重试。
"""
# asyncio.TimeoutError → TaskTimeoutError(统一异常类型)
if isinstance(exc, asyncio.TimeoutError):
exc = TaskTimeoutError(spec.name, spec.timeout or 0.0)
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
)
else:
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
spec.retry.max_attempts,
exc,
)
result.error = exc
if _TaskRetryMixin._should_retry(spec, result.attempts, exc):
return False
_run_hooks(spec.hooks, "on_failure", spec, exc)
_TaskRetryMixin._finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return True
# ---------------------------------------------------------------------- #
# 同步执行内核
# 任务执行器:同步 / 异步(复用 _TaskSkipMixin + _TaskRetryMixin
# ---------------------------------------------------------------------- #
def _run_sync_with_retry(
class SyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
"""同步任务执行器:带重试与跳过预检。"""
@staticmethod
def run(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
) -> TaskResult[Any]:
"""执行同步任务并带重试;返回填充好的 TaskResult。"""
skipped = _prepare_for_execution(spec, context, report, on_event)
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
_run_hooks(spec.hooks, "pre_run", spec)
@@ -299,25 +311,60 @@ def _run_sync_with_retry(
result.attempts += 1
try:
with spec.env_context():
result.value = spec.effective_fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
value = spec.effective_fn(*args, **kwargs)
_TaskRetryMixin._mark_success(spec, result, value)
return result
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
return result
_log_retry(spec, result.attempts, max_attempts, exc)
_sleep_for_retry(spec, result.attempts)
# pragma: no cover
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
time.sleep(wait)
class AsyncTaskRunner(_TaskSkipMixin, _TaskRetryMixin):
"""异步任务执行器:在事件循环上运行同步或异步任务,带重试与跳过预检。"""
@staticmethod
async def run(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
skipped = _TaskSkipMixin._prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
async def _inner() -> TaskResult[Any]:
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
_run_hooks(spec.hooks, "pre_run", spec)
while True:
result.attempts += 1
try:
value = await _execute_async_task(spec, args, kwargs, loop)
_TaskRetryMixin._mark_success(spec, result, value)
return result
except Exception as exc:
if _TaskRetryMixin._handle_failure(spec, result, exc, layer_idx, on_event):
return result
wait = spec.retry.wait_seconds(result.attempts)
if wait > 0:
await asyncio.sleep(wait)
if semaphore is not None:
async with semaphore:
return await _inner()
return await _inner()
# ---------------------------------------------------------------------- #
# 异步执行内核
# ---------------------------------------------------------------------- #
async def _execute_async_task(
spec: TaskSpec[Any],
args: tuple[Any, ...],
@@ -329,9 +376,7 @@ async def _execute_async_task(
coro = cast(Awaitable[Any], spec.effective_fn(*args, **kwargs))
if spec.timeout is not None:
return await asyncio.wait_for(coro, timeout=spec.timeout)
else:
return await coro
else:
def fn_call() -> Any:
with spec.env_context():
@@ -339,87 +384,89 @@ async def _execute_async_task(
if spec.timeout is not None:
return await asyncio.wait_for(loop.run_in_executor(None, fn_call), timeout=spec.timeout)
else:
return await loop.run_in_executor(None, fn_call)
async def _run_async_with_retry(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None,
semaphore: asyncio.Semaphore | None = None,
) -> TaskResult[Any]:
"""在事件循环上执行任务(同步或异步)并带重试。"""
skipped = _prepare_for_execution(spec, context, report, on_event)
if skipped is not None:
return skipped
# ---------------------------------------------------------------------- #
# Mixin:层执行共享逻辑
# ---------------------------------------------------------------------- #
class _LayerMixin:
"""层执行共享逻辑:缓存过滤、优先级排序、信号量构建、结果存储。
if semaphore is not None:
async with semaphore:
return await _run_async_inner(spec, context, layer_idx, on_event, report)
return await _run_async_inner(spec, context, layer_idx, on_event, report)
四个层执行器(sequential/threaded/async/dependency)通过组合此 Mixin
消除"过滤缓存→排序→运行→存结果"的样板代码。
"""
@staticmethod
def _filter_and_sort(
layer: list[str],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
) -> list[str]:
"""过滤掉已命中缓存的任务,按优先级排序返回待运行列表。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if not _apply_cached(name, spec, context, report, backend, on_event):
to_run.append(name)
return _sort_by_priority(to_run, graph)
async def _run_async_inner(
spec: TaskSpec[Any],
context: Mapping[str, Any],
layer_idx: int | None,
on_event: EventCallback | None = None,
report: RunReport | None = None, # noqa: ARG001
) -> TaskResult[Any]:
"""异步执行内核的内部实现(已获取 semaphore 后)。"""
result: TaskResult[Any] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retry.max_attempts
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
@staticmethod
def _store_result(
name: str,
result: TaskResult[Any],
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
context_snapshot: Mapping[str, Any] | None = None,
) -> None:
"""存储任务结果到 context/report/backend 并触发事件。"""
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot if context_snapshot is not None else context, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
_run_hooks(spec.hooks, "pre_run", spec)
@staticmethod
def _build_semaphores(
to_run: list[str],
graph: Graph,
sem_factory: Callable[[int], Any],
concurrency_limits: Mapping[str, int],
) -> dict[str, Any]:
"""为每个 ``concurrency_key`` 创建一个信号量。"""
semaphores: dict[str, Any] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = sem_factory(limit)
return semaphores
while True:
result.attempts += 1
try:
result.value = await _execute_async_task(spec, args, kwargs, loop)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
_run_hooks(spec.hooks, "post_run", spec, result.value)
return result
except asyncio.TimeoutError:
exc: BaseException = TaskTimeoutError(spec.name, spec.timeout or 0.0)
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
await _async_sleep_for_retry(spec, result.attempts)
except Exception as exc:
result.error = exc
if result.attempts >= max_attempts or not spec.retry.should_retry(exc):
_run_hooks(spec.hooks, "on_failure", spec, exc)
_finalize_failure(result, layer_idx, on_event, spec.continue_on_error)
return result
_log_retry(spec, result.attempts, max_attempts, exc)
await _async_sleep_for_retry(spec, result.attempts)
# pragma: no cover
@staticmethod
def _get_sem(semaphores: Mapping[str, Any], spec: TaskSpec[Any]) -> Any | None:
"""获取任务对应的信号量(无 concurrency_key 则返回 None)。"""
if spec.concurrency_key is None:
return None
return semaphores.get(spec.concurrency_key)
# ---------------------------------------------------------------------- #
# 层执行器
# ---------------------------------------------------------------------- #
def _sort_by_priority(layer: list[str], graph: Graph) -> list[str]:
"""按优先级降序排序(稳定排序)。"""
return sorted(layer, key=lambda n: -graph.resolved_spec(n).priority)
class SequentialLayerRunner(_LayerMixin):
"""逐个运行某层的任务(按优先级排序)。"""
def _execute_layer_sequential(
@staticmethod
def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -427,22 +474,19 @@ def _execute_layer_sequential(
backend: StateBackend,
layer_idx: int,
on_event: EventCallback | None,
) -> None:
"""逐个运行某层的任务(按优先级排序)。"""
for name in _sort_by_priority(layer, graph):
) -> None:
for name in SequentialLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event):
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
task_ctx = _build_context(spec, context, report)
result = _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
result = SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
SequentialLayerRunner._store_result(name, result, graph, context, report, backend, on_event)
def _execute_layer_threaded(
class ThreadedLayerRunner(_LayerMixin):
"""在线程池中并发运行某层的任务。"""
@staticmethod
def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -452,70 +496,48 @@ def _execute_layer_threaded(
on_event: EventCallback | None,
max_workers: int,
concurrency_limits: Mapping[str, int],
) -> None:
"""在线程池中并发运行某层的任务。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context, report)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
) -> None:
to_run = ThreadedLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建线程信号量
semaphores: dict[str, threading.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = threading.Semaphore(limit)
semaphores = ThreadedLayerRunner._build_semaphores(to_run, graph, threading.Semaphore, concurrency_limits)
context_snapshot = dict(context)
lock = threading.Lock()
def _run_threaded_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
sem = ThreadedLayerRunner._get_sem(semaphores, spec)
if sem is not None:
sem.acquire()
try:
return _run_sync_with_retry(spec, task_ctx, layer_idx, on_event, report)
return SyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report)
finally:
if sem is not None:
sem.release()
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {}
for name in to_run:
fut = pool.submit(_run_threaded_task, name)
future_to_name[fut] = name
future_to_name: dict[concurrent.futures.Future[TaskResult[Any]], str] = {
pool.submit(_run_threaded_task, name): name for name in to_run
}
completed: dict[str, TaskResult[Any]] = {}
try:
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
result = fut.result()
completed[name] = result
completed[name] = fut.result()
finally:
with lock:
for name, result in completed.items():
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
ThreadedLayerRunner._store_result(
name, result, graph, context, report, backend, on_event, context_snapshot
)
async def _execute_layer_async(
class AsyncLayerRunner(_LayerMixin):
"""在事件循环上并发运行某层的任务。"""
@staticmethod
async def execute(
layer: list[str],
graph: Graph,
context: dict[str, Any],
@@ -524,76 +546,41 @@ async def _execute_layer_async(
layer_idx: int,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""在事件循环上并发运行某层的任务。"""
to_run: list[str] = []
for name in layer:
spec = graph.resolved_spec(name)
if _apply_cached(name, spec, context, report, backend, on_event):
continue
to_run.append(name)
) -> None:
to_run = AsyncLayerRunner._filter_and_sort(layer, graph, context, report, backend, on_event)
if not to_run:
return
to_run = _sort_by_priority(to_run, graph)
# 为每个 concurrency_key 创建异步信号量
semaphores: dict[str, asyncio.Semaphore] = {}
for name in to_run:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
semaphores = AsyncLayerRunner._build_semaphores(to_run, graph, asyncio.Semaphore, concurrency_limits)
context_snapshot = dict(context)
async def _run_async_task_wrapped(name: str) -> TaskResult[Any]:
async def _run_async_task(name: str) -> TaskResult[Any]:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
return await _run_async_with_retry(spec, task_ctx, layer_idx, on_event, report)
sem = AsyncLayerRunner._get_sem(semaphores, spec)
return await AsyncTaskRunner.run(spec, task_ctx, layer_idx, on_event, report, sem)
coros = [_run_async_task_wrapped(name) for name in to_run]
results = await asyncio.gather(*coros)
results = await asyncio.gather(*[_run_async_task(name) for name in to_run])
for name, result in zip(to_run, results):
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
spec = graph.resolved_spec(name)
task_ctx = _build_context(spec, context_snapshot, report)
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
AsyncLayerRunner._store_result(name, result, graph, context, report, backend, on_event, context_snapshot)
# ---------------------------------------------------------------------- #
# 依赖驱动调度
# ---------------------------------------------------------------------- #
async def _drive_dependency_async(
class DependencyRunner(_LayerMixin):
"""依赖驱动调度:任务在硬/软依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
@staticmethod
async def execute(
graph: Graph,
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
"""依赖驱动调度:任务在硬依赖完成后立即启动,无层屏障。
所有任务通过 asyncio 并发调度。同步任务卸载到线程池。
"""
all_names = set(graph.all_specs().keys())
semaphores: dict[str, asyncio.Semaphore] = {}
for name in all_names:
spec = graph.resolved_spec(name)
key = spec.concurrency_key
if key is not None and key not in semaphores:
limit = concurrency_limits.get(key, 1)
semaphores[key] = asyncio.Semaphore(limit)
) -> None:
all_names = list(graph.all_specs().keys())
semaphores = DependencyRunner._build_semaphores(all_names, graph, asyncio.Semaphore, concurrency_limits)
futures: dict[str, asyncio.Future[TaskResult[Any]]] = {}
async def _run_task(name: str) -> TaskResult[Any]:
@@ -611,24 +598,14 @@ async def _drive_dependency_async(
if _apply_cached(name, spec, context, report, backend, on_event):
return report.results[name]
sem = semaphores.get(spec.concurrency_key) if spec.concurrency_key else None
if sem is not None:
async with sem:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
else:
result = await _run_async_with_retry(spec, task_ctx, None, on_event, report)
context[name] = result.value
if result.status == TaskStatus.SUCCESS:
backend.save(spec.storage_key(task_ctx), result.value)
report.results[name] = result
_emit(on_event, result)
sem = DependencyRunner._get_sem(semaphores, spec)
result = await AsyncTaskRunner.run(spec, task_ctx, None, on_event, report, sem)
DependencyRunner._store_result(name, result, graph, context, report, backend, on_event)
return result
loop = asyncio.get_event_loop()
for name in all_names:
futures[name] = loop.create_task(_run_task(name))
await asyncio.gather(*futures.values())
@@ -719,9 +696,9 @@ def run(
elif strategy == "thread":
_drive_threaded(graph, layers, context, report, backend, effective_callback, max_workers, limits)
elif strategy == "async":
_drive_async(graph, layers, context, report, backend, effective_callback, limits)
asyncio.run(_async_drive(graph, layers, context, report, backend, effective_callback, limits))
elif strategy == "dependency":
asyncio.run(_drive_dependency_async(graph, context, report, backend, effective_callback, limits))
asyncio.run(DependencyRunner.execute(graph, context, report, backend, effective_callback, limits))
else:
raise ValueError(f"Unknown strategy: {strategy!r}")
except TaskFailedError:
@@ -749,7 +726,7 @@ def _drive_sequential(
on_event: EventCallback | None,
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
SequentialLayerRunner.execute(layer, graph, context, report, backend, idx, on_event)
def _drive_threaded(
@@ -764,19 +741,7 @@ def _drive_threaded(
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
def _drive_async(
graph: Graph,
layers: list[list[str]],
context: dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: EventCallback | None,
concurrency_limits: Mapping[str, int],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event, concurrency_limits))
ThreadedLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, workers, concurrency_limits)
async def _async_drive(
@@ -789,4 +754,4 @@ async def _async_drive(
concurrency_limits: Mapping[str, int],
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
await AsyncLayerRunner.execute(layer, graph, context, report, backend, idx, on_event, concurrency_limits)
+25 -16
View File
@@ -12,6 +12,11 @@
from __future__ import annotations
__all__ = [
"Graph",
"GraphDefaults",
]
import sys
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Iterable, Mapping, Sequence
@@ -49,6 +54,15 @@ class GraphDefaults:
verbose: bool = False
def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]:
"""返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。"""
return replace(
spec,
depends_on=tuple(d for d in spec.depends_on if keep(d)),
soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)),
)
@dataclass
class Graph:
"""校验后的有向无环任务图。
@@ -64,6 +78,7 @@ class Graph:
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
defaults: GraphDefaults = field(default_factory=GraphDefaults)
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
_pending_refs: list[str] = field(default_factory=list)
@@ -225,16 +240,13 @@ class Graph:
def subgraph(self, tags: Iterable[str]) -> Graph:
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
wanted: set[str] = set(tags)
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
pruned_soft = tuple(
d for d in spec.soft_depends_on if d in self.specs and (wanted & set(self.specs[d].tags))
)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
def _dep_kept(dep: str) -> bool:
return dep in self.specs and bool(wanted & set(self.specs[dep].tags))
kept: list[TaskSpec[Any]] = [
_prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags)
]
return Graph.from_specs(kept, defaults=self.defaults)
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
@@ -243,12 +255,9 @@ class Graph:
for n in wanted:
if n not in self.specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: list[TaskSpec[Any]] = []
for spec in self.specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
pruned_soft = tuple(d for d in spec.soft_depends_on if d in wanted)
kept.append(replace(spec, depends_on=pruned_deps, soft_depends_on=pruned_soft))
kept: list[TaskSpec[Any]] = [
_prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted
]
return Graph.from_specs(kept, defaults=self.defaults)
# ------------------------------------------------------------------ #
+111 -42
View File
@@ -17,13 +17,14 @@ import json
import sys
import time
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path
from typing import Any, Mapping
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
from typing_extensions import override # pragma: no cover
from .errors import StorageError
@@ -55,7 +56,74 @@ class StateBackend(ABC):
"""清除所有存储状态。"""
class MemoryBackend(StateBackend):
class _TTLStateBackendMixin(StateBackend):
"""TTL 状态后端共享逻辑。
将 ``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现
委托给四个原始存取原语::meth:`_get_raw`、:meth:`_put_raw`、
:meth:`_iter_raw`、:meth:`_clear_raw`,并基于 :meth:`_now` 与
``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`。
子类需设置 ``self._ttl`` 并实现上述四个原语;如需自定义时间源
(如 ``time.monotonic``)可覆盖 :meth:`_now`。
"""
_ttl: float | None
# ---- 原语:由子类实现 ---- #
@abstractmethod
def _get_raw(self, key: str) -> tuple[Any, float] | None:
"""返回 ``(value, ts)``;键不存在时返回 ``None``。"""
@abstractmethod
def _put_raw(self, key: str, value: Any, ts: float) -> None:
"""写入一条记录。"""
@abstractmethod
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
"""迭代所有记录(不做过期过滤),yield ``(key, value, ts)``。"""
@abstractmethod
def _clear_raw(self) -> None:
"""清空所有记录。"""
# ---- 共享实现 ---- #
def _now(self) -> float:
"""当前时间戳,默认为 wall-clock 秒。"""
return time.time()
def _is_expired(self, ts: float) -> bool:
"""时间戳 ``ts`` 是否已过期。"""
if self._ttl is None:
return False
return (self._now() - ts) > self._ttl
@override
def load(self) -> Mapping[str, Any]:
return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)}
@override
def save(self, key: str, value: Any) -> None:
self._put_raw(key, value, self._now())
@override
def has(self, key: str) -> bool:
entry = self._get_raw(key)
return entry is not None and not self._is_expired(entry[1])
@override
def get(self, key: str) -> Any:
entry = self._get_raw(key)
if entry is None or self._is_expired(entry[1]):
raise KeyError(key)
return entry[0]
@override
def clear(self) -> None:
self._clear_raw()
class MemoryBackend(_TTLStateBackendMixin):
"""进程内 dict 后端。进程退出即丢失。
Parameters
@@ -70,35 +138,35 @@ class MemoryBackend(StateBackend):
self._ttl = ttl
@override
def load(self) -> Mapping[str, Any]:
return {k: v for k, (v, _ts) in self._store.items() if not self._expired(k)}
def _now(self) -> float:
return time.monotonic()
@override
def save(self, key: str, value: Any) -> None:
self._store[key] = (value, time.monotonic())
def _get_raw(self, key: str) -> tuple[Any, float] | None:
return self._store.get(key)
@override
def has(self, key: str) -> bool:
return key in self._store and not self._expired(key)
def _put_raw(self, key: str, value: Any, ts: float) -> None:
self._store[key] = (value, ts)
@override
def get(self, key: str) -> Any:
if key not in self._store or self._expired(key):
raise KeyError(key)
return self._store[key][0]
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
for k, (v, ts) in self._store.items():
yield k, v, ts
@override
def clear(self) -> None:
def _clear_raw(self) -> None:
self._store.clear()
def _expired(self, key: str) -> bool:
if self._ttl is None or key not in self._store:
"""键是否已过期(兼容旧测试 API)。"""
entry = self._get_raw(key)
if entry is None:
return False
_value, ts = self._store[key]
return (time.monotonic() - ts) > self._ttl
return self._is_expired(entry[1])
class JSONBackend(StateBackend):
class JSONBackend(_TTLStateBackendMixin):
"""基于文件的 JSON 存储,用于跨进程续跑。
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
@@ -131,7 +199,6 @@ class JSONBackend(StateBackend):
if isinstance(v, dict) and "value" in v and "ts" in v:
self._store[k] = v
else:
# 旧格式:纯值
self._store[k] = {"value": v, "ts": time.time()}
except (OSError, json.JSONDecodeError) as exc:
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
@@ -145,17 +212,30 @@ class JSONBackend(StateBackend):
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
def _now(self) -> float:
return time.time()
def _expired(self, entry: dict[str, Any]) -> bool:
if self._ttl is None:
return False
return (self._now() - float(entry.get("ts", 0))) > self._ttl
@override
def _get_raw(self, key: str) -> tuple[Any, float] | None:
entry = self._store.get(key)
if entry is None:
return None
return entry["value"], float(entry.get("ts", 0))
@override
def load(self) -> Mapping[str, Any]:
return {k: v["value"] for k, v in self._store.items() if not self._expired(v)}
def _put_raw(self, key: str, value: Any, ts: float) -> None:
self._store[key] = {"value": value, "ts": ts}
@override
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
for k, entry in self._store.items():
yield k, entry["value"], float(entry.get("ts", 0))
@override
def _clear_raw(self) -> None:
self._store.clear()
@override
def clear(self) -> None:
super().clear()
self._flush()
@override
def save(self, key: str, value: Any) -> None:
@@ -163,23 +243,12 @@ class JSONBackend(StateBackend):
_ = json.dumps(value)
except (TypeError, ValueError) as exc:
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
self._store[key] = {"value": value, "ts": self._now()}
super().save(key, value)
self._flush()
@override
def has(self, key: str) -> bool:
return key in self._store and not self._expired(self._store[key])
@override
def get(self, key: str) -> Any:
if key not in self._store or self._expired(self._store[key]):
raise KeyError(key)
return self._store[key]["value"]
@override
def clear(self) -> None:
self._store.clear()
self._flush()
def _expired(self, entry: Mapping[str, Any]) -> bool:
"""带元数据的条目是否已过期(兼容旧测试 API)。"""
return self._is_expired(float(entry.get("ts", 0)))
def resolve_backend(backend: StateBackend | None) -> StateBackend:
+19 -6
View File
@@ -31,8 +31,8 @@ from typing import (
Callable,
ContextManager,
Coroutine,
Generator,
Generic,
Iterator,
List,
Mapping,
Union,
@@ -42,7 +42,7 @@ from typing import (
if sys.version_info >= (3, 13):
from typing import TypeVar
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # pragma: no cover
T = TypeVar("T", default=Any)
@@ -74,6 +74,13 @@ Condition = Callable[[Context], bool]
CacheKeyFn = Callable[[Context], str]
def _format_skip_reason(failed_conditions: list[str]) -> str:
"""格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。"""
if len(failed_conditions) <= 2:
return f"条件不满足: {', '.join(failed_conditions)}"
return f"条件不满足: {', '.join(failed_conditions[:2])}{len(failed_conditions)}个条件"
# ---------------------------------------------------------------------- #
# 重试策略
# ---------------------------------------------------------------------- #
@@ -315,6 +322,7 @@ class TaskSpec(Generic[T]):
-------
(should_run, skip_reason)
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
失败条件超过 2 个时仅展示前 2 个并附总数。
"""
# 逐个求值条件,记录失败项。
failed_conditions: list[str] = []
@@ -323,14 +331,19 @@ class TaskSpec(Generic[T]):
ok = condition(context)
except Exception:
ok = False
name = getattr(condition, "__name__", None) or "匿名条件(执行错误)"
failed_conditions.append(name)
failed_conditions.append("匿名条件(执行错误)")
continue
if not ok:
reason = getattr(condition, "_reason", None)
if reason is not None:
failed_conditions.append(
", ".join(str(r) for r in reason) if isinstance(reason, list) else str(reason),
)
else:
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
if failed_conditions:
return False, f"条件不满足: {', '.join(failed_conditions)}"
return False, _format_skip_reason(failed_conditions)
if self.skip_if_missing and not self._is_cmd_available():
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
@@ -367,7 +380,7 @@ class TaskSpec(Generic[T]):
def _env_and_cwd(
env: Mapping[str, str] | None,
cwd: Path | None,
) -> Iterator[None]:
) -> Generator[None, None, None]:
"""临时设置环境变量与工作目录。"""
saved_env: dict[str, str] = {}
saved_cwd: str | None = None
+27 -3
View File
@@ -6,6 +6,15 @@
from __future__ import annotations
__all__ = [
"clr",
"reset_icon_cache",
"setenv",
"setenv_group",
"which",
"write_file",
]
import os
import subprocess
from pathlib import Path
@@ -66,7 +75,7 @@ def reset_icon_cache() -> list[px.TaskSpec]:
]
def setenv(name: str, value: str, default: bool = False):
def setenv(name: str, value: str, default: bool = False) -> px.TaskSpec:
"""设置环境变量任务."""
def set_env():
@@ -78,7 +87,12 @@ def setenv(name: str, value: str, default: bool = False):
return px.TaskSpec(f"setenv_{name.lower()}", fn=set_env, verbose=True)
def which(cmd: str):
def setenv_group(envs: dict[str, str], default: bool = False) -> list[px.TaskSpec]:
"""设置环境变量组任务."""
return [setenv(name, value, default) for name, value in envs.items()]
def which(cmd: str) -> px.TaskSpec:
"""查找命令路径任务."""
which_cmd = "where" if Constants.IS_WINDOWS else "which"
@@ -95,4 +109,14 @@ def which(cmd: str):
return px.TaskSpec(f"which_{cmd}", fn=find_command)
__all__ = ["clr", "setenv", "which"]
def write_file(path: str, content: str, encoding: str = "utf-8") -> px.TaskSpec:
"""写入文件任务."""
def write():
try:
with open(path, "w", encoding=encoding) as f:
f.write(content)
except Exception as e:
print(f"写入文件 {path} 失败: {e}")
return px.TaskSpec(f"write_file_{path}", fn=write, verbose=True)
+107
View File
@@ -0,0 +1,107 @@
"""常用工具函数."""
from __future__ import annotations
__all__ = ["perf_timer"]
import functools
import logging
import time
from collections import defaultdict
from typing import Callable, TypedDict
try:
from typing_extensions import ParamSpec, TypeVar
except ImportError:
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class _PerformanceMetrics(TypedDict):
"""性能指标."""
count: int
total_time: float
_perf_metrics: defaultdict[str, _PerformanceMetrics] = defaultdict(
lambda: _PerformanceMetrics(
count=0,
total_time=0.0,
)
)
def _generate_report(unit: str, precision: int) -> str:
"""生成性能指标报告,返回报告字符串."""
if not _perf_metrics:
return ""
lines: list[str] = []
lines.append("=" * 50)
lines.append("性能指标报告 (Performance Metrics Report)")
lines.append("-" * 50)
# 按总耗时排序,最耗时的函数排在前面
sorted_metrics = sorted(_perf_metrics.items(), key=lambda x: x[1]["total_time"], reverse=True)
for name, metrics in sorted_metrics:
avg_time = metrics["total_time"] / metrics["count"] if metrics["count"] > 0 else 0
lines.append(
f"{name}: "
f"调用次数={metrics['count']}, "
f"总耗时={metrics['total_time']:.{precision}f}{unit}, "
f"平均耗时={avg_time:.{precision}f}{unit}"
)
lines.append("=" * 50)
report_str = "\n".join(lines)
# 同时输出到日志
logging.info("\n".join(lines))
return report_str
def perf_timer(unit: str = "ms", precision: int = 4, report: bool = False):
"""性能计时器装饰器."""
scale: dict[str, float] = {
"s": 1.0,
"ms": 1000.0,
"us": 1000000.0,
}
def decorator(func: Callable[P, R]) -> Callable[P, R]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
_perf_metrics[func.__name__]["count"] += 1
_perf_metrics[func.__name__]["total_time"] += (end_time - start_time) * scale[unit]
if not report:
logging.info(
f"{func.__name__} {unit}: {_perf_metrics[func.__name__]['total_time']:.{precision}f}{unit}"
)
return result
return wrapper
if report:
import atexit
logging.basicConfig(level=logging.INFO)
logging.info(f"Performance metrics report enabled with unit {unit} and precision {precision}")
@atexit.register
def _report_at_exit() -> None:
"""在程序退出时报告性能指标."""
_generate_report(unit, precision)
# 将报告生成逻辑提取为独立函数,便于测试
return decorator
+1 -1
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
from unittest.mock import patch
import pyflowx as px
from pyflowx.cli import clearscreen
from pyflowx.cli.system import clearscreen
# ---------------------------------------------------------------------- #
-110
View File
@@ -1,110 +0,0 @@
"""Tests for cli.envpy module."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import envpy
# ---------------------------------------------------------------------- #
# set_pip_mirror
# ---------------------------------------------------------------------- #
class TestSetPipMirror:
"""Test set_pip_mirror function."""
def test_set_pip_mirror_tsinghua(self, tmp_path: Path) -> None:
"""Should set tsinghua mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua")
# Check pip config
pip_config = tmp_path / "pip" / "pip.ini"
if envpy.Constants.IS_WINDOWS:
assert pip_config.exists() or (tmp_path / "pip" / "pip.conf").exists()
def test_set_pip_mirror_aliyun(self, tmp_path: Path) -> None:
"""Should set aliyun mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("aliyun")
# Check pip config
pip_dir = tmp_path / "pip"
assert pip_dir.exists()
def test_set_pip_mirror_with_token(self, tmp_path: Path) -> None:
"""Should set mirror with token."""
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua", token="test_token")
# Check that token is set
def test_set_pip_mirror_creates_pip_dir(self, tmp_path: Path) -> None:
"""Should create pip directory if it doesn't exist."""
pip_dir = tmp_path / "pip"
with patch.object(Path, "home", return_value=tmp_path):
envpy.set_pip_mirror("tsinghua")
assert pip_dir.exists()
assert pip_dir.is_dir()
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_mirror_tsinghua(self) -> None:
"""main() should handle mirror tsinghua command."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.called
def test_main_mirror_aliyun(self) -> None:
"""main() should handle mirror aliyun command."""
with patch("sys.argv", ["envpy", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.called
def test_main_mirror_with_token(self) -> None:
"""main() should handle mirror with token."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua", "--token", "test_token"]), patch.object(
px, "run"
) as mock_run, patch.object(envpy, "set_pip_mirror"):
envpy.main()
assert mock_run.called
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and return."""
with patch("sys.argv", ["envpy"]):
envpy.main()
# Should print help and return
def test_main_invalid_mirror_shows_error(self) -> None:
"""main() with invalid mirror should show error."""
with patch("sys.argv", ["envpy", "mirror", "invalid"]), pytest.raises(SystemExit) as exc_info:
envpy.main()
assert exc_info.value.code == 2
def test_main_creates_task_spec_with_correct_name(self) -> None:
"""main() should create TaskSpec with correct name."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
graph = mock_run.call_args[0][0]
task_names = list(graph.all_specs().keys())
assert "set_pip_mirror" in task_names
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["envpy", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envpy, "set_pip_mirror"
):
envpy.main()
assert mock_run.call_args[1]["strategy"] == "thread"
-210
View File
@@ -1,210 +0,0 @@
"""Tests for cli.envrs module."""
from __future__ import annotations
import os
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
import pyflowx as px
from pyflowx.cli import envrs
# ---------------------------------------------------------------------- #
# set_rust_mirror
# ---------------------------------------------------------------------- #
class TestSetRustMirror:
"""Test set_rust_mirror function."""
def test_set_rust_mirror_aliyun(self, tmp_path: Path) -> None:
"""Should set aliyun mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
# Check environment variables
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.aliyun.com/rustup"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.aliyun.com/rustup/rustup"
# Check cargo config
cargo_config = tmp_path / ".cargo" / "config.toml"
assert cargo_config.exists()
content = cargo_config.read_text()
assert "aliyun" in content
def test_set_rust_mirror_ustc(self, tmp_path: Path) -> None:
"""Should set ustc mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("ustc")
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.ustc.edu.cn/rust-static"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.ustc.edu.cn/rust-static/rustup"
def test_set_rust_mirror_tsinghua(self, tmp_path: Path) -> None:
"""Should set tsinghua mirror."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("tsinghua")
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.tuna.tsinghua.edu.cn/rustup"
assert os.environ.get("RUSTUP_UPDATE_ROOT") == "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"
def test_set_rust_mirror_unknown_uses_default(self, tmp_path: Path) -> None:
"""Should use default mirror for unknown mirror name."""
with patch.object(Path, "home", return_value=tmp_path):
# pyrefly: ignore [bad-argument-type]
envrs.set_rust_mirror("unknown")
# Should use default mirror (tsinghua)
assert os.environ.get("RUSTUP_DIST_SERVER") == "https://mirrors.tuna.tsinghua.edu.cn/rustup"
def test_set_rust_mirror_creates_cargo_dir(self, tmp_path: Path) -> None:
"""Should create .cargo directory if it doesn't exist."""
cargo_dir = tmp_path / ".cargo"
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
assert cargo_dir.exists()
assert cargo_dir.is_dir()
def test_set_rust_mirror_prints_message(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
"""Should print mirror name."""
with patch.object(Path, "home", return_value=tmp_path):
envrs.set_rust_mirror("aliyun")
captured = capsys.readouterr()
assert "已设置 Rust 镜像源: aliyun" in captured.out
# ---------------------------------------------------------------------- #
# install_rust
# ---------------------------------------------------------------------- #
class TestInstallRust:
"""Test install_rust function."""
def test_install_rust_stable(self) -> None:
"""Should install stable Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("stable")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "stable"], check=True)
def test_install_rust_nightly(self) -> None:
"""Should install nightly Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("nightly")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "nightly"], check=True)
def test_install_rust_beta(self) -> None:
"""Should install beta Rust."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("beta")
mock_run.assert_called_once_with(["rustup", "toolchain", "install", "beta"], check=True)
def test_install_rust_file_not_found(self) -> None:
"""Should raise FileNotFoundError when rustup not found."""
with patch("subprocess.run", side_effect=FileNotFoundError), pytest.raises(FileNotFoundError):
envrs.install_rust("stable")
def test_install_rust_prints_message(self, capsys: pytest.CaptureFixture[str]) -> None:
"""Should print installation message."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
envrs.install_rust("stable")
captured = capsys.readouterr()
assert "已安装 Rust stable" in captured.out
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_mirror_aliyun(self) -> None:
"""main() should handle mirror aliyun command."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_ustc(self) -> None:
"""main() should handle mirror ustc command."""
with patch("sys.argv", ["envrs", "mirror", "ustc"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_tsinghua(self) -> None:
"""main() should handle mirror tsinghua command."""
with patch("sys.argv", ["envrs", "mirror", "tsinghua"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_mirror_default(self) -> None:
"""main() should use default mirror when not specified."""
with patch("sys.argv", ["envrs", "mirror"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.called
def test_main_install_stable(self) -> None:
"""main() should handle install stable command."""
with patch("sys.argv", ["envrs", "install", "stable"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_nightly(self) -> None:
"""main() should handle install nightly command."""
with patch("sys.argv", ["envrs", "install", "nightly"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_beta(self) -> None:
"""main() should handle install beta command."""
with patch("sys.argv", ["envrs", "install", "beta"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_install_default(self) -> None:
"""main() should use default version when not specified."""
with patch("sys.argv", ["envrs", "install"]), patch.object(px, "run") as mock_run:
envrs.main()
assert mock_run.called
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and return."""
with patch("sys.argv", ["envrs"]):
envrs.main()
# Should print help and return
def test_main_invalid_version_shows_error(self) -> None:
"""main() with invalid version should show error."""
with patch("sys.argv", ["envrs", "install", "invalid"]), pytest.raises(SystemExit) as exc_info:
envrs.main()
assert exc_info.value.code == 2
def test_main_invalid_mirror_shows_error(self) -> None:
"""main() with invalid mirror should show error."""
with patch("sys.argv", ["envrs", "mirror", "invalid"]), pytest.raises(SystemExit) as exc_info:
envrs.main()
assert exc_info.value.code == 2
def test_main_creates_task_spec_with_verbose(self) -> None:
"""main() should create TaskSpec with verbose=True."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
graph = mock_run.call_args[0][0]
specs = graph.all_specs()
for spec in specs.values():
assert spec.verbose is True
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["envrs", "mirror", "aliyun"]), patch.object(px, "run") as mock_run, patch.object(
envrs, "set_rust_mirror"
):
envrs.main()
assert mock_run.call_args[1]["strategy"] == "thread"
+1 -1
View File
@@ -7,7 +7,7 @@ from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import taskkill
from pyflowx.cli.system import taskkill
from pyflowx.conditions import Constants
-66
View File
@@ -1,66 +0,0 @@
"""Tests for cli.which module."""
from __future__ import annotations
import shutil
from unittest.mock import patch
import pytest
import pyflowx as px
from pyflowx.cli import which
# ---------------------------------------------------------------------- #
# main function
# ---------------------------------------------------------------------- #
class TestMain:
"""Test main function."""
def test_main_with_single_command(self) -> None:
"""main() should handle single command argument."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with one task
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_multiple_commands(self) -> None:
"""main() should handle multiple command arguments."""
with patch("sys.argv", ["which", "python", "pip", "node"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
# Should create a graph with three tasks
assert mock_run.called
graph = mock_run.call_args[0][0]
assert isinstance(graph, px.Graph)
def test_main_with_no_args_shows_help(self) -> None:
"""main() with no args should show help and exit."""
with patch("sys.argv", ["which"]), pytest.raises(SystemExit) as exc_info:
which.main()
assert exc_info.value.code == 2
def test_main_creates_task_specs_with_correct_names(self) -> None:
"""main() should create TaskSpecs with correct names."""
with patch("sys.argv", ["which", "git", "npm"]), patch.object(
shutil, "which", return_value="/usr/bin/cmd"
), patch.object(px, "run") as mock_run:
which.main()
graph = mock_run.call_args[0][0]
# Check that task names are correct
task_names = list(graph.all_specs().keys())
assert "which_git" in task_names
assert "which_npm" in task_names
def test_main_uses_thread_strategy(self) -> None:
"""main() should use thread strategy."""
with patch("sys.argv", ["which", "python"]), patch.object(
shutil, "which", return_value="/usr/bin/python"
), patch.object(px, "run") as mock_run:
which.main()
assert mock_run.call_args[1]["strategy"] == "thread"
+57 -29
View File
@@ -338,6 +338,63 @@ class TestGraphDefaults:
assert report.success
assert calls["n"] == 3
def test_defaults_strategy_env_cwd(self) -> None:
"""测试strategy、env、cwd字段的继承。"""
defaults = px.GraphDefaults(
strategy="thread",
env={"VAR": "value"},
cwd="/tmp",
)
graph = px.Graph(defaults=defaults)
graph.add(px.TaskSpec("a", lambda: "ok"))
resolved = graph.resolved_spec("a")
assert resolved.strategy == "thread"
assert resolved.env == {"VAR": "value"}
assert resolved.cwd == "/tmp"
def test_defaults_continue_on_error_concurrency_key_verbose(self) -> None:
"""测试continue_on_error、concurrency_key、verbose字段的继承。"""
defaults = px.GraphDefaults(
continue_on_error=True,
concurrency_key="pool",
verbose=True,
)
graph = px.Graph(defaults=defaults)
graph.add(px.TaskSpec("a", lambda: "ok"))
resolved = graph.resolved_spec("a")
assert resolved.continue_on_error is True
assert resolved.concurrency_key == "pool"
assert resolved.verbose is True
def test_defaults_spec_excludes_non_default_values(self) -> None:
"""测试当spec已有非默认值时,不应被defaults覆盖。"""
defaults = px.GraphDefaults(
strategy="thread",
continue_on_error=True,
verbose=True,
priority=5,
)
graph = px.Graph(defaults=defaults)
graph.add(
px.TaskSpec(
"a",
lambda: "ok",
strategy="sequential",
continue_on_error=True, # True是非默认值,不会被覆盖
verbose=True, # True是非默认值,不会被覆盖
priority=10, # 非0值,不会被覆盖
)
)
resolved = graph.resolved_spec("a")
# strategy已有非默认值,不会被覆盖
assert resolved.strategy == "sequential"
# continue_on_error=True不会被defaults覆盖(只有False才会被覆盖)
assert resolved.continue_on_error is True
# verbose=True不会被defaults覆盖(只有False才会被覆盖)
assert resolved.verbose is True
# priority非0值不会被覆盖
assert resolved.priority == 10
# ---------------------------------------------------------------------- #
# 软依赖 soft_depends_on
@@ -449,35 +506,6 @@ class TestDependencyDrivenScheduling:
assert report["b"] == 2
assert report["c"] == 3
def test_dependency_strategy_faster_than_layered(self) -> None:
"""依赖驱动应比层屏障更快(无层等待)。"""
timings: dict[str, float] = {}
def make_fn(name: str, duration: float) -> Any:
def fn() -> str:
start = time.monotonic()
time.sleep(duration)
timings[name] = time.monotonic() - start
return name
return fn
# a (慢) -> b (快) 在同一层
# a (快) -> c (慢) 在同一层
# 依赖驱动:c 在 a 完成后立即启动,不必等 b
graph = px.Graph.from_specs([
px.TaskSpec("a", make_fn("a", 0.05)),
px.TaskSpec("b", make_fn("b", 0.05), depends_on=("a",)),
px.TaskSpec("c", make_fn("c", 0.05), depends_on=("a",)),
px.TaskSpec("d", make_fn("d", 0.01), depends_on=("b", "c")),
])
start = time.monotonic()
report = px.run(graph, strategy="dependency")
elapsed = time.monotonic() - start
assert report.success
# a(0.05) + max(b,c)(0.05) + d(0.01) ≈ 0.11,层屏障会更慢
assert elapsed < 0.20
def test_dependency_strategy_with_async_fn(self) -> None:
async def a() -> str:
await asyncio.sleep(0.01)
+85
View File
@@ -4,8 +4,11 @@ from __future__ import annotations
import os
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
from pyflowx.conditions import (
IS_LINUX,
IS_MACOS,
@@ -216,3 +219,85 @@ def test_logical_combination_with_dep_conditions():
BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")),
)
assert cond(ctx) is True
# ---------------------------------------------------------------------- #
# IS_RUNNING: 跨平台 subprocess 检测
# ---------------------------------------------------------------------- #
def test_is_running_windows_found(monkeypatch: pytest.MonkeyPatch):
"""Windows 上 tasklist 检测到进程."""
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
class MockResult:
stdout = "explorer.exe\nother.exe"
returncode = 0
monkeypatch.setattr(
"subprocess.run",
lambda *_, **__: MockResult(),
)
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
assert cond({}) is True
def test_is_running_windows_not_found(monkeypatch: pytest.MonkeyPatch):
"""Windows 上 tasklist 未检测到进程."""
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
class MockResult:
stdout = "other.exe"
returncode = 0
monkeypatch.setattr(
"subprocess.run",
lambda *_, **__: MockResult(),
)
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
assert cond({}) is False
def test_is_running_linux_found(monkeypatch: pytest.MonkeyPatch):
"""Linux 上 pgrep 检测到进程."""
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
class MockResult:
returncode = 0
monkeypatch.setattr(
"subprocess.run",
lambda *_, **__: MockResult(),
)
cond = BuiltinConditions.IS_RUNNING("nginx")
assert cond({}) is True
def test_is_running_linux_not_found(monkeypatch: pytest.MonkeyPatch):
"""Linux 上 pgrep 未检测到进程."""
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
class MockResult:
returncode = 1
monkeypatch.setattr(
"subprocess.run",
lambda *_, **__: MockResult(),
)
cond = BuiltinConditions.IS_RUNNING("nonexistent")
assert cond({}) is False
def test_dir_exists_true(tmp_path: Path):
"""DIR_EXISTS 检测路径存在."""
cond = BuiltinConditions.DIR_EXISTS(tmp_path)
assert cond({}) is True
def test_dir_exists_false(tmp_path: Path):
"""DIR_EXISTS 检测路径不存在."""
missing = tmp_path / "nonexistent"
cond = BuiltinConditions.DIR_EXISTS(missing)
assert cond({}) is False
+41
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import logging
import tempfile
import threading
import time
@@ -93,6 +94,46 @@ def test_retries_then_succeeds() -> None:
assert attempts["n"] == 3
def test_retries_with_delay() -> None:
"""测试带delay的重试会实际等待。"""
attempts = {"n": 0}
start_time = time.time()
def flaky() -> str:
attempts["n"] += 1
if attempts["n"] < 2:
raise RuntimeError("not yet")
return "ok"
graph = px.Graph.from_specs([
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=2, delay=0.1)),
])
report = px.run(graph, strategy="sequential")
elapsed = time.time() - start_time
assert report.success
assert elapsed >= 0.1 # 应有至少0.1秒的等待时间
assert attempts["n"] == 2
def test_timeout_then_retry_async(caplog: pytest.LogCaptureFixture) -> None:
"""测试超时后可以重试,并记录warning日志。"""
async def slow_task() -> str:
await asyncio.sleep(10) # 会触发超时
return "ok"
graph = px.Graph.from_specs([
px.TaskSpec("slow", slow_task, timeout=0.2, retry=px.RetryPolicy(max_attempts=2)),
])
with caplog.at_level(logging.WARNING, logger="pyflowx"):
with pytest.raises(px.TaskFailedError) as exc_info:
_ = px.run(graph, strategy="async")
assert exc_info.value.attempts == 2
assert "timed out" in str(exc_info.value.cause)
# 应有超时重试的warning日志
assert any("timed out" in r.message for r in caplog.records)
def test_retries_exhausted() -> None:
def always_fail() -> None:
raise RuntimeError("nope")
+309
View File
@@ -1,7 +1,11 @@
"""Tests for executors module edge cases."""
from __future__ import annotations
import asyncio
import logging
import sys
from typing import Callable
import pytest
@@ -251,3 +255,308 @@ def test_execute_async_with_error():
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="async")
# ---------------------------------------------------------------------- #
# _check_upstream_skipped 分支测试
# ---------------------------------------------------------------------- #
def test_allow_upstream_skip_allows_execution_after_skipped() -> None:
"""allow_upstream_skip=True 时上游被 SKIPPED 后本任务仍执行."""
never_true = lambda _ctx: False # noqa: E731
def downstream_task() -> str:
return "ran despite upstream skipped"
graph = px.Graph.from_specs([
px.TaskSpec("upstream", fn=lambda: "up", conditions=(never_true,)),
px.TaskSpec("downstream", fn=downstream_task, depends_on=("upstream",), allow_upstream_skip=True),
])
report = px.run(graph, strategy="sequential")
assert report.success
assert report.results["upstream"].status == TaskStatus.SKIPPED
assert report.results["downstream"].status == TaskStatus.SUCCESS
assert report["downstream"] == "ran despite upstream skipped"
def test_upstream_failed_skips_downstream() -> None:
"""上游 FAILED 时下游被 SKIPPED(除非 allow_upstream_skip=True."""
def boom():
raise ValueError("boom")
def downstream():
return "should not run"
graph = px.Graph.from_specs([
px.TaskSpec("upstream", fn=boom),
px.TaskSpec("downstream", fn=downstream, depends_on=("upstream",)),
])
with pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential")
# ---------------------------------------------------------------------- #
# _evaluate_conditions 多条件分支测试
# ---------------------------------------------------------------------- #
def test_multiple_conditions_failure_truncation() -> None:
"""超过 2 个条件失败时应截断显示."""
spec = px.TaskSpec(
"multi_skip",
fn=lambda: "result",
conditions=(lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False),
)
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", verbose=True)
assert report.success
assert report.results["multi_skip"].status == TaskStatus.SKIPPED
# reason 应显示 "条件不满足: <lambda>, <lambda> 等5个条件"
# ---------------------------------------------------------------------- #
# concurrency_key 测试
# ---------------------------------------------------------------------- #
def test_concurrency_key_sequential() -> None:
"""sequential 策略下 concurrency_key 无效果."""
spec = px.TaskSpec("a", fn=lambda: 1, concurrency_key="group1")
graph = px.Graph.from_specs([spec])
report = px.run(graph, strategy="sequential", concurrency_limits={"group1": 1})
assert report.success
def test_concurrency_key_thread() -> None:
"""thread 策略下 concurrency_key 应限制并发."""
import time
order = []
def make(name: str) -> Callable[[], str]:
def fn():
order.append(f"{name}-start")
time.sleep(0.1)
order.append(f"{name}-end")
return name
return fn
graph = px.Graph.from_specs([
px.TaskSpec("a", fn=make("a"), concurrency_key="group1"),
px.TaskSpec("b", fn=make("b"), concurrency_key="group1"),
px.TaskSpec("c", fn=make("c"), concurrency_key="group1"),
])
report = px.run(graph, strategy="thread", max_workers=10, concurrency_limits={"group1": 1})
assert report.success
# 由于 concurrency_key 限制为 1,任务应串行执行
# 验证顺序:每个任务的 start-end 应连续
# 可能顺序:a-start, a-end, b-start, b-end, c-start, c-end
def test_concurrency_key_async() -> None:
"""async 策略下 concurrency_key 应限制并发."""
import asyncio
async def task_a():
await asyncio.sleep(0.01)
return "a"
async def task_b():
await asyncio.sleep(0.01)
return "b"
graph = px.Graph.from_specs([
px.TaskSpec("a", fn=task_a, concurrency_key="group1"),
px.TaskSpec("b", fn=task_b, concurrency_key="group1"),
])
report = px.run(graph, strategy="async", concurrency_limits={"group1": 1})
assert report.success
# ---------------------------------------------------------------------- #
# dependency 策略测试
# ---------------------------------------------------------------------- #
def test_dependency_strategy_basic() -> None:
"""dependency 策略应正确执行."""
order = []
def make(name: str) -> Callable[[], str]:
def fn():
order.append(name)
return name
return fn
graph = px.Graph.from_specs([
px.TaskSpec("a", fn=make("a")),
px.TaskSpec("b", fn=make("b"), depends_on=("a",)),
px.TaskSpec("c", fn=make("c"), depends_on=("a",)),
px.TaskSpec("d", fn=make("d"), depends_on=("b", "c")),
])
report = px.run(graph, strategy="dependency")
assert report.success
assert "a" in order
assert "d" in order
def test_dependency_strategy_async() -> None:
"""dependency 策略下异步任务应正确执行."""
async def a():
return "a"
async def b(a: str):
return a + "b"
graph = px.Graph.from_specs([
px.TaskSpec("a", fn=a),
px.TaskSpec("b", fn=b, depends_on=("a",)),
])
report = px.run(graph, strategy="dependency")
assert report.success
assert report["b"] == "ab"
# ---------------------------------------------------------------------- #
# continue_on_error 测试
# ---------------------------------------------------------------------- #
def test_continue_on_error_marks_failed_but_continues() -> None:
"""continue_on_error=True 时任务失败不抛异常,但 report.success 为 True(无 TaskFailedError 抛出)。"""
def boom():
raise ValueError("boom")
graph = px.Graph.from_specs([
px.TaskSpec("fail", fn=boom, continue_on_error=True),
px.TaskSpec("other", fn=lambda: "ok"), # 无依赖,应继续
])
# continue_on_error=True 时 run 不抛异常,report.success 为 True
report = px.run(graph, strategy="sequential")
# report.success 为 True 因为没有抛 TaskFailedError
assert report.success # 因为 continue_on_error 阻止了 TaskFailedError
assert report.results["fail"].status == TaskStatus.FAILED
assert report.results["other"].status == TaskStatus.SUCCESS
def test_continue_on_error_downstream_skipped() -> None:
"""continue_on_error=True 时失败任务的下游被 SKIPPEDallow_upstream_skip=False 时)。"""
def boom():
raise ValueError("boom")
def downstream():
return "should not run"
graph = px.Graph.from_specs([
px.TaskSpec("fail", fn=boom, continue_on_error=True),
px.TaskSpec("dep", fn=downstream, depends_on=("fail",), allow_upstream_skip=False),
])
report = px.run(graph, strategy="sequential")
# report.success 为 True 因为 continue_on_error 阻止了 TaskFailedError
assert report.success
assert report.results["fail"].status == TaskStatus.FAILED
assert report.results["dep"].status == TaskStatus.SKIPPED
# ---------------------------------------------------------------------- #
# soft_depends_on 默认值注入测试
# ---------------------------------------------------------------------- #
def test_soft_depends_on_default_value_injection() -> None:
"""软依赖存在且成功时注入其结果值(参数名需与依赖名一致)。"""
def task_with_soft_dep(a: str | None = None) -> str:
return f"a={a}"
graph = px.Graph.from_specs([
px.TaskSpec("a", fn=lambda: "value"),
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("a",)),
])
report = px.run(graph, strategy="sequential")
assert report.success
assert report["b"] == "a=value"
def test_soft_depends_on_skipped_injects_none() -> None:
"""软依赖被 SKIPPED 时注入 None(参数名需与依赖名一致)。"""
never_true = lambda _ctx: False # noqa: E731
def task_with_soft_dep(skipped: str | None = None) -> str:
return f"skipped={skipped}"
graph = px.Graph.from_specs([
px.TaskSpec("skipped", fn=lambda: "value", conditions=(never_true,)),
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("skipped",)),
])
report = px.run(graph, strategy="sequential")
assert report.success
# 软依赖被 skipped 时注入 None(因为 global_context 中有 skipped,值为 None
assert report["b"] == "skipped=None"
# ---------------------------------------------------------------------- #
# hooks 异常处理测试
# ---------------------------------------------------------------------- #
def test_hooks_pre_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
"""pre_run hook 抛异常应被记录但不影响任务."""
def bad_hook(_spec):
raise RuntimeError("hook error")
hooks = px.TaskHooks(pre_run=bad_hook)
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
graph = px.Graph.from_specs([spec])
with caplog.at_level(logging.WARNING, logger="pyflowx"):
report = px.run(graph, strategy="sequential")
assert report.success
assert any("hook" in r.message for r in caplog.records)
def test_hooks_post_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
"""post_run hook 抛异常应被记录但不影响任务."""
def bad_hook(_spec, _value):
raise RuntimeError("post hook error")
hooks = px.TaskHooks(post_run=bad_hook)
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
graph = px.Graph.from_specs([spec])
with caplog.at_level(logging.WARNING, logger="pyflowx"):
report = px.run(graph, strategy="sequential")
assert report.success
assert any("hook" in r.message for r in caplog.records)
def test_hooks_on_failure_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
"""on_failure hook 抛异常应被记录但不影响任务."""
def bad_hook(_spec, _exc):
raise RuntimeError("failure hook error")
hooks = px.TaskHooks(on_failure=bad_hook)
spec = px.TaskSpec("a", fn=lambda: (_ for _ in ()).throw(ValueError("task error")), hooks=hooks)
graph = px.Graph.from_specs([spec])
with caplog.at_level(logging.WARNING, logger="pyflowx"), pytest.raises(px.TaskFailedError):
px.run(graph, strategy="sequential")
assert any("hook" in r.message for r in caplog.records)
# ---------------------------------------------------------------------- #
# unknown strategy 测试
# ---------------------------------------------------------------------- #
def test_unknown_strategy_raises() -> None:
"""未知 strategy 应抛 ValueError."""
graph = px.Graph.from_specs([px.TaskSpec("a", fn=lambda: 1)])
with pytest.raises(ValueError, match="Unknown strategy"):
# pyrefly: ignore [bad-argument-type]
px.run(graph, strategy="unknown_strategy")
# ---------------------------------------------------------------------- #
# 空图测试
# ---------------------------------------------------------------------- #
def test_empty_graph_dependency_strategy() -> None:
"""dependency 策略下空图应正常返回."""
graph = px.Graph()
report = px.run(graph, strategy="dependency")
assert report.success
assert len(report) == 0
+126
View File
@@ -6,6 +6,7 @@ import pytest
import pyflowx as px
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
from pyflowx.graph import GraphComposer, compose
def _fn() -> None:
@@ -161,6 +162,19 @@ def test_all_specs_returns_view() -> None:
assert view is graph.all_specs() or view == graph.all_specs()
def test_all_deps_combines_hard_and_soft() -> None:
"""all_deps 应返回硬依赖 + 软依赖的组合。"""
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn),
px.TaskSpec("c", _fn, depends_on=("a",), soft_depends_on=("b",)),
])
all_deps = graph.all_deps("c")
assert set(all_deps) == {"a", "b"}
# 硬依赖在前,软依赖在后
assert all_deps == ("a", "b")
def test_spec_accessor() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
assert graph.spec("a").name == "a"
@@ -213,3 +227,115 @@ def test_subgraph_by_tags_no_match() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, tags=("x",))])
sub = graph.subgraph(["z"])
assert len(sub) == 0
# ---------------------------------------------------------------------- #
# from_specs str 类型分支测试
# ---------------------------------------------------------------------- #
def test_from_specs_with_string_ref() -> None:
"""from_specs 接受字符串引用并收集到 pending_refs."""
# 字符串引用被收集到 _pending_refs,而非尝试打开文件
graph = px.Graph.from_specs(["ref_cmd"])
assert graph._pending_refs == ["ref_cmd"]
def test_from_specs_with_invalid_type() -> None:
"""from_specs 接受不支持的类型时应抛 TypeError."""
with pytest.raises(TypeError, match="from_specs 只接受 TaskSpec 或 str"):
_ = px.Graph.from_specs([123]) # type: ignore[list-item]
# ---------------------------------------------------------------------- #
# to_mermaid 软依赖测试
# ---------------------------------------------------------------------- #
def test_to_mermaid_soft_depends_on() -> None:
"""to_mermaid 应正确绘制软依赖为虚线."""
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, soft_depends_on=("a",)),
])
mermaid = graph.to_mermaid()
assert "a -.-> b" in mermaid # 软依赖用虚线
# ---------------------------------------------------------------------- #
# GraphComposer 与 compose 测试
# ---------------------------------------------------------------------- #
def test_graph_composer_resolve_all() -> None:
"""GraphComposer.resolve_all 应展开所有图的字符串引用."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn, depends_on=("a1",))])
# 创建带 _pending_refs 的图
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
resolved = composer.resolve_all()
# graph_b 应包含 graph_a 的任务
assert "a1" in resolved["cmd_b"]
assert "a2" in resolved["cmd_b"]
def test_graph_composer_parse_ref_self_reference() -> None:
"""GraphComposer.parse_ref 应检测循环引用."""
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
composer = GraphComposer({"cmd": graph})
with pytest.raises(ValueError, match="循环引用"):
_ = composer.parse_ref("cmd", "cmd")
def test_graph_composer_parse_ref_cmd_not_found() -> None:
"""GraphComposer.parse_ref 应检测引用的命令不存在."""
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
composer = GraphComposer({"cmd": graph})
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
_ = composer.parse_ref("missing", "current")
def test_graph_composer_parse_ref_task_not_found() -> None:
"""GraphComposer.parse_ref 应检测任务不存在于引用的命令中."""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
with pytest.raises(ValueError, match="任务 'missing' 不存在于命令 'cmd_a'"):
_ = composer.parse_ref("cmd_a.missing", "cmd_b")
def test_graph_composer_expand_refs_no_pending() -> None:
"""GraphComposer.expand_refs 无 pending_refs 时应原样返回."""
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
composer = GraphComposer({"cmd": graph})
expanded = composer.expand_refs(graph, "cmd")
assert expanded is graph
def test_compose_function() -> None:
"""compose() 函数应等同于 GraphComposer().resolve_all()。"""
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
resolved = compose({"cmd_a": graph_a, "cmd_b": graph_b})
assert "a1" in resolved["cmd_b"]
# ---------------------------------------------------------------------- #
# resolved_spec defaults 测试
# ---------------------------------------------------------------------- #
def test_resolved_spec_applies_defaults() -> None:
"""resolved_spec 应应用 Graph.defaults。"""
defaults = px.GraphDefaults(timeout=10.0, retry=px.RetryPolicy(max_attempts=2))
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], defaults=defaults)
resolved = graph.resolved_spec("a")
assert resolved.timeout == 10.0
assert resolved.retry.max_attempts == 2
def test_resolved_spec_no_override() -> None:
"""resolved_spec 不应覆盖任务已有的设置。"""
defaults = px.GraphDefaults(timeout=10.0)
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, timeout=5.0)], defaults=defaults)
resolved = graph.resolved_spec("a")
assert resolved.timeout == 5.0 # 保持原值,不被 defaults 覆盖
+144
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
import json
import os
import tempfile
import time
from pathlib import Path
from typing import Any
@@ -43,6 +44,46 @@ def test_memory_backend_get_missing_raises() -> None:
b.get("nope")
def test_memory_backend_ttl_expired() -> None:
"""MemoryBackend TTL 过期后 has/get 返回 False/抛 KeyError."""
b = MemoryBackend(ttl=0.1) # 0.1 秒过期
b.save("a", 1)
assert b.has("a")
time.sleep(0.15)
assert not b.has("a")
with pytest.raises(KeyError):
b.get("a")
def test_memory_backend_ttl_load_filters_expired() -> None:
"""MemoryBackend.load() 应过滤过期的条目."""
b = MemoryBackend(ttl=0.1)
b.save("a", 1)
b.save("b", 2)
time.sleep(0.15)
# a 过期,但 b 也要过期... 需要更精确控制
# 使用 monkeypatch 更可控
b._store["expired"] = ("value", time.monotonic() - 100) # 手动设置过期时间
b._store["fresh"] = ("value2", time.monotonic())
assert "expired" not in dict(b.load())
assert "fresh" in dict(b.load())
def test_memory_backend_expired_key_not_in_store() -> None:
"""_expired 对不存在键返回 False."""
b = MemoryBackend(ttl=1.0)
assert b._expired("nonexistent") is False
def test_memory_backend_no_ttl_never_expired() -> None:
"""无 TTL 时永不过期."""
b = MemoryBackend()
b.save("a", 1)
b._store["a"] = (1, time.monotonic() - 1000) # 手动设置很久以前的存储
assert b.has("a") # 仍然存在
assert b.get("a") == 1
# ---------------------------------------------------------------------- #
# JSONBackend
# ---------------------------------------------------------------------- #
@@ -150,6 +191,109 @@ def test_json_backend_non_dict_content_ignored(tmp_path: Path) -> None:
assert dict(b.load()) == {}
def test_json_backend_old_format_migration(tmp_path: Path) -> None:
"""旧格式JSON(纯值)应被迁移为新格式(带ts)。"""
path = tmp_path / "state.json"
# 写入旧格式:纯值
old_data = {"a": 1, "b": "value"}
_ = path.write_text(json.dumps(old_data))
b = JSONBackend(str(path))
# 读取后应有ts字段
assert "a" in b._store
assert "value" in b._store["a"]
assert "ts" in b._store["a"]
assert b._store["a"]["value"] == 1
# ---------------------------------------------------------------------- #
# JSONBackend TTL 测试
# ---------------------------------------------------------------------- #
def test_json_backend_ttl_expired_has_returns_false() -> None:
"""JSONBackend TTL 过期后 has 返回 False."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=0.1)
b.save("a", 1)
assert b.has("a")
time.sleep(0.15)
assert not b.has("a")
def test_json_backend_ttl_expired_get_raises_keyerror() -> None:
"""JSONBackend TTL 过期后 get 抛 KeyError."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=0.1)
b.save("a", 1)
time.sleep(0.15)
with pytest.raises(KeyError):
b.get("a")
def test_json_backend_ttl_load_filters_expired() -> None:
"""JSONBackend.load() 应过滤过期的条目."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=0.1)
b.save("a", 1)
b.save("b", 2)
time.sleep(0.15)
# 两个都过期了
assert dict(b.load()) == {}
def test_json_backend_expired_no_ttl() -> None:
"""无 TTL 时 _expired 返回 False."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path)
b.save("a", 1)
# 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 1000
assert b._expired(b._store["a"]) is False # 无 TTL,永不过期
def test_json_backend_expired_with_ttl() -> None:
"""有 TTL 时 _expired 检查是否过期."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b.save("a", 1)
# 手动修改 ts 为很久以前
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
assert b._expired(b._store["a"]) is True
def test_json_backend_expired_missing_ts() -> None:
"""entry 缺少 ts 时使用默认值 0."""
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path, ttl=1.0)
b._store["a"] = {"value": 1} # 缺少 ts
# ts 默认为 0,已经过了很久
assert b._expired(b._store["a"]) is True
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
"""save 时 json.dumps 抛 ValueError 应转为 StorageError."""
import json as _json
with tempfile.TemporaryDirectory() as tmp:
path = str(Path(tmp) / "state.json")
b = JSONBackend(path)
original_dumps = _json.dumps
def flaky_dumps(*_args: Any, **_kwargs: Any) -> str:
raise ValueError("simulated dumps failure")
monkeypatch.setattr(_json, "dumps", flaky_dumps)
with pytest.raises(StorageError, match="not JSON-serialisable"):
b.save("a", 1)
monkeypatch.setattr(_json, "dumps", original_dumps)
# ---------------------------------------------------------------------- #
# resolve_backend
# ---------------------------------------------------------------------- #
+191
View File
@@ -0,0 +1,191 @@
"""Tests for tasks/system.py."""
import os
import subprocess
import pytest
from pyflowx.conditions import Constants
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, which
def test_clr_creates_task_spec() -> None:
"""clr() 应创建 TaskSpec。"""
spec = clr()
assert spec.name == "clear_screen"
assert spec.fn is not None
def test_clr_executes_on_linux(monkeypatch: pytest.MonkeyPatch) -> None:
"""clr() 在 Linux 上应执行 clear 命令。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
monkeypatch.setattr(Constants, "IS_LINUX", True)
# Mock subprocess.run
ran = []
monkeypatch.setattr(
subprocess,
"run",
lambda *cmd, **__: ran.append(cmd),
)
spec = clr()
assert spec.fn is not None
spec.fn()
assert ran == [(["clear"],)]
def test_clr_executes_on_windows(monkeypatch: pytest.MonkeyPatch) -> None:
"""clr() 在 Windows 上应执行 cls 命令。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
# Mock subprocess.run
ran = []
monkeypatch.setattr(
subprocess,
"run",
lambda *cmd, **__: ran.append(cmd),
)
spec = clr()
assert spec.fn is not None
spec.fn()
assert ran == [(["cls"],)]
def test_reset_icon_cache_non_windows(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
"""reset_icon_cache() 在非 Windows 上应返回空列表并打印提示。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
specs = reset_icon_cache()
assert specs == []
captured = capsys.readouterr()
assert "仅在 Windows 上支持" in captured.out
def test_reset_icon_cache_windows(monkeypatch: pytest.MonkeyPatch) -> None:
"""reset_icon_cache() 在 Windows 上应返回任务列表。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
monkeypatch.setenv("LOCALAPPDATA", "C:\\Users\\test\\AppData\\Local")
specs = reset_icon_cache()
assert len(specs) == 4
assert specs[0].name == "kill_explorer"
assert specs[1].name == "delete_icon_cache"
assert specs[2].name == "delete_icon_cache_all"
assert specs[3].name == "restart_explorer"
def test_setenv_creates_task_spec() -> None:
"""setenv() 应创建 TaskSpec。"""
spec = setenv("TEST_VAR", "test_value")
assert spec.name == "setenv_test_var"
assert spec.verbose is True
def test_setenv_sets_environment_variable(monkeypatch: pytest.MonkeyPatch) -> None:
"""setenv() 应设置环境变量。"""
spec = setenv("PYFLOWX_TEST_VAR_1", "test_value")
assert spec.fn is not None
spec.fn()
assert os.environ["PYFLOWX_TEST_VAR_1"] == "test_value"
# Clean up
del os.environ["PYFLOWX_TEST_VAR_1"]
def test_setenv_default_not_overwrite(monkeypatch: pytest.MonkeyPatch) -> None:
"""setenv(default=True) 不应覆盖已存在的环境变量。"""
os.environ["PYFLOWX_TEST_VAR_EXISTS"] = "original"
spec = setenv("PYFLOWX_TEST_VAR_EXISTS", "new_value", default=True)
assert spec.fn is not None
spec.fn()
assert os.environ["PYFLOWX_TEST_VAR_EXISTS"] == "original"
# Clean up
del os.environ["PYFLOWX_TEST_VAR_EXISTS"]
def test_setenv_default_sets_when_missing() -> None:
"""setenv(default=True) 应在缺失时设置环境变量。"""
# Ensure variable does not exist
var_name = "PYFLOWX_TEST_VAR_MISSING"
if var_name in os.environ:
del os.environ[var_name]
spec = setenv(var_name, "default_value", default=True)
assert spec.fn is not None
spec.fn()
assert os.environ[var_name] == "default_value"
# Clean up after test
del os.environ[var_name]
def test_which_creates_task_spec() -> None:
"""which() 应创建 TaskSpec。"""
spec = which("python")
assert spec.name == "which_python"
def test_which_linux_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
"""which() 在 Linux 上找到命令应打印路径。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
class MockResult:
returncode = 0
stdout = "/usr/bin/python\n"
monkeypatch.setattr(
subprocess,
"run",
lambda *_, **__: MockResult(),
)
spec = which("python")
assert spec.fn is not None
spec.fn()
captured = capsys.readouterr()
assert "python ->" in captured.out
assert "/usr/bin/python" in captured.out
def test_which_windows_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
"""which() 在 Windows 上找到命令应打印路径。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
class MockResult:
returncode = 0
stdout = "C:\\Python\\python.exe\nC:\\Python\\Scripts\\python.exe\n"
monkeypatch.setattr(
subprocess,
"run",
lambda *_, **__: MockResult(),
)
spec = which("python")
assert spec.fn is not None
spec.fn()
captured = capsys.readouterr()
assert "python ->" in captured.out
assert "C:\\Python\\python.exe" in captured.out
def test_which_not_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
"""which() 未找到命令应打印提示。"""
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
class MockResult:
returncode = 1
stdout = ""
monkeypatch.setattr(
subprocess,
"run",
lambda *_, **__: MockResult(),
)
spec = which("nonexistent_cmd")
assert spec.fn is not None
spec.fn()
captured = capsys.readouterr()
assert "nonexistent_cmd -> 未找到" in captured.out
+306 -1
View File
@@ -2,11 +2,20 @@
from __future__ import annotations
import os
from datetime import datetime
from pathlib import Path
import pytest
from pyflowx.task import RetryPolicy, TaskResult, TaskSpec, TaskStatus
from pyflowx.task import (
RetryPolicy,
TaskResult,
TaskSpec,
TaskStatus,
_env_and_cwd,
task_template,
)
def _fn() -> None:
@@ -28,11 +37,283 @@ def test_spec_zero_timeout_rejected() -> None:
TaskSpec("a", _fn, timeout=0)
def test_spec_negative_timeout_rejected() -> None:
"""负数timeout应被拒绝。"""
with pytest.raises(ValueError, match="timeout"):
TaskSpec("a", _fn, timeout=-1.0)
def test_spec_self_dependency_rejected() -> None:
with pytest.raises(ValueError, match="depend on itself"):
TaskSpec("a", _fn, depends_on=("a",))
def test_spec_self_soft_dependency_rejected() -> None:
"""self dependency via soft_depends_on 也应被拒绝."""
with pytest.raises(ValueError, match="depend on itself"):
TaskSpec("a", _fn, soft_depends_on=("a",))
def test_spec_overlap_depends_rejected() -> None:
"""depends_on 和 soft_depends_on 重叠应被拒绝."""
with pytest.raises(ValueError, match="不能重叠"):
TaskSpec("a", _fn, depends_on=("b",), soft_depends_on=("b",))
# ---------------------------------------------------------------------- #
# RetryPolicy 参数验证
# ---------------------------------------------------------------------- #
def test_retry_policy_negative_delay_rejected() -> None:
with pytest.raises(ValueError, match="delay must be >= 0"):
RetryPolicy(delay=-1)
def test_retry_policy_negative_backoff_rejected() -> None:
with pytest.raises(ValueError, match="backoff must be >= 0"):
RetryPolicy(backoff=-1)
def test_retry_policy_negative_jitter_rejected() -> None:
with pytest.raises(ValueError, match="jitter must be >= 0"):
RetryPolicy(jitter=-1)
def test_retry_policy_retries_property() -> None:
policy = RetryPolicy(max_attempts=3)
assert policy.retries == 2
def test_retry_policy_should_retry_matching() -> None:
policy = RetryPolicy(max_attempts=3, retry_on=(ValueError,))
assert policy.should_retry(ValueError("x")) is True
assert policy.should_retry(RuntimeError("x")) is False
def test_retry_policy_should_retry_empty_tuple() -> None:
"""空元组等价于不重试."""
policy = RetryPolicy(max_attempts=3, retry_on=())
assert policy.should_retry(ValueError("x")) is False
def test_retry_policy_wait_seconds_zero_attempt() -> None:
"""attempt < 1 时返回 0."""
policy = RetryPolicy(delay=1.0, backoff=2.0)
assert policy.wait_seconds(0) == 0.0
assert policy.wait_seconds(-1) == 0.0
def test_retry_policy_wait_seconds_with_backoff() -> None:
"""有 backoff 时等待时间应递增."""
policy = RetryPolicy(delay=1.0, backoff=2.0)
# attempt=1: delay * backoff^0 = 1
# attempt=2: delay * backoff^1 = 2
assert policy.wait_seconds(1) == 1.0
assert policy.wait_seconds(2) == 2.0
def test_retry_policy_wait_seconds_with_jitter() -> None:
"""有 jitter 时等待时间应增加随机量."""
policy = RetryPolicy(delay=1.0, jitter=0.5)
# 多次调用验证结果在合理范围内
for _ in range(5):
wait = policy.wait_seconds(1)
assert 1.0 <= wait <= 1.5
# ---------------------------------------------------------------------- #
# should_execute 条件异常处理
# ---------------------------------------------------------------------- #
def test_should_execute_condition_exception_returns_false() -> None:
"""条件执行抛异常时应返回 False 并记录原因."""
def bad_condition(_ctx):
raise RuntimeError("condition error")
bad_condition.__name__ = ""
spec = TaskSpec("a", _fn, conditions=(bad_condition,))
should_run, reason = spec.should_execute({})
assert should_run is False
# pyrefly: ignore [not-iterable]
assert "匿名条件(执行错误)" in reason
def test_should_execute_condition_lambda_name() -> None:
"""lambda 条件有 __name__ 为 '<lambda>'."""
spec = TaskSpec("a", _fn, conditions=(lambda _ctx: False,))
should_run, reason = spec.should_execute({})
assert should_run is False
# pyrefly: ignore [not-iterable]
assert "<lambda>" in reason
def test_should_execute_skip_if_missing_cmd_not_found() -> None:
"""skip_if_missing 且命令不存在时应跳过."""
spec = TaskSpec("a", cmd=["nonexistent_cmd_xyz"], skip_if_missing=True)
should_run, reason = spec.should_execute({})
assert should_run is False
# pyrefly: ignore [not-iterable]
assert "命令不存在" in reason
def test_should_execute_skip_if_missing_cmd_found() -> None:
"""skip_if_missing 但命令存在时应执行."""
# 使用 Python 作为已安装的命令
spec = TaskSpec("a", cmd=["echo"], skip_if_missing=True) # echo 应存在
should_run, reason = spec.should_execute({})
assert should_run is True
assert reason is None
def test_should_execute_skip_if_missing_non_list_cmd() -> None:
"""skip_if_missing 对非 list 命令不影响."""
spec = TaskSpec("a", cmd="echo hello", skip_if_missing=True)
should_run, reason = spec.should_execute({})
assert should_run is True
assert reason is None
def test_should_execute_skip_if_missing_empty_list() -> None:
"""skip_if_missing 对空列表命令返回 True."""
spec = TaskSpec("a", cmd=[], skip_if_missing=True)
# 空 list 不检查
_should_run, _reason = spec.should_execute({})
# 因为 cmd=[] 且 fn=None,这会在 __post_init__ 中抛异常
# 所以这个测试无效,我们用另一个方式测试 _is_cmd_available
def test_is_cmd_available_empty_list_returns_true() -> None:
"""_is_cmd_available 对空列表返回 True."""
spec = TaskSpec("a", cmd=[], fn=_fn) # 提供 fn 避免 __post_init__ 异常
assert spec._is_cmd_available() is True
def test_is_cmd_available_string_returns_true() -> None:
"""_is_cmd_available 对字符串命令返回 True."""
spec = TaskSpec("a", cmd="echo hello")
assert spec._is_cmd_available() is True
def test_is_cmd_available_callable_returns_true() -> None:
"""_is_cmd_available 对可调用命令返回 True."""
spec = TaskSpec("a", cmd=_fn)
assert spec._is_cmd_available() is True
# ---------------------------------------------------------------------- #
# storage_key 异常处理
# ---------------------------------------------------------------------- #
def test_storage_key_cache_key_exception_returns_name() -> None:
"""cache_key 抛异常时应返回任务名."""
def bad_cache_key(_ctx):
raise RuntimeError("cache key error")
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
key = spec.storage_key({})
assert key == "a"
def test_storage_key_cache_key_success() -> None:
"""cache_key 成功时应返回组合键."""
spec = TaskSpec("a", _fn, cache_key=lambda ctx: ctx.get("x", "default"))
key = spec.storage_key({"x": "value"})
assert key == "a:value"
def test_storage_key_no_cache_key() -> None:
"""无 cache_key 时返回任务名."""
spec = TaskSpec("a", _fn)
key = spec.storage_key({})
assert key == "a"
# ---------------------------------------------------------------------- #
# _env_and_cwd 上下文管理器
# ---------------------------------------------------------------------- #
def test_env_and_cwd_sets_env() -> None:
"""应临时设置环境变量。"""
var_name = "PYFLOWX_TEST_ENV_VAR_1"
with _env_and_cwd({var_name: "test_value"}, None):
assert os.environ[var_name] == "test_value"
# 退出后应恢复
assert var_name not in os.environ
def test_env_and_cwd_restores_existing_env() -> None:
"""应恢复已有的环境变量."""
os.environ["EXISTING_VAR"] = "original"
try:
with _env_and_cwd({"EXISTING_VAR": "new_value"}, None):
assert os.environ["EXISTING_VAR"] == "new_value"
# 退出后应恢复原值
assert os.environ["EXISTING_VAR"] == "original"
finally:
os.environ.pop("EXISTING_VAR", None)
def test_env_and_cwd_sets_cwd(tmp_path: Path) -> None:
"""应临时切换工作目录."""
original = Path.cwd()
with _env_and_cwd(None, tmp_path):
assert Path.cwd() == tmp_path
# 退出后应恢复
assert Path.cwd() == original
def test_env_and_cwd_no_changes() -> None:
"""无 env 和 cwd 时不应有任何变化."""
original_env = dict(os.environ)
original_cwd = Path.cwd()
with _env_and_cwd(None, None):
pass
assert dict(os.environ) == original_env
assert Path.cwd() == original_cwd
def test_spec_env_context() -> None:
"""TaskSpec.env_context 应正确工作."""
var_name = "PYFLOWX_TEST_ENV_VAR_2"
spec = TaskSpec("a", _fn, env={var_name: "value"})
with spec.env_context():
assert os.environ[var_name] == "value"
assert var_name not in os.environ
# ---------------------------------------------------------------------- #
# task_template 工厂
# ---------------------------------------------------------------------- #
def test_task_template_creates_specs() -> None:
"""task_template 应创建 TaskSpec 工厂."""
template = task_template(fn=_fn, retry=RetryPolicy(max_attempts=3))
spec = template("task1")
assert spec.name == "task1"
assert spec.retry.max_attempts == 3
def test_task_template_with_cmd() -> None:
"""task_template 可以使用 cmd."""
template = task_template(cmd=["echo", "hello"])
spec = template("task1")
assert spec.name == "task1"
assert spec.cmd == ["echo", "hello"]
def test_task_template_overrides() -> None:
"""task_template 工厂可以覆盖默认值."""
template = task_template(fn=_fn, timeout=10.0)
spec = template("task1", timeout=5.0)
assert spec.timeout == 5.0
def test_task_template_factory_name() -> None:
"""工厂函数名应为 task_template_factory."""
template = task_template(fn=_fn)
assert template.__name__ == "task_template_factory"
# ---------------------------------------------------------------------- #
# TaskResult 测试
# ---------------------------------------------------------------------- #
def test_task_result_duration_none_when_not_started() -> None:
spec: TaskSpec[None] = TaskSpec("a", _fn)
result: TaskResult[None] = TaskResult(spec=spec)
@@ -61,3 +342,27 @@ def test_task_result_default_status() -> None:
assert result.value is None
assert result.error is None
assert result.attempts == 0
# ---------------------------------------------------------------------- #
# _run_command callable 命令测试
# ---------------------------------------------------------------------- #
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
"""callable 命令 verbose 模式应打印信息."""
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
import pyflowx.task as task_module
result = task_module._run_command(spec)
assert result == "result"
captured = capsys.readouterr()
assert "执行可调用命令" in captured.out
assert "工作目录" in captured.out
def test_run_command_callable_exception() -> None:
"""callable 命令抛异常应转为 RuntimeError."""
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
import pyflowx.task as task_module
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
task_module._run_command(spec)
+65
View File
@@ -0,0 +1,65 @@
import time
import pytest
from pytest_mock import MockerFixture
from pyflowx.utils import _perf_metrics, perf_timer
@pytest.fixture(autouse=True)
def reset_perf_metrics():
"""重置性能指标."""
_perf_metrics.clear()
class TestPerformanceTimer:
def test_perf_timer(self):
@perf_timer()
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
def test_perf_timer_report(self, mocker: MockerFixture):
mock_log = mocker.patch("logging.info")
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
test_func()
assert _perf_metrics["test_func"] is not None
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert mock_log.call_count == 1
def test_generate_report(self, mocker: MockerFixture, caplog: pytest.LogCaptureFixture):
mock_log = mocker.patch("logging.info")
from pyflowx.utils import _generate_report
@perf_timer(report=True, unit="ms", precision=3)
def test_func():
time.sleep(0.1)
@perf_timer(report=True, unit="ms", precision=3)
def test_func2():
time.sleep(0.2)
test_func()
test_func2()
_generate_report("ms", 3)
assert mock_log.call_count == 3
assert _perf_metrics["test_func"]["count"] == 1
assert _perf_metrics["test_func"]["total_time"] >= 0.1
assert _perf_metrics["test_func2"]["count"] == 1
assert _perf_metrics["test_func2"]["total_time"] >= 0.2
Generated
+482 -2122
View File
File diff suppressed because it is too large Load Diff