# Source: nsocket/nsocket/core/niosocket.py

import errno
import fcntl
import os
import socket
import sys

from nsocket.core import cothread
from nsocket.core.engine import engine

BUFSIZE = 4096

_socket = socket.socket

CONNECT_ERROR = (errno.EINPROGRESS, errno.EALREADY, errno.EWOULDBLOCK)
CONNECT_SUCCESS = (0, errno.EISCONN)

def get_fd(fd):
    """Return the integer descriptor behind *fd*.

    Objects with a ``fileno`` attribute (sockets, files) are asked for it;
    anything else is assumed to already be a descriptor and returned as-is.
    """
    missing = object()
    fileno = getattr(fd, 'fileno', missing)
    if fileno is missing:
        return fd
    return fileno()

def set_nonblocking(fd):
    """Put *fd* into non-blocking mode.

    Socket-like objects (anything with ``setblocking``) are switched with
    ``setblocking(0)``.  Anything else -- a raw descriptor or an object with
    ``fileno()`` -- gets ``O_NONBLOCK`` added to its file status flags.
    """
    if hasattr(fd, 'setblocking'):
        fd.setblocking(0)
        return
    # F_GETFL/F_SETFL are provided by the fcntl module itself.
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

def set_blocking(fd):
    """Put *fd* back into blocking mode.

    Socket-like objects (anything with ``setblocking``) are switched with
    ``setblocking(1)``.  Anything else -- a raw descriptor or an object with
    ``fileno()`` -- has ``O_NONBLOCK`` cleared from its file status flags.
    """
    if hasattr(fd, 'setblocking'):
        fd.setblocking(1)
        return
    # F_GETFL/F_SETFL are provided by the fcntl module itself.
    flags = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, flags & ~os.O_NONBLOCK)

def nonblock_recv(func):
    """Build a cooperative ``recv`` method around a non-blocking reader.

    *func* is called as ``func(sock, size)`` and must return the data read,
    or ``None`` when the read would block (the contract of ``_nb_recv``).
    The returned ``recv(self, size)`` method:

    * serves up to *size* bytes from ``self.recvbuf`` first, if it is
      non-empty, without touching the socket;
    * otherwise calls *func* on ``self._socket`` and, while it returns
      ``None``, waits in ``io_switch`` for readability and retries;
    * treats an ``EPIPE`` raised while waiting as end-of-stream (``b''``);
    * adds the length of data read from the socket to ``self.recvcount``.
    """

    def recv(self, size):
        buffered = self.recvbuf
        if buffered:
            chunk, self.recvbuf = buffered[:size], buffered[size:]
            return chunk
        sock = self._socket
        data = func(sock, size)
        while data is None:
            try:
                # TODO: honour a per-socket timeout; io_switch does not use
                # its ``timeout`` argument yet.
                io_switch(sock, read=True, timeout=None)
            except socket.timeout:
                raise
            except socket.error as e:
                # e.args[0] is the errno on both Python 2 and 3.
                if e.args and e.args[0] == errno.EPIPE:
                    data = b''
                else:
                    raise
            else:
                data = func(sock, size)
        self.recvcount += len(data)
        return data

    return recv

def nonblock_send(func):
    """Build a ``send`` method around a non-blocking writer.

    *func* is called as ``func(sock, data)`` on ``self._socket``.  A falsy
    result (``None`` for would-block, or ``0``) is reported to the caller as
    ``0`` without waiting.  Otherwise the byte count is added to
    ``self.sendcount`` and returned.
    """

    def send(self, data):
        sent = func(self._socket, data)
        if sent:
            self.sendcount += sent
            return sent
        return 0

    return send

