Compare commits

..

2 commits

Author SHA1 Message Date
ffb1f7e204 refactor 2025-03-08 00:43:48 -06:00
414266eefc refactored code 2025-03-08 00:43:33 -06:00
8 changed files with 836 additions and 601 deletions

View file

@ -2,109 +2,211 @@
import os import os
import subprocess import subprocess
import glob
from pathlib import Path
from typing import Optional, List, Dict, Set, Tuple
from dataclasses import dataclass
from .utils import print_error, print_warning, print_info, safe_input from .utils import print_error, print_warning, print_info, safe_input
from .config import CONF_DIR
def add_host(conf_dir): @dataclass
class HostConfig:
label: str
hostname: str
user: str = "root"
port: str = "22"
identity_file: str = ""
def to_config_lines(self) -> List[str]:
"""Convert host configuration to SSH config file lines."""
lines = [
f"Host {self.label}",
f" HostName {self.hostname}",
f" User {self.user}",
f" Port {self.port}"
]
if self.identity_file:
lines.append(f" IdentityFile {self.identity_file}")
return lines
def get_existing_hosts(conf_dir: str) -> Tuple[Set[str], Dict[str, str]]:
"""
Get existing host labels and hostnames from config files.
Returns (host_labels, hostname_to_label) where:
- host_labels is a set of existing host labels
- hostname_to_label maps hostnames to their labels
"""
host_labels = set()
hostname_to_label = {}
pattern = os.path.join(conf_dir, "*", "config")
for config_file in glob.glob(pattern):
try:
with open(config_file, 'r') as f:
lines = f.readlines()
current_label = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue
if line.lower().startswith('host '):
labels = line.split()[1:]
for label in labels:
if '*' not in label:
current_label = label
host_labels.add(label)
break
elif current_label and line.lower().startswith('hostname '):
hostname = line.split(None, 1)[1].strip()
hostname_to_label[hostname] = current_label
except Exception as e:
print_warning(f"Error reading config file {config_file}: {e}")
continue
return host_labels, hostname_to_label
def generate_ssh_key(key_path: Path) -> bool:
"""Generate a new ED25519 SSH key pair."""
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", str(key_path)
])
print_info(f"Generated new SSH key at {key_path}")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
return False
def copy_ssh_key(key_path: Path, user: str, hostname: str, port: str) -> bool:
"""Copy SSH public key to remote server."""
try:
cmd = ["ssh-copy-id", "-i", str(key_path)]
if port != "22":
cmd.extend(["-p", port])
cmd.append(f"{user}@{hostname}")
subprocess.check_call(cmd)
print_info("Key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
return False
def write_config_file(config_path: Path, config_lines: List[str]) -> bool:
"""Write SSH config lines to file."""
try:
config_path.write_text("\n".join(config_lines) + "\n")
print_info(f"Created/updated config at: {config_path}")
return True
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
return False
def add_host(conf_dir: str) -> bool:
""" """
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config. Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
Offers to generate a new SSH key pair (ed25519) quietly (-q), Offers to generate a new SSH key pair (ed25519) quietly (-q),
and then prompt to copy that key to the remote server via ssh-copy-id. and then prompt to copy that key to the remote server via ssh-copy-id.
Returns True if host was added successfully, False otherwise.
""" """
print_info("Adding a new SSH host...") print_info("Adding a new SSH host...")
host_label = safe_input("Enter Host label (e.g. myserver): ") # Get existing hosts to check for duplicates
if host_label is None: existing_labels, hostname_to_label = get_existing_hosts(conf_dir)
return # User canceled (Ctrl+C)
host_label = host_label.strip()
if not host_label:
print_error("Host label cannot be empty.")
return
hostname = safe_input("Enter HostName (IP or domain): ") # Get host label
if hostname is None: while True:
return host_label = safe_input("Enter Host label (e.g. myserver): ")
hostname = hostname.strip() if not host_label or host_label is None:
if not hostname: print_error("Host label cannot be empty.")
print_error("HostName cannot be empty.") return False
return
host_label = host_label.strip()
if host_label in existing_labels:
print_error(f"A host with label '{host_label}' already exists. Please choose a different label.")
continue
break
# Get hostname
while True:
hostname = safe_input("Enter HostName (IP or domain): ")
if not hostname or hostname is None:
print_error("HostName cannot be empty.")
return False
hostname = hostname.strip()
if hostname in hostname_to_label:
existing_label = hostname_to_label[hostname]
print_error(f"This hostname is already configured for host '{existing_label}'. Please use a different hostname.")
continue
break
user = safe_input("Enter username (default: 'root'): ") # Get optional parameters
user = safe_input("Enter username (default: 'root'): ") or "root"
if user is None: if user is None:
return return False
user = user.strip() or "root"
port = safe_input("Enter SSH port (default: 22): ") or "22"
port = safe_input("Enter SSH port (default: 22): ")
if port is None: if port is None:
return return False
port = port.strip() or "22"
# Create subdirectory: ~/.ssh/conf/<label> # Create host configuration
host_dir = os.path.join(conf_dir, host_label) host_config = HostConfig(
if os.path.exists(host_dir): label=host_label,
hostname=hostname,
user=user.strip(),
port=port.strip()
)
# Setup directory structure
host_dir = Path(conf_dir) / host_label
if host_dir.exists():
print_warning(f"Directory {host_dir} already exists; continuing anyway.") print_warning(f"Directory {host_dir} already exists; continuing anyway.")
else: host_dir.mkdir(mode=0o700, exist_ok=True)
os.makedirs(host_dir, mode=0o700, exist_ok=True) print_info(f"Created directory: {host_dir}")
print_info(f"Created directory: {host_dir}")
config_path = os.path.join(host_dir, "config") config_path = host_dir / "config"
if os.path.exists(config_path): if config_path.exists():
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.") print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
# Handle SSH key generation
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ") gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
if gen_key_choice is None: if gen_key_choice is None:
return return False
gen_key_choice = gen_key_choice.lower().strip()
identity_file = "" if gen_key_choice.lower().strip() == 'y':
if gen_key_choice == 'y': key_path = host_dir / "id_ed25519"
key_path = os.path.join(host_dir, "id_ed25519")
if os.path.exists(key_path): if key_path.exists():
print_warning(f"{key_path} already exists. Skipping generation.") print_warning(f"{key_path} already exists. Skipping generation.")
identity_file = key_path host_config.identity_file = str(key_path)
else: else:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", key_path] if generate_ssh_key(key_path):
try: host_config.identity_file = str(key_path)
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {key_path}")
identity_file = key_path
# Prompt to copy the key # Prompt to copy the key
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ") copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
if copy_key is None: if copy_key is None:
return return False
if copy_key.lower().strip() == 'y': if copy_key.lower().strip() == 'y':
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path] copy_ssh_key(key_path, host_config.user, host_config.hostname, host_config.port)
if port != "22":
ssh_copy_cmd += ["-p", port]
ssh_copy_cmd.append(f"{user}@{hostname}")
try:
subprocess.check_call(ssh_copy_cmd)
print_info("Key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
else: else:
# Handle existing key
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ") existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
if existing_key is None: if existing_key is None:
return return False
existing_key = existing_key.strip()
if existing_key: if existing_key.strip():
identity_file = os.path.expanduser(existing_key) host_config.identity_file = os.path.expanduser(existing_key.strip())
config_lines = [ # Write the config file
f"Host {host_label}", return write_config_file(config_path, host_config.to_config_lines())
f" HostName {hostname}",
f" User {user}",
f" Port {port}"
]
if identity_file:
config_lines.append(f" IdentityFile {identity_file}")
try:
with open(config_path, "w") as f:
for line in config_lines:
f.write(line + "\n")
print_info(f"Created/updated config at: {config_path}")
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")

121
cli.py
View file

@ -2,6 +2,8 @@
import os import os
import asyncio import asyncio
from typing import Callable, Dict, Optional
from functools import wraps
from .utils import print_info, print_error, print_warning, Colors, safe_input from .utils import print_info, print_error, print_warning, Colors, safe_input
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
from .add_host import add_host from .add_host import add_host
@ -10,62 +12,101 @@ from .list_hosts import list_hosts
from .regen_key import regenerate_key from .regen_key import regenerate_key
from .remove_host import remove_host from .remove_host import remove_host
def ensure_ssh_setup(): def ensure_ssh_setup() -> None:
""" """
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing, Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
and writes a default ~/.ssh/config if it doesn't exist. and writes a default ~/.ssh/config if it doesn't exist.
""" """
if not os.path.isdir(SSH_DIR): directories = [
os.makedirs(SSH_DIR, mode=0o700, exist_ok=True) (SSH_DIR, "SSH directory"),
print_info(f"Created directory: {SSH_DIR}") (CONF_DIR, "configuration directory"),
(SOCKET_DIR, "socket directory")
if not os.path.isdir(CONF_DIR): ]
os.makedirs(CONF_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {CONF_DIR}") for directory, description in directories:
if not os.path.isdir(directory):
if not os.path.isdir(SOCKET_DIR): os.makedirs(directory, mode=0o700, exist_ok=True)
os.makedirs(SOCKET_DIR, mode=0o700, exist_ok=True) print_info(f"Created {description}: {directory}")
print_info(f"Created directory: {SOCKET_DIR}")
if not os.path.isfile(MAIN_CONFIG): if not os.path.isfile(MAIN_CONFIG):
with open(MAIN_CONFIG, "w") as f: with open(MAIN_CONFIG, "w") as f:
f.write(DEFAULT_CONFIG_CONTENT) f.write(DEFAULT_CONFIG_CONTENT)
print_info(f"Created default SSH config at: {MAIN_CONFIG}") print_info(f"Created default SSH config at: {MAIN_CONFIG}")
def main(): def async_handler(func: Callable) -> Callable:
ensure_ssh_setup() """Decorator to handle async functions in the command dispatch"""
@wraps(func)
def wrapper(*args, **kwargs):
return asyncio.run(func(*args, **kwargs))
return wrapper
class SSHManager:
def __init__(self):
self.commands: Dict[str, tuple[Callable, str]] = {
"1": (self.list_hosts, "List Hosts"),
"2": (self.add_host, "Add a Host"),
"3": (self.edit_host, "Edit a Host"),
"4": (self.regenerate_key, "Regenerate Key"),
"5": (self.remove_host, "Remove Host"),
"6": (self.exit_app, "Exit")
}
@async_handler
async def list_hosts(self) -> None:
await list_hosts(CONF_DIR)
def add_host(self) -> None:
add_host(CONF_DIR)
@async_handler
async def edit_host(self) -> None:
await edit_host(CONF_DIR)
@async_handler
async def regenerate_key(self) -> None:
await regenerate_key(CONF_DIR)
@async_handler
async def remove_host(self) -> None:
await remove_host(CONF_DIR)
def exit_app(self) -> None:
print_info("Exiting...")
raise SystemExit(0)
def display_menu(self) -> None:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
for key, (_, description) in self.commands.items():
print(f"{key}. {description}")
def handle_command(self, choice: str) -> bool:
if choice not in self.commands:
print_error("Invalid choice. Please select 1 through 6.")
return True
try:
self.commands[choice][0]()
return choice != "6"
except SystemExit:
return False
except Exception as e:
print_error(f"Error executing command: {str(e)}")
return True
def main() -> int:
ensure_ssh_setup()
manager = SSHManager()
# Display the server list on first load # Display the server list on first load
asyncio.run(list_hosts(CONF_DIR)) manager.list_hosts()
while True: while True:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}") manager.display_menu()
print("1. List Hosts")
print("2. Add a Host")
print("3. Edit a Host")
print("4. Regenerate Key")
print("5. Remove Host")
print("6. Exit")
choice = safe_input("Select an option (1-6): ") choice = safe_input("Select an option (1-6): ")
if choice is None: if choice is None:
continue # User pressed Ctrl+C => safe_input returns None => re-show menu continue
choice = choice.strip() if not manager.handle_command(choice.strip()):
if choice == '1':
asyncio.run(list_hosts(CONF_DIR))
elif choice == '2':
add_host(CONF_DIR)
elif choice == '3':
edit_host(CONF_DIR)
elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5':
asyncio.run(remove_host(CONF_DIR))
elif choice == '6':
print_info("Exiting...")
break break
else:
print_error("Invalid choice. Please select 1 through 6.")
return 0 return 0

View file

@ -1,15 +1,24 @@
# ssh_manager/config.py # ssh_manager/config.py
import os import os
from pathlib import Path
from typing import Final
import subprocess
# Paths # Use Path for more efficient path handling
SSH_DIR = os.path.expanduser("~/.ssh") HOME: Final[Path] = Path.home()
CONF_DIR = os.path.join(SSH_DIR, "conf") SSH_DIR: Final[Path] = HOME / ".ssh"
SOCKET_DIR = os.path.join(SSH_DIR, "s") CONF_DIR: Final[Path] = SSH_DIR / "conf"
MAIN_CONFIG = os.path.join(SSH_DIR, "config") SOCKET_DIR: Final[Path] = SSH_DIR / "s"
MAIN_CONFIG: Final[Path] = SSH_DIR / "config"
# Validate paths on import
for path in (SSH_DIR, CONF_DIR, SOCKET_DIR):
if not path.exists():
path.mkdir(mode=0o700, parents=True, exist_ok=True)
# Default SSH config content if ~/.ssh/config is missing # Default SSH config content if ~/.ssh/config is missing
DEFAULT_CONFIG_CONTENT = """### DEFAULT_CONFIG_CONTENT: Final[str] = """###
#Local ssh #Local ssh
### ###
@ -30,3 +39,28 @@ Host *
ControlPersist 72000 ControlPersist 72000
ControlPath ~/.ssh/s/%C ControlPath ~/.ssh/s/%C
""" """
# Export string versions for backward compatibility
SSH_DIR_STR: Final[str] = str(SSH_DIR)
CONF_DIR_STR: Final[str] = str(CONF_DIR)
SOCKET_DIR_STR: Final[str] = str(SOCKET_DIR)
MAIN_CONFIG_STR: Final[str] = str(MAIN_CONFIG)
def validate_key_path(key_path: Path) -> bool:
if not key_path.exists():
return False
if not key_path.is_dir():
return False
if not os.access(key_path, os.W_OK):
return False
return True
def update_config_with_key(key_path: Path) -> bool:
if not validate_key_path(key_path):
return False # Early return if path validation fails
try:
subprocess.check_call([...])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False

View file

@ -2,98 +2,31 @@ import os
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from .utils import print_error, print_warning, print_info, safe_input from .utils import print_error, print_warning, print_info, safe_input
from .list_hosts import list_hosts, load_config_file, check_ssh_port from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
async def get_all_host_blocks(conf_dir): async def edit_host(conf_dir):
"""
Similar to list_hosts, but returns the list of host blocks + a table of results.
We'll build a table ourselves so we can map row numbers to actual host labels.
"""
import glob
import socket
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
# If no blocks found, return empty
if not all_blocks:
return []
# We want to do a partial version of check_host to get row data
# so we can display the table right here and keep track of each blocks host label.
# But let's do it similarly to list_hosts:
table_rows = []
for idx, b in enumerate(all_blocks, start=1):
host_label = b.get("Host", "N/A")
hostname = b.get("HostName", "N/A")
user = b.get("User", "N/A")
port = int(b.get("Port", "22"))
identity_file = b.get("IdentityFile", "N/A")
# Identity check
if identity_file != "N/A":
expanded_identity = os.path.expanduser(identity_file)
identity_exists = os.path.isfile(expanded_identity)
else:
identity_exists = False
# IP resolution
try:
ip_address = socket.gethostbyname(hostname)
except socket.error:
ip_address = None
# Port check
if ip_address:
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
else:
port_open = False
# Colors for display (optional, or we can keep it simple):
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
row = [
idx,
host_label,
user,
port_display,
hostname,
ip_display,
identity_disp
]
table_rows.append(row)
# Print the table
from tabulate import tabulate
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
return all_blocks
def edit_host(conf_dir):
""" """
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config. Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
The user may type either the row number OR the actual host label. 1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
2) Prompt row number or host label
3) Rewrite config with updated fields
""" """
# 1) Gather + display the current host list
print_info("Here is the current list of hosts:\n") print_info("Here is the current list of hosts:\n")
all_blocks = asyncio.run(get_all_host_blocks(conf_dir)) headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
if not all_blocks: print_warning("No hosts to edit.")
print_warning("No hosts found to edit.")
return return
# 2) Prompt for which host to edit (by label or row number) from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
choice = safe_input("Enter the row number or the Host label to edit: ") choice = safe_input("Enter the row number or the Host label to edit: ")
if choice is None: if choice is None:
return # user canceled (Ctrl+C) return # user canceled (Ctrl+C)
@ -102,43 +35,47 @@ def edit_host(conf_dir):
print_error("Host label or row number cannot be empty.") print_error("Host label or row number cannot be empty.")
return return
# Check if user typed a digit -> row number # We replicate the approach to find the matching block
target_block = None all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
target_tuple = None
if choice.isdigit(): if choice.isdigit():
row_idx = int(choice) idx = int(choice)
# Validate index if idx < 1 or idx > len(sorted_rows):
if row_idx < 1 or row_idx > len(all_blocks): print_warning(f"Row number {idx} is invalid.")
print_warning(f"Invalid row number {row_idx}.")
return return
target_block = all_blocks[row_idx - 1] # zero-based target_tuple = sorted_rows[idx - 1]
else: else:
# The user typed a host label for t in sorted_rows:
# We must search all_blocks for a matching Host if t[0] == choice: # t[0] => host_label
for b in all_blocks: target_tuple = t
if b.get("Host") == choice:
target_block = b
break break
if not target_block: if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching Host '{choice}' found in the table.")
return return
# Now we have a target_block with existing data host_label = target_tuple[0]
host_label = target_block.get("Host", "") # find the config block
if not host_label: found_block = None
print_warning("This host block has no label. Cannot edit.") for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return return
# Derive the config path old_hostname = found_block.get("HostName", "")
host_dir = os.path.join(conf_dir, host_label) old_user = found_block.get("User", "")
config_path = os.path.join(host_dir, "config") old_port = found_block.get("Port", "22")
if not os.path.isfile(config_path): old_identity = found_block.get("IdentityFile", "")
print_warning(f"No config file found at {config_path}; cannot edit this host.")
return
old_hostname = target_block.get("HostName", "")
old_user = target_block.get("User", "")
old_port = target_block.get("Port", "22")
old_identity = target_block.get("IdentityFile", "")
print_info("Leave a field blank to keep its current value.") print_info("Leave a field blank to keep its current value.")
@ -167,6 +104,8 @@ def edit_host(conf_dir):
final_port = new_port if new_port else old_port final_port = new_port if new_port else old_port
final_ident = new_ident if new_ident else old_identity final_ident = new_ident if new_ident else old_identity
# Overwrite the file
config_path = os.path.join(conf_dir, host_label, "config")
new_config_lines = [ new_config_lines = [
f"Host {host_label}", f"Host {host_label}",
f" HostName {final_hostname}", f" HostName {final_hostname}",

View file

@ -5,34 +5,28 @@ import glob
import socket import socket
import asyncio import asyncio
import ipaddress import ipaddress
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Optional, OrderedDict as OrderedDictType
from functools import lru_cache
from tabulate import tabulate from tabulate import tabulate
from collections import OrderedDict from collections import OrderedDict
from .utils import print_warning, print_error, Colors from .utils import print_warning, print_error, Colors
from .config import CONF_DIR
async def check_ssh_port(ip_address, port): # Cache DNS lookups
""" @lru_cache(maxsize=128, typed=True)
Attempt to open an SSH connection to see if the port is open. def resolve_hostname(hostname: str) -> Optional[str]:
Returns True if successful, False otherwise.
"""
try: try:
reader, writer = await asyncio.wait_for( return socket.gethostbyname(hostname)
asyncio.open_connection(ip_address, port), timeout=1 except socket.error:
) return None
writer.close()
await writer.wait_closed()
return True
except:
return False
def load_config_file(file_path):
"""
Parse a single SSH config file and return a list of host blocks.
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
"""
blocks = []
host_data = None
def load_config_file(file_path: str) -> List[OrderedDictType]:
blocks: List[OrderedDictType] = []
host_data: Optional[OrderedDictType] = None
try: try:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
lines = f.readlines() lines = f.readlines()
@ -42,11 +36,9 @@ def load_config_file(file_path):
for line in lines: for line in lines:
stripped_line = line.strip() stripped_line = line.strip()
# Skip empty lines and comments
if not stripped_line or stripped_line.startswith('#'): if not stripped_line or stripped_line.startswith('#'):
continue continue
# Start of a new Host block
if stripped_line.lower().startswith('host '): if stripped_line.lower().startswith('host '):
host_labels = stripped_line.split()[1:] host_labels = stripped_line.split()[1:]
for label in host_labels: for label in host_labels:
@ -55,7 +47,7 @@ def load_config_file(file_path):
blocks.append(host_data) blocks.append(host_data)
host_data = OrderedDict({'Host': label}) host_data = OrderedDict({'Host': label})
break break
elif host_data: elif host_data is not None:
if ' ' in stripped_line: if ' ' in stripped_line:
key, value = stripped_line.split(None, 1) key, value = stripped_line.split(None, 1)
host_data[key] = value.strip() host_data[key] = value.strip()
@ -64,111 +56,137 @@ def load_config_file(file_path):
blocks.append(host_data) blocks.append(host_data)
return blocks return blocks
async def check_host(host): async def gather_host_info(all_host_blocks: List[OrderedDictType]) -> List[Tuple]:
""" """
Given a host block, resolve IP, check SSH port, etc. Given a list of host blocks, gather full info:
Returns a tuple of: - resolved IP
1) Host label - 'Conf Directory' coloring if IdentityFile != 'N/A'
2) User Returns a list of 7-tuples:
3) Port (colored if open) (host_label, user, port, hostname, colored_ip, conf_path_display, raw_ip)
4) HostName
5) IP Address (colored if resolved)
6) Conf Directory (green if has IdentityFile, else no color)
7) raw_ip (uncolored string for sorting)
""" """
host_label = host.get('Host', 'N/A') results = []
hostname = host.get('HostName', 'N/A')
user = host.get('User', 'N/A')
port = int(host.get('Port', '22'))
identity_file = host.get('IdentityFile', 'N/A')
# Resolve IP async def process_block(h: OrderedDictType) -> Tuple:
try: host_label: str = h.get('Host', 'N/A')
raw_ip = socket.gethostbyname(hostname) # uncolored hostname: str = h.get('HostName', 'N/A')
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}" user: str = h.get('User', 'N/A')
except socket.error: port: str = h.get('Port', '22') # Keep as string since we're not testing it
raw_ip = "N/A" identity_file: str = h.get('IdentityFile', 'N/A')
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
# Check port # Resolve IP using cached function
if raw_ip != "N/A": raw_ip = resolve_hostname(hostname) or "N/A"
port_open = await check_ssh_port(raw_ip, port)
colored_port = ( # Check if the IP is reachable
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}" def is_ip_reachable(ip: str) -> bool:
try:
subprocess.run(["ping", "-c", "1", "-W", "1", ip], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return True
except subprocess.CalledProcessError:
return False
# Determine IP color
if raw_ip != "N/A" and is_ip_reachable(raw_ip):
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
else:
colored_ip = raw_ip # No color if not reachable
# Determine hostname color
try:
ipaddress.ip_address(hostname)
colored_hostname = hostname # No color if hostname is an IP
except ValueError:
if raw_ip != "N/A":
colored_hostname = f"{Colors.GREEN}{hostname}{Colors.RESET}"
else:
colored_hostname = f"{Colors.RED}{hostname}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
conf_path_display = (
f"{Colors.GREEN}{conf_path}{Colors.RESET}"
if identity_file != 'N/A'
else conf_path
) )
else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label> return (
conf_path = f"~/.ssh/conf/{host_label}" host_label,
# If there's an IdentityFile, color the conf path green user,
if identity_file != 'N/A': port, # Port is now uncolored
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}" colored_hostname,
else: colored_ip,
conf_path_display = conf_path conf_path_display,
raw_ip # uncolored IP for sorting
)
# Return the data plus the uncolored IP for sorting # Process blocks concurrently with semaphore to limit concurrent connections
return ( sem = asyncio.Semaphore(10) # Limit concurrent connections
host_label, async def process_with_semaphore(block):
user, async with sem:
colored_port, return await process_block(block)
hostname,
colored_ip,
conf_path_display,
raw_ip # for sorting
)
async def list_hosts(conf_dir): tasks = [process_with_semaphore(b) for b in all_host_blocks]
results = await asyncio.gather(*tasks)
return results
@lru_cache(maxsize=1)
def parse_ip(ip_str: str) -> Optional[ipaddress.IPv4Address]:
""" """
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order. Convert a string IP to an ipaddress object for sorting.
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory Returns None if invalid or 'N/A'.
""" """
pattern = os.path.join(conf_dir, "*", "config") try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
def sort_by_ip(results: List[Tuple]) -> List[Tuple]:
"""
Sort the 7-tuples by IP ascending, with 'N/A' last.
"""
def sort_key(row):
raw_ip = row[-1]
ip_obj = parse_ip(raw_ip)
return (ip_obj is None, ip_obj or ipaddress.IPv4Address('0.0.0.0'))
return sorted(results, key=sort_key)
async def build_host_list_table(conf_dir: str) -> Tuple[List[str], List[List]]:
"""
Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip.
"""
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern)) conf_files = sorted(glob.glob(pattern))
all_host_blocks = [] all_host_blocks: List[OrderedDictType] = []
for conf_file in conf_files: for conf_file in conf_files:
blocks = load_config_file(conf_file) blocks = load_config_file(conf_file)
all_host_blocks.extend(blocks) all_host_blocks.extend(blocks)
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"] headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
if not all_host_blocks: if not all_host_blocks:
return headers, []
results = await gather_host_info(all_host_blocks)
sorted_rows = sort_by_ip(results)
# Build final table
final_data = [
[idx] + list(row[:-1])
for idx, row in enumerate(sorted_rows, start=1)
]
return headers, final_data
async def list_hosts(conf_dir: str) -> None:
"""Display a formatted table of all SSH hosts."""
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. The server list is empty.") print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List") print("\nSSH Conf Subdirectory Host List")
print(tabulate([], headers=headers, tablefmt="grid")) print(tabulate([], headers=headers, tablefmt="grid"))
return return
# Gather full data for each host
tasks = [check_host(h) for h in all_host_blocks]
results = await asyncio.gather(*tasks)
# We want to sort by IP ascending. results[i] is a tuple:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
def parse_ip(ip_str):
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
# Convert the results into a list of (ip_obj, original_tuple)
# so we can sort, then rebuild the final data.
sortable = []
for row in results:
raw_ip = row[-1] # last element
ip_obj = parse_ip(raw_ip)
# We'll sort None last by using a sort key that puts (True) after (False)
# e.g. (ip_obj is None, ip_obj)
sortable.append(((ip_obj is None, ip_obj), row))
# Sort by (is_none, ip_obj)
sortable.sort(key=lambda x: x[0])
# Rebuild the final display table, ignoring the raw_ip at the end
final_data = []
for idx, (_, row) in enumerate(sortable, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
final_data.append([idx] + list(row[:-1])) # omit raw_ip
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)") print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid")) print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,146 +1,121 @@
import os import os
import glob
import subprocess import subprocess
import asyncio import asyncio
from collections import OrderedDict from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
from .utils import ( from .utils import (
print_info, print_info,
print_warning, print_warning,
print_error, print_error,
safe_input, safe_input
Colors )
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
) )
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir): def validate_key_path(key_path: str) -> bool:
pattern = os.path.join(conf_dir, "*", "config") """Validate that the key path and its directory are valid."""
conf_files = sorted(glob.glob(pattern)) try:
key_path = Path(key_path)
key_dir = key_path.parent
if not key_dir.exists():
key_dir.mkdir(mode=0o700, parents=True)
elif not os.access(str(key_dir), os.W_OK):
print_error(f"No write permission for directory: {key_dir}")
return False
return True
except Exception as e:
print_error(f"Error validating key path: {e}")
return False
all_blocks = [] def generate_new_key(key_path: str, user: str, hostname: str, port: int) -> bool:
for conf_file in conf_files: """Generate a new ED25519 SSH key and optionally copy it to the remote server."""
blocks = load_config_file(conf_file) if not validate_key_path(key_path):
all_blocks.extend(blocks) return False
if not all_blocks: print_info("Generating new ed25519 SSH key...")
return [] try:
return all_blocks subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", key_path
])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False
print_info(f"Generated new SSH key at {key_path}")
# Ask to copy the key
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if not copy_choice or not copy_choice.lower().startswith('y'):
return True
async def regenerate_key(conf_dir): try:
""" ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
Menu-driven function to regenerate a key for a selected host: if port != 22:
1) Show the host table with row numbers ssh_copy_cmd += ["-p", str(port)]
2) Let user pick row # or label ssh_copy_cmd.append(f"{user}@{hostname}")
3) Read old pub data
4) Remove old local key files subprocess.check_call(ssh_copy_cmd)
5) Generate new key print_info("New key successfully copied to remote server.")
6) Copy new key return True
7) Remove old key from remote authorized_keys (improved logic to detect existence) except subprocess.CalledProcessError as e:
""" print_error(f"Error copying new key: {e}")
print_info("Regenerate Key - Step 1: Show current hosts...\n") return False
# 1) Gather host blocks def update_config_with_key(config_path: Path, new_key_path: str) -> bool:
all_blocks = await get_all_host_blocks(conf_dir) """Update the SSH config file with the new identity file."""
if not all_blocks: try:
print_warning("No hosts found. Cannot regenerate a key.") with open(config_path, 'r') as f:
return config_lines = [
line.rstrip('\n')
for line in f
if not line.strip().lower().startswith('identityfile')
]
config_lines.append(f" IdentityFile {new_key_path}")
with open(config_path, 'w') as f:
f.write('\n'.join(config_lines) + '\n')
print_info(f"Updated config file with new IdentityFile: {new_key_path}")
return True
except Exception as e:
print_error(f"Failed to update config file: {e}")
return False
# Display them in a table (similar to list_hosts): def find_target_host(sorted_rows: List[Tuple], choice: str) -> Optional[Tuple]:
import socket """Find the target host based on user choice."""
from tabulate import tabulate
table_rows = []
for idx, block in enumerate(all_blocks, start=1):
host_label = block.get("Host", "N/A")
hostname = block.get("HostName", "N/A")
user = block.get("User", "N/A")
port = int(block.get("Port", "22"))
identity_file = block.get("IdentityFile", "N/A")
# Check IP
try:
ip_address = socket.gethostbyname(hostname)
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
# 2) Prompt for row # or label
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return
choice = choice.strip()
if not choice: if not choice:
print_error("No choice given.") print_error("No choice given.")
return return None
target_block = None
if choice.isdigit(): if choice.isdigit():
row_idx = int(choice) row_idx = int(choice)
if row_idx < 1 or row_idx > len(all_blocks): if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.") print_warning(f"Invalid row number {row_idx}.")
return return None
target_block = all_blocks[row_idx - 1] return sorted_rows[row_idx - 1]
else:
# user typed a label # User typed a label
for b in all_blocks: for row in sorted_rows:
if b.get("Host") == choice: if row[0] == choice: # row[0] is host_label
target_block = b return row
break
if not target_block: print_warning(f"No matching host label '{choice}' found.")
print_warning(f"No matching host label '{choice}' found.") return None
return
# 3) Gather info from block def remove_key_files(key_paths: List[str]) -> None:
host_label = target_block.get("Host", "") """Remove SSH key files."""
hostname = target_block.get("HostName", "") for path in key_paths:
user = target_block.get("User", "root")
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
if not host_label or not hostname:
print_error("Missing Host or HostName; cannot regenerate.")
return
if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.")
return
# Derive local paths
expanded_key = os.path.expanduser(identity_file)
key_dir = os.path.dirname(expanded_key)
pub_path = expanded_key + ".pub"
# 3a) Read old pub key data
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
old_pub_data = f.read().strip()
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
else:
print_warning("No old pub key found locally.")
# 4) Remove old local key files
print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]:
if os.path.isfile(path): if os.path.isfile(path):
try: try:
os.remove(path) os.remove(path)
@ -148,74 +123,196 @@ async def regenerate_key(conf_dir):
except Exception as e: except Exception as e:
print_warning(f"Could not remove {path}: {e}") print_warning(f"Could not remove {path}: {e}")
# 5) Generate new key async def regenerate_key(conf_dir: str) -> bool:
print_info("Generating new ed25519 SSH key...") """
new_key_path = expanded_key # Reuse the same path from config Regenerate the SSH key for a chosen host by:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path] 1) Displaying the unified table of hosts
try: 2) Letting you pick a row number or host label
subprocess.check_call(cmd) 3) Reading/deleting any existing local keys
print_info(f"Generated new SSH key at {new_key_path}") 4) Generating a new key
except subprocess.CalledProcessError as e: 5) Optionally copying it to the remote
print_error(f"Error generating new SSH key: {e}") 6) Removing the old pub key from the remote authorized_keys if present
return
Returns True if key was successfully regenerated, False otherwise.
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n")
# 6) Copy new key to remote # Get host list
copy_choice = safe_input("Copy new key to remote now? (y/n): ") headers, final_data = await build_host_list_table(conf_dir)
if copy_choice and copy_choice.lower().startswith('y'): if not final_data:
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path] print_warning("No hosts found. Cannot regenerate a key.")
if port != 22: return False
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}") # Display host table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
# Get host blocks and sort them
all_blocks = []
pattern = os.path.join(conf_dir, "*", "config")
for cfile in glob.glob(pattern):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# Get user choice
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return False
target_tuple = find_target_host(sorted_rows, choice.strip())
if not target_tuple:
return False
# Get host information
host_label, user, _, hostname, *_ = target_tuple
# Find config block
found_block = next(
(b for b in all_blocks if b.get("Host") == host_label),
None
)
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return False
port = int(found_block.get("Port", "22"))
identity_file = found_block.get("IdentityFile", "")
# Handle missing identity file
if not identity_file or identity_file == "N/A":
print_warning("No existing SSH key found in the configuration.")
gen_choice = safe_input("Would you like to generate a new key? (y/n): ")
if not gen_choice or not gen_choice.lower().startswith('y'):
return False
# Set up new key path and generate key
host_dir = Path(conf_dir) / host_label
host_dir.mkdir(mode=0o700, exist_ok=True)
new_key_path = str(host_dir / "id_ed25519")
if not generate_new_key(new_key_path, user, hostname, port):
return False
# Update config with new key
config_path = host_dir / "config"
return update_config_with_key(config_path, new_key_path)
# Handle existing key regeneration
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
# Try to read old public key
if os.path.isfile(pub_path):
try: try:
subprocess.check_call(ssh_copy_cmd) old_pub_data = Path(pub_path).read_text().rstrip("\n")
print_info("New key successfully copied to remote server.") except Exception as e:
except subprocess.CalledProcessError as e: print_warning(f"Could not read old pub key: {e}")
print_error(f"Error copying new key: {e}")
# 7) Remove old key from authorized_keys if old_pub_data is non-empty # Remove old key files
print_info("Removing old key files locally...")
remove_key_files([expanded_key, pub_path])
# Generate new key
if not generate_new_key(expanded_key, user, hostname, port):
return False
# Remove old key from remote if we had it
if old_pub_data: if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...") print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port) await remove_old_key_remote(old_pub_data, user, hostname, port)
else: else:
print_warning("No old pub key data found locally, skipping remote removal.") print_warning("No old pub key data found locally, skipping remote removal.")
return True
async def remove_old_key_remote(old_pub_data, user, hostname, port): async def remove_old_key_remote(old_pub_data: str, user: str, hostname: str, port: int) -> bool:
""" """
1) Check if old_pub_data is present on remote server (grep -q). Remove the old public key from remote authorized_keys file.
2) If found, remove it with grep -v ... Returns True if successful, False otherwise.
3) Otherwise, print "No old key found on remote."
""" """
# 1) Check if old_pub_data exists in authorized_keys # Escape the public key data for shell safety
check_cmd = [ escaped_key = old_pub_data.replace('"', '\\"')
# First check if authorized_keys exists
check_file_cmd = [
"ssh", "ssh",
"-o", "StrictHostKeyChecking=no", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys" "test -f ~/.ssh/authorized_keys && echo 'exists'"
] ]
found_key = False
try: try:
subprocess.check_call(check_cmd) result = subprocess.run(check_file_cmd, capture_output=True, text=True)
found_key = True if result.returncode != 0 or 'exists' not in result.stdout:
except subprocess.CalledProcessError: print_warning("No authorized_keys file found on remote host.")
# grep returns exit code 1 if not found return False
pass except subprocess.CalledProcessError as e:
print_error(f"Error checking authorized_keys: {e}")
return False
if not found_key: # Create a temporary file for the sed script
print_warning("No old key found on remote authorized_keys.") temp_script = """#!/bin/bash
return set -e
KEYS_FILE="$HOME/.ssh/authorized_keys"
TEMP_FILE="$HOME/.ssh/authorized_keys.tmp"
grep -Fxv "$1" "$KEYS_FILE" > "$TEMP_FILE"
mv "$TEMP_FILE" "$KEYS_FILE"
chmod 600 "$KEYS_FILE"
"""
# Create a temporary script on the remote host
setup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'cat > ~/.ssh/remove_key.sh << \'EOF\'\n{temp_script}\nEOF\n'
f'chmod +x ~/.ssh/remove_key.sh'
]
try:
subprocess.run(setup_cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print_error(f"Failed to create temporary script: {e}")
return False
# 2) Actually remove it # Execute the script with the key
remove_cmd = [ remove_cmd = [
"ssh", "ssh",
"-o", "StrictHostKeyChecking=no", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys" f'~/.ssh/remove_key.sh "{escaped_key}"'
] ]
try: try:
subprocess.check_call(remove_cmd) subprocess.run(remove_cmd, check=True, capture_output=True)
print_info("Old public key removed from remote authorized_keys.") print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).") # Clean up the temporary script
cleanup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
"rm -f ~/.ssh/remove_key.sh"
]
subprocess.run(cleanup_cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print_error(f"Failed to remove old key: {e}")
# Try to clean up even if removal failed
try:
subprocess.run(cleanup_cmd, check=True, capture_output=True)
except:
pass
return False
except Exception as e:
print_error(f"Unexpected error removing old key: {e}")
return False

View file

@ -1,82 +1,68 @@
# ssh_manager/remove_host.py # ssh_manager/remove_host.py
import os import os
import glob
import subprocess import subprocess
import asyncio import asyncio
from collections import OrderedDict
from .utils import ( from .utils import (
print_info, print_info,
print_warning, print_warning,
print_error, print_error,
safe_input safe_input
) )
from .list_hosts import load_config_file, check_ssh_port from .list_hosts import build_host_list_table, load_config_file
"""
async def get_all_host_blocks(conf_dir): Remove host now reuses build_host_list_table to display the same columns:
pattern = os.path.join(conf_dir, "*", "config") No. | Host | User | Port | HostName | IP Address | Conf Directory
conf_files = sorted(glob.glob(pattern)) """
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
return all_blocks
async def remove_host(conf_dir): async def remove_host(conf_dir):
""" """
Remove an SSH host by: Remove an SSH host by:
1) Listing all hosts 1) Listing all hosts (same columns as main list)
2) Letting user pick row number or label 2) Letting user pick row number or label
3) Attempting to remove the old pub key from remote authorized_keys 3) Removing old pub key from remote
4) Deleting the subdirectory in ~/.ssh/conf/<host_label> 4) Deleting ~/.ssh/conf/<host_label>
""" """
print_info("Remove Host - Step 1: Show current hosts...\n") print_info("Remove Host - Step 1: Show current hosts...\n")
all_blocks = await get_all_host_blocks(conf_dir) # Reuse the unified table from list_hosts
if not all_blocks: headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot remove anything.") print_warning("No hosts found. Cannot remove anything.")
return return
# We'll display them in a table # Print the same table
import socket
from tabulate import tabulate from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = [] # We have final_data rows => need to map row => block
for idx, block in enumerate(all_blocks, start=1): # So let's gather the raw blocks again to correlate.
host_label = block.get("Host", "N/A") # We'll do a separate approach or we can parse final_data.
hostname = block.get("HostName", "N/A") # Easiest: Re-run load_config_file if needed or:
user = block.get("User", "N/A") blocks = []
port = int(block.get("Port", "22")) # The last gather call for build_host_list_table used load_config_file already
identity_file = block.get("IdentityFile", "N/A") # but it doesn't return the correlation. We'll replicate the logic quickly.
try: pattern = os.path.join(conf_dir, "*", "config")
ip_address = socket.gethostbyname(hostname) conf_files = sorted(os.listdir(conf_dir))
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m" # Actually, let's do a small approach:
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m" from .list_hosts import gather_host_info, sort_by_ip
all_blocks = []
import glob
row = [ for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
idx, blocks.extend(load_config_file(cfile))
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"] # gather the same big list
print("\nSSH Conf Subdirectory Host List") results = await gather_host_info(blocks)
print(tabulate(table_rows, headers=headers, tablefmt="grid")) sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
# Prompt which host to remove
choice = safe_input("Enter the row number or Host label to remove: ") choice = safe_input("Enter the row number or Host label to remove: ")
if choice is None: if choice is None:
return return
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
print_error("Invalid empty choice.") print_error("Invalid empty choice.")
return return
target_block = None target_tuple = None
# If digit => index in sorted_rows
if choice.isdigit(): if choice.isdigit():
idx = int(choice) idx = int(choice)
if idx < 1 or idx > len(all_blocks): if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.") print_warning(f"Row number {idx} is invalid.")
return return
target_block = all_blocks[idx - 1] target_tuple = sorted_rows[idx - 1]
else: else:
# They typed a label # They typed a label
for b in all_blocks: for t in sorted_rows:
if b.get("Host") == choice: if t[0] == choice: # t[0] = host_label
target_block = b target_tuple = t
break break
if not target_block: if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching host label '{choice}' found.")
return return
host_label = target_block.get("Host", "") # target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
hostname = target_block.get("HostName", "") host_label = target_tuple[0]
user = target_block.get("User", "root") hostname = target_tuple[3]
port = int(target_block.get("Port", "22")) user = target_tuple[1]
identity_file = target_block.get("IdentityFile", "") # parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
# We do a second approach: let's see if we can find the block in "blocks".
# But let's parse the config block to get the real port, identity, etc.
# We find the matching block in "blocks" by host_label
found_block = None
for b in blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for {host_label}.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# now do removal logic
if not host_label: if not host_label:
print_warning("Target block has no Host label. Cannot remove.") print_warning("Target block has no Host label. Cannot remove.")
return return
# Check IdentityFile for old pub key
if identity_file and identity_file != "N/A": if identity_file and identity_file != "N/A":
expanded_key = os.path.expanduser(identity_file) expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub" pub_path = expanded_key + ".pub"
old_pub_data = "" old_pub_data = ""
if os.path.isfile(pub_path): if os.path.isfile(pub_path):
try: try:
with open(pub_path, "r") as f: with open(pub_path, "r") as f:
# This is the EXACT line that might appear in authorized_keys
old_pub_data = f.read().rstrip("\n") old_pub_data = f.read().rstrip("\n")
except Exception as e: except Exception as e:
print_warning(f"Could not read old pub key: {e}") print_warning(f"Could not read old pub key: {e}")
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
print_info("Attempting to remove old key from remote authorized_keys...") print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port) await remove_old_key_remote(old_pub_data, user, hostname, port)
# Finally, remove the local subdirectory # remove local folder
host_dir = os.path.join(conf_dir, host_label) host_dir = os.path.join(conf_dir, host_label)
import shutil
if os.path.isdir(host_dir): if os.path.isdir(host_dir):
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ") confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
if confirm and confirm.lower().startswith('y'): if confirm and confirm.lower().startswith('y'):
try: try:
import shutil
shutil.rmtree(host_dir) shutil.rmtree(host_dir)
print_info(f"Removed local folder: {host_dir}") print_info(f"Removed local folder: {host_dir}")
except Exception as e: except Exception as e:
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
else: else:
print_warning("Local folder not removed.") print_warning("Local folder not removed.")
else: else:
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.") print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
async def remove_old_key_remote(old_pub_data, user, hostname, port): async def remove_old_key_remote(old_pub_data, user, hostname, port):
""" # same logic as before
Remove the old key from remote authorized_keys if found, matching EXACT lines.
We'll do:
grep -Fxq for existence
grep -vFx to remove it
"""
# 1) Check if old_pub_data is present in authorized_keys (exact line)
check_cmd = [ check_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no", "ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys" f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
] ]
found_key = False found_key = False
try: try:
subprocess.check_call(check_cmd) subprocess.check_call(check_cmd)
found_key = True found_key = True
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
print_warning("No old key found on remote authorized_keys.") print_warning("No old key found on remote authorized_keys.")
return return
# 2) Actually remove it by ignoring EXACT matches to that line
remove_cmd = [ remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no", "ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys" f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
] ]
try: try:
subprocess.check_call(remove_cmd) subprocess.check_call(remove_cmd)

View file

@ -1,25 +1,34 @@
# ssh_manager/utils.py # ssh_manager/utils.py
from enum import Enum
import sys import sys
from functools import lru_cache
class Colors: class Colors(Enum):
GREEN = "\033[0;32m" GREEN = "\033[0;32m"
RED = "\033[0;31m" RED = "\033[0;31m"
YELLOW = "\033[1;33m" YELLOW = "\033[1;33m"
CYAN = "\033[0;36m" CYAN = "\033[0;36m"
BOLD = "\033[1m" BOLD = "\033[1m"
RESET = "\033[0m" RESET = "\033[0m"
def print_error(message): def __str__(self):
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}") return self.value
def print_warning(message): @lru_cache(maxsize=32)
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}") def _format_message(prefix: str, color: Colors, message: str) -> str:
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
def print_info(message): def print_error(message: str) -> None:
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}") print(_format_message("", Colors.RED, message))
def safe_input(prompt=""): def print_warning(message: str) -> None:
print(_format_message("", Colors.YELLOW, message))
def print_info(message: str) -> None:
print(_format_message("", Colors.GREEN, message))
def safe_input(prompt: str = "") -> str:
""" """
A wrapper around input() that exits the entire script on Ctrl+C. A wrapper around input() that exits the entire script on Ctrl+C.
""" """