Commits

Kelvin Wong committed 3edcc9c Draft

Added test suite to setup.py. python setup.py test now runs the test suite. Removed old test file as it was unimportable. Old tests moved to new test suite. Moved value variable to top of functions in both C files to allow statis analysis testing tools. Added additional bounds checking in C files to avoid crashing.

Comments (0)

Files changed (10)

 .DS_Store
 MANIFEST
 dist
+.pyc
+.so
 #!/usr/bin/env python
-from distutils.core import setup, Extension
+from distutils.core import setup, Extension, Command
 
 import sys
 import platform
 
 includes = []
 library_dirs = []
+cmdclasses = dict()
 CFLAGS = []
 
+
+class Tester(Command):
+    """Runs unit tests"""
+
+    user_options = []
+
+    def initialize_options(self):
+        pass
+
+    def finalize_options(self):
+        pass
+
+    def run(self):
+        if ((sys.version_info > (3, 2, 0, 'final', 0)) or
+            (sys.version_info > (2, 7, 0, 'final', 0) and sys.version_info < (3, 0, 0, 'final', 0))):
+            from unittest import TextTestRunner, defaultTestLoader
+        else:
+            try:
+                from unittest2 import TextTestRunner, defaultTestLoader
+            except ImportError:
+                print("Please install unittest2 to run the test suite")
+                exit(-1)
+        from tests import test_scrypt, test_scrypt_py2x, test_scrypt_py3x
+        suite = defaultTestLoader.loadTestsFromModule(test_scrypt)
+        suite.addTests(defaultTestLoader.loadTestsFromModule(test_scrypt_py2x))
+        suite.addTests(defaultTestLoader.loadTestsFromModule(test_scrypt_py3x))
+        runner = TextTestRunner(verbosity=2)
+        result = runner.run(suite)
+
+cmdclasses['test'] = Tester
+
 if sys.platform.startswith('linux'):
     define_macros = [('HAVE_CLOCK_GETTIME', '1'),
                      ('HAVE_LIBRT', '1'),
       url='http://bitbucket.org/mhallin/py-scrypt',
       ext_modules=[scrypt_module],
       classifiers=['Development Status :: 4 - Beta',
-                   'Programming Language :: Python :: 2',
+                   'Intended Audience :: Developers',
+                   'License :: OSI Approved :: BSD License',
+                   'Programming Language :: Python :: 2.6',
+                   'Programming Language :: Python :: 2.7',
                    'Programming Language :: Python :: 3',
                    'Topic :: Security :: Cryptography',
                    'Topic :: Software Development :: Libraries'],
       license='2-clause BSD',
-      long_description=open('README.markdown').read())
+      long_description=open('README.markdown').read(),
+      cmdclass=cmdclasses)
     double maxmemfrac = g_maxmemfrac_default_enc;
     double maxtime = g_maxtime_default_enc;
     uint8_t *outbuf;
+    PyObject *value = NULL;
 
     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "SS|dnd", g_kwlist,
                                      &input, &password,
     Py_DECREF(password);
     Py_DECREF(input);
 
