# Source
#
# pypesocket / pypesocketTest.py

#!/usr/bin/env python

import cStringIO
import errno
import functools
import os
import re
import socket
import time
import unittest
from threading import Thread

import pypesocket

def server(response=None):
    """Decorate a test method so it runs while ``self.server`` is started.

    ``self.server.start(response)`` is called before the test body and
    ``self.server.stop()`` afterwards -- in a ``finally`` so the server is
    torn down even when the test fails or errors.

    :param response: forwarded unchanged to ``self.server.start``.
    """
    def decorator(func):
        # wraps() keeps the test's __name__/__doc__, which unittest uses
        # for reporting (shortDescription reads the docstring).
        @functools.wraps(func)
        def wrapper(self):
            self.server.start(response)
            try:
                func(self)
            finally:
                self.server.stop()
        return wrapper
    return decorator


def expect_socketerror(errcode, errtext=None):
    """Decorator requiring the test method to raise ``socket.error``.

    A stand-in for ``assertRaisesRegexp`` on Pythons that lack it.  The
    decorated test passes only when the wrapped method raises
    ``socket.error`` whose first argument equals *errcode* and, when
    *errtext* is given, whose message (second argument) contains a match
    for *errtext* -- ``re.search`` semantics, as ``assertRaisesRegexp`` uses.

    :param errcode: expected errno value, compared with ``e.args[0]``.
    :param errtext: optional regex, as a pattern string or compiled pattern.
    """
    # Compile anything that is not already a compiled pattern (covers both
    # str and unicode patterns on Python 2).
    if errtext is not None and not hasattr(errtext, 'search'):
        errtext = re.compile(errtext)

    def test_decorator(func):
        @functools.wraps(func)
        def test_decorated(self, *args, **kwargs):
            try:
                func(self, *args, **kwargs)
            except socket.error as e:
                err = e.args
                # An error may carry fewer than two args; don't IndexError.
                self.assertEqual(errcode, err[0] if err else None)
                if errtext is not None:
                    pattern = getattr(errtext, 'pattern', errtext)
                    if len(err) < 2:
                        self.fail('socket.error %r has no message to match %r'
                                  % (err, pattern))
                    self.assertTrue(errtext.search(err[1]),
                                    '%r does not match %r' % (err[1], pattern))
            else:
                self.fail('socket.error not raised (expected errno %r)'
                          % (errcode,))
        return test_decorated
    return test_decorator


def recv_all(sock):
    """Read from *sock* until the peer stops sending, then shut down reading.

    Data is read in chunks of up to 65536 bytes until ``recv`` returns an
    empty string.  The read side of *sock* is then shut down; ``ENOTCONN``
    from that shutdown is ignored.

    :param sock: a connected socket (or socket-like object).
    :returns: everything received, concatenated.
    :raises socket.error: from ``recv``, or from ``shutdown`` for any errno
        other than ``ENOTCONN``.  If ``recv`` itself fails, errors from the
        follow-up ``shutdown`` are suppressed so the ``recv`` error surfaces.
    """
    chunks = []
    finished = False
    try:
        while True:
            chunk = sock.recv(65536)
            if not chunk:
                break
            chunks.append(chunk)
        finished = True
    finally:
        try:
            sock.shutdown(socket.SHUT_RD)
        except socket.error as err:
            # ENOTCONN: the connection is already gone, nothing to shut down.
            # Other errors are only re-raised on the success path so they
            # never mask an exception coming out of recv().
            if finished and err.args[0] != errno.ENOTCONN:
                raise
    # b'' is a plain str on Python 2, so the result type is unchanged there.
    return b''.join(chunks)


def echo(input):
    """Response handler that replies with the request exactly as received."""
    reply = input
    return reply


_PipePath = "PipeSocketTest.sock"


class Server(object):
    """Minimal AF_UNIX test server listening on ``_PipePath``.

    ``start()`` binds the socket and spawns a thread that accepts connections
    in a loop; ``stop()`` closes the listener and joins that thread.
    """

    def start(self, response=None):
        """Bind and listen on ``_PipePath`` and serve on a new thread.

        :param response: ``None`` to accept each connection and close it at
            once; otherwise a callable that receives everything the client
            sent (read with ``recv_all``) and returns the data to send back,
            after which the write side is shut down and the connection closed.
        """
        def handle(conn):
            # Always close the accepted connection, even if recv/send fails.
            try:
                # The wake-up connection made by stop() gets no reply.
                if response and not self.closing:
                    conn.sendall(response(recv_all(conn)))
                    conn.shutdown(socket.SHUT_WR)
            finally:
                conn.close()

        def run():
            try:
                while True:
                    conn, _ = self.socket.accept()
                    handle(conn)
            except socket.error as err:
                # EBADF: stop() closed the listening socket -- normal exit.
                if err.args[0] != errno.EBADF:
                    raise

        self.closing = False
        listener = socket.socket(socket.AF_UNIX)
        try:
            listener.bind(_PipePath)
            listener.listen(5)
        except socket.error:
            listener.close()
            raise
        self.socket = listener

        self.serverthread = Thread(target=run)
        self.serverthread.start()

    def stop(self):
        """Close the listening socket and wait for the server thread to end."""
        self.closing = True
        self.socket.close()
        # Interrupt socket.accept if it has already been invoked: connect
        # once so a blocked accept() returns; best effort only.
        try:
            waker = socket.socket(socket.AF_UNIX)
            try:
                waker.connect(_PipePath)
            finally:
                waker.close()
        except socket.error:
            pass
        self.serverthread.join()


