refactored code

This commit is contained in:
Arctic 2025-03-08 00:43:33 -06:00
parent 73429771d4
commit 414266eefc
6 changed files with 315 additions and 404 deletions

2
cli.py
View file

@ -57,7 +57,7 @@ def main():
elif choice == '2': elif choice == '2':
add_host(CONF_DIR) add_host(CONF_DIR)
elif choice == '3': elif choice == '3':
edit_host(CONF_DIR) asyncio.run(edit_host(CONF_DIR))
elif choice == '4': elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR)) asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5': elif choice == '5':

View file

@ -2,98 +2,31 @@ import os
import asyncio import asyncio
from collections import OrderedDict from collections import OrderedDict
from .utils import print_error, print_warning, print_info, safe_input from .utils import print_error, print_warning, print_info, safe_input
from .list_hosts import list_hosts, load_config_file, check_ssh_port from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
async def get_all_host_blocks(conf_dir): async def edit_host(conf_dir):
"""
Similar to list_hosts, but returns the list of host blocks + a table of results.
We'll build a table ourselves so we can map row numbers to actual host labels.
"""
import glob
import socket
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
# If no blocks found, return empty
if not all_blocks:
return []
# We want to do a partial version of check_host to get row data
# so we can display the table right here and keep track of each blocks host label.
# But let's do it similarly to list_hosts:
table_rows = []
for idx, b in enumerate(all_blocks, start=1):
host_label = b.get("Host", "N/A")
hostname = b.get("HostName", "N/A")
user = b.get("User", "N/A")
port = int(b.get("Port", "22"))
identity_file = b.get("IdentityFile", "N/A")
# Identity check
if identity_file != "N/A":
expanded_identity = os.path.expanduser(identity_file)
identity_exists = os.path.isfile(expanded_identity)
else:
identity_exists = False
# IP resolution
try:
ip_address = socket.gethostbyname(hostname)
except socket.error:
ip_address = None
# Port check
if ip_address:
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
else:
port_open = False
# Colors for display (optional, or we can keep it simple):
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
row = [
idx,
host_label,
user,
port_display,
hostname,
ip_display,
identity_disp
]
table_rows.append(row)
# Print the table
from tabulate import tabulate
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
return all_blocks
def edit_host(conf_dir):
""" """
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config. Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
The user may type either the row number OR the actual host label. 1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
2) Prompt row number or host label
3) Rewrite config with updated fields
""" """
# 1) Gather + display the current host list
print_info("Here is the current list of hosts:\n") print_info("Here is the current list of hosts:\n")
all_blocks = asyncio.run(get_all_host_blocks(conf_dir)) headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
if not all_blocks: print_warning("No hosts to edit.")
print_warning("No hosts found to edit.")
return return
# 2) Prompt for which host to edit (by label or row number) from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
choice = safe_input("Enter the row number or the Host label to edit: ") choice = safe_input("Enter the row number or the Host label to edit: ")
if choice is None: if choice is None:
return # user canceled (Ctrl+C) return # user canceled (Ctrl+C)
@ -102,43 +35,47 @@ def edit_host(conf_dir):
print_error("Host label or row number cannot be empty.") print_error("Host label or row number cannot be empty.")
return return
# Check if user typed a digit -> row number # We replicate the approach to find the matching block
target_block = None all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
target_tuple = None
if choice.isdigit(): if choice.isdigit():
row_idx = int(choice) idx = int(choice)
# Validate index if idx < 1 or idx > len(sorted_rows):
if row_idx < 1 or row_idx > len(all_blocks): print_warning(f"Row number {idx} is invalid.")
print_warning(f"Invalid row number {row_idx}.")
return return
target_block = all_blocks[row_idx - 1] # zero-based target_tuple = sorted_rows[idx - 1]
else: else:
# The user typed a host label for t in sorted_rows:
# We must search all_blocks for a matching Host if t[0] == choice: # t[0] => host_label
for b in all_blocks: target_tuple = t
if b.get("Host") == choice:
target_block = b
break break
if not target_block: if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching Host '{choice}' found in the table.")
return return
# Now we have a target_block with existing data host_label = target_tuple[0]
host_label = target_block.get("Host", "") # find the config block
if not host_label: found_block = None
print_warning("This host block has no label. Cannot edit.") for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return return
# Derive the config path old_hostname = found_block.get("HostName", "")
host_dir = os.path.join(conf_dir, host_label) old_user = found_block.get("User", "")
config_path = os.path.join(host_dir, "config") old_port = found_block.get("Port", "22")
if not os.path.isfile(config_path): old_identity = found_block.get("IdentityFile", "")
print_warning(f"No config file found at {config_path}; cannot edit this host.")
return
old_hostname = target_block.get("HostName", "")
old_user = target_block.get("User", "")
old_port = target_block.get("Port", "22")
old_identity = target_block.get("IdentityFile", "")
print_info("Leave a field blank to keep its current value.") print_info("Leave a field blank to keep its current value.")
@ -167,6 +104,8 @@ def edit_host(conf_dir):
final_port = new_port if new_port else old_port final_port = new_port if new_port else old_port
final_ident = new_ident if new_ident else old_identity final_ident = new_ident if new_ident else old_identity
# Overwrite the file
config_path = os.path.join(conf_dir, host_label, "config")
new_config_lines = [ new_config_lines = [
f"Host {host_label}", f"Host {host_label}",
f" HostName {final_hostname}", f" HostName {final_hostname}",

View file

@ -11,10 +11,6 @@ from collections import OrderedDict
from .utils import print_warning, print_error, Colors from .utils import print_warning, print_error, Colors
async def check_ssh_port(ip_address, port): async def check_ssh_port(ip_address, port):
"""
Attempt to open an SSH connection to see if the port is open.
Returns True if successful, False otherwise.
"""
try: try:
reader, writer = await asyncio.wait_for( reader, writer = await asyncio.wait_for(
asyncio.open_connection(ip_address, port), timeout=1 asyncio.open_connection(ip_address, port), timeout=1
@ -26,13 +22,8 @@ async def check_ssh_port(ip_address, port):
return False return False
def load_config_file(file_path): def load_config_file(file_path):
"""
Parse a single SSH config file and return a list of host blocks.
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
"""
blocks = [] blocks = []
host_data = None host_data = None
try: try:
with open(file_path, 'r') as f: with open(file_path, 'r') as f:
lines = f.readlines() lines = f.readlines()
@ -42,11 +33,9 @@ def load_config_file(file_path):
for line in lines: for line in lines:
stripped_line = line.strip() stripped_line = line.strip()
# Skip empty lines and comments
if not stripped_line or stripped_line.startswith('#'): if not stripped_line or stripped_line.startswith('#'):
continue continue
# Start of a new Host block
if stripped_line.lower().startswith('host '): if stripped_line.lower().startswith('host '):
host_labels = stripped_line.split()[1:] host_labels = stripped_line.split()[1:]
for label in host_labels: for label in host_labels:
@ -64,64 +53,92 @@ def load_config_file(file_path):
blocks.append(host_data) blocks.append(host_data)
return blocks return blocks
async def check_host(host): async def gather_host_info(all_host_blocks):
""" """
Given a host block, resolve IP, check SSH port, etc. Given a list of host blocks, gather full info:
Returns a tuple of: - resolved IP
1) Host label - port open check
2) User - 'Conf Directory' coloring if IdentityFile != 'N/A'
3) Port (colored if open) Returns a list of 7-tuples:
4) HostName (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
5) IP Address (colored if resolved)
6) Conf Directory (green if has IdentityFile, else no color)
7) raw_ip (uncolored string for sorting)
""" """
host_label = host.get('Host', 'N/A') results = []
hostname = host.get('HostName', 'N/A')
user = host.get('User', 'N/A')
port = int(host.get('Port', '22'))
identity_file = host.get('IdentityFile', 'N/A')
# Resolve IP async def process_block(h):
try: host_label = h.get('Host', 'N/A')
raw_ip = socket.gethostbyname(hostname) # uncolored hostname = h.get('HostName', 'N/A')
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}" user = h.get('User', 'N/A')
except socket.error: port = int(h.get('Port', '22'))
raw_ip = "N/A" identity_file = h.get('IdentityFile', 'N/A')
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
# Check port # Resolve IP
if raw_ip != "N/A": try:
port_open = await check_ssh_port(raw_ip, port) raw_ip = socket.gethostbyname(hostname)
colored_port = ( colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}" except socket.error:
raw_ip = "N/A"
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
# Check port
if raw_ip != "N/A":
port_open = await check_ssh_port(raw_ip, port)
colored_port = (
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
)
else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
# If there's an IdentityFile, color the conf path
if identity_file != 'N/A':
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
else:
conf_path_display = conf_path
return (
host_label,
user,
colored_port,
hostname,
colored_ip,
conf_path_display,
raw_ip # uncolored IP for sorting
) )
else:
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
# Conf Directory = ~/.ssh/conf/<host_label> # Process blocks concurrently
conf_path = f"~/.ssh/conf/{host_label}" tasks = [process_block(b) for b in all_host_blocks]
# If there's an IdentityFile, color the conf path green results = await asyncio.gather(*tasks)
if identity_file != 'N/A': return results
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
else:
conf_path_display = conf_path
# Return the data plus the uncolored IP for sorting def parse_ip(ip_str):
return (
host_label,
user,
colored_port,
hostname,
colored_ip,
conf_path_display,
raw_ip # for sorting
)
async def list_hosts(conf_dir):
""" """
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order. Convert a string IP to an ipaddress object for sorting.
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory Returns None if invalid or 'N/A'.
"""
import ipaddress
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
def sort_by_ip(results):
"""
Sort the 7-tuples by IP ascending, with 'N/A' last.
"""
sortable = []
for row in results:
raw_ip = row[-1]
ip_obj = parse_ip(raw_ip)
sortable.append(((ip_obj is None, ip_obj), row))
sortable.sort(key=lambda x: x[0])
return [row for (_, row) in sortable]
async def build_host_list_table(conf_dir):
"""
Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip.
""" """
pattern = os.path.join(conf_dir, "*", "config") pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern)) conf_files = sorted(glob.glob(pattern))
@ -133,42 +150,28 @@ async def list_hosts(conf_dir):
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"] headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
if not all_host_blocks: if not all_host_blocks:
return headers, []
results = await gather_host_info(all_host_blocks)
sorted_rows = sort_by_ip(results)
# Build final table
final_data = []
for idx, row in enumerate(sorted_rows, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
final_data.append([idx] + list(row[:-1]))
return headers, final_data
async def list_hosts(conf_dir):
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. The server list is empty.") print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List") print("\nSSH Conf Subdirectory Host List")
from tabulate import tabulate
print(tabulate([], headers=headers, tablefmt="grid")) print(tabulate([], headers=headers, tablefmt="grid"))
return return
# Gather full data for each host from tabulate import tabulate
tasks = [check_host(h) for h in all_host_blocks]
results = await asyncio.gather(*tasks)
# We want to sort by IP ascending. results[i] is a tuple:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
def parse_ip(ip_str):
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
# Convert the results into a list of (ip_obj, original_tuple)
# so we can sort, then rebuild the final data.
sortable = []
for row in results:
raw_ip = row[-1] # last element
ip_obj = parse_ip(raw_ip)
# We'll sort None last by using a sort key that puts (True) after (False)
# e.g. (ip_obj is None, ip_obj)
sortable.append(((ip_obj is None, ip_obj), row))
# Sort by (is_none, ip_obj)
sortable.sort(key=lambda x: x[0])
# Rebuild the final display table, ignoring the raw_ip at the end
final_data = []
for idx, (_, row) in enumerate(sortable, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
final_data.append([idx] + list(row[:-1])) # omit raw_ip
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)") print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid")) print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,88 +1,54 @@
import os import os
import glob
import subprocess import subprocess
import asyncio import asyncio
from collections import OrderedDict
from .utils import ( from .utils import (
print_info, print_info,
print_warning, print_warning,
print_error, print_error,
safe_input, safe_input
Colors )
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
) )
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir):
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
if not all_blocks:
return []
return all_blocks
async def regenerate_key(conf_dir): async def regenerate_key(conf_dir):
""" """
Menu-driven function to regenerate a key for a selected host: Regenerate the SSH key for a chosen host by:
1) Show the host table with row numbers 1) Displaying the unified table of hosts (No. | Host | User | Port | HostName | IP Address | Conf Directory).
2) Let user pick row # or label 2) Letting you pick a row number or host label.
3) Read old pub data 3) Reading/deleting any existing local keys,
4) Remove old local key files 4) Generating a new key,
5) Generate new key 5) Optionally copying it to the remote,
6) Copy new key 6) Removing the old pub key from the remote authorized_keys if present.
7) Remove old key from remote authorized_keys (improved logic to detect existence)
""" """
print_info("Regenerate Key - Step 1: Show current hosts...\n") print_info("Regenerate Key - Step 1: Show current hosts...\n")
# 1) Gather host blocks # 1) Reuse the same columns as the main 'list_hosts'
all_blocks = await get_all_host_blocks(conf_dir) headers, final_data = await build_host_list_table(conf_dir)
if not all_blocks: if not final_data:
print_warning("No hosts found. Cannot regenerate a key.") print_warning("No hosts found. Cannot regenerate a key.")
return return
# Display them in a table (similar to list_hosts):
import socket
from tabulate import tabulate from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = [] # 2) We need to correlate row => actual config block.
for idx, block in enumerate(all_blocks, start=1): # We'll replicate the logic that build_host_list_table uses.
host_label = block.get("Host", "N/A") all_blocks = []
hostname = block.get("HostName", "N/A") import glob
user = block.get("User", "N/A") for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
port = int(block.get("Port", "22")) all_blocks.extend(load_config_file(cfile))
identity_file = block.get("IdentityFile", "N/A")
# Check IP results = await gather_host_info(all_blocks)
try: sorted_rows = sort_by_ip(results)
ip_address = socket.gethostbyname(hostname) # sorted_rows is a list of 7-tuples:
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1) # (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
except:
ip_address = None
port_open = False
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
# 2) Prompt for row # or label
choice = safe_input("Enter the row number or the Host label to regenerate: ") choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None: if choice is None:
return return
@ -91,54 +57,60 @@ async def regenerate_key(conf_dir):
print_error("No choice given.") print_error("No choice given.")
return return
target_block = None target_tuple = None
if choice.isdigit(): if choice.isdigit():
row_idx = int(choice) row_idx = int(choice)
if row_idx < 1 or row_idx > len(all_blocks): if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.") print_warning(f"Invalid row number {row_idx}.")
return return
target_block = all_blocks[row_idx - 1] target_tuple = sorted_rows[row_idx - 1]
else: else:
# user typed a label # user typed a label
for b in all_blocks: for t in sorted_rows:
if b.get("Host") == choice: if t[0] == choice: # t[0] is host_label
target_block = b target_tuple = t
break break
if not target_block: if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching host label '{choice}' found.")
return return
# 3) Gather info from block # 3) Retrieve the full config block so we have Port, IdentityFile, etc.
host_label = target_block.get("Host", "") host_label = target_tuple[0]
hostname = target_block.get("HostName", "") hostname = target_tuple[3] # (host_label, user, colored_port, hostname, ...)
user = target_block.get("User", "root") user = target_tuple[1]
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
if not host_label or not hostname: # find that block
print_error("Missing Host or HostName; cannot regenerate.") found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# If missing or "N/A", we can't regenerate
if not identity_file or identity_file == "N/A": if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.") print_error("No IdentityFile found in config; cannot regenerate.")
return return
# Derive local paths # 4) Remove old local key files
expanded_key = os.path.expanduser(identity_file) expanded_key = os.path.expanduser(identity_file)
key_dir = os.path.dirname(expanded_key)
pub_path = expanded_key + ".pub" pub_path = expanded_key + ".pub"
old_pub_data = ""
# 3a) Read old pub key data
old_pub_data = ""
if os.path.isfile(pub_path): if os.path.isfile(pub_path):
try: try:
with open(pub_path, "r") as f: with open(pub_path, "r") as f:
old_pub_data = f.read().strip() old_pub_data = f.read().rstrip("\n")
except Exception as e: except Exception as e:
print_warning(f"Could not read old pub key: {e}") print_warning(f"Could not read old pub key: {e}")
else:
print_warning("No old pub key found locally.")
# 4) Remove old local key files
print_info("Removing old key files locally...") print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]: for path in [expanded_key, pub_path]:
if os.path.isfile(path): if os.path.isfile(path):
@ -150,7 +122,7 @@ async def regenerate_key(conf_dir):
# 5) Generate new key # 5) Generate new key
print_info("Generating new ed25519 SSH key...") print_info("Generating new ed25519 SSH key...")
new_key_path = expanded_key # Reuse the same path from config new_key_path = expanded_key
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path] cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
try: try:
subprocess.check_call(cmd) subprocess.check_call(cmd)
@ -159,7 +131,7 @@ async def regenerate_key(conf_dir):
print_error(f"Error generating new SSH key: {e}") print_error(f"Error generating new SSH key: {e}")
return return
# 6) Copy new key to remote # 6) Copy the new key to remote if user wants
copy_choice = safe_input("Copy new key to remote now? (y/n): ") copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if copy_choice and copy_choice.lower().startswith('y'): if copy_choice and copy_choice.lower().startswith('y'):
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path] ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
@ -172,7 +144,7 @@ async def regenerate_key(conf_dir):
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}") print_error(f"Error copying new key: {e}")
# 7) Remove old key from authorized_keys if old_pub_data is non-empty # 7) Remove old key from authorized_keys if we had old_pub_data
if old_pub_data: if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...") print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port) await remove_old_key_remote(old_pub_data, user, hostname, port)
@ -181,41 +153,34 @@ async def regenerate_key(conf_dir):
async def remove_old_key_remote(old_pub_data, user, hostname, port): async def remove_old_key_remote(old_pub_data, user, hostname, port):
""" """
1) Check if old_pub_data is present on remote server (grep -q). Checks and removes the exact matching line from remote authorized_keys
2) If found, remove it with grep -v ... via grep -Fxq / grep -vFx.
3) Otherwise, print "No old key found on remote."
""" """
# 1) Check if old_pub_data exists in authorized_keys
check_cmd = [ check_cmd = [
"ssh", "ssh", "-o", "StrictHostKeyChecking=no",
"-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys" f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
] ]
found_key = False found_key = False
try: try:
subprocess.check_call(check_cmd) subprocess.check_call(check_cmd)
found_key = True found_key = True
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
# grep returns exit code 1 if not found
pass pass
if not found_key: if not found_key:
print_warning("No old key found on remote authorized_keys.") print_warning("No old key found on remote authorized_keys.")
return return
# 2) Actually remove it
remove_cmd = [ remove_cmd = [
"ssh", "ssh", "-o", "StrictHostKeyChecking=no",
"-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys" f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
] ]
try: try:
subprocess.check_call(remove_cmd) subprocess.check_call(remove_cmd)
print_info("Old public key removed from remote authorized_keys.") print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError: except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).") print_warning("Failed to remove old key from remote authorized_keys (permissions or other error).")

