This commit is contained in:
2026-06-22 12:31:26 +08:00
parent 413ab40044
commit 0df795237d
9 changed files with 474 additions and 105 deletions
+1 -1
View File
@@ -22,7 +22,7 @@ version = "0.1.7"
[project.scripts] [project.scripts]
autofmt = "pyflowx.cli.autofmt:main" autofmt = "pyflowx.cli.autofmt:main"
bumpver = "pyflowx.cli.bumpversion:main" bumpver = "pyflowx.cli.bumpversion:main"
cls = "pyflowx.cli.clearscreen:main" clr = "pyflowx.cli.clearscreen:main"
envpy = "pyflowx.cli.envpy:main" envpy = "pyflowx.cli.envpy:main"
envqt = "pyflowx.cli.envqt:main" envqt = "pyflowx.cli.envqt:main"
envrs = "pyflowx.cli.envrs:main" envrs = "pyflowx.cli.envrs:main"
-5
View File
@@ -10,10 +10,6 @@ import subprocess
import pyflowx as px import pyflowx as px
from pyflowx.conditions import Constants from pyflowx.conditions import Constants
# ============================================================================
# 辅助函数
# ============================================================================
def clear_screen() -> None: def clear_screen() -> None:
"""使用系统命令清屏.""" """使用系统命令清屏."""
@@ -23,7 +19,6 @@ def clear_screen() -> None:
subprocess.run(["clear"], check=False) subprocess.run(["clear"], check=False)
print("\033[2J\033[H", end="") print("\033[2J\033[H", end="")
print("ClearScreen - 清屏工具")
def main() -> None: def main() -> None:
+16 -45
View File
@@ -8,10 +8,6 @@ from __future__ import annotations
import pyflowx as px import pyflowx as px
from pyflowx.conditions import Constants from pyflowx.conditions import Constants
# ============================================================================
# Qt 依赖列表
# ============================================================================
QT_LIBS: list[str] = [ QT_LIBS: list[str] = [
"build-essential", "build-essential",
"libgl1", "libgl1",
@@ -40,47 +36,22 @@ CHINESE_FONTS: list[str] = [
] ]
# ============================================================================
# TaskSpec 定义
# ============================================================================
# 条件: 仅在 Unix 系统上执行
def is_linux() -> bool:
"""判断是否为 Linux 系统."""
return Constants.IS_LINUX and not Constants.IS_MACOS
envqt_install: px.TaskSpec = px.TaskSpec(
"envqt_install",
cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
conditions=(is_linux,),
)
envqt_fonts: px.TaskSpec = px.TaskSpec(
"envqt_fonts",
cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(is_linux,),
)
# ============================================================================
# CLI Runner
# ============================================================================
def main() -> None: def main() -> None:
"""PyQt 环境配置工具主函数.""" """PyQt 环境配置工具主函数."""
runner = px.CliRunner( graph = px.Graph.from_specs(
strategy="thread", [
description="EnvQt - PyQt 环境配置工具", px.TaskSpec(
graphs={ "envqt_install",
# 安装 Qt 依赖 cmd=["sudo", "apt", "install", "-y", *QT_LIBS],
"i": px.Graph.from_specs([envqt_install]), conditions=(lambda: Constants.IS_LINUX,),
# 安装中文字体 verbose=True,
"f": px.Graph.from_specs([envqt_fonts]), ),
# 安装全部 px.TaskSpec(
"a": px.Graph.from_specs([envqt_install, envqt_fonts]), "envqt_fonts",
}, cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS],
conditions=(lambda: Constants.IS_LINUX,),
verbose=True,
),
],
) )
runner.run_cli() px.run(graph, strategy="thread", verbose=True)
+2 -4
View File
@@ -20,15 +20,13 @@ def maturin_build_cmd() -> list[str]:
""" """
command = ["maturin", "build", "-r"].copy() command = ["maturin", "build", "-r"].copy()
if Constants.IS_WINDOWS: if Constants.IS_WINDOWS:
command.extend( command.extend([
[
"--target", "--target",
"x86_64-win7-windows-msvc", "x86_64-win7-windows-msvc",
"-Zbuild-std", "-Zbuild-std",
"-i", "-i",
"python3.8", "python3.8",
] ])
)
return command return command
+91 -6
View File
@@ -22,6 +22,15 @@ class TestFormatWithRuff:
autofmt.format_with_ruff(tmp_path, fix=True) autofmt.format_with_ruff(tmp_path, fix=True)
assert mock_run.called assert mock_run.called
def test_format_with_ruff_no_fix(self, tmp_path: Path) -> None:
"""Should format with ruff without fix."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
autofmt.format_with_ruff(tmp_path, fix=False)
# Should not include --fix flag
call_args = mock_run.call_args[0][0]
assert "--fix" not in call_args
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# lint_with_ruff # lint_with_ruff
@@ -36,6 +45,15 @@ class TestLintWithRuff:
autofmt.lint_with_ruff(tmp_path, fix=True) autofmt.lint_with_ruff(tmp_path, fix=True)
assert mock_run.called assert mock_run.called
def test_lint_with_ruff_no_fix(self, tmp_path: Path) -> None:
"""Should lint with ruff without fix."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
autofmt.lint_with_ruff(tmp_path, fix=False)
# Should not include --fix flag
call_args = mock_run.call_args[0][0]
assert "--fix" not in call_args
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# add_docstring # add_docstring
@@ -51,15 +69,64 @@ class TestAddDocstring:
result = autofmt.add_docstring(py_file, '"""Test module."""') result = autofmt.add_docstring(py_file, '"""Test module."""')
assert result is True assert result is True
def test_add_docstring_skips_non_python_files(self, tmp_path: Path) -> None: def test_add_docstring_skips_files_with_docstring(self, tmp_path: Path) -> None:
"""Should skip non-Python files.""" """Should skip files that already have docstring."""
txt_file = tmp_path / "test.txt" py_file = tmp_path / "test.py"
txt_file.write_text("test content") py_file.write_text('"""Existing docstring."""\ndef test():\n pass\n')
result = autofmt.add_docstring(txt_file, '"""Test."""') result = autofmt.add_docstring(py_file, '"""New docstring."""')
# Should return False for non-Python files
assert result is False assert result is False
def test_add_docstring_empty_file(self, tmp_path: Path) -> None:
"""Should handle empty file."""
py_file = tmp_path / "test.py"
py_file.write_text("")
result = autofmt.add_docstring(py_file, '"""Test module."""')
# Should handle empty file
assert result is True
# ---------------------------------------------------------------------- #
# generate_module_docstring
# ---------------------------------------------------------------------- #
class TestGenerateModuleDocstring:
"""Test generate_module_docstring function."""
def test_generate_module_docstring_basic(self, tmp_path: Path) -> None:
"""Should generate basic docstring."""
py_file = tmp_path / "test.py"
py_file.write_text("def test():\n pass\n")
result = autofmt.generate_module_docstring(py_file)
# Should contain "Tests for" since stem contains "test"
assert "Tests for" in result
def test_generate_module_docstring_with_package(self, tmp_path: Path) -> None:
"""Should generate docstring for package."""
py_file = tmp_path / "mypackage" / "test.py"
py_file.parent.mkdir(parents=True)
py_file.write_text("def test():\n pass\n")
result = autofmt.generate_module_docstring(py_file)
assert "mypackage" in result
def test_generate_module_docstring_cli(self, tmp_path: Path) -> None:
"""Should generate docstring for CLI module."""
py_file = tmp_path / "cli.py"
py_file.write_text("def test():\n pass\n")
result = autofmt.generate_module_docstring(py_file)
assert "Command-line interface" in result
def test_generate_module_docstring_util(self, tmp_path: Path) -> None:
"""Should generate docstring for utility module."""
py_file = tmp_path / "utils.py"
py_file.write_text("def test():\n pass\n")
result = autofmt.generate_module_docstring(py_file)
assert "Utility functions" in result
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# auto_add_docstrings # auto_add_docstrings
@@ -76,6 +143,24 @@ class TestAutoAddDocstrings:
count = autofmt.auto_add_docstrings(tmp_path) count = autofmt.auto_add_docstrings(tmp_path)
assert count >= 0 assert count >= 0
def test_auto_add_docstrings_skips_ignored(self, tmp_path: Path) -> None:
"""Should skip ignored directories."""
py_file = tmp_path / "__pycache__" / "test.py"
py_file.parent.mkdir()
py_file.write_text("def test():\n pass\n")
count = autofmt.auto_add_docstrings(tmp_path)
# Should skip __pycache__
assert count == 0
def test_auto_add_docstrings_no_files(self, tmp_path: Path) -> None:
"""Should handle no Python files."""
txt_file = tmp_path / "test.txt"
txt_file.write_text("test content")
count = autofmt.auto_add_docstrings(tmp_path)
assert count == 0
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# sync_pyproject_config # sync_pyproject_config
+101
View File
@@ -9,6 +9,95 @@ import pyflowx as px
from pyflowx.cli import folderback from pyflowx.cli import folderback
# ---------------------------------------------------------------------- #
# remove_dump
# ---------------------------------------------------------------------- #
class TestRemoveDump:
"""Test remove_dump function."""
def test_remove_dump_no_files(self, tmp_path: Path) -> None:
"""Should handle no zip files."""
src = tmp_path / "source"
src.mkdir()
dst = tmp_path / "backup"
dst.mkdir()
folderback.remove_dump(src, dst, 5)
# Should not raise error
def test_remove_dump_within_limit(self, tmp_path: Path) -> None:
"""Should not remove files within limit."""
src = tmp_path / "source"
src.mkdir()
dst = tmp_path / "backup"
dst.mkdir()
# Create some zip files
for i in range(3):
zip_file = dst / f"source_20240101_12000{i}.zip"
zip_file.write_bytes(b"ZIP content")
folderback.remove_dump(src, dst, 5)
# All files should remain
assert len(list(dst.glob("*.zip"))) == 3
def test_remove_dump_exceeds_limit(self, tmp_path: Path) -> None:
"""Should remove oldest files when exceeds limit."""
src = tmp_path / "source"
src.mkdir()
dst = tmp_path / "backup"
dst.mkdir()
# Create more zip files than limit
for i in range(7):
zip_file = dst / f"source_20240101_12000{i}.zip"
zip_file.write_bytes(b"ZIP content")
folderback.remove_dump(src, dst, 5)
# Should have only 5 files
assert len(list(dst.glob("*.zip"))) == 5
# ---------------------------------------------------------------------- #
# zip_target
# ---------------------------------------------------------------------- #
class TestZipTarget:
"""Test zip_target function."""
def test_zip_target_creates_zip(self, tmp_path: Path) -> None:
"""Should create zip file."""
src = tmp_path / "source"
src.mkdir()
(src / "test.txt").write_text("test content")
dst = tmp_path / "backup"
dst.mkdir()
with patch("time.strftime", return_value="_20240101_120000"):
folderback.zip_target(src, dst, 5)
# Should create zip file
zip_files = list(dst.glob("*.zip"))
assert len(zip_files) == 1
def test_zip_target_with_subdirectories(self, tmp_path: Path) -> None:
"""Should zip files in subdirectories."""
src = tmp_path / "source"
src.mkdir()
subdir = src / "subdir"
subdir.mkdir()
(src / "test.txt").write_text("test content")
(subdir / "nested.txt").write_text("nested content")
dst = tmp_path / "backup"
dst.mkdir()
with patch("time.strftime", return_value="_20240101_120000"):
folderback.zip_target(src, dst, 5)
# Should create zip file
zip_files = list(dst.glob("*.zip"))
assert len(zip_files) == 1
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# backup_folder # backup_folder
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -46,6 +135,18 @@ class TestBackupFolder:
folderback.backup_folder(str(source_dir), str(backup_dir), 5) folderback.backup_folder(str(source_dir), str(backup_dir), 5)
# Should print error message and return # Should print error message and return
def test_backup_folder_creates_dst(self, tmp_path: Path) -> None:
"""Should create destination directory."""
source_dir = tmp_path / "source"
source_dir.mkdir()
(source_dir / "test.txt").write_text("test content")
backup_dir = tmp_path / "backup"
with patch.object(folderback, "zip_target") as mock_zip:
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
assert backup_dir.exists()
assert mock_zip.called
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# TaskSpec definitions # TaskSpec definitions
+79
View File
@@ -2,13 +2,92 @@
from __future__ import annotations from __future__ import annotations
from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import pyflowx as px
from pyflowx.cli import gittool from pyflowx.cli import gittool
# ---------------------------------------------------------------------- #
# not_has_git_repo
# ---------------------------------------------------------------------- #
class TestNotHasGitRepo:
"""Test not_has_git_repo function."""
def test_not_has_git_repo_true(self, tmp_path: Path) -> None:
"""Should return True when no .git directory."""
with patch.object(Path, "cwd", return_value=tmp_path):
result = gittool.not_has_git_repo()
assert result is True
def test_not_has_git_repo_false(self, tmp_path: Path) -> None:
"""Should return False when .git directory exists."""
git_dir = tmp_path / ".git"
git_dir.mkdir()
with patch.object(Path, "cwd", return_value=tmp_path):
result = gittool.not_has_git_repo()
assert result is False
def test_not_has_git_repo_cwd_not_exists(self, tmp_path: Path) -> None:
"""Should return True when cwd doesn't exist."""
nonexistent = tmp_path / "nonexistent"
with patch.object(Path, "cwd", return_value=nonexistent):
result = gittool.not_has_git_repo()
assert result is True
# ---------------------------------------------------------------------- #
# has_files
# ---------------------------------------------------------------------- #
class TestHasFiles:
"""Test has_files function."""
def test_has_files_true(self, tmp_path: Path) -> None:
"""Should return True when files exist."""
(tmp_path / "test.txt").write_text("test")
with patch.object(Path, "cwd", return_value=tmp_path):
result = gittool.has_files()
assert result is True
def test_has_files_false(self, tmp_path: Path) -> None:
"""Should return False when no files."""
with patch.object(Path, "cwd", return_value=tmp_path):
result = gittool.has_files()
assert result is False
# ---------------------------------------------------------------------- #
# init_sub_dirs
# ---------------------------------------------------------------------- #
class TestInitSubDirs:
"""Test init_sub_dirs function."""
def test_init_sub_dirs_with_subdirectories(self, tmp_path: Path) -> None:
"""Should initialize git in subdirectories."""
subdir1 = tmp_path / "subdir1"
subdir1.mkdir()
subdir2 = tmp_path / "subdir2"
subdir2.mkdir()
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
gittool.init_sub_dirs()
# Should call px.run for each subdirectory
assert mock_run.call_count == 2
def test_init_sub_dirs_no_subdirectories(self, tmp_path: Path) -> None:
"""Should handle no subdirectories."""
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
gittool.init_sub_dirs()
# Should not call px.run
assert mock_run.call_count == 0
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# TaskSpec definitions # TaskSpec definitions
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
+106
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import subprocess
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -9,6 +10,95 @@ import pyflowx as px
from pyflowx.cli import piptool from pyflowx.cli import piptool
# ---------------------------------------------------------------------- #
# _get_installed_packages
# ---------------------------------------------------------------------- #
class TestGetInstalledPackages:
"""Test _get_installed_packages function."""
def test_get_installed_packages_success(self) -> None:
"""Should get installed packages."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(stdout="numpy==1.0.0\npandas==2.0.0\n", returncode=0)
result = piptool._get_installed_packages()
assert "numpy" in result
assert "pandas" in result
def test_get_installed_packages_empty(self) -> None:
"""Should handle empty output."""
with patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(stdout="", returncode=0)
result = piptool._get_installed_packages()
assert result == []
def test_get_installed_packages_error(self) -> None:
"""Should handle subprocess error."""
with patch("subprocess.run", side_effect=subprocess.SubprocessError):
result = piptool._get_installed_packages()
assert result == []
def test_get_installed_packages_oserror(self) -> None:
"""Should handle OSError."""
with patch("subprocess.run", side_effect=OSError):
result = piptool._get_installed_packages()
assert result == []
# ---------------------------------------------------------------------- #
# _expand_wildcard_packages
# ---------------------------------------------------------------------- #
class TestExpandWildcardPackages:
"""Test _expand_wildcard_packages function."""
def test_expand_wildcard_no_pattern(self) -> None:
"""Should return package name when no wildcard."""
result = piptool._expand_wildcard_packages("numpy")
assert result == ["numpy"]
def test_expand_wildcard_with_star(self) -> None:
"""Should expand wildcard with star."""
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numpy-core", "pandas"]):
result = piptool._expand_wildcard_packages("numpy*")
assert "numpy" in result
assert "numpy-core" in result
def test_expand_wildcard_with_question(self) -> None:
"""Should expand wildcard with question mark."""
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numba"]):
result = piptool._expand_wildcard_packages("num??")
assert len(result) > 0
def test_expand_wildcard_no_match(self) -> None:
"""Should return empty list when no match."""
with patch.object(piptool, "_get_installed_packages", return_value=["pandas", "scipy"]):
result = piptool._expand_wildcard_packages("numpy*")
assert result == []
# ---------------------------------------------------------------------- #
# _filter_protected_packages
# ---------------------------------------------------------------------- #
class TestFilterProtectedPackages:
"""Test _filter_protected_packages function."""
def test_filter_protected_packages_normal(self) -> None:
"""Should filter protected packages."""
result = piptool._filter_protected_packages(["numpy", "pandas", "pyflowx"])
assert "numpy" in result
assert "pandas" in result
assert "pyflowx" not in result
def test_filter_protected_packages_all_protected(self) -> None:
"""Should filter all protected packages."""
result = piptool._filter_protected_packages(["pyflowx", "bitool"])
assert result == []
def test_filter_protected_packages_case_insensitive(self) -> None:
"""Should filter case insensitive."""
result = piptool._filter_protected_packages(["PyFlowX", "BITOOL"])
assert result == []
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# pip_uninstall # pip_uninstall
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
@@ -39,6 +129,17 @@ class TestPipUninstall:
piptool.pip_uninstall(["numpy*"]) piptool.pip_uninstall(["numpy*"])
assert mock_run.called assert mock_run.called
def test_pip_uninstall_empty_packages(self) -> None:
"""Should handle empty packages list."""
with patch.object(piptool, "_expand_wildcard_packages", return_value=[]):
piptool.pip_uninstall(["nonexistent*"])
# Should not call subprocess.run
def test_pip_uninstall_all_protected(self) -> None:
"""Should handle all protected packages."""
piptool.pip_uninstall(["pyflowx"])
# Should not call subprocess.run
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# pip_reinstall # pip_reinstall
@@ -62,6 +163,11 @@ class TestPipReinstall:
# Should call pip install with offline flags # Should call pip install with offline flags
assert mock_run.called assert mock_run.called
def test_pip_reinstall_all_protected(self) -> None:
"""Should handle all protected packages."""
piptool.pip_reinstall(["pyflowx"])
# Should not call subprocess.run
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
# pip_download # pip_download
+73 -39
View File
@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import subprocess
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -17,51 +18,84 @@ from pyflowx.cli import sshcopyid
class TestSshCopyId: class TestSshCopyId:
"""Test ssh_copy_id function.""" """Test ssh_copy_id function."""
def test_ssh_copy_id_success(self) -> None: def test_ssh_copy_id_pub_key_not_exists(self, tmp_path: Path) -> None:
"""ssh_copy_id should deploy SSH key successfully.""" """Should handle nonexistent public key."""
pytest.importorskip("paramiko") with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent.pub"), pytest.raises(SystemExit):
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( sshcopyid.ssh_copy_id("localhost", "user", "password")
Path, "exists", return_value=True
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
mock_client = MagicMock()
mock_ssh_client.return_value = mock_client
mock_client.connect.return_value = None
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
result = sshcopyid.ssh_copy_id("localhost", "user", "password") def test_ssh_copy_id_sshpass_not_found(self, tmp_path: Path) -> None:
assert result is None # Function doesn't return anything """Should handle sshpass not found."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
def test_ssh_copy_id_with_custom_port(self) -> None: with patch.object(Path, "expanduser", return_value=pub_key), patch(
"""ssh_copy_id should handle custom port.""" "subprocess.run", side_effect=FileNotFoundError
pytest.importorskip("paramiko") ), pytest.raises(SystemExit):
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( sshcopyid.ssh_copy_id("localhost", "user", "password")
Path, "exists", return_value=True
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
mock_client = MagicMock()
mock_ssh_client.return_value = mock_client
mock_client.connect.return_value = None
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
def test_ssh_copy_id_timeout(self, tmp_path: Path) -> None:
"""Should handle SSH timeout."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
with patch.object(Path, "expanduser", return_value=pub_key), patch(
"subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30)
), pytest.raises(SystemExit):
sshcopyid.ssh_copy_id("localhost", "user", "password")
def test_ssh_copy_id_process_error(self, tmp_path: Path) -> None:
"""Should handle SSH process error."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
with patch.object(Path, "expanduser", return_value=pub_key), patch(
"subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd")
), pytest.raises(SystemExit):
sshcopyid.ssh_copy_id("localhost", "user", "password")
def test_ssh_copy_id_success(self, tmp_path: Path) -> None:
"""Should deploy SSH key successfully."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
sshcopyid.ssh_copy_id("localhost", "user", "password")
assert mock_run.called
def test_ssh_copy_id_with_custom_port(self, tmp_path: Path) -> None:
"""Should handle custom port."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
sshcopyid.ssh_copy_id("localhost", "user", "password", port=2222) sshcopyid.ssh_copy_id("localhost", "user", "password", port=2222)
# Verify that connect was called with custom port # Verify port is used
mock_client.connect.assert_called_once() call_args = mock_run.call_args[0][0]
call_args = mock_client.connect.call_args assert "2222" in call_args
assert call_args[1]["port"] == 2222
def test_ssh_copy_id_with_custom_keypath(self) -> None: def test_ssh_copy_id_with_custom_keypath(self, tmp_path: Path) -> None:
"""ssh_copy_id should handle custom key path.""" """Should handle custom keypath."""
pytest.importorskip("paramiko") custom_key = tmp_path / "custom.pub"
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( custom_key.write_text("ssh-rsa AAAAB3...")
Path, "exists", return_value=True
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
mock_client = MagicMock()
mock_ssh_client.return_value = mock_client
mock_client.connect.return_value = None
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
result = sshcopyid.ssh_copy_id("localhost", "user", "password", keypath="/custom/key.pub") with patch.object(Path, "expanduser", return_value=custom_key), patch("subprocess.run") as mock_run:
# Verify that the custom keypath was used mock_run.return_value = MagicMock(returncode=0)
assert result is None sshcopyid.ssh_copy_id("localhost", "user", "password", keypath=str(custom_key))
assert mock_run.called
def test_ssh_copy_id_with_custom_timeout(self, tmp_path: Path) -> None:
"""Should handle custom timeout."""
pub_key = tmp_path / "id_rsa.pub"
pub_key.write_text("ssh-rsa AAAAB3...")
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
mock_run.return_value = MagicMock(returncode=0)
sshcopyid.ssh_copy_id("localhost", "user", "password", timeout=60)
# Verify timeout is used in ConnectTimeout option
call_args = mock_run.call_args[0][0]
assert "ConnectTimeout=60" in call_args
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #