Source

vinstall / vinstall / backend / fstab.py

# coding: utf8

#    This file is part of vinstall.
#
#    vinstall is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License v3 as published by
#    the Free Software Foundation.
#
#    vinstall is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with vinstall.  If not, see <http://www.gnu.org/licenses/>.


"""fstab.py
Provides a data model for working with /etc/fstab.

"""


__author__ = "Moises Henriquez"
# The address is stored reversed and flipped back at import time
# (presumably to keep the plain address away from harvesters).
__author_email__ = "moc.liamg@xnl.E0M"[::-1]


import os
import tempfile
import unittest


    
class FstabEntry(object):
    """ Object representing a single entry (one line) in /etc/fstab

    """
    # Directory searched by find_uuid() and find_path().  Each entry in it is
    # expected to be named after a filesystem UUID and to be a symlink whose
    # target's basename is the partition's device name (ie, ../../sda1).
    # Kept as a class attribute so it can be overridden per instance.
    UUID_DIR = "/dev/disk/by-uuid"

    # Filesystem names that all receive the swap options.
    _SWAP_TYPES = ("swap", "none", "linux-swap", "linux-swap(v1)")
    _SWAP_OPTIONS = "sw 0 0"

    # "options dump pass" columns per filesystem; unlisted filesystems fall
    # back to "defaults 0 0".
    _DEFAULT_OPTIONS = {
        "ext2": "defaults 0 0",
        "ext3": "defaults 0 0",
        "ext4": "defaults 0 0",
        "jfs": "defaults 0 0",
        "iso9660": "noauto,ro,user 0 0",
        "msdos": "defaults 0 0",
    }

    def __init__(self, device=None, mountpoint=None, filesystem=None, options=None):
        """ Creates an object that can be used with Fstab.

        Arguments:
            device = device to be mounted (ie, /dev/sda1)
            mountpoint = where to mount it (ie, /, or /home)
            filesystem = filesystem to be used while mounting (ie, 'ext2')
            options = the options/dump/pass columns (ie, 'defaults 0 0');
                      Fstab.add_entry() uses find_default_options() when
                      this is None

        The uuid attribute starts as None and is filled in by Fstab when
        the entry is written or read.
        """
        self.device = device
        self.uuid = None
        self.mountpoint = mountpoint
        self.filesystem = filesystem
        self.options = options

    def find_default_mountpoint(self):
        """ Suggest a mountpoint for this entry.

        Swap filesystems get 'swap'; anything else is placed under /mnt
        using the device name (ie, /dev/sda1 -> /mnt/sda1).

        Raises AssertionError if device or filesystem is not set.
        """
        assert self.device is not None, "device path must be set first."
        assert self.filesystem is not None, "Unknown filesystem.  Set filesystem property first."
        if 'swap' in self.filesystem:
            return 'swap'
        return self.device.replace('/dev/', '/mnt/')

    def find_default_options(self):
        """ Returns the default fstab options for the partition's filesystem

        The returned string holds the options, dump and pass columns.
        """
        if self.filesystem in self._SWAP_TYPES:
            return self._SWAP_OPTIONS
        return self._DEFAULT_OPTIONS.get(self.filesystem, "defaults 0 0")

    def find_uuid(self):
        """ Find the UUID of this entry's device.

        Returns the name of the UUID_DIR entry whose symlink target has the
        same basename as self.device, or None when nothing matches or
        UUID_DIR cannot be listed.
        """
        partition = os.path.basename(self.device)
        try:
            names = os.listdir(self.UUID_DIR)
        except OSError:
            # No UUID directory means no UUID can be found; Fstab.add_entry
            # then falls back to writing the plain device path.
            return None
        for name in names:
            target = os.readlink(os.path.join(self.UUID_DIR, name))
            # Compare whole basenames: a suffix test would let /dev/vda1
            # match a link to xvda1.
            if os.path.basename(target) == partition:
                return name
        return None

    def find_path(self):
        """ Find the partition path.  This is useful if fstab provides UUID but not path

        Returns /dev/<name> for the node that UUID_DIR/<uuid> links to, or
        None when that UUID is not listed or UUID_DIR cannot be listed.

        Raises AssertionError if self.uuid is not set.
        """
        assert self.uuid is not None, "Cannot find path if uuid is not known"
        try:
            names = os.listdir(self.UUID_DIR)
        except OSError:
            return None
        if self.uuid not in names:
            return None
        target = os.readlink(os.path.join(self.UUID_DIR, self.uuid))
        return os.path.join("/dev/", os.path.basename(target))


