~tests
This commit is contained in:
@@ -22,6 +22,15 @@ class TestFormatWithRuff:
|
||||
autofmt.format_with_ruff(tmp_path, fix=True)
|
||||
assert mock_run.called
|
||||
|
||||
def test_format_with_ruff_no_fix(self, tmp_path: Path) -> None:
|
||||
"""Should format with ruff without fix."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.format_with_ruff(tmp_path, fix=False)
|
||||
# Should not include --fix flag
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "--fix" not in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# lint_with_ruff
|
||||
@@ -36,6 +45,15 @@ class TestLintWithRuff:
|
||||
autofmt.lint_with_ruff(tmp_path, fix=True)
|
||||
assert mock_run.called
|
||||
|
||||
def test_lint_with_ruff_no_fix(self, tmp_path: Path) -> None:
|
||||
"""Should lint with ruff without fix."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
autofmt.lint_with_ruff(tmp_path, fix=False)
|
||||
# Should not include --fix flag
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "--fix" not in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# add_docstring
|
||||
@@ -51,15 +69,64 @@ class TestAddDocstring:
|
||||
result = autofmt.add_docstring(py_file, '"""Test module."""')
|
||||
assert result is True
|
||||
|
||||
def test_add_docstring_skips_non_python_files(self, tmp_path: Path) -> None:
|
||||
"""Should skip non-Python files."""
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text("test content")
|
||||
def test_add_docstring_skips_files_with_docstring(self, tmp_path: Path) -> None:
|
||||
"""Should skip files that already have docstring."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text('"""Existing docstring."""\ndef test():\n pass\n')
|
||||
|
||||
result = autofmt.add_docstring(txt_file, '"""Test."""')
|
||||
# Should return False for non-Python files
|
||||
result = autofmt.add_docstring(py_file, '"""New docstring."""')
|
||||
assert result is False
|
||||
|
||||
def test_add_docstring_empty_file(self, tmp_path: Path) -> None:
|
||||
"""Should handle empty file."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("")
|
||||
|
||||
result = autofmt.add_docstring(py_file, '"""Test module."""')
|
||||
# Should handle empty file
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# generate_module_docstring
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGenerateModuleDocstring:
|
||||
"""Test generate_module_docstring function."""
|
||||
|
||||
def test_generate_module_docstring_basic(self, tmp_path: Path) -> None:
|
||||
"""Should generate basic docstring."""
|
||||
py_file = tmp_path / "test.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
# Should contain "Tests for" since stem contains "test"
|
||||
assert "Tests for" in result
|
||||
|
||||
def test_generate_module_docstring_with_package(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for package."""
|
||||
py_file = tmp_path / "mypackage" / "test.py"
|
||||
py_file.parent.mkdir(parents=True)
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "mypackage" in result
|
||||
|
||||
def test_generate_module_docstring_cli(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for CLI module."""
|
||||
py_file = tmp_path / "cli.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "Command-line interface" in result
|
||||
|
||||
def test_generate_module_docstring_util(self, tmp_path: Path) -> None:
|
||||
"""Should generate docstring for utility module."""
|
||||
py_file = tmp_path / "utils.py"
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
result = autofmt.generate_module_docstring(py_file)
|
||||
assert "Utility functions" in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# auto_add_docstrings
|
||||
@@ -76,6 +143,24 @@ class TestAutoAddDocstrings:
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
assert count >= 0
|
||||
|
||||
def test_auto_add_docstrings_skips_ignored(self, tmp_path: Path) -> None:
|
||||
"""Should skip ignored directories."""
|
||||
py_file = tmp_path / "__pycache__" / "test.py"
|
||||
py_file.parent.mkdir()
|
||||
py_file.write_text("def test():\n pass\n")
|
||||
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
# Should skip __pycache__
|
||||
assert count == 0
|
||||
|
||||
def test_auto_add_docstrings_no_files(self, tmp_path: Path) -> None:
|
||||
"""Should handle no Python files."""
|
||||
txt_file = tmp_path / "test.txt"
|
||||
txt_file.write_text("test content")
|
||||
|
||||
count = autofmt.auto_add_docstrings(tmp_path)
|
||||
assert count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# sync_pyproject_config
|
||||
|
||||
@@ -9,6 +9,95 @@ import pyflowx as px
|
||||
from pyflowx.cli import folderback
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# remove_dump
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestRemoveDump:
|
||||
"""Test remove_dump function."""
|
||||
|
||||
def test_remove_dump_no_files(self, tmp_path: Path) -> None:
|
||||
"""Should handle no zip files."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# Should not raise error
|
||||
|
||||
def test_remove_dump_within_limit(self, tmp_path: Path) -> None:
|
||||
"""Should not remove files within limit."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
# Create some zip files
|
||||
for i in range(3):
|
||||
zip_file = dst / f"source_20240101_12000{i}.zip"
|
||||
zip_file.write_bytes(b"ZIP content")
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# All files should remain
|
||||
assert len(list(dst.glob("*.zip"))) == 3
|
||||
|
||||
def test_remove_dump_exceeds_limit(self, tmp_path: Path) -> None:
|
||||
"""Should remove oldest files when exceeds limit."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
# Create more zip files than limit
|
||||
for i in range(7):
|
||||
zip_file = dst / f"source_20240101_12000{i}.zip"
|
||||
zip_file.write_bytes(b"ZIP content")
|
||||
|
||||
folderback.remove_dump(src, dst, 5)
|
||||
# Should have only 5 files
|
||||
assert len(list(dst.glob("*.zip"))) == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# zip_target
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestZipTarget:
|
||||
"""Test zip_target function."""
|
||||
|
||||
def test_zip_target_creates_zip(self, tmp_path: Path) -> None:
|
||||
"""Should create zip file."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
(src / "test.txt").write_text("test content")
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
with patch("time.strftime", return_value="_20240101_120000"):
|
||||
folderback.zip_target(src, dst, 5)
|
||||
|
||||
# Should create zip file
|
||||
zip_files = list(dst.glob("*.zip"))
|
||||
assert len(zip_files) == 1
|
||||
|
||||
def test_zip_target_with_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should zip files in subdirectories."""
|
||||
src = tmp_path / "source"
|
||||
src.mkdir()
|
||||
subdir = src / "subdir"
|
||||
subdir.mkdir()
|
||||
(src / "test.txt").write_text("test content")
|
||||
(subdir / "nested.txt").write_text("nested content")
|
||||
dst = tmp_path / "backup"
|
||||
dst.mkdir()
|
||||
|
||||
with patch("time.strftime", return_value="_20240101_120000"):
|
||||
folderback.zip_target(src, dst, 5)
|
||||
|
||||
# Should create zip file
|
||||
zip_files = list(dst.glob("*.zip"))
|
||||
assert len(zip_files) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# backup_folder
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -46,6 +135,18 @@ class TestBackupFolder:
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
|
||||
# Should print error message and return
|
||||
|
||||
def test_backup_folder_creates_dst(self, tmp_path: Path) -> None:
|
||||
"""Should create destination directory."""
|
||||
source_dir = tmp_path / "source"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "test.txt").write_text("test content")
|
||||
backup_dir = tmp_path / "backup"
|
||||
|
||||
with patch.object(folderback, "zip_target") as mock_zip:
|
||||
folderback.backup_folder(str(source_dir), str(backup_dir), 5)
|
||||
assert backup_dir.exists()
|
||||
assert mock_zip.called
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
|
||||
@@ -2,13 +2,92 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import pyflowx as px
|
||||
from pyflowx.cli import gittool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# not_has_git_repo
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestNotHasGitRepo:
|
||||
"""Test not_has_git_repo function."""
|
||||
|
||||
def test_not_has_git_repo_true(self, tmp_path: Path) -> None:
|
||||
"""Should return True when no .git directory."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is True
|
||||
|
||||
def test_not_has_git_repo_false(self, tmp_path: Path) -> None:
|
||||
"""Should return False when .git directory exists."""
|
||||
git_dir = tmp_path / ".git"
|
||||
git_dir.mkdir()
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is False
|
||||
|
||||
def test_not_has_git_repo_cwd_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should return True when cwd doesn't exist."""
|
||||
nonexistent = tmp_path / "nonexistent"
|
||||
|
||||
with patch.object(Path, "cwd", return_value=nonexistent):
|
||||
result = gittool.not_has_git_repo()
|
||||
assert result is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# has_files
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestHasFiles:
|
||||
"""Test has_files function."""
|
||||
|
||||
def test_has_files_true(self, tmp_path: Path) -> None:
|
||||
"""Should return True when files exist."""
|
||||
(tmp_path / "test.txt").write_text("test")
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.has_files()
|
||||
assert result is True
|
||||
|
||||
def test_has_files_false(self, tmp_path: Path) -> None:
|
||||
"""Should return False when no files."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path):
|
||||
result = gittool.has_files()
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# init_sub_dirs
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestInitSubDirs:
|
||||
"""Test init_sub_dirs function."""
|
||||
|
||||
def test_init_sub_dirs_with_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should initialize git in subdirectories."""
|
||||
subdir1 = tmp_path / "subdir1"
|
||||
subdir1.mkdir()
|
||||
subdir2 = tmp_path / "subdir2"
|
||||
subdir2.mkdir()
|
||||
|
||||
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
|
||||
gittool.init_sub_dirs()
|
||||
# Should call px.run for each subdirectory
|
||||
assert mock_run.call_count == 2
|
||||
|
||||
def test_init_sub_dirs_no_subdirectories(self, tmp_path: Path) -> None:
|
||||
"""Should handle no subdirectories."""
|
||||
with patch.object(Path, "cwd", return_value=tmp_path), patch.object(px, "run") as mock_run:
|
||||
gittool.init_sub_dirs()
|
||||
# Should not call px.run
|
||||
assert mock_run.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# TaskSpec definitions
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -9,6 +10,95 @@ import pyflowx as px
|
||||
from pyflowx.cli import piptool
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _get_installed_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestGetInstalledPackages:
|
||||
"""Test _get_installed_packages function."""
|
||||
|
||||
def test_get_installed_packages_success(self) -> None:
|
||||
"""Should get installed packages."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="numpy==1.0.0\npandas==2.0.0\n", returncode=0)
|
||||
result = piptool._get_installed_packages()
|
||||
assert "numpy" in result
|
||||
assert "pandas" in result
|
||||
|
||||
def test_get_installed_packages_empty(self) -> None:
|
||||
"""Should handle empty output."""
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(stdout="", returncode=0)
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
def test_get_installed_packages_error(self) -> None:
|
||||
"""Should handle subprocess error."""
|
||||
with patch("subprocess.run", side_effect=subprocess.SubprocessError):
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
def test_get_installed_packages_oserror(self) -> None:
|
||||
"""Should handle OSError."""
|
||||
with patch("subprocess.run", side_effect=OSError):
|
||||
result = piptool._get_installed_packages()
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _expand_wildcard_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestExpandWildcardPackages:
|
||||
"""Test _expand_wildcard_packages function."""
|
||||
|
||||
def test_expand_wildcard_no_pattern(self) -> None:
|
||||
"""Should return package name when no wildcard."""
|
||||
result = piptool._expand_wildcard_packages("numpy")
|
||||
assert result == ["numpy"]
|
||||
|
||||
def test_expand_wildcard_with_star(self) -> None:
|
||||
"""Should expand wildcard with star."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numpy-core", "pandas"]):
|
||||
result = piptool._expand_wildcard_packages("numpy*")
|
||||
assert "numpy" in result
|
||||
assert "numpy-core" in result
|
||||
|
||||
def test_expand_wildcard_with_question(self) -> None:
|
||||
"""Should expand wildcard with question mark."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["numpy", "numba"]):
|
||||
result = piptool._expand_wildcard_packages("num??")
|
||||
assert len(result) > 0
|
||||
|
||||
def test_expand_wildcard_no_match(self) -> None:
|
||||
"""Should return empty list when no match."""
|
||||
with patch.object(piptool, "_get_installed_packages", return_value=["pandas", "scipy"]):
|
||||
result = piptool._expand_wildcard_packages("numpy*")
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# _filter_protected_packages
|
||||
# ---------------------------------------------------------------------- #
|
||||
class TestFilterProtectedPackages:
|
||||
"""Test _filter_protected_packages function."""
|
||||
|
||||
def test_filter_protected_packages_normal(self) -> None:
|
||||
"""Should filter protected packages."""
|
||||
result = piptool._filter_protected_packages(["numpy", "pandas", "pyflowx"])
|
||||
assert "numpy" in result
|
||||
assert "pandas" in result
|
||||
assert "pyflowx" not in result
|
||||
|
||||
def test_filter_protected_packages_all_protected(self) -> None:
|
||||
"""Should filter all protected packages."""
|
||||
result = piptool._filter_protected_packages(["pyflowx", "bitool"])
|
||||
assert result == []
|
||||
|
||||
def test_filter_protected_packages_case_insensitive(self) -> None:
|
||||
"""Should filter case insensitive."""
|
||||
result = piptool._filter_protected_packages(["PyFlowX", "BITOOL"])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_uninstall
|
||||
# ---------------------------------------------------------------------- #
|
||||
@@ -39,6 +129,17 @@ class TestPipUninstall:
|
||||
piptool.pip_uninstall(["numpy*"])
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_uninstall_empty_packages(self) -> None:
|
||||
"""Should handle empty packages list."""
|
||||
with patch.object(piptool, "_expand_wildcard_packages", return_value=[]):
|
||||
piptool.pip_uninstall(["nonexistent*"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
def test_pip_uninstall_all_protected(self) -> None:
|
||||
"""Should handle all protected packages."""
|
||||
piptool.pip_uninstall(["pyflowx"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_reinstall
|
||||
@@ -62,6 +163,11 @@ class TestPipReinstall:
|
||||
# Should call pip install with offline flags
|
||||
assert mock_run.called
|
||||
|
||||
def test_pip_reinstall_all_protected(self) -> None:
|
||||
"""Should handle all protected packages."""
|
||||
piptool.pip_reinstall(["pyflowx"])
|
||||
# Should not call subprocess.run
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
# pip_download
|
||||
|
||||
+73
-39
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@@ -17,51 +18,84 @@ from pyflowx.cli import sshcopyid
|
||||
class TestSshCopyId:
|
||||
"""Test ssh_copy_id function."""
|
||||
|
||||
def test_ssh_copy_id_success(self) -> None:
|
||||
"""ssh_copy_id should deploy SSH key successfully."""
|
||||
pytest.importorskip("paramiko")
|
||||
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object(
|
||||
Path, "exists", return_value=True
|
||||
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
|
||||
mock_client = MagicMock()
|
||||
mock_ssh_client.return_value = mock_client
|
||||
mock_client.connect.return_value = None
|
||||
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
def test_ssh_copy_id_pub_key_not_exists(self, tmp_path: Path) -> None:
|
||||
"""Should handle nonexistent public key."""
|
||||
with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent.pub"), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
result = sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
assert result is None # Function doesn't return anything
|
||||
def test_ssh_copy_id_sshpass_not_found(self, tmp_path: Path) -> None:
|
||||
"""Should handle sshpass not found."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
def test_ssh_copy_id_with_custom_port(self) -> None:
|
||||
"""ssh_copy_id should handle custom port."""
|
||||
pytest.importorskip("paramiko")
|
||||
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object(
|
||||
Path, "exists", return_value=True
|
||||
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
|
||||
mock_client = MagicMock()
|
||||
mock_ssh_client.return_value = mock_client
|
||||
mock_client.connect.return_value = None
|
||||
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=FileNotFoundError
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_timeout(self, tmp_path: Path) -> None:
|
||||
"""Should handle SSH timeout."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=subprocess.TimeoutExpired("cmd", 30)
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_process_error(self, tmp_path: Path) -> None:
|
||||
"""Should handle SSH process error."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch(
|
||||
"subprocess.run", side_effect=subprocess.CalledProcessError(1, "cmd")
|
||||
), pytest.raises(SystemExit):
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
|
||||
def test_ssh_copy_id_success(self, tmp_path: Path) -> None:
|
||||
"""Should deploy SSH key successfully."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password")
|
||||
assert mock_run.called
|
||||
|
||||
def test_ssh_copy_id_with_custom_port(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom port."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", port=2222)
|
||||
# Verify that connect was called with custom port
|
||||
mock_client.connect.assert_called_once()
|
||||
call_args = mock_client.connect.call_args
|
||||
assert call_args[1]["port"] == 2222
|
||||
# Verify port is used
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "2222" in call_args
|
||||
|
||||
def test_ssh_copy_id_with_custom_keypath(self) -> None:
|
||||
"""ssh_copy_id should handle custom key path."""
|
||||
pytest.importorskip("paramiko")
|
||||
with patch("paramiko.SSHClient") as mock_ssh_client, patch.object(
|
||||
Path, "exists", return_value=True
|
||||
), patch.object(Path, "read_text", return_value="ssh-rsa AAAAB3..."):
|
||||
mock_client = MagicMock()
|
||||
mock_ssh_client.return_value = mock_client
|
||||
mock_client.connect.return_value = None
|
||||
mock_client.exec_command.return_value = (MagicMock(), MagicMock(), MagicMock())
|
||||
def test_ssh_copy_id_with_custom_keypath(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom keypath."""
|
||||
custom_key = tmp_path / "custom.pub"
|
||||
custom_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
result = sshcopyid.ssh_copy_id("localhost", "user", "password", keypath="/custom/key.pub")
|
||||
# Verify that the custom keypath was used
|
||||
assert result is None
|
||||
with patch.object(Path, "expanduser", return_value=custom_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", keypath=str(custom_key))
|
||||
assert mock_run.called
|
||||
|
||||
def test_ssh_copy_id_with_custom_timeout(self, tmp_path: Path) -> None:
|
||||
"""Should handle custom timeout."""
|
||||
pub_key = tmp_path / "id_rsa.pub"
|
||||
pub_key.write_text("ssh-rsa AAAAB3...")
|
||||
|
||||
with patch.object(Path, "expanduser", return_value=pub_key), patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = MagicMock(returncode=0)
|
||||
sshcopyid.ssh_copy_id("localhost", "user", "password", timeout=60)
|
||||
# Verify timeout is used in ConnectTimeout option
|
||||
call_args = mock_run.call_args[0][0]
|
||||
assert "ConnectTimeout=60" in call_args
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------- #
|
||||
|
||||
Reference in New Issue
Block a user