class PipesocketTest(unittest.TestCase):
    """Base fixture: a fresh ``Server`` per test and no leftover pipe file."""

    def setUp(self):
        # An aborted earlier run can leave the socket file behind, which can
        # make the next bind() fail -- start every test from a clean slate.
        self._remove_pipe()
        self.server = Server()

    def tearDown(self):
        self._remove_pipe()

    @staticmethod
    def _remove_pipe():
        """Best-effort removal of the socket file at ``_PipePath``."""
        try:
            os.unlink(_PipePath)
        except OSError:
            pass  # e.g. the file does not exist


class ConnectionTest(PipesocketTest):
    """connect()/connect_ex() against a running and an absent server."""

    @server()
    def test_connect(self):
        for _ in range(10):
            sock = socket.socket(socket.AF_UNIX)
            try:
                sock.connect(_PipePath)
            finally:
                sock.close()

    @server()
    def test_connect_ex(self):
        for _ in range(10):
            sock = socket.socket(socket.AF_UNIX)
            try:
                self.assertEqual(0, sock.connect_ex(_PipePath))
            finally:
                sock.close()

    @expect_socketerror(errno.ENOENT)
    def test_connect_fail(self):
        # No @server here, so nothing is listening on the pipe path.
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.connect(_PipePath)
        finally:
            sock.close()

    def test_connect_ex_fail(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            err = sock.connect_ex(_PipePath)
            self.assertEqual(errno.ENOENT, err)
        finally:
            sock.close()

    @server(echo)
    def test_echo(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.connect(_PipePath)
            sock.sendall('hello')
            # Half-close so the server's recv_all() sees end of input.
            sock.shutdown(socket.SHUT_WR)
            self.assertEqual('hello', recv_all(sock))
        finally:
            sock.close()


class GetPeerNameTest(PipesocketTest):
    """getpeername() on client, bound, listening, accepted and broken sockets."""

    @server(echo)
    def test_getpeername(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.connect(_PipePath)
            self.assertEqual(_PipePath, sock.getpeername())
            sock.shutdown(socket.SHUT_WR)
            time.sleep(0.1)  # give the server thread time to handle it
        finally:
            sock.close()

    @expect_socketerror(errno.ENOTCONN)
    def test_getpeername_server_wo_listen(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.bind(_PipePath)
            sock.getpeername()
        finally:
            sock.close()

    @expect_socketerror(errno.ENOTCONN)
    def test_getpeername_server_wo_connect(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.bind(_PipePath)
            sock.listen(5)
            sock.getpeername()
        finally:
            sock.close()

    def test_getpeername_server(self):
        svr = socket.socket(socket.AF_UNIX)
        try:
            svr.bind(_PipePath)
            svr.listen(5)
            time.sleep(0.2)  # wait server to be ready
            cli = socket.socket(socket.AF_UNIX)
            try:
                cli.connect(_PipePath)
                sock, addr = svr.accept()
                try:
                    self.assertEqual('', sock.getpeername())
                    self.assertEqual('', addr)
                finally:
                    sock.close()
            finally:
                cli.close()
        finally:
            svr.close()

    @expect_socketerror(errno.ENOTCONN)
    def test_getpeername_wo_connect(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.getpeername()
        finally:
            sock.close()

    @server()
    @expect_socketerror(errno.EINVAL)
    def test_getpeername_on_broken(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.connect(_PipePath)
            time.sleep(0.1)  # wait server to close
            # Expected to raise, so close() must live in the finally clause.
            sock.getpeername()
        finally:
            sock.close()


class GetSockNameTest(PipesocketTest):
    """getsockname() on client, bound and broken sockets."""

    @server(echo)
    def test_getsockname(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            self.assertEqual('', sock.getsockname())
            sock.connect(_PipePath)
            self.assertEqual('', sock.getsockname())
            sock.shutdown(socket.SHUT_WR)
            time.sleep(0.1)  # give the server thread time to handle it
        finally:
            sock.close()

    def test_getsockname_server(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.bind(_PipePath)
            self.assertEqual(_PipePath, sock.getsockname())
        finally:
            sock.close()

    @server()
    def test_getsockname_on_broken(self):
        sock = socket.socket(socket.AF_UNIX)
        try:
            sock.connect(_PipePath)
            time.sleep(0.1)  # wait server to close
            self.assertEqual('', sock.getsockname())
        finally:
            sock.close()

if __name__ == '__main__':
    # Script entry point: hand every TestCase in this module to unittest.
    unittest.main()