Commits

Albert Hopkins  committed 1270d51

Handly port forwarding (more) properly

  • Participants
  • Parent commits b77ae2d
  • Branches forward

Comments (0)

Files changed (2)

File hemp/fabfile.py

 
 
 from fabric.api import *
-from fabric.state import output
+from fabric.state import connections, output
 try:
     import fabric.colors as colors
 except ImportError:
 
     boot()
 
+def _stop_forwarding(_sig, _frame):
+    """
+    Signal handler for the forward task.
+    This stops the TCP handlers, shuts down the ssh tunnels and exits the
+    interpreter.
+    """
+    for host in env.hosts:
+        forwards = cloud.settings.FORWARD.get(host)
+        if not forwards:
+            continue
+        transport = connections[host]._transport
+        for local_port, remote_port in forwards:
+            transport.cancel_port_forward('localhost', remote_port)
+    for tunnel in _stop_forwarding.tunnels:
+        tunnel.shutdown()
+    print '\nTunnels shut down'
+    raise SystemExit
+
+@runs_once
 def forward():
-    """[H] Enable port forwarding"""
-    forwards = cloud.settings.FORWARD.get(env.host)
-    if not forwards:
-        return
+    """[H] Start port forwarding"""
+    import signal
+    import threading
 
-    for local_port, remote_port in forwards:
-        local('ssh -f -N -L %s:localhost:%s %s@%s' % (local_port, remote_port,
-            env.user, env.host), capture=False)
+    signal.signal(signal.SIGINT, _stop_forwarding)
 
+    _stop_forwarding.tunnels = []
+    padding = max([len(x) for x in env.hosts])
+    for host in env.hosts:
+        forwards = cloud.settings.FORWARD.get(host)
+        if not forwards:
+            continue
+        if not cloud.started(host):
+            print ('%s: %s' % (colors.yellow(host.rjust(padding),
+                bold=True), colors.red('not started. Cannot forward')))
+            continue
+
+        transport = connections[host]._transport
+        for local_port, remote_port in forwards:
+            print colors.green(':%s -> %s:%s' % (local_port, host,
+                remote_port))
+            tunnel = cloud.helpers.forward_tunnel(local_port, 'localhost',
+                    remote_port, transport)
+            _stop_forwarding.tunnels.append(tunnel)
+            thread = threading.Thread(target=tunnel.serve_forever)
+            thread.start()
+
+    while True:
+        time.sleep(5)
 
 def _hilight_hemp_tasks():
     for i in globals().values():

File hemp/helpers.py

 from contextlib import contextmanager
 import os
 import re
+import select
+import SocketServer
 from subprocess import Popen, PIPE
 from time import sleep
 import xml.etree.ElementTree as ElementTree
 
     state.close()
 
-
 def which(progname, path=None, test=None):
     """Like the 'which' command, but works on any type of file/dir, not
     just executables.  path is an interable of directories
                 return None
             return full_path
     return None
+
+class TCPServer(SocketServer.ThreadingTCPServer):
+    """Threading TCP Server"""
+    daemon_threads = True
+    allow_reuse_address = True
+
+def forward_tunnel(local_port, remote_host, remote_port, transport):
+    """
+    Using «transport», create a tunnel from *:«local_port» to
+    «remote_host»:«remote_port». Return the TCPServer instance to handle the
+    forwarding
+    """
+    class Handler(SocketServer.BaseRequestHandler):
+        """Handler class for the TCPServer"""
+        def handle(self):
+            channel = self.ssh_transport.open_channel('direct-tcpip',
+                    (self.chain_host, self.chain_port),
+                    self.request.getpeername())
+
+            while True:
+                rlist = select.select([self.request, channel], [], [])[0]
+                if self.request in rlist:
+                    data = self.request.recv(1024)
+                    if len(data) == 0:
+                        break
+                    channel.send(data)
+                if channel in rlist:
+                    data = channel.recv(1024)
+                    if len(data) == 0:
+                        break
+                    self.request.send(data)
+            channel.close()
+            self.request.close()
+
+    Handler.chain_host = remote_host
+    Handler.chain_port = remote_port
+    Handler.ssh_transport = transport
+
+    return TCPServer(('', local_port), Handler)