from filetypes.base import *
import malcat


class PascalScriptHeader(Struct):
    """Fixed IFPS file header.

    A 4-character magic string (not zero-terminated) is followed by six
    UInt32 fields: build version, the sizes of the type/function/variable
    tables, the entry-point function index and the import size.
    """

    def parse(self):
        yield String(4, zero_terminated=False, name="Magic")
        for field_name in ("Version", "TypesCount", "FunctionsCount",
                           "VariablesCount", "EntryPoint", "ImportSize"):
            yield UInt32(name=field_name)



# PascalScript base-type codes: code -> (display name, size in bytes as listed
# here; 0 where no fixed size is listed).
TYPES = {
    0: ("ReturnAddress", 28),
    1: ("U8", 1),
    2: ("S8", 1),
    3: ("U16", 2),
    4: ("S16", 2),
    5: ("U32", 4),
    6: ("S32", 4),
    7: ("Single", 4),
    8: ("Double", 8),
    9: ("Extended", 10),
    10: ("String", 4),
    11: ("Record", 0),
    12: ("Array", 0),
    13: ("Pointer", 4),
    14: ("PChar", 4),
    15: ("ResourcePointer", 0),
    16: ("Variant", 16),
    17: ("S64", 8),
    18: ("Char", 1),
    19: ("WideString", 4),
    20: ("WideChar", 2),
    21: ("ProcPtr", 4),
    22: ("StaticArray", 0),
    23: ("Set", 0),
    24: ("Currency", 8),
    25: ("Class", 4),
    26: ("Interface", 4),
    27: ("NotificationVariant", 0),
    28: ("UnicodeString", 4),
    # Codes >= 0x80 never match after PascalType masks off bit 7 (& 0x7f);
    # they only serve as labels for raw type bytes via TYPES_ENUM.
    0x82: ("Type", 0),
    0x81: ("Enum", 0),
    0x83: ("ExtClass", 0),
    0x97: ("ExportedSet", 0),
    0x99: ("ExportedClass", 0),
}

# (name, code) pairs in TYPES order, used as value labels for the type byte.
TYPES_ENUM = [(label, code) for code, (label, _size) in TYPES.items()]



class PascalTypeArray(Struct):
    """The type table: ``count`` consecutive PascalType entries.

    Every parsed entry is also appended to ``self.parser.types``, which lets
    later structures (PascalAttribute) resolve type indices against it.
    """

    def __init__(self, count, version, **kwargs):
        super().__init__(**kwargs)
        self._count = count      # number of entries to parse
        self._version = version  # forwarded to each PascalType

    def parse(self):
        for _ in range(self._count):
            entry = yield PascalType(self._version)
            self.parser.types.append(entry)


class PascalType(Struct):
    """One entry of the PascalScript type table.

    Layout: a base-type byte (bit 7 set means the type is exported), optional
    type-specific data, an export name for exported types and, for builds
    21 and later, an attribute list.
    """

    def __init__(self, version, **kwargs):
        Struct.__init__(self, **kwargs)
        self.version = version  # build number taken from the IFPS header

    def parse(self):
        """Parse one type entry.

        Raises:
            FatalError: when the base type (after masking bit 7) is unknown.
        """
        raw = yield UInt8(name="Type", values=TYPES_ENUM)
        exported = raw > 0x7f
        typ = raw & 0x7f
        if typ not in TYPES:
            raise FatalError(f"Unsupported type {typ:x}")

        if typ == 25:        # Class: class name
            yield PascalString(name="ClassName")
        elif typ == 12:      # Array: element type index
            yield UInt32(name="TypeIndex")
        elif typ == 21:      # ProcPtr: length-prefixed parameter blob
            sz = yield UInt32(name="ParametersSize")
            yield Bytes(sz, name="ParametersData")
        elif typ == 22:      # StaticArray: element type, size, start offset
            yield UInt32(name="TypeIndex")
            yield UInt32(name="Size")
            if self.version > 22:
                yield UInt32(name="Start")
        elif typ == 23:      # Set: bit count (PascalAttribute reads it back)
            yield UInt32(name="BitCount")
        elif typ == 26:      # Interface: GUID
            yield GUID(name="Guid")
        elif typ == 11:      # Record: list of field type indices
            types_count = yield UInt32(name="Count")
            yield Array(types_count, UInt32(), name="Types")

        if exported:
            yield PascalString(name="ExportName")

        # The PascalScript runtime reads type attributes when PSBuildNo >= 21
        # ("since build 21 we support attributes"); using "> 21" skipped them
        # for build-21 files and desynchronised the rest of the table.
        if self.version >= 21:
            yield PascalAttributeArray(name="Attributes")
            
        


