from filetypes.base import *
import malcat
import struct
import re

def align(val, what, down=False):
    """Round ``val`` to a multiple of ``what``.

    Rounds up by default, or down when ``down`` is true. A value that is
    already a multiple of ``what`` is returned unchanged.
    """
    remainder = val % what
    if not remainder:
        return val
    return val - remainder if down else val + (what - remainder)



class KoliHeader(Struct):
    """512-byte trailer of a DMG image (on-disk magic spelled "koly").

    Every integer field is big-endian. The byte offsets noted on each field
    (relative to the start of the trailer) follow from the declared field
    sizes and agree with ``DMGAnalyzer.locate`` (XML offset/length read at
    +216) and ``DMGAnalyzer.regexp`` (ImageVariant checked at +488).
    """

    def parse(self):
        """Yield the trailer fields in on-disk order (order is the layout)."""
        yield String(4, zero_terminated=False, name="Signature")  # +0
        yield UInt32BE(name="Version")  # +4 (regexp requires 4)
        yield UInt32BE(name="HeaderSize")  # +8 (regexp requires 0x200)
        yield UInt32BE(name="Flags")  # +12
        yield UInt64BE(name="RunningDataForkOffset")  # +16
        yield UInt64BE(name="DataForkOffset", comment="fork offset (usually 0, beginning of file)")  # +24
        yield UInt64BE(name="DataForkLength")  # +32
        yield UInt64BE(name="RsrcForkOffset")  # +40
        yield UInt64BE(name="RsrcForkLength")  # +48
        yield UInt32BE(name="SegmentNumber")  # +56
        yield UInt32BE(name="SegmentCount")  # +60
        yield GUID(microsoft_order=False, name="SegmentID")  # +64, 16 bytes
        yield UInt32BE(name="DataChecksumType")  # +80
        yield UInt32BE(name="DataChecksumSize")  # +84
        yield Array(32, UInt32BE(), name="DataChecksums")  # +88, 128 bytes
        yield UInt64BE(name="XMLOffset")  # +216: location of the XML plist
        yield UInt64BE(name="XMLLength")  # +224
        yield Unused(120)  # +232
        yield UInt32BE(name="ChecksumType")  # +352
        yield UInt32BE(name="ChecksumSize")  # +356
        yield Array(32, UInt32BE(), name="Checksums")  # +360, 128 bytes
        yield UInt32BE(name="ImageVariant")  # +488 (regexp requires 1)
        yield UInt64BE(name="SectorCount")  # +492
        yield Unused(12)  # +500, trailer ends at +512


class CompressedChunk:
    """One run of a DMG block table.

    Attributes:
        offset: file offset of the chunk's stored (possibly compressed) data.
        size: number of stored bytes at ``offset``.
        sector_start: first 512-byte sector this chunk produces, used as an
            index into the owning table's decompressed output.
        sector_count: number of 512-byte sectors this chunk produces.
        compression: numeric chunk/compression type code.
    """

    def __init__(self, offset, size, sector_start, sector_count, compression):
        self.offset, self.size = offset, size
        self.sector_start, self.sector_count = sector_start, sector_count
        self.compression = compression

    def __repr__(self):
        # Source byte range --> destination byte range (compression type), all hex.
        src_end = self.offset + self.size
        dst_begin = self.sector_start * 512
        dst_end = (self.sector_start + self.sector_count) * 512
        return f"[{self.offset:x}-{src_end:x}[ --> [{dst_begin:x}:{dst_end:x}[ ({self.compression:x})"

class DmgTable:
    """A named DMG block table, made of a list of chunks.

    Attributes:
        name: display name of the table.
        sector_start: first sector of the table, as read from its header.
        sector_count: number of 512-byte sectors the table expands to.
        chunks: chunk objects (each exposing ``offset`` and ``size``).
    """

    def __init__(self, name, sector_start, sector_count, chunks=None):
        self.name = name
        self.sector_start = sector_start
        self.sector_count = sector_count
        # ``None`` default (not ``[]``) so no list object is ever shared
        # between instances; a falsy argument still yields a fresh list.
        self.chunks = chunks or []

    @property
    def offset(self):
        """Lowest chunk offset, or ``None`` when the table has no chunks."""
        return min((c.offset for c in self.chunks), default=None)

    @property
    def size(self):
        """Total stored size of all chunks (0 when there are none)."""
        return sum(c.size for c in self.chunks)


class DMGAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for Apple DMG disk images.

    The image ends with a 512-byte "koly" trailer (``KoliHeader``) pointing to
    an XML property list. Each ``resource-fork``/``blkx`` entry of that plist
    describes a block table, which is exposed as a virtual file whose content
    is rebuilt on demand by :meth:`open_table`.
    """
    category = malcat.FileType.FILESYSTEM
    name = "DMG"
    # "koly" magic, Version == 4, HeaderSize == 0x200, then 476 bytes up to
    # ImageVariant (trailer offset 488) which must be 1.
    # NOTE(review): ``.`` only matches newline bytes if the framework compiles
    # this with DOTALL — confirm.
    regexp = r"koly\x00\x00\x00\x04\x00\x00\x02\x00.{476}\x00\x00\x00\x01"

    _SECTOR_SIZE = 512
    # Fixed part of a block table (">IIQQQII160sI") and size of one chunk entry (">IIQQQQ").
    _TABLE_HEADER_SIZE = 204
    _TABLE_ENTRY_SIZE = 40
    # Chunk types carrying no stored data: their sectors stay zero.
    _ZERO_CHUNK_TYPES = (0x00000000, 0x00000002)
    # Chunk entries that describe no data at all (comment / end-of-table markers).
    _SKIPPED_CHUNK_TYPES = (0x7ffffffe, 0xffffffff)

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        """Validate a trailer match and compute where the image starts.

        Reads XMLOffset/XMLLength at ``offset_magic + 216``. The XML length is
        clamped so the plist cannot run past the trailer.

        Returns:
            ``(start, hint)``, where ``start`` is the trailer offset minus the
            end of the XML plist. This assumes the plist ends right where the
            trailer begins (TODO confirm for images with other trailing data).
            ``hint`` is the trailer offset as a decimal string, consumed by
            :meth:`parse`. Returns ``(None, None)`` when the XML location is
            implausible.
        """
        xml_offset, xml_size = struct.unpack(">QQ", curfile.read(offset_magic + 216, 16))
        xml_size = min(xml_size, offset_magic - xml_offset)
        if xml_offset <= 0 or xml_size < 4 or xml_offset + xml_size >= len(curfile):
            return None, None
        return offset_magic - (xml_offset + xml_size), f"{offset_magic:d}"

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # table name -> DmgTable, filled by parse_table, read by open_table
        self.filesystem = {}

    def open_table(self, vfile, password=""):
        """Rebuild the raw content of the block table backing ``vfile``.

        Args:
            vfile: virtual file registered by :meth:`parse_table`; its
                ``path`` is the key into ``self.filesystem``.
            password: unused.

        Returns:
            A bytearray of ``sector_count * 512`` bytes. Sectors covered by
            zero-fill chunks (or by no chunk at all) are left as zeroes.

        Raises:
            KeyError: ``vfile.path`` does not name a known table.
            ValueError: a chunk uses an unsupported compression type, lies
                outside the table, or decompresses to an unexpected size.
        """
        import bz2
        import zlib
        lzfse = malcat.LzfseDecompressor()

        dmgtable = self.filesystem.get(vfile.path)
        if dmgtable is None:
            raise KeyError(f"File not found : {vfile.path} {list(self.filesystem.keys())}")
        sector_size = self._SECTOR_SIZE
        output = bytearray(dmgtable.sector_count * sector_size)
        for chunk in dmgtable.chunks:
            if chunk.compression in self._ZERO_CHUNK_TYPES:
                continue  # output buffer is already zeroed
            expected = chunk.sector_count * sector_size
            # sector_start is a sector index: convert it to a byte offset
            start = chunk.sector_start * sector_size
            if start + expected > len(output):
                # A bytearray slice assignment would otherwise silently grow the buffer
                raise ValueError(f"Chunk {chunk!r} lies outside of table {dmgtable.name}")
            compressed_data = self.read(chunk.offset, chunk.size)
            if chunk.compression == 0x80000007:  # LZFSE
                dec = lzfse.decompress(compressed_data, expected)
            elif chunk.compression == 0x80000006:  # bzip2
                dec = bz2.decompress(compressed_data)
            elif chunk.compression == 0x80000005:  # zlib
                dec = zlib.decompress(compressed_data)
            elif chunk.compression == 0x00000001:  # stored uncompressed
                dec = compressed_data
            else:
                raise ValueError(f"Unsupported compression type 0x{chunk.compression:x}")
            if len(dec) != expected:
                raise ValueError(f"Invalid decompressed size: {len(dec)} != {expected}")
            output[start:start + expected] = dec
        return output

    @staticmethod
    def _read_plist_dict(node):
        """Return the key -> value pairs of a plist ``<dict>`` holding ``<string>``/``<data>`` values."""
        from base64 import b64decode
        meta = {}
        key = None
        for entry in node:
            if entry.tag == "key":
                key = entry.text
                continue
            if entry.tag == "string":
                value = entry.text or ""
            elif entry.tag == "data":
                # b64decode discards the whitespace/newlines plists wrap data with
                value = b64decode(entry.text or "")
            else:
                raise ValueError(f"Unrecognized value type {entry.tag}")
            meta[key] = value
            key = ""
        return meta

    def parse_xml(self, xml):
        """Walk the image plist and register every ``blkx`` block table.

        Args:
            xml: text of the XML property list.

        Raises:
            ValueError: a ``blkx`` entry holds an unsupported value type, or a
                table is malformed (see :meth:`parse_table`).
        """
        import xml.etree.ElementTree as ET
        root = ET.fromstring(xml)
        top_dict = root[0]
        # plist dicts alternate <key> elements and their values
        for key, fork in zip(top_dict[::2], top_dict[1::2]):
            if key.tag != "key" or key.text != "resource-fork":
                continue
            for fork_key, blkx in zip(fork[::2], fork[1::2]):
                if fork_key.tag != "key" or fork_key.text != "blkx":
                    continue
                for dchunk in blkx.iter("dict"):
                    meta = self._read_plist_dict(dchunk)
                    # Prefer CFName, fall back to Name when it is missing or empty
                    table_name = meta.get("CFName") or meta.get("Name")
                    if table_name and "Data" in meta:
                        self.parse_table(table_name, meta["Data"])

    def parse_table(self, name, data):
        """Decode one binary block table and register it as a virtual file.

        Args:
            name: table name from the plist; any " : <number>" part is removed.
            data: raw bytes of the table (the plist ``Data`` value).

        Raises:
            ValueError: the table version is not 1.
            struct.error: the table is truncated.
        """
        (_, version, table_sector_start, table_sector_count, data_offset,
         _, _, _, num_blocks) = struct.unpack_from(">IIQQQII160sI", data)
        name = re.sub(r"\s*:\s*\d+\s*", "", name).capitalize()
        if not name:
            return
        if version != 1:
            raise ValueError("Unsupported version")
        tbl = DmgTable(name, table_sector_start, table_sector_count)
        # Only the announced entries: trailing bytes must not break iter_unpack
        first = self._TABLE_HEADER_SIZE
        entries = data[first:first + num_blocks * self._TABLE_ENTRY_SIZE]
        for chunk_type, _, sector_start, sector_count, offset, size in struct.iter_unpack(">IIQQQQ", entries):
            if chunk_type in self._SKIPPED_CHUNK_TYPES:
                continue
            # Chunk offsets are relative to the table's data start (usually 0)
            tbl.chunks.append(CompressedChunk(data_offset + offset, size, sector_start, sector_count, chunk_type))
        if not tbl.chunks or name in self.filesystem:
            return
        min_offset = min(c.offset for c in tbl.chunks)
        max_offset = max(c.offset + c.size for c in tbl.chunks)
        if max_offset <= min_offset:
            return  # no stored data at all (e.g. only zero-fill chunks)
        self.add_file(name, tbl.sector_count * self._SECTOR_SIZE, "open_table")
        self.filesystem[name] = tbl
        self.add_section(name, min_offset, max_offset - min_offset,
                         tbl.sector_start * self._SECTOR_SIZE, tbl.sector_count * self._SECTOR_SIZE,
                         r=True)

    def parse(self, hint):
        """Parse the trailer at offset ``hint`` and the plist it points to.

        Declares the header and XML sections, ends the file right after the
        512-byte trailer, then registers every block table of the plist.

        Raises:
            FatalError: the XML location is invalid or the plist cannot be parsed.
        """
        koli_location = int(hint or "0")
        self.jump(koli_location)
        self.add_section("header", self.tell(), 512, r=False, discardable=True)
        hdr = yield KoliHeader()
        self.add_metadata("Segment ID", hdr["SegmentID"])
        # The trailer is the last 512 bytes of the image (independent of
        # where the cursor stands after parsing the header).
        self.set_eof(koli_location + 512)
        xml_offset = hdr["XMLOffset"]
        self.jump(xml_offset)
        # The plist cannot extend past the trailer
        sz = min(hdr["XMLLength"], koli_location - xml_offset)
        if sz <= 0:
            raise FatalError("Invalid XML metadata location")
        xml = yield String(sz, zero_terminated=False, name="XMLMetadata", category=Type.META)
        # Section size in bytes (len(xml) counts characters once decoded)
        self.add_section("XML", xml_offset, sz, r=False, discardable=True)
        self.confirm()

        try:
            self.parse_xml(xml)
        except Exception as e:
            raise FatalError(f"Could not parse XML stream: {e}") from e

