Anonymous avatar Anonymous committed 3098659

*Drastic* change to reduce the size of RootSeekableFieldSet and bring it in line with GenericFieldSet.

If this breaks existing parsers, please let me know!

Comments (0)

Files changed (1)

hachoir-core/hachoir_core/field/seekable_field_set.py

-from hachoir_core.field import Field, BasicFieldSet, FakeArray, MissingField, ParserError
-from hachoir_core.tools import lowerBound, makeUnicode
+from hachoir_core.field import BasicFieldSet, GenericFieldSet, ParserError, createRawField
 from hachoir_core.error import HACHOIR_ERRORS
-from itertools import repeat
-import hachoir_core.config as config
 
-class RootSeekableFieldSet(BasicFieldSet):
-    def __init__(self, parent, name, stream, description, size):
-        BasicFieldSet.__init__(self, parent, name, stream, description, size)
-        self._generator = self.createFields()
-        self._offset = 0
-        self._current_size = 0
-        if size:
-            self._current_max_size = size
-        else:
-            self._current_max_size = 0
-        self._field_dict = {}
-        self._field_array = []
# getgaps(int, int, [listof (int, int)]) -> generator of (int, int)
def getgaps(start, length, blocks):
    '''
    Yield (start, length) pairs for every gap not covered by a block in
    `blocks`, within the window beginning at `start` and spanning
    `length` units.

    Blocks may be unsorted, duplicated or overlapping. Blocks (or block
    parts) lying outside the window are ignored, so every yielded gap
    lies entirely within [start, start+length).

    Example:
    >>> list(getgaps(0, 20, [(15,3), (6,2), (6,2), (1,2), (2,3), (11,2), (9,5)]))
    [(0, 1), (5, 1), (8, 1), (14, 1), (18, 2)]
    '''
    # sorted() copies, so the caller's list is never mutated
    blocks = sorted(blocks, key=lambda block: block[0])
    end = start + length
    for s, l in blocks:
        if s >= end:
            # Block begins at or past the window end; since blocks are
            # sorted, no later block can cover anything in the window.
            break
        if s > start:
            # Uncovered range between the cursor and this block
            yield (start, s - start)
            start = s
        if s + l > start:
            # Advance the cursor past this block's coverage
            start = s + l
    if start < end:
        # Trailing gap up to the end of the window
        yield (start, end - start)
 
