sbt avatar sbt committed 9b43638 Draft

Add initial forkserver support

Comments (0)

Files changed (3)

Lib/multiprocessing/forking.py

     global _start_method
 
     if sys.platform == 'win32':
-        assert meth == 'spawn'
+        if meth != 'spawn':
+            raise ValueError('unrecognized start method %r' % meth)
     else:
-        assert meth in ('spawn', 'fork')
-        from .helper import CleanupHelper
+        if meth not in ('fork', 'spawn', 'forkserver'):
+            raise ValueError('unrecognized start method %r' % meth)
+        from . import helper
         if meth != 'fork':
-            CleanupHelper.start()
+            helper.CleanupHelper.start()
+        if meth == 'forkserver':
+            helper.ForkServerHelper.start()
     _start_method = meth
 
 def get_start_method():
 
 if sys.platform != 'win32':
     import _posixsubprocess
-    import fcntl
     WINEXE = False
     WINSERVICE = False
 
             self.returncode = None
             self._launch(process_obj)
 
+        def duplicate_for_child(self, fd):
+            self._fds.append(fd)
+            return fd
+
         def poll(self, flag=os.WNOHANG):
             if self.returncode is None:
                 try:
                 util.Finalize(self, os.close, (parent_r,))
                 self.sentinel = parent_r
 
+
     class PopenSpawn(PopenFork):
         def __init__(self, process_obj):
             self._fds = []
             return fd
 
         def _launch(self, process_obj):
-            from .helper import CleanupHelper
-            try:
-                self._fds.append(CleanupHelper.getfd())
-            except KeyError:
-                pass
+            from . import helper
+            self._fds.append(helper.CleanupHelper.getfd())
 
             prep_data = get_preparation_data(process_obj._name, False)
             with io.BytesIO() as fp:
                     if fd is not None:
                         os.close(fd)
 
+    class PopenForkServer(PopenFork):
+
+        def __init__(self, process_obj):
+            self._fds = []
+            PopenFork.__init__(self, process_obj)
+
+        def duplicate_for_child(self, fd):
+            self._fds.append(fd)
+            return fd
+
+        def _launch(self, process_obj):
+            from . import helper
+            self.sentinel, w = helper.ForkServerHelper.prepare_new_process()
+            util.Finalize(self, os.close, (self.sentinel,))
+            with open(w, 'wb', True) as f:
+                prep_data = get_preparation_data(process_obj._name, True)
+                _tls.spawning_popen = self
+                try:
+                    dump(prep_data, f, HIGHEST_PROTOCOL)
+                    dump(process_obj, f, HIGHEST_PROTOCOL)
+                finally:
+                    del _tls.spawning_popen
+            self.pid = helper.ForkServerHelper.read_ulong(self.sentinel)
+
+        def poll(self, flag=os.WNOHANG):
+            if self.returncode is None:
+                from .connection import wait
+                from .helper import ForkServerHelper
+                timeout = 0 if flag == os.WNOHANG else None
+                if not wait([self.sentinel], timeout):
+                    return None
+                try:
+                    self.returncode = ForkServerHelper.read_ulong(
+                        self.sentinel)
+                except (OSError, ValueError):
+                    # The process ended abnormally perhaps because of a signal
+                    self.returncode = 255
+            return self.returncode
+
 #
 # Windows
 #

Lib/multiprocessing/helper.py

 import sys
 import signal
 import socket
+import struct
 import threading
 
 from . import current_process
 from _multiprocessing import SemLock
 
 #
-# Support for cleaning up leaked named semaphores on Unix
+# Support for cleaning up leaked named semaphores on Unix when not using fork
 #
 
 class CleanupHelper(object):
-    @staticmethod
-    def getfd():
+    @classmethod
+    def getfd(cls):
         fd = current_process()._config.get('unlinkfd')
         if fd is None:
             raise RuntimeError('helper has not been started')
     @classmethod
     def start(cls):
         cp = current_process()
-
         if threading.active_count() > 1:
             raise RuntimeError('cannot start helper after threads started')
-
         if cp._config.get('unlinkfd') is not None:
             raise RuntimeError('helper already started')
 
     def unregister(cls, name):
         cls._send('UNREGISTER', name)
 
-    @staticmethod
-    def _send(cmd, name):
+    @classmethod
+    def _send(cls, cmd, name):
         msg = '{0}:{1}\n'.format(cmd, name).encode('ascii')
         if len(name) > 512:
             # posix guarantees that writes to a pipe of less than PIPE_BUF
         nbytes = os.write(fd, msg)
         assert nbytes == len(msg)
 
-    @staticmethod
-    def _run(r):
+    @classmethod
+    def _run(cls, r):
         # protect the process from ^C and "killall python" etc
         signal.signal(signal.SIGINT, signal.SIG_IGN)
         signal.signal(signal.SIGTERM, signal.SIG_IGN)
                     print('cleaning up semaphore %r' % name, file=sys.stderr)
                 except:
                     pass
