Source

vinstall / vinstall / backend / partitioning.py

Full commit
# -*- coding: utf8 -*-

#    This file is part of vinstall.
#
#    vinstall is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License v3 as published by
#    the Free Software Foundation.
#
#    vinstall is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with vinstall.  If not, see <http://www.gnu.org/licenses/>.


""" Convinience methods for working with partitions for the vectorlinux installer

"""

import utils
import parted
import unittest, tempfile, os
from vinstall.backend import media
from vinstall.core import log

__author__ = "Moises Henriquez"
# The address is stored reversed and flipped back with [::-1] at import time,
# presumably to keep it away from address harvesters.
__email__ = "moc.liamg@xnl.e0m"[::-1]
__version__ = "0.1"
__date__ = "2012-06-05"

# Module-wide logger, obtained from vinstall.core.log.get_logger(__file__).
LOG = log.get_logger(__file__)

class PartitionDim(object):
    """Plain record describing a partition that has not been created yet.

    ``start``, ``end`` and ``length`` all begin at 0 and are filled in by the
    caller (DiskPartitioner.append_partition stores sector numbers in them).
    ``partition_type`` defaults to a normal (primary) parted partition.

    """

    def __init__(self, device=None, region=None):
        self.device, self.region = device, region
        self.start = self.end = self.length = 0
        self.partition_type = parted.PARTITION_NORMAL

class DiskPartitioner(object):
    """ Class used to modify the partition table on the given disk.

    New partitions are first queued in ``partition_cache`` by
    ``append_partition``; they are only created on the disk, and the
    partition table committed, by ``write_changes``.

    """

    # The MBR partition table starts at byte 0x1BE and holds four 16-byte
    # entries, one per primary partition.
    _MBR_TABLE_OFFSET = 0x1BE
    _MBR_ENTRY_SIZE = 16
    _MBR_ENTRY_COUNT = 4

    # Sector where a queued partition starts when nothing precedes it.
    _FIRST_USABLE_SECTOR = 2048

    def __init__(self, disk):
        """ Arguments:

                disk - a media.Disk object

        """
        self.disk = disk
        # PartitionDim objects queued by append_partition that have not been
        # written to the disk yet.
        self.partition_cache = []

    def has_partition_table(self):
        """Find if there is a partition table.

        The MBR partition table starts at position 0x1BE and contains 4
        entries of 16 bytes, one for each primary partition.  The disk is
        considered partitioned when any of those entries holds a non-zero
        byte.

        Returns:
            True if at least one of the four entries contains data, False
            otherwise (including when the device is too short to hold a
            partition table).

        """
        table_size = self._MBR_ENTRY_SIZE * self._MBR_ENTRY_COUNT
        # Binary mode: the table is raw bytes, not text.
        with open(self.disk._device.path, "rb") as f:
            f.seek(self._MBR_TABLE_OFFSET)
            data = f.read(table_size)
        # Anything left after stripping zero bytes means an entry is in use.
        return bool(data.strip(b"\x00"))

    def create_partition_table(self, table_type="msdos"):
        """ Create the partition table on the disk.
        This is only necessary when using a fresh disk that has
        never been partitioned before.

        Args:
            table_type - disk label type handed to parted.freshDisk
                         (defaults to "msdos").

        """
        self.disk._disk = parted.freshDisk(self.disk._device, table_type)
        self.disk._disk.has_partition_table = True
        LOG.debug("New partition table created on %s"% self.disk)

    def delete_all_partitions(self):
        """ Delete all partitions from the drive and flush the partition cache.

        The deletion is committed to the disk right away.

        Returns:
            A tuple with the results of ``deleteAllPartitions()`` and
            ``commit()`` on the underlying parted disk.

        """
        LOG.debug("Deleting all cached and existing partitions in %s"% self.disk)
        self.partition_cache = []
        deleted = self.disk._disk.deleteAllPartitions()
        committed = self.disk._disk.commit()
        return (deleted, committed)

    def partitions_number(self):
        """Return the number of partitions in disk

        """
        return len(self.disk._disk.partitions)

    def append_partition(self, size=0, units='MB'):
        """Queue a new primary partition in the partition cache.

        Nothing is written to the disk until ``write_changes`` is called.
        The partition is placed right after the last partition that is
        already on the disk or already queued, and never before sector 2048.

        Args:
            size - partition size, expressed in ``units``
            units - size units understood by parted.sizeToSectors
                    (e.g. 'MB', 'GB', 'MiB', 'GiB'); defaults to 'MB'

        Raises:
            AssertionError - if the disk plus the cache already hold 3 or
                             more primary partitions.

        """
        dprimaries = self.disk._disk.getPrimaryPartitions()
        cprimaries = [p for p in self.partition_cache
                      if p.partition_type == parted.PARTITION_NORMAL]
        tprimaries = len(dprimaries) + len(cprimaries)
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        if tprimaries >= 3:
            raise AssertionError("Too many primary partitions.")
        mypsize = parted.sizeToSectors(size, units, self.disk._device.sectorSize)

        free_space = self.disk._disk.getFreeSpacePartitions()
        if not free_space:
            # FIXME: Raise an exception instead of silently doing nothing.
            LOG.debug("No free space left on %s, partition not added"% self.disk)
            return

        # Start right after the last sector used by a partition already on
        # the disk or by a queued one, so the new partition overlaps neither.
        # FIXME: compare which free-space region each partition belongs to
        # instead of always appending after the last one.
        used_ends = [p.geometry.end for p in self.disk._disk.partitions]
        used_ends.extend(c.end for c in self.partition_cache)
        last_end = max(used_ends) if used_ends else 0
        newstart = max(self._FIRST_USABLE_SECTOR, last_end + 1)

        partdim = PartitionDim()
        partdim.start = newstart
        partdim.length = mypsize
        partdim.device = self.disk
        # parted geometries are inclusive: the last sector is start + length - 1.
        partdim.end = newstart + mypsize - 1
        partdim.partition_type = parted.PARTITION_NORMAL
        self.partition_cache.append(partdim)

    def add_partition(self,  size=0, units='MB'):
        """Add a partition to the disk.

        Deprecated alias of ``append_partition``.

        Args:  size = partition size
               units = partition units (MB, GB, MiB, GiB), (defaults to 'MB')

        """
        # FIXME: DEPRECATION WARNING:  Use append_partition method
        return self.append_partition(size, units)

    def create_partition(self, partdim):
        """Create a partition on the in-memory parted disk from a PartitionDim.

        The partition is tagged with an ext2 filesystem type and added with
        an exact-geometry constraint starting at ``partdim.start`` and
        spanning ``partdim.length`` sectors.  Nothing is committed here;
        see ``write_changes``.

        """
        pgeometry = parted.Geometry(device=self.disk._device,
            start=partdim.start,
            length=partdim.length)
        pfilesystem = parted.FileSystem(type="ext2", geometry=pgeometry)
        ppartition = parted.Partition(disk=self.disk._disk,
            fs=pfilesystem,
            type=partdim.partition_type,
            geometry=pgeometry)
        # exactGeom: parted may not move or resize the requested geometry.
        pconstraint = parted.Constraint(exactGeom=pgeometry)
        self.disk._disk.addPartition(partition=ppartition,
            constraint=pconstraint)

    def write_changes(self):
        """Finalize changes to the disk.  This needs to be called
        after creating partitions or deleting partitions to make sure
        the changes are actually written to the disk.

        Every queued partition is created and the partition table is
        committed; the cache is then emptied so that a later call does not
        try to create the same partitions again.

        """
        # Create the cached partitions and then write changes
        for cached in self.partition_cache:
            LOG.debug("Creating new partition starting at sector %s, ending in sector %s"%(
                cached.start, cached.end))
            self.create_partition(cached)
        LOG.debug("Writing partition table changes to %s"%(self.disk))
        self.disk._disk.commit()
        # The partitions now live on the disk itself; keeping them cached
        # would make append_partition count them twice as primaries.
        self.partition_cache = []


