This commit is contained in:
Arctic 2025-03-08 00:43:48 -06:00
parent 414266eefc
commit ffb1f7e204
5 changed files with 636 additions and 312 deletions

View file

@ -2,109 +2,211 @@
import os
import subprocess
import glob
from pathlib import Path
from typing import Optional, List, Dict, Set, Tuple
from dataclasses import dataclass
from .utils import print_error, print_warning, print_info, safe_input
from .config import CONF_DIR
def add_host(conf_dir):
@dataclass
class HostConfig:
label: str
hostname: str
user: str = "root"
port: str = "22"
identity_file: str = ""
def to_config_lines(self) -> List[str]:
"""Convert host configuration to SSH config file lines."""
lines = [
f"Host {self.label}",
f" HostName {self.hostname}",
f" User {self.user}",
f" Port {self.port}"
]
if self.identity_file:
lines.append(f" IdentityFile {self.identity_file}")
return lines
def get_existing_hosts(conf_dir: str) -> Tuple[Set[str], Dict[str, str]]:
"""
Get existing host labels and hostnames from config files.
Returns (host_labels, hostname_to_label) where:
- host_labels is a set of existing host labels
- hostname_to_label maps hostnames to their labels
"""
host_labels = set()
hostname_to_label = {}
pattern = os.path.join(conf_dir, "*", "config")
for config_file in glob.glob(pattern):
try:
with open(config_file, 'r') as f:
lines = f.readlines()
current_label = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue
if line.lower().startswith('host '):
labels = line.split()[1:]
for label in labels:
if '*' not in label:
current_label = label
host_labels.add(label)
break
elif current_label and line.lower().startswith('hostname '):
hostname = line.split(None, 1)[1].strip()
hostname_to_label[hostname] = current_label
except Exception as e:
print_warning(f"Error reading config file {config_file}: {e}")
continue
return host_labels, hostname_to_label
def generate_ssh_key(key_path: Path) -> bool:
"""Generate a new ED25519 SSH key pair."""
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", str(key_path)
])
print_info(f"Generated new SSH key at {key_path}")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
return False
def copy_ssh_key(key_path: Path, user: str, hostname: str, port: str) -> bool:
"""Copy SSH public key to remote server."""
try:
cmd = ["ssh-copy-id", "-i", str(key_path)]
if port != "22":
cmd.extend(["-p", port])
cmd.append(f"{user}@{hostname}")
subprocess.check_call(cmd)
print_info("Key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
return False
def write_config_file(config_path: Path, config_lines: List[str]) -> bool:
"""Write SSH config lines to file."""
try:
config_path.write_text("\n".join(config_lines) + "\n")
print_info(f"Created/updated config at: {config_path}")
return True
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
return False
def add_host(conf_dir: str) -> bool:
"""
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
Offers to generate a new SSH key pair (ed25519) quietly (-q),
and then prompt to copy that key to the remote server via ssh-copy-id.
Returns True if host was added successfully, False otherwise.
"""
print_info("Adding a new SSH host...")
host_label = safe_input("Enter Host label (e.g. myserver): ")
if host_label is None:
return # User canceled (Ctrl+C)
host_label = host_label.strip()
if not host_label:
print_error("Host label cannot be empty.")
return
# Get existing hosts to check for duplicates
existing_labels, hostname_to_label = get_existing_hosts(conf_dir)
hostname = safe_input("Enter HostName (IP or domain): ")
if hostname is None:
return
hostname = hostname.strip()
if not hostname:
print_error("HostName cannot be empty.")
return
# Get host label
while True:
host_label = safe_input("Enter Host label (e.g. myserver): ")
if not host_label or host_label is None:
print_error("Host label cannot be empty.")
return False
host_label = host_label.strip()
if host_label in existing_labels:
print_error(f"A host with label '{host_label}' already exists. Please choose a different label.")
continue
break
# Get hostname
while True:
hostname = safe_input("Enter HostName (IP or domain): ")
if not hostname or hostname is None:
print_error("HostName cannot be empty.")
return False
hostname = hostname.strip()
if hostname in hostname_to_label:
existing_label = hostname_to_label[hostname]
print_error(f"This hostname is already configured for host '{existing_label}'. Please use a different hostname.")
continue
break
user = safe_input("Enter username (default: 'root'): ")
# Get optional parameters
user = safe_input("Enter username (default: 'root'): ") or "root"
if user is None:
return
user = user.strip() or "root"
port = safe_input("Enter SSH port (default: 22): ")
return False
port = safe_input("Enter SSH port (default: 22): ") or "22"
if port is None:
return
port = port.strip() or "22"
return False
# Create subdirectory: ~/.ssh/conf/<label>
host_dir = os.path.join(conf_dir, host_label)
if os.path.exists(host_dir):
# Create host configuration
host_config = HostConfig(
label=host_label,
hostname=hostname,
user=user.strip(),
port=port.strip()
)
# Setup directory structure
host_dir = Path(conf_dir) / host_label
if host_dir.exists():
print_warning(f"Directory {host_dir} already exists; continuing anyway.")
else:
os.makedirs(host_dir, mode=0o700, exist_ok=True)
print_info(f"Created directory: {host_dir}")
host_dir.mkdir(mode=0o700, exist_ok=True)
print_info(f"Created directory: {host_dir}")
config_path = os.path.join(host_dir, "config")
if os.path.exists(config_path):
config_path = host_dir / "config"
if config_path.exists():
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
# Handle SSH key generation
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
if gen_key_choice is None:
return
gen_key_choice = gen_key_choice.lower().strip()
return False
identity_file = ""
if gen_key_choice == 'y':
key_path = os.path.join(host_dir, "id_ed25519")
if os.path.exists(key_path):
if gen_key_choice.lower().strip() == 'y':
key_path = host_dir / "id_ed25519"
if key_path.exists():
print_warning(f"{key_path} already exists. Skipping generation.")
identity_file = key_path
host_config.identity_file = str(key_path)
else:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", key_path]
try:
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {key_path}")
identity_file = key_path
if generate_ssh_key(key_path):
host_config.identity_file = str(key_path)
# Prompt to copy the key
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
if copy_key is None:
return
return False
if copy_key.lower().strip() == 'y':
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
if port != "22":
ssh_copy_cmd += ["-p", port]
ssh_copy_cmd.append(f"{user}@{hostname}")
try:
subprocess.check_call(ssh_copy_cmd)
print_info("Key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
copy_ssh_key(key_path, host_config.user, host_config.hostname, host_config.port)
else:
# Handle existing key
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
if existing_key is None:
return
existing_key = existing_key.strip()
if existing_key:
identity_file = os.path.expanduser(existing_key)
return False
if existing_key.strip():
host_config.identity_file = os.path.expanduser(existing_key.strip())
config_lines = [
f"Host {host_label}",
f" HostName {hostname}",
f" User {user}",
f" Port {port}"
]
if identity_file:
config_lines.append(f" IdentityFile {identity_file}")
try:
with open(config_path, "w") as f:
for line in config_lines:
f.write(line + "\n")
print_info(f"Created/updated config at: {config_path}")
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
# Write the config file
return write_config_file(config_path, host_config.to_config_lines())

121
cli.py
View file

@ -2,6 +2,8 @@
import os
import asyncio
from typing import Callable, Dict, Optional
from functools import wraps
from .utils import print_info, print_error, print_warning, Colors, safe_input
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
from .add_host import add_host
@ -10,62 +12,101 @@ from .list_hosts import list_hosts
from .regen_key import regenerate_key
from .remove_host import remove_host
def ensure_ssh_setup():
def ensure_ssh_setup() -> None:
"""
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
and writes a default ~/.ssh/config if it doesn't exist.
"""
if not os.path.isdir(SSH_DIR):
os.makedirs(SSH_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {SSH_DIR}")
if not os.path.isdir(CONF_DIR):
os.makedirs(CONF_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {CONF_DIR}")
if not os.path.isdir(SOCKET_DIR):
os.makedirs(SOCKET_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {SOCKET_DIR}")
directories = [
(SSH_DIR, "SSH directory"),
(CONF_DIR, "configuration directory"),
(SOCKET_DIR, "socket directory")
]
for directory, description in directories:
if not os.path.isdir(directory):
os.makedirs(directory, mode=0o700, exist_ok=True)
print_info(f"Created {description}: {directory}")
if not os.path.isfile(MAIN_CONFIG):
with open(MAIN_CONFIG, "w") as f:
f.write(DEFAULT_CONFIG_CONTENT)
print_info(f"Created default SSH config at: {MAIN_CONFIG}")
def main():
ensure_ssh_setup()
def async_handler(func: Callable) -> Callable:
"""Decorator to handle async functions in the command dispatch"""
@wraps(func)
def wrapper(*args, **kwargs):
return asyncio.run(func(*args, **kwargs))
return wrapper
class SSHManager:
def __init__(self):
self.commands: Dict[str, tuple[Callable, str]] = {
"1": (self.list_hosts, "List Hosts"),
"2": (self.add_host, "Add a Host"),
"3": (self.edit_host, "Edit a Host"),
"4": (self.regenerate_key, "Regenerate Key"),
"5": (self.remove_host, "Remove Host"),
"6": (self.exit_app, "Exit")
}
@async_handler
async def list_hosts(self) -> None:
await list_hosts(CONF_DIR)
def add_host(self) -> None:
add_host(CONF_DIR)
@async_handler
async def edit_host(self) -> None:
await edit_host(CONF_DIR)
@async_handler
async def regenerate_key(self) -> None:
await regenerate_key(CONF_DIR)
@async_handler
async def remove_host(self) -> None:
await remove_host(CONF_DIR)
def exit_app(self) -> None:
print_info("Exiting...")
raise SystemExit(0)
def display_menu(self) -> None:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
for key, (_, description) in self.commands.items():
print(f"{key}. {description}")
def handle_command(self, choice: str) -> bool:
if choice not in self.commands:
print_error("Invalid choice. Please select 1 through 6.")
return True
try:
self.commands[choice][0]()
return choice != "6"
except SystemExit:
return False
except Exception as e:
print_error(f"Error executing command: {str(e)}")
return True
def main() -> int:
ensure_ssh_setup()
manager = SSHManager()
# Display the server list on first load
asyncio.run(list_hosts(CONF_DIR))
manager.list_hosts()
while True:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
print("1. List Hosts")
print("2. Add a Host")
print("3. Edit a Host")
print("4. Regenerate Key")
print("5. Remove Host")
print("6. Exit")
manager.display_menu()
choice = safe_input("Select an option (1-6): ")
if choice is None:
continue # User pressed Ctrl+C => safe_input returns None => re-show menu
choice = choice.strip()
if choice == '1':
asyncio.run(list_hosts(CONF_DIR))
elif choice == '2':
add_host(CONF_DIR)
elif choice == '3':
asyncio.run(edit_host(CONF_DIR))
elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5':
asyncio.run(remove_host(CONF_DIR))
elif choice == '6':
print_info("Exiting...")
continue
if not manager.handle_command(choice.strip()):
break
else:
print_error("Invalid choice. Please select 1 through 6.")
return 0

View file

@ -1,15 +1,24 @@
# ssh_manager/config.py
import os
from pathlib import Path
from typing import Final
import subprocess
# Paths
SSH_DIR = os.path.expanduser("~/.ssh")
CONF_DIR = os.path.join(SSH_DIR, "conf")
SOCKET_DIR = os.path.join(SSH_DIR, "s")
MAIN_CONFIG = os.path.join(SSH_DIR, "config")
# Use Path for more efficient path handling
HOME: Final[Path] = Path.home()
SSH_DIR: Final[Path] = HOME / ".ssh"
CONF_DIR: Final[Path] = SSH_DIR / "conf"
SOCKET_DIR: Final[Path] = SSH_DIR / "s"
MAIN_CONFIG: Final[Path] = SSH_DIR / "config"
# Validate paths on import
for path in (SSH_DIR, CONF_DIR, SOCKET_DIR):
if not path.exists():
path.mkdir(mode=0o700, parents=True, exist_ok=True)
# Default SSH config content if ~/.ssh/config is missing
DEFAULT_CONFIG_CONTENT = """###
DEFAULT_CONFIG_CONTENT: Final[str] = """###
#Local ssh
###
@ -30,3 +39,28 @@ Host *
ControlPersist 72000
ControlPath ~/.ssh/s/%C
"""
# Export string versions for backward compatibility
SSH_DIR_STR: Final[str] = str(SSH_DIR)
CONF_DIR_STR: Final[str] = str(CONF_DIR)
SOCKET_DIR_STR: Final[str] = str(SOCKET_DIR)
MAIN_CONFIG_STR: Final[str] = str(MAIN_CONFIG)
def validate_key_path(key_path: Path) -> bool:
if not key_path.exists():
return False
if not key_path.is_dir():
return False
if not os.access(key_path, os.W_OK):
return False
return True
def update_config_with_key(key_path: Path) -> bool:
if not validate_key_path(key_path):
return False # Early return if path validation fails
try:
subprocess.check_call([...])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False

View file

@ -5,25 +5,28 @@ import glob
import socket
import asyncio
import ipaddress
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Optional, OrderedDict as OrderedDictType
from functools import lru_cache
from tabulate import tabulate
from collections import OrderedDict
from .utils import print_warning, print_error, Colors
from .config import CONF_DIR
async def check_ssh_port(ip_address, port):
# Cache DNS lookups
@lru_cache(maxsize=128, typed=True)
def resolve_hostname(hostname: str) -> Optional[str]:
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ip_address, port), timeout=1
)
writer.close()
await writer.wait_closed()
return True
except:
return False
return socket.gethostbyname(hostname)
except socket.error:
return None
def load_config_file(file_path):
blocks = []
host_data = None
def load_config_file(file_path: str) -> List[OrderedDictType]:
blocks: List[OrderedDictType] = []
host_data: Optional[OrderedDictType] = None
try:
with open(file_path, 'r') as f:
lines = f.readlines()
@ -44,7 +47,7 @@ def load_config_file(file_path):
blocks.append(host_data)
host_data = OrderedDict({'Host': label})
break
elif host_data:
elif host_data is not None:
if ' ' in stripped_line:
key, value = stripped_line.split(None, 1)
host_data[key] = value.strip()
@ -53,97 +56,109 @@ def load_config_file(file_path):
blocks.append(host_data)
return blocks
async def gather_host_info(all_host_blocks):
async def gather_host_info(all_host_blocks: List[OrderedDictType]) -> List[Tuple]:
"""
Given a list of host blocks, gather full info:
- resolved IP
- port open check
- 'Conf Directory' coloring if IdentityFile != 'N/A'
Returns a list of 7-tuples:
(host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
(host_label, user, port, hostname, colored_ip, conf_path_display, raw_ip)
"""
results = []
async def process_block(h):
host_label = h.get('Host', 'N/A')
hostname = h.get('HostName', 'N/A')
user = h.get('User', 'N/A')
port = int(h.get('Port', '22'))
identity_file = h.get('IdentityFile', 'N/A')
async def process_block(h: OrderedDictType) -> Tuple:
host_label: str = h.get('Host', 'N/A')
hostname: str = h.get('HostName', 'N/A')
user: str = h.get('User', 'N/A')
port: str = h.get('Port', '22') # Keep as string since we're not testing it
identity_file: str = h.get('IdentityFile', 'N/A')
# Resolve IP
try:
raw_ip = socket.gethostbyname(hostname)
# Resolve IP using cached function
raw_ip = resolve_hostname(hostname) or "N/A"
# Check if the IP is reachable
def is_ip_reachable(ip: str) -> bool:
try:
subprocess.run(["ping", "-c", "1", "-W", "1", ip], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return True
except subprocess.CalledProcessError:
return False
# Determine IP color
if raw_ip != "N/A" and is_ip_reachable(raw_ip):
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
except socket.error:
raw_ip = "N/A"
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
# Check port
if raw_ip != "N/A":
port_open = await check_ssh_port(raw_ip, port)
colored_port = (
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
)
else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
colored_ip = raw_ip # No color if not reachable
# Determine hostname color
try:
ipaddress.ip_address(hostname)
colored_hostname = hostname # No color if hostname is an IP
except ValueError:
if raw_ip != "N/A":
colored_hostname = f"{Colors.GREEN}{hostname}{Colors.RESET}"
else:
colored_hostname = f"{Colors.RED}{hostname}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
# If there's an IdentityFile, color the conf path
if identity_file != 'N/A':
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
else:
conf_path_display = conf_path
conf_path_display = (
f"{Colors.GREEN}{conf_path}{Colors.RESET}"
if identity_file != 'N/A'
else conf_path
)
return (
host_label,
user,
colored_port,
hostname,
port, # Port is now uncolored
colored_hostname,
colored_ip,
conf_path_display,
raw_ip # uncolored IP for sorting
)
# Process blocks concurrently
tasks = [process_block(b) for b in all_host_blocks]
# Process blocks concurrently with semaphore to limit concurrent connections
sem = asyncio.Semaphore(10) # Limit concurrent connections
async def process_with_semaphore(block):
async with sem:
return await process_block(block)
tasks = [process_with_semaphore(b) for b in all_host_blocks]
results = await asyncio.gather(*tasks)
return results
def parse_ip(ip_str):
@lru_cache(maxsize=1)
def parse_ip(ip_str: str) -> Optional[ipaddress.IPv4Address]:
"""
Convert a string IP to an ipaddress object for sorting.
Returns None if invalid or 'N/A'.
"""
import ipaddress
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
def sort_by_ip(results):
def sort_by_ip(results: List[Tuple]) -> List[Tuple]:
"""
Sort the 7-tuples by IP ascending, with 'N/A' last.
"""
sortable = []
for row in results:
def sort_key(row):
raw_ip = row[-1]
ip_obj = parse_ip(raw_ip)
sortable.append(((ip_obj is None, ip_obj), row))
return (ip_obj is None, ip_obj or ipaddress.IPv4Address('0.0.0.0'))
sortable.sort(key=lambda x: x[0])
return [row for (_, row) in sortable]
return sorted(results, key=sort_key)
async def build_host_list_table(conf_dir):
async def build_host_list_table(conf_dir: str) -> Tuple[List[str], List[List]]:
"""
Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip.
"""
pattern = os.path.join(conf_dir, "*", "config")
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_host_blocks = []
all_host_blocks: List[OrderedDictType] = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_host_blocks.extend(blocks)
@ -156,22 +171,22 @@ async def build_host_list_table(conf_dir):
sorted_rows = sort_by_ip(results)
# Build final table
final_data = []
for idx, row in enumerate(sorted_rows, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
final_data.append([idx] + list(row[:-1]))
final_data = [
[idx] + list(row[:-1])
for idx, row in enumerate(sorted_rows, start=1)
]
return headers, final_data
async def list_hosts(conf_dir):
async def list_hosts(conf_dir: str) -> None:
"""Display a formatted table of all SSH hosts."""
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List")
from tabulate import tabulate
print(tabulate([], headers=headers, tablefmt="grid"))
return
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,6 +1,8 @@
import os
import subprocess
import asyncio
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
from .utils import (
print_info,
print_warning,
@ -14,105 +16,106 @@ from .list_hosts import (
sort_by_ip
)
async def regenerate_key(conf_dir):
"""
Regenerate the SSH key for a chosen host by:
1) Displaying the unified table of hosts (No. | Host | User | Port | HostName | IP Address | Conf Directory).
2) Letting you pick a row number or host label.
3) Reading/deleting any existing local keys,
4) Generating a new key,
5) Optionally copying it to the remote,
6) Removing the old pub key from the remote authorized_keys if present.
"""
def validate_key_path(key_path: str) -> bool:
"""Validate that the key path and its directory are valid."""
try:
key_path = Path(key_path)
key_dir = key_path.parent
if not key_dir.exists():
key_dir.mkdir(mode=0o700, parents=True)
elif not os.access(str(key_dir), os.W_OK):
print_error(f"No write permission for directory: {key_dir}")
return False
return True
except Exception as e:
print_error(f"Error validating key path: {e}")
return False
print_info("Regenerate Key - Step 1: Show current hosts...\n")
def generate_new_key(key_path: str, user: str, hostname: str, port: int) -> bool:
"""Generate a new ED25519 SSH key and optionally copy it to the remote server."""
if not validate_key_path(key_path):
return False
# 1) Reuse the same columns as the main 'list_hosts'
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot regenerate a key.")
return
print_info("Generating new ed25519 SSH key...")
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", key_path
])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False
print_info(f"Generated new SSH key at {key_path}")
# Ask to copy the key
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if not copy_choice or not copy_choice.lower().startswith('y'):
return True
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
try:
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
return False
# 2) We need to correlate row => actual config block.
# We'll replicate the logic that build_host_list_table uses.
all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
def update_config_with_key(config_path: Path, new_key_path: str) -> bool:
"""Update the SSH config file with the new identity file."""
try:
with open(config_path, 'r') as f:
config_lines = [
line.rstrip('\n')
for line in f
if not line.strip().lower().startswith('identityfile')
]
config_lines.append(f" IdentityFile {new_key_path}")
with open(config_path, 'w') as f:
f.write('\n'.join(config_lines) + '\n')
print_info(f"Updated config file with new IdentityFile: {new_key_path}")
return True
except Exception as e:
print_error(f"Failed to update config file: {e}")
return False
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return
choice = choice.strip()
def find_target_host(sorted_rows: List[Tuple], choice: str) -> Optional[Tuple]:
"""Find the target host based on user choice."""
if not choice:
print_error("No choice given.")
return
target_tuple = None
return None
if choice.isdigit():
row_idx = int(choice)
if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.")
return
target_tuple = sorted_rows[row_idx - 1]
else:
# user typed a label
for t in sorted_rows:
if t[0] == choice: # t[0] is host_label
target_tuple = t
break
if not target_tuple:
print_warning(f"No matching host label '{choice}' found.")
return
return None
return sorted_rows[row_idx - 1]
# User typed a label
for row in sorted_rows:
if row[0] == choice: # row[0] is host_label
return row
print_warning(f"No matching host label '{choice}' found.")
return None
# 3) Retrieve the full config block so we have Port, IdentityFile, etc.
host_label = target_tuple[0]
hostname = target_tuple[3] # (host_label, user, colored_port, hostname, ...)
user = target_tuple[1]
# find that block
found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# If missing or "N/A", we can't regenerate
if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.")
return
# 4) Remove old local key files
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
old_pub_data = f.read().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]:
def remove_key_files(key_paths: List[str]) -> None:
"""Remove SSH key files."""
for path in key_paths:
if os.path.isfile(path):
try:
os.remove(path)
@ -120,67 +123,196 @@ async def regenerate_key(conf_dir):
except Exception as e:
print_warning(f"Could not remove {path}: {e}")
# 5) Generate new key
print_info("Generating new ed25519 SSH key...")
new_key_path = expanded_key
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
try:
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {new_key_path}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return
async def regenerate_key(conf_dir: str) -> bool:
"""
Regenerate the SSH key for a chosen host by:
1) Displaying the unified table of hosts
2) Letting you pick a row number or host label
3) Reading/deleting any existing local keys
4) Generating a new key
5) Optionally copying it to the remote
6) Removing the old pub key from the remote authorized_keys if present
Returns True if key was successfully regenerated, False otherwise.
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n")
# 6) Copy the new key to remote if user wants
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if copy_choice and copy_choice.lower().startswith('y'):
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
# Get host list
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot regenerate a key.")
return False
# Display host table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
# Get host blocks and sort them
all_blocks = []
pattern = os.path.join(conf_dir, "*", "config")
for cfile in glob.glob(pattern):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# Get user choice
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return False
target_tuple = find_target_host(sorted_rows, choice.strip())
if not target_tuple:
return False
# Get host information
host_label, user, _, hostname, *_ = target_tuple
# Find config block
found_block = next(
(b for b in all_blocks if b.get("Host") == host_label),
None
)
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return False
port = int(found_block.get("Port", "22"))
identity_file = found_block.get("IdentityFile", "")
# Handle missing identity file
if not identity_file or identity_file == "N/A":
print_warning("No existing SSH key found in the configuration.")
gen_choice = safe_input("Would you like to generate a new key? (y/n): ")
if not gen_choice or not gen_choice.lower().startswith('y'):
return False
# Set up new key path and generate key
host_dir = Path(conf_dir) / host_label
host_dir.mkdir(mode=0o700, exist_ok=True)
new_key_path = str(host_dir / "id_ed25519")
if not generate_new_key(new_key_path, user, hostname, port):
return False
# Update config with new key
config_path = host_dir / "config"
return update_config_with_key(config_path, new_key_path)
# Handle existing key regeneration
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
# Try to read old public key
if os.path.isfile(pub_path):
try:
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
old_pub_data = Path(pub_path).read_text().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
# 7) Remove old key from authorized_keys if we had old_pub_data
# Remove old key files
print_info("Removing old key files locally...")
remove_key_files([expanded_key, pub_path])
# Generate new key
if not generate_new_key(expanded_key, user, hostname, port):
return False
# Remove old key from remote if we had it
if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port)
else:
print_warning("No old pub key data found locally, skipping remote removal.")
return True
async def remove_old_key_remote(old_pub_data, user, hostname, port):
async def remove_old_key_remote(old_pub_data: str, user: str, hostname: str, port: int) -> bool:
"""
Checks and removes the exact matching line from remote authorized_keys
via grep -Fxq / grep -vFx.
Remove the old public key from remote authorized_keys file.
Returns True if successful, False otherwise.
"""
check_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
# Escape the public key data for shell safety
escaped_key = old_pub_data.replace('"', '\\"')
# First check if authorized_keys exists
check_file_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
"test -f ~/.ssh/authorized_keys && echo 'exists'"
]
found_key = False
try:
subprocess.check_call(check_cmd)
found_key = True
except subprocess.CalledProcessError:
pass
result = subprocess.run(check_file_cmd, capture_output=True, text=True)
if result.returncode != 0 or 'exists' not in result.stdout:
print_warning("No authorized_keys file found on remote host.")
return False
except subprocess.CalledProcessError as e:
print_error(f"Error checking authorized_keys: {e}")
return False
if not found_key:
print_warning("No old key found on remote authorized_keys.")
return
# Create a temporary file for the sed script
temp_script = """#!/bin/bash
set -e
KEYS_FILE="$HOME/.ssh/authorized_keys"
TEMP_FILE="$HOME/.ssh/authorized_keys.tmp"
grep -Fxv "$1" "$KEYS_FILE" > "$TEMP_FILE"
mv "$TEMP_FILE" "$KEYS_FILE"
chmod 600 "$KEYS_FILE"
"""
# Create a temporary script on the remote host
setup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'cat > ~/.ssh/remove_key.sh << \'EOF\'\n{temp_script}\nEOF\n'
f'chmod +x ~/.ssh/remove_key.sh'
]
try:
subprocess.run(setup_cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print_error(f"Failed to create temporary script: {e}")
return False
# Execute the script with the key
remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
f'~/.ssh/remove_key.sh "{escaped_key}"'
]
try:
subprocess.check_call(remove_cmd)
subprocess.run(remove_cmd, check=True, capture_output=True)
print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permissions or other error).")
# Clean up the temporary script
cleanup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
"rm -f ~/.ssh/remove_key.sh"
]
subprocess.run(cleanup_cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print_error(f"Failed to remove old key: {e}")
# Try to clean up even if removal failed
try:
subprocess.run(cleanup_cmd, check=True, capture_output=True)
except:
pass
return False
except Exception as e:
print_error(f"Unexpected error removing old key: {e}")
return False