# Source: session-cart / session_cart / cart.py (full commit)

'''Session based shopping cart'''


class CartItem(object):
    '''Lightweight container for cart items

    :param index: position of this entry in its container (``Cart`` passes
        ``len(cart)`` at insertion time).
    :param item: the wrapped object.
    :param quantity: amount of ``item``; any falsy value (``None``, ``0``)
        is stored as ``0``.

    An entry compares equal to another entry wrapping an equal item, and to
    the bare item itself.
    '''
    # 'index' must be listed here: __init__ assigns it, and slotted
    # instances have no __dict__ to fall back on.
    __slots__ = ('index', 'item', 'quantity')

    def __init__(self, index, item, quantity):
        self.index = index
        self.item = item
        self.quantity = quantity or 0

    def __repr__(self):
        return 'CartItem(%r, %r)' % (self.item, self.quantity)

    def __eq__(self, other):
        if isinstance(other, CartItem):
            return self.item == other.item
        return self.item == other

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Must agree with __eq__: equal entries, and an entry and its bare
        # item, have to hash alike.
        return hash(self.item)

    def __cmp__(self, other):
        # Python 2 ordering hook; Python 3 never calls __cmp__ (and has no
        # builtin ``cmp``), so equality there comes from __eq__ above.
        if isinstance(other, CartItem):
            return cmp(self.item, other.item)
        return cmp(self.item, other)


class Cart(list):
    '''Handles a list of items stored in the session

    Subclasses must set ``model`` to a model class exposing
    ``objects.get(pk=...)`` and a ``DoesNotExist`` exception (Django-style).
    The cart is persisted in ``request.session[key]`` as a sequence of
    ``(item_pk, quantity)`` pairs.
    '''
    model = None
    key = 'cart'

    def __init__(self, request, items=None):
        '''Build a cart for ``request``.

        :param request: object with a dict-like ``session`` attribute.
        :param items: optional iterable of entries, each with ``item`` and
            ``quantity`` attributes. When given, the cart is built from them
            and the session is not read.
        '''
        super(Cart, self).__init__()
        self.request = request
        if items is not None:
            for item in items:
                self.add_item(item.item, item.quantity)
        else:
            # Cart is stored as a sequence of (item_pk, quantity) pairs
            for prod, quantity in request.session.get(self.key, []):
                try:
                    self.add_item(prod, quantity)
                except self.model.DoesNotExist:
                    # The stored item no longer exists; drop it from the
                    # cart rather than failing to load the whole cart.
                    pass

    def save(self):
        '''Save this cart to the session as (item_pk, quantity) pairs'''
        self.request.session[self.key] = tuple(
            (i.item.pk, i.quantity,)
            for i in self
        )

    def add_item(self, item, quantity=1):
        '''Add ``quantity`` of ``item`` to this cart.

        :param item: a ``model`` instance, or a primary key that is looked
            up with ``model.objects.get``.
        :param quantity: amount to add; merged into the existing line when
            the same item is already in the cart.
        :raises model.DoesNotExist: presumably, when ``item`` is a key with
            no matching object (``__init__`` handles this case).
        '''
        if not isinstance(item, self.model):
            item = self.model.objects.get(pk=item)
        # Dupe checking: compare the wrapped items directly instead of
        # relying on CartItem's comparison hooks. A ``list.index`` lookup
        # depended on ``__cmp__``, which Python 3 ignores, so it never
        # matched and duplicate lines were appended.
        for entry in self:
            if entry.item == item:
                entry.quantity += quantity
                return
        self.append(CartItem(len(self), item, quantity))

    def size(self):
        '''Return the number of items in this cart'''
        return len(self)

    def empty(self):
        '''Remove all items from cart and save the empty cart'''
        del self[:]
        self.save()

    def __repr__(self):
        return ','.join(str(x) for x in self)