class DiskPartitionerTests(unittest.TestCase):
    """Exercise DiskPartitioner against a temporary file posing as a disk."""

    # Offset of the single byte written to give the fake device its size;
    # on most filesystems the file stays sparse, so little space is used.
    FAKE_DEVICE_SIZE = 1024000000

    def setUp(self):
        (fd, self.path) = tempfile.mkstemp(prefix="fake-device-")
        # Seek and write through one binary file object so the position and
        # the write agree, and the descriptor is closed deterministically.
        with os.fdopen(fd, "wb") as f:
            f.seek(self.FAKE_DEVICE_SIZE)
            f.write(b"0")
        self.disk = media.Disk()
        self.disk._device = parted.Device(self.path)
        self.partitioner = DiskPartitioner(self.disk)

    def tearDown(self):
        os.unlink(self.path)

    def test_create_partition_table(self):
        """Create partition table in new disk"""
        self.partitioner.create_partition_table()
        self.partitioner.add_partition(size=0.1)
        self.partitioner.write_changes()
        self.assertTrue(self.partitioner.has_partition_table())

    def test_add_partition(self):
        """Adding a single partition"""
        self.partitioner.create_partition_table()
        self.partitioner.append_partition(size=0.1)
        self.partitioner.write_changes()
        self.assertEqual(len(self.disk._disk.partitions), 1)

    def test_add_multiple_partitions(self):
        """Add multiple partitions to a single disk"""
        self.partitioner.create_partition_table()
        self.partitioner.add_partition(size=3)
        self.partitioner.add_partition(size=5)
        self.partitioner.write_changes()
        self.assertEqual(len(self.disk._disk.partitions), 2)

    def test_has_partition_table(self):
        """Check if partition table exists in disk"""
        self.assertFalse(self.partitioner.has_partition_table())
        self.partitioner.create_partition_table()
        self.partitioner.add_partition(size=0.1)
        self.partitioner.write_changes()
        self.assertTrue(self.partitioner.has_partition_table())

    def test_delete_all_partitions(self):
        """Make sure all partitions are deleted and the cache is flushed"""
        self.partitioner.create_partition_table()
        for size in (1, 2):
            self.partitioner.add_partition(size)
        self.partitioner.write_changes()
        self.assertEqual(len(self.disk._disk.partitions), 2)
        self.partitioner.delete_all_partitions()
        self.assertEqual(len(self.partitioner.partition_cache), 0)
        self.assertEqual(len(self.disk._disk.partitions), 0)

    def test_toom_many_primary_partitions(self):
        """Raise an exception if there are too many primary partitions
        (cached + written to the disk)"""
        self.partitioner.create_partition_table()
        for size in (1, 2, 3):
            self.partitioner.add_partition(size)
        # The fourth primary partition must be rejected.
        self.assertRaises(AssertionError, self.partitioner.add_partition, 1)


if __name__ == "__main__":
    # Running this module directly executes the DiskPartitionerTests suite.
    unittest.main()