from filetypes.base import *
from filetypes.CFB_office import codepage2codec
import struct
import json
import itertools
import traceback

class MsiDatabase:
    """Decoded MSI database state shared by the table and column-type decoders.

    The attributes are filled in by ``parse_msi``: the string pool, the width
    of string references, and the table schemas.
    """

    def __init__(self):
        # Decoded string pool; string id N (N >= 1) is stored at strings[N - 1].
        self.strings = []
        # Width in bytes of a string-pool reference inside table streams (2 or 3).
        self.string_index_size = 2
        # Set to True by parse_msi once the table schemas have been read.
        self.valid = False
        # Table name -> MsiTable.
        self.tables = {}
        # NOTE(review): not populated anywhere in the visible code.
        self.types = {}

    def read_string(self, data):
        """Resolve the string-pool reference at the start of *data*.

        Returns:
            ``(text, width)``, where *width* is the number of bytes the
            reference occupied. String id 0 resolves to the empty string.

        Raises:
            ValueError: if ``string_index_size`` is neither 2 nor 3.
        """
        width = self.string_index_size
        if width == 2:
            (string_id,) = struct.unpack_from("<H", data)
        elif width == 3:
            # 3-byte reference: 16-bit low word followed by an 8-bit high byte.
            low_word, high_byte = struct.unpack_from("<HB", data)
            string_id = (high_byte << 16) | low_word
        else:
            raise ValueError("Invalid string index size")
        text = self.strings[string_id - 1] if string_id else ""
        return text, width


class MsiTable:
    """Schema of one MSI table plus a decoder for its raw data stream."""

    def __init__(self, name, columns=None):
        self.name = name
        # MsiColumn objects in stream order; parse_msi creates tables without
        # columns and appends them while reading the _Columns stream.
        self.columns = columns or []

    def read_data(self, db, table_data):
        """Decode a raw table stream and yield each row as ``{column name: value}``.

        Values are laid out column by column: all values of the first column,
        then all values of the second column, and so on. The row count is
        ``len(table_data) // row_size``, and any trailing partial row is
        ignored. A table with no columns yields nothing. Previously this case
        raised ZeroDivisionError.

        Args:
            db: Object passed through to each column type's ``read_data``.
                String columns use it to resolve string-pool references.
            table_data: Bytes-like contents of the table's stream.

        Yields:
            One dict per row, keyed by column name.
        """
        if not self.columns:
            # No schema: nothing can be decoded, and row_size would be 0.
            return
        table_data = memoryview(table_data)  # zero-copy slicing below
        sizes = [c.type.size for c in self.columns]
        num_rows = len(table_data) // sum(sizes)
        # Start offset of each column's contiguous run of values in the stream.
        starts = [0]
        for size in sizes[:-1]:
            starts.append(starts[-1] + size * num_rows)
        layout = list(zip(self.columns, starts, sizes))
        for n in range(num_rows):
            yield {
                column.name: column.type.read_data(db, table_data[start + n * size:])[0]
                for column, start, size in layout
            }

    def __repr__(self):
        return f"{self.name} ({', '.join([c.name + ':' + c.type.name for c in self.columns])})"


class MsiColumn:
    """A named column of an MSI table paired with the MsiType that decodes it."""

    def __init__(self, name, type):
        # ``type`` shadows the builtin but is kept because callers may pass it by keyword.
        self.name, self.type = name, type


class MsiType:
    """Base class for decoders of a single MSI column value.

    Subclasses implement ``read_data``, ``name`` and ``size``.
    """

    # Bit fields of the ``Type`` code stored in the _Columns table.
    DATASIZEMASK = 0x00FF
    VALID = 0x0100
    LOCALIZABLE = 0x0200
    NONBINARY = 0x0400
    STRING = 0x0800
    NULLABLE = 0x1000
    KEY = 0x2000
    TEMPORARY = 0x4000
    UNKNOWN = 0x8000

    def __init__(self, valid=True, nullable=False, key=False):
        self.valid, self.nullable, self.key = valid, nullable, key

    def read_data(self, db, data):
        """Decode one value from the start of *data*; return ``(value, bytes_consumed)``."""
        raise NotImplementedError

    @property
    def name(self):
        """Short display name of the type."""
        raise NotImplementedError

    @property
    def size(self):
        """Number of bytes one value of this type occupies in a table stream."""
        raise NotImplementedError

    def __repr__(self):
        return self.name


class MsiTypeByte(MsiType):
    """Unsigned 8-bit column value, returned as stored (no NULL mapping, no bias)."""

    @property
    def name(self):
        return "BYTE"

    @property
    def size(self):
        return 1

    def read_data(self, db, data):
        (value,) = struct.unpack_from("<B", data)
        return value, 1


