from filetypes.base import *
import malcat
import struct



class ChmHeader(Struct):
    """Leading header of a CHM file.

    It ends with the header-section table (file-size section, directory
    section). For version 3 and later it is followed by the absolute offset
    of content section 0.
    """

    def parse(self):
        yield String(4, name="Signature")
        header_version = yield UInt32(name="Version", comment="should be 3")
        yield UInt32(name="HeaderLength", comment="header length, including header section table and following data")
        yield UInt32(name="Unknown", comment="")
        yield UInt32BE(name="Timestamp", comment="a timestamp.  Considered as a big-endian DWORD, it appears to contain seconds (MSB) and fractional seconds (second byte).  The third and fourth bytes may contain even more fractional bits.  The 4 least significant bits in the last byte are constant.")
        yield LanguageId(name="Language", comment="file language")
        for guid_field in ("Guid1", "Guid2"):
            yield GUID(name=guid_field, comment="CHM GUID")
        # header section table: where the two header sections live
        for table_entry in ("FileSizeSection", "DirectorySection"):
            yield SectionHeader(name=table_entry)
        if header_version >= 3:
            # the ContentOffset field only exists in version 3+ headers
            yield Offset64(name="ContentOffset", comment="offset within file of content section 0")



class SectionHeader(Struct):
    """Entry of the header-section table: file offset and byte size of one header section."""

    def parse(self):
        layout = (
            (Offset64, "Offset", "offset of section from beginning of file"),
            (UInt64, "Size", "section size in bytes"),
        )
        for field_type, field_name, note in layout:
            yield field_type(name=field_name, comment=note)


class FileSizeHeader(Struct):
    """Header section 0: records the total size of the CHM file.

    Raises:
        FatalError: the leading DWORD is not the expected 0x1FE magic.
    """

    def parse(self):
        size_magic = yield UInt32(name="Magic")
        if size_magic != 0x1FE:
            raise FatalError("Invalid FileSizeSection magic")
        yield Unused(4)
        yield UInt64(name="FileSize", comment="file size in bytes")
        yield Unused(8)


class DirectoryHeader(Struct):
    """Header section 1: the "ITSP" directory header describing the chunk tree.

    Raises:
        FatalError: wrong magic, a version other than 1, or a chunk size
            other than 0x1000.
    """

    def parse(self):
        signature = yield String(4, name="Magic")
        if signature != "ITSP":
            raise FatalError("Invalid DirectoryHeader magic")
        dir_version = yield UInt32(name="Version", comment="should be 1")
        yield UInt32(name="HeaderLength", comment="length of this header")
        # the version is validated only once the header length has been consumed
        if dir_version != 1:
            raise FatalError("Invalid DirectoryHeader version")
        yield Unused(4)
        chunk_size = yield UInt32(name="ChunkSize", comment="should be 0x1000")
        if chunk_size != 0x1000:
            raise FatalError("Invalid chunk size")
        tree_fields = (
            ("Density", "density of quickref section, usually 2"),
            ("TreeDepth", "depth of index tree, 1 there is no index, 2 if there is one level of PMGI chunks."),
            ("RootChunk", "chunk number of root index chunk, -1 if there is none (though at least one file has 0 despite there being no index chunk, probably a bug.) "),
            ("FirstListingChunk", "chunk number of first PMGL (listing) chunk"),
            ("LastListingChunk", "chunk number of last PMGL (listing) chunk"),
        )
        for field_name, note in tree_fields:
            yield UInt32(name=field_name, comment=note)
        yield Unused(4)
        yield UInt32(name="TotalChunks", comment="number of directory chunks (total)")
        yield LanguageId(name="Language", comment="file language")
        yield GUID(name="Guid", comment="CHM GUID")
        yield UInt32(name="HeaderLength2", comment="length of this header")
        yield Unused(12)


class IndexChunks(Struct):
    """A run of consecutive PMGI index chunks.

    Args:
        number: how many IndexChunk structures to parse.
        density: quickref density forwarded to every IndexChunk.
    """

    def __init__(self, number, density, **kwargs):
        super().__init__(**kwargs)
        self.number = number
        self.density = density

    def parse(self):
        for _ in range(self.number):
            yield IndexChunk(self.density)

class ListingChunks(Struct):
    """A run of consecutive PMGL listing chunks.

    Args:
        number: how many ListingChunk structures to parse.
        density: quickref density forwarded to every ListingChunk.
    """

    def __init__(self, number, density, **kwargs):
        super().__init__(**kwargs)
        self.number = number
        self.density = density

    def parse(self):
        for _ in range(self.number):
            yield ListingChunk(self.density)