-    def _feedOne(self):
-        assert self._generator
-        field = self._generator.next()
-        self._addField(field)
-        return field
-
-    def array(self, key):
-        return FakeArray(self, key)
-
-    def getFieldByAddress(self, address, feed=True):
-        for field in self._field_array:
-            if field.address <= address < field.address + field.size:
-                return field
-        for field in self._readFields():
-            if field.address <= address < field.address + field.size:
-                return field
-        return None
-
-    def _stopFeed(self):
-        self._size = self._current_max_size
-        self._generator = None
-    done = property(lambda self: not bool(self._generator))
-
-    def _getSize(self):
-        if self._size is None:
-            self._feedAll()
-        return self._size
-    size = property(_getSize)
-
-    def _getField(self, key, const):
-        field = Field._getField(self, key, const)
-        if field is not None:
-            return field
-        if key in self._field_dict:
-            return self._field_dict[key]
-        if self._generator and not const:
-            try:
-                while True:
-                    field = self._feedOne()
-                    if field.name == key:
-                        return field
-            except StopIteration:
-                self._stopFeed()
-            except HACHOIR_ERRORS, err:
-                self.error("Error: %s" % makeUnicode(err))
-                self._stopFeed()
-        return None
-
-    def getField(self, key, const=True):
-        if isinstance(key, (int, long)):
-            if key < 0:
-                raise KeyError("Key must be positive!")
-            if not const:
-                self.readFirstFields(key+1)
-            if len(self._field_array) <= key:
-                raise MissingField(self, key)
-            return self._field_array[key]
-        return Field.getField(self, key, const)
-
-    def _addField(self, field):
-        if field._name.endswith("[]"):
-            self.setUniqueFieldName(field)
-        if config.debug:
-            self.info("[+] DBG: _addField(%s)" % field.name)
-
-        if field._address != self._offset:
-            self.warning("Set field %s address to %s (was %s)" % (
-                field.path, self._offset//8, field._address//8))
-            field._address = self._offset
-        assert field.name not in self._field_dict
-
-        self._checkFieldSize(field)
-
-        self._field_dict[field.name] = field
-        self._field_array.append(field)
-        self._current_size += field.size
-        self._offset += field.size
-        self._current_max_size = max(self._current_max_size, field.address + field.size)
-
-    def _checkAddress(self, address):
-        if self._size is not None:
-            max_addr = self._size
-        else:
-            # FIXME: Use parent size
-            max_addr = self.stream.size
-        return address < max_addr
-
-    def _checkFieldSize(self, field):
-        size = field.size
-        addr = field.address
-        if not self._checkAddress(addr+size-1):
-            raise ParserError("Unable to add %s: field is too large" % field.name)
-
class RootSeekableFieldSet(GenericFieldSet):
    """Field set whose parser may seek freely in the input stream.

    Unlike a plain GenericFieldSet, fields need not be contiguous: the
    parser can reposition itself with seekBit()/seekByte(), and once
    the total size is known any uncovered ranges are plugged with raw
    "unparsed[]" fields by _fixLastField().
    """

    def seekBit(self, address, relative=True):
        """Move the parse cursor to the given bit address.

        With relative=False, `address` is absolute in the stream and is
        converted to a set-relative offset first. Seeking before the
        start of the field set raises ParserError. Returns None.
        """
        if not relative:
            address -= self.absolute_address
        if address < 0:
            raise ParserError("Seek below field set start (%s.%s)" % divmod(address, 8))
        # The current size doubles as the address of the next field.
        self._current_size = address
        return None

    def seekByte(self, address, relative=True):
        """Byte-granularity convenience wrapper around seekBit()."""
        return self.seekBit(address * 8, relative)

    def _fixLastField(self):
        """Reconcile the field list with the (now known) set size.

        Stops the field generator, deletes trailing fields that overrun
        the declared size, then fills every uncovered range with a raw
        "unparsed[]" field. Returns the list of raw fields added
        (possibly empty).
        """
        assert self._size is not None

        # Halt parsing: no further fields will be generated.
        notes = ["stop parser"]
        self._field_generator = None

        # Drop fields from the tail while the set overruns its size.
        while self._size < self._current_size:
            removed = self._deleteField(len(self._fields) - 1)
            notes.append("delete field %s" % removed.path)
        assert self._current_size <= self._size

        # Plug every hole between the remaining fields with raw data.
        occupied = [(item.absolute_address, item.size) for item in self._fields]
        added = []
        for gap_start, gap_size in getgaps(self.absolute_address, self._size, occupied):
            self.seekBit(gap_start, relative=False)
            raw = createRawField(self, gap_size, "unparsed[]")
            self.setUniqueFieldName(raw)
            self._fields.append(raw.name, raw)
            added.append(raw)
            notes.append("found unparsed segment: start %s, length %s" % (gap_start, gap_size))

        # Leave the cursor at the end of the field set.
        self.seekBit(self._size, relative=False)
        if added:
            self.warning("[Autofix] Fix parser error: " + ", ".join(notes))
        return added

    def _stopFeeding(self):
        """Finalize the set when the field generator is exhausted.

        Fixes the size if it was unknown, repairs coverage via
        _fixLastField(), and returns the raw fields that were added.
        """
        if self._size is None and self._parent:
            self._size = self._current_size
        fixed = self._fixLastField()
        self._field_generator = None
        return fixed
 
 class SeekableFieldSet(RootSeekableFieldSet):
     def __init__(self, parent, name, description=None, size=None):
         assert issubclass(parent.__class__, BasicFieldSet)
         RootSeekableFieldSet.__init__(self, parent, name, parent.stream, description, size)
-
Tip: Filter by directory path e.g. /media app.js to search for public/media/app.js.
Tip: Use camelCasing e.g. ProjME to search for ProjectModifiedEvent.java.
Tip: Filter by extension type e.g. /repo .js to search for all .js files in the /repo directory.
Tip: Separate your search with spaces e.g. /ssh pom.xml to search for src/ssh/pom.xml.
Tip: Use ↑ and ↓ arrow keys to navigate and return to view the file.
Tip: You can also navigate files with Ctrl+j (next) and Ctrl+k (previous) and view the file with Ctrl+o.
Tip: You can also navigate files with Alt+j (next) and Alt+k (previous) and view the file with Alt+o.