-    PyObject *value = NULL;
     if (errorcode != 0) {
         PyErr_Format(ScryptError, "%s", g_error_codes[errorcode]);
         PyErr_SetNone(ScryptError);
     double maxmemfrac = g_maxmemfrac_default;
     double maxtime = g_maxtime_default;
     uint8_t *outbuf;
+    PyObject *value = NULL;
 
     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "SS|dnd", g_kwlist,
                                      &input, &password,
     Py_DECREF(password);
     Py_DECREF(input);
 
-    PyObject *value = NULL;
     if (errorcode != 0) {
         PyErr_Format(ScryptError, "%s", g_error_codes[errorcode]);
     } else {
     unsigned long int p = 1;
     unsigned char *outbuf;
     int outbuflen;
+    PyObject *value = NULL;
 
     static char *g2_kwlist[] = {"password", "salt", "N", "r", "p", NULL};
 
 
     Py_BEGIN_ALLOW_THREADS;
 
-    if ( r * p >= (1 << 30) || N <= 1 || (N & (N-1)) != 0) {
+    if ( r * p >= (1 << 30) || r < 1 || p < 1 || N <= 1 || (N & (N-1)) != 0) {
         paramerror = -1;
         hasherror = 0;
     } else {
     Py_DECREF(password);
     Py_DECREF(salt);
 
-    PyObject *value = NULL;
     if (paramerror != 0) {
         PyErr_Format(ScryptError, "%s",
             "hash parameters are wrong (r*p should be < 2**30, and N should be a power of two > 1)");
     double maxmemfrac = g_maxmemfrac_default_enc;
     double maxtime = g_maxtime_default_enc;
     uint8_t *outbuf;
+    PyObject *value = NULL;
 
     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#s#|dnd", g_kwlist,
                                      &input, &inputlen, &password, &passwordlen,
                               maxmem, maxmemfrac, maxtime);
     Py_END_ALLOW_THREADS;
 
-    PyObject *value = NULL;
     if (errorcode != 0) {
         PyErr_Format(ScryptError, "%s", g_error_codes[errorcode]);
         PyErr_SetNone(ScryptError);
     double maxmemfrac = g_maxmemfrac_default;
     double maxtime = g_maxtime_default;
     uint8_t *outbuf;
+    PyObject *value = NULL;
 
     if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s#s#|dnd", g_kwlist,
                                      &input, &inputlen, &password, &passwordlen,
                               maxmem, maxmemfrac, maxtime);
     Py_END_ALLOW_THREADS;
 
-    PyObject *value = NULL;
     if (errorcode != 0) {
         PyErr_Format(ScryptError, "%s", g_error_codes[errorcode]);
     } else {
     unsigned long p = 1;
     unsigned char *outbuf;
     int outbuflen;
+    PyObject *value = NULL;
 
     static char *g2_kwlist[] = {"password", "salt", "N", "r", "p", NULL};
 
 
     Py_BEGIN_ALLOW_THREADS;
 
-    if ( r * p >= (1 << 30) || N <= 1 || (N & (N-1)) != 0) {
+    if ( r * p >= (1 << 30) || r < 1 || p < 1 || N <= 1 || (N & (N-1)) != 0) {
         paramerror = -1;
         hasherror = 0;
     } else {
 
     Py_END_ALLOW_THREADS;
 
-    PyObject *value = NULL;
     if (paramerror != 0) {
         PyErr_Format(ScryptError, "%s",
             "hash parameters are wrong (r*p should be < 2**30, and N should be a power of two > 1)");

tests/ciphertexts.csv

+input,password,maxtime,maxmem,maxmemfrac,ciphertext
+message,password,0.01,0,0.0625,736372797074000a00000008000000019f6d3fe5e9423a12d330e35089befdfbb476c7d4faea91492a2561f942c1599701aba424220218b9f81812df06d7cf2a281fd0fdbc7c9d978c335bf5209b1062ee2e49993c4d3a37d347ad6bf0eaecc2fe579a6f320b0acf475882c222c0ba34a7ac5379bedc82358bb3f736ea31d4b824e8bc95c75579
+message,password,5.0000,0.0000,0.1250,736372797074000a00000008000000cb974d55992fea307caa3593205c8851cd56b3ddfd241ee7b1c075cfd2e2f871dddcee71b6bc2b6b075caa1699ea58ce32e9fcd802b18069828201692380574e23e84c2d39d6a951c9c2401dea1a0fa44195b01fca9332f7aac223c84251c69a28037788e09cf297a003a62e2e7c5b6f039e3c1d21fa5da6
+message,password,0.5000,0.0000,0.0625,736372797074000a00000008000000141429bb22aef840e1dde29564dada8f2a77fc2855fe8c9e27e5bf5df4a0eec330344e8471bf83e2466a28acc9d1813a3e50e64697fb8e9c7c17954950f16a5b78fe114d8b147c5936802fe52b17f83e6da9de884257f2a938aa37c2e92b1c33cd9afd0f5e91baba29be4b1c0709bd2a31a4c52394f9b0d4
+message,password,0.5000,0.0000,0.1250,736372797074000a0000000800000014e9baa49b57b2e5b561df54121d870532714f063f295c15ca48e32b677e8e11126ab8da24683951749a19eede36ea3768cb01e8eda6c0f570db7051e07e99d2d687ce28f96aa2e6afa8a0a8c5d902c847e041e14435a22d873c4e676f7a7f5ab9107171b583d02724c056707640931a3c160b9e0ef87bae
+message,password,0.5000,10485760.0000,0.1250,736372797074000a0000000800000014e53c55cd89b6e2de639c1621d92e79212e5f4fc793387435486427eb8adeb192382b5b6397f98ac66fa528a13cba01bd611082369fe1e2244a01ba8329ff6fcaa2db5e87fb75683ae6fc636d873c2e004a0209b947c7b8e805174d65ec0ac2cb597baadae6cf092c33a5096590860b51570faa89e39bf0

tests/hashvectors.csv

+password,salt,n,r,p,hexhash
+,,16,1,1,77d6576238657b203b19ca42c18a0497f16b4844e3074ae8dfdffa3fede21442fcd0069ded0948f8326a753a0fc81f17e8d3e0fb2e0d3628cf35e20c38d18906
+password,NaCl,1024,8,16,fdbabe1c9d3472007856e7190d01e9fe7c6ad7cbc8237830e77376634b3731622eaf30d92e22a3886ff109279d9830dac727afb94a83ee6d8360cbdfa2cc0640
+pleaseletmein,SodiumChloride,16384,8,1,7023bdcb3afd7348461c06cd81fd38ebfda8fbba904f8e3ea9b543f6545da1f2d5432955613f0fcf62d49705242a9af9e61e85dc0d651e40dfcf017b45575887
+pleaseletmein,SodiumChloride,32768,16,4,cbc397a9b5f5a53048c5b9f039ee1246d9532c8089fb346a4ab47cd0701febf18652b1ee042e070d1b6c631c43fd05ececd5b165ee1c2ffc1a2e98406fc2cd52
+pleaseletmein,SodiumChloride,1048576,8,1,2101cb9b6a511aaeaddbbe09cf70f881ec568d574a2ffd4dabe5ee9820adaa478e56fd8f4ba5d09ffa1c6d927c40f4c337304049e8a952fbcbf45c6fa77a41a4

tests/scrypt-tests.py

-import unittest
-
-import scrypt
-
-class TestScrypt(unittest.TestCase):
-    def test_encrypt(self):
-        s = scrypt.encrypt('message', 'password', .1)
-        self.assertEqual(len(s), 128+len('message'))
-        
-    def test_encrypt_decrypt(self):
-        orig_m = 'message'
-        s = scrypt.encrypt(orig_m, 'password', .1)
-        m = scrypt.decrypt(s, 'password', .1)
-        self.assertEqual(m, orig_m)
-        
-    def test_too_little_time(self):
-        orig_m = 'message'
-        s = scrypt.encrypt(orig_m, 'password', .1)
-        self.assertRaises(scrypt.error, lambda: scrypt.decrypt(s, 'password', .01))
-        
-if __name__ == '__main__':
-    unittest.main()
+# -*- coding: utf-8 -*-
+
+from os import urandom
+from os.path import dirname, abspath, sep
+from sys import version_info
+from csv import reader
+from binascii import a2b_hex, b2a_hex
+import base64
+import json
+
+if ((version_info > (3, 2, 0, 'final', 0)) or
+    (version_info > (2, 7, 0, 'final', 0) and version_info < (3, 0, 0, 'final', 0))):
+    import unittest as testm
+else:
+    try:
+        import unittest2 as testm
+    except ImportError:
+        print("Please install unittest2 to run the test suite")
+        exit(-1)
+
+import scrypt
+
+
+class TestScrypt(testm.TestCase):
+
+    def setUp(self):
+        self.input = "message"
+        self.password = "password"
+        self.longinput = str(urandom(100000))
+        self.one_byte = 1  # in Bytes
+        self.one_megabyte = 1024 * 1024  # in Bytes
+        self.ten_megabytes = 10 * self.one_megabyte
+        base_dir = dirname(abspath(__file__)) + sep
+        cvf = open(base_dir + "ciphertexts.csv", "r")
+        ciphertxt_reader = reader(cvf, dialect="excel")
+        self.ciphertexts = []
+        for row in ciphertxt_reader:
+            self.ciphertexts.append(row)
+        cvf.close()
+        self.ciphertext = a2b_hex(bytes(self.ciphertexts[1][5].encode('ascii')))
+
+    def test_encrypt_decrypt(self):
+        """Test encrypt for simple encryption and decryption"""
+        s = scrypt.encrypt(self.input, self.password, 0.1)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt(self):
+        """Test encrypt takes input and password strings as
+        positional arguments and produces ciphertext"""
+        s = scrypt.encrypt(self.input, self.password)
+        self.assertEqual(len(s), 128 + len(self.input))
+
+    def test_encrypt_input_and_password_as_keywords(self):
+        """Test encrypt for input and password accepted as keywords"""
+        s = scrypt.encrypt(password=self.password, input=self.input)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_missing_input_keyword_argument(self):
+        """Test encrypt raises TypeError if keyword argument missing input"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt(password=self.password))
+
+    def test_encrypt_missing_password_positional_argument(self):
+        """Test encrypt raises TypeError if second positional argument missing
+        (password)"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.input))
+
+    def test_encrypt_missing_both_required_positional_arguments(self):
+        """Test encrypt raises TypeError if both positional arguments missing
+        (input and password)"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt())
+
+    def test_encrypt_maxtime_positional(self):
+        """Test encrypt maxtime accepts maxtime at position 3"""
+        s = scrypt.encrypt(self.input, self.password, 0.01)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxtime_key(self):
+        """Test encrypt maxtime accepts maxtime as keyword argument"""
+        s = scrypt.encrypt(self.input, self.password, maxtime=0.01)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmem_positional(self):
+        """Test encrypt maxmem accepts 4th positional argument and exactly
+        (1 megabyte) of storage to use for V array"""
+        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_megabyte)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmem_undersized(self):
+        """Test encrypt maxmem accepts (< 1 megabyte) of storage to use for V array"""
+        s = scrypt.encrypt(self.input, self.password, 0.01, self.one_byte)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmem_in_normal_range(self):
+        """Test encrypt maxmem accepts (> 1 megabyte) of storage to use for V array"""
+        s = scrypt.encrypt(self.input,
+                           self.password,
+                           0.01,
+                           self.ten_megabytes)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmem_keyword_argument(self):
+        """Test encrypt maxmem accepts exactly (1 megabyte) of storage to use for
+        V array"""
+        s = scrypt.encrypt(self.input,
+                           self.password,
+                           maxmem=self.one_megabyte,
+                           maxtime=0.01)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmemfrac_positional(self):
+        """Test encrypt maxmemfrac accepts 5th positional argument of 1/16 total
+        memory for V array"""
+        s = scrypt.encrypt(self.input, self.password, 0.01, 0, 0.0625)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_maxmemfrac_keyword_argument(self):
+        """Test encrypt maxmemfrac accepts keyword argument of 1/16 total memory for
+        V array"""
+        s = scrypt.encrypt(self.input, self.password, maxmemfrac=0.0625,
+                           maxtime=0.01)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(m, self.input)
+
+    def test_encrypt_long_input(self):
+        """Test encrypt accepts long input for encryption"""
+        s = scrypt.encrypt(self.longinput, self.password, 0.1)
+        self.assertEqual(len(s), 128 + len(self.longinput))
+
+    def test_encrypt_raises_error_on_invalid_keyword(self):
+        """Test encrypt raises TypeError if invalid keyword used in argument"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.input,
+            self.password, nonsense="Raise error"))
+
+    def test_decrypt_from_csv_ciphertexts(self):
+        """Test decrypt function with precalculated combinations"""
+        for row in self.ciphertexts[1:]:
+            h = scrypt.decrypt(a2b_hex(bytes(row[5].encode('ascii'))), row[1])
+            self.assertEqual(bytes(h.encode("ascii")), row[0].encode("ascii"))
+
+    def test_decrypt_maxtime_positional(self):
+        """Test decrypt function accepts third positional argument"""
+        m = scrypt.decrypt(self.ciphertext, self.password, 1.0)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_maxtime_keyword_argument(self):
+        """Test decrypt function accepts maxtime keyword argument"""
+        m = scrypt.decrypt(maxtime=1.0, input=self.ciphertext, password=self.password)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_maxmem_positional(self):
+        """Test decrypt function accepts fourth positional argument"""
+        m = scrypt.decrypt(self.ciphertext, self.password, 0.1, self.ten_megabytes)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_maxmem_keyword_argument(self):
+        """Test decrypt function accepts maxmem keyword argument"""
+        m = scrypt.decrypt(maxmem=self.ten_megabytes, input=self.ciphertext, password=self.password)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_maxmemfrac_positional(self):
+        """Test decrypt function accepts maxmem keyword argument"""
+        m = scrypt.decrypt(self.ciphertext, self.password, 0.1, self.one_megabyte, 0.0625)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_maxmemfrac_keyword_argument(self):
+        """Test decrypt function accepts maxmem keyword argument"""
+        m = scrypt.decrypt(maxmemfrac=0.625, input=self.ciphertext, password=self.password)
+        self.assertEqual(m, self.input)
+
+    def test_decrypt_raises_error_on_too_little_time(self):
+        """Test decrypt function raises scrypt.error raised if insufficient time allowed for
+        ciphertext decryption"""
+        s = scrypt.encrypt(self.input, self.password, 0.1)
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.decrypt(s, self.password, .01))
+
+
+class TestScryptHash(testm.TestCase):
+
+    def setUp(self):
+        self.input = "message"
+        self.password = "password"
+        self.salt = "NaCl"
+        self.hashes = []
+        base_dir = dirname(abspath(__file__)) + sep
+        hvf = open(base_dir + "hashvectors.csv", "r")
+        hash_reader = reader(hvf, dialect="excel")
+        for row in hash_reader:
+            self.hashes.append(row)
+        hvf.close()
+
+    def test_hash_vectors_from_csv(self):
+        """Test hash function with precalculated combinations"""
+        for row in self.hashes[1:]:
+            h = scrypt.hash(row[0], row[1], int(row[2]), int(row[3]), int(row[4]))
+            hhex = b2a_hex(h)
+            self.assertEqual(hhex, bytes(row[5].encode("utf-8")))
+
+    def test_hash_n_positional(self):
+        """Test hash accepts valid N in position 3"""
+        h = scrypt.hash(self.input, self.salt, 256)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_n_keyword(self):
+        """Test hash takes keyword valid N"""
+        h = scrypt.hash(N=256, password=self.input, salt=self.salt)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_r_positional(self):
+        """Test hash accepts valid r in position 4"""
+        h = scrypt.hash(self.input, self.salt, 256, 16)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_r_keyword(self):
+        """Test hash takes keyword valid r"""
+        h = scrypt.hash(r=16, password=self.input, salt=self.salt)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_p_positional(self):
+        """Test hash accepts valid p in position 5"""
+        h = scrypt.hash(self.input, self.salt, 256, 8, 2)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_p_keyword(self):
+        """Test hash takes keyword valid p"""
+        h = scrypt.hash(p=4, password=self.input, salt=self.salt)
+        self.assertEqual(len(h), 64)
+
+    def test_hash_raises_error_on_p_equals_zero(self):
+        """Test hash raises scrypt error on illegal parameter value (p = 0)"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, p=0))
+
+    def test_hash_raises_error_on_negative_p(self):
+        """Test hash raises scrypt error on illegal parameter value (p < 0)"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, p=-1))
+
+    def test_hash_raises_error_on_r_equals_zero(self):
+        """Test hash raises scrypt error on illegal parameter value (r = 0)"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, r=0))
+
+    def test_hash_raises_error_on_negative_r(self):
+        """Test hash raises scrypt error on illegal parameter value (r < 1)"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, r=-1))
+
+    def test_hash_raises_error_r_p_over_limit(self):
+        """Test hash raises scrypt error when parameters r multiplied by p over limit 2**30"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, r=2, p=2 ** 29))
+
+    def test_hash_raises_error_n_not_power_of_two(self):
+        """Test hash raises scrypt error when parameter N is not a power of two {2, 4, 8, 16, etc}"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, N=3))
+
+    def test_hash_raises_error_n_under_limit(self):
+        """Test hash raises scrypt error when parameter N under limit of 1"""
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, N=1))
+        self.assertRaises(scrypt.error,
+                          lambda: scrypt.hash(self.input, self.salt, N=-1))
+
+if __name__ == "__main__":
+    testm.main()

tests/test_scrypt_py2x.py

+# -*- coding: utf-8 -*-
+
+from sys import version_info
+
+if ((version_info > (3, 2, 0, 'final', 0)) or
+    (version_info > (2, 7, 0, 'final', 0) and version_info < (3, 0, 0, 'final', 0))):
+    import unittest as testm
+else:
+    try:
+        import unittest2 as testm
+    except ImportError:
+        print("Please install unittest2 to run the test suite")
+        exit(-1)
+
+import scrypt
+
+
+@testm.skipIf(version_info > (3, 0, 0, 'final', 0), "Tests for Python 2 only")
+class TestScryptForPython2(testm.TestCase):
+
+    def setUp(self):
+        self.input = "message"
+        self.password = "password"
+        self.unicode_text = '\xe1\x93\x84\xe1\x93\x87\xe1\x95\x97\xe1\x92\xbb\xe1\x92\xa5\xe1\x90\x85\xe1\x91\xa6'.decode('utf-8')
+
+    def test_py2_encrypt_fails_on_unicode_input(self):
+        """Test Py2 encrypt raises TypeError when Unicode input passed"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.unicode_text, self.password))
+
+    def test_py2_encrypt_fails_on_unicode_password(self):
+        """Test Py2 encrypt raises TypeError when Unicode password passed"""
+        self.assertRaises(TypeError, lambda: scrypt.encrypt(self.input, self.unicode_text))
+
+    def test_py2_encrypt_returns_string(self):
+        """Test Py2 encrypt returns str"""
+        e = scrypt.encrypt(self.input, self.password, 0.1)
+        self.assertTrue(isinstance(e, str))
+
+    def test_py2_decrypt_returns_string(self):
+        """Test Py2 decrypt returns str"""
+        s = scrypt.encrypt(self.input, self.password, 0.1)
+        m = scrypt.decrypt(s, self.password)
+        self.assertTrue(isinstance(m, str))
+
+    def test_py2_hash_returns_string(self):
+        """Test Py2 hash return str"""
+        h = scrypt.hash(self.input, self.password)
+        self.assertTrue(isinstance(h, str))
+
+if __name__ == "__main__":
+    testm.main()

tests/test_scrypt_py3x.py

+# -*- coding: utf-8 -*-
+
+from sys import version_info
+
+if ((version_info > (3, 2, 0, 'final', 0)) or
+    (version_info > (2, 7, 0, 'final', 0) and version_info < (3, 0, 0, 'final', 0))):
+    import unittest as testm
+else:
+    try:
+        import unittest2 as testm
+    except ImportError:
+        print("Please install unittest2 to run the test suite")
+        exit(-1)
+
+import scrypt
+
+
+@testm.skipIf(version_info < (3, 0, 0, 'final', 0), "Tests for Python 3 only")
+class TestScryptForPy3(testm.TestCase):
+
+    def setUp(self):
+        self.input = "message"
+        self.password = "password"
+        self.byte_text = b'\xe1\x93\x84\xe1\x93\x87\xe1\x95\x97\xe1\x92\xbb\xe1\x92\xa5\xe1\x90\x85\xe1\x91\xa6'
+        self.unicode_text = self.byte_text.decode('utf-8', "strict")
+
+    def test_py3_encrypt_allows_bytes_input(self):
+        """Test Py3 encrypt allows unicode input"""
+        s = scrypt.encrypt(self.byte_text, self.password, 0.1)
+        m = scrypt.decrypt(s, self.password)
+        self.assertEqual(bytes(m.encode("utf-8")), self.byte_text)
+
+    def test_py3_encrypt_allows_bytes_password(self):
+        """Test Py3 encrypt allows unicode password"""
+        s = scrypt.encrypt(self.input, self.byte_text, 0.1)
+        m = scrypt.decrypt(s, self.byte_text)
+        self.assertEqual(m, self.input)
+
+    def test_py3_encrypt_returns_bytes(self):
+        """Test Py3 encrypt return bytes"""
+        s = scrypt.encrypt(self.input, self.password, 0.1)
+        self.assertTrue(isinstance(s, bytes))
+
+    def test_py3_decrypt_returns_unicode_string(self):
+        """Test Py3 decrypt returns Unicode UTF-8 string"""
+        s = scrypt.encrypt(self.input, self.password, 0.1)
+        m = scrypt.decrypt(s, self.password)
+        self.assertTrue(isinstance(m, str))
+
+    def test_py3_hash_returns_bytes(self):
+        """Test Py3 hash return bytes"""
+        h = scrypt.hash(self.input, self.password)
+        self.assertTrue(isinstance(h, bytes))
+
+if __name__ == "__main__":
+    testm.main()