Commits

Sybren Stüvel committed 812d745

Removed some fluff, rewritten some stuff, broken the lot

Comments (0)

Files changed (4)

 
 import rsa
 
-(pub, priv) = rsa.newkeys(64)
+keysize = 64 # bits
+(pub, priv) = rsa.newkeys(keysize)
 
 print "Testing integer operations:"
 
 message = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
 print "\tMessage:   %s" % message
 
-encrypted = rsa.encrypt(message, pub)
+encrypted = rsa.encrypt(message, pub, 17)
 print "\tEncrypted: %s" % encrypted
 
 decrypted = rsa.decrypt(encrypted, priv)
 __date__ = "2010-02-08"
 __version__ = '2.1-beta0'
 
-import rsa.prime
-import rsa.transform
-import rsa.common
+import functools
+
+from rsa import transform
+from rsa import common
 
 from rsa.keygen import newkeys
 from rsa.core import encrypt_int, decrypt_int
 
-def encode64chops(chops):
-    """base64encodes chops and combines them into a ',' delimited string"""
+def get_blocks(message, block_size):
+    '''Generator, yields the blocks of the message.'''
+    
+    msglen = len(message)
+    blocks = msglen / block_size
 
-    # chips are character chops
-    chips = [rsa.transform.int2str64(chop) for chop in chops]
-
-    # delimit chops with comma
-    encoded = ','.join(chips)
-
-    return encoded
-
-def decode64chops(string):
-    """base64decodes and makes a ',' delimited string into chops"""
-
-    # split chops at commas
-    chips = string.split(',')
-
-    # make character chips into numeric chops
-    chops = [rsa.transform.str642int(chip) for chip in chips]
-
-    return chops
-
-def block_size(n):
-    '''Returns the block size in bytes, given the public key.
-
-    The block size is determined by the 'n=p*q' component of the key.
-    '''
-
-    # Set aside 2 bits so setting of safebit won't overflow modulo n.
-    nbits = rsa.common.bit_size(n) - 2
-    nbytes = nbits / 8
-
-    return nbytes
-
-
-def chopstring(message, key, n, int_op):
-    """Chops the 'message' into integers that fit into n.
-    
-    Leaves room for a safebit to be added to ensure that all messages fold
-    during exponentiation. The MSB of the number n is not independent modulo n
-    (setting it could cause overflow), so use the next lower bit for the
-    safebit. Therefore this function reserves 2 bits in the number n for
-    non-data bits.
-
-    Calls specified encryption function 'int_op' for each chop before storing.
-
-    Used by 'encrypt' and 'sign'.
-    """
-
-
-    nbytes = block_size(n)
-
-    msglen = len(message)
-    blocks = msglen / nbytes
-
-    if msglen % nbytes > 0:
+    if msglen % block_size > 0:
         blocks += 1
 
-    cypher = []
-    
     for bindex in range(blocks):
-        offset = bindex * nbytes
-        block = message[offset:offset + nbytes]
+        offset = bindex * block_size
+        yield message[offset:offset + block_size]
 
-        value = rsa.transform.bytes2int(block)
-        to_store = int_op(value, key, n)
-
-        cypher.append(to_store)
-
-    return encode64chops(cypher)   #Encode encrypted ints to base64 strings
-
-def gluechops(string, key, n, funcref):
-    """Glues chops back together into a string.  calls
-    funcref(integer, key, n) for each chop.
-
-    Used by 'decrypt' and 'verify'.
-    """
-
-    messageparts = []
-    chops = decode64chops(string)  #Decode base64 strings into integer chops
-
-    nbytes = block_size(n)
-    
-    for chop in chops:
-        value = funcref(chop, key, n) #Decrypt each chop
-        block = rsa.transform.int2bytes(value)
-
-        # Pad block with 0-bytes until we have reached the block size
-        blocksize = len(block)
-        padsize = nbytes - blocksize
-        if padsize < 0:
-            raise ValueError('Block larger than block size (%i > %i)!' %
-                    (blocksize, nbytes))
-        elif padsize > 0:
-            block = '\x00' * padsize + block
-
-        messageparts.append(block)
-
-    # Combine decrypted strings into a msg
-    return ''.join(messageparts)
-
-def encrypt(message, key):
+def encrypt(message, key, block_size):
     """Encrypts a string 'message' with the public key 'key'"""
     if 'n' not in key:
         raise Exception("You must use the public key with encrypt")
 
