from filetypes.base import *
import malcat
from malcat import PythonHelper
import struct
import os




# ZIP compression method identifiers as (display name, numeric value) pairs.
# Used as the `values=` table of the "CompressionMethod" fields in this file.
# NOTE: the names "no compression", "deflated", "BZIP2" and "WzAES" are
# compared against in ZipAnalyzer.unpack_manual and must stay unchanged.
COMPRESSION_METHODS = [
    ("no compression", 0),
    ("shrunk", 1),
    ("reduced with compression factor 1", 2),
    ("reduced with compression factor 2", 3),
    ("reduced with compression factor 3", 4),
    ("reduced with compression factor 4", 5),
    ("imploded", 6),
    ("reserved", 7),
    ("deflated", 8),
    ("enhanced deflated", 9),
    ("PKWare DCL imploded", 10),
    ("reserved", 11),
    ("BZIP2", 12),
    ("reserved", 13),
    ("LZMA", 14),
    # 15 used to be mislabelled "17: reserved"; 16 and 17 were missing
    ("reserved", 15),
    ("IBM z/OS CMPSC", 16),
    ("reserved", 17),
    ("compressed using IBM TERSE", 18),
    ("IBM LZ77 z", 19),
    ("Zstandard", 93),
    ("MP3", 94),
    ("XZ", 95),
    ("JPEG variant", 96),
    ("WavPack", 97),
    ("PPMd version I, Rev 1", 98),
    ("WzAES", 99),
]

# WinZip AES parameters, indexed by the strength code passed to
# ZipAnalyzer.decrypt_aes (1 = 128 bit, 2 = 192 bit, 3 = 256 bit).

# AES key size in bytes for each strength code.
WZ_KEY_LENGTHS = {
    strength: bits // 8
    for strength, bits in ((1, 128), (2, 192), (3, 256))
}

# Salt size in bytes for each strength code (always half of the key size).
WZ_SALT_LENGTHS = {
    strength: key_length // 2
    for strength, key_length in WZ_KEY_LENGTHS.items()
}


class LocalFile(Struct):
    """Local file header: fixed fields, file name and extra field area."""

    # The fixed-layout fields are built once at class level and re-yielded by
    # every parse() call (the original author marked this as an optimisation).
    fields1 = [
        UInt32(name="Signature"),
        UInt16(name="VersionNeeded", comment="ZIP version needed to extract (0x14 = 20 = 2.0)"),
    ]

    # General purpose bit flags.
    flags = BitsField(
        Bit(name="EncryptedFile"),
        Bit(name="CompressionOption"),
        Bit(name="CompressionOption"),
        Bit(name="DataDescriptor"),
        Bit(name="EnhancedDeflation"),
        Bit(name="CompressedPatchedData"),
        Bit(name="StrongEncrpytion"),
        NullBits(4),
        Bit(name="LanguageEncoding"),
        NullBits(1),
        Bit(name="MaskHeaderValues"),
        name="Flags",
    )

    fields2 = [
        UInt16(name="CompressionMethod", values=COMPRESSION_METHODS),
        DosDateTime(name="ModificationTime", comment="stored in standard MS-DOS format"),
        UInt32(name="Crc32", comment="value computed over file data by CRC-32 algorithm with 'magic number' 0xdebb20e3 (little endian)"),
        UInt32(name="CompressedSize", comment="if archive is in ZIP64 format, this filed is 0xffffffff and the length is stored in the extra field"),
        UInt32(name="UncompressedSize", comment="if archive is in ZIP64 format, this filed is 0xffffffff and the length is stored in the extra field"),
    ]

    def parse(self):
        """Yield the header fields, then the file name and the extra fields."""
        # NOTE: plain loops (not `yield from`) because the caller sends the
        # parsed value back into this generator after each yield.
        for header_field in LocalFile.fields1:
            yield header_field
        flags = yield LocalFile.flags
        for header_field in LocalFile.fields2:
            yield header_field
        name_len = yield UInt16(name="FileNameLen", comment="the length of the file name field below")
        extra_len = yield UInt16(name="ExtraFieldLen", comment="the length of the extra field below")

        if name_len:
            # the LanguageEncoding flag selects the UTF-8 string type
            name_type = StringUtf8 if flags["LanguageEncoding"] else String
            yield name_type(name_len, name="FileName", zero_terminated=False, comment="the name of the file including an optional relative path. All slashes in the path should be forward slashes '/'")

        # Structured extra fields: each needs at least a 4-byte id/size header.
        extra_end = len(self) + extra_len
        while len(self) + 4 <= extra_end:
            (tag,) = struct.unpack("<H", self.look_ahead(2))
            # unknown ids fall back to the generic ExtraField layout
            field_class = EXTRA_FIELDS.get(tag, ExtraField)
            yield field_class(name=field_class.__name__)

        # Whatever is left of the declared extra area is kept as raw bytes.
        leftover = extra_end - len(self)
        if leftover > 0:
            yield Bytes(leftover, name="ExtraDataRaw", comment="unparsed extra data")


