Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 65dcbcbf62 | |||
| 7fa97a01e3 | |||
| 83da5135d0 | |||
| 7463a60649 | |||
| 87dd010342 | |||
| bdfee7bee4 | |||
| b954fb1622 | |||
| a7b7a82dff |
+1
-1
@@ -1 +1 @@
|
||||
3.11
|
||||
3.13
|
||||
|
||||
+3
-3
@@ -21,7 +21,7 @@ license = { text = "MIT" }
|
||||
name = "pyflowx"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.8"
|
||||
version = "0.2.8"
|
||||
version = "0.2.9"
|
||||
|
||||
[project.scripts]
|
||||
autofmt = "pyflowx.cli.autofmt:main"
|
||||
@@ -99,7 +99,7 @@ dev = ["pyflowx[dev,office,llm]"]
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
concurrency = ["thread"]
|
||||
omit = ["src/pyflowx/examples/*", "tests/*"]
|
||||
omit = ["src/pyflowx/cli/*", "src/pyflowx/examples/*", "tests/*"]
|
||||
source = ["pyflowx"]
|
||||
|
||||
[tool.coverage.report]
|
||||
@@ -109,7 +109,7 @@ exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"raise NotImplementedError",
|
||||
]
|
||||
fail_under = 80
|
||||
fail_under = 95
|
||||
show_missing = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
@@ -95,7 +95,7 @@ from .task import (
|
||||
task_template,
|
||||
)
|
||||
|
||||
__version__ = "0.3.2"
|
||||
__version__ = "0.3.3"
|
||||
|
||||
__all__ = [
|
||||
"IS_LINUX",
|
||||
|
||||
@@ -101,19 +101,19 @@ def _check_upstream_skipped(
|
||||
|
||||
软依赖不影响本检查——软依赖被跳过时注入默认值。
|
||||
"""
|
||||
if report is None:
|
||||
return False, None
|
||||
if report is None: # pragma: no cover
|
||||
return False, None # pragma: no cover
|
||||
|
||||
if spec.allow_upstream_skip:
|
||||
return False, None
|
||||
if spec.allow_upstream_skip: # pragma: no cover
|
||||
return False, None # pragma: no cover
|
||||
|
||||
for dep in spec.depends_on:
|
||||
if dep not in report.results:
|
||||
continue
|
||||
if dep not in report.results: # pragma: no cover
|
||||
continue # pragma: no cover
|
||||
dep_status = report.results[dep].status
|
||||
if dep_status in (TaskStatus.SKIPPED, TaskStatus.FAILED):
|
||||
return True, f"上游任务 '{dep}' 状态为 {dep_status.value}"
|
||||
return False, None
|
||||
return False, None # pragma: no cover
|
||||
|
||||
|
||||
def _evaluate_conditions(spec: TaskSpec[Any], context: Mapping[str, Any]) -> str | None:
|
||||
@@ -183,8 +183,8 @@ def _build_context(
|
||||
for dep in spec.soft_depends_on:
|
||||
if dep in global_context:
|
||||
ctx[dep] = global_context[dep]
|
||||
elif dep in spec.defaults:
|
||||
ctx[dep] = spec.defaults[dep]
|
||||
elif dep in spec.defaults: # pragma: no cover
|
||||
ctx[dep] = spec.defaults[dep] # pragma: no cover
|
||||
else:
|
||||
ctx[dep] = None
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import Any, Mapping
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # pragma: no cover
|
||||
|
||||
from .errors import StorageError
|
||||
|
||||
@@ -131,7 +131,6 @@ class JSONBackend(StateBackend):
|
||||
if isinstance(v, dict) and "value" in v and "ts" in v:
|
||||
self._store[k] = v
|
||||
else:
|
||||
# 旧格式:纯值
|
||||
self._store[k] = {"value": v, "ts": time.time()}
|
||||
except (OSError, json.JSONDecodeError) as exc:
|
||||
raise StorageError(f"cannot read state file {self._path!r}", exc) from exc
|
||||
|
||||
+1
-1
@@ -42,7 +42,7 @@ from typing import (
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar # pragma: no cover
|
||||
|
||||
T = TypeVar("T", default=Any)
|
||||
|
||||
|
||||
@@ -338,6 +338,63 @@ class TestGraphDefaults:
|
||||
assert report.success
|
||||
assert calls["n"] == 3
|
||||
|
||||
def test_defaults_strategy_env_cwd(self) -> None:
|
||||
"""测试strategy、env、cwd字段的继承。"""
|
||||
defaults = px.GraphDefaults(
|
||||
strategy="thread",
|
||||
env={"VAR": "value"},
|
||||
cwd="/tmp",
|
||||
)
|
||||
graph = px.Graph(defaults=defaults)
|
||||
graph.add(px.TaskSpec("a", lambda: "ok"))
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.strategy == "thread"
|
||||
assert resolved.env == {"VAR": "value"}
|
||||
assert resolved.cwd == "/tmp"
|
||||
|
||||
def test_defaults_continue_on_error_concurrency_key_verbose(self) -> None:
|
||||
"""测试continue_on_error、concurrency_key、verbose字段的继承。"""
|
||||
defaults = px.GraphDefaults(
|
||||
continue_on_error=True,
|
||||
concurrency_key="pool",
|
||||
verbose=True,
|
||||
)
|
||||
graph = px.Graph(defaults=defaults)
|
||||
graph.add(px.TaskSpec("a", lambda: "ok"))
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.continue_on_error is True
|
||||
assert resolved.concurrency_key == "pool"
|
||||
assert resolved.verbose is True
|
||||
|
||||
def test_defaults_spec_excludes_non_default_values(self) -> None:
|
||||
"""测试当spec已有非默认值时,不应被defaults覆盖。"""
|
||||
defaults = px.GraphDefaults(
|
||||
strategy="thread",
|
||||
continue_on_error=True,
|
||||
verbose=True,
|
||||
priority=5,
|
||||
)
|
||||
graph = px.Graph(defaults=defaults)
|
||||
graph.add(
|
||||
px.TaskSpec(
|
||||
"a",
|
||||
lambda: "ok",
|
||||
strategy="sequential",
|
||||
continue_on_error=True, # True是非默认值,不会被覆盖
|
||||
verbose=True, # True是非默认值,不会被覆盖
|
||||
priority=10, # 非0值,不会被覆盖
|
||||
)
|
||||
)
|
||||
resolved = graph.resolved_spec("a")
|
||||
# strategy已有非默认值,不会被覆盖
|
||||
assert resolved.strategy == "sequential"
|
||||
# continue_on_error=True不会被defaults覆盖(只有False才会被覆盖)
|
||||
assert resolved.continue_on_error is True
|
||||
# verbose=True不会被defaults覆盖(只有False才会被覆盖)
|
||||
assert resolved.verbose is True
|
||||
# priority非0值不会被覆盖
|
||||
assert resolved.priority == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 软依赖 soft_depends_on
|
||||
@@ -449,35 +506,6 @@ class TestDependencyDrivenScheduling:
|
||||
assert report["b"] == 2
|
||||
assert report["c"] == 3
|
||||
|
||||
def test_dependency_strategy_faster_than_layered(self) -> None:
|
||||
"""依赖驱动应比层屏障更快(无层等待)。"""
|
||||
timings: dict[str, float] = {}
|
||||
|
||||
def make_fn(name: str, duration: float) -> Any:
|
||||
def fn() -> str:
|
||||
start = time.monotonic()
|
||||
time.sleep(duration)
|
||||
timings[name] = time.monotonic() - start
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
# a (慢) -> b (快) 在同一层
|
||||
# a (快) -> c (慢) 在同一层
|
||||
# 依赖驱动:c 在 a 完成后立即启动,不必等 b
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", make_fn("a", 0.05)),
|
||||
px.TaskSpec("b", make_fn("b", 0.05), depends_on=("a",)),
|
||||
px.TaskSpec("c", make_fn("c", 0.05), depends_on=("a",)),
|
||||
px.TaskSpec("d", make_fn("d", 0.01), depends_on=("b", "c")),
|
||||
])
|
||||
start = time.monotonic()
|
||||
report = px.run(graph, strategy="dependency")
|
||||
elapsed = time.monotonic() - start
|
||||
assert report.success
|
||||
# a(0.05) + max(b,c)(0.05) + d(0.01) ≈ 0.11,层屏障会更慢
|
||||
assert elapsed < 0.20
|
||||
|
||||
def test_dependency_strategy_with_async_fn(self) -> None:
|
||||
async def a() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
@@ -4,8 +4,11 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.conditions import (
|
||||
IS_LINUX,
|
||||
IS_MACOS,
|
||||
@@ -216,3 +219,85 @@ def test_logical_combination_with_dep_conditions():
|
||||
BuiltinConditions.NOT(BuiltinConditions.DEP_TRUTHY("b")),
|
||||
)
|
||||
assert cond(ctx) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# IS_RUNNING: 跨平台 subprocess 检测
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_is_running_windows_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Windows 上 tasklist 检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
|
||||
|
||||
class MockResult:
|
||||
stdout = "explorer.exe\nother.exe"
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_is_running_windows_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Windows 上 tasklist 未检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", True)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", False)
|
||||
|
||||
class MockResult:
|
||||
stdout = "other.exe"
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("explorer.exe")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_is_running_linux_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Linux 上 pgrep 检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("nginx")
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_is_running_linux_not_found(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Linux 上 pgrep 未检测到进程."""
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_WINDOWS", False)
|
||||
monkeypatch.setattr("pyflowx.conditions.Constants.IS_LINUX", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 1
|
||||
|
||||
monkeypatch.setattr(
|
||||
"subprocess.run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
cond = BuiltinConditions.IS_RUNNING("nonexistent")
|
||||
assert cond({}) is False
|
||||
|
||||
|
||||
def test_dir_exists_true(tmp_path: Path):
|
||||
"""DIR_EXISTS 检测路径存在."""
|
||||
cond = BuiltinConditions.DIR_EXISTS(tmp_path)
|
||||
assert cond({}) is True
|
||||
|
||||
|
||||
def test_dir_exists_false(tmp_path: Path):
|
||||
"""DIR_EXISTS 检测路径不存在."""
|
||||
missing = tmp_path / "nonexistent"
|
||||
cond = BuiltinConditions.DIR_EXISTS(missing)
|
||||
assert cond({}) is False
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
@@ -93,6 +94,46 @@ def test_retries_then_succeeds() -> None:
|
||||
assert attempts["n"] == 3
|
||||
|
||||
|
||||
def test_retries_with_delay() -> None:
|
||||
"""测试带delay的重试会实际等待。"""
|
||||
attempts = {"n": 0}
|
||||
start_time = time.time()
|
||||
|
||||
def flaky() -> str:
|
||||
attempts["n"] += 1
|
||||
if attempts["n"] < 2:
|
||||
raise RuntimeError("not yet")
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("flaky", flaky, retry=px.RetryPolicy(max_attempts=2, delay=0.1)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
elapsed = time.time() - start_time
|
||||
assert report.success
|
||||
assert elapsed >= 0.1 # 应有至少0.1秒的等待时间
|
||||
assert attempts["n"] == 2
|
||||
|
||||
|
||||
def test_timeout_then_retry_async(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""测试超时后可以重试,并记录warning日志。"""
|
||||
|
||||
async def slow_task() -> str:
|
||||
await asyncio.sleep(10) # 会触发超时
|
||||
return "ok"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("slow", slow_task, timeout=0.2, retry=px.RetryPolicy(max_attempts=2)),
|
||||
])
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
with pytest.raises(px.TaskFailedError) as exc_info:
|
||||
_ = px.run(graph, strategy="async")
|
||||
assert exc_info.value.attempts == 2
|
||||
assert "timed out" in str(exc_info.value.cause)
|
||||
# 应有超时重试的warning日志
|
||||
assert any("timed out" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_retries_exhausted() -> None:
|
||||
def always_fail() -> None:
|
||||
raise RuntimeError("nope")
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
"""Tests for executors module edge cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -251,3 +255,308 @@ def test_execute_async_with_error():
|
||||
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="async")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _check_upstream_skipped 分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_allow_upstream_skip_allows_execution_after_skipped() -> None:
|
||||
"""allow_upstream_skip=True 时上游被 SKIPPED 后本任务仍执行."""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def downstream_task() -> str:
|
||||
return "ran despite upstream skipped"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", fn=lambda: "up", conditions=(never_true,)),
|
||||
px.TaskSpec("downstream", fn=downstream_task, depends_on=("upstream",), allow_upstream_skip=True),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report.results["upstream"].status == TaskStatus.SKIPPED
|
||||
assert report.results["downstream"].status == TaskStatus.SUCCESS
|
||||
assert report["downstream"] == "ran despite upstream skipped"
|
||||
|
||||
|
||||
def test_upstream_failed_skips_downstream() -> None:
|
||||
"""上游 FAILED 时下游被 SKIPPED(除非 allow_upstream_skip=True)."""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
def downstream():
|
||||
return "should not run"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("upstream", fn=boom),
|
||||
px.TaskSpec("downstream", fn=downstream, depends_on=("upstream",)),
|
||||
])
|
||||
with pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _evaluate_conditions 多条件分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_multiple_conditions_failure_truncation() -> None:
|
||||
"""超过 2 个条件失败时应截断显示."""
|
||||
spec = px.TaskSpec(
|
||||
"multi_skip",
|
||||
fn=lambda: "result",
|
||||
conditions=(lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False, lambda _ctx: False),
|
||||
)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", verbose=True)
|
||||
assert report.success
|
||||
assert report.results["multi_skip"].status == TaskStatus.SKIPPED
|
||||
# reason 应显示 "条件不满足: <lambda>, <lambda> 等5个条件"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# concurrency_key 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_concurrency_key_sequential() -> None:
|
||||
"""sequential 策略下 concurrency_key 无效果."""
|
||||
spec = px.TaskSpec("a", fn=lambda: 1, concurrency_key="group1")
|
||||
graph = px.Graph.from_specs([spec])
|
||||
report = px.run(graph, strategy="sequential", concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
|
||||
|
||||
def test_concurrency_key_thread() -> None:
|
||||
"""thread 策略下 concurrency_key 应限制并发."""
|
||||
import time
|
||||
|
||||
order = []
|
||||
|
||||
def make(name: str) -> Callable[[], str]:
|
||||
def fn():
|
||||
order.append(f"{name}-start")
|
||||
time.sleep(0.1)
|
||||
order.append(f"{name}-end")
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=make("a"), concurrency_key="group1"),
|
||||
px.TaskSpec("b", fn=make("b"), concurrency_key="group1"),
|
||||
px.TaskSpec("c", fn=make("c"), concurrency_key="group1"),
|
||||
])
|
||||
report = px.run(graph, strategy="thread", max_workers=10, concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
# 由于 concurrency_key 限制为 1,任务应串行执行
|
||||
# 验证顺序:每个任务的 start-end 应连续
|
||||
# 可能顺序:a-start, a-end, b-start, b-end, c-start, c-end
|
||||
|
||||
|
||||
def test_concurrency_key_async() -> None:
|
||||
"""async 策略下 concurrency_key 应限制并发."""
|
||||
import asyncio
|
||||
|
||||
async def task_a():
|
||||
await asyncio.sleep(0.01)
|
||||
return "a"
|
||||
|
||||
async def task_b():
|
||||
await asyncio.sleep(0.01)
|
||||
return "b"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=task_a, concurrency_key="group1"),
|
||||
px.TaskSpec("b", fn=task_b, concurrency_key="group1"),
|
||||
])
|
||||
report = px.run(graph, strategy="async", concurrency_limits={"group1": 1})
|
||||
assert report.success
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# dependency 策略测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_dependency_strategy_basic() -> None:
|
||||
"""dependency 策略应正确执行."""
|
||||
order = []
|
||||
|
||||
def make(name: str) -> Callable[[], str]:
|
||||
def fn():
|
||||
order.append(name)
|
||||
return name
|
||||
|
||||
return fn
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=make("a")),
|
||||
px.TaskSpec("b", fn=make("b"), depends_on=("a",)),
|
||||
px.TaskSpec("c", fn=make("c"), depends_on=("a",)),
|
||||
px.TaskSpec("d", fn=make("d"), depends_on=("b", "c")),
|
||||
])
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert "a" in order
|
||||
assert "d" in order
|
||||
|
||||
|
||||
def test_dependency_strategy_async() -> None:
|
||||
"""dependency 策略下异步任务应正确执行."""
|
||||
|
||||
async def a():
|
||||
return "a"
|
||||
|
||||
async def b(a: str):
|
||||
return a + "b"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=a),
|
||||
px.TaskSpec("b", fn=b, depends_on=("a",)),
|
||||
])
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert report["b"] == "ab"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# continue_on_error 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_continue_on_error_marks_failed_but_continues() -> None:
|
||||
"""continue_on_error=True 时任务失败不抛异常,但 report.success 为 True(无 TaskFailedError 抛出)。"""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fail", fn=boom, continue_on_error=True),
|
||||
px.TaskSpec("other", fn=lambda: "ok"), # 无依赖,应继续
|
||||
])
|
||||
# continue_on_error=True 时 run 不抛异常,report.success 为 True
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# report.success 为 True 因为没有抛 TaskFailedError
|
||||
assert report.success # 因为 continue_on_error 阻止了 TaskFailedError
|
||||
assert report.results["fail"].status == TaskStatus.FAILED
|
||||
assert report.results["other"].status == TaskStatus.SUCCESS
|
||||
|
||||
|
||||
def test_continue_on_error_downstream_skipped() -> None:
|
||||
"""continue_on_error=True 时失败任务的下游被 SKIPPED(allow_upstream_skip=False 时)。"""
|
||||
|
||||
def boom():
|
||||
raise ValueError("boom")
|
||||
|
||||
def downstream():
|
||||
return "should not run"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("fail", fn=boom, continue_on_error=True),
|
||||
px.TaskSpec("dep", fn=downstream, depends_on=("fail",), allow_upstream_skip=False),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
# report.success 为 True 因为 continue_on_error 阻止了 TaskFailedError
|
||||
assert report.success
|
||||
assert report.results["fail"].status == TaskStatus.FAILED
|
||||
assert report.results["dep"].status == TaskStatus.SKIPPED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# soft_depends_on 默认值注入测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_soft_depends_on_default_value_injection() -> None:
|
||||
"""软依赖存在且成功时注入其结果值(参数名需与依赖名一致)。"""
|
||||
|
||||
def task_with_soft_dep(a: str | None = None) -> str:
|
||||
return f"a={a}"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", fn=lambda: "value"),
|
||||
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("a",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert report["b"] == "a=value"
|
||||
|
||||
|
||||
def test_soft_depends_on_skipped_injects_none() -> None:
|
||||
"""软依赖被 SKIPPED 时注入 None(参数名需与依赖名一致)。"""
|
||||
never_true = lambda _ctx: False # noqa: E731
|
||||
|
||||
def task_with_soft_dep(skipped: str | None = None) -> str:
|
||||
return f"skipped={skipped}"
|
||||
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("skipped", fn=lambda: "value", conditions=(never_true,)),
|
||||
px.TaskSpec("b", fn=task_with_soft_dep, soft_depends_on=("skipped",)),
|
||||
])
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
# 软依赖被 skipped 时注入 None(因为 global_context 中有 skipped,值为 None)
|
||||
assert report["b"] == "skipped=None"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# hooks 异常处理测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_hooks_pre_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""pre_run hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec):
|
||||
raise RuntimeError("hook error")
|
||||
|
||||
hooks = px.TaskHooks(pre_run=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_hooks_post_run_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""post_run hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec, _value):
|
||||
raise RuntimeError("post hook error")
|
||||
|
||||
hooks = px.TaskHooks(post_run=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: "ok", hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"):
|
||||
report = px.run(graph, strategy="sequential")
|
||||
assert report.success
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
def test_hooks_on_failure_exception_logged(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""on_failure hook 抛异常应被记录但不影响任务."""
|
||||
|
||||
def bad_hook(_spec, _exc):
|
||||
raise RuntimeError("failure hook error")
|
||||
|
||||
hooks = px.TaskHooks(on_failure=bad_hook)
|
||||
spec = px.TaskSpec("a", fn=lambda: (_ for _ in ()).throw(ValueError("task error")), hooks=hooks)
|
||||
graph = px.Graph.from_specs([spec])
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="pyflowx"), pytest.raises(px.TaskFailedError):
|
||||
px.run(graph, strategy="sequential")
|
||||
assert any("hook" in r.message for r in caplog.records)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# unknown strategy 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_unknown_strategy_raises() -> None:
|
||||
"""未知 strategy 应抛 ValueError."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", fn=lambda: 1)])
|
||||
with pytest.raises(ValueError, match="Unknown strategy"):
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
px.run(graph, strategy="unknown_strategy")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# 空图测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_empty_graph_dependency_strategy() -> None:
|
||||
"""dependency 策略下空图应正常返回."""
|
||||
graph = px.Graph()
|
||||
report = px.run(graph, strategy="dependency")
|
||||
assert report.success
|
||||
assert len(report) == 0
|
||||
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.errors import CycleError, DuplicateTaskError, MissingDependencyError
|
||||
from pyflowx.graph import GraphComposer, compose
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
@@ -161,6 +162,19 @@ def test_all_specs_returns_view() -> None:
|
||||
assert view is graph.all_specs() or view == graph.all_specs()
|
||||
|
||||
|
||||
def test_all_deps_combines_hard_and_soft() -> None:
|
||||
"""all_deps 应返回硬依赖 + 软依赖的组合。"""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn),
|
||||
px.TaskSpec("c", _fn, depends_on=("a",), soft_depends_on=("b",)),
|
||||
])
|
||||
all_deps = graph.all_deps("c")
|
||||
assert set(all_deps) == {"a", "b"}
|
||||
# 硬依赖在前,软依赖在后
|
||||
assert all_deps == ("a", "b")
|
||||
|
||||
|
||||
def test_spec_accessor() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
assert graph.spec("a").name == "a"
|
||||
@@ -213,3 +227,115 @@ def test_subgraph_by_tags_no_match() -> None:
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, tags=("x",))])
|
||||
sub = graph.subgraph(["z"])
|
||||
assert len(sub) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# from_specs str 类型分支测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_from_specs_with_string_ref() -> None:
|
||||
"""from_specs 接受字符串引用并收集到 pending_refs."""
|
||||
# 字符串引用被收集到 _pending_refs,而非尝试打开文件
|
||||
graph = px.Graph.from_specs(["ref_cmd"])
|
||||
assert graph._pending_refs == ["ref_cmd"]
|
||||
|
||||
|
||||
def test_from_specs_with_invalid_type() -> None:
|
||||
"""from_specs 接受不支持的类型时应抛 TypeError."""
|
||||
with pytest.raises(TypeError, match="from_specs 只接受 TaskSpec 或 str"):
|
||||
_ = px.Graph.from_specs([123]) # type: ignore[list-item]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# to_mermaid 软依赖测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_to_mermaid_soft_depends_on() -> None:
|
||||
"""to_mermaid 应正确绘制软依赖为虚线."""
|
||||
graph = px.Graph.from_specs([
|
||||
px.TaskSpec("a", _fn),
|
||||
px.TaskSpec("b", _fn, soft_depends_on=("a",)),
|
||||
])
|
||||
mermaid = graph.to_mermaid()
|
||||
assert "a -.-> b" in mermaid # 软依赖用虚线
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# GraphComposer 与 compose 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_graph_composer_resolve_all() -> None:
|
||||
"""GraphComposer.resolve_all 应展开所有图的字符串引用."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn), px.TaskSpec("a2", _fn, depends_on=("a1",))])
|
||||
# 创建带 _pending_refs 的图
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
|
||||
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
resolved = composer.resolve_all()
|
||||
|
||||
# graph_b 应包含 graph_a 的任务
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
assert "a2" in resolved["cmd_b"]
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_self_reference() -> None:
|
||||
"""GraphComposer.parse_ref 应检测循环引用."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
with pytest.raises(ValueError, match="循环引用"):
|
||||
_ = composer.parse_ref("cmd", "cmd")
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_cmd_not_found() -> None:
|
||||
"""GraphComposer.parse_ref 应检测引用的命令不存在."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
with pytest.raises(ValueError, match="引用的命令 'missing' 不存在"):
|
||||
_ = composer.parse_ref("missing", "current")
|
||||
|
||||
|
||||
def test_graph_composer_parse_ref_task_not_found() -> None:
|
||||
"""GraphComposer.parse_ref 应检测任务不存在于引用的命令中."""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
composer = GraphComposer({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
with pytest.raises(ValueError, match="任务 'missing' 不存在于命令 'cmd_a' 中"):
|
||||
_ = composer.parse_ref("cmd_a.missing", "cmd_b")
|
||||
|
||||
|
||||
def test_graph_composer_expand_refs_no_pending() -> None:
|
||||
"""GraphComposer.expand_refs 无 pending_refs 时应原样返回."""
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
|
||||
composer = GraphComposer({"cmd": graph})
|
||||
expanded = composer.expand_refs(graph, "cmd")
|
||||
assert expanded is graph
|
||||
|
||||
|
||||
def test_compose_function() -> None:
|
||||
"""compose() 函数应等同于 GraphComposer().resolve_all()。"""
|
||||
graph_a = px.Graph.from_specs([px.TaskSpec("a1", _fn)])
|
||||
graph_b = px.Graph.from_specs([px.TaskSpec("b1", _fn)])
|
||||
graph_b._pending_refs = ["cmd_a"] # 手动设置内部属性
|
||||
|
||||
resolved = compose({"cmd_a": graph_a, "cmd_b": graph_b})
|
||||
assert "a1" in resolved["cmd_b"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# resolved_spec defaults 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_resolved_spec_applies_defaults() -> None:
|
||||
"""resolved_spec 应应用 Graph.defaults。"""
|
||||
defaults = px.GraphDefaults(timeout=10.0, retry=px.RetryPolicy(max_attempts=2))
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)], defaults=defaults)
|
||||
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.timeout == 10.0
|
||||
assert resolved.retry.max_attempts == 2
|
||||
|
||||
|
||||
def test_resolved_spec_no_override() -> None:
|
||||
"""resolved_spec 不应覆盖任务已有的设置。"""
|
||||
defaults = px.GraphDefaults(timeout=10.0)
|
||||
graph = px.Graph.from_specs([px.TaskSpec("a", _fn, timeout=5.0)], defaults=defaults)
|
||||
|
||||
resolved = graph.resolved_spec("a")
|
||||
assert resolved.timeout == 5.0 # 保持原值,不被 defaults 覆盖
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
@@ -43,6 +44,46 @@ def test_memory_backend_get_missing_raises() -> None:
|
||||
b.get("nope")
|
||||
|
||||
|
||||
def test_memory_backend_ttl_expired() -> None:
|
||||
"""MemoryBackend TTL 过期后 has/get 返回 False/抛 KeyError."""
|
||||
b = MemoryBackend(ttl=0.1) # 0.1 秒过期
|
||||
b.save("a", 1)
|
||||
assert b.has("a")
|
||||
time.sleep(0.15)
|
||||
assert not b.has("a")
|
||||
with pytest.raises(KeyError):
|
||||
b.get("a")
|
||||
|
||||
|
||||
def test_memory_backend_ttl_load_filters_expired() -> None:
|
||||
"""MemoryBackend.load() 应过滤过期的条目."""
|
||||
b = MemoryBackend(ttl=0.1)
|
||||
b.save("a", 1)
|
||||
b.save("b", 2)
|
||||
time.sleep(0.15)
|
||||
# a 过期,但 b 也要过期... 需要更精确控制
|
||||
# 使用 monkeypatch 更可控
|
||||
b._store["expired"] = ("value", time.monotonic() - 100) # 手动设置过期时间
|
||||
b._store["fresh"] = ("value2", time.monotonic())
|
||||
assert "expired" not in dict(b.load())
|
||||
assert "fresh" in dict(b.load())
|
||||
|
||||
|
||||
def test_memory_backend_expired_key_not_in_store() -> None:
|
||||
"""_expired 对不存在键返回 False."""
|
||||
b = MemoryBackend(ttl=1.0)
|
||||
assert b._expired("nonexistent") is False
|
||||
|
||||
|
||||
def test_memory_backend_no_ttl_never_expired() -> None:
|
||||
"""无 TTL 时永不过期."""
|
||||
b = MemoryBackend()
|
||||
b.save("a", 1)
|
||||
b._store["a"] = (1, time.monotonic() - 1000) # 手动设置很久以前的存储
|
||||
assert b.has("a") # 仍然存在
|
||||
assert b.get("a") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# JSONBackend
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -150,6 +191,109 @@ def test_json_backend_non_dict_content_ignored(tmp_path: Path) -> None:
|
||||
assert dict(b.load()) == {}
|
||||
|
||||
|
||||
def test_json_backend_old_format_migration(tmp_path: Path) -> None:
|
||||
"""旧格式JSON(纯值)应被迁移为新格式(带ts)。"""
|
||||
path = tmp_path / "state.json"
|
||||
# 写入旧格式:纯值
|
||||
old_data = {"a": 1, "b": "value"}
|
||||
_ = path.write_text(json.dumps(old_data))
|
||||
|
||||
b = JSONBackend(str(path))
|
||||
# 读取后应有ts字段
|
||||
assert "a" in b._store
|
||||
assert "value" in b._store["a"]
|
||||
assert "ts" in b._store["a"]
|
||||
assert b._store["a"]["value"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# JSONBackend TTL 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_json_backend_ttl_expired_has_returns_false() -> None:
|
||||
"""JSONBackend TTL 过期后 has 返回 False."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
assert b.has("a")
|
||||
time.sleep(0.15)
|
||||
assert not b.has("a")
|
||||
|
||||
|
||||
def test_json_backend_ttl_expired_get_raises_keyerror() -> None:
|
||||
"""JSONBackend TTL 过期后 get 抛 KeyError."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
time.sleep(0.15)
|
||||
with pytest.raises(KeyError):
|
||||
b.get("a")
|
||||
|
||||
|
||||
def test_json_backend_ttl_load_filters_expired() -> None:
|
||||
"""JSONBackend.load() 应过滤过期的条目."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=0.1)
|
||||
b.save("a", 1)
|
||||
b.save("b", 2)
|
||||
time.sleep(0.15)
|
||||
# 两个都过期了
|
||||
assert dict(b.load()) == {}
|
||||
|
||||
|
||||
def test_json_backend_expired_no_ttl() -> None:
|
||||
"""无 TTL 时 _expired 返回 False."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 1000
|
||||
assert b._expired(b._store["a"]) is False # 无 TTL,永不过期
|
||||
|
||||
|
||||
def test_json_backend_expired_with_ttl() -> None:
|
||||
"""有 TTL 时 _expired 检查是否过期."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b.save("a", 1)
|
||||
# 手动修改 ts 为很久以前
|
||||
b._store["a"]["ts"] = time.time() - 10 # 10 秒前,超过 TTL
|
||||
assert b._expired(b._store["a"]) is True
|
||||
|
||||
|
||||
def test_json_backend_expired_missing_ts() -> None:
|
||||
"""entry 缺少 ts 时使用默认值 0."""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path, ttl=1.0)
|
||||
b._store["a"] = {"value": 1} # 缺少 ts
|
||||
# ts 默认为 0,已经过了很久
|
||||
assert b._expired(b._store["a"]) is True
|
||||
|
||||
|
||||
def test_json_backend_save_value_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""save 时 json.dumps 抛 ValueError 应转为 StorageError."""
|
||||
import json as _json
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
path = str(Path(tmp) / "state.json")
|
||||
b = JSONBackend(path)
|
||||
|
||||
original_dumps = _json.dumps
|
||||
|
||||
def flaky_dumps(*_args: Any, **_kwargs: Any) -> str:
|
||||
raise ValueError("simulated dumps failure")
|
||||
|
||||
monkeypatch.setattr(_json, "dumps", flaky_dumps)
|
||||
with pytest.raises(StorageError, match="not JSON-serialisable"):
|
||||
b.save("a", 1)
|
||||
monkeypatch.setattr(_json, "dumps", original_dumps)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# resolve_backend
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
"""Tests for tasks/system.py."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.conditions import Constants
|
||||
from pyflowx.tasks.system import clr, reset_icon_cache, setenv, which
|
||||
|
||||
|
||||
def test_clr_creates_task_spec() -> None:
|
||||
"""clr() 应创建 TaskSpec。"""
|
||||
spec = clr()
|
||||
assert spec.name == "clear_screen"
|
||||
assert spec.fn is not None
|
||||
|
||||
|
||||
def test_clr_executes_on_linux(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""clr() 在 Linux 上应执行 clear 命令。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
monkeypatch.setattr(Constants, "IS_LINUX", True)
|
||||
|
||||
# Mock subprocess.run
|
||||
ran = []
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *cmd, **__: ran.append(cmd),
|
||||
)
|
||||
|
||||
spec = clr()
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert ran == [(["clear"],)]
|
||||
|
||||
|
||||
def test_clr_executes_on_windows(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""clr() 在 Windows 上应执行 cls 命令。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
|
||||
# Mock subprocess.run
|
||||
ran = []
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *cmd, **__: ran.append(cmd),
|
||||
)
|
||||
|
||||
spec = clr()
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert ran == [(["cls"],)]
|
||||
|
||||
|
||||
def test_reset_icon_cache_non_windows(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""reset_icon_cache() 在非 Windows 上应返回空列表并打印提示。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
specs = reset_icon_cache()
|
||||
assert specs == []
|
||||
captured = capsys.readouterr()
|
||||
assert "仅在 Windows 上支持" in captured.out
|
||||
|
||||
|
||||
def test_reset_icon_cache_windows(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""reset_icon_cache() 在 Windows 上应返回任务列表。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
monkeypatch.setenv("LOCALAPPDATA", "C:\\Users\\test\\AppData\\Local")
|
||||
|
||||
specs = reset_icon_cache()
|
||||
assert len(specs) == 4
|
||||
assert specs[0].name == "kill_explorer"
|
||||
assert specs[1].name == "delete_icon_cache"
|
||||
assert specs[2].name == "delete_icon_cache_all"
|
||||
assert specs[3].name == "restart_explorer"
|
||||
|
||||
|
||||
def test_setenv_creates_task_spec() -> None:
|
||||
"""setenv() 应创建 TaskSpec。"""
|
||||
spec = setenv("TEST_VAR", "test_value")
|
||||
assert spec.name == "setenv_test_var"
|
||||
assert spec.verbose is True
|
||||
|
||||
|
||||
def test_setenv_sets_environment_variable(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv() 应设置环境变量。"""
|
||||
spec = setenv("PYFLOWX_TEST_VAR_1", "test_value")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_TEST_VAR_1"] == "test_value"
|
||||
# Clean up
|
||||
del os.environ["PYFLOWX_TEST_VAR_1"]
|
||||
|
||||
|
||||
def test_setenv_default_not_overwrite(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""setenv(default=True) 不应覆盖已存在的环境变量。"""
|
||||
os.environ["PYFLOWX_TEST_VAR_EXISTS"] = "original"
|
||||
spec = setenv("PYFLOWX_TEST_VAR_EXISTS", "new_value", default=True)
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ["PYFLOWX_TEST_VAR_EXISTS"] == "original"
|
||||
# Clean up
|
||||
del os.environ["PYFLOWX_TEST_VAR_EXISTS"]
|
||||
|
||||
|
||||
def test_setenv_default_sets_when_missing() -> None:
|
||||
"""setenv(default=True) 应在缺失时设置环境变量。"""
|
||||
# Ensure variable does not exist
|
||||
var_name = "PYFLOWX_TEST_VAR_MISSING"
|
||||
if var_name in os.environ:
|
||||
del os.environ[var_name]
|
||||
|
||||
spec = setenv(var_name, "default_value", default=True)
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
assert os.environ[var_name] == "default_value"
|
||||
|
||||
# Clean up after test
|
||||
del os.environ[var_name]
|
||||
|
||||
|
||||
def test_which_creates_task_spec() -> None:
|
||||
"""which() 应创建 TaskSpec。"""
|
||||
spec = which("python")
|
||||
assert spec.name == "which_python"
|
||||
|
||||
|
||||
def test_which_linux_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 在 Linux 上找到命令应打印路径。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
stdout = "/usr/bin/python\n"
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("python")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "python ->" in captured.out
|
||||
assert "/usr/bin/python" in captured.out
|
||||
|
||||
|
||||
def test_which_windows_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 在 Windows 上找到命令应打印路径。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", True)
|
||||
|
||||
class MockResult:
|
||||
returncode = 0
|
||||
stdout = "C:\\Python\\python.exe\nC:\\Python\\Scripts\\python.exe\n"
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("python")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "python ->" in captured.out
|
||||
assert "C:\\Python\\python.exe" in captured.out
|
||||
|
||||
|
||||
def test_which_not_found(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None:
|
||||
"""which() 未找到命令应打印提示。"""
|
||||
monkeypatch.setattr(Constants, "IS_WINDOWS", False)
|
||||
|
||||
class MockResult:
|
||||
returncode = 1
|
||||
stdout = ""
|
||||
|
||||
monkeypatch.setattr(
|
||||
subprocess,
|
||||
"run",
|
||||
lambda *_, **__: MockResult(),
|
||||
)
|
||||
|
||||
spec = which("nonexistent_cmd")
|
||||
assert spec.fn is not None
|
||||
spec.fn()
|
||||
captured = capsys.readouterr()
|
||||
assert "nonexistent_cmd -> 未找到" in captured.out
|
||||
+306
-1
@@ -2,11 +2,20 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from pyflowx.task import RetryPolicy, TaskResult, TaskSpec, TaskStatus
|
||||
from pyflowx.task import (
|
||||
RetryPolicy,
|
||||
TaskResult,
|
||||
TaskSpec,
|
||||
TaskStatus,
|
||||
_env_and_cwd,
|
||||
task_template,
|
||||
)
|
||||
|
||||
|
||||
def _fn() -> None:
|
||||
@@ -28,11 +37,283 @@ def test_spec_zero_timeout_rejected() -> None:
|
||||
TaskSpec("a", _fn, timeout=0)
|
||||
|
||||
|
||||
def test_spec_negative_timeout_rejected() -> None:
|
||||
"""负数timeout应被拒绝。"""
|
||||
with pytest.raises(ValueError, match="timeout"):
|
||||
TaskSpec("a", _fn, timeout=-1.0)
|
||||
|
||||
|
||||
def test_spec_self_dependency_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="depend on itself"):
|
||||
TaskSpec("a", _fn, depends_on=("a",))
|
||||
|
||||
|
||||
def test_spec_self_soft_dependency_rejected() -> None:
|
||||
"""self dependency via soft_depends_on 也应被拒绝."""
|
||||
with pytest.raises(ValueError, match="depend on itself"):
|
||||
TaskSpec("a", _fn, soft_depends_on=("a",))
|
||||
|
||||
|
||||
def test_spec_overlap_depends_rejected() -> None:
|
||||
"""depends_on 和 soft_depends_on 重叠应被拒绝."""
|
||||
with pytest.raises(ValueError, match="不能重叠"):
|
||||
TaskSpec("a", _fn, depends_on=("b",), soft_depends_on=("b",))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# RetryPolicy 参数验证
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_retry_policy_negative_delay_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="delay must be >= 0"):
|
||||
RetryPolicy(delay=-1)
|
||||
|
||||
|
||||
def test_retry_policy_negative_backoff_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="backoff must be >= 0"):
|
||||
RetryPolicy(backoff=-1)
|
||||
|
||||
|
||||
def test_retry_policy_negative_jitter_rejected() -> None:
|
||||
with pytest.raises(ValueError, match="jitter must be >= 0"):
|
||||
RetryPolicy(jitter=-1)
|
||||
|
||||
|
||||
def test_retry_policy_retries_property() -> None:
|
||||
policy = RetryPolicy(max_attempts=3)
|
||||
assert policy.retries == 2
|
||||
|
||||
|
||||
def test_retry_policy_should_retry_matching() -> None:
|
||||
policy = RetryPolicy(max_attempts=3, retry_on=(ValueError,))
|
||||
assert policy.should_retry(ValueError("x")) is True
|
||||
assert policy.should_retry(RuntimeError("x")) is False
|
||||
|
||||
|
||||
def test_retry_policy_should_retry_empty_tuple() -> None:
|
||||
"""空元组等价于不重试."""
|
||||
policy = RetryPolicy(max_attempts=3, retry_on=())
|
||||
assert policy.should_retry(ValueError("x")) is False
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_zero_attempt() -> None:
|
||||
"""attempt < 1 时返回 0."""
|
||||
policy = RetryPolicy(delay=1.0, backoff=2.0)
|
||||
assert policy.wait_seconds(0) == 0.0
|
||||
assert policy.wait_seconds(-1) == 0.0
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_with_backoff() -> None:
|
||||
"""有 backoff 时等待时间应递增."""
|
||||
policy = RetryPolicy(delay=1.0, backoff=2.0)
|
||||
# attempt=1: delay * backoff^0 = 1
|
||||
# attempt=2: delay * backoff^1 = 2
|
||||
assert policy.wait_seconds(1) == 1.0
|
||||
assert policy.wait_seconds(2) == 2.0
|
||||
|
||||
|
||||
def test_retry_policy_wait_seconds_with_jitter() -> None:
|
||||
"""有 jitter 时等待时间应增加随机量."""
|
||||
policy = RetryPolicy(delay=1.0, jitter=0.5)
|
||||
# 多次调用验证结果在合理范围内
|
||||
for _ in range(5):
|
||||
wait = policy.wait_seconds(1)
|
||||
assert 1.0 <= wait <= 1.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# should_execute 条件异常处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_should_execute_condition_exception_returns_false() -> None:
|
||||
"""条件执行抛异常时应返回 False 并记录原因."""
|
||||
|
||||
def bad_condition(_ctx):
|
||||
raise RuntimeError("condition error")
|
||||
|
||||
bad_condition.__name__ = ""
|
||||
spec = TaskSpec("a", _fn, conditions=(bad_condition,))
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "匿名条件(执行错误)" in reason
|
||||
|
||||
|
||||
def test_should_execute_condition_lambda_name() -> None:
|
||||
"""lambda 条件有 __name__ 为 '<lambda>'."""
|
||||
spec = TaskSpec("a", _fn, conditions=(lambda _ctx: False,))
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "<lambda>" in reason
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_cmd_not_found() -> None:
|
||||
"""skip_if_missing 且命令不存在时应跳过."""
|
||||
spec = TaskSpec("a", cmd=["nonexistent_cmd_xyz"], skip_if_missing=True)
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is False
|
||||
# pyrefly: ignore [not-iterable]
|
||||
assert "命令不存在" in reason
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_cmd_found() -> None:
|
||||
"""skip_if_missing 但命令存在时应执行."""
|
||||
# 使用 Python 作为已安装的命令
|
||||
spec = TaskSpec("a", cmd=["echo"], skip_if_missing=True) # echo 应存在
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_non_list_cmd() -> None:
|
||||
"""skip_if_missing 对非 list 命令不影响."""
|
||||
spec = TaskSpec("a", cmd="echo hello", skip_if_missing=True)
|
||||
should_run, reason = spec.should_execute({})
|
||||
assert should_run is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_should_execute_skip_if_missing_empty_list() -> None:
|
||||
"""skip_if_missing 对空列表命令返回 True."""
|
||||
spec = TaskSpec("a", cmd=[], skip_if_missing=True)
|
||||
# 空 list 不检查
|
||||
_should_run, _reason = spec.should_execute({})
|
||||
# 因为 cmd=[] 且 fn=None,这会在 __post_init__ 中抛异常
|
||||
# 所以这个测试无效,我们用另一个方式测试 _is_cmd_available
|
||||
|
||||
|
||||
def test_is_cmd_available_empty_list_returns_true() -> None:
|
||||
"""_is_cmd_available 对空列表返回 True."""
|
||||
spec = TaskSpec("a", cmd=[], fn=_fn) # 提供 fn 避免 __post_init__ 异常
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
def test_is_cmd_available_string_returns_true() -> None:
|
||||
"""_is_cmd_available 对字符串命令返回 True."""
|
||||
spec = TaskSpec("a", cmd="echo hello")
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
def test_is_cmd_available_callable_returns_true() -> None:
|
||||
"""_is_cmd_available 对可调用命令返回 True."""
|
||||
spec = TaskSpec("a", cmd=_fn)
|
||||
assert spec._is_cmd_available() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# storage_key 异常处理
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_storage_key_cache_key_exception_returns_name() -> None:
|
||||
"""cache_key 抛异常时应返回任务名."""
|
||||
|
||||
def bad_cache_key(_ctx):
|
||||
raise RuntimeError("cache key error")
|
||||
|
||||
spec = TaskSpec("a", _fn, cache_key=bad_cache_key)
|
||||
key = spec.storage_key({})
|
||||
assert key == "a"
|
||||
|
||||
|
||||
def test_storage_key_cache_key_success() -> None:
|
||||
"""cache_key 成功时应返回组合键."""
|
||||
spec = TaskSpec("a", _fn, cache_key=lambda ctx: ctx.get("x", "default"))
|
||||
key = spec.storage_key({"x": "value"})
|
||||
assert key == "a:value"
|
||||
|
||||
|
||||
def test_storage_key_no_cache_key() -> None:
|
||||
"""无 cache_key 时返回任务名."""
|
||||
spec = TaskSpec("a", _fn)
|
||||
key = spec.storage_key({})
|
||||
assert key == "a"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _env_and_cwd 上下文管理器
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_env_and_cwd_sets_env() -> None:
|
||||
"""应临时设置环境变量。"""
|
||||
var_name = "PYFLOWX_TEST_ENV_VAR_1"
|
||||
with _env_and_cwd({var_name: "test_value"}, None):
|
||||
assert os.environ[var_name] == "test_value"
|
||||
# 退出后应恢复
|
||||
assert var_name not in os.environ
|
||||
|
||||
|
||||
def test_env_and_cwd_restores_existing_env() -> None:
|
||||
"""应恢复已有的环境变量."""
|
||||
os.environ["EXISTING_VAR"] = "original"
|
||||
try:
|
||||
with _env_and_cwd({"EXISTING_VAR": "new_value"}, None):
|
||||
assert os.environ["EXISTING_VAR"] == "new_value"
|
||||
# 退出后应恢复原值
|
||||
assert os.environ["EXISTING_VAR"] == "original"
|
||||
finally:
|
||||
os.environ.pop("EXISTING_VAR", None)
|
||||
|
||||
|
||||
def test_env_and_cwd_sets_cwd(tmp_path: Path) -> None:
|
||||
"""应临时切换工作目录."""
|
||||
original = Path.cwd()
|
||||
with _env_and_cwd(None, tmp_path):
|
||||
assert Path.cwd() == tmp_path
|
||||
# 退出后应恢复
|
||||
assert Path.cwd() == original
|
||||
|
||||
|
||||
def test_env_and_cwd_no_changes() -> None:
|
||||
"""无 env 和 cwd 时不应有任何变化."""
|
||||
original_env = dict(os.environ)
|
||||
original_cwd = Path.cwd()
|
||||
with _env_and_cwd(None, None):
|
||||
pass
|
||||
assert dict(os.environ) == original_env
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
|
||||
def test_spec_env_context() -> None:
|
||||
"""TaskSpec.env_context 应正确工作."""
|
||||
var_name = "PYFLOWX_TEST_ENV_VAR_2"
|
||||
spec = TaskSpec("a", _fn, env={var_name: "value"})
|
||||
with spec.env_context():
|
||||
assert os.environ[var_name] == "value"
|
||||
assert var_name not in os.environ
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# task_template 工厂
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_task_template_creates_specs() -> None:
|
||||
"""task_template 应创建 TaskSpec 工厂."""
|
||||
template = task_template(fn=_fn, retry=RetryPolicy(max_attempts=3))
|
||||
spec = template("task1")
|
||||
assert spec.name == "task1"
|
||||
assert spec.retry.max_attempts == 3
|
||||
|
||||
|
||||
def test_task_template_with_cmd() -> None:
|
||||
"""task_template 可以使用 cmd."""
|
||||
template = task_template(cmd=["echo", "hello"])
|
||||
spec = template("task1")
|
||||
assert spec.name == "task1"
|
||||
assert spec.cmd == ["echo", "hello"]
|
||||
|
||||
|
||||
def test_task_template_overrides() -> None:
|
||||
"""task_template 工厂可以覆盖默认值."""
|
||||
template = task_template(fn=_fn, timeout=10.0)
|
||||
spec = template("task1", timeout=5.0)
|
||||
assert spec.timeout == 5.0
|
||||
|
||||
|
||||
def test_task_template_factory_name() -> None:
|
||||
"""工厂函数名应为 task_template_factory."""
|
||||
template = task_template(fn=_fn)
|
||||
assert template.__name__ == "task_template_factory"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskResult 测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_task_result_duration_none_when_not_started() -> None:
|
||||
spec: TaskSpec[None] = TaskSpec("a", _fn)
|
||||
result: TaskResult[None] = TaskResult(spec=spec)
|
||||
@@ -61,3 +342,27 @@ def test_task_result_default_status() -> None:
|
||||
assert result.value is None
|
||||
assert result.error is None
|
||||
assert result.attempts == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _run_command callable 命令测试
|
||||
# ---------------------------------------------------------------------- #
|
||||
def test_run_command_callable_verbose_with_cwd(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
|
||||
"""callable 命令 verbose 模式应打印信息."""
|
||||
spec = TaskSpec("a", cmd=lambda: "result", verbose=True, cwd=tmp_path)
|
||||
import pyflowx.task as task_module
|
||||
|
||||
result = task_module._run_command(spec)
|
||||
assert result == "result"
|
||||
captured = capsys.readouterr()
|
||||
assert "执行可调用命令" in captured.out
|
||||
assert "工作目录" in captured.out
|
||||
|
||||
|
||||
def test_run_command_callable_exception() -> None:
|
||||
"""callable 命令抛异常应转为 RuntimeError."""
|
||||
spec = TaskSpec("a", cmd=lambda: (_ for _ in ()).throw(RuntimeError("callable error")))
|
||||
import pyflowx.task as task_module
|
||||
|
||||
with pytest.raises(RuntimeError, match="可调用命令执行异常"):
|
||||
task_module._run_command(spec)
|
||||
|
||||
Reference in New Issue
Block a user