class MsiTypeRawShort(MsiType):
    """Little-endian 16-bit column value returned verbatim (no NULL mapping, no bias)."""

    @property
    def name(self):
        return "SHORT"

    @property
    def size(self):
        return 2

    def read_data(self, db, data):
        (raw,) = struct.unpack_from("<H", data)
        return raw, 2

class MsiTypeShort(MsiType):
    """Nullable 16-bit integer column.

    A stored 0 decodes to None. Any other stored value carries a 0x8000 bias,
    which is subtracted here.
    """

    @property
    def name(self):
        return "SHORT"

    @property
    def size(self):
        return 2

    def read_data(self, db, data):
        (raw,) = struct.unpack_from("<H", data)
        value = raw - 0x8000 if raw else None
        return value, 2


class MsiTypeInt(MsiType):
    """Nullable 32-bit integer column.

    A stored 0 decodes to None. Any other stored value carries a 0x80000000
    bias, which is subtracted here.
    """

    @property
    def name(self):
        return "INT"

    @property
    def size(self):
        return 4

    def read_data(self, db, data):
        (raw,) = struct.unpack_from("<I", data)
        value = raw - 0x80000000 if raw else None
        return value, 4



class MsiTypeBinary(MsiType):
    """Column type that parse_msi uses for ``STRING | VALID`` codes with no data size.

    The 16-bit cell is returned as a raw integer. It is presumably a reference
    to the binary data, which is not resolved here (confirm).
    """

    @property
    def name(self):
        return "BINARY"

    @property
    def size(self):
        return 2

    def read_data(self, db, data):
        (raw,) = struct.unpack_from("<H", data)
        return raw, 2


class MsiTypeStringShort(MsiType):
    """String column stored as a 2-byte string-pool reference.

    Decoding is delegated to ``db.read_string``, which uses
    ``db.string_index_size``. That width must agree with ``size``.
    """

    @property
    def size(self):
        return 2

    @property
    def name(self):
        return "STR"

    def read_data(self, db, data):
        return db.read_string(data)


class MsiTypeStringLong(MsiType):
    """String column stored as a 3-byte string-pool reference.

    Decoding is delegated to ``db.read_string``, which uses
    ``db.string_index_size``. That width must agree with ``size``.
    """

    @property
    def size(self):
        return 3

    @property
    def name(self):
        return "STR"

    def read_data(self, db, data):
        return db.read_string(data)



def _msi_read_string_pool(self, db):
    """Decode the MSI string pool into ``db.strings``.

    ``/_StringPool`` layout:
    - a 4-byte header holding the codepage; bit 31 set means string references
      are 3 bytes wide;
    - one ``(size, refcount)`` pair of 16-bit values per string.

    The string bytes themselves are stored back to back in ``/_StringData``.

    Returns:
        The MsiType subclass to use for string-reference columns.

    Raises:
        FatalError: if either stream is missing or the pool header is truncated.
    """
    if "/_StringPool" not in self.filesystem:
        raise FatalError("Missing StringPool stream")
    if "/_StringData" not in self.filesystem:
        raise FatalError("Missing _StringData stream")
    pool = self.read_file("/_StringPool")
    data = self.read_file("/_StringData")
    if len(pool) < 4:
        # Previously surfaced as a bare struct.error.
        raise FatalError("Truncated StringPool stream")
    codepage, = struct.unpack_from("<I", pool)
    if codepage & 0x80000000:
        db.string_index_size = 3
        string_type = MsiTypeStringLong
    else:
        string_type = MsiTypeStringShort
    codec = codepage2codec(codepage & 0x7fffffff)
    index_pool = 4
    index_data = 0
    while index_pool + 4 <= len(pool):
        size, refcount = struct.unpack_from("<HH", pool, index_pool)
        index_pool += 4
        if size == 0 and refcount > 0:
            # A zero size with a non-zero refcount marks a long string; its
            # 32-bit length is stored in the next 4 bytes of the pool.
            size, = struct.unpack_from("<I", pool, index_pool)
            index_pool += 4
        # Every entry, even an empty one, is appended so that list position
        # stays aligned with string id - 1 (see MsiDatabase.read_string).
        db.strings.append(data[index_data:index_data + size].decode(codec, errors="replace"))
        index_data += size
    return string_type