+
+#
+#
+#
+
+class ForkServerHelper(object):
+    @classmethod
+    def getsock(cls):
+        sock = getattr(cls, '_client_sock')
+        if sock is None:
+            raise RuntimeError('helper has not been started')
+        return sock
+
+    @classmethod
+    def prepare_new_process(cls):
+        sock = cls.getsock()
+        parent_r, child_w = os.pipe()
+        child_r, parent_w = os.pipe()
+        try:
+            fds = struct.pack("@3i", child_r, child_w, sock.fileno())
+            sock.sendmsg([b'x'], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fds)])
+            # XXX acknowledge
+            return parent_r, parent_w
+        except:
+            os.close(parent_r)
+            os.close(parent_w)
+            raise
+        finally:
+            os.close(child_r)
+            os.close(child_w)
+
+    @classmethod
+    def read_ulong(cls, fd):
+        data = b''
+        length = struct.calcsize('L')
+        while len(data) < length:
+            s = os.read(fd, length - len(data))
+            if not s:
+                raise ValueError('incomplete read of unsigned long')
+            data += s
+        return struct.unpack('L', data)[0]
+
+    @classmethod
+    def write_ulong(cls, fd, n):
+        nbytes = os.write(fd, struct.pack('L', n))
+        assert struct.calcsize('@L') == nbytes
+
+    @classmethod
+    def start(cls):
+        cp = current_process()
+        if threading.active_count() > 1:
+            raise RuntimeError('cannot start helper after threads started')
+        if cp._config.get('forkserverfd') is not None:
+            raise RuntimeError('helper already started')
+
+        cls._client_sock, cls._server_sock = socket.socketpair(socket.AF_UNIX)
+        cls.pid = os.fork()
+        if cls.pid == 0:
+            try:
+                # close client socket -- child processes will
+                # overwrite with a new socket
+                cls._client_sock.close()
+
+                # close sys.stdin
+                if sys.stdin is not None:
+                    try:
+                        sys.stdin.close()
+                        sys.stdin = open(os.devnull)
+                    except (OSError, ValueError):
+                        pass
+
+                # ignoring SIGCHLD prevents zombie processes
+                signal.signal(signal.SIGCHLD, signal.SIG_IGN)
+                cls._run()
+            except Exception:
+                sys.excepthook(*sys.exc_info())
+            finally:
+                os._exit(0)
+        else:
+            cls._server_sock.close()
+
+    @classmethod
+    def _run(cls):
+        while True:
+            try:
+                cls._serve_one()
+            except Exception:
+                sys.excepthook(*sys.exc_info())
+
+    @classmethod
+    def _receive_fds(cls):
+        # get submitted fds from a client process
+        length = socket.CMSG_LEN(struct.calcsize('@3i'))
+        msg = cls._server_sock.recvmsg(1, length)
+        ancdata = msg[1]
+        if not ancdata:
+            # no client processes left
+            sys.exit(0)
+        cmsg_level, cmsg_type, cmsg_data = ancdata[0]
+        assert cmsg_level == socket.SOL_SOCKET
+        assert cmsg_type == socket.SCM_RIGHTS
+        return struct.unpack_from('@3i', cmsg_data)
+
+    @classmethod
+    def _serve_one(cls):
+        code = 1
+        child_r, child_w, fssock_fd = cls._receive_fds()
+        try:
+            if os.fork() == 0:
+                # this is the process started at the request of the client
+                try:
+                    # send pid to client processes
+                    cls.write_ulong(child_w, os.getpid())
+
+                    # reseed random number generator
+                    if 'random' in sys.modules:
+                        import random
+                        random.seed()
+
+                    # reset cls._client_sock so we can access ForkServer
+                    cls._client_sock = socket.fromfd(
+                        fssock_fd, socket.AF_UNIX, socket.SOCK_STREAM)
+                    os.close(fssock_fd)
+
+                    # run process object received over pipe
+                    from .forking import _main
+                    code = _main(child_r)
+
+                    # write the exit code to the pipe
+                    cls.write_ulong(child_w, code)
+                except Exception:
+                    sys.excepthook(*sys.exc_info())
+                finally:
+                    os._exit(code)
+        finally:
+            os.close(fssock_fd)
+            os.close(child_r)
+            os.close(child_w)

Lib/test/test_multiprocessing.py

             self.assertEqual('', err.decode('ascii'))
 
 #
+# Check that spawning results in unneeded fds being closed
+#
+
+class TestCloseFds(unittest.TestCase):
+    @classmethod
+    def _test_closefds(cls, conn, fd):
+        try:
+            os.close(fd)
+        except OSError as e:
+            conn.send(True)               # expect EBADF
+        else:
+            conn.send(False)
+
+    def test_closefds(self):
+        if WIN32 or multiprocessing.get_start_method() != 'spawn':
+            raise unittest.SkipTest('only valid for spawn method on unix')
+        reader, writer = multiprocessing.Pipe()
+        fd, _ = os.pipe()
+        os.close(_)
+        try:
+            p = multiprocessing.Process(target=self._test_closefds,
+                                        args=(writer, fd))
+            p.start()
+            res = reader.recv()
+            p.join()
+        finally:
+            os.close(fd)
+            writer.close()
+            reader.close()
+        self.assertTrue(res)
+
+#
 #
 #
 
 testcases_other = [OtherTest, TestInvalidHandle, TestInitializers,
                    TestStdinBadfiledescriptor, TestWait, TestInvalidFamily,
-                   TestFlags, TestTimeouts, TestNoForkBomb]
+                   TestFlags, TestTimeouts, TestNoForkBomb, TestCloseFds]
 
 #
 #
 def test_main(run=None):
     if sys.argv[1:] == ['--spawn']:
         multiprocessing.set_start_method('spawn')
+    elif sys.argv[1:] == ['--forkserver']:
+        multiprocessing.set_start_method('forkserver')
     elif sys.argv[1:] == ['--fork']:
         multiprocessing.set_start_method('fork')
     elif sys.argv[1:] != []:
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.