Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: python sshnpd threading issue #1040

Merged
merged 4 commits into from
May 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 25 additions & 38 deletions packages/python/sshnpd/sshnpd.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
#!/usr/bin/env python3
import argparse
import errno
import getpass
import json
import logging
import os
import subprocess
import sys
import os, threading, getpass, json, logging, subprocess, argparse, errno
from io import StringIO
import threading
from queue import Empty, Queue
from time import sleep
from threading import Event

from select import select
from socket import socket, gethostbyname, gethostname, create_connection, error
from socket import gethostbyname, create_connection, error

from at_client import AtClient
from at_client.common import AtSign
Expand All @@ -19,12 +23,11 @@
class SocketConnector:
_logger = logging.getLogger("sshrv | socket_connector")

def __init__(self, server1_ip, server1_port, server2_ip, server2_port, verbose = False):
def __init__(self, server1_ip, server1_port, server2_ip, server2_port, verbose=False):
self._logger.setLevel(logging.INFO)
self._logger.addHandler(logging.StreamHandler())
if verbose:
self._logger.setLevel(logging.DEBUG)

# Create sockets for both servers
self.socketA = create_connection((server1_ip, server1_port))
self.socketB = create_connection((server2_ip, server2_port))
Expand All @@ -49,7 +52,7 @@ def connect(self):
sockets_to_monitor.remove(sock)
sock.close()
return
elif not data:
elif not data:
timeout += 1
sleep(0.1)
if data == b'':
Expand All @@ -63,7 +66,6 @@ def connect(self):
self._logger.debug("RECV B -> A : " + str(data))
self.socketA.send(data)
timeout = 0

except error as e:
if e.errno == errno.EWOULDBLOCK:
pass # No data available, continue
Expand Down Expand Up @@ -96,42 +98,37 @@ def run(self):
t1.start()
self.socket_connector = t1
return True

except Exception as e:
raise e

def is_alive(self):
return self.socket_connector.is_alive()


class SSHNPDClient:
def __init__(self, atsign, manager_atsign, device="default", username=None, verbose=False, expecting_ssh_keys=False):

def __init__(self, atsign, manager_atsign, device="default", username=None, verbose=False, expecting_ssh_keys=False):
# AtClient Stuff
self.atsign = atsign
self.manager_atsign = manager_atsign
self.device = device
self.username = username
self.device_namespace = f".{device}.sshnp"
self.at_client = AtClient(AtSign(atsign), queue=Queue(maxsize=20), verbose=verbose)

# SSH Stuff
self.ssh_client = None
self.rv = None
self.expecting_ssh_keys = expecting_ssh_keys

# Logger
self.logger = logging.getLogger("sshnpd")
self.logger.setLevel((logging.DEBUG if verbose else logging.INFO))
self.logger.addHandler(logging.StreamHandler())

# Directory Stuff
home_dir = os.path.expanduser("~")
self.ssh_path = f"{home_dir}/.ssh"

self.threads = []
self.encrypted_queue = Queue()

def start(self):
if self.username:
self.set_username()
Expand All @@ -146,7 +143,7 @@ def start(self):
def close(self):
self.threads.clear()
sys.exit()

def is_alive(self):
foreach_thread = [thread.is_alive() for thread in self.threads]
return all(foreach_thread)
Expand All @@ -155,7 +152,7 @@ def set_username(self):
username = getpass.getuser()
username_key = SharedKey(
f"username.{self.device}.sshnp", AtSign(self.atsign), AtSign(self.manager_atsign))
metadata = Metadata(iv_nonce= EncryptionUtil.generate_iv_nonce(), is_public=False, is_encrypted=True, is_hidden=False)
metadata = Metadata(iv_nonce=EncryptionUtil.generate_iv_nonce(), is_public=False, is_encrypted=True, is_hidden=False)
username_key.metadata = metadata
username_key.cache(-1, True)
self.at_client.put(username_key, username)
Expand All @@ -168,12 +165,10 @@ def handle_notifications(self):
at_event = self.at_client.queue.get(block=False)
event_type = at_event.event_type
event_data = at_event.event_data

if event_type == AtEventType.UPDATE_NOTIFICATION:
self.encrypted_queue.put(at_event)
if event_type != AtEventType.DECRYPTED_UPDATE_NOTIFICATION:
continue

key = event_data["key"].split(":")[1].split(".")[0]
decrypted_value = str(event_data["decryptedValue"])

Expand Down Expand Up @@ -206,10 +201,7 @@ def handle_events(self):
except Empty:
pass

def direct_ssh(
self,
event: AtEvent,
):
def direct_ssh(self, event: AtEvent):
uuid = event.event_data["id"]
ssh_list = json.loads(event.event_data["decryptedValue"])['payload']
iv_nonce = EncryptionUtil.generate_iv_nonce()
Expand All @@ -227,7 +219,6 @@ def direct_ssh(
at_key.shared_with = AtSign(self.manager_atsign)
at_key.metadata = metadata
at_key.namespace = self.device_namespace

hostname = ssh_list['host']
port = ssh_list['port']
session_id = ssh_list['sessionId']
Expand All @@ -254,20 +245,18 @@ def generate_ssh_keys(self, session_id):
# Generate SSH Keys
self.logger.info("Generating SSH Keys")
if not os.path.exists(f"{self.ssh_path}/tmp/"):
os.makedirs(f"{self.ssh_path}/tmp/")
os.makedirs(f"{self.ssh_path}/tmp/")
ssh_keygen = subprocess.Popen(
["ssh-keygen", "-t", "ed25519", "-a", "100", "-f", f"{session_id}_sshnp", "-q", "-N", ""],
cwd=f'{self.ssh_path}/tmp/',
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)

stdout, stderr = ssh_keygen.communicate()
if ssh_keygen.returncode != 0:
self.logger.error("SSH Key generation failed")
self.logger.error(stderr.decode("utf-8"))
return False

self.logger.info("SSH Keys Generated")
ssh_public_key = ""
ssh_private_key = ""
Expand All @@ -277,11 +266,9 @@ def generate_ssh_keys(self, session_id):

with open(f"{self.ssh_path}/tmp/{session_id}_sshnp", 'r') as private_key_file:
ssh_private_key = private_key_file.read()

except Exception as e:
self.logger.error(e)
return False

return (ssh_public_key, ssh_private_key)

def ephemeral_cleanup(self, session_id):
Expand Down Expand Up @@ -316,21 +303,21 @@ def main():
optional.add_argument("-s", action="store_true", dest="expecting_ssh_keys", help="Add ssh key into authorized_keys", default=False)
optional.add_argument("-u", action='store_true', dest="username", help="Username", default="default", required=True)
optional.add_argument("-v", action='store_true', dest="verbose", help="Verbose")

args = parser.parse_args()
sshnpd = SSHNPDClient(args.atsign, args.manager_atsign, args.device, args.username, args.verbose, args.expecting_ssh_keys)

thread = None
while True:
try:
threading.Thread(target=sshnpd.start).start()
try:
thread = threading.Thread(target=sshnpd.start)
thread.start()
while sshnpd.is_alive():
sleep(3)
sleep(3)
except Exception as e:
thread.join()
sshnpd.close()
print(e)
print("Restarting sshnpd in 3 seconds..")
sleep(3)

sleep(3)


if __name__ == "__main__":
Expand Down