class ExtraField(Struct):
    """Generic extra field: 16-bit id, 16-bit size, then the raw payload."""

    def parse(self):
        """Yield the id, the declared size and `size` bytes of opaque data."""
        yield UInt16(name="FieldId")
        payload_size = yield UInt16(name="FieldSize")
        yield Bytes(payload_size, name="FieldData")


class Zip64ExtraField(Struct):
    """Zip64 extended information extra field (registered under id 1).

    Members are yielded in the fixed order UncompressedSize, CompressedSize,
    RelativeHeaderOffset, DiskStart, as long as the declared FieldSize still
    has room for them. For the well-formed sizes 8, 16, 24 and 28 the result
    is the same as before; what changed is that the parser never reads past
    the declared size and never leaves declared bytes unconsumed.

    NOTE(review): per the ZIP specification only the values whose 32-bit
    counterpart is saturated are present, so a short field may carry a
    different member than the label assigned here -- confirm against the
    owning header if exact labels matter.
    """

    def parse(self):
        """Yield id, size, then as many 64-bit members as FieldSize allows."""
        yield UInt16(name="FieldId")
        sz = yield UInt16(name="FieldSize")
        remaining = sz
        layout = (
            (UInt64, "UncompressedSize", 8),
            (UInt64, "CompressedSize", 8),
            (UInt64, "RelativeHeaderOffset", 8),
            (UInt32, "DiskStart", 4),
        )
        for field_type, label, width in layout:
            if remaining < width:
                # not enough declared bytes left: do not read past the field
                break
            yield field_type(name=label)
            remaining -= width
        if remaining:
            # keep the extra-field stream in sync for the next field
            yield Bytes(remaining, name="UnparsedData", comment="bytes of the field that do not map to a known member")


class UnixExtraFieldNew(Struct):
    """Unix extra field registered under id 0x7855: optional 16-bit UID/GID."""

    def parse(self):
        """Yield id and size, then UID and GID when the size covers them."""
        yield UInt16(name="FieldId")
        declared_size = yield UInt16(name="FieldSize")
        # each member is only present if the declared size reaches its end
        for label, bytes_needed in (("UID", 2), ("GID", 4)):
            if declared_size >= bytes_needed:
                yield UInt16(name=label)


class UnixExtraFieldOld(Struct):
    """Unix extra field registered under id 0x5855: timestamps, optional ids."""

    def parse(self):
        """Yield id, size, both timestamps and, for sizes above 8, UID/GID."""
        yield UInt16(name="FieldId")
        declared_size = yield UInt16(name="FieldSize")
        for label in ("LastAccess", "LastModification"):
            yield Timestamp(name=label)
        if declared_size <= 8:
            # only the two timestamps are present
            return
        for label in ("UID", "GID"):
            yield UInt16(name=label)