class ListingChunk(Struct):
    """One 0x1000-byte "PMGL" directory listing chunk.

    Layout: a 0x14-byte header, the directory entries, free space, then the
    quickref area, whose last WORD is the number of entries in the chunk.

    Args:
        density: quickref density from the directory header; one quickref
            offset is stored per ``1 + (1 << density)`` entries.

    Raises:
        FatalError: the chunk is truncated or its magic is not "PMGL".
    """

    def __init__(self, density, **args):
        Struct.__init__(self, **args)
        self.density = density


    def parse(self):
        # The entry count is the chunk's last WORD, so peek at the whole chunk
        # before parsing; on a short read [-2:] would pick up unrelated bytes.
        chunk_bytes = self.look_ahead(0x1000)
        if len(chunk_bytes) < 0x1000:
            raise FatalError("Truncated ListingChunk")
        number_of_entries, = struct.unpack("<H", chunk_bytes[-2:])
        magic = yield String(4, name="Magic")
        if magic != "PMGL":
            raise FatalError("Invalid ListingChunk magic")
        yield UInt32(name="ExtraSize", comment="length of free space and/or quickref area at end of directory chunk ")
        yield Unused(4)
        yield UInt32(name="PreviousChunk", comment="chunk number of previous listing chunk when reading directory in sequence (-1 if this is the first listing chunk)")
        yield UInt32(name="NextChunk", comment="chunk number of next listing chunk when reading directory in sequence (-1 if this is the last listing chunk)")
        yield ListingChunkEntries(number_of_entries, name="Entries")
        n = 1 + (1 << self.density)
        # NOTE(review): libmspack derives ceil(entries / n) - 1 stored quickref
        # offsets; entries // n differs from that when entries is a multiple of n.
        quickref_nums = number_of_entries // n
        quickref_sz = 2 + quickref_nums * 2  # quickref offsets + trailing entry-count WORD
        padding = 0x1000 - (len(self) + quickref_sz)
        if padding > 0:
            yield Unused(padding)
        # Quickref offsets are relative to the first directory entry, which
        # starts right after the 0x14-byte PMGL header, not to the chunk start.
        yield QuickRefArea(quickref_nums, self.density, self.offset + 0x14)


class QuickRefArea(Struct):
    """Quickref area at the end of a directory chunk.

    It holds `number` WORD offsets, laid out from the highest entry index
    down to the lowest, followed by the WORD entry count. Each offset is
    interpreted relative to `base`.

    Args:
        number: count of stored quickref offsets.
        density: quickref density; offsets exist for every
            ``1 + (1 << density)``-th entry.
        base: address the offsets are relative to.
    """

    def __init__(self, number, density, base, **kwargs):
        super().__init__(**kwargs)
        self.number = number
        self.density = density
        self.base = base

    def parse(self):
        stride = 1 + (1 << self.density)
        for slot in reversed(range(1, self.number + 1)):
            yield Offset16(name=f"OffsetEntry{slot * stride}", base=self.base)
        yield UInt16(name="NumberOfEntries")


class ListingChunkEntries(Struct):
    """The `number` consecutive directory entries stored in one PMGL chunk."""

    def __init__(self, number, **kwargs):
        super().__init__(**kwargs)
        self.number = number

    def parse(self):
        for _ in range(self.number):
            yield DirectoryListingEntry()

class IndexChunk(Struct):
    """One 0x1000-byte "PMGI" directory index chunk.

    Only the 8-byte header and the trailing quickref area are decoded. The
    index entries in between are covered by padding.

    Args:
        density: quickref density from the directory header; one quickref
            offset is stored per ``1 + (1 << density)`` entries.

    Raises:
        FatalError: the chunk is truncated or its magic is not "PMGI".
    """

    def __init__(self, density, **args):
        Struct.__init__(self, **args)
        self.density = density


    def parse(self):
        # The entry count is the chunk's last WORD, so peek at the whole chunk
        # before parsing; on a short read [-2:] would pick up unrelated bytes.
        chunk_bytes = self.look_ahead(0x1000)
        if len(chunk_bytes) < 0x1000:
            raise FatalError("Truncated IndexChunk")
        number_of_entries, = struct.unpack("<H", chunk_bytes[-2:])
        magic = yield String(4, name="Magic")
        if magic != "PMGI":
            raise FatalError("Invalid IndexChunk magic")
        yield UInt32(name="ExtraSize", comment="length of free space and/or quickref area at end of directory chunk ")
        n = 1 + (1 << self.density)
        # NOTE(review): libmspack derives ceil(entries / n) - 1 stored quickref
        # offsets; entries // n differs from that when entries is a multiple of n.
        quickref_nums = number_of_entries // n
        quickref_sz = 2 + quickref_nums * 2  # quickref offsets + trailing entry-count WORD
        padding = 0x1000 - (len(self) + quickref_sz)
        if padding > 0:
            yield Unused(padding)
        # Quickref offsets are relative to the first index entry, which starts
        # right after the 8-byte PMGI header, not to the chunk start.
        yield QuickRefArea(quickref_nums, self.density, self.offset + 8)


