Compare commits
2 commits
73429771d4
...
ffb1f7e204
Author | SHA1 | Date | |
---|---|---|---|
ffb1f7e204 | |||
414266eefc |
8 changed files with 836 additions and 601 deletions
236
add_host.py
236
add_host.py
|
@ -2,109 +2,211 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import glob
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Dict, Set, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
from .utils import print_error, print_warning, print_info, safe_input
|
from .utils import print_error, print_warning, print_info, safe_input
|
||||||
|
from .config import CONF_DIR
|
||||||
|
|
||||||
def add_host(conf_dir):
|
@dataclass
|
||||||
|
class HostConfig:
|
||||||
|
label: str
|
||||||
|
hostname: str
|
||||||
|
user: str = "root"
|
||||||
|
port: str = "22"
|
||||||
|
identity_file: str = ""
|
||||||
|
|
||||||
|
def to_config_lines(self) -> List[str]:
|
||||||
|
"""Convert host configuration to SSH config file lines."""
|
||||||
|
lines = [
|
||||||
|
f"Host {self.label}",
|
||||||
|
f" HostName {self.hostname}",
|
||||||
|
f" User {self.user}",
|
||||||
|
f" Port {self.port}"
|
||||||
|
]
|
||||||
|
if self.identity_file:
|
||||||
|
lines.append(f" IdentityFile {self.identity_file}")
|
||||||
|
return lines
|
||||||
|
|
||||||
|
def get_existing_hosts(conf_dir: str) -> Tuple[Set[str], Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Get existing host labels and hostnames from config files.
|
||||||
|
Returns (host_labels, hostname_to_label) where:
|
||||||
|
- host_labels is a set of existing host labels
|
||||||
|
- hostname_to_label maps hostnames to their labels
|
||||||
|
"""
|
||||||
|
host_labels = set()
|
||||||
|
hostname_to_label = {}
|
||||||
|
|
||||||
|
pattern = os.path.join(conf_dir, "*", "config")
|
||||||
|
for config_file in glob.glob(pattern):
|
||||||
|
try:
|
||||||
|
with open(config_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
current_label = None
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if not line or line.startswith('#'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.lower().startswith('host '):
|
||||||
|
labels = line.split()[1:]
|
||||||
|
for label in labels:
|
||||||
|
if '*' not in label:
|
||||||
|
current_label = label
|
||||||
|
host_labels.add(label)
|
||||||
|
break
|
||||||
|
elif current_label and line.lower().startswith('hostname '):
|
||||||
|
hostname = line.split(None, 1)[1].strip()
|
||||||
|
hostname_to_label[hostname] = current_label
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print_warning(f"Error reading config file {config_file}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return host_labels, hostname_to_label
|
||||||
|
|
||||||
|
def generate_ssh_key(key_path: Path) -> bool:
|
||||||
|
"""Generate a new ED25519 SSH key pair."""
|
||||||
|
try:
|
||||||
|
subprocess.check_call([
|
||||||
|
"ssh-keygen",
|
||||||
|
"-q",
|
||||||
|
"-t", "ed25519",
|
||||||
|
"-N", "",
|
||||||
|
"-f", str(key_path)
|
||||||
|
])
|
||||||
|
print_info(f"Generated new SSH key at {key_path}")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error generating SSH key: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def copy_ssh_key(key_path: Path, user: str, hostname: str, port: str) -> bool:
|
||||||
|
"""Copy SSH public key to remote server."""
|
||||||
|
try:
|
||||||
|
cmd = ["ssh-copy-id", "-i", str(key_path)]
|
||||||
|
if port != "22":
|
||||||
|
cmd.extend(["-p", port])
|
||||||
|
cmd.append(f"{user}@{hostname}")
|
||||||
|
|
||||||
|
subprocess.check_call(cmd)
|
||||||
|
print_info("Key successfully copied to remote server.")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error copying key to server: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def write_config_file(config_path: Path, config_lines: List[str]) -> bool:
|
||||||
|
"""Write SSH config lines to file."""
|
||||||
|
try:
|
||||||
|
config_path.write_text("\n".join(config_lines) + "\n")
|
||||||
|
print_info(f"Created/updated config at: {config_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Failed to write config to {config_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def add_host(conf_dir: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
|
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
|
||||||
Offers to generate a new SSH key pair (ed25519) quietly (-q),
|
Offers to generate a new SSH key pair (ed25519) quietly (-q),
|
||||||
and then prompt to copy that key to the remote server via ssh-copy-id.
|
and then prompt to copy that key to the remote server via ssh-copy-id.
|
||||||
|
|
||||||
|
Returns True if host was added successfully, False otherwise.
|
||||||
"""
|
"""
|
||||||
print_info("Adding a new SSH host...")
|
print_info("Adding a new SSH host...")
|
||||||
|
|
||||||
|
# Get existing hosts to check for duplicates
|
||||||
|
existing_labels, hostname_to_label = get_existing_hosts(conf_dir)
|
||||||
|
|
||||||
|
# Get host label
|
||||||
|
while True:
|
||||||
host_label = safe_input("Enter Host label (e.g. myserver): ")
|
host_label = safe_input("Enter Host label (e.g. myserver): ")
|
||||||
if host_label is None:
|
if not host_label or host_label is None:
|
||||||
return # User canceled (Ctrl+C)
|
|
||||||
host_label = host_label.strip()
|
|
||||||
if not host_label:
|
|
||||||
print_error("Host label cannot be empty.")
|
print_error("Host label cannot be empty.")
|
||||||
return
|
return False
|
||||||
|
|
||||||
|
host_label = host_label.strip()
|
||||||
|
if host_label in existing_labels:
|
||||||
|
print_error(f"A host with label '{host_label}' already exists. Please choose a different label.")
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get hostname
|
||||||
|
while True:
|
||||||
hostname = safe_input("Enter HostName (IP or domain): ")
|
hostname = safe_input("Enter HostName (IP or domain): ")
|
||||||
if hostname is None:
|
if not hostname or hostname is None:
|
||||||
return
|
|
||||||
hostname = hostname.strip()
|
|
||||||
if not hostname:
|
|
||||||
print_error("HostName cannot be empty.")
|
print_error("HostName cannot be empty.")
|
||||||
return
|
return False
|
||||||
|
|
||||||
user = safe_input("Enter username (default: 'root'): ")
|
hostname = hostname.strip()
|
||||||
|
if hostname in hostname_to_label:
|
||||||
|
existing_label = hostname_to_label[hostname]
|
||||||
|
print_error(f"This hostname is already configured for host '{existing_label}'. Please use a different hostname.")
|
||||||
|
continue
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get optional parameters
|
||||||
|
user = safe_input("Enter username (default: 'root'): ") or "root"
|
||||||
if user is None:
|
if user is None:
|
||||||
return
|
return False
|
||||||
user = user.strip() or "root"
|
|
||||||
|
|
||||||
port = safe_input("Enter SSH port (default: 22): ")
|
port = safe_input("Enter SSH port (default: 22): ") or "22"
|
||||||
if port is None:
|
if port is None:
|
||||||
return
|
return False
|
||||||
port = port.strip() or "22"
|
|
||||||
|
|
||||||
# Create subdirectory: ~/.ssh/conf/<label>
|
# Create host configuration
|
||||||
host_dir = os.path.join(conf_dir, host_label)
|
host_config = HostConfig(
|
||||||
if os.path.exists(host_dir):
|
label=host_label,
|
||||||
|
hostname=hostname,
|
||||||
|
user=user.strip(),
|
||||||
|
port=port.strip()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setup directory structure
|
||||||
|
host_dir = Path(conf_dir) / host_label
|
||||||
|
if host_dir.exists():
|
||||||
print_warning(f"Directory {host_dir} already exists; continuing anyway.")
|
print_warning(f"Directory {host_dir} already exists; continuing anyway.")
|
||||||
else:
|
host_dir.mkdir(mode=0o700, exist_ok=True)
|
||||||
os.makedirs(host_dir, mode=0o700, exist_ok=True)
|
|
||||||
print_info(f"Created directory: {host_dir}")
|
print_info(f"Created directory: {host_dir}")
|
||||||
|
|
||||||
config_path = os.path.join(host_dir, "config")
|
config_path = host_dir / "config"
|
||||||
if os.path.exists(config_path):
|
if config_path.exists():
|
||||||
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
|
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
|
||||||
|
|
||||||
|
# Handle SSH key generation
|
||||||
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
|
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
|
||||||
if gen_key_choice is None:
|
if gen_key_choice is None:
|
||||||
return
|
return False
|
||||||
gen_key_choice = gen_key_choice.lower().strip()
|
|
||||||
|
|
||||||
identity_file = ""
|
if gen_key_choice.lower().strip() == 'y':
|
||||||
if gen_key_choice == 'y':
|
key_path = host_dir / "id_ed25519"
|
||||||
key_path = os.path.join(host_dir, "id_ed25519")
|
|
||||||
if os.path.exists(key_path):
|
if key_path.exists():
|
||||||
print_warning(f"{key_path} already exists. Skipping generation.")
|
print_warning(f"{key_path} already exists. Skipping generation.")
|
||||||
identity_file = key_path
|
host_config.identity_file = str(key_path)
|
||||||
else:
|
else:
|
||||||
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", key_path]
|
if generate_ssh_key(key_path):
|
||||||
try:
|
host_config.identity_file = str(key_path)
|
||||||
subprocess.check_call(cmd)
|
|
||||||
print_info(f"Generated new SSH key at {key_path}")
|
|
||||||
identity_file = key_path
|
|
||||||
|
|
||||||
# Prompt to copy the key
|
# Prompt to copy the key
|
||||||
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
|
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
|
||||||
if copy_key is None:
|
if copy_key is None:
|
||||||
return
|
return False
|
||||||
|
|
||||||
if copy_key.lower().strip() == 'y':
|
if copy_key.lower().strip() == 'y':
|
||||||
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
|
copy_ssh_key(key_path, host_config.user, host_config.hostname, host_config.port)
|
||||||
if port != "22":
|
|
||||||
ssh_copy_cmd += ["-p", port]
|
|
||||||
ssh_copy_cmd.append(f"{user}@{hostname}")
|
|
||||||
try:
|
|
||||||
subprocess.check_call(ssh_copy_cmd)
|
|
||||||
print_info("Key successfully copied to remote server.")
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print_error(f"Error copying key to server: {e}")
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print_error(f"Error generating SSH key: {e}")
|
|
||||||
else:
|
else:
|
||||||
|
# Handle existing key
|
||||||
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
|
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
|
||||||
if existing_key is None:
|
if existing_key is None:
|
||||||
return
|
return False
|
||||||
existing_key = existing_key.strip()
|
|
||||||
if existing_key:
|
|
||||||
identity_file = os.path.expanduser(existing_key)
|
|
||||||
|
|
||||||
config_lines = [
|
if existing_key.strip():
|
||||||
f"Host {host_label}",
|
host_config.identity_file = os.path.expanduser(existing_key.strip())
|
||||||
f" HostName {hostname}",
|
|
||||||
f" User {user}",
|
|
||||||
f" Port {port}"
|
|
||||||
]
|
|
||||||
if identity_file:
|
|
||||||
config_lines.append(f" IdentityFile {identity_file}")
|
|
||||||
|
|
||||||
try:
|
# Write the config file
|
||||||
with open(config_path, "w") as f:
|
return write_config_file(config_path, host_config.to_config_lines())
|
||||||
for line in config_lines:
|
|
||||||
f.write(line + "\n")
|
|
||||||
print_info(f"Created/updated config at: {config_path}")
|
|
||||||
except Exception as e:
|
|
||||||
print_error(f"Failed to write config to {config_path}: {e}")
|
|
||||||
|
|
115
cli.py
115
cli.py
|
@ -2,6 +2,8 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
from functools import wraps
|
||||||
from .utils import print_info, print_error, print_warning, Colors, safe_input
|
from .utils import print_info, print_error, print_warning, Colors, safe_input
|
||||||
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
|
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
|
||||||
from .add_host import add_host
|
from .add_host import add_host
|
||||||
|
@ -10,62 +12,101 @@ from .list_hosts import list_hosts
|
||||||
from .regen_key import regenerate_key
|
from .regen_key import regenerate_key
|
||||||
from .remove_host import remove_host
|
from .remove_host import remove_host
|
||||||
|
|
||||||
def ensure_ssh_setup():
|
def ensure_ssh_setup() -> None:
|
||||||
"""
|
"""
|
||||||
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
|
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
|
||||||
and writes a default ~/.ssh/config if it doesn't exist.
|
and writes a default ~/.ssh/config if it doesn't exist.
|
||||||
"""
|
"""
|
||||||
if not os.path.isdir(SSH_DIR):
|
directories = [
|
||||||
os.makedirs(SSH_DIR, mode=0o700, exist_ok=True)
|
(SSH_DIR, "SSH directory"),
|
||||||
print_info(f"Created directory: {SSH_DIR}")
|
(CONF_DIR, "configuration directory"),
|
||||||
|
(SOCKET_DIR, "socket directory")
|
||||||
|
]
|
||||||
|
|
||||||
if not os.path.isdir(CONF_DIR):
|
for directory, description in directories:
|
||||||
os.makedirs(CONF_DIR, mode=0o700, exist_ok=True)
|
if not os.path.isdir(directory):
|
||||||
print_info(f"Created directory: {CONF_DIR}")
|
os.makedirs(directory, mode=0o700, exist_ok=True)
|
||||||
|
print_info(f"Created {description}: {directory}")
|
||||||
if not os.path.isdir(SOCKET_DIR):
|
|
||||||
os.makedirs(SOCKET_DIR, mode=0o700, exist_ok=True)
|
|
||||||
print_info(f"Created directory: {SOCKET_DIR}")
|
|
||||||
|
|
||||||
if not os.path.isfile(MAIN_CONFIG):
|
if not os.path.isfile(MAIN_CONFIG):
|
||||||
with open(MAIN_CONFIG, "w") as f:
|
with open(MAIN_CONFIG, "w") as f:
|
||||||
f.write(DEFAULT_CONFIG_CONTENT)
|
f.write(DEFAULT_CONFIG_CONTENT)
|
||||||
print_info(f"Created default SSH config at: {MAIN_CONFIG}")
|
print_info(f"Created default SSH config at: {MAIN_CONFIG}")
|
||||||
|
|
||||||
def main():
|
def async_handler(func: Callable) -> Callable:
|
||||||
|
"""Decorator to handle async functions in the command dispatch"""
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
return asyncio.run(func(*args, **kwargs))
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
class SSHManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.commands: Dict[str, tuple[Callable, str]] = {
|
||||||
|
"1": (self.list_hosts, "List Hosts"),
|
||||||
|
"2": (self.add_host, "Add a Host"),
|
||||||
|
"3": (self.edit_host, "Edit a Host"),
|
||||||
|
"4": (self.regenerate_key, "Regenerate Key"),
|
||||||
|
"5": (self.remove_host, "Remove Host"),
|
||||||
|
"6": (self.exit_app, "Exit")
|
||||||
|
}
|
||||||
|
|
||||||
|
@async_handler
|
||||||
|
async def list_hosts(self) -> None:
|
||||||
|
await list_hosts(CONF_DIR)
|
||||||
|
|
||||||
|
def add_host(self) -> None:
|
||||||
|
add_host(CONF_DIR)
|
||||||
|
|
||||||
|
@async_handler
|
||||||
|
async def edit_host(self) -> None:
|
||||||
|
await edit_host(CONF_DIR)
|
||||||
|
|
||||||
|
@async_handler
|
||||||
|
async def regenerate_key(self) -> None:
|
||||||
|
await regenerate_key(CONF_DIR)
|
||||||
|
|
||||||
|
@async_handler
|
||||||
|
async def remove_host(self) -> None:
|
||||||
|
await remove_host(CONF_DIR)
|
||||||
|
|
||||||
|
def exit_app(self) -> None:
|
||||||
|
print_info("Exiting...")
|
||||||
|
raise SystemExit(0)
|
||||||
|
|
||||||
|
def display_menu(self) -> None:
|
||||||
|
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
|
||||||
|
for key, (_, description) in self.commands.items():
|
||||||
|
print(f"{key}. {description}")
|
||||||
|
|
||||||
|
def handle_command(self, choice: str) -> bool:
|
||||||
|
if choice not in self.commands:
|
||||||
|
print_error("Invalid choice. Please select 1 through 6.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.commands[choice][0]()
|
||||||
|
return choice != "6"
|
||||||
|
except SystemExit:
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Error executing command: {str(e)}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
ensure_ssh_setup()
|
ensure_ssh_setup()
|
||||||
|
manager = SSHManager()
|
||||||
|
|
||||||
# Display the server list on first load
|
# Display the server list on first load
|
||||||
asyncio.run(list_hosts(CONF_DIR))
|
manager.list_hosts()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
|
manager.display_menu()
|
||||||
print("1. List Hosts")
|
|
||||||
print("2. Add a Host")
|
|
||||||
print("3. Edit a Host")
|
|
||||||
print("4. Regenerate Key")
|
|
||||||
print("5. Remove Host")
|
|
||||||
print("6. Exit")
|
|
||||||
|
|
||||||
choice = safe_input("Select an option (1-6): ")
|
choice = safe_input("Select an option (1-6): ")
|
||||||
if choice is None:
|
if choice is None:
|
||||||
continue # User pressed Ctrl+C => safe_input returns None => re-show menu
|
continue
|
||||||
|
|
||||||
choice = choice.strip()
|
if not manager.handle_command(choice.strip()):
|
||||||
if choice == '1':
|
|
||||||
asyncio.run(list_hosts(CONF_DIR))
|
|
||||||
elif choice == '2':
|
|
||||||
add_host(CONF_DIR)
|
|
||||||
elif choice == '3':
|
|
||||||
edit_host(CONF_DIR)
|
|
||||||
elif choice == '4':
|
|
||||||
asyncio.run(regenerate_key(CONF_DIR))
|
|
||||||
elif choice == '5':
|
|
||||||
asyncio.run(remove_host(CONF_DIR))
|
|
||||||
elif choice == '6':
|
|
||||||
print_info("Exiting...")
|
|
||||||
break
|
break
|
||||||
else:
|
|
||||||
print_error("Invalid choice. Please select 1 through 6.")
|
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
46
config.py
46
config.py
|
@ -1,15 +1,24 @@
|
||||||
# ssh_manager/config.py
|
# ssh_manager/config.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Final
|
||||||
|
import subprocess
|
||||||
|
|
||||||
# Paths
|
# Use Path for more efficient path handling
|
||||||
SSH_DIR = os.path.expanduser("~/.ssh")
|
HOME: Final[Path] = Path.home()
|
||||||
CONF_DIR = os.path.join(SSH_DIR, "conf")
|
SSH_DIR: Final[Path] = HOME / ".ssh"
|
||||||
SOCKET_DIR = os.path.join(SSH_DIR, "s")
|
CONF_DIR: Final[Path] = SSH_DIR / "conf"
|
||||||
MAIN_CONFIG = os.path.join(SSH_DIR, "config")
|
SOCKET_DIR: Final[Path] = SSH_DIR / "s"
|
||||||
|
MAIN_CONFIG: Final[Path] = SSH_DIR / "config"
|
||||||
|
|
||||||
|
# Validate paths on import
|
||||||
|
for path in (SSH_DIR, CONF_DIR, SOCKET_DIR):
|
||||||
|
if not path.exists():
|
||||||
|
path.mkdir(mode=0o700, parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Default SSH config content if ~/.ssh/config is missing
|
# Default SSH config content if ~/.ssh/config is missing
|
||||||
DEFAULT_CONFIG_CONTENT = """###
|
DEFAULT_CONFIG_CONTENT: Final[str] = """###
|
||||||
#Local ssh
|
#Local ssh
|
||||||
###
|
###
|
||||||
|
|
||||||
|
@ -30,3 +39,28 @@ Host *
|
||||||
ControlPersist 72000
|
ControlPersist 72000
|
||||||
ControlPath ~/.ssh/s/%C
|
ControlPath ~/.ssh/s/%C
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Export string versions for backward compatibility
|
||||||
|
SSH_DIR_STR: Final[str] = str(SSH_DIR)
|
||||||
|
CONF_DIR_STR: Final[str] = str(CONF_DIR)
|
||||||
|
SOCKET_DIR_STR: Final[str] = str(SOCKET_DIR)
|
||||||
|
MAIN_CONFIG_STR: Final[str] = str(MAIN_CONFIG)
|
||||||
|
|
||||||
|
def validate_key_path(key_path: Path) -> bool:
|
||||||
|
if not key_path.exists():
|
||||||
|
return False
|
||||||
|
if not key_path.is_dir():
|
||||||
|
return False
|
||||||
|
if not os.access(key_path, os.W_OK):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def update_config_with_key(key_path: Path) -> bool:
|
||||||
|
if not validate_key_path(key_path):
|
||||||
|
return False # Early return if path validation fails
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.check_call([...])
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error generating new SSH key: {e}")
|
||||||
|
return False
|
||||||
|
|
165
edit_host.py
165
edit_host.py
|
@ -2,98 +2,31 @@ import os
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from .utils import print_error, print_warning, print_info, safe_input
|
from .utils import print_error, print_warning, print_info, safe_input
|
||||||
from .list_hosts import list_hosts, load_config_file, check_ssh_port
|
from .list_hosts import (
|
||||||
|
build_host_list_table,
|
||||||
|
load_config_file,
|
||||||
|
gather_host_info,
|
||||||
|
sort_by_ip
|
||||||
|
)
|
||||||
|
|
||||||
async def get_all_host_blocks(conf_dir):
|
async def edit_host(conf_dir):
|
||||||
"""
|
|
||||||
Similar to list_hosts, but returns the list of host blocks + a table of results.
|
|
||||||
We'll build a table ourselves so we can map row numbers to actual host labels.
|
|
||||||
"""
|
|
||||||
import glob
|
|
||||||
import socket
|
|
||||||
|
|
||||||
pattern = os.path.join(conf_dir, "*", "config")
|
|
||||||
conf_files = sorted(glob.glob(pattern))
|
|
||||||
|
|
||||||
all_blocks = []
|
|
||||||
for conf_file in conf_files:
|
|
||||||
blocks = load_config_file(conf_file)
|
|
||||||
all_blocks.extend(blocks)
|
|
||||||
|
|
||||||
# If no blocks found, return empty
|
|
||||||
if not all_blocks:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# We want to do a partial version of check_host to get row data
|
|
||||||
# so we can display the table right here and keep track of each block’s host label.
|
|
||||||
# But let's do it similarly to list_hosts:
|
|
||||||
|
|
||||||
table_rows = []
|
|
||||||
for idx, b in enumerate(all_blocks, start=1):
|
|
||||||
host_label = b.get("Host", "N/A")
|
|
||||||
hostname = b.get("HostName", "N/A")
|
|
||||||
user = b.get("User", "N/A")
|
|
||||||
port = int(b.get("Port", "22"))
|
|
||||||
identity_file = b.get("IdentityFile", "N/A")
|
|
||||||
|
|
||||||
# Identity check
|
|
||||||
if identity_file != "N/A":
|
|
||||||
expanded_identity = os.path.expanduser(identity_file)
|
|
||||||
identity_exists = os.path.isfile(expanded_identity)
|
|
||||||
else:
|
|
||||||
identity_exists = False
|
|
||||||
|
|
||||||
# IP resolution
|
|
||||||
try:
|
|
||||||
ip_address = socket.gethostbyname(hostname)
|
|
||||||
except socket.error:
|
|
||||||
ip_address = None
|
|
||||||
|
|
||||||
# Port check
|
|
||||||
if ip_address:
|
|
||||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
|
||||||
else:
|
|
||||||
port_open = False
|
|
||||||
|
|
||||||
# Colors for display (optional, or we can keep it simple):
|
|
||||||
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
|
||||||
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
|
||||||
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
|
|
||||||
|
|
||||||
row = [
|
|
||||||
idx,
|
|
||||||
host_label,
|
|
||||||
user,
|
|
||||||
port_display,
|
|
||||||
hostname,
|
|
||||||
ip_display,
|
|
||||||
identity_disp
|
|
||||||
]
|
|
||||||
table_rows.append(row)
|
|
||||||
|
|
||||||
# Print the table
|
|
||||||
from tabulate import tabulate
|
|
||||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
|
|
||||||
print("\nSSH Conf Subdirectory Host List")
|
|
||||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
|
||||||
|
|
||||||
return all_blocks
|
|
||||||
|
|
||||||
def edit_host(conf_dir):
|
|
||||||
"""
|
"""
|
||||||
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
|
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
|
||||||
The user may type either the row number OR the actual host label.
|
1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
|
||||||
|
2) Prompt row number or host label
|
||||||
|
3) Rewrite config with updated fields
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 1) Gather + display the current host list
|
|
||||||
print_info("Here is the current list of hosts:\n")
|
print_info("Here is the current list of hosts:\n")
|
||||||
all_blocks = asyncio.run(get_all_host_blocks(conf_dir))
|
headers, final_data = await build_host_list_table(conf_dir)
|
||||||
|
if not final_data:
|
||||||
if not all_blocks:
|
print_warning("No hosts to edit.")
|
||||||
print_warning("No hosts found to edit.")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2) Prompt for which host to edit (by label or row number)
|
from tabulate import tabulate
|
||||||
|
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||||
|
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||||
|
|
||||||
choice = safe_input("Enter the row number or the Host label to edit: ")
|
choice = safe_input("Enter the row number or the Host label to edit: ")
|
||||||
if choice is None:
|
if choice is None:
|
||||||
return # user canceled (Ctrl+C)
|
return # user canceled (Ctrl+C)
|
||||||
|
@ -102,43 +35,47 @@ def edit_host(conf_dir):
|
||||||
print_error("Host label or row number cannot be empty.")
|
print_error("Host label or row number cannot be empty.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check if user typed a digit -> row number
|
# We replicate the approach to find the matching block
|
||||||
target_block = None
|
all_blocks = []
|
||||||
|
import glob
|
||||||
|
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
|
||||||
|
all_blocks.extend(load_config_file(cfile))
|
||||||
|
|
||||||
|
results = await gather_host_info(all_blocks)
|
||||||
|
sorted_rows = sort_by_ip(results)
|
||||||
|
|
||||||
|
target_tuple = None
|
||||||
if choice.isdigit():
|
if choice.isdigit():
|
||||||
row_idx = int(choice)
|
idx = int(choice)
|
||||||
# Validate index
|
if idx < 1 or idx > len(sorted_rows):
|
||||||
if row_idx < 1 or row_idx > len(all_blocks):
|
print_warning(f"Row number {idx} is invalid.")
|
||||||
print_warning(f"Invalid row number {row_idx}.")
|
|
||||||
return
|
return
|
||||||
target_block = all_blocks[row_idx - 1] # zero-based
|
target_tuple = sorted_rows[idx - 1]
|
||||||
else:
|
else:
|
||||||
# The user typed a host label
|
for t in sorted_rows:
|
||||||
# We must search all_blocks for a matching Host
|
if t[0] == choice: # t[0] => host_label
|
||||||
for b in all_blocks:
|
target_tuple = t
|
||||||
if b.get("Host") == choice:
|
|
||||||
target_block = b
|
|
||||||
break
|
break
|
||||||
if not target_block:
|
if not target_tuple:
|
||||||
print_warning(f"No matching host label '{choice}' found.")
|
print_warning(f"No matching Host '{choice}' found in the table.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Now we have a target_block with existing data
|
host_label = target_tuple[0]
|
||||||
host_label = target_block.get("Host", "")
|
# find the config block
|
||||||
if not host_label:
|
found_block = None
|
||||||
print_warning("This host block has no label. Cannot edit.")
|
for b in all_blocks:
|
||||||
|
if b.get("Host") == host_label:
|
||||||
|
found_block = b
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found_block:
|
||||||
|
print_warning(f"No config block found for '{host_label}'.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Derive the config path
|
old_hostname = found_block.get("HostName", "")
|
||||||
host_dir = os.path.join(conf_dir, host_label)
|
old_user = found_block.get("User", "")
|
||||||
config_path = os.path.join(host_dir, "config")
|
old_port = found_block.get("Port", "22")
|
||||||
if not os.path.isfile(config_path):
|
old_identity = found_block.get("IdentityFile", "")
|
||||||
print_warning(f"No config file found at {config_path}; cannot edit this host.")
|
|
||||||
return
|
|
||||||
|
|
||||||
old_hostname = target_block.get("HostName", "")
|
|
||||||
old_user = target_block.get("User", "")
|
|
||||||
old_port = target_block.get("Port", "22")
|
|
||||||
old_identity = target_block.get("IdentityFile", "")
|
|
||||||
|
|
||||||
print_info("Leave a field blank to keep its current value.")
|
print_info("Leave a field blank to keep its current value.")
|
||||||
|
|
||||||
|
@ -167,6 +104,8 @@ def edit_host(conf_dir):
|
||||||
final_port = new_port if new_port else old_port
|
final_port = new_port if new_port else old_port
|
||||||
final_ident = new_ident if new_ident else old_identity
|
final_ident = new_ident if new_ident else old_identity
|
||||||
|
|
||||||
|
# Overwrite the file
|
||||||
|
config_path = os.path.join(conf_dir, host_label, "config")
|
||||||
new_config_lines = [
|
new_config_lines = [
|
||||||
f"Host {host_label}",
|
f"Host {host_label}",
|
||||||
f" HostName {final_hostname}",
|
f" HostName {final_hostname}",
|
||||||
|
|
210
list_hosts.py
210
list_hosts.py
|
@ -5,33 +5,27 @@ import glob
|
||||||
import socket
|
import socket
|
||||||
import asyncio
|
import asyncio
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Dict, Tuple, Optional, OrderedDict as OrderedDictType
|
||||||
|
from functools import lru_cache
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from .utils import print_warning, print_error, Colors
|
from .utils import print_warning, print_error, Colors
|
||||||
|
from .config import CONF_DIR
|
||||||
|
|
||||||
async def check_ssh_port(ip_address, port):
|
# Cache DNS lookups
|
||||||
"""
|
@lru_cache(maxsize=128, typed=True)
|
||||||
Attempt to open an SSH connection to see if the port is open.
|
def resolve_hostname(hostname: str) -> Optional[str]:
|
||||||
Returns True if successful, False otherwise.
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
reader, writer = await asyncio.wait_for(
|
return socket.gethostbyname(hostname)
|
||||||
asyncio.open_connection(ip_address, port), timeout=1
|
except socket.error:
|
||||||
)
|
return None
|
||||||
writer.close()
|
|
||||||
await writer.wait_closed()
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def load_config_file(file_path):
|
def load_config_file(file_path: str) -> List[OrderedDictType]:
|
||||||
"""
|
blocks: List[OrderedDictType] = []
|
||||||
Parse a single SSH config file and return a list of host blocks.
|
host_data: Optional[OrderedDictType] = None
|
||||||
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
|
|
||||||
"""
|
|
||||||
blocks = []
|
|
||||||
host_data = None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(file_path, 'r') as f:
|
with open(file_path, 'r') as f:
|
||||||
|
@ -42,11 +36,9 @@ def load_config_file(file_path):
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
stripped_line = line.strip()
|
stripped_line = line.strip()
|
||||||
# Skip empty lines and comments
|
|
||||||
if not stripped_line or stripped_line.startswith('#'):
|
if not stripped_line or stripped_line.startswith('#'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Start of a new Host block
|
|
||||||
if stripped_line.lower().startswith('host '):
|
if stripped_line.lower().startswith('host '):
|
||||||
host_labels = stripped_line.split()[1:]
|
host_labels = stripped_line.split()[1:]
|
||||||
for label in host_labels:
|
for label in host_labels:
|
||||||
|
@ -55,7 +47,7 @@ def load_config_file(file_path):
|
||||||
blocks.append(host_data)
|
blocks.append(host_data)
|
||||||
host_data = OrderedDict({'Host': label})
|
host_data = OrderedDict({'Host': label})
|
||||||
break
|
break
|
||||||
elif host_data:
|
elif host_data is not None:
|
||||||
if ' ' in stripped_line:
|
if ' ' in stripped_line:
|
||||||
key, value = stripped_line.split(None, 1)
|
key, value = stripped_line.split(None, 1)
|
||||||
host_data[key] = value.strip()
|
host_data[key] = value.strip()
|
||||||
|
@ -64,111 +56,137 @@ def load_config_file(file_path):
|
||||||
blocks.append(host_data)
|
blocks.append(host_data)
|
||||||
return blocks
|
return blocks
|
||||||
|
|
||||||
async def check_host(host):
|
async def gather_host_info(all_host_blocks: List[OrderedDictType]) -> List[Tuple]:
|
||||||
"""
|
"""
|
||||||
Given a host block, resolve IP, check SSH port, etc.
|
Given a list of host blocks, gather full info:
|
||||||
Returns a tuple of:
|
- resolved IP
|
||||||
1) Host label
|
- 'Conf Directory' coloring if IdentityFile != 'N/A'
|
||||||
2) User
|
Returns a list of 7-tuples:
|
||||||
3) Port (colored if open)
|
(host_label, user, port, hostname, colored_ip, conf_path_display, raw_ip)
|
||||||
4) HostName
|
|
||||||
5) IP Address (colored if resolved)
|
|
||||||
6) Conf Directory (green if has IdentityFile, else no color)
|
|
||||||
7) raw_ip (uncolored string for sorting)
|
|
||||||
"""
|
"""
|
||||||
host_label = host.get('Host', 'N/A')
|
results = []
|
||||||
hostname = host.get('HostName', 'N/A')
|
|
||||||
user = host.get('User', 'N/A')
|
|
||||||
port = int(host.get('Port', '22'))
|
|
||||||
identity_file = host.get('IdentityFile', 'N/A')
|
|
||||||
|
|
||||||
# Resolve IP
|
async def process_block(h: OrderedDictType) -> Tuple:
|
||||||
|
host_label: str = h.get('Host', 'N/A')
|
||||||
|
hostname: str = h.get('HostName', 'N/A')
|
||||||
|
user: str = h.get('User', 'N/A')
|
||||||
|
port: str = h.get('Port', '22') # Keep as string since we're not testing it
|
||||||
|
identity_file: str = h.get('IdentityFile', 'N/A')
|
||||||
|
|
||||||
|
# Resolve IP using cached function
|
||||||
|
raw_ip = resolve_hostname(hostname) or "N/A"
|
||||||
|
|
||||||
|
# Check if the IP is reachable
|
||||||
|
def is_ip_reachable(ip: str) -> bool:
|
||||||
try:
|
try:
|
||||||
raw_ip = socket.gethostbyname(hostname) # uncolored
|
subprocess.run(["ping", "-c", "1", "-W", "1", ip], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||||
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
|
return True
|
||||||
except socket.error:
|
except subprocess.CalledProcessError:
|
||||||
raw_ip = "N/A"
|
return False
|
||||||
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
|
|
||||||
|
|
||||||
# Check port
|
# Determine IP color
|
||||||
if raw_ip != "N/A":
|
if raw_ip != "N/A" and is_ip_reachable(raw_ip):
|
||||||
port_open = await check_ssh_port(raw_ip, port)
|
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
|
||||||
colored_port = (
|
|
||||||
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
|
colored_ip = raw_ip # No color if not reachable
|
||||||
|
|
||||||
|
# Determine hostname color
|
||||||
|
try:
|
||||||
|
ipaddress.ip_address(hostname)
|
||||||
|
colored_hostname = hostname # No color if hostname is an IP
|
||||||
|
except ValueError:
|
||||||
|
if raw_ip != "N/A":
|
||||||
|
colored_hostname = f"{Colors.GREEN}{hostname}{Colors.RESET}"
|
||||||
|
else:
|
||||||
|
colored_hostname = f"{Colors.RED}{hostname}{Colors.RESET}"
|
||||||
|
|
||||||
# Conf Directory = ~/.ssh/conf/<host_label>
|
# Conf Directory = ~/.ssh/conf/<host_label>
|
||||||
conf_path = f"~/.ssh/conf/{host_label}"
|
conf_path = f"~/.ssh/conf/{host_label}"
|
||||||
# If there's an IdentityFile, color the conf path green
|
conf_path_display = (
|
||||||
if identity_file != 'N/A':
|
f"{Colors.GREEN}{conf_path}{Colors.RESET}"
|
||||||
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
|
if identity_file != 'N/A'
|
||||||
else:
|
else conf_path
|
||||||
conf_path_display = conf_path
|
)
|
||||||
|
|
||||||
# Return the data plus the uncolored IP for sorting
|
|
||||||
return (
|
return (
|
||||||
host_label,
|
host_label,
|
||||||
user,
|
user,
|
||||||
colored_port,
|
port, # Port is now uncolored
|
||||||
hostname,
|
colored_hostname,
|
||||||
colored_ip,
|
colored_ip,
|
||||||
conf_path_display,
|
conf_path_display,
|
||||||
raw_ip # for sorting
|
raw_ip # uncolored IP for sorting
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_hosts(conf_dir):
|
# Process blocks concurrently with semaphore to limit concurrent connections
|
||||||
|
sem = asyncio.Semaphore(10) # Limit concurrent connections
|
||||||
|
async def process_with_semaphore(block):
|
||||||
|
async with sem:
|
||||||
|
return await process_block(block)
|
||||||
|
|
||||||
|
tasks = [process_with_semaphore(b) for b in all_host_blocks]
|
||||||
|
results = await asyncio.gather(*tasks)
|
||||||
|
return results
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def parse_ip(ip_str: str) -> Optional[ipaddress.IPv4Address]:
|
||||||
"""
|
"""
|
||||||
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order.
|
Convert a string IP to an ipaddress object for sorting.
|
||||||
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory
|
Returns None if invalid or 'N/A'.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return ipaddress.ip_address(ip_str)
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def sort_by_ip(results: List[Tuple]) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
Sort the 7-tuples by IP ascending, with 'N/A' last.
|
||||||
|
"""
|
||||||
|
def sort_key(row):
|
||||||
|
raw_ip = row[-1]
|
||||||
|
ip_obj = parse_ip(raw_ip)
|
||||||
|
return (ip_obj is None, ip_obj or ipaddress.IPv4Address('0.0.0.0'))
|
||||||
|
|
||||||
|
return sorted(results, key=sort_key)
|
||||||
|
|
||||||
|
async def build_host_list_table(conf_dir: str) -> Tuple[List[str], List[List]]:
|
||||||
|
"""
|
||||||
|
Gathers + sorts all hosts in conf_dir by IP ascending.
|
||||||
|
Returns (headers, final_table_rows), each row omitting the raw_ip.
|
||||||
"""
|
"""
|
||||||
pattern = os.path.join(conf_dir, "*", "config")
|
pattern = os.path.join(conf_dir, "*", "config")
|
||||||
conf_files = sorted(glob.glob(pattern))
|
conf_files = sorted(glob.glob(pattern))
|
||||||
|
|
||||||
all_host_blocks = []
|
all_host_blocks: List[OrderedDictType] = []
|
||||||
for conf_file in conf_files:
|
for conf_file in conf_files:
|
||||||
blocks = load_config_file(conf_file)
|
blocks = load_config_file(conf_file)
|
||||||
all_host_blocks.extend(blocks)
|
all_host_blocks.extend(blocks)
|
||||||
|
|
||||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
|
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
|
||||||
if not all_host_blocks:
|
if not all_host_blocks:
|
||||||
|
return headers, []
|
||||||
|
|
||||||
|
results = await gather_host_info(all_host_blocks)
|
||||||
|
sorted_rows = sort_by_ip(results)
|
||||||
|
|
||||||
|
# Build final table
|
||||||
|
final_data = [
|
||||||
|
[idx] + list(row[:-1])
|
||||||
|
for idx, row in enumerate(sorted_rows, start=1)
|
||||||
|
]
|
||||||
|
|
||||||
|
return headers, final_data
|
||||||
|
|
||||||
|
async def list_hosts(conf_dir: str) -> None:
|
||||||
|
"""Display a formatted table of all SSH hosts."""
|
||||||
|
headers, final_data = await build_host_list_table(conf_dir)
|
||||||
|
|
||||||
|
if not final_data:
|
||||||
print_warning("No hosts found. The server list is empty.")
|
print_warning("No hosts found. The server list is empty.")
|
||||||
print("\nSSH Conf Subdirectory Host List")
|
print("\nSSH Conf Subdirectory Host List")
|
||||||
print(tabulate([], headers=headers, tablefmt="grid"))
|
print(tabulate([], headers=headers, tablefmt="grid"))
|
||||||
return
|
return
|
||||||
|
|
||||||
# Gather full data for each host
|
|
||||||
tasks = [check_host(h) for h in all_host_blocks]
|
|
||||||
results = await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
# We want to sort by IP ascending. results[i] is a tuple:
|
|
||||||
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
|
|
||||||
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
|
|
||||||
def parse_ip(ip_str):
|
|
||||||
try:
|
|
||||||
return ipaddress.ip_address(ip_str)
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Convert the results into a list of (ip_obj, original_tuple)
|
|
||||||
# so we can sort, then rebuild the final data.
|
|
||||||
sortable = []
|
|
||||||
for row in results:
|
|
||||||
raw_ip = row[-1] # last element
|
|
||||||
ip_obj = parse_ip(raw_ip)
|
|
||||||
# We'll sort None last by using a sort key that puts (True) after (False)
|
|
||||||
# e.g. (ip_obj is None, ip_obj)
|
|
||||||
sortable.append(((ip_obj is None, ip_obj), row))
|
|
||||||
|
|
||||||
# Sort by (is_none, ip_obj)
|
|
||||||
sortable.sort(key=lambda x: x[0])
|
|
||||||
|
|
||||||
# Rebuild the final display table, ignoring the raw_ip at the end
|
|
||||||
final_data = []
|
|
||||||
for idx, (_, row) in enumerate(sortable, start=1):
|
|
||||||
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
|
|
||||||
final_data.append([idx] + list(row[:-1])) # omit raw_ip
|
|
||||||
|
|
||||||
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||||
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||||
|
|
425
regen_key.py
425
regen_key.py
|
@ -1,146 +1,121 @@
|
||||||
import os
|
import os
|
||||||
import glob
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
from pathlib import Path
|
||||||
|
from typing import Optional, List, Dict, Tuple, Any
|
||||||
from .utils import (
|
from .utils import (
|
||||||
print_info,
|
print_info,
|
||||||
print_warning,
|
print_warning,
|
||||||
print_error,
|
print_error,
|
||||||
safe_input,
|
safe_input
|
||||||
Colors
|
)
|
||||||
|
from .list_hosts import (
|
||||||
|
build_host_list_table,
|
||||||
|
load_config_file,
|
||||||
|
gather_host_info,
|
||||||
|
sort_by_ip
|
||||||
)
|
)
|
||||||
from .list_hosts import load_config_file, check_ssh_port
|
|
||||||
|
|
||||||
async def get_all_host_blocks(conf_dir):
|
def validate_key_path(key_path: str) -> bool:
|
||||||
pattern = os.path.join(conf_dir, "*", "config")
|
"""Validate that the key path and its directory are valid."""
|
||||||
conf_files = sorted(glob.glob(pattern))
|
|
||||||
|
|
||||||
all_blocks = []
|
|
||||||
for conf_file in conf_files:
|
|
||||||
blocks = load_config_file(conf_file)
|
|
||||||
all_blocks.extend(blocks)
|
|
||||||
|
|
||||||
if not all_blocks:
|
|
||||||
return []
|
|
||||||
return all_blocks
|
|
||||||
|
|
||||||
async def regenerate_key(conf_dir):
|
|
||||||
"""
|
|
||||||
Menu-driven function to regenerate a key for a selected host:
|
|
||||||
1) Show the host table with row numbers
|
|
||||||
2) Let user pick row # or label
|
|
||||||
3) Read old pub data
|
|
||||||
4) Remove old local key files
|
|
||||||
5) Generate new key
|
|
||||||
6) Copy new key
|
|
||||||
7) Remove old key from remote authorized_keys (improved logic to detect existence)
|
|
||||||
"""
|
|
||||||
print_info("Regenerate Key - Step 1: Show current hosts...\n")
|
|
||||||
|
|
||||||
# 1) Gather host blocks
|
|
||||||
all_blocks = await get_all_host_blocks(conf_dir)
|
|
||||||
if not all_blocks:
|
|
||||||
print_warning("No hosts found. Cannot regenerate a key.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Display them in a table (similar to list_hosts):
|
|
||||||
import socket
|
|
||||||
from tabulate import tabulate
|
|
||||||
|
|
||||||
table_rows = []
|
|
||||||
for idx, block in enumerate(all_blocks, start=1):
|
|
||||||
host_label = block.get("Host", "N/A")
|
|
||||||
hostname = block.get("HostName", "N/A")
|
|
||||||
user = block.get("User", "N/A")
|
|
||||||
port = int(block.get("Port", "22"))
|
|
||||||
identity_file = block.get("IdentityFile", "N/A")
|
|
||||||
|
|
||||||
# Check IP
|
|
||||||
try:
|
try:
|
||||||
ip_address = socket.gethostbyname(hostname)
|
key_path = Path(key_path)
|
||||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
key_dir = key_path.parent
|
||||||
except:
|
|
||||||
ip_address = None
|
|
||||||
port_open = False
|
|
||||||
|
|
||||||
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
if not key_dir.exists():
|
||||||
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
key_dir.mkdir(mode=0o700, parents=True)
|
||||||
|
elif not os.access(str(key_dir), os.W_OK):
|
||||||
|
print_error(f"No write permission for directory: {key_dir}")
|
||||||
|
return False
|
||||||
|
|
||||||
row = [
|
return True
|
||||||
idx,
|
except Exception as e:
|
||||||
host_label,
|
print_error(f"Error validating key path: {e}")
|
||||||
user,
|
return False
|
||||||
port_disp,
|
|
||||||
hostname,
|
def generate_new_key(key_path: str, user: str, hostname: str, port: int) -> bool:
|
||||||
ip_disp,
|
"""Generate a new ED25519 SSH key and optionally copy it to the remote server."""
|
||||||
identity_file
|
if not validate_key_path(key_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
print_info("Generating new ed25519 SSH key...")
|
||||||
|
try:
|
||||||
|
subprocess.check_call([
|
||||||
|
"ssh-keygen",
|
||||||
|
"-q",
|
||||||
|
"-t", "ed25519",
|
||||||
|
"-N", "",
|
||||||
|
"-f", key_path
|
||||||
|
])
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error generating new SSH key: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
print_info(f"Generated new SSH key at {key_path}")
|
||||||
|
|
||||||
|
# Ask to copy the key
|
||||||
|
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
|
||||||
|
if not copy_choice or not copy_choice.lower().startswith('y'):
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
|
||||||
|
if port != 22:
|
||||||
|
ssh_copy_cmd += ["-p", str(port)]
|
||||||
|
ssh_copy_cmd.append(f"{user}@{hostname}")
|
||||||
|
|
||||||
|
subprocess.check_call(ssh_copy_cmd)
|
||||||
|
print_info("New key successfully copied to remote server.")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error copying new key: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def update_config_with_key(config_path: Path, new_key_path: str) -> bool:
|
||||||
|
"""Update the SSH config file with the new identity file."""
|
||||||
|
try:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config_lines = [
|
||||||
|
line.rstrip('\n')
|
||||||
|
for line in f
|
||||||
|
if not line.strip().lower().startswith('identityfile')
|
||||||
]
|
]
|
||||||
table_rows.append(row)
|
|
||||||
|
|
||||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
|
config_lines.append(f" IdentityFile {new_key_path}")
|
||||||
print("\nSSH Conf Subdirectory Host List")
|
|
||||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
|
||||||
|
|
||||||
# 2) Prompt for row # or label
|
with open(config_path, 'w') as f:
|
||||||
choice = safe_input("Enter the row number or the Host label to regenerate: ")
|
f.write('\n'.join(config_lines) + '\n')
|
||||||
if choice is None:
|
|
||||||
return
|
print_info(f"Updated config file with new IdentityFile: {new_key_path}")
|
||||||
choice = choice.strip()
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Failed to update config file: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def find_target_host(sorted_rows: List[Tuple], choice: str) -> Optional[Tuple]:
|
||||||
|
"""Find the target host based on user choice."""
|
||||||
if not choice:
|
if not choice:
|
||||||
print_error("No choice given.")
|
print_error("No choice given.")
|
||||||
return
|
return None
|
||||||
|
|
||||||
target_block = None
|
|
||||||
if choice.isdigit():
|
if choice.isdigit():
|
||||||
row_idx = int(choice)
|
row_idx = int(choice)
|
||||||
if row_idx < 1 or row_idx > len(all_blocks):
|
if row_idx < 1 or row_idx > len(sorted_rows):
|
||||||
print_warning(f"Invalid row number {row_idx}.")
|
print_warning(f"Invalid row number {row_idx}.")
|
||||||
return
|
return None
|
||||||
target_block = all_blocks[row_idx - 1]
|
return sorted_rows[row_idx - 1]
|
||||||
else:
|
|
||||||
# user typed a label
|
# User typed a label
|
||||||
for b in all_blocks:
|
for row in sorted_rows:
|
||||||
if b.get("Host") == choice:
|
if row[0] == choice: # row[0] is host_label
|
||||||
target_block = b
|
return row
|
||||||
break
|
|
||||||
if not target_block:
|
|
||||||
print_warning(f"No matching host label '{choice}' found.")
|
print_warning(f"No matching host label '{choice}' found.")
|
||||||
return
|
return None
|
||||||
|
|
||||||
# 3) Gather info from block
|
def remove_key_files(key_paths: List[str]) -> None:
|
||||||
host_label = target_block.get("Host", "")
|
"""Remove SSH key files."""
|
||||||
hostname = target_block.get("HostName", "")
|
for path in key_paths:
|
||||||
user = target_block.get("User", "root")
|
|
||||||
port = int(target_block.get("Port", "22"))
|
|
||||||
identity_file = target_block.get("IdentityFile", "")
|
|
||||||
|
|
||||||
if not host_label or not hostname:
|
|
||||||
print_error("Missing Host or HostName; cannot regenerate.")
|
|
||||||
return
|
|
||||||
if not identity_file or identity_file == "N/A":
|
|
||||||
print_error("No IdentityFile found in config; cannot regenerate.")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Derive local paths
|
|
||||||
expanded_key = os.path.expanduser(identity_file)
|
|
||||||
key_dir = os.path.dirname(expanded_key)
|
|
||||||
pub_path = expanded_key + ".pub"
|
|
||||||
|
|
||||||
# 3a) Read old pub key data
|
|
||||||
old_pub_data = ""
|
|
||||||
if os.path.isfile(pub_path):
|
|
||||||
try:
|
|
||||||
with open(pub_path, "r") as f:
|
|
||||||
old_pub_data = f.read().strip()
|
|
||||||
except Exception as e:
|
|
||||||
print_warning(f"Could not read old pub key: {e}")
|
|
||||||
else:
|
|
||||||
print_warning("No old pub key found locally.")
|
|
||||||
|
|
||||||
# 4) Remove old local key files
|
|
||||||
print_info("Removing old key files locally...")
|
|
||||||
for path in [expanded_key, pub_path]:
|
|
||||||
if os.path.isfile(path):
|
if os.path.isfile(path):
|
||||||
try:
|
try:
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
@ -148,74 +123,196 @@ async def regenerate_key(conf_dir):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_warning(f"Could not remove {path}: {e}")
|
print_warning(f"Could not remove {path}: {e}")
|
||||||
|
|
||||||
# 5) Generate new key
|
async def regenerate_key(conf_dir: str) -> bool:
|
||||||
print_info("Generating new ed25519 SSH key...")
|
"""
|
||||||
new_key_path = expanded_key # Reuse the same path from config
|
Regenerate the SSH key for a chosen host by:
|
||||||
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
|
1) Displaying the unified table of hosts
|
||||||
try:
|
2) Letting you pick a row number or host label
|
||||||
subprocess.check_call(cmd)
|
3) Reading/deleting any existing local keys
|
||||||
print_info(f"Generated new SSH key at {new_key_path}")
|
4) Generating a new key
|
||||||
except subprocess.CalledProcessError as e:
|
5) Optionally copying it to the remote
|
||||||
print_error(f"Error generating new SSH key: {e}")
|
6) Removing the old pub key from the remote authorized_keys if present
|
||||||
return
|
|
||||||
|
|
||||||
# 6) Copy new key to remote
|
Returns True if key was successfully regenerated, False otherwise.
|
||||||
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
|
"""
|
||||||
if copy_choice and copy_choice.lower().startswith('y'):
|
print_info("Regenerate Key - Step 1: Show current hosts...\n")
|
||||||
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
|
|
||||||
if port != 22:
|
|
||||||
ssh_copy_cmd += ["-p", str(port)]
|
|
||||||
ssh_copy_cmd.append(f"{user}@{hostname}")
|
|
||||||
try:
|
|
||||||
subprocess.check_call(ssh_copy_cmd)
|
|
||||||
print_info("New key successfully copied to remote server.")
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print_error(f"Error copying new key: {e}")
|
|
||||||
|
|
||||||
# 7) Remove old key from authorized_keys if old_pub_data is non-empty
|
# Get host list
|
||||||
|
headers, final_data = await build_host_list_table(conf_dir)
|
||||||
|
if not final_data:
|
||||||
|
print_warning("No hosts found. Cannot regenerate a key.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Display host table
|
||||||
|
from tabulate import tabulate
|
||||||
|
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||||
|
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||||
|
|
||||||
|
# Get host blocks and sort them
|
||||||
|
all_blocks = []
|
||||||
|
pattern = os.path.join(conf_dir, "*", "config")
|
||||||
|
for cfile in glob.glob(pattern):
|
||||||
|
all_blocks.extend(load_config_file(cfile))
|
||||||
|
|
||||||
|
results = await gather_host_info(all_blocks)
|
||||||
|
sorted_rows = sort_by_ip(results)
|
||||||
|
|
||||||
|
# Get user choice
|
||||||
|
choice = safe_input("Enter the row number or the Host label to regenerate: ")
|
||||||
|
if choice is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
target_tuple = find_target_host(sorted_rows, choice.strip())
|
||||||
|
if not target_tuple:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get host information
|
||||||
|
host_label, user, _, hostname, *_ = target_tuple
|
||||||
|
|
||||||
|
# Find config block
|
||||||
|
found_block = next(
|
||||||
|
(b for b in all_blocks if b.get("Host") == host_label),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if not found_block:
|
||||||
|
print_warning(f"No config block found for '{host_label}'.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
port = int(found_block.get("Port", "22"))
|
||||||
|
identity_file = found_block.get("IdentityFile", "")
|
||||||
|
|
||||||
|
# Handle missing identity file
|
||||||
|
if not identity_file or identity_file == "N/A":
|
||||||
|
print_warning("No existing SSH key found in the configuration.")
|
||||||
|
gen_choice = safe_input("Would you like to generate a new key? (y/n): ")
|
||||||
|
if not gen_choice or not gen_choice.lower().startswith('y'):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Set up new key path and generate key
|
||||||
|
host_dir = Path(conf_dir) / host_label
|
||||||
|
host_dir.mkdir(mode=0o700, exist_ok=True)
|
||||||
|
new_key_path = str(host_dir / "id_ed25519")
|
||||||
|
|
||||||
|
if not generate_new_key(new_key_path, user, hostname, port):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Update config with new key
|
||||||
|
config_path = host_dir / "config"
|
||||||
|
return update_config_with_key(config_path, new_key_path)
|
||||||
|
|
||||||
|
# Handle existing key regeneration
|
||||||
|
expanded_key = os.path.expanduser(identity_file)
|
||||||
|
pub_path = expanded_key + ".pub"
|
||||||
|
old_pub_data = ""
|
||||||
|
|
||||||
|
# Try to read old public key
|
||||||
|
if os.path.isfile(pub_path):
|
||||||
|
try:
|
||||||
|
old_pub_data = Path(pub_path).read_text().rstrip("\n")
|
||||||
|
except Exception as e:
|
||||||
|
print_warning(f"Could not read old pub key: {e}")
|
||||||
|
|
||||||
|
# Remove old key files
|
||||||
|
print_info("Removing old key files locally...")
|
||||||
|
remove_key_files([expanded_key, pub_path])
|
||||||
|
|
||||||
|
# Generate new key
|
||||||
|
if not generate_new_key(expanded_key, user, hostname, port):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Remove old key from remote if we had it
|
||||||
if old_pub_data:
|
if old_pub_data:
|
||||||
print_info("Attempting to remove old key from remote authorized_keys...")
|
print_info("Attempting to remove old key from remote authorized_keys...")
|
||||||
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
||||||
else:
|
else:
|
||||||
print_warning("No old pub key data found locally, skipping remote removal.")
|
print_warning("No old pub key data found locally, skipping remote removal.")
|
||||||
|
|
||||||
async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
return True
|
||||||
|
|
||||||
|
async def remove_old_key_remote(old_pub_data: str, user: str, hostname: str, port: int) -> bool:
|
||||||
"""
|
"""
|
||||||
1) Check if old_pub_data is present on remote server (grep -q).
|
Remove the old public key from remote authorized_keys file.
|
||||||
2) If found, remove it with grep -v ...
|
Returns True if successful, False otherwise.
|
||||||
3) Otherwise, print "No old key found on remote."
|
|
||||||
"""
|
"""
|
||||||
# 1) Check if old_pub_data exists in authorized_keys
|
# Escape the public key data for shell safety
|
||||||
check_cmd = [
|
escaped_key = old_pub_data.replace('"', '\\"')
|
||||||
|
|
||||||
|
# First check if authorized_keys exists
|
||||||
|
check_file_cmd = [
|
||||||
"ssh",
|
"ssh",
|
||||||
"-o", "StrictHostKeyChecking=no",
|
"-o", "StrictHostKeyChecking=no",
|
||||||
"-p", str(port),
|
"-p", str(port),
|
||||||
f"{user}@{hostname}",
|
f"{user}@{hostname}",
|
||||||
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys"
|
"test -f ~/.ssh/authorized_keys && echo 'exists'"
|
||||||
]
|
]
|
||||||
found_key = False
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(check_cmd)
|
result = subprocess.run(check_file_cmd, capture_output=True, text=True)
|
||||||
found_key = True
|
if result.returncode != 0 or 'exists' not in result.stdout:
|
||||||
except subprocess.CalledProcessError:
|
print_warning("No authorized_keys file found on remote host.")
|
||||||
# grep returns exit code 1 if not found
|
return False
|
||||||
pass
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Error checking authorized_keys: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
if not found_key:
|
# Create a temporary file for the sed script
|
||||||
print_warning("No old key found on remote authorized_keys.")
|
temp_script = """#!/bin/bash
|
||||||
return
|
set -e
|
||||||
|
KEYS_FILE="$HOME/.ssh/authorized_keys"
|
||||||
|
TEMP_FILE="$HOME/.ssh/authorized_keys.tmp"
|
||||||
|
grep -Fxv "$1" "$KEYS_FILE" > "$TEMP_FILE"
|
||||||
|
mv "$TEMP_FILE" "$KEYS_FILE"
|
||||||
|
chmod 600 "$KEYS_FILE"
|
||||||
|
"""
|
||||||
|
|
||||||
# 2) Actually remove it
|
# Create a temporary script on the remote host
|
||||||
|
setup_cmd = [
|
||||||
|
"ssh",
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-p", str(port),
|
||||||
|
f"{user}@{hostname}",
|
||||||
|
f'cat > ~/.ssh/remove_key.sh << \'EOF\'\n{temp_script}\nEOF\n'
|
||||||
|
f'chmod +x ~/.ssh/remove_key.sh'
|
||||||
|
]
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(setup_cmd, check=True, capture_output=True)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Failed to create temporary script: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Execute the script with the key
|
||||||
remove_cmd = [
|
remove_cmd = [
|
||||||
"ssh",
|
"ssh",
|
||||||
"-o", "StrictHostKeyChecking=no",
|
"-o", "StrictHostKeyChecking=no",
|
||||||
"-p", str(port),
|
"-p", str(port),
|
||||||
f"{user}@{hostname}",
|
f"{user}@{hostname}",
|
||||||
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
|
f'~/.ssh/remove_key.sh "{escaped_key}"'
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(remove_cmd)
|
subprocess.run(remove_cmd, check=True, capture_output=True)
|
||||||
print_info("Old public key removed from remote authorized_keys.")
|
print_info("Old public key removed from remote authorized_keys.")
|
||||||
except subprocess.CalledProcessError:
|
|
||||||
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).")
|
# Clean up the temporary script
|
||||||
|
cleanup_cmd = [
|
||||||
|
"ssh",
|
||||||
|
"-o", "StrictHostKeyChecking=no",
|
||||||
|
"-p", str(port),
|
||||||
|
f"{user}@{hostname}",
|
||||||
|
"rm -f ~/.ssh/remove_key.sh"
|
||||||
|
]
|
||||||
|
subprocess.run(cleanup_cmd, check=True, capture_output=True)
|
||||||
|
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print_error(f"Failed to remove old key: {e}")
|
||||||
|
# Try to clean up even if removal failed
|
||||||
|
try:
|
||||||
|
subprocess.run(cleanup_cmd, check=True, capture_output=True)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Unexpected error removing old key: {e}")
|
||||||
|
return False
|
||||||
|
|
153
remove_host.py
153
remove_host.py
|
@ -1,82 +1,68 @@
|
||||||
# ssh_manager/remove_host.py
|
# ssh_manager/remove_host.py
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import glob
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import OrderedDict
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
print_info,
|
print_info,
|
||||||
print_warning,
|
print_warning,
|
||||||
print_error,
|
print_error,
|
||||||
safe_input
|
safe_input
|
||||||
)
|
)
|
||||||
from .list_hosts import load_config_file, check_ssh_port
|
from .list_hosts import build_host_list_table, load_config_file
|
||||||
|
"""
|
||||||
async def get_all_host_blocks(conf_dir):
|
Remove host now reuses build_host_list_table to display the same columns:
|
||||||
pattern = os.path.join(conf_dir, "*", "config")
|
No. | Host | User | Port | HostName | IP Address | Conf Directory
|
||||||
conf_files = sorted(glob.glob(pattern))
|
"""
|
||||||
|
|
||||||
all_blocks = []
|
|
||||||
for conf_file in conf_files:
|
|
||||||
blocks = load_config_file(conf_file)
|
|
||||||
all_blocks.extend(blocks)
|
|
||||||
|
|
||||||
return all_blocks
|
|
||||||
|
|
||||||
async def remove_host(conf_dir):
|
async def remove_host(conf_dir):
|
||||||
"""
|
"""
|
||||||
Remove an SSH host by:
|
Remove an SSH host by:
|
||||||
1) Listing all hosts
|
1) Listing all hosts (same columns as main list)
|
||||||
2) Letting user pick row number or label
|
2) Letting user pick row number or label
|
||||||
3) Attempting to remove the old pub key from remote authorized_keys
|
3) Removing old pub key from remote
|
||||||
4) Deleting the subdirectory in ~/.ssh/conf/<host_label>
|
4) Deleting ~/.ssh/conf/<host_label>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print_info("Remove Host - Step 1: Show current hosts...\n")
|
print_info("Remove Host - Step 1: Show current hosts...\n")
|
||||||
|
|
||||||
all_blocks = await get_all_host_blocks(conf_dir)
|
# Reuse the unified table from list_hosts
|
||||||
if not all_blocks:
|
headers, final_data = await build_host_list_table(conf_dir)
|
||||||
|
|
||||||
|
if not final_data:
|
||||||
print_warning("No hosts found. Cannot remove anything.")
|
print_warning("No hosts found. Cannot remove anything.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# We'll display them in a table
|
# Print the same table
|
||||||
import socket
|
|
||||||
from tabulate import tabulate
|
from tabulate import tabulate
|
||||||
|
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||||
|
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||||
|
|
||||||
table_rows = []
|
# We have final_data rows => need to map row => block
|
||||||
for idx, block in enumerate(all_blocks, start=1):
|
# So let's gather the raw blocks again to correlate.
|
||||||
host_label = block.get("Host", "N/A")
|
# We'll do a separate approach or we can parse final_data.
|
||||||
hostname = block.get("HostName", "N/A")
|
# Easiest: Re-run load_config_file if needed or:
|
||||||
user = block.get("User", "N/A")
|
blocks = []
|
||||||
port = int(block.get("Port", "22"))
|
# The last gather call for build_host_list_table used load_config_file already
|
||||||
identity_file = block.get("IdentityFile", "N/A")
|
# but it doesn't return the correlation. We'll replicate the logic quickly.
|
||||||
|
|
||||||
try:
|
pattern = os.path.join(conf_dir, "*", "config")
|
||||||
ip_address = socket.gethostbyname(hostname)
|
conf_files = sorted(os.listdir(conf_dir))
|
||||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
|
||||||
except:
|
|
||||||
ip_address = None
|
|
||||||
port_open = False
|
|
||||||
|
|
||||||
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
# Actually, let's do a small approach:
|
||||||
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
from .list_hosts import gather_host_info, sort_by_ip
|
||||||
|
all_blocks = []
|
||||||
|
import glob
|
||||||
|
|
||||||
row = [
|
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
|
||||||
idx,
|
blocks.extend(load_config_file(cfile))
|
||||||
host_label,
|
|
||||||
user,
|
|
||||||
port_disp,
|
|
||||||
hostname,
|
|
||||||
ip_disp,
|
|
||||||
identity_file
|
|
||||||
]
|
|
||||||
table_rows.append(row)
|
|
||||||
|
|
||||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
|
# gather the same big list
|
||||||
print("\nSSH Conf Subdirectory Host List")
|
results = await gather_host_info(blocks)
|
||||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
sorted_rows = sort_by_ip(results)
|
||||||
|
# sorted_rows is a list of 7-tuples:
|
||||||
|
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
|
||||||
|
|
||||||
# Prompt which host to remove
|
|
||||||
choice = safe_input("Enter the row number or Host label to remove: ")
|
choice = safe_input("Enter the row number or Host label to remove: ")
|
||||||
if choice is None:
|
if choice is None:
|
||||||
return
|
return
|
||||||
|
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
|
||||||
print_error("Invalid empty choice.")
|
print_error("Invalid empty choice.")
|
||||||
return
|
return
|
||||||
|
|
||||||
target_block = None
|
target_tuple = None
|
||||||
|
|
||||||
|
# If digit => index in sorted_rows
|
||||||
if choice.isdigit():
|
if choice.isdigit():
|
||||||
idx = int(choice)
|
idx = int(choice)
|
||||||
if idx < 1 or idx > len(all_blocks):
|
if idx < 1 or idx > len(sorted_rows):
|
||||||
print_warning(f"Row number {idx} is invalid.")
|
print_warning(f"Row number {idx} is invalid.")
|
||||||
return
|
return
|
||||||
target_block = all_blocks[idx - 1]
|
target_tuple = sorted_rows[idx - 1]
|
||||||
else:
|
else:
|
||||||
# They typed a label
|
# They typed a label
|
||||||
for b in all_blocks:
|
for t in sorted_rows:
|
||||||
if b.get("Host") == choice:
|
if t[0] == choice: # t[0] = host_label
|
||||||
target_block = b
|
target_tuple = t
|
||||||
break
|
break
|
||||||
if not target_block:
|
if not target_tuple:
|
||||||
print_warning(f"No matching host label '{choice}' found.")
|
print_warning(f"No matching host label '{choice}' found.")
|
||||||
return
|
return
|
||||||
|
|
||||||
host_label = target_block.get("Host", "")
|
# target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
|
||||||
hostname = target_block.get("HostName", "")
|
host_label = target_tuple[0]
|
||||||
user = target_block.get("User", "root")
|
hostname = target_tuple[3]
|
||||||
port = int(target_block.get("Port", "22"))
|
user = target_tuple[1]
|
||||||
identity_file = target_block.get("IdentityFile", "")
|
# parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
|
||||||
|
# We do a second approach: let's see if we can find the block in "blocks".
|
||||||
|
# But let's parse the config block to get the real port, identity, etc.
|
||||||
|
|
||||||
|
# We find the matching block in "blocks" by host_label
|
||||||
|
found_block = None
|
||||||
|
for b in blocks:
|
||||||
|
if b.get("Host") == host_label:
|
||||||
|
found_block = b
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found_block:
|
||||||
|
print_warning(f"No config block found for {host_label}.")
|
||||||
|
return
|
||||||
|
|
||||||
|
port_str = found_block.get("Port", "22")
|
||||||
|
port = int(port_str)
|
||||||
|
identity_file = found_block.get("IdentityFile", "")
|
||||||
|
|
||||||
|
# now do removal logic
|
||||||
if not host_label:
|
if not host_label:
|
||||||
print_warning("Target block has no Host label. Cannot remove.")
|
print_warning("Target block has no Host label. Cannot remove.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Check IdentityFile for old pub key
|
|
||||||
if identity_file and identity_file != "N/A":
|
if identity_file and identity_file != "N/A":
|
||||||
expanded_key = os.path.expanduser(identity_file)
|
expanded_key = os.path.expanduser(identity_file)
|
||||||
pub_path = expanded_key + ".pub"
|
pub_path = expanded_key + ".pub"
|
||||||
old_pub_data = ""
|
old_pub_data = ""
|
||||||
|
|
||||||
if os.path.isfile(pub_path):
|
if os.path.isfile(pub_path):
|
||||||
try:
|
try:
|
||||||
with open(pub_path, "r") as f:
|
with open(pub_path, "r") as f:
|
||||||
# This is the EXACT line that might appear in authorized_keys
|
|
||||||
old_pub_data = f.read().rstrip("\n")
|
old_pub_data = f.read().rstrip("\n")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_warning(f"Could not read old pub key: {e}")
|
print_warning(f"Could not read old pub key: {e}")
|
||||||
|
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
|
||||||
print_info("Attempting to remove old key from remote authorized_keys...")
|
print_info("Attempting to remove old key from remote authorized_keys...")
|
||||||
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
||||||
|
|
||||||
# Finally, remove the local subdirectory
|
# remove local folder
|
||||||
host_dir = os.path.join(conf_dir, host_label)
|
host_dir = os.path.join(conf_dir, host_label)
|
||||||
|
import shutil
|
||||||
if os.path.isdir(host_dir):
|
if os.path.isdir(host_dir):
|
||||||
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
|
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
|
||||||
if confirm and confirm.lower().startswith('y'):
|
if confirm and confirm.lower().startswith('y'):
|
||||||
try:
|
try:
|
||||||
import shutil
|
|
||||||
shutil.rmtree(host_dir)
|
shutil.rmtree(host_dir)
|
||||||
print_info(f"Removed local folder: {host_dir}")
|
print_info(f"Removed local folder: {host_dir}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
|
||||||
else:
|
else:
|
||||||
print_warning("Local folder not removed.")
|
print_warning("Local folder not removed.")
|
||||||
else:
|
else:
|
||||||
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.")
|
print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
|
||||||
|
|
||||||
async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
||||||
"""
|
# same logic as before
|
||||||
Remove the old key from remote authorized_keys if found, matching EXACT lines.
|
|
||||||
We'll do:
|
|
||||||
grep -Fxq for existence
|
|
||||||
grep -vFx to remove it
|
|
||||||
"""
|
|
||||||
# 1) Check if old_pub_data is present in authorized_keys (exact line)
|
|
||||||
check_cmd = [
|
check_cmd = [
|
||||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||||
"-p", str(port),
|
"-p", str(port),
|
||||||
f"{user}@{hostname}",
|
f"{user}@{hostname}",
|
||||||
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys"
|
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
|
||||||
]
|
]
|
||||||
found_key = False
|
found_key = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(check_cmd)
|
subprocess.check_call(check_cmd)
|
||||||
found_key = True
|
found_key = True
|
||||||
|
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
||||||
print_warning("No old key found on remote authorized_keys.")
|
print_warning("No old key found on remote authorized_keys.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2) Actually remove it by ignoring EXACT matches to that line
|
|
||||||
remove_cmd = [
|
remove_cmd = [
|
||||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||||
"-p", str(port),
|
"-p", str(port),
|
||||||
f"{user}@{hostname}",
|
f"{user}@{hostname}",
|
||||||
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
|
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
subprocess.check_call(remove_cmd)
|
subprocess.check_call(remove_cmd)
|
||||||
|
|
25
utils.py
25
utils.py
|
@ -1,8 +1,10 @@
|
||||||
# ssh_manager/utils.py
|
# ssh_manager/utils.py
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
import sys
|
import sys
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
class Colors:
|
class Colors(Enum):
|
||||||
GREEN = "\033[0;32m"
|
GREEN = "\033[0;32m"
|
||||||
RED = "\033[0;31m"
|
RED = "\033[0;31m"
|
||||||
YELLOW = "\033[1;33m"
|
YELLOW = "\033[1;33m"
|
||||||
|
@ -10,16 +12,23 @@ class Colors:
|
||||||
BOLD = "\033[1m"
|
BOLD = "\033[1m"
|
||||||
RESET = "\033[0m"
|
RESET = "\033[0m"
|
||||||
|
|
||||||
def print_error(message):
|
def __str__(self):
|
||||||
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}")
|
return self.value
|
||||||
|
|
||||||
def print_warning(message):
|
@lru_cache(maxsize=32)
|
||||||
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}")
|
def _format_message(prefix: str, color: Colors, message: str) -> str:
|
||||||
|
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
|
||||||
|
|
||||||
def print_info(message):
|
def print_error(message: str) -> None:
|
||||||
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}")
|
print(_format_message("✖", Colors.RED, message))
|
||||||
|
|
||||||
def safe_input(prompt=""):
|
def print_warning(message: str) -> None:
|
||||||
|
print(_format_message("⚠", Colors.YELLOW, message))
|
||||||
|
|
||||||
|
def print_info(message: str) -> None:
|
||||||
|
print(_format_message("✔", Colors.GREEN, message))
|
||||||
|
|
||||||
|
def safe_input(prompt: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
A wrapper around input() that exits the entire script on Ctrl+C.
|
A wrapper around input() that exits the entire script on Ctrl+C.
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue