# Source
#
# tornadis / redis.py
#
# Full commit
import socket
from collections import deque

from tornado import iostream, ioloop
import swirl

def _flatten(l):
    """Return one flat list holding the elements of every item in *l*.

    Only a single level of nesting is removed: each item of *l* is iterated
    once and its elements are appended in order.
    """
    return [element for group in l for element in group]

class RedisTornado(object):
    """Asynchronous Redis client built on a Tornado ``IOStream``.

    Commands are queued and processed one at a time: the next command is
    written only after the reply to the previous one has been read and
    handed to its callback.
    """

    def __init__(self, host=None, port=None, encoding='utf-8'):
        """
        Asynchronous Redis client using the Tornado IOLoop

        host     -- server host name; 'localhost' is used when falsy.
        port     -- server port; 6379 is used when falsy.
        encoding -- codec used to encode unicode arguments before sending.

        The TCP connection is opened here with a plain ``connect`` call
        before the socket is wrapped in an ``IOStream``.
        """
        self.host = host or 'localhost'
        self.port = port or 6379
        self.encoding = encoding
        self._queue = deque()       # pending (command string, callback) pairs
        self.callback = None        # not read anywhere in this class; kept as-is
        self._waiting = False       # True while _work() is handling a command
        self.DELIMITER = '\r\n'
        self.DELBYTES = len(self.DELIMITER)

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        sock.connect((self.host, self.port))

        self.stream = iostream.IOStream(sock)

    @swirl.asynchronous
    def _work(self):
        """Send the next queued command and dispatch its reply.

        Pops one (command, callback) pair, writes the command, reads the
        reply header line and translates it:

        * ``$<n>`` bulk reply -> the payload string, or None when n < 0 (nil)
        * ``+...`` status     -> True
        * ``-...`` error      -> False

        The callback (if one was given) receives that value, after which the
        method calls itself again to continue with the next queued command.
        When the queue is empty ``_waiting`` is cleared and nothing is sent.

        Each ``yield`` hands swirl a callable taking a callback; presumably
        swirl resumes the generator with whatever that callback receives --
        confirm against the swirl documentation.
        """
        self._waiting = True
        try:
            cmd, callback = self._queue.popleft()
        except IndexError:
            # Nothing left to send: allow the next enqueue to restart us.
            self._waiting = False
            return

        yield lambda cb: self.stream.write(cmd, cb)
        response = yield lambda cb: self.stream.read_until(self.DELIMITER, cb)

        reply_type = response[0]
        if reply_type == '$':
            # int() ignores surrounding whitespace, so a trailing delimiter
            # left on the header line does not need to be stripped first.
            num_bytes = int(response[1:])
            if num_bytes < 0:
                # "$-1" is the nil reply; no payload follows it.
                result = None
            else:
                # A zero-length bulk ("$0") is a real empty string and is
                # still followed by a delimiter that must be consumed,
                # otherwise every later reply would be read out of step.
                payload = yield lambda cb: self.stream.read_bytes(
                    num_bytes + self.DELBYTES, cb)
                result = payload[:-self.DELBYTES]
        elif reply_type == '+':
            result = True
        elif reply_type == '-':
            result = False
        else:
            # NOTE(review): other reply types (e.g. ':' integer, '*' multi-bulk)
            # are not parsed. As before, the callback is not invoked and
            # _waiting stays True, so the queue stops here. None of the
            # commands defined on this class should produce such replies.
            return

        # callback defaults to None in the public commands; only call it
        # when the caller actually supplied one.
        if callback is not None:
            callback(result)
        self._waiting = False
        self._work()

    def _encode(self, value):
        """Return *value* encoded with ``self.encoding`` if it is unicode.

        Any other value is returned unchanged.
        """
        if isinstance(value, unicode):
            return value.encode(self.encoding)
        return value

    def _cmd_creator(self, command, args):
        """Build the multi-bulk request string for *command* with *args*.

        The result is ``*<count>`` followed by one ``$<len>``/value pair for
        the command name and for each argument, every part terminated by
        CRLF. Unicode arguments are encoded first so the announced lengths
        are lengths of the encoded strings.
        """
        num_args = len(args) + 1  # add one for the command
        encoded = [self._encode(arg) for arg in args]
        args_string = ''.join(
            '$%d\r\n%s\r\n' % (len(arg), arg) for arg in encoded)
        return '*%d\r\n$%d\r\n%s\r\n%s' % (
            num_args, len(command), command, args_string)

    def _enqueue(self, cmd, callback):
        """Queue *cmd* with its *callback* and start the worker if idle."""
        self._queue.append((cmd, callback))
        if not self._waiting:
            self._work()

    def set(self, key, value, callback=None):
        """Queue ``SET key value``.

        *callback*, when given, is called with True for a status reply or
        False for an error reply.
        """
        self._enqueue(self._cmd_creator('SET', [key, value]), callback)

    def mset(self, kv_list, callback=None):
        """Queue ``MSET`` for every item in *kv_list*.

        Each item of *kv_list* is flattened into the argument list, so the
        items are expected to be (key, value) pairs. *callback*, when given,
        is called with True for a status reply or False for an error reply.
        """
        self._enqueue(self._cmd_creator('MSET', _flatten(kv_list)), callback)

    def get(self, key, callback=None):
        """Queue ``GET key``.

        *callback*, when given, is called with the stored string, None for
        a nil reply, or False for an error reply.
        """
        self._enqueue(self._cmd_creator('GET', [key]), callback)