refactored code

This commit is contained in:
Arctic 2025-03-08 00:43:33 -06:00
parent 73429771d4
commit 414266eefc
6 changed files with 315 additions and 404 deletions

2
cli.py
View file

@ -57,7 +57,7 @@ def main():
elif choice == '2':
add_host(CONF_DIR)
elif choice == '3':
edit_host(CONF_DIR)
asyncio.run(edit_host(CONF_DIR))
elif choice == '4':
asyncio.run(regenerate_key(CONF_DIR))
elif choice == '5':

View file

@ -2,98 +2,31 @@ import os
import asyncio
from collections import OrderedDict
from .utils import print_error, print_warning, print_info, safe_input
from .list_hosts import list_hosts, load_config_file, check_ssh_port
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
async def get_all_host_blocks(conf_dir):
"""
Similar to list_hosts, but returns the list of host blocks + a table of results.
We'll build a table ourselves so we can map row numbers to actual host labels.
"""
import glob
import socket
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
# If no blocks found, return empty
if not all_blocks:
return []
# We want to do a partial version of check_host to get row data
# so we can display the table right here and keep track of each blocks host label.
# But let's do it similarly to list_hosts:
table_rows = []
for idx, b in enumerate(all_blocks, start=1):
host_label = b.get("Host", "N/A")
hostname = b.get("HostName", "N/A")
user = b.get("User", "N/A")
port = int(b.get("Port", "22"))
identity_file = b.get("IdentityFile", "N/A")
# Identity check
if identity_file != "N/A":
expanded_identity = os.path.expanduser(identity_file)
identity_exists = os.path.isfile(expanded_identity)
else:
identity_exists = False
# IP resolution
try:
ip_address = socket.gethostbyname(hostname)
except socket.error:
ip_address = None
# Port check
if ip_address:
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
else:
port_open = False
# Colors for display (optional, or we can keep it simple):
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
row = [
idx,
host_label,
user,
port_display,
hostname,
ip_display,
identity_disp
]
table_rows.append(row)
# Print the table
from tabulate import tabulate
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
return all_blocks
def edit_host(conf_dir):
async def edit_host(conf_dir):
"""
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
The user may type either the row number OR the actual host label.
1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
2) Prompt row number or host label
3) Rewrite config with updated fields
"""
# 1) Gather + display the current host list
print_info("Here is the current list of hosts:\n")
all_blocks = asyncio.run(get_all_host_blocks(conf_dir))
if not all_blocks:
print_warning("No hosts found to edit.")
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts to edit.")
return
# 2) Prompt for which host to edit (by label or row number)
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
choice = safe_input("Enter the row number or the Host label to edit: ")
if choice is None:
return # user canceled (Ctrl+C)
@ -102,43 +35,47 @@ def edit_host(conf_dir):
print_error("Host label or row number cannot be empty.")
return
# Check if user typed a digit -> row number
target_block = None
# We replicate the approach to find the matching block
all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
target_tuple = None
if choice.isdigit():
row_idx = int(choice)
# Validate index
if row_idx < 1 or row_idx > len(all_blocks):
print_warning(f"Invalid row number {row_idx}.")
idx = int(choice)
if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.")
return
target_block = all_blocks[row_idx - 1] # zero-based
target_tuple = sorted_rows[idx - 1]
else:
# The user typed a host label
# We must search all_blocks for a matching Host
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
for t in sorted_rows:
if t[0] == choice: # t[0] => host_label
target_tuple = t
break
if not target_block:
print_warning(f"No matching host label '{choice}' found.")
if not target_tuple:
print_warning(f"No matching Host '{choice}' found in the table.")
return
# Now we have a target_block with existing data
host_label = target_block.get("Host", "")
if not host_label:
print_warning("This host block has no label. Cannot edit.")
host_label = target_tuple[0]
# find the config block
found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return
# Derive the config path
host_dir = os.path.join(conf_dir, host_label)
config_path = os.path.join(host_dir, "config")
if not os.path.isfile(config_path):
print_warning(f"No config file found at {config_path}; cannot edit this host.")
return
old_hostname = target_block.get("HostName", "")
old_user = target_block.get("User", "")
old_port = target_block.get("Port", "22")
old_identity = target_block.get("IdentityFile", "")
old_hostname = found_block.get("HostName", "")
old_user = found_block.get("User", "")
old_port = found_block.get("Port", "22")
old_identity = found_block.get("IdentityFile", "")
print_info("Leave a field blank to keep its current value.")
@ -167,6 +104,8 @@ def edit_host(conf_dir):
final_port = new_port if new_port else old_port
final_ident = new_ident if new_ident else old_identity
# Overwrite the file
config_path = os.path.join(conf_dir, host_label, "config")
new_config_lines = [
f"Host {host_label}",
f" HostName {final_hostname}",

View file

@ -11,10 +11,6 @@ from collections import OrderedDict
from .utils import print_warning, print_error, Colors
async def check_ssh_port(ip_address, port):
"""
Attempt to open an SSH connection to see if the port is open.
Returns True if successful, False otherwise.
"""
try:
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ip_address, port), timeout=1
@ -26,13 +22,8 @@ async def check_ssh_port(ip_address, port):
return False
def load_config_file(file_path):
"""
Parse a single SSH config file and return a list of host blocks.
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
"""
blocks = []
host_data = None
try:
with open(file_path, 'r') as f:
lines = f.readlines()
@ -42,11 +33,9 @@ def load_config_file(file_path):
for line in lines:
stripped_line = line.strip()
# Skip empty lines and comments
if not stripped_line or stripped_line.startswith('#'):
continue
# Start of a new Host block
if stripped_line.lower().startswith('host '):
host_labels = stripped_line.split()[1:]
for label in host_labels:
@ -64,27 +53,27 @@ def load_config_file(file_path):
blocks.append(host_data)
return blocks
async def check_host(host):
async def gather_host_info(all_host_blocks):
"""
Given a host block, resolve IP, check SSH port, etc.
Returns a tuple of:
1) Host label
2) User
3) Port (colored if open)
4) HostName
5) IP Address (colored if resolved)
6) Conf Directory (green if has IdentityFile, else no color)
7) raw_ip (uncolored string for sorting)
Given a list of host blocks, gather full info:
- resolved IP
- port open check
- 'Conf Directory' coloring if IdentityFile != 'N/A'
Returns a list of 7-tuples:
(host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
"""
host_label = host.get('Host', 'N/A')
hostname = host.get('HostName', 'N/A')
user = host.get('User', 'N/A')
port = int(host.get('Port', '22'))
identity_file = host.get('IdentityFile', 'N/A')
results = []
async def process_block(h):
host_label = h.get('Host', 'N/A')
hostname = h.get('HostName', 'N/A')
user = h.get('User', 'N/A')
port = int(h.get('Port', '22'))
identity_file = h.get('IdentityFile', 'N/A')
# Resolve IP
try:
raw_ip = socket.gethostbyname(hostname) # uncolored
raw_ip = socket.gethostbyname(hostname)
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
except socket.error:
raw_ip = "N/A"
@ -101,13 +90,12 @@ async def check_host(host):
# Conf Directory = ~/.ssh/conf/<host_label>
conf_path = f"~/.ssh/conf/{host_label}"
# If there's an IdentityFile, color the conf path green
# If there's an IdentityFile, color the conf path
if identity_file != 'N/A':
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
else:
conf_path_display = conf_path
# Return the data plus the uncolored IP for sorting
return (
host_label,
user,
@ -115,13 +103,42 @@ async def check_host(host):
hostname,
colored_ip,
conf_path_display,
raw_ip # for sorting
raw_ip # uncolored IP for sorting
)
async def list_hosts(conf_dir):
# Process blocks concurrently
tasks = [process_block(b) for b in all_host_blocks]
results = await asyncio.gather(*tasks)
return results
def parse_ip(ip_str):
"""
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order.
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory
Convert a string IP to an ipaddress object for sorting.
Returns None if invalid or 'N/A'.
"""
import ipaddress
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
def sort_by_ip(results):
"""
Sort the 7-tuples by IP ascending, with 'N/A' last.
"""
sortable = []
for row in results:
raw_ip = row[-1]
ip_obj = parse_ip(raw_ip)
sortable.append(((ip_obj is None, ip_obj), row))
sortable.sort(key=lambda x: x[0])
return [row for (_, row) in sortable]
async def build_host_list_table(conf_dir):
"""
Gathers + sorts all hosts in conf_dir by IP ascending.
Returns (headers, final_table_rows), each row omitting the raw_ip.
"""
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
@ -133,42 +150,28 @@ async def list_hosts(conf_dir):
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
if not all_host_blocks:
return headers, []
results = await gather_host_info(all_host_blocks)
sorted_rows = sort_by_ip(results)
# Build final table
final_data = []
for idx, row in enumerate(sorted_rows, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
final_data.append([idx] + list(row[:-1]))
return headers, final_data
async def list_hosts(conf_dir):
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. The server list is empty.")
print("\nSSH Conf Subdirectory Host List")
from tabulate import tabulate
print(tabulate([], headers=headers, tablefmt="grid"))
return
# Gather full data for each host
tasks = [check_host(h) for h in all_host_blocks]
results = await asyncio.gather(*tasks)
# We want to sort by IP ascending. results[i] is a tuple:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
def parse_ip(ip_str):
try:
return ipaddress.ip_address(ip_str)
except ValueError:
return None
# Convert the results into a list of (ip_obj, original_tuple)
# so we can sort, then rebuild the final data.
sortable = []
for row in results:
raw_ip = row[-1] # last element
ip_obj = parse_ip(raw_ip)
# We'll sort None last by using a sort key that puts (True) after (False)
# e.g. (ip_obj is None, ip_obj)
sortable.append(((ip_obj is None, ip_obj), row))
# Sort by (is_none, ip_obj)
sortable.sort(key=lambda x: x[0])
# Rebuild the final display table, ignoring the raw_ip at the end
final_data = []
for idx, (_, row) in enumerate(sortable, start=1):
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
final_data.append([idx] + list(row[:-1])) # omit raw_ip
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))

