# Source: amqpev/amqpev/binops.py

import datetime
import struct
import time
from re import finditer


def pack_splice(spec, fields):
    """
    Enhanced version of struct.pack which supports splicing arbitrary-length
    binary data, as well as packing contiguous strings of bits.

    Adds the following format characters:

      ':' - Indicates a splice operation. The corresponding datum in the
            argument list should be a byte string, and will be spliced wholly
            into the output at this location.
      '.' - Indicates a packed bit. Each occurrence will consume one datum from
            the argument list, and strings of contiguous bits will be packed
            into bytes, working from LSB to MSB order within each byte.
            Runs longer than eight bits continue into further bytes.

    Every character of ``spec`` must consume exactly one datum from
    ``fields``: repeat counts (e.g. '3B') and pad bytes ('x') are not
    supported. All struct fields are packed in network (big-endian) order.
    Data left over after the last format character are ignored.

    Returns the packed bytes (a plain ``str`` on Python 2).
    """
    values = list(fields)
    spec_list = list(spec)

    # Collapse each run of up to eight '.' into one 'B' whose value has the
    # run's bits OR-ed in LSB first. Work right-to-left so that replacing a
    # run never shifts the offsets of the runs still to be processed.
    bit_vectors = [(m.start(), m.end()) for m in finditer(r'\.{1,8}', spec)]
    for start, end in reversed(bit_vectors):
        byte = 0
        for i, bit in enumerate(values[start:end]):
            if bit:
                byte |= 1 << i
        values[start:end] = [byte]
        spec_list[start:end] = ['B']

    # Walk the ':'-separated chunks, consuming values by index. Each ':'
    # (i.e. the start of every chunk after the first) takes one value that is
    # copied verbatim; the chunk's own characters are handed to struct.pack.
    parts = []
    pos = 0
    for chunk_idx, chunk in enumerate(''.join(spec_list).split(':')):
        if chunk_idx > 0:
            parts.append(values[pos])
            pos += 1
        if chunk:
            nr_elems = len(chunk)
            parts.append(struct.pack('!' + chunk, *values[pos:pos + nr_elems]))
            pos += nr_elems

    # b''.join keeps assembly linear and yields bytes on Python 3 (str on
    # Python 2), so struct output and spliced byte strings can be mixed.
    return b''.join(parts)


def unpack_splice(spec, buf, return_consumed_bytes=False):
    """
    Basically the inverse of pack_splice(), with one extra restriction: Each
    occurrence of a splice (':') must be preceded by a format character that
    produces an integer when unpacked. The value of this integer gives the
    number of bytes to slice from the packed buffer, starting at this point.
    Unpacking then continues with the byte following the slice.

    This means that binary formats that include splices are required to be
    preceded by their length in the packed data. Note however that for
    compatibility with pack_splice(), the length will appear in the output
    list.

    Each run of '.' characters (up to eight per byte) is read as one byte and
    expanded back into a list of bools, LSB first. Bit runs are located by
    their character offset in ``spec``, so every character before a bit run
    must yield exactly one value (no repeat counts, no 'x' padding ahead of
    a '.').

    Bytes in ``buf`` after the last field are ignored.

    Returns the list of unpacked values, or a (values, consumed_bytes) tuple
    when ``return_consumed_bytes`` is true.

    Raises struct.error if ``buf`` is too short for a fixed-size chunk or for
    a splice, and ValueError if a ':' is not preceded by any value.
    """
    bit_vectors = [(m.start(), m.end()) for m in finditer(r'\.{1,8}', spec)]

    # Replace each bit run with a single 'B', remembering how many bits it
    # held. Right-to-left so the recorded start offsets stay valid.
    spec_list = list(spec)
    bit_unpack_map = {}
    for start, end in reversed(bit_vectors):
        spec_list[start:end] = ['B']
        bit_unpack_map[start] = end - start

    fields = []
    offset = 0
    for chunk_idx, chunk in enumerate(''.join(spec_list).split(':')):
        if chunk_idx > 0:
            # A ':' separated this chunk from the previous one: slice out as
            # many bytes as the most recently unpacked value says.
            if not fields:
                raise ValueError("splice ':' in %r is not preceded by a "
                                 "length field" % (spec,))
            unsplice_len = fields[-1]
            end = offset + unsplice_len
            if unsplice_len < 0 or end > len(buf):
                raise struct.error("splice of %d bytes at offset %d exceeds "
                                   "buffer of %d bytes"
                                   % (unsplice_len, offset, len(buf)))
            fields.append(buf[offset:end])
            offset = end

        # unpack_from reads in place at `offset` instead of re-slicing buf.
        fmt = '!' + chunk
        fields.extend(struct.unpack_from(fmt, buf, offset))
        offset += struct.calcsize(fmt)

    # Expand each bit byte back into bools. Ascending order matters: once the
    # run at spec offset i is expanded in place, later list indices line up
    # with spec offsets again, so the next run's spec offset is its index.
    for idx in sorted(bit_unpack_map):
        nr_bits = bit_unpack_map[idx]
        byte = fields[idx]
        fields[idx:idx + 1] = [bool(byte & (1 << i)) for i in range(nr_bits)]

    if return_consumed_bytes:
        return (fields, offset)
    return fields


