From 0df795237d598242d3d885b6bf1989d4e34156a5 Mon Sep 17 00:00:00 2001 From: gooker_young Date: Mon, 22 Jun 2026 12:31:26 +0800 Subject: [PATCH] ~tests --- pyproject.toml | 2 +- src/pyflowx/cli/clearscreen.py | 5 -- src/pyflowx/cli/envqt.py | 61 +++++------------- src/pyflowx/cli/pymake.py | 16 +++-- tests/cli/test_autofmt.py | 97 ++++++++++++++++++++++++++-- tests/cli/test_folderback.py | 101 +++++++++++++++++++++++++++++ tests/cli/test_gittool.py | 79 +++++++++++++++++++++++ tests/cli/test_piptool.py | 106 +++++++++++++++++++++++++++++++ tests/cli/test_sshcopyid.py | 112 +++++++++++++++++++++------------ 9 files changed, 474 insertions(+), 105 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 45925a0..a6eb058 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ version = "0.1.7" [project.scripts] autofmt = "pyflowx.cli.autofmt:main" bumpver = "pyflowx.cli.bumpversion:main" -cls = "pyflowx.cli.clearscreen:main" +clr = "pyflowx.cli.clearscreen:main" envpy = "pyflowx.cli.envpy:main" envqt = "pyflowx.cli.envqt:main" envrs = "pyflowx.cli.envrs:main" diff --git a/src/pyflowx/cli/clearscreen.py b/src/pyflowx/cli/clearscreen.py index 19f35f0..3407f81 100644 --- a/src/pyflowx/cli/clearscreen.py +++ b/src/pyflowx/cli/clearscreen.py @@ -10,10 +10,6 @@ import subprocess import pyflowx as px from pyflowx.conditions import Constants -# ============================================================================ -# 辅助函数 -# ============================================================================ - def clear_screen() -> None: """使用系统命令清屏.""" @@ -23,7 +19,6 @@ def clear_screen() -> None: subprocess.run(["clear"], check=False) print("\033[2J\033[H", end="") - print("ClearScreen - 清屏工具") def main() -> None: diff --git a/src/pyflowx/cli/envqt.py b/src/pyflowx/cli/envqt.py index 2b82e99..8eaff49 100644 --- a/src/pyflowx/cli/envqt.py +++ b/src/pyflowx/cli/envqt.py @@ -8,10 +8,6 @@ from __future__ import annotations import pyflowx as px from pyflowx.conditions import Constants -# ============================================================================ -# Qt 依赖列表 -# ============================================================================ - QT_LIBS: list[str] = [ "build-essential", "libgl1", @@ -40,47 +36,22 @@ CHINESE_FONTS: list[str] = [ ] -# ============================================================================ -# TaskSpec 定义 -# ============================================================================ - - -# 条件: 仅在 Unix 系统上执行 -def is_linux() -> bool: - """判断是否为 Linux 系统.""" - return Constants.IS_LINUX and not Constants.IS_MACOS - - -envqt_install: px.TaskSpec = px.TaskSpec( - "envqt_install", - cmd=["sudo", "apt", "install", "-y", *QT_LIBS], - conditions=(is_linux,), -) - -envqt_fonts: px.TaskSpec = px.TaskSpec( - "envqt_fonts", - cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS], - conditions=(is_linux,), -) - - -# ============================================================================ -# CLI Runner -# ============================================================================ - - def main() -> None: """PyQt 环境配置工具主函数.""" - runner = px.CliRunner( - strategy="thread", - description="EnvQt - PyQt 环境配置工具", - graphs={ - # 安装 Qt 依赖 - "i": px.Graph.from_specs([envqt_install]), - # 安装中文字体 - "f": px.Graph.from_specs([envqt_fonts]), - # 安装全部 - "a": px.Graph.from_specs([envqt_install, envqt_fonts]), - }, + graph = px.Graph.from_specs( + [ + px.TaskSpec( + "envqt_install", + cmd=["sudo", "apt", "install", "-y", *QT_LIBS], + conditions=(lambda: Constants.IS_LINUX,), + verbose=True, + ), + px.TaskSpec( + "envqt_fonts", + cmd=["sudo", "apt", "install", "-y", *CHINESE_FONTS], + conditions=(lambda: Constants.IS_LINUX,), + verbose=True, + ), + ], ) - runner.run_cli() + px.run(graph, strategy="thread", verbose=True) diff --git a/src/pyflowx/cli/pymake.py b/src/pyflowx/cli/pymake.py index 1666081..5b38590 100644 --- a/src/pyflowx/cli/pymake.py +++ b/src/pyflowx/cli/pymake.py @@ -20,15 +20,13 @@ def maturin_build_cmd() -> list[str]: """ command = ["maturin", "build", "-r"].copy() if Constants.IS_WINDOWS: - command.extend( - [ - "--target", - "x86_64-win7-windows-msvc", - "-Zbuild-std", - "-i", - "python3.8", - ] - ) + command.extend([ + "--target", + "x86_64-win7-windows-msvc", + "-Zbuild-std", + "-i", + "python3.8", + ]) return command diff --git a/tests/cli/test_autofmt.py b/tests/cli/test_autofmt.py index 27c9aab..5c221c6 100644 --- a/tests/cli/test_autofmt.py +++ b/tests/cli/test_autofmt.py @@ -22,6 +22,15 @@ class TestFormatWithRuff: autofmt.format_with_ruff(tmp_path, fix=True) assert mock_run.called + def test_format_with_ruff_no_fix(self, tmp_path: Path) -> None: + """Should format with ruff without fix.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + autofmt.format_with_ruff(tmp_path, fix=False) + # Should not include --fix flag + call_args = mock_run.call_args[0][0] + assert "--fix" not in call_args + # ---------------------------------------------------------------------- # # lint_with_ruff @@ -36,6 +45,15 @@ class TestLintWithRuff: autofmt.lint_with_ruff(tmp_path, fix=True) assert mock_run.called + def test_lint_with_ruff_no_fix(self, tmp_path: Path) -> None: + """Should lint with ruff without fix.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + autofmt.lint_with_ruff(tmp_path, fix=False) + # Should not include --fix flag + call_args = mock_run.call_args[0][0] + assert "--fix" not in call_args + # ---------------------------------------------------------------------- # # add_docstring @@ -51,15 +69,64 @@ class TestAddDocstring: result = autofmt.add_docstring(py_file, '"""Test module."""') assert result is True - def test_add_docstring_skips_non_python_files(self, tmp_path: Path) -> None: - """Should skip non-Python files.""" - txt_file = tmp_path / "test.txt" - txt_file.write_text("test content") + def test_add_docstring_skips_files_with_docstring(self, tmp_path: Path) -> None: + """Should skip files that already have docstring.""" + py_file = tmp_path / "test.py" + py_file.write_text('"""Existing docstring."""\ndef test():\n pass\n') - result = autofmt.add_docstring(txt_file, '"""Test."""') - # Should return False for non-Python files + result = autofmt.add_docstring(py_file, '"""New docstring."""') assert result is False + def test_add_docstring_empty_file(self, tmp_path: Path) -> None: + """Should handle empty file.""" + py_file = tmp_path / "test.py" + py_file.write_text("") + + result = autofmt.add_docstring(py_file, '"""Test module."""') + # Should handle empty file + assert result is True + + +# ---------------------------------------------------------------------- # +# generate_module_docstring +# ---------------------------------------------------------------------- # +class TestGenerateModuleDocstring: + """Test generate_module_docstring function.""" + + def test_generate_module_docstring_basic(self, tmp_path: Path) -> None: + """Should generate basic docstring.""" + py_file = tmp_path / "test.py" + py_file.write_text("def test():\n pass\n") + + result = autofmt.generate_module_docstring(py_file) + # Should contain "Tests for" since stem contains "test" + assert "Tests for" in result + + def test_generate_module_docstring_with_package(self, tmp_path: Path) -> None: + """Should generate docstring for package.""" + py_file = tmp_path / "mypackage" / "test.py" + py_file.parent.mkdir(parents=True) + py_file.write_text("def test():\n pass\n") + + result = autofmt.generate_module_docstring(py_file) + assert "mypackage" in result + + def test_generate_module_docstring_cli(self, tmp_path: Path) -> None: + """Should generate docstring for CLI module.""" + py_file = tmp_path / "cli.py" + py_file.write_text("def test():\n pass\n") + + result = autofmt.generate_module_docstring(py_file) + assert "Command-line interface" in result + + def test_generate_module_docstring_util(self, tmp_path: Path) -> None: + """Should generate docstring for utility module.""" + py_file = tmp_path / "utils.py" + py_file.write_text("def test():\n pass\n") + + result = autofmt.generate_module_docstring(py_file) + assert "Utility functions" in result + # ---------------------------------------------------------------------- # # auto_add_docstrings @@ -76,6 +143,24 @@ class TestAutoAddDocstrings: count = autofmt.auto_add_docstrings(tmp_path) assert count >= 0 + def test_auto_add_docstrings_skips_ignored(self, tmp_path: Path) -> None: + """Should skip ignored directories.""" + py_file = tmp_path / "__pycache__" / "test.py" + py_file.parent.mkdir() + py_file.write_text("def test():\n pass\n") + + count = autofmt.auto_add_docstrings(tmp_path) + # Should skip __pycache__ + assert count == 0 + + def test_auto_add_docstrings_no_files(self, tmp_path: Path) -> None: + """Should handle no Python files.""" + txt_file = tmp_path / "test.txt" + txt_file.write_text("test content") + + count = autofmt.auto_add_docstrings(tmp_path) + assert count == 0 + # ---------------------------------------------------------------------- # # sync_pyproject_config diff --git a/tests/cli/test_folderback.py b/tests/cli/test_folderback.py index 49f90d1..a6d3041 100644 --- a/tests/cli/test_folderback.py +++ b/tests/cli/test_folderback.py @@ -9,6 +9,95 @@ import pyflowx as px from pyflowx.cli import folderback +# ---------------------------------------------------------------------- # +# remove_dump +# ---------------------------------------------------------------------- # +class TestRemoveDump: + """Test remove_dump function.""" + + def test_remove_dump_no_files(self, tmp_path: Path) -> None: + """Should handle no zip files.""" + src = tmp_path / "source" + src.mkdir() + dst = tmp_path / "backup" + dst.mkdir() + + folderback.remove_dump(src, dst, 5) + # Should not raise error + + def test_remove_dump_within_limit(self, tmp_path: Path) -> None: + """Should not remove files within limit.""" + src = tmp_path / "source" + src.mkdir() + dst = tmp_path / "backup" + dst.mkdir() + + # Create some zip files + for i in range(3): + zip_file = dst / f"source_20240101_12000{i}.zip" + zip_file.write_bytes(b"ZIP content") + + folderback.remove_dump(src, dst, 5) + # All files should remain + assert len(list(dst.glob("*.zip"))) == 3 + + def test_remove_dump_exceeds_limit(self, tmp_path: Path) -> None: + """Should remove oldest files when exceeds limit.""" + src = tmp_path / "source" + src.mkdir() + dst = tmp_path / "backup" + dst.mkdir() + + # Create more zip files than limit + for i in range(7): + zip_file = dst / f"source_20240101_12000{i}.zip" + zip_file.write_bytes(b"ZIP content") + + folderback.remove_dump(src, dst, 5) + # Should have only 5 files + assert len(list(dst.glob("*.zip"))) == 5 + + +# ---------------------------------------------------------------------- # +# zip_target +# ---------------------------------------------------------------------- # +class TestZipTarget: + """Test zip_target function.""" + + def test_zip_target_creates_zip(self, tmp_path: Path) -> None: + """Should create zip file.""" + src = tmp_path / "source" + src.mkdir() + (src / "test.txt").write_text("test content") + dst = tmp_path / "backup" + dst.mkdir() + + with patch("time.strftime", return_value="_20240101_120000"): + folderback.zip_target(src, dst, 5) + + # Should create zip file + zip_files = list(dst.glob("*.zip")) + assert len(zip_files) == 1 + + def test_zip_target_with_subdirectories(self, tmp_path: Path) -> None: + """Should zip files in subdirectories.""" + src = tmp_path / "source" + src.mkdir() + subdir = src / "subdir" + subdir.mkdir() + (src / "test.txt").write_text("test content") + (subdir / "nested.txt").write_text("nested content") + dst = tmp_path / "backup" + dst.mkdir() + + with patch("time.strftime", return_value="_20240101_120000"): + folderback.zip_target(src, dst, 5) + + # Should create zip file + zip_files = list(dst.glob("*.zip")) + assert len(zip_files) == 1 + + # ---------------------------------------------------------------------- # # backup_folder # ---------------------------------------------------------------------- # @@ -46,6 +135,18 @@ class TestBackupFolder: folderback.backup_folder(str(source_dir), str(backup_dir), 5) # Should print error message and return + def test_backup_folder_creates_dst(self, tmp_path: Path) -> None: + """Should create destination directory.""" + source_dir = tmp_path / "source" + source_dir.mkdir() + (source_dir / "test.txt").write_text("test content") + backup_dir = tmp_path / "backup" + + with patch.object(folderback, "zip_target") as mock_zip: + folderback.backup_folder(str(source_dir), str(backup_dir), 5) + assert backup_dir.exists() + assert mock_zip.called + # ---------------------------------------------------------------------- # # TaskSpec definitions diff --git a/tests/cli/test_gittool.py b/tests/cli/test_gittool.py index fdae3a2..8c1461c 100644 --- a/tests/cli/test_gittool.py +++ b/tests/cli/test_gittool.py @@ -2,13 +2,92 @@ from __future__ import annotations +from pathlib import Path from unittest.mock import patch import pytest +import pyflowx as px from pyflowx.cli import gittool +# ---------------------------------------------------------------------- # +# not_has_git_repo +# ---------------------------------------------------------------------- # +class TestNotHasGitRepo: + """Test not_has_git_repo function.""" + + def test_not_has_git_repo_true(self, tmp_path: Path) -> None: + """Should return True when no .git directory.""" + with patch.object(Path, "cwd", return_value=tmp_path): + result = gittool.not_has_git_repo() + assert result is True + + def test_not_has_git_repo_false(self, tmp_path: Path) -> None: + """Should return False when .git directory exists.""" + git_dir = tmp_path / ".git" + git_dir.mkdir() + + with patch.object(Path, "cwd", return_value=tmp_path): + result = gittool.not_has_git_repo() + assert result is False + + def test_not_has_git_repo_cwd_not_exists(self, tmp_path: Path) -> None: + """Should return True when cwd doesn't exist.""" + nonexistent = tmp_path / "nonexistent" + + with patch.object(Path, "cwd", return_value=nonexistent): + result = gittool.not_has_git_repo() + assert result is True + + +# ---------------------------------------------------------------------- # +# has_files +# ---------------------------------------------------------------------- # +class TestHasFiles: + """Test has_files function.""" + + def test_has_files_true(self, tmp_path: Path) -> None: + """Should return True when files exist.""" + (tmp_path / "test.txt").write_text("test") + + with patch.object(Path, "cwd", return_value=tmp_path): + result = gittool.has_files() + assert result is True + + def test_has_files_false(self, tmp_path: Path) -> None: + """Should return False when no files.""" + with patch.object(Path, "cwd", return_value=tmp_path): + result = gittool.has_files() + assert result is False + + +# ---------------------------------------------------------------------- # +# init_sub_dirs +# ---------------------------------------------------------------------- # +class TestInitSubDirs: + """Test init_sub_dirs function.""" + + def test_init_sub_dirs_with_subdirectories(self, tmp_path: Path) -> None: + """Should initialize git in subdirectories.""" + subdir1 = tmp_path / "subdir1" + subdir1.mkdir() + subdir2 = tmp_path / "subdir2" + subdir2.mkdir() + + with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run: + gittool.init_sub_dirs() + # Should call px.run for each subdirectory + assert mock_run.call_count == 2 + + def test_init_sub_dirs_no_subdirectories(self, tmp_path: Path) -> None: + """Should handle no subdirectories.""" + with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run: + gittool.init_sub_dirs() + # Should not call px.run + assert mock_run.call_count == 0 + + # ---------------------------------------------------------------------- # # TaskSpec definitions # ---------------------------------------------------------------------- # diff --git a/tests/cli/test_piptool.py b/tests/cli/test_piptool.py index 9a43935..5252b37 100644 --- a/tests/cli/test_piptool.py +++ b/tests/cli/test_piptool.py @@ -2,6 +2,7 @@ from __future__ import annotations +import subprocess from pathlib import Path from unittest.mock import MagicMock, patch @@ -9,6 +10,95 @@ import pyflowx as px from pyflowx.cli import piptool +# ---------------------------------------------------------------------- # +# _get_installed_packages +# ---------------------------------------------------------------------- # +class TestGetInstalledPackages: + """Test _get_installed_packages function.""" + + def test_get_installed_packages_success(self) -> None: + """Should get installed packages.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(stdout="numpy==1.0.0\npandas==2.0.0\n", returncode=0) + result = piptool._get_installed_packages() + assert "numpy" in result + assert "pandas" in result + + def test_get_installed_packages_empty(self) -> None: + """Should handle empty output.""" + with patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(stdout="", returncode=0) + result = piptool._get_installed_packages() + assert result == [] + + def test_get_installed_packages_error(self) -> None: + """Should handle subprocess error.""" + with patch("subprocess.run", side_effect=subprocess.SubprocessError): + result = piptool._get_installed_packages() + assert result == [] + + def test_get_installed_packages_oserror(self) -> None: + """Should handle OSError.""" + with patch("subprocess.run", side_effect=OSError): + result = piptool._get_installed_packages() + assert result == [] + + +# ---------------------------------------------------------------------- # +# _expand_wildcard_packages +# ---------------------------------------------------------------------- # +class TestExpandWildcardPackages: + """Test _expand_wildcard_packages function.""" + + def test_expand_wildcard_no_pattern(self) -> None: + """Should return package name when no wildcard.""" + result = piptool._expand_wildcard_packages("numpy") + assert result == ["numpy"] + + def test_expand_wildcard_with_star(self) -> None: + """Should expand wildcard with star.""" + with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numpy-core", "pandas"]): + result = piptool._expand_wildcard_packages("numpy*") + assert "numpy" in result + assert "numpy-core" in result + + def test_expand_wildcard_with_question(self) -> None: + """Should expand wildcard with question mark.""" + with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numba"]): + result = piptool._expand_wildcard_packages("num??") + assert len(result) > 0 + + def test_expand_wildcard_no_match(self) -> None: + """Should return empty list when no match.""" + with patch.object(piptool, "_get_installed_packages", return_value=["pandas", "scipy"]): + result = piptool._expand_wildcard_packages("numpy*") + assert result == [] + + +# ---------------------------------------------------------------------- # +# _filter_protected_packages +# ---------------------------------------------------------------------- # +class TestFilterProtectedPackages: + """Test _filter_protected_packages function.""" + + def test_filter_protected_packages_normal(self) -> None: + """Should filter protected packages.""" + result = piptool._filter_protected_packages(["numpy", "pandas", "pyflowx"]) + assert "numpy" in result + assert "pandas" in result + assert "pyflowx" not in result + + def test_filter_protected_packages_all_protected(self) -> None: + """Should filter all protected packages.""" + result = piptool._filter_protected_packages(["pyflowx", "bitool"]) + assert result == [] + + def test_filter_protected_packages_case_insensitive(self) -> None: + """Should filter case insensitive.""" + result = piptool._filter_protected_packages(["PyFlowX", "BITOOL"]) + assert result == [] + + # ---------------------------------------------------------------------- # # pip_uninstall # ---------------------------------------------------------------------- # @@ -39,6 +129,17 @@ class TestPipUninstall: piptool.pip_uninstall(["numpy*"]) assert mock_run.called + def test_pip_uninstall_empty_packages(self) -> None: + """Should handle empty packages list.""" + with patch.object(piptool, "_expand_wildcard_packages", return_value=[]): + piptool.pip_uninstall(["nonexistent*"]) + # Should not call subprocess.run + + def test_pip_uninstall_all_protected(self) -> None: + """Should handle all protected packages.""" + piptool.pip_uninstall(["pyflowx"]) + # Should not call subprocess.run + # ---------------------------------------------------------------------- # # pip_reinstall @@ -62,6 +163,11 @@ class TestPipReinstall: # Should call pip install with offline flags assert mock_run.called + def test_pip_reinstall_all_protected(self) -> None: + """Should handle all protected packages.""" + piptool.pip_reinstall(["pyflowx"]) + # Should not call subprocess.run + # ---------------------------------------------------------------------- # # pip_download diff --git a/tests/cli/test_sshcopyid.py b/tests/cli/test_sshcopyid.py index d2cbf24..5caad89 100644 --- a/tests/cli/test_sshcopyid.py +++ b/tests/cli/test_sshcopyid.py @@ -2,6 +2,7 @@ from __future__ import annotations +import subprocess from pathlib import Path from unittest.mock import MagicMock, patch @@ -17,51 +18,84 @@ from pyflowx.cli import sshcopyid class TestSshCopyId: """Test ssh_copy_id function.""" - def test_ssh_copy_id_success(self) -> None: - """ssh_copy_id should deploy SSH key successfully.""" - pytest.importorskip("paramiko") - with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( - Path, "exists", return_value=True - ), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."): - mock_client = MagicMock() - mock_ssh_client.return_value = mock_client - mock_client.connect.return_value = None - mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock()) + def test_ssh_copy_id_pub_key_not_exists(self, tmp_path: Path) -> None: + """Should handle nonexistent public key.""" + with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent.pub"), pytest.raises(SystemExit): + sshcopyid.ssh_copy_id("localhost", "user", "password") - result = sshcopyid.ssh_copy_id("localhost", "user", "password") - assert result is None # Function doesn't return anything + def test_ssh_copy_id_sshpass_not_found(self, tmp_path: Path) -> None: + """Should handle sshpass not found.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") - def test_ssh_copy_id_with_custom_port(self) -> None: - """ssh_copy_id should handle custom port.""" - pytest.importorskip("paramiko") - with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( - Path, "exists", return_value=True - ), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."): - mock_client = MagicMock() - mock_ssh_client.return_value = mock_client - mock_client.connect.return_value = None - mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock()) + with patch.object(Path, "expanduser", return_value=pub_key), patch( + "subprocess.run", side_effect=FileNotFoundError + ), pytest.raises(SystemExit): + sshcopyid.ssh_copy_id("localhost", "user", "password") + def test_ssh_copy_id_timeout(self, tmp_path: Path) -> None: + """Should handle SSH timeout.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") + + with patch.object(Path, "expanduser", return_value=pub_key), patch( + "subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30) + ), pytest.raises(SystemExit): + sshcopyid.ssh_copy_id("localhost", "user", "password") + + def test_ssh_copy_id_process_error(self, tmp_path: Path) -> None: + """Should handle SSH process error.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") + + with patch.object(Path, "expanduser", return_value=pub_key), patch( + "subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd") + ), pytest.raises(SystemExit): + sshcopyid.ssh_copy_id("localhost", "user", "password") + + def test_ssh_copy_id_success(self, tmp_path: Path) -> None: + """Should deploy SSH key successfully.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") + + with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + sshcopyid.ssh_copy_id("localhost", "user", "password") + assert mock_run.called + + def test_ssh_copy_id_with_custom_port(self, tmp_path: Path) -> None: + """Should handle custom port.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") + + with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) sshcopyid.ssh_copy_id("localhost", "user", "password", port=2222) - # Verify that connect was called with custom port - mock_client.connect.assert_called_once() - call_args = mock_client.connect.call_args - assert call_args[1]["port"] == 2222 + # Verify port is used + call_args = mock_run.call_args[0][0] + assert "2222" in call_args - def test_ssh_copy_id_with_custom_keypath(self) -> None: - """ssh_copy_id should handle custom key path.""" - pytest.importorskip("paramiko") - with patch("paramiko.SSHClient") as mock_ssh_client, patch.object( - Path, "exists", return_value=True - ), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."): - mock_client = MagicMock() - mock_ssh_client.return_value = mock_client - mock_client.connect.return_value = None - mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock()) + def test_ssh_copy_id_with_custom_keypath(self, tmp_path: Path) -> None: + """Should handle custom keypath.""" + custom_key = tmp_path / "custom.pub" + custom_key.write_text("ssh-rsa AAAAB3...") - result = sshcopyid.ssh_copy_id("localhost", "user", "password", keypath="/custom/key.pub") - # Verify that the custom keypath was used - assert result is None + with patch.object(Path, "expanduser", return_value=custom_key), patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + sshcopyid.ssh_copy_id("localhost", "user", "password", keypath=str(custom_key)) + assert mock_run.called + + def test_ssh_copy_id_with_custom_timeout(self, tmp_path: Path) -> None: + """Should handle custom timeout.""" + pub_key = tmp_path / "id_rsa.pub" + pub_key.write_text("ssh-rsa AAAAB3...") + + with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0) + sshcopyid.ssh_copy_id("localhost", "user", "password", timeout=60) + # Verify timeout is used in ConnectTimeout option + call_args = mock_run.call_args[0][0] + assert "ConnectTimeout=60" in call_args # ---------------------------------------------------------------------- #