View file

@ -1,88 +1,54 @@
import os
import glob
import subprocess
import asyncio
from collections import OrderedDict
from .utils import (
print_info,
print_warning,
print_error,
safe_input,
Colors
safe_input
)
from .list_hosts import (
build_host_list_table,
load_config_file,
gather_host_info,
sort_by_ip
)
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir):
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
if not all_blocks:
return []
return all_blocks
async def regenerate_key(conf_dir):
"""
Menu-driven function to regenerate a key for a selected host:
1) Show the host table with row numbers
2) Let user pick row # or label
3) Read old pub data
4) Remove old local key files
5) Generate new key
6) Copy new key
7) Remove old key from remote authorized_keys (improved logic to detect existence)
Regenerate the SSH key for a chosen host by:
1) Displaying the unified table of hosts (No. | Host | User | Port | HostName | IP Address | Conf Directory).
2) Letting you pick a row number or host label.
3) Reading/deleting any existing local keys,
4) Generating a new key,
5) Optionally copying it to the remote,
6) Removing the old pub key from the remote authorized_keys if present.
"""
print_info("Regenerate Key - Step 1: Show current hosts...\n")
# 1) Gather host blocks
all_blocks = await get_all_host_blocks(conf_dir)
if not all_blocks:
# 1) Reuse the same columns as the main 'list_hosts'
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot regenerate a key.")
return
# Display them in a table (similar to list_hosts):
import socket
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = []
for idx, block in enumerate(all_blocks, start=1):
host_label = block.get("Host", "N/A")
hostname = block.get("HostName", "N/A")
user = block.get("User", "N/A")
port = int(block.get("Port", "22"))
identity_file = block.get("IdentityFile", "N/A")
# 2) We need to correlate row => actual config block.
# We'll replicate the logic that build_host_list_table uses.
all_blocks = []
import glob
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
all_blocks.extend(load_config_file(cfile))
# Check IP
try:
ip_address = socket.gethostbyname(hostname)
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
results = await gather_host_info(all_blocks)
sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
# 2) Prompt for row # or label
choice = safe_input("Enter the row number or the Host label to regenerate: ")
if choice is None:
return
@ -91,54 +57,60 @@ async def regenerate_key(conf_dir):
print_error("No choice given.")
return
target_block = None
target_tuple = None
if choice.isdigit():
row_idx = int(choice)
if row_idx < 1 or row_idx > len(all_blocks):
if row_idx < 1 or row_idx > len(sorted_rows):
print_warning(f"Invalid row number {row_idx}.")
return
target_block = all_blocks[row_idx - 1]
target_tuple = sorted_rows[row_idx - 1]
else:
# user typed a label
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
for t in sorted_rows:
if t[0] == choice: # t[0] is host_label
target_tuple = t
break
if not target_block:
if not target_tuple:
print_warning(f"No matching host label '{choice}' found.")
return
# 3) Gather info from block
host_label = target_block.get("Host", "")
hostname = target_block.get("HostName", "")
user = target_block.get("User", "root")
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
# 3) Retrieve the full config block so we have Port, IdentityFile, etc.
host_label = target_tuple[0]
hostname = target_tuple[3] # (host_label, user, colored_port, hostname, ...)
user = target_tuple[1]
if not host_label or not hostname:
print_error("Missing Host or HostName; cannot regenerate.")
# find that block
found_block = None
for b in all_blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for '{host_label}'.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# If missing or "N/A", we can't regenerate
if not identity_file or identity_file == "N/A":
print_error("No IdentityFile found in config; cannot regenerate.")
return
# Derive local paths
# 4) Remove old local key files
expanded_key = os.path.expanduser(identity_file)
key_dir = os.path.dirname(expanded_key)
pub_path = expanded_key + ".pub"
# 3a) Read old pub key data
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
old_pub_data = f.read().strip()
old_pub_data = f.read().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
else:
print_warning("No old pub key found locally.")
# 4) Remove old local key files
print_info("Removing old key files locally...")
for path in [expanded_key, pub_path]:
if os.path.isfile(path):
@ -150,7 +122,7 @@ async def regenerate_key(conf_dir):
# 5) Generate new key
print_info("Generating new ed25519 SSH key...")
new_key_path = expanded_key # Reuse the same path from config
new_key_path = expanded_key
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
try:
subprocess.check_call(cmd)
@ -159,7 +131,7 @@ async def regenerate_key(conf_dir):
print_error(f"Error generating new SSH key: {e}")
return
# 6) Copy new key to remote
# 6) Copy the new key to remote if user wants
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
if copy_choice and copy_choice.lower().startswith('y'):
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
@ -172,7 +144,7 @@ async def regenerate_key(conf_dir):
except subprocess.CalledProcessError as e:
print_error(f"Error copying new key: {e}")
# 7) Remove old key from authorized_keys if old_pub_data is non-empty
# 7) Remove old key from authorized_keys if we had old_pub_data
if old_pub_data:
print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port)
@ -181,41 +153,34 @@ async def regenerate_key(conf_dir):
async def remove_old_key_remote(old_pub_data, user, hostname, port):
"""
1) Check if old_pub_data is present on remote server (grep -q).
2) If found, remove it with grep -v ...
3) Otherwise, print "No old key found on remote."
Checks and removes the exact matching line from remote authorized_keys
via grep -Fxq / grep -vFx.
"""
# 1) Check if old_pub_data exists in authorized_keys
check_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys"
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
]
found_key = False
try:
subprocess.check_call(check_cmd)
found_key = True
except subprocess.CalledProcessError:
# grep returns exit code 1 if not found
pass
if not found_key:
print_warning("No old key found on remote authorized_keys.")
return
# 2) Actually remove it
remove_cmd = [
"ssh",
"-o", "StrictHostKeyChecking=no",
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
]
try:
subprocess.check_call(remove_cmd)
print_info("Old public key removed from remote authorized_keys.")
except subprocess.CalledProcessError:
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).")
print_warning("Failed to remove old key from remote authorized_keys (permissions or other error).")

