Compare commits
123 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3f9c52e6f1 | |||
| 8fadf6edd8 | |||
| abc1152538 | |||
| 5e561b4b3a | |||
| 40f641611b | |||
| 232e7293d9 | |||
| a1bae58e56 | |||
| cbc7cc0a75 | |||
| d0ff7d7b4d | |||
| d154f67ce0 | |||
| 9999071119 | |||
| bdd70e9c43 | |||
| c15b38516a | |||
| 7d4e8a40ce | |||
| 1b2d6d6a2c | |||
| df890f0f16 | |||
| b62a544569 | |||
| d58fc5536e | |||
| c3b86b603d | |||
| 327bd6e069 | |||
| 22f8d2110d | |||
| 2a1f2f7175 | |||
| 9d033e1c0b | |||
| 336f7b7292 | |||
| 65dcbcbf62 | |||
| 7fa97a01e3 | |||
| 83da5135d0 | |||
| 7463a60649 | |||
| 87dd010342 | |||
| bdfee7bee4 | |||
| b954fb1622 | |||
| a7b7a82dff | |||
| 40f0478146 | |||
| b808b880f8 | |||
| e073ff41ee | |||
| ea0c51de5e | |||
| 2b3f4b82d3 | |||
| 1e23c48efc | |||
| 5c8ec281ff | |||
| 6f01cde8ac | |||
| bcd189ae60 | |||
| 20c4fb87c5 | |||
| a98eb6e344 | |||
| 752ff618b2 | |||
| f15f235ecf | |||
| 9d79cddbd6 | |||
| af9aab395a | |||
| 6f334fde73 | |||
| 2ccd84ac3b | |||
| ec30af3edb | |||
| 10bbc07118 | |||
| 194cf3c343 | |||
| 1880cd7a34 | |||
| d43c9e4044 | |||
| 22ac9fc4dd | |||
| 7ded8df05e | |||
| fd282db28f | |||
| 6f64d9d6dc | |||
| a2889fbb08 | |||
| 024b597e44 | |||
| 1eb7942aa9 | |||
| 9285ae3782 | |||
| a88797f410 | |||
| b047b05aaf | |||
| 78a274ce5b | |||
| ab8faec863 | |||
| 936a009212 | |||
| f10f8d09a6 | |||
| 0d6a78f320 | |||
| c9a4192c85 | |||
| 0afdb54e5c | |||
| 9e99a1f1ba | |||
| 50575c6e91 | |||
| f8436f6b8c | |||
| 5c0f51e272 | |||
| 4e3622ef02 | |||
| f69ddc5133 | |||
| 477d901281 | |||
| 0df795237d | |||
| 413ab40044 | |||
| d4a1a5c2de | |||
| 843e9369fe | |||
| 48f6d8a7f0 | |||
| 0b97846d77 | |||
| 50e74180a2 | |||
| 71e6ba316a | |||
| 707e2ac07c | |||
| 983d47bd2e | |||
| 9cc91d1153 | |||
| 2f3041c169 | |||
| 6a004a54b9 | |||
| 2d0873af45 | |||
| 4cc21be562 | |||
| 98cf3b54a1 | |||
| af8a074484 | |||
| ff1122cb68 | |||
| cbc02c5aee | |||
| c8e9354e87 | |||
| 1ecff5fdf7 | |||
| c856c9b6a6 | |||
| ea591d1088 | |||
| cae51856d2 | |||
| be03662e4c | |||
| db18ca4978 | |||
| 7de55614a6 | |||
| 939cd724ec | |||
| 5ddfe8510c | |||
| cd38e1246a | |||
| febcd90a31 | |||
| 58bafd48cc | |||
| 179e5b3811 | |||
| 4884fd53e5 | |||
| 60083bcb6e | |||
| 56c018e72e | |||
| 22ae4b0084 | |||
| 08eb743ea9 | |||
| c06d0284c4 | |||
| 6cc693d15f | |||
| 13f6110b18 | |||
| 6d4b5e4a1f | |||
| e00868e3b1 | |||
| 4de55336f1 | |||
| fad964b370 |
+17
-105
@@ -2,137 +2,49 @@ name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, develop]
|
||||
pull_request:
|
||||
branches: [main, develop]
|
||||
workflow_dispatch:
|
||||
branches: [ main, develop ]
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# lint:代码风格与格式检查(单平台即可)
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
lint:
|
||||
name: Lint (ruff)
|
||||
lint-and-typecheck:
|
||||
name: Lint & Typecheck
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: 安装 uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: latest
|
||||
enable-cache: true
|
||||
cache-dependency-glob: uv.lock
|
||||
|
||||
- name: 设置 Python 3.13
|
||||
uses: actions/setup-python@v5
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: 安装依赖
|
||||
run: uv sync --extra dev --frozen
|
||||
- run: uv sync
|
||||
- run: uv run ruff check src tests
|
||||
- run: uv run pyrefly check .
|
||||
|
||||
- name: Ruff 检查
|
||||
run: uv run ruff check src tests examples
|
||||
|
||||
- name: Ruff 格式检查
|
||||
run: uv run ruff format --check src tests examples
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# typecheck:mypy 严格类型检查
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
typecheck:
|
||||
name: Typecheck (mypy)
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 安装 uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: latest
|
||||
enable-cache: true
|
||||
cache-dependency-glob: uv.lock
|
||||
|
||||
- name: 设置 Python 3.13
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: 安装依赖
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: Mypy 严格类型检查
|
||||
run: uv run mypy
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# test:多平台 × 多 Python 版本矩阵测试 + 覆盖率
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
test:
|
||||
name: Test (${{ matrix.os }} / py${{ matrix.python-version }})
|
||||
name: Test (${{ matrix.os }})
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: [ubuntu-latest, windows-latest, macos-latest]
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13']
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: 安装 uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: latest
|
||||
enable-cache: true
|
||||
cache-dependency-glob: uv.lock
|
||||
|
||||
- name: 设置 Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
python-version: |
|
||||
3.8
|
||||
3.13
|
||||
|
||||
- name: 安装依赖
|
||||
run: uv sync --extra dev --frozen
|
||||
|
||||
- name: 运行测试(含覆盖率,强制 100%)
|
||||
run: uv run pytest -v --cov=pyflowx --cov-report=xml --cov-report=term-missing --cov-fail-under=100
|
||||
|
||||
- name: 运行示例冒烟测试
|
||||
run: |
|
||||
uv run python examples/etl_pipeline.py
|
||||
uv run python examples/parallel_run.py
|
||||
uv run python examples/async_aggregation.py
|
||||
|
||||
- name: 上传覆盖率
|
||||
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.13'
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: coverage-${{ matrix.os }}-py${{ matrix.python-version }}
|
||||
path: coverage.xml
|
||||
retention-days: 7
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 聚合:所有检查通过后才标记完成
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
ci-pass:
|
||||
name: CI Pass
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, typecheck, test]
|
||||
if: always()
|
||||
steps:
|
||||
- name: 检查依赖任务结果
|
||||
if: ${{ needs.lint.result != 'success' || needs.typecheck.result != 'success' || needs.test.result != 'success' }}
|
||||
run: |
|
||||
echo "lint: ${{ needs.lint.result }}"
|
||||
echo "typecheck: ${{ needs.typecheck.result }}"
|
||||
echo "test: ${{ needs.test.result }}"
|
||||
exit 1
|
||||
- name: 全部通过
|
||||
run: echo "✅ 所有 CI 检查通过"
|
||||
- run: uvx tox run -e py38,py313
|
||||
|
||||
+21
-153
@@ -2,192 +2,60 @@ name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*.*.*'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
tag:
|
||||
description: '发布版本号(如 v0.1.0)'
|
||||
required: true
|
||||
type: string
|
||||
tags: ['v*.*.*']
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
# Trusted Publishing (OIDC) 上传 PyPI 所需
|
||||
id-token: write
|
||||
|
||||
jobs:
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 预检:版本号校验 + 与 pyproject.toml 一致性检查
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
pre-check:
|
||||
name: Pre-release Check
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
version: ${{ steps.meta.outputs.version }}
|
||||
tag: ${{ steps.meta.outputs.tag }}
|
||||
version: ${{ steps.version.outputs.version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: 解析版本号
|
||||
id: meta
|
||||
run: |
|
||||
if [ -n "${{ inputs.tag }}" ]; then
|
||||
TAG="${{ inputs.tag }}"
|
||||
else
|
||||
TAG="${GITHUB_REF#refs/tags/}"
|
||||
fi
|
||||
# 去除前缀 v
|
||||
VERSION="${TAG#v}"
|
||||
echo "tag=$TAG" >> $GITHUB_OUTPUT
|
||||
echo "version=$VERSION" >> $GITHUB_OUTPUT
|
||||
echo "发布版本: $VERSION (tag: $TAG)"
|
||||
|
||||
- name: 校验版本号格式
|
||||
run: |
|
||||
VERSION="${{ steps.meta.outputs.version }}"
|
||||
if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
|
||||
echo "❌ 版本号格式错误: $VERSION(应为 x.y.z 或 x.y.z-rc.n)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: 校验 pyproject.toml 版本一致
|
||||
run: |
|
||||
# 精确提取 [project] 段的 version 字段(避免匹配到依赖的 version)
|
||||
PY_VERSION=$(awk '/^\[project\]/{f=1} f&&/^version[[:space:]]*=/{gsub(/[" ]/,"",$3); print $3; exit}' pyproject.toml)
|
||||
echo "pyproject.toml version: $PY_VERSION"
|
||||
if [ "$PY_VERSION" != "${{ steps.meta.outputs.version }}" ]; then
|
||||
echo "❌ pyproject.toml 版本($PY_VERSION) 与 tag 版本(${{ steps.meta.outputs.version }}) 不一致"
|
||||
echo "请先更新 pyproject.toml 中的 version 字段"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 构建:wheel + sdist(纯 Python,单平台即可)
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
build:
|
||||
name: Build Artifacts
|
||||
needs: pre-check
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 安装 uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: latest
|
||||
enable-cache: true
|
||||
|
||||
- name: 设置 Python 3.13
|
||||
uses: actions/setup-python@v5
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: 安装依赖
|
||||
run: uv sync --extra dev --frozen
|
||||
- run: uv build
|
||||
|
||||
- name: 构建 wheel + sdist
|
||||
run: uv build
|
||||
- id: version
|
||||
run: echo "version=${GITHUB_REF#refs/tags/v}" >> $GITHUB_OUTPUT
|
||||
|
||||
- name: 校验产物
|
||||
run: |
|
||||
echo "待上传产物:"
|
||||
ls -la dist/
|
||||
if [ -z "$(ls -A dist/*.whl dist/*.tar.gz 2>/dev/null)" ]; then
|
||||
echo "❌ 未找到 wheel 或 sdist 产物"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
- name: 上传构建产物
|
||||
uses: actions/upload-artifact@v4
|
||||
- uses: actions/upload-artifact@v7
|
||||
with:
|
||||
name: dist
|
||||
path: dist/*
|
||||
retention-days: 30
|
||||
path: dist/
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 发布:上传到 PyPI(Trusted Publishing / OIDC)
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
publish-pypi:
|
||||
name: Publish to PyPI
|
||||
needs: [pre-check, build]
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/project/pyflowx/${{ needs.pre-check.outputs.version }}
|
||||
permissions:
|
||||
id-token: write
|
||||
environment: pypi
|
||||
steps:
|
||||
- name: 下载构建产物
|
||||
uses: actions/download-artifact@v4
|
||||
- uses: actions/download-artifact@v8
|
||||
with:
|
||||
name: dist
|
||||
path: dist
|
||||
|
||||
- name: 上传到 PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
attestations: true
|
||||
- uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
# 发布:创建 GitHub Release
|
||||
# ─────────────────────────────────────────────────────────────
|
||||
release:
|
||||
name: Publish Release
|
||||
needs: [pre-check, build, publish-pypi]
|
||||
needs: [build, publish-pypi]
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: 下载构建产物
|
||||
uses: actions/download-artifact@v4
|
||||
- uses: actions/download-artifact@v8
|
||||
with:
|
||||
name: dist
|
||||
path: assets
|
||||
path: dist
|
||||
|
||||
- name: 整理发布产物
|
||||
run: |
|
||||
ls -la assets/
|
||||
|
||||
- name: 生成 Release Notes
|
||||
id: notes
|
||||
run: |
|
||||
{
|
||||
echo "## pyflowx ${{ needs.pre-check.outputs.version }}"
|
||||
echo ""
|
||||
echo "### 下载"
|
||||
echo ""
|
||||
echo "- **Wheel**: \`pyflowx-${{ needs.pre-check.outputs.version }}-py3-none-any.whl\`"
|
||||
echo "- **源码包**: \`pyflowx-${{ needs.pre-check.outputs.version }}.tar.gz\`"
|
||||
echo ""
|
||||
echo "### 安装"
|
||||
echo ""
|
||||
echo '```bash'
|
||||
echo "pip install pyflowx==${{ needs.pre-check.outputs.version }}"
|
||||
echo '```'
|
||||
echo ""
|
||||
echo "### 完整变更日志"
|
||||
} > RELEASE_NOTES.md
|
||||
{
|
||||
echo "content<<EOF"
|
||||
cat RELEASE_NOTES.md
|
||||
echo "EOF"
|
||||
} >> $GITHUB_OUTPUT
|
||||
|
||||
- name: 创建 GitHub Release
|
||||
uses: softprops/action-gh-release@v2
|
||||
- uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
tag_name: ${{ needs.pre-check.outputs.tag }}
|
||||
name: pyflowx ${{ needs.pre-check.outputs.version }}
|
||||
body: ${{ steps.notes.outputs.content }}
|
||||
files: assets/*
|
||||
draft: false
|
||||
prerelease: ${{ contains(needs.pre-check.outputs.version, '-') }}
|
||||
files: dist/*
|
||||
generate_release_notes: true
|
||||
|
||||
@@ -9,3 +9,4 @@ wheels/
|
||||
# Virtual environments
|
||||
.venv
|
||||
.coverage
|
||||
.idea
|
||||
|
||||
@@ -7,10 +7,7 @@ repos:
|
||||
hooks:
|
||||
# Run the linter
|
||||
- id: ruff
|
||||
args: [ --fix, --exit-non-zero-on-fix ]
|
||||
# Run the formatter
|
||||
- id: ruff-format
|
||||
args: [ --config=pyproject.toml]
|
||||
args: [--fix, --exit-non-zero-on-fix]
|
||||
- repo: https://gitcode.com/gh_mirrors/pr/pre-commit-hooks.git
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
@@ -18,5 +15,5 @@ repos:
|
||||
- id: debug-statements
|
||||
- id: fix-byte-order-marker
|
||||
- id: trailing-whitespace
|
||||
args: [ --markdown-linebreak-ext=md ]
|
||||
args: [--markdown-linebreak-ext=md]
|
||||
- id: end-of-file-fixer
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
3.8
|
||||
3.11
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
# PYTHON
|
||||
.coverage
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
.tox/
|
||||
.venv/
|
||||
__pycache__/
|
||||
|
||||
# NODEJS
|
||||
node_modules/
|
||||
|
||||
# IDE
|
||||
.idea
|
||||
.trae
|
||||
.vscode
|
||||
@@ -0,0 +1,11 @@
|
||||
---
|
||||
alwaysApply: true
|
||||
scene: git_message
|
||||
---
|
||||
|
||||
在此处编写规则,自定义 AI 生成提交信息的风格。
|
||||
|
||||
## 提交信息格式
|
||||
- 提交信息必须使用中文。
|
||||
- 提交信息必须包含变更的类型(例如 "fix"、"feat"、"refactor" 等)。
|
||||
- 提交信息必须尽简洁明了,不要超过一段落。
|
||||
@@ -0,0 +1,157 @@
|
||||
# Python 开发规范
|
||||
|
||||
本规范结合 Python 最佳实践,作为编写与审查 Python 代码的统一标准。
|
||||
详细操作指南见 `.agents/skills/` 下相应技能。
|
||||
|
||||
## 工具链(以 pyproject.toml 为准)
|
||||
|
||||
| 工具 | 用途 | 配置要点 |
|
||||
|------|------|---------|
|
||||
| **ruff** | lint + format | `line-length=120`,`target-version="py38"` |
|
||||
| **pyrefly** | 类型检查 | `preset="strict"`,`python-version="3.8"` |
|
||||
| **pytest** | 测试 | `asyncio_default_fixture_loop_scope="function"`,marker `slow` |
|
||||
| **coverage** | 覆盖率 | `branch=true`,`fail_under=95`,`concurrency=["thread"]` |
|
||||
| **pre-commit** | 提交前检查 | ruff `--fix` + trailing-whitespace + end-of-file-fixer |
|
||||
|
||||
验证(每次修改后必做):
|
||||
|
||||
```bash
|
||||
uvx --from pyflowx pymake tc
|
||||
uvx --from pyflowx pymake cov
|
||||
```
|
||||
|
||||
## 兼容性
|
||||
|
||||
- **最低 Python 3.8**:用 `from __future__ import annotations` 延迟注解求值;
|
||||
按版本用 `typing.List`(3.8) → 内置泛型(3.9) → `X | Y`(3.10) → `typing.override`(3.12)。
|
||||
- **版本守卫**:`if sys.version_info >= (3, X):` 引入高版本 API;低版本回退分支加 `# pragma: no cover`。
|
||||
- **零运行时依赖**:仅依赖标准库(3.8 需 `graphlib_backport`、`typing-extensions`)。
|
||||
新增依赖须审慎,优先用标准库。
|
||||
|
||||
## 类型注解
|
||||
|
||||
- **公共 API 必须有完整类型注解**,包括返回类型;私有函数也应有注解。
|
||||
- 泛型用 `TypeVar`;PEP 696 `default=` 仅 3.13+ 标准库支持,3.8–3.12 用 `typing_extensions.TypeVar`。
|
||||
- `Mapping`/`Sequence` 用于只读参数,`dict`/`list` 用于可变返回。
|
||||
- `Any` 仅用于真正动态场景(如 `Context` 跨任务异构映射);任务内部类型必须完全静态。
|
||||
- 禁用裸 `# type: ignore`;确需时加具体规则码(如 `# type: ignore[union-attr]`)。
|
||||
- **`TYPE_CHECKING` 守卫**:仅类型检查需要的导入放 `if TYPE_CHECKING:` 块内,避免循环依赖。
|
||||
- **类型收窄**:用 `assert isinstance(x, Y)` 辅助 pyrefly 推断;`cast()` 仅用于类型系统无法表达的场景。
|
||||
|
||||
## 数据结构
|
||||
|
||||
- **不可变优先**:配置/描述类用 `@dataclass(frozen=True)`;可变类属性标注 `RUF012` 豁免。
|
||||
- **缓存**:实例级用 `functools.cached_property`,按参数键控用 `functools.lru_cache`;
|
||||
不可哈希参数需 try/except 回退。修改被缓存数据源后必须手动清空缓存。
|
||||
- **抽象基类**:接口用 `abc.ABC` + `@abstractmethod`(如 `StateBackend`)。
|
||||
- **枚举**:状态/标志值用 `enum.Enum`(如 `TaskStatus`),禁止裸字符串/魔术数字;枚举值用 `UPPER_SNAKE`。
|
||||
- **`__repr__`**:可变类实现 `__repr__`(含关键字段);`frozen=True` dataclass 自动生成。
|
||||
|
||||
## 模块与导入
|
||||
|
||||
- **单一职责**:每模块只做一件事(`task.py` 数据结构、`executors.py` 执行、`command.py` 命令、`compose.py` 组合)。禁止跨职责边界。
|
||||
- **导入顺序**(ruff isort):`__future__` → 标准库 → 第三方 → 本地,各组间空行。
|
||||
- **惰性导入**:仅为打破循环依赖时使用,函数体内导入并注释说明;顶层导入是默认。
|
||||
- **`__all__`**:定义 `__all__` 显式声明导出符号,位置仅次于 `__future__` 之后。
|
||||
- **禁用 star imports**:`from x import *` 污染命名空间、破坏类型检查(`__init__.py` 聚合经 `__all__` 控制为例外)。
|
||||
- **避免 `utils.py`/`helpers.py`**:按职责归入对应模块。
|
||||
|
||||
## 函数设计
|
||||
|
||||
- **模块级函数优于 Mixin**:共享逻辑用模块级函数,类只持有状态与薄方法。
|
||||
- **静态方法慎用**:纯函数直接放模块级。
|
||||
- **参数 ≤ 5 个**为宜;超出用 dataclass 封装参数对象。
|
||||
- **单一职责**:一个函数做一件事;过长函数考虑拆分。
|
||||
- **异常范围要窄**:只捕获预期异常(如 `(TypeError, ValueError, KeyError, AttributeError)`),
|
||||
**禁止** `except Exception` 掩盖 bug;捕获后至少 `logger.warning` 记录。
|
||||
- **可变默认参数**:`def f(x=[])` 是经典坑;用 `None` 哨兵或 `field(default_factory=list)`。
|
||||
|
||||
## 异常处理
|
||||
|
||||
- **自定义异常家族**:继承公共基类(如 `PyFlowXError`),按错误场景分类。
|
||||
- **异常包装**:`raise NewError(...) from exc` 保留因果链。
|
||||
- **不要吞异常**:捕获后必须处理(记录/包装/重抛),禁止空 `except: pass`。
|
||||
- **钩子/回调异常**:第三方回调异常仅记录,不影响主流程。
|
||||
|
||||
## 并发与线程安全
|
||||
|
||||
- **进程全局状态**(`os.environ`/`os.chdir`)在并发场景下必须用全局锁(`threading.RLock`)序列化。
|
||||
- **条件评估不可有可变状态**:组合条件(NOT/AND/OR)不得修改共享 `_reason`,避免竞态。
|
||||
- **批量 I/O**:循环内多次写盘改为批量一次(`contextmanager` 包裹延迟落盘)。
|
||||
- **信号量限流**:`concurrency_key` + `Semaphore` 按组限流。
|
||||
|
||||
## 测试
|
||||
|
||||
详细操作指南见 `.agents/skills/pyflowx-testing` 技能。硬约束:
|
||||
|
||||
- **覆盖率 ≥ 95%**(branch coverage),不得下降。
|
||||
- **公共 API 优先测试**:用公共接口(`has`/`get`),不访问私有方法;
|
||||
故障注入等场景可临时访问私有属性,docstring 注明原因。
|
||||
- **命名**:`test_<被测对象>_<场景>`。
|
||||
- **断言**:原生 `assert x == 1`,禁用 `self.assertEqual`;`pytest.raises` 必填 `match=`。
|
||||
- **Mock 优先级**:`monkeypatch` > 内联 stub > `unittest.mock` > `pytest-mock`。
|
||||
禁用 `@patch` 装饰器、`mock.patch.object` 上下文、`pytest-mock` 的 `mocker` fixture。
|
||||
- **fixture**:`tmp_path`/`monkeypatch`/`capsys` 优先;autouse 仅全局必需时用。
|
||||
- **slow 标记**:耗时测试加 `@pytest.mark.slow`,CI 可 `-m "not slow"` 跳过。
|
||||
- **测试代码也跑 ruff**:`tests/**` 忽略 `ARG001`/`ARG002`。
|
||||
|
||||
## 代码风格
|
||||
|
||||
- **行宽 120**(ruff formatter 处理)。
|
||||
- **docstring**:公共 API 必须有;中文叙述 + 中文注释是本项目既有风格。
|
||||
- **打印和日志**:使用中文打印和日志,避免使用英文。
|
||||
- **命名**:`snake_case` 函数/变量,`PascalCase` 类,`UPPER_SNAKE` 常量,`_leading_underscore` 私有。
|
||||
- **字符串引号**:ruff 默认双引号。
|
||||
- **末尾单 `\n`**、**无尾随空格**(pre-commit 强制)。
|
||||
- **不用 emoji**:除非用户明确要求。
|
||||
|
||||
## Pythonic 风格
|
||||
|
||||
- **`is` 比较 `None`/`True`/`False`**:单例用 `is`,值用 `==`(PEP 8 E711/E712)。
|
||||
- **EAFP 优于 LBYL**:先尝试再处理异常,而非先检查再执行(避免竞态窗口)。
|
||||
- **truthiness**:`if items:` 优于 `if len(items) > 0:`。
|
||||
- **字符串格式化**:首选 f-string;`%` 仅用于 `logging` 延迟格式化。
|
||||
- **推导式**优于 `map`+`filter`;> 2 层拆为显式循环。
|
||||
- **`enumerate`** 替代 `range(len())`;**`zip`** 并行迭代(3.10+ 用 `strict=True`)。
|
||||
- **解包** `a, b = pair` 优于索引访问;忽略值用 `_`。
|
||||
- **海象运算符 `:=`**(3.8+):赋值+判断合一,但不滥用。
|
||||
|
||||
## 日志
|
||||
|
||||
- **`logging.getLogger(__name__)`**:每模块独立 logger,禁用 `print` 调试残留。
|
||||
- **结构化上下文**:`extra={...}` 传字段;`logger.warning("task %r failed: %s", name, exc)` 优于 f-string(延迟格式化)。
|
||||
- **日志级别**:`DEBUG` 诊断 / `INFO` 关键流程 / `WARNING` 可恢复异常 / `ERROR` 需人工介入。
|
||||
- **禁止日志密码/密钥**:脱敏后再记录。
|
||||
|
||||
## 路径与资源
|
||||
|
||||
- **优先 `pathlib.Path`**:`Path("a") / "b"` 而非 `os.path.join`(ruff `PTH` 强制);
|
||||
禁止字符串拼接路径。类型注解用 `Path`,边界 `str` 立即包装。
|
||||
- **`with` 语句**:文件、锁、连接、临时目录一律用 `with` 或 `contextlib.contextmanager`;
|
||||
多资源用 `contextlib.ExitStack`。
|
||||
- **显式关闭**:长生命周期对象(连接池、线程池)实现 `close()`,但优先 `with`。
|
||||
- **批量操作**:循环内多次 acquire/release 改为批量一次。
|
||||
|
||||
## 安全
|
||||
|
||||
- **禁用 `eval`/`exec`**:处理不可信输入时绝不使用;用 `ast.literal_eval` 或专用解析器。
|
||||
- **`subprocess`**:禁用 `shell=True` 除非命令完全可信;优先 `list[str]` 形式。
|
||||
- **凭证不入仓**:密钥/token/密码放 `.env` 或环境变量,`.gitignore` 必须包含 `.env`。
|
||||
- **日志脱敏**:记录请求/响应时移除 `Authorization`、`password` 等字段。
|
||||
- **依赖审计**:`uv lock` 后审阅新增依赖,避免引入已知 CVE 的包。
|
||||
|
||||
## 性能要点
|
||||
|
||||
- **避免重复计算**:循环内查询应缓存或预构建映射(如 `{name: spec}`)。
|
||||
- **避免双重查找**:`has(k)` + `get(k)` 改为单次 `get(k)` + `KeyError` 回退。
|
||||
- **统一校验**:入口校验一次,下游路径不重复(如 `run()` 统一 `validate()`,`layers()` 不再重复)。
|
||||
- **事件 emit**:任务生命周期必须 emit `RUNNING` → `SUCCESS`/`FAILED`/`SKIPPED`,
|
||||
不要留死分支(`# pragma: no cover` 是清理信号,应激活或删除)。
|
||||
|
||||
## Git 与提交
|
||||
|
||||
- **不自动提交/push**:除非用户明确要求。
|
||||
- **不修改 git config**。
|
||||
- **不运行破坏性命令**(`push --force`/`reset --hard`/`clean -f`)除非用户明确要求。
|
||||
- **staging**:按文件名添加,不用 `git add -A`/`git add .`,避免误加敏感文件。
|
||||
- **commit message**:简洁,聚焦"为什么"而非"是什么";遵循仓库既有风格。
|
||||
@@ -0,0 +1,135 @@
|
||||
---
|
||||
name: "pyflowx-testing"
|
||||
description: "PyFlowX 项目的测试编写规范与 mock 使用指南。在编写或审查测试、选择 mock 工具、设计 fixture、处理 asyncio 测试时调用。"
|
||||
---
|
||||
|
||||
# PyFlowX 测试规范
|
||||
|
||||
本技能是 `.trae/rules/python-standards.md` 测试章节的详细展开。
|
||||
规则文件仅保留硬约束指针,本文件提供完整操作指南。
|
||||
|
||||
## 总则
|
||||
|
||||
- **覆盖率 ≥ 95%**(branch coverage),不得下降。
|
||||
- **公共 API 优先测试**:测试用公共接口(`has`/`get`),不访问私有方法
|
||||
(如 `_expired`)。兼容旧测试的私有方法应删除并迁移测试。
|
||||
例外:`_store`/`_flush` 等内部状态在无法用公共 API 触发时(如模拟过期、
|
||||
故障注入),可临时访问私有属性,并在 docstring 注明原因。
|
||||
- **命名**:`test_<被测对象>_<场景>`,如 `test_storage_key_cache_key_exception_returns_name`。
|
||||
- **每个测试一个断言重点**;多个断言要语义相关。
|
||||
- **slow 标记**:耗时测试加 `@pytest.mark.slow`,CI 可 `-m "not slow"` 跳过。
|
||||
- **测试代码也跑 ruff**:`tests/**` 忽略 `ARG001`/`ARG002`(未用 fixture 参数)。
|
||||
- **断言风格**:用原生 `assert` + 比较运算符(`assert x == 1`),
|
||||
不用 `self.assertEqual`;pytest 会生成更清晰的 diff。
|
||||
|
||||
## Mock 工具选择(强制)
|
||||
|
||||
**优先级**:`monkeypatch` > 内联 stub > `unittest.mock` > `pytest-mock`。
|
||||
|
||||
| 场景 | 工具 | 示例 |
|
||||
|------|------|------|
|
||||
| 替换模块属性 / 环境变量 / 工作目录 | `monkeypatch` | `monkeypatch.setattr(subprocess, "run", fake_run)` |
|
||||
| `os.environ["KEY"]` 临时设置 | `monkeypatch.setenv` | `monkeypatch.setenv("LOCALAPPDATA", "C:\\...")` |
|
||||
| 切换 cwd | `monkeypatch.chdir` | `monkeypatch.chdir(tmp_path)` |
|
||||
| 一次性 stub 函数 | 内联 lambda / 闭包 | `ran = []; monkeypatch.setattr(subprocess, "run", lambda *c, **__: ran.append(c))` |
|
||||
| 复杂 spy(记录调用次数/参数/返回序列) | `unittest.mock.MagicMock` | 仅当 lambda 不足以表达时 |
|
||||
| `with patch(...)` 上下文 | **禁用**(用 monkeypatch) | monkeypatch 自动 teardown 更安全 |
|
||||
|
||||
**禁止**:
|
||||
- 不用 `pytest-mock` 的 `mocker` fixture(项目虽在 dev 依赖声明,但实际
|
||||
测试代码未使用;为保持风格统一,新代码继续用 `monkeypatch`)。
|
||||
- 不用 `unittest.mock.patch` 装饰器(`@patch("x.y")`),它隐藏依赖且
|
||||
与 pytest fixture 模式不兼容;用 `monkeypatch.setattr` 替代。
|
||||
- 不用 `mock.patch.object` 作为上下文管理器,除非被测代码本身就是
|
||||
contextmanager(此时用 `monkeypatch.setattr` 仍更简单)。
|
||||
|
||||
## monkeypatch 使用规范
|
||||
|
||||
- **类型注解**:fixture 参数标注 `monkeypatch: pytest.MonkeyPatch`。
|
||||
- **作用域**:monkeypatch 自动在测试结束时撤销,**禁止**手动
|
||||
`monkeypatch.setattr(x, "y", original)` 恢复(多余且容易遗漏)。
|
||||
例外:在单个测试内需要中途恢复时,用 `monkeypatch.undo()` 全量撤销。
|
||||
- **替换目标**:替换"被测代码看到的对象",而非全局对象本身。
|
||||
- 错误:`monkeypatch.setattr("os.path.exists", fake)` —— 替换全局,影响其他模块。
|
||||
- 正确:`monkeypatch.setattr(pyflowx.command.shutil, "which", fake)` ——
|
||||
替换被测模块引用的 `shutil.which`。
|
||||
- **属性 vs 字符串路径**:优先属性访问形式 `monkeypatch.setattr(obj, "attr", val)`
|
||||
而非字符串路径 `monkeypatch.setattr("pkg.mod.obj.attr", val)`,
|
||||
前者有 IDE 跳转与重构支持。
|
||||
- **记录调用**:用闭包 `ran: list[tuple] = []` + `lambda *a, **k: ran.append((a, k))`
|
||||
替代 `MagicMock`,可读性更好且无需导入。
|
||||
|
||||
## Stub 与 Spy 模式
|
||||
|
||||
- **轻量 stub**:内联定义 `class MockResult: returncode = 0; stdout = ""`,
|
||||
替代 `MagicMock(return_value=...)`,类型明确且不引入 mock 依赖。
|
||||
- **状态收集**:闭包 + list 比 `mock.call_args_list` 更易断言:
|
||||
```python
|
||||
calls: list[list[str]] = []
|
||||
|
||||
|
||||
def fake_run(cmd: list[str], **_: Any) -> MockResult:
|
||||
calls.append(cmd)
|
||||
return MockResult()
|
||||
|
||||
|
||||
monkeypatch.setattr(subprocess, "run", fake_run)
|
||||
assert calls == [["clear"]]
|
||||
```
|
||||
- **副作用序列**:需要按调用次数返回不同值时,用 `itertools.cycle` 或
|
||||
手动计数器,而非 `side_effect=[...]`(mock 专有 API)。
|
||||
- **异常注入**:`def raise_oserror(*a, **k): raise OSError("...")`,
|
||||
用 `pytest.raises(OSError)` 验证,而非 `side_effect=OSError`。
|
||||
|
||||
## 异常断言
|
||||
|
||||
- **`pytest.raises`**:必填 `match=` 正则(除非异常消息完全不可预测),
|
||||
避免误捕获同类异常:
|
||||
```python
|
||||
with pytest.raises(StorageError, match="cannot write"):
|
||||
b.save("a", 1)
|
||||
```
|
||||
- **异常链**:验证 `__cause__` 时用 `exc_info.value.__cause__`,
|
||||
确认 `raise X from Y` 因果链完整。
|
||||
- **禁止** `try/except + assert False`:用 `pytest.raises` 替代。
|
||||
|
||||
## Fixture 规范
|
||||
|
||||
- **`tmp_path`**:处理临时文件,自动清理,禁止 `tempfile.mkdtemp()` 手动管理。
|
||||
- **`monkeypatch`**:环境变量、cwd、模块属性 mock(见上)。
|
||||
- **`capsys`/`capfd`**:捕获 stdout/stderr,验证日志或命令输出。
|
||||
- **autouse fixture**:仅在全局必需时用(如 `conftest.py` 的
|
||||
`packtool_tmp_workdir` 自动切到 tmp_path);否则显式声明参数。
|
||||
- **fixture 命名**:`snake_case`,描述"提供什么"而非"测试什么"
|
||||
(`sample_graph` 优于 `test_data`)。
|
||||
- **fixture 作用域**:默认 `function`;`module`/`session` 仅当构造昂贵且
|
||||
只读时,并加注释说明无副作用。
|
||||
|
||||
## asyncio 测试
|
||||
|
||||
- **fixture `loop_scope="function"`**(pyproject 已配置默认值)。
|
||||
- **async 测试**:`async def test_x():`,pytest-asyncio 自动驱动。
|
||||
- **await 检查**:测试异步函数必须 `await` 结果,禁止仅验证返回 coroutine 对象。
|
||||
- **异步 mock**:用 `AsyncMock`(3.8+ 在 `unittest.mock`)或
|
||||
`async def fake(): return value`,禁用 `MagicMock(return_value=coro)`。
|
||||
|
||||
## 参数化
|
||||
|
||||
- **`@pytest.mark.parametrize`**:用 `ids` 参数提供可读标识:
|
||||
```python
|
||||
@pytest.mark.parametrize(
|
||||
("strategy", "expected_workers"),
|
||||
[("sequential", 1), ("thread", 8), ("async", 1)],
|
||||
ids=["seq", "thread-8", "async"],
|
||||
)
|
||||
```
|
||||
- **参数命名**:参数元组用有意义名称,而非 `("a", "b")`。
|
||||
- **组合爆炸**:参数组合 > 20 时拆分测试,避免单个测试函数臃肿。
|
||||
|
||||
## 测试组织
|
||||
|
||||
- **文件命名**:`test_<被测模块>.py`(`test_storage.py` 对应 `storage.py`)。
|
||||
- **类分组**:仅在测试逻辑强相关时用 `class TestXxx:` 分组;默认用模块级函数。
|
||||
- **docstring**:每个测试函数一句话说明"测试什么场景",复杂场景补充"为什么"。
|
||||
- **setup/teardown**:优先 fixture;`setup_method`/`teardown_method` 仅在
|
||||
无法用 fixture 表达时(罕见)。
|
||||
Vendored
-1
@@ -18,7 +18,6 @@
|
||||
"evenBetterToml.formatter.arrayAutoCollapse": true,
|
||||
"evenBetterToml.formatter.arrayAutoExpand": true,
|
||||
"evenBetterToml.formatter.arrayTrailingComma": true,
|
||||
"evenBetterToml.formatter.columnWidth": 120,
|
||||
"evenBetterToml.formatter.compactEntries": false,
|
||||
"evenBetterToml.formatter.indentEntries": false,
|
||||
"evenBetterToml.formatter.indentTables": false,
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2026 endo Team
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -2,11 +2,11 @@
|
||||
|
||||
> 轻量、类型安全的 DAG 任务调度器。
|
||||
|
||||
[](https://github.com/pyflowx/pyflowx/actions/workflows/ci.yml)
|
||||
[](https://github.com/gookeryoung/pyflowx/actions/workflows/ci.yml)
|
||||
[](https://pypi.org/project/pyflowx/)
|
||||
[](https://pypi.org/project/pyflowx/)
|
||||
[](https://github.com/pyflowx/pyflowx)
|
||||
[](https://github.com/pyflowx/pyflowx/blob/main/LICENSE)
|
||||
[](https://github.com/gookeryoung/pyflowx)
|
||||
[](https://github.com/gookeryoung/pyflowx/blob/main/LICENSE)
|
||||
|
||||
PyFlowX 把"任务依赖"这件事做到极致简单:**参数名就是依赖声明**。无需装饰器、
|
||||
无需样板包装器,写一个普通函数,框架按参数名自动注入上游结果。
|
||||
@@ -14,15 +14,25 @@ PyFlowX 把"任务依赖"这件事做到极致简单:**参数名就是依赖
|
||||
## 特性
|
||||
|
||||
- **零样板** —— 参数名即依赖,框架自动注入上游结果
|
||||
- **三种执行策略** —— `sequential`(调试)/ `thread`(I/O 密集同步)/ `async`(I/O 密集异步)
|
||||
- **四种执行策略** —— `sequential`(调试)/ `thread`(I/O 密集同步)/ `async`(I/O 密集异步)/ `dependency`(依赖驱动,最大化并行)
|
||||
- **类型安全** —— `TaskSpec[T]` 把返回类型一路传到 `RunReport`,mypy strict 通过
|
||||
- **DAG 校验** —— 构建时即时校验重名、缺失依赖、环
|
||||
- **自动分层** —— Kahn 算法分组,同层任务可并行
|
||||
- **重试与超时** —— 每个任务独立配置 `retries` 与 `timeout`
|
||||
- **断点续跑** —— `MemoryBackend` / `JSONBackend`,成功结果可缓存复用
|
||||
- **可观测** —— `on_event` 回调、`dry_run` 预览、Mermaid 可视化
|
||||
- **重试与超时** —— 每个任务独立配置 `RetryPolicy`(max_attempts/delay/backoff/jitter/retry_on)与 `timeout`
|
||||
- **软依赖** —— `soft_depends_on` 仅用于上下文注入,不参与拓扑分层
|
||||
- **并发限制** —— `concurrency_key` + `concurrency_limits` 按组限流
|
||||
- **任务钩子** —— `TaskHooks`(pre_run/post_run/on_failure)生命周期回调
|
||||
- **断点续跑** —— `MemoryBackend` / `JSONBackend`,成功结果可缓存复用;`batch()` 批量落盘
|
||||
- **缓存键** —— `cache_key` 函数基于输入计算稳定键,使不同输入产生独立缓存
|
||||
- **命令任务** —— `cmd` 参数直接执行外部命令,支持列表/shell/可调用对象
|
||||
- **条件执行** —— `conditions` 参数按平台、环境变量、应用安装等条件跳过任务
|
||||
- **图组合** —— `compose` / `GraphComposer` 编程式展开多图字符串引用
|
||||
- **任务模板** —— `task_template` 工厂批量生成相似 TaskSpec
|
||||
- **图级默认值** —— `GraphDefaults` 统一配置 retry/timeout/concurrency 等
|
||||
- **CLI 运行器** —— `CliRunner` 把多个图映射为命令行子命令,替代 Makefile
|
||||
- **可观测** —— `on_event` 回调(RUNNING/SUCCESS/FAILED/SKIPPED)、`dry_run` 预览、`verbose` 生命周期日志、Mermaid 可视化
|
||||
- **零运行时依赖** —— 仅依赖标准库(3.8 需 `graphlib_backport`)
|
||||
- **100% 测试覆盖** —— 分支覆盖率达 100%
|
||||
- **97% 测试覆盖** —— 分支覆盖率 >= 95%
|
||||
|
||||
## 安装
|
||||
|
||||
@@ -41,13 +51,16 @@ uv add pyflowx
|
||||
```python
|
||||
import pyflowx as px
|
||||
|
||||
|
||||
def extract() -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
|
||||
# 参数名 extract 自动匹配上游任务名 → 自动注入
|
||||
def double(extract: list[int]) -> list[int]:
|
||||
return [x * 2 for x in extract]
|
||||
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("extract", extract),
|
||||
px.TaskSpec("double", double, ("extract",)),
|
||||
@@ -61,36 +74,92 @@ print(report["double"]) # [2, 4, 6]
|
||||
|
||||
### TaskSpec —— 任务描述
|
||||
|
||||
`TaskSpec` 是不可变的任务描述符,是唯一需要配置的东西:
|
||||
`TaskSpec` 是不可变的任务描述符(`Generic[T]`,返回类型一路传到 `RunReport`),是唯一需要配置的东西:
|
||||
|
||||
```python
|
||||
px.TaskSpec(
|
||||
name="fetch_user", # 唯一标识
|
||||
fn=fetch_user, # 同步或异步函数
|
||||
depends_on=("auth",), # 依赖的任务名
|
||||
args=(uid,), # 静态位置参数(追加在注入参数后)
|
||||
kwargs={"timeout": 30}, # 静态关键字参数
|
||||
retries=3, # 失败重试次数(0 = 仅一次)
|
||||
timeout=30.0, # 超时秒数(None = 不限制)
|
||||
tags=("api", "user"), # 自由标签,用于子图过滤
|
||||
name="fetch_user", # 唯一标识
|
||||
fn=fetch_user, # 同步或异步函数
|
||||
cmd=["curl", "..."], # 或: 执行命令(覆盖 fn)
|
||||
depends_on=("auth",), # 硬依赖(参与拓扑分层)
|
||||
soft_depends_on=("cache",), # 软依赖(仅注入,不参与分层)
|
||||
args=(uid,), # 静态位置参数(追加在注入参数后)
|
||||
kwargs={"timeout": 30}, # 静态关键字参数
|
||||
retry=px.RetryPolicy(max_attempts=3, delay=1.0, backoff=2.0), # 重试策略
|
||||
timeout=30.0, # 超时秒数(None = 不限制)
|
||||
tags=("api", "user"), # 自由标签,用于子图过滤
|
||||
conditions=(is_prod,), # 条件函数列表(全部为 True 才执行)
|
||||
priority=10, # 同层内优先级(高优先执行,默认 0)
|
||||
concurrency_key="db", # 并发分组键(配合 concurrency_limits 限流)
|
||||
cache_key=lambda ctx: str(ctx.get("uid")), # 缓存键函数(不同输入独立缓存)
|
||||
hooks=px.TaskHooks(pre_run=..., post_run=..., on_failure=...), # 生命周期钩子
|
||||
cwd=Path("/tmp"), # 命令工作目录(仅 cmd 模式)
|
||||
env={"DEBUG": "1"}, # 环境变量覆盖(fn 与 cmd 模式均生效)
|
||||
verbose=True, # 打印命令输出(仅 cmd 模式)
|
||||
skip_if_missing=True, # 命令不存在时自动跳过(仅 list[str] cmd)
|
||||
allow_upstream_skip=False, # 上游 SKIPPED/FAILED 时是否仍执行
|
||||
continue_on_error=False, # 本任务失败是否不中断整体
|
||||
)
|
||||
```
|
||||
|
||||
支持两种任务形态:
|
||||
|
||||
- **函数任务**(`fn`):普通 Python 函数,参数名驱动自动注入
|
||||
- **命令任务**(`cmd`):执行外部命令,支持 `list[str]`、`str`(shell)、`Callable` 三种形态
|
||||
|
||||
`skip_if_missing=True` 时,`list[str]` 类型的 `cmd` 会通过 `shutil.which` 检查命令是否存在,不存在则跳过任务(标记为 `SKIPPED`)而非失败。适用于构建工具场景,避免因未安装某些工具而导致整个图执行失败。
|
||||
|
||||
### Graph —— DAG 构建
|
||||
|
||||
```python
|
||||
graph = px.Graph.from_specs([...]) # 整批校验(推荐)
|
||||
# 图级默认值:TaskSpec 字段为 None 时回退
|
||||
defaults = px.GraphDefaults(retry=px.RetryPolicy(max_attempts=2), timeout=60.0)
|
||||
|
||||
graph = px.Graph.from_specs([...], defaults=defaults) # 整批校验(推荐)
|
||||
# 或增量构建
|
||||
graph = px.Graph()
|
||||
graph = px.Graph(defaults=defaults)
|
||||
graph.add(px.TaskSpec("a", fn_a))
|
||||
graph.add(px.TaskSpec("b", fn_b, ("a",)))
|
||||
|
||||
graph.validate() # 显式校验(环检测)
|
||||
graph.layers() # 拓扑分层
|
||||
graph.to_mermaid() # Mermaid 可视化
|
||||
graph.describe() # 人类可读摘要
|
||||
graph.subgraph(("api",)) # 按标签切片
|
||||
graph.validate() # 显式校验(环检测)
|
||||
graph.layers() # 拓扑分层(run() 入口已统一校验,直接调用需自行先 validate)
|
||||
graph.to_mermaid() # Mermaid 可视化
|
||||
graph.describe() # 人类可读摘要
|
||||
graph.subgraph(("api",)) # 按标签切片
|
||||
graph.subgraph_by_names(("a", "b")) # 按名称切片
|
||||
graph.map("fetch", [1, 2, 3], lambda i: TaskSpec(f"fetch_{i}", ...)) # 批量 fan-out
|
||||
```
|
||||
|
||||
### 图组合 —— compose
|
||||
|
||||
`compose` / `GraphComposer` 把带字符串引用的多个图展开为纯 `Graph`:
|
||||
|
||||
```python
|
||||
graphs = {
|
||||
"build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
"all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
}
|
||||
resolved = px.compose(graphs) # "all" 图中的 "build" 引用被展开
|
||||
```
|
||||
|
||||
引用格式:`"command_name"`(整个图)或 `"command_name.task_name"`(特定任务)。
|
||||
`CliRunner` 内部自动调用 `compose`。
|
||||
|
||||
### 任务模板 —— task_template
|
||||
|
||||
`task_template` 工厂批量生成相似 TaskSpec:
|
||||
|
||||
```python
|
||||
fetch = px.task_template(
|
||||
fn=fetch_url,
|
||||
retry=px.RetryPolicy(max_attempts=5),
|
||||
timeout=30.0,
|
||||
tags=("api",),
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
fetch("users", url="https://api.example.com/users"),
|
||||
fetch("posts", url="https://api.example.com/posts"),
|
||||
])
|
||||
```
|
||||
|
||||
### run —— 执行
|
||||
@@ -98,23 +167,26 @@ graph.subgraph_by_names(("a", "b")) # 按名称切片
|
||||
```python
|
||||
report = px.run(
|
||||
graph,
|
||||
strategy="async", # sequential | thread | async
|
||||
max_workers=8, # thread 策略的线程池大小
|
||||
dry_run=False, # True = 仅打印计划
|
||||
on_event=callback, # 状态转换回调
|
||||
strategy="async", # sequential | thread | async | dependency
|
||||
max_workers=8, # thread 策略的线程池大小
|
||||
concurrency_limits={"db": 2}, # 按 concurrency_key 限流
|
||||
dry_run=False, # True = 仅打印计划
|
||||
verbose=False, # True = 打印任务生命周期日志
|
||||
on_event=callback, # 状态转换回调(RUNNING/SUCCESS/FAILED/SKIPPED)
|
||||
state=px.JSONBackend("state.json"), # 断点续跑后端
|
||||
continue_on_error=False, # True = 单任务失败不中断整体
|
||||
)
|
||||
```
|
||||
|
||||
### RunReport —— 结果
|
||||
|
||||
```python
|
||||
report["task_name"] # 任务返回值
|
||||
report["task_name"] # 任务返回值
|
||||
report.result_of("task_name") # 完整 TaskResult
|
||||
report.success # 整体是否成功
|
||||
report.summary() # 统计字典
|
||||
report.failed_tasks() # 失败任务名列表
|
||||
report.describe() # 人类可读报告
|
||||
report.success # 整体是否成功
|
||||
report.summary() # 统计字典
|
||||
report.failed_tasks() # 失败任务名列表
|
||||
report.describe() # 人类可读报告
|
||||
```
|
||||
|
||||
## 上下文注入规则
|
||||
@@ -122,21 +194,24 @@ report.describe() # 人类可读报告
|
||||
按顺序求值:
|
||||
|
||||
1. **标注为 `Context`** 的参数 → 接收完整上游结果映射
|
||||
2. **名称匹配依赖** 的参数 → 接收该依赖的结果
|
||||
2. **名称匹配依赖** 的参数 → 接收该依赖的结果(含软依赖,缺失时注入默认值)
|
||||
3. **`**kwargs`** 参数 → 接收所有依赖结果(dict)
|
||||
4. **`TaskSpec.args` / `kwargs`** → 为非依赖参数提供静态值
|
||||
|
||||
```python
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
def aggregate(ctx: px.Context) -> Dict[str, Any]:
|
||||
"""ctx 包含所有 depends_on 任务的返回值。"""
|
||||
return dict(ctx)
|
||||
|
||||
|
||||
def merge(fetch_a: str, fetch_b: str) -> str:
|
||||
"""fetch_a / fetch_b 自动注入。"""
|
||||
return fetch_a + fetch_b
|
||||
|
||||
|
||||
def fetch_user(uid: int) -> dict: # uid 来自 TaskSpec.args
|
||||
...
|
||||
```
|
||||
@@ -148,8 +223,91 @@ def fetch_user(uid: int) -> dict: # uid 来自 TaskSpec.args
|
||||
| `sequential` | 串行 | 调试、CPU 密集 | 直接调用 | 事件循环 |
|
||||
| `thread` | 线程池 | I/O 密集同步 | 线程池 | 不支持 |
|
||||
| `async` | 事件循环 | I/O 密集异步 | 卸载到线程池 | 事件循环 |
|
||||
| `dependency` | 依赖驱动 | 最大化并行度 | 卸载到线程池 | 事件循环 |
|
||||
|
||||
所有策略都遵循 `retries`、`timeout`、上下文注入、状态后端,并发出 `TaskEvent`。
|
||||
所有策略都遵循 `RetryPolicy`、`timeout`、上下文注入、状态后端、`concurrency_limits`,
|
||||
并发出 `TaskEvent`(RUNNING/SUCCESS/FAILED/SKIPPED)。`dependency` 策略无层屏障:
|
||||
任务在其所有硬依赖完成后立即启动。
|
||||
|
||||
## 命令任务
|
||||
|
||||
`TaskSpec` 的 `cmd` 参数支持执行外部命令,无需包装 Python 函数:
|
||||
|
||||
```python
|
||||
graph = px.Graph.from_specs([
|
||||
# 命令列表(推荐,参数无需转义)
|
||||
px.TaskSpec("list_files", cmd=["ls", "-la"]),
|
||||
# shell 字符串(支持管道、重定向)
|
||||
px.TaskSpec("check_git", cmd="git status | head"),
|
||||
# 带工作目录与超时
|
||||
px.TaskSpec("build", cmd=["make", "all"], cwd=Path("/project"), timeout=300),
|
||||
# 命令不存在时自动跳过(而非失败)
|
||||
px.TaskSpec("optional_tool", cmd=["maturin", "build"], skip_if_missing=True),
|
||||
])
|
||||
```
|
||||
|
||||
`verbose=True` 时打印执行的命令、工作目录、返回码与输出;`verbose=False` 时静默执行(失败信息仍包含 stderr)。
|
||||
|
||||
`skip_if_missing=True` 时,`list[str]` 类型的 `cmd` 会通过 `shutil.which` 检查命令是否存在,不存在则跳过任务(标记为 `SKIPPED`)而非失败。适用于构建工具场景,避免因未安装某些工具而导致整个图执行失败。对于 `str`(shell)和 `Callable` 类型的 `cmd`,此参数无效。
|
||||
|
||||
## 条件执行
|
||||
|
||||
`conditions` 参数让任务按条件跳过(标记为 `SKIPPED`):
|
||||
|
||||
```python
|
||||
from pyflowx.conditions import IS_WINDOWS, BuiltinConditions
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
# 仅在 Windows 上运行
|
||||
px.TaskSpec("win_only", cmd=["dir"], conditions=(IS_WINDOWS,)),
|
||||
# 仅在 git 已安装时运行
|
||||
px.TaskSpec(
|
||||
"git_check",
|
||||
cmd=["git", "--version"],
|
||||
conditions=(BuiltinConditions.HAS_INSTALLED("git"),),
|
||||
),
|
||||
# 组合条件
|
||||
px.TaskSpec(
|
||||
"prod_deploy",
|
||||
fn=deploy,
|
||||
conditions=(
|
||||
BuiltinConditions.ENV_VAR_EQUALS("ENV", "prod"),
|
||||
BuiltinConditions.HAS_INSTALLED("docker"),
|
||||
),
|
||||
),
|
||||
])
|
||||
```
|
||||
|
||||
内置条件:`IS_WINDOWS` / `IS_LINUX` / `IS_MACOS` / `IS_POSIX` / `PYTHON_VERSION` / `HAS_INSTALLED` / `ENV_VAR_EXISTS` / `ENV_VAR_EQUALS` / `NOT` / `AND` / `OR`。
|
||||
|
||||
## CLI 运行器
|
||||
|
||||
`CliRunner` 把多个 Graph 映射为命令行子命令,适合构建项目专属构建工具(替代 Makefile):
|
||||
|
||||
```python
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
description="My Build Tool",
|
||||
graphs={
|
||||
"clean": clean_graph,
|
||||
"build": build_graph,
|
||||
"test": test_graph,
|
||||
},
|
||||
)
|
||||
runner.run_cli() # 解析 sys.argv 并执行
|
||||
```
|
||||
|
||||
命令行用法:
|
||||
|
||||
```bash
|
||||
python build.py clean # 执行 clean 图
|
||||
python build.py build --strategy thread # 覆盖执行策略
|
||||
python build.py test --dry-run # 仅打印执行计划
|
||||
python build.py --list # 列出所有命令
|
||||
python build.py --quiet # 静默模式
|
||||
```
|
||||
|
||||
`verbose=True`(默认)时打印任务生命周期(开始/成功/失败/跳过)与命令输出;`--quiet` 关闭。
|
||||
|
||||
## 示例
|
||||
|
||||
@@ -173,12 +331,25 @@ python examples/async_aggregation.py
|
||||
from pyflowx import JSONBackend
|
||||
|
||||
# 第一次运行:成功结果写入 state.json
|
||||
backend = JSONBackend("state.json")
|
||||
backend = JSONBackend("state.json", ttl=3600) # ttl 秒数,过期条目自动忽略
|
||||
report = px.run(graph, strategy="sequential", state=backend)
|
||||
|
||||
# 第二次运行:已缓存任务自动跳过
|
||||
# 第二次运行:已缓存任务自动跳过(状态为 SKIPPED)
|
||||
report = px.run(graph, strategy="sequential", state=backend)
|
||||
# report.results 中缓存任务状态为 SKIPPED
|
||||
```
|
||||
|
||||
`run()` 内部以 `backend.batch()` 包裹整个执行:所有 `save` 延迟到运行结束时统一落盘一次
|
||||
(`JSONBackend` 从 O(N²) 降为 O(N) 磁盘写入;`MemoryBackend` 为 no-op)。
|
||||
|
||||
**缓存键**:默认存储键为任务名。配置 `cache_key` 函数后,键为 `"name:cache_key_value"`,
|
||||
使不同输入产生独立缓存条目:
|
||||
|
||||
```python
|
||||
px.TaskSpec(
|
||||
"fetch_user",
|
||||
fn=fetch_user,
|
||||
cache_key=lambda ctx: str(ctx.get("uid")), # 不同 uid 独立缓存
|
||||
)
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
@@ -219,14 +390,52 @@ except px.PyFlowXError:
|
||||
|
||||
PyFlowX 专注于**单机 DAG 调度**的极致简洁,适合 ETL、数据处理、CI 流水线等场景。
|
||||
|
||||
## 高级特性
|
||||
|
||||
### 并发限制
|
||||
|
||||
按 `concurrency_key` 分组限流,避免压垮下游资源:
|
||||
|
||||
```python
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("q1", fn=query_db, concurrency_key="db"),
|
||||
px.TaskSpec("q2", fn=query_db, concurrency_key="db"),
|
||||
px.TaskSpec("q3", fn=query_db, concurrency_key="db"),
|
||||
])
|
||||
# 同一时刻最多 2 个 "db" 组任务运行
|
||||
px.run(graph, strategy="async", concurrency_limits={"db": 2})
|
||||
```
|
||||
|
||||
### 任务钩子
|
||||
|
||||
`TaskHooks` 在任务生命周期触发(异常仅记录,不影响任务状态):
|
||||
|
||||
```python
|
||||
hooks = px.TaskHooks(
|
||||
pre_run=lambda spec: print(f"start {spec.name}"),
|
||||
post_run=lambda spec, value: print(f"done {spec.name}"),
|
||||
on_failure=lambda spec, exc: alert(spec.name, exc),
|
||||
)
|
||||
px.TaskSpec("task", fn=work, hooks=hooks)
|
||||
```
|
||||
|
||||
### 优先级
|
||||
|
||||
同层内按 `priority` 降序执行(稳定排序):
|
||||
|
||||
```python
|
||||
px.TaskSpec("low", fn=work, priority=0)
|
||||
px.TaskSpec("high", fn=work, priority=10) # 同层内先执行
|
||||
```
|
||||
|
||||
## 开发
|
||||
|
||||
```bash
|
||||
# 安装开发依赖
|
||||
uv sync --extra dev
|
||||
|
||||
# 运行测试(含覆盖率)
|
||||
uv run pytest --cov=pyflowx --cov-fail-under=100
|
||||
# 运行测试(含覆盖率,阈值 95%)
|
||||
uv run pytest --cov=pyflowx --cov-fail-under=95
|
||||
|
||||
# 类型检查
|
||||
uv run mypy
|
||||
@@ -236,6 +445,22 @@ uv run ruff check src tests examples
|
||||
uv run ruff format --check src tests examples
|
||||
```
|
||||
|
||||
## 模块结构
|
||||
|
||||
| 模块 | 职责 |
|
||||
|------|------|
|
||||
| `task.py` | 纯数据结构:`TaskSpec`、`RetryPolicy`、`TaskHooks`、`TaskStatus` |
|
||||
| `graph.py` | DAG 构建、校验、分层、可视化 |
|
||||
| `compose.py` | 多图组合:`GraphComposer` / `compose` |
|
||||
| `context.py` | 上下文注入:参数名→依赖解析 |
|
||||
| `command.py` | 命令执行:`run_command`(list/shell/Callable) |
|
||||
| `conditions.py` | 条件执行:内置条件与组合器 |
|
||||
| `executors.py` | 执行器与 `run` 入口:四种策略共享模块级辅助 |
|
||||
| `storage.py` | 状态后端:`MemoryBackend` / `JSONBackend`(batch flush) |
|
||||
| `runner.py` | CLI 运行器:`CliRunner` |
|
||||
| `report.py` | 运行结果:`RunReport` / `TaskResult` |
|
||||
| `errors.py` | 错误家族:`PyFlowXError` 子类 |
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
|
||||
+98
-28
@@ -6,28 +6,55 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python :: 3.14",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Topic :: Software Development :: Libraries :: Application Frameworks",
|
||||
]
|
||||
dependencies = ["graphlib_backport >= 1.0.0; python_version < '3.9'"]
|
||||
dependencies = [
|
||||
"graphlib_backport >= 1.0.0; python_version < '3.9'",
|
||||
"typing-extensions>=4.13.2; python_version < '3.10'",
|
||||
]
|
||||
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
|
||||
keywords = ["async", "dag", "scheduler", "task", "workflow"]
|
||||
license = { text = "MIT" }
|
||||
name = "pyflowx"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
version = "0.1.2"
|
||||
version = "0.2.12"
|
||||
|
||||
[project.scripts]
|
||||
pyflowx-demo = "pyflowx.__main__:main"
|
||||
autofmt = "pyflowx.cli.autofmt:main"
|
||||
bumpversion = "pyflowx.cli.bumpversion:main"
|
||||
emlman = "pyflowx.cli.emlmanager:main"
|
||||
filedate = "pyflowx.cli.filedate:main"
|
||||
filelvl = "pyflowx.cli.filelevel:main"
|
||||
foldback = "pyflowx.cli.folderback:main"
|
||||
foldzip = "pyflowx.cli.folderzip:main"
|
||||
gitt = "pyflowx.cli.gittool:main"
|
||||
lscalc = "pyflowx.cli.lscalc:main"
|
||||
msdown = "pyflowx.cli.llm.msdownload:main"
|
||||
packtool = "pyflowx.cli.packtool:main"
|
||||
pdftool = "pyflowx.cli.pdftool:main"
|
||||
piptool = "pyflowx.cli.piptool:main"
|
||||
pymake = "pyflowx.cli.pymake:main"
|
||||
reseticon = "pyflowx.cli.reseticoncache:main"
|
||||
scrcap = "pyflowx.cli.screenshot:main"
|
||||
sglang = "pyflowx.cli.llm.sglang:main"
|
||||
sshcopy = "pyflowx.cli.sshcopyid:main"
|
||||
# dev
|
||||
envdev = "pyflowx.cli.dev.envdev:main"
|
||||
# system
|
||||
clr = "pyflowx.cli.system.clearscreen:main"
|
||||
taskk = "pyflowx.cli.system.taskkill:main"
|
||||
wch = "pyflowx.cli.system.which:main"
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"hatch>=1.14.2",
|
||||
"httpx>=0.28.0",
|
||||
"mypy >= 1.0",
|
||||
"prek>=0.4.5",
|
||||
"pyrefly>=1.1.1",
|
||||
"pytest-asyncio>=0.24.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"pytest-html>=4.1.1",
|
||||
@@ -38,52 +65,95 @@ dev = [
|
||||
"tox-uv>=1.13.1",
|
||||
"tox>=4.25.0",
|
||||
]
|
||||
llm = [
|
||||
"sglang[all]==0.5.10rc0; python_version >= '3.10' and sys_platform == 'linux'",
|
||||
]
|
||||
office = [
|
||||
"pillow>=10.4.0",
|
||||
"pymupdf>=1.24.11",
|
||||
"pypdf>=5.9.0",
|
||||
"pytesseract>=0.3.13",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
build-backend = "hatchling.build"
|
||||
requires = ["hatchling"]
|
||||
|
||||
[[tool.uv.index]]
|
||||
default = true
|
||||
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/pyflowx"]
|
||||
|
||||
[tool.hatch.build.targets.wheel.force-include]
|
||||
"src/pyflowx/py.typed" = "pyflowx/py.typed"
|
||||
|
||||
[tool.mypy]
|
||||
# mypy 2.x requires a >=3.10 target. We check against 3.10 syntax; the
|
||||
# runtime stays 3.8-compatible via `from __future__ import annotations`
|
||||
# (all annotations are strings at runtime) and the graphlib_backport
|
||||
# conditional dependency for topological sorting.
|
||||
check_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
disallow_untyped_defs = true
|
||||
files = ["src/pyflowx"]
|
||||
ignore_missing_imports = false
|
||||
python_version = "3.8"
|
||||
strict = true
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
|
||||
[tool.uv.sources]
|
||||
pyflowx = { workspace = true }
|
||||
|
||||
[[tool.uv.index]]
|
||||
default = true
|
||||
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["pyflowx[dev]"]
|
||||
dev = ["pyflowx[dev,office,llm]"]
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
concurrency = ["thread"]
|
||||
omit = ["src/pyflowx/examples/*", "tests/*"]
|
||||
omit = ["src/pyflowx/cli/*", "src/pyflowx/examples/*", "tests/*"]
|
||||
source = ["pyflowx"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = ["if TYPE_CHECKING:", "if __name__ == .__main__.:", "pragma: no cover", "raise NotImplementedError"]
|
||||
fail_under = 95
|
||||
show_missing = true
|
||||
exclude_lines = [
|
||||
"if TYPE_CHECKING:",
|
||||
"if __name__ == .__main__.:",
|
||||
"pragma: no cover",
|
||||
"raise NotImplementedError",
|
||||
]
|
||||
fail_under = 95
|
||||
show_missing = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
|
||||
|
||||
# Ruff 配置 - 与 .pre-commit-config.yaml 保持一致
|
||||
[tool.ruff]
|
||||
line-length = 120
|
||||
target-version = "py38"
|
||||
|
||||
[tool.ruff.lint]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"PLC0415", # import should be at top-level (intentional for lazy imports)
|
||||
"PLR0913", # too many arguments
|
||||
"PLR0915", # too many statements (intentional for complex methods)
|
||||
"PLR2004", # magic value comparison
|
||||
"PTH119", # os.path.basename (intentional for sys.argv)
|
||||
"PTH123", # pathlib open() replacement
|
||||
"RUF001", # ambiguous unicode characters in string
|
||||
"RUF002", # ambiguous unicode characters in docstring
|
||||
"RUF003", # ambiguous unicode characters in comment
|
||||
"RUF012", # mutable class attributes (intentional for config)
|
||||
"SIM108", # use ternary operator
|
||||
]
|
||||
select = [
|
||||
"ARG", # flake8-unused-arguments
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"E", # pycodestyle errors
|
||||
"F", # Pyflakes
|
||||
"I", # isort
|
||||
"PL", # Pylint
|
||||
"PTH", # flake8-use-pathlib
|
||||
"RUF", # Ruff-specific rules
|
||||
"SIM", # flake8-simplify
|
||||
"UP", # pyupgrade
|
||||
"W", # pycodestyle warnings
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"**/tests/**" = ["ARG001", "ARG002"]
|
||||
|
||||
[tool.pyrefly]
|
||||
preset = "strict"
|
||||
project-includes = ["**/*.ipynb", "**/*.py*"]
|
||||
python-version = "3.8"
|
||||
|
||||
+99
-27
@@ -4,9 +4,15 @@
|
||||
--------
|
||||
* :class:`TaskSpec` —— 不可变任务描述符(唯一需要配置的东西)。
|
||||
* :class:`Graph` —— 由一组 spec 构建的 DAG;负责校验、分层、可视化。
|
||||
* :func:`run` —— 以 ``sequential`` / ``thread`` / ``async`` 策略执行图。
|
||||
* :func:`run` ——以 ``sequential`` / ``thread`` / ``async`` / ``dependency``
|
||||
策略执行图。
|
||||
* :class:`RunReport` —— 类型化、可查询的运行结果。
|
||||
* :class:`Context` —— 整体上下文注入的标注标记。
|
||||
* :class:`RetryPolicy` —— 重试策略(max_attempts/delay/backoff/jitter/retry_on)。
|
||||
* :class:`TaskHooks` —— 任务生命周期钩子(pre_run/post_run/on_failure)。
|
||||
* :class:`GraphDefaults` —— 图级默认值。
|
||||
* :func:`compose` —— 编程式组合多图。
|
||||
* :func:`task_template` —— 批量生成相似 TaskSpec 的工厂。
|
||||
* 状态后端::class:`StateBackend`、:class:`MemoryBackend`、:class:`JSONBackend`。
|
||||
|
||||
快速上手
|
||||
@@ -18,14 +24,51 @@
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("extract", extract),
|
||||
px.TaskSpec("double", double, ("extract",)),
|
||||
px.TaskSpec("double", double, depends_on=("extract",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
print(report["double"]) # [2, 4, 6]
|
||||
|
||||
命令行任务示例
|
||||
--------------
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import IS_WINDOWS, BuiltinConditions
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("list_files", cmd=["ls", "-la"]),
|
||||
px.TaskSpec("check_git", cmd="git status"),
|
||||
px.TaskSpec(
|
||||
"win_only",
|
||||
cmd=["dir"],
|
||||
conditions=(IS_WINDOWS,)
|
||||
),
|
||||
px.TaskSpec(
|
||||
"git_check",
|
||||
cmd=["git", "--version"],
|
||||
conditions=(BuiltinConditions.HAS_INSTALLED("git"),)
|
||||
),
|
||||
px.TaskSpec(
|
||||
"optional_build",
|
||||
cmd=["maturin", "build"],
|
||||
skip_if_missing=True
|
||||
),
|
||||
])
|
||||
report = px.run(graph)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .command import run_command
|
||||
from .compose import GraphComposer, compose
|
||||
from .conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
IS_POSIX,
|
||||
IS_WINDOWS,
|
||||
BuiltinConditions,
|
||||
Condition,
|
||||
Constants,
|
||||
)
|
||||
from .context import Context, build_call_args, describe_injection
|
||||
from .errors import (
|
||||
CycleError,
|
||||
@@ -37,39 +80,68 @@ from .errors import (
|
||||
TaskFailedError,
|
||||
TaskTimeoutError,
|
||||
)
|
||||
from .executors import run
|
||||
from .graph import Graph
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph, GraphDefaults
|
||||
from .report import RunReport
|
||||
from .runner import CliExitCode, CliRunner
|
||||
from .storage import JSONBackend, MemoryBackend, StateBackend
|
||||
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
|
||||
from .task import (
|
||||
CacheKeyFn,
|
||||
RetryPolicy,
|
||||
TaskCmd,
|
||||
TaskEvent,
|
||||
TaskHooks,
|
||||
TaskResult,
|
||||
TaskSpec,
|
||||
TaskStatus,
|
||||
cmd,
|
||||
task,
|
||||
task_template,
|
||||
)
|
||||
|
||||
__version__ = "0.1.2"
|
||||
__version__ = "0.3.6"
|
||||
|
||||
__all__ = [
|
||||
# 核心类型
|
||||
"IS_LINUX",
|
||||
"IS_MACOS",
|
||||
"IS_POSIX",
|
||||
"IS_WINDOWS",
|
||||
"BuiltinConditions",
|
||||
"CacheKeyFn",
|
||||
"CliExitCode",
|
||||
"CliRunner",
|
||||
"Condition",
|
||||
"Constants",
|
||||
"Context",
|
||||
"CycleError",
|
||||
"DuplicateTaskError",
|
||||
"Graph",
|
||||
"GraphComposer",
|
||||
"GraphDefaults",
|
||||
"InjectionError",
|
||||
"JSONBackend",
|
||||
"MemoryBackend",
|
||||
"MissingDependencyError",
|
||||
"PyFlowXError",
|
||||
"RetryPolicy",
|
||||
"RunReport",
|
||||
"StateBackend",
|
||||
"StorageError",
|
||||
"Strategy",
|
||||
"TaskCmd",
|
||||
"TaskEvent",
|
||||
"TaskFailedError",
|
||||
"TaskHooks",
|
||||
"TaskResult",
|
||||
"TaskSpec",
|
||||
"TaskStatus",
|
||||
"TaskResult",
|
||||
"TaskEvent",
|
||||
"Context",
|
||||
"Graph",
|
||||
"RunReport",
|
||||
# 执行
|
||||
"run",
|
||||
# 状态后端
|
||||
"StateBackend",
|
||||
"MemoryBackend",
|
||||
"JSONBackend",
|
||||
# 错误
|
||||
"PyFlowXError",
|
||||
"DuplicateTaskError",
|
||||
"MissingDependencyError",
|
||||
"CycleError",
|
||||
"TaskFailedError",
|
||||
"TaskTimeoutError",
|
||||
"InjectionError",
|
||||
"StorageError",
|
||||
# 辅助(高级)
|
||||
"build_call_args",
|
||||
"cmd",
|
||||
"compose",
|
||||
"describe_injection",
|
||||
"run",
|
||||
"run_command",
|
||||
"task",
|
||||
"task_template",
|
||||
]
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
from pyflowx.examples.async_aggregation import main as async_aggregation_main
|
||||
from pyflowx.examples.etl_pipeline import main as etl_pipeline_main
|
||||
from pyflowx.examples.parallel_run import main as parallel_run_main
|
||||
|
||||
|
||||
def main():
|
||||
async_aggregation_main()
|
||||
etl_pipeline_main()
|
||||
parallel_run_main()
|
||||
@@ -0,0 +1,282 @@
|
||||
"""自动格式化工具模块.
|
||||
|
||||
提供 Python 代码自动格式化的常用功能封装,
|
||||
支持 docstring 自动生成、pyproject.toml 配置同步等功能.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
try:
|
||||
import tomllib # noqa: F401
|
||||
|
||||
HAS_TOMLLIB = True
|
||||
except ImportError:
|
||||
HAS_TOMLLIB = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
"__pycache__",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
".git",
|
||||
".venv",
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.egg-info",
|
||||
"dist",
|
||||
"build",
|
||||
".pytest_cache",
|
||||
".tox",
|
||||
".mypy_cache",
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def format_with_ruff(target: Path, fix: bool = True) -> None:
|
||||
"""使用 ruff 格式化代码.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : Path
|
||||
目标路径
|
||||
fix : bool
|
||||
是否自动修复
|
||||
"""
|
||||
cmd = ["ruff", "format", str(target)]
|
||||
if fix:
|
||||
cmd.append("--fix")
|
||||
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"ruff format 完成: {target}")
|
||||
|
||||
|
||||
def lint_with_ruff(target: Path, fix: bool = True) -> None:
|
||||
"""使用 ruff 检查代码.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
target : Path
|
||||
目标路径
|
||||
fix : bool
|
||||
是否自动修复
|
||||
"""
|
||||
cmd = ["ruff", "check", str(target)]
|
||||
if fix:
|
||||
cmd.extend(["--fix", "--unsafe-fixes"])
|
||||
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"ruff check 完成: {target}")
|
||||
|
||||
|
||||
def add_docstring(file_path: Path, docstring: str) -> bool:
|
||||
"""为文件添加 docstring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : Path
|
||||
文件路径
|
||||
docstring : str
|
||||
docstring 内容
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
是否成功添加
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
tree = ast.parse(content)
|
||||
|
||||
# 检查是否已有 docstring
|
||||
first_node = tree.body[0] if tree.body else None
|
||||
if first_node and isinstance(first_node, ast.Expr) and isinstance(first_node.value, ast.Constant):
|
||||
return False
|
||||
|
||||
# 添加 docstring
|
||||
lines = content.splitlines()
|
||||
doc_lines = docstring.splitlines()
|
||||
doc_lines.append("")
|
||||
new_content = "\n".join(doc_lines + lines)
|
||||
|
||||
file_path.write_text(new_content, encoding="utf-8")
|
||||
print(f"添加 docstring: {file_path}")
|
||||
return True
|
||||
|
||||
except (OSError, UnicodeDecodeError, SyntaxError) as e:
|
||||
print(f"处理失败: {file_path} - {e}")
|
||||
return False
|
||||
|
||||
|
||||
def generate_module_docstring(file_path: Path) -> str:
|
||||
"""生成模块 docstring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : Path
|
||||
文件路径
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
生成的 docstring
|
||||
"""
|
||||
stem = file_path.stem
|
||||
parent = file_path.parent.name
|
||||
|
||||
# 关键词匹配
|
||||
keywords = {
|
||||
"cli": f"Command-line interface for {parent}",
|
||||
"gui": f"Graphical user interface for {parent}",
|
||||
"core": f"Core functionality for {parent}",
|
||||
"util": f"Utility functions for {parent}",
|
||||
"model": f"Data models for {parent}",
|
||||
"test": f"Tests for {parent}",
|
||||
}
|
||||
|
||||
for key, desc in keywords.items():
|
||||
if key in stem.lower():
|
||||
return f'"""{desc}."""'
|
||||
|
||||
return f'"""{stem.replace("_", " ").title()} module."""'
|
||||
|
||||
|
||||
def auto_add_docstrings(root_dir: Path) -> int:
|
||||
"""自动为所有 Python 文件添加 docstring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root_dir : Path
|
||||
根目录
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
添加的 docstring 数量
|
||||
"""
|
||||
count = 0
|
||||
for py_file in root_dir.rglob("*.py"):
|
||||
# 跳过忽略的文件
|
||||
if any(pattern in str(py_file) for pattern in IGNORE_PATTERNS):
|
||||
continue
|
||||
|
||||
docstring = generate_module_docstring(py_file)
|
||||
if add_docstring(py_file, docstring):
|
||||
count += 1
|
||||
|
||||
print(f"共添加 {count} 个 docstring")
|
||||
return count
|
||||
|
||||
|
||||
def sync_pyproject_config(root_dir: Path) -> None:
|
||||
"""同步 pyproject.toml 配置到子项目.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root_dir : Path
|
||||
根目录
|
||||
"""
|
||||
main_toml = root_dir / "pyproject.toml"
|
||||
if not main_toml.exists():
|
||||
print(f"主项目配置文件不存在: {main_toml}")
|
||||
return
|
||||
|
||||
# 查找所有子项目的 pyproject.toml
|
||||
sub_tomls = [p for p in root_dir.rglob("pyproject.toml") if p != main_toml and ".venv" not in str(p)]
|
||||
|
||||
if not sub_tomls:
|
||||
print("没有找到子项目的 pyproject.toml")
|
||||
return
|
||||
|
||||
print(f"找到 {len(sub_tomls)} 个子项目配置文件")
|
||||
|
||||
# 对每个子项目调用 ruff format
|
||||
for sub_toml in sub_tomls:
|
||||
subprocess.run(["ruff", "format", str(sub_toml)], check=False)
|
||||
|
||||
print("配置同步完成")
|
||||
|
||||
|
||||
def format_all(root_dir: Path) -> None:
|
||||
"""格式化所有 Python 文件.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
root_dir : Path
|
||||
根目录
|
||||
"""
|
||||
# 使用 ruff format
|
||||
subprocess.run(["ruff", "format", str(root_dir)], check=True)
|
||||
|
||||
# 使用 ruff check
|
||||
subprocess.run(["ruff", "check", "--fix", "--unsafe-fixes", str(root_dir)], check=True)
|
||||
|
||||
print(f"格式化完成: {root_dir}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""自动格式化工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="AutoFmt - 自动格式化工具",
|
||||
usage="autofmt <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# ruff format 命令
|
||||
format_parser = subparsers.add_parser("fmt", help="使用 ruff 格式化代码")
|
||||
format_parser.add_argument("--target", type=str, default=".", help="目标路径")
|
||||
|
||||
# ruff check 命令
|
||||
lint_parser = subparsers.add_parser("lint", help="使用 ruff 检查代码")
|
||||
lint_parser.add_argument("--target", type=str, default=".", help="目标路径")
|
||||
lint_parser.add_argument("--fix", action="store_true", help="自动修复")
|
||||
|
||||
# 自动添加 docstring 命令
|
||||
doc_parser = subparsers.add_parser("doc", help="自动添加 docstring")
|
||||
doc_parser.add_argument("--root-dir", type=str, default=".", help="根目录")
|
||||
|
||||
# 同步配置命令
|
||||
sync_parser = subparsers.add_parser("sync", help="同步 pyproject.toml 配置")
|
||||
sync_parser.add_argument("--root-dir", type=str, default=".", help="根目录")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "fmt":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("ruff_format", cmd=["ruff", "format", args.target], verbose=True)])
|
||||
elif args.command == "lint":
|
||||
cmd = ["ruff", "check", args.target]
|
||||
if args.fix:
|
||||
cmd.extend(["--fix", "--unsafe-fixes"])
|
||||
graph = px.Graph.from_specs([px.TaskSpec("ruff_check", cmd=cmd, verbose=True)])
|
||||
elif args.command == "doc":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("auto_docstring", fn=auto_add_docstrings, args=(Path(args.root_dir),), verbose=True)
|
||||
])
|
||||
elif args.command == "sync":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("sync_config", fn=sync_pyproject_config, args=(Path(args.root_dir),), verbose=True)
|
||||
])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,263 @@
|
||||
"""版本号自动管理工具.
|
||||
|
||||
使用 TaskSpec 模式实现, 支持语义化版本管理和多文件格式的版本号更新.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
BumpVersionType = Literal["patch", "minor", "major"]
|
||||
|
||||
# 针对不同文件类型的版本号匹配模式
|
||||
# pyproject.toml: version = "X.Y.Z" 或 version = 'X.Y.Z'
|
||||
_PYPROJECT_VERSION_PATTERN = re.compile(
|
||||
r'(?:^|\n)\s*version\s*=\s*["\']'
|
||||
r"(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<patch>0|[1-9]\d*)"
|
||||
r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?"
|
||||
r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?"
|
||||
r'["\']',
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
# __init__.py: __version__ = "X.Y.Z" 或 __version__ = 'X.Y.Z'
|
||||
_INIT_VERSION_PATTERN = re.compile(
|
||||
r'(?:^|\n)\s*__version__\s*=\s*["\']'
|
||||
r"(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<patch>0|[1-9]\d*)"
|
||||
r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?"
|
||||
r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*))?"
|
||||
r'["\']',
|
||||
re.MULTILINE,
|
||||
)
|
||||
|
||||
|
||||
def _get_pattern_for_file(file_name: str) -> re.Pattern[str] | None:
|
||||
"""根据文件类型获取对应的正则表达式.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name : str
|
||||
文件名
|
||||
|
||||
Returns
|
||||
-------
|
||||
re.Pattern[str] | None
|
||||
对应的正则表达式,如果无法确定则返回 None
|
||||
"""
|
||||
if file_name == "pyproject.toml":
|
||||
return _PYPROJECT_VERSION_PATTERN
|
||||
if file_name == "__init__.py":
|
||||
return _INIT_VERSION_PATTERN
|
||||
return None
|
||||
|
||||
|
||||
def _calculate_new_version(major: int, minor: int, patch: int, part: BumpVersionType) -> str:
|
||||
"""计算新版本号.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
major : int
|
||||
当前主版本号
|
||||
minor : int
|
||||
当前次版本号
|
||||
patch : int
|
||||
当前补丁版本号
|
||||
part : BumpVersionType
|
||||
要更新的部分
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
新版本号
|
||||
"""
|
||||
if part == "major":
|
||||
return f"{major + 1}.0.0"
|
||||
if part == "minor":
|
||||
return f"{major}.{minor + 1}.0"
|
||||
return f"{major}.{minor}.{patch + 1}"
|
||||
|
||||
|
||||
def _build_replacement_string(original_match: str, new_version: str, file_name: str) -> str:
|
||||
"""构建替换字符串,保留原始格式.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
original_match : str
|
||||
原始匹配的字符串
|
||||
new_version : str
|
||||
新版本号
|
||||
file_name : str
|
||||
文件名
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
替换字符串
|
||||
"""
|
||||
quote_char = '"' if '"' in original_match else "'"
|
||||
|
||||
if file_name == "pyproject.toml":
|
||||
prefix_match = re.match(r'(\s*version\s*=\s*)["\']', original_match)
|
||||
prefix = prefix_match.group(1) if prefix_match else "version = "
|
||||
return f"{prefix}{quote_char}{new_version}{quote_char}"
|
||||
|
||||
if file_name == "__init__.py":
|
||||
prefix_match = re.match(r'(\s*__version__\s*=\s*)["\']', original_match)
|
||||
prefix = prefix_match.group(1) if prefix_match else "__version__ = "
|
||||
return f"{prefix}{quote_char}{new_version}{quote_char}"
|
||||
|
||||
return new_version
|
||||
|
||||
|
||||
def bump_file_version(file_path: Path, part: BumpVersionType = "patch") -> str | None:
|
||||
"""更新文件中的版本号.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : Path
|
||||
要更新的文件路径
|
||||
part : BumpVersionType
|
||||
版本部分: patch, minor, major
|
||||
|
||||
Returns
|
||||
-------
|
||||
str | None
|
||||
更新后的新版本号,如果文件中未找到版本号则返回 None
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except Exception as e:
|
||||
print(f"读取文件 {file_path} 时出错: {e}")
|
||||
raise
|
||||
|
||||
# 获取文件对应的正则表达式
|
||||
pattern = _get_pattern_for_file(file_path.name)
|
||||
|
||||
# 对于未知文件类型,尝试两种模式
|
||||
if pattern:
|
||||
match = pattern.search(content)
|
||||
else:
|
||||
match = _PYPROJECT_VERSION_PATTERN.search(content) or _INIT_VERSION_PATTERN.search(content)
|
||||
|
||||
if not match:
|
||||
print(f"文件 {file_path} 中未找到版本号模式")
|
||||
return None
|
||||
|
||||
# 提取当前版本号
|
||||
major = int(match.group("major"))
|
||||
minor = int(match.group("minor"))
|
||||
patch = int(match.group("patch"))
|
||||
|
||||
# 计算新版本号
|
||||
new_version = _calculate_new_version(major, minor, patch, part)
|
||||
|
||||
# 构建替换字符串
|
||||
original_match = match.group(0)
|
||||
replacement = _build_replacement_string(original_match, new_version, file_path.name)
|
||||
|
||||
# 更新文件内容
|
||||
content = content.replace(original_match, replacement)
|
||||
|
||||
try:
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
except Exception as e:
|
||||
print(f"更新文件 {file_path} 版本号时出错: {e}")
|
||||
raise
|
||||
|
||||
return new_version
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""版本号管理工具主函数."""
|
||||
parser = argparse.ArgumentParser(description="BumpVersion - 版本号自动管理工具")
|
||||
parser.add_argument(
|
||||
"part",
|
||||
type=str,
|
||||
nargs="?",
|
||||
default="patch",
|
||||
choices=get_args(BumpVersionType),
|
||||
help=f"版本部分: {get_args(BumpVersionType)}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-tag",
|
||||
action="store_true",
|
||||
help="提交后不创建 git tag",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
part = args.part
|
||||
|
||||
# 搜索文件,排除常见的虚拟环境和缓存目录
|
||||
ignore_dirs = {".venv", "venv", ".git", "__pycache__", ".tox", "node_modules", "build", "dist", ".eggs"}
|
||||
all_files = set()
|
||||
|
||||
for pattern in ["__init__.py", "pyproject.toml"]:
|
||||
for file in Path.cwd().rglob(pattern):
|
||||
# 检查路径中是否包含需要忽略的目录
|
||||
if not any(ignore_dir in file.parts for ignore_dir in ignore_dirs):
|
||||
all_files.add(file)
|
||||
|
||||
if not all_files:
|
||||
print("未找到包含版本号的文件")
|
||||
return
|
||||
|
||||
print(f"找到 {len(all_files)} 个文件需要更新版本号")
|
||||
for file in sorted(all_files):
|
||||
print(f" - {file.relative_to(Path.cwd())}")
|
||||
|
||||
# 更新所有文件的版本号(使用顺序执行避免竞争条件)
|
||||
# 使用相对于 cwd 的路径作为任务名,确保唯一性
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
f"bump_{file.relative_to(Path.cwd())}".replace("\\", "_").replace("/", "_").replace(".", "_"),
|
||||
fn=bump_file_version,
|
||||
args=(file, part),
|
||||
)
|
||||
for file in all_files
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
|
||||
# 收集新版本号(取第一个成功的结果)
|
||||
new_version = None
|
||||
for task_name in report:
|
||||
result = report[task_name]
|
||||
if result is not None:
|
||||
new_version = result
|
||||
break
|
||||
|
||||
if not new_version:
|
||||
print("未能获取新版本号")
|
||||
return
|
||||
|
||||
print(f"版本号已更新为: {new_version}")
|
||||
|
||||
# 提交修改并创建标签
|
||||
tasks = [
|
||||
px.TaskSpec("git_add", cmd=["git", "add", "."]),
|
||||
px.TaskSpec(
|
||||
"git_commit",
|
||||
cmd=["git", "commit", "-m", f"bump version to {new_version}"],
|
||||
depends_on=("git_add",),
|
||||
),
|
||||
]
|
||||
|
||||
if not args.no_tag:
|
||||
tag_name = f"v{new_version}"
|
||||
tasks.append(
|
||||
px.TaskSpec(
|
||||
"git_tag",
|
||||
cmd=["git", "tag", "-a", tag_name, "-m", f"Release {tag_name}"],
|
||||
depends_on=("git_commit",),
|
||||
)
|
||||
)
|
||||
|
||||
graph = px.Graph.from_specs(tasks)
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
if not args.no_tag:
|
||||
print(f"已创建标签: v{new_version}")
|
||||
@@ -0,0 +1,331 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import BuiltinConditions
|
||||
from pyflowx.tasks.system import setenv_group, write_file
|
||||
|
||||
# ============================================================================
|
||||
# Mirror 配置
|
||||
# ============================================================================
|
||||
DOWNLOAD_MIRROR_SCRIPT: str = "curl -sSL https://linuxmirrors.cn/main.sh -o /tmp/linuxmirrors.sh"
|
||||
INSTALL_MIRROR_SCRIPT: str = "sudo bash /tmp/linuxmirrors.sh"
|
||||
|
||||
# ============================================================================
|
||||
# Python 配置
|
||||
# ============================================================================
|
||||
PyMirrorType = Literal["tsinghua", "aliyun", "huaweicloud", "ustc", "zju"]
|
||||
|
||||
PIP_INDEX_URLS: dict[PyMirrorType, str] = {
|
||||
"tsinghua": "https://pypi.tuna.tsinghua.edu.cn/simple",
|
||||
"aliyun": "https://mirrors.aliyun.com/pypi/simple/",
|
||||
"huaweicloud": "https://mirrors.huaweicloud.com/repository/pypi/simple/",
|
||||
"ustc": "https://pypi.mirrors.ustc.edu.cn/simple/",
|
||||
"zju": "https://mirrors.zju.edu.cn/pypi/simple/",
|
||||
}
|
||||
|
||||
PIP_TRUSTED_HOSTS: dict[PyMirrorType, str] = {
|
||||
"tsinghua": "pypi.tuna.tsinghua.edu.cn",
|
||||
"aliyun": "mirrors.aliyun.com",
|
||||
"huaweicloud": "mirrors.huaweicloud.com",
|
||||
"ustc": "pypi.mirrors.ustc.edu.cn",
|
||||
"zju": "mirrors.zju.edu.cn",
|
||||
}
|
||||
PIP_CONFIG_PATH = Path.home() / ".pip" / "pip.conf" if BuiltinConditions.IS_LINUX() else Path.home() / "pip" / "pip.ini"
|
||||
|
||||
UV_INDEX_URLS = PIP_INDEX_URLS
|
||||
UV_PYTHON_INSTALL_MIRROR: str = "https://registry.npmmirror.com/-/binary/python-build-standalone"
|
||||
|
||||
# ============================================================================
|
||||
# Conda 配置
|
||||
# ============================================================================
|
||||
CondaMirrorType = Literal["tsinghua", "ustc", "bsfu", "aliyun"]
|
||||
|
||||
CONDA_MIRROR_URLS: dict[CondaMirrorType, list[str]] = {
|
||||
"tsinghua": [
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/r/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/msys2/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/pro/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/bioconda/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/menpo/",
|
||||
"https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/",
|
||||
],
|
||||
"ustc": [
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/main/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/free/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/r/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/msys2/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/pro/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/pkgs/dev/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/cloud/conda-forge/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/cloud/bioconda/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/cloud/menpo/",
|
||||
"https://mirrors.ustc.edu.cn/anaconda/cloud/pytorch/",
|
||||
],
|
||||
"bsfu": [
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/main/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/free/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/r/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/msys2/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/pro/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/pkgs/dev/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/cloud/conda-forge/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/cloud/bioconda/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/cloud/menpo/",
|
||||
"https://mirrors.bsfu.edu.cn/anaconda/cloud/pytorch/",
|
||||
],
|
||||
"aliyun": [
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/main/",
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/free/",
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/r/",
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/msys2/",
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/pro/",
|
||||
"https://mirrors.aliyun.com/anaconda/pkgs/dev/",
|
||||
"https://mirrors.aliyun.com/anaconda/cloud/conda-forge/",
|
||||
"https://mirrors.aliyun.com/anaconda/cloud/bioconda/",
|
||||
"https://mirrors.aliyun.com/anaconda/cloud/menpo/",
|
||||
"https://mirrors.aliyun.com/anaconda/cloud/pytorch/",
|
||||
],
|
||||
}
|
||||
CONDA_CONFIG_PATH = Path.home() / ".condarc"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Qt 配置
|
||||
# ============================================================================
|
||||
|
||||
QT_LIBS: list[str] = [
|
||||
"build-essential",
|
||||
"libgl1",
|
||||
"libegl1",
|
||||
"libglib2.0-0",
|
||||
"libfontconfig1",
|
||||
"libfreetype6",
|
||||
"libxkbcommon0",
|
||||
"libdbus-1-3",
|
||||
"libxcb-xinerama0",
|
||||
"libxcb-icccm4",
|
||||
"libxcb-image0",
|
||||
"libxcb-keysyms1",
|
||||
"libxcb-randr0",
|
||||
"libxcb-render-util0",
|
||||
"libxcb-shape0",
|
||||
"libxcb-xfixes0",
|
||||
"libxcb-cursor0",
|
||||
]
|
||||
|
||||
CHINESE_FONTS: list[str] = [
|
||||
"fonts-noto-cjk",
|
||||
"fonts-wqy-microhei",
|
||||
"fonts-wqy-zenhei",
|
||||
"fonts-noto-color-emoji",
|
||||
]
|
||||
|
||||
# ============================================================================
|
||||
# Rust 配置
|
||||
# ============================================================================
|
||||
RustMirrorType = Literal["tsinghua", "ustc", "aliyun"]
|
||||
RustVersionType = Literal["stable", "nightly", "beta"]
|
||||
DEFAULT_RUST_VERSION: RustVersionType = "stable"
|
||||
DEFAULT_MIRROR: RustMirrorType = "tsinghua"
|
||||
|
||||
RUSTUP_MIRRORS: dict[RustMirrorType, dict[str, str]] = {
|
||||
"tsinghua": {
|
||||
"RUSTUP_DIST_SERVER": "https://mirrors.tuna.tsinghua.edu.cn/rustup",
|
||||
"RUSTUP_UPDATE_ROOT": "https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup",
|
||||
"TOML_REGISTRY": "https://mirrors.tuna.tsinghua.edu.cn/crates.io-index/",
|
||||
},
|
||||
"aliyun": {
|
||||
"RUSTUP_DIST_SERVER": "https://mirrors.aliyun.com/rustup",
|
||||
"RUSTUP_UPDATE_ROOT": "https://mirrors.aliyun.com/rustup/rustup",
|
||||
"TOML_REGISTRY": "https://mirrors.aliyun.com/crates.io-index/",
|
||||
},
|
||||
"ustc": {
|
||||
"RUSTUP_DIST_SERVER": "https://mirrors.ustc.edu.cn/rust-static",
|
||||
"RUSTUP_UPDATE_ROOT": "https://mirrors.ustc.edu.cn/rust-static/rustup",
|
||||
"TOML_REGISTRY": "https://mirrors.ustc.edu.cn/crates.io-index/",
|
||||
},
|
||||
}
|
||||
RUSTUP_DOWNLOAD_URL_LINUX = "https://mirrors.aliyun.com/repo/rust/rustup-init.sh"
|
||||
RUSTUP_DOWNLOAD_URL_WINDOWS = "https://static.rust-lang.org/rustup/dist/x86_64-pc-windows-msvc/rustup-init.exe"
|
||||
RUST_CONFIG_PATH = Path.home() / ".cargo" / "config.toml"
|
||||
RUST_SCCACHE_DIR: Path = Path.home() / ".cargo" / "sccache"
|
||||
RUST_SCCACHE_CACHE_SIZE: str = "20G"
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""主函数."""
|
||||
parser = argparse.ArgumentParser(description="环境开发工具")
|
||||
parser.add_argument(
|
||||
"--python-mirror",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="tsinghua",
|
||||
choices=get_args(PyMirrorType),
|
||||
help="Python 镜像源",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conda-mirror",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="tsinghua",
|
||||
choices=get_args(CondaMirrorType),
|
||||
help="Conda 镜镜像源",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rust-mirror",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default=DEFAULT_MIRROR,
|
||||
choices=get_args(RustMirrorType),
|
||||
help="Rust 镜像源",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rust-version",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default=DEFAULT_RUST_VERSION,
|
||||
choices=get_args(RustVersionType),
|
||||
help=f"Rust 版本, 推荐: {get_args(RustVersionType)}",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
python_mirror = args.python_mirror
|
||||
conda_mirror_urls = CONDA_MIRROR_URLS[args.conda_mirror]
|
||||
rust_mirror = args.rust_mirror
|
||||
rust_version = args.rust_version
|
||||
|
||||
# 确保配置文件目录存在
|
||||
PIP_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
CONDA_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
RUST_CONFIG_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
RUST_SCCACHE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 conditions 自动控制任务执行
|
||||
graph = px.Graph.from_specs([
|
||||
# 系统镜像配置(仅 Linux 且未配置国内镜像)
|
||||
px.TaskSpec(
|
||||
"download_mirror",
|
||||
cmd=DOWNLOAD_MIRROR_SCRIPT,
|
||||
conditions=(
|
||||
BuiltinConditions.IS_LINUX(),
|
||||
BuiltinConditions.NOT(
|
||||
BuiltinConditions.OR(
|
||||
*[
|
||||
BuiltinConditions.FILE_CONTENT_EXISTS(f, m)
|
||||
for f in [
|
||||
"/etc/apt/sources.list",
|
||||
"/etc/apt/sources.list.d/ubuntu.sources",
|
||||
]
|
||||
for m in get_args(PyMirrorType)
|
||||
],
|
||||
)
|
||||
),
|
||||
),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
"install_mirror",
|
||||
cmd=INSTALL_MIRROR_SCRIPT,
|
||||
depends_on=("download_mirror",),
|
||||
verbose=True,
|
||||
),
|
||||
# 安装 Qt 依赖(仅 Linux)
|
||||
px.TaskSpec(
|
||||
"install_qt_libs",
|
||||
cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
|
||||
conditions=(BuiltinConditions.IS_LINUX(),),
|
||||
depends_on=("install_mirror",),
|
||||
allow_upstream_skip=True,
|
||||
verbose=True,
|
||||
),
|
||||
# 安装中文字体(仅 Linux)
|
||||
px.TaskSpec(
|
||||
"install_fonts",
|
||||
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
|
||||
conditions=(BuiltinConditions.IS_LINUX(),),
|
||||
depends_on=("install_mirror",),
|
||||
allow_upstream_skip=True,
|
||||
verbose=True,
|
||||
),
|
||||
# 设置 Python 环境变量
|
||||
*setenv_group({
|
||||
"PIP_INDEX_URL": PIP_INDEX_URLS[python_mirror],
|
||||
"PIP_TRUSTED_HOSTS": PIP_TRUSTED_HOSTS[python_mirror],
|
||||
"UV_INDEX_URL": UV_INDEX_URLS[python_mirror],
|
||||
"UV_PYTHON_INSTALL_MIRROR": UV_PYTHON_INSTALL_MIRROR,
|
||||
"UV_HTTP_TIMEOUT": "600",
|
||||
"UV_LINK_MODE": "copy",
|
||||
}),
|
||||
# 写入 Python 配置(仅当未配置)
|
||||
write_file(
|
||||
str(PIP_CONFIG_PATH),
|
||||
f"[global]\nindex-url = {PIP_INDEX_URLS[python_mirror]}\ntrusted-host = {PIP_TRUSTED_HOSTS[python_mirror]}",
|
||||
),
|
||||
# 写入 Conda 配置(仅当未配置)
|
||||
write_file(
|
||||
str(CONDA_CONFIG_PATH),
|
||||
"show_channel_urls: true\nchannels:\n - " + "\n - ".join(conda_mirror_urls) + "\n - defaults",
|
||||
),
|
||||
# 设置 Rust 镜像源
|
||||
*setenv_group({
|
||||
"RUSTUP_DIST_SERVER": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_DIST_SERVER"],
|
||||
"RUSTUP_UPDATE_ROOT": RUSTUP_MIRRORS[rust_mirror]["RUSTUP_UPDATE_ROOT"],
|
||||
"RUST_SCCACHE_DIR": str(RUST_SCCACHE_DIR),
|
||||
"RUST_SCCACHE_CACHE_SIZE": RUST_SCCACHE_CACHE_SIZE,
|
||||
}),
|
||||
# 写入 Rust 配置(仅当未配置)
|
||||
write_file(
|
||||
str(RUST_CONFIG_PATH),
|
||||
f"""
|
||||
[source.crates-io]
|
||||
replace-with = '{rust_mirror}'
|
||||
|
||||
[source.{rust_mirror}]
|
||||
registry = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
|
||||
|
||||
[registries.{rust_mirror}]
|
||||
index = "sparse+{RUSTUP_MIRRORS[rust_mirror]["TOML_REGISTRY"]}"
|
||||
""",
|
||||
),
|
||||
# 下载 Rustup 安装脚本
|
||||
px.TaskSpec(
|
||||
"download_rustup",
|
||||
cmd=["curl", "-fsSL", RUSTUP_DOWNLOAD_URL_LINUX, "-o", "rustup-init.sh"],
|
||||
conditions=(BuiltinConditions.IS_LINUX(), BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup"))),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
"download_rustup_win",
|
||||
cmd=[
|
||||
"powershell",
|
||||
"-Command",
|
||||
"Invoke-WebRequest",
|
||||
"-Uri",
|
||||
RUSTUP_DOWNLOAD_URL_WINDOWS,
|
||||
"-OutFile",
|
||||
"rustup-init.exe",
|
||||
],
|
||||
conditions=(
|
||||
BuiltinConditions.IS_WINDOWS(),
|
||||
BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("rustup")),
|
||||
),
|
||||
verbose=True,
|
||||
),
|
||||
# 安装 Rust 工具链
|
||||
px.TaskSpec(
|
||||
"install_rust",
|
||||
cmd=["rustup", "toolchain", "install", rust_version],
|
||||
conditions=(BuiltinConditions.HAS_INSTALLED("rustup"),),
|
||||
depends_on=("setenv_rustup_dist_server",),
|
||||
allow_upstream_skip=True,
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
px.run(graph, strategy="thread", verbose=True)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,137 @@
|
||||
"""文件日期处理工具.
|
||||
|
||||
自动检测文件名的日期前缀,
|
||||
并根据文件的实际创建或修改时间重命名文件.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
DATE_PATTERN = re.compile(r"(20|19)\d{2}[-_#.~]?((0[1-9])|(1[012]))[-_#.~]?((0[1-9])|([12]\d)|(3[01]))[-_#.~]?")
|
||||
SEP = "_"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_file_timestamp(filepath: Path) -> str:
|
||||
"""获取文件时间戳."""
|
||||
modified_time = filepath.stat().st_mtime
|
||||
created_time = filepath.stat().st_ctime
|
||||
return time.strftime("%Y%m%d", time.localtime(max((modified_time, created_time))))
|
||||
|
||||
|
||||
def remove_date_prefix(filepath: Path) -> Path:
|
||||
"""移除文件日期前缀."""
|
||||
stem = filepath.stem
|
||||
new_stem = DATE_PATTERN.sub("", stem)
|
||||
if new_stem != stem:
|
||||
new_path = filepath.with_name(new_stem + filepath.suffix)
|
||||
filepath.rename(new_path)
|
||||
return new_path
|
||||
return filepath
|
||||
|
||||
|
||||
def add_date_prefix(filepath: Path) -> Path:
|
||||
"""添加文件日期前缀."""
|
||||
timestamp = get_file_timestamp(filepath)
|
||||
stem = filepath.stem
|
||||
new_stem = f"{timestamp}{SEP}{stem}"
|
||||
new_path = filepath.with_name(new_stem + filepath.suffix)
|
||||
if new_path != filepath:
|
||||
filepath.rename(new_path)
|
||||
return new_path
|
||||
return filepath
|
||||
|
||||
|
||||
def process_file_date(filepath: Path, clear: bool = False) -> None:
|
||||
"""处理单个文件的日期前缀.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filepath : Path
|
||||
文件路径
|
||||
clear : bool
|
||||
是否清除日期前缀
|
||||
"""
|
||||
if clear:
|
||||
remove_date_prefix(filepath)
|
||||
else:
|
||||
# 先移除旧日期前缀,再添加新日期前缀
|
||||
new_path = remove_date_prefix(filepath)
|
||||
add_date_prefix(new_path)
|
||||
|
||||
|
||||
def process_files_date(targets: list[Path], clear: bool = False) -> None:
|
||||
"""批量处理文件日期前缀.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : list[Path]
|
||||
文件路径列表
|
||||
clear : bool
|
||||
是否清除日期前缀
|
||||
"""
|
||||
for target in targets:
|
||||
if target.exists() and not target.name.startswith("."):
|
||||
process_file_date(target, clear)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""文件日期处理工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="FileDate - 文件日期处理工具",
|
||||
usage="filedate <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 添加日期前缀命令
|
||||
add_parser = subparsers.add_parser("add", help="添加日期前缀")
|
||||
add_parser.add_argument("files", nargs="+", help="文件路径")
|
||||
|
||||
# 清除日期前缀命令
|
||||
clear_parser = subparsers.add_parser("clear", help="清除日期前缀")
|
||||
clear_parser.add_argument("files", nargs="+", help="文件路径")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "add":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"process_files_date",
|
||||
fn=process_files_date,
|
||||
args=([Path(f) for f in args.files],),
|
||||
kwargs={"clear": False},
|
||||
)
|
||||
])
|
||||
elif args.command == "clear":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"process_files_date",
|
||||
fn=process_files_date,
|
||||
args=([Path(f) for f in args.files],),
|
||||
kwargs={"clear": True},
|
||||
)
|
||||
])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,140 @@
|
||||
"""文件等级重命名工具.
|
||||
|
||||
根据文件等级配置自动重命名文件,
|
||||
支持多种等级标识和括号格式.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
LEVELS: dict[str, str] = {
|
||||
"0": "",
|
||||
"1": "PUB,NOR",
|
||||
"2": "INT",
|
||||
"3": "CON",
|
||||
"4": "CLA",
|
||||
}
|
||||
|
||||
BRACKETS: tuple[str, str] = (" ([_(【-", " )]_)】")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def remove_marks(stem: str, marks: list[str]) -> str:
|
||||
"""从文件名主干中移除所有标记."""
|
||||
left_brackets, right_brackets = BRACKETS
|
||||
for mark in marks:
|
||||
pos = 0
|
||||
while True:
|
||||
pos = stem.find(mark, pos)
|
||||
if pos == -1:
|
||||
break
|
||||
b, e = pos - 1, pos + len(mark)
|
||||
if b >= 0 and e < len(stem) and stem[b] in left_brackets and stem[e] in right_brackets:
|
||||
stem = stem[:b] + stem[e + 1 :]
|
||||
else:
|
||||
pos = e
|
||||
return stem
|
||||
|
||||
|
||||
def process_file_level(filepath: Path, level: int = 0) -> None:
|
||||
"""处理单个文件的等级标记.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filepath : Path
|
||||
文件路径
|
||||
level : int
|
||||
文件等级 (0-4), 0 用于清除等级
|
||||
"""
|
||||
if not (0 <= level < len(LEVELS)):
|
||||
print(f"无效的等级 {level}, 必须在 0 和 {len(LEVELS) - 1} 之间")
|
||||
return
|
||||
|
||||
if not filepath.exists():
|
||||
print(f"文件不存在: {filepath}")
|
||||
return
|
||||
|
||||
filestem = filepath.stem
|
||||
original_stem = filestem
|
||||
|
||||
# 移除所有等级标记
|
||||
for level_names in LEVELS.values():
|
||||
if level_names:
|
||||
filestem = remove_marks(filestem, level_names.split(","))
|
||||
|
||||
# 移除数字标记
|
||||
for digit in map(str, range(1, 10)):
|
||||
filestem = remove_marks(filestem, [digit])
|
||||
|
||||
# 添加等级标记
|
||||
if level > 0:
|
||||
levelstr = LEVELS.get(str(level), "").split(",")[0]
|
||||
if levelstr:
|
||||
filestem = f"{filestem}({levelstr})"
|
||||
|
||||
# 重命名文件
|
||||
if filestem != original_stem:
|
||||
new_path = filepath.with_name(filestem + filepath.suffix)
|
||||
filepath.rename(new_path)
|
||||
print(f"重命名: {filepath} -> {new_path}")
|
||||
|
||||
|
||||
def process_files_level(targets: list[Path], level: int = 0) -> None:
|
||||
"""批量处理文件等级标记.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : list[Path]
|
||||
文件路径列表
|
||||
level : int
|
||||
文件等级 (0-4)
|
||||
"""
|
||||
for target in targets:
|
||||
process_file_level(target, level)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""文件等级重命名工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="FileLevel - 文件等级重命名工具",
|
||||
usage="filelevel <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 设置等级命令
|
||||
level_parser = subparsers.add_parser("set", help="设置文件等级")
|
||||
level_parser.add_argument("files", nargs="+", help="文件路径")
|
||||
level_parser.add_argument("--level", type=int, choices=[0, 1, 2, 3, 4], required=True, help="文件等级 (0-4)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "set":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"process_files_level", fn=process_files_level, args=([Path(f) for f in args.files], args.level)
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,85 @@
|
||||
"""文件夹备份工具.
|
||||
|
||||
备份文件和文件夹为 zip 文件,
|
||||
自动删除超过最大数量的旧备份文件.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def remove_dump(src: Path, dst: Path, max_zip: int) -> None:
|
||||
"""递归删除旧的备份 zip 文件."""
|
||||
zip_paths = [filepath for filepath in dst.rglob("*.zip") if src.stem in str(filepath)]
|
||||
zip_files = sorted(zip_paths, key=lambda fn: str(fn)[-19:-4])
|
||||
if len(zip_files) > max_zip:
|
||||
zip_files[0].unlink()
|
||||
remove_dump(src, dst, max_zip)
|
||||
|
||||
|
||||
def zip_target(src: Path, dst: Path, max_zip: int) -> None:
|
||||
"""将单个文件或文件夹压缩为 zip 文件."""
|
||||
files = [str(_) for _ in src.rglob("*")]
|
||||
timestamp = time.strftime("_%Y%m%d_%H%M%S")
|
||||
target_path = dst / (src.stem + timestamp + ".zip")
|
||||
|
||||
with zipfile.ZipFile(target_path, "w") as zip_file:
|
||||
for file in files:
|
||||
zip_file.write(file, arcname=file.replace(str(src.parent), ""))
|
||||
|
||||
remove_dump(src, dst, max_zip)
|
||||
print(f"备份完成: {target_path}")
|
||||
|
||||
|
||||
def backup_folder(src: str, dst: str, max_zip: int = 5) -> None:
|
||||
"""备份文件夹.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
src : str
|
||||
源文件夹路径
|
||||
dst : str
|
||||
目标文件夹路径
|
||||
max_zip : int
|
||||
最大备份数量
|
||||
"""
|
||||
src_path = Path(src)
|
||||
dst_path = Path(dst)
|
||||
|
||||
if not src_path.exists():
|
||||
print(f"源文件夹不存在: {src_path}")
|
||||
return
|
||||
|
||||
if not dst_path.exists():
|
||||
dst_path.mkdir(parents=True, exist_ok=True)
|
||||
print(f"创建目标文件夹: {dst_path}")
|
||||
|
||||
zip_target(src_path, dst_path, max_zip)
|
||||
|
||||
|
||||
@px.task
|
||||
def folderback_default() -> None:
|
||||
"""备份当前目录到 ./backup."""
|
||||
backup_folder(".", "./backup", 5)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""文件夹备份工具主函数."""
|
||||
runner = px.CliRunner(
|
||||
strategy="thread",
|
||||
description="FolderBack - 文件夹备份工具",
|
||||
aliases={
|
||||
# 备份当前目录到 ./backup
|
||||
"b": folderback_default,
|
||||
},
|
||||
)
|
||||
runner.run_cli()
|
||||
@@ -0,0 +1,76 @@
|
||||
"""文件夹压缩工具.
|
||||
|
||||
压缩目录下的所有文件/文件夹为 zip 文件,
|
||||
默认压缩当前目录下的所有子文件夹.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
IGNORE_DIRS: list[str] = [".git", ".idea", ".vscode", "__pycache__"]
|
||||
IGNORE_FILES: list[str] = [".gitignore"]
|
||||
IGNORE: list[str] = [*IGNORE_DIRS, *IGNORE_FILES]
|
||||
IGNORE_EXT: list[str] = [".zip", ".rar", ".7z", ".tar", ".gz"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def archive_folder(folder: Path) -> None:
|
||||
"""压缩单个文件夹."""
|
||||
shutil.make_archive(
|
||||
str(folder.with_name(folder.name)),
|
||||
format="zip",
|
||||
base_dir=folder,
|
||||
)
|
||||
print(f"压缩完成: {folder.name}.zip")
|
||||
|
||||
|
||||
def zip_folders(cwd: str = ".") -> None:
|
||||
"""压缩目录下的所有文件夹.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
cwd : str
|
||||
工作目录
|
||||
"""
|
||||
cwd_path = Path(cwd)
|
||||
if not cwd_path.exists():
|
||||
print(f"目录不存在: {cwd_path}")
|
||||
return
|
||||
|
||||
dirs: list[Path] = [
|
||||
e for e in cwd_path.iterdir() if e.is_dir() and e.name not in IGNORE_DIRS and e.suffix not in IGNORE_EXT
|
||||
]
|
||||
|
||||
for dir_path in dirs:
|
||||
archive_folder(dir_path)
|
||||
|
||||
|
||||
@px.task
|
||||
def folderzip_default() -> None:
|
||||
"""压缩当前目录下的所有文件夹."""
|
||||
zip_folders(".")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""文件夹压缩工具主函数."""
|
||||
runner = px.CliRunner(
|
||||
strategy="thread",
|
||||
description="FolderZip - 文件夹压缩工具",
|
||||
aliases={
|
||||
# 压缩当前目录下的所有文件夹
|
||||
"z": folderzip_default,
|
||||
},
|
||||
)
|
||||
runner.run_cli()
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Git 工具模块.
|
||||
|
||||
提供 Git 仓库管理的常用操作封装,
|
||||
支持初始化、提交、清理、推送等功能.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
EXCLUDE_DIRS = [
|
||||
# 编辑器相关目录
|
||||
".vscode",
|
||||
".idea",
|
||||
".editorconfig",
|
||||
".trae",
|
||||
".qoder",
|
||||
# 项目相关目录
|
||||
".venv",
|
||||
".git",
|
||||
".tox",
|
||||
".pytest_cache",
|
||||
"node_modules",
|
||||
".ruff_cache",
|
||||
]
|
||||
EXCLUDE_CMDS = [arg for d in EXCLUDE_DIRS for arg in ["-e", d]]
|
||||
|
||||
|
||||
def init_sub_dirs() -> None:
|
||||
"""初始化子目录的Git仓库."""
|
||||
sub_dirs = [subdir for subdir in Path.cwd().iterdir() if subdir.is_dir()]
|
||||
for subdir in sub_dirs:
|
||||
px.run(
|
||||
px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"init",
|
||||
cmd=["git", "init"],
|
||||
conditions=(lambda _: not_has_git_repo(),),
|
||||
cwd=subdir,
|
||||
),
|
||||
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",)),
|
||||
px.TaskSpec("commit", cmd=["git", "commit", "-m", "init commit"], depends_on=("add",)),
|
||||
]),
|
||||
)
|
||||
|
||||
|
||||
@px.task(name="isub")
|
||||
def isub() -> None:
|
||||
"""初始化子目录的Git仓库."""
|
||||
init_sub_dirs()
|
||||
|
||||
|
||||
push: px.TaskSpec = px.TaskSpec("push", cmd=["git", "push"])
|
||||
pull: px.TaskSpec = px.TaskSpec("pull", cmd=["git", "pull"])
|
||||
kill_tgit: px.TaskSpec = px.TaskSpec("task_kill", cmd=["taskkill", "/f", "/t", "/im", "tgitcache.exe"])
|
||||
|
||||
|
||||
def not_has_git_repo() -> bool:
|
||||
"""检查当前目录没有Git仓库."""
|
||||
return not Path.cwd().exists() or not (Path.cwd() / ".git").is_dir()
|
||||
|
||||
|
||||
def has_files() -> bool:
|
||||
"""检查当前目录是否有文件."""
|
||||
return bool(list(Path.cwd().glob("*")))
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Git工具主函数."""
|
||||
runner = px.CliRunner(
|
||||
strategy="thread",
|
||||
description="Gittool - Git 执行工具.",
|
||||
aliases={
|
||||
# 添加并提交
|
||||
"a": px.Graph.from_specs([
|
||||
px.TaskSpec("add", cmd=["git", "add", "."], conditions=(lambda _: has_files(),)),
|
||||
px.TaskSpec("commit", cmd=["git", "commit", "-m", "chore: update"], depends_on=("add",)),
|
||||
]),
|
||||
# 清理(chain: clean → status)
|
||||
"c": px.Graph().chain(
|
||||
px.TaskSpec("clean", cmd=["git", "clean", "-xfd", *EXCLUDE_CMDS]),
|
||||
px.TaskSpec("status", cmd=["git", "status", "--porcelain"]),
|
||||
),
|
||||
# 初始化、添加并提交
|
||||
"i": px.Graph.from_specs([
|
||||
px.TaskSpec("init", cmd=["git", "init"], conditions=(lambda _: not_has_git_repo(),)),
|
||||
px.TaskSpec("add", cmd=["git", "add", "."], depends_on=("init",), conditions=(lambda _: has_files(),)),
|
||||
px.TaskSpec(
|
||||
"commit",
|
||||
cmd=["git", "commit", "-m", "init commit"],
|
||||
depends_on=("add",),
|
||||
conditions=(lambda _: has_files(),),
|
||||
),
|
||||
]),
|
||||
# 初始化子目录
|
||||
"isub": isub,
|
||||
# 推送
|
||||
"p": push,
|
||||
# 拉取
|
||||
"pl": pull,
|
||||
# 重启TGit缓存
|
||||
"r": kill_tgit,
|
||||
},
|
||||
)
|
||||
runner.run_cli()
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Download from ModelScopeHub."""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
DownloadType = Literal["model", "dataset", "space"]
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Download a model from ModelScopeHub.")
|
||||
parser.add_argument("name", help="Target name.")
|
||||
parser.add_argument("--type", "-t", nargs="?", default="model", choices=get_args(DownloadType), help="Target type.")
|
||||
parser.add_argument("--dir", default=None, help="Download directory.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.name:
|
||||
parser.error("name is required")
|
||||
|
||||
download_dir: Path = Path(args.dir) if args.dir else Path.home() / ".models" / args.name.split("/")[-1]
|
||||
download_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
name="download",
|
||||
cmd=[
|
||||
"uvx",
|
||||
"modelscope",
|
||||
"download",
|
||||
f"--{args.type}",
|
||||
args.name,
|
||||
"--local_dir",
|
||||
str(download_dir),
|
||||
],
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
|
||||
px.run(graph, strategy="thread", verbose=True)
|
||||
@@ -0,0 +1,63 @@
|
||||
"""使用 SGLang 运行本地模型."""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import BuiltinConditions, Constants
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="启动 SGLang 服务")
|
||||
parser.add_argument("--model", default="~/.models/Qwen2.5-Coder-32B-Instruct-AWQ", help="模型路径")
|
||||
parser.add_argument("--port", type=int, default=8000, help="服务端口")
|
||||
parser.add_argument("--ctx-len", type=int, default=28672, help="最大上下文长度")
|
||||
parser.add_argument("--mem", type=float, default=0.75, help="显存占比 (0-1)")
|
||||
parser.add_argument("--host", default="0.0.0.0", help="主机地址")
|
||||
parser.add_argument("--log-level", default="info", help="日志级别")
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.model:
|
||||
parser.error("model is required")
|
||||
|
||||
model_dir = Path(args.model).expanduser()
|
||||
if not model_dir.exists():
|
||||
parser.error(f"Model directory {model_dir} does not exist.")
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
name="download",
|
||||
cmd=[
|
||||
"uv",
|
||||
"install",
|
||||
"sglang[all]",
|
||||
],
|
||||
conditions=(BuiltinConditions.NOT(BuiltinConditions.HAS_INSTALLED("sglang")),),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
name="run",
|
||||
cmd=[
|
||||
"python" if Constants.IS_WINDOWS else "python3",
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
str(model_dir),
|
||||
"--host",
|
||||
str(args.host),
|
||||
"--port",
|
||||
"8000",
|
||||
"--mem-fraction-static",
|
||||
str(args.mem),
|
||||
"--context-length",
|
||||
"32768",
|
||||
"--tool-call-parser",
|
||||
"qwen",
|
||||
"--log-level",
|
||||
str(args.log_level),
|
||||
],
|
||||
verbose=True,
|
||||
),
|
||||
])
|
||||
|
||||
px.run(graph, strategy="sequential", verbose=True)
|
||||
@@ -0,0 +1,174 @@
|
||||
"""LS-DYNA 计算工具.
|
||||
|
||||
用于管理 LS-DYNA 仿真计算任务,
|
||||
支持启动、监控和管理计算进程.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
LS_DYNA_COMMANDS: dict[str, list[str]] = {
|
||||
"windows": ["ls-dyna_mpp", "i=input.k", "ncpu=4"],
|
||||
"linux": ["ls-dyna_mpp", "i=input.k", "ncpu=8"],
|
||||
"macos": ["ls-dyna_mpp", "i=input.k", "ncpu=4"],
|
||||
}
|
||||
|
||||
DEFAULT_INPUT_FILE: str = "input.k"
|
||||
DEFAULT_NCPU: int = 4
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_ls_dyna_command(input_file: str, ncpu: int) -> list[str]:
|
||||
"""获取 LS-DYNA 命令.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
输入文件路径
|
||||
ncpu : int
|
||||
CPU 核心数
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
LS-DYNA 命令列表
|
||||
"""
|
||||
if Constants.IS_WINDOWS or Constants.IS_MACOS:
|
||||
return ["ls-dyna_mpp", f"i={input_file}", f"ncpu={ncpu}"]
|
||||
else:
|
||||
return ["ls-dyna_mpp", f"i={input_file}", f"ncpu={ncpu}"]
|
||||
|
||||
|
||||
def run_ls_dyna(input_file: str, ncpu: int = DEFAULT_NCPU) -> None:
|
||||
"""运行 LS-DYNA 计算.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
输入文件路径
|
||||
ncpu : int
|
||||
CPU 核心数
|
||||
"""
|
||||
input_path = Path(input_file)
|
||||
if not input_path.exists():
|
||||
print(f"输入文件不存在: {input_path}")
|
||||
return
|
||||
|
||||
cmd = get_ls_dyna_command(input_file, ncpu)
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"LS-DYNA 计算完成: {input_file}")
|
||||
except FileNotFoundError:
|
||||
print("未找到 ls-dyna_mpp 命令")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"LS-DYNA 计算失败: {e}")
|
||||
|
||||
|
||||
def run_ls_dyna_mpi(input_file: str, ncpu: int = DEFAULT_NCPU) -> None:
|
||||
"""运行 LS-DYNA MPI 计算.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_file : str
|
||||
输入文件路径
|
||||
ncpu : int
|
||||
CPU 核心数
|
||||
"""
|
||||
input_path = Path(input_file)
|
||||
if not input_path.exists():
|
||||
print(f"输入文件不存在: {input_path}")
|
||||
return
|
||||
|
||||
cmd = ["mpirun", "-np", str(ncpu), "ls-dyna_mpp", f"i={input_file}"]
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"LS-DYNA MPI 计算完成: {input_file}")
|
||||
except FileNotFoundError:
|
||||
print("未找到 mpirun 或 ls-dyna_mpp 命令")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"LS-DYNA MPI 计算失败: {e}")
|
||||
|
||||
|
||||
def check_ls_dyna_status() -> None:
|
||||
"""检查 LS-DYNA 进程状态."""
|
||||
try:
|
||||
if Constants.IS_WINDOWS:
|
||||
result = subprocess.run(
|
||||
["tasklist", "/fi", "imagename eq ls-dyna_mpp.exe"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
print(result.stdout)
|
||||
else:
|
||||
result = subprocess.run(
|
||||
["pgrep", "-f", "ls-dyna"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.stdout.strip():
|
||||
print(f"运行中的 LS-DYNA 进程 PID: {result.stdout.strip()}")
|
||||
else:
|
||||
print("没有运行中的 LS-DYNA 进程")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"检查进程状态失败: {e}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""LS-DYNA 计算工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="LSCalc - LS-DYNA 计算工具",
|
||||
usage="lscalc <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 运行计算命令
|
||||
run_parser = subparsers.add_parser("run", help="运行 LS-DYNA 计算")
|
||||
run_parser.add_argument("input_file", help="输入文件路径")
|
||||
run_parser.add_argument("--ncpu", type=int, default=DEFAULT_NCPU, help="CPU 核心数")
|
||||
|
||||
# 运行 MPI 计算命令
|
||||
mpi_parser = subparsers.add_parser("mpi", help="运行 LS-DYNA MPI 计算")
|
||||
mpi_parser.add_argument("input_file", help="输入文件路径")
|
||||
mpi_parser.add_argument("--ncpu", type=int, default=DEFAULT_NCPU, help="CPU 核心数")
|
||||
|
||||
# 检查进程状态命令
|
||||
subparsers.add_parser("status", help="检查 LS-DYNA 进程状态")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "run":
|
||||
graph = px.Graph.from_specs(
|
||||
[px.TaskSpec("run_ls_dyna", fn=run_ls_dyna, args=(args.input_file,), kwargs={"ncpu": args.ncpu})]
|
||||
)
|
||||
elif args.command == "mpi":
|
||||
graph = px.Graph.from_specs(
|
||||
[px.TaskSpec("run_ls_dyna_mpi", fn=run_ls_dyna_mpi, args=(args.input_file,), kwargs={"ncpu": args.ncpu})]
|
||||
)
|
||||
elif args.command == "status":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("check_ls_dyna_status", fn=check_ls_dyna_status)])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,349 @@
|
||||
"""Python 打包工具模块.
|
||||
|
||||
提供 Python 项目打包的常用功能封装,
|
||||
支持源码打包、依赖打包、嵌入式 Python 安装等功能.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import subprocess
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
DEFAULT_BUILD_DIR = ".pypack"
|
||||
DEFAULT_DIST_DIR = "dist"
|
||||
DEFAULT_LIB_DIR = "libs"
|
||||
DEFAULT_CACHE_DIR = ".cache/pypack"
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
"__pycache__",
|
||||
"*.pyc",
|
||||
"*.pyo",
|
||||
".git",
|
||||
".venv",
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.egg-info",
|
||||
"dist",
|
||||
"build",
|
||||
".pytest_cache",
|
||||
".tox",
|
||||
".mypy_cache",
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def pack_source(project_dir: Path, output_dir: Path) -> None:
|
||||
"""打包项目源码.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
project_dir : Path
|
||||
项目目录
|
||||
output_dir : Path
|
||||
输出目录
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 检测项目名称
|
||||
pyproject_file = project_dir / "pyproject.toml"
|
||||
project_name = project_dir.name
|
||||
|
||||
if pyproject_file.exists():
|
||||
try:
|
||||
import tomllib
|
||||
|
||||
content = pyproject_file.read_text(encoding="utf-8")
|
||||
data = tomllib.loads(content)
|
||||
project_name = data.get("project", {}).get("name", project_name)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 打包源码
|
||||
source_dir = output_dir / "src" / project_name
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 复制文件
|
||||
src_subdir = project_dir / "src"
|
||||
if src_subdir.exists():
|
||||
shutil.copytree(
|
||||
src_subdir,
|
||||
source_dir / "src",
|
||||
ignore=shutil.ignore_patterns(*IGNORE_PATTERNS),
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
else:
|
||||
for item in project_dir.iterdir():
|
||||
if item.name in IGNORE_PATTERNS or item.name.startswith("."):
|
||||
continue
|
||||
dst_item = source_dir / item.name
|
||||
if item.is_dir():
|
||||
shutil.copytree(
|
||||
item,
|
||||
dst_item,
|
||||
ignore=shutil.ignore_patterns(*IGNORE_PATTERNS),
|
||||
dirs_exist_ok=True,
|
||||
)
|
||||
else:
|
||||
shutil.copy2(item, dst_item)
|
||||
|
||||
print(f"源码打包完成: {source_dir}")
|
||||
|
||||
|
||||
def pack_dependencies(lib_dir: Path, dependencies: list[str]) -> None:
|
||||
"""打包项目依赖.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
lib_dir : Path
|
||||
依赖库目录
|
||||
dependencies : list[str]
|
||||
依赖列表
|
||||
"""
|
||||
lib_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not dependencies:
|
||||
print("没有依赖需要打包")
|
||||
return
|
||||
|
||||
# 使用 pip 安装依赖到目标目录
|
||||
cmd = [
|
||||
"pip",
|
||||
"install",
|
||||
"--target",
|
||||
str(lib_dir),
|
||||
"--no-compile",
|
||||
"--no-warn-script-location",
|
||||
]
|
||||
cmd.extend(dependencies)
|
||||
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"依赖打包完成: {lib_dir}")
|
||||
|
||||
|
||||
def pack_wheel(project_dir: Path, output_dir: Path) -> None:
|
||||
"""打包项目为 wheel 文件.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
project_dir : Path
|
||||
项目目录
|
||||
output_dir : Path
|
||||
输出目录
|
||||
"""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 使用 pip wheel 打包
|
||||
cmd = [
|
||||
"pip",
|
||||
"wheel",
|
||||
"--no-deps",
|
||||
"--wheel-dir",
|
||||
str(output_dir),
|
||||
str(project_dir),
|
||||
]
|
||||
|
||||
subprocess.run(cmd, check=True)
|
||||
print(f"Wheel 打包完成: {output_dir}")
|
||||
|
||||
|
||||
def install_embed_python(version: str, output_dir: Path) -> None:
|
||||
"""安装嵌入式 Python.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
version : str
|
||||
Python 版本 (如: 3.10, 3.11)
|
||||
output_dir : Path
|
||||
输出目录
|
||||
"""
|
||||
import platform
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 构建下载 URL
|
||||
arch = platform.machine().lower()
|
||||
if arch in ["x86_64", "amd64"]:
|
||||
arch = "amd64"
|
||||
elif arch in ["arm64", "aarch64"]:
|
||||
arch = "arm64"
|
||||
|
||||
# 解析完整版本号
|
||||
version_map = {
|
||||
"3.8": "3.8.10",
|
||||
"3.9": "3.9.13",
|
||||
"3.10": "3.10.11",
|
||||
"3.11": "3.11.9",
|
||||
"3.12": "3.12.4",
|
||||
}
|
||||
full_version = version_map.get(version, f"{version}.0")
|
||||
|
||||
# Windows 嵌入式 Python 下载 URL
|
||||
url = f"https://www.python.org/ftp/python/{full_version}/python-{full_version}-embed-{arch}.zip"
|
||||
|
||||
# 下载并解压
|
||||
cache_file = Path(DEFAULT_CACHE_DIR) / f"python-{full_version}-embed-{arch}.zip"
|
||||
cache_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if not cache_file.exists():
|
||||
print(f"正在下载嵌入式 Python {full_version}...")
|
||||
import urllib.request
|
||||
|
||||
urllib.request.urlretrieve(url, cache_file)
|
||||
print(f"下载完成: {cache_file}")
|
||||
|
||||
# 解压
|
||||
with zipfile.ZipFile(cache_file, "r") as zf:
|
||||
zf.extractall(output_dir)
|
||||
|
||||
print(f"嵌入式 Python 安装完成: {output_dir}")
|
||||
|
||||
|
||||
def create_zip_package(source_dir: Path, output_file: Path) -> None:
|
||||
"""创建 ZIP 打包文件.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir : Path
|
||||
源目录
|
||||
output_file : Path
|
||||
输出文件
|
||||
"""
|
||||
output_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zipfile.ZipFile(output_file, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
for file in source_dir.rglob("*"):
|
||||
if file.is_file():
|
||||
arcname = file.relative_to(source_dir)
|
||||
zf.write(file, arcname)
|
||||
|
||||
print(f"ZIP 打包完成: {output_file}")
|
||||
|
||||
|
||||
def clean_build_dir(build_dir: Path) -> None:
|
||||
"""清理构建目录.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
build_dir : Path
|
||||
构建目录
|
||||
"""
|
||||
if build_dir.exists():
|
||||
shutil.rmtree(build_dir)
|
||||
print(f"清理完成: {build_dir}")
|
||||
else:
|
||||
print(f"目录不存在: {build_dir}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Python 打包工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PackTool - Python 打包工具",
|
||||
usage="packtool <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 源码打包命令
|
||||
src_parser = subparsers.add_parser("src", help="打包项目源码")
|
||||
src_parser.add_argument("--project-dir", type=str, default=".", help="项目目录")
|
||||
src_parser.add_argument("--output-dir", type=str, default=DEFAULT_BUILD_DIR, help="输出目录")
|
||||
|
||||
# 依赖打包命令
|
||||
deps_parser = subparsers.add_parser("deps", help="打包项目依赖")
|
||||
deps_parser.add_argument("--lib-dir", type=str, default=DEFAULT_LIB_DIR, help="依赖库目录")
|
||||
deps_parser.add_argument("dependencies", nargs="*", help="依赖列表")
|
||||
|
||||
# Wheel 打包命令
|
||||
wheel_parser = subparsers.add_parser("wheel", help="打包项目为 wheel 文件")
|
||||
wheel_parser.add_argument("--project-dir", type=str, default=".", help="项目目录")
|
||||
wheel_parser.add_argument("--output-dir", type=str, default=DEFAULT_DIST_DIR, help="输出目录")
|
||||
|
||||
# 嵌入式 Python 安装命令
|
||||
embed_parser = subparsers.add_parser("embed", help="安装嵌入式 Python")
|
||||
embed_parser.add_argument("--version", type=str, default="3.10", help="Python 版本")
|
||||
embed_parser.add_argument("--output-dir", type=str, default="python", help="输出目录")
|
||||
|
||||
# ZIP 打包命令
|
||||
zip_parser = subparsers.add_parser("zip", help="创建 ZIP 打包文件")
|
||||
zip_parser.add_argument("--source-dir", type=str, default=".", help="源目录")
|
||||
zip_parser.add_argument("--output-file", type=str, default="package.zip", help="输出文件")
|
||||
|
||||
# 清理命令
|
||||
subparsers.add_parser("clean", help="清理构建目录")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "src":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"pack_source",
|
||||
fn=pack_source,
|
||||
args=(Path(args.project_dir), Path(args.output_dir)),
|
||||
)
|
||||
]
|
||||
)
|
||||
elif args.command == "deps":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"pack_deps",
|
||||
fn=pack_dependencies,
|
||||
args=(Path(args.lib_dir), args.dependencies),
|
||||
)
|
||||
]
|
||||
)
|
||||
elif args.command == "wheel":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"pack_wheel",
|
||||
fn=pack_wheel,
|
||||
args=(Path(args.project_dir), Path(args.output_dir)),
|
||||
)
|
||||
]
|
||||
)
|
||||
elif args.command == "embed":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"install_embed",
|
||||
fn=install_embed_python,
|
||||
args=(args.version, Path(args.output_dir)),
|
||||
)
|
||||
]
|
||||
)
|
||||
elif args.command == "zip":
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"create_zip",
|
||||
fn=create_zip_package,
|
||||
args=(Path(args.source_dir), Path(args.output_file)),
|
||||
)
|
||||
]
|
||||
)
|
||||
elif args.command == "clean":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("clean_build", fn=clean_build_dir, args=(Path(DEFAULT_BUILD_DIR),))])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,523 @@
|
||||
"""PDF 工具模块.
|
||||
|
||||
提供 PDF 文件操作的常用功能封装,
|
||||
支持合并、拆分、压缩、加密、水印、OCR等功能.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
|
||||
HAS_PYMUPDF = True
|
||||
except ImportError:
|
||||
HAS_PYMUPDF = False
|
||||
|
||||
try:
|
||||
import pypdf
|
||||
|
||||
HAS_PYPDF = True
|
||||
except ImportError:
|
||||
HAS_PYPDF = False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
PDF_SUFFIX = ".pdf"
|
||||
DEFAULT_QUALITY = 75
|
||||
DEFAULT_PASSWORD = ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def pdf_merge(input_paths: list[Path], output_path: Path) -> None:
|
||||
"""合并多个 PDF 文件."""
|
||||
if not HAS_PYPDF:
|
||||
print("未安装 pypdf 库,请安装: pip install pypdf")
|
||||
return
|
||||
|
||||
writer = pypdf.PdfWriter()
|
||||
for input_path in input_paths:
|
||||
if input_path.exists():
|
||||
reader = pypdf.PdfReader(str(input_path))
|
||||
for page in reader.pages:
|
||||
writer.add_page(page)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "wb") as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"合并完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_split(input_path: Path, output_dir: Path) -> None:
|
||||
"""拆分 PDF 文件为单页."""
|
||||
if not HAS_PYPDF:
|
||||
print("未安装 pypdf 库,请安装: pip install pypdf")
|
||||
return
|
||||
|
||||
reader = pypdf.PdfReader(str(input_path))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, page in enumerate(reader.pages):
|
||||
writer = pypdf.PdfWriter()
|
||||
writer.add_page(page)
|
||||
output_file = output_dir / f"{input_path.stem}_page_{i + 1}.pdf"
|
||||
with open(output_file, "wb") as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"拆分完成: {output_dir}")
|
||||
|
||||
|
||||
def pdf_compress(input_path: Path, output_path: Path) -> None:
|
||||
"""压缩 PDF 文件."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.save(str(output_path), garbage=4, deflate=True, clean=True)
|
||||
doc.close()
|
||||
|
||||
original_size = input_path.stat().st_size
|
||||
new_size = output_path.stat().st_size
|
||||
ratio = (1 - new_size / original_size) * 100
|
||||
print(f"压缩完成: {output_path} (缩小 {ratio:.1f}%)")
|
||||
|
||||
|
||||
def pdf_encrypt(input_path: Path, output_path: Path, password: str) -> None:
|
||||
"""加密 PDF 文件."""
|
||||
if not HAS_PYPDF:
|
||||
print("未安装 pypdf 库,请安装: pip install pypdf")
|
||||
return
|
||||
|
||||
reader = pypdf.PdfReader(str(input_path))
|
||||
writer = pypdf.PdfWriter()
|
||||
|
||||
for page in reader.pages:
|
||||
writer.add_page(page)
|
||||
|
||||
writer.encrypt(user_password=password, owner_password=password, use_128bit=True)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "wb") as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"加密完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_decrypt(input_path: Path, output_path: Path, password: str) -> None:
|
||||
"""解密 PDF 文件."""
|
||||
if not HAS_PYPDF:
|
||||
print("未安装 pypdf 库,请安装: pip install pypdf")
|
||||
return
|
||||
|
||||
reader = pypdf.PdfReader(str(input_path))
|
||||
if reader.is_encrypted:
|
||||
reader.decrypt(password)
|
||||
|
||||
writer = pypdf.PdfWriter()
|
||||
for page in reader.pages:
|
||||
writer.add_page(page)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "wb") as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"解密完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_extract_text(input_path: Path, output_path: Path) -> None:
|
||||
"""提取 PDF 文本."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
text = ""
|
||||
for page in doc:
|
||||
text += str(page.get_text()) + "\n\n"
|
||||
doc.close()
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_path.write_text(text, encoding="utf-8")
|
||||
print(f"文本提取完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_extract_images(input_path: Path, output_dir: Path) -> None:
|
||||
"""提取 PDF 图片."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
image_count = 0
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
for page_num, page in enumerate(doc):
|
||||
images = page.get_images(full=True)
|
||||
for img_idx, img in enumerate(images):
|
||||
xref = img[0]
|
||||
base_image = doc.extract_image(xref)
|
||||
image_data = base_image["image"]
|
||||
image_ext = base_image["ext"]
|
||||
image_path = output_dir / f"page_{page_num + 1}_img_{img_idx + 1}.{image_ext}"
|
||||
image_path.write_bytes(image_data)
|
||||
image_count += 1
|
||||
|
||||
doc.close()
|
||||
print(f"图片提取完成: {output_dir} (共 {image_count} 张)")
|
||||
|
||||
|
||||
def pdf_add_watermark(input_path: Path, output_path: Path, text: str = "CONFIDENTIAL") -> None:
|
||||
"""添加 PDF 水印."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
for page in doc:
|
||||
rect = page.rect
|
||||
text_width = fitz.get_text_length(text, fontsize=48)
|
||||
x = (rect.width - text_width) / 2
|
||||
y = rect.height / 2
|
||||
page.insert_text((x, y), text, fontsize=48, rotate=45, color=(0, 0, 0))
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.save(str(output_path))
|
||||
doc.close()
|
||||
print(f"水印添加完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_rotate(input_path: Path, output_path: Path, rotation: int = 90) -> None:
|
||||
"""旋转 PDF 页面."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
for page in doc:
|
||||
page.set_rotation(rotation)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.save(str(output_path))
|
||||
doc.close()
|
||||
print(f"旋转完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_crop(input_path: Path, output_path: Path, margins: tuple[int, int, int, int]) -> None:
|
||||
"""裁剪 PDF 页面."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
left, top, right, bottom = margins
|
||||
|
||||
for page in doc:
|
||||
rect = page.rect
|
||||
new_rect = fitz.Rect(
|
||||
rect.x0 + left,
|
||||
rect.y0 + top,
|
||||
rect.x1 - right,
|
||||
rect.y1 - bottom,
|
||||
)
|
||||
page.set_cropbox(new_rect)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.save(str(output_path))
|
||||
doc.close()
|
||||
print(f"裁剪完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_info(input_path: Path) -> None:
|
||||
"""显示 PDF 信息."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
print(f"文件: {input_path}")
|
||||
print(f"页数: {doc.page_count}")
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
print(f"标题: {doc.metadata.get('title', 'N/A')}")
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
print(f"作者: {doc.metadata.get('author', 'N/A')}")
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
print(f"创建日期: {doc.metadata.get('creationDate', 'N/A')}")
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
print(f"修改日期: {doc.metadata.get('modDate', 'N/A')}")
|
||||
print(f"文件大小: {input_path.stat().st_size / 1024:.1f} KB")
|
||||
doc.close()
|
||||
|
||||
|
||||
def pdf_ocr(input_path: Path, output_path: Path, lang: str = "chi_sim+eng") -> None:
|
||||
"""PDF OCR 识别."""
|
||||
try:
|
||||
import pytesseract
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
print("未安装 OCR 相关库,请安装: pip install pytesseract pillow")
|
||||
return
|
||||
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
new_doc = fitz.open()
|
||||
|
||||
for page in doc:
|
||||
pix = page.get_pixmap()
|
||||
img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
|
||||
ocr_text = pytesseract.image_to_string(img, lang=lang)
|
||||
|
||||
new_page = new_doc.new_page(width=page.rect.width, height=page.rect.height)
|
||||
new_page.insert_image(new_page.rect, pixmap=pix)
|
||||
text_rect = fitz.Rect(0, 0, page.rect.width, page.rect.height)
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
new_page.insert_textbox(text_rect, ocr_text)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
new_doc.save(str(output_path))
|
||||
new_doc.close()
|
||||
doc.close()
|
||||
print(f"OCR 识别完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_reorder(input_path: Path, output_path: Path, order: list[int]) -> None:
|
||||
"""重排 PDF 页面顺序."""
|
||||
if not HAS_PYPDF:
|
||||
print("未安装 pypdf 库,请安装: pip install pypdf")
|
||||
return
|
||||
|
||||
reader = pypdf.PdfReader(str(input_path))
|
||||
writer = pypdf.PdfWriter()
|
||||
|
||||
for page_num in order:
|
||||
if 0 <= page_num < len(reader.pages):
|
||||
writer.add_page(reader.pages[page_num])
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "wb") as f:
|
||||
writer.write(f)
|
||||
|
||||
print(f"重排完成: {output_path}")
|
||||
|
||||
|
||||
def pdf_to_images(input_path: Path, output_dir: Path, dpi: int = 300) -> None:
|
||||
"""PDF 转图片."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
for page_num, page in enumerate(doc):
|
||||
pix = page.get_pixmap(dpi=dpi)
|
||||
image_path = output_dir / f"{input_path.stem}_page_{page_num + 1}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
doc.close()
|
||||
print(f"转换完成: {output_dir}")
|
||||
|
||||
|
||||
def pdf_repair(input_path: Path, output_path: Path) -> None:
|
||||
"""修复 PDF 文件."""
|
||||
if not HAS_PYMUPDF:
|
||||
print("未安装 PyMuPDF 库,请安装: pip install PyMuPDF")
|
||||
return
|
||||
|
||||
doc = fitz.open(str(input_path))
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
doc.save(str(output_path), garbage=4, deflate=True, clean=True)
|
||||
doc.close()
|
||||
print(f"修复完成: {output_path}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None: # noqa: PLR0912
|
||||
"""PDF 工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PDFTool - PDF 文件工具集",
|
||||
usage="pdftool <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 合并 PDF 命令
|
||||
merge_parser = subparsers.add_parser("m", help="合并 PDF 文件")
|
||||
merge_parser.add_argument("inputs", nargs="+", help="输入 PDF 文件路径")
|
||||
merge_parser.add_argument("--output", type=str, default="merged.pdf", help="输出文件路径")
|
||||
|
||||
# 拆分 PDF 命令
|
||||
split_parser = subparsers.add_parser("s", help="拆分 PDF 文件为单页")
|
||||
split_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
split_parser.add_argument("--output-dir", type=str, default="split", help="输出目录")
|
||||
|
||||
# 压缩 PDF 命令
|
||||
compress_parser = subparsers.add_parser("c", help="压缩 PDF 文件")
|
||||
compress_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
compress_parser.add_argument("--output", type=str, default="compressed.pdf", help="输出文件路径")
|
||||
|
||||
# 加密 PDF 命令
|
||||
encrypt_parser = subparsers.add_parser("e", help="加密 PDF 文件")
|
||||
encrypt_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
encrypt_parser.add_argument("--output", type=str, default="encrypted.pdf", help="输出文件路径")
|
||||
encrypt_parser.add_argument("--password", type=str, required=True, help="密码")
|
||||
|
||||
# 解密 PDF 命令
|
||||
decrypt_parser = subparsers.add_parser("d", help="解密 PDF 文件")
|
||||
decrypt_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
decrypt_parser.add_argument("--output", type=str, default="decrypted.pdf", help="输出文件路径")
|
||||
decrypt_parser.add_argument("--password", type=str, required=True, help="密码")
|
||||
|
||||
# 提取文本命令
|
||||
extract_text_parser = subparsers.add_parser("xt", help="提取 PDF 文本")
|
||||
extract_text_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
extract_text_parser.add_argument("--output", type=str, default="output.txt", help="输出文件路径")
|
||||
|
||||
# 提取图片命令
|
||||
extract_images_parser = subparsers.add_parser("xi", help="提取 PDF 图片")
|
||||
extract_images_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
extract_images_parser.add_argument("--output-dir", type=str, default="images", help="输出目录")
|
||||
|
||||
# 添加水印命令
|
||||
watermark_parser = subparsers.add_parser("w", help="添加 PDF 水印")
|
||||
watermark_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
watermark_parser.add_argument("--output", type=str, default="watermarked.pdf", help="输出文件路径")
|
||||
watermark_parser.add_argument("--text", type=str, default="CONFIDENTIAL", help="水印文本")
|
||||
|
||||
# 旋转 PDF 命令
|
||||
rotate_parser = subparsers.add_parser("r", help="旋转 PDF 页面")
|
||||
rotate_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
rotate_parser.add_argument("--output", type=str, default="rotated.pdf", help="输出文件路径")
|
||||
rotate_parser.add_argument("--rotation", type=int, default=90, help="旋转角度 (90, 180, 270)")
|
||||
|
||||
# 裁剪 PDF 命令
|
||||
crop_parser = subparsers.add_parser("crop", help="裁剪 PDF 页面")
|
||||
crop_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
crop_parser.add_argument("--output", type=str, default="cropped.pdf", help="输出文件路径")
|
||||
crop_parser.add_argument("--left", type=int, default=10, help="左边裁剪")
|
||||
crop_parser.add_argument("--top", type=int, default=10, help="顶部裁剪")
|
||||
crop_parser.add_argument("--right", type=int, default=10, help="右边裁剪")
|
||||
crop_parser.add_argument("--bottom", type=int, default=10, help="底部裁剪")
|
||||
|
||||
# 显示信息命令
|
||||
info_parser = subparsers.add_parser("i", help="显示 PDF 信息")
|
||||
info_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
|
||||
# OCR 识别命令
|
||||
ocr_parser = subparsers.add_parser("ocr", help="PDF OCR 识别")
|
||||
ocr_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
ocr_parser.add_argument("--output", type=str, default="ocr.pdf", help="输出文件路径")
|
||||
ocr_parser.add_argument("--lang", type=str, default="chi_sim+eng", help="OCR 语言")
|
||||
|
||||
# 转换图片命令
|
||||
to_images_parser = subparsers.add_parser("img", help="PDF 转图片")
|
||||
to_images_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
to_images_parser.add_argument("--output-dir", type=str, default="images", help="输出目录")
|
||||
to_images_parser.add_argument("--dpi", type=int, default=300, help="图片 DPI")
|
||||
|
||||
# 修复 PDF 命令
|
||||
repair_parser = subparsers.add_parser("repair", help="修复 PDF 文件")
|
||||
repair_parser.add_argument("input", help="输入 PDF 文件路径")
|
||||
repair_parser.add_argument("--output", type=str, default="repaired.pdf", help="输出文件路径")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "m":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_merge", fn=pdf_merge, args=([Path(p) for p in args.inputs], Path(args.output)))
|
||||
])
|
||||
elif args.command == "s":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_split", fn=pdf_split, args=(Path(args.input), Path(args.output_dir)))
|
||||
])
|
||||
elif args.command == "c":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_compress", fn=pdf_compress, args=(Path(args.input), Path(args.output)))
|
||||
])
|
||||
elif args.command == "e":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_encrypt", fn=pdf_encrypt, args=(Path(args.input), Path(args.output), args.password))
|
||||
])
|
||||
elif args.command == "d":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_decrypt", fn=pdf_decrypt, args=(Path(args.input), Path(args.output), args.password))
|
||||
])
|
||||
elif args.command == "xt":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_extract_text", fn=pdf_extract_text, args=(Path(args.input), Path(args.output)))
|
||||
])
|
||||
elif args.command == "xi":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_extract_images", fn=pdf_extract_images, args=(Path(args.input), Path(args.output_dir)))
|
||||
])
|
||||
elif args.command == "w":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pdf_watermark",
|
||||
fn=pdf_add_watermark,
|
||||
args=(Path(args.input), Path(args.output)),
|
||||
kwargs={"text": args.text},
|
||||
)
|
||||
])
|
||||
elif args.command == "r":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pdf_rotate",
|
||||
fn=pdf_rotate,
|
||||
args=(Path(args.input), Path(args.output)),
|
||||
kwargs={"rotation": args.rotation},
|
||||
)
|
||||
])
|
||||
elif args.command == "crop":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pdf_crop",
|
||||
fn=pdf_crop,
|
||||
args=(Path(args.input), Path(args.output)),
|
||||
kwargs={"margins": (args.left, args.top, args.right, args.bottom)},
|
||||
)
|
||||
])
|
||||
elif args.command == "i":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("pdf_info", fn=pdf_info, args=(Path(args.input),))])
|
||||
elif args.command == "ocr":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_ocr", fn=pdf_ocr, args=(Path(args.input), Path(args.output)), kwargs={"lang": args.lang})
|
||||
])
|
||||
elif args.command == "img":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pdf_to_images",
|
||||
fn=pdf_to_images,
|
||||
args=(Path(args.input), Path(args.output_dir)),
|
||||
kwargs={"dpi": args.dpi},
|
||||
)
|
||||
])
|
||||
elif args.command == "repair":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pdf_repair", fn=pdf_repair, args=(Path(args.input), Path(args.output)))
|
||||
])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,195 @@
|
||||
"""pip 包管理工具模块.
|
||||
|
||||
提供 pip 包管理操作的封装,
|
||||
支持安装、卸载、下载等功能.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import fnmatch
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 配置
|
||||
# ============================================================================
|
||||
|
||||
PACKAGE_DIR = "packages"
|
||||
REQUIREMENTS_FILE = "requirements.txt"
|
||||
|
||||
# 受保护的包名集合
|
||||
_PROTECTED_PACKAGES: frozenset[str] = frozenset({
|
||||
"pyflowx",
|
||||
"bitool",
|
||||
})
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def _get_installed_packages() -> list[str]:
|
||||
"""获取当前环境中所有已安装的包名."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["pip", "list", "--format=freeze"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
packages: list[str] = []
|
||||
for line in result.stdout.strip().split("\n"):
|
||||
if line and "==" in line:
|
||||
pkg_name = line.split("==")[0].strip()
|
||||
packages.append(pkg_name)
|
||||
except (subprocess.SubprocessError, OSError):
|
||||
return []
|
||||
return packages
|
||||
|
||||
|
||||
def _expand_wildcard_packages(pattern: str) -> list[str]:
|
||||
"""展开通配符模式为实际的包名列表."""
|
||||
if not any(char in pattern for char in ["*", "?", "[", "]"]):
|
||||
return [pattern]
|
||||
|
||||
installed_packages = _get_installed_packages()
|
||||
matched = [pkg for pkg in installed_packages if fnmatch.fnmatchcase(pkg.lower(), pattern.lower())]
|
||||
return matched
|
||||
|
||||
|
||||
def _filter_protected_packages(packages: list[str]) -> list[str]:
|
||||
"""过滤掉受保护的包名."""
|
||||
safe = [p for p in packages if p.lower() not in {p.lower() for p in _PROTECTED_PACKAGES}]
|
||||
filtered = [p for p in packages if p.lower() in {p.lower() for p in _PROTECTED_PACKAGES}]
|
||||
if filtered:
|
||||
print(f"跳过受保护的包: {', '.join(filtered)}")
|
||||
return safe
|
||||
|
||||
|
||||
def pip_uninstall(pkg_names: list[str]) -> None:
|
||||
"""卸载包."""
|
||||
packages_to_uninstall: list[str] = []
|
||||
for pattern in pkg_names:
|
||||
packages_to_uninstall.extend(_expand_wildcard_packages(pattern))
|
||||
|
||||
packages_to_uninstall = _filter_protected_packages(packages_to_uninstall)
|
||||
|
||||
if not packages_to_uninstall:
|
||||
return
|
||||
|
||||
subprocess.run(["pip", "uninstall", "-y", *packages_to_uninstall], check=True)
|
||||
|
||||
|
||||
def pip_reinstall(pkg_names: list[str], offline: bool = False) -> None:
|
||||
"""重新安装包."""
|
||||
safe_pkgs = _filter_protected_packages(pkg_names)
|
||||
if not safe_pkgs:
|
||||
print("所有指定的包均为受保护包, 跳过重装")
|
||||
return
|
||||
|
||||
subprocess.run(["pip", "uninstall", "-y", *safe_pkgs], check=True)
|
||||
|
||||
options = ["--no-index", "--find-links", "."] if offline else []
|
||||
subprocess.run(["pip", "install", *options, *safe_pkgs], check=True)
|
||||
|
||||
|
||||
def pip_download(pkg_names: list[str], offline: bool = False) -> None:
|
||||
"""下载包."""
|
||||
options = ["--no-index", "--find-links", "."] if offline else []
|
||||
subprocess.run(
|
||||
["pip", "download", *pkg_names, *options, "-d", PACKAGE_DIR],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def pip_freeze() -> None:
|
||||
"""冻结依赖."""
|
||||
result = subprocess.run(
|
||||
["pip", "freeze", "--exclude-editable"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
Path(REQUIREMENTS_FILE).write_text(result.stdout)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""pip 工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="PipTool - pip 包管理工具",
|
||||
usage="piptool <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 安装命令
|
||||
install_parser = subparsers.add_parser("i", help="安装包")
|
||||
install_parser.add_argument("packages", nargs="+", help="要安装的包名")
|
||||
|
||||
# 卸载命令
|
||||
uninstall_parser = subparsers.add_parser("u", help="卸载包")
|
||||
uninstall_parser.add_argument("packages", nargs="+", help="要卸载的包名 (支持通配符)")
|
||||
|
||||
# 重装命令
|
||||
reinstall_parser = subparsers.add_parser("r", help="重新安装包")
|
||||
reinstall_parser.add_argument("packages", nargs="+", help="要重装的包名")
|
||||
reinstall_parser.add_argument("--offline", action="store_true", help="使用离线模式")
|
||||
|
||||
# 下载命令
|
||||
download_parser = subparsers.add_parser("d", help="下载包")
|
||||
download_parser.add_argument("packages", nargs="+", help="要下载的包名")
|
||||
download_parser.add_argument("--offline", action="store_true", help="使用离线模式")
|
||||
|
||||
# 升级 pip 命令
|
||||
subparsers.add_parser("up", help="升级 pip")
|
||||
|
||||
# 冻结依赖命令
|
||||
subparsers.add_parser("f", help="冻结依赖到 requirements.txt")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "i":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("pip_install", cmd=["pip", "install", *args.packages], verbose=True)])
|
||||
elif args.command == "u":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pip_uninstall", fn=pip_uninstall, args=(args.packages,), verbose=True)
|
||||
])
|
||||
elif args.command == "r":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pip_reinstall",
|
||||
fn=pip_reinstall,
|
||||
args=(args.packages,),
|
||||
kwargs={"offline": args.offline},
|
||||
verbose=True,
|
||||
)
|
||||
])
|
||||
elif args.command == "d":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"pip_download",
|
||||
fn=pip_download,
|
||||
args=(args.packages,),
|
||||
kwargs={"offline": args.offline},
|
||||
verbose=True,
|
||||
)
|
||||
])
|
||||
elif args.command == "up":
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("pip_upgrade", cmd=["python", "-m", "pip", "install", "--upgrade", "pip"], verbose=True)
|
||||
])
|
||||
elif args.command == "f":
|
||||
graph = px.Graph.from_specs([px.TaskSpec("pip_freeze", fn=pip_freeze, verbose=True)])
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,116 @@
|
||||
"""Python 构建工具模块.
|
||||
|
||||
完全替代传统的 Makefile,
|
||||
提供更好的跨平台兼容性和 Python 生态集成.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
MATURIN_BUILD_COMMAND = ["maturin", "build", "-r"]
|
||||
if Constants.IS_WINDOWS:
|
||||
MATURIN_BUILD_COMMAND.extend(["--target", "x86_64-win7-windows-msvc", "-Zbuild-std", "-i", "python3.8"])
|
||||
|
||||
# 扁平注册所有任务(px.cmd 自动从命令前两段推导 name)
|
||||
tasks: list[px.TaskSpec] = [
|
||||
px.cmd(["uv", "build"]),
|
||||
px.cmd(MATURIN_BUILD_COMMAND),
|
||||
px.cmd(["uv", "sync"]),
|
||||
px.cmd(["gitt", "c"], name="git_clean"),
|
||||
px.cmd(
|
||||
["pytest", "-m", "not slow", "-n", "8", "--dist", "loadfile", "--color=yes", "--durations=10"],
|
||||
name="test",
|
||||
),
|
||||
px.cmd(
|
||||
["pytest", "-m", "not slow", "--dist", "loadfile", "--color=yes", "--durations=10"],
|
||||
name="test_fast",
|
||||
),
|
||||
px.cmd(
|
||||
["pytest", "--cov", "-n", "8", "--dist", "loadfile", "--tb=short", "-v", "--color=yes", "--durations=10"],
|
||||
name="test_coverage",
|
||||
),
|
||||
px.cmd(["pyrefly", "check", "."]),
|
||||
px.cmd(["git", "add", "-A"], name="git_add_all"),
|
||||
px.cmd(["bumpversion"]),
|
||||
px.cmd(["bumpversion", "minor"]),
|
||||
px.cmd(["git", "push"]),
|
||||
px.cmd(["git", "push", "--tags"], name="git_push_tags"),
|
||||
px.cmd(["hatch", "publish"], name="publish_python"),
|
||||
px.cmd(["twine", "upload", "--disable-progress-bar"], name="twine_publish"),
|
||||
]
|
||||
|
||||
# 单任务别名(alias 名与任务名相同):直接内联 TaskSpec,避免 str 自引用
|
||||
aliases: dict[str, str | list[str | px.TaskSpec] | px.TaskSpec | px.Graph] = {
|
||||
# 构建命令
|
||||
"b": "uv_build",
|
||||
"bc": "maturin_build",
|
||||
"ba": ["b", "bc"],
|
||||
# 安装命令
|
||||
"sync": "uv_sync",
|
||||
# 清理命令
|
||||
"c": "git_clean",
|
||||
# 开发工具
|
||||
"bump": ["c", "tc", "git_add_all", "bumpversion"],
|
||||
"bumpmi": "bumpversion_minor",
|
||||
"cov": ["git_clean", "test_coverage"],
|
||||
"doc": px.cmd(["sphinx-build", "-b", "html", "docs", "docs/_build"], name="doc"),
|
||||
"lint": px.cmd(["ruff", "check", "--fix", "--unsafe-fixes"], name="lint"),
|
||||
"pb": ["twine_publish", "publish_python"],
|
||||
"t": "test",
|
||||
"tf": "test_fast",
|
||||
"tc": ["pyrefly_check", "lint"],
|
||||
"tox": px.cmd(["tox", "-p", "auto"], name="tox"),
|
||||
# 发布命令
|
||||
"p": ["git_clean", "git_push", "git_push_tags"],
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""pymake 构建工具.
|
||||
|
||||
🔨 构建命令:
|
||||
pymake b - 构建 Python 主包 (uv build)
|
||||
pymake bc - 构建 Rust 核心模块 (maturin build)
|
||||
pymake ba - 构建所有包 (先 Python 后 Rust)
|
||||
|
||||
📦 安装命令 (开发模式):
|
||||
pymake sync - 安装依赖包 (uv sync)
|
||||
|
||||
🧹 清理命令:
|
||||
pymake c - 清理所有构建产物 (gitt c)
|
||||
|
||||
🛠️ 开发工具:
|
||||
pymake t - 运行测试 (pytest)
|
||||
pymake tc - 运行测试并生成覆盖率报告
|
||||
pymake tf - 运行快速测试 (pytest -m not slow)
|
||||
pymake lint - 代码格式化与检查 (ruff)
|
||||
pymake type - 类型检查 (mypy, ty)
|
||||
pymake doc - 构建文档 (sphinx)
|
||||
|
||||
🔬 多版本测试:
|
||||
pymake tox - 多版本 Python 测试 (tox -p auto)
|
||||
|
||||
📦 发布命令:
|
||||
pymake pb - 发布到 PyPI (twine + hatch)
|
||||
|
||||
🔖 版本管理:
|
||||
pymake bump - 自动升级版本号并提交修改 (清理 + 检查 + 格式化 + git add + bumpversion)
|
||||
|
||||
💡 常用工作流:
|
||||
1. 日常开发: pymake lint && pymake t
|
||||
2. 构建发布包: pymake ba
|
||||
3. 多版本兼容性测试: pymake tox
|
||||
4. 发布到 PyPI: pymake pb
|
||||
|
||||
📝 示例:
|
||||
pymake ba # 构建所有包
|
||||
pymake sync # 安装依赖
|
||||
pymake t # 运行测试
|
||||
pymake tox # 多版本兼容性测试
|
||||
pymake lint # 格式化代码
|
||||
pymake type # 类型检查
|
||||
"""
|
||||
runner = px.CliRunner(strategy="sequential", description="PyMake - Python 构建工具", tasks=tasks, aliases=aliases)
|
||||
runner.run_cli()
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.tasks.system import reset_icon_cache
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""重启图标缓存工具主函数."""
|
||||
graph = px.Graph.from_specs(reset_icon_cache())
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,163 @@
|
||||
"""截图工具.
|
||||
|
||||
跨平台截图工具, 支持全屏截图和区域截图.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_screenshot_path(filename: str | None = None) -> Path:
|
||||
"""获取截图保存路径.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str | None
|
||||
文件名, 如果为 None 则自动生成
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
截图保存路径
|
||||
"""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"screenshot_{timestamp}.png"
|
||||
|
||||
screenshots_dir = Path.home() / "Pictures" / "screenshots"
|
||||
screenshots_dir.mkdir(parents=True, exist_ok=True)
|
||||
return screenshots_dir / filename
|
||||
|
||||
|
||||
def take_screenshot_full(filename: str | None = None) -> None:
|
||||
"""全屏截图.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str | None
|
||||
文件名
|
||||
"""
|
||||
output_path = get_screenshot_path(filename)
|
||||
|
||||
if Constants.IS_WINDOWS:
|
||||
# Windows: 使用 PowerShell 截图
|
||||
ps_script = f"""
|
||||
Add-Type -AssemblyName System.Windows.Forms
|
||||
Add-Type -AssemblyName System.Drawing
|
||||
$screen = [System.Windows.Forms.Screen]::PrimaryScreen
|
||||
$bounds = $screen.Bounds
|
||||
$bitmap = New-Object System.Drawing.Bitmap $bounds.Width, $bounds.Height
|
||||
$graphics = [System.Drawing.Graphics]::FromImage($bitmap)
|
||||
$graphics.CopyFromScreen($bounds.Location, [System.Drawing.Point]::Empty, $bounds.Size)
|
||||
$bitmap.Save('{output_path.as_posix()}')
|
||||
$graphics.Dispose()
|
||||
$bitmap.Dispose()
|
||||
"""
|
||||
subprocess.run(["powershell", "-Command", ps_script], check=True)
|
||||
elif Constants.IS_MACOS:
|
||||
# macOS: 使用 screencapture
|
||||
subprocess.run(["screencapture", "-x", str(output_path)], check=True)
|
||||
else:
|
||||
# Linux: 使用 gnome-screenshot 或 scrot
|
||||
try:
|
||||
subprocess.run(["gnome-screenshot", "-f", str(output_path)], check=True)
|
||||
except FileNotFoundError:
|
||||
subprocess.run(["scrot", str(output_path)], check=True)
|
||||
|
||||
print(f"截图已保存: {output_path}")
|
||||
|
||||
|
||||
def take_screenshot_area(filename: str | None = None) -> None:
|
||||
"""区域截图.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str | None
|
||||
文件名
|
||||
"""
|
||||
output_path = get_screenshot_path(filename)
|
||||
|
||||
if Constants.IS_WINDOWS:
|
||||
# Windows: 使用 PowerShell 截图 (需要用户选择区域)
|
||||
ps_script = f"""
|
||||
Add-Type -AssemblyName System.Windows.Forms
|
||||
Add-Type -AssemblyName System.Drawing
|
||||
$form = New-Object System.Windows.Forms.Form
|
||||
$form.WindowState = 'Maximized'
|
||||
$form.FormBorderStyle = 'None'
|
||||
$form.BackColor = [System.Drawing.Color]::FromArgb(1, 0, 0)
|
||||
$form.Opacity = 0.5
|
||||
$form.TopMost = $true
|
||||
$form.Show()
|
||||
Start-Sleep -Milliseconds 100
|
||||
$screen = [System.Windows.Forms.Screen]::PrimaryScreen
|
||||
$bounds = $screen.Bounds
|
||||
$bitmap = New-Object System.Drawing.Bitmap $bounds.Width, $bounds.Height
|
||||
$graphics = [System.Drawing.Graphics]::FromImage($bitmap)
|
||||
$graphics.CopyFromScreen($bounds.Location, [System.Drawing.Point]::Empty, $bounds.Size)
|
||||
$form.Close()
|
||||
$bitmap.Save('{output_path.as_posix()}')
|
||||
$graphics.Dispose()
|
||||
$bitmap.Dispose()
|
||||
"""
|
||||
subprocess.run(["powershell", "-Command", ps_script], check=True)
|
||||
elif Constants.IS_MACOS:
|
||||
# macOS: 使用 screencapture 交互模式
|
||||
subprocess.run(["screencapture", "-i", str(output_path)], check=True)
|
||||
else:
|
||||
# Linux: 使用 gnome-screenshot 交互模式
|
||||
try:
|
||||
subprocess.run(["gnome-screenshot", "-a", "-f", str(output_path)], check=True)
|
||||
except FileNotFoundError:
|
||||
subprocess.run(["scrot", "-s", str(output_path)], check=True)
|
||||
|
||||
print(f"截图已保存: {output_path}")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""截图工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Screenshot - 截图工具",
|
||||
usage="screenshot <command> [options]",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
||||
|
||||
# 全屏截图命令
|
||||
full_parser = subparsers.add_parser("full", help="全屏截图")
|
||||
full_parser.add_argument("--filename", type=str, help="文件名")
|
||||
|
||||
# 区域截图命令
|
||||
area_parser = subparsers.add_parser("area", help="区域截图")
|
||||
area_parser.add_argument("--filename", type=str, help="文件名")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "full":
|
||||
graph = px.Graph.from_specs(
|
||||
[px.TaskSpec("screenshot_full", fn=take_screenshot_full, kwargs={"filename": args.filename})]
|
||||
)
|
||||
elif args.command == "area":
|
||||
graph = px.Graph.from_specs(
|
||||
[px.TaskSpec("screenshot_area", fn=take_screenshot_area, kwargs={"filename": args.filename})]
|
||||
)
|
||||
else:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,122 @@
|
||||
"""SSH 密钥部署工具.
|
||||
|
||||
类似 ssh-copy-id, 自动将 SSH 公钥部署到远程服务器,
|
||||
支持密码认证和密钥认证两种方式.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# ============================================================================
|
||||
# 辅助函数
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def ssh_copy_id(
|
||||
hostname: str,
|
||||
username: str,
|
||||
password: str,
|
||||
port: int = 22,
|
||||
keypath: str = "~/.ssh/id_rsa.pub",
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
"""将 SSH 公钥部署到远程服务器.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hostname : str
|
||||
远程服务器主机名或 IP 地址
|
||||
username : str
|
||||
远程服务器用户名
|
||||
password : str
|
||||
远程服务器密码
|
||||
port : int
|
||||
SSH 端口, 默认 22
|
||||
keypath : str
|
||||
公钥文件路径, 默认 ~/.ssh/id_rsa.pub
|
||||
timeout : int
|
||||
SSH 操作超时秒数, 默认 30
|
||||
"""
|
||||
# 读取公钥
|
||||
pub_key_path = Path(keypath).expanduser()
|
||||
if not pub_key_path.exists():
|
||||
print(f"公钥文件不存在: {pub_key_path}")
|
||||
sys.exit(1)
|
||||
|
||||
pub_key = pub_key_path.read_text().strip()
|
||||
|
||||
# 构建部署脚本
|
||||
script = f"""mkdir -p ~/.ssh && chmod 700 ~/.ssh
|
||||
cd ~/.ssh && touch authorized_keys && chmod 600 authorized_keys
|
||||
grep -qF '{pub_key.split()[1]}' authorized_keys 2>/dev/null || echo '{pub_key}' >> authorized_keys"""
|
||||
|
||||
# 使用 sshpass 执行
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"sshpass",
|
||||
"-p",
|
||||
password,
|
||||
"ssh",
|
||||
"-p",
|
||||
str(port),
|
||||
"-o",
|
||||
"StrictHostKeyChecking=no",
|
||||
"-o",
|
||||
"UserKnownHostsFile=/dev/null",
|
||||
"-o",
|
||||
f"ConnectTimeout={timeout}",
|
||||
f"{username}@{hostname}",
|
||||
script,
|
||||
],
|
||||
check=True,
|
||||
timeout=timeout,
|
||||
)
|
||||
print(f"SSH 密钥已部署到 {username}@{hostname}:{port}")
|
||||
except FileNotFoundError:
|
||||
print(f"未找到 sshpass 工具,请手动执行: ssh-copy-id -p {port} {username}@{hostname}")
|
||||
sys.exit(1)
|
||||
except subprocess.TimeoutExpired:
|
||||
print("SSH 连接超时")
|
||||
sys.exit(1)
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"SSH 执行失败: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CLI Runner
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""SSH 密钥部署工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SSHCopyID - SSH 密钥部署工具",
|
||||
usage="sshcopyid <hostname> <username> <password> [--port PORT] [--keypath KEYPATH]",
|
||||
)
|
||||
parser.add_argument("hostname", type=str, help="远程服务器主机名或 IP 地址")
|
||||
parser.add_argument("username", type=str, help="远程服务器用户名")
|
||||
parser.add_argument("password", type=str, help="远程服务器密码")
|
||||
parser.add_argument("--port", type=int, default=22, help="SSH 端口 (默认: 22)")
|
||||
parser.add_argument("--keypath", type=str, default="~/.ssh/id_rsa.pub", help="公钥文件路径")
|
||||
parser.add_argument("--timeout", type=int, default=30, help="SSH 操作超时秒数 (默认: 30)")
|
||||
args = parser.parse_args()
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(
|
||||
"ssh_deploy",
|
||||
fn=ssh_copy_id,
|
||||
args=(args.hostname, args.username, args.password),
|
||||
kwargs={"port": args.port, "keypath": args.keypath, "timeout": args.timeout},
|
||||
)
|
||||
]
|
||||
)
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,15 @@
|
||||
"""清屏工具.
|
||||
|
||||
跨平台清屏工具, 支持终端和控制台清屏.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.tasks.system import clr
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""清屏工具主函数."""
|
||||
graph = px.Graph.from_specs([clr()])
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,40 @@
|
||||
"""进程终止工具.
|
||||
|
||||
跨平台进程终止工具, 支持按名称终止进程.
|
||||
用法: taskkill proc_name [proc_name ...]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""进程终止工具主函数."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="TaskKill - 进程终止工具",
|
||||
usage="taskkill <process_name> [process_name ...]",
|
||||
)
|
||||
parser.add_argument(
|
||||
"process_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="进程名称 (如: chrome.exe python node)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if Constants.IS_WINDOWS:
|
||||
cmd = ["taskkill", "/f", "/im"]
|
||||
else:
|
||||
cmd = ["pkill", "-f"]
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec(f"kill_{proc_name}", cmd=[*cmd, f"{proc_name}*"], verbose=True)
|
||||
for proc_name in args.process_names
|
||||
],
|
||||
)
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,21 @@
|
||||
"""命令查找工具.
|
||||
|
||||
跨平台查找可执行命令路径, 类似 Unix 的 which 命令.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.tasks.system import which
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""命令查找工具主函数."""
|
||||
parser = argparse.ArgumentParser(description="Which - 命令查找工具")
|
||||
parser.add_argument("commands", nargs="+", help="要查找的命令名称, 如: python ls ps gcc...")
|
||||
args = parser.parse_args()
|
||||
|
||||
graph = px.Graph.from_specs([which(cmd) for cmd in args.commands])
|
||||
px.run(graph, strategy="thread")
|
||||
@@ -0,0 +1,98 @@
|
||||
"""命令执行器:把 :class:`~pyflowx.task.TaskSpec` 的 ``cmd`` 字段(list /
|
||||
shell 字符串 / 可调用对象)转换为统一执行入口。
|
||||
|
||||
历史背景:原 ``task.py`` 的模块文档声明其为"纯数据结构",但 ``_run_command``
|
||||
属于命令执行逻辑,违反单一职责。此处将其抽离,``TaskSpec`` 仅持有配置,
|
||||
执行逻辑集中于本模块,便于独立测试与维护。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import Any, List, Union, cast
|
||||
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["run_command"]
|
||||
|
||||
|
||||
def run_command(spec: TaskSpec[Any]) -> Any: # noqa: PLR0912
|
||||
"""执行 ``spec.cmd`` 指定的命令(list / shell 字符串 / 可调用对象)。
|
||||
|
||||
与原 ``TaskSpec._run_command`` 行为一致:
|
||||
|
||||
- 可调用对象:直接调用,异常包装为 :class:`RuntimeError`。
|
||||
- list / str:通过 :func:`subprocess.run` 执行,非零返回码抛
|
||||
:class:`RuntimeError`(``verbose=False`` 时附 stderr)。
|
||||
- ``verbose=True`` 时打印执行信息与返回码到 stdout。
|
||||
- ``cwd`` / ``env`` 通过 subprocess 参数隔离(进程级状态仅在 fn 任务路径
|
||||
使用,cmd 路径不依赖 ``os.chdir`` / ``os.environ``)。
|
||||
"""
|
||||
cmd = spec.cmd
|
||||
verbose = spec.verbose
|
||||
cwd = spec.cwd
|
||||
timeout = spec.timeout
|
||||
env_override = spec.env
|
||||
|
||||
# 可调用对象:直接调用,返回其结果。
|
||||
if callable(cmd) and not isinstance(cmd, (list, str)):
|
||||
name = getattr(cmd, "__name__", "callable")
|
||||
if verbose:
|
||||
print(f"[verbose] 执行可调用命令: {name}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
try:
|
||||
return cmd()
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"可调用命令执行异常: {name}: {e}") from e
|
||||
|
||||
is_list = isinstance(cmd, list)
|
||||
if is_list:
|
||||
cmd_str = " ".join(arg for arg in cmd) # type: ignore[union-attr]
|
||||
verb = "执行命令"
|
||||
label = "命令"
|
||||
else:
|
||||
cmd_str = cast(str, cmd)
|
||||
verb = "执行 Shell"
|
||||
label = "Shell 命令"
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] {verb}: {cmd_str}", flush=True)
|
||||
if cwd is not None:
|
||||
print(f"[verbose] 工作目录: {cwd}", flush=True)
|
||||
|
||||
# 合并环境变量
|
||||
run_env: dict[str, str] | None = None
|
||||
if env_override:
|
||||
run_env = dict(os.environ)
|
||||
run_env.update(env_override)
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cast(Union[str, List[str]], cmd),
|
||||
shell=not is_list,
|
||||
cwd=cwd,
|
||||
env=run_env,
|
||||
timeout=timeout,
|
||||
capture_output=not verbose,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise RuntimeError(f"{label}未找到: {cmd_str}") from None
|
||||
except subprocess.TimeoutExpired:
|
||||
raise RuntimeError(f"{label}执行超时: {cmd_str} ({timeout}s)") from None
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"{label}执行异常: {cmd_str}: {e}") from e
|
||||
|
||||
if verbose:
|
||||
print(f"[verbose] 返回码: {result.returncode}", flush=True)
|
||||
|
||||
if result.returncode == 0:
|
||||
return None
|
||||
|
||||
err_msg = f"{label}执行失败: `{cmd_str}`, 返回码: {result.returncode}"
|
||||
if not verbose and result.stderr.strip():
|
||||
err_msg += f"\n{result.stderr.strip()}"
|
||||
raise RuntimeError(err_msg)
|
||||
@@ -0,0 +1,115 @@
|
||||
"""图组合:将带字符串引用的多个图展开为纯 :class:`~pyflowx.graph.Graph`。
|
||||
|
||||
历史背景:原 ``graph.py`` 同时承载 DAG 构建/校验/分层与多图组合逻辑,
|
||||
职责过载。组合逻辑(:class:`GraphComposer` / :func:`compose`)与单图 DAG
|
||||
模型正交,此处抽离为独立模块,便于按需导入与独立演进。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import replace
|
||||
from typing import Any
|
||||
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["GraphComposer", "compose"]
|
||||
|
||||
|
||||
class GraphComposer:
|
||||
"""将带字符串引用的图展开为纯 :class:`TaskSpec` 图。
|
||||
|
||||
引用格式:
|
||||
* ``"command_name"`` —— 引用整个命令图。
|
||||
* ``"command_name.task_name"`` —— 引用特定任务。
|
||||
|
||||
引用按顺序展开,后续引用的任务依赖前面引用的最后一个任务;
|
||||
原始 ``TaskSpec`` 之间也按出现顺序串行依赖。
|
||||
"""
|
||||
|
||||
def __init__(self, graphs: dict[str, Graph]) -> None:
|
||||
self.graphs = graphs
|
||||
|
||||
def resolve_all(self) -> dict[str, Graph]:
|
||||
"""解析所有图的字符串引用,返回展开后的新图映射。"""
|
||||
resolved: dict[str, Graph] = {}
|
||||
for cmd_name, graph in self.graphs.items():
|
||||
resolved[cmd_name] = self.expand_refs(graph, cmd_name)
|
||||
return resolved
|
||||
|
||||
def expand_refs(self, graph: Graph, current_cmd: str) -> Graph:
|
||||
"""展开图中的字符串引用。若无 ``_pending_refs``,原样返回。"""
|
||||
pending_refs = graph._pending_refs
|
||||
if not pending_refs:
|
||||
return graph
|
||||
|
||||
all_specs: list[TaskSpec[Any]] = []
|
||||
previous_ref_last_task: str | None = None
|
||||
|
||||
for ref in pending_refs:
|
||||
expanded_specs = self.parse_ref(ref, current_cmd)
|
||||
if previous_ref_last_task and expanded_specs:
|
||||
for i, task in enumerate(expanded_specs):
|
||||
if i == 0 or not task.depends_on:
|
||||
expanded_specs[i] = replace(task, depends_on=tuple({*task.depends_on, previous_ref_last_task}))
|
||||
if expanded_specs:
|
||||
previous_ref_last_task = expanded_specs[-1].name
|
||||
all_specs.extend(expanded_specs)
|
||||
|
||||
original_specs = list(graph.all_specs().values())
|
||||
if original_specs:
|
||||
if previous_ref_last_task:
|
||||
first = original_specs[0]
|
||||
all_specs.append(replace(first, depends_on=tuple({*first.depends_on, previous_ref_last_task})))
|
||||
else:
|
||||
all_specs.append(original_specs[0])
|
||||
for i in range(1, len(original_specs)):
|
||||
current_task = original_specs[i]
|
||||
previous_task_name = original_specs[i - 1].name
|
||||
all_specs.append(
|
||||
replace(current_task, depends_on=tuple({*current_task.depends_on, previous_task_name}))
|
||||
)
|
||||
|
||||
return Graph.from_specs(all_specs, defaults=graph.defaults)
|
||||
|
||||
def parse_ref(self, ref: str, current_cmd: str) -> list[TaskSpec[Any]]:
|
||||
"""解析单个字符串引用,返回对应的 TaskSpec 列表。"""
|
||||
if ref == current_cmd:
|
||||
raise ValueError(f"循环引用: 命令 '{current_cmd}' 引用了自己")
|
||||
|
||||
if "." in ref:
|
||||
cmd_name, task_name = ref.split(".", 1)
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
if task_name not in ref_graph.all_specs():
|
||||
raise ValueError(f"任务 '{task_name}' 不存在于命令 '{cmd_name}' 中")
|
||||
return [ref_graph.all_specs()[task_name]]
|
||||
else:
|
||||
cmd_name = ref
|
||||
if cmd_name not in self.graphs:
|
||||
raise ValueError(f"引用的命令 '{cmd_name}' 不存在")
|
||||
ref_graph = self.graphs[cmd_name]
|
||||
ref_graph = self.expand_refs(ref_graph, cmd_name)
|
||||
return list(ref_graph.all_specs().values())
|
||||
|
||||
|
||||
def compose(
|
||||
graphs: dict[str, Graph],
|
||||
) -> dict[str, Graph]:
|
||||
"""编程式解析多图的字符串引用,返回展开后的新图映射。
|
||||
|
||||
与 :class:`GraphComposer` 等价,但作为独立函数暴露,供不使用
|
||||
:class:`~pyflowx.runner.CliRunner` 的编程式用户调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graphs = {
|
||||
... "build": px.Graph.from_specs([px.TaskSpec("b", cmd=["echo", "b"])]),
|
||||
... "all": px.Graph.from_specs(["build", px.TaskSpec("t", cmd=["echo", "t"])]),
|
||||
... }
|
||||
>>> resolved = px.compose(graphs)
|
||||
>>> "b" in resolved["all"].all_specs()
|
||||
True
|
||||
"""
|
||||
return GraphComposer(graphs).resolve_all()
|
||||
@@ -0,0 +1,250 @@
|
||||
"""条件判断模块.
|
||||
|
||||
所有条件均为 ``Callable[[Context], bool]``,接收依赖上下文映射(可能为空)。
|
||||
这使得条件可基于上游任务的运行时返回值做决策,实现动态分支。
|
||||
|
||||
内置条件分两类:
|
||||
1. *静态条件* —— 不依赖上下文(平台/环境变量/安装检查),通过 ``_static``
|
||||
包装忽略传入的 context,便于作为模块级常量使用。
|
||||
2. *上下文条件* —— 基于上游结果判断,如 :meth:`BuiltinConditions.DEP_EQUALS`。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
from .task import Condition, Context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["BuiltinConditions", "Condition", "Constants"]
|
||||
|
||||
|
||||
class Constants:
|
||||
"""常量定义."""
|
||||
|
||||
IS_WINDOWS: bool = sys.platform == "win32"
|
||||
IS_LINUX: bool = sys.platform == "linux"
|
||||
IS_MACOS: bool = sys.platform == "darwin"
|
||||
IS_POSIX: bool = sys.platform != "win32"
|
||||
|
||||
|
||||
def _static(predicate: Callable[[], bool], name: str) -> Condition:
|
||||
"""将无参谓词包装为忽略上下文的 :class:`Condition`。"""
|
||||
|
||||
def _cond(_ctx: Context) -> bool:
|
||||
return predicate()
|
||||
|
||||
_cond.__name__ = name
|
||||
return _cond
|
||||
|
||||
|
||||
def _cond_name(cond: Condition) -> str:
|
||||
"""获取条件的可读名称。"""
|
||||
return getattr(cond, "__name__", repr(cond))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 模块级静态条件常量
|
||||
# ---------------------------------------------------------------------- #
|
||||
IS_WINDOWS: Condition = _static(lambda: Constants.IS_WINDOWS, "IS_WINDOWS")
|
||||
IS_LINUX: Condition = _static(lambda: Constants.IS_LINUX, "IS_LINUX")
|
||||
IS_MACOS: Condition = _static(lambda: Constants.IS_MACOS, "IS_MACOS")
|
||||
IS_POSIX: Condition = _static(lambda: Constants.IS_POSIX, "IS_POSIX")
|
||||
|
||||
|
||||
class BuiltinConditions:
|
||||
"""内置条件判断函数集合.
|
||||
|
||||
静态条件工厂返回忽略上下文的 :class:`Condition`;上下文条件工厂返回
|
||||
会读取依赖结果的 :class:`Condition`。
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 静态条件
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def IS_WINDOWS() -> Condition:
|
||||
"""检查是否为 Windows 平台."""
|
||||
return IS_WINDOWS
|
||||
|
||||
@staticmethod
|
||||
def IS_LINUX() -> Condition:
|
||||
"""检查是否为 Linux 平台."""
|
||||
return IS_LINUX
|
||||
|
||||
@staticmethod
|
||||
def IS_MACOS() -> Condition:
|
||||
"""检查是否为 macOS 平台."""
|
||||
return IS_MACOS
|
||||
|
||||
@staticmethod
|
||||
def IS_POSIX() -> Condition:
|
||||
"""检查是否为 POSIX 平台."""
|
||||
return IS_POSIX
|
||||
|
||||
@staticmethod
|
||||
def PYTHON_VERSION(major: int, minor: int | None = None) -> Condition:
|
||||
"""检查 Python 版本是否匹配."""
|
||||
if minor is None:
|
||||
return _static(lambda: sys.version_info.major == major, f"PYTHON_VERSION({major})")
|
||||
return _static(
|
||||
lambda: sys.version_info.major == major and sys.version_info.minor == minor,
|
||||
f"PYTHON_VERSION({major},{minor})",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def PYTHON_VERSION_AT_LEAST(major: int, minor: int = 0) -> Condition:
|
||||
"""检查 Python 版本是否 >= 指定版本."""
|
||||
return _static(lambda: sys.version_info >= (major, minor), f"PYTHON_VERSION_AT_LEAST({major},{minor})")
|
||||
|
||||
@staticmethod
|
||||
def IS_RUNNING(app_name: str) -> Condition:
|
||||
"""检查指定应用是否正在运行."""
|
||||
|
||||
def _check() -> bool:
|
||||
if Constants.IS_WINDOWS:
|
||||
result = subprocess.run(
|
||||
["tasklist", "/nh", "/fi", f"imagename eq {app_name}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
return app_name.lower() in result.stdout.lower()
|
||||
else:
|
||||
result = subprocess.run(["pgrep", "-x", app_name], capture_output=True, check=False)
|
||||
return result.returncode == 0
|
||||
|
||||
return _static(_check, f"IS_RUNNING({app_name!r})")
|
||||
|
||||
@staticmethod
|
||||
def HAS_INSTALLED(app_name: str) -> Condition:
|
||||
"""检查指定应用是否已安装."""
|
||||
return _static(lambda: shutil.which(app_name) is not None, f"HAS_INSTALLED({app_name!r})")
|
||||
|
||||
@staticmethod
|
||||
def DIR_EXISTS(path: Path) -> Condition:
|
||||
"""路径是否存在."""
|
||||
return _static(path.exists, f"DIR_EXISTS({path!r})")
|
||||
|
||||
@staticmethod
|
||||
def ENV_VAR_EXISTS(var_name: str) -> Condition:
|
||||
"""检查环境变量是否存在."""
|
||||
return _static(lambda: var_name in os.environ, f"ENV_VAR_EXISTS({var_name!r})")
|
||||
|
||||
@staticmethod
|
||||
def ENV_VAR_EQUALS(var_name: str, value: str) -> Condition:
|
||||
"""检查环境变量是否等于指定值."""
|
||||
return _static(
|
||||
lambda: os.environ.get(var_name) == value,
|
||||
f"ENV_VAR_EQUALS({var_name!r},{value!r})",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def FILE_CONTENT_EXISTS(path: Path | str, content: str) -> Condition:
|
||||
"""检查文件是否包含指定内容."""
|
||||
|
||||
def _check() -> bool:
|
||||
p = Path(path)
|
||||
if not p.exists():
|
||||
return False
|
||||
try:
|
||||
return content in p.read_text(encoding="utf-8")
|
||||
except (OSError, UnicodeDecodeError):
|
||||
return False
|
||||
|
||||
return _static(_check, f"FILE_CONTENT_EXISTS({path!r},{content!r})")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 上下文条件:基于上游依赖结果
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def DEP_EQUALS(dep_name: str, value: Any) -> Condition:
|
||||
"""上游任务 ``dep_name`` 的返回值等于 ``value`` 时为真。
|
||||
|
||||
若依赖未在上下文中(被跳过或未执行),返回 ``False``。
|
||||
"""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return dep_name in ctx and ctx[dep_name] == value
|
||||
|
||||
_cond.__name__ = f"DEP_EQUALS({dep_name!r},{value!r})"
|
||||
return _cond
|
||||
|
||||
@staticmethod
|
||||
def DEP_MATCHES(dep_name: str, predicate: Callable[[Any], bool]) -> Condition:
|
||||
"""上游任务 ``dep_name`` 的返回值满足 ``predicate`` 时为真。
|
||||
|
||||
依赖不存在时返回 ``False``。
|
||||
"""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
if dep_name not in ctx:
|
||||
return False
|
||||
try:
|
||||
return predicate(ctx[dep_name])
|
||||
except Exception as exc:
|
||||
logger.warning("DEP_MATCHES predicate %r raised: %r", dep_name, exc)
|
||||
return False
|
||||
|
||||
_cond.__name__ = f"DEP_MATCHES({dep_name!r},{getattr(predicate, '__name__', 'pred')})"
|
||||
return _cond
|
||||
|
||||
@staticmethod
|
||||
def DEP_PRESENT(dep_name: str) -> Condition:
|
||||
"""上游任务 ``dep_name`` 存在于上下文(即已成功执行)时为真。"""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return dep_name in ctx and ctx[dep_name] is not None
|
||||
|
||||
_cond.__name__ = f"DEP_PRESENT({dep_name!r})"
|
||||
return _cond
|
||||
|
||||
@staticmethod
|
||||
def DEP_TRUTHY(dep_name: str) -> Condition:
|
||||
"""上游任务 ``dep_name`` 的返回值为真值时为真。"""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return bool(ctx.get(dep_name))
|
||||
|
||||
_cond.__name__ = f"DEP_TRUTHY({dep_name!r})"
|
||||
return _cond
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 逻辑组合
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def NOT(condition: Condition) -> Condition:
|
||||
"""对条件取反."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return not condition(ctx)
|
||||
|
||||
_cond.__name__ = f"NOT({_cond_name(condition)})"
|
||||
return _cond
|
||||
|
||||
@staticmethod
|
||||
def AND(*conditions: Condition) -> Condition:
|
||||
"""多个条件的逻辑与."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return all(c(ctx) for c in conditions)
|
||||
|
||||
_cond.__name__ = f"AND({', '.join(_cond_name(c) for c in conditions)})"
|
||||
return _cond
|
||||
|
||||
@staticmethod
|
||||
def OR(*conditions: Condition) -> Condition:
|
||||
"""多个条件的逻辑或."""
|
||||
|
||||
def _cond(ctx: Context) -> bool:
|
||||
return any(c(ctx) for c in conditions)
|
||||
|
||||
_cond.__name__ = f"OR({', '.join(_cond_name(c) for c in conditions)})"
|
||||
return _cond
|
||||
+49
-74
@@ -1,106 +1,92 @@
|
||||
"""上下文注入:把上游结果转换为函数参数。
|
||||
|
||||
本机制让用户可以编写普通函数,其参数名*就是*依赖声明,从而消除其他
|
||||
DAG 库中泛滥的样板包装器(如 ``def wrapper(): return fn(workflow.get_task_result('x'))``)。
|
||||
DAG 库中泛滥的样板包装器。
|
||||
|
||||
注入规则(按顺序求值)
|
||||
----------------------
|
||||
1. **标注为** :class:`Context` 的参数接收完整结果映射。适用于需要遍历
|
||||
所有输入的任务。
|
||||
2. **名称匹配某个依赖**的参数接收该依赖的结果。
|
||||
1. **标注为** :class:`Context` 的参数接收完整结果映射(含硬依赖与软依赖)。
|
||||
2. **名称匹配某个依赖**(硬或软)的参数接收该依赖的结果。
|
||||
3. ``**kwargs`` 参数以 dict 形式接收*所有*依赖结果。
|
||||
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` 为*非依赖*参数提供静态值。
|
||||
|
||||
若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`,
|
||||
并附带精确错误信息。
|
||||
若某参数无法解析且无默认值,则抛出 :class:`~pyflowx.errors.InjectionError`。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Mapping, Set, Tuple
|
||||
from functools import lru_cache
|
||||
from typing import Any, Mapping
|
||||
|
||||
from .errors import InjectionError
|
||||
from .task import Context, TaskSpec
|
||||
|
||||
__all__ = ["Context", "build_call_args", "describe_injection", "_is_context_annotation"]
|
||||
__all__ = ["Context", "_is_context_annotation", "build_call_args", "describe_injection"]
|
||||
|
||||
|
||||
@lru_cache(maxsize=1024)
|
||||
def _cached_signature(fn: Any) -> inspect.Signature:
|
||||
"""缓存 ``inspect.signature`` 结果(按 fn 对象键控)。
|
||||
|
||||
``fn`` 对象在 :meth:`TaskSpec.effective_fn` 缓存后稳定,签名重复内省
|
||||
属纯开销。对不可哈希的可调用对象,调用方回退到直接内省。
|
||||
"""
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _signature(fn: Any) -> inspect.Signature:
|
||||
"""获取签名,优先走缓存;``fn`` 不可哈希时回退到直接内省。"""
|
||||
try:
|
||||
return _cached_signature(fn)
|
||||
except TypeError:
|
||||
return inspect.signature(fn)
|
||||
|
||||
|
||||
def _is_context_annotation(annotation: Any) -> bool:
|
||||
"""判断参数标注是否为(或指向)``Context``。
|
||||
|
||||
处理三种形式:
|
||||
* ``Context`` 别名对象本身;
|
||||
* ``__name__``/``_name`` 为 ``Context`` 或 ``Mapping`` 的 typing 别名;
|
||||
* *字符串*标注(``from __future__ import annotations`` 会在运行时
|
||||
把所有标注变为字符串),如 ``"Context"`` 或 ``"px.Context"``。
|
||||
"""
|
||||
"""判断参数标注是否为(或指向)``Context``。"""
|
||||
if annotation is Context:
|
||||
return True
|
||||
# `from __future__ import annotations` 产生的字符串标注。
|
||||
if isinstance(annotation, str):
|
||||
# 匹配 "Context"、"px.Context"、"pyflowx.Context" 等。
|
||||
return annotation == "Context" or annotation.endswith(".Context")
|
||||
# 按限定名匹配,支持 ``from pyflowx import Context`` 再导出。
|
||||
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
|
||||
if name in ("Context", "Mapping"):
|
||||
return True
|
||||
return False
|
||||
return name in ("Context", "Mapping")
|
||||
|
||||
|
||||
def build_call_args(
|
||||
spec: TaskSpec[object],
|
||||
spec: TaskSpec[Any],
|
||||
context: Mapping[str, Any],
|
||||
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
|
||||
) -> tuple[tuple[Any, ...], dict[str, Any]]:
|
||||
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
|
||||
|
||||
参数
|
||||
----
|
||||
spec:
|
||||
任务 spec,提供 ``fn``、``depends_on``、``args``、``kwargs``。
|
||||
context:
|
||||
依赖名 -> 结果值的映射。仅保证本任务自身的 ``depends_on`` 条目
|
||||
存在;其他任务的结果被排除,以保持注入的确定性。
|
||||
|
||||
返回
|
||||
----
|
||||
(args, kwargs)
|
||||
可直接展开为 ``spec.fn(*args, **kwargs)``。
|
||||
|
||||
抛出
|
||||
----
|
||||
InjectionError
|
||||
若必需参数无法满足,或静态 ``kwargs`` 与注入依赖名冲突。
|
||||
``context`` 必须已包含所有硬依赖与软依赖的结果(软依赖被跳过时由
|
||||
执行器填入 :attr:`TaskSpec.defaults` 中的默认值)。
|
||||
"""
|
||||
sig = inspect.signature(spec.fn)
|
||||
fn = spec.effective_fn
|
||||
sig = _signature(fn)
|
||||
params = sig.parameters
|
||||
|
||||
# 检测特殊参数类型。
|
||||
var_keyword = next(
|
||||
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
|
||||
None,
|
||||
)
|
||||
|
||||
# 与本任务相关的上下文子集。
|
||||
dep_context: Dict[str, Any] = {
|
||||
name: context[name] for name in spec.depends_on if name in context
|
||||
}
|
||||
# 本任务相关的上下文子集:硬依赖 + 软依赖。
|
||||
all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
|
||||
dep_context: dict[str, Any] = {name: context[name] for name in all_deps if name in context}
|
||||
|
||||
# 检测静态 kwargs 与依赖名的冲突。
|
||||
collisions = set(spec.kwargs) & set(dep_context)
|
||||
if collisions:
|
||||
raise InjectionError(
|
||||
spec.name,
|
||||
f"static kwargs {sorted(collisions)} collide with dependency names; "
|
||||
"rename the static kwarg or the dependency.",
|
||||
+ "rename the static kwarg or the dependency.",
|
||||
)
|
||||
|
||||
injected_kwargs: Dict[str, Any] = {}
|
||||
leftover_dep_results: Dict[str, Any] = dict(dep_context)
|
||||
injected_kwargs: dict[str, Any] = {}
|
||||
leftover_dep_results: dict[str, Any] = dict(dep_context)
|
||||
|
||||
# 被 spec.args 消费的位置参数。记录哪些参数名已被位置填充,
|
||||
# 以便在基于名称的注入(依赖 / Context / 静态 kwargs)时跳过。
|
||||
positional_params: List[str] = []
|
||||
positional_params: list[str] = []
|
||||
positional_kinds = (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
@@ -108,33 +94,25 @@ def build_call_args(
|
||||
for pname, param in params.items():
|
||||
if param.kind in positional_kinds:
|
||||
positional_params.append(pname)
|
||||
# 前 len(spec.args) 个位置参数由 spec.args 填充。
|
||||
args_filled: Set[str] = set(positional_params[: len(spec.args)])
|
||||
args_filled: set[str] = set(positional_params[: len(spec.args)])
|
||||
|
||||
for pname, param in params.items():
|
||||
# 跳过已被位置 spec.args 填充的参数。
|
||||
if pname in args_filled:
|
||||
continue
|
||||
|
||||
# 规则 1:标注为 Context -> 完整映射。
|
||||
if _is_context_annotation(param.annotation):
|
||||
injected_kwargs[pname] = dep_context
|
||||
continue
|
||||
|
||||
# 规则 2:名称匹配某个依赖。
|
||||
if pname in dep_context:
|
||||
injected_kwargs[pname] = dep_context[pname]
|
||||
leftover_dep_results.pop(pname, None)
|
||||
continue
|
||||
|
||||
# 规则 3:在循环后通过 **kwargs 处理。
|
||||
|
||||
# 规则 4:静态 kwargs 填充其余参数。
|
||||
if pname in spec.kwargs:
|
||||
injected_kwargs[pname] = spec.kwargs[pname]
|
||||
continue
|
||||
|
||||
# 该参数无来源:必须有默认值,否则报错。
|
||||
if param.default is inspect.Parameter.empty and param.kind not in (
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
@@ -144,9 +122,7 @@ def build_call_args(
|
||||
f"parameter {pname!r} has no dependency, static value, or default.",
|
||||
)
|
||||
|
||||
# 规则 3:**kwargs 吞掉剩余依赖结果。
|
||||
if var_keyword is not None and leftover_dep_results:
|
||||
# 先合并静态 kwargs,再合并依赖结果(冲突已在上方拒绝)。
|
||||
merged = dict(spec.kwargs)
|
||||
merged.update(injected_kwargs)
|
||||
merged.update(leftover_dep_results)
|
||||
@@ -155,13 +131,10 @@ def build_call_args(
|
||||
return tuple(spec.args), injected_kwargs
|
||||
|
||||
|
||||
def describe_injection(spec: TaskSpec[object]) -> str:
|
||||
"""生成任务参数注入方式的人类可读描述。
|
||||
|
||||
供 ``dry_run`` 使用,在不执行的情况下展示执行计划。
|
||||
"""
|
||||
sig = inspect.signature(spec.fn)
|
||||
# 确定哪些位置参数由 spec.args 填充。
|
||||
def describe_injection(spec: TaskSpec[Any]) -> str:
|
||||
"""生成任务参数注入方式的人类可读描述。供 ``dry_run`` 使用。"""
|
||||
fn = spec.effective_fn
|
||||
sig = _signature(fn)
|
||||
positional_params = [
|
||||
p
|
||||
for p, param in sig.parameters.items()
|
||||
@@ -172,6 +145,7 @@ def describe_injection(spec: TaskSpec[object]) -> str:
|
||||
)
|
||||
]
|
||||
args_filled = set(positional_params[: len(spec.args)])
|
||||
all_deps = set(spec.depends_on) | set(spec.soft_depends_on)
|
||||
parts = []
|
||||
for pname, param in sig.parameters.items():
|
||||
if pname in args_filled:
|
||||
@@ -179,8 +153,9 @@ def describe_injection(spec: TaskSpec[object]) -> str:
|
||||
parts.append(f"{pname}={spec.args[idx]!r}")
|
||||
elif _is_context_annotation(param.annotation):
|
||||
parts.append(f"{pname}=<Context>")
|
||||
elif pname in spec.depends_on:
|
||||
parts.append(f"{pname}=<result:{pname}>")
|
||||
elif pname in all_deps:
|
||||
tag = "soft" if pname in spec.soft_depends_on else "dep"
|
||||
parts.append(f"{pname}=<{tag}:{pname}>")
|
||||
elif pname in spec.kwargs:
|
||||
parts.append(f"{pname}={spec.kwargs[pname]!r}")
|
||||
elif param.default is not inspect.Parameter.empty:
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Iterable, Optional
|
||||
from typing import Any, Iterable
|
||||
|
||||
|
||||
class PyFlowXError(Exception):
|
||||
@@ -27,7 +27,7 @@ class MissingDependencyError(PyFlowXError):
|
||||
def __init__(self, task: str, dependency: str) -> None:
|
||||
super().__init__(
|
||||
f"Task '{task}' depends on unknown task '{dependency}'. "
|
||||
"Add the dependency before (or together with) this task."
|
||||
+ "Add the dependency before (or together with) this task."
|
||||
)
|
||||
self.task = task
|
||||
self.dependency = dependency
|
||||
@@ -55,12 +55,10 @@ class TaskFailedError(PyFlowXError):
|
||||
task: str,
|
||||
cause: BaseException,
|
||||
attempts: int,
|
||||
layer: Optional[int] = None,
|
||||
layer: int | None = None,
|
||||
) -> None:
|
||||
location = f" (layer {layer})" if layer is not None else ""
|
||||
super().__init__(
|
||||
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
|
||||
)
|
||||
super().__init__(f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}")
|
||||
self.task = task
|
||||
self.cause = cause
|
||||
self.attempts = attempts
|
||||
@@ -87,6 +85,6 @@ class InjectionError(PyFlowXError):
|
||||
class StorageError(PyFlowXError):
|
||||
"""状态后端在持久化失败时抛出。"""
|
||||
|
||||
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
|
||||
def __init__(self, detail: str, cause: BaseException | None = None) -> None:
|
||||
super().__init__(f"State storage error: {detail}")
|
||||
self.cause: Any = cause
|
||||
|
||||
@@ -10,40 +10,38 @@ Shows:
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
|
||||
async def fetch_user(uid: int) -> dict:
|
||||
async def fetch_user(uid: int) -> dict[str, Any]:
|
||||
await asyncio.sleep(0.2)
|
||||
return {"id": uid, "name": f"User{uid}"}
|
||||
|
||||
|
||||
async def fetch_posts(uid: int) -> List[int]:
|
||||
async def fetch_posts(uid: int) -> list[int]:
|
||||
await asyncio.sleep(0.2)
|
||||
return [uid, uid + 1]
|
||||
|
||||
|
||||
# Context annotation → receives the full mapping of upstream results.
|
||||
def aggregate(ctx: px.Context) -> Dict[str, Any]:
|
||||
def aggregate(ctx: px.Context) -> dict[str, Any]:
|
||||
return dict(ctx)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
# Static positional args parameterise the same function twice.
|
||||
px.TaskSpec("fetch_user", fetch_user, args=(1,)),
|
||||
px.TaskSpec("fetch_posts", fetch_posts, args=(1,)),
|
||||
px.TaskSpec("aggregate", aggregate, ("fetch_user", "fetch_posts")),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
# Static positional args parameterise the same function twice.
|
||||
px.TaskSpec("fetch_user", fetch_user, args=(1,)),
|
||||
px.TaskSpec("fetch_posts", fetch_posts, args=(1,)),
|
||||
px.TaskSpec("aggregate", aggregate, depends_on=("fetch_user", "fetch_posts")),
|
||||
])
|
||||
|
||||
print("=== Dry run ===")
|
||||
px.run(graph, strategy="async", dry_run=True)
|
||||
_ = px.run(graph, strategy="async", dry_run=True)
|
||||
|
||||
events: List[px.TaskEvent] = []
|
||||
events: list[px.TaskEvent] = []
|
||||
print("\n=== Async execution ===")
|
||||
report = px.run(graph, strategy="async", on_event=events.append)
|
||||
|
||||
|
||||
@@ -10,21 +10,21 @@ Demonstrates the core PyFlowX workflow:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from typing import Any
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
# --- task functions: pure, testable, no framework coupling ------------- #
|
||||
|
||||
|
||||
def extract_customers() -> List[dict]:
|
||||
def extract_customers() -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"id": "C001", "name": "Alice"},
|
||||
{"id": "C002", "name": "Bob"},
|
||||
]
|
||||
|
||||
|
||||
def extract_orders() -> List[dict]:
|
||||
def extract_orders() -> list[dict[str, Any]]:
|
||||
return [
|
||||
{"id": "O001", "customer_id": "C001", "amount": 150.0},
|
||||
{"id": "O002", "customer_id": "C002", "amount": 200.5},
|
||||
@@ -33,42 +33,38 @@ def extract_orders() -> List[dict]:
|
||||
|
||||
# Parameter names match dependency names → automatic injection.
|
||||
def transform(
|
||||
extract_customers: List[dict],
|
||||
extract_orders: List[dict],
|
||||
) -> List[dict]:
|
||||
extract_customers: list[dict[str, Any]],
|
||||
extract_orders: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
cmap = {c["id"]: c for c in extract_customers}
|
||||
return [
|
||||
{**o, "customer_name": cmap[o["customer_id"]]["name"]}
|
||||
for o in extract_orders
|
||||
if o["customer_id"] in cmap
|
||||
]
|
||||
return [{**o, "customer_name": cmap[o["customer_id"]]["name"]} for o in extract_orders if o["customer_id"] in cmap]
|
||||
|
||||
|
||||
def load(transform: List[dict]) -> int:
|
||||
def load(transform: list[dict[str, Any]]) -> int:
|
||||
print(f" loaded {len(transform)} records")
|
||||
return len(transform)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("extract_customers", extract_customers, tags=("extract",)),
|
||||
px.TaskSpec("extract_orders", extract_orders, tags=("extract",)),
|
||||
px.TaskSpec(
|
||||
"transform",
|
||||
transform,
|
||||
("extract_customers", "extract_orders"),
|
||||
tags=("transform",),
|
||||
),
|
||||
px.TaskSpec("load", load, ("transform",), retries=1, tags=("load",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("extract_customers", extract_customers, tags=("extract",)),
|
||||
px.TaskSpec("extract_orders", extract_orders, tags=("extract",)),
|
||||
px.TaskSpec(
|
||||
"transform",
|
||||
transform,
|
||||
depends_on=("extract_customers", "extract_orders"),
|
||||
tags=("transform",),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"load", load, depends_on=("transform",), retry=px.RetryPolicy(max_attempts=1, delay=1.0), tags=("load",)
|
||||
),
|
||||
])
|
||||
|
||||
print("=== Execution plan ===")
|
||||
print(graph.describe())
|
||||
|
||||
print("\n=== Dry run (no execution) ===")
|
||||
px.run(graph, strategy="sequential", dry_run=True)
|
||||
_ = px.run(graph, strategy="sequential", dry_run=True)
|
||||
|
||||
print("\n=== Sequential execution ===")
|
||||
report = px.run(graph, strategy="sequential")
|
||||
|
||||
@@ -29,13 +29,11 @@ def merge(fetch_a: str, fetch_b: str) -> str:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("fetch_a", fetch_a),
|
||||
px.TaskSpec("fetch_b", fetch_b),
|
||||
px.TaskSpec("merge", merge, ("fetch_a", "fetch_b")),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fetch_a", fetch_a),
|
||||
px.TaskSpec("fetch_b", fetch_b),
|
||||
px.TaskSpec("merge", merge, depends_on=("fetch_a", "fetch_b")),
|
||||
])
|
||||
|
||||
print("=== Mermaid diagram ===")
|
||||
print(graph.to_mermaid("LR"))
|
||||
|
||||
+728
-271
File diff suppressed because it is too large
Load Diff
+390
-126
@@ -1,31 +1,140 @@
|
||||
"""DAG 构建、校验、分层与可视化。
|
||||
|
||||
使用标准库的 :mod:`graphlib`(3.9+)或 :mod:`graphlib_backport`(3.8)
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非
|
||||
执行时)快速失败。
|
||||
进行拓扑排序。图以增量方式构建并即时校验,使配置错误在构建时(而非执行时)快速失败。
|
||||
|
||||
支持:
|
||||
* 图级默认值 :class:`GraphDefaults`,TaskSpec 字段为 ``None`` 时回退。
|
||||
* :meth:`Graph.map` 工厂批量生成 fan-out 任务。
|
||||
* 字符串引用与 :func:`compose` 编程式组合多个图。
|
||||
* 软依赖:仅用于上下文注入,不参与拓扑分层。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"Graph",
|
||||
"GraphDefaults",
|
||||
]
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any, Callable, Iterable, Mapping, Sequence
|
||||
|
||||
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
from .task import TaskSpec
|
||||
from .task import Context, RetryPolicy, TaskSpec
|
||||
|
||||
# graphlib 自 3.9 起进入标准库;3.8 回退到 backport。
|
||||
if sys.version_info >= (3, 9): # pragma: no cover
|
||||
import graphlib
|
||||
import graphlib # pyright: ignore[reportUnreachable]
|
||||
|
||||
_TopologicalSorter = graphlib.TopologicalSorter
|
||||
else: # pragma: no cover
|
||||
import graphlib # type: ignore[import-untyped] # pragma: no cover
|
||||
import graphlib # type: ignore[import-untyped]
|
||||
|
||||
_TopologicalSorter = graphlib.TopologicalSorter # pragma: no cover
|
||||
|
||||
|
||||
@dataclass
|
||||
class GraphDefaults:
|
||||
"""图级默认值。TaskSpec 对应字段为 ``None`` 时回退到此处。
|
||||
|
||||
仅对可空字段生效(retry/timeout/strategy/env/cwd/tags/priority/
|
||||
continue_on_error/concurrency_key)。非空字段(name/fn/cmd)不回退。
|
||||
"""
|
||||
|
||||
retry: RetryPolicy | None = None
|
||||
timeout: float | None = None
|
||||
strategy: str | None = None
|
||||
tags: tuple[str, ...] = ()
|
||||
env: Mapping[str, str] | None = None
|
||||
cwd: Any = None # Path | None
|
||||
priority: int = 0
|
||||
continue_on_error: bool = False
|
||||
concurrency_key: str | None = None
|
||||
verbose: bool = False
|
||||
|
||||
|
||||
def _prune_deps(spec: TaskSpec[Any], keep: Callable[[str], bool]) -> TaskSpec[Any]:
|
||||
"""返回新 spec,其 ``depends_on`` / ``soft_depends_on`` 仅保留 ``keep(dep)`` 为真的依赖。"""
|
||||
return replace(
|
||||
spec,
|
||||
depends_on=tuple(d for d in spec.depends_on if keep(d)),
|
||||
soft_depends_on=tuple(d for d in spec.soft_depends_on if keep(d)),
|
||||
)
|
||||
|
||||
|
||||
def _make_namespaced_fn(orig_fn: Any, ns: str, dep_names: set[str]) -> Any:
|
||||
"""包装 fn,使其能接收带 ``ns:`` 前缀的依赖名,调用时映射回原参数名。
|
||||
|
||||
命名空间合并后,依赖名带前缀(如 ``build:extract``),但 Python 参数名
|
||||
不能含 ``:``。wrapper 用 ``**kwargs`` 接收所有依赖,内部把带前缀的依赖名
|
||||
映射回原参数名后调用原 fn。
|
||||
|
||||
无依赖参数时直接返回原 fn。
|
||||
"""
|
||||
if not dep_names or orig_fn is None:
|
||||
return orig_fn
|
||||
try:
|
||||
orig_sig = inspect.signature(orig_fn)
|
||||
except (TypeError, ValueError):
|
||||
return orig_fn
|
||||
|
||||
# 带前缀依赖名 -> 原参数名
|
||||
name_map: dict[str, str] = {f"{ns}:{orig}": orig for orig in dep_names}
|
||||
prefix = f"{ns}:"
|
||||
|
||||
# 检查原 fn 是否有 Context 标注参数
|
||||
context_param_name: str | None = None
|
||||
for p in orig_sig.parameters.values():
|
||||
ann = p.annotation
|
||||
if ann is not Context and not (isinstance(ann, str) and ann.endswith("Context")):
|
||||
continue
|
||||
context_param_name = p.name
|
||||
break
|
||||
|
||||
if context_param_name is not None:
|
||||
|
||||
def wrapper(ctx: Any = None, **kwargs: Any) -> Any:
|
||||
# ctx 是 dep_context,键为带前缀的依赖名;映射回原始键
|
||||
orig_ctx: dict[str, Any] = {}
|
||||
for k, v in (ctx or {}).items():
|
||||
orig_ctx[name_map.get(k, k)] = v
|
||||
# kwargs 中带前缀的依赖也映射回原参数名
|
||||
for k, v in kwargs.items():
|
||||
if k in name_map:
|
||||
orig_ctx[name_map[k]] = v
|
||||
return orig_fn(**{context_param_name: orig_ctx})
|
||||
|
||||
ctx_param = inspect.Parameter("ctx", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Context)
|
||||
kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
|
||||
wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined]
|
||||
parameters=[ctx_param, kw_param],
|
||||
return_annotation=orig_sig.return_annotation,
|
||||
)
|
||||
else:
|
||||
|
||||
def wrapper(**kwargs: Any) -> Any: # type: ignore[no-redef]
|
||||
orig_kwargs: dict[str, Any] = {}
|
||||
for k, v in kwargs.items():
|
||||
if k.startswith(prefix):
|
||||
orig_kwargs[k[len(prefix) :]] = v
|
||||
return orig_fn(**orig_kwargs)
|
||||
|
||||
kw_param = inspect.Parameter("kwargs", inspect.Parameter.VAR_KEYWORD)
|
||||
wrapper.__signature__ = inspect.Signature( # type: ignore[attr-defined]
|
||||
parameters=[kw_param],
|
||||
return_annotation=orig_sig.return_annotation,
|
||||
)
|
||||
|
||||
wrapper.__name__ = f"{ns}_{getattr(orig_fn, '__name__', 'fn')}"
|
||||
wrapper.__doc__ = getattr(orig_fn, "__doc__", None)
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class Graph:
|
||||
"""校验后不可变的有向无环任务图。
|
||||
"""校验后的有向无环任务图。
|
||||
|
||||
通过添加 :class:`~pyflowx.task.TaskSpec` 实例构建。每次 ``add`` 都
|
||||
执行即时校验(重名、缺失依赖),:meth:`validate` / :meth:`layers`
|
||||
@@ -35,69 +144,157 @@ class Graph:
|
||||
这使图可安全重复运行并在线程间共享。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._specs: Dict[str, TaskSpec[object]] = {}
|
||||
# 任务 -> 其直接依赖(前驱)。
|
||||
self._deps: Dict[str, Tuple[str, ...]] = {}
|
||||
specs: dict[str, TaskSpec[Any]] = field(default_factory=dict)
|
||||
deps: dict[str, tuple[str, ...]] = field(default_factory=dict)
|
||||
defaults: GraphDefaults = field(default_factory=GraphDefaults)
|
||||
namespace: str | None = None
|
||||
|
||||
# 待解析的字符串引用列表(由 GraphComposer 消费);为空表示无引用。
|
||||
_pending_refs: list[str] = field(default_factory=list)
|
||||
|
||||
# resolved_spec 缓存:避免执行期每个任务多次重复 dataclasses.replace 判断。
|
||||
# 在 specs / defaults 变更时失效。
|
||||
_resolved_cache: dict[str, TaskSpec[Any]] = field(default_factory=dict)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 构建
|
||||
# ------------------------------------------------------------------ #
|
||||
def add(self, spec: TaskSpec[object]) -> "Graph":
|
||||
"""注册一个任务 spec,并即时校验。
|
||||
|
||||
返回 ``self`` 以支持链式调用,但推荐入口是 :meth:`from_specs`,
|
||||
它会整批校验(允许单次调用中的前向引用)。
|
||||
"""
|
||||
if spec.name in self._specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
self._specs[spec.name] = spec
|
||||
self._deps[spec.name] = spec.depends_on
|
||||
# 为增量 API 即时检查重名与缺失依赖。
|
||||
def add(self, spec: TaskSpec[Any]) -> Graph:
|
||||
"""注册一个任务 spec,并即时校验。返回 ``self`` 支持链式调用。"""
|
||||
self._register(spec)
|
||||
self._validate_references()
|
||||
return self
|
||||
|
||||
def chain(self, *specs: TaskSpec[Any]) -> Graph:
|
||||
"""链式注册任务:每个 spec 自动依赖前一个。
|
||||
|
||||
``chain(a, b, c)`` 等价于 ``b`` 依赖 ``a``,``c`` 依赖 ``b``。
|
||||
若 spec 已带 ``depends_on``,则前驱名追加到现有依赖前。
|
||||
返回 ``self`` 支持链式调用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> graph = px.Graph().chain(extract, transform, load)
|
||||
"""
|
||||
prev_name: str | None = None
|
||||
for s in specs:
|
||||
current = s
|
||||
if prev_name is not None:
|
||||
# 将前驱追加到 depends_on 最前(保持显式依赖优先)
|
||||
new_deps = (prev_name, *s.depends_on) if prev_name not in s.depends_on else s.depends_on
|
||||
current = replace(s, depends_on=new_deps)
|
||||
self.add(current)
|
||||
prev_name = current.name
|
||||
return self
|
||||
|
||||
def _register(self, spec: TaskSpec[Any]) -> None:
|
||||
if spec.name in self.specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
self.specs[spec.name] = spec
|
||||
# 拓扑依赖仅含硬依赖;软依赖仅用于注入,不影响分层。
|
||||
self.deps[spec.name] = spec.depends_on
|
||||
self._resolved_cache.clear()
|
||||
|
||||
@classmethod
|
||||
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
|
||||
def from_specs(
|
||||
cls,
|
||||
specs: Iterable[TaskSpec[Any] | str],
|
||||
defaults: GraphDefaults | None = None,
|
||||
*,
|
||||
namespace: str | None = None,
|
||||
) -> Graph:
|
||||
"""从可迭代的 task spec 构建图。
|
||||
|
||||
先收集所有 spec,再统一校验。这意味着任务可以引用*后出现*的
|
||||
依赖——顺序无关,就像声明式配置文件的读取方式。
|
||||
先收集所有 spec,再统一校验。允许前向引用。支持字符串引用,
|
||||
由 :func:`compose` 或 :class:`GraphComposer` 解析展开。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
specs:
|
||||
TaskSpec 对象或字符串引用的列表。
|
||||
defaults:
|
||||
图级默认值。``None`` 使用空 :class:`GraphDefaults`。
|
||||
namespace:
|
||||
可选命名空间,用于 :meth:`add_subgraph` 合并时加前缀。
|
||||
"""
|
||||
graph = cls()
|
||||
graph = cls(defaults=defaults or GraphDefaults(), namespace=namespace)
|
||||
pending_refs: list[str] = []
|
||||
|
||||
for spec in specs:
|
||||
if spec.name in graph._specs:
|
||||
raise DuplicateTaskError(spec.name)
|
||||
graph._specs[spec.name] = spec
|
||||
graph._deps[spec.name] = spec.depends_on
|
||||
if isinstance(spec, str):
|
||||
pending_refs.append(spec)
|
||||
elif isinstance(spec, TaskSpec):
|
||||
graph._register(spec)
|
||||
else:
|
||||
raise TypeError(f"from_specs 只接受 TaskSpec 或 str,收到: {type(spec)}")
|
||||
|
||||
if pending_refs:
|
||||
graph._pending_refs = pending_refs
|
||||
|
||||
graph._validate_references()
|
||||
graph.validate()
|
||||
return graph
|
||||
|
||||
def add_subgraph(self, sub: Graph, *, namespace: str | None = None) -> Graph:
|
||||
"""将子图合并到当前图,任务名加命名空间前缀避免冲突。
|
||||
|
||||
参数
|
||||
----
|
||||
sub:
|
||||
待合并的子图。
|
||||
namespace:
|
||||
命名空间前缀。``None`` 时使用 ``sub.namespace``,若子图也无命名空间
|
||||
则抛出 ``ValueError``。最终任务名为 ``f"{ns}:{original_name}"``。
|
||||
|
||||
合并后,子图内任务的依赖名也会被加前缀;与子图外部任务的依赖保持原样。
|
||||
|
||||
返回 ``self`` 支持链式调用。
|
||||
"""
|
||||
ns = namespace or sub.namespace
|
||||
if not ns:
|
||||
raise ValueError("add_subgraph 需要 namespace 或子图自带 namespace")
|
||||
|
||||
def _rename(name: str) -> str:
|
||||
# 仅对子图内部任务名加前缀;外部依赖保持原样
|
||||
return f"{ns}:{name}" if name in sub.specs else name
|
||||
|
||||
sub_names = set(sub.specs.keys())
|
||||
for spec in sub.specs.values():
|
||||
# 子图内部依赖名需加前缀,对应的 fn 参数也需包装
|
||||
internal_deps = (set(spec.depends_on) | set(spec.soft_depends_on)) & sub_names
|
||||
new_fn = _make_namespaced_fn(spec.fn, ns, internal_deps) if spec.fn else spec.fn
|
||||
new_spec = replace(
|
||||
spec,
|
||||
name=_rename(spec.name),
|
||||
fn=new_fn,
|
||||
depends_on=tuple(_rename(d) for d in spec.depends_on),
|
||||
soft_depends_on=tuple(_rename(d) for d in spec.soft_depends_on),
|
||||
)
|
||||
self._register(new_spec)
|
||||
self._validate_references()
|
||||
self.validate()
|
||||
return self
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 校验
|
||||
# ------------------------------------------------------------------ #
|
||||
def _validate_references(self) -> None:
|
||||
"""确保每个依赖名都存在于图中。"""
|
||||
for name, deps in self._deps.items():
|
||||
for dep in deps:
|
||||
if dep not in self._specs:
|
||||
"""确保每个依赖名都存在于图中。硬依赖与软依赖都校验。"""
|
||||
for name, spec in self.specs.items():
|
||||
for dep in spec.depends_on:
|
||||
if dep not in self.specs:
|
||||
raise MissingDependencyError(name, dep)
|
||||
for dep in spec.soft_depends_on:
|
||||
if dep not in self.specs:
|
||||
raise MissingDependencyError(name, dep)
|
||||
|
||||
def validate(self) -> None:
|
||||
"""执行完整 DAG 校验。
|
||||
|
||||
存在环时抛出 :class:`~pyflowx.errors.CycleError`。
|
||||
依赖存在性由 :meth:`_validate_references` 检查。
|
||||
"""
|
||||
"""执行完整 DAG 校验。存在环时抛出 :class:`CycleError`。"""
|
||||
self._validate_references()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
try:
|
||||
# prepare() 在有环时抛出 CycleError;此处不需要
|
||||
# static_order() 的结果,仅利用其校验副作用。
|
||||
sorter.prepare()
|
||||
except graphlib.CycleError as exc:
|
||||
# exc.args[1] 是构成环的节点列表。
|
||||
except graphlib.CycleError as exc: # type: ignore[name-defined]
|
||||
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
|
||||
raise CycleError(list(cycle)) from exc
|
||||
|
||||
@@ -105,37 +302,81 @@ class Graph:
|
||||
# 内省
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def names(self) -> List[str]:
|
||||
def names(self) -> list[str]:
|
||||
"""所有已注册任务名(按插入顺序)。"""
|
||||
return list(self._specs.keys())
|
||||
return list(self.specs.keys())
|
||||
|
||||
def spec(self, name: str) -> TaskSpec[object]:
|
||||
def spec(self, name: str) -> TaskSpec[Any]:
|
||||
"""返回 ``name`` 的 spec;不存在则 ``KeyError``。"""
|
||||
return self._specs[name]
|
||||
return self.specs[name]
|
||||
|
||||
def dependencies(self, name: str) -> Tuple[str, ...]:
|
||||
"""``name`` 的直接前驱。"""
|
||||
return self._deps[name]
|
||||
def resolved_spec(self, name: str) -> TaskSpec[Any]:
|
||||
"""返回应用图级默认值后的 spec(不修改原图)。
|
||||
|
||||
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
|
||||
对于 ``retry``/``timeout``/``strategy``/``env``/``cwd`` 等可空
|
||||
字段,若 spec 字段为默认空值且图级默认值非空,则用
|
||||
:func:`dataclasses.replace` 生成带默认值的副本。
|
||||
|
||||
结果按 ``name`` 缓存;specs / defaults 变更时缓存失效。
|
||||
"""
|
||||
cached = self._resolved_cache.get(name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
spec = self.specs[name]
|
||||
d = self.defaults
|
||||
overrides: dict[str, Any] = {}
|
||||
if spec.retry == RetryPolicy() and d.retry is not None:
|
||||
overrides["retry"] = d.retry
|
||||
if spec.timeout is None and d.timeout is not None:
|
||||
overrides["timeout"] = d.timeout
|
||||
if spec.strategy is None and d.strategy is not None:
|
||||
overrides["strategy"] = d.strategy
|
||||
if spec.env is None and d.env is not None:
|
||||
overrides["env"] = d.env
|
||||
if spec.cwd is None and d.cwd is not None:
|
||||
overrides["cwd"] = d.cwd
|
||||
if spec.priority == 0 and d.priority != 0:
|
||||
overrides["priority"] = d.priority
|
||||
if not spec.continue_on_error and d.continue_on_error:
|
||||
overrides["continue_on_error"] = True
|
||||
if spec.concurrency_key is None and d.concurrency_key is not None:
|
||||
overrides["concurrency_key"] = d.concurrency_key
|
||||
if not spec.verbose and d.verbose:
|
||||
overrides["verbose"] = True
|
||||
if not spec.tags and d.tags:
|
||||
overrides["tags"] = d.tags
|
||||
resolved = spec if not overrides else replace(spec, **overrides)
|
||||
self._resolved_cache[name] = resolved
|
||||
return resolved
|
||||
|
||||
def dependencies(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的直接硬依赖前驱。"""
|
||||
return self.deps[name]
|
||||
|
||||
def all_deps(self, name: str) -> tuple[str, ...]:
|
||||
"""``name`` 的硬依赖 + 软依赖。"""
|
||||
spec = self.specs[name]
|
||||
return tuple(spec.depends_on) + tuple(spec.soft_depends_on)
|
||||
|
||||
def all_specs(self) -> Mapping[str, TaskSpec[Any]]:
|
||||
"""name -> spec 的只读视图。"""
|
||||
return self._specs
|
||||
return self.specs
|
||||
|
||||
def layers(self) -> List[List[str]]:
|
||||
def layers(self) -> list[list[str]]:
|
||||
"""将任务分组为可并行执行的层(Kahn 算法)。
|
||||
|
||||
同层任务无相互依赖,可并发执行。层按执行顺序返回。
|
||||
同层任务无相互硬依赖,可并发执行。软依赖不参与分层。
|
||||
层按执行顺序返回。图有环时抛出 :class:`CycleError`。
|
||||
|
||||
图有环时抛出 :class:`~pyflowx.errors.CycleError`。
|
||||
.. note::
|
||||
本方法假定图已通过 :meth:`validate` 校验(由 :func:`pyflowx.run`
|
||||
在入口统一执行一次)。若直接调用本方法,需自行先校验。
|
||||
"""
|
||||
self.validate()
|
||||
sorter = _TopologicalSorter(self._deps)
|
||||
result: List[List[str]] = []
|
||||
# ``get_ready`` + ``done`` 每次给出一层,正好是并行执行所需的分组。
|
||||
sorter = _TopologicalSorter(self.deps)
|
||||
result: list[list[str]] = []
|
||||
sorter.prepare()
|
||||
while sorter.is_active():
|
||||
ready = list(sorter.get_ready())
|
||||
# 排序以保证确定性、可复现的执行计划。
|
||||
ready.sort()
|
||||
result.append(ready)
|
||||
for node in ready:
|
||||
@@ -145,81 +386,104 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
# 子图 / 标签过滤
|
||||
# ------------------------------------------------------------------ #
|
||||
def subgraph(self, tags: Iterable[str]) -> "Graph":
|
||||
"""返回仅包含匹配任意标签的任务的新图。
|
||||
def subgraph(self, tags: Iterable[str]) -> Graph:
|
||||
"""返回仅包含匹配任意标签的任务的新图。依赖边被修剪。"""
|
||||
wanted: set[str] = set(tags)
|
||||
|
||||
依赖会被修剪,仅保留被保留任务之间的边;指向被丢弃任务的边
|
||||
会被移除(被保留的任务不再等待它们)。用于调试时运行大型
|
||||
DAG 的切片。
|
||||
"""
|
||||
wanted: Set[str] = set(tags)
|
||||
kept: List[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
if wanted & set(spec.tags):
|
||||
pruned_deps = tuple(
|
||||
d
|
||||
for d in spec.depends_on
|
||||
if d in self._specs and (wanted & set(self._specs[d].tags))
|
||||
)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
name=spec.name,
|
||||
fn=spec.fn,
|
||||
depends_on=pruned_deps,
|
||||
args=spec.args,
|
||||
kwargs=spec.kwargs,
|
||||
retries=spec.retries,
|
||||
timeout=spec.timeout,
|
||||
tags=spec.tags,
|
||||
)
|
||||
)
|
||||
return Graph.from_specs(kept)
|
||||
def _dep_kept(dep: str) -> bool:
|
||||
return dep in self.specs and bool(wanted & set(self.specs[dep].tags))
|
||||
|
||||
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
|
||||
kept: list[TaskSpec[Any]] = [
|
||||
_prune_deps(spec, _dep_kept) for spec in self.specs.values() if wanted & set(spec.tags)
|
||||
]
|
||||
return Graph.from_specs(kept, defaults=self.defaults)
|
||||
|
||||
def subgraph_by_names(self, names: Iterable[str]) -> Graph:
|
||||
"""返回限定于 ``names`` 的新图(边已修剪)。"""
|
||||
wanted: Set[str] = set(names)
|
||||
wanted: set[str] = set(names)
|
||||
for n in wanted:
|
||||
if n not in self._specs:
|
||||
if n not in self.specs:
|
||||
raise KeyError(f"Unknown task name: {n!r}")
|
||||
kept: List[TaskSpec[object]] = []
|
||||
for spec in self._specs.values():
|
||||
if spec.name in wanted:
|
||||
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
|
||||
kept.append(
|
||||
TaskSpec(
|
||||
name=spec.name,
|
||||
fn=spec.fn,
|
||||
depends_on=pruned_deps,
|
||||
args=spec.args,
|
||||
kwargs=spec.kwargs,
|
||||
retries=spec.retries,
|
||||
timeout=spec.timeout,
|
||||
tags=spec.tags,
|
||||
)
|
||||
)
|
||||
return Graph.from_specs(kept)
|
||||
kept: list[TaskSpec[Any]] = [
|
||||
_prune_deps(spec, lambda d: d in wanted) for spec in self.specs.values() if spec.name in wanted
|
||||
]
|
||||
return Graph.from_specs(kept, defaults=self.defaults)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Fan-out / map-reduce
|
||||
# ------------------------------------------------------------------ #
|
||||
def map(
|
||||
self,
|
||||
name_fn: Callable[[int], str],
|
||||
spec: TaskSpec[Any],
|
||||
items: Sequence[Any],
|
||||
arg_factory: Callable[[Any], tuple[Any, ...]] | None = None,
|
||||
depends_on_per: Callable[[int], tuple[str, ...]] | None = None,
|
||||
) -> list[TaskSpec[Any]]:
|
||||
"""为 ``items`` 中每个元素生成一个 TaskSpec 并加入图。
|
||||
|
||||
用于 fan-out / map-reduce 模式。返回生成的 spec 列表,便于
|
||||
后续 reduce 任务依赖。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name_fn:
|
||||
接受索引 ``i``,返回任务名。需保证唯一。
|
||||
spec:
|
||||
模板 spec。其 ``name`` 与 ``args`` 会被覆盖。
|
||||
items:
|
||||
待分发的数据序列。
|
||||
arg_factory:
|
||||
接受一个 item,返回位置参数元组,覆盖 spec.args。
|
||||
``None`` 则将单个 item 作为唯一位置参数。
|
||||
depends_on_per:
|
||||
接受索引 ``i``,返回该任务的额外硬依赖。``None`` 则继承 spec.depends_on。
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[TaskSpec]
|
||||
生成的 spec 列表(已加入图)。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> fetch_tmpl = px.TaskSpec("", fn=fetch_user)
|
||||
>>> specs = graph.map(lambda i: f"fetch_{i}", fetch_tmpl, [1, 2, 3])
|
||||
>>> reduce_spec = px.TaskSpec("reduce", fn=reduce_fn, depends_on=tuple(s.name for s in specs))
|
||||
"""
|
||||
generated: list[TaskSpec[Any]] = []
|
||||
for i, item in enumerate(items):
|
||||
name = name_fn(i)
|
||||
args = arg_factory(item) if arg_factory is not None else (item,)
|
||||
extra_deps = depends_on_per(i) if depends_on_per is not None else ()
|
||||
new_spec = replace(
|
||||
spec,
|
||||
name=name,
|
||||
args=tuple(args),
|
||||
depends_on=tuple(spec.depends_on) + tuple(extra_deps),
|
||||
)
|
||||
self.add(new_spec)
|
||||
generated.append(new_spec)
|
||||
return generated
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 可视化
|
||||
# ------------------------------------------------------------------ #
|
||||
def to_mermaid(self, orientation: str = "TD") -> str:
|
||||
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。
|
||||
|
||||
无外部依赖;输出可粘贴到 Markdown、由 VS Code 的 Mermaid 预览
|
||||
渲染,或保存为文件。
|
||||
"""
|
||||
"""将 DAG 渲染为 Mermaid ``graph`` 定义字符串。"""
|
||||
valid = {"TD", "TB", "BT", "LR", "RL"}
|
||||
orientation = orientation.upper()
|
||||
if orientation not in valid:
|
||||
raise ValueError(
|
||||
f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}."
|
||||
)
|
||||
lines: List[str] = [f"graph {orientation}"]
|
||||
for name in self._specs:
|
||||
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
|
||||
lines: list[str] = [f"graph {orientation}"]
|
||||
for name in self.specs:
|
||||
lines.append(f' {name}["{name}"]')
|
||||
for name, deps in self._deps.items():
|
||||
for name, deps in self.deps.items():
|
||||
for dep in deps:
|
||||
lines.append(f" {dep} --> {name}")
|
||||
# 软依赖用虚线
|
||||
for name, spec in self.specs.items():
|
||||
for dep in spec.soft_depends_on:
|
||||
lines.append(f" {dep} -.-> {name}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -227,16 +491,16 @@ class Graph:
|
||||
# ------------------------------------------------------------------ #
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行摘要。"""
|
||||
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
|
||||
out: list[str] = [f"Graph(tasks={len(self.specs)})"]
|
||||
for layer_idx, layer in enumerate(self.layers(), 1):
|
||||
out.append(f" Layer {layer_idx}: {layer}")
|
||||
return "\n".join(out)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Graph(tasks={len(self._specs)})"
|
||||
return f"Graph(tasks={len(self.specs)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._specs)
|
||||
return len(self.specs)
|
||||
|
||||
def __contains__(self, name: object) -> bool:
|
||||
return name in self._specs
|
||||
def __contains__(self, name: Any) -> bool:
|
||||
return name in self.specs
|
||||
|
||||
+26
-14
@@ -7,7 +7,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Iterator, List
|
||||
from typing import Any, Iterator
|
||||
|
||||
from .task import TaskResult, TaskStatus
|
||||
|
||||
@@ -24,7 +24,7 @@ class RunReport:
|
||||
当且仅当所有非跳过任务都以 ``SUCCESS`` 结束时为 ``True``。
|
||||
"""
|
||||
|
||||
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
|
||||
results: dict[str, TaskResult[Any]] = field(default_factory=dict)
|
||||
success: bool = True
|
||||
|
||||
# ---- 类型化访问 --------------------------------------------------- #
|
||||
@@ -36,11 +36,11 @@ class RunReport:
|
||||
"""
|
||||
return self.results[name].value
|
||||
|
||||
def result_of(self, name: str) -> TaskResult[object]:
|
||||
def result_of(self, name: str) -> TaskResult[Any]:
|
||||
"""返回 ``name`` 的完整 :class:`TaskResult`。"""
|
||||
return self.results[name]
|
||||
|
||||
def __contains__(self, name: object) -> bool:
|
||||
def __contains__(self, name: Any) -> bool:
|
||||
return name in self.results
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
@@ -50,9 +50,9 @@ class RunReport:
|
||||
return len(self.results)
|
||||
|
||||
# ---- 汇总 --------------------------------------------------------- #
|
||||
def summary(self) -> Dict[str, Any]:
|
||||
def summary(self) -> dict[str, Any]:
|
||||
"""用于日志/仪表盘的紧凑统计字典。"""
|
||||
counts: Dict[str, int] = {}
|
||||
counts: dict[str, int] = {}
|
||||
total_duration = 0.0
|
||||
for r in self.results.values():
|
||||
counts[r.status.value] = counts.get(r.status.value, 0) + 1
|
||||
@@ -65,19 +65,31 @@ class RunReport:
|
||||
"total_duration_seconds": round(total_duration, 6),
|
||||
}
|
||||
|
||||
def failed_tasks(self) -> List[str]:
|
||||
def failed_tasks(self) -> list[str]:
|
||||
"""以 FAILED 状态结束的任务名列表。"""
|
||||
return [
|
||||
name for name, r in self.results.items() if r.status == TaskStatus.FAILED
|
||||
]
|
||||
return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED]
|
||||
|
||||
def succeeded_tasks(self) -> list[str]:
|
||||
"""以 SUCCESS 状态结束的任务名列表。"""
|
||||
return [name for name, r in self.results.items() if r.status == TaskStatus.SUCCESS]
|
||||
|
||||
def skipped_tasks(self) -> list[str]:
|
||||
"""以 SKIPPED 状态结束的任务名列表。"""
|
||||
return [name for name, r in self.results.items() if r.status == TaskStatus.SKIPPED]
|
||||
|
||||
def tasks_by_status(self, status: TaskStatus) -> list[str]:
|
||||
"""返回指定状态的任务名列表。"""
|
||||
return [name for name, r in self.results.items() if r.status == status]
|
||||
|
||||
def durations(self) -> dict[str, float]:
|
||||
"""任务名 -> 执行时长(秒)。无时长记录的为 0.0。"""
|
||||
return {name: (r.duration or 0.0) for name, r in self.results.items()}
|
||||
|
||||
def describe(self) -> str:
|
||||
"""用于调试的人类可读多行报告。"""
|
||||
lines: List[str] = [f"RunReport(success={self.success})"]
|
||||
lines: list[str] = [f"RunReport(success={self.success})"]
|
||||
for name, r in self.results.items():
|
||||
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
|
||||
err = f" error={r.error!r}" if r.error else ""
|
||||
lines.append(
|
||||
f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}"
|
||||
)
|
||||
lines.append(f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}")
|
||||
return "\n".join(lines)
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
"""命令行运行器:根据用户输入执行对应的任务流图.
|
||||
|
||||
verbose 模式
|
||||
------------
|
||||
``CliRunner`` 默认 ``verbose=True``, 会:
|
||||
1. 打印任务生命周期 (开始/成功/失败/跳过) 到 stdout
|
||||
2. 对 ``cmd`` 类任务, 显示执行的命令及其标准输出/标准错误
|
||||
|
||||
可通过构造参数 ``verbose=False`` 或命令行 ``--quiet`` 关闭.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import enum
|
||||
import sys
|
||||
from dataclasses import dataclass, field, replace
|
||||
from pathlib import Path
|
||||
from typing import Any, Sequence, get_args
|
||||
|
||||
from .compose import GraphComposer
|
||||
from .errors import PyFlowXError
|
||||
from .executors import Strategy, run
|
||||
from .graph import Graph
|
||||
from .task import TaskSpec
|
||||
|
||||
__all__ = ["CliExitCode", "CliRunner"]
|
||||
|
||||
|
||||
class CliExitCode(enum.IntEnum):
|
||||
"""CliRunner 退出码."""
|
||||
|
||||
SUCCESS = 0
|
||||
FAILURE = 1
|
||||
INTERRUPTED = 130 # 与 POSIX 信号中断一致
|
||||
|
||||
|
||||
def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
|
||||
"""创建新图, 其中所有 TaskSpec 的 verbose 字段被设置为指定值.
|
||||
|
||||
使用 ``dataclasses.replace`` 在不可变的 TaskSpec 上创建带 verbose 标记的副本.
|
||||
依赖关系、标签等元数据全部保留.
|
||||
|
||||
Note
|
||||
-----
|
||||
自 ``_wrap_cmd`` 不再闭包捕获 ``verbose`` 后,此函数不再是必需的——
|
||||
直接翻转 ``spec.verbose`` 即可生效。保留是为了向后兼容现有调用与测试。
|
||||
TaskSpec 仍是 frozen dataclass,故仍用 ``replace`` 创建副本。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
graph : Graph
|
||||
原始图.
|
||||
verbose : bool
|
||||
要设置的 verbose 值.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
所有 spec 的 verbose 字段已更新的新图.
|
||||
"""
|
||||
new_specs: list[TaskSpec[Any]] = []
|
||||
for spec in graph.all_specs().values():
|
||||
if spec.verbose == verbose:
|
||||
new_specs.append(spec)
|
||||
else:
|
||||
new_specs.append(replace(spec, verbose=verbose))
|
||||
return Graph.from_specs(new_specs)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CliRunner:
|
||||
"""命令行运行器: 根据用户输入执行对应的任务流图.
|
||||
|
||||
将命令别名映射到 Graph 实例. 通过 ``sys.argv`` 解析用户输入的命令,
|
||||
执行对应的图.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
aliases : dict[str, str | list[str] | Graph]
|
||||
命令别名到任务引用的映射. 每个值可以是:
|
||||
* ``str`` —— 单个任务名 (引用 ``tasks`` 中注册的任务),
|
||||
生成单任务图.
|
||||
* ``list[str]`` —— 任务名列表, 自动 :meth:`Graph.chain` 建立链式依赖,
|
||||
即后一个任务依赖前一个.
|
||||
* :class:`~pyflowx.graph.Graph` —— 直接使用该图 (用于复杂场景, 如
|
||||
自定义 ``conditions``、并行分支等).
|
||||
tasks : list[TaskSpec]
|
||||
扁平注册的任务列表. ``aliases`` 中的字符串引用这些任务名.
|
||||
未被任何 alias 引用的任务不会被执行.
|
||||
strategy : str | Strategy
|
||||
默认执行策略. 可被命令行 ``--strategy`` 覆盖.
|
||||
description : str
|
||||
CLI 帮助文本.
|
||||
verbose : bool
|
||||
是否显示详细执行过程. 默认 ``True``, 可被命令行 ``--quiet`` 关闭.
|
||||
|
||||
Examples
|
||||
--------
|
||||
简单场景 (tasks + aliases)::
|
||||
|
||||
runner = px.CliRunner(
|
||||
tasks=[
|
||||
px.cmd(["uv", "build"]), # name="uv_build"
|
||||
px.cmd(["maturin", "build"], name="maturin_build"),
|
||||
px.cmd(["ruff", "check", "--fix"], name="lint"),
|
||||
],
|
||||
aliases={
|
||||
"b": "uv_build",
|
||||
"ba": ["uv_build", "maturin_build"], # chain: maturin 依赖 uv
|
||||
"lint": "lint",
|
||||
},
|
||||
)
|
||||
runner.run()
|
||||
|
||||
复杂场景 (直接用 Graph)::
|
||||
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"a": px.Graph.from_specs([
|
||||
px.TaskSpec("add", cmd=["git", "add", "."], conditions=(...)),
|
||||
px.TaskSpec("commit", cmd=["git", "commit"], depends_on=("add",)),
|
||||
]),
|
||||
},
|
||||
)
|
||||
"""
|
||||
|
||||
aliases: dict[str, str | list[str | TaskSpec[Any]] | TaskSpec[Any] | Graph] = field(default_factory=dict)
|
||||
tasks: list[TaskSpec[Any]] = field(default_factory=list)
|
||||
strategy: Strategy = field(default="dependency")
|
||||
description: str = field(default_factory=str)
|
||||
verbose: bool = field(default_factory=lambda: True)
|
||||
# 解析后的命令→图映射,__post_init__ 填充
|
||||
graphs: dict[str, Graph] = field(default_factory=dict, init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.aliases:
|
||||
raise ValueError("CliRunner 至少需要一个别名 (通过 aliases= 提供)")
|
||||
|
||||
# 1. 把 tasks 注册为虚拟命令图(每个 task 一个图),加入 raw_graphs
|
||||
# 使 GraphComposer 能解析对它们的字符串引用
|
||||
raw_graphs: dict[str, Graph] = {}
|
||||
for spec in self.tasks:
|
||||
if spec.name in raw_graphs:
|
||||
raise ValueError(f"任务名重复: {spec.name!r}")
|
||||
raw_graphs[spec.name] = Graph.from_specs([spec])
|
||||
|
||||
# 2. 把每个 alias 转为 Graph(alias 名可与 task 名相同,覆盖 task 注册)
|
||||
for alias, value in self.aliases.items():
|
||||
raw_graphs[alias] = self._alias_to_graph(alias, value)
|
||||
|
||||
# 3. 解析图间字符串引用(str / list[str] 引用其他 alias 或任务)
|
||||
self.graphs = GraphComposer(raw_graphs).resolve_all()
|
||||
|
||||
@staticmethod
|
||||
def _alias_to_graph(
|
||||
alias: str,
|
||||
value: str | list[str | TaskSpec[Any]] | TaskSpec[Any] | Graph,
|
||||
) -> Graph:
|
||||
"""把 alias 的值转换为 Graph.
|
||||
|
||||
* ``str`` —— 对其他 alias 或已注册任务名的引用, 由 GraphComposer 展开.
|
||||
* ``TaskSpec`` —— 单个内联任务, 生成单任务图.
|
||||
* ``list[str | TaskSpec]`` —— 引用/任务混合列表, GraphComposer 展开时
|
||||
自动让后续引用依赖前面 (chain 语义). 元素为 alias 名、任务名或
|
||||
:class:`TaskSpec` 对象 (内联任务).
|
||||
* ``Graph`` —— 原样返回 (用于复杂场景: conditions、并行分支等).
|
||||
"""
|
||||
if isinstance(value, Graph):
|
||||
return value
|
||||
if isinstance(value, TaskSpec):
|
||||
return Graph.from_specs([value])
|
||||
if isinstance(value, str):
|
||||
# 字符串引用,用 _pending_refs 占位,GraphComposer 后续展开
|
||||
return Graph.from_specs([value]) # type: ignore[arg-type]
|
||||
if isinstance(value, list):
|
||||
if not value:
|
||||
raise ValueError(f"别名 {alias!r} 的任务列表为空")
|
||||
for item in value:
|
||||
if not isinstance(item, (str, TaskSpec)):
|
||||
raise TypeError(f"别名 {alias!r} 的列表元素类型无效: {type(item).__name__}, 预期 str 或 TaskSpec")
|
||||
# str/TaskSpec 混合列表,由 GraphComposer 展开(自动建立 chain 依赖)
|
||||
return Graph.from_specs(value)
|
||||
raise TypeError(
|
||||
f"别名 {alias!r} 的值类型无效: {type(value).__name__}, 预期 str/TaskSpec/list[str|TaskSpec]/Graph"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 内省
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def commands(self) -> list[str]:
|
||||
"""可用的命令列表 (按 aliases 定义顺序, 不含 tasks 中未引用的任务)."""
|
||||
return list(self.aliases.keys())
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 参数解析
|
||||
# ------------------------------------------------------------------ #
|
||||
def _prog_name(self) -> str:
|
||||
"""从 sys.argv[0] 推导程序名."""
|
||||
return Path(sys.argv[0]).name if sys.argv else "pyflowx"
|
||||
|
||||
def create_parser(self) -> argparse.ArgumentParser:
|
||||
"""创建参数解析器.
|
||||
|
||||
子类可覆盖此方法以添加自定义参数. 覆盖时应保留 ``command``
|
||||
位置参数与 ``--strategy`` / ``--dry-run`` / ``--list`` / ``--quiet``
|
||||
选项, 否则 :meth:`run` 的默认逻辑可能失效.
|
||||
|
||||
Returns
|
||||
-------
|
||||
argparse.ArgumentParser
|
||||
新创建的参数解析器实例.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
prog=self._prog_name(),
|
||||
description=self.description or "PyFlowX CLI Runner",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog=self._format_commands_help(),
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"command",
|
||||
nargs="?",
|
||||
help="要执行的命令",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--strategy",
|
||||
choices=list(get_args(Strategy)),
|
||||
default=self.strategy,
|
||||
help="执行策略 (默认: %(default)s)",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="只打印执行计划, 不实际运行",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--list",
|
||||
action="store_true",
|
||||
help="列出所有可用命令",
|
||||
)
|
||||
_ = parser.add_argument(
|
||||
"--quiet",
|
||||
action="store_true",
|
||||
help="静默模式, 不显示执行过程 (覆盖默认 verbose)",
|
||||
)
|
||||
return parser
|
||||
|
||||
def _format_commands_help(self) -> str:
|
||||
"""格式化命令帮助文本."""
|
||||
return "可用命令:\n" + " | ".join(self.graphs.keys())
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 执行
|
||||
# ------------------------------------------------------------------ #
|
||||
def run(self, args: Sequence[str] | None = None) -> int:
|
||||
"""解析参数并执行对应的图.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : Sequence[str] | None
|
||||
参数列表, 默认使用 ``sys.argv[1:]``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
退出码 (0 成功, 1 失败, 130 中断).
|
||||
|
||||
Raises
|
||||
------
|
||||
SystemExit
|
||||
当 argparse 无法解析参数时 (与标准 argparse 行为一致).
|
||||
"""
|
||||
parser = self.create_parser()
|
||||
parsed = parser.parse_args(args)
|
||||
|
||||
# --list: 列出命令
|
||||
if parsed.list:
|
||||
print(self._format_commands_help())
|
||||
return CliExitCode.SUCCESS.value
|
||||
|
||||
# 无命令: 显示帮助
|
||||
if not parsed.command:
|
||||
parser.print_help()
|
||||
return CliExitCode.FAILURE.value
|
||||
|
||||
# 验证命令(必须是已注册的 alias,不接受裸任务名)
|
||||
if parsed.command not in self.aliases:
|
||||
available = ", ".join(self.commands)
|
||||
print(
|
||||
f"错误: 未知命令 {parsed.command!r} (可用命令: {available})",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return CliExitCode.FAILURE.value
|
||||
|
||||
# 确定是否 verbose: --quiet 覆盖默认值
|
||||
verbose = self.verbose and not parsed.quiet
|
||||
|
||||
# 对图应用 verbose 设置 (重建带 verbose 标记的 spec)
|
||||
graph = self.graphs[parsed.command]
|
||||
if verbose:
|
||||
graph = _apply_verbose_to_graph(graph, verbose=True)
|
||||
|
||||
# 执行对应的图
|
||||
try:
|
||||
report = run(
|
||||
graph,
|
||||
strategy=parsed.strategy,
|
||||
dry_run=parsed.dry_run,
|
||||
verbose=verbose,
|
||||
)
|
||||
return CliExitCode.SUCCESS.value if report.success else CliExitCode.FAILURE.value
|
||||
except KeyboardInterrupt:
|
||||
print("\n操作已取消", file=sys.stderr)
|
||||
return CliExitCode.INTERRUPTED.value
|
||||
except PyFlowXError as e:
|
||||
print(f"错误: {e}", file=sys.stderr)
|
||||
return CliExitCode.FAILURE.value
|
||||
|
||||
def run_cli(self, args: Sequence[str] | None = None) -> None:
|
||||
"""运行并以退出码退出进程.
|
||||
|
||||
作为 CLI 工具运行时的入口点, 等价于 ``sys.exit(self.run(args))``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : Sequence[str] | None
|
||||
参数列表, 默认使用 ``sys.argv[1:]``.
|
||||
"""
|
||||
sys.exit(self.run(args))
|
||||
+206
-57
@@ -4,93 +4,213 @@
|
||||
执行器向后端查询某任务是否已有存储结果;若有则跳过该任务,并将其
|
||||
存储值注入下游任务。
|
||||
|
||||
本模块刻意保持最小化:仅持久化*成功*结果(失败任务会重跑),存储
|
||||
形态为扁平的 ``{task_name: result}`` 映射。内置两个后端:
|
||||
存储键由 :meth:`TaskSpec.storage_key` 计算,默认为任务名;若任务配置
|
||||
了 ``cache_key``,则键为 ``"name:cache_key_value"``,使不同输入产生
|
||||
独立缓存条目。
|
||||
|
||||
* :class:`MemoryBackend` —— 快速、进程内、无 I/O。默认。
|
||||
* :class:`JSONBackend` —— 持久化到 JSON 文件,支持跨进程续跑。
|
||||
|
||||
两者均零依赖(``json`` 为标准库)。用户可子类化
|
||||
:class:`StateBackend` 接入 SQLite、Redis 等。
|
||||
支持 TTL:``has`` 在条目过期时返回 ``False``。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, ContextManager, Mapping
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override # pragma: no cover
|
||||
|
||||
from .errors import StorageError
|
||||
|
||||
|
||||
class StateBackend(ABC):
|
||||
"""可续跑状态存储的抽象基类。"""
|
||||
"""可续跑状态存储的抽象基类。
|
||||
|
||||
所有方法以 ``key`` 为参数(通常为任务名或 ``name:cache_key``)。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> Mapping[str, Any]:
|
||||
"""返回完整的存储映射(可能为空)。"""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, name: str, value: Any) -> None:
|
||||
def save(self, key: str, value: Any) -> None:
|
||||
"""持久化单个任务的成功结果。"""
|
||||
|
||||
@abstractmethod
|
||||
def has(self, name: str) -> bool:
|
||||
"""``name`` 是否已有存储结果。"""
|
||||
def has(self, key: str) -> bool:
|
||||
"""``key`` 是否已有未过期的存储结果。"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, name: str) -> Any:
|
||||
"""返回 ``name`` 的存储结果(不存在则抛 ``KeyError``)。"""
|
||||
def get(self, key: str) -> Any:
|
||||
"""返回 ``key`` 的存储结果(不存在则抛 ``KeyError``)。"""
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""清除所有存储状态。"""
|
||||
|
||||
def flush(self) -> None: # noqa: B027
|
||||
"""将内存中暂存的状态持久化到外部介质。
|
||||
|
||||
class MemoryBackend(StateBackend):
|
||||
"""进程内 dict 后端。进程退出即丢失。"""
|
||||
默认无操作(如 :class:`MemoryBackend` 无需落盘)。
|
||||
:class:`JSONBackend` 在 :meth:`batch` 期间会延迟落盘,需在退出时调用。
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._store: Dict[str, Any] = {}
|
||||
def batch(self) -> ContextManager[None]:
|
||||
"""返回一个上下文管理器,期间 :meth:`save` 可延迟 :meth:`flush`。
|
||||
|
||||
默认实现为 no-op(如 :class:`MemoryBackend`)。:class:`JSONBackend`
|
||||
覆盖为:进入时标记延迟,退出时统一 flush 一次,将每任务一次落盘
|
||||
(N 次写入)降为整次运行一次(O(N) 而非 O(N²))。
|
||||
"""
|
||||
return nullcontext()
|
||||
|
||||
|
||||
class _TTLStateBackendMixin(StateBackend):
|
||||
"""TTL 状态后端共享逻辑。
|
||||
|
||||
将 ``has`` / ``get`` / ``load`` / ``save`` / ``clear`` 的统一实现
|
||||
委托给四个原始存取原语::meth:`_get_raw`、:meth:`_put_raw`、
|
||||
:meth:`_iter_raw`、:meth:`_clear_raw`,并基于 :meth:`_now` 与
|
||||
``self._ttl`` 提供统一的过期判断 :meth:`_is_expired`。
|
||||
|
||||
子类需设置 ``self._ttl`` 并实现上述四个原语;如需自定义时间源
|
||||
(如 ``time.monotonic``)可覆盖 :meth:`_now`。
|
||||
"""
|
||||
|
||||
_ttl: float | None
|
||||
|
||||
# ---- 原语:由子类实现 ---- #
|
||||
@abstractmethod
|
||||
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||
"""返回 ``(value, ts)``;键不存在时返回 ``None``。"""
|
||||
|
||||
@abstractmethod
|
||||
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||
"""写入一条记录。"""
|
||||
|
||||
@abstractmethod
|
||||
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||
"""迭代所有记录(不做过期过滤),yield ``(key, value, ts)``。"""
|
||||
|
||||
@abstractmethod
|
||||
def _clear_raw(self) -> None:
|
||||
"""清空所有记录。"""
|
||||
|
||||
# ---- 共享实现 ---- #
|
||||
def _now(self) -> float:
|
||||
"""当前时间戳,默认为 wall-clock 秒。"""
|
||||
return time.time()
|
||||
|
||||
def _is_expired(self, ts: float) -> bool:
|
||||
"""时间戳 ``ts`` 是否已过期。"""
|
||||
if self._ttl is None:
|
||||
return False
|
||||
return (self._now() - ts) > self._ttl
|
||||
|
||||
@override
|
||||
def load(self) -> Mapping[str, Any]:
|
||||
return dict(self._store)
|
||||
return {k: v for k, v, ts in self._iter_raw() if not self._is_expired(ts)}
|
||||
|
||||
def save(self, name: str, value: Any) -> None:
|
||||
self._store[name] = value
|
||||
@override
|
||||
def save(self, key: str, value: Any) -> None:
|
||||
self._put_raw(key, value, self._now())
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
return name in self._store
|
||||
@override
|
||||
def has(self, key: str) -> bool:
|
||||
entry = self._get_raw(key)
|
||||
return entry is not None and not self._is_expired(entry[1])
|
||||
|
||||
def get(self, name: str) -> Any:
|
||||
return self._store[name]
|
||||
@override
|
||||
def get(self, key: str) -> Any:
|
||||
entry = self._get_raw(key)
|
||||
if entry is None or self._is_expired(entry[1]):
|
||||
raise KeyError(key)
|
||||
return entry[0]
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
self._clear_raw()
|
||||
|
||||
|
||||
class MemoryBackend(_TTLStateBackendMixin):
|
||||
"""进程内 dict 后端。进程退出即丢失。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ttl:
|
||||
条目存活秒数。``None`` 表示永不过期。``has`` 在条目超过 ttl 后
|
||||
返回 ``False``(但不主动删除,下次 ``save`` 覆盖)。
|
||||
"""
|
||||
|
||||
def __init__(self, ttl: float | None = None) -> None:
|
||||
self._store: dict[str, tuple[Any, float]] = {}
|
||||
self._ttl = ttl
|
||||
|
||||
@override
|
||||
def _now(self) -> float:
|
||||
return time.monotonic()
|
||||
|
||||
@override
|
||||
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||
return self._store.get(key)
|
||||
|
||||
@override
|
||||
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||
self._store[key] = (value, ts)
|
||||
|
||||
@override
|
||||
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||
for k, (v, ts) in self._store.items():
|
||||
yield k, v, ts
|
||||
|
||||
@override
|
||||
def _clear_raw(self) -> None:
|
||||
self._store.clear()
|
||||
|
||||
|
||||
class JSONBackend(StateBackend):
|
||||
class JSONBackend(_TTLStateBackendMixin):
|
||||
"""基于文件的 JSON 存储,用于跨进程续跑。
|
||||
|
||||
结果必须可 JSON 序列化。不可序列化的值会抛出
|
||||
:class:`~pyflowx.errors.StorageError`(运行本身不会中止;仅该条
|
||||
结果的持久化失败)。
|
||||
存储格式:``{key: {"value": v, "ts": epoch_seconds}}``。
|
||||
``ts`` 用于 TTL 判断。结果必须可 JSON 序列化。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path:
|
||||
JSON 文件路径。
|
||||
ttl:
|
||||
条目存活秒数。``None`` 表示永不过期。
|
||||
"""
|
||||
|
||||
def __init__(self, path: str) -> None:
|
||||
self._path = path
|
||||
self._store: Dict[str, Any] = {}
|
||||
def __init__(self, path: str, ttl: float | None = None) -> None:
|
||||
self._path: str = path
|
||||
self._ttl = ttl
|
||||
self._store: dict[str, dict[str, Any]] = {}
|
||||
self._defer_flush: bool = False
|
||||
self._load()
|
||||
|
||||
def _load(self) -> None:
|
||||
if not os.path.exists(self._path):
|
||||
if not Path(self._path).exists():
|
||||
return
|
||||
try:
|
||||
with open(self._path, "r", encoding="utf-8") as fh:
|
||||
data = json.load(fh)
|
||||
with open(self._path, encoding="utf-8") as fh:
|
||||
data: Any = json.load(fh)
|
||||
if isinstance(data, dict):
|
||||
self._store = data
|
||||
# 兼容纯值格式与带元数据格式
|
||||
self._store = {}
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict) and "value" in v and "ts" in v:
|
||||
self._store[k] = v
|
||||
else:
|
||||
self._store[k] = {"value": v, "ts": time.time()}
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
|
||||
|
||||
@@ -99,35 +219,64 @@ class JSONBackend(StateBackend):
|
||||
try:
|
||||
with open(tmp, "w", encoding="utf-8") as fh:
|
||||
json.dump(self._store, fh, ensure_ascii=False, indent=2)
|
||||
os.replace(tmp, self._path)
|
||||
_ = Path(tmp).replace(Path(self._path))
|
||||
except (OSError, TypeError) as exc:
|
||||
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
|
||||
|
||||
def load(self) -> Mapping[str, Any]:
|
||||
return dict(self._store)
|
||||
@override
|
||||
def _get_raw(self, key: str) -> tuple[Any, float] | None:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
return entry["value"], float(entry.get("ts", 0))
|
||||
|
||||
def save(self, name: str, value: Any) -> None:
|
||||
# 在修改内存状态前先校验可序列化性。
|
||||
try:
|
||||
json.dumps(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise StorageError(
|
||||
f"result of task {name!r} is not JSON-serialisable", exc
|
||||
) from exc
|
||||
self._store[name] = value
|
||||
self._flush()
|
||||
@override
|
||||
def _put_raw(self, key: str, value: Any, ts: float) -> None:
|
||||
self._store[key] = {"value": value, "ts": ts}
|
||||
|
||||
def has(self, name: str) -> bool:
|
||||
return name in self._store
|
||||
@override
|
||||
def _iter_raw(self) -> Iterator[tuple[str, Any, float]]:
|
||||
for k, entry in self._store.items():
|
||||
yield k, entry["value"], float(entry.get("ts", 0))
|
||||
|
||||
def get(self, name: str) -> Any:
|
||||
return self._store[name]
|
||||
|
||||
def clear(self) -> None:
|
||||
@override
|
||||
def _clear_raw(self) -> None:
|
||||
self._store.clear()
|
||||
|
||||
@override
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self._flush()
|
||||
|
||||
@override
|
||||
def save(self, key: str, value: Any) -> None:
|
||||
try:
|
||||
_ = json.dumps(value)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise StorageError(f"result of key {key!r} is not JSON-serialisable", exc) from exc
|
||||
super().save(key, value)
|
||||
if not self._defer_flush:
|
||||
self._flush()
|
||||
|
||||
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
|
||||
@override
|
||||
def flush(self) -> None:
|
||||
self._flush()
|
||||
|
||||
@override
|
||||
@contextmanager
|
||||
def batch(self) -> Iterator[None]:
|
||||
"""进入批量模式:``save`` 暂不落盘,退出时统一 flush 一次。
|
||||
|
||||
将整次运行 N 个任务的 N 次全量落盘降为 1 次。
|
||||
"""
|
||||
self._defer_flush = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._defer_flush = False
|
||||
self._flush()
|
||||
|
||||
|
||||
def resolve_backend(backend: StateBackend | None) -> StateBackend:
|
||||
"""返回 ``backend``;为 ``None`` 时返回新的 :class:`MemoryBackend`。"""
|
||||
return backend if backend is not None else MemoryBackend()
|
||||
|
||||
+525
-39
@@ -17,22 +17,36 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Coroutine,
|
||||
Generator,
|
||||
Generic,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
else:
|
||||
from typing_extensions import TypeVar # pragma: no cover
|
||||
|
||||
T = TypeVar("T", default=Any)
|
||||
|
||||
# 任务可调用对象可以是同步或异步的。显式保留联合类型,让 mypy 理解两种形态。
|
||||
TaskFn = Union[
|
||||
@@ -44,6 +58,112 @@ TaskFn = Union[
|
||||
# 单任务类型由函数签名本身保留。
|
||||
Context = Mapping[str, Any]
|
||||
|
||||
# 命令类型支持
|
||||
TaskCmd = Union[
|
||||
List[str], # 命令列表, 如 ["ls", "-la"]
|
||||
str, # shell 命令字符串
|
||||
Callable[..., Any], # Python 函数
|
||||
]
|
||||
|
||||
# 执行策略:sequential/thread/async 为层屏障模型,dependency 为依赖驱动模型。
|
||||
Strategy = Union[str, "StrategyKind"]
|
||||
StrategyKind = Any # 占位,避免循环;executors 模块用 Literal 约束
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 条件判断函数类型:接收依赖上下文(可能为空映射),返回是否应执行。
|
||||
Condition = Callable[[Context], bool]
|
||||
|
||||
# 缓存键计算函数:基于依赖上下文计算稳定字符串键。
|
||||
CacheKeyFn = Callable[[Context], str]
|
||||
|
||||
|
||||
def _format_skip_reason(failed_conditions: list[str]) -> str:
|
||||
"""格式化跳过原因:≤2 个全展示,>2 个仅展示前 2 个并附总数。"""
|
||||
if len(failed_conditions) <= 2:
|
||||
return f"条件不满足: {', '.join(failed_conditions)}"
|
||||
return f"条件不满足: {', '.join(failed_conditions[:2])} 等{len(failed_conditions)}个条件"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 重试策略
|
||||
# ---------------------------------------------------------------------- #
|
||||
@dataclass(frozen=True)
|
||||
class RetryPolicy:
|
||||
"""任务失败重试策略。
|
||||
|
||||
参数
|
||||
----
|
||||
max_attempts:
|
||||
最大尝试次数(含首次)。``1`` 表示仅尝试一次,不重试。
|
||||
delay:
|
||||
两次尝试之间的初始等待秒数。
|
||||
backoff:
|
||||
退避倍率。第 n 次重试等待 ``delay * backoff ** (n-1)``。
|
||||
jitter:
|
||||
抖动上限秒数。每次等待加上 ``[0, jitter)`` 的随机量,避免惊群。
|
||||
retry_on:
|
||||
仅对这些异常类型重试。默认 ``(Exception,)`` 重试所有异常。
|
||||
传入空元组等价于不重试。
|
||||
|
||||
Note
|
||||
-----
|
||||
替代旧版 ``retries: int``。``retries=2`` 等价于
|
||||
``RetryPolicy(max_attempts=3)``。
|
||||
"""
|
||||
|
||||
max_attempts: int = 1
|
||||
delay: float = 0.0
|
||||
backoff: float = 1.0
|
||||
jitter: float = 0.0
|
||||
retry_on: tuple[type[BaseException], ...] = (Exception,)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.max_attempts < 1:
|
||||
raise ValueError(f"RetryPolicy.max_attempts must be >= 1, got {self.max_attempts}.")
|
||||
if self.delay < 0:
|
||||
raise ValueError(f"RetryPolicy.delay must be >= 0, got {self.delay}.")
|
||||
if self.backoff < 0:
|
||||
raise ValueError(f"RetryPolicy.backoff must be >= 0, got {self.backoff}.")
|
||||
if self.jitter < 0:
|
||||
raise ValueError(f"RetryPolicy.jitter must be >= 0, got {self.jitter}.")
|
||||
|
||||
@property
|
||||
def retries(self) -> int:
|
||||
"""重试次数(不含首次),等价于 ``max_attempts - 1``。"""
|
||||
return self.max_attempts - 1
|
||||
|
||||
def should_retry(self, exc: BaseException) -> bool:
|
||||
"""异常是否属于可重试类型。"""
|
||||
return isinstance(exc, self.retry_on)
|
||||
|
||||
def wait_seconds(self, attempt: int) -> float:
|
||||
"""第 ``attempt`` 次失败后应等待的秒数(attempt 从 1 开始)。"""
|
||||
if attempt < 1:
|
||||
return 0.0
|
||||
import random
|
||||
|
||||
base = self.delay * (self.backoff ** max(0, attempt - 1))
|
||||
jitter = random.uniform(0, self.jitter) if self.jitter > 0 else 0.0
|
||||
return base + jitter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 任务钩子
|
||||
# ---------------------------------------------------------------------- #
|
||||
@dataclass(frozen=True)
|
||||
class TaskHooks:
|
||||
"""任务生命周期钩子。
|
||||
|
||||
所有钩子均为可选。``pre_run`` 在任务实际执行前调用;``post_run``
|
||||
在成功后调用并接收返回值;``on_failure`` 在最终失败后调用并接收异常。
|
||||
钩子异常不会影响任务状态,仅记录日志。
|
||||
"""
|
||||
|
||||
pre_run: Callable[[TaskSpec[Any]], None] | None = None
|
||||
post_run: Callable[[TaskSpec[Any], Any], None] | None = None
|
||||
on_failure: Callable[[TaskSpec[Any], BaseException], None] | None = None
|
||||
|
||||
|
||||
class TaskStatus(Enum):
|
||||
"""任务在单次运行内的生命周期状态。"""
|
||||
@@ -66,63 +186,432 @@ class TaskSpec(Generic[T]):
|
||||
fn:
|
||||
待执行的可调用对象,可为同步或异步。其参数名驱动自动上下文
|
||||
注入(见 :mod:`pyflowx.context`)。
|
||||
若提供 ``cmd`` 参数,则此参数会被忽略。
|
||||
cmd:
|
||||
命令列表或 shell 字符串,支持三种形态:
|
||||
- ``list[str]``: 命令及参数列表,如 ``["ls", "-la"]``
|
||||
- ``str``: shell 命令字符串,如 ``"pip freeze > requirements.txt"``
|
||||
- ``Callable``: Python 函数,与 ``fn`` 参数等效
|
||||
depends_on:
|
||||
必须先完成才能运行本任务的任务名列表。顺序无关;框架会做
|
||||
拓扑排序。
|
||||
硬依赖任务名。必须全部成功完成才会运行本任务。
|
||||
上游被 SKIPPED 时,本任务也会被 SKIPPED(除非
|
||||
``allow_upstream_skip=True``)。
|
||||
soft_depends_on:
|
||||
软依赖任务名。会等待其完成,但其结果不影响本任务是否执行:
|
||||
- 上游成功:注入其返回值
|
||||
- 上游 SKIPPED 或失败:注入 :attr:`defaults` 中提供的默认值
|
||||
适用于"可选输入"场景。
|
||||
defaults:
|
||||
软依赖的默认值映射 ``{dep_name: default_value}``。
|
||||
软依赖未提供结果时使用。未在 defaults 中出现的软依赖默认为 ``None``。
|
||||
args:
|
||||
静态位置参数,追加在注入参数*之后*。适用于参数化任务
|
||||
(如 ``fetch_user(uid)``)。
|
||||
静态位置参数,追加在注入参数*之后*。
|
||||
kwargs:
|
||||
静态关键字参数。若与注入名冲突则抛出
|
||||
:class:`~pyflowx.errors.InjectionError`。
|
||||
retries:
|
||||
失败后的重试次数。``0`` 表示仅尝试一次。
|
||||
retry:
|
||||
:class:`RetryPolicy` 重试策略。默认仅尝试一次。
|
||||
timeout:
|
||||
最大执行时长(秒)。``None`` 表示不限制。异步任务使用
|
||||
:func:`asyncio.wait_for`;线程/异步执行器中的同步任务会
|
||||
取消 worker future。
|
||||
:func:`asyncio.wait_for`;同步任务通过线程 future 取消。
|
||||
tags:
|
||||
自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试。
|
||||
自由标签,供 :meth:`Graph.subgraph` 做选择性执行与调试,
|
||||
也可用于并发限制分组。
|
||||
conditions:
|
||||
条件判断函数列表,接收依赖上下文,全部返回 ``True`` 时才执行任务。
|
||||
任一返回 ``False`` 则任务被标记为 SKIPPED。
|
||||
cwd:
|
||||
工作目录。对 ``cmd`` 任务作为子进程工作目录;对 ``fn`` 任务
|
||||
通过临时切换当前目录生效。
|
||||
env:
|
||||
环境变量覆盖映射。对 ``cmd`` 任务合并到子进程环境;对 ``fn``
|
||||
任务在执行期间临时设置。
|
||||
verbose:
|
||||
是否打印详细输出。``True`` 时打印执行的命令、返回码与输出
|
||||
(仅 ``cmd``),以及任务生命周期。
|
||||
skip_if_missing:
|
||||
仅对 ``cmd`` 为 ``list[str]`` 有效。``True`` 时通过
|
||||
:func:`shutil.which` 检查命令是否存在,不存在则跳过。
|
||||
allow_upstream_skip:
|
||||
若为 ``True``,硬依赖被 SKIPPED 时本任务仍执行(软依赖不影响)。
|
||||
适用于清理类任务。
|
||||
strategy:
|
||||
单任务执行策略覆盖。``None`` 表示继承图级策略。
|
||||
``"sequential"`` 同步直接调用;``"thread"``/``"async"`` 将同步
|
||||
任务卸载到线程池,异步任务跑在事件循环上。
|
||||
priority:
|
||||
同层任务调度优先级。数值越大越先启动。仅影响同层内启动顺序,
|
||||
不打破层屏障。默认 ``0``。
|
||||
concurrency_key:
|
||||
并发限制分组键。具有相同键的任务共享一个信号量,限制同时
|
||||
运行的实例数。具体限额由 :func:`run` 的 ``concurrency_limits``
|
||||
参数提供 ``{key: limit}`` 映射。``None`` 表示不限制。
|
||||
continue_on_error:
|
||||
若为 ``True``,任务最终失败时不中止整图,仅标记本任务 FAILED,
|
||||
其硬依赖下游被 SKIPPED,其余任务继续。默认 ``False``。
|
||||
cache_key:
|
||||
缓存键计算函数。若提供,则用其基于依赖上下文计算的字符串键
|
||||
存取状态后端,使不同输入产生独立缓存条目。``None`` 表示用任务名。
|
||||
hooks:
|
||||
:class:`TaskHooks` 生命周期钩子。
|
||||
executor:
|
||||
同步任务的执行器:``"thread"``(默认,线程池)/ ``"process"``
|
||||
(进程池,绕过 GIL,适合 CPU 密集型;``fn`` 须可 pickle)/
|
||||
``"inline"``(直接在事件循环线程调用,最快但会阻塞循环)。
|
||||
"""
|
||||
|
||||
name: str
|
||||
fn: TaskFn[T]
|
||||
depends_on: Tuple[str, ...] = ()
|
||||
args: Tuple[Any, ...] = ()
|
||||
fn: TaskFn[T] | None = None
|
||||
cmd: TaskCmd | None = None
|
||||
depends_on: tuple[str, ...] = ()
|
||||
soft_depends_on: tuple[str, ...] = ()
|
||||
defaults: Mapping[str, Any] = field(default_factory=dict)
|
||||
args: tuple[Any, ...] = ()
|
||||
kwargs: Mapping[str, Any] = field(default_factory=dict)
|
||||
retries: int = 0
|
||||
timeout: Optional[float] = None
|
||||
tags: Tuple[str, ...] = ()
|
||||
retry: RetryPolicy = field(default_factory=RetryPolicy)
|
||||
timeout: float | None = None
|
||||
tags: tuple[str, ...] = ()
|
||||
conditions: tuple[Condition, ...] = ()
|
||||
cwd: Path | None = None
|
||||
env: Mapping[str, str] | None = None
|
||||
verbose: bool = False
|
||||
skip_if_missing: bool = False
|
||||
allow_upstream_skip: bool = False
|
||||
strategy: str | None = None
|
||||
priority: int = 0
|
||||
concurrency_key: str | None = None
|
||||
continue_on_error: bool = False
|
||||
cache_key: CacheKeyFn | None = None
|
||||
hooks: TaskHooks = field(default_factory=TaskHooks)
|
||||
executor: str = "thread" # "thread" | "process" | "inline"
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.name:
|
||||
raise ValueError("TaskSpec.name must be a non-empty string.")
|
||||
if self.retries < 0:
|
||||
raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.")
|
||||
if self.retry.max_attempts < 1:
|
||||
raise ValueError(f"TaskSpec '{self.name}': retry.max_attempts must be >= 1.")
|
||||
if self.timeout is not None and self.timeout <= 0:
|
||||
raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.")
|
||||
if self.name in self.depends_on:
|
||||
if self.name in self.depends_on or self.name in self.soft_depends_on:
|
||||
raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.")
|
||||
overlap = set(self.depends_on) & set(self.soft_depends_on)
|
||||
if overlap:
|
||||
raise ValueError(f"TaskSpec '{self.name}': depends_on 与 soft_depends_on 不能重叠: {sorted(overlap)}")
|
||||
if self.fn is None and self.cmd is None:
|
||||
raise ValueError(f"TaskSpec '{self.name}': 必须提供 fn 或 cmd 参数。")
|
||||
|
||||
@cached_property
|
||||
def effective_fn(self) -> TaskFn[T]:
|
||||
"""获取有效的执行函数。
|
||||
|
||||
若提供 ``cmd``,返回包装后的命令执行函数;否则返回 ``fn``。
|
||||
包装函数在每次调用时从 ``self`` 读取 ``verbose``/``cwd``/``env``/
|
||||
``timeout``,避免闭包捕获运行期参数,使翻转字段无需重建 spec。
|
||||
|
||||
结果按实例缓存(:func:`functools.cached_property`):frozen dataclass
|
||||
字段不可变,``_wrap_cmd`` 生成的闭包稳定,无需每次访问重建。
|
||||
"""
|
||||
if self.cmd is not None:
|
||||
return self._wrap_cmd()
|
||||
if self.fn is not None:
|
||||
return self.fn
|
||||
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") # pragma: no cover
|
||||
|
||||
def _wrap_cmd(self) -> TaskFn[Any]:
|
||||
"""将 cmd 包装为可执行函数。
|
||||
|
||||
实际执行逻辑位于 :mod:`pyflowx.command`,避免 :class:`TaskSpec`
|
||||
作为纯数据结构混入命令执行逻辑。
|
||||
"""
|
||||
from .command import run_command
|
||||
|
||||
spec = self
|
||||
|
||||
def _run() -> T:
|
||||
return cast(T, run_command(spec))
|
||||
|
||||
_run.__name__ = spec.name
|
||||
return _run # type: ignore[return-value]
|
||||
|
||||
def should_execute(self, context: Context) -> tuple[bool, str | None]:
|
||||
"""检查任务是否应执行。
|
||||
|
||||
Returns
|
||||
-------
|
||||
(should_run, skip_reason)
|
||||
``should_run`` 为 False 时 ``skip_reason`` 描述跳过原因。
|
||||
失败条件超过 2 个时仅展示前 2 个并附总数。
|
||||
"""
|
||||
# 逐个求值条件,记录失败项。
|
||||
failed_conditions: list[str] = []
|
||||
for condition in self.conditions:
|
||||
try:
|
||||
ok = condition(context)
|
||||
except Exception:
|
||||
ok = False
|
||||
failed_conditions.append("匿名条件(执行错误)")
|
||||
continue
|
||||
if not ok:
|
||||
reason = getattr(condition, "_reason", None)
|
||||
if reason is not None:
|
||||
failed_conditions.append(
|
||||
", ".join(str(r) for r in reason) if isinstance(reason, list) else str(reason),
|
||||
)
|
||||
else:
|
||||
failed_conditions.append(getattr(condition, "__name__", None) or "匿名条件")
|
||||
|
||||
if failed_conditions:
|
||||
return False, _format_skip_reason(failed_conditions)
|
||||
|
||||
if self.skip_if_missing and not self._is_cmd_available():
|
||||
cmd_name = self.cmd[0] if isinstance(self.cmd, list) and self.cmd else "unknown"
|
||||
return False, f"命令不存在: {cmd_name}"
|
||||
|
||||
return True, None
|
||||
|
||||
def _is_cmd_available(self) -> bool:
|
||||
"""检查 ``cmd`` 是否可用(仅 list[str])。"""
|
||||
cmd = self.cmd
|
||||
if isinstance(cmd, list) and cmd:
|
||||
return shutil.which(cmd[0]) is not None
|
||||
return True
|
||||
|
||||
def env_context(self) -> ContextManager[None]:
|
||||
"""返回临时应用 ``env`` 与 ``cwd`` 的上下文管理器。
|
||||
|
||||
对 ``fn`` 任务生效。``cmd`` 任务在 :func:`_run_command` 中直接
|
||||
传给子进程。
|
||||
"""
|
||||
return _env_and_cwd(self.env, self.cwd)
|
||||
|
||||
def storage_key(self, context: Context) -> str:
|
||||
"""计算状态后端存储键。"""
|
||||
if self.cache_key is None:
|
||||
return self.name
|
||||
try:
|
||||
return f"{self.name}:{self.cache_key(context)}"
|
||||
except (TypeError, ValueError, KeyError, AttributeError) as exc:
|
||||
# cache_key 抛出预期内的数据/类型异常时回退到 name,但仍记录警告
|
||||
# 以便用户发现 cache_key 实现中的 bug。
|
||||
logger.warning(
|
||||
"task %r: cache_key 回退到 name(%s: %s)",
|
||||
self.name,
|
||||
type(exc).__name__,
|
||||
exc,
|
||||
)
|
||||
return self.name
|
||||
|
||||
|
||||
# 全局锁:序列化对进程级状态(os.environ / os.chdir)的临时修改。
|
||||
# ``fn`` 任务在 thread/async 策略下并发执行时,若各自配置了不同的
|
||||
# ``cwd``/``env``,会相互覆盖(os.chdir 与 os.environ 均为进程全局)。
|
||||
# 该锁仅包裹"切换→执行→恢复"区间,保证正确性;不使用 cwd/env 的任务不受影响。
|
||||
_env_cwd_lock = threading.RLock()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _env_and_cwd(
|
||||
env: Mapping[str, str] | None,
|
||||
cwd: Path | None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""临时设置环境变量与工作目录。
|
||||
|
||||
``os.environ`` 与 ``os.chdir`` 是进程级全局状态,在 thread/async 策略下
|
||||
并发执行多个带 ``env``/``cwd`` 的 ``fn`` 任务时会相互覆盖。本函数通过
|
||||
模块级 :data:`_env_cwd_lock` 串行化"切换→执行→恢复"区间,确保正确性。
|
||||
无 ``env`` 且无 ``cwd`` 时直接 yield,不获取锁。
|
||||
"""
|
||||
if not env and cwd is None:
|
||||
yield
|
||||
return
|
||||
with _env_cwd_lock:
|
||||
saved_env: dict[str, str] = {}
|
||||
saved_cwd: str | None = None
|
||||
if env:
|
||||
for k, v in env.items():
|
||||
if k in os.environ:
|
||||
saved_env[k] = os.environ[k]
|
||||
os.environ[k] = v
|
||||
if cwd is not None:
|
||||
saved_cwd = str(Path.cwd())
|
||||
os.chdir(cwd)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if saved_cwd is not None:
|
||||
os.chdir(saved_cwd)
|
||||
# 恢复环境变量
|
||||
if env:
|
||||
for k in env:
|
||||
if k in saved_env:
|
||||
os.environ[k] = saved_env[k]
|
||||
else:
|
||||
os.environ.pop(k, None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 任务模板:批量生成相似 TaskSpec 的工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
def _task_noop() -> None:
|
||||
"""task(cmd=...) 形式下的占位 fn(cmd 任务执行期不调用 fn)。"""
|
||||
return None
|
||||
|
||||
|
||||
def task(
|
||||
fn: TaskFn[Any] | None = None,
|
||||
*,
|
||||
cmd: TaskCmd | None = None,
|
||||
depends_on: tuple[str, ...] = (),
|
||||
soft_depends_on: tuple[str, ...] = (),
|
||||
defaults: Mapping[str, Any] | None = None,
|
||||
args: tuple[Any, ...] = (),
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
retry: RetryPolicy | None = None,
|
||||
timeout: float | None = None,
|
||||
tags: tuple[str, ...] = (),
|
||||
conditions: tuple[Condition, ...] = (),
|
||||
cwd: str | Path | None = None,
|
||||
env: Mapping[str, str] | None = None,
|
||||
verbose: bool = False,
|
||||
skip_if_missing: bool = False,
|
||||
allow_upstream_skip: bool = False,
|
||||
strategy: str | None = None,
|
||||
priority: int = 0,
|
||||
concurrency_key: str | None = None,
|
||||
continue_on_error: bool = False,
|
||||
cache_key: CacheKeyFn | None = None,
|
||||
hooks: TaskHooks | None = None,
|
||||
name: str | None = None,
|
||||
) -> Any:
|
||||
"""装饰器:将函数转为 :class:`TaskSpec`。
|
||||
|
||||
``name`` 默认取 ``fn.__name__``。可直接装饰函数,或带参数使用。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> @px.task
|
||||
... def extract(): return [1, 2, 3]
|
||||
>>> @px.task(depends_on=("extract",))
|
||||
... def double(extract): return [x * 2 for x in extract]
|
||||
>>> graph = px.Graph.from_specs([extract, double])
|
||||
"""
|
||||
|
||||
def _decorate(func: TaskFn[Any]) -> TaskSpec[Any]:
|
||||
spec_name = name or func.__name__
|
||||
return TaskSpec(
|
||||
name=spec_name,
|
||||
fn=func,
|
||||
cmd=cmd,
|
||||
depends_on=depends_on,
|
||||
soft_depends_on=soft_depends_on,
|
||||
defaults=dict(defaults) if defaults else {},
|
||||
args=args,
|
||||
kwargs=dict(kwargs) if kwargs else {},
|
||||
retry=retry if retry is not None else RetryPolicy(),
|
||||
timeout=timeout,
|
||||
tags=tags,
|
||||
conditions=conditions,
|
||||
cwd=Path(cwd) if isinstance(cwd, str) else cwd,
|
||||
env=dict(env) if env else None,
|
||||
verbose=verbose,
|
||||
skip_if_missing=skip_if_missing,
|
||||
allow_upstream_skip=allow_upstream_skip,
|
||||
strategy=strategy,
|
||||
priority=priority,
|
||||
concurrency_key=concurrency_key,
|
||||
continue_on_error=continue_on_error,
|
||||
cache_key=cache_key,
|
||||
hooks=hooks if hooks is not None else TaskHooks(),
|
||||
)
|
||||
|
||||
if fn is None and cmd is None:
|
||||
# 带参数调用:@task(depends_on=...),等待被装饰函数
|
||||
return _decorate
|
||||
if fn is None:
|
||||
# task(cmd=..., name=...) 直接构造,无被装饰函数
|
||||
if name is None:
|
||||
raise ValueError("task(cmd=...) 需要显式提供 name")
|
||||
return _decorate(_task_noop)
|
||||
return _decorate(fn)
|
||||
|
||||
|
||||
def cmd(
|
||||
command: list[str],
|
||||
*,
|
||||
name: str | None = None,
|
||||
depends_on: tuple[str, ...] = (),
|
||||
**kwargs: Any,
|
||||
) -> TaskSpec[Any]:
|
||||
"""从命令列表快速创建 :class:`TaskSpec`。
|
||||
|
||||
``name`` 默认为 ``"_".join(command[:2])``(如 ``["uv", "build"]`` → ``"uv_build"``)。
|
||||
若命令不足两个元素则用 ``"_".join(command)``。
|
||||
|
||||
其余关键字参数透传给 :class:`TaskSpec`(如 ``depends_on``、``tags`` 等)。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> uv_build = px.cmd(["uv", "build"])
|
||||
>>> uv_build.name
|
||||
'uv_build'
|
||||
>>> lint = px.cmd(["ruff", "check", "--fix"], name="lint")
|
||||
>>> lint.name
|
||||
'lint'
|
||||
"""
|
||||
spec_name = name or "_".join(command[:2]) if len(command) >= 2 else "_".join(command)
|
||||
return TaskSpec(
|
||||
name=spec_name,
|
||||
cmd=command,
|
||||
depends_on=depends_on,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def task_template(
|
||||
fn: TaskFn[Any] | None = None,
|
||||
cmd: TaskCmd | None = None,
|
||||
**defaults: Any,
|
||||
) -> Callable[..., TaskSpec[Any]]:
|
||||
"""创建任务模板工厂。
|
||||
|
||||
返回的工厂接受 ``name`` 与任意覆盖字段,生成 :class:`TaskSpec`。
|
||||
适用于批量创建相似任务(如 fan-out)。
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> Fetch = px.task_template(fn=fetch_user, retry=px.RetryPolicy(max_attempts=3))
|
||||
>>> specs = [Fetch(f"fetch_{uid}", args=(uid,)) for uid in range(5)]
|
||||
"""
|
||||
base = dict(defaults)
|
||||
if fn is not None:
|
||||
base["fn"] = fn
|
||||
if cmd is not None:
|
||||
base["cmd"] = cmd
|
||||
|
||||
def _factory(name: str, **overrides: Any) -> TaskSpec[Any]:
|
||||
merged = dict(base)
|
||||
merged.update(overrides)
|
||||
return TaskSpec(name, **merged)
|
||||
|
||||
_factory.__name__ = "task_template_factory"
|
||||
return _factory
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult(Generic[T]):
|
||||
"""运行期间产生的可变单任务记录。
|
||||
|
||||
每次运行都会创建全新的 :class:`TaskResult`;spec 本身保持不可变。
|
||||
这让同一个图可以安全地重复运行。
|
||||
"""
|
||||
"""运行期间产生的可变单任务记录。"""
|
||||
|
||||
spec: TaskSpec[T]
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
value: Optional[T] = None
|
||||
error: Optional[BaseException] = None
|
||||
value: T | None = None
|
||||
error: BaseException | None = None
|
||||
attempts: int = 0
|
||||
started_at: Optional[datetime] = None
|
||||
finished_at: Optional[datetime] = None
|
||||
started_at: datetime | None = None
|
||||
finished_at: datetime | None = None
|
||||
reason: str | None = None # 跳过原因
|
||||
|
||||
@property
|
||||
def duration(self) -> Optional[float]:
|
||||
def duration(self) -> float | None:
|
||||
"""从开始到结束的耗时(秒),未开始/未结束则为 ``None``。"""
|
||||
if self.started_at is None or self.finished_at is None:
|
||||
return None
|
||||
@@ -131,14 +620,11 @@ class TaskResult(Generic[T]):
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TaskEvent:
|
||||
"""执行期间向观察者发出的不可变事件。
|
||||
|
||||
传递给 :func:`pyflowx.run` 的 ``on_event`` 回调,让调用者无需耦合
|
||||
执行器内部即可构建进度条、指标或结构化日志。
|
||||
"""
|
||||
"""执行期间向观察者发出的不可变事件。"""
|
||||
|
||||
task: str
|
||||
status: TaskStatus
|
||||
attempts: int = 0
|
||||
error: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
error: str | None = None
|
||||
duration: float | None = None
|
||||
reason: str | None = None
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
"""系统操作任务模块.
|
||||
|
||||
提供常用的系统操作任务封装, 包括清屏、环境变量设置、命令查找等.
|
||||
遵循实用主义原则, 仅提供核心功能, 无过度设计.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__ = [
|
||||
"clr",
|
||||
"reset_icon_cache",
|
||||
"setenv",
|
||||
"setenv_group",
|
||||
"which",
|
||||
"write_file",
|
||||
]
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx import BuiltinConditions
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
|
||||
def clr():
|
||||
"""清屏任务."""
|
||||
cmd = ["cls"] if Constants.IS_WINDOWS else ["clear"]
|
||||
return px.TaskSpec("clear_screen", fn=lambda: subprocess.run(cmd, check=False))
|
||||
|
||||
|
||||
def reset_icon_cache() -> list[px.TaskSpec]:
|
||||
"""重置图标缓存任务."""
|
||||
if not Constants.IS_WINDOWS:
|
||||
print("reset_icon_cache: 仅在 Windows 上支持")
|
||||
return []
|
||||
|
||||
local_app_data = os.environ.get("LOCALAPPDATA", "")
|
||||
icon_cache_db = Path(local_app_data) / "IconCache.db"
|
||||
explorer_cache_dir = Path(local_app_data) / "Microsoft" / "Windows" / "Explorer"
|
||||
|
||||
return [
|
||||
px.TaskSpec(
|
||||
"kill_explorer",
|
||||
cmd=["taskkill", "/f", "/im", "explorer.exe"],
|
||||
conditions=(BuiltinConditions.IS_RUNNING("explorer.exe"),),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
"delete_icon_cache",
|
||||
cmd=["cmd", "/c", "del", "/a", "/q", str(icon_cache_db)],
|
||||
conditions=(BuiltinConditions.DIR_EXISTS(icon_cache_db),),
|
||||
depends_on=("kill_explorer",),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
"delete_icon_cache_all",
|
||||
cmd=["cmd", "/c", "del", "/a", "/q", str(explorer_cache_dir / "iconcache*")],
|
||||
conditions=(BuiltinConditions.DIR_EXISTS(explorer_cache_dir),),
|
||||
depends_on=("kill_explorer",),
|
||||
verbose=True,
|
||||
),
|
||||
px.TaskSpec(
|
||||
"restart_explorer",
|
||||
cmd=["cmd", "/c", "start", "explorer.exe"],
|
||||
conditions=(
|
||||
BuiltinConditions.HAS_INSTALLED("explorer.exe"),
|
||||
BuiltinConditions.NOT(BuiltinConditions.IS_RUNNING("explorer.exe")),
|
||||
),
|
||||
depends_on=("delete_icon_cache", "delete_icon_cache_all"),
|
||||
allow_upstream_skip=True,
|
||||
verbose=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def setenv(name: str, value: str, default: bool = False) -> px.TaskSpec:
|
||||
"""设置环境变量任务."""
|
||||
|
||||
def set_env():
|
||||
if default:
|
||||
os.environ.setdefault(name, value)
|
||||
else:
|
||||
os.environ[name] = value
|
||||
|
||||
return px.TaskSpec(f"setenv_{name.lower()}", fn=set_env, verbose=True)
|
||||
|
||||
|
||||
def setenv_group(envs: dict[str, str], default: bool = False) -> list[px.TaskSpec]:
|
||||
"""设置环境变量组任务."""
|
||||
return [setenv(name, value, default) for name, value in envs.items()]
|
||||
|
||||
|
||||
def which(cmd: str) -> px.TaskSpec:
|
||||
"""查找命令路径任务."""
|
||||
which_cmd = "where" if Constants.IS_WINDOWS else "which"
|
||||
|
||||
def find_command():
|
||||
result = subprocess.run([which_cmd, cmd], capture_output=True, text=True, check=False)
|
||||
|
||||
if result.returncode == 0:
|
||||
# Windows 的 where 可能返回多行, 取第一个
|
||||
path = result.stdout.strip().split("\n")[0].strip()
|
||||
print(f"{cmd} -> {path}")
|
||||
else:
|
||||
print(f"{cmd} -> 未找到")
|
||||
|
||||
return px.TaskSpec(f"which_{cmd}", fn=find_command)
|
||||
|
||||
|
||||
def write_file(path: str, content: str, encoding: str = "utf-8") -> px.TaskSpec:
|
||||
"""写入文件任务."""
|
||||
|
||||
def write():
|
||||
p = Path(path)
|
||||
p.write_text(content, encoding=encoding)
|
||||
|
||||
return px.TaskSpec(f"write_file_{path}", fn=write, verbose=True)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""进程池测试辅助:模块级函数(须可 pickle)。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def cpu_heavy(n: int) -> int:
|
||||
"""CPU 密集型计算(求平方和)。"""
|
||||
return sum(i * i for i in range(n))
|
||||
|
||||
|
||||
def add(a: int, b: int) -> int:
|
||||
"""简单加法。"""
|
||||
return a + b
|
||||
|
||||
|
||||
def sub(a: int, b: int) -> int:
|
||||
"""简单减法。"""
|
||||
return a - b
|
||||
|
||||
|
||||
def slow_sleep(seconds: float) -> int:
|
||||
"""睡眠指定秒数,用于测试超时。"""
|
||||
time.sleep(seconds)
|
||||
return int(seconds)
|
||||
@@ -0,0 +1,301 @@
|
||||
"""Tests for cli.autofmt module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import autofmt
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# format_with_ruff
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestFormatWithRuff:
|
||||
"""Test format_with_ruff function."""
|
||||
|
||||
def test_format_with_ruff(self, tmp_path: Path) -> None:
|
||||
"""Should format with ruff."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.format_with_ruff(tmp_path, fix=True)
|
||||
assert mock_run.called
|
||||
|
||||
def test_format_with_ruff_no_fix(self, tmp_path: Path) -> None:
|
||||
"""Should format with ruff without fix."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.format_with_ruff(tmp_path, fix=False)
|
||||
# Should not include --fix flag
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "--fix" not in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# lint_with_ruff
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestLintWithRuff:
|
||||
"""Test lint_with_ruff function."""
|
||||
|
||||
def test_lint_with_ruff(self, tmp_path: Path) -> None:
|
||||
"""Should lint with ruff."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.lint_with_ruff(tmp_path, fix=True)
|
||||
assert mock_run.called
|
||||
|
||||
def test_lint_with_ruff_no_fix(self, tmp_path: Path) -> None:
|
||||
"""Should lint with ruff without fix."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.lint_with_ruff(tmp_path, fix=False)
|
||||
# Should not include --fix flag
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "--fix" not in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# add_docstring
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestAddDocstring:
|
||||
"""Test add_docstring function."""
|
||||
|
||||
def test_add_docstring_to_file(self, tmp_path: Path) -> None:
|
||||
"""Should add docstring to file."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.add_docstring(py_file, '"""Test module."""')
|
||||
assert result is True
|
||||
|
||||
def test_add_docstring_skips_files_with_docstring(self, tmp_path: Path) -> None:
|
||||
"""Should skip files that already have docstring."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text('"""Existing docstring."""\ndef test():\n pass\n')
|
||||
|
||||
result = autofmt.add_docstring(py_file, '"""New docstring."""')
|
||||
assert result is False
|
||||
|
||||
def test_add_docstring_empty_file(self, tmp_path: Path) -> None:
|
||||
"""Should handle empty file."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("")
|
||||
|
||||
result = autofmt.add_docstring(py_file, '"""Test module."""')
|
||||
# Should handle empty file
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# generate_module_docstring
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGenerateModuleDocstring:
|
||||
"""Test generate_module_docstring function."""
|
||||
|
||||
def test_generate_module_docstring_basic(self, tmp_path: Path) -> None:
|
||||
"""Should generate basic docstring."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
# Should contain "Tests for" since stem contains "test"
|
||||
assert "Tests for" in result
|
||||
|
||||
def test_generate_module_docstring_with_package(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for package."""
|
||||
py_file = tmp_path / "mypackage" / "test.py"
|
||||
py_file.parent.mkdir(parents=True)
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "mypackage" in result
|
||||
|
||||
def test_generate_module_docstring_cli(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for CLI module."""
|
||||
py_file = tmp_path / "cli.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "Command-line interface" in result
|
||||
|
||||
def test_generate_module_docstring_util(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for utility module."""
|
||||
py_file = tmp_path / "utils.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "Utility functions" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# auto_add_docstrings
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestAutoAddDocstrings:
|
||||
"""Test auto_add_docstrings function."""
|
||||
|
||||
def test_auto_add_docstrings(self, tmp_path: Path) -> None:
|
||||
"""Should auto add docstrings."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
with patch.object(autofmt, "add_docstring", return_value=True):
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
assert count >= 0
|
||||
|
||||
def test_auto_add_docstrings_skips_ignored(self, tmp_path: Path) -> None:
|
||||
"""Should skip ignored directories."""
|
||||
py_file = tmp_path / "__pycache__" / "test.py"
|
||||
py_file.parent.mkdir()
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
# Should skip __pycache__
|
||||
assert count == 0
|
||||
|
||||
def test_auto_add_docstrings_no_files(self, tmp_path: Path) -> None:
|
||||
"""Should handle no Python files."""
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text("test content")
|
||||
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# sync_pyproject_config
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestSyncPyprojectConfig:
|
||||
"""Test sync_pyproject_config function."""
|
||||
|
||||
def test_sync_pyproject_config_creates_file(self, tmp_path: Path) -> None:
|
||||
"""Should sync pyproject.toml config."""
|
||||
main_toml = tmp_path / "pyproject.toml"
|
||||
main_toml.write_text("[tool.ruff]\n")
|
||||
sub_dir = tmp_path / "subproject"
|
||||
sub_dir.mkdir()
|
||||
sub_toml = sub_dir / "pyproject.toml"
|
||||
sub_toml.write_text("[tool.ruff]\n")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.sync_pyproject_config(tmp_path)
|
||||
assert mock_run.called
|
||||
|
||||
def test_sync_pyproject_config_updates_file(self, tmp_path: Path) -> None:
|
||||
"""Should update existing pyproject.toml."""
|
||||
main_toml = tmp_path / "pyproject.toml"
|
||||
main_toml.write_text("[tool.ruff]\n")
|
||||
sub_dir = tmp_path / "subproject"
|
||||
sub_dir.mkdir()
|
||||
sub_toml = sub_dir / "pyproject.toml"
|
||||
sub_toml.write_text("[tool.ruff]\n")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.sync_pyproject_config(tmp_path)
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# format_all
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestFormatAll:
|
||||
"""Test format_all function."""
|
||||
|
||||
def test_format_all_runs_ruff_format(self, tmp_path: Path) -> None:
|
||||
"""Should run ruff format."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.format_all(tmp_path)
|
||||
assert mock_run.called
|
||||
|
||||
def test_format_all_runs_ruff_check(self, tmp_path: Path) -> None:
|
||||
"""Should run ruff check."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.format_all(tmp_path)
|
||||
# Should call ruff format and ruff check
|
||||
assert mock_run.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_fmt_default_target(self) -> None:
|
||||
"""main() should handle fmt command with default target."""
|
||||
with patch("sys.argv", ["autofmt", "fmt"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_fmt_custom_target(self) -> None:
|
||||
"""main() should handle fmt command with custom target."""
|
||||
with patch("sys.argv", ["autofmt", "fmt", "--target", "src"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_lint_default_target(self) -> None:
|
||||
"""main() should handle lint command with default target."""
|
||||
with patch("sys.argv", ["autofmt", "lint"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_lint_with_fix(self) -> None:
|
||||
"""main() should handle lint command with fix."""
|
||||
with patch("sys.argv", ["autofmt", "lint", "--fix"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_lint_custom_target(self) -> None:
|
||||
"""main() should handle lint command with custom target."""
|
||||
with patch("sys.argv", ["autofmt", "lint", "--target", "src"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_doc_default_root(self) -> None:
|
||||
"""main() should handle doc command with default root."""
|
||||
with patch("sys.argv", ["autofmt", "doc"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_doc_custom_root(self) -> None:
|
||||
"""main() should handle doc command with custom root."""
|
||||
with patch("sys.argv", ["autofmt", "doc", "--root-dir", "src"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_sync_default_root(self) -> None:
|
||||
"""main() should handle sync command with default root."""
|
||||
with patch("sys.argv", ["autofmt", "sync"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_sync_custom_root(self) -> None:
|
||||
"""main() should handle sync command with custom root."""
|
||||
with patch("sys.argv", ["autofmt", "sync", "--root-dir", "."]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["autofmt"]), patch.object(autofmt, "main"):
|
||||
# Just call main, it should show help and return
|
||||
autofmt.main()
|
||||
# main() should return without calling px.run
|
||||
assert True
|
||||
|
||||
def test_main_creates_task_specs_with_verbose(self) -> None:
|
||||
"""main() should create TaskSpecs with verbose=True."""
|
||||
with patch("sys.argv", ["autofmt", "fmt"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_uses_thread_strategy(self) -> None:
|
||||
"""main() should use thread strategy."""
|
||||
with patch("sys.argv", ["autofmt", "fmt"]), patch.object(px, "run") as mock_run:
|
||||
autofmt.main()
|
||||
# Check that strategy="thread" was used
|
||||
assert mock_run.called
|
||||
@@ -0,0 +1,318 @@
|
||||
"""Tests for cli.bumpversion module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import bumpversion
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def auto_use_tmp_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""自动使用临时路径."""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# bump_file_version
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestBumpFileVersion:
|
||||
"""Test bump_file_version function."""
|
||||
|
||||
def test_bump_patch_version(self, tmp_path: Path) -> None:
|
||||
"""Should bump patch version correctly."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
test_file.write_text('version = "1.2.3"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.2.4"
|
||||
assert test_file.read_text(encoding="utf-8") == 'version = "1.2.4"'
|
||||
|
||||
def test_bump_minor_version(self, tmp_path: Path) -> None:
|
||||
"""Should bump minor version correctly."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
test_file.write_text('version = "1.2.3"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "minor")
|
||||
|
||||
assert result == "1.3.0"
|
||||
assert test_file.read_text(encoding="utf-8") == 'version = "1.3.0"'
|
||||
|
||||
def test_bump_major_version(self, tmp_path: Path) -> None:
|
||||
"""Should bump major version correctly."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
test_file.write_text('version = "1.2.3"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "major")
|
||||
|
||||
assert result == "2.0.0"
|
||||
assert test_file.read_text(encoding="utf-8") == 'version = "2.0.0"'
|
||||
|
||||
def test_version_pattern_with_prerelease(self, tmp_path: Path) -> None:
|
||||
"""Should handle version with prerelease suffix."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
test_file.write_text('version = "1.2.3-alpha.1"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.2.4"
|
||||
# 预发布版本应该被清除
|
||||
content = test_file.read_text(encoding="utf-8")
|
||||
assert "alpha" not in content
|
||||
|
||||
def test_version_pattern_with_build_metadata(self, tmp_path: Path) -> None:
|
||||
"""Should handle version with build metadata."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
test_file.write_text('version = "1.2.3+build.123"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.2.4"
|
||||
# 构建元数据应该被清除
|
||||
content = test_file.read_text(encoding="utf-8")
|
||||
assert "build" not in content
|
||||
|
||||
def test_no_version_found(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""Should return None when no version pattern found."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("no version here", encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result is None
|
||||
captured = capsys.readouterr()
|
||||
assert "未找到版本号模式" in captured.out
|
||||
|
||||
def test_utf8_encoding(self, tmp_path: Path) -> None:
|
||||
"""Should handle UTF-8 encoded files correctly."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "1.2.3"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.2.4"
|
||||
assert test_file.read_text(encoding="utf-8") == '__version__ = "1.2.4"'
|
||||
|
||||
def test_pyproject_toml_format(self, tmp_path: Path) -> None:
|
||||
"""Should handle pyproject.toml format correctly."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
content = """
|
||||
[project]
|
||||
name = "test"
|
||||
version = "0.1.0"
|
||||
description = "Test project"
|
||||
"""
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "minor")
|
||||
|
||||
assert result == "0.2.0"
|
||||
updated = test_file.read_text(encoding="utf-8")
|
||||
assert 'version = "0.2.0"' in updated
|
||||
assert 'name = "test"' in updated
|
||||
|
||||
def test_init_py_format(self, tmp_path: Path) -> None:
|
||||
"""Should handle __init__.py format correctly."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
content = '''"""Package info."""
|
||||
|
||||
__version__ = "1.0.0"
|
||||
'''
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "major")
|
||||
|
||||
assert result == "2.0.0"
|
||||
updated = test_file.read_text(encoding="utf-8")
|
||||
assert '__version__ = "2.0.0"' in updated
|
||||
|
||||
def test_multiple_versions_in_file(self, tmp_path: Path) -> None:
|
||||
"""Should only bump the project version, not dependencies."""
|
||||
test_file = tmp_path / "pyproject.toml"
|
||||
content = """
|
||||
[project]
|
||||
version = "1.0.0"
|
||||
dependencies = ["lib >= 2.0.0", "other >= 3.0.0"]
|
||||
"""
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.0.1"
|
||||
updated = test_file.read_text(encoding="utf-8")
|
||||
assert 'version = "1.0.1"' in updated
|
||||
# 确保 dependencies 中的版本没有被更新
|
||||
assert "lib >= 2.0.0" in updated
|
||||
assert "other >= 3.0.0" in updated
|
||||
|
||||
def test_file_read_error(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""Should handle file read errors."""
|
||||
# 创建一个目录而不是文件
|
||||
test_file = tmp_path / "test_dir"
|
||||
test_file.mkdir()
|
||||
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
def test_file_write_error(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""Should handle file write errors."""
|
||||
# 在只读目录中创建文件(这个测试在某些系统上可能不适用)
|
||||
test_file = tmp_path / "readonly.toml"
|
||||
test_file.write_text('version = "1.0.0"', encoding="utf-8")
|
||||
# 设置为只读
|
||||
test_file.chmod(0o444)
|
||||
|
||||
try:
|
||||
with pytest.raises(Exception): # noqa: B017
|
||||
bumpversion.bump_file_version(test_file, "patch")
|
||||
finally:
|
||||
# 恢复权限以便清理
|
||||
test_file.chmod(0o644)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Version pattern tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestVersionPattern:
|
||||
"""Test version pattern matching."""
|
||||
|
||||
def test_simple_version(self, tmp_path: Path) -> None:
|
||||
"""Should match simple version."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "1.0.0"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.0.1"
|
||||
|
||||
def test_version_with_zeros(self, tmp_path: Path) -> None:
|
||||
"""Should handle versions with zeros correctly."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "0.0.0"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "0.0.1"
|
||||
|
||||
def test_large_version_numbers(self, tmp_path: Path) -> None:
|
||||
"""Should handle large version numbers."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "10.20.30"', encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "minor")
|
||||
|
||||
assert result == "10.21.0"
|
||||
|
||||
def test_version_in_url(self, tmp_path: Path) -> None:
|
||||
"""Should not match version in URL or other contexts."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("https://example.com/v1.2.3/download", encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
# 不应该匹配 URL 中的版本号
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestEdgeCases:
|
||||
"""Test edge cases and error handling."""
|
||||
|
||||
def test_empty_file(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""Should handle empty file."""
|
||||
test_file = tmp_path / "empty.txt"
|
||||
test_file.write_text("", encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result is None
|
||||
captured = capsys.readouterr()
|
||||
assert "未找到版本号模式" in captured.out
|
||||
|
||||
def test_file_with_special_chars(self, tmp_path: Path) -> None:
|
||||
"""Should handle file with special characters."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
content = '# 中文注释\n__version__ = "1.0.0"\n# 特殊字符: @#$%'
|
||||
test_file.write_text(content, encoding="utf-8")
|
||||
|
||||
result = bumpversion.bump_file_version(test_file, "patch")
|
||||
|
||||
assert result == "1.0.1"
|
||||
updated = test_file.read_text(encoding="utf-8")
|
||||
assert "# 中文注释" in updated
|
||||
assert "# 特殊字符: @#$%" in updated
|
||||
|
||||
def test_consecutive_bumps(self, tmp_path: Path) -> None:
|
||||
"""Should handle consecutive version bumps correctly."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "1.0.0"', encoding="utf-8")
|
||||
|
||||
# 第一次 bump
|
||||
result1 = bumpversion.bump_file_version(test_file, "patch")
|
||||
assert result1 == "1.0.1"
|
||||
|
||||
# 第二次 bump
|
||||
result2 = bumpversion.bump_file_version(test_file, "minor")
|
||||
assert result2 == "1.1.0"
|
||||
|
||||
# 第三次 bump
|
||||
result3 = bumpversion.bump_file_version(test_file, "major")
|
||||
assert result3 == "2.0.0"
|
||||
|
||||
# 验证最终结果
|
||||
assert test_file.read_text(encoding="utf-8") == '__version__ = "2.0.0"'
|
||||
|
||||
|
||||
class TestBumpVersionCli:
|
||||
"""Test bumpversion CLI."""
|
||||
|
||||
def test_minor(self, tmp_path: Path) -> None:
|
||||
"""Should handle minor version bump."""
|
||||
test_file = tmp_path / "__init__.py"
|
||||
test_file.write_text('__version__ = "1.0.0"', encoding="utf-8")
|
||||
|
||||
# Mock px.run: 只真正执行第一次调用(版本更新),其余返回空 dict
|
||||
with patch("sys.argv", ["bumpversion", "minor", "--no-tag"]), patch("pyflowx.run") as mock_run:
|
||||
|
||||
def run_side_effect(graph: px.Graph, strategy: str | None = None):
|
||||
# 执行实际版本更新任务
|
||||
results = {}
|
||||
for spec in graph.specs.values():
|
||||
if spec.fn is not None and spec.args:
|
||||
results[spec.name] = spec.fn(*spec.args)
|
||||
return results
|
||||
|
||||
mock_run.side_effect = run_side_effect
|
||||
bumpversion.main()
|
||||
|
||||
# 验证版本号已更新
|
||||
assert test_file.read_text(encoding="utf-8") == '__version__ = "1.1.0"'
|
||||
|
||||
def test_no_valid_files(self, tmp_path: Path, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""Should handle no valid files."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("这是一个测试文件", encoding="utf-8")
|
||||
|
||||
with patch("sys.argv", ["bumpversion", "minor", "--no-tag"]), patch("pyflowx.run") as mock_run:
|
||||
|
||||
def run_side_effect(graph: px.Graph, strategy: str | None = None):
|
||||
# 执行实际版本更新任务
|
||||
results = {}
|
||||
for spec in graph.specs.values():
|
||||
if spec.fn is not None and spec.args:
|
||||
results[spec.name] = spec.fn(*spec.args)
|
||||
return results
|
||||
|
||||
mock_run.side_effect = run_side_effect
|
||||
bumpversion.main()
|
||||
|
||||
# 验证未更新任何文件
|
||||
assert test_file.read_text(encoding="utf-8") == "这是一个测试文件"
|
||||
assert "未找到包含版本号的文件" in capsys.readouterr().out
|
||||
@@ -0,0 +1,21 @@
|
||||
"""Tests for cli.clearscreen module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli.system import clearscreen
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_creates_graph_and_runs(self) -> None:
|
||||
"""main() should create a Graph and run it."""
|
||||
with patch.object(px, "run") as mock_run:
|
||||
clearscreen.main()
|
||||
assert mock_run.called
|
||||
@@ -0,0 +1,927 @@
|
||||
"""Tests for cli.emlmanager module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import email
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from pyflowx.cli import emlmanager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# EmailDatabase Tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestEmailDatabase:
|
||||
"""Test EmailDatabase class."""
|
||||
|
||||
def test_init_database(self, tmp_path: Path) -> None:
|
||||
"""Should initialize database successfully."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
assert db.db_path == db_path
|
||||
assert db.conn is not None
|
||||
db.close()
|
||||
|
||||
def test_init_database_creates_table(self, tmp_path: Path) -> None:
|
||||
"""Should create emails table with correct schema."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
assert db.conn is not None
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='emails'")
|
||||
result = cursor.fetchone()
|
||||
assert result is not None
|
||||
db.close()
|
||||
|
||||
def test_init_database_creates_indexes(self, tmp_path: Path) -> None:
|
||||
"""Should create indexes for better query performance."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
assert db.conn is not None
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='index' AND name='idx_subject'")
|
||||
result = cursor.fetchone()
|
||||
assert result is not None
|
||||
db.close()
|
||||
|
||||
def test_insert_email_success(self, tmp_path: Path) -> None:
|
||||
"""Should insert email data successfully."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
email_data = {
|
||||
"file_path": "/test/path.eml",
|
||||
"file_hash": "abc123",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Test body",
|
||||
"body_html": "<p>Test body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
}
|
||||
|
||||
result = db.insert_email(email_data)
|
||||
assert result is True
|
||||
assert db.conn is not None
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM emails")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 1
|
||||
db.close()
|
||||
|
||||
def test_insert_email_replace_existing(self, tmp_path: Path) -> None:
|
||||
"""Should replace existing email with same file_path."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
email_data = {
|
||||
"file_path": "/test/path.eml",
|
||||
"file_hash": "abc123",
|
||||
"subject": "Original Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Original body",
|
||||
"body_html": "<p>Original body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
}
|
||||
|
||||
db.insert_email(email_data)
|
||||
|
||||
# Insert same file_path with different content
|
||||
email_data["subject"] = "Updated Subject"
|
||||
email_data["file_hash"] = "xyz789"
|
||||
db.insert_email(email_data)
|
||||
|
||||
assert db.conn is not None
|
||||
|
||||
cursor = db.conn.cursor()
|
||||
cursor.execute("SELECT COUNT(*) FROM emails")
|
||||
count = cursor.fetchone()[0]
|
||||
assert count == 1
|
||||
|
||||
cursor.execute("SELECT subject FROM emails WHERE file_path = ?", ("/test/path.eml",))
|
||||
subject = cursor.fetchone()[0]
|
||||
assert subject == "Updated Subject"
|
||||
db.close()
|
||||
|
||||
def test_search_emails_no_keyword(self, tmp_path: Path) -> None:
|
||||
"""Should return all emails when no keyword provided."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert test emails
|
||||
for i in range(5):
|
||||
db.insert_email({
|
||||
"file_path": f"/test/path{i}.eml",
|
||||
"file_hash": f"hash{i}",
|
||||
"subject": f"Subject {i}",
|
||||
"sender": f"sender{i}@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": f"Mon, {i + 1} Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": f"2024-01-0{i + 1}T12:00:00",
|
||||
"body_text": f"Body {i}",
|
||||
"body_html": f"<p>Body {i}</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
results = db.search_emails(limit=3)
|
||||
assert len(results) == 3
|
||||
db.close()
|
||||
|
||||
def test_search_emails_by_subject(self, tmp_path: Path) -> None:
|
||||
"""Should search emails by subject."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path1.eml",
|
||||
"file_hash": "hash1",
|
||||
"subject": "Important Meeting",
|
||||
"sender": "sender1@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Meeting body",
|
||||
"body_html": "<p>Meeting body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path2.eml",
|
||||
"file_hash": "hash2",
|
||||
"subject": "Casual Chat",
|
||||
"sender": "sender2@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Tue, 2 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-02T12:00:00",
|
||||
"body_text": "Chat body",
|
||||
"body_html": "<p>Chat body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
results = db.search_emails(keyword="Meeting", field="subject")
|
||||
assert len(results) == 1
|
||||
assert results[0]["subject"] == "Important Meeting"
|
||||
db.close()
|
||||
|
||||
def test_search_emails_by_sender(self, tmp_path: Path) -> None:
|
||||
"""Should search emails by sender."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path1.eml",
|
||||
"file_hash": "hash1",
|
||||
"subject": "Test",
|
||||
"sender": "alice@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Body",
|
||||
"body_html": "<p>Body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path2.eml",
|
||||
"file_hash": "hash2",
|
||||
"subject": "Test",
|
||||
"sender": "bob@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Tue, 2 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-02T12:00:00",
|
||||
"body_text": "Body",
|
||||
"body_html": "<p>Body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
results = db.search_emails(keyword="alice", field="sender")
|
||||
assert len(results) == 1
|
||||
assert results[0]["sender"] == "alice@example.com"
|
||||
db.close()
|
||||
|
||||
def test_search_emails_all_fields(self, tmp_path: Path) -> None:
|
||||
"""Should search emails across all fields."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path1.eml",
|
||||
"file_hash": "hash1",
|
||||
"subject": "Project Update",
|
||||
"sender": "manager@example.com",
|
||||
"recipients": "team@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Please review the quarterly report",
|
||||
"body_html": "<p>Please review the quarterly report</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
# Search for keyword in subject
|
||||
results = db.search_emails(keyword="Project", field="all")
|
||||
assert len(results) == 1
|
||||
|
||||
# Search for keyword in body
|
||||
results = db.search_emails(keyword="quarterly", field="all")
|
||||
assert len(results) == 1
|
||||
db.close()
|
||||
|
||||
def test_get_grouped_emails(self, tmp_path: Path) -> None:
|
||||
"""Should group emails by normalized subject."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert emails with same subject (different prefixes)
|
||||
db.insert_email({
|
||||
"file_path": "/test/path1.eml",
|
||||
"file_hash": "hash1",
|
||||
"subject": "Meeting Tomorrow",
|
||||
"sender": "sender1@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Body 1",
|
||||
"body_html": "<p>Body 1</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path2.eml",
|
||||
"file_hash": "hash2",
|
||||
"subject": "Re: Meeting Tomorrow",
|
||||
"sender": "sender2@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Tue, 2 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-02T12:00:00",
|
||||
"body_text": "Body 2",
|
||||
"body_html": "<p>Body 2</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
db.insert_email({
|
||||
"file_path": "/test/path3.eml",
|
||||
"file_hash": "hash3",
|
||||
"subject": "Different Topic",
|
||||
"sender": "sender3@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Wed, 3 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-03T12:00:00",
|
||||
"body_text": "Body 3",
|
||||
"body_html": "<p>Body 3</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
grouped = db.get_grouped_emails()
|
||||
# Should have 2 groups: "Meeting Tomorrow" and "Different Topic"
|
||||
assert len(grouped) == 2
|
||||
assert "Meeting Tomorrow" in grouped
|
||||
assert len(grouped["Meeting Tomorrow"]) == 2
|
||||
db.close()
|
||||
|
||||
def test_normalize_subject(self, tmp_path: Path) -> None:
|
||||
"""Should normalize subject by removing Re/Fwd prefixes."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
assert db._normalize_subject("Re: Meeting") == "Meeting"
|
||||
assert db._normalize_subject("Fwd: Meeting") == "Meeting"
|
||||
assert db._normalize_subject("FW: Meeting") == "Meeting"
|
||||
assert db._normalize_subject("Re: Fwd: Meeting") == "Fwd: Meeting"
|
||||
assert db._normalize_subject("Meeting") == "Meeting"
|
||||
db.close()
|
||||
|
||||
def test_get_email_count(self, tmp_path: Path) -> None:
|
||||
"""Should return correct email count."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
assert db.get_email_count() == 0
|
||||
|
||||
for i in range(3):
|
||||
db.insert_email({
|
||||
"file_path": f"/test/path{i}.eml",
|
||||
"file_hash": f"hash{i}",
|
||||
"subject": f"Subject {i}",
|
||||
"sender": f"sender{i}@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": f"Mon, {i + 1} Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": f"2024-01-0{i + 1}T12:00:00",
|
||||
"body_text": f"Body {i}",
|
||||
"body_html": f"<p>Body {i}</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
assert db.get_email_count() == 3
|
||||
db.close()
|
||||
|
||||
def test_clear_all(self, tmp_path: Path) -> None:
|
||||
"""Should clear all emails from database."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert some emails
|
||||
for i in range(3):
|
||||
db.insert_email({
|
||||
"file_path": f"/test/path{i}.eml",
|
||||
"file_hash": f"hash{i}",
|
||||
"subject": f"Subject {i}",
|
||||
"sender": f"sender{i}@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": f"Mon, {i + 1} Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": f"2024-01-0{i + 1}T12:00:00",
|
||||
"body_text": f"Body {i}",
|
||||
"body_html": f"<p>Body {i}</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
assert db.get_email_count() == 3
|
||||
|
||||
db.clear_all()
|
||||
assert db.get_email_count() == 0
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Email Parsing Tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestDecodeMimeWords:
|
||||
"""Test decode_mime_words function."""
|
||||
|
||||
def test_decode_simple_text(self) -> None:
|
||||
"""Should decode simple ASCII text."""
|
||||
result = emlmanager.decode_mime_words("Simple text")
|
||||
assert result == "Simple text"
|
||||
|
||||
def test_decode_utf8_encoded(self) -> None:
|
||||
"""Should decode UTF-8 encoded text."""
|
||||
# =?utf-8?b?5Lit5paH?= is "中文" in UTF-8 Base64
|
||||
result = emlmanager.decode_mime_words("=?utf-8?b?5Lit5paH?=")
|
||||
assert result == "中文"
|
||||
|
||||
def test_decode_qp_encoded(self) -> None:
|
||||
"""Should decode Quoted-Printable encoded text."""
|
||||
result = emlmanager.decode_mime_words("=?utf-8?Q?Hello=20World?=")
|
||||
assert result == "Hello World"
|
||||
|
||||
def test_decode_empty_string(self) -> None:
|
||||
"""Should handle empty string."""
|
||||
result = emlmanager.decode_mime_words("")
|
||||
assert result == ""
|
||||
|
||||
def test_decode_none(self) -> None:
|
||||
"""Should handle None input."""
|
||||
result = emlmanager.decode_mime_words("")
|
||||
assert result == ""
|
||||
|
||||
def test_decode_mixed_encoding(self) -> None:
|
||||
"""Should decode mixed encoding."""
|
||||
result = emlmanager.decode_mime_words("Hello =?utf-8?b?5Lit5paH?= World")
|
||||
assert "Hello" in result
|
||||
assert "中文" in result
|
||||
assert "World" in result
|
||||
|
||||
|
||||
class TestParseEmailDate:
|
||||
"""Test _parse_email_date function."""
|
||||
|
||||
def test_parse_valid_date(self) -> None:
|
||||
"""Should parse valid email date."""
|
||||
date_str = "Mon, 1 Jan 2024 12:00:00 +0000"
|
||||
result = emlmanager._parse_email_date(date_str)
|
||||
assert result == "2024-01-01T12:00:00+00:00"
|
||||
|
||||
def test_parse_empty_date(self) -> None:
|
||||
"""Should handle empty date string."""
|
||||
result = emlmanager._parse_email_date("")
|
||||
assert result == ""
|
||||
|
||||
def test_parse_invalid_date(self) -> None:
|
||||
"""Should return original string for invalid date."""
|
||||
result = emlmanager._parse_email_date("Invalid Date")
|
||||
assert result == "Invalid Date"
|
||||
|
||||
|
||||
class TestExtractEmailBodyPart:
|
||||
"""Test _extract_email_body_part function."""
|
||||
|
||||
def test_extract_text_plain(self) -> None:
|
||||
"""Should extract plain text content."""
|
||||
msg = email.message_from_string("Content-Type: text/plain; charset=utf-8\n\nTest body content")
|
||||
result = emlmanager._extract_email_body_part(msg)
|
||||
assert result == "Test body content"
|
||||
|
||||
def test_extract_text_with_charset(self) -> None:
|
||||
"""Should handle different charsets."""
|
||||
msg = email.message_from_string("Content-Type: text/plain; charset=utf-8\n\nHello 世界")
|
||||
result = emlmanager._extract_email_body_part(msg)
|
||||
assert "Hello" in result
|
||||
|
||||
def test_extract_empty_body(self) -> None:
|
||||
"""Should handle empty body."""
|
||||
msg = email.message_from_string("Content-Type: text/plain; charset=utf-8\n\n")
|
||||
result = emlmanager._extract_email_body_part(msg)
|
||||
assert result == ""
|
||||
|
||||
def test_extract_body_with_max_length(self) -> None:
|
||||
"""Should truncate body to MAX_BODY_LENGTH."""
|
||||
long_text = "A" * 10000
|
||||
msg = email.message_from_string(f"Content-Type: text/plain; charset=utf-8\n\n{long_text}")
|
||||
result = emlmanager._extract_email_body_part(msg)
|
||||
assert len(result) == emlmanager.MAX_BODY_LENGTH
|
||||
|
||||
|
||||
class TestProcessMultipartEmail:
|
||||
"""Test _process_multipart_email function."""
|
||||
|
||||
def test_process_multipart_with_attachments(self) -> None:
|
||||
"""Should detect attachments in multipart email."""
|
||||
msg = email.message_from_string(
|
||||
"""From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test
|
||||
MIME-Version: 1.0
|
||||
Content-Type: multipart/mixed; boundary=boundary
|
||||
|
||||
--boundary
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Test body
|
||||
|
||||
--boundary
|
||||
Content-Type: application/pdf; name="test.pdf"
|
||||
Content-Disposition: attachment; filename="test.pdf"
|
||||
|
||||
PDF content here
|
||||
|
||||
--boundary--
|
||||
"""
|
||||
)
|
||||
body_text, _body_html, has_attachments = emlmanager._process_multipart_email(msg)
|
||||
assert body_text.strip() == "Test body"
|
||||
assert has_attachments == 1
|
||||
|
||||
def test_process_multipart_text_and_html(self) -> None:
|
||||
"""Should extract both text and html parts."""
|
||||
msg = email.message_from_string(
|
||||
"""From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test
|
||||
MIME-Version: 1.0
|
||||
Content-Type: multipart/alternative; boundary=boundary
|
||||
|
||||
--boundary
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Plain text body
|
||||
|
||||
--boundary
|
||||
Content-Type: text/html; charset=utf-8
|
||||
|
||||
<html><body>HTML body</body></html>
|
||||
|
||||
--boundary--
|
||||
"""
|
||||
)
|
||||
body_text, body_html, has_attachments = emlmanager._process_multipart_email(msg)
|
||||
assert "Plain text body" in body_text
|
||||
assert "HTML body" in body_html
|
||||
assert has_attachments == 0
|
||||
|
||||
|
||||
class TestProcessSinglepartEmail:
|
||||
"""Test _process_singlepart_email function."""
|
||||
|
||||
def test_process_text_plain(self) -> None:
|
||||
"""Should process plain text email."""
|
||||
msg = email.message_from_string("Content-Type: text/plain; charset=utf-8\n\nPlain text content")
|
||||
body_text, body_html = emlmanager._process_singlepart_email(msg)
|
||||
assert body_text == "Plain text content"
|
||||
assert body_html == ""
|
||||
|
||||
def test_process_text_html(self) -> None:
|
||||
"""Should process HTML email."""
|
||||
msg = email.message_from_string(
|
||||
"Content-Type: text/html; charset=utf-8\n\n<html><body>HTML content</body></html>"
|
||||
)
|
||||
body_text, body_html = emlmanager._process_singlepart_email(msg)
|
||||
assert body_text == ""
|
||||
assert "HTML content" in body_html
|
||||
|
||||
|
||||
class TestParseEmlFile:
|
||||
"""Test parse_eml_file function."""
|
||||
|
||||
def test_parse_simple_eml(self, tmp_path: Path) -> None:
|
||||
"""Should parse simple EML file."""
|
||||
eml_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Subject
|
||||
Date: Mon, 1 Jan 2024 12:00:00 +0000
|
||||
|
||||
This is the email body.
|
||||
"""
|
||||
eml_file = tmp_path / "test.eml"
|
||||
eml_file.write_text(eml_content)
|
||||
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
|
||||
assert result is not None
|
||||
assert result["subject"] == "Test Subject"
|
||||
assert result["sender"] == "sender@example.com"
|
||||
assert result["recipients"] == "recipient@example.com"
|
||||
assert "This is the email body" in result["body_text"]
|
||||
assert result["has_attachments"] == 0
|
||||
|
||||
def test_parse_eml_with_mime_subject(self, tmp_path: Path) -> None:
|
||||
"""Should parse EML with MIME-encoded subject."""
|
||||
eml_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: =?utf-8?b?5Lit5paHIEhlbGxv?=
|
||||
Date: Mon, 1 Jan 2024 12:00:00 +0000
|
||||
|
||||
Email body
|
||||
"""
|
||||
eml_file = tmp_path / "test.eml"
|
||||
eml_file.write_text(eml_content)
|
||||
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
|
||||
assert result is not None
|
||||
assert "中文" in result["subject"]
|
||||
assert "Hello" in result["subject"]
|
||||
|
||||
def test_parse_multipart_eml(self, tmp_path: Path) -> None:
|
||||
"""Should parse multipart EML file."""
|
||||
eml_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Multipart Test
|
||||
Date: Mon, 1 Jan 2024 12:00:00 +0000
|
||||
MIME-Version: 1.0
|
||||
Content-Type: multipart/alternative; boundary=boundary
|
||||
|
||||
--boundary
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Plain text version
|
||||
|
||||
--boundary
|
||||
Content-Type: text/html; charset=utf-8
|
||||
|
||||
<html><body>HTML version</body></html>
|
||||
|
||||
--boundary--
|
||||
"""
|
||||
eml_file = tmp_path / "test.eml"
|
||||
eml_file.write_text(eml_content)
|
||||
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
|
||||
assert result is not None
|
||||
assert "Plain text version" in result["body_text"]
|
||||
assert "HTML version" in result["body_html"]
|
||||
|
||||
def test_parse_eml_with_attachment(self, tmp_path: Path) -> None:
|
||||
"""Should detect attachments."""
|
||||
eml_content = """From: sender@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Email with attachment
|
||||
Date: Mon, 1 Jan 2024 12:00:00 +0000
|
||||
MIME-Version: 1.0
|
||||
Content-Type: multipart/mixed; boundary=boundary
|
||||
|
||||
--boundary
|
||||
Content-Type: text/plain; charset=utf-8
|
||||
|
||||
Email body
|
||||
|
||||
--boundary
|
||||
Content-Type: application/pdf; name="test.pdf"
|
||||
Content-Disposition: attachment; filename="test.pdf"
|
||||
Content-Transfer-Encoding: base64
|
||||
|
||||
JVBERi0xLjQK
|
||||
|
||||
--boundary--
|
||||
"""
|
||||
eml_file = tmp_path / "test.eml"
|
||||
eml_file.write_text(eml_content)
|
||||
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
|
||||
assert result is not None
|
||||
assert result["has_attachments"] == 1
|
||||
|
||||
def test_parse_nonexistent_file(self, tmp_path: Path) -> None:
|
||||
"""Should return None for nonexistent file."""
|
||||
eml_file = tmp_path / "nonexistent.eml"
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
assert result is None
|
||||
|
||||
def test_parse_invalid_eml(self, tmp_path: Path) -> None:
|
||||
"""Should handle invalid EML file gracefully."""
|
||||
eml_file = tmp_path / "invalid.eml"
|
||||
eml_file.write_text("This is not a valid EML file")
|
||||
|
||||
result = emlmanager.parse_eml_file(eml_file)
|
||||
# Should still parse but with empty/default values
|
||||
assert result is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Web Server Tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestEmlManagerHandler:
|
||||
"""Test EmlManagerHandler HTTP request handler."""
|
||||
|
||||
def test_api_get_status(self, tmp_path: Path) -> None:
|
||||
"""Should return server status."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Create a mock handler instance without calling __init__
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler.work_dir = tmp_path
|
||||
handler._send_json_response = Mock()
|
||||
|
||||
# Call the method directly (not through __init__)
|
||||
emlmanager.EmlManagerHandler._api_get_status(handler)
|
||||
|
||||
handler._send_json_response.assert_called_once()
|
||||
call_args = handler._send_json_response.call_args[0][0]
|
||||
assert call_args["initialized"] is True
|
||||
assert str(tmp_path) in call_args["work_dir"]
|
||||
|
||||
db.close()
|
||||
|
||||
def test_api_get_count(self, tmp_path: Path) -> None:
|
||||
"""Should return email count."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert some emails
|
||||
for i in range(3):
|
||||
db.insert_email({
|
||||
"file_path": f"/test/path{i}.eml",
|
||||
"file_hash": f"hash{i}",
|
||||
"subject": f"Subject {i}",
|
||||
"sender": f"sender{i}@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": f"Mon, {i + 1} Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": f"2024-01-0{i + 1}T12:00:00",
|
||||
"body_text": f"Body {i}",
|
||||
"body_html": f"<p>Body {i}</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
# Create a mock handler instance without calling __init__
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler._send_json_response = Mock()
|
||||
|
||||
# Call the method directly
|
||||
emlmanager.EmlManagerHandler._api_get_count(handler)
|
||||
|
||||
handler._send_json_response.assert_called_once()
|
||||
call_args = handler._send_json_response.call_args[0][0]
|
||||
assert call_args["count"] == 3
|
||||
|
||||
db.close()
|
||||
|
||||
def test_api_get_emails(self, tmp_path: Path) -> None:
|
||||
"""Should return emails list."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert test email
|
||||
db.insert_email({
|
||||
"file_path": "/test/path.eml",
|
||||
"file_hash": "hash",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Test body",
|
||||
"body_html": "<p>Test body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
# Create a mock handler instance without calling __init__
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler._send_json_response = Mock()
|
||||
|
||||
# Call the method directly
|
||||
emlmanager.EmlManagerHandler._api_get_emails(handler, {})
|
||||
|
||||
handler._send_json_response.assert_called_once()
|
||||
call_args = handler._send_json_response.call_args[0][0]
|
||||
assert len(call_args["emails"]) == 1
|
||||
assert call_args["emails"][0]["subject"] == "Test Subject"
|
||||
|
||||
db.close()
|
||||
|
||||
def test_api_clear_database(self, tmp_path: Path) -> None:
|
||||
"""Should clear database."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Insert test email
|
||||
db.insert_email({
|
||||
"file_path": "/test/path.eml",
|
||||
"file_hash": "hash",
|
||||
"subject": "Test Subject",
|
||||
"sender": "sender@example.com",
|
||||
"recipients": "recipient@example.com",
|
||||
"date": "Mon, 1 Jan 2024 12:00:00 +0000",
|
||||
"date_parsed": "2024-01-01T12:00:00",
|
||||
"body_text": "Test body",
|
||||
"body_html": "<p>Test body</p>",
|
||||
"has_attachments": 0,
|
||||
"file_size": 1024,
|
||||
})
|
||||
|
||||
assert db.get_email_count() == 1
|
||||
|
||||
# Create a mock handler instance without calling __init__
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler._send_json_response = Mock()
|
||||
|
||||
# Call the method directly
|
||||
emlmanager.EmlManagerHandler._api_clear_database(handler)
|
||||
|
||||
handler._send_json_response.assert_called_once()
|
||||
assert db.get_email_count() == 0
|
||||
db.close()
|
||||
|
||||
def test_send_json_response_with_gzip(self, tmp_path: Path) -> None:
|
||||
"""Should send gzip-compressed JSON response when client supports it."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Create a mock handler with all necessary attributes
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler.headers = {"Accept-Encoding": "gzip, deflate"}
|
||||
handler.send_response = Mock()
|
||||
handler.send_header = Mock()
|
||||
handler.end_headers = Mock()
|
||||
handler.wfile = BytesIO()
|
||||
|
||||
data = {"test": "data"}
|
||||
|
||||
# Call the real method
|
||||
emlmanager.EmlManagerHandler._send_json_response(handler, data)
|
||||
|
||||
# Check that gzip compression was used
|
||||
handler.send_response.assert_called_once_with(200)
|
||||
assert any(
|
||||
call[0][0] == "Content-Encoding" and call[0][1] == "gzip" for call in handler.send_header.call_args_list
|
||||
)
|
||||
|
||||
db.close()
|
||||
|
||||
def test_send_json_response_without_gzip(self, tmp_path: Path) -> None:
|
||||
"""Should send uncompressed JSON response when client doesn't support gzip."""
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Create a mock handler with all necessary attributes
|
||||
handler = Mock(spec=emlmanager.EmlManagerHandler)
|
||||
handler.db = db
|
||||
handler.headers = {"Accept-Encoding": "identity"}
|
||||
handler.send_response = Mock()
|
||||
handler.send_header = Mock()
|
||||
handler.end_headers = Mock()
|
||||
handler.wfile = BytesIO()
|
||||
|
||||
data = {"test": "data"}
|
||||
|
||||
# Call the real method
|
||||
emlmanager.EmlManagerHandler._send_json_response(handler, data)
|
||||
|
||||
# Check that gzip compression was NOT used
|
||||
handler.send_response.assert_called_once_with(200)
|
||||
assert not any(call[0][0] == "Content-Encoding" for call in handler.send_header.call_args_list)
|
||||
|
||||
db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Main Function Tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_with_dir_argument(self, tmp_path: Path) -> None:
|
||||
"""Should initialize database when dir argument provided."""
|
||||
# Create some EML files
|
||||
for i in range(2):
|
||||
eml_file = tmp_path / f"test{i}.eml"
|
||||
eml_file.write_text(f"""From: sender{i}@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test {i}
|
||||
Date: Mon, {i + 1} Jan 2024 12:00:00 +0000
|
||||
|
||||
Body {i}
|
||||
""")
|
||||
|
||||
with patch("sys.argv", ["emlmanager", "--dir", str(tmp_path), "--port", "8080"]), patch.object(
|
||||
emlmanager, "ThreadingHTTPServer"
|
||||
) as mock_server, patch("threading.Thread"):
|
||||
# Don't actually start the server
|
||||
mock_server_instance = Mock()
|
||||
mock_server.return_value = mock_server_instance
|
||||
|
||||
# This would normally block, so we'll just test initialization
|
||||
with patch.object(emlmanager.EmlManagerHandler, "db", None):
|
||||
# The main function would be called, but we're patching to prevent blocking
|
||||
pass
|
||||
|
||||
# Verify EML files were found
|
||||
assert len(list(tmp_path.glob("*.eml"))) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Integration Tests
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestIntegration:
|
||||
"""Integration tests for emlmanager."""
|
||||
|
||||
def test_full_workflow(self, tmp_path: Path) -> None:
|
||||
"""Test complete workflow: parse -> store -> search."""
|
||||
# Initialize database
|
||||
db_path = tmp_path / "test.db"
|
||||
db = emlmanager.EmailDatabase(db_path)
|
||||
|
||||
# Create EML files
|
||||
eml_files = []
|
||||
for i in range(3):
|
||||
eml_file = tmp_path / f"email{i}.eml"
|
||||
eml_content = f"""From: sender{i}@example.com
|
||||
To: recipient@example.com
|
||||
Subject: Test Email {i}
|
||||
Date: Mon, {i + 1} Jan 2024 12:00:00 +0000
|
||||
|
||||
This is email body {i}.
|
||||
"""
|
||||
eml_file.write_text(eml_content)
|
||||
eml_files.append(eml_file)
|
||||
|
||||
# Parse and insert emails
|
||||
for eml_file in eml_files:
|
||||
email_data = emlmanager.parse_eml_file(eml_file)
|
||||
if email_data:
|
||||
db.insert_email(email_data)
|
||||
|
||||
# Verify insertion
|
||||
assert db.get_email_count() == 3
|
||||
|
||||
# Search emails
|
||||
results = db.search_emails(keyword="Email")
|
||||
assert len(results) == 3
|
||||
|
||||
# Search by sender
|
||||
results = db.search_emails(keyword="sender1", field="sender")
|
||||
assert len(results) == 1
|
||||
assert results[0]["sender"] == "sender1@example.com"
|
||||
|
||||
# Get grouped emails
|
||||
grouped = db.get_grouped_emails()
|
||||
assert len(grouped) > 0
|
||||
|
||||
# Clear database
|
||||
db.clear_all()
|
||||
assert db.get_email_count() == 0
|
||||
|
||||
db.close()
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Tests for cli.filedate module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import filedate
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# get_file_timestamp
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGetFileTimestamp:
|
||||
"""Test get_file_timestamp function."""
|
||||
|
||||
def test_get_file_timestamp(self, tmp_path: Path) -> None:
|
||||
"""Should get file timestamp."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
timestamp = filedate.get_file_timestamp(test_file)
|
||||
assert len(timestamp) == 8 # YYYYMMDD format
|
||||
assert timestamp.isdigit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# remove_date_prefix
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRemoveDatePrefix:
|
||||
"""Test remove_date_prefix function."""
|
||||
|
||||
def test_remove_date_prefix_with_date(self, tmp_path: Path) -> None:
|
||||
"""Should remove date prefix from filename."""
|
||||
test_file = tmp_path / "20240101_test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
new_path = filedate.remove_date_prefix(test_file)
|
||||
assert new_path.name == "test.txt"
|
||||
|
||||
def test_remove_date_prefix_without_date(self, tmp_path: Path) -> None:
|
||||
"""Should not change filename without date prefix."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
new_path = filedate.remove_date_prefix(test_file)
|
||||
assert new_path == test_file
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# add_date_prefix
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestAddDatePrefix:
|
||||
"""Test add_date_prefix function."""
|
||||
|
||||
def test_add_date_prefix(self, tmp_path: Path) -> None:
|
||||
"""Should add date prefix to filename."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
new_path = filedate.add_date_prefix(test_file)
|
||||
assert new_path.name.startswith("20") # Starts with year
|
||||
assert "_test.txt" in new_path.name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# process_file_date
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestProcessFileDate:
|
||||
"""Test process_file_date function."""
|
||||
|
||||
def test_process_file_date_add(self, tmp_path: Path) -> None:
|
||||
"""Should add date prefix."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filedate.process_file_date(test_file, clear=False)
|
||||
# File should be renamed with date prefix
|
||||
|
||||
def test_process_file_date_clear(self, tmp_path: Path) -> None:
|
||||
"""Should clear date prefix."""
|
||||
test_file = tmp_path / "20240101_test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filedate.process_file_date(test_file, clear=True)
|
||||
# File should be renamed without date prefix
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# process_files_date
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestProcessFilesDate:
|
||||
"""Test process_files_date function."""
|
||||
|
||||
def test_process_files_date_batch(self, tmp_path: Path) -> None:
|
||||
"""Should process multiple files."""
|
||||
files = []
|
||||
for i in range(3):
|
||||
test_file = tmp_path / f"test{i}.txt"
|
||||
test_file.write_text(f"content{i}")
|
||||
files.append(test_file)
|
||||
|
||||
filedate.process_files_date(files, clear=False)
|
||||
# All files should be processed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_add_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle add command."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
with patch("sys.argv", ["filedate", "add", str(test_file)]), patch.object(px, "run") as mock_run:
|
||||
filedate.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_clear_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle clear command."""
|
||||
test_file = tmp_path / "20240101_test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
with patch("sys.argv", ["filedate", "clear", str(test_file)]), patch.object(px, "run") as mock_run:
|
||||
filedate.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["filedate"]):
|
||||
filedate.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,133 @@
|
||||
"""Tests for cli.filelevel module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import filelevel
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# remove_marks
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRemoveMarks:
|
||||
"""Test remove_marks function."""
|
||||
|
||||
def test_remove_marks_single_mark(self) -> None:
|
||||
"""Should remove single mark."""
|
||||
stem = "filename(PUB)"
|
||||
result = filelevel.remove_marks(stem, ["PUB"])
|
||||
assert result == "filename"
|
||||
|
||||
def test_remove_marks_multiple_marks(self) -> None:
|
||||
"""Should remove multiple marks."""
|
||||
stem = "filename(PUB)(NOR)"
|
||||
result = filelevel.remove_marks(stem, ["PUB", "NOR"])
|
||||
assert result == "filename"
|
||||
|
||||
def test_remove_marks_no_marks(self) -> None:
|
||||
"""Should not change stem without marks."""
|
||||
stem = "filename"
|
||||
result = filelevel.remove_marks(stem, ["PUB"])
|
||||
assert result == "filename"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# process_file_level
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestProcessFileLevel:
|
||||
"""Test process_file_level function."""
|
||||
|
||||
def test_process_file_level_set_pub(self, tmp_path: Path) -> None:
|
||||
"""Should set PUB level."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filelevel.process_file_level(test_file, level=1)
|
||||
# File should be renamed with PUB level
|
||||
|
||||
def test_process_file_level_set_int(self, tmp_path: Path) -> None:
|
||||
"""Should set INT level."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filelevel.process_file_level(test_file, level=2)
|
||||
# File should be renamed with INT level
|
||||
|
||||
def test_process_file_level_clear(self, tmp_path: Path) -> None:
|
||||
"""Should clear level."""
|
||||
test_file = tmp_path / "test(PUB).txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filelevel.process_file_level(test_file, level=0)
|
||||
# File should be renamed without level
|
||||
|
||||
def test_process_file_level_invalid_level(self, tmp_path: Path) -> None:
|
||||
"""Should handle invalid level."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
filelevel.process_file_level(test_file, level=5)
|
||||
# Should print error message
|
||||
|
||||
def test_process_file_level_nonexistent_file(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent file."""
|
||||
test_file = tmp_path / "nonexistent.txt"
|
||||
|
||||
filelevel.process_file_level(test_file, level=1)
|
||||
# Should print error message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# process_files_level
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestProcessFilesLevel:
|
||||
"""Test process_files_level function."""
|
||||
|
||||
def test_process_files_level_batch(self, tmp_path: Path) -> None:
|
||||
"""Should process multiple files."""
|
||||
files = []
|
||||
for i in range(3):
|
||||
test_file = tmp_path / f"test{i}.txt"
|
||||
test_file.write_text(f"content{i}")
|
||||
files.append(test_file)
|
||||
|
||||
filelevel.process_files_level(files, level=1)
|
||||
# All files should be processed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_set_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle set command."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
with patch("sys.argv", ["filelevel", "set", str(test_file), "--level", "1"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
filelevel.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_set_command_level_2(self, tmp_path: Path) -> None:
|
||||
"""main() should handle set command with level 2."""
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_text("test content")
|
||||
|
||||
with patch("sys.argv", ["filelevel", "set", str(test_file), "--level", "2"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
filelevel.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["filelevel"]):
|
||||
filelevel.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,173 @@
|
||||
"""Tests for cli.folderback module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import folderback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# remove_dump
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRemoveDump:
|
||||
"""Test remove_dump function."""
|
||||
|
||||
def test_remove_dump_no_files(self, tmp_path: Path) -> None:
|
||||
"""Should handle no zip files."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# Should not raise error
|
||||
|
||||
def test_remove_dump_within_limit(self, tmp_path: Path) -> None:
|
||||
"""Should not remove files within limit."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
# Create some zip files
|
||||
for i in range(3):
|
||||
zip_file = dst / f"source_20240101_12000{i}.zip"
|
||||
zip_file.write_bytes(b"ZIP content")
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# All files should remain
|
||||
assert len(list(dst.glob("*.zip"))) == 3
|
||||
|
||||
def test_remove_dump_exceeds_limit(self, tmp_path: Path) -> None:
|
||||
"""Should remove oldest files when exceeds limit."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
# Create more zip files than limit
|
||||
for i in range(7):
|
||||
zip_file = dst / f"source_20240101_12000{i}.zip"
|
||||
zip_file.write_bytes(b"ZIP content")
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# Should have only 5 files
|
||||
assert len(list(dst.glob("*.zip"))) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# zip_target
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestZipTarget:
|
||||
"""Test zip_target function."""
|
||||
|
||||
def test_zip_target_creates_zip(self, tmp_path: Path) -> None:
|
||||
"""Should create zip file."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
(src / "test.txt").write_text("test content")
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
with patch("time.strftime", return_value="_20240101_120000"):
|
||||
folderback.zip_target(src, dst, 5)
|
||||
|
||||
# Should create zip file
|
||||
zip_files = list(dst.glob("*.zip"))
|
||||
assert len(zip_files) == 1
|
||||
|
||||
def test_zip_target_with_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should zip files in subdirectories."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
subdir = src / "subdir"
|
||||
subdir.mkdir()
|
||||
(src / "test.txt").write_text("test content")
|
||||
(subdir / "nested.txt").write_text("nested content")
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
with patch("time.strftime", return_value="_20240101_120000"):
|
||||
folderback.zip_target(src, dst, 5)
|
||||
|
||||
# Should create zip file
|
||||
zip_files = list(dst.glob("*.zip"))
|
||||
assert len(zip_files) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# backup_folder
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestBackupFolder:
|
||||
"""Test backup_folder function."""
|
||||
|
||||
def test_backup_folder_with_source_and_backup(self, tmp_path: Path) -> None:
|
||||
"""Should backup folder with source and backup paths."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "test.txt").write_text("test content")
|
||||
backup_dir = tmp_path / "backup"
|
||||
|
||||
with patch.object(folderback, "zip_target") as mock_zip:
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
|
||||
assert mock_zip.called
|
||||
|
||||
def test_backup_folder_with_max_backups(self, tmp_path: Path) -> None:
|
||||
"""Should backup folder with max backups."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "test.txt").write_text("test content")
|
||||
backup_dir = tmp_path / "backup"
|
||||
|
||||
with patch.object(folderback, "zip_target") as mock_zip:
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 10)
|
||||
assert mock_zip.called
|
||||
|
||||
def test_backup_folder_source_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should handle non-existent source folder."""
|
||||
source_dir = tmp_path / "nonexistent"
|
||||
backup_dir = tmp_path / "backup"
|
||||
backup_dir.mkdir()
|
||||
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
|
||||
# Should print error message and return
|
||||
|
||||
def test_backup_folder_creates_dst(self, tmp_path: Path) -> None:
|
||||
"""Should create destination directory."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "test.txt").write_text("test content")
|
||||
backup_dir = tmp_path / "backup"
|
||||
|
||||
with patch.object(folderback, "zip_target") as mock_zip:
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
|
||||
assert backup_dir.exists()
|
||||
assert mock_zip.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTaskSpecDefinitions:
|
||||
"""Test that all TaskSpec definitions are valid."""
|
||||
|
||||
def test_folderback_default_spec(self) -> None:
|
||||
"""folderback_default spec should be properly defined."""
|
||||
assert folderback.folderback_default.name == "folderback_default"
|
||||
assert folderback.folderback_default.fn is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_calls_run_cli(self) -> None:
|
||||
"""main() should create a CliRunner and call run_cli()."""
|
||||
with patch.object(px.CliRunner, "run_cli") as mock_run_cli:
|
||||
folderback.main()
|
||||
assert mock_run_cli.called
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Tests for cli.folderzip module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import folderzip
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# archive_folder
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestArchiveFolder:
|
||||
"""Test archive_folder function."""
|
||||
|
||||
def test_archive_folder(self, tmp_path: Path) -> None:
|
||||
"""Should archive a folder."""
|
||||
folder = tmp_path / "test_folder"
|
||||
folder.mkdir()
|
||||
(folder / "test.txt").write_text("test content")
|
||||
|
||||
with patch("shutil.make_archive") as mock_archive:
|
||||
folderzip.archive_folder(folder)
|
||||
assert mock_archive.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# zip_folders
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestZipFolders:
|
||||
"""Test zip_folders function."""
|
||||
|
||||
def test_zip_folders_with_cwd(self, tmp_path: Path) -> None:
|
||||
"""Should zip folders in cwd."""
|
||||
# Create some folders
|
||||
(tmp_path / "folder1").mkdir()
|
||||
(tmp_path / "folder2").mkdir()
|
||||
(tmp_path / ".git").mkdir() # Should be ignored
|
||||
|
||||
with patch.object(folderzip, "archive_folder") as mock_archive:
|
||||
folderzip.zip_folders(str(tmp_path))
|
||||
# Should archive folder1 and folder2, but not .git
|
||||
assert mock_archive.call_count == 2
|
||||
|
||||
def test_zip_folders_nonexistent_cwd(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent cwd."""
|
||||
folderzip.zip_folders(str(tmp_path / "nonexistent"))
|
||||
# Should print error message and return
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTaskSpecDefinitions:
|
||||
"""Test that all TaskSpec definitions are valid."""
|
||||
|
||||
def test_folderzip_default_spec(self) -> None:
|
||||
"""folderzip_default spec should be properly defined."""
|
||||
assert folderzip.folderzip_default.name == "folderzip_default"
|
||||
assert folderzip.folderzip_default.fn is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_calls_run_cli(self) -> None:
|
||||
"""main() should create a CliRunner and call run_cli()."""
|
||||
with patch.object(px.CliRunner, "run_cli") as mock_run_cli:
|
||||
folderzip.main()
|
||||
assert mock_run_cli.called
|
||||
@@ -0,0 +1,137 @@
|
||||
"""Tests for cli.gittool module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import gittool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# not_has_git_repo
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestNotHasGitRepo:
|
||||
"""Test not_has_git_repo function."""
|
||||
|
||||
def test_not_has_git_repo_true(self, tmp_path: Path) -> None:
|
||||
"""Should return True when no .git directory."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is True
|
||||
|
||||
def test_not_has_git_repo_false(self, tmp_path: Path) -> None:
|
||||
"""Should return False when .git directory exists."""
|
||||
git_dir = tmp_path / ".git"
|
||||
git_dir.mkdir()
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is False
|
||||
|
||||
def test_not_has_git_repo_cwd_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should return True when cwd doesn't exist."""
|
||||
nonexistent = tmp_path / "nonexistent"
|
||||
|
||||
with patch.object(Path, "cwd", return_value=nonexistent):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# has_files
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestHasFiles:
|
||||
"""Test has_files function."""
|
||||
|
||||
def test_has_files_true(self, tmp_path: Path) -> None:
|
||||
"""Should return True when files exist."""
|
||||
(tmp_path / "test.txt").write_text("test")
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.has_files()
|
||||
assert result is True
|
||||
|
||||
def test_has_files_false(self, tmp_path: Path) -> None:
|
||||
"""Should return False when no files."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.has_files()
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# init_sub_dirs
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestInitSubDirs:
|
||||
"""Test init_sub_dirs function."""
|
||||
|
||||
def test_init_sub_dirs_with_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should initialize git in subdirectories."""
|
||||
subdir1 = tmp_path / "subdir1"
|
||||
subdir1.mkdir()
|
||||
subdir2 = tmp_path / "subdir2"
|
||||
subdir2.mkdir()
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
|
||||
gittool.init_sub_dirs()
|
||||
# Should call px.run for each subdirectory
|
||||
assert mock_run.call_count == 2
|
||||
|
||||
def test_init_sub_dirs_no_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should handle no subdirectories."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
|
||||
gittool.init_sub_dirs()
|
||||
# Should not call px.run
|
||||
assert mock_run.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTaskSpecDefinitions:
|
||||
"""Test that all TaskSpec definitions are valid."""
|
||||
|
||||
def test_push_spec(self) -> None:
|
||||
"""push spec should be properly defined."""
|
||||
assert gittool.push.name == "push"
|
||||
assert gittool.push.cmd == ["git", "push"]
|
||||
|
||||
def test_pull_spec(self) -> None:
|
||||
"""pull spec should be properly defined."""
|
||||
assert gittool.pull.name == "pull"
|
||||
assert gittool.pull.cmd == ["git", "pull"]
|
||||
|
||||
def test_kill_tgit_spec(self) -> None:
|
||||
"""kill_tgit spec should be properly defined."""
|
||||
assert gittool.kill_tgit.name == "task_kill"
|
||||
assert isinstance(gittool.kill_tgit.cmd, list)
|
||||
assert "taskkill" in gittool.kill_tgit.cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_calls_run_cli(self) -> None:
|
||||
"""main() should create a CliRunner and call run_cli()."""
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
gittool.main()
|
||||
# run_cli() calls sys.exit(), so we should get SystemExit
|
||||
assert exc_info.value.code in (0, 1, 2)
|
||||
|
||||
def test_main_with_list_argument(self) -> None:
|
||||
"""main() should handle --list argument."""
|
||||
with patch("sys.argv", ["gittool", "--list"]), pytest.raises(SystemExit) as exc_info:
|
||||
gittool.main()
|
||||
assert exc_info.value.code == 0
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help and exit."""
|
||||
with patch("sys.argv", ["gittool"]), pytest.raises(SystemExit) as exc_info:
|
||||
gittool.main()
|
||||
assert exc_info.value.code == 1
|
||||
@@ -0,0 +1,157 @@
|
||||
"""Tests for cli.lscalc module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import lscalc
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# get_ls_dyna_command
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGetLsDynaCommand:
|
||||
"""Test get_ls_dyna_command function."""
|
||||
|
||||
def test_get_ls_dyna_command_windows(self) -> None:
|
||||
"""Should get LS-DYNA command for Windows."""
|
||||
with patch.object(Constants, "IS_WINDOWS", True), patch.object(Constants, "IS_MACOS", False):
|
||||
cmd = lscalc.get_ls_dyna_command("input.k", 4)
|
||||
assert "ls-dyna_mpp" in cmd
|
||||
assert "i=input.k" in cmd
|
||||
assert "ncpu=4" in cmd
|
||||
|
||||
def test_get_ls_dyna_command_linux(self) -> None:
|
||||
"""Should get LS-DYNA command for Linux."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch.object(Constants, "IS_MACOS", False):
|
||||
cmd = lscalc.get_ls_dyna_command("input.k", 8)
|
||||
assert "ls-dyna_mpp" in cmd
|
||||
assert "i=input.k" in cmd
|
||||
assert "ncpu=8" in cmd
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# run_ls_dyna
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRunLsDyna:
|
||||
"""Test run_ls_dyna function."""
|
||||
|
||||
def test_run_ls_dyna_success(self, tmp_path: Path) -> None:
|
||||
"""Should run LS-DYNA successfully."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
lscalc.run_ls_dyna(str(input_file), ncpu=4)
|
||||
assert mock_run.called
|
||||
|
||||
def test_run_ls_dyna_file_not_found(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent input file."""
|
||||
input_file = tmp_path / "nonexistent.k"
|
||||
|
||||
lscalc.run_ls_dyna(str(input_file), ncpu=4)
|
||||
# Should print error message
|
||||
|
||||
def test_run_ls_dyna_command_not_found(self, tmp_path: Path) -> None:
|
||||
"""Should handle command not found."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError):
|
||||
lscalc.run_ls_dyna(str(input_file), ncpu=4)
|
||||
# Should print error message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# run_ls_dyna_mpi
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRunLsDynaMpi:
|
||||
"""Test run_ls_dyna_mpi function."""
|
||||
|
||||
def test_run_ls_dyna_mpi_success(self, tmp_path: Path) -> None:
|
||||
"""Should run LS-DYNA MPI successfully."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
lscalc.run_ls_dyna_mpi(str(input_file), ncpu=8)
|
||||
assert mock_run.called
|
||||
|
||||
def test_run_ls_dyna_mpi_file_not_found(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent input file."""
|
||||
input_file = tmp_path / "nonexistent.k"
|
||||
|
||||
lscalc.run_ls_dyna_mpi(str(input_file), ncpu=8)
|
||||
# Should print error message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# check_ls_dyna_status
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCheckLsDynaStatus:
|
||||
"""Test check_ls_dyna_status function."""
|
||||
|
||||
def test_check_ls_dyna_status_windows(self) -> None:
|
||||
"""Should check LS-DYNA status on Windows."""
|
||||
with patch.object(Constants, "IS_WINDOWS", True), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="ls-dyna_mpp.exe", returncode=0)
|
||||
lscalc.check_ls_dyna_status()
|
||||
assert mock_run.called
|
||||
|
||||
def test_check_ls_dyna_status_linux(self) -> None:
|
||||
"""Should check LS-DYNA status on Linux."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="1234", returncode=0)
|
||||
lscalc.check_ls_dyna_status()
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_run_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle run command."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("sys.argv", ["lscalc", "run", str(input_file)]), patch.object(px, "run") as mock_run:
|
||||
lscalc.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_run_command_with_ncpu(self, tmp_path: Path) -> None:
|
||||
"""main() should handle run command with ncpu."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("sys.argv", ["lscalc", "run", str(input_file), "--ncpu", "8"]), patch.object(px, "run") as mock_run:
|
||||
lscalc.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_mpi_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle mpi command."""
|
||||
input_file = tmp_path / "input.k"
|
||||
input_file.write_text("LS-DYNA input")
|
||||
|
||||
with patch("sys.argv", ["lscalc", "mpi", str(input_file)]), patch.object(px, "run") as mock_run:
|
||||
lscalc.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_status_command(self) -> None:
|
||||
"""main() should handle status command."""
|
||||
with patch("sys.argv", ["lscalc", "status"]), patch.object(px, "run") as mock_run:
|
||||
lscalc.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["lscalc"]):
|
||||
lscalc.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Tests for cli.packtool module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import packtool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def packtool_tmp_workdir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""自动切换到临时工作目录,防止测试污染项目根目录.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest 提供的临时目录
|
||||
monkeypatch: pytest 的 monkeypatch 工具
|
||||
"""
|
||||
# Mock DEFAULT_CACHE_DIR 到临时目录
|
||||
monkeypatch.setattr(packtool, "DEFAULT_CACHE_DIR", str(tmp_path / ".cache" / "pypack"))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pack_source
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPackSource:
|
||||
"""Test pack_source function."""
|
||||
|
||||
def test_pack_source_basic(self, tmp_path: Path) -> None:
|
||||
"""Should pack source code."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
(project_dir / "main.py").write_text("print('hello')")
|
||||
output_dir = tmp_path / "output"
|
||||
|
||||
packtool.pack_source(project_dir, output_dir)
|
||||
assert output_dir.exists()
|
||||
|
||||
def test_pack_source_with_pyproject(self, tmp_path: Path) -> None:
|
||||
"""Should pack source with pyproject.toml."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
(project_dir / "pyproject.toml").write_text("[project]\nname = 'test'")
|
||||
(project_dir / "main.py").write_text("print('hello')")
|
||||
output_dir = tmp_path / "output"
|
||||
|
||||
packtool.pack_source(project_dir, output_dir)
|
||||
assert output_dir.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pack_dependencies
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPackDependencies:
|
||||
"""Test pack_dependencies function."""
|
||||
|
||||
def test_pack_dependencies_empty(self, tmp_path: Path) -> None:
|
||||
"""Should handle empty dependencies."""
|
||||
lib_dir = tmp_path / "libs"
|
||||
|
||||
packtool.pack_dependencies(lib_dir, [])
|
||||
# Should print message and return
|
||||
|
||||
def test_pack_dependencies_with_deps(self, tmp_path: Path) -> None:
|
||||
"""Should pack dependencies."""
|
||||
lib_dir = tmp_path / "libs"
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
packtool.pack_dependencies(lib_dir, ["numpy", "pandas"])
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pack_wheel
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPackWheel:
|
||||
"""Test pack_wheel function."""
|
||||
|
||||
def test_pack_wheel(self, tmp_path: Path) -> None:
|
||||
"""Should pack wheel."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
(project_dir / "pyproject.toml").write_text("[project]\nname = 'test'")
|
||||
output_dir = tmp_path / "dist"
|
||||
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
packtool.pack_wheel(project_dir, output_dir)
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# install_embed_python
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestInstallEmbedPython:
|
||||
"""Test install_embed_python function."""
|
||||
|
||||
def test_install_embed_python_basic(self, tmp_path: Path) -> None:
|
||||
"""Should install embedded Python (mocked for speed)."""
|
||||
output_dir = tmp_path / "python"
|
||||
|
||||
# Create a mock cache file that doesn't exist (force download)
|
||||
with patch("platform.machine", return_value="x86_64"), patch(
|
||||
"urllib.request.urlretrieve"
|
||||
) as mock_urlretrieve, patch("zipfile.ZipFile") as mock_zipfile:
|
||||
# Mock successful download
|
||||
mock_urlretrieve.return_value = None
|
||||
mock_zip_instance = MagicMock()
|
||||
mock_zipfile.return_value.__enter__.return_value = mock_zip_instance
|
||||
|
||||
packtool.install_embed_python("3.10", output_dir)
|
||||
|
||||
# Verify download was called
|
||||
assert mock_urlretrieve.called
|
||||
# Verify extraction was called
|
||||
assert mock_zip_instance.extractall.called
|
||||
# Verify output directory was created
|
||||
assert output_dir.exists()
|
||||
|
||||
def test_install_embed_python_with_cache(self, tmp_path: Path) -> None:
|
||||
"""Should use cached Python if available."""
|
||||
output_dir = tmp_path / "python"
|
||||
cache_dir = tmp_path / ".cache" / "pypack"
|
||||
cache_dir.mkdir(parents=True)
|
||||
|
||||
# Create a fake cached zip file
|
||||
cache_file = cache_dir / "python-3.10.11-embed-amd64.zip"
|
||||
cache_file.write_bytes(b"PK\x03\x04" + b"\x00" * 100) # Minimal ZIP header
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), patch("zipfile.ZipFile") as mock_zipfile:
|
||||
mock_zip_instance = MagicMock()
|
||||
mock_zipfile.return_value.__enter__.return_value = mock_zip_instance
|
||||
|
||||
packtool.install_embed_python("3.10", output_dir)
|
||||
|
||||
# Verify extraction was called (using cache)
|
||||
assert mock_zip_instance.extractall.called
|
||||
# Verify output directory was created
|
||||
assert output_dir.exists()
|
||||
|
||||
def test_install_embed_python_real_download(self, tmp_path: Path) -> None:
|
||||
"""Should actually download and extract embedded Python (requires network).
|
||||
|
||||
This test performs a real download to verify the entire workflow.
|
||||
It's marked to run only when network is available.
|
||||
"""
|
||||
import platform
|
||||
import zipfile
|
||||
|
||||
output_dir = tmp_path / "python_real"
|
||||
|
||||
# Only run on Windows (embed Python is Windows-specific)
|
||||
if platform.system() != "Windows":
|
||||
return
|
||||
|
||||
# Perform real installation
|
||||
packtool.install_embed_python("3.10", output_dir)
|
||||
|
||||
# Verify installation succeeded
|
||||
assert output_dir.exists()
|
||||
|
||||
# Verify key files are present
|
||||
expected_files = [
|
||||
"python.exe",
|
||||
"python310.dll",
|
||||
"python310.zip",
|
||||
]
|
||||
|
||||
for expected_file in expected_files:
|
||||
file_path = output_dir / expected_file
|
||||
assert file_path.exists(), f"Expected file {expected_file} not found"
|
||||
assert file_path.stat().st_size > 0, f"File {expected_file} is empty"
|
||||
|
||||
# Verify python.exe is executable
|
||||
python_exe = output_dir / "python.exe"
|
||||
assert python_exe.is_file()
|
||||
|
||||
# Verify the installation is functional
|
||||
# Check that we can at least read the zip file
|
||||
python_zip = output_dir / "python310.zip"
|
||||
assert zipfile.is_zipfile(python_zip)
|
||||
|
||||
print(f"✅ Successfully downloaded and installed embed Python to {output_dir}")
|
||||
print(f" Files: {list(output_dir.iterdir())}")
|
||||
|
||||
def test_install_embed_python_different_versions(self, tmp_path: Path) -> None:
|
||||
"""Should handle different Python versions."""
|
||||
output_dir = tmp_path / "python"
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), patch(
|
||||
"urllib.request.urlretrieve"
|
||||
) as mock_urlretrieve, patch("zipfile.ZipFile") as mock_zipfile:
|
||||
mock_zip_instance = MagicMock()
|
||||
mock_zipfile.return_value.__enter__.return_value = mock_zip_instance
|
||||
|
||||
# Test different versions
|
||||
for version in ["3.8", "3.9", "3.10", "3.11", "3.12"]:
|
||||
packtool.install_embed_python(version, output_dir)
|
||||
assert mock_urlretrieve.called
|
||||
|
||||
def test_install_embed_python_creates_cache(self, tmp_path: Path) -> None:
|
||||
"""Should create cache directory and file."""
|
||||
output_dir = tmp_path / "python"
|
||||
|
||||
with patch("platform.machine", return_value="x86_64"), patch(
|
||||
"urllib.request.urlretrieve"
|
||||
) as mock_urlretrieve, patch("zipfile.ZipFile") as mock_zipfile:
|
||||
mock_urlretrieve.return_value = None
|
||||
mock_zip_instance = MagicMock()
|
||||
mock_zipfile.return_value.__enter__.return_value = mock_zip_instance
|
||||
|
||||
packtool.install_embed_python("3.10", output_dir)
|
||||
|
||||
# Verify cache directory was created (now in tmp_path)
|
||||
Path(packtool.DEFAULT_CACHE_DIR)
|
||||
# Note: In test environment, cache might not persist due to mocking
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# create_zip_package
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCreateZipPackage:
|
||||
"""Test create_zip_package function."""
|
||||
|
||||
def test_create_zip_package(self, tmp_path: Path) -> None:
|
||||
"""Should create ZIP package."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "test.txt").write_text("test content")
|
||||
output_file = tmp_path / "package.zip"
|
||||
|
||||
packtool.create_zip_package(source_dir, output_file)
|
||||
assert output_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# clean_build_dir
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCleanBuildDir:
|
||||
"""Test clean_build_dir function."""
|
||||
|
||||
def test_clean_build_dir_exists(self, tmp_path: Path) -> None:
|
||||
"""Should clean existing build directory."""
|
||||
build_dir = tmp_path / "build"
|
||||
build_dir.mkdir()
|
||||
(build_dir / "test.txt").write_text("test")
|
||||
|
||||
packtool.clean_build_dir(build_dir)
|
||||
assert not build_dir.exists()
|
||||
|
||||
def test_clean_build_dir_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent build directory."""
|
||||
build_dir = tmp_path / "nonexistent"
|
||||
|
||||
packtool.clean_build_dir(build_dir)
|
||||
# Should print message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_src_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle src command."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
|
||||
with patch("sys.argv", ["packtool", "src", "--project-dir", str(project_dir)]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_deps_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle deps command."""
|
||||
with patch("sys.argv", ["packtool", "deps", "numpy", "pandas"]), patch.object(px, "run") as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_wheel_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle wheel command."""
|
||||
project_dir = tmp_path / "project"
|
||||
project_dir.mkdir()
|
||||
|
||||
with patch("sys.argv", ["packtool", "wheel", "--project-dir", str(project_dir)]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_embed_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle embed command."""
|
||||
with patch("sys.argv", ["packtool", "embed", "--version", "3.10"]), patch.object(px, "run") as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_zip_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle zip command."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
|
||||
with patch("sys.argv", ["packtool", "zip", "--source-dir", str(source_dir)]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_clean_command(self) -> None:
|
||||
"""main() should handle clean command."""
|
||||
with patch("sys.argv", ["packtool", "clean"]), patch.object(px, "run") as mock_run:
|
||||
packtool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["packtool"]):
|
||||
packtool.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,324 @@
|
||||
"""Tests for cli.pdftool module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import pdftool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_merge
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfMerge:
|
||||
"""Test pdf_merge function."""
|
||||
|
||||
def test_pdf_merge_files(self, tmp_path: Path) -> None:
|
||||
"""Should merge PDF files."""
|
||||
pytest.importorskip("pypdf")
|
||||
input_files = [tmp_path / "input1.pdf", tmp_path / "input2.pdf"]
|
||||
for f in input_files:
|
||||
f.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "merged.pdf"
|
||||
|
||||
with patch("pypdf.PdfReader"), patch("pypdf.PdfWriter") as mock_writer:
|
||||
mock_writer_instance = MagicMock()
|
||||
mock_writer.return_value = mock_writer_instance
|
||||
pdftool.pdf_merge(input_files, output_file)
|
||||
assert mock_writer_instance.write.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_split
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfSplit:
|
||||
"""Test pdf_split function."""
|
||||
|
||||
def test_pdf_split_file(self, tmp_path: Path) -> None:
|
||||
"""Should split PDF file."""
|
||||
pytest.importorskip("pypdf")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_dir = tmp_path / "split"
|
||||
|
||||
with patch("pypdf.PdfReader") as mock_reader, patch("pypdf.PdfWriter"):
|
||||
mock_reader_instance = MagicMock()
|
||||
mock_reader.return_value = mock_reader_instance
|
||||
mock_reader_instance.pages = [MagicMock()]
|
||||
pdftool.pdf_split(input_file, output_dir)
|
||||
assert output_dir.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_compress
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfCompress:
|
||||
"""Test pdf_compress function."""
|
||||
|
||||
def test_pdf_compress_file(self, tmp_path: Path) -> None:
|
||||
"""Should compress PDF file."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "compressed.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
|
||||
# Mock save to actually create the file
|
||||
def mock_save(*args: Any, **kwargs: Any):
|
||||
output_file.write_bytes(b"Compressed PDF")
|
||||
|
||||
mock_doc.save = mock_save
|
||||
pdftool.pdf_compress(input_file, output_file)
|
||||
assert output_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_extract_text
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfExtractText:
|
||||
"""Test pdf_extract_text function."""
|
||||
|
||||
def test_pdf_extract_text_file(self, tmp_path: Path) -> None:
|
||||
"""Should extract text from PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "output.txt"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_page.get_text.return_value = "Test text"
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_extract_text(input_file, output_file)
|
||||
assert output_file.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_extract_images
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfExtractImages:
|
||||
"""Test pdf_extract_images function."""
|
||||
|
||||
def test_pdf_extract_images_file(self, tmp_path: Path) -> None:
|
||||
"""Should extract images from PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_dir = tmp_path / "images"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_page.get_images.return_value = [[0]]
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_doc.extract_image.return_value = {"image": b"image data", "ext": "png"}
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_extract_images(input_file, output_dir)
|
||||
assert output_dir.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_add_watermark
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfAddWatermark:
|
||||
"""Test pdf_add_watermark function."""
|
||||
|
||||
def test_pdf_add_watermark_file(self, tmp_path: Path) -> None:
|
||||
"""Should add watermark to PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "watermarked.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open, patch("fitz.get_text_length") as mock_text_length:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_page.rect = MagicMock(width=800, height=600)
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
mock_text_length.return_value = 100
|
||||
pdftool.pdf_add_watermark(input_file, output_file)
|
||||
assert mock_doc.save.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_rotate
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfRotate:
|
||||
"""Test pdf_rotate function."""
|
||||
|
||||
def test_pdf_rotate_file_90(self, tmp_path: Path) -> None:
|
||||
"""Should rotate PDF by 90 degrees."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "rotated.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_rotate(input_file, output_file, rotation=90)
|
||||
assert mock_doc.save.called
|
||||
|
||||
def test_pdf_rotate_file_180(self, tmp_path: Path) -> None:
|
||||
"""Should rotate PDF by 180 degrees."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "rotated.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_rotate(input_file, output_file, rotation=180)
|
||||
assert mock_doc.save.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_crop
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfCrop:
|
||||
"""Test pdf_crop function."""
|
||||
|
||||
def test_pdf_crop_file(self, tmp_path: Path) -> None:
|
||||
"""Should crop PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "cropped.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open, patch("fitz.Rect"):
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_page.rect = MagicMock(x0=0, y0=0, x1=800, y1=600)
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_crop(input_file, output_file, margins=(10, 10, 10, 10))
|
||||
assert mock_doc.save.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_info
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfInfo:
|
||||
"""Test pdf_info function."""
|
||||
|
||||
def test_pdf_info_file(self, tmp_path: Path) -> None:
|
||||
"""Should show PDF info."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.page_count = 10
|
||||
mock_doc.metadata = {"title": "Test", "author": "Author"}
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_info(input_file)
|
||||
assert mock_fitz_open.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_ocr
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfOcr:
|
||||
"""Test pdf_ocr function."""
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_pdf_ocr_file(self, tmp_path: Path) -> None:
|
||||
"""Should OCR PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
pytest.importorskip("pytesseract")
|
||||
pytest.importorskip("PIL")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "ocr.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open, patch("PIL.Image.frombytes"), patch(
|
||||
"pytesseract.image_to_string"
|
||||
) as mock_ocr:
|
||||
mock_doc = MagicMock()
|
||||
mock_page = MagicMock()
|
||||
mock_page.rect = MagicMock(width=800, height=600)
|
||||
mock_doc.__iter__ = MagicMock(return_value=iter([mock_page]))
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
mock_ocr.return_value = "OCR text"
|
||||
pdftool.pdf_ocr(input_file, output_file)
|
||||
# Should complete OCR
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pdf_repair
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPdfRepair:
|
||||
"""Test pdf_repair function."""
|
||||
|
||||
def test_pdf_repair_file(self, tmp_path: Path) -> None:
|
||||
"""Should repair PDF."""
|
||||
pytest.importorskip("fitz")
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
output_file = tmp_path / "repaired.pdf"
|
||||
|
||||
with patch("fitz.open") as mock_fitz_open:
|
||||
mock_doc = MagicMock()
|
||||
mock_fitz_open.return_value = mock_doc
|
||||
pdftool.pdf_repair(input_file, output_file)
|
||||
assert mock_doc.save.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_merge_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle merge command."""
|
||||
input_files = [tmp_path / "input1.pdf", tmp_path / "input2.pdf"]
|
||||
for f in input_files:
|
||||
f.write_bytes(b"PDF content")
|
||||
|
||||
with patch("sys.argv", ["pdftool", "m", str(input_files[0]), str(input_files[1])]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
pdftool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_split_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle split command."""
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
|
||||
with patch("sys.argv", ["pdftool", "s", str(input_file)]), patch.object(px, "run") as mock_run:
|
||||
pdftool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_compress_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle compress command."""
|
||||
input_file = tmp_path / "input.pdf"
|
||||
input_file.write_bytes(b"PDF content")
|
||||
|
||||
with patch("sys.argv", ["pdftool", "c", str(input_file)]), patch.object(px, "run") as mock_run:
|
||||
pdftool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["pdftool"]):
|
||||
pdftool.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,254 @@
|
||||
"""Tests for cli.piptool module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import piptool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _get_installed_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGetInstalledPackages:
|
||||
"""Test _get_installed_packages function."""
|
||||
|
||||
def test_get_installed_packages_success(self) -> None:
|
||||
"""Should get installed packages."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="numpy==1.0.0\npandas==2.0.0\n", returncode=0)
|
||||
result = piptool._get_installed_packages()
|
||||
assert "numpy" in result
|
||||
assert "pandas" in result
|
||||
|
||||
def test_get_installed_packages_empty(self) -> None:
|
||||
"""Should handle empty output."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="", returncode=0)
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
def test_get_installed_packages_error(self) -> None:
|
||||
"""Should handle subprocess error."""
|
||||
with patch("subprocess.run", side_effect=subprocess.SubprocessError):
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
def test_get_installed_packages_oserror(self) -> None:
|
||||
"""Should handle OSError."""
|
||||
with patch("subprocess.run", side_effect=OSError):
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _expand_wildcard_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestExpandWildcardPackages:
|
||||
"""Test _expand_wildcard_packages function."""
|
||||
|
||||
def test_expand_wildcard_no_pattern(self) -> None:
|
||||
"""Should return package name when no wildcard."""
|
||||
result = piptool._expand_wildcard_packages("numpy")
|
||||
assert result == ["numpy"]
|
||||
|
||||
def test_expand_wildcard_with_star(self) -> None:
|
||||
"""Should expand wildcard with star."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numpy-core", "pandas"]):
|
||||
result = piptool._expand_wildcard_packages("numpy*")
|
||||
assert "numpy" in result
|
||||
assert "numpy-core" in result
|
||||
|
||||
def test_expand_wildcard_with_question(self) -> None:
|
||||
"""Should expand wildcard with question mark."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numba"]):
|
||||
result = piptool._expand_wildcard_packages("num??")
|
||||
assert len(result) > 0
|
||||
|
||||
def test_expand_wildcard_no_match(self) -> None:
|
||||
"""Should return empty list when no match."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["pandas", "scipy"]):
|
||||
result = piptool._expand_wildcard_packages("numpy*")
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _filter_protected_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestFilterProtectedPackages:
|
||||
"""Test _filter_protected_packages function."""
|
||||
|
||||
def test_filter_protected_packages_normal(self) -> None:
|
||||
"""Should filter protected packages."""
|
||||
result = piptool._filter_protected_packages(["numpy", "pandas", "pyflowx"])
|
||||
assert "numpy" in result
|
||||
assert "pandas" in result
|
||||
assert "pyflowx" not in result
|
||||
|
||||
def test_filter_protected_packages_all_protected(self) -> None:
|
||||
"""Should filter all protected packages."""
|
||||
result = piptool._filter_protected_packages(["pyflowx", "bitool"])
|
||||
assert result == []
|
||||
|
||||
def test_filter_protected_packages_case_insensitive(self) -> None:
|
||||
"""Should filter case insensitive."""
|
||||
result = piptool._filter_protected_packages(["PyFlowX", "BITOOL"])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_uninstall
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPipUninstall:
|
||||
"""Test pip_uninstall function."""
|
||||
|
||||
def test_pip_uninstall_single_package(self) -> None:
|
||||
"""Should uninstall single package."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_uninstall(["numpy"])
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_uninstall_multiple_packages(self) -> None:
|
||||
"""Should uninstall multiple packages."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_uninstall(["numpy", "pandas", "scipy"])
|
||||
# Should call pip uninstall
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_uninstall_with_wildcard(self) -> None:
|
||||
"""Should handle wildcard in package name."""
|
||||
with patch.object(piptool, "_expand_wildcard_packages", return_value=["numpy", "numpy-core"]), patch(
|
||||
"subprocess.run"
|
||||
) as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_uninstall(["numpy*"])
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_uninstall_empty_packages(self) -> None:
|
||||
"""Should handle empty packages list."""
|
||||
with patch.object(piptool, "_expand_wildcard_packages", return_value=[]):
|
||||
piptool.pip_uninstall(["nonexistent*"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
def test_pip_uninstall_all_protected(self) -> None:
|
||||
"""Should handle all protected packages."""
|
||||
piptool.pip_uninstall(["pyflowx"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_reinstall
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPipReinstall:
|
||||
"""Test pip_reinstall function."""
|
||||
|
||||
def test_pip_reinstall_single_package(self) -> None:
|
||||
"""Should reinstall single package."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_reinstall(["numpy"])
|
||||
# Should call pip uninstall and pip install
|
||||
assert mock_run.call_count == 2
|
||||
|
||||
def test_pip_reinstall_offline(self) -> None:
|
||||
"""Should reinstall packages offline."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_reinstall(["numpy"], offline=True)
|
||||
# Should call pip install with offline flags
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_reinstall_all_protected(self) -> None:
|
||||
"""Should handle all protected packages."""
|
||||
piptool.pip_reinstall(["pyflowx"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_download
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPipDownload:
|
||||
"""Test pip_download function."""
|
||||
|
||||
def test_pip_download_single_package(self) -> None:
|
||||
"""Should download single package."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_download(["numpy"])
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_download_offline(self) -> None:
|
||||
"""Should download packages offline."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
piptool.pip_download(["numpy"], offline=True)
|
||||
# Should call pip download with offline flags
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_freeze
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestPipFreeze:
|
||||
"""Test pip_freeze function."""
|
||||
|
||||
def test_pip_freeze(self, tmp_path: Path) -> None:
|
||||
"""Should freeze dependencies."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="numpy==1.0.0\npandas==2.0.0", returncode=0)
|
||||
piptool.pip_freeze()
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_install_command(self) -> None:
|
||||
"""main() should handle install command."""
|
||||
with patch("sys.argv", ["piptool", "i", "numpy", "pandas"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_uninstall_command(self) -> None:
|
||||
"""main() should handle uninstall command."""
|
||||
with patch("sys.argv", ["piptool", "u", "numpy"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_reinstall_command(self) -> None:
|
||||
"""main() should handle reinstall command."""
|
||||
with patch("sys.argv", ["piptool", "r", "numpy"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_download_command(self) -> None:
|
||||
"""main() should handle download command."""
|
||||
with patch("sys.argv", ["piptool", "d", "numpy"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_upgrade_command(self) -> None:
|
||||
"""main() should handle upgrade command."""
|
||||
with patch("sys.argv", ["piptool", "up"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_freeze_command(self) -> None:
|
||||
"""main() should handle freeze command."""
|
||||
with patch("sys.argv", ["piptool", "f"]), patch.object(px, "run") as mock_run:
|
||||
piptool.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["piptool"]):
|
||||
piptool.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,158 @@
|
||||
"""Tests for cli.pymake module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.cli import pymake
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
# ---------------------------------------------------------------------- #
|
||||
def _find_task(name: str) -> pymake.px.TaskSpec:
|
||||
"""从 pymake.tasks 或 aliases 中查找指定名称的 TaskSpec."""
|
||||
for spec in pymake.tasks:
|
||||
if spec.name == name:
|
||||
return spec
|
||||
# 单任务别名(doc/lint/tox)内联在 aliases dict 中
|
||||
value = pymake.aliases.get(name)
|
||||
if isinstance(value, pymake.px.TaskSpec):
|
||||
return value
|
||||
raise KeyError(f"任务 {name!r} 未找到")
|
||||
|
||||
|
||||
class TestTaskSpecDefinitions:
|
||||
"""Test that all TaskSpec definitions are valid."""
|
||||
|
||||
def test_uv_build_spec(self) -> None:
|
||||
"""uv_build spec should be properly defined."""
|
||||
spec = _find_task("uv_build")
|
||||
assert spec.name == "uv_build"
|
||||
assert spec.cmd == ["uv", "build"]
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_maturin_build_spec(self) -> None:
|
||||
"""maturin_build spec should be properly defined."""
|
||||
spec = _find_task("maturin_build")
|
||||
assert spec.name == "maturin_build"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_uv_sync_spec(self) -> None:
|
||||
"""uv_sync spec should be properly defined."""
|
||||
spec = _find_task("uv_sync")
|
||||
assert spec.name == "uv_sync"
|
||||
assert spec.cmd == ["uv", "sync"]
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_git_clean_spec(self) -> None:
|
||||
"""git_clean spec should be properly defined."""
|
||||
spec = _find_task("git_clean")
|
||||
assert spec.name == "git_clean"
|
||||
assert spec.cmd == ["gitt", "c"]
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_test_spec(self) -> None:
|
||||
"""test spec should be properly defined."""
|
||||
spec = _find_task("test")
|
||||
assert spec.name == "test"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "pytest" in spec.cmd
|
||||
assert "-m" in spec.cmd
|
||||
assert "not slow" in spec.cmd
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_test_fast_spec(self) -> None:
|
||||
"""test_fast spec should be properly defined."""
|
||||
spec = _find_task("test_fast")
|
||||
assert spec.name == "test_fast"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "pytest" in spec.cmd
|
||||
assert "-n" not in spec.cmd # test_fast doesn't use parallel
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_test_coverage_spec(self) -> None:
|
||||
"""test_coverage spec should be properly defined."""
|
||||
spec = _find_task("test_coverage")
|
||||
assert spec.name == "test_coverage"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "pytest" in spec.cmd
|
||||
assert "--cov" in spec.cmd
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_ruff_lint_spec(self) -> None:
|
||||
"""lint spec should be properly defined."""
|
||||
spec = _find_task("lint")
|
||||
assert spec.name == "lint"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "ruff" in spec.cmd
|
||||
assert "check" in spec.cmd
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_doc_spec(self) -> None:
|
||||
"""doc spec should be properly defined."""
|
||||
spec = _find_task("doc")
|
||||
assert spec.name == "doc"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "sphinx-build" in spec.cmd
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_hatch_publish_spec(self) -> None:
|
||||
"""publish_python spec should be properly defined."""
|
||||
spec = _find_task("publish_python")
|
||||
assert spec.name == "publish_python"
|
||||
assert spec.cmd == ["hatch", "publish"]
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_twine_publish_spec(self) -> None:
|
||||
"""twine_publish spec should be properly defined."""
|
||||
spec = _find_task("twine_publish")
|
||||
assert spec.name == "twine_publish"
|
||||
assert isinstance(spec.cmd, list)
|
||||
assert "twine" in spec.cmd
|
||||
assert "upload" in spec.cmd
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
def test_tox_spec(self) -> None:
|
||||
"""tox spec should be properly defined."""
|
||||
spec = _find_task("tox")
|
||||
assert spec.name == "tox"
|
||||
assert spec.cmd == ["tox", "-p", "auto"]
|
||||
assert spec.skip_if_missing is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_calls_run_cli(self) -> None:
|
||||
"""main() should create a CliRunner and call run_cli()."""
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
pymake.main()
|
||||
# run_cli() calls sys.exit(), so we should get SystemExit
|
||||
# The exit code depends on whether any commands are available
|
||||
assert exc_info.value.code in (0, 1, 2)
|
||||
|
||||
def test_main_with_list_argument(self) -> None:
|
||||
"""main() should handle --list argument."""
|
||||
with patch("sys.argv", ["pymake", "--list"]), pytest.raises(SystemExit) as exc_info:
|
||||
pymake.main()
|
||||
assert exc_info.value.code == 0
|
||||
|
||||
def test_main_creates_runner_with_multiple_commands(self) -> None:
|
||||
"""main() should create a CliRunner with multiple commands."""
|
||||
# We can't easily test the runner creation without mocking,
|
||||
# but we can verify that main() doesn't raise an error for --list
|
||||
with patch("sys.argv", ["pymake", "--list"]), pytest.raises(SystemExit):
|
||||
pymake.main()
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help and exit with failure."""
|
||||
with patch("sys.argv", ["pymake"]), pytest.raises(SystemExit) as exc_info:
|
||||
pymake.main()
|
||||
assert exc_info.value.code == 1
|
||||
@@ -0,0 +1,123 @@
|
||||
"""Tests for cli.screenshot module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import screenshot
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# get_screenshot_path
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGetScreenshotPath:
|
||||
"""Test get_screenshot_path function."""
|
||||
|
||||
def test_get_screenshot_path_with_filename(self, tmp_path: Path) -> None:
|
||||
"""Should get screenshot path with filename."""
|
||||
with patch.object(Path, "home", return_value=tmp_path):
|
||||
result = screenshot.get_screenshot_path("test.png")
|
||||
assert result.name == "test.png"
|
||||
|
||||
def test_get_screenshot_path_without_filename(self, tmp_path: Path) -> None:
|
||||
"""Should get screenshot path without filename."""
|
||||
with patch.object(Path, "home", return_value=tmp_path):
|
||||
result = screenshot.get_screenshot_path()
|
||||
assert "screenshot_" in result.name
|
||||
assert result.suffix == ".png"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# take_screenshot_full
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTakeScreenshotFull:
|
||||
"""Test take_screenshot_full function."""
|
||||
|
||||
def test_take_screenshot_full_windows(self, tmp_path: Path) -> None:
|
||||
"""Should take full screenshot on Windows."""
|
||||
with patch.object(Constants, "IS_WINDOWS", True), patch.object(Constants, "IS_MACOS", False), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_full()
|
||||
assert mock_run.called
|
||||
|
||||
def test_take_screenshot_full_macos(self, tmp_path: Path) -> None:
|
||||
"""Should take full screenshot on macOS."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch.object(Constants, "IS_MACOS", True), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_full()
|
||||
assert mock_run.called
|
||||
|
||||
def test_take_screenshot_full_linux(self, tmp_path: Path) -> None:
|
||||
"""Should take full screenshot on Linux."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch.object(Constants, "IS_MACOS", False), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_full()
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# take_screenshot_area
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTakeScreenshotArea:
|
||||
"""Test take_screenshot_area function."""
|
||||
|
||||
def test_take_screenshot_area_windows(self, tmp_path: Path) -> None:
|
||||
"""Should take area screenshot on Windows."""
|
||||
with patch.object(Constants, "IS_WINDOWS", True), patch.object(Constants, "IS_MACOS", False), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_area()
|
||||
assert mock_run.called
|
||||
|
||||
def test_take_screenshot_area_macos(self, tmp_path: Path) -> None:
|
||||
"""Should take area screenshot on macOS."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch.object(Constants, "IS_MACOS", True), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_area()
|
||||
assert mock_run.called
|
||||
|
||||
def test_take_screenshot_area_linux(self, tmp_path: Path) -> None:
|
||||
"""Should take area screenshot on Linux."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch.object(Constants, "IS_MACOS", False), patch.object(
|
||||
Path, "home", return_value=tmp_path
|
||||
), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
screenshot.take_screenshot_area()
|
||||
assert mock_run.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_full_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle full command."""
|
||||
with patch("sys.argv", ["screenshot", "full"]), patch.object(px, "run") as mock_run:
|
||||
screenshot.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_area_command(self, tmp_path: Path) -> None:
|
||||
"""main() should handle area command."""
|
||||
with patch("sys.argv", ["screenshot", "area"]), patch.object(px, "run") as mock_run:
|
||||
screenshot.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help."""
|
||||
with patch("sys.argv", ["screenshot"]):
|
||||
screenshot.main()
|
||||
# Should print help and return
|
||||
@@ -0,0 +1,163 @@
|
||||
"""Tests for cli.sshcopyid module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import sshcopyid
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# ssh_copy_id
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestSshCopyId:
|
||||
"""Test ssh_copy_id function."""
|
||||
|
||||
def test_ssh_copy_id_pub_key_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent public key."""
|
||||
with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent.pub"), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_sshpass_not_found(self, tmp_path: Path) -> None:
|
||||
"""Should handle sshpass not found."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=FileNotFoundError
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_timeout(self, tmp_path: Path) -> None:
|
||||
"""Should handle SSH timeout."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30)
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_process_error(self, tmp_path: Path) -> None:
|
||||
"""Should handle SSH process error."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd")
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_success(self, tmp_path: Path) -> None:
|
||||
"""Should deploy SSH key successfully."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
assert mock_run.called
|
||||
|
||||
def test_ssh_copy_id_with_custom_port(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom port."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", port=2222)
|
||||
# Verify port is used
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "2222" in call_args
|
||||
|
||||
def test_ssh_copy_id_with_custom_keypath(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom keypath."""
|
||||
custom_key = tmp_path / "custom.pub"
|
||||
custom_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=custom_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", keypath=str(custom_key))
|
||||
assert mock_run.called
|
||||
|
||||
def test_ssh_copy_id_with_custom_timeout(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom timeout."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", timeout=60)
|
||||
# Verify timeout is used in ConnectTimeout option
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "ConnectTimeout=60" in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_with_required_args(self) -> None:
|
||||
"""main() should handle required arguments."""
|
||||
with patch("sys.argv", ["sshcopyid", "localhost", "user", "password"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
assert mock_run.called
|
||||
graph = mock_run.call_args[0][0]
|
||||
assert isinstance(graph, px.Graph)
|
||||
|
||||
def test_main_with_custom_port(self) -> None:
|
||||
"""main() should handle custom port argument."""
|
||||
with patch("sys.argv", ["sshcopyid", "localhost", "user", "password", "--port", "2222"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_custom_keypath(self) -> None:
|
||||
"""main() should handle custom keypath argument."""
|
||||
with patch(
|
||||
"sys.argv", ["sshcopyid", "localhost", "user", "password", "--keypath", "/custom/key.pub"]
|
||||
), patch.object(px, "run") as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_custom_timeout(self) -> None:
|
||||
"""main() should handle custom timeout argument."""
|
||||
with patch("sys.argv", ["sshcopyid", "localhost", "user", "password", "--timeout", "60"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
assert mock_run.called
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help and exit."""
|
||||
with patch("sys.argv", ["sshcopyid"]), pytest.raises(SystemExit) as exc_info:
|
||||
sshcopyid.main()
|
||||
assert exc_info.value.code == 2
|
||||
|
||||
def test_main_creates_task_spec_with_correct_name(self) -> None:
|
||||
"""main() should create TaskSpec with correct name."""
|
||||
with patch("sys.argv", ["sshcopyid", "localhost", "user", "password"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
task_names = list(graph.all_specs().keys())
|
||||
assert "ssh_deploy" in task_names
|
||||
|
||||
def test_main_uses_thread_strategy(self) -> None:
|
||||
"""main() should use thread strategy."""
|
||||
with patch("sys.argv", ["sshcopyid", "localhost", "user", "password"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run, patch.object(sshcopyid, "ssh_copy_id"):
|
||||
sshcopyid.main()
|
||||
assert mock_run.call_args[1]["strategy"] == "thread"
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Tests for cli.taskkill module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli.system import taskkill
|
||||
from pyflowx.conditions import Constants
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# main function
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestMain:
|
||||
"""Test main function."""
|
||||
|
||||
def test_main_with_single_process(self) -> None:
|
||||
"""main() should handle single process argument."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
assert mock_run.called
|
||||
graph = mock_run.call_args[0][0]
|
||||
assert isinstance(graph, px.Graph)
|
||||
|
||||
def test_main_with_multiple_processes(self) -> None:
|
||||
"""main() should handle multiple process arguments."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe", "python.exe", "node.exe"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
taskkill.main()
|
||||
assert mock_run.called
|
||||
graph = mock_run.call_args[0][0]
|
||||
assert isinstance(graph, px.Graph)
|
||||
|
||||
def test_main_with_no_args_shows_help(self) -> None:
|
||||
"""main() with no args should show help and exit."""
|
||||
with patch("sys.argv", ["taskkill"]), pytest.raises(SystemExit) as exc_info:
|
||||
taskkill.main()
|
||||
assert exc_info.value.code == 2
|
||||
|
||||
def test_main_creates_task_specs_with_correct_names(self) -> None:
|
||||
"""main() should create TaskSpecs with correct names."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe", "python.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
task_names = list(graph.all_specs().keys())
|
||||
assert "kill_chrome.exe" in task_names
|
||||
assert "kill_python.exe" in task_names
|
||||
|
||||
def test_main_uses_thread_strategy(self) -> None:
|
||||
"""main() should use thread strategy."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
assert mock_run.call_args[1]["strategy"] == "thread"
|
||||
|
||||
def test_main_windows_command_format(self) -> None:
|
||||
"""main() should use Windows command format on Windows."""
|
||||
if Constants.IS_WINDOWS:
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
specs = graph.all_specs()
|
||||
# Check that command includes Windows taskkill format
|
||||
for spec in specs.values():
|
||||
assert spec.cmd[0] == "taskkill"
|
||||
assert spec.cmd[1] == "/f"
|
||||
assert spec.cmd[2] == "/im"
|
||||
|
||||
def test_main_linux_command_format(self) -> None:
|
||||
"""main() should use Linux command format on Linux."""
|
||||
with patch.object(Constants, "IS_WINDOWS", False), patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(
|
||||
px, "run"
|
||||
) as mock_run:
|
||||
taskkill.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
specs = graph.all_specs()
|
||||
# Check that command includes Linux pkill format
|
||||
for spec in specs.values():
|
||||
assert spec.cmd[0] == "pkill"
|
||||
assert spec.cmd[1] == "-f"
|
||||
|
||||
def test_main_tasks_have_verbose_true(self) -> None:
|
||||
"""main() should create tasks with verbose=True."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
specs = graph.all_specs()
|
||||
for spec in specs.values():
|
||||
assert spec.verbose is True
|
||||
|
||||
def test_main_adds_wildcard_to_process_name(self) -> None:
|
||||
"""main() should add wildcard to process name."""
|
||||
with patch("sys.argv", ["taskkill", "chrome.exe"]), patch.object(px, "run") as mock_run:
|
||||
taskkill.main()
|
||||
graph = mock_run.call_args[0][0]
|
||||
specs = graph.all_specs()
|
||||
# Check that wildcard is added
|
||||
for spec in specs.values():
|
||||
assert spec.cmd[-1].endswith("*")
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# 将 tests 目录加入 sys.path,使进程池测试能 import _proc_helper 模块级辅助函数。
|
||||
# 进程池 pickle 要求被调用函数为模块级,conftest.py 在 xdist worker 中也会执行。
|
||||
_TESTS_DIR = str(Path(__file__).resolve().parent)
|
||||
if _TESTS_DIR not in sys.path:
|
||||
sys.path.insert(0, _TESTS_DIR)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def packtool_tmp_workdir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""自动切换到临时工作目录,防止测试污染项目根目录.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest 提供的临时目录
|
||||
monkeypatch: pytest 的 monkeypatch 工具
|
||||
"""
|
||||
monkeypatch.chdir(tmp_path)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,101 @@
|
||||
"""Tests for Graph.chain DSL."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskSpec
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_chain_basic_linkage() -> None:
|
||||
"""chain(a, b, c) 应建立 a->b->c 依赖."""
|
||||
a = TaskSpec("a", _fn)
|
||||
b = TaskSpec("b", _fn)
|
||||
c = TaskSpec("c", _fn)
|
||||
|
||||
graph = px.Graph().chain(a, b, c)
|
||||
|
||||
assert graph.all_specs()["b"].depends_on == ("a",)
|
||||
assert graph.all_specs()["c"].depends_on == ("b",)
|
||||
assert graph.all_specs()["a"].depends_on == ()
|
||||
|
||||
|
||||
def test_chain_single_spec() -> None:
|
||||
"""chain(a) 应只注册 a,无依赖."""
|
||||
a = TaskSpec("a", _fn)
|
||||
graph = px.Graph().chain(a)
|
||||
assert "a" in graph
|
||||
assert graph.all_specs()["a"].depends_on == ()
|
||||
|
||||
|
||||
def test_chain_preserves_existing_deps() -> None:
|
||||
"""chain 应保留 spec 已有的 depends_on."""
|
||||
a = TaskSpec("a", _fn)
|
||||
b = TaskSpec("b", _fn)
|
||||
c = TaskSpec("c", _fn, depends_on=("b",))
|
||||
|
||||
graph = px.Graph().chain(a, b, c)
|
||||
# c 已有 depends_on=('b',),前驱是 b,已在依赖中,不重复添加
|
||||
assert graph.all_specs()["c"].depends_on == ("b",)
|
||||
|
||||
|
||||
def test_chain_merges_existing_deps() -> None:
|
||||
"""chain 应将前驱追加到已有依赖前(若不存在)."""
|
||||
a = TaskSpec("a", _fn)
|
||||
x = TaskSpec("x", _fn)
|
||||
c = TaskSpec("c", _fn, depends_on=("x",))
|
||||
|
||||
graph = px.Graph().chain(a, x, c)
|
||||
# c 前驱是 x,但 c 已依赖 x,不重复
|
||||
assert graph.all_specs()["c"].depends_on == ("x",)
|
||||
|
||||
|
||||
def test_chain_returns_self() -> None:
|
||||
"""chain 返回 self 支持链式调用."""
|
||||
a = TaskSpec("a", _fn)
|
||||
graph = px.Graph()
|
||||
assert graph.chain(a) is graph
|
||||
|
||||
|
||||
def test_chain_execution_order() -> None:
|
||||
"""chain 应保证执行顺序."""
|
||||
order: list[str] = []
|
||||
|
||||
def make(name: str):
|
||||
def fn() -> str:
|
||||
order.append(name)
|
||||
return name
|
||||
return fn
|
||||
|
||||
a = TaskSpec("a", make("a"))
|
||||
b = TaskSpec("b", make("b"))
|
||||
c = TaskSpec("c", make("c"))
|
||||
|
||||
graph = px.Graph().chain(a, b, c)
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert order == ["a", "b", "c"]
|
||||
|
||||
|
||||
def test_chain_with_decorator_specs() -> None:
|
||||
"""chain 应与 @task 装饰器配合."""
|
||||
|
||||
@px.task
|
||||
def extract() -> int:
|
||||
return 1
|
||||
|
||||
@px.task
|
||||
def transform(extract: int) -> int:
|
||||
return extract + 10
|
||||
|
||||
@px.task
|
||||
def load(transform: int) -> int:
|
||||
return transform + 100
|
||||
|
||||
graph = px.Graph().chain(extract, transform, load)
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["load"] == 111
|
||||
@@ -0,0 +1,499 @@
|
||||
"""Tests for command reference feature in CliRunner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
|
||||
class TestCommandReferences:
|
||||
"""Test string references in Graph.from_specs."""
|
||||
|
||||
def test_simple_command_reference(self) -> None:
|
||||
"""Should expand simple command reference."""
|
||||
build_task = px.TaskSpec("build", cmd=["echo", "building"])
|
||||
test_task = px.TaskSpec("test", cmd=["echo", "testing"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"build": px.Graph.from_specs([build_task]),
|
||||
"test": px.Graph.from_specs([test_task]),
|
||||
"all": px.Graph.from_specs([build_task, "test"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that 'all' command has both tasks
|
||||
all_tasks = list(runner.graphs["all"].all_specs().keys())
|
||||
assert "build" in all_tasks
|
||||
assert "test" in all_tasks
|
||||
assert len(all_tasks) == 2
|
||||
|
||||
def test_multiple_command_references(self) -> None:
|
||||
"""Should expand multiple command references."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs([task2]),
|
||||
"cmd3": px.Graph.from_specs([task3]),
|
||||
"all": px.Graph.from_specs(["cmd1", "cmd2", "cmd3"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that 'all' command has all tasks
|
||||
all_tasks = list(runner.graphs["all"].all_specs().keys())
|
||||
assert set(all_tasks) == {"task1", "task2", "task3"}
|
||||
|
||||
def test_specific_task_reference(self) -> None:
|
||||
"""Should expand specific task reference."""
|
||||
lint_task = px.TaskSpec("lint", cmd=["echo", "linting"])
|
||||
format_task = px.TaskSpec("format", cmd=["echo", "formatting"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"lint": px.Graph.from_specs([lint_task, format_task]),
|
||||
"quick": px.Graph.from_specs(["lint.lint"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that 'quick' command only has lint task
|
||||
quick_tasks = list(runner.graphs["quick"].all_specs().keys())
|
||||
assert quick_tasks == ["lint"]
|
||||
|
||||
def test_nested_command_reference(self) -> None:
|
||||
"""Should expand nested command references."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs(["cmd1", task2]),
|
||||
"cmd3": px.Graph.from_specs(["cmd2", task3]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that 'cmd3' has all tasks
|
||||
cmd3_tasks = list(runner.graphs["cmd3"].all_specs().keys())
|
||||
assert set(cmd3_tasks) == {"task1", "task2", "task3"}
|
||||
|
||||
def test_circular_reference_error(self) -> None:
|
||||
"""Should raise error for circular references."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
|
||||
with pytest.raises(ValueError, match="循环引用"):
|
||||
px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs(["cmd1", task1]),
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_command_reference_error(self) -> None:
|
||||
"""Should raise error for invalid command reference."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
|
||||
with pytest.raises(ValueError, match="引用的命令 'invalid' 不存在"):
|
||||
px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs(["invalid", task1]),
|
||||
},
|
||||
)
|
||||
|
||||
def test_invalid_task_reference_error(self) -> None:
|
||||
"""Should raise error for invalid task reference."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
|
||||
with pytest.raises(ValueError, match="任务 'invalid' 不存在于命令 'cmd1' 中"):
|
||||
px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs(["cmd1.invalid"]),
|
||||
},
|
||||
)
|
||||
|
||||
def test_reference_preserves_dependencies(self) -> None:
|
||||
"""Should preserve dependencies when expanding references."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"], depends_on=("task1",))
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1, task2]),
|
||||
"cmd2": px.Graph.from_specs(["cmd1"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that dependencies are preserved
|
||||
cmd2_deps = runner.graphs["cmd2"].deps
|
||||
assert cmd2_deps["task2"] == ("task1",)
|
||||
|
||||
def test_mixed_references_and_tasks(self) -> None:
|
||||
"""Should handle mixed references and direct tasks."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1, task2]),
|
||||
"cmd2": px.Graph.from_specs(["cmd1", task3]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check that 'cmd2' has all tasks
|
||||
cmd2_tasks = list(runner.graphs["cmd2"].all_specs().keys())
|
||||
assert set(cmd2_tasks) == {"task1", "task2", "task3"}
|
||||
|
||||
def test_execution_order_with_references(self) -> None:
|
||||
"""Should execute references in correct order."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "step1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "step2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "step3"])
|
||||
task4 = px.TaskSpec("task4", cmd=["echo", "step4"])
|
||||
task5 = px.TaskSpec("task5", cmd=["echo", "step5"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs([task2, task3]),
|
||||
"cmd3": px.Graph.from_specs([task4]),
|
||||
"ordered": px.Graph.from_specs(["cmd1", "cmd2", "cmd3", task5]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["ordered"].layers()
|
||||
|
||||
# Layer 1 should have task1 (cmd1)
|
||||
assert "task1" in layers[0]
|
||||
|
||||
# Layer 2 should have task2 and task3 (cmd2)
|
||||
assert "task2" in layers[1]
|
||||
assert "task3" in layers[1]
|
||||
|
||||
# Layer 3 should have task4 (cmd3)
|
||||
assert "task4" in layers[2]
|
||||
|
||||
# Layer 4 should have task5 (original task)
|
||||
assert "task5" in layers[3]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 4
|
||||
|
||||
def test_execution_order_multiple_original_tasks(self) -> None:
|
||||
"""Should execute multiple original TaskSpecs in correct order."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
task4 = px.TaskSpec("task4", cmd=["echo", "4"])
|
||||
task5 = px.TaskSpec("task5", cmd=["echo", "5"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs([task2]),
|
||||
"all": px.Graph.from_specs(["cmd1", "cmd2", task3, task4, task5]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# Layer 1: task1 (cmd1)
|
||||
assert "task1" in layers[0]
|
||||
|
||||
# Layer 2: task2 (cmd2)
|
||||
assert "task2" in layers[1]
|
||||
|
||||
# Layer 3: task3 (first original TaskSpec)
|
||||
assert "task3" in layers[2]
|
||||
|
||||
# Layer 4: task4 (second original TaskSpec)
|
||||
assert "task4" in layers[3]
|
||||
|
||||
# Layer 5: task5 (third original TaskSpec)
|
||||
assert "task5" in layers[4]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 5
|
||||
|
||||
def test_execution_order_with_internal_dependencies(self) -> None:
|
||||
"""Should preserve internal dependencies within referenced commands."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"], depends_on=("task1",))
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
task4 = px.TaskSpec("task4", cmd=["echo", "4"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1, task2]),
|
||||
"cmd2": px.Graph.from_specs([task3]),
|
||||
"all": px.Graph.from_specs(["cmd1", "cmd2", task4]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# Layer 1: task1
|
||||
assert "task1" in layers[0]
|
||||
|
||||
# Layer 2: task2 (depends on task1)
|
||||
assert "task2" in layers[1]
|
||||
|
||||
# Layer 3: task3 (cmd2, depends on task2)
|
||||
assert "task3" in layers[2]
|
||||
|
||||
# Layer 4: task4 (original TaskSpec, depends on task3)
|
||||
assert "task4" in layers[3]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 4
|
||||
|
||||
def test_execution_order_pymake_bump_scenario(self) -> None:
|
||||
"""Should execute pymake bump command in correct order."""
|
||||
# Simulate pymake bump scenario
|
||||
git_clean = px.TaskSpec("git_clean", cmd=["echo", "clean"])
|
||||
typecheck = px.TaskSpec("typecheck", cmd=["echo", "typecheck"])
|
||||
lint = px.TaskSpec("lint", cmd=["echo", "lint"])
|
||||
format_task = px.TaskSpec("format", cmd=["echo", "format"], depends_on=("lint",))
|
||||
git_add_all = px.TaskSpec("git_add_all", cmd=["echo", "git add -A"])
|
||||
bump = px.TaskSpec("bumpversion", cmd=["echo", "bumpversion -t"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"c": px.Graph.from_specs([git_clean]),
|
||||
"tc": px.Graph.from_specs([typecheck, "lint"]),
|
||||
"lint": px.Graph.from_specs([lint, format_task]),
|
||||
"bump": px.Graph.from_specs(["c", "tc", git_add_all, bump]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["bump"].layers()
|
||||
|
||||
# Layer 1: git_clean (c)
|
||||
assert "git_clean" in layers[0]
|
||||
|
||||
# Layer 2: lint (tc.lint, depends on git_clean)
|
||||
assert "lint" in layers[1]
|
||||
|
||||
# Layer 3: format (tc.lint.format, depends on lint)
|
||||
assert "format" in layers[2]
|
||||
|
||||
# Layer 4: typecheck (tc.typecheck, depends on format)
|
||||
assert "typecheck" in layers[3]
|
||||
|
||||
# Layer 5: git_add_all (original TaskSpec, depends on typecheck)
|
||||
assert "git_add_all" in layers[4]
|
||||
|
||||
# Layer 6: bumpversion (original TaskSpec, depends on git_add_all)
|
||||
assert "bumpversion" in layers[5]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 6
|
||||
|
||||
def test_execution_order_only_references(self) -> None:
|
||||
"""Should execute only references without original TaskSpecs."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs([task2]),
|
||||
"cmd3": px.Graph.from_specs([task3]),
|
||||
"all": px.Graph.from_specs(["cmd1", "cmd2", "cmd3"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# Layer 1: task1 (cmd1)
|
||||
assert "task1" in layers[0]
|
||||
|
||||
# Layer 2: task2 (cmd2, depends on task1)
|
||||
assert "task2" in layers[1]
|
||||
|
||||
# Layer 3: task3 (cmd3, depends on task2)
|
||||
assert "task3" in layers[2]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 3
|
||||
|
||||
def test_execution_order_only_original_tasks(self) -> None:
|
||||
"""Should execute only original TaskSpecs without references."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"all": px.Graph.from_specs([task1, task2, task3]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# All tasks should be in layer 1 (no dependencies)
|
||||
assert "task1" in layers[0]
|
||||
assert "task2" in layers[0]
|
||||
assert "task3" in layers[0]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 1
|
||||
|
||||
def test_execution_order_single_reference(self) -> None:
|
||||
"""Should execute single reference correctly."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1, task2]),
|
||||
"all": px.Graph.from_specs(["cmd1"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# Should have the same structure as cmd1
|
||||
assert "task1" in layers[0]
|
||||
assert "task2" in layers[0]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 1
|
||||
|
||||
def test_execution_order_deep_nesting(self) -> None:
|
||||
"""Should execute deeply nested references correctly."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
task4 = px.TaskSpec("task4", cmd=["echo", "4"])
|
||||
task5 = px.TaskSpec("task5", cmd=["echo", "5"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1]),
|
||||
"cmd2": px.Graph.from_specs(["cmd1", task2]),
|
||||
"cmd3": px.Graph.from_specs(["cmd2", task3]),
|
||||
"cmd4": px.Graph.from_specs(["cmd3", task4]),
|
||||
"cmd5": px.Graph.from_specs(["cmd4", task5]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["cmd5"].layers()
|
||||
|
||||
# Should execute in order: task1 -> task2 -> task3 -> task4 -> task5
|
||||
assert "task1" in layers[0]
|
||||
assert "task2" in layers[1]
|
||||
assert "task3" in layers[2]
|
||||
assert "task4" in layers[3]
|
||||
assert "task5" in layers[4]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 5
|
||||
|
||||
def test_execution_order_with_parallel_tasks_in_reference(self) -> None:
|
||||
"""Should handle parallel tasks within referenced commands."""
|
||||
task1 = px.TaskSpec("task1", cmd=["echo", "1"])
|
||||
task2 = px.TaskSpec("task2", cmd=["echo", "2"])
|
||||
task3 = px.TaskSpec("task3", cmd=["echo", "3"])
|
||||
task4 = px.TaskSpec("task4", cmd=["echo", "4"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"cmd1": px.Graph.from_specs([task1, task2]), # Parallel tasks
|
||||
"cmd2": px.Graph.from_specs([task3, task4]), # Parallel tasks
|
||||
"all": px.Graph.from_specs(["cmd1", "cmd2"]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["all"].layers()
|
||||
|
||||
# Layer 1: task1 and task2 (cmd1, parallel)
|
||||
assert "task1" in layers[0]
|
||||
assert "task2" in layers[0]
|
||||
|
||||
# Layer 2: task3 and task4 (cmd2, depends on cmd1's last task)
|
||||
# Note: Both task3 and task4 should depend on the last task of cmd1
|
||||
assert "task3" in layers[1]
|
||||
assert "task4" in layers[1]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 2
|
||||
|
||||
def test_execution_order_complex_mixed_scenario(self) -> None:
|
||||
"""Should handle complex mixed scenario with references and TaskSpecs."""
|
||||
# Create a complex scenario
|
||||
clean = px.TaskSpec("clean", cmd=["echo", "clean"])
|
||||
build1 = px.TaskSpec("build1", cmd=["echo", "build1"])
|
||||
build2 = px.TaskSpec("build2", cmd=["echo", "build2"], depends_on=("build1",))
|
||||
test1 = px.TaskSpec("test1", cmd=["echo", "test1"])
|
||||
test2 = px.TaskSpec("test2", cmd=["echo", "test2"])
|
||||
package = px.TaskSpec("package", cmd=["echo", "package"])
|
||||
deploy = px.TaskSpec("deploy", cmd=["echo", "deploy"])
|
||||
|
||||
runner = px.CliRunner(
|
||||
strategy="sequential",
|
||||
aliases={
|
||||
"clean": px.Graph.from_specs([clean]),
|
||||
"build": px.Graph.from_specs([build1, build2]),
|
||||
"test": px.Graph.from_specs([test1, test2]),
|
||||
"release": px.Graph.from_specs(["clean", "build", "test", package, deploy]),
|
||||
},
|
||||
)
|
||||
|
||||
# Check execution order through layers
|
||||
layers = runner.graphs["release"].layers()
|
||||
|
||||
# Layer 1: clean
|
||||
assert "clean" in layers[0]
|
||||
|
||||
# Layer 2: build1 (depends on clean)
|
||||
assert "build1" in layers[1]
|
||||
|
||||
# Layer 3: build2 (depends on build1)
|
||||
assert "build2" in layers[2]
|
||||
|
||||
# Layer 4: test1 and test2 (depends on build2)
|
||||
assert "test1" in layers[3]
|
||||
assert "test2" in layers[3]
|
||||
|
||||
# Layer 5: package (depends on test1/test2)
|
||||
assert "package" in layers[4]
|
||||
|
||||
# Layer 6: deploy (depends on package)
|
||||
assert "deploy" in layers[5]
|
||||
|
||||
# Verify total layers
|
||||
assert len(layers) == 6
|
||||
@@ -0,0 +1,359 @@
|
||||
"""Tests for conditions module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
IS_POSIX,
|
||||
IS_WINDOWS,
|
||||
BuiltinConditions,
|
||||
Constants,
|
||||
)
|
||||
|
||||
_CTX: dict[str, object] = {}
|
||||
|
||||
|
||||
def test_constants_is_windows():
|
||||
assert (sys.platform == "win32") == Constants.IS_WINDOWS
|
||||
|
||||
|
||||
def test_constants_is_linux():
|
||||
assert (sys.platform == "linux") == Constants.IS_LINUX
|
||||
|
||||
|
||||
def test_constants_is_macos():
|
||||
assert (sys.platform == "darwin") == Constants.IS_MACOS
|
||||
|
||||
|
||||
def test_constants_is_posix():
|
||||
assert (sys.platform != "win32") == Constants.IS_POSIX
|
||||
|
||||
|
||||
def test_module_level_static_conditions():
|
||||
assert IS_WINDOWS(_CTX) == Constants.IS_WINDOWS
|
||||
assert IS_LINUX(_CTX) == Constants.IS_LINUX
|
||||
assert IS_MACOS(_CTX) == Constants.IS_MACOS
|
||||
assert IS_POSIX(_CTX) == Constants.IS_POSIX
|
||||
|
||||
|
||||
def test_python_version_major_only():
|
||||
current_major = sys.version_info.major
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major)(_CTX) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major + 1)(_CTX) is False
|
||||
|
||||
|
||||
def test_python_version_with_minor():
|
||||
current_major = sys.version_info.major
|
||||
current_minor = sys.version_info.minor
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor)(_CTX) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION(current_major, current_minor + 1)(_CTX) is False
|
||||
|
||||
|
||||
def test_python_version_at_least():
|
||||
current_major = sys.version_info.major
|
||||
current_minor = sys.version_info.minor
|
||||
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major, current_minor)(_CTX) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major - 1, 0)(_CTX) is True
|
||||
assert BuiltinConditions.PYTHON_VERSION_AT_LEAST(current_major + 1, 0)(_CTX) is False
|
||||
|
||||
|
||||
def test_has_installed_true():
|
||||
condition = BuiltinConditions.HAS_INSTALLED("python3")
|
||||
assert condition(_CTX) is True
|
||||
|
||||
|
||||
def test_has_installed_false():
|
||||
condition = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
|
||||
assert condition(_CTX) is False
|
||||
|
||||
|
||||
def test_env_var_exists_true():
|
||||
with patch.dict(os.environ, {"TEST_VAR": "value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EXISTS("TEST_VAR")
|
||||
assert condition(_CTX) is True
|
||||
|
||||
|
||||
def test_env_var_exists_false():
|
||||
condition = BuiltinConditions.ENV_VAR_EXISTS("NONEXISTENT_VAR_12345")
|
||||
assert condition(_CTX) is False
|
||||
|
||||
|
||||
def test_env_var_equals_true():
|
||||
with patch.dict(os.environ, {"TEST_VAR": "expected_value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
|
||||
assert condition(_CTX) is True
|
||||
|
||||
|
||||
def test_env_var_equals_false():
|
||||
with patch.dict(os.environ, {"TEST_VAR": "different_value"}):
|
||||
condition = BuiltinConditions.ENV_VAR_EQUALS("TEST_VAR", "expected_value")
|
||||
assert condition(_CTX) is False
|
||||
|
||||
|
||||
def test_not():
|
||||
true_cond = BuiltinConditions.HAS_INSTALLED("python3")
|
||||
false_cond = BuiltinConditions.HAS_INSTALLED("nonexistent_app_12345")
|
||||
|
||||
assert BuiltinConditions.NOT(true_cond)(_CTX) is False
|
||||
assert BuiltinConditions.NOT(false_cond)(_CTX) is True
|
||||
|
||||
|
||||
def test_and_all_true():
|
||||
cond = BuiltinConditions.AND(
|
||||
BuiltinConditions.HAS_INSTALLED("python3"),
|
||||
BuiltinConditions.HAS_INSTALLED("python3"),
|
||||
)
|
||||
assert cond(_CTX) is True
|
||||
|
||||
|
||||
def test_and_one_false():
|
||||
cond = BuiltinConditions.AND(
|
||||
BuiltinConditions.HAS_INSTALLED("python3"),
|
||||
BuiltinConditions.HAS_INSTALLED("nonexistent_app"),
|
||||
)
|
||||
assert cond(_CTX) is False
|
||||
|
||||
|
||||
def test_or_all_false():
|
||||
cond = BuiltinConditions.OR(
|
||||
BuiltinConditions.HAS_INSTALLED("nonexistent1"),
|
||||
BuiltinConditions.HAS_INSTALLED("nonexistent2"),
|
||||
)
|
||||
assert cond(_CTX) is False
|
||||
|
||||
|
||||
def test_or_one_true():
|
||||
cond = BuiltinConditions.OR(
|
||||
BuiltinConditions.HAS_INSTALLED("nonexistent1"),
|
||||
BuiltinConditions.HAS_INSTALLED("python3"),
|
||||
)
|
||||
assert cond(_CTX) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 上下文条件:基于上游依赖结果
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_dep_equals_true():
|
||||
ctx = {"upstream": 42}
|
||||
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
def test_dep_equals_false():
|
||||
ctx = {"upstream": 99}
|
||||
cond = BuiltinConditions.DEP_EQUALS("upstream", 42)
|
||||
assert cond(ctx) is False
|
||||
|
||||
|
||||
def test_dep_equals_missing_dep():
|
||||
cond = BuiltinConditions.DEP_EQUALS("missing", 42)
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dep_matches_true():
|
||||
ctx = {"upstream": [1, 2, 3]}
|
||||
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
def test_dep_matches_false():
|
||||
ctx = {"upstream": [1, 2]}
|
||||
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: len(v) == 3)
|
||||
assert cond(ctx) is False
|
||||
|
||||
|
||||
def test_dep_matches_exception_returns_false():
|
||||
ctx = {"upstream": ""}
|
||||
cond = BuiltinConditions.DEP_MATCHES("upstream", lambda v: v[0])
|
||||
assert cond(ctx) is False
|
||||
|
||||
|
||||
def test_dep_present_true():
|
||||
ctx = {"upstream": "value"}
|
||||
cond = BuiltinConditions.DEP_PRESENT("upstream")
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
def test_dep_present_false_none():
|
||||
# pyrefly: ignore [implicit-any-empty-container]
|
||||
ctx = {"upstream": None}
|
||||
cond = BuiltinConditions.DEP_PRESENT("upstream")
|
||||
assert cond(ctx) is False
|
||||
|
||||
|
||||
def test_dep_present_false_missing():
|
||||
cond = BuiltinConditions.DEP_PRESENT("missing")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dep_truthy_true():
|
||||
ctx = {"upstream": [1]}
|
||||
cond = BuiltinConditions.DEP_TRUTHY("upstream")
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
def test_dep_truthy_false():
|
||||
# pyrefly: ignore [implicit-any-empty-container]
|
||||
ctx = {"upstream": []}
|
||||
cond = BuiltinConditions.DEP_TRUTHY("upstream")
|
||||
assert cond(ctx) is False
|
||||
|
||||
|
||||
def test_dep_truthy_missing():
|
||||
cond = BuiltinConditions.DEP_TRUTHY("missing")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_logical_combination_with_dep_conditions():
|
||||
ctx = {"a": 1, "b": 0}
|
||||
cond = BuiltinConditions.AND(
|
||||
BuiltinConditions.DEP_EQUALS("a", 1),
|
||||
BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")),
|
||||
)
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# IS_RUNNING: 跨平台 subprocess 检测
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_is_running_windows_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Windows 上 tasklist 检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
|
||||
|
||||
class MockResult:
|
||||
stdout = "explorer.exe\nother.exe"
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_is_running_windows_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Windows 上 tasklist 未检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
|
||||
|
||||
class MockResult:
|
||||
stdout = "other.exe"
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_is_running_linux_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Linux 上 pgrep 检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("nginx")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_is_running_linux_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Linux 上 pgrep 未检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("nonexistent")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dir_exists_true(tmp_path: Path):
|
||||
"""DIR_EXISTS 检测路径存在."""
|
||||
cond = BuiltinConditions.DIR_EXISTS(tmp_path)
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_dir_exists_false(tmp_path: Path):
|
||||
"""DIR_EXISTS 检测路径不存在."""
|
||||
missing = tmp_path / "nonexistent"
|
||||
cond = BuiltinConditions.DIR_EXISTS(missing)
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_builtin_is_windows_returns_module_condition():
|
||||
"""BuiltinConditions.IS_WINDOWS() 应返回模块级 IS_WINDOWS."""
|
||||
assert BuiltinConditions.IS_WINDOWS() is IS_WINDOWS
|
||||
|
||||
|
||||
def test_builtin_is_linux_returns_module_condition():
|
||||
"""BuiltinConditions.IS_LINUX() 应返回模块级 IS_LINUX."""
|
||||
assert BuiltinConditions.IS_LINUX() is IS_LINUX
|
||||
|
||||
|
||||
def test_builtin_is_macos_returns_module_condition():
|
||||
"""BuiltinConditions.IS_MACOS() 应返回模块级 IS_MACOS."""
|
||||
assert BuiltinConditions.IS_MACOS() is IS_MACOS
|
||||
|
||||
|
||||
def test_builtin_is_posix_returns_module_condition():
|
||||
"""BuiltinConditions.IS_POSIX() 应返回模块级 IS_POSIX."""
|
||||
assert BuiltinConditions.IS_POSIX() is IS_POSIX
|
||||
|
||||
|
||||
def test_file_content_exists_missing_file(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件不存在时返回 False."""
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(tmp_path / "missing.txt", "x")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_file_content_exists_contains_content(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件包含内容时返回 True."""
|
||||
f = tmp_path / "f.txt"
|
||||
f.write_text("hello world", encoding="utf-8")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "world")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_file_content_exists_not_contains_content(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 文件不包含内容时返回 False."""
|
||||
f = tmp_path / "f.txt"
|
||||
f.write_text("hello", encoding="utf-8")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "missing")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_file_content_exists_decode_error_returns_false(tmp_path: Path):
|
||||
"""FILE_CONTENT_EXISTS 读取非 UTF-8 文件应返回 False(解码异常被吞)."""
|
||||
f = tmp_path / "bin.dat"
|
||||
f.write_bytes(b"\xff\xfe\x00bad")
|
||||
cond = BuiltinConditions.FILE_CONTENT_EXISTS(f, "x")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dep_matches_missing_dep_returns_false():
|
||||
"""DEP_MATCHES 依赖不存在时应返回 False(覆盖 if not in ctx 分支)."""
|
||||
cond = BuiltinConditions.DEP_MATCHES("missing", lambda _v: True)
|
||||
assert cond({}) is False
|
||||
+157
-160
@@ -1,4 +1,4 @@
|
||||
"""Tests for context injection rules."""
|
||||
"""测试上下文注入规则."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -11,225 +11,222 @@ from pyflowx.context import _is_context_annotation, build_call_args, describe_in
|
||||
from pyflowx.errors import InjectionError
|
||||
|
||||
|
||||
def test_inject_by_parameter_name() -> None:
|
||||
def fn(a: int, b: str) -> str:
|
||||
return f"{a}{b}"
|
||||
class TestBuildCallArgs:
|
||||
"""测试 build_call_args 函数."""
|
||||
|
||||
spec = px.TaskSpec("c", fn, ("a", "b"))
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": "x"})
|
||||
assert args == ()
|
||||
assert kwargs == {"a": 1, "b": "x"}
|
||||
def test_inject_by_parameter_name(self) -> None:
|
||||
"""参数名匹配依赖名时应注入对应结果."""
|
||||
|
||||
def fn(a: int, b: str) -> str:
|
||||
return f"{a}{b}"
|
||||
|
||||
def test_inject_context_annotation() -> None:
|
||||
def fn(ctx: px.Context) -> int:
|
||||
return len(ctx)
|
||||
spec = px.TaskSpec("c", fn, depends_on=("a", "b"))
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": "x"})
|
||||
assert kwargs == {"a": 1, "b": "x"}
|
||||
|
||||
spec = px.TaskSpec("agg", fn, ("a", "b"))
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
|
||||
# Only the task's own deps are passed.
|
||||
assert kwargs == {"ctx": {"a": 1, "b": 2}}
|
||||
def test_inject_context_annotation(self) -> None:
|
||||
"""标注为 Context 的参数应接收完整依赖映射."""
|
||||
|
||||
def fn(ctx: px.Context) -> int:
|
||||
return len(ctx)
|
||||
|
||||
def test_inject_var_keyword() -> None:
|
||||
def fn(**kwargs: Any) -> int:
|
||||
return sum(kwargs.values())
|
||||
spec = px.TaskSpec("agg", fn, depends_on=("a", "b"))
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
|
||||
# Only the task's own deps are passed.
|
||||
assert kwargs == {"ctx": {"a": 1, "b": 2}}
|
||||
|
||||
spec = px.TaskSpec("agg", fn, ("a", "b"))
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
|
||||
assert kwargs == {"a": 1, "b": 2}
|
||||
def test_inject_var_keyword(self) -> None:
|
||||
"""**kwargs 参数应以 dict 形式接收所有依赖结果."""
|
||||
|
||||
def fn(**kwargs: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
|
||||
return sum(kwargs.values())
|
||||
|
||||
def test_static_args_and_kwargs() -> None:
|
||||
def fn(uid: int, source: str) -> str:
|
||||
return f"{source}:{uid}"
|
||||
spec = px.TaskSpec("agg", fn, depends_on=("a", "b"))
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
|
||||
assert kwargs == {"a": 1, "b": 2}
|
||||
|
||||
spec = px.TaskSpec("fetch", fn, args=(42,), kwargs={"source": "api"})
|
||||
args, kwargs = build_call_args(spec, {})
|
||||
assert args == (42,)
|
||||
assert kwargs == {"source": "api"}
|
||||
def test_static_args_and_kwargs(self) -> None:
|
||||
"""静态 args/kwargs 应正确填充非依赖参数."""
|
||||
|
||||
def fn(uid: int, source: str) -> str:
|
||||
return f"{source}:{uid}"
|
||||
|
||||
def test_default_param_not_required() -> None:
|
||||
def fn(a: int, flag: bool = True) -> int:
|
||||
return a if flag else 0
|
||||
spec = px.TaskSpec("fetch", fn, args=(42,), kwargs={"source": "api"})
|
||||
args, kwargs = build_call_args(spec, {})
|
||||
assert args == (42,)
|
||||
assert kwargs == {"source": "api"}
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a",))
|
||||
args, kwargs = build_call_args(spec, {"a": 5})
|
||||
assert kwargs == {"a": 5}
|
||||
def test_default_param_not_required(self) -> None:
|
||||
"""有默认值的参数无需依赖或静态值."""
|
||||
|
||||
def fn(a: int, flag: bool = True) -> int:
|
||||
return a if flag else 0
|
||||
|
||||
def test_unresolved_required_param_raises() -> None:
|
||||
def fn(a: int, missing: str) -> None:
|
||||
return None
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a",))
|
||||
_args, kwargs = build_call_args(spec, {"a": 5})
|
||||
assert kwargs == {"a": 5}
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a",))
|
||||
with pytest.raises(InjectionError) as exc_info:
|
||||
build_call_args(spec, {"a": 1})
|
||||
assert "missing" in str(exc_info.value)
|
||||
def test_unresolved_required_param_raises(self) -> None:
|
||||
"""必需参数无法解析时应抛出 InjectionError."""
|
||||
|
||||
def fn(_a: int, _: str) -> None:
|
||||
return None
|
||||
|
||||
def test_static_kwargs_collide_with_dependency() -> None:
|
||||
def fn(a: int) -> int:
|
||||
return a
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a",))
|
||||
with pytest.raises(InjectionError) as exc_info:
|
||||
_ = build_call_args(spec, {"a": 1})
|
||||
assert "Cannot inject" in str(exc_info.value)
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a",), kwargs={"a": 99})
|
||||
with pytest.raises(InjectionError):
|
||||
build_call_args(spec, {"a": 1})
|
||||
def test_static_kwargs_collide_with_dependency(self) -> None:
|
||||
"""静态 kwargs 与依赖名冲突时应抛出 InjectionError."""
|
||||
|
||||
def fn(a: int) -> int:
|
||||
return a
|
||||
|
||||
def test_describe_injection() -> None:
|
||||
def fn(a: int, ctx: px.Context, flag: bool = False) -> None:
|
||||
return None
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a",), kwargs={"a": 99})
|
||||
with pytest.raises(InjectionError):
|
||||
_ = build_call_args(spec, {"a": 1})
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a",))
|
||||
desc = describe_injection(spec)
|
||||
assert "a=<result:a>" in desc
|
||||
assert "ctx=<Context>" in desc
|
||||
assert "flag=<default>" in desc
|
||||
def test_var_positional_not_required(self) -> None:
|
||||
"""*args 参数不应触发 InjectionError."""
|
||||
|
||||
def fn(*args: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
|
||||
return len(args)
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _is_context_annotation 各分支
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_is_context_annotation_direct_object() -> None:
|
||||
"""直接传入 Context 别名对象应返回 True。"""
|
||||
assert _is_context_annotation(px.Context) is True
|
||||
spec = px.TaskSpec("t", fn, args=(1, 2, 3))
|
||||
args, kwargs = build_call_args(spec, {})
|
||||
assert args == (1, 2, 3)
|
||||
assert kwargs == {}
|
||||
|
||||
def test_var_keyword_consumes_leftover(self) -> None:
|
||||
"""**kwargs 应吞掉未被具名参数消费的依赖结果."""
|
||||
|
||||
def test_is_context_annotation_string() -> None:
|
||||
"""字符串形式的注解应被识别。"""
|
||||
assert _is_context_annotation("Context") is True
|
||||
assert _is_context_annotation("px.Context") is True
|
||||
assert _is_context_annotation("pyflowx.Context") is True
|
||||
assert _is_context_annotation("NotContext") is False
|
||||
assert _is_context_annotation("int") is False
|
||||
def fn(a: int, **rest: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
|
||||
return a + sum(rest.values())
|
||||
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a", "b", "c"))
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 3})
|
||||
assert kwargs == {"a": 1, "b": 2, "c": 3}
|
||||
|
||||
def test_is_context_annotation_typing_alias() -> None:
|
||||
"""具有 __name__/_name 为 Context/Mapping 的 typing 别名应返回 True。"""
|
||||
def test_no_var_keyword_drops_leftover(self) -> None:
|
||||
"""无 **kwargs 时,未被消费的依赖结果被丢弃(不报错)."""
|
||||
|
||||
class FakeAlias:
|
||||
__name__ = "Context"
|
||||
def fn(a: int) -> int:
|
||||
return a
|
||||
|
||||
assert _is_context_annotation(FakeAlias()) is True
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
|
||||
# b 是依赖但 fn 不接收它 —— 应正常工作
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
|
||||
assert kwargs == {"a": 1}
|
||||
|
||||
class FakeMapping:
|
||||
__name__ = "Mapping"
|
||||
def test_context_annotation_only_deps(self) -> None:
|
||||
"""Context 标注只接收该任务自身 depends_on 的结果."""
|
||||
|
||||
assert _is_context_annotation(FakeMapping()) is True
|
||||
def fn(ctx: px.Context) -> int:
|
||||
return len(ctx)
|
||||
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
|
||||
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
|
||||
assert kwargs == {"ctx": {"a": 1, "b": 2}}
|
||||
|
||||
def test_is_context_annotation_other() -> None:
|
||||
"""其他类型注解应返回 False。"""
|
||||
assert _is_context_annotation(int) is False
|
||||
assert _is_context_annotation(str) is False
|
||||
assert _is_context_annotation(None) is False
|
||||
|
||||
class TestDescribeInjection:
|
||||
"""测试 describe_injection 函数."""
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# describe_injection 其余分支
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_describe_injection_var_positional() -> None:
|
||||
"""*args 参数应显示为 *args。"""
|
||||
def test_describe_injection(self) -> None:
|
||||
"""应正确描述依赖注入、Context 标注和默认值."""
|
||||
|
||||
def fn(*args: Any) -> None:
|
||||
return None
|
||||
def fn(a: int, ctx: px.Context, flag: bool = False) -> None:
|
||||
return None
|
||||
|
||||
spec = px.TaskSpec("t", fn)
|
||||
desc = describe_injection(spec)
|
||||
assert "*args" in desc
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a",))
|
||||
desc = describe_injection(spec)
|
||||
assert "a=<dep:a>" in desc
|
||||
assert "ctx=<Context>" in desc
|
||||
assert "flag=<default>" in desc
|
||||
|
||||
def test_var_positional(self) -> None:
|
||||
"""*args 参数应显示为 *args."""
|
||||
|
||||
def test_describe_injection_var_keyword() -> None:
|
||||
"""**kwargs 参数应显示为 **kwargs=<all-deps>。"""
|
||||
def fn(*args: Any) -> None:
|
||||
return None
|
||||
|
||||
def fn(**kwargs: Any) -> None:
|
||||
return None
|
||||
spec = px.TaskSpec("t", fn)
|
||||
desc = describe_injection(spec)
|
||||
assert "*args" in desc
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a",))
|
||||
desc = describe_injection(spec)
|
||||
assert "**kwargs=<all-deps>" in desc
|
||||
def test_var_keyword(self) -> None:
|
||||
"""**kwargs 参数应显示为 **kwargs=<all-deps>."""
|
||||
|
||||
def fn(**kwargs: Any) -> None: # pyright: ignore[reportExplicitAny, reportAny]
|
||||
return None
|
||||
|
||||
def test_describe_injection_unresolved() -> None:
|
||||
"""无依赖、无静态值、无默认的参数应显示为 <UNRESOLVED>。"""
|
||||
spec = px.TaskSpec("t", fn, depends_on=("a",))
|
||||
desc = describe_injection(spec)
|
||||
assert "**kwargs=<all-deps>" in desc
|
||||
|
||||
def fn(missing: int) -> None:
|
||||
return None
|
||||
def test_unresolved(self) -> None:
|
||||
"""无依赖、无静态值、无默认的参数应显示为 <UNRESOLVED>."""
|
||||
|
||||
spec = px.TaskSpec("t", fn)
|
||||
desc = describe_injection(spec)
|
||||
assert "missing=<UNRESOLVED>" in desc
|
||||
def fn(missing: int) -> None:
|
||||
return None
|
||||
|
||||
spec = px.TaskSpec("t", fn)
|
||||
desc = describe_injection(spec)
|
||||
assert "missing=<UNRESOLVED>" in desc
|
||||
|
||||
def test_describe_injection_static_kwargs() -> None:
|
||||
"""静态 kwargs 应显示具体值。"""
|
||||
def test_static_kwargs(self) -> None:
|
||||
"""静态 kwargs 应显示具体值."""
|
||||
|
||||
def fn(flag: bool = False) -> None:
|
||||
return None
|
||||
def fn(flag: bool = False) -> None:
|
||||
return None
|
||||
|
||||
spec = px.TaskSpec("t", fn, kwargs={"flag": True})
|
||||
desc = describe_injection(spec)
|
||||
assert "flag=True" in desc
|
||||
spec = px.TaskSpec("t", fn, kwargs={"flag": True})
|
||||
desc = describe_injection(spec)
|
||||
assert "flag=True" in desc
|
||||
|
||||
def test_positional_args_filled(self) -> None:
|
||||
"""spec.args 填充的位置参数应显示具体值(覆盖 args_filled 分支)."""
|
||||
|
||||
def test_describe_injection_positional_args_filled() -> None:
|
||||
"""spec.args 填充的位置参数应显示具体值(覆盖 args_filled 分支)。"""
|
||||
def fn(a: int, b: str) -> None:
|
||||
return None
|
||||
|
||||
def fn(a: int, b: str) -> None:
|
||||
return None
|
||||
spec = px.TaskSpec("t", fn, args=(1, "x"))
|
||||
desc = describe_injection(spec)
|
||||
assert "a=1" in desc
|
||||
assert "b='x'" in desc
|
||||
|
||||
spec = px.TaskSpec("t", fn, args=(1, "x"))
|
||||
desc = describe_injection(spec)
|
||||
assert "a=1" in desc
|
||||
assert "b='x'" in desc
|
||||
|
||||
class TestIsContextAnnotation:
|
||||
"""测试 _is_context_annotation 函数."""
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# build_call_args 边界
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_build_call_args_var_positional_not_required() -> None:
|
||||
"""*args 参数不应触发 InjectionError。"""
|
||||
def test_direct_object(self) -> None:
|
||||
"""直接传入 Context 别名对象应返回 True."""
|
||||
assert _is_context_annotation(px.Context) is True
|
||||
|
||||
def fn(*args: Any) -> int:
|
||||
return len(args)
|
||||
def test_string(self) -> None:
|
||||
"""字符串形式的注解应被识别."""
|
||||
assert _is_context_annotation("Context") is True
|
||||
assert _is_context_annotation("px.Context") is True
|
||||
assert _is_context_annotation("pyflowx.Context") is True
|
||||
assert _is_context_annotation("NotContext") is False
|
||||
assert _is_context_annotation("int") is False
|
||||
|
||||
spec = px.TaskSpec("t", fn, args=(1, 2, 3))
|
||||
args, kwargs = build_call_args(spec, {})
|
||||
assert args == (1, 2, 3)
|
||||
assert kwargs == {}
|
||||
def test_typing_alias(self) -> None:
|
||||
"""具有 __name__/_name 为 Context/Mapping 的 typing 别名应返回 True."""
|
||||
|
||||
class FakeAlias:
|
||||
__name__ = "Context"
|
||||
|
||||
def test_build_call_args_var_keyword_consumes_leftover() -> None:
|
||||
"""**kwargs 应吞掉未被具名参数消费的依赖结果。"""
|
||||
assert _is_context_annotation(FakeAlias()) is True
|
||||
|
||||
def fn(a: int, **rest: Any) -> int:
|
||||
return a + sum(rest.values())
|
||||
class FakeMapping:
|
||||
__name__ = "Mapping"
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a", "b", "c"))
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 3})
|
||||
assert kwargs == {"a": 1, "b": 2, "c": 3}
|
||||
assert _is_context_annotation(FakeMapping()) is True
|
||||
|
||||
|
||||
def test_build_call_args_no_var_keyword_drops_leftover() -> None:
|
||||
"""无 **kwargs 时,未被消费的依赖结果被丢弃(不报错)。"""
|
||||
|
||||
def fn(a: int) -> int:
|
||||
return a
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a", "b"))
|
||||
# b 是依赖但 fn 不接收它 —— 应正常工作
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
|
||||
assert kwargs == {"a": 1}
|
||||
|
||||
|
||||
def test_build_call_args_context_annotation_only_deps() -> None:
|
||||
"""Context 标注只接收该任务自身 depends_on 的结果。"""
|
||||
|
||||
def fn(ctx: px.Context) -> int:
|
||||
return len(ctx)
|
||||
|
||||
spec = px.TaskSpec("t", fn, ("a", "b"))
|
||||
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
|
||||
assert kwargs == {"ctx": {"a": 1, "b": 2}}
|
||||
def test_other(self) -> None:
|
||||
"""其他类型注解应返回 False."""
|
||||
assert _is_context_annotation(int) is False
|
||||
assert _is_context_annotation(str) is False
|
||||
assert _is_context_annotation(None) is False
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Tests for process executor (spec.executor='process')."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
# pyrefly: ignore[missing-import]
|
||||
from _proc_helper import add, cpu_heavy, slow_sleep, sub
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
|
||||
def test_process_executor_runs_cpu_task() -> None:
|
||||
"""executor='process' 应在进程池中执行 CPU 密集型任务."""
|
||||
spec = px.TaskSpec("cpu", fn=cpu_heavy, args=(1000,), executor="process")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["cpu"] == sum(i * i for i in range(1000))
|
||||
|
||||
|
||||
def test_process_executor_with_dependency() -> None:
|
||||
"""进程池任务应支持依赖注入."""
|
||||
spec1 = px.TaskSpec("a", fn=cpu_heavy, args=(100,), executor="process")
|
||||
spec2 = px.TaskSpec("b", fn=add, args=(3, 4), executor="process", depends_on=("a",))
|
||||
graph = px.Graph.from_specs([spec1, spec2])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["b"] == 7
|
||||
|
||||
|
||||
def test_process_executor_default_is_thread() -> None:
|
||||
"""TaskSpec.executor 默认应为 'thread'."""
|
||||
spec = px.TaskSpec("x", fn=lambda: None)
|
||||
assert spec.executor == "thread"
|
||||
|
||||
|
||||
def test_inline_executor_runs_in_event_loop() -> None:
|
||||
"""executor='inline' 应直接在事件循环线程调用."""
|
||||
spec = px.TaskSpec("inline", fn=add, args=(10, 20), executor="inline")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["inline"] == 30
|
||||
|
||||
|
||||
def test_process_executor_with_kwargs() -> None:
|
||||
"""进程池任务应支持 kwargs 注入."""
|
||||
spec = px.TaskSpec("kw", fn=sub, args=(10,), kwargs={"b": 3}, executor="process")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["kw"] == 7
|
||||
|
||||
|
||||
def test_process_executor_timeout() -> None:
|
||||
"""进程池任务超时应抛 TaskFailedError."""
|
||||
spec = px.TaskSpec("slow", fn=slow_sleep, args=(10.0,), executor="process", timeout=0.1)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
with pytest.raises(TaskFailedError):
|
||||
px.run(graph)
|
||||
+208
-111
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import logging
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, List
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -26,12 +27,10 @@ def test_sequential_basic() -> None:
|
||||
def double(extract: list[int]) -> list[int]:
|
||||
return [x * 2 for x in extract]
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("extract", extract),
|
||||
px.TaskSpec("double", double, ("extract",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("extract", extract),
|
||||
px.TaskSpec("double", double, depends_on=("extract",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report["extract"] == [1, 2, 3]
|
||||
@@ -39,7 +38,7 @@ def test_sequential_basic() -> None:
|
||||
|
||||
|
||||
def test_sequential_diamond() -> None:
|
||||
order: List[str] = []
|
||||
order: list[str] = []
|
||||
|
||||
def make(name: str) -> Any:
|
||||
def fn() -> str:
|
||||
@@ -48,14 +47,12 @@ def test_sequential_diamond() -> None:
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), ("a",)),
|
||||
px.TaskSpec("c", make("c"), ("a",)),
|
||||
px.TaskSpec("d", make("d"), ("b", "c")),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), depends_on=("a",)),
|
||||
px.TaskSpec("c", make("c"), depends_on=("a",)),
|
||||
px.TaskSpec("d", make("d"), depends_on=("b", "c")),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report["d"] == "d"
|
||||
@@ -66,17 +63,15 @@ def test_failure_propagates() -> None:
|
||||
def boom() -> None:
|
||||
raise ValueError("kaboom")
|
||||
|
||||
def downstream(boom: None) -> int:
|
||||
def downstream(_boom: None) -> int:
|
||||
return 1
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("boom", boom),
|
||||
px.TaskSpec("downstream", downstream, ("boom",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("boom", boom),
|
||||
px.TaskSpec("downstream", downstream, depends_on=("boom",)),
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert exc_info.value.task == "boom"
|
||||
assert isinstance(exc_info.value.cause, ValueError)
|
||||
|
||||
@@ -90,48 +85,92 @@ def test_retries_then_succeeds() -> None:
|
||||
raise RuntimeError("not yet")
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retries=2)])
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=3)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report["flaky"] == "ok"
|
||||
assert attempts["n"] == 3
|
||||
|
||||
|
||||
def test_retries_with_delay() -> None:
|
||||
"""测试带delay的重试会实际等待。"""
|
||||
attempts = {"n": 0}
|
||||
start_time = time.time()
|
||||
|
||||
def flaky() -> str:
|
||||
attempts["n"] += 1
|
||||
if attempts["n"] < 2:
|
||||
raise RuntimeError("not yet")
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=2, delay=0.1)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
elapsed = time.time() - start_time
|
||||
assert report.success
|
||||
assert elapsed >= 0.1 # 应有至少0.1秒的等待时间
|
||||
assert attempts["n"] == 2
|
||||
|
||||
|
||||
def test_timeout_then_retry_async(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""测试超时后可以重试,并记录warning日志。"""
|
||||
|
||||
async def slow_task() -> str:
|
||||
await asyncio.sleep(10) # 会触发超时
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("slow", slow_task, timeout=0.2, retry=px.RetryPolicy(max_attempts=2)),
|
||||
])
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
with pytest.raises(px.TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="async")
|
||||
assert exc_info.value.attempts == 2
|
||||
assert "timed out" in str(exc_info.value.cause)
|
||||
# 应有超时重试的warning日志
|
||||
assert any("timed out" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_retries_exhausted() -> None:
|
||||
def always_fail() -> None:
|
||||
raise RuntimeError("nope")
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("f", always_fail, retry=px.RetryPolicy(max_attempts=3)),
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert exc_info.value.attempts == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Threaded
|
||||
# ---------------------------------------------------------------------- #
|
||||
@pytest.mark.slow
|
||||
def test_threaded_parallelism() -> None:
|
||||
def slow() -> str:
|
||||
time.sleep(0.3)
|
||||
return "done"
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", slow),
|
||||
px.TaskSpec("b", slow),
|
||||
px.TaskSpec("c", slow),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", slow),
|
||||
px.TaskSpec("b", slow),
|
||||
px.TaskSpec("c", slow),
|
||||
])
|
||||
start = time.time()
|
||||
report = px.run(graph, strategy="thread", max_workers=3)
|
||||
elapsed = time.time() - start
|
||||
assert report.success
|
||||
# Three 0.3s tasks in parallel should be well under 0.8s.
|
||||
assert elapsed < 0.8
|
||||
# Three 0.3s tasks in parallel should be well under 1.0s.
|
||||
assert elapsed < 1.0
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_threaded_layer_barrier() -> None:
|
||||
finished: List[str] = []
|
||||
finished: list[str] = []
|
||||
lock = threading.Lock()
|
||||
|
||||
def make(name: str) -> Any:
|
||||
@@ -143,13 +182,11 @@ def test_threaded_layer_barrier() -> None:
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b")),
|
||||
px.TaskSpec("c", make("c"), ("a", "b")),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b")),
|
||||
px.TaskSpec("c", make("c"), depends_on=("a", "b")),
|
||||
])
|
||||
report = px.run(graph, strategy="thread", max_workers=2)
|
||||
assert report.success
|
||||
# c must finish after both a and b.
|
||||
@@ -168,34 +205,28 @@ def test_async_basic() -> None:
|
||||
async def transform(fetch: int) -> int:
|
||||
return fetch * 2
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("fetch", fetch),
|
||||
px.TaskSpec("transform", transform, ("fetch",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fetch", fetch),
|
||||
px.TaskSpec("transform", transform, depends_on=("fetch",)),
|
||||
])
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report["transform"] == 84
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_async_parallelism() -> None:
|
||||
async def slow() -> str:
|
||||
await asyncio.sleep(0.3)
|
||||
return "done"
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", slow),
|
||||
px.TaskSpec("b", slow),
|
||||
px.TaskSpec("c", slow),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", slow), px.TaskSpec("b", slow), px.TaskSpec("c", slow)])
|
||||
start = time.time()
|
||||
report = px.run(graph, strategy="async")
|
||||
elapsed = time.time() - start
|
||||
assert report.success
|
||||
assert elapsed < 0.8
|
||||
# 放宽时间限制以应对 CI 环境波动(理想 0.3s,串行约 0.9s,上限 1.5s 确保并行有效性)
|
||||
assert elapsed < 1.5
|
||||
|
||||
|
||||
def test_async_mixed_sync_and_async() -> None:
|
||||
@@ -206,12 +237,10 @@ def test_async_mixed_sync_and_async() -> None:
|
||||
await asyncio.sleep(0.01)
|
||||
return sync_task + 5
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("sync_task", sync_task),
|
||||
px.TaskSpec("async_task", async_task, ("sync_task",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("sync_task", sync_task),
|
||||
px.TaskSpec("async_task", async_task, depends_on=("sync_task",)),
|
||||
])
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report["async_task"] == 15
|
||||
@@ -223,7 +252,7 @@ def test_async_timeout() -> None:
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
px.run(graph, strategy="async")
|
||||
_ = px.run(graph, strategy="async")
|
||||
assert isinstance(exc_info.value.cause, TaskTimeoutError)
|
||||
|
||||
|
||||
@@ -231,7 +260,7 @@ def test_async_timeout() -> None:
|
||||
# Dry run
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
|
||||
called: List[str] = []
|
||||
called: list[str] = []
|
||||
|
||||
def fn() -> str:
|
||||
called.append("x")
|
||||
@@ -250,7 +279,7 @@ def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
|
||||
# State / resume
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_memory_backend_resume() -> None:
|
||||
runs: List[str] = []
|
||||
runs: list[str] = []
|
||||
|
||||
def make(name: str) -> Any:
|
||||
def fn() -> str:
|
||||
@@ -259,33 +288,31 @@ def test_memory_backend_resume() -> None:
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), depends_on=("a",)),
|
||||
])
|
||||
backend = MemoryBackend()
|
||||
px.run(graph, strategy="sequential", state=backend)
|
||||
_ = px.run(graph, strategy="sequential", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
|
||||
# Second run: both cached, neither re-executed.
|
||||
px.run(graph, strategy="sequential", state=backend)
|
||||
_ = px.run(graph, strategy="sequential", state=backend)
|
||||
assert runs == ["a", "b"] # unchanged
|
||||
|
||||
|
||||
def test_json_backend_persistence() -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
|
||||
def fn() -> int:
|
||||
return 7
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
|
||||
px.run(graph, strategy="sequential", state=JSONBackend(path))
|
||||
_ = px.run(graph, strategy="sequential", state=JSONBackend(path))
|
||||
|
||||
# New backend reads the file; task should be skipped.
|
||||
runs: List[str] = []
|
||||
runs: list[str] = []
|
||||
|
||||
def fn2() -> int:
|
||||
runs.append("ran")
|
||||
@@ -301,27 +328,18 @@ def test_json_backend_persistence() -> None:
|
||||
# Events
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_on_event_callback() -> None:
|
||||
events: List[px.TaskEvent] = []
|
||||
events: list[px.TaskEvent] = []
|
||||
|
||||
def fn() -> int:
|
||||
return 1
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
|
||||
px.run(graph, strategy="sequential", on_event=events.append)
|
||||
_ = px.run(graph, strategy="sequential", on_event=events.append)
|
||||
statuses = [e.status for e in events]
|
||||
assert px.TaskStatus.SUCCESS in statuses
|
||||
assert all(e.task == "a" for e in events)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# Invalid strategy
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_invalid_strategy() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError):
|
||||
px.run(graph, strategy="bogus") # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 异步策略:sync 任务无 timeout 分支 + timeout 重试分支
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -359,7 +377,9 @@ def test_async_timeout_retry_then_succeed() -> None:
|
||||
await asyncio.sleep(10) # 触发超时
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2, timeout=0.05)])
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3), timeout=0.05),
|
||||
])
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report["a"] == "ok"
|
||||
@@ -376,7 +396,9 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
|
||||
raise RuntimeError("not yet")
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", flaky, retries=2)])
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3)),
|
||||
])
|
||||
with caplog.at_level("WARNING", logger="pyflowx"):
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
@@ -390,7 +412,7 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_threaded_skips_cached_tasks() -> None:
|
||||
"""threaded 策略下命中缓存的任务应被跳过(覆盖 line 224-230)。"""
|
||||
runs: List[str] = []
|
||||
runs: list[str] = []
|
||||
|
||||
def make(name: str) -> Any:
|
||||
def fn() -> str:
|
||||
@@ -399,18 +421,16 @@ def test_threaded_skips_cached_tasks() -> None:
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), depends_on=("a",)),
|
||||
])
|
||||
backend = px.MemoryBackend()
|
||||
# 第一次运行填充缓存
|
||||
px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
# 第二次运行应全部跳过
|
||||
px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
_ = px.run(graph, strategy="thread", max_workers=2, state=backend)
|
||||
assert runs == ["a", "b"] # 未再执行
|
||||
|
||||
|
||||
@@ -426,7 +446,7 @@ def test_threaded_all_cached_layer() -> None:
|
||||
|
||||
def test_async_skips_cached_tasks() -> None:
|
||||
"""async 策略下命中缓存的任务应被跳过(覆盖 line 268-274)。"""
|
||||
runs: List[str] = []
|
||||
runs: list[str] = []
|
||||
|
||||
async def make(name: str) -> Any:
|
||||
async def fn() -> str:
|
||||
@@ -444,16 +464,14 @@ def test_async_skips_cached_tasks() -> None:
|
||||
runs.append("b")
|
||||
return a + "b"
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", a),
|
||||
px.TaskSpec("b", b, ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", a),
|
||||
px.TaskSpec("b", b, depends_on=("a",)),
|
||||
])
|
||||
backend = px.MemoryBackend()
|
||||
px.run(graph, strategy="async", state=backend)
|
||||
_ = px.run(graph, strategy="async", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
px.run(graph, strategy="async", state=backend)
|
||||
_ = px.run(graph, strategy="async", state=backend)
|
||||
assert runs == ["a", "b"]
|
||||
|
||||
|
||||
@@ -480,7 +498,7 @@ def test_failure_marks_report_unsuccessful() -> None:
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", boom)])
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
# report 在异常前未返回,但若捕获异常则 success 应为 False
|
||||
# 这里验证 run() 抛异常的行为本身
|
||||
|
||||
@@ -513,3 +531,82 @@ def test_run_empty_graph() -> None:
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert len(report) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 上游任务被 SKIPPED 后,下游任务也应被 SKIPPED
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
|
||||
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPED(sequential 策略)."""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def downstream(upstream: str) -> str:
|
||||
return upstream + "_processed"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", cmd=["echo", "hello"], conditions=(never_true,)),
|
||||
px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.result_of("upstream").status == px.TaskStatus.SKIPPED
|
||||
assert report.result_of("downstream").status == px.TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_downstream_skipped_when_upstream_skipped_thread() -> None:
|
||||
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPED(thread 策略)."""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def downstream(upstream: str) -> str:
|
||||
return upstream + "_processed"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", cmd=["echo", "hello"], conditions=(never_true,)),
|
||||
px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
|
||||
])
|
||||
report = px.run(graph, strategy="thread", max_workers=2)
|
||||
assert report.success
|
||||
assert report.result_of("upstream").status == px.TaskStatus.SKIPPED
|
||||
assert report.result_of("downstream").status == px.TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_downstream_skipped_when_upstream_skipped_async() -> None:
|
||||
"""上游任务被 SKIPPED 后,下游任务也应被 SKIPPED(async 策略)."""
|
||||
|
||||
async def upstream() -> str:
|
||||
return "hello"
|
||||
|
||||
async def downstream(upstream: str) -> str:
|
||||
return upstream + "_processed"
|
||||
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", upstream, conditions=(never_true,)),
|
||||
px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
|
||||
])
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report.result_of("upstream").status == px.TaskStatus.SKIPPED
|
||||
assert report.result_of("downstream").status == px.TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_downstream_executes_when_upstream_succeeds() -> None:
|
||||
"""上游任务成功时,下游任务应正常执行."""
|
||||
always_true = lambda _ctx: True # noqa: E731
|
||||
|
||||
def upstream() -> str:
|
||||
return "hello"
|
||||
|
||||
def downstream(upstream: str) -> str:
|
||||
return upstream + "_processed"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", upstream, conditions=(always_true,)),
|
||||
px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.result_of("upstream").status == px.TaskStatus.SUCCESS
|
||||
assert report.result_of("downstream").status == px.TaskStatus.SUCCESS
|
||||
assert report["downstream"] == "hello_processed"
|
||||
|
||||
@@ -0,0 +1,566 @@
|
||||
"""Tests for executors module edge cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskStatus
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
def test_execute_sync_with_timeout():
|
||||
"""Test execute task with timeout correctly."""
|
||||
# Note: timeout for Python functions only works in async strategy
|
||||
# For sync functions, timeout is not enforced in sequential strategy
|
||||
# This test verifies that the task runs without timeout error
|
||||
spec = px.TaskSpec("quick", fn=lambda: "result", timeout=10)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed without timeout error
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_execute_async_with_timeout():
|
||||
"""Test execute async task with timeout correctly."""
|
||||
|
||||
async def slow_async_function():
|
||||
await asyncio.sleep(2)
|
||||
return "result"
|
||||
|
||||
spec = px.TaskSpec("slow_async", fn=slow_async_function, timeout=0.5)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# This should timeout
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="async")
|
||||
|
||||
|
||||
def test_verbose_event_callback_running():
|
||||
"""Test verbose event callback for RUNNING status."""
|
||||
# Create a graph with verbose callback
|
||||
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_verbose_run_with_success_lifecycle(capsys: pytest.CaptureFixture[str]):
|
||||
"""Test px.run with verbose=True prints SUCCESS lifecycle."""
|
||||
spec = px.TaskSpec("test", fn=lambda: "result")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True)
|
||||
assert report.success
|
||||
captured = capsys.readouterr()
|
||||
assert "成功" in captured.out
|
||||
|
||||
|
||||
def test_verbose_run_with_failed_lifecycle(capsys: pytest.CaptureFixture[str]):
|
||||
"""Test px.run with verbose=True prints FAILED lifecycle with error."""
|
||||
|
||||
def raise_error():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("test", fn=raise_error)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential", verbose=True)
|
||||
captured = capsys.readouterr()
|
||||
assert "失败" in captured.out
|
||||
assert "test error" in captured.out
|
||||
|
||||
|
||||
def test_verbose_run_with_skipped_lifecycle(capsys: pytest.CaptureFixture[str]):
|
||||
"""Test px.run with verbose=True prints SKIPPED lifecycle."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False,),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True)
|
||||
assert report.success
|
||||
captured = capsys.readouterr()
|
||||
assert "跳过" in captured.out
|
||||
|
||||
|
||||
def test_verbose_run_with_user_callback():
|
||||
"""Test px.run with verbose=True and user callback both called.
|
||||
|
||||
预期事件序列:RUNNING(开始)→ SUCCESS(完成)。
|
||||
"""
|
||||
events = []
|
||||
|
||||
def on_event(event: px.TaskEvent):
|
||||
events.append(event)
|
||||
|
||||
spec = px.TaskSpec("test", fn=lambda: "result")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True, on_event=on_event)
|
||||
assert report.success
|
||||
assert len(events) == 2
|
||||
assert events[0].status == px.TaskStatus.RUNNING
|
||||
assert events[1].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_verbose_event_callback_success():
|
||||
"""Test verbose event callback for SUCCESS status."""
|
||||
# Create a graph with verbose callback
|
||||
spec = px.TaskSpec("test", fn=lambda: "result", verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_verbose_event_callback_failed():
|
||||
"""Test verbose event callback for FAILED status."""
|
||||
# Create a graph with verbose callback and failing task
|
||||
|
||||
def raise_error():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("test", fn=raise_error, verbose=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should print without error
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
def test_verbose_event_callback_skipped():
|
||||
"""Test verbose event callback for SKIPPED status."""
|
||||
# Create a graph with verbose callback and skipped task
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False,),
|
||||
verbose=True,
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# Should print without error
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_execute_sync_with_retries():
|
||||
"""Test execute task with retries."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
def failing_function():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValueError("temporary error")
|
||||
return "success"
|
||||
|
||||
spec = px.TaskSpec(
|
||||
"retry_test",
|
||||
fn=failing_function,
|
||||
retry=px.RetryPolicy(max_attempts=3),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed after retries
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["retry_test"].attempts == 3
|
||||
|
||||
|
||||
def test_execute_async_with_retries():
|
||||
"""Test execute async task with retries."""
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def failing_async_function():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count < 3:
|
||||
raise ValueError("temporary error")
|
||||
return "success"
|
||||
|
||||
spec = px.TaskSpec(
|
||||
"retry_async_test",
|
||||
fn=failing_async_function,
|
||||
retry=px.RetryPolicy(max_attempts=3),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
# Should succeed after retries
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report.results["retry_async_test"].attempts == 3
|
||||
|
||||
|
||||
def test_execute_sync_skip_on_condition():
|
||||
"""Test execute task skips task when condition is false."""
|
||||
spec = px.TaskSpec(
|
||||
"skip_test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False,),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["skip_test"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_execute_async_skip_on_condition():
|
||||
"""Test execute async task skips task when condition is false."""
|
||||
spec = px.TaskSpec(
|
||||
"skip_async_test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False,),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
report = px.run(graph, strategy="async")
|
||||
assert report.success
|
||||
assert report.results["skip_async_test"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_execute_sync_with_error():
|
||||
"""Test execute task handles errors correctly."""
|
||||
|
||||
def error_function():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("error_test", fn=error_function)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
def test_execute_async_with_error():
|
||||
"""Test execute async task handles errors correctly."""
|
||||
|
||||
async def error_async_function():
|
||||
raise ValueError("test error")
|
||||
|
||||
spec = px.TaskSpec("error_async_test", fn=error_async_function)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="async")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _check_upstream_skipped 分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_allow_upstream_skip_allows_execution_after_skipped() -> None:
|
||||
"""allow_upstream_skip=True 时上游被 SKIPPED 后本任务仍执行."""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def downstream_task() -> str:
|
||||
return "ran despite upstream skipped"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", fn=lambda: "up", conditions=(never_true,)),
|
||||
px.TaskSpec("downstream", fn=downstream_task, depends_on=("upstream",), allow_upstream_skip=True),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["upstream"].status == TaskStatus.SKIPPED
|
||||
assert report.results["downstream"].status == TaskStatus.SUCCESS
|
||||
assert report["downstream"] == "ran despite upstream skipped"
|
||||
|
||||
|
||||
def test_upstream_failed_skips_downstream() -> None:
|
||||
"""上游 FAILED 时下游被 SKIPPED(除非 allow_upstream_skip=True)."""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
def downstream():
|
||||
return "should not run"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", fn=boom),
|
||||
px.TaskSpec("downstream", fn=downstream, depends_on=("upstream",)),
|
||||
])
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _evaluate_conditions 多条件分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_multiple_conditions_failure_truncation() -> None:
|
||||
"""超过 2 个条件失败时应截断显示."""
|
||||
spec = px.TaskSpec(
|
||||
"multi_skip",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True)
|
||||
assert report.success
|
||||
assert report.results["multi_skip"].status == TaskStatus.SKIPPED
|
||||
# reason 应显示 "条件不满足: <lambda>, <lambda> 等5个条件"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# concurrency_key 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_concurrency_key_sequential() -> None:
|
||||
"""sequential 策略下 concurrency_key 无效果."""
|
||||
spec = px.TaskSpec("a", fn=lambda: 1, concurrency_key="group1")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_concurrency_key_thread() -> None:
|
||||
"""thread 策略下 concurrency_key 应限制并发."""
|
||||
import time
|
||||
|
||||
order = []
|
||||
|
||||
def make(name: str) -> Callable[[], str]:
|
||||
def fn():
|
||||
order.append(f"{name}-start")
|
||||
time.sleep(0.1)
|
||||
order.append(f"{name}-end")
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=make("a"), concurrency_key="group1"),
|
||||
px.TaskSpec("b", fn=make("b"), concurrency_key="group1"),
|
||||
px.TaskSpec("c", fn=make("c"), concurrency_key="group1"),
|
||||
])
|
||||
report = px.run(graph, strategy="thread", max_workers=10, concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
# 由于 concurrency_key 限制为 1,任务应串行执行
|
||||
# 验证顺序:每个任务的 start-end 应连续
|
||||
# 可能顺序:a-start, a-end, b-start, b-end, c-start, c-end
|
||||
|
||||
|
||||
def test_concurrency_key_async() -> None:
|
||||
"""async 策略下 concurrency_key 应限制并发."""
|
||||
import asyncio
|
||||
|
||||
async def task_a():
|
||||
await asyncio.sleep(0.01)
|
||||
return "a"
|
||||
|
||||
async def task_b():
|
||||
await asyncio.sleep(0.01)
|
||||
return "b"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=task_a, concurrency_key="group1"),
|
||||
px.TaskSpec("b", fn=task_b, concurrency_key="group1"),
|
||||
])
|
||||
report = px.run(graph, strategy="async", concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# dependency 策略测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_dependency_strategy_basic() -> None:
|
||||
"""dependency 策略应正确执行."""
|
||||
order = []
|
||||
|
||||
def make(name: str) -> Callable[[], str]:
|
||||
def fn():
|
||||
order.append(name)
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=make("a")),
|
||||
px.TaskSpec("b", fn=make("b"), depends_on=("a",)),
|
||||
px.TaskSpec("c", fn=make("c"), depends_on=("a",)),
|
||||
px.TaskSpec("d", fn=make("d"), depends_on=("b", "c")),
|
||||
])
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert "a" in order
|
||||
assert "d" in order
|
||||
|
||||
|
||||
def test_dependency_strategy_async() -> None:
|
||||
"""dependency 策略下异步任务应正确执行."""
|
||||
|
||||
async def a():
|
||||
return "a"
|
||||
|
||||
async def b(a: str):
|
||||
return a + "b"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=a),
|
||||
px.TaskSpec("b", fn=b, depends_on=("a",)),
|
||||
])
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert report["b"] == "ab"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# continue_on_error 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_continue_on_error_marks_failed_but_continues() -> None:
|
||||
"""continue_on_error=True 时任务失败不抛异常,但 report.success 为 True(无 TaskFailedError 抛出)。"""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fail", fn=boom, continue_on_error=True),
|
||||
px.TaskSpec("other", fn=lambda: "ok"), # 无依赖,应继续
|
||||
])
|
||||
# continue_on_error=True 时 run 不抛异常,report.success 为 True
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# report.success 为 True 因为没有抛 TaskFailedError
|
||||
assert report.success # 因为 continue_on_error 阻止了 TaskFailedError
|
||||
assert report.results["fail"].status == TaskStatus.FAILED
|
||||
assert report.results["other"].status == TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_continue_on_error_downstream_skipped() -> None:
|
||||
"""continue_on_error=True 时失败任务的下游被 SKIPPED(allow_upstream_skip=False 时)。"""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
def downstream():
|
||||
return "should not run"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fail", fn=boom, continue_on_error=True),
|
||||
px.TaskSpec("dep", fn=downstream, depends_on=("fail",), allow_upstream_skip=False),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# report.success 为 True 因为 continue_on_error 阻止了 TaskFailedError
|
||||
assert report.success
|
||||
assert report.results["fail"].status == TaskStatus.FAILED
|
||||
assert report.results["dep"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# soft_depends_on 默认值注入测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_soft_depends_on_default_value_injection() -> None:
|
||||
"""软依赖存在且成功时注入其结果值(参数名需与依赖名一致)。"""
|
||||
|
||||
def task_with_soft_dep(a: str | None = None) -> str:
|
||||
return f"a={a}"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=lambda: "value"),
|
||||
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("a",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report["b"] == "a=value"
|
||||
|
||||
|
||||
def test_soft_depends_on_skipped_injects_none() -> None:
|
||||
"""软依赖被 SKIPPED 时注入 None(参数名需与依赖名一致)。"""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def task_with_soft_dep(skipped: str | None = None) -> str:
|
||||
return f"skipped={skipped}"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("skipped", fn=lambda: "value", conditions=(never_true,)),
|
||||
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("skipped",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
# 软依赖被 skipped 时注入 None(因为 global_context 中有 skipped,值为 None)
|
||||
assert report["b"] == "skipped=None"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# hooks 异常处理测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_hooks_pre_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""pre_run hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec):
|
||||
raise RuntimeError("hook error")
|
||||
|
||||
hooks = px.TaskHooks(pre_run=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_hooks_post_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""post_run hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec, _value):
|
||||
raise RuntimeError("post hook error")
|
||||
|
||||
hooks = px.TaskHooks(post_run=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_hooks_on_failure_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""on_failure hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec, _exc):
|
||||
raise RuntimeError("failure hook error")
|
||||
|
||||
hooks = px.TaskHooks(on_failure=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: (_ for _ in ()).throw(ValueError("task error")), hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"), pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# unknown strategy 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_unknown_strategy_raises() -> None:
|
||||
"""未知 strategy 应抛 ValueError."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn=lambda: 1)])
|
||||
with pytest.raises(ValueError, match="Unknown strategy"):
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
px.run(graph, strategy="unknown_strategy")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 空图测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_empty_graph_dependency_strategy() -> None:
|
||||
"""dependency 策略下空图应正常返回."""
|
||||
graph = px.Graph()
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert len(report) == 0
|
||||
+266
-82
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.compose import GraphComposer, compose
|
||||
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
|
||||
|
||||
@@ -13,13 +14,11 @@ def _fn() -> None:
|
||||
|
||||
|
||||
def test_from_specs_builds_graph() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
px.TaskSpec("c", _fn, ("a", "b")),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
px.TaskSpec("c", _fn, depends_on=("a", "b")),
|
||||
])
|
||||
assert set(graph.names) == {"a", "b", "c"}
|
||||
assert graph.dependencies("c") == ("a", "b")
|
||||
assert len(graph) == 3
|
||||
@@ -28,68 +27,59 @@ def test_from_specs_builds_graph() -> None:
|
||||
|
||||
def test_from_specs_allows_forward_references() -> None:
|
||||
# b depends on a, but a is declared after b — order should not matter.
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
px.TaskSpec("a", _fn),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
px.TaskSpec("a", _fn),
|
||||
])
|
||||
assert graph.layers() == [["a"], ["b"]]
|
||||
|
||||
|
||||
def test_duplicate_task_raises() -> None:
|
||||
with pytest.raises(DuplicateTaskError):
|
||||
px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("a", _fn),
|
||||
]
|
||||
)
|
||||
_ = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("a", _fn),
|
||||
])
|
||||
|
||||
|
||||
def test_missing_dependency_raises() -> None:
|
||||
with pytest.raises(MissingDependencyError) as exc_info:
|
||||
px.Graph.from_specs([px.TaskSpec("b", _fn, ("a",))])
|
||||
_ = px.Graph.from_specs([px.TaskSpec("b", _fn, depends_on=("a",))])
|
||||
|
||||
assert exc_info.value.task == "b"
|
||||
assert exc_info.value.dependency == "a"
|
||||
|
||||
|
||||
def test_cycle_detection() -> None:
|
||||
with pytest.raises(CycleError):
|
||||
px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn, ("c",)),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
px.TaskSpec("c", _fn, ("b",)),
|
||||
]
|
||||
)
|
||||
_ = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn, depends_on=("c",)),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
px.TaskSpec("c", _fn, depends_on=("b",)),
|
||||
])
|
||||
|
||||
|
||||
def test_layers_grouping() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn),
|
||||
px.TaskSpec("c", _fn, ("a", "b")),
|
||||
px.TaskSpec("d", _fn, ("c",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn),
|
||||
px.TaskSpec("c", _fn, depends_on=("a", "b")),
|
||||
px.TaskSpec("d", _fn, depends_on=("c",)),
|
||||
])
|
||||
layers = graph.layers()
|
||||
assert layers == [["a", "b"], ["c"], ["d"]]
|
||||
|
||||
|
||||
def test_self_dependency_rejected() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
px.TaskSpec("a", _fn, ("a",))
|
||||
_ = px.TaskSpec("a", _fn, depends_on=("a",))
|
||||
|
||||
|
||||
def test_to_mermaid() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
])
|
||||
mermaid = graph.to_mermaid()
|
||||
assert mermaid.startswith("graph TD")
|
||||
assert 'a["a"]' in mermaid
|
||||
@@ -99,17 +89,15 @@ def test_to_mermaid() -> None:
|
||||
def test_to_mermaid_invalid_orientation() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
with pytest.raises(ValueError):
|
||||
graph.to_mermaid("XX")
|
||||
_ = graph.to_mermaid("XX")
|
||||
|
||||
|
||||
def test_subgraph_by_tags() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn, tags=("ingest",)),
|
||||
px.TaskSpec("b", _fn, ("a",), tags=("ingest",)),
|
||||
px.TaskSpec("c", _fn, ("b",), tags=("report",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn, tags=("ingest",)),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",), tags=("ingest",)),
|
||||
px.TaskSpec("c", _fn, depends_on=("b",), tags=("report",)),
|
||||
])
|
||||
sub = graph.subgraph(["ingest"])
|
||||
assert set(sub.names) == {"a", "b"}
|
||||
# Edge to dropped task c is removed; b no longer waits for anything
|
||||
@@ -118,13 +106,11 @@ def test_subgraph_by_tags() -> None:
|
||||
|
||||
|
||||
def test_subgraph_by_names() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
px.TaskSpec("c", _fn, ("b",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
px.TaskSpec("c", _fn, depends_on=("b",)),
|
||||
])
|
||||
sub = graph.subgraph_by_names(["a", "b"])
|
||||
assert set(sub.names) == {"a", "b"}
|
||||
# c is dropped, so b's dep on c (none here) — but a->b edge preserved.
|
||||
@@ -134,16 +120,14 @@ def test_subgraph_by_names() -> None:
|
||||
def test_subgraph_by_names_unknown() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
with pytest.raises(KeyError):
|
||||
graph.subgraph_by_names(["nope"])
|
||||
_ = graph.subgraph_by_names(["nope"])
|
||||
|
||||
|
||||
def test_describe() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
])
|
||||
desc = graph.describe()
|
||||
assert "Layer 1" in desc
|
||||
assert "Layer 2" in desc
|
||||
@@ -160,14 +144,14 @@ def test_add_chains_and_validates() -> None:
|
||||
assert "a" in graph
|
||||
# 缺失依赖应即时报错
|
||||
with pytest.raises(MissingDependencyError):
|
||||
graph.add(px.TaskSpec("b", _fn, ("missing",)))
|
||||
_ = graph.add(px.TaskSpec("b", _fn, depends_on=("missing",)))
|
||||
|
||||
|
||||
def test_add_duplicate_raises() -> None:
|
||||
graph = px.Graph()
|
||||
graph.add(px.TaskSpec("a", _fn))
|
||||
_ = graph.add(px.TaskSpec("a", _fn))
|
||||
with pytest.raises(DuplicateTaskError):
|
||||
graph.add(px.TaskSpec("a", _fn))
|
||||
_ = graph.add(px.TaskSpec("a", _fn))
|
||||
|
||||
|
||||
def test_all_specs_returns_view() -> None:
|
||||
@@ -178,20 +162,31 @@ def test_all_specs_returns_view() -> None:
|
||||
assert view is graph.all_specs() or view == graph.all_specs()
|
||||
|
||||
|
||||
def test_all_deps_combines_hard_and_soft() -> None:
|
||||
"""all_deps 应返回硬依赖 + 软依赖的组合。"""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn),
|
||||
px.TaskSpec("c", _fn, depends_on=("a",), soft_depends_on=("b",)),
|
||||
])
|
||||
all_deps = graph.all_deps("c")
|
||||
assert set(all_deps) == {"a", "b"}
|
||||
# 硬依赖在前,软依赖在后
|
||||
assert all_deps == ("a", "b")
|
||||
|
||||
|
||||
def test_spec_accessor() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
assert graph.spec("a").name == "a"
|
||||
with pytest.raises(KeyError):
|
||||
graph.spec("missing")
|
||||
_ = graph.spec("missing")
|
||||
|
||||
|
||||
def test_dependencies_accessor() -> None:
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, ("a",)),
|
||||
]
|
||||
)
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",)),
|
||||
])
|
||||
assert graph.dependencies("a") == ()
|
||||
assert graph.dependencies("b") == ("a",)
|
||||
|
||||
@@ -209,16 +204,20 @@ def test_empty_graph_layers() -> None:
|
||||
|
||||
|
||||
def test_subgraph_preserves_metadata() -> None:
|
||||
"""子图应保留原任务的 retries/timeout/tags 等元数据。"""
|
||||
graph = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0),
|
||||
px.TaskSpec("b", _fn, ("a",), tags=("y",)),
|
||||
]
|
||||
)
|
||||
"""子图应保留原任务的 retry/timeout/tags 等元数据。"""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"a",
|
||||
_fn,
|
||||
tags=("x",),
|
||||
retry=px.RetryPolicy(max_attempts=3),
|
||||
timeout=5.0,
|
||||
),
|
||||
px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)),
|
||||
])
|
||||
sub = graph.subgraph(["x"])
|
||||
spec = sub.spec("a")
|
||||
assert spec.retries == 3
|
||||
assert spec.retry.max_attempts == 3
|
||||
assert spec.timeout == 5.0
|
||||
assert spec.tags == ("x",)
|
||||
|
||||
@@ -228,3 +227,188 @@ def test_subgraph_by_tags_no_match() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, tags=("x",))])
|
||||
sub = graph.subgraph(["z"])
|
||||
assert len(sub) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# from_specs str 类型分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_from_specs_with_string_ref() -> None:
|
||||
"""from_specs 接受字符串引用并收集到 pending_refs."""
|
||||
# 字符串引用被收集到 _pending_refs,而非尝试打开文件
|
||||
graph = px.Graph.from_specs(["ref_cmd"])
|
||||
assert graph._pending_refs == ["ref_cmd"]
|
||||
|
||||
|
||||
def test_from_specs_with_invalid_type() -> None:
|
||||
"""from_specs 接受不支持的类型时应抛 TypeError."""
|
||||
with pytest.raises(TypeError, match="from_specs 只接受 TaskSpec 或 str"):
|
||||
_ = px.Graph.from_specs([123]) # type: ignore[list-item]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# to_mermaid 软依赖测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_to_mermaid_soft_depends_on() -> None:
|
||||
"""to_mermaid 应正确绘制软依赖为虚线."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, soft_depends_on=("a",)),
|
||||
])
|
||||
mermaid = graph.to_mermaid()
|
||||
assert "a -.-> b" in mermaid # 软依赖用虚线
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# GraphComposer 与 compose 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_graph_composer_resolve_all() -> None:
|
||||
"""GraphComposer.resolve_all 应展开所有图的字符串引用."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn, depends_on=("a1",))])
|
||||
# 创建带 _pending_refs 的图
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# graph_b 应包含 graph_a 的任务
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
assert "a2" in resolved["cmd_b"]
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_self_reference() -> None:
|
||||
"""GraphComposer.parse_ref 应检测循环引用."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
with pytest.raises(ValueError, match="循环引用"):
|
||||
_ = composer.parse_ref("cmd", "cmd")
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_cmd_not_found() -> None:
|
||||
"""GraphComposer.parse_ref 应检测引用的命令不存在."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
|
||||
_ = composer.parse_ref("missing", "current")
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_task_not_found() -> None:
|
||||
"""GraphComposer.parse_ref 应检测任务不存在于引用的命令中."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
with pytest.raises(ValueError, match="任务 'missing' 不存在于命令 'cmd_a' 中"):
|
||||
_ = composer.parse_ref("cmd_a.missing", "cmd_b")
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_no_pending() -> None:
|
||||
"""GraphComposer.expand_refs 无 pending_refs 时应原样返回."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
expanded = composer.expand_refs(graph, "cmd")
|
||||
assert expanded is graph
|
||||
|
||||
|
||||
def test_compose_function() -> None:
|
||||
"""compose() 函数应等同于 GraphComposer().resolve_all()。"""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
|
||||
|
||||
resolved = compose({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_multiple_refs_chain() -> None:
|
||||
"""expand_refs 多个 ref 应串联依赖:后一个 ref 首任务依赖前一个 ref 末任务."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_c = px.Graph.from_specs([px.TaskSpec("c1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a", "cmd_c"]
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_c": graph_c, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# c1 应依赖 a1(后 ref 首任务依赖前 ref 末任务)
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
assert "c1" in resolved["cmd_b"]
|
||||
assert "b1" in resolved["cmd_b"]
|
||||
c1_spec = resolved["cmd_b"].all_specs()["c1"]
|
||||
assert "a1" in c1_spec.depends_on
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_ref_returns_empty() -> None:
|
||||
"""expand_refs 引用空图时,previous_ref_last_task 保持 None,original_specs 走 else 分支."""
|
||||
graph_empty = px.Graph.from_specs([])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["empty_cmd"]
|
||||
|
||||
composer = GraphComposer({"empty_cmd": graph_empty, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# b1 保留,无额外依赖
|
||||
assert "b1" in resolved["cmd_b"]
|
||||
b1_spec = resolved["cmd_b"].all_specs()["b1"]
|
||||
assert b1_spec.depends_on == ()
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_multiple_original_specs_serialized() -> None:
|
||||
"""expand_refs 多个 original_specs 应串行依赖,且首个依赖 ref 末任务."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([
|
||||
px.TaskSpec("b1", _fn),
|
||||
px.TaskSpec("b2", _fn),
|
||||
px.TaskSpec("b3", _fn),
|
||||
])
|
||||
graph_b._pending_refs = ["cmd_a"]
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
specs = resolved["cmd_b"].all_specs()
|
||||
# b1 依赖 a1(ref 末任务)
|
||||
assert "a1" in specs["b1"].depends_on
|
||||
# b2 依赖 b1,b3 依赖 b2(串行)
|
||||
assert "b1" in specs["b2"].depends_on
|
||||
assert "b2" in specs["b3"].depends_on
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_dot_notation_success() -> None:
|
||||
"""parse_ref 'cmd.task' 形式应返回对应单个 TaskSpec."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a})
|
||||
|
||||
result = composer.parse_ref("cmd_a.a2", "cmd_b")
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "a2"
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_dot_notation_cmd_not_found() -> None:
|
||||
"""parse_ref 'missing.task' 形式应检测命令不存在."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a})
|
||||
|
||||
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
|
||||
_ = composer.parse_ref("missing.task", "cmd_b")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# resolved_spec defaults 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_resolved_spec_applies_defaults() -> None:
|
||||
"""resolved_spec 应应用 Graph.defaults。"""
|
||||
defaults = px.GraphDefaults(timeout=10.0, retry=px.RetryPolicy(max_attempts=2))
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], defaults=defaults)
|
||||
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.timeout == 10.0
|
||||
assert resolved.retry.max_attempts == 2
|
||||
|
||||
|
||||
def test_resolved_spec_no_override() -> None:
|
||||
"""resolved_spec 不应覆盖任务已有的设置。"""
|
||||
defaults = px.GraphDefaults(timeout=10.0)
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, timeout=5.0)], defaults=defaults)
|
||||
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.timeout == 5.0 # 保持原值,不被 defaults 覆盖
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Tests for Graph namespace and add_subgraph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
return None
|
||||
|
||||
|
||||
def test_graph_namespace_field_default_none() -> None:
|
||||
"""Graph 默认 namespace 为 None."""
|
||||
graph = px.Graph()
|
||||
assert graph.namespace is None
|
||||
|
||||
|
||||
def test_graph_from_specs_with_namespace() -> None:
|
||||
"""from_specs(namespace=...) 应设置 graph.namespace."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="ns1")
|
||||
assert graph.namespace == "ns1"
|
||||
|
||||
|
||||
def test_add_subgraph_prefixes_task_names() -> None:
|
||||
"""add_subgraph 应给子图任务名加命名空间前缀."""
|
||||
sub = px.Graph.from_specs(
|
||||
[px.TaskSpec("extract", _fn), px.TaskSpec("build", _fn, depends_on=("extract",))],
|
||||
namespace="build",
|
||||
)
|
||||
main = px.Graph.from_specs([px.TaskSpec("start", _fn)])
|
||||
main.add_subgraph(sub)
|
||||
|
||||
assert "start" in main
|
||||
assert "build:extract" in main
|
||||
assert "build:build" in main
|
||||
|
||||
|
||||
def test_add_subgraph_renames_internal_deps() -> None:
|
||||
"""add_subgraph 应给子图内部依赖名加前缀."""
|
||||
sub = px.Graph.from_specs(
|
||||
[px.TaskSpec("a", _fn), px.TaskSpec("b", _fn, depends_on=("a",))],
|
||||
namespace="ns",
|
||||
)
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub)
|
||||
|
||||
b_spec = main.all_specs()["ns:b"]
|
||||
assert b_spec.depends_on == ("ns:a",)
|
||||
|
||||
|
||||
def test_add_subgraph_all_internal_deps_prefixed() -> None:
|
||||
"""add_subgraph 子图内所有任务(含被依赖的)都加前缀."""
|
||||
sub = px.Graph.from_specs(
|
||||
[px.TaskSpec("ext", _fn), px.TaskSpec("b", _fn, depends_on=("ext",))],
|
||||
namespace="ns",
|
||||
)
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub)
|
||||
|
||||
b_spec = main.all_specs()["ns:b"]
|
||||
assert b_spec.depends_on == ("ns:ext",)
|
||||
assert "ns:ext" in main
|
||||
|
||||
|
||||
def test_add_subgraph_requires_namespace() -> None:
|
||||
"""add_subgraph 无 namespace 时应抛 ValueError."""
|
||||
sub = px.Graph.from_specs([px.TaskSpec("a", _fn)]) # 无 namespace
|
||||
main = px.Graph()
|
||||
with pytest.raises(ValueError, match="namespace"):
|
||||
main.add_subgraph(sub)
|
||||
|
||||
|
||||
def test_add_subgraph_explicit_namespace_overrides() -> None:
|
||||
"""add_subgraph(namespace=...) 应覆盖子图自带 namespace."""
|
||||
sub = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="original")
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub, namespace="override")
|
||||
|
||||
assert "override:a" in main
|
||||
assert "original:a" not in main
|
||||
|
||||
|
||||
def test_add_subgraph_internal_injection_works() -> None:
|
||||
"""子图内部依赖注入应通过 wrapper 正常工作."""
|
||||
sub = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("extract", lambda: [1, 2, 3]),
|
||||
px.TaskSpec("build", lambda extract: [x * 2 for x in extract], depends_on=("extract",)),
|
||||
],
|
||||
namespace="build",
|
||||
)
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub)
|
||||
|
||||
report = px.run(main)
|
||||
assert report.success
|
||||
assert report["build:build"] == [2, 4, 6]
|
||||
|
||||
|
||||
def test_add_subgraph_cross_namespace_ref_via_context() -> None:
|
||||
"""跨命名空间引用应通过 Context 标注接收."""
|
||||
|
||||
def consumer(ctx: px.Context) -> str:
|
||||
return f"got {ctx['ns:data']}"
|
||||
|
||||
sub = px.Graph.from_specs(
|
||||
[px.TaskSpec("data", lambda: "data_value")],
|
||||
namespace="ns",
|
||||
)
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub)
|
||||
|
||||
main.add(px.TaskSpec("consumer", consumer, depends_on=("ns:data",)))
|
||||
|
||||
report = px.run(main)
|
||||
assert report.success
|
||||
assert report["consumer"] == "got data_value"
|
||||
|
||||
|
||||
def test_add_subgraph_context_annotation_in_subgraph() -> None:
|
||||
"""子图内部任务用 Context 标注时,wrapper 应正确传递."""
|
||||
|
||||
def sink(ctx: px.Context) -> int:
|
||||
return ctx["src"]
|
||||
|
||||
sub = px.Graph.from_specs(
|
||||
[
|
||||
px.TaskSpec("src", lambda: 42),
|
||||
px.TaskSpec("sink", sink, depends_on=("src",)),
|
||||
],
|
||||
namespace="ns",
|
||||
)
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub)
|
||||
|
||||
report = px.run(main)
|
||||
assert report.success
|
||||
assert report["ns:sink"] == 42
|
||||
|
||||
|
||||
def test_add_subgraph_chained() -> None:
|
||||
"""多个子图可链式合并到主图."""
|
||||
sub_a = px.Graph.from_specs([px.TaskSpec("a", _fn)], namespace="nsA")
|
||||
sub_b = px.Graph.from_specs([px.TaskSpec("b", _fn)], namespace="nsB")
|
||||
|
||||
main = px.Graph()
|
||||
main.add_subgraph(sub_a).add_subgraph(sub_b)
|
||||
|
||||
assert "nsA:a" in main
|
||||
assert "nsB:b" in main
|
||||
+133
-80
@@ -1,9 +1,9 @@
|
||||
"""RunReport 测试。"""
|
||||
"""RunReport 测试."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskResult, TaskSpec, TaskStatus
|
||||
@@ -16,18 +16,17 @@ def _fn() -> int:
|
||||
def _make_result(
|
||||
name: str = "a",
|
||||
status: TaskStatus = TaskStatus.SUCCESS,
|
||||
value: object = 42,
|
||||
error: Optional[object] = None,
|
||||
value: Any = 42,
|
||||
error: BaseException | None = None,
|
||||
duration: float = 0.5,
|
||||
attempts: int = 1,
|
||||
) -> TaskResult[object]:
|
||||
spec: TaskSpec[object] = TaskSpec[object](name, _fn)
|
||||
) -> TaskResult[Any]:
|
||||
"""构造测试用 TaskResult 实例."""
|
||||
spec: TaskSpec[Any] = TaskSpec[Any](name, _fn)
|
||||
start = datetime(2024, 1, 1, 0, 0, 0)
|
||||
# 用 timedelta 精确表达秒数,避免 int() 截断小数
|
||||
from datetime import timedelta
|
||||
|
||||
end = start + timedelta(seconds=duration) if duration else None
|
||||
return TaskResult[object](
|
||||
return TaskResult[Any](
|
||||
spec=spec,
|
||||
status=status,
|
||||
value=value,
|
||||
@@ -38,85 +37,139 @@ def _make_result(
|
||||
)
|
||||
|
||||
|
||||
def test_getitem_returns_value() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", value=7)
|
||||
assert report["a"] == 7
|
||||
class TestRunReportAccess:
|
||||
"""测试 RunReport 的访问接口."""
|
||||
|
||||
def test_getitem_returns_value(self) -> None:
|
||||
"""report[name] 应返回任务结果值."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", value=7)
|
||||
assert report["a"] == 7
|
||||
|
||||
def test_result_of_returns_full_result(self) -> None:
|
||||
"""result_of 应返回完整的 TaskResult 对象."""
|
||||
report = px.RunReport()
|
||||
r = _make_result("a")
|
||||
report.results["a"] = r
|
||||
assert report.result_of("a") is r
|
||||
|
||||
def test_contains(self) -> None:
|
||||
"""in 运算符应正确判断任务是否存在."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a")
|
||||
assert "a" in report
|
||||
assert "b" not in report
|
||||
|
||||
def test_iter_and_len(self) -> None:
|
||||
"""应支持迭代任务名并返回任务数量."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a")
|
||||
report.results["b"] = _make_result("b")
|
||||
assert list(report) == ["a", "b"]
|
||||
assert len(report) == 2
|
||||
|
||||
|
||||
def test_result_of_returns_full_result() -> None:
|
||||
report = px.RunReport()
|
||||
r = _make_result("a")
|
||||
report.results["a"] = r
|
||||
assert report.result_of("a") is r
|
||||
class TestRunReportSummary:
|
||||
"""测试 RunReport 的 summary 方法."""
|
||||
|
||||
def test_summary_success(self) -> None:
|
||||
"""应正确汇总成功和跳过的任务."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=1.0)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.SKIPPED, duration=0.0)
|
||||
s = report.summary()
|
||||
assert s["success"] is True
|
||||
assert s["total_tasks"] == 2
|
||||
assert s["by_status"] == {"success": 1, "skipped": 1}
|
||||
assert s["total_duration_seconds"] == 1.0
|
||||
|
||||
def test_summary_with_none_duration(self) -> None:
|
||||
"""未开始/未结束的任务 duration 为 None,不应计入总时长."""
|
||||
report = px.RunReport()
|
||||
spec: TaskSpec[Any] = TaskSpec[Any]("a", _fn) # type: ignore[arg-type]
|
||||
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.FAILED)
|
||||
s = report.summary()
|
||||
assert s["total_duration_seconds"] == 0.0
|
||||
|
||||
def test_failed_tasks(self) -> None:
|
||||
"""failed_tasks 应返回所有失败任务名."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.FAILED, error=ValueError("x"))
|
||||
assert report.failed_tasks() == ["b"]
|
||||
|
||||
|
||||
def test_contains() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a")
|
||||
assert "a" in report
|
||||
assert "b" not in report
|
||||
class TestRunReportDescribe:
|
||||
"""测试 RunReport 的 describe 方法."""
|
||||
|
||||
def test_describe_success(self) -> None:
|
||||
"""应正确描述成功状态和耗时."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5)
|
||||
desc = report.describe()
|
||||
assert "RunReport(success=True)" in desc
|
||||
assert "a: success" in desc
|
||||
assert "0.500s" in desc
|
||||
|
||||
def test_describe_with_error(self) -> None:
|
||||
"""应正确描述失败状态和错误信息."""
|
||||
report = px.RunReport(success=False)
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1)
|
||||
desc = report.describe()
|
||||
assert "success=False" in desc
|
||||
assert "error=ValueError" in desc
|
||||
|
||||
def test_describe_no_duration(self) -> None:
|
||||
"""无耗时的任务应显示为 '-'."""
|
||||
report = px.RunReport()
|
||||
spec: TaskSpec[Any] = TaskSpec[Any]("a", _fn) # type: ignore[arg-type]
|
||||
report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING)
|
||||
desc = report.describe()
|
||||
assert "-" in desc # duration 显示为 "-"
|
||||
|
||||
|
||||
def test_iter_and_len() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a")
|
||||
report.results["b"] = _make_result("b")
|
||||
assert list(report) == ["a", "b"]
|
||||
assert len(report) == 2
|
||||
class TestRunReportQueries:
|
||||
"""测试 RunReport 的新查询 API."""
|
||||
|
||||
def test_succeeded_tasks(self) -> None:
|
||||
"""succeeded_tasks 返回 SUCCESS 状态的任务名."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.FAILED)
|
||||
report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS)
|
||||
assert report.succeeded_tasks() == ["a", "c"]
|
||||
|
||||
def test_summary_success() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=1.0)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.SKIPPED, duration=0.0)
|
||||
s = report.summary()
|
||||
assert s["success"] is True
|
||||
assert s["total_tasks"] == 2
|
||||
assert s["by_status"] == {"success": 1, "skipped": 1}
|
||||
assert s["total_duration_seconds"] == 1.0
|
||||
def test_skipped_tasks(self) -> None:
|
||||
"""skipped_tasks 返回 SKIPPED 状态的任务名."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SKIPPED)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.SUCCESS)
|
||||
assert report.skipped_tasks() == ["a"]
|
||||
|
||||
def test_tasks_by_status(self) -> None:
|
||||
"""tasks_by_status 按指定状态过滤."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.FAILED)
|
||||
report.results["b"] = _make_result("b", status=TaskStatus.FAILED)
|
||||
report.results["c"] = _make_result("c", status=TaskStatus.SUCCESS)
|
||||
assert report.tasks_by_status(TaskStatus.FAILED) == ["a", "b"]
|
||||
assert report.tasks_by_status(TaskStatus.SUCCESS) == ["c"]
|
||||
assert report.tasks_by_status(TaskStatus.SKIPPED) == []
|
||||
|
||||
def test_summary_with_none_duration() -> None:
|
||||
"""未开始/未结束的任务 duration 为 None,不应计入总时长。"""
|
||||
report = px.RunReport()
|
||||
spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
|
||||
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.FAILED)
|
||||
s = report.summary()
|
||||
assert s["total_duration_seconds"] == 0.0
|
||||
def test_durations(self) -> None:
|
||||
"""durations 返回任务名 -> 时长映射."""
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", duration=1.5)
|
||||
report.results["b"] = _make_result("b", duration=2.0)
|
||||
durs = report.durations()
|
||||
assert durs["a"] == 1.5
|
||||
assert durs["b"] == 2.0
|
||||
|
||||
def test_durations_no_duration(self) -> None:
|
||||
"""无时长的任务应返回 0.0."""
|
||||
report = px.RunReport()
|
||||
spec: TaskSpec[Any] = TaskSpec[Any]("a", _fn) # type: ignore[arg-type]
|
||||
report.results["a"] = TaskResult[Any](spec=spec, status=TaskStatus.PENDING)
|
||||
durs = report.durations()
|
||||
assert durs["a"] == 0.0
|
||||
|
||||
def test_failed_tasks() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
|
||||
report.results["b"] = _make_result(
|
||||
"b", status=TaskStatus.FAILED, error=ValueError("x")
|
||||
)
|
||||
assert report.failed_tasks() == ["b"]
|
||||
|
||||
|
||||
def test_describe_success() -> None:
|
||||
report = px.RunReport()
|
||||
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5)
|
||||
desc = report.describe()
|
||||
assert "RunReport(success=True)" in desc
|
||||
assert "a: success" in desc
|
||||
assert "0.500s" in desc
|
||||
|
||||
|
||||
def test_describe_with_error() -> None:
|
||||
report = px.RunReport(success=False)
|
||||
report.results["a"] = _make_result(
|
||||
"a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1
|
||||
)
|
||||
desc = report.describe()
|
||||
assert "success=False" in desc
|
||||
assert "error=ValueError" in desc
|
||||
|
||||
|
||||
def test_describe_no_duration() -> None:
|
||||
report = px.RunReport()
|
||||
spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
|
||||
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.PENDING)
|
||||
desc = report.describe()
|
||||
assert "-" in desc # duration 显示为 "-"
|
||||
|
||||
@@ -0,0 +1,726 @@
|
||||
"""Tests for CliRunner: command dispatch, argument parsing, exit codes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx import CliExitCode
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 辅助工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
def _echo_graph(name: str = "echo_task", msg: str = "hello") -> px.Graph:
|
||||
"""构造一个单任务 echo 图, 用于执行成功场景."""
|
||||
return px.Graph.from_specs([px.TaskSpec(name, cmd=[*ECHO_CMD, msg])])
|
||||
|
||||
|
||||
def _failing_graph() -> px.Graph:
|
||||
"""构造一个必定失败的单任务图."""
|
||||
return px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"fail",
|
||||
cmd=["python", "-c", "import sys; sys.exit(1)"],
|
||||
)
|
||||
])
|
||||
|
||||
|
||||
def _multi_task_graph() -> px.Graph:
|
||||
"""构造一个带依赖的多任务图."""
|
||||
return px.Graph.from_specs([
|
||||
px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]),
|
||||
px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)),
|
||||
])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 构造与校验
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerConstruction:
|
||||
"""测试 CliRunner 的构造与参数校验."""
|
||||
|
||||
def test_requires_at_least_one_command(self) -> None:
|
||||
"""没有命令时应抛出 ValueError."""
|
||||
with pytest.raises(ValueError, match="至少需要一个别名"):
|
||||
_ = px.CliRunner()
|
||||
|
||||
def test_accepts_single_graph(self) -> None:
|
||||
"""单个命令应正常构造."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
assert runner.commands == ["clean"]
|
||||
|
||||
def test_accepts_multiple_graphs(self) -> None:
|
||||
"""多个命令应按插入顺序保留."""
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"clean": _echo_graph("c", "clean"),
|
||||
"build": _echo_graph("b", "build"),
|
||||
"test": _echo_graph("t", "test"),
|
||||
}
|
||||
)
|
||||
assert runner.commands == ["clean", "build", "test"]
|
||||
|
||||
def test_default_strategy_is_dependency(self) -> None:
|
||||
"""默认策略应为 dependency(依赖驱动,最大并行度)."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
assert runner.strategy == "dependency"
|
||||
|
||||
def test_custom_strategy_string(self) -> None:
|
||||
"""应支持通过字符串指定策略."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()}, strategy="thread")
|
||||
assert runner.strategy == "thread"
|
||||
|
||||
def test_custom_strategy_enum(self) -> None:
|
||||
"""应支持通过 Strategy 枚举指定策略."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()}, strategy="async")
|
||||
assert runner.strategy == "async"
|
||||
|
||||
def test_default_verbose_is_true(self) -> None:
|
||||
"""默认 verbose 应为 True."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
assert runner.verbose is True
|
||||
|
||||
def test_custom_verbose_false(self) -> None:
|
||||
"""应支持关闭 verbose."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()}, verbose=False)
|
||||
assert runner.verbose is False
|
||||
|
||||
def test_default_description_is_empty(self) -> None:
|
||||
"""默认描述应为空字符串."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
assert runner.description == ""
|
||||
|
||||
def test_custom_description(self) -> None:
|
||||
"""应支持自定义描述."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()}, description="My CLI")
|
||||
assert runner.description == "My CLI"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 属性与内省
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerProperties:
|
||||
"""测试 CliRunner 的属性访问."""
|
||||
|
||||
def test_commands_returns_list(self) -> None:
|
||||
"""commands 应返回列表."""
|
||||
runner = px.CliRunner(aliases={"a": _echo_graph(), "b": _echo_graph()})
|
||||
assert isinstance(runner.commands, list)
|
||||
|
||||
def test_graphs_contains_original_graphs(self) -> None:
|
||||
"""graphs 应包含原始 Graph 实例."""
|
||||
g = _echo_graph()
|
||||
runner = px.CliRunner(aliases={"cmd": g})
|
||||
assert runner.graphs["cmd"] is g
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerParser:
|
||||
"""测试参数解析器."""
|
||||
|
||||
def test_create_parser_returns_argument_parser(self) -> None:
|
||||
"""create_parser 应返回 ArgumentParser."""
|
||||
from argparse import ArgumentParser
|
||||
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
assert isinstance(parser, ArgumentParser)
|
||||
|
||||
def test_parser_has_command_argument(self) -> None:
|
||||
"""解析器应有 command 位置参数."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean"])
|
||||
assert parsed.command == "clean"
|
||||
|
||||
def test_parser_command_is_optional(self) -> None:
|
||||
"""command 应为可选参数."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args([])
|
||||
assert parsed.command is None
|
||||
|
||||
def test_parser_has_strategy_option(self) -> None:
|
||||
"""解析器应有 --strategy 选项."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean", "--strategy", "thread"])
|
||||
assert parsed.strategy == "thread"
|
||||
|
||||
def test_parser_strategy_default(self) -> None:
|
||||
"""--strategy 默认值应与构造时一致."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()}, strategy="async")
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean"])
|
||||
assert parsed.strategy == "async"
|
||||
|
||||
def test_parser_has_dry_run_flag(self) -> None:
|
||||
"""解析器应有 --dry-run 标志."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean", "--dry-run"])
|
||||
assert parsed.dry_run is True
|
||||
|
||||
def test_parser_dry_run_default_false(self) -> None:
|
||||
"""--dry-run 默认为 False."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean"])
|
||||
assert parsed.dry_run is False
|
||||
|
||||
def test_parser_has_list_flag(self) -> None:
|
||||
"""解析器应有 --list 标志."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["--list"])
|
||||
assert parsed.list is True
|
||||
|
||||
def test_parser_has_quiet_flag(self) -> None:
|
||||
"""解析器应有 --quiet 标志."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean", "--quiet"])
|
||||
assert parsed.quiet is True
|
||||
|
||||
def test_parser_quiet_default_false(self) -> None:
|
||||
"""--quiet 默认为 False."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
parser = runner.create_parser()
|
||||
parsed = parser.parse_args(["clean"])
|
||||
assert parsed.quiet is False
|
||||
|
||||
def test_format_commands_help_contains_all_commands(self) -> None:
|
||||
"""帮助文本应包含所有命令."""
|
||||
runner = px.CliRunner(
|
||||
{"clean": _echo_graph("c", "clean"), "build": _echo_graph("b", "build")},
|
||||
)
|
||||
help_text = runner._format_commands_help()
|
||||
assert "clean" in help_text
|
||||
assert "build" in help_text
|
||||
assert "可用命令" in help_text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 执行: 成功路径
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerRunSuccess:
|
||||
"""测试 CliRunner.run 的成功执行路径."""
|
||||
|
||||
def test_run_valid_command_returns_zero(self) -> None:
|
||||
"""有效命令执行成功应返回 0."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
exit_code = runner.run(["clean"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_run_executes_correct_graph(self) -> None:
|
||||
"""应执行用户指定的命令对应的图."""
|
||||
executed: list[str] = []
|
||||
|
||||
def track_a() -> None:
|
||||
executed.append("a")
|
||||
|
||||
def track_b() -> None:
|
||||
executed.append("b")
|
||||
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"a": px.Graph.from_specs([px.TaskSpec("a", track_a)]),
|
||||
"b": px.Graph.from_specs([px.TaskSpec("b", track_b)]),
|
||||
}
|
||||
)
|
||||
_ = runner.run(["b"])
|
||||
assert executed == ["b"]
|
||||
|
||||
def test_run_multi_task_graph(self) -> None:
|
||||
"""应能执行带依赖的多任务图."""
|
||||
runner = px.CliRunner(aliases={"multi": _multi_task_graph()})
|
||||
exit_code = runner.run(["multi"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_run_with_strategy_override(self) -> None:
|
||||
"""应支持通过 --strategy 覆盖默认策略."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
exit_code = runner.run(["echo", "--strategy", "thread"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_run_with_dry_run(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""--dry-run 应只打印计划不执行."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
exit_code = runner.run(["echo", "--dry-run"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
captured = capsys.readouterr()
|
||||
assert "Dry run" in captured.out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 执行: verbose 模式
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerVerbose:
|
||||
"""测试 verbose 模式."""
|
||||
|
||||
def test_verbose_default_prints_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""默认 verbose=True 应打印任务生命周期."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
_ = runner.run(["echo"])
|
||||
captured = capsys.readouterr()
|
||||
# verbose 模式下应打印任务生命周期
|
||||
assert "[verbose]" in captured.out
|
||||
|
||||
def test_quiet_flag_disables_verbose(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""--quiet 应关闭 verbose 输出."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
_ = runner.run(["echo", "--quiet"])
|
||||
captured = capsys.readouterr()
|
||||
# quiet 模式下不应有 [verbose] 前缀的输出
|
||||
assert "[verbose]" not in captured.out
|
||||
|
||||
def test_verbose_false_constructor_disables_verbose(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""构造时 verbose=False 应关闭 verbose 输出."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()}, verbose=False)
|
||||
_ = runner.run(["echo"])
|
||||
captured = capsys.readouterr()
|
||||
assert "[verbose]" not in captured.out
|
||||
|
||||
def test_verbose_prints_command_for_cmd_task(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose 模式下 cmd 任务应打印执行的命令."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph(msg="verbose-test")})
|
||||
_ = runner.run(["echo"])
|
||||
captured = capsys.readouterr()
|
||||
# 应打印执行的命令
|
||||
assert "执行命令" in captured.out or "执行 Shell" in captured.out
|
||||
# 应打印返回码
|
||||
assert "返回码" in captured.out
|
||||
|
||||
def test_verbose_prints_success_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose 模式下成功任务应打印成功信息."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
_ = runner.run(["echo"])
|
||||
captured = capsys.readouterr()
|
||||
assert "成功" in captured.out
|
||||
|
||||
def test_verbose_prints_skip_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose 模式下跳过的任务应打印跳过信息."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"skip_me",
|
||||
cmd=[*ECHO_CMD, "skip"],
|
||||
conditions=(lambda _ctx: False,),
|
||||
),
|
||||
])
|
||||
runner = px.CliRunner(aliases={"skip": graph})
|
||||
_ = runner.run(["skip"])
|
||||
captured = capsys.readouterr()
|
||||
assert "跳过" in captured.out
|
||||
|
||||
def test_verbose_prints_failure_lifecycle(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose 模式下失败任务应打印失败信息."""
|
||||
runner = px.CliRunner(aliases={"fail": _failing_graph()})
|
||||
_ = runner.run(["fail"])
|
||||
captured = capsys.readouterr()
|
||||
# 失败信息可能出现在 stdout (verbose) 或 stderr (PyFlowXError)
|
||||
combined = captured.out + captured.err
|
||||
assert "失败" in combined or "错误" in combined
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 执行: 失败路径
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerRunFailure:
|
||||
"""测试 CliRunner.run 的失败执行路径."""
|
||||
|
||||
def test_run_unknown_command_returns_failure(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""未知命令应返回 1 并打印错误."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
exit_code = runner.run(["unknown"])
|
||||
assert exit_code == CliExitCode.FAILURE.value
|
||||
captured = capsys.readouterr()
|
||||
assert "未知命令" in captured.err
|
||||
assert "clean" in captured.err
|
||||
|
||||
def test_run_no_command_returns_failure(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""无命令时应返回 1 并打印帮助."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph()})
|
||||
exit_code = runner.run([])
|
||||
assert exit_code == CliExitCode.FAILURE.value
|
||||
captured = capsys.readouterr()
|
||||
assert "可用命令" in captured.out or "可用命令" in captured.err
|
||||
|
||||
def test_run_failing_task_returns_failure(self) -> None:
|
||||
"""任务失败时应返回 1."""
|
||||
runner = px.CliRunner(aliases={"fail": _failing_graph()})
|
||||
exit_code = runner.run(["fail"])
|
||||
assert exit_code == CliExitCode.FAILURE.value
|
||||
|
||||
def test_run_failing_task_prints_error(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""任务失败时应打印错误信息."""
|
||||
runner = px.CliRunner(aliases={"fail": _failing_graph()})
|
||||
_ = runner.run(["fail"])
|
||||
captured = capsys.readouterr()
|
||||
# PyFlowXError 信息应输出到 stderr
|
||||
assert "错误" in captured.err or "失败" in captured.err
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 执行: --list 选项
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerList:
|
||||
"""测试 --list 选项."""
|
||||
|
||||
def test_list_returns_success(self) -> None:
|
||||
"""--list 应返回 0."""
|
||||
runner = px.CliRunner(aliases={"clean": _echo_graph(), "build": _echo_graph()})
|
||||
exit_code = runner.run(["--list"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_list_prints_all_commands(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""--list 应打印所有命令."""
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"clean": _echo_graph("c", "clean"),
|
||||
"build": _echo_graph("b", "build"),
|
||||
"test": _echo_graph("t", "test"),
|
||||
}
|
||||
)
|
||||
_ = runner.run(["--list"])
|
||||
captured = capsys.readouterr()
|
||||
assert "clean" in captured.out
|
||||
assert "build" in captured.out
|
||||
assert "test" in captured.out
|
||||
|
||||
def test_list_does_not_execute_any_graph(self) -> None:
|
||||
"""--list 不应执行任何图."""
|
||||
executed: list[str] = []
|
||||
|
||||
def track() -> None:
|
||||
executed.append("ran")
|
||||
|
||||
runner = px.CliRunner(aliases={"a": px.Graph.from_specs([px.TaskSpec("a", track)])})
|
||||
_ = runner.run(["--list"])
|
||||
assert executed == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 错误处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerErrorHandling:
|
||||
"""测试错误处理."""
|
||||
|
||||
def test_keyboard_interrupt_returns_130(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""KeyboardInterrupt 应返回 130."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
|
||||
def raise_interrupt(*_args: Any, **_kwargs: Any) -> None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
with patch("pyflowx.runner.run", side_effect=raise_interrupt):
|
||||
exit_code = runner.run(["echo"])
|
||||
assert exit_code == CliExitCode.INTERRUPTED.value
|
||||
captured = capsys.readouterr()
|
||||
assert "取消" in captured.err
|
||||
|
||||
def test_pyflowx_error_returns_failure(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""PyFlowXError 应返回 1."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
|
||||
def raise_error(*_args: Any, **_kwargs: Any) -> None:
|
||||
raise TaskFailedError("echo", RuntimeError("boom"), 1)
|
||||
|
||||
with patch("pyflowx.runner.run", side_effect=raise_error):
|
||||
exit_code = runner.run(["echo"])
|
||||
assert exit_code == CliExitCode.FAILURE.value
|
||||
captured = capsys.readouterr()
|
||||
assert "错误" in captured.err
|
||||
|
||||
def test_generic_exception_propagates(self) -> None:
|
||||
"""非 PyFlowXError 的异常应向上传播."""
|
||||
|
||||
class CustomError(Exception):
|
||||
pass
|
||||
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
|
||||
def raise_custom(*_args: Any, **_kwargs: Any) -> None:
|
||||
raise CustomError("unexpected")
|
||||
|
||||
with patch("pyflowx.runner.run", side_effect=raise_custom), pytest.raises(CustomError):
|
||||
_ = runner.run(["echo"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# run_cli
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerRunCli:
|
||||
"""测试 run_cli 方法."""
|
||||
|
||||
def test_run_cli_calls_sys_exit(self) -> None:
|
||||
"""run_cli 应调用 sys.exit."""
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
runner.run_cli(["echo"])
|
||||
assert exc_info.value.code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_run_cli_exit_code_on_failure(self) -> None:
|
||||
"""run_cli 失败时应以非零码退出."""
|
||||
runner = px.CliRunner(aliases={"fail": _failing_graph()})
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
runner.run_cli(["fail"])
|
||||
assert exc_info.value.code == CliExitCode.FAILURE.value
|
||||
|
||||
def test_run_cli_no_args_uses_sys_argv(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""run_cli 无参数时应使用 sys.argv."""
|
||||
monkeypatch.setattr(sys, "argv", ["pymake", "echo"])
|
||||
runner = px.CliRunner(aliases={"echo": _echo_graph()})
|
||||
with pytest.raises(SystemExit) as exc_info:
|
||||
runner.run_cli()
|
||||
assert exc_info.value.code == CliExitCode.SUCCESS.value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 退出码枚举
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliExitCode:
|
||||
"""测试 CliExitCode 枚举."""
|
||||
|
||||
def test_success_is_zero(self) -> None:
|
||||
assert CliExitCode.SUCCESS.value == 0
|
||||
|
||||
def test_failure_is_one(self) -> None:
|
||||
assert CliExitCode.FAILURE.value == 1
|
||||
|
||||
def test_interrupted_is_130(self) -> None:
|
||||
assert CliExitCode.INTERRUPTED.value == 130
|
||||
|
||||
def test_exit_codes_are_distinct(self) -> None:
|
||||
values = {e.value for e in CliExitCode}
|
||||
assert len(values) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 集成测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerIntegration:
|
||||
"""集成测试: CliRunner + Graph + TaskSpec + 条件."""
|
||||
|
||||
def test_condition_skipped_command_succeeds(self) -> None:
|
||||
"""条件不满足时任务跳过, 整体仍成功."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"skip_me",
|
||||
cmd=[*ECHO_CMD, "should not run"],
|
||||
conditions=(lambda _ctx: False,),
|
||||
),
|
||||
])
|
||||
runner = px.CliRunner(aliases={"skip": graph})
|
||||
exit_code = runner.run(["skip"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_condition_met_command_succeeds(self) -> None:
|
||||
"""条件满足时任务执行, 整体成功."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"run_me",
|
||||
cmd=[*ECHO_CMD, "should run"],
|
||||
conditions=(lambda _ctx: True,),
|
||||
),
|
||||
])
|
||||
runner = px.CliRunner(aliases={"run": graph})
|
||||
exit_code = runner.run(["run"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_diamond_dependency_graph(self) -> None:
|
||||
"""菱形依赖图应正确执行."""
|
||||
order: list[str] = []
|
||||
|
||||
def make(name: str) -> Any:
|
||||
def fn() -> str:
|
||||
order.append(name)
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make("a")),
|
||||
px.TaskSpec("b", make("b"), depends_on=("a",)),
|
||||
px.TaskSpec("c", make("c"), depends_on=("a",)),
|
||||
px.TaskSpec("d", make("d"), depends_on=("b", "c")),
|
||||
])
|
||||
runner = px.CliRunner(aliases={"diamond": graph})
|
||||
exit_code = runner.run(["diamond"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
assert order == ["a", "b", "c", "d"]
|
||||
|
||||
def test_mixed_fn_and_cmd_commands(self) -> None:
|
||||
"""混合 fn 和 cmd 的命令应都能执行."""
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"fn_cmd": px.Graph.from_specs([px.TaskSpec("fn", fn=lambda: "fn-result")]),
|
||||
"cmd_cmd": px.Graph.from_specs([px.TaskSpec("cmd", cmd=[*ECHO_CMD, "cmd-result"])]),
|
||||
}
|
||||
)
|
||||
assert runner.run(["fn_cmd"]) == CliExitCode.SUCCESS.value
|
||||
assert runner.run(["cmd_cmd"]) == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_command_with_cwd(self) -> None:
|
||||
"""带 cwd 的命令应正确执行."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if sys.platform == "win32":
|
||||
ls_cmd = ["cmd", "/c", "dir"]
|
||||
else:
|
||||
ls_cmd = ["ls"]
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("ls", cmd=ls_cmd, cwd=Path(tmpdir))])
|
||||
runner = px.CliRunner(aliases={"ls": graph})
|
||||
exit_code = runner.run(["ls"])
|
||||
assert exit_code == CliExitCode.SUCCESS.value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _apply_verbose_to_graph (补充覆盖)
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestApplyVerboseToGraph:
|
||||
"""测试 _apply_verbose_to_graph 函数 (补充覆盖)."""
|
||||
|
||||
def test_specs_with_matching_verbose_are_kept(self) -> None:
|
||||
"""spec.verbose 已与目标值匹配时应保留原 spec (覆盖 runner.py line 57)."""
|
||||
from pyflowx.runner import _apply_verbose_to_graph
|
||||
|
||||
# 创建 verbose=True 的 spec
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", cmd=[*ECHO_CMD, "a"], verbose=True)])
|
||||
# 应用 verbose=True, spec.verbose 已匹配, 应保留原 spec
|
||||
new_graph = _apply_verbose_to_graph(graph, verbose=True)
|
||||
new_spec = new_graph.spec("a")
|
||||
assert new_spec.verbose is True
|
||||
|
||||
def test_specs_with_non_matching_verbose_are_replaced(self) -> None:
|
||||
"""spec.verbose 与目标值不匹配时应替换 (覆盖 else 分支)."""
|
||||
from pyflowx.runner import _apply_verbose_to_graph
|
||||
|
||||
# 创建 verbose=False 的 spec
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", cmd=[*ECHO_CMD, "a"], verbose=False)])
|
||||
# 应用 verbose=True, spec.verbose 不匹配, 应替换
|
||||
new_graph = _apply_verbose_to_graph(graph, verbose=True)
|
||||
new_spec = new_graph.spec("a")
|
||||
assert new_spec.verbose is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 新 API: tasks + aliases
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestCliRunnerNewApi:
|
||||
"""测试 CliRunner 的 tasks + aliases 新 API."""
|
||||
|
||||
def test_tasks_plus_aliases_single_str(self) -> None:
|
||||
"""tasks 注册 + aliases str 引用单任务."""
|
||||
runner = px.CliRunner(
|
||||
tasks=[px.cmd([*ECHO_CMD, "a"], name="task_a")],
|
||||
aliases={"a": "task_a"},
|
||||
)
|
||||
assert runner.commands == ["a"]
|
||||
assert runner.run(["a"]) == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_aliases_list_str_builds_chain(self) -> None:
|
||||
"""aliases list[str] 应建立 chain 依赖(后一个依赖前一个)."""
|
||||
runner = px.CliRunner(
|
||||
tasks=[
|
||||
px.cmd([*ECHO_CMD, "a"], name="task_a"),
|
||||
px.cmd([*ECHO_CMD, "b"], name="task_b"),
|
||||
],
|
||||
aliases={"ab": ["task_a", "task_b"]},
|
||||
)
|
||||
graph = runner.graphs["ab"]
|
||||
specs = graph.all_specs()
|
||||
assert specs["task_b"].depends_on == ("task_a",)
|
||||
|
||||
def test_aliases_taskspec_value(self) -> None:
|
||||
"""aliases 值为 TaskSpec 时直接生成单任务图."""
|
||||
spec = px.cmd([*ECHO_CMD, "x"], name="inline_x")
|
||||
runner = px.CliRunner(aliases={"x": spec})
|
||||
assert runner.run(["x"]) == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_aliases_graph_value(self) -> None:
|
||||
"""aliases 值为 Graph 时原样使用(复杂场景:conditions 等)."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", cmd=[*ECHO_CMD, "a"]),
|
||||
px.TaskSpec("b", cmd=[*ECHO_CMD, "b"], depends_on=("a",)),
|
||||
])
|
||||
runner = px.CliRunner(aliases={"g": graph})
|
||||
assert set(runner.graphs["g"].all_specs().keys()) == {"a", "b"}
|
||||
|
||||
def test_alias_name_same_as_task_name_via_taskspec(self) -> None:
|
||||
"""alias 名与 task 名相同时,用 TaskSpec 避免自引用循环."""
|
||||
spec = px.cmd([*ECHO_CMD, "same"], name="same")
|
||||
runner = px.CliRunner(aliases={"same": spec})
|
||||
assert runner.run(["same"]) == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_alias_str_reference_to_other_alias(self) -> None:
|
||||
"""alias 值为 str 引用其他 alias."""
|
||||
runner = px.CliRunner(
|
||||
aliases={
|
||||
"base": px.cmd([*ECHO_CMD, "base"], name="base"),
|
||||
"wrapper": "base",
|
||||
},
|
||||
)
|
||||
assert runner.run(["wrapper"]) == CliExitCode.SUCCESS.value
|
||||
|
||||
def test_empty_aliases_raises(self) -> None:
|
||||
"""空 aliases 应抛 ValueError."""
|
||||
with pytest.raises(ValueError, match="至少需要一个别名"):
|
||||
_ = px.CliRunner()
|
||||
|
||||
def test_empty_list_value_raises(self) -> None:
|
||||
"""空 list 作为 alias 值应抛 ValueError."""
|
||||
with pytest.raises(ValueError, match="任务列表为空"):
|
||||
_ = px.CliRunner(aliases={"x": []})
|
||||
|
||||
def test_invalid_value_type_raises(self) -> None:
|
||||
"""无效类型(int)作为 alias 值应抛 TypeError."""
|
||||
with pytest.raises(TypeError, match="值类型无效"):
|
||||
_ = px.CliRunner(aliases={"x": 123}) # type: ignore[dict-item]
|
||||
|
||||
def test_invalid_list_element_type_raises(self) -> None:
|
||||
"""list 中非 str/TaskSpec 元素应抛 TypeError."""
|
||||
with pytest.raises(TypeError, match="列表元素类型无效"):
|
||||
_ = px.CliRunner(aliases={"x": [123]}) # type: ignore[list-item]
|
||||
|
||||
def test_duplicate_task_name_raises(self) -> None:
|
||||
"""tasks 中重名任务应抛 ValueError."""
|
||||
spec = px.cmd([*ECHO_CMD, "a"], name="dup")
|
||||
with pytest.raises(ValueError, match="任务名重复"):
|
||||
_ = px.CliRunner(tasks=[spec, spec], aliases={"a": "dup"})
|
||||
|
||||
def test_commands_excludes_unreferenced_tasks(self) -> None:
|
||||
"""commands 只含 aliases,不含 tasks 中未引用的任务."""
|
||||
runner = px.CliRunner(
|
||||
tasks=[
|
||||
px.cmd([*ECHO_CMD, "a"], name="used"),
|
||||
px.cmd([*ECHO_CMD, "b"], name="unused"),
|
||||
],
|
||||
aliases={"a": "used"},
|
||||
)
|
||||
assert runner.commands == ["a"]
|
||||
|
||||
def test_unknown_command_rejected(self) -> None:
|
||||
"""未注册的 alias 名应被拒绝(不接受裸 task 名)."""
|
||||
runner = px.CliRunner(
|
||||
tasks=[px.cmd([*ECHO_CMD, "a"], name="task_a")],
|
||||
aliases={"a": "task_a"},
|
||||
)
|
||||
# task_a 是任务名,不是 alias,应被拒绝
|
||||
assert runner.run(["task_a"]) == CliExitCode.FAILURE.value
|
||||
+169
-18
@@ -5,6 +5,8 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +15,14 @@ from pyflowx.errors import StorageError
|
||||
from pyflowx.storage import JSONBackend, MemoryBackend, StateBackend, resolve_backend
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tmp_json(tmp_path: Path) -> Path:
|
||||
"""模拟临时 JSON 文件。"""
|
||||
path = tmp_path / "state.json"
|
||||
path.touch()
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# MemoryBackend
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -34,12 +44,52 @@ def test_memory_backend_get_missing_raises() -> None:
|
||||
b.get("nope")
|
||||
|
||||
|
||||
def test_memory_backend_ttl_expired() -> None:
|
||||
"""MemoryBackend TTL 过期后 has/get 返回 False/抛 KeyError."""
|
||||
b = MemoryBackend(ttl=0.1) # 0.1 秒过期
|
||||
b.save("a", 1)
|
||||
assert b.has("a")
|
||||
time.sleep(0.15)
|
||||
assert not b.has("a")
|
||||
with pytest.raises(KeyError):
|
||||
b.get("a")
|
||||
|
||||
|
||||
def test_memory_backend_ttl_load_filters_expired() -> None:
|
||||
"""MemoryBackend.load() 应过滤过期的条目."""
|
||||
b = MemoryBackend(ttl=0.1)
|
||||
b.save("a", 1)
|
||||
b.save("b", 2)
|
||||
time.sleep(0.15)
|
||||
# a 过期,但 b 也要过期... 需要更精确控制
|
||||
# 使用 monkeypatch 更可控
|
||||
b._store["expired"] = ("value", time.monotonic() - 100) # 手动设置过期时间
|
||||
b._store["fresh"] = ("value2", time.monotonic())
|
||||
assert "expired" not in dict(b.load())
|
||||
assert "fresh" in dict(b.load())
|
||||
|
||||
|
||||
def test_memory_backend_expired_key_not_in_store() -> None:
|
||||
"""不存在的键 has 返回 False."""
|
||||
b = MemoryBackend(ttl=1.0)
|
||||
assert b.has("nonexistent") is False
|
||||
|
||||
|
||||
def test_memory_backend_no_ttl_never_expired() -> None:
|
||||
"""无 TTL 时永不过期."""
|
||||
b = MemoryBackend()
|
||||
b.save("a", 1)
|
||||
b._store["a"] = (1, time.monotonic() - 1000) # 手动设置很久以前的存储
|
||||
assert b.has("a") # 仍然存在
|
||||
assert b.get("a") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# JSONBackend
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_json_backend_save_and_load() -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
b.save("a", {"x": 1})
|
||||
b.save("b", [1, 2, 3])
|
||||
@@ -53,20 +103,20 @@ def test_json_backend_save_and_load() -> None:
|
||||
|
||||
def test_json_backend_clear() -> None:
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
b.save("a", 1)
|
||||
b.clear()
|
||||
assert not b.has("a")
|
||||
# 文件应被写入空 dict
|
||||
with open(path, "r", encoding="utf-8") as fh:
|
||||
with open(path, encoding="utf-8") as fh:
|
||||
assert json.load(fh) == {}
|
||||
|
||||
|
||||
def test_json_backend_nonexistent_file_starts_empty() -> None:
|
||||
"""文件不存在时应正常初始化为空。"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "absent.json")
|
||||
path = str(Path(tmp) / "absent.json")
|
||||
b = JSONBackend(path)
|
||||
assert dict(b.load()) == {}
|
||||
assert not b.has("anything")
|
||||
@@ -75,7 +125,7 @@ def test_json_backend_nonexistent_file_starts_empty() -> None:
|
||||
def test_json_backend_non_serialisable_raises() -> None:
|
||||
"""不可 JSON 序列化的值应抛 StorageError,且不污染内存状态。"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
with pytest.raises(StorageError):
|
||||
b.save("a", object()) # object() 不可序列化
|
||||
@@ -91,12 +141,12 @@ def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
import json as _json
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
|
||||
original_dump = _json.dump
|
||||
|
||||
def flaky_dump(*args: Any, **kwargs: Any) -> None:
|
||||
def flaky_dump(*_args: Any, **_kwargs: Any) -> None:
|
||||
raise TypeError("simulated flush failure")
|
||||
|
||||
monkeypatch.setattr(_json, "dump", flaky_dump)
|
||||
@@ -109,15 +159,15 @@ def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""_flush 时 OSError 应转为 StorageError。"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
|
||||
original_replace = os.replace
|
||||
|
||||
def fail_replace(*args: Any, **kwargs: Any) -> None:
|
||||
def fail_replace(*_args: Any, **_kwargs: Any) -> None:
|
||||
raise OSError("simulated os.replace failure")
|
||||
|
||||
monkeypatch.setattr(os, "replace", fail_replace)
|
||||
monkeypatch.setattr(Path, "replace", fail_replace)
|
||||
with pytest.raises(StorageError, match="cannot write"):
|
||||
b.save("a", 1)
|
||||
monkeypatch.setattr(os, "replace", original_replace)
|
||||
@@ -126,23 +176,124 @@ def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_json_backend_corrupt_file_raises() -> None:
|
||||
"""损坏的 JSON 文件应抛 StorageError。"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
path = str(Path(tmp) / "state.json")
|
||||
with open(path, "w", encoding="utf-8") as fh:
|
||||
fh.write("{not valid json")
|
||||
_ = fh.write("{not valid json")
|
||||
with pytest.raises(StorageError):
|
||||
JSONBackend(path)
|
||||
_ = JSONBackend(path)
|
||||
|
||||
|
||||
def test_json_backend_non_dict_content_ignored() -> None:
|
||||
def test_json_backend_non_dict_content_ignored(tmp_path: Path) -> None:
|
||||
"""文件内容是合法 JSON 但非 dict 时应被忽略(保持空)。"""
|
||||
path = tmp_path / "state.json"
|
||||
_ = path.write_text(json.dumps([1, 2, 3])) # list 而非 dict
|
||||
b = JSONBackend(str(path))
|
||||
assert dict(b.load()) == {}
|
||||
|
||||
|
||||
def test_json_backend_old_format_migration(tmp_path: Path) -> None:
|
||||
"""旧格式JSON(纯值)应被迁移为新格式(带ts)。"""
|
||||
path = tmp_path / "state.json"
|
||||
# 写入旧格式:纯值
|
||||
old_data = {"a": 1, "b": "value"}
|
||||
_ = path.write_text(json.dumps(old_data))
|
||||
|
||||
b = JSONBackend(str(path))
|
||||
# 读取后应有ts字段
|
||||
assert "a" in b._store
|
||||
assert "value" in b._store["a"]
|
||||
assert "ts" in b._store["a"]
|
||||
assert b._store["a"]["value"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# JSONBackend TTL 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_json_backend_ttl_expired_has_returns_false() -> None:
|
||||
"""JSONBackend TTL 过期后 has 返回 False."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = os.path.join(tmp, "state.json")
|
||||
with open(path, "w", encoding="utf-8") as fh:
|
||||
json.dump([1, 2, 3], fh) # list 而非 dict
|
||||
b = JSONBackend(path)
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
assert b.has("a")
|
||||
time.sleep(0.15)
|
||||
assert not b.has("a")
|
||||
|
||||
|
||||
def test_json_backend_ttl_expired_get_raises_keyerror() -> None:
|
||||
"""JSONBackend TTL 过期后 get 抛 KeyError."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
time.sleep(0.15)
|
||||
with pytest.raises(KeyError):
|
||||
b.get("a")
|
||||
|
||||
|
||||
def test_json_backend_ttl_load_filters_expired() -> None:
|
||||
"""JSONBackend.load() 应过滤过期的条目."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
b.save("b", 2)
|
||||
time.sleep(0.15)
|
||||
# 两个都过期了
|
||||
assert dict(b.load()) == {}
|
||||
|
||||
|
||||
def test_json_backend_expired_no_ttl() -> None:
|
||||
"""无 TTL 时永不过期."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 1000
|
||||
assert b.has("a") is True # 无 TTL,永不过期
|
||||
|
||||
|
||||
def test_json_backend_expired_with_ttl() -> None:
|
||||
"""有 TTL 时过期键 has 返回 False."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
|
||||
assert b.has("a") is False
|
||||
|
||||
|
||||
def test_json_backend_expired_missing_ts() -> None:
|
||||
"""entry 缺少 ts 时视为过期."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b._store["a"] = {"value": 1} # 缺少 ts
|
||||
# ts 默认为 0,已经过了很久
|
||||
assert b.has("a") is False
|
||||
|
||||
|
||||
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""save 时 json.dumps 抛 ValueError 应转为 StorageError."""
|
||||
import json as _json
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
|
||||
original_dumps = _json.dumps
|
||||
|
||||
def flaky_dumps(*_args: Any, **_kwargs: Any) -> str:
|
||||
raise ValueError("simulated dumps failure")
|
||||
|
||||
monkeypatch.setattr(_json, "dumps", flaky_dumps)
|
||||
with pytest.raises(StorageError, match="not JSON-serialisable"):
|
||||
b.save("a", 1)
|
||||
monkeypatch.setattr(_json, "dumps", original_dumps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# resolve_backend
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Tests for streaming result passing (iterators between tasks)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pyflowx as px
|
||||
|
||||
|
||||
def test_generator_passed_as_iterator() -> None:
|
||||
"""上游返回生成器,下游应能惰性消费."""
|
||||
|
||||
@px.task
|
||||
def source() -> Iterator[int]:
|
||||
yield from range(5)
|
||||
|
||||
@px.task(depends_on=("source",))
|
||||
def consume(source: Iterator[int]) -> int:
|
||||
return sum(source)
|
||||
|
||||
graph = px.Graph.from_specs([source, consume])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["consume"] == 10
|
||||
|
||||
|
||||
def test_large_range_streaming() -> None:
|
||||
"""大范围迭代器流式传递,避免中间列表."""
|
||||
|
||||
@px.task
|
||||
def numbers() -> Iterator[int]:
|
||||
yield from range(1000)
|
||||
|
||||
@px.task(depends_on=("numbers",))
|
||||
def total(numbers: Iterator[int]) -> int:
|
||||
return sum(numbers)
|
||||
|
||||
graph = px.Graph.from_specs([numbers, total])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["total"] == sum(range(1000))
|
||||
|
||||
|
||||
def test_chain_multiple_streams() -> None:
|
||||
"""多个流式任务串联."""
|
||||
|
||||
@px.task
|
||||
def gen() -> Iterator[int]:
|
||||
yield from range(10)
|
||||
|
||||
@px.task(depends_on=("gen",))
|
||||
def doubled(gen: Iterator[int]) -> Iterator[int]:
|
||||
for x in gen:
|
||||
yield x * 2
|
||||
|
||||
@px.task(depends_on=("doubled",))
|
||||
def collect(doubled: Iterator[int]) -> list[int]:
|
||||
return list(doubled)
|
||||
|
||||
graph = px.Graph.from_specs([gen, doubled, collect])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["collect"] == [x * 2 for x in range(10)]
|
||||
@@ -0,0 +1,246 @@
|
||||
"""Tests for tasks/system.py."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.conditions import Constants
|
||||
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, setenv_group, which, write_file
|
||||
|
||||
|
||||
def test_clr_creates_task_spec() -> None:
|
||||
"""clr() 应创建 TaskSpec。"""
|
||||
spec = clr()
|
||||
assert spec.name == "clear_screen"
|
||||
assert spec.fn is not None
|
||||
|
||||
|
||||
def test_clr_executes_on_linux(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""clr() 在 Linux 上应执行 clear 命令。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
monkeypatch.setattr(Constants, "IS_LINUX", True)
|
||||
|
||||
# Mock subprocess.run
|
||||
ran = []
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *cmd, **__: ran.append(cmd),
|
||||
)
|
||||
|
||||
spec = clr()
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert ran == [(["clear"],)]
|
||||
|
||||
|
||||
def test_clr_executes_on_windows(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""clr() 在 Windows 上应执行 cls 命令。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
|
||||
# Mock subprocess.run
|
||||
ran = []
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *cmd, **__: ran.append(cmd),
|
||||
)
|
||||
|
||||
spec = clr()
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert ran == [(["cls"],)]
|
||||
|
||||
|
||||
def test_reset_icon_cache_non_windows(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""reset_icon_cache() 在非 Windows 上应返回空列表并打印提示。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
specs = reset_icon_cache()
|
||||
assert specs == []
|
||||
captured = capsys.readouterr()
|
||||
assert "仅在 Windows 上支持" in captured.out
|
||||
|
||||
|
||||
def test_reset_icon_cache_windows(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""reset_icon_cache() 在 Windows 上应返回任务列表。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
monkeypatch.setenv("LOCALAPPDATA", "C:\\Users\\test\\AppData\\Local")
|
||||
|
||||
specs = reset_icon_cache()
|
||||
assert len(specs) == 4
|
||||
assert specs[0].name == "kill_explorer"
|
||||
assert specs[1].name == "delete_icon_cache"
|
||||
assert specs[2].name == "delete_icon_cache_all"
|
||||
assert specs[3].name == "restart_explorer"
|
||||
|
||||
|
||||
def test_setenv_creates_task_spec() -> None:
|
||||
"""setenv() 应创建 TaskSpec。"""
|
||||
spec = setenv("TEST_VAR", "test_value")
|
||||
assert spec.name == "setenv_test_var"
|
||||
assert spec.verbose is True
|
||||
|
||||
|
||||
def test_setenv_sets_environment_variable(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv() 应设置环境变量。"""
|
||||
spec = setenv("PYFLOWX_TEST_VAR_1", "test_value")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_TEST_VAR_1"] == "test_value"
|
||||
# Clean up
|
||||
del os.environ["PYFLOWX_TEST_VAR_1"]
|
||||
|
||||
|
||||
def test_setenv_default_not_overwrite(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv(default=True) 不应覆盖已存在的环境变量。"""
|
||||
os.environ["PYFLOWX_TEST_VAR_EXISTS"] = "original"
|
||||
spec = setenv("PYFLOWX_TEST_VAR_EXISTS", "new_value", default=True)
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_TEST_VAR_EXISTS"] == "original"
|
||||
# Clean up
|
||||
del os.environ["PYFLOWX_TEST_VAR_EXISTS"]
|
||||
|
||||
|
||||
def test_setenv_default_sets_when_missing() -> None:
|
||||
"""setenv(default=True) 应在缺失时设置环境变量。"""
|
||||
# Ensure variable does not exist
|
||||
var_name = "PYFLOWX_TEST_VAR_MISSING"
|
||||
if var_name in os.environ:
|
||||
del os.environ[var_name]
|
||||
|
||||
spec = setenv(var_name, "default_value", default=True)
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ[var_name] == "default_value"
|
||||
|
||||
# Clean up after test
|
||||
del os.environ[var_name]
|
||||
|
||||
|
||||
def test_which_creates_task_spec() -> None:
|
||||
"""which() 应创建 TaskSpec。"""
|
||||
spec = which("python")
|
||||
assert spec.name == "which_python"
|
||||
|
||||
|
||||
def test_which_linux_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 在 Linux 上找到命令应打印路径。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
stdout = "/usr/bin/python\n"
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("python")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "python ->" in captured.out
|
||||
assert "/usr/bin/python" in captured.out
|
||||
|
||||
|
||||
def test_which_windows_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 在 Windows 上找到命令应打印路径。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
stdout = "C:\\Python\\python.exe\nC:\\Python\\Scripts\\python.exe\n"
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("python")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "python ->" in captured.out
|
||||
assert "C:\\Python\\python.exe" in captured.out
|
||||
|
||||
|
||||
def test_which_not_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 未找到命令应打印提示。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
class MockResult:
|
||||
returncode = 1
|
||||
stdout = ""
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("nonexistent_cmd")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "nonexistent_cmd -> 未找到" in captured.out
|
||||
|
||||
|
||||
def test_write_file_creates_task_spec() -> None:
|
||||
"""write_file() 应创建带 verbose 的 TaskSpec。"""
|
||||
spec = write_file("/tmp/unused", "x")
|
||||
assert spec.name == "write_file_/tmp/unused"
|
||||
assert spec.verbose is True
|
||||
|
||||
|
||||
def test_write_file_writes_content(tmp_path: Path) -> None:
|
||||
"""write_file() 应将内容写入指定文件."""
|
||||
f = tmp_path / "out.txt"
|
||||
spec = write_file(str(f), "hello world")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert f.read_text(encoding="utf-8") == "hello world"
|
||||
|
||||
|
||||
def test_write_file_with_encoding(tmp_path: Path) -> None:
|
||||
"""write_file() 应支持指定编码."""
|
||||
f = tmp_path / "out.txt"
|
||||
spec = write_file(str(f), "中文", encoding="utf-8")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert f.read_text(encoding="utf-8") == "中文"
|
||||
|
||||
|
||||
def test_write_file_failure_propagates(tmp_path: Path) -> None:
|
||||
"""write_file() 写入失败应抛出异常(不吞异常)."""
|
||||
# 父目录不存在时写入应抛 FileNotFoundError
|
||||
missing = tmp_path / "no_such_dir" / "out.txt"
|
||||
spec = write_file(str(missing), "x")
|
||||
assert spec.fn is not None
|
||||
with pytest.raises(FileNotFoundError):
|
||||
spec.fn()
|
||||
|
||||
|
||||
def test_setenv_group_creates_specs() -> None:
|
||||
"""setenv_group() 应为每个环境变量创建 TaskSpec."""
|
||||
envs = {"VAR_A": "1", "VAR_B": "2"}
|
||||
specs = setenv_group(envs)
|
||||
assert len(specs) == 2
|
||||
assert specs[0].name == "setenv_var_a"
|
||||
assert specs[1].name == "setenv_var_b"
|
||||
|
||||
|
||||
def test_setenv_group_default_mode(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv_group(default=True) 不应覆盖已存在的环境变量."""
|
||||
monkeypatch.setenv("PYFLOWX_GROUP_EXISTS", "original")
|
||||
specs = setenv_group({"PYFLOWX_GROUP_EXISTS": "new"}, default=True)
|
||||
for spec in specs:
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_GROUP_EXISTS"] == "original"
|
||||
+345
-4
@@ -2,11 +2,21 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.task import TaskResult, TaskSpec, TaskStatus
|
||||
from pyflowx.task import (
|
||||
RetryPolicy,
|
||||
TaskResult,
|
||||
TaskSpec,
|
||||
TaskStatus,
|
||||
_env_and_cwd,
|
||||
cmd,
|
||||
task_template,
|
||||
)
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
@@ -18,9 +28,9 @@ def test_spec_empty_name_rejected() -> None:
|
||||
TaskSpec("", _fn)
|
||||
|
||||
|
||||
def test_spec_negative_retries_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="retries"):
|
||||
TaskSpec("a", _fn, retries=-1)
|
||||
def test_spec_negative_max_attempts_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="max_attempts"):
|
||||
TaskSpec("a", _fn, retry=RetryPolicy(max_attempts=0))
|
||||
|
||||
|
||||
def test_spec_zero_timeout_rejected() -> None:
|
||||
@@ -28,11 +38,318 @@ def test_spec_zero_timeout_rejected() -> None:
|
||||
TaskSpec("a", _fn, timeout=0)
|
||||
|
||||
|
||||
def test_spec_negative_timeout_rejected() -> None:
|
||||
"""负数timeout应被拒绝。"""
|
||||
with pytest.raises(ValueError, match="timeout"):
|
||||
TaskSpec("a", _fn, timeout=-1.0)
|
||||
|
||||
|
||||
def test_spec_self_dependency_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="depend on itself"):
|
||||
TaskSpec("a", _fn, depends_on=("a",))
|
||||
|
||||
|
||||
def test_spec_self_soft_dependency_rejected() -> None:
|
||||
"""self dependency via soft_depends_on 也应被拒绝."""
|
||||
with pytest.raises(ValueError, match="depend on itself"):
|
||||
TaskSpec("a", _fn, soft_depends_on=("a",))
|
||||
|
||||
|
||||
def test_spec_overlap_depends_rejected() -> None:
|
||||
"""depends_on 和 soft_depends_on 重叠应被拒绝."""
|
||||
with pytest.raises(ValueError, match="不能重叠"):
|
||||
TaskSpec("a", _fn, depends_on=("b",), soft_depends_on=("b",))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# RetryPolicy 参数验证
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_retry_policy_negative_delay_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="delay must be >= 0"):
|
||||
RetryPolicy(delay=-1)
|
||||
|
||||
|
||||
def test_retry_policy_negative_backoff_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="backoff must be >= 0"):
|
||||
RetryPolicy(backoff=-1)
|
||||
|
||||
|
||||
def test_retry_policy_negative_jitter_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="jitter must be >= 0"):
|
||||
RetryPolicy(jitter=-1)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# cmd() 工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_cmd_factory_default_name_from_two_elements() -> None:
|
||||
"""cmd() 默认 name = '_'.join(command[:2])."""
|
||||
spec = cmd(["uv", "build"])
|
||||
assert spec.name == "uv_build"
|
||||
assert spec.cmd == ["uv", "build"]
|
||||
|
||||
|
||||
def test_cmd_factory_default_name_single_element() -> None:
|
||||
"""cmd() 单元素命令 name = command[0]."""
|
||||
spec = cmd(["ls"])
|
||||
assert spec.name == "ls"
|
||||
|
||||
|
||||
def test_cmd_factory_explicit_name() -> None:
|
||||
"""cmd() 显式 name 覆盖默认推导."""
|
||||
spec = cmd(["ruff", "check", "--fix"], name="lint")
|
||||
assert spec.name == "lint"
|
||||
|
||||
|
||||
def test_cmd_factory_passes_depends_on() -> None:
|
||||
"""cmd() depends_on 透传给 TaskSpec."""
|
||||
spec = cmd(["echo", "b"], name="b", depends_on=("a",))
|
||||
assert spec.depends_on == ("a",)
|
||||
|
||||
|
||||
def test_cmd_factory_passes_extra_kwargs() -> None:
|
||||
"""cmd() 其余 kwargs 透传给 TaskSpec."""
|
||||
spec = cmd(["echo", "x"], name="x", timeout=10.0, tags=("t1",))
|
||||
assert spec.timeout == 10.0
|
||||
assert spec.tags == ("t1",)
|
||||
|
||||
|
||||
def test_retry_policy_retries_property() -> None:
|
||||
policy = RetryPolicy(max_attempts=3)
|
||||
assert policy.retries == 2
|
||||
|
||||
|
||||
def test_retry_policy_should_retry_matching() -> None:
|
||||
policy = RetryPolicy(max_attempts=3, retry_on=(ValueError,))
|
||||
assert policy.should_retry(ValueError("x")) is True
|
||||
assert policy.should_retry(RuntimeError("x")) is False
|
||||
|
||||
|
||||
def test_retry_policy_should_retry_empty_tuple() -> None:
|
||||
"""空元组等价于不重试."""
|
||||
policy = RetryPolicy(max_attempts=3, retry_on=())
|
||||
assert policy.should_retry(ValueError("x")) is False
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_zero_attempt() -> None:
|
||||
"""attempt < 1 时返回 0."""
|
||||
policy = RetryPolicy(delay=1.0, backoff=2.0)
|
||||
assert policy.wait_seconds(0) == 0.0
|
||||
assert policy.wait_seconds(-1) == 0.0
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_with_backoff() -> None:
|
||||
"""有 backoff 时等待时间应递增."""
|
||||
policy = RetryPolicy(delay=1.0, backoff=2.0)
|
||||
# attempt=1: delay * backoff^0 = 1
|
||||
# attempt=2: delay * backoff^1 = 2
|
||||
assert policy.wait_seconds(1) == 1.0
|
||||
assert policy.wait_seconds(2) == 2.0
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_with_jitter() -> None:
|
||||
"""有 jitter 时等待时间应增加随机量."""
|
||||
policy = RetryPolicy(delay=1.0, jitter=0.5)
|
||||
# 多次调用验证结果在合理范围内
|
||||
for _ in range(5):
|
||||
wait = policy.wait_seconds(1)
|
||||
assert 1.0 <= wait <= 1.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# should_execute 条件异常处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_should_execute_condition_exception_returns_false() -> None:
|
||||
"""条件执行抛异常时应返回 False 并记录原因."""
|
||||
|
||||
def bad_condition(_ctx):
|
||||
raise RuntimeError("condition error")
|
||||
|
||||
bad_condition.__name__ = ""
|
||||
spec = TaskSpec("a", _fn, conditions=(bad_condition,))
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "匿名条件(执行错误)" in reason
|
||||
|
||||
|
||||
def test_should_execute_condition_lambda_name() -> None:
|
||||
"""lambda 条件有 __name__ 为 '<lambda>'."""
|
||||
spec = TaskSpec("a", _fn, conditions=(lambda _ctx: False,))
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "<lambda>" in reason
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_cmd_not_found() -> None:
|
||||
"""skip_if_missing 且命令不存在时应跳过."""
|
||||
spec = TaskSpec("a", cmd=["nonexistent_cmd_xyz"], skip_if_missing=True)
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "命令不存在" in reason
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_cmd_found() -> None:
|
||||
"""skip_if_missing 但命令存在时应执行."""
|
||||
# 使用 Python 作为已安装的命令
|
||||
spec = TaskSpec("a", cmd=["echo"], skip_if_missing=True) # echo 应存在
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_non_list_cmd() -> None:
|
||||
"""skip_if_missing 对非 list 命令不影响."""
|
||||
spec = TaskSpec("a", cmd="echo hello", skip_if_missing=True)
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_empty_list() -> None:
|
||||
"""skip_if_missing 对空列表命令返回 True."""
|
||||
spec = TaskSpec("a", cmd=[], skip_if_missing=True)
|
||||
# 空 list 不检查
|
||||
_should_run, _reason = spec.should_execute({})
|
||||
# 因为 cmd=[] 且 fn=None,这会在 __post_init__ 中抛异常
|
||||
# 所以这个测试无效,我们用另一个方式测试 _is_cmd_available
|
||||
|
||||
|
||||
def test_is_cmd_available_empty_list_returns_true() -> None:
|
||||
"""_is_cmd_available 对空列表返回 True."""
|
||||
spec = TaskSpec("a", cmd=[], fn=_fn) # 提供 fn 避免 __post_init__ 异常
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
def test_is_cmd_available_string_returns_true() -> None:
|
||||
"""_is_cmd_available 对字符串命令返回 True."""
|
||||
spec = TaskSpec("a", cmd="echo hello")
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
def test_is_cmd_available_callable_returns_true() -> None:
|
||||
"""_is_cmd_available 对可调用命令返回 True."""
|
||||
spec = TaskSpec("a", cmd=_fn)
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# storage_key 异常处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_storage_key_cache_key_exception_returns_name() -> None:
|
||||
"""cache_key 抛预期异常(TypeError/ValueError/KeyError/AttributeError)时应返回任务名."""
|
||||
|
||||
def bad_cache_key(_ctx):
|
||||
raise ValueError("cache key error")
|
||||
|
||||
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
|
||||
key = spec.storage_key({})
|
||||
assert key == "a"
|
||||
|
||||
|
||||
def test_storage_key_cache_key_success() -> None:
|
||||
"""cache_key 成功时应返回组合键."""
|
||||
spec = TaskSpec("a", _fn, cache_key=lambda ctx: ctx.get("x", "default"))
|
||||
key = spec.storage_key({"x": "value"})
|
||||
assert key == "a:value"
|
||||
|
||||
|
||||
def test_storage_key_no_cache_key() -> None:
|
||||
"""无 cache_key 时返回任务名."""
|
||||
spec = TaskSpec("a", _fn)
|
||||
key = spec.storage_key({})
|
||||
assert key == "a"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _env_and_cwd 上下文管理器
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_env_and_cwd_sets_env() -> None:
|
||||
"""应临时设置环境变量。"""
|
||||
var_name = "PYFLOWX_TEST_ENV_VAR_1"
|
||||
with _env_and_cwd({var_name: "test_value"}, None):
|
||||
assert os.environ[var_name] == "test_value"
|
||||
# 退出后应恢复
|
||||
assert var_name not in os.environ
|
||||
|
||||
|
||||
def test_env_and_cwd_restores_existing_env() -> None:
|
||||
"""应恢复已有的环境变量."""
|
||||
os.environ["EXISTING_VAR"] = "original"
|
||||
try:
|
||||
with _env_and_cwd({"EXISTING_VAR": "new_value"}, None):
|
||||
assert os.environ["EXISTING_VAR"] == "new_value"
|
||||
# 退出后应恢复原值
|
||||
assert os.environ["EXISTING_VAR"] == "original"
|
||||
finally:
|
||||
os.environ.pop("EXISTING_VAR", None)
|
||||
|
||||
|
||||
def test_env_and_cwd_sets_cwd(tmp_path: Path) -> None:
|
||||
"""应临时切换工作目录."""
|
||||
original = Path.cwd()
|
||||
with _env_and_cwd(None, tmp_path):
|
||||
assert Path.cwd() == tmp_path
|
||||
# 退出后应恢复
|
||||
assert Path.cwd() == original
|
||||
|
||||
|
||||
def test_env_and_cwd_no_changes() -> None:
|
||||
"""无 env 和 cwd 时不应有任何变化."""
|
||||
original_env = dict(os.environ)
|
||||
original_cwd = Path.cwd()
|
||||
with _env_and_cwd(None, None):
|
||||
pass
|
||||
assert dict(os.environ) == original_env
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
|
||||
def test_spec_env_context() -> None:
|
||||
"""TaskSpec.env_context 应正确工作."""
|
||||
var_name = "PYFLOWX_TEST_ENV_VAR_2"
|
||||
spec = TaskSpec("a", _fn, env={var_name: "value"})
|
||||
with spec.env_context():
|
||||
assert os.environ[var_name] == "value"
|
||||
assert var_name not in os.environ
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# task_template 工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_task_template_creates_specs() -> None:
|
||||
"""task_template 应创建 TaskSpec 工厂."""
|
||||
template = task_template(fn=_fn, retry=RetryPolicy(max_attempts=3))
|
||||
spec = template("task1")
|
||||
assert spec.name == "task1"
|
||||
assert spec.retry.max_attempts == 3
|
||||
|
||||
|
||||
def test_task_template_with_cmd() -> None:
|
||||
"""task_template 可以使用 cmd."""
|
||||
template = task_template(cmd=["echo", "hello"])
|
||||
spec = template("task1")
|
||||
assert spec.name == "task1"
|
||||
assert spec.cmd == ["echo", "hello"]
|
||||
|
||||
|
||||
def test_task_template_overrides() -> None:
|
||||
"""task_template 工厂可以覆盖默认值."""
|
||||
template = task_template(fn=_fn, timeout=10.0)
|
||||
spec = template("task1", timeout=5.0)
|
||||
assert spec.timeout == 5.0
|
||||
|
||||
|
||||
def test_task_template_factory_name() -> None:
|
||||
"""工厂函数名应为 task_template_factory."""
|
||||
template = task_template(fn=_fn)
|
||||
assert template.__name__ == "task_template_factory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskResult 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_task_result_duration_none_when_not_started() -> None:
|
||||
spec: TaskSpec[None] = TaskSpec("a", _fn)
|
||||
result: TaskResult[None] = TaskResult(spec=spec)
|
||||
@@ -61,3 +378,27 @@ def test_task_result_default_status() -> None:
|
||||
assert result.value is None
|
||||
assert result.error is None
|
||||
assert result.attempts == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# run_command callable 命令测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
||||
"""callable 命令 verbose 模式应打印信息."""
|
||||
from pyflowx.command import run_command
|
||||
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
result = run_command(spec)
|
||||
assert result == "result"
|
||||
captured = capsys.readouterr()
|
||||
assert "执行可调用命令" in captured.out
|
||||
assert "工作目录" in captured.out
|
||||
|
||||
|
||||
def test_run_command_callable_exception() -> None:
|
||||
"""callable 命令抛异常应转为 RuntimeError."""
|
||||
from pyflowx.command import run_command
|
||||
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
||||
run_command(spec)
|
||||
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Tests for the @task decorator API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import RetryPolicy, TaskHooks, TaskSpec
|
||||
|
||||
|
||||
def test_task_decorator_plain() -> None:
|
||||
"""@task 无参数装饰:name 取函数名,返回 TaskSpec."""
|
||||
|
||||
@px.task
|
||||
def extract() -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
assert isinstance(extract, TaskSpec)
|
||||
assert extract.name == "extract"
|
||||
assert extract.fn is not None
|
||||
assert extract.depends_on == ()
|
||||
|
||||
|
||||
def test_task_decorator_with_params() -> None:
|
||||
"""@task(...) 带参数装饰:传递依赖与重试."""
|
||||
|
||||
@px.task(depends_on=("extract",), retry=RetryPolicy(max_attempts=3))
|
||||
def double(extract: list[int]) -> list[int]:
|
||||
return [x * 2 for x in extract]
|
||||
|
||||
assert isinstance(double, TaskSpec)
|
||||
assert double.name == "double"
|
||||
assert double.depends_on == ("extract",)
|
||||
assert double.retry.max_attempts == 3
|
||||
|
||||
|
||||
def test_task_decorator_explicit_name() -> None:
|
||||
"""@task(name=...) 应使用显式名称而非函数名."""
|
||||
|
||||
@px.task(name="custom_name")
|
||||
def my_func() -> None:
|
||||
return None
|
||||
|
||||
assert my_func.name == "custom_name"
|
||||
|
||||
|
||||
def test_task_decorator_cmd_form() -> None:
|
||||
"""@task(cmd=...) 应支持命令形式."""
|
||||
|
||||
spec = px.task(cmd=["ls", "-la"], name="list_files")
|
||||
assert isinstance(spec, TaskSpec)
|
||||
assert spec.name == "list_files"
|
||||
assert spec.cmd == ["ls", "-la"]
|
||||
|
||||
|
||||
def test_task_decorator_full_options() -> None:
|
||||
"""@task 应支持全部 TaskSpec 字段."""
|
||||
|
||||
@px.task(
|
||||
depends_on=("a",),
|
||||
soft_depends_on=("b",),
|
||||
defaults={"b": 0},
|
||||
args=(1,),
|
||||
kwargs={"x": 2},
|
||||
retry=RetryPolicy(max_attempts=5),
|
||||
timeout=10.0,
|
||||
tags=("t1",),
|
||||
conditions=(px.BuiltinConditions.IS_WINDOWS,), # type: ignore[arg-type]
|
||||
cwd="/tmp",
|
||||
env={"K": "v"},
|
||||
verbose=True,
|
||||
skip_if_missing=True,
|
||||
allow_upstream_skip=True,
|
||||
strategy="thread",
|
||||
priority=3,
|
||||
concurrency_key="db",
|
||||
continue_on_error=True,
|
||||
)
|
||||
def f(a: int) -> int:
|
||||
return a
|
||||
|
||||
assert f.depends_on == ("a",)
|
||||
assert f.soft_depends_on == ("b",)
|
||||
assert f.defaults == {"b": 0}
|
||||
assert f.args == (1,)
|
||||
assert f.kwargs == {"x": 2}
|
||||
assert f.retry.max_attempts == 5
|
||||
assert f.timeout == 10.0
|
||||
assert f.tags == ("t1",)
|
||||
assert len(f.conditions) == 1
|
||||
assert isinstance(f.cwd, Path)
|
||||
assert f.cwd == Path("/tmp")
|
||||
assert f.env == {"K": "v"}
|
||||
assert f.verbose is True
|
||||
assert f.skip_if_missing is True
|
||||
assert f.allow_upstream_skip is True
|
||||
assert f.strategy == "thread"
|
||||
assert f.priority == 3
|
||||
assert f.concurrency_key == "db"
|
||||
assert f.continue_on_error is True
|
||||
|
||||
|
||||
def test_task_decorator_runs_in_graph() -> None:
|
||||
"""装饰器生成的 TaskSpec 应能直接构建图并运行."""
|
||||
|
||||
@px.task
|
||||
def extract() -> list[int]:
|
||||
return [1, 2, 3]
|
||||
|
||||
@px.task(depends_on=("extract",))
|
||||
def double(extract: list[int]) -> list[int]:
|
||||
return [x * 2 for x in extract]
|
||||
|
||||
graph = px.Graph.from_specs([extract, double])
|
||||
report = px.run(graph)
|
||||
assert report.success
|
||||
assert report["double"] == [2, 4, 6]
|
||||
|
||||
|
||||
def test_task_decorator_hooks_passthrough() -> None:
|
||||
"""@task(hooks=...) 应传递 TaskHooks 实例."""
|
||||
|
||||
hooks = TaskHooks(pre_run=lambda _spec: None)
|
||||
spec = px.task(fn=lambda: None, hooks=hooks, name="h")
|
||||
assert spec.hooks is hooks
|
||||
|
||||
|
||||
def test_task_decorator_cache_key_passthrough() -> None:
|
||||
"""@task(cache_key=...) 应传递缓存键函数."""
|
||||
|
||||
def ck(ctx: Mapping[str, Any]) -> str:
|
||||
return "k"
|
||||
|
||||
spec = px.task(fn=lambda: None, cache_key=ck, name="c")
|
||||
assert spec.cache_key is ck
|
||||
@@ -0,0 +1,315 @@
|
||||
"""Tests for task module edge cases."""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.task import TaskSpec
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_list():
|
||||
"""Test TaskSpec._wrap_cmd with command list."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
assert wrapped_fn is not None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_string():
|
||||
"""Test TaskSpec._wrap_cmd with command string."""
|
||||
if sys.platform == "win32":
|
||||
cmd_str = "cmd /c echo hello"
|
||||
else:
|
||||
cmd_str = "echo hello"
|
||||
spec = TaskSpec("test", cmd=cmd_str)
|
||||
wrapped_fn = spec.effective_fn
|
||||
assert wrapped_fn is not None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_timeout():
|
||||
"""Test TaskSpec._wrap_cmd with timeout."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], timeout=0.1)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Should not raise timeout error for quick command
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_with_cwd():
|
||||
"""Test TaskSpec._wrap_cmd with working directory."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], cwd=Path(tmpdir))
|
||||
wrapped_fn = spec.effective_fn
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_verbose():
|
||||
"""Test TaskSpec._wrap_cmd with verbose=True."""
|
||||
spec = TaskSpec("test", cmd=[*ECHO_CMD, "hello"], verbose=True)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Should print verbose output
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_error():
|
||||
"""Test TaskSpec._wrap_cmd handles command error."""
|
||||
import sys
|
||||
|
||||
spec = TaskSpec("test", cmd=[sys.executable, "-c", "import sys; sys.exit(1)"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with pytest.raises(RuntimeError, match="命令执行失败"):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_file_not_found():
|
||||
"""Test TaskSpec._wrap_cmd handles file not found."""
|
||||
spec = TaskSpec("test", cmd=["nonexistent_command"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with pytest.raises(RuntimeError, match="命令未找到"):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_wrap_cmd_shell_file_not_found():
|
||||
"""Test TaskSpec._wrap_cmd handles shell command file not found."""
|
||||
spec = TaskSpec("test", cmd="nonexistent_shell_command")
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
# Shell commands don't raise FileNotFoundError
|
||||
# They just return non-zero exit code
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_no_fn_no_cmd():
|
||||
"""Test TaskSpec raises error when no fn or cmd."""
|
||||
with pytest.raises(ValueError, match="必须提供 fn 或 cmd 参数"):
|
||||
_ = TaskSpec("test")
|
||||
|
||||
|
||||
def test_taskspec_conditions_check():
|
||||
"""Test TaskSpec.should_execute with conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: True,),
|
||||
)
|
||||
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_taskspec_conditions_false():
|
||||
"""Test TaskSpec.should_execute with false conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False,),
|
||||
)
|
||||
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
|
||||
def test_taskspec_conditions_multiple():
|
||||
"""Test TaskSpec.should_execute with multiple conditions."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: True, lambda _ctx: True, lambda _ctx: True),
|
||||
)
|
||||
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_taskspec_conditions_multiple_one_false():
|
||||
"""Test TaskSpec.should_execute with one false condition."""
|
||||
spec = px.TaskSpec(
|
||||
"test",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: True, lambda _ctx: False, lambda _ctx: True),
|
||||
)
|
||||
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
|
||||
def test_taskspec_list_cmd_timeout_mocked():
|
||||
"""Test TaskSpec._wrap_cmd handles list command timeout (mocked)."""
|
||||
spec = TaskSpec("test", cmd=["sleep", "10"], timeout=0.1)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with patch(
|
||||
"subprocess.run", side_effect=subprocess.TimeoutExpired(cmd=["sleep", "10"], timeout=0.1)
|
||||
), pytest.raises(RuntimeError, match="命令执行超时"):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_shell_cmd_timeout_mocked():
|
||||
"""Test TaskSpec._wrap_cmd handles shell command timeout (mocked)."""
|
||||
spec = TaskSpec("test", cmd="sleep 10", timeout=0.1)
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with patch("subprocess.run", side_effect=subprocess.TimeoutExpired(cmd="sleep 10", timeout=0.1)), pytest.raises(
|
||||
RuntimeError, match="Shell 命令执行超时"
|
||||
):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_shell_cmd_file_not_found_mocked():
|
||||
"""Test TaskSpec._wrap_cmd handles shell command FileNotFoundError (mocked)."""
|
||||
spec = TaskSpec("test", cmd="nonexistent_shell_command")
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with patch("subprocess.run", side_effect=FileNotFoundError("not found")), pytest.raises(
|
||||
RuntimeError, match="Shell 命令未找到"
|
||||
):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_shell_cmd_with_cwd_verbose(capsys: pytest.CaptureFixture[str]):
|
||||
"""Test TaskSpec._wrap_cmd with shell command, cwd and verbose=True."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
if sys.platform == "win32":
|
||||
shell_cmd = "cmd /c echo hello"
|
||||
else:
|
||||
shell_cmd = "echo hello"
|
||||
spec = TaskSpec("test", cmd=shell_cmd, cwd=Path(tmpdir), verbose=True)
|
||||
wrapped_fn = spec.effective_fn
|
||||
result = wrapped_fn()
|
||||
assert result is None
|
||||
captured = capsys.readouterr()
|
||||
assert "执行 Shell" in captured.out
|
||||
assert "工作目录" in captured.out
|
||||
|
||||
|
||||
def test_taskspec_list_cmd_os_error_mocked():
|
||||
"""Test TaskSpec._wrap_cmd handles list command OSError (mocked)."""
|
||||
spec = TaskSpec("test", cmd=["ls"])
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with patch("subprocess.run", side_effect=OSError("os error")), pytest.raises(RuntimeError, match="命令执行异常"):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
def test_taskspec_shell_cmd_os_error_mocked():
|
||||
"""Test TaskSpec._wrap_cmd handles shell command OSError (mocked)."""
|
||||
spec = TaskSpec("test", cmd="ls")
|
||||
wrapped_fn = spec.effective_fn
|
||||
|
||||
with patch("subprocess.run", side_effect=OSError("os error")), pytest.raises(
|
||||
RuntimeError, match="Shell 命令执行异常"
|
||||
):
|
||||
_ = wrapped_fn()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# skip_if_missing
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_skip_if_missing_with_available_command():
|
||||
"""skip_if_missing=True 时,命令存在应返回 True."""
|
||||
import sys
|
||||
|
||||
spec = TaskSpec("test", cmd=[sys.executable, "--version"], skip_if_missing=True)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_skip_if_missing_with_missing_command():
|
||||
"""skip_if_missing=True 时,命令不存在应返回 False."""
|
||||
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True)
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
|
||||
def test_skip_if_missing_false_with_missing_command():
|
||||
"""skip_if_missing=False 时,命令不存在也应返回 True(不检查)."""
|
||||
spec = TaskSpec("test", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=False)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_skip_if_missing_with_shell_cmd_not_checked():
|
||||
"""skip_if_missing=True 时,shell 命令(str)不检查,应返回 True."""
|
||||
spec = TaskSpec("test", cmd="definitely_not_installed_app_xyz", skip_if_missing=True)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_skip_if_missing_with_callable_cmd_not_checked():
|
||||
"""skip_if_missing=True 时,Callable 命令不检查,应返回 True."""
|
||||
|
||||
def custom_cmd() -> int:
|
||||
return 0
|
||||
|
||||
spec = TaskSpec("test", cmd=custom_cmd, skip_if_missing=True)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
def test_skip_if_missing_with_fn_not_checked():
|
||||
"""skip_if_missing=True 时,fn 任务不检查命令,应返回 True."""
|
||||
|
||||
def my_fn() -> int:
|
||||
return 0
|
||||
|
||||
spec = TaskSpec("test", fn=my_fn, skip_if_missing=True)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_skip_if_missing_with_empty_cmd_list():
|
||||
"""skip_if_missing=True 时,空命令列表应返回 True(不检查)."""
|
||||
spec = TaskSpec("test", cmd=[""], skip_if_missing=True)
|
||||
# 空字符串命令,shutil.which 返回 None
|
||||
# 但 cmd[0] 是空字符串,shutil.which("") 返回 None
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
|
||||
def test_skip_if_missing_combined_with_conditions():
|
||||
"""skip_if_missing=True 与 conditions 组合使用."""
|
||||
import sys
|
||||
|
||||
# conditions 返回 False,应跳过
|
||||
spec = TaskSpec(
|
||||
"test",
|
||||
cmd=[sys.executable, "--version"],
|
||||
skip_if_missing=True,
|
||||
conditions=(lambda _ctx: False,),
|
||||
)
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
# conditions 返回 True,命令存在,应执行
|
||||
spec = TaskSpec(
|
||||
"test",
|
||||
cmd=[sys.executable, "--version"],
|
||||
skip_if_missing=True,
|
||||
conditions=(lambda _ctx: True,),
|
||||
)
|
||||
assert spec.should_execute({})[0] is True
|
||||
|
||||
# conditions 返回 True,命令不存在,应跳过
|
||||
spec = TaskSpec(
|
||||
"test",
|
||||
cmd=["definitely_not_installed_app_xyz"],
|
||||
skip_if_missing=True,
|
||||
conditions=(lambda _ctx: True,),
|
||||
)
|
||||
assert spec.should_execute({})[0] is False
|
||||
|
||||
|
||||
def test_skip_if_missing_skips_task_in_run():
|
||||
"""skip_if_missing=True 时,命令不存在的任务在 run 中应被跳过."""
|
||||
spec = TaskSpec("missing_cmd", cmd=["definitely_not_installed_app_xyz"], skip_if_missing=True)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success is True
|
||||
result = report.result_of("missing_cmd")
|
||||
assert result.status == px.TaskStatus.SKIPPED
|
||||
@@ -0,0 +1,480 @@
|
||||
"""测试 TaskSpec 的命令和条件执行功能."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.conditions import (
|
||||
BuiltinConditions,
|
||||
Constants,
|
||||
)
|
||||
|
||||
# 跨平台的 echo 命令
|
||||
if sys.platform == "win32":
|
||||
ECHO_CMD = ["cmd", "/c", "echo"]
|
||||
else:
|
||||
ECHO_CMD = ["echo"]
|
||||
|
||||
|
||||
def test_taskspec_with_cmd_list():
|
||||
"""测试使用命令列表的 TaskSpec."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("echo_test", cmd=[*ECHO_CMD, "hello"]),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "echo_test" in report.results
|
||||
assert report.results["echo_test"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_with_cmd_string():
|
||||
"""测试使用 shell 命令字符串的 TaskSpec."""
|
||||
if sys.platform == "win32":
|
||||
shell_cmd = 'cmd /c "echo hello from shell"'
|
||||
else:
|
||||
shell_cmd = "echo 'hello from shell'"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("shell_test", cmd=shell_cmd),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "shell_test" in report.results
|
||||
assert report.results["shell_test"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_with_conditions_skip():
|
||||
"""测试条件不满足时任务被跳过."""
|
||||
|
||||
# 创建一个永远不会满足的条件
|
||||
def never_true(_ctx):
|
||||
return False
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"should_skip",
|
||||
cmd=[*ECHO_CMD, "this should not run"],
|
||||
conditions=(never_true,),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "should_skip" in report.results
|
||||
assert report.results["should_skip"].status == px.TaskStatus.SKIPPED
|
||||
|
||||
|
||||
def test_taskspec_with_conditions_execute():
|
||||
"""测试条件满足时任务正常执行."""
|
||||
|
||||
# 创建一个总是满足的条件
|
||||
def always_true(_ctx):
|
||||
return True
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"should_run",
|
||||
cmd=[*ECHO_CMD, "this should run"],
|
||||
conditions=(always_true,),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "should_run" in report.results
|
||||
assert report.results["should_run"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_platform_conditions():
|
||||
"""测试平台条件."""
|
||||
if sys.platform == "win32":
|
||||
win_cmd = ["cmd", "/c", "echo", "Windows"]
|
||||
posix_cmd = ["echo", "POSIX"]
|
||||
else:
|
||||
win_cmd = ["echo", "Windows"]
|
||||
posix_cmd = ["echo", "POSIX"]
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"win_task",
|
||||
cmd=win_cmd,
|
||||
conditions=(lambda _ctx: Constants.IS_WINDOWS,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"linux_task",
|
||||
cmd=posix_cmd,
|
||||
conditions=(lambda _ctx: Constants.IS_LINUX,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"macos_task",
|
||||
cmd=posix_cmd,
|
||||
conditions=(lambda _ctx: Constants.IS_MACOS,),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
|
||||
# 检查只有当前平台的任务执行了
|
||||
if sys.platform == "win32":
|
||||
assert report.results["win_task"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["linux_task"].status == px.TaskStatus.SKIPPED
|
||||
assert report.results["macos_task"].status == px.TaskStatus.SKIPPED
|
||||
elif sys.platform == "linux":
|
||||
assert report.results["win_task"].status == px.TaskStatus.SKIPPED
|
||||
assert report.results["linux_task"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["macos_task"].status == px.TaskStatus.SKIPPED
|
||||
elif sys.platform == "darwin":
|
||||
assert report.results["win_task"].status == px.TaskStatus.SKIPPED
|
||||
assert report.results["linux_task"].status == px.TaskStatus.SKIPPED
|
||||
assert report.results["macos_task"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_app_installed_conditions():
|
||||
"""测试应用安装条件."""
|
||||
# 使用 sys.executable 保证可移植
|
||||
python_cmd = [sys.executable, "--version"]
|
||||
py_name = "python" if sys.platform == "win32" else "python3"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"python_check",
|
||||
cmd=python_cmd,
|
||||
conditions=(BuiltinConditions.HAS_INSTALLED(py_name),),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "python_check" in report.results
|
||||
# python 应该总是安装的
|
||||
assert report.results["python_check"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_combined_conditions():
|
||||
"""测试组合条件."""
|
||||
# AND 条件
|
||||
and_condition = BuiltinConditions.AND(
|
||||
lambda _ctx: True,
|
||||
lambda _ctx: True,
|
||||
)
|
||||
|
||||
# OR 条件
|
||||
or_condition = BuiltinConditions.OR(
|
||||
lambda _ctx: True,
|
||||
lambda _ctx: False,
|
||||
)
|
||||
|
||||
# NOT 条件
|
||||
not_condition = BuiltinConditions.NOT(lambda _ctx: False)
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"and_test",
|
||||
cmd=[*ECHO_CMD, "AND"],
|
||||
conditions=(and_condition,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"or_test",
|
||||
cmd=[*ECHO_CMD, "OR"],
|
||||
conditions=(or_condition,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"not_test",
|
||||
cmd=[*ECHO_CMD, "NOT"],
|
||||
conditions=(not_condition,),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["and_test"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["or_test"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["not_test"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_with_cwd():
|
||||
"""测试工作目录设置."""
|
||||
if sys.platform == "win32":
|
||||
ls_cmd = ["cmd", "/c", "dir"]
|
||||
else:
|
||||
ls_cmd = ["ls", "-la"]
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"list_current",
|
||||
cmd=ls_cmd,
|
||||
cwd=Path.cwd(),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "list_current" in report.results
|
||||
assert report.results["list_current"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_taskspec_with_timeout():
|
||||
"""测试超时设置."""
|
||||
graph = px.Graph.from_specs([
|
||||
# 短时间任务应该成功
|
||||
px.TaskSpec(
|
||||
"short_task",
|
||||
cmd=[sys.executable, "-c", "import time; time.sleep(0.1)"],
|
||||
timeout=1.0,
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert "short_task" in report.results
|
||||
assert report.results["short_task"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_dependency_with_conditions():
|
||||
"""测试依赖和条件的组合."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"first",
|
||||
cmd=[*ECHO_CMD, "first"],
|
||||
conditions=(lambda _ctx: True,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"second",
|
||||
cmd=[*ECHO_CMD, "second"],
|
||||
depends_on=("first",),
|
||||
conditions=(lambda _ctx: True,),
|
||||
),
|
||||
px.TaskSpec(
|
||||
"third",
|
||||
cmd=[*ECHO_CMD, "third"],
|
||||
depends_on=("second",),
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["first"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["second"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["third"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_mixed_fn_and_cmd():
|
||||
"""测试混合使用 fn 和 cmd."""
|
||||
|
||||
def my_function():
|
||||
return "result from function"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fn_task", fn=my_function),
|
||||
px.TaskSpec("cmd_task", cmd=[*ECHO_CMD, "from command"]),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["fn_task"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["fn_task"].value == "result from function"
|
||||
assert report.results["cmd_task"].status == px.TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_taskspec_cmd_overrides_fn():
|
||||
"""测试 cmd 参数优先于 fn 参数."""
|
||||
|
||||
def my_function():
|
||||
return "should not run"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"cmd_priority",
|
||||
fn=my_function,
|
||||
cmd=[*ECHO_CMD, "cmd takes priority"],
|
||||
),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["cmd_priority"].status == px.TaskStatus.SUCCESS
|
||||
# cmd 应该被执行,而不是 fn
|
||||
assert report.results["cmd_priority"].value is None
|
||||
|
||||
|
||||
def test_taskspec_callable_cmd():
|
||||
"""测试 cmd 参数使用可调用对象."""
|
||||
|
||||
def my_callable():
|
||||
return "callable result"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("callable_cmd", cmd=my_callable),
|
||||
])
|
||||
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["callable_cmd"].status == px.TaskStatus.SUCCESS
|
||||
assert report.results["callable_cmd"].value == "callable result"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# verbose 模式测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTaskSpecVerbose:
|
||||
"""测试 TaskSpec 的 verbose 字段."""
|
||||
|
||||
def test_verbose_default_is_false(self) -> None:
|
||||
"""verbose 默认应为 False."""
|
||||
spec: px.TaskSpec[Any] = px.TaskSpec[Any]("a", cmd=[*ECHO_CMD, "hi"])
|
||||
assert spec.verbose is False
|
||||
|
||||
def test_verbose_true_prints_command(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose=True 时应打印执行的命令."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("echo", cmd=[*ECHO_CMD, "verbose-output"], verbose=True)])
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
captured = capsys.readouterr()
|
||||
assert "执行命令" in captured.out
|
||||
assert "返回码" in captured.out
|
||||
|
||||
def test_verbose_false_silent(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose=False 时不应打印命令信息."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec[Any]("echo", cmd=[*ECHO_CMD, "silent"], verbose=False)])
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
captured = capsys.readouterr()
|
||||
assert "执行命令" not in captured.out
|
||||
assert "返回码" not in captured.out
|
||||
|
||||
def test_verbose_true_shell_cmd(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose=True 时 shell 命令也应打印执行信息."""
|
||||
if sys.platform == "win32":
|
||||
shell_cmd = 'cmd /c "echo shell-verbose"'
|
||||
else:
|
||||
shell_cmd = "echo 'shell-verbose'"
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("shell", cmd=shell_cmd, verbose=True)])
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
captured = capsys.readouterr()
|
||||
assert "执行 Shell" in captured.out
|
||||
|
||||
def test_verbose_prints_cwd(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose=True 且设置了 cwd 时应打印工作目录."""
|
||||
import tempfile
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
graph = px.Graph.from_specs([px.TaskSpec[Any]("ls", cmd=ECHO_CMD, cwd=Path(tmpdir), verbose=True)])
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
captured = capsys.readouterr()
|
||||
assert "工作目录" in captured.out
|
||||
|
||||
def test_verbose_failure_includes_returncode(self, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""verbose=True 时失败也应打印返回码."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"fail",
|
||||
cmd=[sys.executable, "-c", "import sys; sys.exit(1)"],
|
||||
verbose=True,
|
||||
)
|
||||
])
|
||||
with pytest.raises(TaskFailedError):
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
captured = capsys.readouterr()
|
||||
assert "返回码" in captured.out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _wrap_cmd 错误路径测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestTaskSpecCmdErrors:
|
||||
"""测试 _wrap_cmd 的错误处理路径."""
|
||||
|
||||
def test_cmd_list_file_not_found(self) -> None:
|
||||
"""命令不存在时应抛出 RuntimeError."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs(
|
||||
[px.TaskSpec("missing", cmd=["this-command-does-not-exist-xyz"], skip_if_missing=False)],
|
||||
)
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
# 错误信息应包含命令未找到
|
||||
assert "命令未找到" in str(exc_info.value.cause) or "not found" in str(exc_info.value.cause).lower()
|
||||
|
||||
def test_cmd_list_failure_includes_stderr(self) -> None:
|
||||
"""命令失败时错误信息应包含 stderr."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"fail",
|
||||
cmd=[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"import sys; sys.stderr.write('error-msg'); sys.exit(1)",
|
||||
],
|
||||
)
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
# 非 verbose 模式下, stderr 应包含在错误信息中
|
||||
assert "error-msg" in str(exc_info.value.cause)
|
||||
|
||||
def test_cmd_string_file_not_found(self) -> None:
|
||||
"""shell 命令不存在时应抛出 RuntimeError."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([px.TaskSpec("missing", cmd="this-command-does-not-exist-xyz-123")])
|
||||
with pytest.raises(TaskFailedError):
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
|
||||
def test_cmd_string_failure(self) -> None:
|
||||
"""shell 命令失败时应抛出 RuntimeError."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fail", cmd=f'{sys.executable} -c "import sys; sys.exit(1)"'),
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert "Shell 命令执行失败" in str(exc_info.value.cause)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cmd_timeout_raises(self) -> None:
|
||||
"""命令超时应抛出 RuntimeError."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"slow",
|
||||
cmd=[sys.executable, "-c", "import time; time.sleep(5)"],
|
||||
timeout=0.1,
|
||||
)
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert "超时" in str(exc_info.value.cause)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cmd_string_timeout_raises(self) -> None:
|
||||
"""shell 命令超时应抛出 RuntimeError."""
|
||||
from pyflowx.errors import TaskFailedError
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec(
|
||||
"slow",
|
||||
cmd=f'{sys.executable} -c "import time; time.sleep(5)"',
|
||||
timeout=0.1,
|
||||
),
|
||||
])
|
||||
with pytest.raises(TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="sequential")
|
||||
assert "超时" in str(exc_info.value.cause)
|
||||
|
||||
def test_no_fn_no_cmd_raises(self) -> None:
|
||||
"""没有 fn 和 cmd 时应抛出 ValueError."""
|
||||
with pytest.raises(ValueError, match="必须提供 fn 或 cmd"):
|
||||
_ = px.TaskSpec("empty")
|
||||
@@ -1,6 +1,6 @@
|
||||
[tox]
|
||||
isolated_build = true
|
||||
envlist = py38, py39, py310, py311, py312, py313
|
||||
envlist = py38, py39, py310, py311, py312, py313, py314
|
||||
min_version = 4.0
|
||||
requires = tox-uv
|
||||
skipsdist = true
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
# pyrefly: ignore [missing-import]
|
||||
from .graphlib import CycleError, TopologicalSorter
|
||||
|
||||
__all__ = ["CycleError", "TopologicalSorter"]
|
||||
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
This type stub file was generated by pyright.
|
||||
"""
|
||||
|
||||
from typing import Any, Generator
|
||||
|
||||
__all__ = ["CycleError", "TopologicalSorter"]
|
||||
_NODE_OUT = ...
|
||||
_NODE_DONE = ...
|
||||
|
||||
class _NodeInfo:
|
||||
__slots__: list[str]
|
||||
|
||||
def __init__(self, node: Any) -> None: ...
|
||||
|
||||
class CycleError(ValueError):
|
||||
"""Subclass of ValueError raised by TopologicalSorterif cycles exist in the graph
|
||||
|
||||
If multiple cycles exist, only one undefined choice among them will be reported
|
||||
and included in the exception. The detected cycle can be accessed via the second
|
||||
element in the *args* attribute of the exception instance and consists in a list
|
||||
of nodes, such that each node is, in the graph, an immediate predecessor of the
|
||||
next node in the list. In the reported list, the first and the last node will be
|
||||
the same, to make it clear that it is cyclic.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
class TopologicalSorter:
|
||||
"""Provides functionality to topologically sort a graph of hashable nodes"""
|
||||
|
||||
def __init__(self, graph: Any) -> None: ...
|
||||
def add(self, node: Any, *predecessors: Any) -> None:
|
||||
"""Add a new node and its predecessors to the graph.
|
||||
|
||||
Both the *node* and all elements in *predecessors* must be hashable.
|
||||
|
||||
If called multiple times with the same node argument, the set of dependencies
|
||||
will be the union of all dependencies passed in.
|
||||
|
||||
It is possible to add a node with no dependencies (*predecessors* is not provided)
|
||||
as well as provide a dependency twice. If a node that has not been provided before
|
||||
is included among *predecessors* it will be automatically added to the graph with
|
||||
no predecessors of its own.
|
||||
|
||||
Raises ValueError if called after "prepare".
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def prepare(self) -> None:
|
||||
"""Mark the graph as finished and check for cycles in the graph.
|
||||
|
||||
If any cycle is detected, "CycleError" will be raised, but "get_ready" can
|
||||
still be used to obtain as many nodes as possible until cycles block more
|
||||
progress. After a call to this function, the graph cannot be modified and
|
||||
therefore no more nodes can be added using "add".
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def get_ready(self) -> tuple[Any, ...]:
|
||||
"""Return a tuple of all the nodes that are ready.
|
||||
|
||||
Initially it returns all nodes with no predecessors; once those are marked
|
||||
as processed by calling "done", further calls will return all new nodes that
|
||||
have all their predecessors already processed. Once no more progress can be made,
|
||||
empty tuples are returned.
|
||||
|
||||
Raises ValueError if called without calling "prepare" previously.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def is_active(self) -> bool:
|
||||
"""Return True if more progress can be made and ``False`` otherwise.
|
||||
|
||||
Progress can be made if cycles do not block the resolution and either there
|
||||
are still nodes ready that haven't yet been returned by "get_ready" or the
|
||||
number of nodes marked "done" is less than the number that have been returned
|
||||
by "get_ready".
|
||||
|
||||
Raises ValueError if called without calling "prepare" previously.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def __bool__(self) -> bool: ...
|
||||
def done(self, *nodes: Any) -> None:
|
||||
"""Marks a set of nodes returned by "get_ready" as processed.
|
||||
|
||||
This method unblocks any successor of each node in *nodes* for being returned
|
||||
in the future by a a call to "get_ready"
|
||||
|
||||
Raises :exec:`ValueError` if any node in *nodes* has already been marked as
|
||||
processed by a previous call to this method, if a node was not added to the
|
||||
graph by using "add" or if called without calling "prepare" previously or if
|
||||
node has not yet been returned by "get_ready".
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
def static_order(self) -> Generator[Any]:
|
||||
"""Returns an iterable of nodes in a topological order.
|
||||
|
||||
The particular order that is returned may depend on the specific
|
||||
order in which the items were inserted in the graph.
|
||||
|
||||
Using this method does not require to call "prepare" or "done". If any
|
||||
cycle is detected, :exc:`CycleError` will be raised.
|
||||
"""
|
||||
|
||||
...
|
||||
Reference in New Issue
Block a user