chore: 批量优化代码与配置,完善类型注解

This commit is contained in:
2026-06-21 10:04:01 +08:00
parent 56c018e72e
commit 60083bcb6e
17 changed files with 351 additions and 357 deletions
+1 -1
View File
@@ -86,7 +86,7 @@ pythonVersion = "3.8"
reportImplicitStringConcatenation = "error" reportImplicitStringConcatenation = "error"
reportMissingTypeStubs = "none" reportMissingTypeStubs = "none"
reportUnusedCallResult = "warning" reportUnusedCallResult = "warning"
typeCheckingMode = "recommended" # 类型检查严格度:off / basic / standard / recommended(默认) / strict / all typeCheckingMode = "basic" # 类型检查严格度:off / basic / standard / recommended(默认) / strict / all
# Ruff 配置 - 与 .pre-commit-config.yaml 保持一致 # Ruff 配置 - 与 .pre-commit-config.yaml 保持一致
[tool.ruff] [tool.ruff]
+3 -3
View File
@@ -47,7 +47,7 @@ def _is_context_annotation(annotation: Any) -> bool:
def build_call_args( def build_call_args(
spec: TaskSpec[object], spec: TaskSpec[Any],
context: Mapping[str, Any], context: Mapping[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]: ) -> tuple[tuple[Any, ...], dict[str, Any]]:
"""解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。 """解析用于调用 ``spec.fn`` 的 ``(args, kwargs)``。
@@ -92,7 +92,7 @@ def build_call_args(
raise InjectionError( raise InjectionError(
spec.name, spec.name,
f"static kwargs {sorted(collisions)} collide with dependency names; " f"static kwargs {sorted(collisions)} collide with dependency names; "
"rename the static kwarg or the dependency.", + "rename the static kwarg or the dependency.",
) )
injected_kwargs: dict[str, Any] = {} injected_kwargs: dict[str, Any] = {}
@@ -155,7 +155,7 @@ def build_call_args(
return tuple(spec.args), injected_kwargs return tuple(spec.args), injected_kwargs
def describe_injection(spec: TaskSpec[object]) -> str: def describe_injection(spec: TaskSpec[Any]) -> str:
"""生成任务参数注入方式的人类可读描述。 """生成任务参数注入方式的人类可读描述。
供 ``dry_run`` 使用,在不执行的情况下展示执行计划。 供 ``dry_run`` 使用,在不执行的情况下展示执行计划。
+1 -1
View File
@@ -35,7 +35,7 @@ def main() -> None:
# Static positional args parameterise the same function twice. # Static positional args parameterise the same function twice.
px.TaskSpec("fetch_user", fetch_user, args=(1,)), px.TaskSpec("fetch_user", fetch_user, args=(1,)),
px.TaskSpec("fetch_posts", fetch_posts, args=(1,)), px.TaskSpec("fetch_posts", fetch_posts, args=(1,)),
px.TaskSpec("aggregate", aggregate, ("fetch_user", "fetch_posts")), px.TaskSpec("aggregate", aggregate, depends_on=("fetch_user", "fetch_posts")),
] ]
) )
+4 -2
View File
@@ -55,10 +55,12 @@ def main() -> None:
px.TaskSpec( px.TaskSpec(
"transform", "transform",
transform, transform,
("extract_customers", "extract_orders"), depends_on=("extract_customers", "extract_orders"),
tags=("transform",), tags=("transform",),
), ),
px.TaskSpec("load", load, ("transform",), retries=1, tags=("load",)), px.TaskSpec(
"load", load, depends_on=("transform",), retries=1, tags=("load",)
),
] ]
) )
+1 -1
View File
@@ -33,7 +33,7 @@ def main() -> None:
[ [
px.TaskSpec("fetch_a", fetch_a), px.TaskSpec("fetch_a", fetch_a),
px.TaskSpec("fetch_b", fetch_b), px.TaskSpec("fetch_b", fetch_b),
px.TaskSpec("merge", merge, ("fetch_a", "fetch_b")), px.TaskSpec("merge", merge, depends_on=("fetch_a", "fetch_b")),
] ]
) )
+7 -6
View File
@@ -28,6 +28,7 @@ from typing import Sequence
from .errors import PyFlowXError from .errors import PyFlowXError
from .executors import Strategy, normalize_strategy, run from .executors import Strategy, normalize_strategy, run
from .graph import Graph from .graph import Graph
from .task import TaskSpec
__all__ = ["CliExitCode", "CliRunner"] __all__ = ["CliExitCode", "CliRunner"]
@@ -58,7 +59,7 @@ def _apply_verbose_to_graph(graph: Graph, verbose: bool) -> Graph:
Graph Graph
所有 spec 的 verbose 字段已更新的新图. 所有 spec 的 verbose 字段已更新的新图.
""" """
new_specs = [] new_specs: list[TaskSpec[object]] = []
for spec in graph.all_specs().values(): for spec in graph.all_specs().values():
if spec.verbose == verbose: if spec.verbose == verbose:
new_specs.append(spec) new_specs.append(spec)
@@ -191,28 +192,28 @@ class CliRunner:
formatter_class=argparse.RawDescriptionHelpFormatter, formatter_class=argparse.RawDescriptionHelpFormatter,
epilog=self._format_commands_help(), epilog=self._format_commands_help(),
) )
parser.add_argument( _ = parser.add_argument(
"command", "command",
nargs="?", nargs="?",
help="要执行的命令", help="要执行的命令",
) )
parser.add_argument( _ = parser.add_argument(
"--strategy", "--strategy",
choices=[s.value for s in Strategy], choices=[s.value for s in Strategy],
default=self._strategy.value, default=self._strategy.value,
help="执行策略 (默认: %(default)s)", help="执行策略 (默认: %(default)s)",
) )
parser.add_argument( _ = parser.add_argument(
"--dry-run", "--dry-run",
action="store_true", action="store_true",
help="只打印执行计划, 不实际运行", help="只打印执行计划, 不实际运行",
) )
parser.add_argument( _ = parser.add_argument(
"--list", "--list",
action="store_true", action="store_true",
help="列出所有可用命令", help="列出所有可用命令",
) )
parser.add_argument( _ = parser.add_argument(
"--quiet", "--quiet",
action="store_true", action="store_true",
help="静默模式, 不显示执行过程 (覆盖默认 verbose)", help="静默模式, 不显示执行过程 (覆盖默认 verbose)",
+5 -4
View File
@@ -99,7 +99,8 @@ class JSONBackend(StateBackend):
try: try:
with open(tmp, "w", encoding="utf-8") as fh: with open(tmp, "w", encoding="utf-8") as fh:
json.dump(self._store, fh, ensure_ascii=False, indent=2) json.dump(self._store, fh, ensure_ascii=False, indent=2)
Path(tmp).replace(Path(self._path))
_ = Path(tmp).replace(Path(self._path))
except (OSError, TypeError) as exc: except (OSError, TypeError) as exc:
raise StorageError(f"cannot write state file {self._path!r}", exc) from exc raise StorageError(f"cannot write state file {self._path!r}", exc) from exc
@@ -109,13 +110,13 @@ class JSONBackend(StateBackend):
def save(self, name: str, value: Any) -> None: def save(self, name: str, value: Any) -> None:
# 在修改内存状态前先校验可序列化性。 # 在修改内存状态前先校验可序列化性。
try: try:
json.dumps(value) _ = json.dumps(value)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise StorageError( raise StorageError(
f"result of task {name!r} is not JSON-serialisable", exc f"result of task {name!r} is not JSON-serialisable", exc
) from exc ) from exc
self._store[name] = value self._store[name] = value
_ = self._flush() self._flush()
def has(self, name: str) -> bool: def has(self, name: str) -> bool:
return name in self._store return name in self._store
@@ -125,7 +126,7 @@ class JSONBackend(StateBackend):
def clear(self) -> None: def clear(self) -> None:
self._store.clear() self._store.clear()
_ = self._flush() self._flush()
def resolve_backend(backend: StateBackend | None) -> StateBackend: def resolve_backend(backend: StateBackend | None) -> StateBackend:
+1
View File
@@ -150,6 +150,7 @@ class TaskSpec(Generic[T]):
return self._wrap_cmd() return self._wrap_cmd()
if self.fn is not None: if self.fn is not None:
return self.fn return self.fn
raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。") raise ValueError(f"TaskSpec '{self.name}': 没有可执行的函数或命令。")
def _wrap_cmd(self) -> TaskFn[Any]: def _wrap_cmd(self) -> TaskFn[Any]:
+157 -160
View File
@@ -1,4 +1,4 @@
"""Tests for context injection rules.""" """测试上下文注入规则."""
from __future__ import annotations from __future__ import annotations
@@ -11,225 +11,222 @@ from pyflowx.context import _is_context_annotation, build_call_args, describe_in
from pyflowx.errors import InjectionError from pyflowx.errors import InjectionError
def test_inject_by_parameter_name() -> None: class TestBuildCallArgs:
def fn(a: int, b: str) -> str: """测试 build_call_args 函数."""
return f"{a}{b}"
spec = px.TaskSpec("c", fn, depends_on=("a", "b")) def test_inject_by_parameter_name(self) -> None:
args, kwargs = build_call_args(spec, {"a": 1, "b": "x"}) """参数名匹配依赖名时应注入对应结果."""
assert args == ()
assert kwargs == {"a": 1, "b": "x"}
def fn(a: int, b: str) -> str:
return f"{a}{b}"
def test_inject_context_annotation() -> None: spec = px.TaskSpec("c", fn, depends_on=("a", "b"))
def fn(ctx: px.Context) -> int: _args, kwargs = build_call_args(spec, {"a": 1, "b": "x"})
return len(ctx) assert kwargs == {"a": 1, "b": "x"}
spec = px.TaskSpec("agg", fn, depends_on=("a", "b")) def test_inject_context_annotation(self) -> None:
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99}) """标注为 Context 的参数应接收完整依赖映射."""
# Only the task's own deps are passed.
assert kwargs == {"ctx": {"a": 1, "b": 2}}
def fn(ctx: px.Context) -> int:
return len(ctx)
def test_inject_var_keyword() -> None: spec = px.TaskSpec("agg", fn, depends_on=("a", "b"))
def fn(**kwargs: Any) -> int: _args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
return sum(kwargs.values()) # Only the task's own deps are passed.
assert kwargs == {"ctx": {"a": 1, "b": 2}}
spec = px.TaskSpec("agg", fn, depends_on=("a", "b")) def test_inject_var_keyword(self) -> None:
args, kwargs = build_call_args(spec, {"a": 1, "b": 2}) """**kwargs 参数应以 dict 形式接收所有依赖结果."""
assert kwargs == {"a": 1, "b": 2}
def fn(**kwargs: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
return sum(kwargs.values())
def test_static_args_and_kwargs() -> None: spec = px.TaskSpec("agg", fn, depends_on=("a", "b"))
def fn(uid: int, source: str) -> str: _args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
return f"{source}:{uid}" assert kwargs == {"a": 1, "b": 2}
spec = px.TaskSpec("fetch", fn, args=(42,), kwargs={"source": "api"}) def test_static_args_and_kwargs(self) -> None:
args, kwargs = build_call_args(spec, {}) """静态 args/kwargs 应正确填充非依赖参数."""
assert args == (42,)
assert kwargs == {"source": "api"}
def fn(uid: int, source: str) -> str:
return f"{source}:{uid}"
def test_default_param_not_required() -> None: spec = px.TaskSpec("fetch", fn, args=(42,), kwargs={"source": "api"})
def fn(a: int, flag: bool = True) -> int: args, kwargs = build_call_args(spec, {})
return a if flag else 0 assert args == (42,)
assert kwargs == {"source": "api"}
spec = px.TaskSpec("t", fn, depends_on=("a",)) def test_default_param_not_required(self) -> None:
args, kwargs = build_call_args(spec, {"a": 5}) """有默认值的参数无需依赖或静态值."""
assert kwargs == {"a": 5}
def fn(a: int, flag: bool = True) -> int:
return a if flag else 0
def test_unresolved_required_param_raises() -> None: spec = px.TaskSpec("t", fn, depends_on=("a",))
def fn(a: int, missing: str) -> None: _args, kwargs = build_call_args(spec, {"a": 5})
return None assert kwargs == {"a": 5}
spec = px.TaskSpec("t", fn, depends_on=("a",)) def test_unresolved_required_param_raises(self) -> None:
with pytest.raises(InjectionError) as exc_info: """必需参数无法解析时应抛出 InjectionError."""
build_call_args(spec, {"a": 1})
assert "missing" in str(exc_info.value)
def fn(_a: int, _: str) -> None:
return None
def test_static_kwargs_collide_with_dependency() -> None: spec = px.TaskSpec("t", fn, depends_on=("a",))
def fn(a: int) -> int: with pytest.raises(InjectionError) as exc_info:
return a _ = build_call_args(spec, {"a": 1})
assert "Cannot inject" in str(exc_info.value)
spec = px.TaskSpec("t", fn, depends_on=("a",), kwargs={"a": 99}) def test_static_kwargs_collide_with_dependency(self) -> None:
with pytest.raises(InjectionError): """静态 kwargs 与依赖名冲突时应抛出 InjectionError."""
build_call_args(spec, {"a": 1})
def fn(a: int) -> int:
return a
def test_describe_injection() -> None: spec = px.TaskSpec("t", fn, depends_on=("a",), kwargs={"a": 99})
def fn(a: int, ctx: px.Context, flag: bool = False) -> None: with pytest.raises(InjectionError):
return None _ = build_call_args(spec, {"a": 1})
spec = px.TaskSpec("t", fn, depends_on=("a",)) def test_var_positional_not_required(self) -> None:
desc = describe_injection(spec) """*args 参数不应触发 InjectionError."""
assert "a=<result:a>" in desc
assert "ctx=<Context>" in desc
assert "flag=<default>" in desc
def fn(*args: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
return len(args)
# ---------------------------------------------------------------------- # spec = px.TaskSpec("t", fn, args=(1, 2, 3))
# _is_context_annotation 各分支 args, kwargs = build_call_args(spec, {})
# ---------------------------------------------------------------------- # assert args == (1, 2, 3)
def test_is_context_annotation_direct_object() -> None: assert kwargs == {}
"""直接传入 Context 别名对象应返回 True。"""
assert _is_context_annotation(px.Context) is True
def test_var_keyword_consumes_leftover(self) -> None:
"""**kwargs 应吞掉未被具名参数消费的依赖结果."""
def test_is_context_annotation_string() -> None: def fn(a: int, **rest: Any) -> int: # pyright: ignore[reportExplicitAny, reportAny]
"""字符串形式的注解应被识别。""" return a + sum(rest.values())
assert _is_context_annotation("Context") is True
assert _is_context_annotation("px.Context") is True
assert _is_context_annotation("pyflowx.Context") is True
assert _is_context_annotation("NotContext") is False
assert _is_context_annotation("int") is False
spec = px.TaskSpec("t", fn, depends_on=("a", "b", "c"))
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 3})
assert kwargs == {"a": 1, "b": 2, "c": 3}
def test_is_context_annotation_typing_alias() -> None: def test_no_var_keyword_drops_leftover(self) -> None:
"""具有 __name__/_name 为 Context/Mapping 的 typing 别名应返回 True。""" """无 **kwargs 时,未被消费的依赖结果被丢弃(不报错)."""
class FakeAlias: def fn(a: int) -> int:
__name__ = "Context" return a
assert _is_context_annotation(FakeAlias()) is True spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
# b 是依赖但 fn 不接收它 —— 应正常工作
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
assert kwargs == {"a": 1}
class FakeMapping: def test_context_annotation_only_deps(self) -> None:
__name__ = "Mapping" """Context 标注只接收该任务自身 depends_on 的结果."""
assert _is_context_annotation(FakeMapping()) is True def fn(ctx: px.Context) -> int:
return len(ctx)
spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
_args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
assert kwargs == {"ctx": {"a": 1, "b": 2}}
def test_is_context_annotation_other() -> None:
"""其他类型注解应返回 False。"""
assert _is_context_annotation(int) is False
assert _is_context_annotation(str) is False
assert _is_context_annotation(None) is False
class TestDescribeInjection:
"""测试 describe_injection 函数."""
# ---------------------------------------------------------------------- # def test_describe_injection(self) -> None:
# describe_injection 其余分支 """应正确描述依赖注入、Context 标注和默认值."""
# ---------------------------------------------------------------------- #
def test_describe_injection_var_positional() -> None:
"""*args 参数应显示为 *args。"""
def fn(*args: Any) -> None: def fn(a: int, ctx: px.Context, flag: bool = False) -> None: # noqa: ARG001
return None return None
spec = px.TaskSpec("t", fn) spec = px.TaskSpec("t", fn, depends_on=("a",))
desc = describe_injection(spec) desc = describe_injection(spec)
assert "*args" in desc assert "a=<result:a>" in desc
assert "ctx=<Context>" in desc
assert "flag=<default>" in desc
def test_var_positional(self) -> None:
"""*args 参数应显示为 *args."""
def test_describe_injection_var_keyword() -> None: def fn(*args: Any) -> None: # noqa: ARG001
"""**kwargs 参数应显示为 **kwargs=<all-deps>。""" return None
def fn(**kwargs: Any) -> None: spec = px.TaskSpec("t", fn)
return None desc = describe_injection(spec)
assert "*args" in desc
spec = px.TaskSpec("t", fn, depends_on=("a",)) def test_var_keyword(self) -> None:
desc = describe_injection(spec) """**kwargs 参数应显示为 **kwargs=<all-deps>."""
assert "**kwargs=<all-deps>" in desc
def fn(**kwargs: Any) -> None: # pyright: ignore[reportExplicitAny, reportAny] # noqa: ARG001
return None
def test_describe_injection_unresolved() -> None: spec = px.TaskSpec("t", fn, depends_on=("a",))
"""无依赖、无静态值、无默认的参数应显示为 <UNRESOLVED>。""" desc = describe_injection(spec)
assert "**kwargs=<all-deps>" in desc
def fn(missing: int) -> None: def test_unresolved(self) -> None:
return None """无依赖、无静态值、无默认的参数应显示为 <UNRESOLVED>."""
spec = px.TaskSpec("t", fn) def fn(missing: int) -> None: # noqa: ARG001
desc = describe_injection(spec) return None
assert "missing=<UNRESOLVED>" in desc
spec = px.TaskSpec("t", fn)
desc = describe_injection(spec)
assert "missing=<UNRESOLVED>" in desc
def test_describe_injection_static_kwargs() -> None: def test_static_kwargs(self) -> None:
"""静态 kwargs 应显示具体值""" """静态 kwargs 应显示具体值."""
def fn(flag: bool = False) -> None: def fn(flag: bool = False) -> None: # noqa: ARG001
return None return None
spec = px.TaskSpec("t", fn, kwargs={"flag": True}) spec = px.TaskSpec("t", fn, kwargs={"flag": True})
desc = describe_injection(spec) desc = describe_injection(spec)
assert "flag=True" in desc assert "flag=True" in desc
def test_positional_args_filled(self) -> None:
"""spec.args 填充的位置参数应显示具体值(覆盖 args_filled 分支)."""
def test_describe_injection_positional_args_filled() -> None: def fn(a: int, b: str) -> None: # noqa: ARG001
"""spec.args 填充的位置参数应显示具体值(覆盖 args_filled 分支)。""" return None
def fn(a: int, b: str) -> None: spec = px.TaskSpec("t", fn, args=(1, "x"))
return None desc = describe_injection(spec)
assert "a=1" in desc
assert "b='x'" in desc
spec = px.TaskSpec("t", fn, args=(1, "x"))
desc = describe_injection(spec)
assert "a=1" in desc
assert "b='x'" in desc
class TestIsContextAnnotation:
"""测试 _is_context_annotation 函数."""
# ---------------------------------------------------------------------- # def test_direct_object(self) -> None:
# build_call_args 边界 """直接传入 Context 别名对象应返回 True."""
# ---------------------------------------------------------------------- # assert _is_context_annotation(px.Context) is True
def test_build_call_args_var_positional_not_required() -> None:
"""*args 参数不应触发 InjectionError。"""
def fn(*args: Any) -> int: def test_string(self) -> None:
return len(args) """字符串形式的注解应被识别."""
assert _is_context_annotation("Context") is True
assert _is_context_annotation("px.Context") is True
assert _is_context_annotation("pyflowx.Context") is True
assert _is_context_annotation("NotContext") is False
assert _is_context_annotation("int") is False
spec = px.TaskSpec("t", fn, args=(1, 2, 3)) def test_typing_alias(self) -> None:
args, kwargs = build_call_args(spec, {}) """具有 __name__/_name 为 Context/Mapping 的 typing 别名应返回 True."""
assert args == (1, 2, 3)
assert kwargs == {}
class FakeAlias:
__name__ = "Context"
def test_build_call_args_var_keyword_consumes_leftover() -> None: assert _is_context_annotation(FakeAlias()) is True
"""**kwargs 应吞掉未被具名参数消费的依赖结果。"""
def fn(a: int, **rest: Any) -> int: class FakeMapping:
return a + sum(rest.values()) __name__ = "Mapping"
spec = px.TaskSpec("t", fn, depends_on=("a", "b", "c")) assert _is_context_annotation(FakeMapping()) is True
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 3})
assert kwargs == {"a": 1, "b": 2, "c": 3}
def test_other(self) -> None:
def test_build_call_args_no_var_keyword_drops_leftover() -> None: """其他类型注解应返回 False."""
"""无 **kwargs 时,未被消费的依赖结果被丢弃(不报错)。""" assert _is_context_annotation(int) is False
assert _is_context_annotation(str) is False
def fn(a: int) -> int: assert _is_context_annotation(None) is False
return a
spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
# b 是依赖但 fn 不接收它 —— 应正常工作
args, kwargs = build_call_args(spec, {"a": 1, "b": 2})
assert kwargs == {"a": 1}
def test_build_call_args_context_annotation_only_deps() -> None:
"""Context 标注只接收该任务自身 depends_on 的结果。"""
def fn(ctx: px.Context) -> int:
return len(ctx)
spec = px.TaskSpec("t", fn, depends_on=("a", "b"))
args, kwargs = build_call_args(spec, {"a": 1, "b": 2, "c": 99})
assert kwargs == {"ctx": {"a": 1, "b": 2}}
+12 -12
View File
@@ -3,11 +3,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import os
import tempfile import tempfile
import threading import threading
import time import time
from typing import Any, List from pathlib import Path
from typing import Any
import pytest import pytest
@@ -39,7 +39,7 @@ def test_sequential_basic() -> None:
def test_sequential_diamond() -> None: def test_sequential_diamond() -> None:
order: List[str] = [] order: list[str] = []
def make(name: str) -> Any: def make(name: str) -> Any:
def fn() -> str: def fn() -> str:
@@ -66,7 +66,7 @@ def test_failure_propagates() -> None:
def boom() -> None: def boom() -> None:
raise ValueError("kaboom") raise ValueError("kaboom")
def downstream(boom: None) -> int: def downstream(_boom: None) -> int:
return 1 return 1
graph = px.Graph.from_specs( graph = px.Graph.from_specs(
@@ -131,7 +131,7 @@ def test_threaded_parallelism() -> None:
def test_threaded_layer_barrier() -> None: def test_threaded_layer_barrier() -> None:
finished: List[str] = [] finished: list[str] = []
lock = threading.Lock() lock = threading.Lock()
def make(name: str) -> Any: def make(name: str) -> Any:
@@ -231,7 +231,7 @@ def test_async_timeout() -> None:
# Dry run # Dry run
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None: def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
called: List[str] = [] called: list[str] = []
def fn() -> str: def fn() -> str:
called.append("x") called.append("x")
@@ -250,7 +250,7 @@ def test_dry_run_does_not_execute(capsys: pytest.CaptureFixture[str]) -> None:
# State / resume # State / resume
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_memory_backend_resume() -> None: def test_memory_backend_resume() -> None:
runs: List[str] = [] runs: list[str] = []
def make(name: str) -> Any: def make(name: str) -> Any:
def fn() -> str: def fn() -> str:
@@ -276,7 +276,7 @@ def test_memory_backend_resume() -> None:
def test_json_backend_persistence() -> None: def test_json_backend_persistence() -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
def fn() -> int: def fn() -> int:
return 7 return 7
@@ -285,7 +285,7 @@ def test_json_backend_persistence() -> None:
px.run(graph, strategy="sequential", state=JSONBackend(path)) px.run(graph, strategy="sequential", state=JSONBackend(path))
# New backend reads the file; task should be skipped. # New backend reads the file; task should be skipped.
runs: List[str] = [] runs: list[str] = []
def fn2() -> int: def fn2() -> int:
runs.append("ran") runs.append("ran")
@@ -301,7 +301,7 @@ def test_json_backend_persistence() -> None:
# Events # Events
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_on_event_callback() -> None: def test_on_event_callback() -> None:
events: List[px.TaskEvent] = [] events: list[px.TaskEvent] = []
def fn() -> int: def fn() -> int:
return 1 return 1
@@ -390,7 +390,7 @@ def test_async_failure_retry_branch(caplog: pytest.LogCaptureFixture) -> None:
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_threaded_skips_cached_tasks() -> None: def test_threaded_skips_cached_tasks() -> None:
"""threaded 策略下命中缓存的任务应被跳过(覆盖 line 224-230)。""" """threaded 策略下命中缓存的任务应被跳过(覆盖 line 224-230)。"""
runs: List[str] = [] runs: list[str] = []
def make(name: str) -> Any: def make(name: str) -> Any:
def fn() -> str: def fn() -> str:
@@ -426,7 +426,7 @@ def test_threaded_all_cached_layer() -> None:
def test_async_skips_cached_tasks() -> None: def test_async_skips_cached_tasks() -> None:
"""async 策略下命中缓存的任务应被跳过(覆盖 line 268-274)。""" """async 策略下命中缓存的任务应被跳过(覆盖 line 268-274)。"""
runs: List[str] = [] runs: list[str] = []
async def make(name: str) -> Any: async def make(name: str) -> Any:
async def fn() -> str: async def fn() -> str:
+12 -11
View File
@@ -39,7 +39,7 @@ def test_from_specs_allows_forward_references() -> None:
def test_duplicate_task_raises() -> None: def test_duplicate_task_raises() -> None:
with pytest.raises(DuplicateTaskError): with pytest.raises(DuplicateTaskError):
px.Graph.from_specs( _ = px.Graph.from_specs(
[ [
px.TaskSpec("a", _fn), px.TaskSpec("a", _fn),
px.TaskSpec("a", _fn), px.TaskSpec("a", _fn),
@@ -49,14 +49,15 @@ def test_duplicate_task_raises() -> None:
def test_missing_dependency_raises() -> None: def test_missing_dependency_raises() -> None:
with pytest.raises(MissingDependencyError) as exc_info: with pytest.raises(MissingDependencyError) as exc_info:
px.Graph.from_specs([px.TaskSpec("b", _fn, depends_on=("a",))]) _ = px.Graph.from_specs([px.TaskSpec("b", _fn, depends_on=("a",))])
assert exc_info.value.task == "b" assert exc_info.value.task == "b"
assert exc_info.value.dependency == "a" assert exc_info.value.dependency == "a"
def test_cycle_detection() -> None: def test_cycle_detection() -> None:
with pytest.raises(CycleError): with pytest.raises(CycleError):
px.Graph.from_specs( _ = px.Graph.from_specs(
[ [
px.TaskSpec("a", _fn, depends_on=("c",)), px.TaskSpec("a", _fn, depends_on=("c",)),
px.TaskSpec("b", _fn, depends_on=("a",)), px.TaskSpec("b", _fn, depends_on=("a",)),
@@ -80,7 +81,7 @@ def test_layers_grouping() -> None:
def test_self_dependency_rejected() -> None: def test_self_dependency_rejected() -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
px.TaskSpec("a", _fn, depends_on=("a",)) _ = px.TaskSpec("a", _fn, depends_on=("a",))
def test_to_mermaid() -> None: def test_to_mermaid() -> None:
@@ -99,7 +100,7 @@ def test_to_mermaid() -> None:
def test_to_mermaid_invalid_orientation() -> None: def test_to_mermaid_invalid_orientation() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
with pytest.raises(ValueError): with pytest.raises(ValueError):
graph.to_mermaid("XX") _ = graph.to_mermaid("XX")
def test_subgraph_by_tags() -> None: def test_subgraph_by_tags() -> None:
@@ -134,7 +135,7 @@ def test_subgraph_by_names() -> None:
def test_subgraph_by_names_unknown() -> None: def test_subgraph_by_names_unknown() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
with pytest.raises(KeyError): with pytest.raises(KeyError):
graph.subgraph_by_names(["nope"]) _ = graph.subgraph_by_names(["nope"])
def test_describe() -> None: def test_describe() -> None:
@@ -160,14 +161,14 @@ def test_add_chains_and_validates() -> None:
assert "a" in graph assert "a" in graph
# 缺失依赖应即时报错 # 缺失依赖应即时报错
with pytest.raises(MissingDependencyError): with pytest.raises(MissingDependencyError):
graph.add(px.TaskSpec("b", _fn, depends_on=("missing",))) _ = graph.add(px.TaskSpec("b", _fn, depends_on=("missing",)))
def test_add_duplicate_raises() -> None: def test_add_duplicate_raises() -> None:
graph = px.Graph() graph = px.Graph()
graph.add(px.TaskSpec("a", _fn)) _ = graph.add(px.TaskSpec("a", _fn))
with pytest.raises(DuplicateTaskError): with pytest.raises(DuplicateTaskError):
graph.add(px.TaskSpec("a", _fn)) _ = graph.add(px.TaskSpec("a", _fn))
def test_all_specs_returns_view() -> None: def test_all_specs_returns_view() -> None:
@@ -182,7 +183,7 @@ def test_spec_accessor() -> None:
graph = px.Graph.from_specs([px.TaskSpec("a", _fn)]) graph = px.Graph.from_specs([px.TaskSpec("a", _fn)])
assert graph.spec("a").name == "a" assert graph.spec("a").name == "a"
with pytest.raises(KeyError): with pytest.raises(KeyError):
graph.spec("missing") _ = graph.spec("missing")
def test_dependencies_accessor() -> None: def test_dependencies_accessor() -> None:
@@ -213,7 +214,7 @@ def test_subgraph_preserves_metadata() -> None:
graph = px.Graph.from_specs( graph = px.Graph.from_specs(
[ [
px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0), px.TaskSpec("a", _fn, tags=("x",), retries=3, timeout=5.0),
px.TaskSpec("b", _fn, ("a",), tags=("y",)), px.TaskSpec("b", _fn, depends_on=("a",), tags=("y",)),
] ]
) )
sub = graph.subgraph(["x"]) sub = graph.subgraph(["x"])
+90 -81
View File
@@ -1,9 +1,8 @@
"""RunReport 测试""" """RunReport 测试."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime from datetime import datetime, timedelta
from typing import Optional
import pyflowx as px import pyflowx as px
from pyflowx.task import TaskResult, TaskSpec, TaskStatus from pyflowx.task import TaskResult, TaskSpec, TaskStatus
@@ -17,15 +16,14 @@ def _make_result(
name: str = "a", name: str = "a",
status: TaskStatus = TaskStatus.SUCCESS, status: TaskStatus = TaskStatus.SUCCESS,
value: object = 42, value: object = 42,
error: Optional[object] = None, error: BaseException | None = None,
duration: float = 0.5, duration: float = 0.5,
attempts: int = 1, attempts: int = 1,
) -> TaskResult[object]: ) -> TaskResult[object]:
"""构造测试用 TaskResult 实例."""
spec: TaskSpec[object] = TaskSpec[object](name, _fn) spec: TaskSpec[object] = TaskSpec[object](name, _fn)
start = datetime(2024, 1, 1, 0, 0, 0) start = datetime(2024, 1, 1, 0, 0, 0)
# 用 timedelta 精确表达秒数,避免 int() 截断小数 # 用 timedelta 精确表达秒数,避免 int() 截断小数
from datetime import timedelta
end = start + timedelta(seconds=duration) if duration else None end = start + timedelta(seconds=duration) if duration else None
return TaskResult[object]( return TaskResult[object](
spec=spec, spec=spec,
@@ -38,85 +36,96 @@ def _make_result(
) )
def test_getitem_returns_value() -> None: class TestRunReportAccess:
report = px.RunReport() """测试 RunReport 的访问接口."""
report.results["a"] = _make_result("a", value=7)
assert report["a"] == 7 def test_getitem_returns_value(self) -> None:
"""report[name] 应返回任务结果值."""
report = px.RunReport()
report.results["a"] = _make_result("a", value=7)
assert report["a"] == 7
def test_result_of_returns_full_result(self) -> None:
"""result_of 应返回完整的 TaskResult 对象."""
report = px.RunReport()
r = _make_result("a")
report.results["a"] = r
assert report.result_of("a") is r
def test_contains(self) -> None:
"""in 运算符应正确判断任务是否存在."""
report = px.RunReport()
report.results["a"] = _make_result("a")
assert "a" in report
assert "b" not in report
def test_iter_and_len(self) -> None:
"""应支持迭代任务名并返回任务数量."""
report = px.RunReport()
report.results["a"] = _make_result("a")
report.results["b"] = _make_result("b")
assert list(report) == ["a", "b"]
assert len(report) == 2
def test_result_of_returns_full_result() -> None: class TestRunReportSummary:
report = px.RunReport() """测试 RunReport 的 summary 方法."""
r = _make_result("a")
report.results["a"] = r def test_summary_success(self) -> None:
assert report.result_of("a") is r """应正确汇总成功和跳过的任务."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=1.0)
report.results["b"] = _make_result("b", status=TaskStatus.SKIPPED, duration=0.0)
s = report.summary()
assert s["success"] is True
assert s["total_tasks"] == 2
assert s["by_status"] == {"success": 1, "skipped": 1}
assert s["total_duration_seconds"] == 1.0
def test_summary_with_none_duration(self) -> None:
"""未开始/未结束的任务 duration 为 None,不应计入总时长."""
report = px.RunReport()
spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.FAILED)
s = report.summary()
assert s["total_duration_seconds"] == 0.0
def test_failed_tasks(self) -> None:
"""failed_tasks 应返回所有失败任务名."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
report.results["b"] = _make_result(
"b", status=TaskStatus.FAILED, error=ValueError("x")
)
assert report.failed_tasks() == ["b"]
def test_contains() -> None: class TestRunReportDescribe:
report = px.RunReport() """测试 RunReport 的 describe 方法."""
report.results["a"] = _make_result("a")
assert "a" in report
assert "b" not in report
def test_describe_success(self) -> None:
"""应正确描述成功状态和耗时."""
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5)
desc = report.describe()
assert "RunReport(success=True)" in desc
assert "a: success" in desc
assert "0.500s" in desc
def test_iter_and_len() -> None: def test_describe_with_error(self) -> None:
report = px.RunReport() """应正确描述失败状态和错误信息."""
report.results["a"] = _make_result("a") report = px.RunReport(success=False)
report.results["b"] = _make_result("b") report.results["a"] = _make_result(
assert list(report) == ["a", "b"] "a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1
assert len(report) == 2 )
desc = report.describe()
assert "success=False" in desc
assert "error=ValueError" in desc
def test_describe_no_duration(self) -> None:
def test_summary_success() -> None: """无耗时的任务应显示为 '-'."""
report = px.RunReport() report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=1.0) spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
report.results["b"] = _make_result("b", status=TaskStatus.SKIPPED, duration=0.0) report.results["a"] = TaskResult(spec=spec, status=TaskStatus.PENDING)
s = report.summary() desc = report.describe()
assert s["success"] is True assert "-" in desc # duration 显示为 "-"
assert s["total_tasks"] == 2
assert s["by_status"] == {"success": 1, "skipped": 1}
assert s["total_duration_seconds"] == 1.0
def test_summary_with_none_duration() -> None:
"""未开始/未结束的任务 duration 为 None,不应计入总时长。"""
report = px.RunReport()
spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.FAILED)
s = report.summary()
assert s["total_duration_seconds"] == 0.0
def test_failed_tasks() -> None:
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS)
report.results["b"] = _make_result(
"b", status=TaskStatus.FAILED, error=ValueError("x")
)
assert report.failed_tasks() == ["b"]
def test_describe_success() -> None:
report = px.RunReport()
report.results["a"] = _make_result("a", status=TaskStatus.SUCCESS, duration=0.5)
desc = report.describe()
assert "RunReport(success=True)" in desc
assert "a: success" in desc
assert "0.500s" in desc
def test_describe_with_error() -> None:
report = px.RunReport(success=False)
report.results["a"] = _make_result(
"a", status=TaskStatus.FAILED, error=ValueError("boom"), duration=0.1
)
desc = report.describe()
assert "success=False" in desc
assert "error=ValueError" in desc
def test_describe_no_duration() -> None:
report = px.RunReport()
spec: TaskSpec[object] = TaskSpec("a", _fn) # type: ignore[arg-type]
report.results["a"] = TaskResult(spec=spec, status=TaskStatus.PENDING)
desc = report.describe()
assert "-" in desc # duration 显示为 "-"
+18 -25
View File
@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
from typing import Any, List from typing import Any
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -77,12 +77,12 @@ class TestCliRunnerConstruction:
def test_rejects_non_graph_value(self) -> None: def test_rejects_non_graph_value(self) -> None:
"""非 Graph 值应抛出 TypeError.""" """非 Graph 值应抛出 TypeError."""
with pytest.raises(TypeError, match="必须是 Graph 实例"): with pytest.raises(TypeError, match="必须是 Graph 实例"):
px.CliRunner(clean="not a graph") # type: ignore[arg-type] _ = px.CliRunner(clean="not a graph") # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def test_rejects_non_graph_list(self) -> None: def test_rejects_non_graph_list(self) -> None:
"""列表类型的值应抛出 TypeError.""" """列表类型的值应抛出 TypeError."""
with pytest.raises(TypeError, match="必须是 Graph 实例"): with pytest.raises(TypeError, match="必须是 Graph 实例"):
px.CliRunner(build=[1, 2, 3]) # type: ignore[arg-type] _ = px.CliRunner(build=[1, 2, 3]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
def test_default_strategy_is_sequential(self) -> None: def test_default_strategy_is_sequential(self) -> None:
"""默认策略应为 Strategy.SEQUENTIAL.""" """默认策略应为 Strategy.SEQUENTIAL."""
@@ -257,19 +257,15 @@ class TestCliRunnerParser:
class TestCliRunnerRunSuccess: class TestCliRunnerRunSuccess:
"""测试 CliRunner.run 的成功执行路径.""" """测试 CliRunner.run 的成功执行路径."""
def test_run_valid_command_returns_zero( def test_run_valid_command_returns_zero(self) -> None:
self, capsys: pytest.CaptureFixture[str]
) -> None:
"""有效命令执行成功应返回 0.""" """有效命令执行成功应返回 0."""
runner = px.CliRunner(echo=_echo_graph()) runner = px.CliRunner(echo=_echo_graph())
exit_code = runner.run(["echo"]) exit_code = runner.run(["echo"])
assert exit_code == CliExitCode.SUCCESS.value assert exit_code == CliExitCode.SUCCESS.value
def test_run_executes_correct_graph( def test_run_executes_correct_graph(self) -> None:
self, capsys: pytest.CaptureFixture[str]
) -> None:
"""应执行用户指定的命令对应的图.""" """应执行用户指定的命令对应的图."""
executed: List[str] = [] executed: list[str] = []
def track_a() -> None: def track_a() -> None:
executed.append("a") executed.append("a")
@@ -418,9 +414,7 @@ class TestCliRunnerRunFailure:
captured = capsys.readouterr() captured = capsys.readouterr()
assert "可用命令" in captured.out or "可用命令" in captured.err assert "可用命令" in captured.out or "可用命令" in captured.err
def test_run_failing_task_returns_failure( def test_run_failing_task_returns_failure(self) -> None:
self, capsys: pytest.CaptureFixture[str]
) -> None:
"""任务失败时应返回 1.""" """任务失败时应返回 1."""
runner = px.CliRunner(fail=_failing_graph()) runner = px.CliRunner(fail=_failing_graph())
exit_code = runner.run(["fail"]) exit_code = runner.run(["fail"])
@@ -443,7 +437,7 @@ class TestCliRunnerRunFailure:
class TestCliRunnerList: class TestCliRunnerList:
"""测试 --list 选项.""" """测试 --list 选项."""
def test_list_returns_success(self, capsys: pytest.CaptureFixture[str]) -> None: def test_list_returns_success(self) -> None:
"""--list 应返回 0.""" """--list 应返回 0."""
runner = px.CliRunner(clean=_echo_graph(), build=_echo_graph()) runner = px.CliRunner(clean=_echo_graph(), build=_echo_graph())
exit_code = runner.run(["--list"]) exit_code = runner.run(["--list"])
@@ -462,11 +456,9 @@ class TestCliRunnerList:
assert "build" in captured.out assert "build" in captured.out
assert "test" in captured.out assert "test" in captured.out
def test_list_does_not_execute_any_graph( def test_list_does_not_execute_any_graph(self) -> None:
self, capsys: pytest.CaptureFixture[str]
) -> None:
"""--list 不应执行任何图.""" """--list 不应执行任何图."""
executed: List[str] = [] executed: list[str] = []
def track() -> None: def track() -> None:
executed.append("ran") executed.append("ran")
@@ -488,7 +480,7 @@ class TestCliRunnerErrorHandling:
"""KeyboardInterrupt 应返回 130.""" """KeyboardInterrupt 应返回 130."""
runner = px.CliRunner(echo=_echo_graph()) runner = px.CliRunner(echo=_echo_graph())
def raise_interrupt(*args: Any, **kwargs: Any) -> None: def raise_interrupt(*_args: Any, **_kwargs: Any) -> None:
raise KeyboardInterrupt raise KeyboardInterrupt
with patch("pyflowx.runner.run", side_effect=raise_interrupt): with patch("pyflowx.runner.run", side_effect=raise_interrupt):
@@ -503,7 +495,7 @@ class TestCliRunnerErrorHandling:
"""PyFlowXError 应返回 1.""" """PyFlowXError 应返回 1."""
runner = px.CliRunner(echo=_echo_graph()) runner = px.CliRunner(echo=_echo_graph())
def raise_error(*args: Any, **kwargs: Any) -> None: def raise_error(*_args: Any, **_kwargs: Any) -> None:
raise TaskFailedError("echo", RuntimeError("boom"), 1) raise TaskFailedError("echo", RuntimeError("boom"), 1)
with patch("pyflowx.runner.run", side_effect=raise_error): with patch("pyflowx.runner.run", side_effect=raise_error):
@@ -520,12 +512,13 @@ class TestCliRunnerErrorHandling:
runner = px.CliRunner(echo=_echo_graph()) runner = px.CliRunner(echo=_echo_graph())
def raise_custom(*args: Any, **kwargs: Any) -> None: def raise_custom(*_args: Any, **_kwargs: Any) -> None:
raise CustomError("unexpected") raise CustomError("unexpected")
with patch("pyflowx.runner.run", side_effect=raise_custom): with patch("pyflowx.runner.run", side_effect=raise_custom), pytest.raises(
with pytest.raises(CustomError): CustomError
runner.run(["echo"]) ):
runner.run(["echo"])
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -617,7 +610,7 @@ class TestCliRunnerIntegration:
def test_diamond_dependency_graph(self) -> None: def test_diamond_dependency_graph(self) -> None:
"""菱形依赖图应正确执行.""" """菱形依赖图应正确执行."""
order: List[str] = [] order: list[str] = []
def make(name: str) -> Any: def make(name: str) -> Any:
def fn() -> str: def fn() -> str:
+13 -26
View File
@@ -22,19 +22,6 @@ def mock_tmp_json(tmp_path: Path) -> Path:
return path return path
class TestStateBackend:
"""测试状态后端。"""
def test_json_backend_save_and_load(self, mock_tmp_json: Path) -> None:
"""测试 JSON 后端保存和加载。"""
b = JSONBackend(str(mock_tmp_json))
assert not b.has("a")
b.save("a", 1)
assert b.has("a")
assert b.get("a") == 1
assert dict(b.load()) == {"a": 1}
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# MemoryBackend # MemoryBackend
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -61,7 +48,7 @@ def test_memory_backend_get_missing_raises() -> None:
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
def test_json_backend_save_and_load() -> None: def test_json_backend_save_and_load() -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
b.save("a", {"x": 1}) b.save("a", {"x": 1})
b.save("b", [1, 2, 3]) b.save("b", [1, 2, 3])
@@ -75,7 +62,7 @@ def test_json_backend_save_and_load() -> None:
def test_json_backend_clear() -> None: def test_json_backend_clear() -> None:
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
b.save("a", 1) b.save("a", 1)
b.clear() b.clear()
@@ -88,7 +75,7 @@ def test_json_backend_clear() -> None:
def test_json_backend_nonexistent_file_starts_empty() -> None: def test_json_backend_nonexistent_file_starts_empty() -> None:
"""文件不存在时应正常初始化为空。""" """文件不存在时应正常初始化为空。"""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "absent.json") path = str(Path(tmp) / "absent.json")
b = JSONBackend(path) b = JSONBackend(path)
assert dict(b.load()) == {} assert dict(b.load()) == {}
assert not b.has("anything") assert not b.has("anything")
@@ -97,7 +84,7 @@ def test_json_backend_nonexistent_file_starts_empty() -> None:
def test_json_backend_non_serialisable_raises() -> None: def test_json_backend_non_serialisable_raises() -> None:
"""不可 JSON 序列化的值应抛 StorageError,且不污染内存状态。""" """不可 JSON 序列化的值应抛 StorageError,且不污染内存状态。"""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
with pytest.raises(StorageError): with pytest.raises(StorageError):
b.save("a", object()) # object() 不可序列化 b.save("a", object()) # object() 不可序列化
@@ -113,12 +100,12 @@ def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None:
import json as _json import json as _json
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
original_dump = _json.dump original_dump = _json.dump
def flaky_dump(*args: Any, **kwargs: Any) -> None: def flaky_dump(*_args: Any, **_kwargs: Any) -> None:
raise TypeError("simulated flush failure") raise TypeError("simulated flush failure")
monkeypatch.setattr(_json, "dump", flaky_dump) monkeypatch.setattr(_json, "dump", flaky_dump)
@@ -131,28 +118,28 @@ def test_json_backend_flush_type_error(monkeypatch: pytest.MonkeyPatch) -> None:
def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None: def test_json_backend_flush_os_error(monkeypatch: pytest.MonkeyPatch) -> None:
"""_flush 时 OSError 应转为 StorageError。""" """_flush 时 OSError 应转为 StorageError。"""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
b = JSONBackend(path) b = JSONBackend(path)
original_replace = os.replace original_replace = os.replace
def fail_replace(*args: Any, **kwargs: Any) -> None: def fail_replace(*_args: Any, **_kwargs: Any) -> None:
raise OSError("simulated os.replace failure") raise OSError("simulated os.replace failure")
monkeypatch.setattr(os, "replace", fail_replace) monkeypatch.setattr(Path, "replace", fail_replace)
with pytest.raises(StorageError, match="cannot write"): with pytest.raises(StorageError, match="cannot write"):
b.save("a", 1) b.save("a", 1)
monkeypatch.setattr(os, "replace", original_replace) monkeypatch.setattr(os, "replace", original_replace)
def test_json_backend_corrupt_file_raises(tmp_path: Path) -> None: def test_json_backend_corrupt_file_raises() -> None:
"""损坏的 JSON 文件应抛 StorageError。""" """损坏的 JSON 文件应抛 StorageError。"""
with tempfile.TemporaryDirectory() as tmp: with tempfile.TemporaryDirectory() as tmp:
path = os.path.join(tmp, "state.json") path = str(Path(tmp) / "state.json")
with open(path, "w", encoding="utf-8") as fh: with open(path, "w", encoding="utf-8") as fh:
fh.write("{not valid json") _ = fh.write("{not valid json")
with pytest.raises(StorageError): with pytest.raises(StorageError):
JSONBackend(path) _ = JSONBackend(path)
def test_json_backend_non_dict_content_ignored(tmp_path: Path) -> None: def test_json_backend_non_dict_content_ignored(tmp_path: Path) -> None:
+1 -1
View File
@@ -356,7 +356,7 @@ class TestTaskSpecVerbose:
def test_verbose_default_is_false(self) -> None: def test_verbose_default_is_false(self) -> None:
"""verbose 默认应为 False.""" """verbose 默认应为 False."""
spec = px.TaskSpec("a", cmd=[*ECHO_CMD, "hi"]) spec: px.TaskSpec[object] = px.TaskSpec("a", cmd=[*ECHO_CMD, "hi"])
assert spec.verbose is False assert spec.verbose is False
def test_verbose_true_prints_command( def test_verbose_true_prints_command(
+2 -1
View File
@@ -2,5 +2,6 @@
This type stub file was generated by pyright. This type stub file was generated by pyright.
""" """
from .graphlib import * from .graphlib import CycleError, TopologicalSorter
__all__ = ["CycleError", "TopologicalSorter"]
+23 -22
View File
@@ -2,15 +2,16 @@
This type stub file was generated by pyright. This type stub file was generated by pyright.
""" """
__all__ = ["TopologicalSorter", "CycleError"] from typing import Any, Generator
__all__ = ["CycleError", "TopologicalSorter"]
_NODE_OUT = ... _NODE_OUT = ...
_NODE_DONE = ... _NODE_DONE = ...
class _NodeInfo:
__slots__ = ...
def __init__(self, node) -> None:
...
class _NodeInfo:
__slots__: list[str]
def __init__(self, node) -> None: ...
class CycleError(ValueError): class CycleError(ValueError):
"""Subclass of ValueError raised by TopologicalSorterif cycles exist in the graph """Subclass of ValueError raised by TopologicalSorterif cycles exist in the graph
@@ -22,14 +23,13 @@ class CycleError(ValueError):
next node in the list. In the reported list, the first and the last node will be next node in the list. In the reported list, the first and the last node will be
the same, to make it clear that it is cyclic. the same, to make it clear that it is cyclic.
""" """
...
...
class TopologicalSorter: class TopologicalSorter:
"""Provides functionality to topologically sort a graph of hashable nodes""" """Provides functionality to topologically sort a graph of hashable nodes"""
def __init__(self, graph=...) -> None:
... def __init__(self, graph=...) -> None: ...
def add(self, node, *predecessors) -> None: def add(self, node, *predecessors) -> None:
"""Add a new node and its predecessors to the graph. """Add a new node and its predecessors to the graph.
@@ -45,8 +45,9 @@ class TopologicalSorter:
Raises ValueError if called after "prepare". Raises ValueError if called after "prepare".
""" """
... ...
def prepare(self) -> None: def prepare(self) -> None:
"""Mark the graph as finished and check for cycles in the graph. """Mark the graph as finished and check for cycles in the graph.
@@ -55,8 +56,9 @@ class TopologicalSorter:
progress. After a call to this function, the graph cannot be modified and progress. After a call to this function, the graph cannot be modified and
therefore no more nodes can be added using "add". therefore no more nodes can be added using "add".
""" """
... ...
def get_ready(self) -> tuple[Any, ...]: def get_ready(self) -> tuple[Any, ...]:
"""Return a tuple of all the nodes that are ready. """Return a tuple of all the nodes that are ready.
@@ -67,8 +69,9 @@ class TopologicalSorter:
Raises ValueError if called without calling "prepare" previously. Raises ValueError if called without calling "prepare" previously.
""" """
... ...
def is_active(self) -> bool: def is_active(self) -> bool:
"""Return True if more progress can be made and ``False`` otherwise. """Return True if more progress can be made and ``False`` otherwise.
@@ -79,11 +82,10 @@ class TopologicalSorter:
Raises ValueError if called without calling "prepare" previously. Raises ValueError if called without calling "prepare" previously.
""" """
... ...
def __bool__(self) -> bool: def __bool__(self) -> bool: ...
...
def done(self, *nodes) -> None: def done(self, *nodes) -> None:
"""Marks a set of nodes returned by "get_ready" as processed. """Marks a set of nodes returned by "get_ready" as processed.
@@ -95,9 +97,10 @@ class TopologicalSorter:
graph by using "add" or if called without calling "prepare" previously or if graph by using "add" or if called without calling "prepare" previously or if
node has not yet been returned by "get_ready". node has not yet been returned by "get_ready".
""" """
... ...
def static_order(self) -> Generator[Any, Any, None]: def static_order(self) -> Generator[Any]:
"""Returns an iterable of nodes in a topological order. """Returns an iterable of nodes in a topological order.
The particular order that is returned may depend on the specific The particular order that is returned may depend on the specific
@@ -106,7 +109,5 @@ class TopologicalSorter:
Using this method does not require to call "prepare" or "done". If any Using this method does not require to call "prepare" or "done". If any
cycle is detected, :exc:`CycleError` will be raised. cycle is detected, :exc:`CycleError` will be raised.
""" """
... ...