View file

@ -1,82 +1,68 @@
# ssh_manager/remove_host.py # ssh_manager/remove_host.py
import os import os
import glob
import subprocess import subprocess
import asyncio import asyncio
from collections import OrderedDict
from .utils import ( from .utils import (
print_info, print_info,
print_warning, print_warning,
print_error, print_error,
safe_input safe_input
) )
from .list_hosts import load_config_file, check_ssh_port from .list_hosts import build_host_list_table, load_config_file
"""
async def get_all_host_blocks(conf_dir): Remove host now reuses build_host_list_table to display the same columns:
pattern = os.path.join(conf_dir, "*", "config") No. | Host | User | Port | HostName | IP Address | Conf Directory
conf_files = sorted(glob.glob(pattern)) """
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
return all_blocks
async def remove_host(conf_dir): async def remove_host(conf_dir):
""" """
Remove an SSH host by: Remove an SSH host by:
1) Listing all hosts 1) Listing all hosts (same columns as main list)
2) Letting user pick row number or label 2) Letting user pick row number or label
3) Attempting to remove the old pub key from remote authorized_keys 3) Removing old pub key from remote
4) Deleting the subdirectory in ~/.ssh/conf/<host_label> 4) Deleting ~/.ssh/conf/<host_label>
""" """
print_info("Remove Host - Step 1: Show current hosts...\n") print_info("Remove Host - Step 1: Show current hosts...\n")
all_blocks = await get_all_host_blocks(conf_dir) # Reuse the unified table from list_hosts
if not all_blocks: headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot remove anything.") print_warning("No hosts found. Cannot remove anything.")
return return
# We'll display them in a table # Print the same table
import socket
from tabulate import tabulate from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = [] # We have final_data rows => need to map row => block
for idx, block in enumerate(all_blocks, start=1): # So let's gather the raw blocks again to correlate.
host_label = block.get("Host", "N/A") # We'll do a separate approach or we can parse final_data.
hostname = block.get("HostName", "N/A") # Easiest: Re-run load_config_file if needed or:
user = block.get("User", "N/A") blocks = []
port = int(block.get("Port", "22")) # The last gather call for build_host_list_table used load_config_file already
identity_file = block.get("IdentityFile", "N/A") # but it doesn't return the correlation. We'll replicate the logic quickly.
try: pattern = os.path.join(conf_dir, "*", "config")
ip_address = socket.gethostbyname(hostname) conf_files = sorted(os.listdir(conf_dir))
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m" # Actually, let's do a small approach:
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m" from .list_hosts import gather_host_info, sort_by_ip
all_blocks = []
import glob
row = [ for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
idx, blocks.extend(load_config_file(cfile))
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"] # gather the same big list
print("\nSSH Conf Subdirectory Host List") results = await gather_host_info(blocks)
print(tabulate(table_rows, headers=headers, tablefmt="grid")) sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
# Prompt which host to remove
choice = safe_input("Enter the row number or Host label to remove: ") choice = safe_input("Enter the row number or Host label to remove: ")
if choice is None: if choice is None:
return return
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
print_error("Invalid empty choice.") print_error("Invalid empty choice.")
return return
target_block = None target_tuple = None
# If digit => index in sorted_rows
if choice.isdigit(): if choice.isdigit():
idx = int(choice) idx = int(choice)
if idx < 1 or idx > len(all_blocks): if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.") print_warning(f"Row number {idx} is invalid.")
return return
target_block = all_blocks[idx - 1] target_tuple = sorted_rows[idx - 1]
else: else:
# They typed a label # They typed a label
for b in all_blocks: for t in sorted_rows:
if b.get("Host") == choice: if t[0] == choice: # t[0] = host_label
target_block = b target_tuple = t
break break
if not target_block: if not target_tuple:
print_warning(f"No matching host label '{choice}' found.") print_warning(f"No matching host label '{choice}' found.")
return return
host_label = target_block.get("Host", "") # target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
hostname = target_block.get("HostName", "") host_label = target_tuple[0]
user = target_block.get("User", "root") hostname = target_tuple[3]
port = int(target_block.get("Port", "22")) user = target_tuple[1]
identity_file = target_block.get("IdentityFile", "") # parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
# We do a second approach: let's see if we can find the block in "blocks".
# But let's parse the config block to get the real port, identity, etc.
# We find the matching block in "blocks" by host_label
found_block = None
for b in blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for {host_label}.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# now do removal logic
if not host_label: if not host_label:
print_warning("Target block has no Host label. Cannot remove.") print_warning("Target block has no Host label. Cannot remove.")
return return
# Check IdentityFile for old pub key
if identity_file and identity_file != "N/A": if identity_file and identity_file != "N/A":
expanded_key = os.path.expanduser(identity_file) expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub" pub_path = expanded_key + ".pub"
old_pub_data = "" old_pub_data = ""
if os.path.isfile(pub_path): if os.path.isfile(pub_path):
try: try:
with open(pub_path, "r") as f: with open(pub_path, "r") as f:
# This is the EXACT line that might appear in authorized_keys
old_pub_data = f.read().rstrip("\n") old_pub_data = f.read().rstrip("\n")
except Exception as e: except Exception as e:
print_warning(f"Could not read old pub key: {e}") print_warning(f"Could not read old pub key: {e}")
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
print_info("Attempting to remove old key from remote authorized_keys...") print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port) await remove_old_key_remote(old_pub_data, user, hostname, port)
# Finally, remove the local subdirectory # remove local folder
host_dir = os.path.join(conf_dir, host_label) host_dir = os.path.join(conf_dir, host_label)
import shutil
if os.path.isdir(host_dir): if os.path.isdir(host_dir):
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ") confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
if confirm and confirm.lower().startswith('y'): if confirm and confirm.lower().startswith('y'):
try: try:
import shutil
shutil.rmtree(host_dir) shutil.rmtree(host_dir)
print_info(f"Removed local folder: {host_dir}") print_info(f"Removed local folder: {host_dir}")
except Exception as e: except Exception as e:
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
else: else:
print_warning("Local folder not removed.") print_warning("Local folder not removed.")
else: else:
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.") print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
async def remove_old_key_remote(old_pub_data, user, hostname, port): async def remove_old_key_remote(old_pub_data, user, hostname, port):
""" # same logic as before
Remove the old key from remote authorized_keys if found, matching EXACT lines.
We'll do:
grep -Fxq for existence
grep -vFx to remove it
"""
# 1) Check if old_pub_data is present in authorized_keys (exact line)
check_cmd = [ check_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no", "ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys" f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
] ]
found_key = False found_key = False
try: try:
subprocess.check_call(check_cmd) subprocess.check_call(check_cmd)
found_key = True found_key = True
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
print_warning("No old key found on remote authorized_keys.") print_warning("No old key found on remote authorized_keys.")
return return
# 2) Actually remove it by ignoring EXACT matches to that line
remove_cmd = [ remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no", "ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port), "-p", str(port),
f"{user}@{hostname}", f"{user}@{hostname}",
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys" f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
] ]
try: try:
subprocess.check_call(remove_cmd) subprocess.check_call(remove_cmd)

