Files
pyflowx/tests/test_executors.py
T
zhou 65dcbcbf62
Release / build (push) Failing after 16m3s
Release / publish-pypi (push) Has been skipped
Release / release (push) Has been skipped
bump version to 0.2.9
2026-06-27 16:42:10 +08:00

613 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Tests for execution: sequential, thread, async, retries, timeout, resume."""
from __future__ import annotations
import asyncio
import logging
import tempfile
import threading
import time
from pathlib import Path
from typing import Any
import pytest
import pyflowx as px
from pyflowx.errors import TaskFailedError, TaskTimeoutError
from pyflowx.storage import JSONBackend, MemoryBackend
# ---------------------------------------------------------------------- #
# Sequential
# ---------------------------------------------------------------------- #
def test_sequential_basic() -> None:
    """Run a two-step linear pipeline sequentially and check both results."""

    def extract() -> list[int]:
        return [1, 2, 3]

    def double(extract: list[int]) -> list[int]:
        return [value * 2 for value in extract]

    specs = [
        px.TaskSpec("extract", extract),
        px.TaskSpec("double", double, depends_on=("extract",)),
    ]
    pipeline = px.Graph.from_specs(specs)

    outcome = px.run(pipeline, strategy="sequential")

    assert outcome.success
    assert outcome["extract"] == [1, 2, 3]
    assert outcome["double"] == [2, 4, 6]
def test_sequential_diamond() -> None:
    """Run a diamond graph (a -> b, c -> d) and expect the order a, b, c, d."""
    executed: list[str] = []

    def make(name: str) -> Any:
        # Each generated task logs its own name so the call order can be checked.
        def fn() -> str:
            executed.append(name)
            return name

        return fn

    specs = [
        px.TaskSpec("a", make("a")),
        px.TaskSpec("b", make("b"), depends_on=("a",)),
        px.TaskSpec("c", make("c"), depends_on=("a",)),
        px.TaskSpec("d", make("d"), depends_on=("b", "c")),
    ]
    diamond = px.Graph.from_specs(specs)

    outcome = px.run(diamond, strategy="sequential")

    assert outcome.success
    assert outcome["d"] == "d"
    assert executed == ["a", "b", "c", "d"]
def test_failure_propagates() -> None:
    """A raising task surfaces as TaskFailedError naming the task and its cause."""

    def boom() -> None:
        raise ValueError("kaboom")

    def downstream(_boom: None) -> int:
        return 1

    specs = [
        px.TaskSpec("boom", boom),
        px.TaskSpec("downstream", downstream, depends_on=("boom",)),
    ]
    failing_graph = px.Graph.from_specs(specs)

    with pytest.raises(TaskFailedError) as exc_info:
        px.run(failing_graph, strategy="sequential")

    failure = exc_info.value
    assert failure.task == "boom"
    assert isinstance(failure.cause, ValueError)
def test_retries_then_succeeds() -> None:
    """A task that fails twice succeeds on the third attempt its policy allows."""
    call_count = 0

    def flaky() -> str:
        nonlocal call_count
        call_count += 1
        if call_count >= 3:
            return "ok"
        raise RuntimeError("not yet")

    policy = px.RetryPolicy(max_attempts=3)
    single_task = px.Graph.from_specs([px.TaskSpec("flaky", flaky, retry=policy)])

    outcome = px.run(single_task, strategy="sequential")

    assert outcome.success
    assert outcome["flaky"] == "ok"
    assert call_count == 3
def test_retries_with_delay() -> None:
    """A retry policy with ``delay`` is expected to actually wait between attempts.

    The task fails once and succeeds on the second attempt, so the run should
    take at least the configured 0.1s delay.
    """
    attempts = {"n": 0}

    def flaky() -> str:
        attempts["n"] += 1
        if attempts["n"] < 2:
            raise RuntimeError("not yet")
        return "ok"

    graph = px.Graph.from_specs([
        px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=2, delay=0.1)),
    ])

    # Use a monotonic, high-resolution clock: time.time() follows the wall
    # clock and may jump (NTP/DST adjustments), which would corrupt the
    # measured interval. Start timing right before the run so only the
    # execution itself is measured.
    started = time.perf_counter()
    report = px.run(graph, strategy="sequential")
    elapsed = time.perf_counter() - started

    assert report.success
    assert elapsed >= 0.1  # at least one 0.1s retry delay must have elapsed
    assert attempts["n"] == 2
def test_timeout_then_retry_async(caplog: pytest.LogCaptureFixture) -> None:
    """A timed-out async task is retried and a warning is logged for the timeout."""

    async def slow_task() -> str:
        await asyncio.sleep(10)  # far longer than the 0.2s timeout below
        return "ok"

    retry_twice = px.RetryPolicy(max_attempts=2)
    graph = px.Graph.from_specs([
        px.TaskSpec("slow", slow_task, timeout=0.2, retry=retry_twice),
    ])

    with caplog.at_level(logging.WARNING, logger="pyflowx"), pytest.raises(
        px.TaskFailedError
    ) as exc_info:
        px.run(graph, strategy="async")

    failure = exc_info.value
    assert failure.attempts == 2
    assert "timed out" in str(failure.cause)

    # A warning mentioning the timeout should have been logged.
    logged_messages = [record.message for record in caplog.records]
    assert any("timed out" in message for message in logged_messages)
def test_retries_exhausted() -> None:
    """A task that always fails raises after using every allowed attempt."""

    def always_fail() -> None:
        raise RuntimeError("nope")

    policy = px.RetryPolicy(max_attempts=3)
    doomed = px.Graph.from_specs([px.TaskSpec("f", always_fail, retry=policy)])

    with pytest.raises(TaskFailedError) as exc_info:
        px.run(doomed, strategy="sequential")

    assert exc_info.value.attempts == 3
# ---------------------------------------------------------------------- #
# Threaded
# ---------------------------------------------------------------------- #
@pytest.mark.slow
def test_threaded_parallelism() -> None:
    """Three independent 0.3s tasks should overlap under the thread strategy.

    A wall-clock bound alone cannot prove this: running the tasks one after
    another takes about 0.9s and would still satisfy ``elapsed < 1.0``. The
    test therefore also records how many tasks were running at the same time.
    """
    counter_lock = threading.Lock()
    running = 0
    peak = 0

    def slow() -> str:
        nonlocal running, peak
        with counter_lock:
            running += 1
            peak = max(peak, running)
        try:
            time.sleep(0.3)
        finally:
            with counter_lock:
                running -= 1
        return "done"

    graph = px.Graph.from_specs([
        px.TaskSpec("a", slow),
        px.TaskSpec("b", slow),
        px.TaskSpec("c", slow),
    ])

    # perf_counter is monotonic, unlike time.time(), so the interval cannot
    # be distorted by wall-clock adjustments.
    start = time.perf_counter()
    report = px.run(graph, strategy="thread", max_workers=3)
    elapsed = time.perf_counter() - start

    assert report.success
    # At least two tasks must have been in flight at the same moment.
    assert peak > 1
    # Three 0.3s tasks in parallel should be well under 1.0s.
    assert elapsed < 1.0
@pytest.mark.slow
def test_threaded_layer_barrier() -> None:
    """Under the thread strategy a task finishes only after all its dependencies."""
    completion_order: list[str] = []
    guard = threading.Lock()

    def make(name: str) -> Any:
        def fn() -> str:
            time.sleep(0.1)
            # Appends come from worker threads, so serialize them.
            with guard:
                completion_order.append(name)
            return name

        return fn

    specs = [
        px.TaskSpec("a", make("a")),
        px.TaskSpec("b", make("b")),
        px.TaskSpec("c", make("c"), depends_on=("a", "b")),
    ]
    graph = px.Graph.from_specs(specs)

    outcome = px.run(graph, strategy="thread", max_workers=2)

    assert outcome.success
    # c must finish after both a and b.
    position_of_c = completion_order.index("c")
    assert position_of_c > completion_order.index("a")
    assert position_of_c > completion_order.index("b")
# ---------------------------------------------------------------------- #
# Async
# ---------------------------------------------------------------------- #
def test_async_basic() -> None:
    """Two chained coroutine tasks run under the async strategy."""

    async def fetch() -> int:
        await asyncio.sleep(0.01)
        return 42

    async def transform(fetch: int) -> int:
        return fetch * 2

    specs = [
        px.TaskSpec("fetch", fetch),
        px.TaskSpec("transform", transform, depends_on=("fetch",)),
    ]
    pipeline = px.Graph.from_specs(specs)

    outcome = px.run(pipeline, strategy="async")

    assert outcome.success
    assert outcome["transform"] == 84
@pytest.mark.slow
def test_async_parallelism() -> None:
    """Three independent 0.3s coroutines should overlap under the async strategy.

    The wall-clock bound is deliberately loose (1.5s) to tolerate CI jitter,
    but serial execution only takes about 0.9s and would pass it too. Overlap
    is therefore checked directly by tracking how many tasks run at once.
    """
    # Guards the counters in case the executor drives coroutines from more
    # than one thread; it is never held across an await.
    counter_lock = threading.Lock()
    running = 0
    peak = 0

    async def slow() -> str:
        nonlocal running, peak
        with counter_lock:
            running += 1
            peak = max(peak, running)
        try:
            await asyncio.sleep(0.3)
        finally:
            with counter_lock:
                running -= 1
        return "done"

    graph = px.Graph.from_specs([
        px.TaskSpec("a", slow),
        px.TaskSpec("b", slow),
        px.TaskSpec("c", slow),
    ])

    # perf_counter is monotonic, unlike time.time().
    start = time.perf_counter()
    report = px.run(graph, strategy="async")
    elapsed = time.perf_counter() - start

    assert report.success
    # At least two coroutines must have been in flight at the same moment.
    assert peak > 1
    # Loose sanity bound only (ideal ~0.3s); it does not by itself prove overlap.
    assert elapsed < 1.5
def test_async_mixed_sync_and_async() -> None:
    """A plain function and a coroutine can be chained under the async strategy."""

    def sync_task() -> int:
        return 10

    async def async_task(sync_task: int) -> int:
        await asyncio.sleep(0.01)
        return sync_task + 5

    specs = [
        px.TaskSpec("sync_task", sync_task),
        px.TaskSpec("async_task", async_task, depends_on=("sync_task",)),
    ]
    mixed = px.Graph.from_specs(specs)

    outcome = px.run(mixed, strategy="async")

    assert outcome.success
    assert outcome["async_task"] == 15
def test_async_timeout() -> None:
    """A coroutine exceeding its timeout fails with TaskTimeoutError as the cause."""

    async def slow() -> None:
        await asyncio.sleep(10)

    spec = px.TaskSpec("slow", slow, timeout=0.05)
    graph = px.Graph.from_specs([spec])

    with pytest.raises(TaskFailedError) as exc_info:
        px.run(graph, strategy="async")

    cause = exc_info.value.cause
    assert isinstance(cause, TaskTimeoutError)
# ---------------------------------------------------------------------- #
# Dry run
# ---------------------------------------------------------------------- #
def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
    """With dry_run=True the task is not called, the report is empty, and a plan is printed."""
    invocations: list[str] = []

    def fn() -> str:
        invocations.append("x")
        return "should-not-run"

    graph = px.Graph.from_specs([px.TaskSpec("a", fn)])

    outcome = px.run(graph, strategy="sequential", dry_run=True)

    assert invocations == []
    assert len(outcome) == 0

    printed = capsys.readouterr().out
    for expected_fragment in ("Dry run", "Layer 1"):
        assert expected_fragment in printed
# ---------------------------------------------------------------------- #
# State / resume
# ---------------------------------------------------------------------- #
def test_memory_backend_resume() -> None:
    """A second run sharing the same MemoryBackend does not re-execute tasks."""
    executed: list[str] = []

    def make(name: str) -> Any:
        def fn() -> str:
            executed.append(name)
            return name

        return fn

    specs = [
        px.TaskSpec("a", make("a")),
        px.TaskSpec("b", make("b"), depends_on=("a",)),
    ]
    graph = px.Graph.from_specs(specs)
    shared_state = MemoryBackend()

    # First run: both tasks execute.
    px.run(graph, strategy="sequential", state=shared_state)
    assert executed == ["a", "b"]

    # Second run: both cached, neither re-executed.
    px.run(graph, strategy="sequential", state=shared_state)
    assert executed == ["a", "b"]  # unchanged
def test_json_backend_persistence() -> None:
    """State written by one JSONBackend is honoured by a fresh one on the same file."""
    with tempfile.TemporaryDirectory() as workdir:
        state_file = str(Path(workdir, "state.json"))

        def fn() -> int:
            return 7

        first_graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
        px.run(first_graph, strategy="sequential", state=JSONBackend(state_file))

        # New backend reads the file; task should be skipped.
        second_run_calls: list[str] = []

        def fn2() -> int:
            second_run_calls.append("ran")
            return 8

        second_graph = px.Graph.from_specs([px.TaskSpec("a", fn2)])
        outcome = px.run(
            second_graph,
            strategy="sequential",
            state=JSONBackend(state_file),
        )

        assert second_run_calls == []
        assert outcome["a"] == 7  # cached value, not fn2's 8
# ---------------------------------------------------------------------- #
# Events
# ---------------------------------------------------------------------- #
def test_on_event_callback() -> None:
    """The on_event callback receives events for the task, including a SUCCESS one."""
    received: list[px.TaskEvent] = []

    def fn() -> int:
        return 1

    graph = px.Graph.from_specs([px.TaskSpec("a", fn)])
    px.run(graph, strategy="sequential", on_event=received.append)

    seen_statuses = [event.status for event in received]
    assert px.TaskStatus.SUCCESS in seen_statuses

    # Every event must belong to the only task in the graph.
    for event in received:
        assert event.task == "a"
# ---------------------------------------------------------------------- #
# Async strategy: sync task without-timeout branch + timeout retry branch
# ---------------------------------------------------------------------- #
def test_async_sync_task_without_timeout() -> None:
    """Run a sync task with no timeout under the async strategy (covers line 131)."""

    def sync_fn() -> int:
        return 42

    spec = px.TaskSpec("a", sync_fn)
    outcome = px.run(px.Graph.from_specs([spec]), strategy="async")

    assert outcome.success
    assert outcome["a"] == 42
def test_async_sync_task_with_timeout() -> None:
    """Run a sync task with a timeout under the async strategy (covers line 129)."""

    def sync_fn() -> int:
        return 42

    spec = px.TaskSpec("a", sync_fn, timeout=5.0)
    outcome = px.run(px.Graph.from_specs([spec]), strategy="async")

    assert outcome.success
    assert outcome["a"] == 42
def test_async_timeout_retry_then_succeed() -> None:
    """An async task that times out once succeeds on retry (covers the retry branch, lines 141-151)."""
    invocations = 0

    async def flaky() -> str:
        nonlocal invocations
        invocations += 1
        if invocations < 2:
            await asyncio.sleep(10)  # first call outlives the 0.05s timeout
        return "ok"

    spec = px.TaskSpec(
        "a",
        flaky,
        retry=px.RetryPolicy(max_attempts=3),
        timeout=0.05,
    )
    graph = px.Graph.from_specs([spec])

    outcome = px.run(graph, strategy="async")

    assert outcome.success
    assert outcome["a"] == "ok"
    assert invocations == 2
def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
    """An async task raising an ordinary exception is retried (covers the except-Exception branch, lines 141-151)."""
    invocations = 0

    async def flaky() -> str:
        nonlocal invocations
        invocations += 1
        if invocations < 2:
            raise RuntimeError("not yet")
        return "ok"

    spec = px.TaskSpec("a", flaky, retry=px.RetryPolicy(max_attempts=3))
    graph = px.Graph.from_specs([spec])

    with caplog.at_level("WARNING", logger="pyflowx"):
        outcome = px.run(graph, strategy="async")

    assert outcome.success
    assert outcome["a"] == "ok"

    # Confirm the retry warning was actually emitted.
    logged_messages = [record.message for record in caplog.records]
    assert any("retrying" in message for message in logged_messages)
# ---------------------------------------------------------------------- #
# Cache-skip branches: threaded and async
# ---------------------------------------------------------------------- #
def test_threaded_skips_cached_tasks() -> None:
    """Tasks already cached are skipped under the thread strategy (covers lines 224-230)."""
    executed: list[str] = []

    def make(name: str) -> Any:
        def fn() -> str:
            executed.append(name)
            return name

        return fn

    specs = [
        px.TaskSpec("a", make("a")),
        px.TaskSpec("b", make("b"), depends_on=("a",)),
    ]
    graph = px.Graph.from_specs(specs)
    shared_state = px.MemoryBackend()

    # The first run populates the cache.
    px.run(graph, strategy="thread", max_workers=2, state=shared_state)
    assert executed == ["a", "b"]

    # The second run should skip everything.
    px.run(graph, strategy="thread", max_workers=2, state=shared_state)
    assert executed == ["a", "b"]  # nothing ran again
def test_threaded_all_cached_layer() -> None:
    """A fully cached layer returns straight away (covers the ``if not to_run: return`` at line 235)."""
    spec = px.TaskSpec("a", lambda: 1)  # type: ignore[arg-type]
    graph = px.Graph.from_specs([spec])

    preloaded = px.MemoryBackend()
    preloaded.save("a", 99)

    outcome = px.run(graph, strategy="thread", max_workers=2, state=preloaded)

    assert outcome["a"] == 99
    status = outcome.result_of("a").status
    assert status == px.TaskStatus.SKIPPED
def test_async_skips_cached_tasks() -> None:
    """Tasks already cached are skipped under the async strategy (covers lines 268-274).

    The first run executes both coroutines and fills the backend; the second
    run with the same backend must not call either of them again.
    """
    runs: list[str] = []

    # Plain coroutine functions are used directly. An earlier, never-called
    # ``make`` factory was dead code and has been dropped.
    async def a() -> str:
        runs.append("a")
        return "a"

    async def b(a: str) -> str:
        runs.append("b")
        return a + "b"

    graph = px.Graph.from_specs([
        px.TaskSpec("a", a),
        px.TaskSpec("b", b, depends_on=("a",)),
    ])
    backend = px.MemoryBackend()

    # First run populates the cache.
    _ = px.run(graph, strategy="async", state=backend)
    assert runs == ["a", "b"]

    # Second run: both results are cached, so nothing is appended.
    _ = px.run(graph, strategy="async", state=backend)
    assert runs == ["a", "b"]
def test_async_all_cached_layer() -> None:
    """A fully cached layer under the async strategy (covers the ``if not to_run: return`` at line 279)."""

    async def a() -> int:
        return 1

    graph = px.Graph.from_specs([px.TaskSpec("a", a)])

    preloaded = px.MemoryBackend()
    preloaded.save("a", 77)

    outcome = px.run(graph, strategy="async", state=preloaded)

    assert outcome["a"] == 77
    status = outcome.result_of("a").status
    assert status == px.TaskStatus.SKIPPED
# ---------------------------------------------------------------------- #
# After a failure, report.success is marked False
# ---------------------------------------------------------------------- #
def test_failure_marks_report_unsuccessful() -> None:
    """A failing task makes ``px.run`` raise TaskFailedError instead of returning.

    No report is returned because the exception is raised first, so this test
    checks the raising behaviour itself and the details carried by the error.
    """

    def boom() -> None:
        raise ValueError("fail")

    graph = px.Graph.from_specs([px.TaskSpec("a", boom)])

    with pytest.raises(px.TaskFailedError) as exc_info:
        _ = px.run(graph, strategy="sequential")

    # Previously nothing beyond the raise was checked; pin which task failed
    # and why.
    assert exc_info.value.task == "a"
    assert isinstance(exc_info.value.cause, ValueError)
# ---------------------------------------------------------------------- #
# dry_run under each strategy
# ---------------------------------------------------------------------- #
def test_dry_run_thread(capsys: pytest.CaptureFixture[str]) -> None:
    """A dry run with the thread strategy yields an empty report and prints a plan."""
    spec = px.TaskSpec("a", lambda: 1)  # type: ignore[arg-type]
    graph = px.Graph.from_specs([spec])

    outcome = px.run(graph, strategy="thread", dry_run=True)

    assert len(outcome) == 0
    printed = capsys.readouterr().out
    assert "Dry run" in printed
def test_dry_run_async(capsys: pytest.CaptureFixture[str]) -> None:
    """A dry run with the async strategy yields an empty report and prints a plan."""

    async def a() -> int:
        return 1

    spec = px.TaskSpec("a", a)
    graph = px.Graph.from_specs([spec])

    outcome = px.run(graph, strategy="async", dry_run=True)

    assert len(outcome) == 0
    printed = capsys.readouterr().out
    assert "Dry run" in printed
# ---------------------------------------------------------------------- #
# Running an empty graph
# ---------------------------------------------------------------------- #
def test_run_empty_graph() -> None:
    """Running a graph with no tasks succeeds and produces an empty report."""
    empty = px.Graph()

    outcome = px.run(empty, strategy="sequential")

    assert outcome.success
    assert len(outcome) == 0
# ---------------------------------------------------------------------- #
# When an upstream task is SKIPPED, its downstream task should be SKIPPED too
# ---------------------------------------------------------------------- #
def test_downstream_skipped_when_upstream_skipped_sequential() -> None:
    """When the upstream task is SKIPPED, the downstream task should be SKIPPED too (sequential strategy)."""

    def downstream(upstream: str) -> str:
        return upstream + "_processed"

    specs = [
        # The condition always returns False, so upstream is expected to be skipped.
        px.TaskSpec(
            "upstream",
            cmd=["echo", "hello"],
            conditions=(lambda _ctx: False,),
        ),
        px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
    ]
    graph = px.Graph.from_specs(specs)

    outcome = px.run(graph, strategy="sequential")

    assert outcome.success
    for task_name in ("upstream", "downstream"):
        assert outcome.result_of(task_name).status == px.TaskStatus.SKIPPED
def test_downstream_skipped_when_upstream_skipped_thread() -> None:
    """When the upstream task is SKIPPED, the downstream task should be SKIPPED too (thread strategy)."""

    def downstream(upstream: str) -> str:
        return upstream + "_processed"

    specs = [
        # The condition always returns False, so upstream is expected to be skipped.
        px.TaskSpec(
            "upstream",
            cmd=["echo", "hello"],
            conditions=(lambda _ctx: False,),
        ),
        px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
    ]
    graph = px.Graph.from_specs(specs)

    outcome = px.run(graph, strategy="thread", max_workers=2)

    assert outcome.success
    for task_name in ("upstream", "downstream"):
        assert outcome.result_of(task_name).status == px.TaskStatus.SKIPPED
def test_downstream_skipped_when_upstream_skipped_async() -> None:
    """When the upstream task is SKIPPED, the downstream task should be SKIPPED too (async strategy)."""

    async def upstream() -> str:
        return "hello"

    async def downstream(upstream: str) -> str:
        return upstream + "_processed"

    specs = [
        # The condition always returns False, so upstream is expected to be skipped.
        px.TaskSpec("upstream", upstream, conditions=(lambda _ctx: False,)),
        px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
    ]
    graph = px.Graph.from_specs(specs)

    outcome = px.run(graph, strategy="async")

    assert outcome.success
    for task_name in ("upstream", "downstream"):
        assert outcome.result_of(task_name).status == px.TaskStatus.SKIPPED
def test_downstream_executes_when_upstream_succeeds() -> None:
    """When the upstream task succeeds, the downstream task should run normally."""

    def upstream() -> str:
        return "hello"

    def downstream(upstream: str) -> str:
        return upstream + "_processed"

    specs = [
        # The condition always returns True, so upstream is expected to run.
        px.TaskSpec("upstream", upstream, conditions=(lambda _ctx: True,)),
        px.TaskSpec("downstream", downstream, depends_on=("upstream",)),
    ]
    graph = px.Graph.from_specs(specs)

    outcome = px.run(graph, strategy="sequential")

    assert outcome.success
    for task_name in ("upstream", "downstream"):
        assert outcome.result_of(task_name).status == px.TaskStatus.SUCCESS
    assert outcome["downstream"] == "hello_processed"