class DirectoryListingEntry(Struct):
    """One PMGL directory entry: a length-prefixed UTF-8 name plus location fields.

    All numeric fields are parsed as VarUInt64BE values.
    """

    def parse(self):
        name_len = yield VarUInt64BE(name="NameSize")
        yield StringUtf8(name_len, name="Name")
        location_fields = (
            ("SectionId", ""),
            ("Offset", "file offset relative to the start of the content section"),
            ("Size", "size of file"),
        )
        for field_name, note in location_fields:
            yield VarUInt64BE(name=field_name, comment=note)



class ContentSections(Struct):
    """Section name list (parsed by the analyzer from ::DataSpace/NameList).

    Two WORD header fields, followed by one ContentSectionsName per entry,
    named Section0, Section1, ...
    """

    def parse(self):
        yield UInt16(name="FileSize")
        entry_count = yield UInt16(name="NumberOfEntries")
        for idx in range(entry_count):
            yield ContentSectionsName(name=f"Section{idx}")

class ContentSectionsName(Struct):
    """One name-list entry: a WORD length, a UTF-16LE name, and a WORD terminator."""

    def parse(self):
        name_len = yield UInt16(name="Size")
        # the name is read with exactly `name_len` as its length; the terminator
        # WORD is parsed separately below
        yield StringUtf16le(name_len, name="Name", zero_terminated=False)
        yield UInt16(name="Terminator")


class ContentSectionsControlData(Struct):
    """LZXC ControlData of a storage section (reset interval, window size, ...)."""

    def parse(self):
        dword_count = yield UInt32(name="NumberOfDwords", comment = "Number of DWORDs following 'LZXC', must be 6 if version is 2")
        yield String(4, zero_terminated=False, name="Magic")
        for field_name in ("Version", "ResetInterval", "WindowSize", "CacheSize"):
            yield UInt32(name=field_name)
        # The count is treated as covering the 'LZXC' magic plus the four
        # fields above (5 DWORDs); anything beyond is kept as ExtraDwords.
        extra = dword_count - 5
        if extra > 0:
            yield Array(extra, UInt32(), name="ExtraDwords")

class ContentSectionsResetTable(Struct):
    """LZX reset table: sizes, block size and one 64-bit offset per compressed block."""

    def parse(self):
        yield UInt32(name="Version")
        entry_count = yield UInt32(name="NumberOfEntries")
        for field_name in ("SizeOfEntry", "SizeOfHeader"):
            yield UInt32(name=field_name)
        for field_name in ("UncompressedSize", "CompressedSize", "BlockSize"):
            yield UInt64(name=field_name)
        yield Array(entry_count, UInt64(), name="Entries")



class ContentSection:
    """Bookkeeping for one content section, as consumed by ChmAnalyzer.read_section.

    `offset`/`size` locate the stored bytes: in the file when `base_section`
    is None, otherwise inside the data of section `base_section`. The section
    is LZX-decompressed only when `window_size` is non-zero and
    `block_offsets` is non-empty.
    """

    def __init__(self, offset, size):
        self.offset, self.size = offset, size
        self.base_section = None
        # until SpanInfo says otherwise, assume stored size == real size
        self.size_uncompressed = self.size
        # LZX parameters; zero / empty means "stored uncompressed"
        self.window_size = self.reset_interval = self.block_size = 0
        self.block_offsets = []




class ChmAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for Microsoft Compiled HTML Help (CHM, "ITSF") files.

    `parse` annotates the file header, the two header sections, the directory
    chunks and the content area. It also registers every directory entry as a
    virtual file that `unpack` can extract.
    """
    category = malcat.FileType.ARCHIVE
    name = "CHM"
    regexp = r"ITSF\x03\x00\x00\x00.{3}\x00\x01\x00\x00\x00.{8}\x10\xFD\x01\x7C\xAA\x7B\xD0\x11\x9E\x0C\x00\xA0\xC9\x22\xE6\xEC\x11\xFD\x01\x7C\xAA\x7B\xD0\x11\x9E\x0C"

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # directory entry name (internal names prefixed with "/$CHM") -> DirectoryListingEntry
        self.filesystem = {}
        # ContentSection objects indexed by section id; index 0 is the raw content area
        self.content_sections = []
        # absolute file offset of content section 0, once known
        self.content_offset = None

    def read_section(self, index):
        """Return the bytes of content section `index`, LZX-decompressed if needed.

        A section without a base section is read straight from the file. Other
        sections are sliced out of their base section's data. When LZX
        parameters and block offsets were found, the data is then decompressed
        block by block.

        Raises:
            IndexError: `index` is not a known section id.
            ValueError: the LZX reset interval is unusable, or a block fails
                to decompress.
        """
        if not 0 <= index < len(self.content_sections):
            raise IndexError(f"Invalid section id {index}")
        s = self.content_sections[index]
        if s.base_section is None:
            data = self.read(s.offset, s.size)
        else:
            data = self.read_section(s.base_section)[s.offset:s.offset + s.size]
        if not (s.window_size and s.block_offsets):
            return data
        if s.reset_interval <= 0:
            # would otherwise surface as a ZeroDivisionError in the loop below
            raise ValueError(f"Invalid LZX reset interval {s.reset_interval}")
        dec = malcat.LzxDecompressor(s.window_size)
        # bytearray grows in place; `bytes +=` would re-copy everything decoded
        # so far for every block (quadratic in the section size)
        res = bytearray()
        last_block = len(s.block_offsets) - 1
        for i, input_ptr in enumerate(s.block_offsets):
            # reset_interval is expressed in blocks (reset-table entries)
            if i % s.reset_interval == 0:
                dec.reset()
            if i < last_block:
                input_size = s.block_offsets[i + 1] - input_ptr
            else:
                input_size = len(data) - input_ptr
            # NOTE(review): every block, including the last, requests a full
            # block_size of output; confirm LzxDecompressor copes with a short
            # final block.
            try:
                res += dec.decompress(data[input_ptr:input_ptr + input_size], s.block_size)
            except Exception as e:
                raise ValueError(f"Error while LZX decompressing {s.block_size} bytes from {input_ptr:x}-{input_ptr + input_size:x} ({len(res)} / {s.size_uncompressed} done): {e}") from e
        return bytes(res)


    def unpack(self, vfile, password=None):
        """Extract the virtual file `vfile` registered by `parse`.

        The file's whole content section is read (and decompressed), then the
        entry's Offset/Size slice is returned. `password` is unused.

        Raises:
            IndexError: `vfile.path` is not in the directory, or its section id
                is unknown.
        """
        file_entry = self.filesystem.get(vfile.path)
        if file_entry is None:
            raise IndexError("Could not find file " + vfile.path)
        section_content = self.read_section(file_entry["SectionId"])
        return section_content[file_entry["Offset"]:file_entry["Offset"] + file_entry["Size"]]

    @staticmethod
    def _lzx_parameters(control_data):
        """Return (window_bits, reset_interval_in_blocks) from an LZXC ControlData.

        Version 2 ControlData stores the reset interval and the window size in
        units of 0x8000-byte frames. Version 1 stores them in bytes (libmspack
        chmd.c applies the same rule), so they are converted here.

        Raises:
            ValueError: the resulting window size is not positive.
        """
        import math

        frame_size = 0x8000
        window = int(control_data["WindowSize"])
        reset = int(control_data["ResetInterval"])
        if int(control_data["Version"]) == 1:
            reset //= frame_size
        else:
            window *= frame_size
        if window <= 0:
            raise ValueError(f"Invalid LZX window size {control_data['WindowSize']}")
        return int(math.log2(window)), reset

    def parse(self, hint):
        """Parse a CHM file.

        Parsing covers the ITSF header, the header sections, the directory and
        the storage sections. Every directory entry is registered as a virtual
        file. `content_sections` is filled with section 0 (the raw content
        area) followed by every section listed in ::DataSpace/NameList after
        section 0.

        Raises:
            FatalError: a required structure or storage file is missing or invalid.
            ValueError: a storage section is not stored in section 0, or its
                LZX window size is invalid.
        """
        ch = yield ChmHeader(category=Type.HEADER)
        if "ContentOffset" in ch:
            self.content_offset = ch["ContentOffset"]

        # header section 0: file size
        self.jump(ch["FileSizeSection"]["Offset"])
        start = self.tell()
        yield FileSizeHeader(category=Type.HEADER)
        self.add_section(".size", start, self.tell() - start, discardable=True)
        eof = self["FileSizeHeader"]["FileSize"]
        self.set_eof(eof)

        # header section 1: ITSP directory header followed by the directory chunks
        self.jump(ch["DirectorySection"]["Offset"])
        start = self.tell()
        directory = yield DirectoryHeader(category=Type.HEADER)
        first_listing = directory["FirstListingChunk"]
        last_listing = directory["LastListingChunk"]

        ## index chunks (only annotated when they precede the first listing chunk)
        if directory["TreeDepth"] > 1 and first_listing > 0:
            yield IndexChunks(first_listing, directory["Density"])
        if first_listing <= last_listing:
            listings = yield ListingChunks(1 + last_listing - first_listing, directory["Density"])
            for chunk in listings:
                for entry in chunk["Entries"]:
                    name = entry["Name"]
                    # internal "/#..." and "/$..." entries are grouped under "/$CHM"
                    if name.startswith(("/#", "/$")):
                        name = "/$CHM" + name
                    self.filesystem[name] = entry
                    self.add_file(name, entry["Size"], "unpack")
        self.add_section(".dir", start, self.tell() - start, discardable=True)

        self.confirm()
        if self.content_offset is None:
            # no ContentOffset in the header: assume content follows the directory
            self.content_offset = self.tell()
        self.jump(self.content_offset)
        self.add_section(".content", self.tell(), eof - self.tell(), r=True, x=True)

        self.content_sections = [ContentSection(self.tell(), eof - self.tell())]

        # parse section name list (stored in section 0)
        name_list = self.filesystem.get("::DataSpace/NameList")
        if name_list is None:
            raise FatalError("could not find ::DataSpace/NameList")
        self.jump(self.tell() + name_list["Offset"])
        sections = yield ContentSections(name="NameList")

        # fields 0-1 are FileSize/NumberOfEntries and field 2 is section 0,
        # which is already registered above
        for section in sections[3:]:
            section_name = section["Name"]
            path = f"::DataSpace/Storage/{section_name}/"

            content_path = path + "Content"
            if content_path not in self.filesystem:
                raise FatalError(f"could not find content file for section {section_name}")
            content = self.filesystem[content_path]
            s = ContentSection(content["Offset"], content["Size"])
            s.base_section = content["SectionId"]
            # only storage inside section 0 is supported; check before indexing
            # content_sections so a bad id yields this error, not an IndexError
            if s.base_section != 0:
                raise ValueError(f"Invalid base section {s.base_section}")
            self.jump(self.content_sections[s.base_section].offset + s.offset)
            yield Bytes(s.size, name=section_name, category=Type.DATA)

            control_path = path + "ControlData"
            if control_path in self.filesystem:
                content = self.filesystem[control_path]
                self.jump(self.content_sections[content["SectionId"]].offset + content["Offset"])
                control_data = yield ContentSectionsControlData(name=f"{section_name}.ControlData", category=Type.HEADER)
                s.window_size, s.reset_interval = self._lzx_parameters(control_data)

            span_info_path = path + "SpanInfo"
            if span_info_path in self.filesystem:
                content = self.filesystem[span_info_path]
                self.jump(self.content_sections[content["SectionId"]].offset + content["Offset"])
                s.size_uncompressed = yield UInt64(name=f"{section_name}.UncompressedSize", category=Type.HEADER)

            reset_table_path = path + "Transform/{7FC28940-9D31-11D0-9B27-00A0C91E9C7C}/InstanceData/ResetTable"
            if reset_table_path in self.filesystem:
                content = self.filesystem[reset_table_path]
                self.jump(self.content_sections[content["SectionId"]].offset + content["Offset"])
                reset_table = yield ContentSectionsResetTable(name=f"{section_name}.ResetTable", category=Type.HEADER)
                s.block_offsets = [x.value for x in reset_table["Entries"]]
                s.block_size = reset_table["BlockSize"]

            self.content_sections.append(s)
            


