feat: 初始化PyFlowX轻量级DAG任务调度库

实现完整的DAG任务调度核心功能,包括:
1.  支持同步/异步/线程三种执行策略
2.  自动上下文注入,无需手动绑定任务依赖
3.  内置状态后端,支持断点续跑
4.  提供完整的测试用例与示例代码
5.  添加CI/CD配置与发布流程
This commit is contained in:
2026-06-20 10:41:33 +08:00
parent 70f3c03986
commit 8b7777d936
21 changed files with 6003 additions and 3 deletions
+131
View File
@@ -0,0 +1,131 @@
name: CI
on:
push:
branches: [main, develop]
pull_request:
branches: [main, develop]
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
# ─────────────────────────────────────────────────────────────
# 后端:多平台 × 多 Python 版本矩阵测试
# ─────────────────────────────────────────────────────────────
backend-test:
name: Backend (${{ matrix.os }} / py${{ matrix.python-version }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: ['3.13', '3.14']
exclude:
# macOS + py3.14 暂时跳过(部分依赖未发布 wheel)
- os: macos-latest
python-version: '3.14'
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
cache-dependency-glob: uv.lock
- name: 设置 Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: 安装依赖
run: uv sync --extra dev --frozen
- name: Ruff 检查
run: uv run ruff check backend/endo tests
- name: Ruff 格式检查
run: uv run ruff format --check backend/endo tests
- name: 运行测试
env:
PYTHONPATH: backend
run: uv run pytest -v --cov=endo --cov-report=xml --cov-report=term-missing
- name: 上传覆盖率
if: matrix.os == 'ubuntu-latest' && matrix.python-version == '3.13'
uses: actions/upload-artifact@v4
with:
name: coverage-${{ matrix.os }}-py${{ matrix.python-version }}
path: coverage.xml
retention-days: 7
# ─────────────────────────────────────────────────────────────
# 前端:多平台构建验证
# ─────────────────────────────────────────────────────────────
frontend-build:
name: Frontend (${{ matrix.os }} / node${{ matrix.node-version }})
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
node-version: [20, 22]
defaults:
run:
working-directory: frontend
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: 设置 Node ${{ matrix.node-version }}
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}
cache: pnpm
cache-dependency-path: frontend/pnpm-lock.yaml
- name: 安装依赖
run: pnpm install --frozen-lockfile
- name: TypeScript 类型检查
run: npx tsc --noEmit -p tsconfig.app.json
- name: 构建
run: pnpm run build
- name: 上传构建产物
if: matrix.os == 'ubuntu-latest' && matrix.node-version == 22
uses: actions/upload-artifact@v4
with:
name: frontend-dist
path: frontend/dist
retention-days: 7
# ─────────────────────────────────────────────────────────────
# 聚合:所有测试通过后才标记完成
# ─────────────────────────────────────────────────────────────
ci-pass:
name: CI Pass
runs-on: ubuntu-latest
needs: [backend-test, frontend-build]
if: always()
steps:
- name: 检查依赖任务结果
if: ${{ needs.backend-test.result != 'success' || needs.frontend-build.result != 'success' }}
run: |
echo "backend-test: ${{ needs.backend-test.result }}"
echo "frontend-build: ${{ needs.frontend-build.result }}"
exit 1
- name: 全部通过
run: echo "✅ 所有 CI 检查通过"
+253
View File
@@ -0,0 +1,253 @@
name: Release
on:
push:
tags:
- 'v*.*.*'
workflow_dispatch:
inputs:
tag:
description: '发布版本号(如 v0.1.0'
required: true
type: string
permissions:
contents: write
# Trusted Publishing (OIDC) 上传 PyPI 所需
id-token: write
jobs:
# ─────────────────────────────────────────────────────────────
# 预检:发布前必须通过 CI
# ─────────────────────────────────────────────────────────────
pre-check:
name: Pre-release Check
runs-on: ubuntu-latest
outputs:
version: ${{ steps.meta.outputs.version }}
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: 解析版本号
id: meta
run: |
if [ -n "${{ inputs.tag }}" ]; then
TAG="${{ inputs.tag }}"
else
TAG="${GITHUB_REF#refs/tags/}"
fi
# 去除前缀 v
VERSION="${TAG#v}"
echo "tag=$TAG" >> $GITHUB_OUTPUT
echo "version=$VERSION" >> $GITHUB_OUTPUT
echo "发布版本: $VERSION (tag: $TAG)"
- name: 校验版本号格式
run: |
VERSION="${{ steps.meta.outputs.version }}"
if ! echo "$VERSION" | grep -qE '^[0-9]+\.[0-9]+\.[0-9]+(-[a-zA-Z0-9.]+)?$'; then
echo "❌ 版本号格式错误: $VERSION(应为 x.y.z 或 x.y.z-rc.n"
exit 1
fi
- name: 校验 pyproject.toml 版本一致
run: |
# 精确提取 [project] 段的 version 字段(避免匹配到依赖的 version)
PY_VERSION=$(awk '/^\[project\]/{f=1} f&&/^version[[:space:]]*=/{gsub(/[" ]/,"",$3); print $3; exit}' pyproject.toml)
echo "pyproject.toml version: $PY_VERSION"
if [ "$PY_VERSION" != "${{ steps.meta.outputs.version }}" ]; then
echo "❌ pyproject.toml 版本($PY_VERSION) 与 tag 版本(${{ steps.meta.outputs.version }}) 不一致"
echo "请先更新 pyproject.toml 中的 version 字段"
exit 1
fi
# ─────────────────────────────────────────────────────────────
# 构建:后端 wheel(纯 Python,单平台即可)+ 前端 dist
# ─────────────────────────────────────────────────────────────
build:
name: Build Artifacts
needs: pre-check
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 uv
uses: astral-sh/setup-uv@v5
with:
version: latest
enable-cache: true
- name: 设置 Python 3.13
uses: actions/setup-python@v5
with:
python-version: '3.13'
- name: 安装 pnpm(前端构建依赖)
uses: pnpm/action-setup@v4
with:
version: 9
- name: 设置 Node 22(前端构建)
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: frontend/pnpm-lock.yaml
- name: 安装前端依赖(缓存)
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: 构建后端 wheel + sdist(自动触发前端构建)
run: uv build
- name: 上传后端产物
uses: actions/upload-artifact@v4
with:
name: backend-dist
path: dist/*
retention-days: 30
build-frontend:
name: Build Frontend
needs: pre-check
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 安装 pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: 设置 Node 22
uses: actions/setup-node@v4
with:
node-version: 22
cache: pnpm
cache-dependency-path: frontend/pnpm-lock.yaml
- name: 安装依赖
working-directory: frontend
run: pnpm install --frozen-lockfile
- name: 构建
working-directory: frontend
run: pnpm run build
- name: 打包前端 dist
run: |
cd frontend
zip -r ../endo-frontend-${{ needs.pre-check.outputs.version }}.zip dist
- name: 上传前端产物
uses: actions/upload-artifact@v4
with:
name: frontend-dist-release
path: endo-frontend-*.zip
retention-days: 30
# ─────────────────────────────────────────────────────────────
# 发布:上传到 PyPITrusted Publishing / OIDC
# ─────────────────────────────────────────────────────────────
publish-pypi:
name: Publish to PyPI
needs: [pre-check, build]
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/project/endo/${{ needs.pre-check.outputs.version }}
permissions:
id-token: write
steps:
- name: 下载后端构建产物
uses: actions/download-artifact@v4
with:
name: backend-dist
path: dist
- name: 校验产物
run: |
echo "待上传产物:"
ls -la dist/
if [ -z "$(ls -A dist/*.whl dist/*.tar.gz 2>/dev/null)" ]; then
echo "❌ 未找到 wheel 或 sdist 产物"
exit 1
fi
- name: 上传到 PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
attestations: true
# ─────────────────────────────────────────────────────────────
# 发布:创建 GitHub Release
# ─────────────────────────────────────────────────────────────
release:
name: Publish Release
needs: [pre-check, build, build-frontend, publish-pypi]
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: 下载所有构建产物
uses: actions/download-artifact@v4
with:
path: release-assets
- name: 整理发布产物
run: |
mkdir -p assets
find release-assets -name "*.whl" -exec cp {} assets/ \;
find release-assets -name "*.tar.gz" -exec cp {} assets/ \;
find release-assets -name "*.zip" -exec cp {} assets/ \;
ls -la assets/
- name: 生成 Release Notes
id: notes
run: |
{
echo "## endo ${{ needs.pre-check.outputs.version }}"
echo ""
echo "### 下载"
echo ""
echo "- **后端 wheel**: \`endo-${{ needs.pre-check.outputs.version }}-py3-none-any.whl\`"
echo "- **源码包**: \`endo-${{ needs.pre-check.outputs.version }}.tar.gz\`"
echo "- **前端 dist**: \`endo-frontend-${{ needs.pre-check.outputs.version }}.zip\`"
echo ""
echo "### 安装"
echo ""
echo '```bash'
echo "# 后端"
echo "pip install endo-${{ needs.pre-check.outputs.version }}-py3-none-any.whl"
echo ""
echo "# 前端"
echo "unzip endo-frontend-${{ needs.pre-check.outputs.version }}.zip -d frontend-dist"
echo '```'
echo ""
echo "### 完整变更日志"
} > RELEASE_NOTES.md
{
echo "content<<EOF"
cat RELEASE_NOTES.md
echo "EOF"
} >> $GITHUB_OUTPUT
- name: 创建 GitHub Release
uses: softprops/action-gh-release@v2
with:
tag_name: ${{ needs.pre-check.outputs.tag }}
name: endo ${{ needs.pre-check.outputs.version }}
body: ${{ steps.notes.outputs.content }}
files: assets/*
draft: false
prerelease: ${{ contains(needs.pre-check.outputs.version, '-') }}
generate_release_notes: true
+1
View File
@@ -8,3 +8,4 @@ wheels/
# Virtual environments
.venv
.coverage
+58
View File
@@ -0,0 +1,58 @@
"""Example 3: async aggregation with static args and Context injection.
Shows:
* async task functions executed with strategy="async".
* static positional args (TaskSpec.args) for parameterised tasks.
* Context annotation to receive the full upstream result mapping.
* on_event callback for real-time progress.
"""
from __future__ import annotations
import asyncio
from typing import Any, Dict, List
import pyflowx as px
async def fetch_user(uid: int) -> dict:
await asyncio.sleep(0.2)
return {"id": uid, "name": f"User{uid}"}
async def fetch_posts(uid: int) -> List[int]:
await asyncio.sleep(0.2)
return [uid, uid + 1]
# Context annotation → receives the full mapping of upstream results.
def aggregate(ctx: px.Context) -> Dict[str, Any]:
return dict(ctx)
def main() -> None:
graph = px.Graph.from_specs(
[
# Static positional args parameterise the same function twice.
px.TaskSpec("fetch_user", fetch_user, args=(1,)),
px.TaskSpec("fetch_posts", fetch_posts, args=(1,)),
px.TaskSpec("aggregate", aggregate, ("fetch_user", "fetch_posts")),
]
)
print("=== Dry run ===")
px.run(graph, strategy="async", dry_run=True)
events: List[px.TaskEvent] = []
print("\n=== Async execution ===")
report = px.run(graph, strategy="async", on_event=events.append)
for ev in events:
print(f" event: {ev.task} -> {ev.status.value}")
print(f"\naggregate = {report['aggregate']}")
print(report.describe())
if __name__ == "__main__":
main()
+81
View File
@@ -0,0 +1,81 @@
"""Example 1: ETL pipeline (sequential strategy).
Demonstrates the core PyFlowX workflow:
* Define tasks as plain functions.
* Declare the DAG with a list of TaskSpec.
* Parameter names == dependency names → automatic context injection,
no wrappers needed (contrast with flowweaver's get_task_result boilerplate).
* dry_run to preview, then execute and read typed results from RunReport.
"""
from __future__ import annotations
from typing import List
import pyflowx as px
# --- task functions: pure, testable, no framework coupling ------------- #
def extract_customers() -> List[dict]:
return [
{"id": "C001", "name": "Alice"},
{"id": "C002", "name": "Bob"},
]
def extract_orders() -> List[dict]:
return [
{"id": "O001", "customer_id": "C001", "amount": 150.0},
{"id": "O002", "customer_id": "C002", "amount": 200.5},
]
# Parameter names match dependency names → automatic injection.
def transform(
extract_customers: List[dict],
extract_orders: List[dict],
) -> List[dict]:
cmap = {c["id"]: c for c in extract_customers}
return [
{**o, "customer_name": cmap[o["customer_id"]]["name"]}
for o in extract_orders
if o["customer_id"] in cmap
]
def load(transform: List[dict]) -> int:
print(f" loaded {len(transform)} records")
return len(transform)
def main() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("extract_customers", extract_customers, tags=("extract",)),
px.TaskSpec("extract_orders", extract_orders, tags=("extract",)),
px.TaskSpec(
"transform",
transform,
("extract_customers", "extract_orders"),
tags=("transform",),
),
px.TaskSpec("load", load, ("transform",), retries=1, tags=("load",)),
]
)
print("=== Execution plan ===")
print(graph.describe())
print("\n=== Dry run (no execution) ===")
px.run(graph, strategy="sequential", dry_run=True)
print("\n=== Sequential execution ===")
report = px.run(graph, strategy="sequential")
print(report.describe())
print(f"\nload result = {report['load']}")
print(f"summary = {report.summary()}")
if __name__ == "__main__":
main()
+59
View File
@@ -0,0 +1,59 @@
"""Example 2: parallel execution (thread strategy).
Same DAG run with sequential vs. thread strategy to show layer-internal
parallelism. Tasks within a layer run concurrently; layers are barriers.
Layer 1: [fetch_a, fetch_b] (parallel)
Layer 2: [merge] (waits for both)
"""
from __future__ import annotations
import time
import pyflowx as px
def fetch_a() -> str:
time.sleep(0.5)
return "a"
def fetch_b() -> str:
time.sleep(0.5)
return "b"
def merge(fetch_a: str, fetch_b: str) -> str:
return fetch_a + fetch_b
def main() -> None:
graph = px.Graph.from_specs(
[
px.TaskSpec("fetch_a", fetch_a),
px.TaskSpec("fetch_b", fetch_b),
px.TaskSpec("merge", merge, ("fetch_a", "fetch_b")),
]
)
print("=== Mermaid diagram ===")
print(graph.to_mermaid("LR"))
print("\n=== Sequential (expect ~1.0s) ===")
start = time.time()
report_seq = px.run(graph, strategy="sequential")
t_seq = time.time() - start
print(f" result={report_seq['merge']} time={t_seq:.2f}s")
print("\n=== Threaded (expect ~0.5s) ===")
start = time.time()
report_thr = px.run(graph, strategy="thread", max_workers=2)
t_thr = time.time() - start
print(f" result={report_thr['merge']} time={t_thr:.2f}s")
print(f"\nspeedup = {t_seq / t_thr:.2f}x")
if __name__ == "__main__":
main()
+68 -3
View File
@@ -1,7 +1,72 @@
[project]
authors = [{ name = "pyflowx" }]
classifiers = [
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Software Development :: Libraries :: Application Frameworks",
]
description = "Lightweight, type-safe DAG task scheduler with multi-strategy execution."
keywords = ["async", "dag", "scheduler", "task", "workflow"]
license = { text = "MIT" }
name = "pyflowx"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.8"
dependencies = []
version = "0.1.0"
# graphlib_backport only needed on Python 3.8 (stdlib graphlib exists in 3.9+)
dependencies = ["graphlib_backport >= 1.0.0; python_version < '3.9'"]
[project.optional-dependencies]
dev = [
"hatch>=1.14.2",
"httpx>=0.28.0",
"mypy >= 1.0",
"prek>=0.4.5",
"pytest-asyncio>=0.24.0",
"pytest-cov>=5.0.0",
"pytest-html>=4.1.1",
"pytest-mock>=3.14.0",
"pytest-xdist>=3.6.1",
"pytest>=8.0.0",
"ruff>=0.8.0",
"tox-uv>=1.13.1",
"tox>=4.25.0",
]
[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]
[tool.hatch.build.targets.wheel]
packages = ["src/pyflowx"]
[tool.hatch.build.targets.wheel.force-include]
"src/pyflowx/py.typed" = "pyflowx/py.typed"
[tool.mypy]
# mypy 2.x requires a >=3.10 target. We check against 3.10 syntax; the
# runtime stays 3.8-compatible via `from __future__ import annotations`
# (all annotations are strings at runtime) and the graphlib_backport
# conditional dependency for topological sorting.
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
files = ["src/pyflowx"]
ignore_missing_imports = false
python_version = "3.8"
strict = true
warn_return_any = true
warn_unused_configs = true
[tool.uv.sources]
pyflowx = { workspace = true }
[[tool.uv.index]]
default = true
url = "https://mirrors.aliyun.com/pypi/simple/"
[dependency-groups]
dev = ["pyflowx[dev]"]
+75
View File
@@ -0,0 +1,75 @@
"""PyFlowX — lightweight, type-safe DAG task scheduler.
Public API
----------
* :class:`TaskSpec` — immutable task descriptor (the only thing you configure).
* :class:`Graph` — DAG built from a list of specs; validates, layers, visualises.
* :func:`run` — execute a graph with ``sequential`` / ``thread`` / ``async``.
* :class:`RunReport` — typed, queryable result of a run.
* :class:`Context` — annotation marker for whole-context injection.
* State backends: :class:`StateBackend`, :class:`MemoryBackend`, :class:`JSONBackend`.
Quick start
-----------
import pyflowx as px
def extract() -> list[int]: return [1, 2, 3]
def double(extract: list[int]) -> list[int]: return [x * 2 for x in extract]
graph = px.Graph.from_specs([
px.TaskSpec("extract", extract),
px.TaskSpec("double", double, ("extract",)),
])
report = px.run(graph, strategy="sequential")
print(report["double"]) # [2, 4, 6]
"""
from __future__ import annotations
from .context import Context, build_call_args, describe_injection
from .errors import (
CycleError,
DuplicateTaskError,
InjectionError,
MissingDependencyError,
PyFlowXError,
StorageError,
TaskFailedError,
TaskTimeoutError,
)
from .executors import run
from .graph import Graph
from .report import RunReport
from .storage import JSONBackend, MemoryBackend, StateBackend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
__version__ = "0.1.0"
__all__ = [
# core types
"TaskSpec",
"TaskStatus",
"TaskResult",
"TaskEvent",
"Context",
"Graph",
"RunReport",
# execution
"run",
# state backends
"StateBackend",
"MemoryBackend",
"JSONBackend",
# errors
"PyFlowXError",
"DuplicateTaskError",
"MissingDependencyError",
"CycleError",
"TaskFailedError",
"TaskTimeoutError",
"InjectionError",
"StorageError",
# helpers (advanced)
"build_call_args",
"describe_injection",
]
+203
View File
@@ -0,0 +1,203 @@
"""Context injection: turn upstream results into function arguments.
This is the mechanism that lets users write plain functions whose
parameter names *are* the dependency declarations, removing the boiler-
plate wrappers that plague other DAG libraries (e.g. ``def wrapper():
return fn(workflow.get_task_result('x'))``).
Injection rules (evaluated in order)
-----------------------------------
1. A parameter whose **annotation is** :class:`Context` receives the full
result mapping. Useful for tasks that need to iterate over all inputs.
2. A parameter whose **name matches a dependency** receives that
dependency's result.
3. A ``**kwargs`` parameter receives *all* dependency results as a dict.
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` supply static values for
parameters that are *not* dependencies.
If a parameter cannot be resolved and has no default, an
:class:`~pyflowx.errors.InjectionError` is raised with a precise message.
"""
from __future__ import annotations
import inspect
from typing import Any, Dict, List, Mapping, Set, Tuple
from .errors import InjectionError
from .task import Context, TaskSpec
__all__ = ["Context", "build_call_args", "describe_injection"]
def _is_context_annotation(annotation: Any) -> bool:
"""True when a parameter annotation is (or refers to) ``Context``.
Handles three forms:
* the ``Context`` alias object itself;
* a typing alias whose ``__name__``/``_name`` is ``Context`` or ``Mapping``;
* a *string* annotation (``from __future__ import annotations`` makes all
annotations strings at runtime) such as ``"Context"`` or ``"px.Context"``.
"""
if annotation is Context:
return True
# String annotation from `from __future__ import annotations`.
if isinstance(annotation, str):
# Match "Context", "px.Context", "pyflowx.Context", etc.
return annotation == "Context" or annotation.endswith(".Context")
# Match by qualified name to support ``from pyflowx import Context``
# re-exports.
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
if name in ("Context", "Mapping"):
return True
return False
def build_call_args(
spec: TaskSpec[object],
context: Mapping[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""Resolve the ``(args, kwargs)`` to call ``spec.fn`` with.
Parameters
----------
spec:
The task spec, providing ``fn``, ``depends_on``, ``args``, ``kwargs``.
context:
Mapping of dependency-name -> result value. Only the task's own
``depends_on`` entries are guaranteed present; other tasks' results
are excluded to keep injection deterministic.
Returns
-------
(args, kwargs)
Ready to splat into ``spec.fn(*args, **kwargs)``.
Raises
------
InjectionError
If a required parameter cannot be satisfied, or if static
``kwargs`` collide with an injected dependency name.
"""
sig = inspect.signature(spec.fn)
params = sig.parameters
# Detect special parameter kinds.
var_keyword = next(
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# The subset of context relevant to this task.
dep_context: Dict[str, Any] = {
name: context[name] for name in spec.depends_on if name in context
}
# Detect collisions between static kwargs and dependency names.
collisions = set(spec.kwargs) & set(dep_context)
if collisions:
raise InjectionError(
spec.name,
f"static kwargs {sorted(collisions)} collide with dependency names; "
"rename the static kwarg or the dependency.",
)
injected_kwargs: Dict[str, Any] = {}
leftover_dep_results: Dict[str, Any] = dict(dep_context)
# Positional parameters consumed by spec.args. We track which param
# names are filled positionally so they are skipped during name-based
# injection (dependency / Context / static kwargs).
positional_params: List[str] = []
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
for pname, param in params.items():
if param.kind in positional_kinds:
positional_params.append(pname)
# The first len(spec.args) positional params are filled by spec.args.
args_filled: Set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items():
# Skip params already filled by positional spec.args.
if pname in args_filled:
continue
# Rule 1: annotated as Context -> full mapping.
if _is_context_annotation(param.annotation):
injected_kwargs[pname] = dep_context
continue
# Rule 2: name matches a dependency.
if pname in dep_context:
injected_kwargs[pname] = dep_context[pname]
leftover_dep_results.pop(pname, None)
continue
# Rule 3: handled after the loop via **kwargs.
# Rule 4: static kwargs fill the rest.
if pname in spec.kwargs:
injected_kwargs[pname] = spec.kwargs[pname]
continue
# No source for this parameter: must have a default, else error.
if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
raise InjectionError(
spec.name,
f"parameter {pname!r} has no dependency, static value, or default.",
)
# Rule 3: **kwargs swallows remaining dependency results.
if var_keyword is not None and leftover_dep_results:
# Merge static kwargs first, then dependency results (static wins
# on collision — but we already rejected collisions above).
merged = dict(spec.kwargs)
merged.update(injected_kwargs)
merged.update(leftover_dep_results)
injected_kwargs = merged
return tuple(spec.args), injected_kwargs
def describe_injection(spec: TaskSpec[object]) -> str:
"""Human-readable description of how a task's args will be injected.
Used by ``dry_run`` to show the execution plan without executing it.
"""
sig = inspect.signature(spec.fn)
# Determine which positional params are filled by spec.args.
positional_params = [
p
for p, param in sig.parameters.items()
if param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
args_filled = set(positional_params[: len(spec.args)])
parts = []
for pname, param in sig.parameters.items():
if pname in args_filled:
idx = positional_params.index(pname)
parts.append(f"{pname}={spec.args[idx]!r}")
elif _is_context_annotation(param.annotation):
parts.append(f"{pname}=<Context>")
elif pname in spec.depends_on:
parts.append(f"{pname}=<result:{pname}>")
elif pname in spec.kwargs:
parts.append(f"{pname}={spec.kwargs[pname]!r}")
elif param.default is not inspect.Parameter.empty:
parts.append(f"{pname}=<default>")
elif param.kind == inspect.Parameter.VAR_KEYWORD:
parts.append("**kwargs=<all-deps>")
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
parts.append("*args")
else:
parts.append(f"{pname}=<UNRESOLVED>")
return f"{spec.name}({', '.join(parts)})"
+93
View File
@@ -0,0 +1,93 @@
"""PyFlowX error hierarchy.
All errors are concrete subclasses of :class:`PyFlowXError` so callers can
catch the entire family with a single ``except`` clause, while still being
able to discriminate by type for fine-grained handling.
"""
from __future__ import annotations
from typing import Any, Iterable, Optional
class PyFlowXError(Exception):
"""Base class for every PyFlowX error."""
class DuplicateTaskError(PyFlowXError):
"""Raised when a task name is registered more than once."""
def __init__(self, name: str) -> None:
super().__init__(f"Task '{name}' is already registered in the graph.")
self.name = name
class MissingDependencyError(PyFlowXError):
"""Raised when a task depends on a name that is not in the graph."""
def __init__(self, task: str, dependency: str) -> None:
super().__init__(
f"Task '{task}' depends on unknown task '{dependency}'. "
"Add the dependency before (or together with) this task."
)
self.task = task
self.dependency = dependency
class CycleError(PyFlowXError):
"""Raised when the dependency graph contains a cycle."""
def __init__(self, cycle: Iterable[str]) -> None:
cycle_list = list(cycle)
chain = " -> ".join(cycle_list + cycle_list[:1])
super().__init__(f"The dependency graph contains a cycle: {chain}")
self.cycle = cycle_list
class TaskFailedError(PyFlowXError):
"""Raised when a task fails after exhausting all retries.
The original exception is preserved on :attr:`__cause__` and also exposed
via :attr:`cause` for convenient access in user code.
"""
def __init__(
self,
task: str,
cause: BaseException,
attempts: int,
layer: Optional[int] = None,
) -> None:
location = f" (layer {layer})" if layer is not None else ""
super().__init__(
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
)
self.task = task
self.cause = cause
self.attempts = attempts
self.layer = layer
class TaskTimeoutError(PyFlowXError):
"""Raised when a task exceeds its configured timeout."""
def __init__(self, task: str, timeout: float) -> None:
super().__init__(f"Task '{task}' timed out after {timeout:.3f}s.")
self.task = task
self.timeout = timeout
class InjectionError(PyFlowXError):
"""Raised when context injection cannot satisfy a task signature."""
def __init__(self, task: str, detail: str) -> None:
super().__init__(f"Cannot inject context for task '{task}': {detail}")
self.task = task
class StorageError(PyFlowXError):
"""Raised by state backends on persistence failures."""
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
super().__init__(f"State storage error: {detail}")
self.cause: Any = cause
+425
View File
@@ -0,0 +1,425 @@
"""Executors and the public :func:`run` entry point.
Three execution strategies share a common layer-by-layer driver:
* ``sequential`` — deterministic, one task at a time. Best for debugging.
* ``thread`` — layer-internal concurrency via a thread pool. Best for
I/O-bound sync tasks.
* ``async`` — layer-internal concurrency via ``asyncio.gather``.
Sync tasks are offloaded to a thread pool; async tasks
run on the event loop. Best for I/O-bound async tasks.
All three honour ``retries``, ``timeout``, context injection, state
backends (resume), and emit :class:`~pyflowx.task.TaskEvent` for observers.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import inspect
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, cast
from .context import build_call_args, describe_injection
from .errors import TaskFailedError, TaskTimeoutError
from .graph import Graph
from .report import RunReport
from .storage import StateBackend, resolve_backend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger("pyflowx")
# Observer callback type.
EventCallback = Callable[[TaskEvent], None]
# Strategy selector literal.
Strategy = str # "sequential" | "thread" | "async"
def _is_async_fn(spec: TaskSpec[object]) -> bool:
"""True if ``spec.fn`` is a coroutine function."""
return inspect.iscoroutinefunction(spec.fn)
def _emit(
on_event: Optional[EventCallback],
result: TaskResult[object],
) -> None:
"""Fire an observer event if a callback is registered."""
if on_event is None:
return
on_event(
TaskEvent(
task=result.spec.name,
status=result.status,
attempts=result.attempts,
error=repr(result.error) if result.error else None,
duration=result.duration,
)
)
def _run_sync_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
) -> TaskResult[object]:
"""Execute a sync task with retries; return a populated TaskResult."""
result: TaskResult[object] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
args, kwargs = build_call_args(spec, context)
while result.attempts < max_attempts:
result.attempts += 1
try:
result.value = spec.fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except Exception as exc: # noqa: BLE001 - user code may raise anything
result.error = exc
if result.attempts >= max_attempts:
break
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
async def _run_async_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
) -> TaskResult[object]:
"""Execute a task (sync or async) on the event loop with retries."""
result: TaskResult[object] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
while result.attempts < max_attempts:
result.attempts += 1
try:
if _is_async_fn(spec):
coro = cast(Awaitable[Any], spec.fn(*args, **kwargs))
if spec.timeout is not None:
result.value = await asyncio.wait_for(coro, timeout=spec.timeout)
else:
result.value = await coro
else:
# Offload sync work to a thread so the event loop stays alive.
fn_call: Callable[[], Any] = lambda: spec.fn(*args, **kwargs)
if spec.timeout is not None:
result.value = await asyncio.wait_for(
loop.run_in_executor(None, fn_call), timeout=spec.timeout
)
else:
result.value = await loop.run_in_executor(None, fn_call)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except asyncio.TimeoutError:
result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0)
if result.attempts >= max_attempts:
break
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
except Exception as exc: # noqa: BLE001
result.error = exc
if result.attempts >= max_attempts:
break
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
# ---------------------------------------------------------------------- #
# Layer driver
# ---------------------------------------------------------------------- #
def _build_context(
spec: TaskSpec[object],
global_context: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Restrict the global context to this task's dependencies."""
return {
dep: global_context[dep] for dep in spec.depends_on if dep in global_context
}
def _execute_layer_sequential(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
) -> None:
"""Run a layer's tasks one by one."""
for name in layer:
spec = graph.spec(name)
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached)
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
continue
result = _run_sync_with_retry(spec, _build_context(spec, context), layer_idx)
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
def _execute_layer_threaded(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
max_workers: int,
) -> None:
"""Run a layer's tasks concurrently in a thread pool."""
# First, satisfy cached tasks synchronously.
to_run: List[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
to_run.append(name)
if not to_run:
return
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {}
for name in to_run:
spec = graph.spec(name)
# Snapshot the context for this task to avoid races.
task_ctx = _build_context(spec, context)
fut = pool.submit(_run_sync_with_retry, spec, task_ctx, layer_idx)
future_to_name[fut] = name
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
result = fut.result() # raises TaskFailedError on failure
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
async def _execute_layer_async(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
) -> None:
"""Run a layer's tasks concurrently on the event loop."""
to_run: List[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
to_run.append(name)
if not to_run:
return
coros = []
for name in to_run:
spec = graph.spec(name)
task_ctx = _build_context(spec, context)
coros.append(_run_async_with_retry(spec, task_ctx, layer_idx))
results = await asyncio.gather(*coros)
for name, result in zip(to_run, results):
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
# ---------------------------------------------------------------------- #
# Public API
# ---------------------------------------------------------------------- #
def run(
graph: Graph,
strategy: Strategy = "sequential",
*,
max_workers: Optional[int] = None,
dry_run: bool = False,
on_event: Optional[EventCallback] = None,
state: Optional[StateBackend] = None,
) -> RunReport:
"""Execute a graph and return a :class:`RunReport`.
Parameters
----------
graph:
The validated :class:`Graph` to execute.
strategy:
``"sequential"`` (default), ``"thread"``, or ``"async"``.
max_workers:
Thread-pool size for ``"thread"``. Defaults to ``min(32, len(layer))``.
dry_run:
If ``True``, print the execution plan (layers + injection) and
return an empty report without executing anything.
on_event:
Optional callback invoked on every status transition.
state:
Optional :class:`StateBackend` for resumable runs. Defaults to an
in-memory backend (no persistence across processes).
Raises
------
ValueError
If ``strategy`` is not recognised.
TaskFailedError
If any task fails after exhausting retries. The run aborts at the
failing layer; tasks in later layers are not attempted.
"""
if strategy not in ("sequential", "thread", "async"):
raise ValueError(
f"unknown strategy {strategy!r}; expected 'sequential', 'thread', or 'async'."
)
graph.validate()
layers = graph.layers()
if dry_run:
_print_dry_run(graph, layers)
return RunReport(success=True)
backend = resolve_backend(state)
report = RunReport()
context: Dict[str, Any] = {}
try:
if strategy == "sequential":
_drive_sequential(graph, layers, context, report, backend, on_event)
elif strategy == "thread":
_drive_threaded(
graph, layers, context, report, backend, on_event, max_workers
)
else:
_drive_async(graph, layers, context, report, backend, on_event)
except TaskFailedError:
report.success = False
raise
return report
def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
"""Print the execution plan without running anything."""
print(f"Dry run: {len(graph)} tasks, {len(layers)} layers")
for idx, layer in enumerate(layers, 1):
print(f" Layer {idx}: {layer}")
for name in layer:
print(f" - {describe_injection(graph.spec(name))}")
def _drive_sequential(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
def _drive_threaded(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
max_workers: Optional[int],
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(
layer, graph, context, report, backend, idx, on_event, workers
)
def _drive_async(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
async def _async_drive(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(
layer, graph, context, report, backend, idx, on_event
)
+245
View File
@@ -0,0 +1,245 @@
"""DAG construction, validation, layering and visualisation.
Uses :mod:`graphlib` from the standard library (3.9+) or
:mod:`graphlib_backport` (3.8) for topological sorting. The graph is
built incrementally and validated eagerly so that misconfiguration fails
fast — at construction time, not at execution time.
"""
from __future__ import annotations
import sys
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec
# graphlib lives in the stdlib since 3.9; fall back to the backport on 3.8.
if sys.version_info >= (3, 9):
import graphlib
_TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover - exercised only on 3.8
import graphlib # type: ignore[no-redef]
_TopologicalSorter = graphlib.TopologicalSorter
class Graph:
"""An immutable-after-validation directed acyclic graph of tasks.
The graph is built by adding :class:`~pyflowx.task.TaskSpec` instances.
Each ``add`` performs eager validation (duplicate names, missing
dependencies), and :meth:`validate` / :meth:`layers` perform full DAG
validation (cycle detection) and topological layering.
The graph holds only the *configuration*; runtime state lives in
:class:`~pyflowx.report.RunReport`. This makes a graph safely
re-runnable and shareable across threads.
"""
def __init__(self) -> None:
self._specs: Dict[str, TaskSpec[object]] = {}
# Map task -> its direct dependencies (predecessors).
self._deps: Dict[str, Tuple[str, ...]] = {}
# ------------------------------------------------------------------ #
# Construction
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[object]) -> "Graph":
"""Register a task spec with eager validation.
Returns ``self`` so calls can be chained, but the recommended
entry point is :meth:`from_specs` which validates the whole batch
together (allowing forward references in a single call).
"""
self._specs[spec.name] = spec
self._deps[spec.name] = spec.depends_on
# Eagerly check duplicates and missing deps for the incremental API.
self._validate_references()
return self
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
"""Build a graph from an iterable of task specs.
All specs are collected first, then validated together. This means
a task may reference a dependency that appears *later* in the
iterable — order does not matter, mirroring how a declarative
config file reads.
"""
graph = cls()
for spec in specs:
if spec.name in graph._specs:
raise DuplicateTaskError(spec.name)
graph._specs[spec.name] = spec
graph._deps[spec.name] = spec.depends_on
graph._validate_references()
graph.validate()
return graph
# ------------------------------------------------------------------ #
# Validation
# ------------------------------------------------------------------ #
def _validate_references(self) -> None:
"""Ensure every dependency name exists in the graph."""
for name, deps in self._deps.items():
for dep in deps:
if dep not in self._specs:
raise MissingDependencyError(name, dep)
def validate(self) -> None:
"""Run full DAG validation.
Raises :class:`~pyflowx.errors.CycleError` if a cycle exists.
Dependency existence is checked by :meth:`_validate_references`.
"""
self._validate_references()
sorter = _TopologicalSorter(self._deps)
try:
# prepare() raises CycleError on cycles; we don't need the
# static_order() result here, just the validation side effect.
sorter.prepare()
except graphlib.CycleError as exc:
# exc.args[1] is the list of nodes forming the cycle.
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc
# ------------------------------------------------------------------ #
# Introspection
# ------------------------------------------------------------------ #
@property
def names(self) -> List[str]:
"""All registered task names (insertion order)."""
return list(self._specs.keys())
def spec(self, name: str) -> TaskSpec[object]:
"""Return the spec for ``name``; ``KeyError`` if absent."""
return self._specs[name]
def dependencies(self, name: str) -> Tuple[str, ...]:
"""Direct predecessors of ``name``."""
return self._deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
"""Read-only view of name -> spec."""
return self._specs
def layers(self) -> List[List[str]]:
"""Group tasks into parallel-executable layers (Kahn's algorithm).
Tasks within the same layer have no mutual dependencies and may
run concurrently. Layers are returned in execution order.
Raises :class:`~pyflowx.errors.CycleError` if the graph is cyclic.
"""
self.validate()
sorter = _TopologicalSorter(self._deps)
result: List[List[str]] = []
# ``get_ready`` + ``done`` gives us one layer at a time, which is
# exactly the parallel-execution grouping we need.
sorter.prepare()
while sorter.is_active():
ready = list(sorter.get_ready())
# Sort for deterministic, reproducible execution plans.
ready.sort()
result.append(ready)
for node in ready:
sorter.done(node)
return result
# ------------------------------------------------------------------ #
# Subgraph / tag filtering
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> "Graph":
"""Return a new graph containing only tasks matching any tag.
Dependencies are pruned to keep only edges between retained tasks;
edges to dropped tasks are removed (the retained task no longer
waits for them). Use this to run a slice of a large DAG for
debugging.
"""
wanted: Set[str] = set(tags)
kept: List[TaskSpec[object]] = []
for spec in self._specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags))
)
kept.append(
TaskSpec(
name=spec.name,
fn=spec.fn,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
)
)
return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
"""Return a new graph restricted to ``names`` (with pruned edges)."""
wanted: Set[str] = set(names)
for n in wanted:
if n not in self._specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: List[TaskSpec[object]] = []
for spec in self._specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(
TaskSpec(
name=spec.name,
fn=spec.fn,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
)
)
return Graph.from_specs(kept)
# ------------------------------------------------------------------ #
# Visualisation
# ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str:
"""Render the DAG as a Mermaid ``graph`` definition string.
No external dependencies; the output can be pasted into Markdown,
rendered by VS Code's Mermaid previewer, or saved to a file.
"""
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
lines: List[str] = [f"graph {orientation}"]
for name in self._specs:
lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items():
for dep in deps:
lines.append(f" {dep} --> {name}")
return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ #
# Debug
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""Human-readable multi-line summary for debugging."""
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
def __repr__(self) -> str:
return f"Graph(tasks={len(self._specs)})"
def __len__(self) -> int:
return len(self._specs)
def __contains__(self, name: object) -> bool:
return name in self._specs
View File
+82
View File
@@ -0,0 +1,82 @@
"""Run report: typed, queryable result of a single :func:`pyflowx.run`.
The report is the single source of truth after execution. It exposes
per-task results via ``report["name"]`` (typed as ``Any`` because the
mapping is heterogeneous), summary statistics, and a flag indicating
whether the whole run succeeded.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Mapping, Optional
from .task import TaskResult, TaskStatus
@dataclass
class RunReport:
"""Aggregated outcome of a workflow run.
Attributes
----------
results:
Mapping of task name -> :class:`TaskResult`. Insertion order
matches the order tasks finished.
success:
``True`` iff every non-skipped task ended in ``SUCCESS``.
"""
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
success: bool = True
# ---- typed access ------------------------------------------------- #
def __getitem__(self, name: str) -> Any:
"""Return the *value* of task ``name`` (not the TaskResult).
Raises ``KeyError`` if the task was not part of the run. Returns
``None`` for tasks that did not reach SUCCESS.
"""
return self.results[name].value
def result_of(self, name: str) -> TaskResult[object]:
"""Return the full :class:`TaskResult` for ``name``."""
return self.results[name]
def __contains__(self, name: object) -> bool:
return name in self.results
def __iter__(self) -> Iterator[str]:
return iter(self.results)
def __len__(self) -> int:
return len(self.results)
# ---- summary ------------------------------------------------------ #
def summary(self) -> Dict[str, Any]:
"""Compact statistics dict for logging / dashboards."""
counts: Dict[str, int] = {}
total_duration = 0.0
for r in self.results.values():
counts[r.status.value] = counts.get(r.status.value, 0) + 1
if r.duration is not None:
total_duration += r.duration
return {
"success": self.success,
"total_tasks": len(self.results),
"by_status": counts,
"total_duration_seconds": round(total_duration, 6),
}
def failed_tasks(self) -> List[str]:
"""Names of tasks that ended in FAILED status."""
return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED]
def describe(self) -> str:
"""Human-readable multi-line report for debugging."""
lines: List[str] = [f"RunReport(success={self.success})"]
for name, r in self.results.items():
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
err = f" error={r.error!r}" if r.error else ""
lines.append(f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}")
return "\n".join(lines)
+135
View File
@@ -0,0 +1,135 @@
"""State persistence backends for resumable runs.
A :class:`StateBackend` stores the result of every successfully completed
task. On a subsequent run, the executor asks the backend whether a task
already has a stored result; if so, the task is skipped and its stored
value is injected into downstream tasks.
This is intentionally minimal: only *successful* results are persisted
(failed tasks are re-run), and the storage shape is a flat
``{task_name: result}`` mapping. Two backends ship in-tree:
* :class:`MemoryBackend` — fast, in-process, no I/O. Default.
* :class:`JSONBackend` — persists to a JSON file for cross-process resume.
Both are zero-dependency (``json`` is stdlib). Users can subclass
:class:`StateBackend` to plug in SQLite, Redis, etc.
"""
from __future__ import annotations
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Optional
from .errors import StorageError
class StateBackend(ABC):
"""Abstract base for resumable state storage."""
@abstractmethod
def load(self) -> Mapping[str, Any]:
"""Return the full stored mapping (may be empty)."""
@abstractmethod
def save(self, name: str, value: Any) -> None:
"""Persist a single task's successful result."""
@abstractmethod
def has(self, name: str) -> bool:
"""Whether ``name`` has a stored result."""
@abstractmethod
def get(self, name: str) -> Any:
"""Return the stored result for ``name`` (raise ``KeyError`` if absent)."""
@abstractmethod
def clear(self) -> None:
"""Remove all stored state."""
class MemoryBackend(StateBackend):
"""In-process dict backend. Lost when the process exits."""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
def load(self) -> Mapping[str, Any]:
return dict(self._store)
def save(self, name: str, value: Any) -> None:
self._store[name] = value
def has(self, name: str) -> bool:
return name in self._store
def get(self, name: str) -> Any:
return self._store[name]
def clear(self) -> None:
self._store.clear()
class JSONBackend(StateBackend):
"""File-backed JSON storage for cross-process resume.
Results must be JSON-serialisable. Non-serialisable values raise
:class:`~pyflowx.errors.StorageError` (the run itself is not aborted;
only persistence of that one result fails).
"""
def __init__(self, path: str) -> None:
self._path = path
self._store: Dict[str, Any] = {}
self._load()
def _load(self) -> None:
if not os.path.exists(self._path):
return
try:
with open(self._path, "r", encoding="utf-8") as fh:
data = json.load(fh)
if isinstance(data, dict):
self._store = data
except (OSError, json.JSONDecodeError) as exc:
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
def _flush(self) -> None:
tmp = self._path + ".tmp"
try:
with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2)
os.replace(tmp, self._path)
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
def load(self) -> Mapping[str, Any]:
return dict(self._store)
def save(self, name: str, value: Any) -> None:
# Validate serialisability before mutating in-memory state.
try:
json.dumps(value)
except (TypeError, ValueError) as exc:
raise StorageError(
f"result of task {name!r} is not JSON-serialisable", exc
) from exc
self._store[name] = value
self._flush()
def has(self, name: str) -> bool:
return name in self._store
def get(self, name: str) -> Any:
return self._store[name]
def clear(self) -> None:
self._store.clear()
self._flush()
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
"""Return ``backend`` or a fresh :class:`MemoryBackend` if ``None``."""
return backend if backend is not None else MemoryBackend()
+151
View File
@@ -0,0 +1,151 @@
"""Core task data structures for PyFlowX.
Everything here is a plain, immutable data structure — no decorators, no
side effects. A :class:`TaskSpec` fully describes a task node; the
:class:`Graph` (see :mod:`pyflowx.graph`) consumes a list of specs and
builds the DAG.
Design notes
------------
* ``TaskSpec`` is a ``Generic[T]`` so that ``TaskSpec[int]`` carries the
return type of ``fn`` all the way to :class:`RunReport`, giving callers
typed access to ``report["name"]``.
* ``Context`` is the only intentionally-dynamic type: results from
upstream tasks are heterogeneous, so the cross-task mapping is
``Mapping[str, Any]``. Within a single task the types remain fully
static because the function signature is checked by mypy.
* ``TaskStatus`` is a closed enum; executors never invent ad-hoc strings.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import (
Any,
Callable,
Coroutine,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)
T = TypeVar("T")
# A task callable may be synchronous or asynchronous. We keep the union
# explicit so mypy understands both shapes.
TaskFn = Union[
Callable[..., T],
Callable[..., Coroutine[Any, Any, T]],
]
# The cross-task result mapping. Deliberately ``Any`` for values because
# different tasks return different types; per-task typing is preserved by
# the function signature itself.
Context = Mapping[str, Any]
class TaskStatus(Enum):
"""Lifecycle states of a task during a single run."""
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped" # used by resumable runs and subgraph filtering
@dataclass(frozen=True)
class TaskSpec(Generic[T]):
"""Immutable description of a single DAG node.
Parameters
----------
name:
Unique identifier of the task within a graph. Other tasks reference
this name in ``depends_on``.
fn:
The callable to execute. May be sync or async. Its parameter names
drive automatic context injection (see :mod:`pyflowx.context`).
depends_on:
Names of tasks whose results must be available before this task
runs. Order is irrelevant; the framework topologically sorts.
args:
Static positional arguments appended *after* injected parameters.
Useful for parameterised tasks (e.g. ``fetch_user(uid)``).
kwargs:
Static keyword arguments. Conflict with injected names raises
:class:`~pyflowx.errors.InjectionError`.
retries:
Number of retry attempts on failure. ``0`` means a single attempt.
timeout:
Maximum execution time in seconds. ``None`` disables the timeout.
For async tasks this uses :func:`asyncio.wait_for`; for sync tasks
in the threaded/async executors it cancels the worker future.
tags:
Free-form labels used by :meth:`Graph.subgraph` for selective
execution and debugging.
"""
name: str
fn: TaskFn[T]
depends_on: Tuple[str, ...] = ()
args: Tuple[Any, ...] = ()
kwargs: Mapping[str, Any] = field(default_factory=dict)
retries: int = 0
timeout: Optional[float] = None
tags: Tuple[str, ...] = ()
def __post_init__(self) -> None:
if not self.name:
raise ValueError("TaskSpec.name must be a non-empty string.")
if self.retries < 0:
raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.")
if self.timeout is not None and self.timeout <= 0:
raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.")
if self.name in self.depends_on:
raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.")
@dataclass
class TaskResult(Generic[T]):
"""Mutable per-task record produced during a run.
A fresh :class:`TaskResult` is created for every run; the spec itself
stays immutable. This keeps the same graph safely re-runnable.
"""
spec: TaskSpec[T]
status: TaskStatus = TaskStatus.PENDING
value: Optional[T] = None
error: Optional[BaseException] = None
attempts: int = 0
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
@property
def duration(self) -> Optional[float]:
"""Elapsed seconds between start and finish, or ``None``."""
if self.started_at is None or self.finished_at is None:
return None
return (self.finished_at - self.started_at).total_seconds()
@dataclass(frozen=True)
class TaskEvent:
"""Immutable event emitted during execution for observers.
Passed to the ``on_event`` callback of :func:`pyflowx.run` so callers
can build progress bars, metrics, or structured logs without coupling
to executor internals.
"""
task: str
status: TaskStatus
attempts: int = 0
error: Optional[str] = None
duration: Optional[float] = None
View File
+89
View File
@@ -0,0 +1,89 @@
"""Tests for context injection rules."""
from __future__ import annotations
from typing import Any
import pytest
import pyflowx as px
from pyflowx.context import build_call_args, describe_injection
from pyflowx.errors import InjectionError
def test_inject_by_parameter_name() -> None:
def fn(a: int, b: str) -> str:
return f"{a}{b}"
spec = px.TaskSpec("c", fn, ("a", "b"))
args, kwargs = build_call_args(spec, {"a": 1, "b": "x"})
assert args == ()
assert kwargs == {"a": 1, "b": "x"}
def test_inject_context_annotation() -> None:
def fn(ctx: px.Context) -> int:
return len(ctx)
spec = px.TaskSpec("agg", fn, ("a", "b"))
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
# Only the task's own deps are passed.
assert kwargs == {"ctx": {"a": 1, "b": 2}}
def test_inject_var_keyword() -> None:
def fn(**kwargs: Any) -> int:
return sum(kwargs.values())
spec = px.TaskSpec("agg", fn, ("a", "b"))
args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
assert kwargs == {"a": 1, "b": 2}
def test_static_args_and_kwargs() -> None:
def fn(uid: int, source: str) -> str:
return f"{source}:{uid}"
spec = px.TaskSpec("fetch", fn, args=(42,), kwargs={"source": "api"})
args, kwargs = build_call_args(spec, {})
assert args == (42,)
assert kwargs == {"source": "api"}
def test_default_param_not_required() -> None:
def fn(a: int, flag: bool = True) -> int:
return a if flag else 0
spec = px.TaskSpec("t", fn, ("a",))
args, kwargs = build_call_args(spec, {"a": 5})
assert kwargs == {"a": 5}
def test_unresolved_required_param_raises() -> None:
def fn(a: int, missing: str) -> None:
return None
spec = px.TaskSpec("t", fn, ("a",))
with pytest.raises(InjectionError) as exc_info:
build_call_args(spec, {"a": 1})
assert "missing" in str(exc_info.value)
def test_static_kwargs_collide_with_dependency() -> None:
def fn(a: int) -> int:
return a
spec = px.TaskSpec("t", fn, ("a",), kwargs={"a": 99})
with pytest.raises(InjectionError):
build_call_args(spec, {"a": 1})
def test_describe_injection() -> None:
def fn(a: int, ctx: px.Context, flag: bool = False) -> None:
return None
spec = px.TaskSpec("t", fn, ("a",))
desc = describe_injection(spec)
assert "a=<result:a>" in desc
assert "ctx=<Context>" in desc
assert "flag=<default>" in desc
+322
View File
@@ -0,0 +1,322 @@
"""Tests for execution: sequential, thread, async, retries, timeout, resume."""
from __future__ import annotations
import asyncio
import os
import tempfile
import threading
import time
from typing import Any, List
import pytest
import pyflowx as px
from pyflowx.errors import TaskFailedError, TaskTimeoutError
from pyflowx.storage import JSONBackend, MemoryBackend
# ---------------------------------------------------------------------- #
# Sequential
# ---------------------------------------------------------------------- #
def test_sequential_basic() -> None:
def extract() -> list[int]:
return [1, 2, 3]
def double(extract: list[int]) -> list[int]:
return [x * 2 for x in extract]
graph = px.Graph.from_specs(
[
px.TaskSpec("extract", extract),
px.TaskSpec("double", double, ("extract",)),
]
)
report = px.run(graph, strategy="sequential")
assert report.success
assert report["extract"] == [1, 2, 3]
assert report["double"] == [2, 4, 6]
def test_sequential_diamond() -> None:
order: List[str] = []
def make(name: str) -> Any:
def fn() -> str:
order.append(name)
return name
return fn
graph = px.Graph.from_specs(
[
px.TaskSpec("a", make("a")),
px.TaskSpec("b", make("b"), ("a",)),
px.TaskSpec("c", make("c"), ("a",)),
px.TaskSpec("d", make("d"), ("b", "c")),
]
)
report = px.run(graph, strategy="sequential")
assert report.success
assert report["d"] == "d"
assert order == ["a", "b", "c", "d"]
def test_failure_propagates() -> None:
def boom() -> None:
raise ValueError("kaboom")
def downstream(boom: None) -> int:
return 1
graph = px.Graph.from_specs(
[
px.TaskSpec("boom", boom),
px.TaskSpec("downstream", downstream, ("boom",)),
]
)
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential")
assert exc_info.value.task == "boom"
assert isinstance(exc_info.value.cause, ValueError)
def test_retries_then_succeeds() -> None:
attempts = {"n": 0}
def flaky() -> str:
attempts["n"] += 1
if attempts["n"] < 3:
raise RuntimeError("not yet")
return "ok"
graph = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retries=2)])
report = px.run(graph, strategy="sequential")
assert report.success
assert report["flaky"] == "ok"
assert attempts["n"] == 3
def test_retries_exhausted() -> None:
def always_fail() -> None:
raise RuntimeError("nope")
graph = px.Graph.from_specs([px.TaskSpec("f", always_fail, retries=2)])
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="sequential")
assert exc_info.value.attempts == 3
# ---------------------------------------------------------------------- #
# Threaded
# ---------------------------------------------------------------------- #
def test_threaded_parallelism() -> None:
def slow() -> str:
time.sleep(0.3)
return "done"
graph = px.Graph.from_specs(
[
px.TaskSpec("a", slow),
px.TaskSpec("b", slow),
px.TaskSpec("c", slow),
]
)
start = time.time()
report = px.run(graph, strategy="thread", max_workers=3)
elapsed = time.time() - start
assert report.success
# Three 0.3s tasks in parallel should be well under 0.8s.
assert elapsed < 0.8
def test_threaded_layer_barrier() -> None:
finished: List[str] = []
lock = threading.Lock()
def make(name: str) -> Any:
def fn() -> str:
time.sleep(0.1)
with lock:
finished.append(name)
return name
return fn
graph = px.Graph.from_specs(
[
px.TaskSpec("a", make("a")),
px.TaskSpec("b", make("b")),
px.TaskSpec("c", make("c"), ("a", "b")),
]
)
report = px.run(graph, strategy="thread", max_workers=2)
assert report.success
# c must finish after both a and b.
assert finished.index("c") > finished.index("a")
assert finished.index("c") > finished.index("b")
# ---------------------------------------------------------------------- #
# Async
# ---------------------------------------------------------------------- #
def test_async_basic() -> None:
async def fetch() -> int:
await asyncio.sleep(0.01)
return 42
async def transform(fetch: int) -> int:
return fetch * 2
graph = px.Graph.from_specs(
[
px.TaskSpec("fetch", fetch),
px.TaskSpec("transform", transform, ("fetch",)),
]
)
report = px.run(graph, strategy="async")
assert report.success
assert report["transform"] == 84
def test_async_parallelism() -> None:
async def slow() -> str:
await asyncio.sleep(0.3)
return "done"
graph = px.Graph.from_specs(
[
px.TaskSpec("a", slow),
px.TaskSpec("b", slow),
px.TaskSpec("c", slow),
]
)
start = time.time()
report = px.run(graph, strategy="async")
elapsed = time.time() - start
assert report.success
assert elapsed < 0.8
def test_async_mixed_sync_and_async() -> None:
def sync_task() -> int:
return 10
async def async_task(sync_task: int) -> int:
await asyncio.sleep(0.01)
return sync_task + 5
graph = px.Graph.from_specs(
[
px.TaskSpec("sync_task", sync_task),
px.TaskSpec("async_task", async_task, ("sync_task",)),
]
)
report = px.run(graph, strategy="async")
assert report.success
assert report["async_task"] == 15
def test_async_timeout() -> None:
async def slow() -> None:
await asyncio.sleep(10)
graph = px.Graph.from_specs([px.TaskSpec("slow", slow, timeout=0.05)])
with pytest.raises(TaskFailedError) as exc_info:
px.run(graph, strategy="async")
assert isinstance(exc_info.value.cause, TaskTimeoutError)
# ---------------------------------------------------------------------- #
# Dry run
# ---------------------------------------------------------------------- #
def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
called: List[str] = []
def fn() -> str:
called.append("x")
return "should-not-run"
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
report = px.run(graph, strategy="sequential", dry_run=True)
assert called == []
assert len(report) == 0
out = capsys.readouterr().out
assert "Dry run" in out
assert "Layer 1" in out
# ---------------------------------------------------------------------- #
# State / resume
# ---------------------------------------------------------------------- #
def test_memory_backend_resume() -> None:
runs: List[str] = []
def make(name: str) -> Any:
def fn() -> str:
runs.append(name)
return name
return fn
graph = px.Graph.from_specs(
[
px.TaskSpec("a", make("a")),
px.TaskSpec("b", make("b"), ("a",)),
]
)
backend = MemoryBackend()
px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"]
# Second run: both cached, neither re-executed.
px.run(graph, strategy="sequential", state=backend)
assert runs == ["a", "b"] # unchanged
def test_json_backend_persistence() -> None:
with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json")
def fn() -> int:
return 7
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", state=JSONBackend(path))
# New backend reads the file; task should be skipped.
runs: List[str] = []
def fn2() -> int:
runs.append("ran")
return 8
graph2 = px.Graph.from_specs([px.TaskSpec("a", fn2)])
report = px.run(graph2, strategy="sequential", state=JSONBackend(path))
assert runs == []
assert report["a"] == 7 # cached value, not fn2's 8
# ---------------------------------------------------------------------- #
# Events
# ---------------------------------------------------------------------- #
def test_on_event_callback() -> None:
events: List[px.TaskEvent] = []
def fn() -> int:
return 1
graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
px.run(graph, strategy="sequential", on_event=events.append)
statuses = [e.status for e in events]
assert px.TaskStatus.SUCCESS in statuses
assert all(e.task == "a" for e in events)
# ---------------------------------------------------------------------- #
# Invalid strategy
# ---------------------------------------------------------------------- #
def test_invalid_strategy() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", lambda: None)]) # type: ignore[arg-type]
with pytest.raises(ValueError):
px.run(graph, strategy="bogus") # type: ignore[arg-type]
+131
View File
@@ -0,0 +1,131 @@
"""Tests for Graph construction, validation, layering and subgraphs."""
from __future__ import annotations
import pytest
import pyflowx as px
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
def _fn() -> None:
return None
def test_from_specs_builds_graph() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, ("a",)),
px.TaskSpec("c", _fn, ("a", "b")),
])
assert set(graph.names) == {"a", "b", "c"}
assert graph.dependencies("c") == ("a", "b")
assert len(graph) == 3
assert "a" in graph
def test_from_specs_allows_forward_references() -> None:
# b depends on a, but a is declared after b — order should not matter.
graph = px.Graph.from_specs([
px.TaskSpec("b", _fn, ("a",)),
px.TaskSpec("a", _fn),
])
assert graph.layers() == [["a"], ["b"]]
def test_duplicate_task_raises() -> None:
with pytest.raises(DuplicateTaskError):
px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn),
])
def test_missing_dependency_raises() -> None:
with pytest.raises(MissingDependencyError) as exc_info:
px.Graph.from_specs([px.TaskSpec("b", _fn, ("a",))])
assert exc_info.value.task == "b"
assert exc_info.value.dependency == "a"
def test_cycle_detection() -> None:
with pytest.raises(CycleError):
px.Graph.from_specs([
px.TaskSpec("a", _fn, ("c",)),
px.TaskSpec("b", _fn, ("a",)),
px.TaskSpec("c", _fn, ("b",)),
])
def test_layers_grouping() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn),
px.TaskSpec("c", _fn, ("a", "b")),
px.TaskSpec("d", _fn, ("c",)),
])
layers = graph.layers()
assert layers == [["a", "b"], ["c"], ["d"]]
def test_self_dependency_rejected() -> None:
with pytest.raises(ValueError):
px.TaskSpec("a", _fn, ("a",))
def test_to_mermaid() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, ("a",)),
])
mermaid = graph.to_mermaid()
assert mermaid.startswith("graph TD")
assert 'a["a"]' in mermaid
assert "a --> b" in mermaid
def test_to_mermaid_invalid_orientation() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
with pytest.raises(ValueError):
graph.to_mermaid("XX")
def test_subgraph_by_tags() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn, tags=("ingest",)),
px.TaskSpec("b", _fn, ("a",), tags=("ingest",)),
px.TaskSpec("c", _fn, ("b",), tags=("report",)),
])
sub = graph.subgraph(["ingest"])
assert set(sub.names) == {"a", "b"}
# Edge to dropped task c is removed; b no longer waits for anything
# outside the subgraph (c was never a dep of b anyway).
assert sub.dependencies("b") == ("a",)
def test_subgraph_by_names() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, ("a",)),
px.TaskSpec("c", _fn, ("b",)),
])
sub = graph.subgraph_by_names(["a", "b"])
assert set(sub.names) == {"a", "b"}
# c is dropped, so b's dep on c (none here) — but a->b edge preserved.
assert sub.dependencies("b") == ("a",)
def test_subgraph_by_names_unknown() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
with pytest.raises(KeyError):
graph.subgraph_by_names(["nope"])
def test_describe() -> None:
graph = px.Graph.from_specs([
px.TaskSpec("a", _fn),
px.TaskSpec("b", _fn, ("a",)),
])
desc = graph.describe()
assert "Layer 1" in desc
assert "Layer 2" in desc
Generated
+3401
View File
File diff suppressed because it is too large Load Diff