class AesExtraField(StaticStruct):
    """AES extra field (registered under id 0x9901).

    ZipAnalyzer.unpack_manual reads "Strength" and "CompressionMethod" from
    this field when the header's compression method is "WzAES".
    """

    @classmethod
    def parse(cls):
        """Yield the fixed layout of the AES extra field."""
        strengths = [
            ("AES_128", 1),
            ("AES_192", 2),
            ("AES_256", 3),
        ]
        yield UInt16(name="FieldId")
        yield UInt16(name="FieldSize")
        yield UInt16(name="Version")
        yield String(2, zero_terminated=False, name="VendorId")
        yield UInt8(name="Strength", values=strengths)
        # compression method actually applied to the data
        yield UInt16(name="CompressionMethod", values=COMPRESSION_METHODS)


# Dedicated parsers for known extra-field ids. The header parsers look the id
# up here and fall back to the generic ExtraField when it is not listed.
EXTRA_FIELDS = dict((
    (0x0001, Zip64ExtraField),
    (0x5855, UnixExtraFieldOld),
    (0x7855, UnixExtraFieldNew),
    (0x9901, AesExtraField),
))

class DataDescriptor(StaticStruct):
    """Data descriptor without signature: three consecutive 32-bit values."""

    @classmethod
    def parse(cls):
        """Yield Crc32, CompressedSize and UncompressedSize as UInt32."""
        for label in ("Crc32", "CompressedSize", "UncompressedSize"):
            yield UInt32(name=label)



class DataDescriptorBlock(Struct):
    """Data descriptor with signature; sizes are 32-bit or 64-bit wide."""

    def parse(self):
        """Yield signature and CRC, then guess the width of the two sizes.

        The descriptor does not say whether its sizes are 32 or 64 bit, so
        the 0x12 bytes after the CRC are inspected: the 64-bit layout is
        chosen when no "PK" follows a 32-bit layout (offset 8), the bytes at
        offsets 6-7 and 14-15 are zero, and "PK" follows a 64-bit layout
        (offset 16).
        """
        yield UInt32(name="Signature")
        yield UInt32(name="Crc32")
        peek = self.look_ahead(0x12)
        record_after_32bit = peek[8:10] == b"PK"
        high_bytes_clear = peek[6:8] == b"\x00\x00" and peek[14:16] == b"\x00\x00"
        record_after_64bit = peek[16:] == b"PK"
        if not record_after_32bit and high_bytes_clear and record_after_64bit:
            # most likely zip64
            size_type = UInt64
        else:
            size_type = UInt32
        yield size_type(name="CompressedSize")
        yield size_type(name="UncompressedSize")