def _msi_column_type(type_code, string_type):
    """Map a _Columns ``Type`` code to the MsiType subclass that decodes it.

    Returns None when the code is not recognized.
    """
    base = type_code & ~MsiType.NULLABLE
    if base in (MsiType.STRING | MsiType.VALID, MsiType.STRING | MsiType.VALID | MsiType.UNKNOWN):
        return MsiTypeBinary
    if type_code & MsiType.STRING:
        return string_type
    data_size = type_code & MsiType.DATASIZEMASK
    if data_size == 2 or base in (MsiType.NONBINARY | MsiType.VALID, MsiType.NONBINARY | MsiType.VALID | MsiType.UNKNOWN):
        return MsiTypeShort
    if data_size == 4:
        return MsiTypeInt
    if data_size == 1:
        return MsiTypeByte
    return None


def _msi_read_certificates(self):
    """Record metadata for each certificate found in ``/DigitalSignature``, if present."""
    if "/DigitalSignature" not in self.filesystem:
        return
    from filetypes.P7X import parse_der_certificate
    sigfile = self.read_file("/DigitalSignature")
    certs_meta = []
    done = 0
    while done < len(sigfile):
        size, meta = parse_der_certificate(sigfile[done:])
        # Keep only successful parses. A falsy meta (presumably None on
        # failure) used to be appended and then crashed the .items() calls below.
        if meta:
            certs_meta.append(meta)
        if not meta or not size:
            break
        done += size

    if len(certs_meta) == 1:
        for k, v in certs_meta[0].items():
            self.add_metadata(k, v, category="Certificate")
    else:
        for i, meta in enumerate(certs_meta):
            for k, v in meta.items():
                self.add_metadata(f"Cert[{i}].{k}", v, category="Certificates")


def parse_msi(self):
    """Parse the MSI database structures stored in this compound file.

    Steps:
    1. Read the string pool.
    2. Read the table list (``/_Tables``).
    3. Read the column schemas (``/_Columns``).
    4. Record certificate metadata from ``/DigitalSignature`` when that
       stream exists.

    Returns:
        An MsiDatabase with ``strings`` and ``tables`` populated and ``valid`` set.

    Raises:
        FatalError: if a required stream is missing or the string pool header
            is truncated.
        KeyError: if ``_Columns`` references a table not listed in ``_Tables``.
        ValueError: if a column has an unrecognized type code.
    """
    db = MsiDatabase()
    string_type = _msi_read_string_pool(self, db)

    # read tables
    if "/_Tables" not in self.filesystem:
        raise FatalError("Missing Tables stream")
    if "/_Columns" not in self.filesystem:
        raise FatalError("Missing Columns stream")
    db.tables["Tables"] = MsiTable("Tables", [
        MsiColumn("Name", string_type()),
    ])
    table_data = self.read_file("/_Tables")
    for table in db.tables["Tables"].read_data(db, table_data):
        db.tables[table["Name"]] = MsiTable(table["Name"])
    self.confirm()

    # read columns; each row appends one column to its table, in stream order
    db.tables["Columns"] = MsiTable("Columns", [
        MsiColumn("Table", string_type()),
        MsiColumn("Number", MsiTypeShort()),
        MsiColumn("Name", string_type()),
        MsiColumn("Type", MsiTypeRawShort()),
    ])
    col_data = self.read_file("/_Columns")
    for c in db.tables["Columns"].read_data(db, col_data):
        table = db.tables[c["Table"]]
        column_type = _msi_column_type(c["Type"], string_type)
        if column_type is None:
            raise ValueError(f"Unrecognized column type for {table.name}.{c['Name']}: {c['Type']:x}")
        table.columns.append(MsiColumn(c["Name"], column_type()))

    db.valid = True

    _msi_read_certificates(self)
    return db


def decompile_msi(self):
    """Render every MSI table that has a data stream as human-readable text.

    Tables are processed in name order. Each table gets a banner line, its
    ``str()`` (for MsiTable, the schema line), and one JSON line per row.
    Tables whose name starts with ``_`` or that have no ``/<name>`` stream in
    the filesystem are skipped.

    Decoding is best effort. If reading or decoding a table raises, the rows
    decoded so far are kept, the traceback is embedded in the output, and the
    next table is processed.

    Returns:
        The complete listing as a single string.
    """
    parts = []  # collected then joined once, avoiding quadratic str +=
    for t in sorted(self.msi_db.tables.values(), key=lambda x: x.name):
        fname = f"/{t.name}"
        if t.name.startswith("_") or fname not in self.filesystem:
            continue
        parts.append(f"############################### {t.name} ##############################\n")
        parts.append(f"{t}\n\nData:\n")
        try:
            table_data = self.read_file(fname)
            for row in t.read_data(self.msi_db, table_data):
                parts.append(json.dumps(row) + "\n")
        except Exception:
            # Deliberately broad: one malformed table must not abort the listing.
            parts.append(traceback.format_exc())
        parts.append("\n\n")
    return "".join(parts)