class Fstab(object):
    """Data model for working with /etc/fstab

    Every method re-reads the file at fstab_path; nothing is cached.
    """
    def __init__(self, fstab_path='/etc/fstab'):
        self.fstab_path = fstab_path

    def remove_entry(self, device_path):
        """ Remove the fstab entry with the device_path.

        Comments, blank lines and every other entry are written back
        unchanged, each terminated by a newline.

        Raises AssertionError if device_path has no entry.
        """
        assert self.has_entry(device_path), "%s is not listed in %s" % \
                (device_path, self.fstab_path)
        kept = []
        with open(self.fstab_path) as data:
            for line in data:
                stripped = line.strip()
                if stripped and not stripped.startswith("#"):
                    entry = self.get_entry_from_line(stripped)
                    if entry.device == device_path:
                        continue
                # Keep the original text but guarantee a line terminator so
                # the rewritten lines never run together.
                if not line.endswith("\n"):
                    line += "\n"
                kept.append(line)
        with open(self.fstab_path, 'w') as f:
            f.writelines(kept)

    def add_entry(self, entry):
        """ Append a new entry to fstab.

        Arguments:
            entry - An FstabEntry() object with device, mountpoint and
                    filesystem set.

        The device column is written as UUID=<uuid> when a UUID is known.
        If entry.uuid is None, it is looked up with entry.find_uuid() and
        stored back on the entry; otherwise the plain device path is used.
        When entry.options is None, entry.find_default_options() is used.
        A missing file is created with a header comment.

        Raises AssertionError if entry is not an FstabEntry or lacks a
        device, mountpoint or filesystem.
        """
        assert isinstance(entry, FstabEntry), "entry argument is of invalid type"
        assert hasattr(entry, 'mountpoint'), "Missing mountpoint attribute."
        assert entry.device is not None, "No device specified"
        assert entry.mountpoint is not None, "No mountpoint defined for this entry"
        # Without this check the join below fails with an opaque TypeError.
        assert entry.filesystem is not None, "No filesystem defined for this entry"
        sep = " " * 5
        if entry.uuid is None:
            entry.uuid = entry.find_uuid()
        # Use UUID when possible
        if entry.uuid is None:
            part = entry.device
        else:
            part = "UUID=%s" % entry.uuid

        if entry.options is None:
            options = entry.find_default_options()
        else:
            options = entry.options

        if os.path.exists(self.fstab_path):
            with open(self.fstab_path) as f:
                data = f.readlines()
        else:
            data = ['# fstab generated by vinstall\n']

        # Don't glue the new entry onto a last line lacking its newline.
        if data and not data[-1].endswith("\n"):
            data[-1] += "\n"

        line = sep.join((
            part,
            entry.mountpoint,
            entry.filesystem,
            options
            ))
        data.append(line + "\n")
        with open(self.fstab_path, 'w') as f:
            f.writelines(data)

    def list_entries(self):
        """ Return a list of FstabEntry objects representing each entry in fstab

        Blank lines and lines starting with '#' are skipped.
        """
        ret = []
        with open(self.fstab_path) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                ret.append(self.get_entry_from_line(line))
        return ret

    def get_entry(self, device_path):
        """ Get an entry object for an existing fstab entry.

        Returns the first entry whose device equals device_path, or None.
        """
        for item in self.list_entries():
            if item.device == device_path:
                return item
        return None

    def get_entry_from_line(self, line):
        """ Create an FstabEntry from the provided (non-comment) fstab line

        Fields are whitespace-separated: device spec, mountpoint, filesystem,
        then the remaining columns, which are re-joined into entry.options.
        A UUID=<uuid> device spec is resolved to a /dev path via find_path();
        a plain path gets its UUID via find_uuid().  Raises IndexError when
        the line has fewer than three fields.
        """
        sp = line.split()
        sep = " " * 5
        entry = FstabEntry()
        dev = sp[0].strip()
        if dev.startswith("UUID="):
            entry.uuid = dev.split("=")[-1].strip()
            entry.device = entry.find_path()
        else:
            entry.device = dev.strip()
            entry.uuid = entry.find_uuid()
        entry.mountpoint = sp[1].strip()
        entry.filesystem = sp[2].strip()
        entry.options = sep.join(sp[3:]).strip()

        return entry

    def has_entry(self, device_path):
        """ Check if fstab has an entry for device_path

        """
        return self.get_entry(device_path) is not None


class FstabTests(unittest.TestCase):
    """Round-trip tests for Fstab and FstabEntry against a throwaway file."""

    def setUp(self):
        # A private temp file per test avoids accumulating entries in a
        # shared /tmp/test-fstab across runs.
        fd, path = tempfile.mkstemp(prefix="test-fstab-")
        os.close(fd)
        self.fstab_path = path
        self.fobject = Fstab(fstab_path=path)
        self.fentry = FstabEntry(device="/dev/sda1",
            mountpoint="/",
            filesystem="ext2")

    def tearDown(self):
        if os.path.exists(self.fstab_path):
            os.remove(self.fstab_path)

    def test_makeEntry(self):
        """Test the attributes on the entry"""
        self.assertEqual(self.fentry.device, "/dev/sda1")
        self.assertEqual(self.fentry.filesystem, "ext2")
        self.assertEqual(self.fentry.mountpoint, "/")

    def test_add_invalid_entry(self):
        self.assertRaises(AssertionError, self.fobject.add_entry, "/dev/sda1")

    def test_add_incomplete_entry(self):
        # Missing mountpoint.
        self.fentry.mountpoint = None
        self.assertRaises(AssertionError, self.fobject.add_entry, self.fentry)
        # Missing device.
        self.fentry.mountpoint = "/"
        self.fentry.device = None
        self.assertRaises(AssertionError, self.fobject.add_entry, self.fentry)

    def test_add_valid_entry(self):
        self.fobject.add_entry(self.fentry)
        self.assertTrue(self.fobject.has_entry("/dev/sda1"))
        rentry = self.fobject.get_entry("/dev/sda1")
        self.assertIsNotNone(rentry)

        self.assertEqual(rentry.device, self.fentry.device)
        self.assertEqual(rentry.filesystem, self.fentry.filesystem)
        self.assertEqual(rentry.mountpoint, self.fentry.mountpoint)

    def test_remove_entry(self):
        self.fobject.add_entry(self.fentry)
        self.fobject.remove_entry("/dev/sda1")
        self.assertFalse(self.fobject.has_entry("/dev/sda1"))
        # Removing something that is not listed is rejected.
        self.assertRaises(AssertionError, self.fobject.remove_entry, "/dev/sda1")
    

if __name__ == "__main__":
    # Running this module directly executes the unittest suite (FstabTests).
    unittest.main()