class CentralDirectory(Struct):
    """Central directory file header: fixed fields, name, extras, comment.

    The fixed-layout fields are built once at class level and re-yielded by
    every parse() call (the original author marked this as an optimisation).
    """

    # optim
    fields1 = [
        UInt32(name="Signature"),
        UInt8(name="Version", comment="Version made by"),
        UInt8(name="OperatingSystem", comment="OS type", values=[
            ("FAT filesystem (MS-DOS, OS/2, NT/Win32)", 0x00),
            ("Amiga", 0x01),
            ("VMS (or OpenVMS)", 0x02),
            ("Unix", 0x03),
            ("VM/CMS", 0x04),
            ("Atari TOS", 0x05),
            ("HPFS filesystem (OS/2, NT)", 0x06),
            ("Macintosh", 0x07),
            ("Z-System", 0x08),
            ("CP/M", 0x09),
            ("Windows NTFS", 0x0a),
            ("MVS (OS/390 - Z/OS)", 0x0b),
            ("VSE", 0x0c),
            ("Acorn RISCOS", 0x0d),
            ("VFAT", 0x0e),
            ("alternate MVS", 0x0f),
            ("BeOS", 0x10),
            ("Tandem", 0x11),
            ("OS/400", 0x12),
            ("OS/X (Darwin)", 0x13),
            ("unknown", 0xff),
            ]),
        UInt16(name="VersionNeeded", comment="ZIP version needed to extract (0x14 = 20 = 2.0)"),
    ]

    fields2 = [
        UInt16(name="CompressionMethod", values=COMPRESSION_METHODS),
        DosDateTime(name="ModificationTime", comment="stored in standard MS-DOS format"),
        UInt32(name="Crc32", comment="value computed over file data by CRC-32 algorithm with 'magic number' 0xdebb20e3 (little endian)"),
        UInt32(name="CompressedSize", comment="if archive is in ZIP64 format, this filed is 0xffffffff and the length is stored in the extra field"),
        UInt32(name="UncompressedSize", comment="if archive is in ZIP64 format, this filed is 0xffffffff and the length is stored in the extra field"),
    ]

    fields3 = [
        UInt16(name="DiskStart", comment="number of the disk on which this file exists"),
        BitsField(
            Bit(name="AsciiFile"),
            NullBits(1),
            Bit(name="ControlFieldBeforeLogical"),
            NullBits(13),
            name="InternalAttributes"),
        UInt32(name="ExternalAttributes", comment="host-system dependent"),
        Offset32(name="LocalFileOffset", comment="offset of where to find the corresponding local file header from the start of the first disk"),
    ]

    # General purpose bit flags (same layout as in the local file header).
    flags = BitsField(
            Bit(name="EncryptedFile"),
            Bit(name="CompressionOption"),
            Bit(name="CompressionOption"),
            Bit(name="DataDescriptor"),
            Bit(name="EnhancedDeflation"),
            Bit(name="CompressedPatchedData"),
            Bit(name="StrongEncrpytion"),
            NullBits(4),
            Bit(name="LanguageEncoding"),
            NullBits(1),
            Bit(name="MaskHeaderValues"),
            name="Flags")

    def parse(self):
        """Yield the fixed fields, then file name, extra fields and comment.

        Bytes of the declared extra area that are not consumed by the
        structured extra fields are yielded as "ExtraDataRaw", exactly as
        LocalFile does. Without this the comment, and every record after
        this one, would be read from the wrong offset.
        """
        # NOTE: plain loops (not `yield from`) because the caller sends the
        # parsed value back into this generator after each yield.
        for f in CentralDirectory.fields1:
            yield f
        flags = yield CentralDirectory.flags
        for f in CentralDirectory.fields2:
            yield f
        fnl = yield UInt16(name="FileNameLen", comment="the length of the file name field below")
        efl = yield UInt16(name="ExtraFieldLen", comment="the length of the extra field below")
        coml = yield UInt16(name="CommentLen", comment="the length of the comment field below")
        for f in CentralDirectory.fields3:
            yield f
        if fnl:
            # the LanguageEncoding flag selects the UTF-8 string type
            sc = StringUtf8 if flags["LanguageEncoding"] else String
            yield sc(fnl, name="FileName", zero_terminated=False, comment="the name of the file including an optional relative path. All slashes in the path should be forward slashes '/'")
        start = len(self)
        # each structured extra field needs at least a 4-byte id/size header
        while len(self) - start + 4 <= efl:
            tag, = struct.unpack("<H", self.look_ahead(2))
            # unknown ids fall back to the generic ExtraField layout
            field_class = EXTRA_FIELDS.get(tag, ExtraField)
            yield field_class(name=field_class.__name__)
        if len(self) - start < efl:
            # consume the rest of the extra area so the comment stays aligned
            yield Bytes(efl - (len(self) - start), name="ExtraDataRaw", comment="unparsed extra data")
        if coml:
            yield String(coml, name="Comment", zero_terminated=False, comment="optional comment for the file")


class EndOfCentralDirectory(Struct):
    """End of central directory record, with its optional archive comment."""

    def parse(self):
        """Yield the fixed fields, then the comment if its length is non-zero."""
        fixed_fields = (
            UInt32(name="Signature"),
            UInt16(name="DiskNumber", comment="number of this disk (containing the end of central directory record)"),
            UInt16(name="DiskStart", comment="number of the disk on which the central directory starts"),
            UInt16(name="DiskEntries", comment="number of central directory entries on this disk"),
            UInt16(name="TotalEntries", comment="number of entries in the central directory"),
            UInt32(name="CentralDirectorySize", comment="size of the central directory"),
            Offset32(name="CentralDirectoryStartOffset", comment="offset of the start of the central directory on the disk on which the central directory starts"),
        )
        for record_field in fixed_fields:
            yield record_field
        comment_len = yield UInt16(name="CommentLen", comment="the length of the comment field below")
        if comment_len:
            yield String(comment_len, name="Comment", zero_terminated=False, comment="optional comment for the archive")

class EndOfCentralDirectoryLocator64(Struct):
    """Zip64 end of central directory locator.

    Per the ZIP specification this record points at the zip64 end of central
    directory record, not at the central directory itself. The field names
    are kept for backward compatibility; only the displayed comments were
    corrected.
    """

    def parse(self):
        """Yield signature, disk number, record offset and total disk count."""
        yield UInt32(name="Signature")
        yield UInt32(name="DiskStart", comment="number of the disk on which the zip64 end of central directory record starts")
        yield Offset64(name="CentralDirectoryStartOffset", comment="relative offset of the zip64 end of central directory record")
        yield UInt32(name="TotalDisk", comment="total number of disks")

class EndOfCentralDirectory64(Struct):
    """Zip64 end of central directory record, with optional extensible data."""

    def parse(self):
        """Yield the fixed fields, then any bytes the Size field still covers.

        The Size value is measured from the end of the Size field itself;
        bytes it covers beyond the fixed fields are yielded as
        "ExtensibleData".
        """
        os_types = [
            ("FAT filesystem (MS-DOS, OS/2, NT/Win32)", 0x00),
            ("Amiga", 0x01),
            ("VMS (or OpenVMS)", 0x02),
            ("Unix", 0x03),
            ("VM/CMS", 0x04),
            ("Atari TOS", 0x05),
            ("HPFS filesystem (OS/2, NT)", 0x06),
            ("Macintosh", 0x07),
            ("Z-System", 0x08),
            ("CP/M", 0x09),
            ("Windows NTFS", 0x0a),
            ("MVS (OS/390 - Z/OS)", 0x0b),
            ("VSE", 0x0c),
            ("Acorn RISCOS", 0x0d),
            ("VFAT", 0x0e),
            ("alternate MVS", 0x0f),
            ("BeOS", 0x10),
            ("Tandem", 0x11),
            ("OS/400", 0x12),
            ("OS/X (Darwin)", 0x13),
            ("unknown", 0xff),
        ]
        yield UInt32(name="Signature")
        record_size = yield UInt64(name="Size", comment="size of zip64 end of central directory record")
        body_start = len(self)
        body_fields = (
            UInt8(name="Version", comment="Version made by"),
            UInt8(name="OperatingSystem", comment="OS type", values=os_types),
            UInt16(name="VersionNeeded", comment="ZIP version needed to extract (0x14 = 20 = 2.0)"),
            UInt32(name="DiskNumber", comment="number of this disk (containing the end of central directory record)"),
            UInt32(name="DiskStart", comment="number of the disk on which the central directory starts"),
            UInt64(name="DiskEntries", comment="number of central directory entries on this disk"),
            UInt64(name="TotalEntries", comment="number of entries in the central directory"),
            UInt64(name="CentralDirectorySize", comment="size of the central directory"),
            Offset64(name="CentralDirectoryStartOffset", comment="offset of the start of the central directory on the disk on which the central directory starts"),
        )
        for record_field in body_fields:
            yield record_field
        extensible_size = record_size - (len(self) - body_start)
        if extensible_size > 0:
            yield Bytes(extensible_size, name="ExtensibleData")


def align(val, what, down=False):
    """Round `val` to a multiple of `what`.

    Rounds up by default and down when `down` is true. `val` is returned
    unchanged when it is already aligned or when `what` is falsy (0 or None).
    """
    if not what:
        return val
    remainder = val % what
    if not remainder:
        return val
    return val - remainder if down else val + (what - remainder)