-    return chopstring(message, key['e'], key['n'], encrypt_int)
+    op = functools.partial(encrypt_int, ekey=key['e'], n=key['n'])
+
+    print 'E  : %i (%i bytes)' % (key['e'], transform.byte_size(key['e']))
+    print 'N  : %i (%i bytes)' % (key['n'], transform.byte_size(key['n']))
+
+    blocks = get_blocks(message, block_size)
+    crypto = list(transform.block_op(blocks, block_size, op))
+
+    return ''.join(crypto)
 
 def sign(message, key):
     """Signs a string 'message' with the private key 'key'"""
     if 'p' not in key:
         raise Exception("You must use the private key with sign")
 
-    return chopstring(message, key['d'], key['p']*key['q'], encrypt_int)
+#    return chopstring(message, key['d'], key['p']*key['q'], encrypt_int)
 
 def decrypt(cypher, key):
     """Decrypts a string 'cypher' with the private key 'key'"""
     if 'p' not in key:
         raise Exception("You must use the private key with decrypt")
 
-    return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
+#    return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
 
 def verify(cypher, key):
     """Verifies a string 'cypher' with the public key 'key'"""
     if 'n' not in key:
         raise Exception("You must use the public key with verify")
 
-    return gluechops(cypher, key['e'], key['n'], decrypt_int)
+#    return gluechops(cypher, key['e'], key['n'], decrypt_int)
 
 # Do doctest if we're not imported
 if __name__ == "__main__":
     if not type(message) is types.LongType:
         raise TypeError("You must pass a long or int")
 
-    if message < 0 or message > n:
-        raise OverflowError("The message is too long")
+    if message < 0:
+        raise ValueError('Only non-negative numbers are supported')
+         
+    if message > n:
+        raise OverflowError("The message %i is too long for n=%i" % (message, n))
 
+    # TODO: reinstate safebit
     #Note: Bit exponents start at zero (bit counts start at 1) this is correct
-    safebit = rsa.common.bit_size(n) - 2        # compute safe bit (MSB - 1)
-    message += (1 << safebit)                   # add safebit to ensure folding
+    #safebit = rsa.common.bit_size(n) - 2        # compute safe bit (MSB - 1)
+    #message += (1 << safebit)                   # add safebit to ensure folding
 
     return pow(message, ekey, n)
 
 
     message = pow(cyphertext, dkey, n)
 
-    safebit = rsa.common.bit_size(n) - 2        # compute safe bit (MSB - 1)
-    message -= (1 << safebit)                   # remove safebit before decode
+    # TODO: reinstate safebit
+    #safebit = rsa.common.bit_size(n) - 2        # compute safe bit (MSB - 1)
+    #message -= (1 << safebit)                   # remove safebit before decode
 
     return message
 
 def bit_size(number):
     """Returns the number of bits required to hold a specific long number"""
 
+    if number < 0:
+        raise ValueError('Only nonnegative numbers possible: %s' % number)
+
+    if number == 0:
+        return 1
+    
     return int(math.ceil(math.log(number, 2)))
 
