feat: 初始化PyFlowX轻量级DAG任务调度库

实现完整的DAG任务调度核心功能,包括:
1.  支持同步/异步/线程三种执行策略
2.  自动上下文注入,无需手动绑定任务依赖
3.  内置状态后端,支持断点续跑
4.  提供完整的测试用例与示例代码
5.  添加CI/CD配置与发布流程
This commit is contained in:
2026-06-20 10:41:33 +08:00
parent 70f3c03986
commit 8b7777d936
21 changed files with 6003 additions and 3 deletions
+75
View File
@@ -0,0 +1,75 @@
"""PyFlowX — lightweight, type-safe DAG task scheduler.
Public API
----------
* :class:`TaskSpec` — immutable task descriptor (the only thing you configure).
* :class:`Graph` — DAG built from a list of specs; validates, layers, visualises.
* :func:`run` — execute a graph with ``sequential`` / ``thread`` / ``async``.
* :class:`RunReport` — typed, queryable result of a run.
* :class:`Context` — annotation marker for whole-context injection.
* State backends: :class:`StateBackend`, :class:`MemoryBackend`, :class:`JSONBackend`.
Quick start
-----------
import pyflowx as px
def extract() -> list[int]: return [1, 2, 3]
def double(extract: list[int]) -> list[int]: return [x * 2 for x in extract]
graph = px.Graph.from_specs([
px.TaskSpec("extract", extract),
px.TaskSpec("double", double, ("extract",)),
])
report = px.run(graph, strategy="sequential")
print(report["double"]) # [2, 4, 6]
"""
from __future__ import annotations
from .context import Context, build_call_args, describe_injection
from .errors import (
CycleError,
DuplicateTaskError,
InjectionError,
MissingDependencyError,
PyFlowXError,
StorageError,
TaskFailedError,
TaskTimeoutError,
)
from .executors import run
from .graph import Graph
from .report import RunReport
from .storage import JSONBackend, MemoryBackend, StateBackend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
__version__ = "0.1.0"
__all__ = [
# core types
"TaskSpec",
"TaskStatus",
"TaskResult",
"TaskEvent",
"Context",
"Graph",
"RunReport",
# execution
"run",
# state backends
"StateBackend",
"MemoryBackend",
"JSONBackend",
# errors
"PyFlowXError",
"DuplicateTaskError",
"MissingDependencyError",
"CycleError",
"TaskFailedError",
"TaskTimeoutError",
"InjectionError",
"StorageError",
# helpers (advanced)
"build_call_args",
"describe_injection",
]
+203
View File
@@ -0,0 +1,203 @@
"""Context injection: turn upstream results into function arguments.
This is the mechanism that lets users write plain functions whose
parameter names *are* the dependency declarations, removing the boiler-
plate wrappers that plague other DAG libraries (e.g. ``def wrapper():
return fn(workflow.get_task_result('x'))``).
Injection rules (evaluated in order)
-----------------------------------
1. A parameter whose **annotation is** :class:`Context` receives the full
result mapping. Useful for tasks that need to iterate over all inputs.
2. A parameter whose **name matches a dependency** receives that
dependency's result.
3. A ``**kwargs`` parameter receives *all* dependency results as a dict.
4. ``TaskSpec.args`` / ``TaskSpec.kwargs`` supply static values for
parameters that are *not* dependencies.
If a parameter cannot be resolved and has no default, an
:class:`~pyflowx.errors.InjectionError` is raised with a precise message.
"""
from __future__ import annotations
import inspect
from typing import Any, Dict, List, Mapping, Set, Tuple
from .errors import InjectionError
from .task import Context, TaskSpec
__all__ = ["Context", "build_call_args", "describe_injection"]
def _is_context_annotation(annotation: Any) -> bool:
"""True when a parameter annotation is (or refers to) ``Context``.
Handles three forms:
* the ``Context`` alias object itself;
* a typing alias whose ``__name__``/``_name`` is ``Context`` or ``Mapping``;
* a *string* annotation (``from __future__ import annotations`` makes all
annotations strings at runtime) such as ``"Context"`` or ``"px.Context"``.
"""
if annotation is Context:
return True
# String annotation from `from __future__ import annotations`.
if isinstance(annotation, str):
# Match "Context", "px.Context", "pyflowx.Context", etc.
return annotation == "Context" or annotation.endswith(".Context")
# Match by qualified name to support ``from pyflowx import Context``
# re-exports.
name = getattr(annotation, "__name__", None) or getattr(annotation, "_name", None)
if name in ("Context", "Mapping"):
return True
return False
def build_call_args(
spec: TaskSpec[object],
context: Mapping[str, Any],
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
"""Resolve the ``(args, kwargs)`` to call ``spec.fn`` with.
Parameters
----------
spec:
The task spec, providing ``fn``, ``depends_on``, ``args``, ``kwargs``.
context:
Mapping of dependency-name -> result value. Only the task's own
``depends_on`` entries are guaranteed present; other tasks' results
are excluded to keep injection deterministic.
Returns
-------
(args, kwargs)
Ready to splat into ``spec.fn(*args, **kwargs)``.
Raises
------
InjectionError
If a required parameter cannot be satisfied, or if static
``kwargs`` collide with an injected dependency name.
"""
sig = inspect.signature(spec.fn)
params = sig.parameters
# Detect special parameter kinds.
var_keyword = next(
(p for p in params.values() if p.kind == inspect.Parameter.VAR_KEYWORD),
None,
)
# The subset of context relevant to this task.
dep_context: Dict[str, Any] = {
name: context[name] for name in spec.depends_on if name in context
}
# Detect collisions between static kwargs and dependency names.
collisions = set(spec.kwargs) & set(dep_context)
if collisions:
raise InjectionError(
spec.name,
f"static kwargs {sorted(collisions)} collide with dependency names; "
"rename the static kwarg or the dependency.",
)
injected_kwargs: Dict[str, Any] = {}
leftover_dep_results: Dict[str, Any] = dict(dep_context)
# Positional parameters consumed by spec.args. We track which param
# names are filled positionally so they are skipped during name-based
# injection (dependency / Context / static kwargs).
positional_params: List[str] = []
positional_kinds = (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
for pname, param in params.items():
if param.kind in positional_kinds:
positional_params.append(pname)
# The first len(spec.args) positional params are filled by spec.args.
args_filled: Set[str] = set(positional_params[: len(spec.args)])
for pname, param in params.items():
# Skip params already filled by positional spec.args.
if pname in args_filled:
continue
# Rule 1: annotated as Context -> full mapping.
if _is_context_annotation(param.annotation):
injected_kwargs[pname] = dep_context
continue
# Rule 2: name matches a dependency.
if pname in dep_context:
injected_kwargs[pname] = dep_context[pname]
leftover_dep_results.pop(pname, None)
continue
# Rule 3: handled after the loop via **kwargs.
# Rule 4: static kwargs fill the rest.
if pname in spec.kwargs:
injected_kwargs[pname] = spec.kwargs[pname]
continue
# No source for this parameter: must have a default, else error.
if param.default is inspect.Parameter.empty and param.kind not in (
inspect.Parameter.VAR_POSITIONAL,
inspect.Parameter.VAR_KEYWORD,
):
raise InjectionError(
spec.name,
f"parameter {pname!r} has no dependency, static value, or default.",
)
# Rule 3: **kwargs swallows remaining dependency results.
if var_keyword is not None and leftover_dep_results:
# Merge static kwargs first, then dependency results (static wins
# on collision — but we already rejected collisions above).
merged = dict(spec.kwargs)
merged.update(injected_kwargs)
merged.update(leftover_dep_results)
injected_kwargs = merged
return tuple(spec.args), injected_kwargs
def describe_injection(spec: TaskSpec[object]) -> str:
"""Human-readable description of how a task's args will be injected.
Used by ``dry_run`` to show the execution plan without executing it.
"""
sig = inspect.signature(spec.fn)
# Determine which positional params are filled by spec.args.
positional_params = [
p
for p, param in sig.parameters.items()
if param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
]
args_filled = set(positional_params[: len(spec.args)])
parts = []
for pname, param in sig.parameters.items():
if pname in args_filled:
idx = positional_params.index(pname)
parts.append(f"{pname}={spec.args[idx]!r}")
elif _is_context_annotation(param.annotation):
parts.append(f"{pname}=<Context>")
elif pname in spec.depends_on:
parts.append(f"{pname}=<result:{pname}>")
elif pname in spec.kwargs:
parts.append(f"{pname}={spec.kwargs[pname]!r}")
elif param.default is not inspect.Parameter.empty:
parts.append(f"{pname}=<default>")
elif param.kind == inspect.Parameter.VAR_KEYWORD:
parts.append("**kwargs=<all-deps>")
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
parts.append("*args")
else:
parts.append(f"{pname}=<UNRESOLVED>")
return f"{spec.name}({', '.join(parts)})"
+93
View File
@@ -0,0 +1,93 @@
"""PyFlowX error hierarchy.
All errors are concrete subclasses of :class:`PyFlowXError` so callers can
catch the entire family with a single ``except`` clause, while still being
able to discriminate by type for fine-grained handling.
"""
from __future__ import annotations
from typing import Any, Iterable, Optional
class PyFlowXError(Exception):
"""Base class for every PyFlowX error."""
class DuplicateTaskError(PyFlowXError):
"""Raised when a task name is registered more than once."""
def __init__(self, name: str) -> None:
super().__init__(f"Task '{name}' is already registered in the graph.")
self.name = name
class MissingDependencyError(PyFlowXError):
"""Raised when a task depends on a name that is not in the graph."""
def __init__(self, task: str, dependency: str) -> None:
super().__init__(
f"Task '{task}' depends on unknown task '{dependency}'. "
"Add the dependency before (or together with) this task."
)
self.task = task
self.dependency = dependency
class CycleError(PyFlowXError):
"""Raised when the dependency graph contains a cycle."""
def __init__(self, cycle: Iterable[str]) -> None:
cycle_list = list(cycle)
chain = " -> ".join(cycle_list + cycle_list[:1])
super().__init__(f"The dependency graph contains a cycle: {chain}")
self.cycle = cycle_list
class TaskFailedError(PyFlowXError):
"""Raised when a task fails after exhausting all retries.
The original exception is preserved on :attr:`__cause__` and also exposed
via :attr:`cause` for convenient access in user code.
"""
def __init__(
self,
task: str,
cause: BaseException,
attempts: int,
layer: Optional[int] = None,
) -> None:
location = f" (layer {layer})" if layer is not None else ""
super().__init__(
f"Task '{task}' failed after {attempts} attempt(s){location}: {cause}"
)
self.task = task
self.cause = cause
self.attempts = attempts
self.layer = layer
class TaskTimeoutError(PyFlowXError):
"""Raised when a task exceeds its configured timeout."""
def __init__(self, task: str, timeout: float) -> None:
super().__init__(f"Task '{task}' timed out after {timeout:.3f}s.")
self.task = task
self.timeout = timeout
class InjectionError(PyFlowXError):
"""Raised when context injection cannot satisfy a task signature."""
def __init__(self, task: str, detail: str) -> None:
super().__init__(f"Cannot inject context for task '{task}': {detail}")
self.task = task
class StorageError(PyFlowXError):
"""Raised by state backends on persistence failures."""
def __init__(self, detail: str, cause: Optional[BaseException] = None) -> None:
super().__init__(f"State storage error: {detail}")
self.cause: Any = cause
+425
View File
@@ -0,0 +1,425 @@
"""Executors and the public :func:`run` entry point.
Three execution strategies share a common layer-by-layer driver:
* ``sequential`` — deterministic, one task at a time. Best for debugging.
* ``thread`` — layer-internal concurrency via a thread pool. Best for
I/O-bound sync tasks.
* ``async`` — layer-internal concurrency via ``asyncio.gather``.
Sync tasks are offloaded to a thread pool; async tasks
run on the event loop. Best for I/O-bound async tasks.
All three honour ``retries``, ``timeout``, context injection, state
backends (resume), and emit :class:`~pyflowx.task.TaskEvent` for observers.
"""
from __future__ import annotations
import asyncio
import concurrent.futures
import inspect
import logging
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Mapping, Optional, cast
from .context import build_call_args, describe_injection
from .errors import TaskFailedError, TaskTimeoutError
from .graph import Graph
from .report import RunReport
from .storage import StateBackend, resolve_backend
from .task import TaskEvent, TaskResult, TaskSpec, TaskStatus
logger = logging.getLogger("pyflowx")
# Observer callback type.
EventCallback = Callable[[TaskEvent], None]
# Strategy selector literal.
Strategy = str # "sequential" | "thread" | "async"
def _is_async_fn(spec: TaskSpec[object]) -> bool:
"""True if ``spec.fn`` is a coroutine function."""
return inspect.iscoroutinefunction(spec.fn)
def _emit(
on_event: Optional[EventCallback],
result: TaskResult[object],
) -> None:
"""Fire an observer event if a callback is registered."""
if on_event is None:
return
on_event(
TaskEvent(
task=result.spec.name,
status=result.status,
attempts=result.attempts,
error=repr(result.error) if result.error else None,
duration=result.duration,
)
)
def _run_sync_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
) -> TaskResult[object]:
"""Execute a sync task with retries; return a populated TaskResult."""
result: TaskResult[object] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
args, kwargs = build_call_args(spec, context)
while result.attempts < max_attempts:
result.attempts += 1
try:
result.value = spec.fn(*args, **kwargs)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except Exception as exc: # noqa: BLE001 - user code may raise anything
result.error = exc
if result.attempts >= max_attempts:
break
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
async def _run_async_with_retry(
spec: TaskSpec[object],
context: Mapping[str, Any],
layer_idx: Optional[int],
) -> TaskResult[object]:
"""Execute a task (sync or async) on the event loop with retries."""
result: TaskResult[object] = TaskResult(spec=spec)
result.started_at = datetime.now()
max_attempts = spec.retries + 1
args, kwargs = build_call_args(spec, context)
loop = asyncio.get_event_loop()
while result.attempts < max_attempts:
result.attempts += 1
try:
if _is_async_fn(spec):
coro = cast(Awaitable[Any], spec.fn(*args, **kwargs))
if spec.timeout is not None:
result.value = await asyncio.wait_for(coro, timeout=spec.timeout)
else:
result.value = await coro
else:
# Offload sync work to a thread so the event loop stays alive.
fn_call: Callable[[], Any] = lambda: spec.fn(*args, **kwargs)
if spec.timeout is not None:
result.value = await asyncio.wait_for(
loop.run_in_executor(None, fn_call), timeout=spec.timeout
)
else:
result.value = await loop.run_in_executor(None, fn_call)
result.status = TaskStatus.SUCCESS
result.finished_at = datetime.now()
return result
except asyncio.TimeoutError:
result.error = TaskTimeoutError(spec.name, spec.timeout or 0.0)
if result.attempts >= max_attempts:
break
logger.warning(
"task %r timed out (attempt %d/%d); retrying",
spec.name,
result.attempts,
max_attempts,
)
except Exception as exc: # noqa: BLE001
result.error = exc
if result.attempts >= max_attempts:
break
logger.warning(
"task %r failed (attempt %d/%d): %r; retrying",
spec.name,
result.attempts,
max_attempts,
exc,
)
result.status = TaskStatus.FAILED
result.finished_at = datetime.now()
raise TaskFailedError(
task=spec.name,
cause=result.error if result.error is not None else RuntimeError("unknown"),
attempts=result.attempts,
layer=layer_idx,
)
# ---------------------------------------------------------------------- #
# Layer driver
# ---------------------------------------------------------------------- #
def _build_context(
spec: TaskSpec[object],
global_context: Mapping[str, Any],
) -> Mapping[str, Any]:
"""Restrict the global context to this task's dependencies."""
return {
dep: global_context[dep] for dep in spec.depends_on if dep in global_context
}
def _execute_layer_sequential(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
) -> None:
"""Run a layer's tasks one by one."""
for name in layer:
spec = graph.spec(name)
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(spec=spec, status=TaskStatus.SKIPPED, value=cached)
report.results[name] = result
_emit(on_event, result)
logger.info("task %r skipped (cached)", name)
continue
result = _run_sync_with_retry(spec, _build_context(spec, context), layer_idx)
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
def _execute_layer_threaded(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
max_workers: int,
) -> None:
"""Run a layer's tasks concurrently in a thread pool."""
# First, satisfy cached tasks synchronously.
to_run: List[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
to_run.append(name)
if not to_run:
return
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
future_to_name: Dict[concurrent.futures.Future[TaskResult[object]], str] = {}
for name in to_run:
spec = graph.spec(name)
# Snapshot the context for this task to avoid races.
task_ctx = _build_context(spec, context)
fut = pool.submit(_run_sync_with_retry, spec, task_ctx, layer_idx)
future_to_name[fut] = name
for fut in concurrent.futures.as_completed(future_to_name):
name = future_to_name[fut]
result = fut.result() # raises TaskFailedError on failure
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
async def _execute_layer_async(
layer: List[str],
graph: Graph,
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
layer_idx: int,
on_event: Optional[EventCallback],
) -> None:
"""Run a layer's tasks concurrently on the event loop."""
to_run: List[str] = []
for name in layer:
if backend.has(name):
cached = backend.get(name)
context[name] = cached
result = TaskResult(
spec=graph.spec(name), status=TaskStatus.SKIPPED, value=cached
)
report.results[name] = result
_emit(on_event, result)
else:
to_run.append(name)
if not to_run:
return
coros = []
for name in to_run:
spec = graph.spec(name)
task_ctx = _build_context(spec, context)
coros.append(_run_async_with_retry(spec, task_ctx, layer_idx))
results = await asyncio.gather(*coros)
for name, result in zip(to_run, results):
context[name] = result.value
backend.save(name, result.value)
report.results[name] = result
_emit(on_event, result)
# ---------------------------------------------------------------------- #
# Public API
# ---------------------------------------------------------------------- #
def run(
graph: Graph,
strategy: Strategy = "sequential",
*,
max_workers: Optional[int] = None,
dry_run: bool = False,
on_event: Optional[EventCallback] = None,
state: Optional[StateBackend] = None,
) -> RunReport:
"""Execute a graph and return a :class:`RunReport`.
Parameters
----------
graph:
The validated :class:`Graph` to execute.
strategy:
``"sequential"`` (default), ``"thread"``, or ``"async"``.
max_workers:
Thread-pool size for ``"thread"``. Defaults to ``min(32, len(layer))``.
dry_run:
If ``True``, print the execution plan (layers + injection) and
return an empty report without executing anything.
on_event:
Optional callback invoked on every status transition.
state:
Optional :class:`StateBackend` for resumable runs. Defaults to an
in-memory backend (no persistence across processes).
Raises
------
ValueError
If ``strategy`` is not recognised.
TaskFailedError
If any task fails after exhausting retries. The run aborts at the
failing layer; tasks in later layers are not attempted.
"""
if strategy not in ("sequential", "thread", "async"):
raise ValueError(
f"unknown strategy {strategy!r}; expected 'sequential', 'thread', or 'async'."
)
graph.validate()
layers = graph.layers()
if dry_run:
_print_dry_run(graph, layers)
return RunReport(success=True)
backend = resolve_backend(state)
report = RunReport()
context: Dict[str, Any] = {}
try:
if strategy == "sequential":
_drive_sequential(graph, layers, context, report, backend, on_event)
elif strategy == "thread":
_drive_threaded(
graph, layers, context, report, backend, on_event, max_workers
)
else:
_drive_async(graph, layers, context, report, backend, on_event)
except TaskFailedError:
report.success = False
raise
return report
def _print_dry_run(graph: Graph, layers: List[List[str]]) -> None:
"""Print the execution plan without running anything."""
print(f"Dry run: {len(graph)} tasks, {len(layers)} layers")
for idx, layer in enumerate(layers, 1):
print(f" Layer {idx}: {layer}")
for name in layer:
print(f" - {describe_injection(graph.spec(name))}")
def _drive_sequential(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
for idx, layer in enumerate(layers, 1):
_execute_layer_sequential(layer, graph, context, report, backend, idx, on_event)
def _drive_threaded(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
max_workers: Optional[int],
) -> None:
for idx, layer in enumerate(layers, 1):
workers = max_workers or max(1, min(32, len(layer)))
_execute_layer_threaded(
layer, graph, context, report, backend, idx, on_event, workers
)
def _drive_async(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
asyncio.run(_async_drive(graph, layers, context, report, backend, on_event))
async def _async_drive(
graph: Graph,
layers: List[List[str]],
context: Dict[str, Any],
report: RunReport,
backend: StateBackend,
on_event: Optional[EventCallback],
) -> None:
for idx, layer in enumerate(layers, 1):
await _execute_layer_async(
layer, graph, context, report, backend, idx, on_event
)
+245
View File
@@ -0,0 +1,245 @@
"""DAG construction, validation, layering and visualisation.
Uses :mod:`graphlib` from the standard library (3.9+) or
:mod:`graphlib_backport` (3.8) for topological sorting. The graph is
built incrementally and validated eagerly so that misconfiguration fails
fast — at construction time, not at execution time.
"""
from __future__ import annotations
import sys
from typing import Dict, Iterable, List, Mapping, Sequence, Set, Tuple
from .errors import CycleError, DuplicateTaskError, MissingDependencyError
from .task import TaskSpec
# graphlib lives in the stdlib since 3.9; fall back to the backport on 3.8.
if sys.version_info >= (3, 9):
import graphlib
_TopologicalSorter = graphlib.TopologicalSorter
else: # pragma: no cover - exercised only on 3.8
import graphlib # type: ignore[no-redef]
_TopologicalSorter = graphlib.TopologicalSorter
class Graph:
"""An immutable-after-validation directed acyclic graph of tasks.
The graph is built by adding :class:`~pyflowx.task.TaskSpec` instances.
Each ``add`` performs eager validation (duplicate names, missing
dependencies), and :meth:`validate` / :meth:`layers` perform full DAG
validation (cycle detection) and topological layering.
The graph holds only the *configuration*; runtime state lives in
:class:`~pyflowx.report.RunReport`. This makes a graph safely
re-runnable and shareable across threads.
"""
def __init__(self) -> None:
self._specs: Dict[str, TaskSpec[object]] = {}
# Map task -> its direct dependencies (predecessors).
self._deps: Dict[str, Tuple[str, ...]] = {}
# ------------------------------------------------------------------ #
# Construction
# ------------------------------------------------------------------ #
def add(self, spec: TaskSpec[object]) -> "Graph":
"""Register a task spec with eager validation.
Returns ``self`` so calls can be chained, but the recommended
entry point is :meth:`from_specs` which validates the whole batch
together (allowing forward references in a single call).
"""
self._specs[spec.name] = spec
self._deps[spec.name] = spec.depends_on
# Eagerly check duplicates and missing deps for the incremental API.
self._validate_references()
return self
@classmethod
def from_specs(cls, specs: Iterable[TaskSpec[object]]) -> "Graph":
"""Build a graph from an iterable of task specs.
All specs are collected first, then validated together. This means
a task may reference a dependency that appears *later* in the
iterable — order does not matter, mirroring how a declarative
config file reads.
"""
graph = cls()
for spec in specs:
if spec.name in graph._specs:
raise DuplicateTaskError(spec.name)
graph._specs[spec.name] = spec
graph._deps[spec.name] = spec.depends_on
graph._validate_references()
graph.validate()
return graph
# ------------------------------------------------------------------ #
# Validation
# ------------------------------------------------------------------ #
def _validate_references(self) -> None:
"""Ensure every dependency name exists in the graph."""
for name, deps in self._deps.items():
for dep in deps:
if dep not in self._specs:
raise MissingDependencyError(name, dep)
def validate(self) -> None:
"""Run full DAG validation.
Raises :class:`~pyflowx.errors.CycleError` if a cycle exists.
Dependency existence is checked by :meth:`_validate_references`.
"""
self._validate_references()
sorter = _TopologicalSorter(self._deps)
try:
# prepare() raises CycleError on cycles; we don't need the
# static_order() result here, just the validation side effect.
sorter.prepare()
except graphlib.CycleError as exc:
# exc.args[1] is the list of nodes forming the cycle.
cycle: Sequence[str] = exc.args[1] if len(exc.args) > 1 else []
raise CycleError(list(cycle)) from exc
# ------------------------------------------------------------------ #
# Introspection
# ------------------------------------------------------------------ #
@property
def names(self) -> List[str]:
"""All registered task names (insertion order)."""
return list(self._specs.keys())
def spec(self, name: str) -> TaskSpec[object]:
"""Return the spec for ``name``; ``KeyError`` if absent."""
return self._specs[name]
def dependencies(self, name: str) -> Tuple[str, ...]:
"""Direct predecessors of ``name``."""
return self._deps[name]
def all_specs(self) -> Mapping[str, TaskSpec[object]]:
"""Read-only view of name -> spec."""
return self._specs
def layers(self) -> List[List[str]]:
"""Group tasks into parallel-executable layers (Kahn's algorithm).
Tasks within the same layer have no mutual dependencies and may
run concurrently. Layers are returned in execution order.
Raises :class:`~pyflowx.errors.CycleError` if the graph is cyclic.
"""
self.validate()
sorter = _TopologicalSorter(self._deps)
result: List[List[str]] = []
# ``get_ready`` + ``done`` gives us one layer at a time, which is
# exactly the parallel-execution grouping we need.
sorter.prepare()
while sorter.is_active():
ready = list(sorter.get_ready())
# Sort for deterministic, reproducible execution plans.
ready.sort()
result.append(ready)
for node in ready:
sorter.done(node)
return result
# ------------------------------------------------------------------ #
# Subgraph / tag filtering
# ------------------------------------------------------------------ #
def subgraph(self, tags: Iterable[str]) -> "Graph":
"""Return a new graph containing only tasks matching any tag.
Dependencies are pruned to keep only edges between retained tasks;
edges to dropped tasks are removed (the retained task no longer
waits for them). Use this to run a slice of a large DAG for
debugging.
"""
wanted: Set[str] = set(tags)
kept: List[TaskSpec[object]] = []
for spec in self._specs.values():
if wanted & set(spec.tags):
pruned_deps = tuple(
d for d in spec.depends_on if d in self._specs and (wanted & set(self._specs[d].tags))
)
kept.append(
TaskSpec(
name=spec.name,
fn=spec.fn,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
)
)
return Graph.from_specs(kept)
def subgraph_by_names(self, names: Iterable[str]) -> "Graph":
"""Return a new graph restricted to ``names`` (with pruned edges)."""
wanted: Set[str] = set(names)
for n in wanted:
if n not in self._specs:
raise KeyError(f"Unknown task name: {n!r}")
kept: List[TaskSpec[object]] = []
for spec in self._specs.values():
if spec.name in wanted:
pruned_deps = tuple(d for d in spec.depends_on if d in wanted)
kept.append(
TaskSpec(
name=spec.name,
fn=spec.fn,
depends_on=pruned_deps,
args=spec.args,
kwargs=spec.kwargs,
retries=spec.retries,
timeout=spec.timeout,
tags=spec.tags,
)
)
return Graph.from_specs(kept)
# ------------------------------------------------------------------ #
# Visualisation
# ------------------------------------------------------------------ #
def to_mermaid(self, orientation: str = "TD") -> str:
"""Render the DAG as a Mermaid ``graph`` definition string.
No external dependencies; the output can be pasted into Markdown,
rendered by VS Code's Mermaid previewer, or saved to a file.
"""
valid = {"TD", "TB", "BT", "LR", "RL"}
orientation = orientation.upper()
if orientation not in valid:
raise ValueError(f"Invalid orientation {orientation!r}; expected one of {sorted(valid)}.")
lines: List[str] = [f"graph {orientation}"]
for name in self._specs:
lines.append(f' {name}["{name}"]')
for name, deps in self._deps.items():
for dep in deps:
lines.append(f" {dep} --> {name}")
return "\n".join(lines) + "\n"
# ------------------------------------------------------------------ #
# Debug
# ------------------------------------------------------------------ #
def describe(self) -> str:
"""Human-readable multi-line summary for debugging."""
out: List[str] = [f"Graph(tasks={len(self._specs)})"]
for layer_idx, layer in enumerate(self.layers(), 1):
out.append(f" Layer {layer_idx}: {layer}")
return "\n".join(out)
def __repr__(self) -> str:
return f"Graph(tasks={len(self._specs)})"
def __len__(self) -> int:
return len(self._specs)
def __contains__(self, name: object) -> bool:
return name in self._specs
View File
+82
View File
@@ -0,0 +1,82 @@
"""Run report: typed, queryable result of a single :func:`pyflowx.run`.
The report is the single source of truth after execution. It exposes
per-task results via ``report["name"]`` (typed as ``Any`` because the
mapping is heterogeneous), summary statistics, and a flag indicating
whether the whole run succeeded.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, Iterator, List, Mapping, Optional
from .task import TaskResult, TaskStatus
@dataclass
class RunReport:
"""Aggregated outcome of a workflow run.
Attributes
----------
results:
Mapping of task name -> :class:`TaskResult`. Insertion order
matches the order tasks finished.
success:
``True`` iff every non-skipped task ended in ``SUCCESS``.
"""
results: Dict[str, TaskResult[object]] = field(default_factory=dict)
success: bool = True
# ---- typed access ------------------------------------------------- #
def __getitem__(self, name: str) -> Any:
"""Return the *value* of task ``name`` (not the TaskResult).
Raises ``KeyError`` if the task was not part of the run. Returns
``None`` for tasks that did not reach SUCCESS.
"""
return self.results[name].value
def result_of(self, name: str) -> TaskResult[object]:
"""Return the full :class:`TaskResult` for ``name``."""
return self.results[name]
def __contains__(self, name: object) -> bool:
return name in self.results
def __iter__(self) -> Iterator[str]:
return iter(self.results)
def __len__(self) -> int:
return len(self.results)
# ---- summary ------------------------------------------------------ #
def summary(self) -> Dict[str, Any]:
"""Compact statistics dict for logging / dashboards."""
counts: Dict[str, int] = {}
total_duration = 0.0
for r in self.results.values():
counts[r.status.value] = counts.get(r.status.value, 0) + 1
if r.duration is not None:
total_duration += r.duration
return {
"success": self.success,
"total_tasks": len(self.results),
"by_status": counts,
"total_duration_seconds": round(total_duration, 6),
}
def failed_tasks(self) -> List[str]:
"""Names of tasks that ended in FAILED status."""
return [name for name, r in self.results.items() if r.status == TaskStatus.FAILED]
def describe(self) -> str:
"""Human-readable multi-line report for debugging."""
lines: List[str] = [f"RunReport(success={self.success})"]
for name, r in self.results.items():
dur = f"{r.duration:.3f}s" if r.duration is not None else "-"
err = f" error={r.error!r}" if r.error else ""
lines.append(f" {name}: {r.status.value} ({dur} attempts={r.attempts}){err}")
return "\n".join(lines)
+135
View File
@@ -0,0 +1,135 @@
"""State persistence backends for resumable runs.
A :class:`StateBackend` stores the result of every successfully completed
task. On a subsequent run, the executor asks the backend whether a task
already has a stored result; if so, the task is skipped and its stored
value is injected into downstream tasks.
This is intentionally minimal: only *successful* results are persisted
(failed tasks are re-run), and the storage shape is a flat
``{task_name: result}`` mapping. Two backends ship in-tree:
* :class:`MemoryBackend` — fast, in-process, no I/O. Default.
* :class:`JSONBackend` — persists to a JSON file for cross-process resume.
Both are zero-dependency (``json`` is stdlib). Users can subclass
:class:`StateBackend` to plug in SQLite, Redis, etc.
"""
from __future__ import annotations
import json
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Mapping, Optional
from .errors import StorageError
class StateBackend(ABC):
"""Abstract base for resumable state storage."""
@abstractmethod
def load(self) -> Mapping[str, Any]:
"""Return the full stored mapping (may be empty)."""
@abstractmethod
def save(self, name: str, value: Any) -> None:
"""Persist a single task's successful result."""
@abstractmethod
def has(self, name: str) -> bool:
"""Whether ``name`` has a stored result."""
@abstractmethod
def get(self, name: str) -> Any:
"""Return the stored result for ``name`` (raise ``KeyError`` if absent)."""
@abstractmethod
def clear(self) -> None:
"""Remove all stored state."""
class MemoryBackend(StateBackend):
"""In-process dict backend. Lost when the process exits."""
def __init__(self) -> None:
self._store: Dict[str, Any] = {}
def load(self) -> Mapping[str, Any]:
return dict(self._store)
def save(self, name: str, value: Any) -> None:
self._store[name] = value
def has(self, name: str) -> bool:
return name in self._store
def get(self, name: str) -> Any:
return self._store[name]
def clear(self) -> None:
self._store.clear()
class JSONBackend(StateBackend):
"""File-backed JSON storage for cross-process resume.
Results must be JSON-serialisable. Non-serialisable values raise
:class:`~pyflowx.errors.StorageError` (the run itself is not aborted;
only persistence of that one result fails).
"""
def __init__(self, path: str) -> None:
self._path = path
self._store: Dict[str, Any] = {}
self._load()
def _load(self) -> None:
if not os.path.exists(self._path):
return
try:
with open(self._path, "r", encoding="utf-8") as fh:
data = json.load(fh)
if isinstance(data, dict):
self._store = data
except (OSError, json.JSONDecodeError) as exc:
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
def _flush(self) -> None:
tmp = self._path + ".tmp"
try:
with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2)
os.replace(tmp, self._path)
except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
def load(self) -> Mapping[str, Any]:
return dict(self._store)
def save(self, name: str, value: Any) -> None:
# Validate serialisability before mutating in-memory state.
try:
json.dumps(value)
except (TypeError, ValueError) as exc:
raise StorageError(
f"result of task {name!r} is not JSON-serialisable", exc
) from exc
self._store[name] = value
self._flush()
def has(self, name: str) -> bool:
return name in self._store
def get(self, name: str) -> Any:
return self._store[name]
def clear(self) -> None:
self._store.clear()
self._flush()
def resolve_backend(backend: Optional[StateBackend]) -> StateBackend:
"""Return ``backend`` or a fresh :class:`MemoryBackend` if ``None``."""
return backend if backend is not None else MemoryBackend()
+151
View File
@@ -0,0 +1,151 @@
"""Core task data structures for PyFlowX.
Everything here is a plain, immutable data structure — no decorators, no
side effects. A :class:`TaskSpec` fully describes a task node; the
:class:`Graph` (see :mod:`pyflowx.graph`) consumes a list of specs and
builds the DAG.
Design notes
------------
* ``TaskSpec`` is a ``Generic[T]`` so that ``TaskSpec[int]`` carries the
return type of ``fn`` all the way to :class:`RunReport`, giving callers
typed access to ``report["name"]``.
* ``Context`` is the only intentionally-dynamic type: results from
upstream tasks are heterogeneous, so the cross-task mapping is
``Mapping[str, Any]``. Within a single task the types remain fully
static because the function signature is checked by mypy.
* ``TaskStatus`` is a closed enum; executors never invent ad-hoc strings.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import (
Any,
Callable,
Coroutine,
Generic,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)
T = TypeVar("T")
# A task callable may be synchronous or asynchronous. We keep the union
# explicit so mypy understands both shapes.
TaskFn = Union[
Callable[..., T],
Callable[..., Coroutine[Any, Any, T]],
]
# The cross-task result mapping. Deliberately ``Any`` for values because
# different tasks return different types; per-task typing is preserved by
# the function signature itself.
Context = Mapping[str, Any]
class TaskStatus(Enum):
"""Lifecycle states of a task during a single run."""
PENDING = "pending"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
SKIPPED = "skipped" # used by resumable runs and subgraph filtering
@dataclass(frozen=True)
class TaskSpec(Generic[T]):
"""Immutable description of a single DAG node.
Parameters
----------
name:
Unique identifier of the task within a graph. Other tasks reference
this name in ``depends_on``.
fn:
The callable to execute. May be sync or async. Its parameter names
drive automatic context injection (see :mod:`pyflowx.context`).
depends_on:
Names of tasks whose results must be available before this task
runs. Order is irrelevant; the framework topologically sorts.
args:
Static positional arguments appended *after* injected parameters.
Useful for parameterised tasks (e.g. ``fetch_user(uid)``).
kwargs:
Static keyword arguments. Conflict with injected names raises
:class:`~pyflowx.errors.InjectionError`.
retries:
Number of retry attempts on failure. ``0`` means a single attempt.
timeout:
Maximum execution time in seconds. ``None`` disables the timeout.
For async tasks this uses :func:`asyncio.wait_for`; for sync tasks
in the threaded/async executors it cancels the worker future.
tags:
Free-form labels used by :meth:`Graph.subgraph` for selective
execution and debugging.
"""
name: str
fn: TaskFn[T]
depends_on: Tuple[str, ...] = ()
args: Tuple[Any, ...] = ()
kwargs: Mapping[str, Any] = field(default_factory=dict)
retries: int = 0
timeout: Optional[float] = None
tags: Tuple[str, ...] = ()
def __post_init__(self) -> None:
if not self.name:
raise ValueError("TaskSpec.name must be a non-empty string.")
if self.retries < 0:
raise ValueError(f"TaskSpec '{self.name}': retries must be >= 0.")
if self.timeout is not None and self.timeout <= 0:
raise ValueError(f"TaskSpec '{self.name}': timeout must be > 0.")
if self.name in self.depends_on:
raise ValueError(f"TaskSpec '{self.name}' cannot depend on itself.")
@dataclass
class TaskResult(Generic[T]):
"""Mutable per-task record produced during a run.
A fresh :class:`TaskResult` is created for every run; the spec itself
stays immutable. This keeps the same graph safely re-runnable.
"""
spec: TaskSpec[T]
status: TaskStatus = TaskStatus.PENDING
value: Optional[T] = None
error: Optional[BaseException] = None
attempts: int = 0
started_at: Optional[datetime] = None
finished_at: Optional[datetime] = None
@property
def duration(self) -> Optional[float]:
"""Elapsed seconds between start and finish, or ``None``."""
if self.started_at is None or self.finished_at is None:
return None
return (self.finished_at - self.started_at).total_seconds()
@dataclass(frozen=True)
class TaskEvent:
"""Immutable event emitted during execution for observers.
Passed to the ``on_event`` callback of :func:`pyflowx.run` so callers
can build progress bars, metrics, or structured logs without coupling
to executor internals.
"""
task: str
status: TaskStatus
attempts: int = 0
error: Optional[str] = None
duration: Optional[float] = None