class PascalAttributeArray(Struct):
    """Attribute list: a UInt32 count followed by that many PascalAttribute records."""

    def parse(self):
        attribute_count = yield UInt32(name="Count")
        for _ in range(attribute_count):
            yield PascalAttribute(name="Attribute")

            
class PascalAttribute(Struct):
    """A named attribute carrying a list of typed values.

    Each value is preceded by a UInt32 index into the type table collected in
    ``self.parser.types``. The base type of that entry (bit 7 masked off)
    selects how the value is decoded. Base types without a decoder here
    consume no bytes.
    """

    def parse(self):
        yield PascalString(name="AttributeName")
        field_count = yield UInt32(name="FieldCount")
        for _ in range(field_count):
            type_index = yield UInt32(name="TypeIndex")
            known_types = self.parser.types
            if type_index >= len(known_types):
                raise FatalError(f"Invalid attribute type index {type_index}")

            type_object = known_types[type_index]
            value = self._value_field(type_object["Type"] & 0x7f, type_object)
            if value is not None:
                yield value

    @staticmethod
    def _value_field(base_type, type_object):
        """Return the field decoding a value of ``base_type``, or None if unhandled."""
        # Fixed-layout values that are all named "Value".
        scalar = {
            1: UInt8, 2: Int8, 3: UInt16, 4: Int16,
            5: UInt32, 13: UInt32, 25: UInt32, 26: UInt32,
            6: Int32, 7: Float, 8: Double,
            17: Int64, 24: UInt64,
            10: PascalString, 14: PascalString,
            19: UnicodeString, 28: UnicodeString,
        }.get(base_type)
        if scalar is not None:
            return scalar(name="Value")

        if base_type == 0:
            return Bytes(28, name="Unknown")
        if base_type == 9:
            return Bytes(10, name="ExtendedValue")
        if base_type == 18:
            return String(1, zero_terminated=False, name="Value")
        if base_type == 20:
            return StringU16LE(1, zero_terminated=False, name="Value")
        if base_type == 22:
            # StaticArray: byte count taken from the type's "Size" field.
            return Bytes(type_object["Size"], name="ArrayData")
        if base_type == 23:
            # Set: ceil(BitCount / 8) bytes.
            return Bytes((type_object["BitCount"] + 7) // 8, name="Bits")
        return None


class PascalFunction(Struct):
    """Function table entry: either an imported function or an internal one.

    Internal functions carry the offset/size of their bytecode. Imported
    functions carry an import name and, when the Exported bit is set, a
    declaration string (stored as "DllName").
    """

    def parse(self):
        flags = yield BitsField(
            Bit(name="Imported", comment="Function is imported from an external dll"),
            Bit(name="Exported", comment="Function is exported"),
            Bit(name="HasAttributes", comment="Attributes follow"),
            name="Flags", comment="function characteristics")
        is_exported = flags["Exported"]

        if not flags["Imported"]:
            # internal function: location of its bytecode in the file
            yield Offset32(name="Offset")
            yield UInt32(name="Size")
            if is_exported:
                yield PascalString(name="ExportName")
                yield PascalString(name="Prototype")
        else:
            yield PascalString8(name="ImportName")
            if is_exported:
                yield PascalString(name="DllName")

        if flags["HasAttributes"]:
            yield PascalAttributeArray(name="Attributes")
        

class PascalVariable(Struct):
    """Global variable table entry: a type index, flags and an optional export name."""

    def parse(self):
        yield UInt32(name="TypeIndex")
        flags = yield BitsField(
            Bit(name="Exported", comment="Variable is exported"),
            name="Flags", comment="variable characteristics")
        if not flags["Exported"]:
            return
        yield PascalString(name="ExportName")
        


class PascalScriptAnalyzer(FileTypeAnalyzer):
    """Parser for compiled PascalScript (IFPS) programs.

    After the header come the type table, the function table and the global
    variable table. These are covered by a ".DATA" section. The bytecode of
    internal functions is located through each function's Offset/Size and
    covered by a ".CODE" section.

    State filled by ``parse``:
        types      -- parsed type entries (appended by PascalTypeArray)
        fns        -- the FunctionsArray structure
        vars       -- the VariablesArray structure
        code_start -- start of the span covering all internal bytecode (0 if none)
        code_size  -- size of that span (0 if none); code_start/code_size are
                      used by InnoParser's pwd guess heuristic
    """
    category = malcat.FileType.PROGRAM
    name = "PascalScript"
    regexp = r"IFPS[\x01-\x20]\x00\x00\x00(?:...\x00){3}"

    # Sanity limit on the table sizes read from the header.
    _MAX_TABLE_ENTRIES = 10000

    def parse(self, hint):
        """Parse the header and the three tables, then register symbols and sections.

        ``hint`` is not used by this parser.

        Raises:
            FatalError: when a table size read from the header is out of range.
        """
        self.set_architecture(malcat.Architecture.PASCALSCRIPT)
        hdr = yield PascalScriptHeader()

        self.types = []
        self.fns = []
        self.vars = []
        self.code_start = 0
        self.code_size = 0
        arrays_start = self.tell()

        # types (PascalTypeArray appends each entry to self.types)
        typecount = hdr["TypesCount"]
        if typecount == 0 or typecount > self._MAX_TABLE_ENTRIES:
            raise FatalError("Types")
        yield PascalTypeArray(typecount, hdr["Version"], name="TypesArray")
        self.confirm()

        # functions
        fncount = hdr["FunctionsCount"]
        if fncount == 0 or fncount > self._MAX_TABLE_ENTRIES:
            raise FatalError("Functions")
        self.fns = yield VariableArray(fncount, PascalFunction, name="FunctionsArray")
        self._register_function_symbols(hdr["EntryPoint"])

        # vars
        varcount = hdr["VariablesCount"]
        if varcount > self._MAX_TABLE_ENTRIES:
            raise FatalError("Variables")
        self.vars = yield VariableArray(varcount, PascalVariable, name="VariablesArray")

        # everything parsed since the header (types, functions, vars) is .DATA
        self.add_section(".DATA", arrays_start, self.tell() - arrays_start, r=True, w=True)
        self._register_code_section()

    def _register_function_symbols(self, ep):
        """Add the entry-point symbol (function index ``ep``) and the import symbols."""
        if ep < self.fns.count:
            fn = self.fns[ep]
            if "Offset" in fn:
                self.add_symbol(fn["Offset"], "EntryPoint", malcat.FileSymbol.ENTRY)

        for fn in self.fns:
            if "ImportName" in fn and "String" in fn["ImportName"]:
                self.add_symbol(fn.offset, fn["ImportName"]["String"], malcat.FileSymbol.IMPORT)
            elif "DllName" in fn and "String" in fn["DllName"]:
                # guard: an empty declaration string has no "String" field
                symbol = self._import_decl_symbol(fn["DllName"]["String"])
                if symbol is not None:
                    self.add_symbol(fn.offset, symbol, malcat.FileSymbol.IMPORT)

    @staticmethod
    def _import_decl_symbol(decl):
        """Build a "module.function" symbol name from an import declaration.

        Handles "dll:[files:]module.ext\\x00function\\x00..." and
        "class:Class|Member|..." declarations. Returns None for other or
        malformed declarations instead of raising.
        """
        if decl.startswith("dll:"):
            decl = decl[4:]
            if decl.startswith("files:"):
                decl = decl[6:]
            module, sep, rest = decl.partition("\x00")
            if not sep:
                # malformed: no NUL between module and function name
                return None
            fname = rest.split("\x00", 1)[0]
            module = module.split(".")[0]
            return f"{module}.{fname}"
        if decl.startswith("class:"):
            decl = decl[6:]
            parts = decl.split("|")
            if len(parts) >= 2:
                module, fname = parts[:2]
            else:
                module = ""
                fname = decl
            return f"{module}.{fname}"
        return None

    def _register_code_section(self):
        """Create .CODE over the span of all internal functions and record it.

        code_start/code_size are only updated when that span is non-empty.
        Previously a file with no internal functions got a negative code_size.
        """
        mincode = self.size()
        maxcode = 0
        for fn in self.fns:
            if "Offset" in fn:
                offset = fn["Offset"]
                mincode = min(mincode, offset)
                maxcode = max(maxcode, offset + fn["Size"])
        if maxcode > mincode:
            self.add_section(".CODE", mincode, maxcode - mincode, r=True, x=True)
            # used by InnoParser's pwd guess heuristic
            self.code_start = mincode
            self.code_size = maxcode - mincode