def pack_str(val):
    """
    Prepare a string for a length-prefixed splice.

    Returns a (length, val) tuple with the length first, so the pair can be
    fed straight into a 'B:' or 'L:' spec for pack_splice().
    """
    length = len(val)
    return length, val

def unpack_str(size, data):
    """
    Inverse of pack_str(): discard the length prefix delivered by a 'B:' or
    'L:' splice and return the spliced data unchanged.
    """
    payload = data
    return payload

def pack_decimal(val):
    """
    Placeholder encoder for an AMQP decimal (scale octet + integer value).

    TODO: real scale handling. The scale is always 0 and ``val`` is passed
    through int(), so any fractional part is truncated away.
    """
    scale = 0
    return scale, int(val)

def unpack_decimal(scale, value):
    """
    Placeholder decoder for an AMQP decimal.

    TODO: apply ``scale``. Currently ``scale`` is ignored and the raw
    integer ``value`` is returned, so e.g. (2, 314) decodes to 314 rather
    than 3.14.
    """
    raw_value = value
    return raw_value

def pack_fieldarray(val):
    """
    Placeholder encoder for an AMQP field array.

    TODO: ``val`` is ignored; an empty array (length 0, empty data) is
    always produced.
    """
    empty_length, empty_data = 0, ''
    return empty_length, empty_data

def unpack_fieldarray(size, data):
    """
    Placeholder decoder for an AMQP field array.

    TODO: the element list is not parsed; the raw array bytes (already
    stripped of their length prefix by the splice) are returned as-is.
    """
    raw = data
    return raw

def pack_table(table_data):
    """
    Encode a dict as an AMQP field table.

    Each entry is written as a short-string field name ('B:') followed by a
    one-byte type tag and the tagged value. Supported value types:

      bool              -> 't' (one byte, 0 or 1)
      str               -> 'S' (long string, 32-bit length)
      int / long        -> smallest fitting integer type (see _fit_integer)
      float             -> 'd' (64-bit double)
      datetime.datetime -> 'T' (64-bit seconds since the epoch, computed
                           with time.mktime, i.e. interpreted as local time)
      dict              -> 'F' (nested table, 32-bit length)
      None              -> 'V' (void, no payload)

    Entries whose value has any other type are skipped entirely (name and
    value), so the resulting table stays well-formed.

    Returns a (length, packed_bytes) tuple, ready for an 'L:' splice.
    """
    coding = ''
    packing_list = []

    for field_name, val in table_data.items():
        encoded = _pack_table_value(val)
        if encoded is None:
            # Unsupported type: drop the whole entry. Writing the name with
            # no typed value after it would corrupt every following entry.
            continue
        value_coding, value_fields = encoded
        coding += 'B:' + value_coding
        packing_list += list(pack_str(field_name)) + value_fields

    buf = pack_splice(coding, packing_list)
    return (len(buf), buf)


def _pack_table_value(val):
    """Return (pack_splice coding, fields) for one table value, or None."""
    # bool is a subclass of int, so it must be tested before the int branch
    # or True/False would be sent as integers instead of AMQP booleans.
    if isinstance(val, bool):
        return 'cB', ['t', int(val)]
    # FIXME This works around a bug in rabbit_binary_parser:parse_table
    # where datatype "s" is interpreted as a int16 integer, and not a
    # string with a uint8 size as specified. So to get along with Rabbit,
    # we always use long strings ('S') instead of short strings ('s').
    if isinstance(val, str):
        return 'cL:', ['S'] + list(pack_str(val))
    if isinstance(val, (int, long)):
        amqp_code, pack_code = _fit_integer(val)
        return 'c' + pack_code, [amqp_code, val]
    if isinstance(val, float):
        return 'cd', ['d', val]
    if isinstance(val, datetime.datetime):
        # time.mktime is the inverse of datetime.fromtimestamp used when
        # decoding 'T'; it returns a float, but 'Q' needs an integer.
        return 'cQ', ['T', int(time.mktime(val.timetuple()))]
    if isinstance(val, dict):
        return 'cL:', ['F'] + list(pack_table(val))
    if val is None:
        return 'c', ['V']
    return None

