from filetypes.base import *
import malcat 



class IHDR(Struct):
    """IHDR chunk: image header (dimensions, bit depth, color type,
    compression/filter/interlace methods) followed by the chunk CRC.

    If the declared data size is larger than the parsed fields, the extra
    bytes are shown as an "Overlay (error)". If it is smaller, parsing fails.
    """

    def parse(self):
        size = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        old = len(self)
        yield UInt32BE(name="Width", comment="width in pixels")
        yield UInt32BE(name="Height", comment="height in pixels")
        yield UInt8(name="Bits", comment="bit depth")
        yield BitsField(
            Bit(name="Palette", comment="colors are palette indices"),
            Bit(name="Color", comment="image uses color (otherwise greyscale)"),
            Bit(name="Alpha", comment="image has alpha channel"),
            NullBits(5),
            name = "Color type")
        # Method 0 is the only compression method defined by the PNG spec (deflate).
        yield UInt8(name="Compression", comment="compression method", values=[("Deflate", 0)])
        # Method 0 (adaptive filtering) is the only filter method defined by the PNG spec.
        yield UInt8(name="Filter", comment="filter method", values=[("Adaptive", 0)])
        yield UInt8(name="Interlace", comment="interlace method", values=[("None", 0), ("Adam7", 1)])
        done = len(self) - old
        if done > size:
            raise FatalError("Invalid IHDR Chunk size {} vs {}".format(done, size))
        elif done < size:
            # Declared size exceeds the known fields: expose the surplus.
            yield Unused(size-done, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")

    
class GAMA(Struct):
    """gAMA chunk: one 32-bit gamma value, then the chunk CRC.

    Surplus declared bytes become an "Overlay (error)". A declared size that
    is too small for the gamma field is fatal.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        yield UInt32BE(name="Gamma", comment="gamma value")
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class SRGB(Struct):
    """sRGB chunk: a one-byte rendering intent, then the chunk CRC.

    Surplus declared bytes become an "Overlay (error)". A declared size that
    is too small for the field is fatal.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        yield UInt8(name="RenderingIntent", comment="how color rendering should be done")
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")

class RGB(Struct):
    """A single palette entry: three consecutive 8-bit color components."""

    def parse(self):
        # Components are stored in red, green, blue order.
        for channel in ("Red", "Green", "Blue"):
            yield UInt8(name=channel, comment="")


class PLTE(Struct):
    """PLTE chunk: a palette of 3-byte RGB entries, then the chunk CRC.

    Raises:
        FatalError: if the declared data size is not a multiple of 3.
    """

    def parse(self):
        size = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        # Each entry is exactly one RGB struct (3 bytes), so the payload must
        # split evenly into entries.
        if size % 3:
            raise FatalError("PLTE size must be a multiple of 3 (got {})".format(size))
        if size:
            yield Array(size // 3, RGB(), name="Palette", comment="color palette")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class PHYS(Struct):
    """pHYs chunk: pixel density on both axes plus a unit byte, then the CRC.

    Surplus declared bytes become an "Overlay (error)". A declared size that
    is too small for the fields is fatal.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        for axis in ("X", "Y"):
            yield UInt32BE(name="PixelRatio" + axis, comment="pixels per unit, " + axis + " axis")
        yield UInt8(name="Unit", comment="unit specifier")
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")

    
class IDAT(Struct):
    """IDAT chunk: one slice of the compressed image data, then the chunk CRC.

    Zero-length IDAT chunks are legal. The Data field is only emitted when the
    chunk actually carries a payload, so no zero-sized field is created.
    """

    def parse(self):
        size = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        if size:
            yield Bytes(size, name="Data", comment="zlib stream chunk")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


    
class IEND(Struct):
    """IEND chunk: end-of-image marker with no expected data, then the CRC.

    Any declared payload is unexpected and is shown as an "Overlay (error)".
    """

    def parse(self):
        payload_len = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        if payload_len:
            yield Unused(payload_len, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")

class Chunk(Struct):
    """Generic chunk of an unrecognised type: length, tag, opaque data, CRC.

    The payload is shown as raw bytes. Its format depends on the chunk type,
    so it is not assumed to be compressed. A zero-length chunk has no Data
    field.
    """

    def parse(self):
        size = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        if size:
            yield Bytes(size, name="Data", comment="chunk data")
        yield UInt32BE(name="Checksum", comment="crc32 of data")

class TEXT(Struct):
    """tEXt chunk: a null-terminated keyword, an optional text value, and the CRC.

    Whatever follows the keyword inside the declared size is the value. If
    the keyword alone overruns the declared size, parsing fails. If it fills
    the size exactly, no Value field is emitted.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        yield CString(max_size=79, name="Keyword", comment="Text keyword")
        keyword_len = len(self) - body_start
        if keyword_len > declared:
            raise FatalError("Cannot read comment value")
        value_len = declared - keyword_len
        if value_len > 0:
            yield String(value_len, name="Value", comment="Text value")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class TIME(Struct):
    """tIME chunk: last-modification timestamp, then the chunk CRC.

    The timestamp is a 16-bit year followed by one byte each for month, day,
    hour, minute and second. Surplus declared bytes become an
    "Overlay (error)". A declared size that is too small is fatal.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        yield UInt16BE(name="Year")
        for unit in ("Month", "Day", "Hour", "Minute", "Second"):
            yield UInt8(name=unit)
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class VPAG(Struct):
    """vpAg chunk: virtual image width/height and a page-unit byte, then the CRC.

    Surplus declared bytes become an "Overlay (error)". A declared size that
    is too small for the fields is fatal.
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        for dimension in ("Width", "Height"):
            yield UInt32BE(name="VirtualImage" + dimension)
        yield UInt8(name="VirtualPageUnits")
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class BKGD(Struct):
    """bKGD chunk: default background color, then the chunk CRC.

    The layout is chosen by the declared data size:
      * 1 byte:  palette index
      * 2 bytes: 16-bit greyscale level
      * 6 bytes: 16-bit red, green and blue
    With any other size no color field is parsed, and the whole payload is
    shown as an "Overlay (error)".
    """

    def parse(self):
        declared = yield UInt32BE(name="Size", comment="size of data")
        yield String(4, name="Tag", comment="chunk name")
        body_start = len(self)
        if declared == 1:
            # palette-based background
            yield UInt8(name="PaletteIndex")
        elif declared == 2:
            yield UInt16BE(name="GreyScale")
        elif declared == 6:
            for channel in ("Red", "Green", "Blue"):
                yield UInt16BE(name=channel)
        parsed = len(self) - body_start
        if parsed > declared:
            raise FatalError("Invalid Chunk size {} vs {}".format(parsed, declared))
        if parsed < declared:
            yield Unused(declared - parsed, name="Overlay (error)")
        yield UInt32BE(name="Checksum", comment="crc32 of data")


class PNGAnalyzer(FileTypeAnalyzer):
    """Malcat file-type analyzer for PNG images.

    After the 8-byte signature, chunks are parsed one after another until an
    IEND chunk has been consumed. Every chunk becomes a structure and a
    section. IHDR, tEXt and tIME chunks additionally feed file metadata.
    """

    category = malcat.FileType.IMAGE
    name = "PNG"
    regexp = r"\x89PNG\r\n\x1A\n"

    def parse(self, hint):
        """Yield the signature and every chunk structure of the file.

        Raises:
            FatalError: if a chunk tag is not ASCII, or if the data ends
                before an IEND chunk has been seen.
        """
        yield Bytes(8, name="Signature", category=Type.HEADER)
        # Smallest possible chunk: 4-byte length + 4-byte tag + 4-byte CRC.
        min_chunk_size = 12
        seen_end = False
        while not self.eof() and not seen_end:
            if self.remaining() < min_chunk_size:
                # A truncated tail cannot hold a full chunk; stop here and let
                # the missing-IEND check below report it.
                break
            # The tag follows the 4-byte length field.
            tag_offset = self.tell() + 4
            tag = self.read(tag_offset, 4)
            try:
                name = tag.decode("ascii")
            except UnicodeDecodeError:
                raise FatalError("Invalid tag name {} at {:x}".format(tag, tag_offset))
            if tag == b"IHDR":
                block = yield IHDR(category=Type.HEADER, name=name)
                # NOTE(review): "Bits" is the IHDR bit depth. Per the PNG spec that
                # is bits per sample/palette index, not per pixel, so the "bpp"
                # label may be misleading for multi-channel images -- confirm.
                self.add_metadata("Resolution", "{}x{} -- {}bpp".format(block["Width"], block["Height"], block["Bits"]))
            elif tag == b"gAMA":
                block = yield GAMA(category=Type.HEADER, name=name)
            elif tag == b"sRGB":
                block = yield SRGB(category=Type.HEADER, name=name)
            elif tag == b"bKGD":
                block = yield BKGD(category=Type.HEADER, name=name)
            elif tag == b"pHYs":
                block = yield PHYS(category=Type.HEADER, name=name)
            elif tag == b"vpAg":
                block = yield VPAG(category=Type.HEADER, name=name)
            elif tag == b"IDAT":
                block = yield IDAT(category=Type.DATA, name=name)
            elif tag == b"tEXt":
                block = yield TEXT(category=Type.META, name=name)
                # Best effort: TEXT only emits "Value" when data follows the
                # keyword. When it is absent, the lookup is expected to fail, and
                # only the structure is kept.
                try:
                    self.add_metadata(block["Keyword"], block["Value"])
                except Exception as e:
                    print("Cannot extract tEXt metadata: {}".format(e))
            elif tag == b"tIME":
                block = yield TIME(category=Type.META, name=name)
                self.add_metadata("Last Modified", "{}-{}-{} {}:{}:{}".format(block["Year"], block["Month"], block["Day"], block["Hour"], block["Minute"], block["Second"]))
            elif tag == b"IEND":
                seen_end = True
                block = yield IEND(category=Type.HEADER, name=name)
            else:
                block = yield Chunk(category=Type.HEADER, name=name)
            self.add_section(name, block.offset, block.size)

        if not seen_end:
            raise FatalError("no END block")

       
        

