kahgean / kahgean /

# Copyright (C) 2012 Xue Can <> and contributors.
# Licensed under the MIT license:

from collections import Sequence

class Address(long):
    def __new__(cls, address):
        if isinstance(address, basestring):
            value = Address._init_from_seq(address.strip().split('.'))
        elif isinstance(address, Sequence):
            value = Address._init_from_seq(address)
        elif type(address) in [int, long]:
            value = Address._init_from_int(address)
            raise ValueError('unsupported address type')
        return super(Address, cls).__new__(cls, value)
    def __init__(self, address):
        value = long(self)
        v0 = value/2**24
        value = value-v0*2**24
        v1 = value/2**16
        value = value-v1*2**16
        v2 = value/2**8
        v3 = value-v2*2**8
        self._tuple_value = (v0, v1, v2, v3)
    def _init_from_int(value):
        if value<0 or value>2**32-1:
            raise ValueError('value should be in [0, 2**32-1]')
        return value
    def _init_from_seq(seq):
        if len(seq)!=4:
            raise ValueError('length of seq should be 4')
        v = [int(i) for i in seq]
        for i in v:
            if i<0 or i>2**8-1:
                raise ValueError('each piece in a IPv4 address should be in '
                                 '[0, 255]')
        return v[0]*2**24 + v[1]*2**16 + v[2]*2**8 + v[3]
    def tuple_value(self):
        return self._tuple_value
    def __str__(self):
        return '%d.%d.%d.%d' % self.tuple_value
    def __repr__(self):
        return 'Address <%s>' % self.__str__()

# Address(MASKS[24]) == Address('')
MASKS = [long((2**i-1)<<(32-i)) for i in range(0,33)]
COMPS = [long((2**32-1)-i) for i in MASKS]

class Network(Sequence):

    Network('', '24')
    Network('', '')
    Network('', '')
    def __init__(self, address, extra=None):
        if not extra:
            # '' => '', '24'
            address, extra = address.split('/')
        address = Address(address)
            prefix_size = int(extra)
            if prefix_size<0 or prefix_size>32:
                raise ValueError
        except ValueError:
            prefix_size = -1
        if prefix_size >= 0:
            # extra is a prefix size
            mask = MASKS[prefix_size]
            extra = Address(extra)
            if extra in MASKS:
                # extra is a subnetwork mask 
                mask = long(extra)
            elif extra-address in COMPS:
                # extra is a broadcast address
                mask = MASKS[COMPS.index(extra-address)]
                mask = -1
        if mask==-1:
            raise ValueError('subnetwork mask not found from <%s/%s>' % (
                str(address), str(extra)))
        index = MASKS.index(mask)
        comp = COMPS[index]
        if address & comp:
            raise ValueError('<%s/%s> is not a valid network prefix' % (
                str(address), str(extra)))
        self._network = address # network prefix
        self._size = index      # prefix size, e.g. 24 means a /24 subnetwork
        self._length = comp+1   # length of the network, for example there
                                #    are 256 addresses in a /24 subnetwork
        self._broadcast = Address(address+comp) # broadcast address
        self._mask = Address(mask)              # subnetwork mask

    def network(self):
        return self._network

    def broadcast(self):
        return self._broadcast

    def size(self):
        return self._size

    def mask(self):
        return self._mask
    def __contains__(self, value):
        if type(value) in [int, long, Address]:
            return value>=self._network and value<=self._broadcast
        elif type(value) in [Network]:
            return< \
                   and self.broadcast>=value.broadcast
    def __len__(self):
        return self._length

    def __getitem__(self, index):
        index = long(index)
        if index<0 and index>=self._length:
            raise IndexError('index out of range')
        return Address(self._network+index)

    def __iter__(self):
        i = 0
        while i<self._length:
            yield self[i]

    def index(self, value):
        if Address(value) in self:
            return value-self._network
        raise ValueError('%s not in this network' % str(value))

    def count(self, value):
        return 1 if Address(value) in self else 0

    def __reversed__(self):
        i = self._length
        while i>0:
            i -=1
            yield self[i]
    def __eq__(self, other):
        return and self.mask==other.mask
    def __ne__(self, other):
        return not self==other
    def __str__(self):
        return '%s/%d' % (str(self._network), self.size)
    def __repr__(self):
        return 'Network <%s>' % self.__str__()