def io_switch(fd, read=None, write=None, timeout=None, remove=True):
    """Park the current coroutine until *fd* is readable and/or writable.

    Registers a callback with ``engine`` for each requested readiness event,
    then switches to the engine.  When the callback fires it switches back
    to the coroutine that called ``io_switch``.

    :param fd: socket-like object or raw descriptor (resolved by ``get_fd``).
    :param read: wait for readability.
    :param write: wait for writability.
    :param timeout: accepted for API compatibility; currently unused.
    :param remove: when true, the callback unregisters the watcher(s) this
        call installed before resuming.  ``accept`` passes ``False`` to keep
        its reader registered across retries.
    :returns: whatever ``engine.switch()`` returns.
    """
    waiter = cothread.getcurrent()
    fileno = get_fd(fd)

    def callback(_fd):
        if remove:
            if read:
                engine.remove_reader(fileno)
            if write:
                engine.remove_writer(fileno)
        waiter.switch()

    if read:
        engine.add_reader(fileno, callback)
    if write:
        engine.add_writer(fileno, callback)
    return engine.switch()


def _accept(sock):
    try:
        return sock.accept()
    except socket.error, e:
        if e[0] == errno.EWOULDBLOCK:
            return None
        raise

def _connect(sock, addr):
    """Attempt one non-blocking connect of *sock* to *addr*.

    Returns *sock* once connected, ``None`` while the connection is still in
    progress (see CONNECT_ERROR), and raises ``socket.error`` for any other
    ``connect_ex`` result.
    """
    ret = sock.connect_ex(addr)
    if ret in CONNECT_ERROR:
        return None
    if ret not in CONNECT_SUCCESS:
        # Report the errno with the OS message, as a blocking connect() would.
        raise socket.error(ret, os.strerror(ret))
    return sock

def _nb_recv(sock, bufsize):
    try:
        return sock.recv(bufsize)
    except socket.error, e:
        if e[0] == errno.EWOULDBLOCK:
            return None
        raise

def _nb_send(sock, data):
    try:
        return sock.send(data)
    except socket.error, e:
        if e[0] == errno.EWOULDBLOCK:
            return None
    raise

