piotrlegnica / stuff

Various stuff and utilities.

Clone this repository (size: 50.9 KB): HTTPS / SSH
$ hg clone http://bitbucket.org/piotrlegnica/stuff/
commit 14: 5c0fd3e2d281
parent 13: 69b9bc502315
branch: default
tags: tip
Spelling and other fun useless stuff. Refactored writeColour a bit.
piotrlegnica
7 days ago
stuff / dns.py
r14:5c0fd3e2d281 197 loc 5.8 KB embed / history / annotate / raw /
# -*- coding: utf-8 -*-
import sys, os, weakref, socket, re
import transaction
from ZODB import FileStorage, DB
from BTrees import OOBTree
from persistent import Persistent
from persistent.list import PersistentList
from persistent.mapping import PersistentMapping

# Open the ZODB database stored in ./dns.fs (a path relative to the current
# working directory) and keep short module-level handles used by the commands:
#   t - the transaction module, c - the open connection, r - its root mapping.
t  = transaction
db = DB(FileStorage.FileStorage('./dns.fs'))
c  = db.open()
r  = c.root()

# A single DNS label: ASCII letters, digits and hyphens (hyphen placement is
# checked separately by the callers).  Raw strings avoid the invalid "\-"
# string escape, and \Z (unlike $) does not accept a trailing newline.
DOMAIN_RE = re.compile(r'^[a-z0-9\-]+\Z', re.I)
# Dotted-quad shape only; the 0-255 range of each octet is checked elsewhere.
IP_RE     = re.compile(r'^(?:[0-9]{1,3}\.){3}[0-9]{1,3}\Z')

