Commits

Anonymous committed ba8b117

hachoir-parser/image: parse JPEG image data fully

Comments (0)

Files changed (1)

hachoir-parser/hachoir_parser/image/jpeg.py

   http://java.sun.com/j2se/1.5.0/docs/api/javax/imageio/metadata/doc-files/jpeg_metadata.html#color
 - APP12:
   http://search.cpan.org/~exiftool/Image-ExifTool/lib/Image/ExifTool/TagNames.pod
+- JPEG Data Format
+  http://www.w3.org/Graphics/JPEG/itu-t81.pdf
 
-Author: Victor Stinner
+Author: Victor Stinner, Robert Xiao
 """
 
 from hachoir_core.error import HachoirError
 from hachoir_parser import Parser
-from hachoir_core.field import (FieldSet, ParserError,
-    UInt8, UInt16, Enum,
-    Bit, Bits, NullBits, NullBytes,
+from hachoir_core.field import (FieldSet, ParserError, FieldError,
+    UInt8, UInt16, Enum, Field,
+    Bit, Bits, NullBits, NullBytes, PaddingBits,
     String, RawBytes)
 from hachoir_parser.image.common import PaletteRGB
 from hachoir_core.endian import BIG_ENDIAN
 from hachoir_core.text_handler import textHandler, hexadecimal
 from hachoir_parser.image.exif import Exif
 from hachoir_parser.image.photoshop_metadata import PhotoshopMetadata
+from hachoir_parser.archive.zlib import build_tree
+from hachoir_core.tools import paddingSize, alignValue
 
 MAX_FILESIZE = 100 * 1024 * 1024
 
         while not self.eof:
             yield Ducky(self, "item[]")
 
+class SOFComponent(FieldSet):
+    def createFields(self):
+        yield UInt8(self, "component_id")
+        yield Bits(self, "horiz_sample", 4, "Horizontal sampling factor")
+        yield Bits(self, "vert_sample", 4, "Vertical sampling factor")
+        yield UInt8(self, "quant_table", "Quantization table destination selector")
+
 class StartOfFrame(FieldSet):
     def createFields(self):
         yield UInt8(self, "precision")
         yield UInt8(self, "nr_components")
 
         for index in range(self["nr_components"].value):
-            yield UInt8(self, "component_id[]")
-            yield UInt8(self, "high[]")
-            yield UInt8(self, "low[]")
+            yield SOFComponent(self, "component[]")
 
 class Comment(FieldSet):
     def createFields(self):
         yield NullBytes(self, "flags1", 2)
         yield Enum(UInt8(self, "color_transform", "Colorspace transformation code"), self.COLORSPACE_TRANSFORMATION)
 
+class SOSComponent(FieldSet):
+    def createFields(self):
+        comp_id = UInt8(self, "component_id")
+        yield comp_id
+        if not(1 <= comp_id.value <= self["../nr_components"].value):
+           raise ParserError("JPEG error: Invalid component-id")
+        yield Bits(self, "dc_coding_table", 4, "DC entropy coding table destination selector")
+        yield Bits(self, "ac_coding_table", 4, "AC entropy coding table destination selector")
+
 class StartOfScan(FieldSet):
     def createFields(self):
         yield UInt8(self, "nr_components")
 
         for index in range(self["nr_components"].value):
-            comp_id = UInt8(self, "component_id[]")
-            yield comp_id
-            if not(1 <= comp_id.value <= self["nr_components"].value):
-               raise ParserError("JPEG error: Invalid component-id")
-            yield UInt8(self, "value[]")
-        yield RawBytes(self, "raw", 3) # TODO: What's this???
+            yield SOSComponent(self, "component[]")
+        yield UInt8(self, "spectral_start", "Start of spectral or predictor selection")
+        yield UInt8(self, "spectral_end", "End of spectral selection")
+        yield Bits(self, "bit_pos_high", 4, "Successive approximation bit position high")
+        yield Bits(self, "bit_pos_low", 4, "Successive approximation bit position low or point transform")
 
 class RestartInterval(FieldSet):
     def createFields(self):
         while self.current_size < self.size:
             yield QuantizationTable(self, "qt[]")
 
+class HuffmanTable(FieldSet):
+    def createFields(self):
+        # http://www.w3.org/Graphics/JPEG/itu-t81.pdf, page 40-41
+        yield Enum(Bits(self, "table_class", 4, "Table class"), {
+            0:"DC or Lossless Table",
+            1:"AC Table"})
+        yield Bits(self, "index", 4, "Huffman table destination identifier")
+        for i in xrange(1, 17):
+            yield UInt8(self, "count[%i]" % i, "Number of codes of length %i" % i)
+        lengths = []
+        remap = {}
+        for i in xrange(1, 17):
+            for j in xrange(self["count[%i]" % i].value):
+                field = UInt8(self, "value[%i][%i]" % (i, j), "Value of code #%i of length %i" % (j, i))
+                yield field
+                remap[len(lengths)] = field.value
+                lengths.append(i)
+        self.tree = {}
+        for i,j in build_tree(lengths).iteritems():
+            self.tree[i] = remap[j]
+
+class DefineHuffmanTable(FieldSet):
+    def createFields(self):
+        while self.current_size < self.size:
+            yield HuffmanTable(self, "huffman_table[]")
+
+class HuffmanCode(Field):
+    """Huffman code. Uses tree parameter as the Huffman tree."""
+    def __init__(self, parent, name, tree, description=""):
+        Field.__init__(self, parent, name, 0, description)
+
+        endian = self.parent.endian
+        stream = self.parent.stream
+        addr = self.absolute_address
+
+        value = 0
+        met_ff = False
+        while (self.size, value) not in tree:
+            if addr % 8 == 0:
+                last_byte = stream.readBytes(addr - 8, 1)
+                if last_byte == '\xFF':
+                    next_byte = stream.readBytes(addr, 1)
+                    if next_byte != '\x00':
+                        raise FieldError("Unexpected byte sequence %r!"%(last_byte + next_byte))
+                    addr += 8 # hack hack hack
+                    met_ff = True
+                    self._description = "[skipped 8 bits after 0xFF] "
+            bit = stream.readBits(addr, 1, endian)
+            value <<= 1
+            value += bit
+            self._size += 1
+            addr += 1
+        self.createValue = lambda: value
+        self.realvalue = tree[(self.size, value)]
+        if met_ff:
+            self._size += 8
+
+class JpegHuffmanImageUnit(FieldSet):
+    """8x8 block of sample/coefficient values"""
+    def __init__(self, parent, name, dc_tree, ac_tree, *args, **kwargs):
+        FieldSet.__init__(self, parent, name, *args, **kwargs)
+        self.dc_tree = dc_tree
+        self.ac_tree = ac_tree
+
+    def createFields(self):
+        field = HuffmanCode(self, "dc_data", self.dc_tree)
+        field._description = "DC Code %i (Huffman Code %i)" % (field.realvalue, field.value) + field._description
+        yield field
+        if field.realvalue != 0:
+            extra = Bits(self, "dc_data_extra", field.realvalue)
+            if extra.value < 2**(field.realvalue - 1):
+                corrected_value = extra.value + (-1 << field.realvalue) + 1
+            else:
+                corrected_value = extra.value
+            extra._description = "Extra Bits: Corrected DC Value %i" % corrected_value
+            yield extra
+        data = []
+        while len(data) < 63:
+            field = HuffmanCode(self, "ac_data[]", self.ac_tree)
+            value_r = field.realvalue >> 4
+            if value_r:
+                data += [0] * value_r
+            value_s = field.realvalue & 0x0F
+            if value_r == value_s == 0:
+                field._description = "AC Code Block Terminator (0, 0) (Huffman Code %i)" % field.value + field._description
+                yield field
+                return
+            field._description = "AC Code %i, %i (Huffman Code %i)" % (value_r, value_s, field.value) + field._description
+            yield field
+            if value_s != 0:
+                extra = Bits(self, "ac_data_extra[%s" % field.name.split('[')[1], value_s)
+                if extra.value < 2**(value_s - 1):
+                    corrected_value = extra.value + (-1 << value_s) + 1
+                else:
+                    corrected_value = extra.value
+                extra._description = "Extra Bits: Corrected AC Value %i" % corrected_value
+                data.append(corrected_value)
+                yield extra
+            else:
+                data.append(0)
+
+class JpegImageData(FieldSet):
+    def __init__(self, parent, name, frame, scan, restart_interval, restart_offset=0, *args, **kwargs):
+        FieldSet.__init__(self, parent, name, *args, **kwargs)
+        self.frame = frame
+        self.scan = scan
+        self.restart_interval = restart_interval
+        self.restart_offset = restart_offset
+        # try to figure out where this field ends
+        start = self.absolute_address
+        while True:
+            end = self.stream.searchBytes("\xff", start, MAX_FILESIZE*8)
+            if end is None:
+                # this is a bad sign, since it means there is no terminator
+                # we ignore this; it likely means a truncated image
+                break
+            if self.stream.readBytes(end, 2) == '\xff\x00':
+                # padding: false alarm
+                start=end+16
+                continue
+            else:
+                self._size = end-self.absolute_address
+                break
+
+    def createFields(self):
+        if self.frame["../type"].value in [0xC0, 0xC1]:
+            # yay, huffman coding!
+            if not hasattr(self, "huffman_tables"):
+                self.huffman_tables = {}
+                for huffman in self.parent.array("huffman"):
+                    for table in huffman["content"].array("huffman_table"):
+                        for _dummy_ in table:
+                            # exhaust table, so the huffman tree is built
+                            pass
+                        self.huffman_tables[table["table_class"].value, table["index"].value] = table.tree
+            components = [] # sos_comp, samples
+            max_vert = 0
+            max_horiz = 0
+            for component in self.scan.array("component"):
+                for sof_comp in self.frame.array("component"):
+                    if sof_comp["component_id"].value == component["component_id"].value:
+                        vert = sof_comp["vert_sample"].value
+                        horiz = sof_comp["horiz_sample"].value
+                        components.append((component, vert * horiz))
+                        max_vert = max(max_vert, vert)
+                        max_horiz = max(max_horiz, horiz)
+            mcu_height = alignValue(self.frame["height"].value, 8 * max_vert) // (8 * max_vert)
+            mcu_width = alignValue(self.frame["width"].value, 8 * max_horiz) // (8 * max_horiz)
+            if self.restart_interval and self.restart_offset > 0:
+                mcu_number = self.restart_interval * self.restart_offset
+            else:
+                mcu_number = 0
+            initial_mcu = mcu_number
+            while True:
+                if (self.restart_interval and mcu_number != initial_mcu and mcu_number % self.restart_interval == 0) or\
+                   mcu_number == mcu_height * mcu_width:
+                    padding = paddingSize(self.current_size, 8)
+                    if padding:
+                        yield PaddingBits(self, "padding[]", padding) # all 1s
+                    last_byte = self.stream.readBytes(self.absolute_address + self.current_size - 8, 1)
+                    if last_byte == '\xFF':
+                        next_byte = self.stream.readBytes(self.absolute_address + self.current_size, 1)
+                        if next_byte != '\x00':
+                            raise FieldError("Unexpected byte sequence %r!"%(last_byte + next_byte))
+                        yield NullBytes(self, "stuffed_byte[]", 1)
+                    break
+                for sos_comp, num_units in components:
+                    for interleave_count in range(num_units):
+                        yield JpegHuffmanImageUnit(self, "block[%i]component[%i][]" % (mcu_number, sos_comp["component_id"].value),
+                                              self.huffman_tables[0, sos_comp["dc_coding_table"].value],
+                                              self.huffman_tables[1, sos_comp["ac_coding_table"].value])
+                mcu_number += 1
+        else:
+            self.warning("Sorry, only supporting Baseline & Extended Sequential JPEG images so far!")
+            return
+
 class JpegChunk(FieldSet):
     TAG_SOI = 0xD8
     TAG_EOI = 0xD9
     TAG_DQT = 0xDB
     TAG_DRI = 0xDD
     TAG_INFO = {
-        0xC4: ("huffman[]", "Define Huffman Table (DHT)", None),
+        0xC4: ("huffman[]", "Define Huffman Table (DHT)", DefineHuffmanTable),
         0xD8: ("start_image", "Start of image (SOI)", None),
         0xD9: ("end_image", "End of image (EOI)", None),
-        0xDA: ("start_scan", "Start Of Scan (SOS)", StartOfScan),
+        0xD0: ("restart_marker_0[]", "Restart Marker (RST0)", None),
+        0xD1: ("restart_marker_1[]", "Restart Marker (RST1)", None),
+        0xD2: ("restart_marker_2[]", "Restart Marker (RST2)", None),
+        0xD3: ("restart_marker_3[]", "Restart Marker (RST3)", None),
+        0xD4: ("restart_marker_4[]", "Restart Marker (RST4)", None),
+        0xD5: ("restart_marker_5[]", "Restart Marker (RST5)", None),
+        0xD6: ("restart_marker_6[]", "Restart Marker (RST6)", None),
+        0xD7: ("restart_marker_7[]", "Restart Marker (RST7)", None),
+        0xDA: ("start_scan[]", "Start Of Scan (SOS)", StartOfScan),
         0xDB: ("quantization[]", "Define Quantization Table (DQT)", DefineQuantizationTable),
         0xDC: ("nb_line", "Define number of Lines (DNL)", None),
         0xDD: ("restart_interval", "Define Restart Interval (DRI)", RestartInterval),
             raise ParserError("JPEG: Invalid chunk header!")
         yield textHandler(UInt8(self, "type", "Type"), hexadecimal)
         tag = self["type"].value
-        if tag in (self.TAG_SOI, self.TAG_EOI):
+        if tag in [self.TAG_SOI, self.TAG_EOI] + range(0xD0, 0xD8): # D0 - D7 inclusive are the restart markers
             return
         yield UInt16(self, "size", "Size")
         size = (self["size"].value - 2)
         return True
 
     def createFields(self):
+        frame = None
+        scan = None
+        restart_interval = None
+        restart_offset = 0
         while not self.eof:
             chunk = JpegChunk(self, "chunk[]")
             yield chunk
+            if chunk["type"].value in JpegChunk.START_OF_FRAME:
+                if chunk["type"].value not in [0xC0, 0xC1]: # SOF0 [Baseline], SOF1 [Extended Sequential]
+                    self.warning("Only supporting Baseline & Extended Sequential JPEG images so far!")
+                frame = chunk["content"]
             if chunk["type"].value == JpegChunk.TAG_SOS:
-                # TODO: Read JPEG image data...
-                break
+                if not frame:
+                    self.warning("Missing or invalid SOF marker before SOS!")
+                    continue
+                scan = chunk["content"]
+                # hack: scan only the fields seen so far (in _fields): don't use the generator
+                if "restart_interval" in self._fields:
+                    restart_interval = self["restart_interval/content/interval"].value
+                else:
+                    restart_interval = None
+                yield JpegImageData(self, "image_data[]", frame, scan, restart_interval)
+            elif chunk["type"].value in range(0xD0, 0xD8):
+                restart_offset += 1
+                yield JpegImageData(self, "image_data[]", frame, scan, restart_interval, restart_offset)
 
         # TODO: is it possible to handle piped input?
         if self._size is None:
 
     def createDescription(self):
         desc = "JPEG picture"
-        if "sof/content" in self:
-            header = self["sof/content"]
+        if "start_frame/content" in self:
+            header = self["start_frame/content"]
             desc += ": %ux%u pixels" % (header["width"].value, header["height"].value)
         return desc
 
         if end is not None:
             return end + 16
         return None
-