refactored code
This commit is contained in:
parent
73429771d4
commit
414266eefc
6 changed files with 315 additions and 404 deletions
2
cli.py
2
cli.py
|
@ -57,7 +57,7 @@ def main():
|
|||
elif choice == '2':
|
||||
add_host(CONF_DIR)
|
||||
elif choice == '3':
|
||||
edit_host(CONF_DIR)
|
||||
asyncio.run(edit_host(CONF_DIR))
|
||||
elif choice == '4':
|
||||
asyncio.run(regenerate_key(CONF_DIR))
|
||||
elif choice == '5':
|
||||
|
|
165
edit_host.py
165
edit_host.py
|
@ -2,98 +2,31 @@ import os
|
|||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from .utils import print_error, print_warning, print_info, safe_input
|
||||
from .list_hosts import list_hosts, load_config_file, check_ssh_port
|
||||
from .list_hosts import (
|
||||
build_host_list_table,
|
||||
load_config_file,
|
||||
gather_host_info,
|
||||
sort_by_ip
|
||||
)
|
||||
|
||||
async def get_all_host_blocks(conf_dir):
|
||||
"""
|
||||
Similar to list_hosts, but returns the list of host blocks + a table of results.
|
||||
We'll build a table ourselves so we can map row numbers to actual host labels.
|
||||
"""
|
||||
import glob
|
||||
import socket
|
||||
|
||||
pattern = os.path.join(conf_dir, "*", "config")
|
||||
conf_files = sorted(glob.glob(pattern))
|
||||
|
||||
all_blocks = []
|
||||
for conf_file in conf_files:
|
||||
blocks = load_config_file(conf_file)
|
||||
all_blocks.extend(blocks)
|
||||
|
||||
# If no blocks found, return empty
|
||||
if not all_blocks:
|
||||
return []
|
||||
|
||||
# We want to do a partial version of check_host to get row data
|
||||
# so we can display the table right here and keep track of each block’s host label.
|
||||
# But let's do it similarly to list_hosts:
|
||||
|
||||
table_rows = []
|
||||
for idx, b in enumerate(all_blocks, start=1):
|
||||
host_label = b.get("Host", "N/A")
|
||||
hostname = b.get("HostName", "N/A")
|
||||
user = b.get("User", "N/A")
|
||||
port = int(b.get("Port", "22"))
|
||||
identity_file = b.get("IdentityFile", "N/A")
|
||||
|
||||
# Identity check
|
||||
if identity_file != "N/A":
|
||||
expanded_identity = os.path.expanduser(identity_file)
|
||||
identity_exists = os.path.isfile(expanded_identity)
|
||||
else:
|
||||
identity_exists = False
|
||||
|
||||
# IP resolution
|
||||
try:
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
except socket.error:
|
||||
ip_address = None
|
||||
|
||||
# Port check
|
||||
if ip_address:
|
||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
||||
else:
|
||||
port_open = False
|
||||
|
||||
# Colors for display (optional, or we can keep it simple):
|
||||
ip_display = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
||||
port_display = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
||||
identity_disp= f"\033[0;32m{identity_file}\033[0m" if identity_exists else f"\033[0;31m{identity_file}\033[0m"
|
||||
|
||||
row = [
|
||||
idx,
|
||||
host_label,
|
||||
user,
|
||||
port_display,
|
||||
hostname,
|
||||
ip_display,
|
||||
identity_disp
|
||||
]
|
||||
table_rows.append(row)
|
||||
|
||||
# Print the table
|
||||
from tabulate import tabulate
|
||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "IdentityFile"]
|
||||
print("\nSSH Conf Subdirectory Host List")
|
||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
||||
|
||||
return all_blocks
|
||||
|
||||
def edit_host(conf_dir):
|
||||
async def edit_host(conf_dir):
|
||||
"""
|
||||
Let the user update fields for an existing host in ~/.ssh/conf/<label>/config.
|
||||
The user may type either the row number OR the actual host label.
|
||||
1) Display the unified table (No. | Host | User | Port | HostName | IP Address | Conf Directory)
|
||||
2) Prompt row number or host label
|
||||
3) Rewrite config with updated fields
|
||||
"""
|
||||
|
||||
# 1) Gather + display the current host list
|
||||
print_info("Here is the current list of hosts:\n")
|
||||
all_blocks = asyncio.run(get_all_host_blocks(conf_dir))
|
||||
|
||||
if not all_blocks:
|
||||
print_warning("No hosts found to edit.")
|
||||
headers, final_data = await build_host_list_table(conf_dir)
|
||||
if not final_data:
|
||||
print_warning("No hosts to edit.")
|
||||
return
|
||||
|
||||
# 2) Prompt for which host to edit (by label or row number)
|
||||
from tabulate import tabulate
|
||||
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||
|
||||
choice = safe_input("Enter the row number or the Host label to edit: ")
|
||||
if choice is None:
|
||||
return # user canceled (Ctrl+C)
|
||||
|
@ -102,43 +35,47 @@ def edit_host(conf_dir):
|
|||
print_error("Host label or row number cannot be empty.")
|
||||
return
|
||||
|
||||
# Check if user typed a digit -> row number
|
||||
target_block = None
|
||||
# We replicate the approach to find the matching block
|
||||
all_blocks = []
|
||||
import glob
|
||||
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
|
||||
all_blocks.extend(load_config_file(cfile))
|
||||
|
||||
results = await gather_host_info(all_blocks)
|
||||
sorted_rows = sort_by_ip(results)
|
||||
|
||||
target_tuple = None
|
||||
if choice.isdigit():
|
||||
row_idx = int(choice)
|
||||
# Validate index
|
||||
if row_idx < 1 or row_idx > len(all_blocks):
|
||||
print_warning(f"Invalid row number {row_idx}.")
|
||||
idx = int(choice)
|
||||
if idx < 1 or idx > len(sorted_rows):
|
||||
print_warning(f"Row number {idx} is invalid.")
|
||||
return
|
||||
target_block = all_blocks[row_idx - 1] # zero-based
|
||||
target_tuple = sorted_rows[idx - 1]
|
||||
else:
|
||||
# The user typed a host label
|
||||
# We must search all_blocks for a matching Host
|
||||
for b in all_blocks:
|
||||
if b.get("Host") == choice:
|
||||
target_block = b
|
||||
for t in sorted_rows:
|
||||
if t[0] == choice: # t[0] => host_label
|
||||
target_tuple = t
|
||||
break
|
||||
if not target_block:
|
||||
print_warning(f"No matching host label '{choice}' found.")
|
||||
if not target_tuple:
|
||||
print_warning(f"No matching Host '{choice}' found in the table.")
|
||||
return
|
||||
|
||||
# Now we have a target_block with existing data
|
||||
host_label = target_block.get("Host", "")
|
||||
if not host_label:
|
||||
print_warning("This host block has no label. Cannot edit.")
|
||||
host_label = target_tuple[0]
|
||||
# find the config block
|
||||
found_block = None
|
||||
for b in all_blocks:
|
||||
if b.get("Host") == host_label:
|
||||
found_block = b
|
||||
break
|
||||
|
||||
if not found_block:
|
||||
print_warning(f"No config block found for '{host_label}'.")
|
||||
return
|
||||
|
||||
# Derive the config path
|
||||
host_dir = os.path.join(conf_dir, host_label)
|
||||
config_path = os.path.join(host_dir, "config")
|
||||
if not os.path.isfile(config_path):
|
||||
print_warning(f"No config file found at {config_path}; cannot edit this host.")
|
||||
return
|
||||
|
||||
old_hostname = target_block.get("HostName", "")
|
||||
old_user = target_block.get("User", "")
|
||||
old_port = target_block.get("Port", "22")
|
||||
old_identity = target_block.get("IdentityFile", "")
|
||||
old_hostname = found_block.get("HostName", "")
|
||||
old_user = found_block.get("User", "")
|
||||
old_port = found_block.get("Port", "22")
|
||||
old_identity = found_block.get("IdentityFile", "")
|
||||
|
||||
print_info("Leave a field blank to keep its current value.")
|
||||
|
||||
|
@ -167,6 +104,8 @@ def edit_host(conf_dir):
|
|||
final_port = new_port if new_port else old_port
|
||||
final_ident = new_ident if new_ident else old_identity
|
||||
|
||||
# Overwrite the file
|
||||
config_path = os.path.join(conf_dir, host_label, "config")
|
||||
new_config_lines = [
|
||||
f"Host {host_label}",
|
||||
f" HostName {final_hostname}",
|
||||
|
|
189
list_hosts.py
189
list_hosts.py
|
@ -11,10 +11,6 @@ from collections import OrderedDict
|
|||
from .utils import print_warning, print_error, Colors
|
||||
|
||||
async def check_ssh_port(ip_address, port):
|
||||
"""
|
||||
Attempt to open an SSH connection to see if the port is open.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
try:
|
||||
reader, writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(ip_address, port), timeout=1
|
||||
|
@ -26,13 +22,8 @@ async def check_ssh_port(ip_address, port):
|
|||
return False
|
||||
|
||||
def load_config_file(file_path):
|
||||
"""
|
||||
Parse a single SSH config file and return a list of host blocks.
|
||||
Each block is an OrderedDict with keys like 'Host', 'HostName', etc.
|
||||
"""
|
||||
blocks = []
|
||||
host_data = None
|
||||
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
@ -42,11 +33,9 @@ def load_config_file(file_path):
|
|||
|
||||
for line in lines:
|
||||
stripped_line = line.strip()
|
||||
# Skip empty lines and comments
|
||||
if not stripped_line or stripped_line.startswith('#'):
|
||||
continue
|
||||
|
||||
# Start of a new Host block
|
||||
if stripped_line.lower().startswith('host '):
|
||||
host_labels = stripped_line.split()[1:]
|
||||
for label in host_labels:
|
||||
|
@ -64,64 +53,92 @@ def load_config_file(file_path):
|
|||
blocks.append(host_data)
|
||||
return blocks
|
||||
|
||||
async def check_host(host):
|
||||
async def gather_host_info(all_host_blocks):
|
||||
"""
|
||||
Given a host block, resolve IP, check SSH port, etc.
|
||||
Returns a tuple of:
|
||||
1) Host label
|
||||
2) User
|
||||
3) Port (colored if open)
|
||||
4) HostName
|
||||
5) IP Address (colored if resolved)
|
||||
6) Conf Directory (green if has IdentityFile, else no color)
|
||||
7) raw_ip (uncolored string for sorting)
|
||||
Given a list of host blocks, gather full info:
|
||||
- resolved IP
|
||||
- port open check
|
||||
- 'Conf Directory' coloring if IdentityFile != 'N/A'
|
||||
Returns a list of 7-tuples:
|
||||
(host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
|
||||
"""
|
||||
host_label = host.get('Host', 'N/A')
|
||||
hostname = host.get('HostName', 'N/A')
|
||||
user = host.get('User', 'N/A')
|
||||
port = int(host.get('Port', '22'))
|
||||
identity_file = host.get('IdentityFile', 'N/A')
|
||||
results = []
|
||||
|
||||
# Resolve IP
|
||||
try:
|
||||
raw_ip = socket.gethostbyname(hostname) # uncolored
|
||||
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
|
||||
except socket.error:
|
||||
raw_ip = "N/A"
|
||||
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
|
||||
async def process_block(h):
|
||||
host_label = h.get('Host', 'N/A')
|
||||
hostname = h.get('HostName', 'N/A')
|
||||
user = h.get('User', 'N/A')
|
||||
port = int(h.get('Port', '22'))
|
||||
identity_file = h.get('IdentityFile', 'N/A')
|
||||
|
||||
# Check port
|
||||
if raw_ip != "N/A":
|
||||
port_open = await check_ssh_port(raw_ip, port)
|
||||
colored_port = (
|
||||
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
|
||||
# Resolve IP
|
||||
try:
|
||||
raw_ip = socket.gethostbyname(hostname)
|
||||
colored_ip = f"{Colors.GREEN}{raw_ip}{Colors.RESET}"
|
||||
except socket.error:
|
||||
raw_ip = "N/A"
|
||||
colored_ip = f"{Colors.RED}N/A{Colors.RESET}"
|
||||
|
||||
# Check port
|
||||
if raw_ip != "N/A":
|
||||
port_open = await check_ssh_port(raw_ip, port)
|
||||
colored_port = (
|
||||
f"{Colors.GREEN}{port}{Colors.RESET}" if port_open else f"{Colors.RED}{port}{Colors.RESET}"
|
||||
)
|
||||
else:
|
||||
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
|
||||
|
||||
# Conf Directory = ~/.ssh/conf/<host_label>
|
||||
conf_path = f"~/.ssh/conf/{host_label}"
|
||||
# If there's an IdentityFile, color the conf path
|
||||
if identity_file != 'N/A':
|
||||
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
|
||||
else:
|
||||
conf_path_display = conf_path
|
||||
|
||||
return (
|
||||
host_label,
|
||||
user,
|
||||
colored_port,
|
||||
hostname,
|
||||
colored_ip,
|
||||
conf_path_display,
|
||||
raw_ip # uncolored IP for sorting
|
||||
)
|
||||
else:
|
||||
colored_port = f"{Colors.RED}{port}{Colors.RESET}"
|
||||
|
||||
# Conf Directory = ~/.ssh/conf/<host_label>
|
||||
conf_path = f"~/.ssh/conf/{host_label}"
|
||||
# If there's an IdentityFile, color the conf path green
|
||||
if identity_file != 'N/A':
|
||||
conf_path_display = f"{Colors.GREEN}{conf_path}{Colors.RESET}"
|
||||
else:
|
||||
conf_path_display = conf_path
|
||||
# Process blocks concurrently
|
||||
tasks = [process_block(b) for b in all_host_blocks]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
|
||||
# Return the data plus the uncolored IP for sorting
|
||||
return (
|
||||
host_label,
|
||||
user,
|
||||
colored_port,
|
||||
hostname,
|
||||
colored_ip,
|
||||
conf_path_display,
|
||||
raw_ip # for sorting
|
||||
)
|
||||
|
||||
async def list_hosts(conf_dir):
|
||||
def parse_ip(ip_str):
|
||||
"""
|
||||
List out all hosts found in ~/.ssh/conf/*/config, sorted by IP in ascending order.
|
||||
Columns: No., Host, User, Port, HostName, IP Address, Conf Directory
|
||||
Convert a string IP to an ipaddress object for sorting.
|
||||
Returns None if invalid or 'N/A'.
|
||||
"""
|
||||
import ipaddress
|
||||
try:
|
||||
return ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def sort_by_ip(results):
|
||||
"""
|
||||
Sort the 7-tuples by IP ascending, with 'N/A' last.
|
||||
"""
|
||||
sortable = []
|
||||
for row in results:
|
||||
raw_ip = row[-1]
|
||||
ip_obj = parse_ip(raw_ip)
|
||||
sortable.append(((ip_obj is None, ip_obj), row))
|
||||
|
||||
sortable.sort(key=lambda x: x[0])
|
||||
return [row for (_, row) in sortable]
|
||||
|
||||
async def build_host_list_table(conf_dir):
|
||||
"""
|
||||
Gathers + sorts all hosts in conf_dir by IP ascending.
|
||||
Returns (headers, final_table_rows), each row omitting the raw_ip.
|
||||
"""
|
||||
pattern = os.path.join(conf_dir, "*", "config")
|
||||
conf_files = sorted(glob.glob(pattern))
|
||||
|
@ -133,42 +150,28 @@ async def list_hosts(conf_dir):
|
|||
|
||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP Address", "Conf Directory"]
|
||||
if not all_host_blocks:
|
||||
return headers, []
|
||||
|
||||
results = await gather_host_info(all_host_blocks)
|
||||
sorted_rows = sort_by_ip(results)
|
||||
|
||||
# Build final table
|
||||
final_data = []
|
||||
for idx, row in enumerate(sorted_rows, start=1):
|
||||
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path_display, raw_ip)
|
||||
final_data.append([idx] + list(row[:-1]))
|
||||
|
||||
return headers, final_data
|
||||
|
||||
async def list_hosts(conf_dir):
|
||||
headers, final_data = await build_host_list_table(conf_dir)
|
||||
if not final_data:
|
||||
print_warning("No hosts found. The server list is empty.")
|
||||
print("\nSSH Conf Subdirectory Host List")
|
||||
from tabulate import tabulate
|
||||
print(tabulate([], headers=headers, tablefmt="grid"))
|
||||
return
|
||||
|
||||
# Gather full data for each host
|
||||
tasks = [check_host(h) for h in all_host_blocks]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# We want to sort by IP ascending. results[i] is a tuple:
|
||||
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
|
||||
# We'll parse raw_ip as an ipaddress for sorting. "N/A" => sort to the end.
|
||||
def parse_ip(ip_str):
|
||||
try:
|
||||
return ipaddress.ip_address(ip_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
# Convert the results into a list of (ip_obj, original_tuple)
|
||||
# so we can sort, then rebuild the final data.
|
||||
sortable = []
|
||||
for row in results:
|
||||
raw_ip = row[-1] # last element
|
||||
ip_obj = parse_ip(raw_ip)
|
||||
# We'll sort None last by using a sort key that puts (True) after (False)
|
||||
# e.g. (ip_obj is None, ip_obj)
|
||||
sortable.append(((ip_obj is None, ip_obj), row))
|
||||
|
||||
# Sort by (is_none, ip_obj)
|
||||
sortable.sort(key=lambda x: x[0])
|
||||
|
||||
# Rebuild the final display table, ignoring the raw_ip at the end
|
||||
final_data = []
|
||||
for idx, (_, row) in enumerate(sortable, start=1):
|
||||
# row is (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
|
||||
final_data.append([idx] + list(row[:-1])) # omit raw_ip
|
||||
|
||||
from tabulate import tabulate
|
||||
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||
|
|
173
regen_key.py
173
regen_key.py
|
@ -1,88 +1,54 @@
|
|||
import os
|
||||
import glob
|
||||
import subprocess
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from .utils import (
|
||||
print_info,
|
||||
print_warning,
|
||||
print_error,
|
||||
safe_input,
|
||||
Colors
|
||||
safe_input
|
||||
)
|
||||
from .list_hosts import (
|
||||
build_host_list_table,
|
||||
load_config_file,
|
||||
gather_host_info,
|
||||
sort_by_ip
|
||||
)
|
||||
from .list_hosts import load_config_file, check_ssh_port
|
||||
|
||||
async def get_all_host_blocks(conf_dir):
|
||||
pattern = os.path.join(conf_dir, "*", "config")
|
||||
conf_files = sorted(glob.glob(pattern))
|
||||
|
||||
all_blocks = []
|
||||
for conf_file in conf_files:
|
||||
blocks = load_config_file(conf_file)
|
||||
all_blocks.extend(blocks)
|
||||
|
||||
if not all_blocks:
|
||||
return []
|
||||
return all_blocks
|
||||
|
||||
async def regenerate_key(conf_dir):
|
||||
"""
|
||||
Menu-driven function to regenerate a key for a selected host:
|
||||
1) Show the host table with row numbers
|
||||
2) Let user pick row # or label
|
||||
3) Read old pub data
|
||||
4) Remove old local key files
|
||||
5) Generate new key
|
||||
6) Copy new key
|
||||
7) Remove old key from remote authorized_keys (improved logic to detect existence)
|
||||
Regenerate the SSH key for a chosen host by:
|
||||
1) Displaying the unified table of hosts (No. | Host | User | Port | HostName | IP Address | Conf Directory).
|
||||
2) Letting you pick a row number or host label.
|
||||
3) Reading/deleting any existing local keys,
|
||||
4) Generating a new key,
|
||||
5) Optionally copying it to the remote,
|
||||
6) Removing the old pub key from the remote authorized_keys if present.
|
||||
"""
|
||||
|
||||
print_info("Regenerate Key - Step 1: Show current hosts...\n")
|
||||
|
||||
# 1) Gather host blocks
|
||||
all_blocks = await get_all_host_blocks(conf_dir)
|
||||
if not all_blocks:
|
||||
# 1) Reuse the same columns as the main 'list_hosts'
|
||||
headers, final_data = await build_host_list_table(conf_dir)
|
||||
if not final_data:
|
||||
print_warning("No hosts found. Cannot regenerate a key.")
|
||||
return
|
||||
|
||||
# Display them in a table (similar to list_hosts):
|
||||
import socket
|
||||
from tabulate import tabulate
|
||||
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||
|
||||
table_rows = []
|
||||
for idx, block in enumerate(all_blocks, start=1):
|
||||
host_label = block.get("Host", "N/A")
|
||||
hostname = block.get("HostName", "N/A")
|
||||
user = block.get("User", "N/A")
|
||||
port = int(block.get("Port", "22"))
|
||||
identity_file = block.get("IdentityFile", "N/A")
|
||||
# 2) We need to correlate row => actual config block.
|
||||
# We'll replicate the logic that build_host_list_table uses.
|
||||
all_blocks = []
|
||||
import glob
|
||||
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
|
||||
all_blocks.extend(load_config_file(cfile))
|
||||
|
||||
# Check IP
|
||||
try:
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
||||
except:
|
||||
ip_address = None
|
||||
port_open = False
|
||||
results = await gather_host_info(all_blocks)
|
||||
sorted_rows = sort_by_ip(results)
|
||||
# sorted_rows is a list of 7-tuples:
|
||||
# (host_label, user, colored_port, hostname, colored_ip, conf_path, raw_ip)
|
||||
|
||||
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
||||
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
||||
|
||||
row = [
|
||||
idx,
|
||||
host_label,
|
||||
user,
|
||||
port_disp,
|
||||
hostname,
|
||||
ip_disp,
|
||||
identity_file
|
||||
]
|
||||
table_rows.append(row)
|
||||
|
||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
|
||||
print("\nSSH Conf Subdirectory Host List")
|
||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
||||
|
||||
# 2) Prompt for row # or label
|
||||
choice = safe_input("Enter the row number or the Host label to regenerate: ")
|
||||
if choice is None:
|
||||
return
|
||||
|
@ -91,54 +57,60 @@ async def regenerate_key(conf_dir):
|
|||
print_error("No choice given.")
|
||||
return
|
||||
|
||||
target_block = None
|
||||
target_tuple = None
|
||||
if choice.isdigit():
|
||||
row_idx = int(choice)
|
||||
if row_idx < 1 or row_idx > len(all_blocks):
|
||||
if row_idx < 1 or row_idx > len(sorted_rows):
|
||||
print_warning(f"Invalid row number {row_idx}.")
|
||||
return
|
||||
target_block = all_blocks[row_idx - 1]
|
||||
target_tuple = sorted_rows[row_idx - 1]
|
||||
else:
|
||||
# user typed a label
|
||||
for b in all_blocks:
|
||||
if b.get("Host") == choice:
|
||||
target_block = b
|
||||
for t in sorted_rows:
|
||||
if t[0] == choice: # t[0] is host_label
|
||||
target_tuple = t
|
||||
break
|
||||
if not target_block:
|
||||
if not target_tuple:
|
||||
print_warning(f"No matching host label '{choice}' found.")
|
||||
return
|
||||
|
||||
# 3) Gather info from block
|
||||
host_label = target_block.get("Host", "")
|
||||
hostname = target_block.get("HostName", "")
|
||||
user = target_block.get("User", "root")
|
||||
port = int(target_block.get("Port", "22"))
|
||||
identity_file = target_block.get("IdentityFile", "")
|
||||
# 3) Retrieve the full config block so we have Port, IdentityFile, etc.
|
||||
host_label = target_tuple[0]
|
||||
hostname = target_tuple[3] # (host_label, user, colored_port, hostname, ...)
|
||||
user = target_tuple[1]
|
||||
|
||||
if not host_label or not hostname:
|
||||
print_error("Missing Host or HostName; cannot regenerate.")
|
||||
# find that block
|
||||
found_block = None
|
||||
for b in all_blocks:
|
||||
if b.get("Host") == host_label:
|
||||
found_block = b
|
||||
break
|
||||
|
||||
if not found_block:
|
||||
print_warning(f"No config block found for '{host_label}'.")
|
||||
return
|
||||
|
||||
port_str = found_block.get("Port", "22")
|
||||
port = int(port_str)
|
||||
identity_file = found_block.get("IdentityFile", "")
|
||||
|
||||
# If missing or "N/A", we can't regenerate
|
||||
if not identity_file or identity_file == "N/A":
|
||||
print_error("No IdentityFile found in config; cannot regenerate.")
|
||||
return
|
||||
|
||||
# Derive local paths
|
||||
# 4) Remove old local key files
|
||||
expanded_key = os.path.expanduser(identity_file)
|
||||
key_dir = os.path.dirname(expanded_key)
|
||||
pub_path = expanded_key + ".pub"
|
||||
old_pub_data = ""
|
||||
|
||||
# 3a) Read old pub key data
|
||||
old_pub_data = ""
|
||||
if os.path.isfile(pub_path):
|
||||
try:
|
||||
with open(pub_path, "r") as f:
|
||||
old_pub_data = f.read().strip()
|
||||
old_pub_data = f.read().rstrip("\n")
|
||||
except Exception as e:
|
||||
print_warning(f"Could not read old pub key: {e}")
|
||||
else:
|
||||
print_warning("No old pub key found locally.")
|
||||
|
||||
# 4) Remove old local key files
|
||||
print_info("Removing old key files locally...")
|
||||
for path in [expanded_key, pub_path]:
|
||||
if os.path.isfile(path):
|
||||
|
@ -150,7 +122,7 @@ async def regenerate_key(conf_dir):
|
|||
|
||||
# 5) Generate new key
|
||||
print_info("Generating new ed25519 SSH key...")
|
||||
new_key_path = expanded_key # Reuse the same path from config
|
||||
new_key_path = expanded_key
|
||||
cmd = ["ssh-keygen", "-q", "-t", "ed25519", "-N", "", "-f", new_key_path]
|
||||
try:
|
||||
subprocess.check_call(cmd)
|
||||
|
@ -159,7 +131,7 @@ async def regenerate_key(conf_dir):
|
|||
print_error(f"Error generating new SSH key: {e}")
|
||||
return
|
||||
|
||||
# 6) Copy new key to remote
|
||||
# 6) Copy the new key to remote if user wants
|
||||
copy_choice = safe_input("Copy new key to remote now? (y/n): ")
|
||||
if copy_choice and copy_choice.lower().startswith('y'):
|
||||
ssh_copy_cmd = ["ssh-copy-id", "-i", new_key_path]
|
||||
|
@ -172,7 +144,7 @@ async def regenerate_key(conf_dir):
|
|||
except subprocess.CalledProcessError as e:
|
||||
print_error(f"Error copying new key: {e}")
|
||||
|
||||
# 7) Remove old key from authorized_keys if old_pub_data is non-empty
|
||||
# 7) Remove old key from authorized_keys if we had old_pub_data
|
||||
if old_pub_data:
|
||||
print_info("Attempting to remove old key from remote authorized_keys...")
|
||||
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
||||
|
@ -181,41 +153,34 @@ async def regenerate_key(conf_dir):
|
|||
|
||||
async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
||||
"""
|
||||
1) Check if old_pub_data is present on remote server (grep -q).
|
||||
2) If found, remove it with grep -v ...
|
||||
3) Otherwise, print "No old key found on remote."
|
||||
Checks and removes the exact matching line from remote authorized_keys
|
||||
via grep -Fxq / grep -vFx.
|
||||
"""
|
||||
# 1) Check if old_pub_data exists in authorized_keys
|
||||
check_cmd = [
|
||||
"ssh",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||
"-p", str(port),
|
||||
f"{user}@{hostname}",
|
||||
f"grep -F '{old_pub_data}' ~/.ssh/authorized_keys"
|
||||
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
|
||||
]
|
||||
found_key = False
|
||||
|
||||
try:
|
||||
subprocess.check_call(check_cmd)
|
||||
found_key = True
|
||||
except subprocess.CalledProcessError:
|
||||
# grep returns exit code 1 if not found
|
||||
pass
|
||||
|
||||
if not found_key:
|
||||
print_warning("No old key found on remote authorized_keys.")
|
||||
return
|
||||
|
||||
# 2) Actually remove it
|
||||
remove_cmd = [
|
||||
"ssh",
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||
"-p", str(port),
|
||||
f"{user}@{hostname}",
|
||||
f"grep -v '{old_pub_data}' ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
|
||||
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
|
||||
]
|
||||
try:
|
||||
subprocess.check_call(remove_cmd)
|
||||
print_info("Old public key removed from remote authorized_keys.")
|
||||
except subprocess.CalledProcessError:
|
||||
print_warning("Failed to remove old key from remote authorized_keys (permission or other error).")
|
||||
print_warning("Failed to remove old key from remote authorized_keys (permissions or other error).")
|
||||
|
|
155
remove_host.py
155
remove_host.py
|
@ -1,82 +1,68 @@
|
|||
# ssh_manager/remove_host.py
|
||||
|
||||
import os
|
||||
import glob
|
||||
import subprocess
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from .utils import (
|
||||
print_info,
|
||||
print_warning,
|
||||
print_error,
|
||||
safe_input
|
||||
)
|
||||
from .list_hosts import load_config_file, check_ssh_port
|
||||
|
||||
async def get_all_host_blocks(conf_dir):
|
||||
pattern = os.path.join(conf_dir, "*", "config")
|
||||
conf_files = sorted(glob.glob(pattern))
|
||||
|
||||
all_blocks = []
|
||||
for conf_file in conf_files:
|
||||
blocks = load_config_file(conf_file)
|
||||
all_blocks.extend(blocks)
|
||||
|
||||
return all_blocks
|
||||
from .list_hosts import build_host_list_table, load_config_file
|
||||
"""
|
||||
Remove host now reuses build_host_list_table to display the same columns:
|
||||
No. | Host | User | Port | HostName | IP Address | Conf Directory
|
||||
"""
|
||||
|
||||
async def remove_host(conf_dir):
|
||||
"""
|
||||
Remove an SSH host by:
|
||||
1) Listing all hosts
|
||||
1) Listing all hosts (same columns as main list)
|
||||
2) Letting user pick row number or label
|
||||
3) Attempting to remove the old pub key from remote authorized_keys
|
||||
4) Deleting the subdirectory in ~/.ssh/conf/<host_label>
|
||||
3) Removing old pub key from remote
|
||||
4) Deleting ~/.ssh/conf/<host_label>
|
||||
"""
|
||||
|
||||
print_info("Remove Host - Step 1: Show current hosts...\n")
|
||||
|
||||
all_blocks = await get_all_host_blocks(conf_dir)
|
||||
if not all_blocks:
|
||||
# Reuse the unified table from list_hosts
|
||||
headers, final_data = await build_host_list_table(conf_dir)
|
||||
|
||||
if not final_data:
|
||||
print_warning("No hosts found. Cannot remove anything.")
|
||||
return
|
||||
|
||||
# We'll display them in a table
|
||||
import socket
|
||||
# Print the same table
|
||||
from tabulate import tabulate
|
||||
print("\nSSH Conf Subdirectory Host List (Sorted by IP Ascending)")
|
||||
print(tabulate(final_data, headers=headers, tablefmt="grid"))
|
||||
|
||||
table_rows = []
|
||||
for idx, block in enumerate(all_blocks, start=1):
|
||||
host_label = block.get("Host", "N/A")
|
||||
hostname = block.get("HostName", "N/A")
|
||||
user = block.get("User", "N/A")
|
||||
port = int(block.get("Port", "22"))
|
||||
identity_file = block.get("IdentityFile", "N/A")
|
||||
# We have final_data rows => need to map row => block
|
||||
# So let's gather the raw blocks again to correlate.
|
||||
# We'll do a separate approach or we can parse final_data.
|
||||
# Easiest: Re-run load_config_file if needed or:
|
||||
blocks = []
|
||||
# The last gather call for build_host_list_table used load_config_file already
|
||||
# but it doesn't return the correlation. We'll replicate the logic quickly.
|
||||
|
||||
try:
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
port_open = await asyncio.wait_for(check_ssh_port(ip_address, port), timeout=1)
|
||||
except:
|
||||
ip_address = None
|
||||
port_open = False
|
||||
pattern = os.path.join(conf_dir, "*", "config")
|
||||
conf_files = sorted(os.listdir(conf_dir))
|
||||
|
||||
ip_disp = f"\033[0;32m{ip_address}\033[0m" if ip_address else "\033[0;31mN/A\033[0m"
|
||||
port_disp = f"\033[0;32m{port}\033[0m" if port_open else f"\033[0;31m{port}\033[0m"
|
||||
# Actually, let's do a small approach:
|
||||
from .list_hosts import gather_host_info, sort_by_ip
|
||||
all_blocks = []
|
||||
import glob
|
||||
|
||||
row = [
|
||||
idx,
|
||||
host_label,
|
||||
user,
|
||||
port_disp,
|
||||
hostname,
|
||||
ip_disp,
|
||||
identity_file
|
||||
]
|
||||
table_rows.append(row)
|
||||
for cfile in glob.glob(os.path.join(conf_dir, "*", "config")):
|
||||
blocks.extend(load_config_file(cfile))
|
||||
|
||||
headers = ["No.", "Host", "User", "Port", "HostName", "IP", "IdentityFile"]
|
||||
print("\nSSH Conf Subdirectory Host List")
|
||||
print(tabulate(table_rows, headers=headers, tablefmt="grid"))
|
||||
# gather the same big list
|
||||
results = await gather_host_info(blocks)
|
||||
sorted_rows = sort_by_ip(results)
|
||||
# sorted_rows is a list of 7-tuples:
|
||||
# (host_label, user, colored_port, hostname, colored_ip, conf_display, raw_ip)
|
||||
|
||||
# Prompt which host to remove
|
||||
choice = safe_input("Enter the row number or Host label to remove: ")
|
||||
if choice is None:
|
||||
return
|
||||
|
@ -85,43 +71,60 @@ async def remove_host(conf_dir):
|
|||
print_error("Invalid empty choice.")
|
||||
return
|
||||
|
||||
target_block = None
|
||||
target_tuple = None
|
||||
|
||||
# If digit => index in sorted_rows
|
||||
if choice.isdigit():
|
||||
idx = int(choice)
|
||||
if idx < 1 or idx > len(all_blocks):
|
||||
if idx < 1 or idx > len(sorted_rows):
|
||||
print_warning(f"Row number {idx} is invalid.")
|
||||
return
|
||||
target_block = all_blocks[idx - 1]
|
||||
target_tuple = sorted_rows[idx - 1]
|
||||
else:
|
||||
# They typed a label
|
||||
for b in all_blocks:
|
||||
if b.get("Host") == choice:
|
||||
target_block = b
|
||||
for t in sorted_rows:
|
||||
if t[0] == choice: # t[0] = host_label
|
||||
target_tuple = t
|
||||
break
|
||||
if not target_block:
|
||||
if not target_tuple:
|
||||
print_warning(f"No matching host label '{choice}' found.")
|
||||
return
|
||||
|
||||
host_label = target_block.get("Host", "")
|
||||
hostname = target_block.get("HostName", "")
|
||||
user = target_block.get("User", "root")
|
||||
port = int(target_block.get("Port", "22"))
|
||||
identity_file = target_block.get("IdentityFile", "")
|
||||
# target_tuple is (host_label, user, colored_port, hostname, colored_ip, conf_dir, raw_ip)
|
||||
host_label = target_tuple[0]
|
||||
hostname = target_tuple[3]
|
||||
user = target_tuple[1]
|
||||
# parse port from colored_port? We can do a quick approach. Alternatively, we re-run the load_config approach again.
|
||||
# We do a second approach: let's see if we can find the block in "blocks".
|
||||
# But let's parse the config block to get the real port, identity, etc.
|
||||
|
||||
# We find the matching block in "blocks" by host_label
|
||||
found_block = None
|
||||
for b in blocks:
|
||||
if b.get("Host") == host_label:
|
||||
found_block = b
|
||||
break
|
||||
|
||||
if not found_block:
|
||||
print_warning(f"No config block found for {host_label}.")
|
||||
return
|
||||
|
||||
port_str = found_block.get("Port", "22")
|
||||
port = int(port_str)
|
||||
identity_file = found_block.get("IdentityFile", "")
|
||||
|
||||
# now do removal logic
|
||||
if not host_label:
|
||||
print_warning("Target block has no Host label. Cannot remove.")
|
||||
return
|
||||
|
||||
# Check IdentityFile for old pub key
|
||||
if identity_file and identity_file != "N/A":
|
||||
expanded_key = os.path.expanduser(identity_file)
|
||||
pub_path = expanded_key + ".pub"
|
||||
pub_path = expanded_key + ".pub"
|
||||
old_pub_data = ""
|
||||
|
||||
if os.path.isfile(pub_path):
|
||||
try:
|
||||
with open(pub_path, "r") as f:
|
||||
# This is the EXACT line that might appear in authorized_keys
|
||||
old_pub_data = f.read().rstrip("\n")
|
||||
except Exception as e:
|
||||
print_warning(f"Could not read old pub key: {e}")
|
||||
|
@ -130,13 +133,13 @@ async def remove_host(conf_dir):
|
|||
print_info("Attempting to remove old key from remote authorized_keys...")
|
||||
await remove_old_key_remote(old_pub_data, user, hostname, port)
|
||||
|
||||
# Finally, remove the local subdirectory
|
||||
# remove local folder
|
||||
host_dir = os.path.join(conf_dir, host_label)
|
||||
import shutil
|
||||
if os.path.isdir(host_dir):
|
||||
confirm = safe_input(f"Are you sure you want to delete local folder {host_dir}? (y/n): ")
|
||||
if confirm and confirm.lower().startswith('y'):
|
||||
try:
|
||||
import shutil
|
||||
shutil.rmtree(host_dir)
|
||||
print_info(f"Removed local folder: {host_dir}")
|
||||
except Exception as e:
|
||||
|
@ -144,24 +147,17 @@ async def remove_host(conf_dir):
|
|||
else:
|
||||
print_warning("Local folder not removed.")
|
||||
else:
|
||||
print_warning(f"Local host folder {host_dir} not found. Nothing to remove.")
|
||||
print_warning(f"Local folder {host_dir} not found. Nothing to remove.")
|
||||
|
||||
async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
||||
"""
|
||||
Remove the old key from remote authorized_keys if found, matching EXACT lines.
|
||||
We'll do:
|
||||
grep -Fxq for existence
|
||||
grep -vFx to remove it
|
||||
"""
|
||||
# 1) Check if old_pub_data is present in authorized_keys (exact line)
|
||||
# same logic as before
|
||||
check_cmd = [
|
||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||
"-p", str(port),
|
||||
f"{user}@{hostname}",
|
||||
f"grep -Fxq \"{old_pub_data}\" ~/.ssh/authorized_keys"
|
||||
f'grep -Fxq "{old_pub_data}" ~/.ssh/authorized_keys'
|
||||
]
|
||||
found_key = False
|
||||
|
||||
try:
|
||||
subprocess.check_call(check_cmd)
|
||||
found_key = True
|
||||
|
@ -172,12 +168,11 @@ async def remove_old_key_remote(old_pub_data, user, hostname, port):
|
|||
print_warning("No old key found on remote authorized_keys.")
|
||||
return
|
||||
|
||||
# 2) Actually remove it by ignoring EXACT matches to that line
|
||||
remove_cmd = [
|
||||
"ssh", "-o", "StrictHostKeyChecking=no",
|
||||
"-p", str(port),
|
||||
f"{user}@{hostname}",
|
||||
f"grep -vFx \"{old_pub_data}\" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys"
|
||||
f'grep -vFx "{old_pub_data}" ~/.ssh/authorized_keys > ~/.ssh/tmp && mv ~/.ssh/tmp ~/.ssh/authorized_keys'
|
||||
]
|
||||
try:
|
||||
subprocess.check_call(remove_cmd)
|
||||
|
|
35
utils.py
35
utils.py
|
@ -1,25 +1,34 @@
|
|||
# ssh_manager/utils.py
|
||||
|
||||
from enum import Enum
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
|
||||
class Colors:
|
||||
GREEN = "\033[0;32m"
|
||||
RED = "\033[0;31m"
|
||||
class Colors(Enum):
|
||||
GREEN = "\033[0;32m"
|
||||
RED = "\033[0;31m"
|
||||
YELLOW = "\033[1;33m"
|
||||
CYAN = "\033[0;36m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = "\033[0m"
|
||||
CYAN = "\033[0;36m"
|
||||
BOLD = "\033[1m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
def print_error(message):
|
||||
print(f"{Colors.RED}{Colors.BOLD}[✖] {Colors.RESET}{message}")
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
def print_warning(message):
|
||||
print(f"{Colors.YELLOW}{Colors.BOLD}[⚠] {Colors.RESET}{message}")
|
||||
@lru_cache(maxsize=32)
|
||||
def _format_message(prefix: str, color: Colors, message: str) -> str:
|
||||
return f"{color}{Colors.BOLD}[{prefix}] {Colors.RESET}{message}"
|
||||
|
||||
def print_info(message):
|
||||
print(f"{Colors.GREEN}{Colors.BOLD}[✔] {Colors.RESET}{message}")
|
||||
def print_error(message: str) -> None:
|
||||
print(_format_message("✖", Colors.RED, message))
|
||||
|
||||
def safe_input(prompt=""):
|
||||
def print_warning(message: str) -> None:
|
||||
print(_format_message("⚠", Colors.YELLOW, message))
|
||||
|
||||
def print_info(message: str) -> None:
|
||||
print(_format_message("✔", Colors.GREEN, message))
|
||||
|
||||
def safe_input(prompt: str = "") -> str:
|
||||
"""
|
||||
A wrapper around input() that exits the entire script on Ctrl+C.
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue