from filetypes.base import *
import malcat


class BitmapFileHeader(Struct):
    """BITMAPFILEHEADER: the 14-byte preamble that opens a .bmp file."""

    def parse(self):
        # signature, total file size, 4 reserved bytes, offset where the bitmap data starts
        header_fields = (
            String(2, name="Signature", comment="signature (always BM)"),
            UInt32(name="FileSize", comment="file size in bytes"),
            Bytes(4, name="reserved"),
            UInt32(name="DataStart", comment="start of bmp data"),
        )
        # plain loop on purpose: the parser sends values back into this generator
        for header_field in header_fields:
            yield header_field


# biCompression codes keyed by a human-readable label; the (label, code) pairs are
# handed to the biCompression field as `values=` and its values() validate the header
COMPRESSION_NAME = dict([
    ("Uncompressed RGB", 0),
    ("RLE 8 bits", 1),
    ("RLE 4 bits", 2),
    ("Bitfields", 3),
    ("JPEG", 4),
    ("PNG", 5),
])



class RGBQuad(Struct):
    """Color table entry following a 40/108/124-byte DIB header (RGBQUAD layout)."""

    def parse(self):
        yield UInt8(name="Blue", comment="")
        yield UInt8(name="Green", comment="")
        yield UInt8(name="Red", comment="")
        yield UInt8(name="Reserved", comment="")


class RGBTriple(Struct):
    """Color table entry following a 12-byte BITMAPCOREHEADER (RGBTRIPLE layout)."""

    def parse(self):
        yield UInt8(name="Blue", comment="")
        yield UInt8(name="Green", comment="")
        yield UInt8(name="Red", comment="")


class BitmapInfoHeader(Struct):
    """DIB header followed by the optional color masks and color table.

    biSize selects the header variant:
      * 12  (BITMAPCOREHEADER): 16-bit width/height, no compression fields,
        and a color table of 3-byte RGBTRIPLE entries for 1/4/8 bpp;
      * 40  (BITMAPINFOHEADER): for BI_BITFIELDS 16/32 bpp images three DWORD
        color masks follow the header, then the RGBQUAD color table;
      * 108/124 (V4/V5 headers): the extra header bytes (which already contain
        the color masks) are skipped, then the RGBQUAD color table follows.

    Raises FatalError on an unsupported header size, biPlanes != 1, an
    unknown bit count or an unknown compression method.
    """

    def parse(self):
        hdrsz = yield UInt32(name="biSize", comment="size of header in bytes")
        if hdrsz not in (12, 40, 108, 124):
            raise FatalError("Invalid BitmapInfoHeader size {}".format(hdrsz))
        # BITMAPCOREHEADER stores width and height as 16-bit words
        dimension = UInt16 if hdrsz == 12 else UInt32
        yield dimension(name="biWidth", comment="width in pixels")
        yield dimension(name="biHeight", comment="height in pixels")
        nop = yield UInt16(name="biPlanes", comment="number of color planes")
        if nop != 1:
            raise FatalError("biPlanes must be 1")
        bpp = yield UInt16(name="biBitCount", comment="bits per pixel")
        if bpp not in (0,1,4,8,15,16,24,32):
            raise FatalError("Invalid biBitCount")
        if hdrsz < 40:
            # core header: no compression/size/color-count fields; palettized
            # images always carry a full table of 3-byte entries
            if bpp in (1, 4, 8):
                yield Array(1 << bpp, RGBTriple(), name="Palette")
            return
        cmpr = yield UInt32(name="biCompression", comment="compression algorithm used", values=list(COMPRESSION_NAME.items()))
        if cmpr not in COMPRESSION_NAME.values():
            raise FatalError("Invalid Compression")
        yield UInt32(name="biSizeImage", comment="size of image data in bytes")
        yield UInt32(name="biXres", comment="horizontal pixels per meter")
        yield UInt32(name="biYres", comment="vertical pixels per meter")
        clus = yield UInt32(name="biColorsUsed", comment="number of colors used")
        yield UInt32(name="biColorsImportant", comment="number of important colors")

        # V4/V5 extra fields belong to the header and precede the color table
        if hdrsz > len(self):
            yield Unused(hdrsz - len(self), name="RestOfHeader")

        # only a plain 40-byte header is followed by separate BI_BITFIELDS masks
        # (V4/V5 embed them); the masks are DWORDs whatever the pixel size
        if cmpr == 3 and hdrsz == 40 and bpp in (16, 32):
            yield Array(3, UInt32(), name="BitsMasks")

        if bpp in (1, 4, 8):
            # biColorsUsed == 0 means a full palette; larger values are clamped
            max_colors = 1 << bpp
            yield Array(min(clus or max_colors, max_colors), RGBQuad(), name="Palette")
        elif clus > 0:
            # optional color table declared for non-palettized images
            yield Array(clus, RGBQuad(), name="Palette")
    

class RGBA(Struct):
    """Four single-byte channels labelled Red, Green, Blue and Alpha, in that order.

    NOTE(review): BMP color tables store RGBQUAD entries as blue, green, red,
    reserved; if this struct is used to label such a table, Red and Blue end
    up swapped.
    """

    def parse(self):
        for channel in ("Red", "Green", "Blue", "Alpha"):
            yield UInt8(name=channel, comment="")


class RGB(Struct):
    """Three single-byte channels labelled Red, Green and Blue, in that order.

    NOTE(review): BMP core-header color tables store RGBTRIPLE entries as blue,
    green, red; if this struct is used to label such a table, Red and Blue end
    up swapped.
    """

    def parse(self):
        for channel in ("Red", "Green", "Blue"):
            yield UInt8(name=channel, comment="")



class BMPRawAnalyzer(FileTypeAnalyzer):
    """Headerless DIB: a BitmapInfoHeader directly followed by the pixel data.

    This is e.g. how bitmaps are stored inside ICO files, where biHeight covers
    both the color image and a trailing 1-bpp AND mask; parse() uses the hint
    and a few heuristics to detect that mask.
    """
    category = malcat.FileType.IMAGE
    name = "DIB"
    # biSize (12/40/108/124 as a little-endian DWORD), 8 bytes of width/height,
    # biPlanes == 1, then a plausible biBitCount
    # NOTE(review): a 12-byte BITMAPCOREHEADER stores 16-bit width/height, so its
    # biPlanes field sits 4 bytes earlier than this pattern expects
    regexp = r"[\x0C\x28\x6C\x7C]\0{3}+.{8}+\x01\x00[\x01\x04\x08\x0F\x10\x18\x20\x28]\x00"

    def parse(self, hint):
        """Parse the DIB header, its pixel data and the optional AND mask.

        Registers a "data" section (and a discardable "mask" section when a mask
        is detected) and a "Resolution" metadata entry.
        Raises FatalError when the dimensions are implausible or the pixel data
        size is zero or exceeds the remaining bytes.
        """
        import collections
        bih = yield BitmapInfoHeader(category=Type.HEADER)
        width = bih["biWidth"]
        height = bih["biHeight"]
        if width * height > 9000000:
            raise FatalError("Image bigger than 8K")
        if width > 8000 or height > 8000:
            raise FatalError("Width or height bigger than 8K")
        has_mask = False
        datasz = 0
        # rows are padded to a multiple of 4 bytes
        line_width = ((width * bih["biBitCount"] + 31) // 32) * 4
        computed_size = line_width * height
        # 1-bpp mask sized for half of the declared height (icon layout)
        mask_size = ((width + 31) // 32) * 4 * (height // 2)
        if bih["biSize"] == 12:
            datasz = computed_size
        elif "biCompression" in bih and bih["biCompression"] == 0:
            if bih["biSizeImage"] == 0:
                datasz = computed_size
                if "ICO" in hint:
                    has_mask = True
                elif computed_size // 2 + mask_size == self.remaining():
                    has_mask = True
                elif height == width * 2 and computed_size >= 256 and bih["biSize"] == 40:
                    # heuristic to see if it's an ICO, i.e there is a bitmask:
                    # compare byte statistics just before / just after the
                    # place where the mask would start
                    toread = min(mask_size, 128)
                    mask_start = self.tell() + computed_size // 2
                    bytes_b4_mask = self.read(mask_start - toread, toread)
                    bytes_after_mask = self.read(mask_start, toread)
                    c1 = collections.Counter(bytes_b4_mask)
                    c2 = collections.Counter(bytes_after_mask)
                    if c2[0xFF] > 5 * c1[0xFF] or c2[0x7F] > 5 * c1[0x7F] or c2[0] > 5 * c1[0]:
                        has_mask = True
            elif computed_size // 2 + mask_size == bih["biSizeImage"] or computed_size // 2 == bih["biSizeImage"]:  # ICO bmp followed by an AND mask
                has_mask = True
            elif ("ICO" in hint or computed_size // 2 + mask_size == self.remaining()) and height == width * 2 and computed_size == bih["biSizeImage"]:
                # weird case
                has_mask = True
            else:
                datasz = bih["biSizeImage"]
        elif "biCompression" in bih and bih["biCompression"] == 3 and bih["biSizeImage"] == 0:
            # BI_BITFIELDS data is uncompressed, so biSizeImage may legally be 0
            datasz = computed_size
        elif "biSizeImage" in bih:
            datasz = bih["biSizeImage"]
        if has_mask:
            # the declared height includes the mask: keep only the color half
            datasz = computed_size // 2
            height = height // 2
        if not datasz or datasz > self.remaining():
            raise FatalError("Invalid bitmap size: {:d} vs {:d} ({})".format(datasz, self.remaining(), has_mask))

        # bitmap bytes
        imagedata_offset = self.tell()
        data = yield Bytes(datasz, name="biImageData", category=Type.DATA)
        self.add_section("data", imagedata_offset, len(data))

        # mask
        if has_mask:
            mask_offset = self.tell()
            data = yield Bytes(mask_size, name="biAndMask", category=Type.FIXUP)
            self.add_section("mask", mask_offset, len(data), discardable=True)

        # metadata: report the color image height, not the height including the mask
        self.add_metadata("Resolution", "{:d}x{:d}x{:d}bpp".format(width, height, bih["biBitCount"]))


class BMPAnalyzer(FileTypeAnalyzer):
    """Regular .bmp file: BitmapFileHeader, BitmapInfoHeader, then pixel data."""
    category = malcat.FileType.IMAGE
    name = "BMP"
    # "BM", 12 bytes of file header, then biSize = 12/40/108/124 as a DWORD
    regexp = r"BM.{12}+[\x0C\x28\x6C\x7C]\0{3}+"

    def parse(self, hint):
        """Parse the headers and the pixel data of a .bmp file.

        Registers a "data" section and a "Resolution" metadata entry.
        Raises FatalError when the pixel data size is zero or exceeds the
        remaining bytes.
        """
        bfh = yield BitmapFileHeader(category=Type.HEADER)
        bih = yield BitmapInfoHeader(category=Type.HEADER)
        if bfh["DataStart"] > self.tell():
            yield Unused(bfh["DataStart"] - self.tell(), name="Gap", category=Type.ANOMALY)
        width = bih["biWidth"]
        height = bih["biHeight"]
        if bih["biSize"] != 12 and height & 0x80000000:
            # biHeight is read unsigned: a negative value (two's complement)
            # marks a top-down bitmap, its magnitude is the real height
            height = (1 << 32) - height
        # rows are padded to a multiple of 4 bytes
        line_width = ((width * bih["biBitCount"] + 31) // 32) * 4
        if bih["biSize"] == 12:
            # BITMAPCOREHEADER has neither biCompression nor biSizeImage
            datasz = line_width * height
        else:
            datasz = bih["biSizeImage"]
            # biSizeImage may be 0 for uncompressed (BI_RGB / BI_BITFIELDS) data
            if bih["biCompression"] in (0, 3) and datasz == 0:
                datasz = line_width * height
        if not datasz or datasz > self.remaining():
            raise FatalError("Invalid bitmap size: {:d}".format(datasz))

        # bitmap bytes
        imagedata_offset = self.tell()
        data = yield Bytes(datasz, name="biImageData", category=Type.DATA)
        self.add_section("data", imagedata_offset, len(data))

        # metadata
        self.add_metadata("Resolution", "{:d}x{:d}x{:d}bpp".format(width, height, bih["biBitCount"]))