class ZipAnalyzer(FileTypeAnalyzer):
    category = malcat.FileType.ARCHIVE
    name = "ZIP"
    regexp = r"PK\x03\x04"

    def __init__(self):
        """Create an analyzer that has not parsed any archive yet."""
        super().__init__()
        # filled by parse(): file name -> (local file header, compressed size)
        self.filesystem = {}
        # end offset reached by parse(); stays 0 (rejected by open()) until
        # parse() completes
        self.size = 0

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        """Accept a magic match unless the parent parser is itself "ZIP".

        Returns None to reject the match, otherwise the tuple
        (offset_magic, "").
        """
        nested_in_zip = parent_parser is not None and parent_parser.name == "ZIP"
        # no ZIP in ZIP
        return None if nested_in_zip else (offset_magic, "")

    def open(self, vfile, password=None):
        """Return the uncompressed content of the archive member `vfile`.

        For encrypted members either the supplied `password` or a short list
        of default passwords is tried. Each attempt first uses
        unpack_manual() and, if that raises anything, falls back to
        unpack_stdlib().

        Raises:
            FatalError: the archive was not parsed successfully.
            KeyError: `vfile.path` was not registered by parse().
            InvalidPassword: the member is encrypted and no candidate worked.
        """
        if not self.size:
            raise FatalError("Invalid zip file")
        lfh, compressed_size = self.filesystem.get(vfile.path, (None, None))
        if lfh is None:
            raise KeyError("Unknown file path {}".format(vfile.path))

        encrypted = lfh["Flags"]["EncryptedFile"]
        if not encrypted:
            candidates = [None]
        elif password:
            candidates = [password.encode("utf8")]
        else:
            candidates = [b"infected", b"virus", b"malware", b"123456", b"infect3d"]

        for candidate in candidates:
            try:
                # compressed data starts right after the local file header
                raw = self.read(lfh.offset + lfh.size, compressed_size)
                try:
                    return self.unpack_manual(lfh, raw, candidate)
                except Exception:
                    # any failure of the manual path: let zipfile have a go
                    return self.unpack_stdlib(vfile.path, candidate)
            except RuntimeError as e:
                print(e)
                if not encrypted:
                    # nothing else to try for a plain member
                    raise
        if encrypted:
            raise InvalidPassword()
        raise ValueError("Could not extract ZIP file '{}' using any of these passwords: {}".format(vfile.path, " or ".join(map(str, candidates))))


    @staticmethod
    def decrypt_aes(buf, pwd, strength=1):
        """Decrypt a WinZip-AES encrypted member.

        `buf` is laid out as: salt | 2-byte password verifier | encrypted
        data | 10-byte authentication code. The trailing authentication code
        is not part of the payload and is stripped before decryption;
        previously it was decrypted too and appended to the result as 10
        garbage bytes. The authentication code itself is not verified.

        Args:
            buf: raw member data as stored in the archive.
            pwd: password (bytes); must be non-empty.
            strength: key into WZ_KEY_LENGTHS / WZ_SALT_LENGTHS (1, 2 or 3).

        Returns:
            The decrypted (still possibly compressed) data.

        Raises:
            ValueError: no password given.
            KeyError: unknown `strength`.
            InvalidPassword: the password verifier does not match.
        """
        from Cryptodome.Protocol.KDF import PBKDF2
        from Cryptodome.Cipher import AES
        from Cryptodome.Util import Counter

        if not pwd:
            raise ValueError("AES encryption requires a password")

        auth_code_length = 10

        # verify password
        key_length = WZ_KEY_LENGTHS[strength]
        salt_length = WZ_SALT_LENGTHS[strength]
        salt, pwd_check = struct.unpack_from("<{}s2s".format(salt_length), buf)
        # derived material: encryption key | authentication key | verifier
        derived_key_len = 2 * key_length + len(pwd_check)
        key_data = PBKDF2(pwd, salt, count=1000, dkLen=derived_key_len)
        enc_password_verify = key_data[2*key_length:]
        if enc_password_verify != pwd_check:
            raise InvalidPassword("Bad password")
        # decrypt everything between the verifier and the authentication code
        ciphertext = buf[salt_length + len(pwd_check):-auth_code_length]
        decrypter = AES.new(key_data[:key_length], AES.MODE_CTR, counter=Counter.new(nbits=128, little_endian=True))
        return decrypter.decrypt(ciphertext)

    
    def unpack_manual(self, lfh, buf, pwd):
        """Decrypt (if needed) and decompress `buf` without using zipfile.

        Args:
            lfh: parsed local file header of the member.
            buf: raw member data.
            pwd: password bytes, or None for plain members.

        Raises:
            ValueError: method is "WzAES" but no AesExtraField is present.
            InvalidPassword: the password check failed.
            NotImplementedError: unsupported compression method.
        """
        method = lfh.CompressionMethod.enum
        if method == "WzAES":
            # the real compression method lives in the AES extra field
            aes_info = next(
                (field for field in lfh[11:] if field.name == "AesExtraField"),
                None,
            )
            if aes_info is None:
                raise ValueError("WzAES encryption and no AesExtraField defined")
            method = aes_info.CompressionMethod.enum
            if lfh["Flags"]["EncryptedFile"]:
                buf = ZipAnalyzer.decrypt_aes(buf, pwd, aes_info["Strength"])
        elif lfh["Flags"]["EncryptedFile"]:
            # decrypt only the 12-byte encryption header first, so that a
            # wrong password is detected before decrypting the whole buffer
            check_byte = PythonHelper.zip_crc32_decrypt(buf[:12], pwd)[-1]
            crc_is_known = not lfh["Flags"]["DataDescriptor"]
            if crc_is_known and check_byte != (lfh["Crc32"] >> 24):
                raise InvalidPassword("Bad password")
            buf = PythonHelper.zip_crc32_decrypt(buf, pwd)[12:]

        if method == "no compression":
            return buf
        if method == "deflated":
            import zlib
            return zlib.decompress(buf, wbits=-15)
        if method == "BZIP2":
            import bz2
            return bz2.decompress(buf)
        raise NotImplementedError


    def unpack_stdlib(self, path, pwd=None):
        """Extract member `path` with the standard library's zipfile module.

        The whole archive (offsets 0 .. self.size) is loaded into memory and
        handed to zipfile.

        Args:
            path: name of the member inside the archive.
            pwd: password bytes, or None.

        Raises:
            KeyError: zipfile does not know `path` (raised by getinfo()).
            ValueError: getinfo() returned a falsy value.
        """
        import io
        import zipfile
        mybuf = io.BytesIO(self.read(0, self.size))
        with zipfile.ZipFile(mybuf, "r") as myzip:
            zinfo = myzip.getinfo(path)
            if not zinfo:
                # the message used to reference an undefined `vfile` name,
                # which would have raised NameError instead of ValueError
                raise ValueError("Could not get infos for file {}".format(path))
            return myzip.read(path, pwd)


    def parse(self, hint):
        """Walk the archive record by record and register its content.

        Records are dispatched on their 4-byte signature: local file
        headers, data descriptors, central directory entries and the
        end-of-central-directory records. Anything else is skipped up to the
        next known signature and reported as "UnusedSpace". Members are
        registered with add_file(); those found through a local header are
        also stored in self.filesystem so that open() can extract them.

        `hint` is not used by this implementation.

        Raises:
            FatalError: a local header or data descriptor follows the
                central directory, the version field is implausible, no
                known signature can be found, or the number of local headers
                and central directory entries disagree.
        """
        num_local_files = 0
        num_central_directory = 0
        files_seen = set()
        central_directory_start = None
        while self.remaining() > 4:
            start = self.tell()
            tag = self.read(start, 4)
            if tag == b"PK\x03\x04":
                if central_directory_start is not None:
                    raise FatalError("Central directory started")
                lfh = yield LocalFile(category=Type.HEADER)
                if lfh["VersionNeeded"] > 0x1000:
                    raise FatalError("Invalid version")
                compressed_size = lfh["CompressedSize"]
                uncompressed_size = lfh["UncompressedSize"]
                if "Zip64ExtraField" in lfh:
                    # non-zero zip64 sizes take precedence over the header
                    zip64 = lfh["Zip64ExtraField"]
                    if "CompressedSize" in zip64 and zip64["CompressedSize"] != 0:
                        compressed_size = zip64["CompressedSize"]
                    if "UncompressedSize" in zip64 and zip64["UncompressedSize"] != 0:
                        uncompressed_size = zip64["UncompressedSize"]
                if lfh["Flags"]["DataDescriptor"]:
                    # Size unknown: look for a data descriptor signature whose
                    # recorded compressed size matches its distance from here.
                    # NOTE(review): when this flag is set but the size is
                    # already non-zero, the data is not skipped here and is
                    # later reported as "UnusedSpace" -- confirm whether this
                    # is intended.
                    p = self.tell()
                    while compressed_size == 0:
                        ddsig, ddsig_sz = self.search(r"PK\x07\x08", p)
                        if not ddsig_sz:
                            break
                        tsize = ddsig - self.tell()
                        # the size may be stored as 32 or 64 bit; try both
                        compressed_size_candidate_32, = struct.unpack("<I", self.read(ddsig + 8, 4))
                        compressed_size_candidate_64, = struct.unpack("<Q", self.read(ddsig + 8, 8))
                        # sizes are compared after rounding up to 16 bytes
                        if align(compressed_size_candidate_32, 16) == align(tsize, 16) or align(compressed_size_candidate_64, 16) == align(tsize, 16):
                            self.jump(ddsig)
                            dd = yield DataDescriptorBlock(category=Type.HEADER, parent=lfh)
                            if dd["CompressedSize"] != 0:
                                compressed_size = dd["CompressedSize"]
                            if dd["UncompressedSize"] != 0:
                                uncompressed_size = dd["UncompressedSize"]
                            break
                        else:
                            # false positive: keep searching after it
                            p = ddsig + ddsig_sz
                else:
                    # size known up-front: skip over the compressed data
                    self.jump(self.tell() + min(compressed_size, self.remaining()))
                fn = ""
                if "FileName" in lfh:
                    fn = lfh["FileName"]
                    if fn.endswith("\x00"):
                        fn = fn[:-1]
                if compressed_size and uncompressed_size:
                    size = self.tell() - start
                    if size and fn:
                        self.add_section(os.path.basename(fn), start, size)
                    if fn and fn not in files_seen:
                        files_seen.add(fn)
                        self.add_file(fn, uncompressed_size, "open")
                        self.filesystem[fn] = (lfh, compressed_size)

                num_local_files += 1
                if num_local_files > 10:
                    self.confirm()
            elif tag == b"PK\x07\x08":
                if central_directory_start is not None:
                    raise FatalError("Central directory started")
                yield DataDescriptorBlock(category=Type.HEADER)
            elif tag == b"PK\x01\x02":
                if central_directory_start is None:
                    central_directory_start = self.tell()
                cd = yield CentralDirectory(category=Type.HEADER)
                fn = ""
                if "FileName" in cd:
                    fn = cd["FileName"]
                    # Strip the trailing NUL *before* the lookup so the name
                    # matches the one registered from the local header; this
                    # used to happen after add_file() and had no effect.
                    if fn.endswith("\x00"):
                        fn = fn[:-1]
                if fn and cd["UncompressedSize"] and fn not in files_seen:
                    files_seen.add(fn)
                    self.add_file(fn, cd["UncompressedSize"], "open")
                num_central_directory += 1
            elif tag == b"PK\x06\x06":
                yield EndOfCentralDirectory64(category=Type.HEADER)
            elif tag == b"PK\x06\x07":
                yield EndOfCentralDirectoryLocator64(category=Type.HEADER)
            elif tag == b"PK\x05\x06":
                yield EndOfCentralDirectory(category=Type.HEADER)
                break
            else:
                # unknown bytes: resynchronise on the next known signature
                off, sz = self.search(b"PK(?:\\x03\\x04|\\x07\\x08|\\x01\\x02|\\x06[\\x06\\x07]|\\x05\\x06)", start=self.tell())
                if sz:
                    yield Unused(off - self.tell(), name="UnusedSpace", category=Type.ANOMALY)
                    continue
                raise FatalError("Unknown tag {}".format(tag))
        if central_directory_start is not None:
            cd_size = self.tell() - central_directory_start
            self.add_section("<directory>", central_directory_start, cd_size, r = False, discardable = True)

        if num_local_files == 0 or num_local_files != num_central_directory:
            raise FatalError("Incomplete ZIP file")
        # a non-zero size marks the archive as usable for open()
        self.size = self.tell()
            