def unpack_table(size, buf):
    """
    Decode an AMQP field table into a dict.

    ``size`` is accepted so this can serve as the converter for an 'L:'
    splice (length, data), but it is not used: all of ``buf`` is decoded.

    Each entry is a short-string name followed by a one-byte type tag; the
    tag selects the spec and converter from TABLE_DATATYPE_CODING. An
    unknown tag raises KeyError.

    NOTE(review): the tag lookup expects buf[0] to be a one-character str,
    i.e. a Python 2 byte string.
    """
    decoded = {}
    remaining = buf

    while remaining:
        name_fields, name_len = unpack_splice('B:', remaining,
                return_consumed_bytes=True)
        key = unpack_str(*name_fields)
        remaining = remaining[name_len:]

        value_coding, convert = TABLE_DATATYPE_CODING[remaining[0]]

        # The leading 'x' skips over the type tag byte itself.
        value_fields, value_len = unpack_splice('x' + value_coding, remaining,
                return_consumed_bytes=True)
        remaining = remaining[value_len:]

        decoded[key] = convert(*value_fields)

    return decoded


# Maps each AMQP field-table type tag (the byte preceding a value on the
# wire) to (unpack_splice spec for the value, callable turning the unpacked
# fields into a Python value). unpack_table() prepends 'x' to the spec to
# skip the tag byte itself.
TABLE_DATATYPE_CODING = {
    't': ('B', bool),                  # boolean
    'b': ('b', int),                   # signed 8-bit
    'B': ('B', int),                   # unsigned 8-bit
    'U': ('h', int),                   # signed 16-bit
    'u': ('H', int),                   # unsigned 16-bit
    'I': ('l', int),                   # signed 32-bit
    'i': ('L', int),                   # unsigned 32-bit
    'L': ('q', int),                   # signed 64-bit
    'l': ('Q', int),                   # unsigned 64-bit
    'f': ('f', float),                 # 32-bit float
    'd': ('d', float),                 # 64-bit double
    'D': ('BL', unpack_decimal),       # decimal: scale octet + 32-bit value
    's': ('B:', unpack_str),           # short string, 8-bit length prefix
    'S': ('L:', unpack_str),           # long string, 32-bit length prefix
    'A': ('L:', unpack_fieldarray),    # field array, 32-bit length prefix
    'T': ('Q', datetime.datetime.fromtimestamp),  # 64-bit epoch seconds
    # Nested tables carry a 32-bit length prefix, matching what pack_table()
    # emits for dict values ('cL:') and the AMQP long-uint table length.
    'F': ('L:', unpack_table),
    'V': ('', lambda: None) }          # void: no payload, decodes to None

# Candidate integer encodings for field-table values, tried in order by
# _fit_integer(). Each entry is (min value, max value, AMQP field-table type
# tag, struct format character). Narrower types come first so the tightest
# fit wins; within each width the signed type is tried before the unsigned.
INT_RANGES = [
        (              -0x80,               0x7F, 'b', 'b'),
        (               0x00,               0xFF, 'B', 'B'),
        (            -0x8000,             0x7FFF, 'U', 'h'),
        (             0x0000,             0xFFFF, 'u', 'H'),
        (        -0x80000000,         0x7FFFFFFF, 'I', 'l'),
        (         0x00000000,         0xFFFFFFFF, 'i', 'L'),
        (-0x8000000000000000, 0x7FFFFFFFFFFFFFFF, 'L', 'q'),
        ( 0x0000000000000000, 0xFFFFFFFFFFFFFFFF, 'l', 'Q') ]

def _fit_integer(val):
    """
    Choose the first INT_RANGES entry whose [lo, hi] interval holds ``val``.

    Returns an (amqp_type_tag, struct_format_char) pair. Raises OverflowError
    when no entry fits.
    """
    fits = ((amqp_code, pack_code)
            for lo, hi, amqp_code, pack_code in INT_RANGES
            if lo <= val <= hi)
    chosen = next(fits, None)
    if chosen is None:
        raise OverflowError("Can't encode an integer that large.")
    return chosen