class NioSocket(object):
    """Cooperative wrapper around a non-blocking socket.

    The wrapped socket is either passed in as ``socket=`` or created from
    the remaining arguments with ``_socket``, the socket class saved at
    import time.  It is always put into non-blocking mode.  ``accept``,
    ``connect`` and ``recv`` wait for readiness through ``io_switch``
    instead of blocking.  ``send`` returns ``0`` rather than waiting when
    the socket is not writable.

    Instance attributes:
        _socket   -- the wrapped socket object.
        _fd       -- its descriptor, cached at construction.
        recvbuf   -- data that ``recv`` serves before touching the socket.
        recvcount -- bytes read from the socket by ``recv``.
        sendcount -- bytes written by ``send``.
        _nb       -- whether ``setblocking`` last requested non-blocking
                     mode (True initially).

    Several passthrough methods rebind themselves as instance attributes on
    first use (``fn = self.x = self._socket.x``), so later calls go straight
    to the wrapped socket.
    """

    def __init__(self, *args, **kwargs):
        sock = kwargs.pop('socket', None)
        if sock is None:
            # Forward keyword arguments as well (family=, type=, proto=) so
            # a patched socket.socket accepts the same call forms.
            sock = _socket(*args, **kwargs)
        set_nonblocking(sock)
        self._socket = sock
        self._fd = sock.fileno()
        self.recvbuf = None
        self.recvcount = 0
        self.sendcount = 0
        self._nb = True

    def setblocking(self, flag):
        """Pass *flag* to the wrapped socket and record the requested mode.

        Deliberately not cached on the instance like the other passthroughs,
        so ``_nb`` stays in sync on every call.
        """
        self._socket.setblocking(flag)
        self._nb = not flag

    def accept(self):
        """Wait for and accept one connection.

        Returns ``(NioSocket, address)``; the accepted socket is wrapped in
        the same class as ``self``.
        """
        sock = self._socket
        while True:
            res = _accept(sock)
            if res is not None:
                client, addr = res
                return type(self)(socket=client), addr
            # Keep the reader registered across retries (remove=False).
            io_switch(sock, read=True, remove=False)

    def bind(self, *args, **kwargs):
        """Enable SO_REUSEADDR (best effort), then bind the wrapped socket."""
        try:
            sock = self._socket
            sock.setsockopt(
                socket.SOL_SOCKET,
                socket.SO_REUSEADDR,
                sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
            )
        except socket.error:
            # Failing to set the option is not fatal; bind() may still work.
            pass
        fn = self.bind = self._socket.bind
        fn(*args, **kwargs)

    def close(self):
        """Close the wrapped socket and drop any engine watchers on its fd."""
        try:
            self._socket.close()
        finally:
            # accept() leaves its reader registered, so always clean up.
            engine.remove_reader(self._fd)
            engine.remove_writer(self._fd)

    def connect(self, address):
        """Connect to *address*, waiting for writability while in progress."""
        sock = self._socket
        while not _connect(sock, address):
            io_switch(sock, write=True)

    def connect_ex(self, *args, **kwargs):
        fn = self.connect_ex = self._socket.connect_ex
        return fn(*args, **kwargs)

    def dup(self, *args, **kw):
        sock = self._socket.dup(*args, **kw)
        set_nonblocking(sock)
        return type(self)(socket=sock)

    def fileno(self):
        return self._fd

    def getpeername(self):
        fn = self.getpeername = self._socket.getpeername
        return fn()

    def getsockname(self):
        fn = self.getsockname = self._socket.getsockname
        return fn()

    def getsockopt(self, *args, **kwargs):
        fn = self.getsockopt = self._socket.getsockopt
        return fn(*args, **kwargs)

    def setsockopt(self, *args, **kwargs):
        fn = self.setsockopt = self._socket.setsockopt
        return fn(*args, **kwargs)

    def listen(self, *args, **kwargs):
        fn = self.listen = self._socket.listen
        return fn(*args, **kwargs)

    def shutdown(self, *args, **kwargs):
        fn = self.shutdown = self._socket.shutdown
        return fn(*args, **kwargs)

    def makefile(self, mode = None, bufsize = None):
        # NioFile is currently a stub; mode and bufsize are ignored.
        return NioFile(self._socket)

    recv = nonblock_recv(_nb_recv)

    # Unimplemented placeholders: these perform no I/O and return None.
    # NOTE(review): the standard socket method names are recvfrom/sendall;
    # callers using those names get AttributeError.
    def recv_from(self, *args, **kwargs):
        pass

    def recvfrom_into(self, *args, **kwargs):
        pass

    def recv_into(self, *args, **kwargs):
        pass

    send = nonblock_send(_nb_send)

    def send_all(self, *args, **kwargs):
        pass

    def sendto(self, *args, **kwargs):
        pass

    def gettimeout(self):
        fn = self.gettimeout = self._socket.gettimeout
        return fn()

    def settimeout(self, value):
        # NOTE(review): settimeout(None) puts the wrapped socket in blocking
        # mode, which defeats the cooperative accept/recv loops -- confirm
        # callers never do this.
        fn = self.settimeout = self._socket.settimeout
        return fn(value)


class NioFile(object):
    """Placeholder file object handed out by ``NioSocket.makefile``.

    It only keeps a reference to the socket it was given; no read/write
    methods are implemented yet.
    """

    def __init__(self, sock):
        # TODO: implement buffered read/write on top of the stored socket.
        self._socket = sock




def install_nsocket():
    """Make ``socket.socket`` construct ``NioSocket`` instances.

    If no truthy ``socket`` entry exists in ``sys.modules``, the module is
    imported afresh, registered there, and patched.  The ``socket`` module
    object imported by this file is always patched.
    """
    targets = []
    if not sys.modules.get('socket', None):
        fresh = __import__('socket')
        targets.append(fresh)
        sys.modules['socket'] = fresh
    targets.append(socket)
    for module in targets:
        module.socket = NioSocket