View file

@ -1,82 +1,68 @@
# ssh_manager/remove_host.py
import os
import glob
import subprocess
import asyncio
from collections import OrderedDict
from .utils import (
print_info,
print_warning,
print_error,
safe_input
)
from .list_hosts import load_config_file, check_ssh_port
async def get_all_host_blocks(conf_dir):
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(glob.glob(pattern))
all_blocks = []
for conf_file in conf_files:
blocks = load_config_file(conf_file)
all_blocks.extend(blocks)
return all_blocks
from .list_hosts import build_host_list_table, load_config_file
"""
Remove host now reuses build_host_list_table to display the same columns:
No. | Host | User | Port | HostName | IP Address | Conf Directory
"""
async def remove_host(conf_dir):
"""
Remove an SSH host by:
1) Listing all hosts
1) Listing all hosts (same columns as main list)
2) Letting user pick row number or label
3) Attempting to remove the old pub key from remote authorized_keys
4) Deleting the subdirectory in ~/.ssh/conf/<host_label>
3) Removing old pub key from remote
4) Deleting ~/.ssh/conf/<host_label>
"""
print_info("Remove Host - Step 1: Show current hosts...\n")
all_blocks = await get_all_host_blocks(conf_dir)
if not all_blocks:
# Reuse the unified table from list_hosts
headers, final_data = await build_host_list_table(conf_dir)
if not final_data:
print_warning("No hosts found. Cannot remove anything.")
return
# We'll display them in a table
import socket
# Print the same table
from tabulate import tabulate
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
print(tabulate(final_data, headers=headers, tablefmt="grid"))
table_rows = []
for idx, block in enumerate(all_blocks, start=1):
host_label = block.get("Host", "N/A")
hostname = block.get("HostName", "N/A")
user = block.get("User", "N/A")
port = int(block.get("Port", "22"))
identity_file = block.get("IdentityFile", "N/A")
# We have final_data rows => need to map row => block
# So let's gather the raw blocks again to correlate.
# We'll do a separate approach or we can parse final_data.
# Easiest: Re-run load_config_file if needed or:
blocks = []
# The last gather call for build_host_list_table used load_config_file already
# but it doesn't return the correlation. We'll replicate the logic quickly.
try:
ip_address = socket.gethostbyname(hostname)
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
except:
ip_address = None
port_open = False
pattern = os.path.join(conf_dir, "*", "config")
conf_files = sorted(os.listdir(conf_dir))
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
# Actually, let's do a small approach:
from .list_hosts import gather_host_info, sort_by_ip
all_blocks = []
import glob
row = [
idx,
host_label,
user,
port_disp,
hostname,
ip_disp,
identity_file
]
table_rows.append(row)
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
blocks.extend(load_config_file(cfile))
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
print("\nSSH Conf Subdirectory Host List")
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
# gather the same big list
results = await gather_host_info(blocks)
sorted_rows = sort_by_ip(results)
# sorted_rows is a list of 7-tuples:
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
# Prompt which host to remove
choice = safe_input("Enter the row number or Host label to remove: ")
if choice is None:
return
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
print_error("Invalid empty choice.")
return
target_block = None
target_tuple = None
# If digit => index in sorted_rows
if choice.isdigit():
idx = int(choice)
if idx < 1 or idx > len(all_blocks):
if idx < 1 or idx > len(sorted_rows):
print_warning(f"Row number {idx} is invalid.")
return
target_block = all_blocks[idx - 1]
target_tuple = sorted_rows[idx - 1]
else:
# They typed a label
for b in all_blocks:
if b.get("Host") == choice:
target_block = b
for t in sorted_rows:
if t[0] == choice: # t[0] = host_label
target_tuple = t
break
if not target_block:
if not target_tuple:
print_warning(f"No matching host label '{choice}' found.")
return
host_label = target_block.get("Host", "")
hostname = target_block.get("HostName", "")
user = target_block.get("User", "root")
port = int(target_block.get("Port", "22"))
identity_file = target_block.get("IdentityFile", "")
# target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
host_label = target_tuple[0]
hostname = target_tuple[3]
user = target_tuple[1]
# parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
# We do a second approach: let's see if we can find the block in "blocks".
# But let's parse the config block to get the real port, identity, etc.
# We find the matching block in "blocks" by host_label
found_block = None
for b in blocks:
if b.get("Host") == host_label:
found_block = b
break
if not found_block:
print_warning(f"No config block found for {host_label}.")
return
port_str = found_block.get("Port", "22")
port = int(port_str)
identity_file = found_block.get("IdentityFile", "")
# now do removal logic
if not host_label:
print_warning("Target block has no Host label. Cannot remove.")
return
# Check IdentityFile for old pub key
if identity_file and identity_file != "N/A":
expanded_key = os.path.expanduser(identity_file)
pub_path = expanded_key + ".pub"
old_pub_data = ""
if os.path.isfile(pub_path):
try:
with open(pub_path, "r") as f:
# This is the EXACT line that might appear in authorized_keys
old_pub_data = f.read().rstrip("\n")
except Exception as e:
print_warning(f"Could not read old pub key: {e}")
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
print_info("Attempting to remove old key from remote authorized_keys...")
await remove_old_key_remote(old_pub_data, user, hostname, port)
# Finally, remove the local subdirectory
# remove local folder
host_dir = os.path.join(conf_dir, host_label)
import shutil
if os.path.isdir(host_dir):
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
if confirm and confirm.lower().startswith('y'):
try:
import shutil
shutil.rmtree(host_dir)
print_info(f"Removed local folder: {host_dir}")
except Exception as e:
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
else:
print_warning("Local folder not removed.")
else:
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.")
print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
async def remove_old_key_remote(old_pub_data, user, hostname, port):
"""
Remove the old key from remote authorized_keys if found, matching EXACT lines.
We'll do:
grep -Fxq for existence
grep -vFx to remove it
"""
# 1) Check if old_pub_data is present in authorized_keys (exact line)
# same logic as before
check_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys"
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
]
found_key = False
try:
subprocess.check_call(check_cmd)
found_key = True
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
print_warning("No old key found on remote authorized_keys.")
return
# 2) Actually remove it by ignoring EXACT matches to that line
remove_cmd = [
"ssh", "-o", "StrictHostKeyChecking=no",
"-p", str(port),
f"{user}@{hostname}",
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
]
try:
subprocess.check_call(remove_cmd)

View file

@ -1,8 +1,10 @@
# ssh_manager/utils.py
from enum import Enum
import sys
from functools import lru_cache
class Colors:
class Colors(Enum):
GREEN = "\033[0;32m"
RED = "\033[0;31m"
YELLOW = "\033[1;33m"
@ -10,16 +12,23 @@ class Colors:
BOLD = "\033[1m"
RESET = "\033[0m"
def print_error(message):
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}")
def __str__(self):
return self.value
def print_warning(message):
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}")
@lru_cache(maxsize=32)
def _format_message(prefix: str, color: Colors, message: str) -> str:
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
def print_info(message):
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}")
def print_error(message: str) -> None:
print(_format_message("", Colors.RED, message))
def safe_input(prompt=""):
def print_warning(message: str) -> None:
print(_format_message("", Colors.YELLOW, message))
def print_info(message: str) -> None:
print(_format_message("", Colors.GREEN, message))
def safe_input(prompt: str = "") -> str:
"""
A wrapper around input() that exits the entire script on Ctrl+C.
"""