Compare commits

...

2 commits

Author SHA1 Message Date
ffb1f7e204 refactor 2025-03-08 00:43:48 -06:00
414266eefc refactored code 2025-03-08 00:43:33 -06:00
8 changed files with 836 additions and 601 deletions

View file

@ -2,109 +2,211 @@
import os
import subprocess
import glob
from pathlib import Path
from typing import Optional, List, Dict, Set, Tuple
from dataclasses import dataclass
from .utils import print_error, print_warning, print_info, safe_input
from .config import CONF_DIR
def add_host(conf_dir):
@dataclass
class HostConfig:
label: str
hostname: str
user: str = "root"
port: str = "22"
identity_file: str = ""
def to_config_lines(self) -> List[str]:
"""Convert host configuration to SSH config file lines."""
lines = [
f"Host {self.label}",
f" HostName {self.hostname}",
f" User {self.user}",
f" Port {self.port}"
]
if self.identity_file:
lines.append(f" IdentityFile {self.identity_file}")
return lines
def get_existing_hosts(conf_dir: str) -> Tuple[Set[str], Dict[str, str]]:
"""
Get existing host labels and hostnames from config files.
Returns (host_labels, hostname_to_label) where:
- host_labels is a set of existing host labels
- hostname_to_label maps hostnames to their labels
"""
host_labels = set()
hostname_to_label = {}
pattern = os.path.join(conf_dir, "*", "config")
for config_file in glob.glob(pattern):
try:
with open(config_file, 'r') as f:
lines = f.readlines()
current_label = None
for line in lines:
line = line.strip()
if not line or line.startswith('#'):
continue
if line.lower().startswith('host '):
labels = line.split()[1:]
for label in labels:
if '*' not in label:
current_label = label
host_labels.add(label)
break
elif current_label and line.lower().startswith('hostname '):
hostname = line.split(None, 1)[1].strip()
hostname_to_label[hostname] = current_label
except Exception as e:
print_warning(f"Error reading config file {config_file}: {e}")
continue
return host_labels, hostname_to_label
def generate_ssh_key(key_path: Path) -> bool:
"""Generate a new ED25519 SSH key pair."""
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", str(key_path)
])
print_info(f"Generated new SSH key at {key_path}")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
return False
def copy_ssh_key(key_path: Path, user: str, hostname: str, port: str) -> bool:
"""Copy SSH public key to remote server."""
try:
cmd = ["ssh-copy-id", "-i", str(key_path)]
if port != "22":
cmd.extend(["-p", port])
cmd.append(f"{user}@{hostname}")
subprocess.check_call(cmd)
print_info("Key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
return False
def write_config_file(config_path: Path, config_lines: List[str]) -> bool:
"""Write SSH config lines to file."""
try:
config_path.write_text("\n".join(config_lines) + "\n")
print_info(f"Created/updated config at: {config_path}")
return True
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
return False
def add_host(conf_dir: str) -> bool:
"""
Interactive prompt to create a new SSH host in ~/.ssh/conf/<label>/config.
Offers to generate a new SSH key pair (ed25519) quietly (-q),
and then prompt to copy that key to the remote server via ssh-copy-id.
Returns True if host was added successfully, False otherwise.
"""
print_info("Adding a new SSH host...")
host_label = safe_input("Enter Host label (e.g. myserver): ")
if host_label is None:
return # User canceled (Ctrl+C)
host_label = host_label.strip()
if not host_label:
print_error("Host label cannot be empty.")
return
# Get existing hosts to check for duplicates
existing_labels, hostname_to_label = get_existing_hosts(conf_dir)
hostname = safe_input("Enter HostName (IP or domain): ")
if hostname is None:
return
hostname = hostname.strip()
if not hostname:
print_error("HostName cannot be empty.")
return
# Get host label
while True:
host_label = safe_input("Enter Host label (e.g. myserver): ")
if not host_label or host_label is None:
print_error("Host label cannot be empty.")
return False
user = safe_input("Enter username (default: 'root'): ")
host_label = host_label.strip()
if host_label in existing_labels:
print_error(f"A host with label '{host_label}' already exists. Please choose a different label.")
continue
break
# Get hostname
while True:
hostname = safe_input("Enter HostName (IP or domain): ")
if not hostname or hostname is None:
print_error("HostName cannot be empty.")
return False
hostname = hostname.strip()
if hostname in hostname_to_label:
existing_label = hostname_to_label[hostname]
print_error(f"This hostname is already configured for host '{existing_label}'. Please use a different hostname.")
continue
break
# Get optional parameters
user = safe_input("Enter username (default: 'root'): ") or "root"
if user is None:
return
user = user.strip() or "root"
return False
port = safe_input("Enter SSH port (default: 22): ")
port = safe_input("Enter SSH port (default: 22): ") or "22"
if port is None:
return
port = port.strip() or "22"
return False
# Create subdirectory: ~/.ssh/conf/<label>
host_dir = os.path.join(conf_dir, host_label)
if os.path.exists(host_dir):
# Create host configuration
host_config = HostConfig(
label=host_label,
hostname=hostname,
user=user.strip(),
port=port.strip()
)
# Setup directory structure
host_dir = Path(conf_dir) / host_label
if host_dir.exists():
print_warning(f"Directory {host_dir} already exists; continuing anyway.")
else:
os.makedirs(host_dir, mode=0o700, exist_ok=True)
print_info(f"Created directory: {host_dir}")
host_dir.mkdir(mode=0o700, exist_ok=True)
print_info(f"Created directory: {host_dir}")
config_path = os.path.join(host_dir, "config")
if os.path.exists(config_path):
config_path = host_dir / "config"
if config_path.exists():
print_warning(f"Config file already exists: {config_path}; it will be overwritten/updated.")
# Handle SSH key generation
gen_key_choice = safe_input("Generate a new ed25519 SSH key for this host? (y/n): ")
if gen_key_choice is None:
return
gen_key_choice = gen_key_choice.lower().strip()
return False
identity_file = ""
if gen_key_choice == 'y':
key_path = os.path.join(host_dir, "id_ed25519")
if os.path.exists(key_path):
if gen_key_choice.lower().strip() == 'y':
key_path = host_dir / "id_ed25519"
if key_path.exists():
print_warning(f"{key_path} already exists. Skipping generation.")
identity_file = key_path
host_config.identity_file = str(key_path)
else:
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", key_path]
try:
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {key_path}")
identity_file = key_path
if generate_ssh_key(key_path):
host_config.identity_file = str(key_path)
# Prompt to copy the key
copy_key = safe_input("Would you like to copy this key to the server now? (y/n): ")
if copy_key is None:
return
return False
if copy_key.lower().strip() == 'y':
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
if port != "22":
ssh_copy_cmd += ["-p", port]
ssh_copy_cmd.append(f"{user}@{hostname}")
try:
subprocess.check_call(ssh_copy_cmd)
print_info("Key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying key to server: {e}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating SSH key: {e}")
copy_ssh_key(key_path, host_config.user, host_config.hostname, host_config.port)
else:
# Handle existing key
existing_key = safe_input("Enter existing IdentityFile path (or leave empty to skip): ")
if existing_key is None:
return
existing_key = existing_key.strip()
if existing_key:
identity_file = os.path.expanduser(existing_key)
return False
config_lines = [
f"Host {host_label}",
f" HostName {hostname}",
f" User {user}",
f" Port {port}"
]
if identity_file:
config_lines.append(f" IdentityFile {identity_file}")
if existing_key.strip():
host_config.identity_file = os.path.expanduser(existing_key.strip())
try:
with open(config_path, "w") as f:
for line in config_lines:
f.write(line + "\n")
print_info(f"Created/updated config at: {config_path}")
except Exception as e:
print_error(f"Failed to write config to {config_path}: {e}")
# Write the config file
return write_config_file(config_path, host_config.to_config_lines())

115
cli.py
View file

@ -2,6 +2,8 @@
import os
import asyncio
from typing import Callable, Dict, Optional
from functools import wraps
from .utils import print_info, print_error, print_warning, Colors, safe_input
from .config import SSH_DIR, CONF_DIR, SOCKET_DIR, MAIN_CONFIG, DEFAULT_CONFIG_CONTENT
from .add_host import add_host
@ -10,62 +12,101 @@ from .list_hosts import list_hosts
from .regen_key import regenerate_key
from .remove_host import remove_host
def ensure_ssh_setup():
def ensure_ssh_setup() -> None:
"""
Creates ~/.ssh, ~/.ssh/conf, and ~/.ssh/s if missing,
and writes a default ~/.ssh/config if it doesn't exist.
"""
if not os.path.isdir(SSH_DIR):
os.makedirs(SSH_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {SSH_DIR}")
directories = [
(SSH_DIR, "SSH directory"),
(CONF_DIR, "configuration directory"),
(SOCKET_DIR, "socket directory")
]
if not os.path.isdir(CONF_DIR):
os.makedirs(CONF_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {CONF_DIR}")
if not os.path.isdir(SOCKET_DIR):
os.makedirs(SOCKET_DIR, mode=0o700, exist_ok=True)
print_info(f"Created directory: {SOCKET_DIR}")
for directory, description in directories:
if not os.path.isdir(directory):
os.makedirs(directory, mode=0o700, exist_ok=True)
print_info(f"Created {description}: {directory}")
if not os.path.isfile(MAIN_CONFIG):
with open(MAIN_CONFIG, "w") as f:
f.write(DEFAULT_CONFIG_CONTENT)
print_info(f"Created default SSH config at: {MAIN_CONFIG}")
def main():
def async_handler(func: Callable) -> Callable:
"""Decorator to handle async functions in the command dispatch"""
@wraps(func)
def wrapper(*args, **kwargs):
return asyncio.run(func(*args, **kwargs))
return wrapper
class SSHManager:
def __init__(self):
self.commands: Dict[str, tuple[Callable, str]] = {
"1": (self.list_hosts, "List Hosts"),
"2": (self.add_host, "Add a Host"),
"3": (self.edit_host, "Edit a Host"),
"4": (self.regenerate_key, "Regenerate Key"),
"5": (self.remove_host, "Remove Host"),
"6": (self.exit_app, "Exit")
}
@async_handler
async def list_hosts(self) -> None:
await list_hosts(CONF_DIR)
def add_host(self) -> None:
add_host(CONF_DIR)
@async_handler
async def edit_host(self) -> None:
await edit_host(CONF_DIR)
@async_handler
async def regenerate_key(self) -> None:
await regenerate_key(CONF_DIR)
@async_handler
async def remove_host(self) -> None:
await remove_host(CONF_DIR)
def exit_app(self) -> None:
print_info("Exiting...")
raise SystemExit(0)
def display_menu(self) -> None:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
for key, (_, description) in self.commands.items():
print(f"{key}. {description}")
def handle_command(self, choice: str) -> bool:
if choice not in self.commands:
print_error("Invalid choice. Please select 1 through 6.")
return True
try:
self.commands[choice][0]()
return choice != "6"
except SystemExit:
return False
except Exception as e:
print_error(f"Error executing command: {str(e)}")
return True
def main() -> int:
ensure_ssh_setup()
manager = SSHManager()
# Display the server list on first load
asyncio.run(list_hosts(CONF_DIR))
manager.list_hosts()
while True:
print("\n" + f"{Colors.CYAN}{Colors.BOLD}SSH Config Manager Menu{Colors.RESET}")
print("1. List Hosts")
print("2. Add a Host")
print("3. Edit a Host")
print("4. Regenerate Key")
print("5. Remove Host")
print("6. Exit")
manager.display_menu()
choice = safe_input("Select an option (1-6): ")
if choice is None:
continue # User pressed Ctrl+C => safe_input returns None => re-show menu
continue
choice = choice.strip()
if choice == '1':
asyncio.run(list_hosts(CONF_DIR))
elif choice == '2':
add_host(CONF_DIR)
elif choice == '3':
edit_host(CONF_DIR)
elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5':
asyncio.run(remove_host(CONF_DIR))
elif choice == '6':
print_info("Exiting...")
if not manager.handle_command(choice.strip()):
break
else:
print_error("Invalid choice. Please select 1 through 6.")
return 0

View file

@ -1,15 +1,24 @@
# ssh_manager/config.py
import os
from pathlib import Path
from typing import Final
import subprocess
# Paths
SSH_DIR = os.path.expanduser("~/.ssh")
CONF_DIR = os.path.join(SSH_DIR, "conf")
SOCKET_DIR = os.path.join(SSH_DIR, "s")
MAIN_CONFIG = os.path.join(SSH_DIR, "config")
# Use Path for more efficient path handling
HOME: Final[Path] = Path.home()
SSH_DIR: Final[Path] = HOME / ".ssh"
CONF_DIR: Final[Path] = SSH_DIR / "conf"
SOCKET_DIR: Final[Path] = SSH_DIR / "s"
MAIN_CONFIG: Final[Path] = SSH_DIR / "config"
# Validate paths on import
for path in (SSH_DIR, CONF_DIR, SOCKET_DIR):
if not path.exists():
path.mkdir(mode=0o700, parents=True, exist_ok=True)
# Default SSH config content if ~/.ssh/config is missing
DEFAULT_CONFIG_CONTENT = """###
DEFAULT_CONFIG_CONTENT: Final[str] = """###
#Local ssh
###
@ -30,3 +39,28 @@ Host *
ControlPersist 72000
ControlPath ~/.ssh/s/%C
"""
# Export string versions for backward compatibility
SSH_DIR_STR: Final[str] = str(SSH_DIR)
CONF_DIR_STR: Final[str] = str(CONF_DIR)
SOCKET_DIR_STR: Final[str] = str(SOCKET_DIR)
MAIN_CONFIG_STR: Final[str] = str(MAIN_CONFIG)
def validate_key_path(key_path: Path) -> bool:
if not key_path.exists():
return False
if not key_path.is_dir():
return False
if not os.access(key_path, os.W_OK):
return False
return True
def update_config_with_key(key_path: Path) -> bool:
if not validate_key_path(key_path):
return False # Early return if path validation fails
try:
subprocess.check_call([...])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False

View file

@ -2,98 +2,31 @@ import os
import asyncio
from collections import OrderedDict
from .utils import print_error, print_warning, print_info, safe_input
from .list_hosts import list_hosts, load_config_file, check_ssh_port
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
async def get_all_host_blocks(conf_dir):
"""
Similar to list_hosts, but returns the list of host blocks + a table of results.
We'll build a table ourselves so we can map row numbers to actual host labels.
"""
import glob
import socket
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
# If no blocks found, return empty
if not all_blocks:
return []
# We want to do a partial version of check_host to get row data
# so we can display the table right here and keep track of each blocks host label.
# But let's do it similarly to list_hosts:
table_rows = []
for idx, b in enumerate(all_blocks, start=1):
host_label = b.get("Host", "N/A")
hostname = b.get("HostName", "N/A")
user = b.get("User", "N/A")
port = int(b.get("Port", "22"))
identity_file = b.get("IdentityFile", "N/A")
# Identity check
if identity_file != "N/A":
expanded_identity = os.path.expanduser(identity_file)
identity_exists = os.path.isfile(expanded_identity)
else:
identity_exists = False
# IP resolution
try:
ip_address = socket.gethostbyname(hostname)
except socket.error:
ip_address = None
# Port check
if ip_address:
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
else:
port_open = False
# Colors for display (optional, or we can keep it simple):
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
row = [
idx,
host_label,
user,
port_display,
hostname,
ip_display,
identity_disp
]
table_rows.append(row)
# Print the table
from tabulate import tabulate
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
return all_blocks
def edit_host(conf_dir):
async def edit_host(conf_dir):
"""
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
The user may type either the row number OR the actual host label.
1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
2) Prompt row number or host label
3) Rewrite config with updated fields
"""
# 1) Gather + display the current host list
print_info("Here is the current list of hosts:\n")
all_blocks = asyncio.run(get_all_host_blocks(conf_dir))
if not all_blocks:
print_warning("No hosts found to edit.")
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts to edit.")
return
# 2) Prompt for which host to edit (by label or row number)
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
choice = safe_input("Enter the row number or the Host label to edit: ")
if choice is None:
return # user canceled (Ctrl+C)
@ -102,43 +35,47 @@ def edit_host(conf_dir):
print_error("Host label or row number cannot be empty.")
return
# Check if user typed a digit -> row number
target_block = None
# We replicate the approach to find the matching block
all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
target_tuple = None
if choice.isdigit():
row_idx = int(choice)
# Validate index
if row_idx < 1 or row_idx > len(all_blocks):
print_warning(f"Invalid row number {row_idx}.")
idx = int(choice)
if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.")
return
target_block = all_blocks[row_idx - 1] # zero-based
target_tuple = sorted_rows[idx - 1]
else:
# The user typed a host label
# We must search all_blocks for a matching Host
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
for t in sorted_rows:
if t[0] == choice: # t[0] => host_label
target_tuple = t
break
if not target_block:
print_warning(f"No matching host label '{choice}' found.")
if not target_tuple:
print_warning(f"No matching Host '{choice}' found in the table.")
return
# Now we have a target_block with existing data
host_label = target_block.get("Host", "")
if not host_label:
print_warning("This host block has no label. Cannot edit.")
host_label = target_tuple[0]
# find the config block
found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return
# Derive the config path
host_dir = os.path.join(conf_dir, host_label)
config_path = os.path.join(host_dir, "config")
if not os.path.isfile(config_path):
print_warning(f"No config file found at {config_path}; cannot edit this host.")
return
old_hostname = target_block.get("HostName", "")
old_user = target_block.get("User", "")
old_port = target_block.get("Port", "22")
old_identity = target_block.get("IdentityFile", "")
old_hostname = found_block.get("HostName", "")
old_user = found_block.get("User", "")
old_port = found_block.get("Port", "22")
old_identity = found_block.get("IdentityFile", "")
print_info("Leave a field blank to keep its current value.")
@ -167,6 +104,8 @@ def edit_host(conf_dir):
final_port = new_port if new_port else old_port
final_ident = new_ident if new_ident else old_identity
# Overwrite the file
config_path = os.path.join(conf_dir, host_label, "config")
new_config_lines = [
f"Host {host_label}",
f" HostName {final_hostname}",

View file

@ -5,33 +5,27 @@ import glob
import socket
import asyncio
import ipaddress
import subprocess
from pathlib import Path
from typing import List, Dict, Tuple, Optional, OrderedDict as OrderedDictType
from functools import lru_cache
from tabulate import tabulate
from collections import OrderedDict
from .utils import print_warning, print_error, Colors
from .config import CONF_DIR
async def check_ssh_port(ip_address, port):
"""
Attempt to open an SSH connection to see if the port is open.
Returns True if successful, False otherwise.
"""
# Cache DNS lookups
@lru_cache(maxsize=128, typed=True)
def resolve_hostname(hostname: str) -> Optional[str]:
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ip_address, port), timeout=1
)
writer.close()
await writer.wait_closed()
return True
except:
return False
return socket.gethostbyname(hostname)
except socket.error:
return None
def load_config_file(file_path):
"""
Parse a single SSH config file and return a list of host blocks.
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
"""
blocks = []
host_data = None
def load_config_file(file_path: str) -> List[OrderedDictType]:
blocks: List[OrderedDictType] = []
host_data: Optional[OrderedDictType] = None
try:
with open(file_path, 'r') as f:
@ -42,11 +36,9 @@ def load_config_file(file_path):
for line in lines:
stripped_line = line.strip()
# Skip empty lines and comments
if not stripped_line or stripped_line.startswith('#'):
continue
# Start of a new Host block
if stripped_line.lower().startswith('host '):
host_labels = stripped_line.split()[1:]
for label in host_labels:
@ -55,7 +47,7 @@ def load_config_file(file_path):
blocks.append(host_data)
host_data = OrderedDict({'Host': label})
break
elif host_data:
elif host_data is not None:
if ' ' in stripped_line:
key, value = stripped_line.split(None, 1)
host_data[key] = value.strip()
@ -64,111 +56,137 @@ def load_config_file(file_path):
blocks.append(host_data)
return blocks
async def check_host(host):
async def gather_host_info(all_host_blocks: List[OrderedDictType]) -> List[Tuple]:
"""
Given a host block, resolve IP, check SSH port, etc.
Returns a tuple of:
1) Host label
2) User
3) Port (colored if open)
4) HostName
5) IP Address (colored if resolved)
6) Conf Directory (green if has IdentityFile, else no color)
7) raw_ip (uncolored string for sorting)
Given a list of host blocks, gather full info:
- resolved IP
- 'Conf Directory' coloring if IdentityFile != 'N/A'
Returns a list of 7-tuples:
(host_label, user, port, hostname, colored_ip, conf_path_display, raw_ip)
"""
host_label = host.get('Host', 'N/A')
hostname = host.get('HostName', 'N/A')
user = host.get('User', 'N/A')
port = int(host.get('Port', '22'))
identity_file = host.get('IdentityFile', 'N/A')
results = []
# Resolve IP
try:
raw_ip = socket.gethostbyname(hostname) # uncolored
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
except socket.error:
raw_ip = "N/A"
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
async def process_block(h: OrderedDictType) -> Tuple:
host_label: str = h.get('Host', 'N/A')
hostname: str = h.get('HostName', 'N/A')
user: str = h.get('User', 'N/A')
port: str = h.get('Port', '22') # Keep as string since we're not testing it
identity_file: str = h.get('IdentityFile', 'N/A')
# Check port
if raw_ip != "N/A":
port_open = await check_ssh_port(raw_ip, port)
colored_port = (
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
# Resolve IP using cached function
raw_ip = resolve_hostname(hostname) or "N/A"
# Check if the IP is reachable
def is_ip_reachable(ip: str) -> bool:
try:
subprocess.run(["ping", "-c", "1", "-W", "1", ip], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return True
except subprocess.CalledProcessError:
return False
# Determine IP color
if raw_ip != "N/A" and is_ip_reachable(raw_ip):
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
else:
colored_ip = raw_ip # No color if not reachable
# Determine hostname color
try:
ipaddress.ip_address(hostname)
colored_hostname = hostname # No color if hostname is an IP
except ValueError:
if raw_ip != "N/A":
colored_hostname = f"{Colors.GREEN}{hostname}{Colors.RESET}"
else:
colored_hostname = f"{Colors.RED}{hostname}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
conf_path_display = (
f"{Colors.GREEN}{conf_path}{Colors.RESET}"
if identity_file != 'N/A'
else conf_path
)
else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
# If there's an IdentityFile, color the conf path green
if identity_file != 'N/A':
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
else:
conf_path_display = conf_path
return (
host_label,
user,
port, # Port is now uncolored
colored_hostname,
colored_ip,
conf_path_display,
raw_ip # uncolored IP for sorting
)
# Return the data plus the uncolored IP for sorting
return (
host_label,
user,
colored_port,
hostname,
colored_ip,
conf_path_display,
raw_ip # for sorting
)
# Process blocks concurrently with semaphore to limit concurrent connections
sem = asyncio.Semaphore(10) # Limit concurrent connections
async def process_with_semaphore(block):
async with sem:
return await process_block(block)
async def list_hosts(conf_dir):
tasks = [process_with_semaphore(b) for b in all_host_blocks]
results = await asyncio.gather(*tasks)
return results
@lru_cache(maxsize=1)
def parse_ip(ip_str: str) -> Optional[ipaddress.IPv4Address]:
"""
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order.
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory
Convert a string IP to an ipaddress object for sorting.
Returns None if invalid or 'N/A'.
"""
pattern = os.path.join(conf_dir, "*", "config")
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
def sort_by_ip(results: List[Tuple]) -> List[Tuple]:
"""
Sort the 7-tuples by IP ascending, with 'N/A' last.
"""
def sort_key(row):
raw_ip = row[-1]
ip_obj = parse_ip(raw_ip)
return (ip_obj is None, ip_obj or ipaddress.IPv4Address('0.0.0.0'))
return sorted(results, key=sort_key)
async def build_host_list_table(conf_dir: str) -> Tuple[List[str], List[List]]:
"""
Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip.
"""
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_host_blocks = []
all_host_blocks: List[OrderedDictType] = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_host_blocks.extend(blocks)
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
if not all_host_blocks:
return headers, []
results = await gather_host_info(all_host_blocks)
sorted_rows = sort_by_ip(results)
# Build final table
final_data = [
[idx] + list(row[:-1])
for idx, row in enumerate(sorted_rows, start=1)
]
return headers, final_data
async def list_hosts(conf_dir: str) -> None:
"""Display a formatted table of all SSH hosts."""
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List")
print(tabulate([], headers=headers, tablefmt="grid"))
return
# Gather full data for each host
tasks = [check_host(h) for h in all_host_blocks]
results = await asyncio.gather(*tasks)
# We want to sort by IP ascending. results[i] is a tuple:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
def parse_ip(ip_str):
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
# Convert the results into a list of (ip_obj, original_tuple)
# so we can sort, then rebuild the final data.
sortable = []
for row in results:
raw_ip = row[-1] # last element
ip_obj = parse_ip(raw_ip)
# We'll sort None last by using a sort key that puts (True) after (False)
# e.g. (ip_obj is None, ip_obj)
sortable.append(((ip_obj is None, ip_obj), row))
# Sort by (is_none, ip_obj)
sortable.sort(key=lambda x: x[0])
# Rebuild the final display table, ignoring the raw_ip at the end
final_data = []
for idx, (_, row) in enumerate(sortable, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
final_data.append([idx] + list(row[:-1])) # omit raw_ip
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,146 +1,121 @@
import os
import glob
import subprocess
import asyncio
from collections import OrderedDict
from pathlib import Path
from typing import Optional, List, Dict, Tuple, Any
from .utils import (
print_info,
print_warning,
print_error,
safe_input,
Colors
safe_input
)
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir):
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
def validate_key_path(key_path: str) -> bool:
"""Validate that the key path and its directory are valid."""
try:
key_path = Path(key_path)
key_dir = key_path.parent
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
if not key_dir.exists():
key_dir.mkdir(mode=0o700, parents=True)
elif not os.access(str(key_dir), os.W_OK):
print_error(f"No write permission for directory: {key_dir}")
return False
if not all_blocks:
return []
return all_blocks
return True
except Exception as e:
print_error(f"Error validating key path: {e}")
return False
async def regenerate_key(conf_dir):
"""
Menu-driven function to regenerate a key for a selected host:
1) Show the host table with row numbers
2) Let user pick row # or label
3) Read old pub data
4) Remove old local key files
5) Generate new key
6) Copy new key
7) Remove old key from remote authorized_keys (improved logic to detect existence)
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n")
def generate_new_key(key_path: str, user: str, hostname: str, port: int) -> bool:
"""Generate a new ED25519 SSH key and optionally copy it to the remote server."""
if not validate_key_path(key_path):
return False
# 1) Gather host blocks
all_blocks = await get_all_host_blocks(conf_dir)
if not all_blocks:
print_warning("No hosts found. Cannot regenerate a key.")
return
print_info("Generating new ed25519 SSH key...")
try:
subprocess.check_call([
"ssh-keygen",
"-q",
"-t", "ed25519",
"-N", "",
"-f", key_path
])
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return False
# Display them in a table (similar to list_hosts):
import socket
from tabulate import tabulate
print_info(f"Generated new SSH key at {key_path}")
table_rows = []
for idx, block in enumerate(all_blocks, start=1):
host_label = block.get("Host", "N/A")
hostname = block.get("HostName", "N/A")
user = block.get("User", "N/A")
port = int(block.get("Port", "22"))
identity_file = block.get("IdentityFile", "N/A")
# Ask to copy the key
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if not copy_choice or not copy_choice.lower().startswith('y'):
return True
# Check IP
try:
ip_address = socket.gethostbyname(hostname)
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
try:
ssh_copy_cmd = ["ssh-copy-id", "-i", key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
return True
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
return False
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
def update_config_with_key(config_path: Path, new_key_path: str) -> bool:
"""Update the SSH config file with the new identity file."""
try:
with open(config_path, 'r') as f:
config_lines = [
line.rstrip('\n')
for line in f
if not line.strip().lower().startswith('identityfile')
]
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
config_lines.append(f" IdentityFile {new_key_path}")
# 2) Prompt for row # or label
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return
choice = choice.strip()
with open(config_path, 'w') as f:
f.write('\n'.join(config_lines) + '\n')
print_info(f"Updated config file with new IdentityFile: {new_key_path}")
return True
except Exception as e:
print_error(f"Failed to update config file: {e}")
return False
def find_target_host(sorted_rows: List[Tuple], choice: str) -> Optional[Tuple]:
"""Find the target host based on user choice."""
if not choice:
print_error("No choice given.")
return
return None
target_block = None
if choice.isdigit():
row_idx = int(choice)
if row_idx < 1 or row_idx > len(all_blocks):
if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.")
return
target_block = all_blocks[row_idx - 1]
else:
# user typed a label
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
break
if not target_block:
print_warning(f"No matching host label '{choice}' found.")
return
return None
return sorted_rows[row_idx - 1]
# 3) Gather info from block
host_label = target_block.get("Host", "")
hostname = target_block.get("HostName", "")
user = target_block.get("User", "root")
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
# User typed a label
for row in sorted_rows:
if row[0] == choice: # row[0] is host_label
return row
if not host_label or not hostname:
print_error("Missing Host or HostName; cannot regenerate.")
return
if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.")
return
print_warning(f"No matching host label '{choice}' found.")
return None
# Derive local paths
expanded_key = os.path.expanduser(identity_file)
key_dir = os.path.dirname(expanded_key)
pub_path = expanded_key + ".pub"
# 3a) Read old pub key data
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
old_pub_data = f.read().strip()
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
else:
print_warning("No old pub key found locally.")
# 4) Remove old local key files
print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]:
def remove_key_files(key_paths: List[str]) -> None:
"""Remove SSH key files."""
for path in key_paths:
if os.path.isfile(path):
try:
os.remove(path)
@ -148,74 +123,196 @@ async def regenerate_key(conf_dir):
except Exception as e:
print_warning(f"Could not remove {path}: {e}")
# 5) Generate new key
print_info("Generating new ed25519 SSH key...")
new_key_path = expanded_key # Reuse the same path from config
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
try:
subprocess.check_call(cmd)
print_info(f"Generated new SSH key at {new_key_path}")
except subprocess.CalledProcessError as e:
print_error(f"Error generating new SSH key: {e}")
return
async def regenerate_key(conf_dir: str) -> bool:
"""
Regenerate the SSH key for a chosen host by:
1) Displaying the unified table of hosts
2) Letting you pick a row number or host label
3) Reading/deleting any existing local keys
4) Generating a new key
5) Optionally copying it to the remote
6) Removing the old pub key from the remote authorized_keys if present
# 6) Copy new key to remote
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if copy_choice and copy_choice.lower().startswith('y'):
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
if port != 22:
ssh_copy_cmd += ["-p", str(port)]
ssh_copy_cmd.append(f"{user}@{hostname}")
Returns True if key was successfully regenerated, False otherwise.
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n")
# Get host list
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot regenerate a key.")
return False
# Display host table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
# Get host blocks and sort them
all_blocks = []
pattern = os.path.join(conf_dir, "*", "config")
for cfile in glob.glob(pattern):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# Get user choice
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return False
target_tuple = find_target_host(sorted_rows, choice.strip())
if not target_tuple:
return False
# Get host information
host_label, user, _, hostname, *_ = target_tuple
# Find config block
found_block = next(
(b for b in all_blocks if b.get("Host") == host_label),
None
)
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return False
port = int(found_block.get("Port", "22"))
identity_file = found_block.get("IdentityFile", "")
# Handle missing identity file
if not identity_file or identity_file == "N/A":
print_warning("No existing SSH key found in the configuration.")
gen_choice = safe_input("Would you like to generate a new key? (y/n): ")
if not gen_choice or not gen_choice.lower().startswith('y'):
return False
# Set up new key path and generate key
host_dir = Path(conf_dir) / host_label
host_dir.mkdir(mode=0o700, exist_ok=True)
new_key_path = str(host_dir / "id_ed25519")
if not generate_new_key(new_key_path, user, hostname, port):
return False
# Update config with new key
config_path = host_dir / "config"
return update_config_with_key(config_path, new_key_path)
# Handle existing key regeneration
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
# Try to read old public key
if os.path.isfile(pub_path):
try:
subprocess.check_call(ssh_copy_cmd)
print_info("New key successfully copied to remote server.")
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
old_pub_data = Path(pub_path).read_text().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
# 7) Remove old key from authorized_keys if old_pub_data is non-empty
# Remove old key files
print_info("Removing old key files locally...")
remove_key_files([expanded_key, pub_path])
# Generate new key
if not generate_new_key(expanded_key, user, hostname, port):
return False
# Remove old key from remote if we had it
if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port)
else:
print_warning("No old pub key data found locally, skipping remote removal.")
async def remove_old_key_remote(old_pub_data, user, hostname, port):
return True
async def remove_old_key_remote(old_pub_data: str, user: str, hostname: str, port: int) -> bool:
"""
1) Check if old_pub_data is present on remote server (grep -q).
2) If found, remove it with grep -v ...
3) Otherwise, print "No old key found on remote."
Remove the old public key from remote authorized_keys file.
Returns True if successful, False otherwise.
"""
# 1) Check if old_pub_data exists in authorized_keys
check_cmd = [
# Escape the public key data for shell safety
escaped_key = old_pub_data.replace('"', '\\"')
# First check if authorized_keys exists
check_file_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys"
"test -f ~/.ssh/authorized_keys && echo 'exists'"
]
found_key = False
try:
subprocess.check_call(check_cmd)
found_key = True
except subprocess.CalledProcessError:
# grep returns exit code 1 if not found
pass
result = subprocess.run(check_file_cmd, capture_output=True, text=True)
if result.returncode != 0 or 'exists' not in result.stdout:
print_warning("No authorized_keys file found on remote host.")
return False
except subprocess.CalledProcessError as e:
print_error(f"Error checking authorized_keys: {e}")
return False
if not found_key:
print_warning("No old key found on remote authorized_keys.")
return
# Create a temporary file for the sed script
temp_script = """#!/bin/bash
set -e
KEYS_FILE="$HOME/.ssh/authorized_keys"
TEMP_FILE="$HOME/.ssh/authorized_keys.tmp"
grep -Fxv "$1" "$KEYS_FILE" > "$TEMP_FILE"
mv "$TEMP_FILE" "$KEYS_FILE"
chmod 600 "$KEYS_FILE"
"""
# 2) Actually remove it
# Create a temporary script on the remote host
setup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f'cat > ~/.ssh/remove_key.sh << \'EOF\'\n{temp_script}\nEOF\n'
f'chmod +x ~/.ssh/remove_key.sh'
]
try:
subprocess.run(setup_cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
print_error(f"Failed to create temporary script: {e}")
return False
# Execute the script with the key
remove_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
f'~/.ssh/remove_key.sh "{escaped_key}"'
]
try:
subprocess.check_call(remove_cmd)
subprocess.run(remove_cmd, check=True, capture_output=True)
print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).")
# Clean up the temporary script
cleanup_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
"rm -f ~/.ssh/remove_key.sh"
]
subprocess.run(cleanup_cmd, check=True, capture_output=True)
return True
except subprocess.CalledProcessError as e:
print_error(f"Failed to remove old key: {e}")
# Try to clean up even if removal failed
try:
subprocess.run(cleanup_cmd, check=True, capture_output=True)
except:
pass
return False
except Exception as e:
print_error(f"Unexpected error removing old key: {e}")
return False

View file

@ -1,82 +1,68 @@
# ssh_manager/remove_host.py
import os
import glob
import subprocess
import asyncio
from collections import OrderedDict
from .utils import (
print_info,
print_warning,
print_error,
safe_input
)
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir):
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
return all_blocks
from .list_hosts import build_host_list_table, load_config_file
"""
Remove host now reuses build_host_list_table to display the same columns:
No. | Host | User | Port | HostName | IP Address | Conf Directory
"""
async def remove_host(conf_dir):
"""
Remove an SSH host by:
1) Listing all hosts
1) Listing all hosts (same columns as main list)
2) Letting user pick row number or label
3) Attempting to remove the old pub key from remote authorized_keys
4) Deleting the subdirectory in ~/.ssh/conf/<host_label>
3) Removing old pub key from remote
4) Deleting ~/.ssh/conf/<host_label>
"""
print_info("Remove Host - Step 1: Show current hosts...\n")
all_blocks = await get_all_host_blocks(conf_dir)
if not all_blocks:
# Reuse the unified table from list_hosts
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot remove anything.")
return
# We'll display them in a table
import socket
# Print the same table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = []
for idx, block in enumerate(all_blocks, start=1):
host_label = block.get("Host", "N/A")
hostname = block.get("HostName", "N/A")
user = block.get("User", "N/A")
port = int(block.get("Port", "22"))
identity_file = block.get("IdentityFile", "N/A")
# We have final_data rows => need to map row => block
# So let's gather the raw blocks again to correlate.
# We'll do a separate approach or we can parse final_data.
# Easiest: Re-run load_config_file if needed or:
blocks = []
# The last gather call for build_host_list_table used load_config_file already
# but it doesn't return the correlation. We'll replicate the logic quickly.
try:
ip_address = socket.gethostbyname(hostname)
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(os.listdir(conf_dir))
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
# Actually, let's do a small approach:
from .list_hosts import gather_host_info, sort_by_ip
all_blocks = []
import glob
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
blocks.extend(load_config_file(cfile))
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
# gather the same big list
results = await gather_host_info(blocks)
sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
# Prompt which host to remove
choice = safe_input("Enter the row number or Host label to remove: ")
if choice is None:
return
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
print_error("Invalid empty choice.")
return
target_block = None
target_tuple = None
# If digit => index in sorted_rows
if choice.isdigit():
idx = int(choice)
if idx < 1 or idx > len(all_blocks):
if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.")
return
target_block = all_blocks[idx - 1]
target_tuple = sorted_rows[idx - 1]
else:
# They typed a label
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
for t in sorted_rows:
if t[0] == choice: # t[0] = host_label
target_tuple = t
break
if not target_block:
if not target_tuple:
print_warning(f"No matching host label '{choice}' found.")
return
host_label = target_block.get("Host", "")
hostname = target_block.get("HostName", "")
user = target_block.get("User", "root")
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
# target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
host_label = target_tuple[0]
hostname = target_tuple[3]
user = target_tuple[1]
# parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
# We do a second approach: let's see if we can find the block in "blocks".
# But let's parse the config block to get the real port, identity, etc.
# We find the matching block in "blocks" by host_label
found_block = None
for b in blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for {host_label}.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# now do removal logic
if not host_label:
print_warning("Target block has no Host label. Cannot remove.")
return
# Check IdentityFile for old pub key
if identity_file and identity_file != "N/A":
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
pub_path = expanded_key + ".pub"
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
# This is the EXACT line that might appear in authorized_keys
old_pub_data = f.read().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port)
# Finally, remove the local subdirectory
# remove local folder
host_dir = os.path.join(conf_dir, host_label)
import shutil
if os.path.isdir(host_dir):
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
if confirm and confirm.lower().startswith('y'):
try:
import shutil
shutil.rmtree(host_dir)
print_info(f"Removed local folder: {host_dir}")
except Exception as e:
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
else:
print_warning("Local folder not removed.")
else:
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.")
print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
async def remove_old_key_remote(old_pub_data, user, hostname, port):
"""
Remove the old key from remote authorized_keys if found, matching EXACT lines.
We'll do:
grep -Fxq for existence
grep -vFx to remove it
"""
# 1) Check if old_pub_data is present in authorized_keys (exact line)
# same logic as before
check_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys"
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
]
found_key = False
try:
subprocess.check_call(check_cmd)
found_key = True
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
print_warning("No old key found on remote authorized_keys.")
return
# 2) Actually remove it by ignoring EXACT matches to that line
remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
]
try:
subprocess.check_call(remove_cmd)

View file

@ -1,25 +1,34 @@
# ssh_manager/utils.py
from enum import Enum
import sys
from functools import lru_cache
class Colors:
GREEN = "\033[0;32m"
RED = "\033[0;31m"
class Colors(Enum):
GREEN = "\033[0;32m"
RED = "\033[0;31m"
YELLOW = "\033[1;33m"
CYAN = "\033[0;36m"
BOLD = "\033[1m"
RESET = "\033[0m"
CYAN = "\033[0;36m"
BOLD = "\033[1m"
RESET = "\033[0m"
def print_error(message):
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}")
def __str__(self):
return self.value
def print_warning(message):
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}")
@lru_cache(maxsize=32)
def _format_message(prefix: str, color: Colors, message: str) -> str:
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
def print_info(message):
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}")
def print_error(message: str) -> None:
print(_format_message("", Colors.RED, message))
def safe_input(prompt=""):
def print_warning(message: str) -> None:
print(_format_message("", Colors.YELLOW, message))
def print_info(message: str) -> None:
print(_format_message("", Colors.GREEN, message))
def safe_input(prompt: str = "") -> str:
"""
A wrapper around input() that exits the entire script on Ctrl+C.
"""