def verifyDomain(domain):
    """Validate a domain name (or relative subdomain) and return it unchanged.

    The name must not look like a dotted-quad IPv4 address and must be at
    most 255 characters long.  Every dot-separated label must match
    DOMAIN_RE, be at most 63 characters long and neither start nor end with
    a hyphen.  Empty labels (``a..b``, a trailing dot, or an empty name) are
    rejected because DOMAIN_RE requires at least one character.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if IP_RE.match(domain):
        raise ValueError('domain is required, but IP given: %s' % domain)

    if len(domain) > 255:
        raise ValueError('domain "%s" too long' % domain)

    for label in domain.split('.'):
        # RFC 1035 section 2.3.4: a single label is limited to 63 octets.
        if (not DOMAIN_RE.match(label) or len(label) > 63
                or label.startswith('-') or label.endswith('-')):
            raise ValueError('invalid label "%s" in domain "%s"' % (label, domain))

    return domain

def verifyIP(ip):
    """Return *ip* unchanged if it is a valid dotted-quad IPv4 address.

    The string must match IP_RE (four dot-separated groups of 1-3 digits)
    and every group must lie in the range 0-255.

    Raises:
        ValueError: if the address is malformed or an octet is out of range.
    """
    # Short-circuit: the octet range check only runs once the shape matched.
    valid = IP_RE.match(ip) and all(0 <= int(part) <= 255 for part in ip.split('.'))
    if not valid:
        raise ValueError('invalid IP: %s' % ip)
    return ip

def verifyShort(uint, m, minimum=0, maximum=65535):
    """Return *uint* if it fits the inclusive range [minimum, maximum].

    Used for MX priorities and SRV port/priority/weight values, which are
    unsigned 16-bit fields (0-65535).  Zero is a legitimate value there (the
    most preferred MX host, or an SRV weight of 0), so the default lower
    bound is 0.  Callers may tighten the bounds via *minimum*/*maximum*.

    Args:
        uint: integer value to check.
        m: human-readable field name used in the error message.
        minimum: smallest accepted value (inclusive).
        maximum: largest accepted value (inclusive).

    Raises:
        ValueError: if *uint* lies outside [minimum, maximum].
    """
    if not minimum <= uint <= maximum:
        raise ValueError('invalid %s: %d' % (m, uint))

    return uint

class Domain(Persistent):
    """A zone: a domain name plus the records attached to its subdomains.

    ``subdomains`` maps a relative prefix to a list of record objects.  Keys
    are stored with a trailing dot (``'www.'``, ``'*.'``) so that
    ``prefix + name`` is the fully qualified name; the empty string denotes
    the domain itself.
    """

    def __init__(self, name):
        self.name       = verifyDomain(name)
        self.subdomains = PersistentMapping()

    def toData(self):
        """Render every record of this domain as tinydns-data lines.

        Records are grouped into ``#``-comment sections: one ``###`` section
        for the domain and one ``##`` section per subdomain.  Subdomains are
        emitted in sorted order so the generated file is stable between runs.
        """
        parts = ['### Domain: %s\n\n' % self.name]
        # Keys are unique, so sorting (key, records) pairs never compares lists.
        for subdomain, records in sorted(self.subdomains.items()):
            fullDomain = subdomain + self.name
            parts.append('## Subdomain: %s\n' % fullDomain)
            for record in records:
                parts.append('# %s IN %s\n' % (fullDomain, record.__class__.__name__))
                parts.append(record.toData(fullDomain) + '\n')
            # Close the subdomain section once, after all of its records.
            parts.append('##\n\n')
        parts.append('###\n\n')
        return ''.join(parts)

    def add(self, subdomain, record):
        """Attach *record* to *subdomain* (relative to this domain).

        ``'*'`` and any non-empty name are normalised to end with a dot; a
        non-empty name other than ``'*'`` is validated first.  ``''`` means
        the domain itself.

        Raises:
            ValueError: if *subdomain* is not a valid name (see verifyDomain).
        """
        if subdomain == '*' or (len(subdomain) > 0 and verifyDomain(subdomain) and not subdomain.endswith('.')):
            subdomain += '.'

        if subdomain not in self.subdomains:
            self.subdomains[subdomain] = PersistentList()

        self.subdomains[subdomain].append(record)
    
# DNS records
class A(Persistent):
    """IPv4 address record."""

    def __init__(self, ip):
        # verifyIP raises ValueError for anything but a dotted-quad address.
        self.ip = verifyIP(ip)

    def toData(self, fullDomain):
        """Return the tinydns-data line ``+fqdn:ip`` for this record."""
        fields = {'fqdn': fullDomain, 'ip': self.ip}
        return '+%(fqdn)s:%(ip)s' % fields
    
class NS(Persistent):
    """Name-server delegation record."""

    def __init__(self, ns):
        # The name server must be given as a host name, not an IP address.
        self.ns = verifyDomain(ns)

    def toData(self, fullDomain):
        """Return the tinydns-data line ``.fqdn::ns`` (no IP field)."""
        fields = {'fqdn': fullDomain, 'ns': self.ns}
        return '.%(fqdn)s::%(ns)s' % fields

class CNAME(Persistent):
    """Alias record pointing the owner name at another domain."""

    def __init__(self, domain):
        # Target of the alias; validated like any other domain name.
        self.domain = verifyDomain(domain)

    def toData(self, fullDomain):
        """Return the tinydns-data line ``Cfqdn:target`` for this record."""
        fields = {'fqdn': fullDomain, 'target': self.domain}
        return 'C%(fqdn)s:%(target)s' % fields

class MX(Persistent):
    """Mail exchanger record: target host plus numeric priority."""

    def __init__(self, domain, priority):
        # The host name is validated before the priority is parsed, so an
        # invalid host is reported even when the priority is also bad.
        self.domain   = verifyDomain(domain)
        numeric       = int(priority)
        self.priority = verifyShort(numeric, 'priority')

    def toData(self, fullDomain):
        """Return the tinydns-data line ``@fqdn::host:priority``."""
        fields = {'fqdn': fullDomain, 'mx': self.domain, 'dist': self.priority}
        return '@%(fqdn)s::%(mx)s:%(dist)d' % fields

class SRV(Persistent):
    """Service locator record (RFC 2782) written as a tinydns generic record.

    ``toData`` emits ``:_service._protocol.fqdn:33:rdata`` where 33 is the
    SRV record type and *rdata* is the wire-format payload (priority, weight,
    port, target name) with every byte written as an octal escape.
    """

    def __init__(self, service, protocol, domain, port, priority, weight):
        self.service  = verifyDomain(service)
        self.protocol = verifyDomain(protocol)
        self.domain   = verifyDomain(domain)
        self.port     = verifyShort(int(port), 'port')
        self.weight   = verifyShort(int(weight), 'weight')
        self.priority = verifyShort(int(priority), 'priority')

    def toOctal(self, num):
        """Return 16-bit *num* as two octal escapes, most significant byte first.

        DNS wire format is big-endian.  The bytes are extracted arithmetically
        rather than through ``socket.htons``: htons is the identity on
        big-endian hosts, so splitting its result produced little-endian
        output there.
        """
        return '\\%03o\\%03o' % ((num >> 8) & 0xff, num & 0xff)

    def toData(self, fullDomain):
        """Return the tinydns-data generic line for this record."""
        # Target name in wire format: each label prefixed by its length, the
        # whole name terminated by the zero-length root label (\000 below).
        dest = ''.join(['\\%03o%s' % (len(x), x) for x in self.domain.split('.')])
        return ':_%s._%s.%s:33:%s%s%s%s\\000' % (
            self.service, self.protocol, fullDomain, self.toOctal(self.priority),
            self.toOctal(self.weight), self.toOctal(self.port), dest
        )

def fail(msg):
    """Report *msg* on stderr and terminate the process with exit status 1.

    Diagnostics go to stderr so they never mix with stdout, which the
    ``data`` command uses for the generated tinydns data file.  The exit is
    raised as SystemExit, so callers' exception handlers still see it.
    """
    sys.stderr.write('%s\n' % msg)
    sys.exit(1)

def getArg(n, msg):
    """Return command-line argument ``sys.argv[n]``.

    When the argument is missing, fail(msg) is called, which exits.
    """
    args = sys.argv
    try:
        value = args[n]
    except IndexError:
        value = fail(msg)
    return value

def getGlobal(g, msg):
    """Return the module-level object named *g*.

    When no such global exists, fail(msg) is called, which exits.
    """
    namespace = globals()
    if g in namespace:
        return namespace[g]
    return fail(msg)

# The first command-line argument names the sub-command (dispatched below).
command = getArg(n=1, msg='dns <cmd> <args>')

def cmdDom():
    """Handle ``dns dom add|del <domain>``: create or remove a Domain.

    Each change is recorded as a note on the current transaction.  An
    unknown subcommand or an invalid domain name is reported via fail().
    """
    subcmd = getArg(2, 'subcommand (add/del) required')
    domain = getArg(3, 'domain required')

    if subcmd == 'add':
        if domain in r:
            fail('domain exists')
        try:
            r[domain] = Domain(domain)
        except ValueError as e:
            # Domain() validates the name; report it instead of a traceback.
            fail(str(e))
        t.get().note('added domain: %s' % domain)
    elif subcmd == 'del':
        if domain not in r:
            fail('domain does not exist')
        del r[domain]
        t.get().note('removed domain: %s' % domain)
    else:
        # Reject anything else rather than silently committing a no-op.
        fail('unknown subcommand: %s' % subcmd)

def cmdRec():
    """Handle ``dns rec <domain> <type> <subdomain> <args...>``.

    Builds a record of class *type* (A, NS, CNAME, MX or SRV) from the
    remaining command-line arguments and attaches it to *subdomain* of
    *domain*.  Wrong argument counts, invalid values and unknown domains or
    record types are reported via fail().
    """
    domain = getArg(2, 'domain required')
    record = getArg(3, 'record type required').upper()
    subdom = getArg(4, 'subdomain required')

    # Only genuine record classes may be instantiated; otherwise any
    # upper-case module global (e.g. DB, IP_RE) could be called with user input.
    if record not in ('A', 'NS', 'CNAME', 'MX', 'SRV'):
        fail('unknown record type: %s' % record)
    klass  = getGlobal(record, 'unknown record type: %s' % record)

    if domain not in r:
        fail('domain %s does not exist' % domain)

    try:
        rec = klass(*sys.argv[5:])
    except TypeError:
        # Wrong number of arguments: list what the constructor expects.
        import inspect
        getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        spec = getspec(klass.__init__)
        fail('required args: %s' % ', '.join(spec.args[1:]))
    except ValueError as e:
        # Validation failure: bad IP/domain, non-numeric or out-of-range value.
        fail(str(e))

    try:
        r[domain].add(subdom, rec)
    except ValueError as e:
        # Invalid subdomain name.
        fail(str(e))
    t.get().note('added %s for %s.%s' % (record, subdom, domain))

def cmdData():
    """Handle ``dns data``: write the tinydns data file to stdout.

    A comment banner with the generation time is followed by the output of
    every Domain stored in the database root, in sorted name order so that
    regenerated files only differ where the data changed.
    """
    import time
    parts = [
        '##################################\n',
        '## tinydns data file            ##\n',
        '## generated: %17s ##\n' % time.strftime('%d-%m-%Y, %H:%M'),
        '##################################\n\n',
    ]
    for name in sorted(r.keys()):
        domain = r[name]
        # The root may hold other objects; only Domain instances are rendered.
        if isinstance(domain, Domain):
            parts.append(domain.toData())

    # Trailing newline matches what the original ``print data`` emitted.
    sys.stdout.write(''.join(parts) + '\n')

# Resolve the sub-command handler, e.g. "dom" -> cmdDom, "data" -> cmdData.
f = getGlobal('cmd%s' % command.title(), 'unknown cmd: %s' % command)

try:
    t.begin()
    f()
    t.commit()
except BaseException:
    # Deliberately broad: this also catches SystemExit raised by fail(), so a
    # failed command never commits a half-applied transaction.
    t.abort()
    raise
finally:
    # Close cleanly so FileStorage can persist its index and release locks.
    c.close()
    db.close()