+def byte_size(number):
+    """Returns the number of bytes required to hold a specific long number.
+    
+    The number of bytes is rounded up.
+    """
+
+    return int(math.ceil(bit_size(number) / 8.0))
+
 def bytes2int(bytes):
     """Converts a list of bytes or an 8-bit string to an integer.
 
     if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
         raise TypeError("You must pass a string or a list")
 
+    
     # Convert byte stream to integer
     integer = 0
     for byte in bytes:
 
     # Do some bounds checking
     if block_size is not None:
-        needed_bytes = int(math.ceil(bit_size(number) / 8.0))
+        needed_bytes = byte_size(number)
         if needed_bytes > block_size:
             raise OverflowError('Needed %i bytes for number, but block size '
                 'is %i' % (needed_bytes, block_size))
 
     return padding + ''.join(bytes)
 
+def block_op(block_provider, block_size, operation):
+    r'''Generator, applies the operation on each block and yields the result
+    
+    Each block is converted to a number, the given operation is applied and then
+    the resulting number is converted back to a block of data. The resulting
+    block is yielded.
+    
+    @param block_provider: an iterable that iterates over the data blocks.
+    @param block_size: the used block size
+    @param operation: a function that accepts an integer and returns an integer 
+    
+    >>> blocks = ['\x00\x01\x02', '\x03\x04\x05']
+    >>> list(block_op(blocks, 3, lambda x: (x + 6)))
+    ['\x00\x01\x08', '\x03\x04\x0b']
+    
+    '''
 
-def to64(number):
-    """Converts a number in the range of 0 to 63 into base 64 digit
-    character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.
-    
-    >>> to64(10)
-    'A'
-
-    """
-
-    if not (type(number) is types.LongType or type(number) is types.IntType):
-        raise TypeError("You must pass a long or an int")
-
-    if 0 <= number <= 9:            #00-09 translates to '0' - '9'
-        return chr(number + 48)
-
-    if 10 <= number <= 35:
-        return chr(number + 55)     #10-35 translates to 'A' - 'Z'
-
-    if 36 <= number <= 61:
-        return chr(number + 61)     #36-61 translates to 'a' - 'z'
-
-    if number == 62:                # 62   translates to '-' (minus)
-        return chr(45)
-
-    if number == 63:                # 63   translates to '_' (underscore)
-        return chr(95)
-
-    raise ValueError(u'Invalid Base64 value: %i' % number)
-
-
-def from64(number):
-    """Converts an ordinal character value in the range of
-    0-9,A-Z,a-z,-,_ to a number in the range of 0-63.
-    
-    >>> from64(49)
-    1
-
-    """
-
-    if not (type(number) is types.LongType or type(number) is types.IntType):
-        raise TypeError("You must pass a long or an int")
-
-    if 48 <= number <= 57:         #ord('0') - ord('9') translates to 0-9
-        return(number - 48)
-
-    if 65 <= number <= 90:         #ord('A') - ord('Z') translates to 10-35
-        return(number - 55)
-
-    if 97 <= number <= 122:        #ord('a') - ord('z') translates to 36-61
-        return(number - 61)
-
-    if number == 45:               #ord('-') translates to 62
-        return(62)
-
-    if number == 95:               #ord('_') translates to 63
-        return(63)
-
-    raise ValueError(u'Invalid Base64 value: %i' % number)
-
-
-def int2str64(number):
-    """Converts a number to a string of base64 encoded characters in
-    the range of '0'-'9','A'-'Z,'a'-'z','-','_'.
-    
-    >>> int2str64(123456789)
-    '7MyqL'
-
-    """
-
-    if not (type(number) is types.LongType or type(number) is types.IntType):
-        raise TypeError("You must pass a long or an int")
-
-    string = ""
-
-    while number > 0:
-        string = "%s%s" % (to64(number & 0x3F), string)
-        number /= 64
-
-    return string
-
-
-def str642int(string):
-    """Converts a base64 encoded string into an integer.
-    The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_'
-    
-    >>> str642int('7MyqL')
-    123456789
-
-    """
-
-    if not (type(string) is types.ListType or type(string) is types.StringType):
-        raise TypeError("You must pass a string or a list")
-
-    integer = 0
-    for byte in string:
-        integer *= 64
-        if type(byte) is types.StringType: byte = ord(byte)
-        integer += from64(byte)
-
-    return integer
-
+    for block in block_provider:
+        number = bytes2int(block)
+        print 'In : %i (%i bytes)' % (number, byte_size(number))
+        after_op = operation(number)
+        print 'Out: %i (%i bytes)' % (after_op, byte_size(after_op))
+        yield int2bytes(after_op, block_size)
 
 if __name__ == '__main__':
     import doctest
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.