View file

@ -1,25 +1,34 @@
# ssh_manager/utils.py # ssh_manager/utils.py
from enum import Enum
import sys import sys
from functools import lru_cache
class Colors: class Colors(Enum):
GREEN = "\033[0;32m" GREEN = "\033[0;32m"
RED = "\033[0;31m" RED = "\033[0;31m"
YELLOW = "\033[1;33m" YELLOW = "\033[1;33m"
CYAN = "\033[0;36m" CYAN = "\033[0;36m"
BOLD = "\033[1m" BOLD = "\033[1m"
RESET = "\033[0m" RESET = "\033[0m"
def print_error(message): def __str__(self):
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}") return self.value
def print_warning(message): @lru_cache(maxsize=32)
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}") def _format_message(prefix: str, color: Colors, message: str) -> str:
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
def print_info(message): def print_error(message: str) -> None:
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}") print(_format_message("", Colors.RED, message))
def safe_input(prompt=""): def print_warning(message: str) -> None:
print(_format_message("", Colors.YELLOW, message))
def print_info(message: str) -> None:
print(_format_message("", Colors.GREEN, message))
def safe_input(prompt: str = "") -> str:
""" """
A wrapper around input() that exits the entire script on Ctrl+C. A wrapper around input() that exits the entire script on Ctrl+C.
""" """