This commit is contained in:
Arctic 2025-03-08 00:43:48 -06:00
parent 414266eefc
commit ffb1f7e204
5 changed files with 636 additions and 312 deletions

View file

@ -2,109 +2,211 @@
import os import os
import subprocess import subprocess
import glob
from pathlib import Path
from typing import Optional, List, Dict, Set, Tuple
from dataclasses import dataclass
from .utils import print_error, print_warning, print_info, safe_input from .utils import print_error, print_warning, print_info, safe_input
from .config import CONF_DIR
def add_host(conf_dir): @dataclass
class HostConfig:
label: str
hostname: str
user: str = "root"
port: str = "22"
identity_file: str = ""
def to_config_lines(self) -> List[str]:
"""Convert host configuration to SSH config file lines."""
lines = [
f"Host {self.label}",
f" HostName {self.hostname}",
f" User {self.user}",
f" Port {self.port}"
]
if self.identity_file:
lines.append(f" IdentityFile {self.identity_file}")
return lines
def get_existing_hosts(conf_dir: str) -> Tuple[Set[str], Dict[str, str]]:
"""
Get existing host labels and hostnames from config files.
Returns (host_labels, hostname_to_label) where:
- host_labels is a set of existing host labels
- hostname_to_label maps hostnames to their labels
"""
host_labels = set()
hostname_to_label = {}
pattern = os.path.join(conf_dir, "*", "config")
for config_file in glob.glob(pattern):
try:
with open(config_file, 'r') as f:
lines = f.readlines()
current_label = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue
if line.lower().startswith('host '):
labels = line.split()[1:]
for label in labels:
if '*' not in label:
current_label = label
host_labels.add(label)
break
elif current_label and line.lower().startswith('hostname '):
hostname = line.split(None, 1)[1].strip()
hostname_to_label[hostname] = current_label
except Exception as e:
print_warning(f"Error reading config file {config_file}: {e}")
continue
return host_labels, hostname_to_label
def generate_ssh_key(key_path: Path) -> bool:
"""Generate a new ED25519 SSH key pair."""
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", str(key_path)
])
print_info(f"Generated new SSH key at {key_path}")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
return False
def copy_ssh_key(key_path: Path, user: str, hostname: str, port: str) -> bool:
"""Copy SSH public key to remote server."""
try:
cmd = ["ssh-copy-id", "-i", str(key_path)]
if port != "22":
cmd.extend(["-p", port])
cmd.append(f"{user}@{hostname}")
subprocess.check_call(cmd)
print_info("Key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
return False
def write_config_file(config_path: Path, config_lines: List[str]) -> bool:
"""Write SSH config lines to file."""
try:
config_path.write_text("\n".join(config_lines) + "\n")
print_info(f"Created/updated config at: {config_path}")
return True
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
return False
def add_host(conf_dir: str) -> bool:
""" """
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config. Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
Offers to generate a new SSH key pair (ed25519) quietly (-q), Offers to generate a new SSH key pair (ed25519) quietly (-q),
and then prompt to copy that key to the remote server via ssh-copy-id. and then prompt to copy that key to the remote server via ssh-copy-id.
Returns True if host was added successfully, False otherwise.
""" """
print_info("Adding a new SSH host...") print_info("Adding a new SSH host...")
# Get existing hosts to check for duplicates
existing_labels, hostname_to_label = get_existing_hosts(conf_dir)
# Get host label
while True:
host_label = safe_input("Enter Host label (e.g. myserver): ") host_label = safe_input("Enter Host label (e.g. myserver): ")
if host_label is None: if not host_label or host_label is None:
return # User canceled (Ctrl+C)
host_label = host_label.strip()
if not host_label:
print_error("Host label cannot be empty.") print_error("Host label cannot be empty.")
return return False
host_label = host_label.strip()
if host_label in existing_labels:
print_error(f"A host with label '{host_label}' already exists. Please choose a different label.")
continue
break
# Get hostname
while True:
hostname = safe_input("Enter HostName (IP or domain): ") hostname = safe_input("Enter HostName (IP or domain): ")
if hostname is None: if not hostname or hostname is None:
return
hostname = hostname.strip()
if not hostname:
print_error("HostName cannot be empty.") print_error("HostName cannot be empty.")
return return False
user = safe_input("Enter username (default: 'root'): ") hostname = hostname.strip()
if hostname in hostname_to_label:
existing_label = hostname_to_label[hostname]
print_error(f"This hostname is already configured for host '{existing_label}'. Please use a different hostname.")
continue
break
# Get optional parameters
user = safe_input("Enter username (default: 'root'): ") or "root"
if user is None: if user is None:
return return False
user = user.strip() or "root"
port = safe_input("Enter SSH port (default: 22): ") port = safe_input("Enter SSH port (default: 22): ") or "22"
if port is None: if port is None:
return return False
port = port.strip() or "22"
# Create subdirectory: ~/.ssh/conf/<label> # Create host configuration
host_dir = os.path.join(conf_dir, host_label) host_config = HostConfig(
if os.path.exists(host_dir): label=host_label,
hostname=hostname,
user=user.strip(),
port=port.strip()
)
# Setup directory structure
host_dir = Path(conf_dir) / host_label
if host_dir.exists():
print_warning(f"Directory {host_dir} already exists; continuing anyway.") print_warning(f"Directory {host_dir} already exists; continuing anyway.")
else: host_dir.mkdir(mode=0o700, exist_ok=True)
os.makedirs(host_dir, mode=0o700, exist_ok=True)
print_info(f"Created directory: {host_dir}") print_info(f"Created directory: {host_dir}")
config_path = os.path.join(host_dir, "config") config_path = host_dir / "config"
if os.path.exists(config_path): if config_path.exists():
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.") print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
# Handle SSH key generation
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ") gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
if gen_key_choice is None: if gen_key_choice is None:
return return False
gen_key_choice = gen_key_choice.lower().strip()
identity_file = "" if gen_key_choice.lower().strip() == 'y':
if gen_key_choice == 'y': key_path = host_dir / "id_ed25519"
key_path = os.path.join(host_dir, "id_ed25519")
if os.path.exists(key_path): if key_path.exists():
print_warning(f"{key_path} already exists. Skipping generation.") print_warning(f"{key_path} already exists. Skipping generation.")
identity_file = key_path host_config.identity_file = str(key_path)
else: else:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", key_path] if generate_ssh_key(key_path):
try: host_config.identity_file = str(key_path)
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {key_path}")
identity_file = key_path
# Prompt to copy the key # Prompt to copy the key
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ") copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
if copy_key is None: if copy_key is None:
return return False
if copy_key.lower().strip() == 'y': if copy_key.lower().strip() == 'y':
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path] copy_ssh_key(key_path, host_config.user, host_config.hostname, host_config.port)
if port != "22":
ssh_copy_cmd += ["-p", port]
ssh_copy_cmd.append(f"{user}@{hostname}")
try:
subprocess.check_call(ssh_copy_cmd)
print_info("Key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
else: else:
# Handle existing key
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ") existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
if existing_key is None: if existing_key is None:
return return False
existing_key = existing_key.strip()
if existing_key:
identity_file = os.path.expanduser(existing_key)
config_lines = [ if existing_key.strip():
f"Host {host_label}", host_config.identity_file = os.path.expanduser(existing_key.strip())
f" HostName {hostname}",
f" User {user}",
f" Port {port}"
]
if identity_file:
config_lines.append(f" IdentityFile {identity_file}")
try: # Write the config file
with open(config_path, "w") as f: return write_config_file(config_path, host_config.to_config_lines())
for line in config_lines:
f.write(line + "\n")
print_info(f"Created/updated config at: {config_path}")
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")

115
cli.py
View file

@ -2,6 +2,8 @@
import os import os
import asyncio import asyncio
from typing import Callable, Dict, Optional
from functools import wraps
from .utils import print_info, print_error, print_warning, Colors, safe_input from .utils import print_info, print_error, print_warning, Colors, safe_input
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
from .add_host import add_host from .add_host import add_host
@ -10,62 +12,101 @@ from .list_hosts import list_hosts
from .regen_key import regenerate_key from .regen_key import regenerate_key
from .remove_host import remove_host from .remove_host import remove_host
def ensure_ssh_setup(): def ensure_ssh_setup() -> None:
""" """
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing, Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
and writes a default ~/.ssh/config if it doesn't exist. and writes a default ~/.ssh/config if it doesn't exist.
""" """
if not os.path.isdir(SSH_DIR): directories = [
os.makedirs(SSH_DIR, mode=0o700, exist_ok=True) (SSH_DIR, "SSH directory"),
print_info(f"Created directory: {SSH_DIR}") (CONF_DIR, "configuration directory"),
(SOCKET_DIR, "socket directory")
]
if not os.path.isdir(CONF_DIR): for directory, description in directories:
os.makedirs(CONF_DIR, mode=0o700, exist_ok=True) if not os.path.isdir(directory):
print_info(f"Created directory: {CONF_DIR}") os.makedirs(directory, mode=0o700, exist_ok=True)
print_info(f"Created {description}: {directory}")
if not os.path.isdir(SOCKET_DIR):
os.makedirs(SOCKET_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {SOCKET_DIR}")
if not os.path.isfile(MAIN_CONFIG): if not os.path.isfile(MAIN_CONFIG):
with open(MAIN_CONFIG, "w") as f: with open(MAIN_CONFIG, "w") as f:
f.write(DEFAULT_CONFIG_CONTENT) f.write(DEFAULT_CONFIG_CONTENT)
print_info(f"Created default SSH config at: {MAIN_CONFIG}") print_info(f"Created default SSH config at: {MAIN_CONFIG}")
def main(): def async_handler(func: Callable) -> Callable:
"""Decorator to handle async functions in the command dispatch"""
@wraps(func)
def wrapper(*args, **kwargs):
return asyncio.run(func(*args, **kwargs))
return wrapper
class SSHManager:
def __init__(self):
self.commands: Dict[str, tuple[Callable, str]] = {
"1": (self.list_hosts, "List Hosts"),
"2": (self.add_host, "Add a Host"),
"3": (self.edit_host, "Edit a Host"),
"4": (self.regenerate_key, "Regenerate Key"),
"5": (self.remove_host, "Remove Host"),
"6": (self.exit_app, "Exit")
}
@async_handler
async def list_hosts(self) -> None:
await list_hosts(CONF_DIR)
def add_host(self) -> None:
add_host(CONF_DIR)
@async_handler
async def edit_host(self) -> None:
await edit_host(CONF_DIR)
@async_handler
async def regenerate_key(self) -> None:
await regenerate_key(CONF_DIR)
@async_handler
async def remove_host(self) -> None:
await remove_host(CONF_DIR)
def exit_app(self) -> None:
print_info("Exiting...")
raise SystemExit(0)
def display_menu(self) -> None:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
for key, (_, description) in self.commands.items():
print(f"{key}. {description}")
def handle_command(self, choice: str) -> bool:
if choice not in self.commands:
print_error("Invalid choice. Please select 1 through 6.")
return True
try:
self.commands[choice][0]()
return choice != "6"
except SystemExit:
return False
except Exception as e:
print_error(f"Error executing command: {str(e)}")
return True
def main() -> int:
ensure_ssh_setup() ensure_ssh_setup()
manager = SSHManager()
# Display the server list on first load # Display the server list on first load
asyncio.run(list_hosts(CONF_DIR)) manager.list_hosts()
while True: while True:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}") manager.display_menu()
print("1. List Hosts")
print("2. Add a Host")
print("3. Edit a Host")
print("4. Regenerate Key")
print("5. Remove Host")
print("6. Exit")
choice = safe_input("Select an option (1-6): ") choice = safe_input("Select an option (1-6): ")
if choice is None: if choice is None:
continue # User pressed Ctrl+C => safe_input returns None => re-show menu continue
choice = choice.strip() if not manager.handle_command(choice.strip()):
if choice == '1':
asyncio.run(list_hosts(CONF_DIR))
elif choice == '2':
add_host(CONF_DIR)
elif choice == '3':
asyncio.run(edit_host(CONF_DIR))
elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5':
asyncio.run(remove_host(CONF_DIR))
elif choice == '6':
print_info("Exiting...")
break break
else:
print_error("Invalid choice. Please select 1 through 6.")
return 0 return 0

View file

@ -1,15 +1,24 @@
# ssh_manager/config.py # ssh_manager/config.py
import os import os
from pathlib import Path
from typing import Final
import subprocess
# Paths # Use Path for more efficient path handling
SSH_DIR = os.path.expanduser("~/.ssh") HOME: Final[Path] = Path.home()
CONF_DIR = os.path.join(SSH_DIR, "conf") SSH_DIR: Final[Path] = HOME / ".ssh"
SOCKET_DIR = os.path.join(SSH_DIR, "s") CONF_DIR: Final[Path] = SSH_DIR / "conf"
MAIN_CONFIG = os.path.join(SSH_DIR, "config") SOCKET_DIR: Final[Path] = SSH_DIR / "s"
MAIN_CONFIG: Final[Path] = SSH_DIR / "config"
# Validate paths on import
for path in (SSH_DIR, CONF_DIR, SOCKET_DIR):
if not path.exists():
path.mkdir(mode=0o700, parents=True, exist_ok=True)
# Default SSH config content if ~/.ssh/config is missing # Default SSH config content if ~/.ssh/config is missing
DEFAULT_CONFIG_CONTENT = """### DEFAULT_CONFIG_CONTENT: Final[str] = """###
#Local ssh #Local ssh
### ###
@ -30,3 +39,28 @@ Host *
ControlPersist 72000 ControlPersist 72000
ControlPath ~/.ssh/s/%C ControlPath ~/.ssh/s/%C
""" """
# Export string versions for backward compatibility
SSH_DIR_STR: Final[str] = str(SSH_DIR)
CONF_DIR_STR: Final[str] = str(CONF_DIR)
SOCKET_DIR_STR: Final[str] = str(SOCKET_DIR)
MAIN_CONFIG_STR: Final[str] = str(MAIN_CONFIG)
def validate_key_path(key_path: Path) -> bool:
if not key_path.exists():
return False
if not key_path.is_dir():
return False
if not os.access(key_path, os.W_OK):
return False
return True
def update_config_with_key(key_path: Path) -> bool:
if not validate_key_path(key_path):
return False # Early return if path validation fails
try:
subprocess.check_call([...])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False

View file

@ -5,25 +5,28 @@ import glob
import socket import socket
import asyncio import asyncio
import ipaddress import ipaddress
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Optional, OrderedDict as OrderedDictType
from functools import lru_cache
from tabulate import tabulate from tabulate import tabulate
from collections import OrderedDict from collections import OrderedDict
from .utils import print_warning, print_error, Colors from .utils import print_warning, print_error, Colors
from .config import CONF_DIR
async def check_ssh_port(ip_address, port): # Cache DNS lookups
@lru_cache(maxsize=128, typed=True)
def resolve_hostname(hostname: str) -> Optional[str]:
try: try:
reader, writer = await asyncio.wait_for( return socket.gethostbyname(hostname)
asyncio.open_connection(ip_address, port), timeout=1 except socket.error:
) return None
writer.close()
await writer.wait_closed() def load_config_file(file_path: str) -> List[OrderedDictType]:
return True blocks: List[OrderedDictType] = []
except: host_data: Optional[OrderedDictType] = None
return False
def load_config_file(file_path):
blocks = []
host_data = None
try: try:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
lines = f.readlines() lines = f.readlines()
@ -44,7 +47,7 @@ def load_config_file(file_path):
blocks.append(host_data) blocks.append(host_data)
host_data = OrderedDict({'Host': label}) host_data = OrderedDict({'Host': label})
break break
elif host_data: elif host_data is not None:
if ' ' in stripped_line: if ' ' in stripped_line:
key, value = stripped_line.split(None, 1) key, value = stripped_line.split(None, 1)
host_data[key] = value.strip() host_data[key] = value.strip()
@ -53,89 +56,101 @@ def load_config_file(file_path):
blocks.append(host_data) blocks.append(host_data)
return blocks return blocks
async def gather_host_info(all_host_blocks): async def gather_host_info(all_host_blocks: List[OrderedDictType]) -> List[Tuple]:
""" """
Given a list of host blocks, gather full info: Given a list of host blocks, gather full info:
- resolved IP - resolved IP
- port open check
- 'Conf Directory' coloring if IdentityFile != 'N/A' - 'Conf Directory' coloring if IdentityFile != 'N/A'
Returns a list of 7-tuples: Returns a list of 7-tuples:
(host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip) (host_label, user, port, hostname, colored_ip, conf_path_display, raw_ip)
""" """
results = [] results = []
async def process_block(h): async def process_block(h: OrderedDictType) -> Tuple:
host_label = h.get('Host', 'N/A') host_label: str = h.get('Host', 'N/A')
hostname = h.get('HostName', 'N/A') hostname: str = h.get('HostName', 'N/A')
user = h.get('User', 'N/A') user: str = h.get('User', 'N/A')
port = int(h.get('Port', '22')) port: str = h.get('Port', '22') # Keep as string since we're not testing it
identity_file = h.get('IdentityFile', 'N/A') identity_file: str = h.get('IdentityFile', 'N/A')
# Resolve IP # Resolve IP using cached function
raw_ip = resolve_hostname(hostname) or "N/A"
# Check if the IP is reachable
def is_ip_reachable(ip: str) -> bool:
try: try:
raw_ip = socket.gethostbyname(hostname) subprocess.run(["ping", "-c", "1", "-W", "1", ip], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}" return True
except socket.error: except subprocess.CalledProcessError:
raw_ip = "N/A" return False
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
# Check port # Determine IP color
if raw_ip != "N/A": if raw_ip != "N/A" and is_ip_reachable(raw_ip):
port_open = await check_ssh_port(raw_ip, port) colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
colored_port = (
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
)
else: else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}" colored_ip = raw_ip # No color if not reachable
# Determine hostname color
try:
ipaddress.ip_address(hostname)
colored_hostname = hostname # No color if hostname is an IP
except ValueError:
if raw_ip != "N/A":
colored_hostname = f"{Colors.GREEN}{hostname}{Colors.RESET}"
else:
colored_hostname = f"{Colors.RED}{hostname}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label> # Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}" conf_path = f"~/.ssh/conf/{host_label}"
# If there's an IdentityFile, color the conf path conf_path_display = (
if identity_file != 'N/A': f"{Colors.GREEN}{conf_path}{Colors.RESET}"
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}" if identity_file != 'N/A'
else: else conf_path
conf_path_display = conf_path )
return ( return (
host_label, host_label,
user, user,
colored_port, port, # Port is now uncolored
hostname, colored_hostname,
colored_ip, colored_ip,
conf_path_display, conf_path_display,
raw_ip # uncolored IP for sorting raw_ip # uncolored IP for sorting
) )
# Process blocks concurrently # Process blocks concurrently with semaphore to limit concurrent connections
tasks = [process_block(b) for b in all_host_blocks] sem = asyncio.Semaphore(10) # Limit concurrent connections
async def process_with_semaphore(block):
async with sem:
return await process_block(block)
tasks = [process_with_semaphore(b) for b in all_host_blocks]
results = await asyncio.gather(*tasks) results = await asyncio.gather(*tasks)
return results return results
def parse_ip(ip_str): @lru_cache(maxsize=1)
def parse_ip(ip_str: str) -> Optional[ipaddress.IPv4Address]:
""" """
Convert a string IP to an ipaddress object for sorting. Convert a string IP to an ipaddress object for sorting.
Returns None if invalid or 'N/A'. Returns None if invalid or 'N/A'.
""" """
import ipaddress
try: try:
return ipaddress.ip_address(ip_str) return ipaddress.ip_address(ip_str)
except ValueError: except ValueError:
return None return None
def sort_by_ip(results): def sort_by_ip(results: List[Tuple]) -> List[Tuple]:
""" """
Sort the 7-tuples by IP ascending, with 'N/A' last. Sort the 7-tuples by IP ascending, with 'N/A' last.
""" """
sortable = [] def sort_key(row):
for row in results:
raw_ip = row[-1] raw_ip = row[-1]
ip_obj = parse_ip(raw_ip) ip_obj = parse_ip(raw_ip)
sortable.append(((ip_obj is None, ip_obj), row)) return (ip_obj is None, ip_obj or ipaddress.IPv4Address('0.0.0.0'))
sortable.sort(key=lambda x: x[0]) return sorted(results, key=sort_key)
return [row for (_, row) in sortable]
async def build_host_list_table(conf_dir): async def build_host_list_table(conf_dir: str) -> Tuple[List[str], List[List]]:
""" """
Gathers + sorts all hosts in conf_dir by IP ascending. Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip. Returns (headers, final_table_rows), each row omitting the raw_ip.
@ -143,7 +158,7 @@ async def build_host_list_table(conf_dir):
pattern = os.path.join(conf_dir, "*", "config") pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern)) conf_files = sorted(glob.glob(pattern))
all_host_blocks = [] all_host_blocks: List[OrderedDictType] = []
for conf_file in conf_files: for conf_file in conf_files:
blocks = load_config_file(conf_file) blocks = load_config_file(conf_file)
all_host_blocks.extend(blocks) all_host_blocks.extend(blocks)
@ -156,22 +171,22 @@ async def build_host_list_table(conf_dir):
sorted_rows = sort_by_ip(results) sorted_rows = sort_by_ip(results)
# Build final table # Build final table
final_data = [] final_data = [
for idx, row in enumerate(sorted_rows, start=1): [idx] + list(row[:-1])
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip) for idx, row in enumerate(sorted_rows, start=1)
final_data.append([idx] + list(row[:-1])) ]
return headers, final_data return headers, final_data
async def list_hosts(conf_dir): async def list_hosts(conf_dir: str) -> None:
"""Display a formatted table of all SSH hosts."""
headers, final_data = await build_host_list_table(conf_dir) headers, final_data = await build_host_list_table(conf_dir)
if not final_data: if not final_data:
print_warning("No hosts found. The server list is empty.") print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List") print("\nSSH Conf Subdirectory Host List")
from tabulate import tabulate
print(tabulate([], headers=headers, tablefmt="grid")) print(tabulate([], headers=headers, tablefmt="grid"))
return return
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)") print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid")) print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,6 +1,8 @@
import os import os
import subprocess import subprocess
import asyncio import asyncio
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
from .utils import ( from .utils import (
print_info, print_info,
print_warning, print_warning,
@ -14,105 +16,106 @@ from .list_hosts import (
sort_by_ip sort_by_ip
) )
async def regenerate_key(conf_dir): def validate_key_path(key_path: str) -> bool:
""" """Validate that the key path and its directory are valid."""
Regenerate the SSH key for a chosen host by: try:
1) Displaying the unified table of hosts (No. | Host | User | Port | HostName | IP Address | Conf Directory). key_path = Path(key_path)
2) Letting you pick a row number or host label. key_dir = key_path.parent
3) Reading/deleting any existing local keys,
4) Generating a new key,
5) Optionally copying it to the remote,
6) Removing the old pub key from the remote authorized_keys if present.
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n") if not key_dir.exists():
key_dir.mkdir(mode=0o700, parents=True)
elif not os.access(str(key_dir), os.W_OK):
print_error(f"No write permission for directory: {key_dir}")
return False
# 1) Reuse the same columns as the main 'list_hosts' return True
headers, final_data = await build_host_list_table(conf_dir) except Exception as e:
if not final_data: print_error(f"Error validating key path: {e}")
print_warning("No hosts found. Cannot regenerate a key.") return False
return
from tabulate import tabulate def generate_new_key(key_path: str, user: str, hostname: str, port: int) -> bool:
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)") """Generate a new ED25519 SSH key and optionally copy it to the remote server."""
print(tabulate(final_data, headers=headers, tablefmt="grid")) if not validate_key_path(key_path):
return False
# 2) We need to correlate row => actual config block. print_info("Generating new ed25519 SSH key...")
# We'll replicate the logic that build_host_list_table uses. try:
all_blocks = [] subprocess.check_call([
import glob "ssh-keygen",
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")): "-q",
all_blocks.extend(load_config_file(cfile)) "-t", "ed25519",
"-N", "",
"-f", key_path
])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False
results = await gather_host_info(all_blocks) print_info(f"Generated new SSH key at {key_path}")
sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
choice = safe_input("Enter the row number or the Host label to regenerate: ") # Ask to copy the key
if choice is None: copy_choice = safe_input("Copy new key to remote now? (y/n): ")
return if not copy_choice or not copy_choice.lower().startswith('y'):
choice = choice.strip() return True
try:
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
return False
def update_config_with_key(config_path: Path, new_key_path: str) -> bool:
"""Update the SSH config file with the new identity file."""
try:
with open(config_path, 'r') as f:
config_lines = [
line.rstrip('\n')
for line in f
if not line.strip().lower().startswith('identityfile')
]
config_lines.append(f" IdentityFile {new_key_path}")
with open(config_path, 'w') as f:
f.write('\n'.join(config_lines) + '\n')
print_info(f"Updated config file with new IdentityFile: {new_key_path}")
return True
except Exception as e:
print_error(f"Failed to update config file: {e}")
return False
def find_target_host(sorted_rows: List[Tuple], choice: str) -> Optional[Tuple]:
"""Find the target host based on user choice."""
if not choice: if not choice:
print_error("No choice given.") print_error("No choice given.")
return return None
target_tuple = None
if choice.isdigit(): if choice.isdigit():
row_idx = int(choice) row_idx = int(choice)
if row_idx < 1 or row_idx > len(sorted_rows): if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.") print_warning(f"Invalid row number {row_idx}.")
return return None
target_tuple = sorted_rows[row_idx - 1] return sorted_rows[row_idx - 1]
else:
# user typed a label # User typed a label
for t in sorted_rows: for row in sorted_rows:
if t[0] == choice: # t[0] is host_label if row[0] == choice: # row[0] is host_label
target_tuple = t return row
break
if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching host label '{choice}' found.")
return return None
# 3) Retrieve the full config block so we have Port, IdentityFile, etc. def remove_key_files(key_paths: List[str]) -> None:
host_label = target_tuple[0] """Remove SSH key files."""
hostname = target_tuple[3] # (host_label, user, colored_port, hostname, ...) for path in key_paths:
user = target_tuple[1]
# find that block
found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# If missing or "N/A", we can't regenerate
if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.")
return
# 4) Remove old local key files
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
old_pub_data = f.read().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]:
if os.path.isfile(path): if os.path.isfile(path):
try: try:
os.remove(path) os.remove(path)
@ -120,67 +123,196 @@ async def regenerate_key(conf_dir):
except Exception as e: except Exception as e:
print_warning(f"Could not remove {path}: {e}") print_warning(f"Could not remove {path}: {e}")
# 5) Generate new key async def regenerate_key(conf_dir: str) -> bool:
print_info("Generating new ed25519 SSH key...") """
new_key_path = expanded_key Regenerate the SSH key for a chosen host by:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path] 1) Displaying the unified table of hosts
try: 2) Letting you pick a row number or host label
subprocess.check_call(cmd) 3) Reading/deleting any existing local keys
print_info(f"Generated new SSH key at {new_key_path}") 4) Generating a new key
except subprocess.CalledProcessError as e: 5) Optionally copying it to the remote
print_error(f"Error generating new SSH key: {e}") 6) Removing the old pub key from the remote authorized_keys if present
return
# 6) Copy the new key to remote if user wants Returns True if key was successfully regenerated, False otherwise.
copy_choice = safe_input("Copy new key to remote now? (y/n): ") """
if copy_choice and copy_choice.lower().startswith('y'): print_info("Regenerate Key - Step 1: Show current hosts...\n")
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
try:
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
# 7) Remove old key from authorized_keys if we had old_pub_data # Get host list
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot regenerate a key.")
return False
# Display host table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
# Get host blocks and sort them
all_blocks = []
pattern = os.path.join(conf_dir, "*", "config")
for cfile in glob.glob(pattern):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# Get user choice
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return False
target_tuple = find_target_host(sorted_rows, choice.strip())
if not target_tuple:
return False
# Get host information
host_label, user, _, hostname, *_ = target_tuple
# Find config block
found_block = next(
(b for b in all_blocks if b.get("Host") == host_label),
None
)
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return False
port = int(found_block.get("Port", "22"))
identity_file = found_block.get("IdentityFile", "")
# Handle missing identity file
if not identity_file or identity_file == "N/A":
print_warning("No existing SSH key found in the configuration.")
gen_choice = safe_input("Would you like to generate a new key? (y/n): ")
if not gen_choice or not gen_choice.lower().startswith('y'):
return False
# Set up new key path and generate key
host_dir = Path(conf_dir) / host_label
host_dir.mkdir(mode=0o700, exist_ok=True)
new_key_path = str(host_dir / "id_ed25519")
if not generate_new_key(new_key_path, user, hostname, port):
return False
# Update config with new key
config_path = host_dir / "config"
return update_config_with_key(config_path, new_key_path)
# Handle existing key regeneration
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
# Try to read old public key
if os.path.isfile(pub_path):
try:
old_pub_data = Path(pub_path).read_text().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
# Remove old key files
print_info("Removing old key files locally...")
remove_key_files([expanded_key, pub_path])
# Generate new key
if not generate_new_key(expanded_key, user, hostname, port):
return False
# Remove old key from remote if we had it
if old_pub_data: if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...") print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port) await remove_old_key_remote(old_pub_data, user, hostname, port)
else: else:
print_warning("No old pub key data found locally, skipping remote removal.") print_warning("No old pub key data found locally, skipping remote removal.")
async def remove_old_key_remote(old_pub_data, user, hostname, port): return True
async def remove_old_key_remote(old_pub_data: str, user: str, hostname: str, port: int) -> bool:
""" """
Checks and removes the exact matching line from remote authorized_keys Remove the old public key from remote authorized_keys file.
via grep -Fxq / grep -vFx. Returns True if successful, False otherwise.
""" """
check_cmd = [ # Escape the public key data for shell safety
"ssh", "-o", "StrictHostKeyChecking=no", escaped_key = old_pub_data.replace('"', '\\"')
# First check if authorized_keys exists
check_file_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys' "test -f ~/.ssh/authorized_keys && echo 'exists'"
] ]
found_key = False
try: try:
subprocess.check_call(check_cmd) result = subprocess.run(check_file_cmd, capture_output=True, text=True)
found_key = True if result.returncode != 0 or 'exists' not in result.stdout:
except subprocess.CalledProcessError: print_warning("No authorized_keys file found on remote host.")
pass return False
except subprocess.CalledProcessError as e:
print_error(f"Error checking authorized_keys: {e}")
return False
if not found_key: # Create a temporary file for the sed script
print_warning("No old key found on remote authorized_keys.") temp_script = """#!/bin/bash
return set -e
KEYS_FILE="$HOME/.ssh/authorized_keys"
TEMP_FILE="$HOME/.ssh/authorized_keys.tmp"
grep -Fxv "$1" "$KEYS_FILE" > "$TEMP_FILE"
mv "$TEMP_FILE" "$KEYS_FILE"
chmod 600 "$KEYS_FILE"
"""
# Create a temporary script on the remote host
setup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'cat > ~/.ssh/remove_key.sh << \'EOF\'\n{temp_script}\nEOF\n'
f'chmod +x ~/.ssh/remove_key.sh'
]
try:
subprocess.run(setup_cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print_error(f"Failed to create temporary script: {e}")
return False
# Execute the script with the key
remove_cmd = [ remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no", "ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys' f'~/.ssh/remove_key.sh "{escaped_key}"'
] ]
try: try:
subprocess.check_call(remove_cmd) subprocess.run(remove_cmd, check=True, capture_output=True)
print_info("Old public key removed from remote authorized_keys.") print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permissions or other error).") # Clean up the temporary script
cleanup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
"rm -f ~/.ssh/remove_key.sh"
]
subprocess.run(cleanup_cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print_error(f"Failed to remove old key: {e}")
# Try to clean up even if removal failed
try:
subprocess.run(cleanup_cmd, check=True, capture_output=True)
except:
pass
return False
except Exception as e:
print_error(f"Unexpected error removing old key: {e}")
return False