from filetypes.base import *
import malcat
import struct
import datetime


# Maps the 32-bit little-endian magic word found at offset 0 of a pyc file to
# the python version that produces it. The high half of every magic is
# 0x0A0D, i.e. the two version-specific bytes are followed by b"\r\n".
# Commented-out entries are versions whose magic is known but which this
# analyzer deliberately does not enable. A three-element tuple is used where
# a point release has its own magic (3.5.3).
MAGIC_TO_VERSION = {
#        0x0A0D2E89: (1,3), 
#        0x0A0D1704: (1,4), 
#        0x0A0D4E99: (1,5), 
#        0x0A0DC4FC: (1,6), 

#        0x0A0DC687: (2,0), 
#        0x0A0DEB2A: (2,1), 
#        0x0A0DED2D: (2,2), 
#        0x0A0DF23B: (2,3), 
#        0x0A0DF26D: (2,4), 
#        0x0A0DF2B3: (2,5), 
        0x0A0DF2D1: (2,6), 
        0x0A0DF303: (2,7), 

#        0x0A0D0C3A: (3,0), 
#        0x0A0D0C4E: (3,1), 
#        0x0A0D0C6C: (3,2), 
        0x0A0D0C9E: (3,3), 
        0x0A0D0CEE: (3,4), 
        0x0A0D0D16: (3,5), 
        0x0A0D0D17: (3,5,3), 
        0x0A0D0D33: (3,6), 
        0x0A0D0D42: (3,7), 
        0x0A0D0D55: (3,8), 
        0x0A0D0D61: (3,9), 
        0x0A0D0D6F: (3,10), 
        0x0A0D0DA7: (3,11), 
        0x0A0D0DCB: (3,12), 
        0x0A0D0DF3: (3,13), 
        0x0A0D0E2b: (3,14), 
}


def build_magic_regex(magics):
    """Return a regex (as text) matching the start of a pyc file.

    *magics* maps magic words to version tuples (see MAGIC_TO_VERSION). Each
    magic contributes its two low bytes to an alternation. Versions that can
    also be written by a "unicode" build additionally contribute ``magic + 1``.
    These are 1.6+, all of 2.x, and 3.0/3.1; for 3.0/3.1 only the ``+ 1`` form
    is emitted.

    The alternation is followed by the fixed b"\\r\\n" half of the magic, 4 to 12
    header bytes, and the type byte of the root object: ``c`` (TYPE_CODE), or
    ``\\xe3`` (TYPE_CODE with the 0x80 flag).
    """
    # NOTE(review): `.` may not match a 0x0a byte in the header, depending on
    # the regex engine/flags used by malcat — verify.
    alternatives = []
    for magic, version in magics.items():
        major, minor = version[0], version[1]
        candidates = []
        if major < 3 or minor > 1:
            candidates.append(magic)
        if (major == 3 and minor <= 1) or major == 2 or (major == 1 and minor >= 6):
            # unicode version
            candidates.append(magic + 1)
        for value in candidates:
            low, high = struct.pack("<I", value)[:2]
            alternatives.append("\\x%02x\\x%02x" % (low, high))
    return r"(?:%s)\x0d\x0a.{4,12}[c\xe3]" % "|".join(alternatives)

class PythonHeader(Struct):
    """Header at the start of every pyc file.

    The header starts with the magic word. Python 3.7+ then has a 32-bit flag
    word (PEP 552). If its lowest bit is set, the pyc is hash-based: an 8-byte
    source hash follows and there is no size field. Otherwise the header
    continues with a modification timestamp and, from 3.3 on, the source size.

    :raises FatalError: when the magic (or magic - 1, for unicode builds) is
        not in MAGIC_TO_VERSION.
    """

    def parse(self):
        magic = yield UInt32(name="Magic", comment="python magic version")
        version = MAGIC_TO_VERSION.get(magic, None)
        if version is None:
            # unicode builds write MAGIC + 1, so map the value back down
            version = MAGIC_TO_VERSION.get(magic - 1, None)
        if version is None:
            raise FatalError("Bad magic")
        if version >= (3, 7):
            flags = yield BitsField(
                    Bit(name="HasChecksum", comment="hash-based pyc (PEP 552)"),
                    Bit(name="CheckSource", comment="hash is validated against the source on import"),
                    NullBits(30),
                    name="Flags")
            if flags["HasChecksum"]:
                # hash-based pyc: 8 bytes of source hash replace timestamp + size
                yield Bytes(8, name="Checksum", comment="hash of the source file")
                return
        yield Timestamp(name="Timestamp")
        if version >= (3, 3):
            yield UInt32(name="Size")



# Marshal object type codes: the low 7 bits of each object's type byte
# (see CPython's Python/marshal.c). Each code is defined exactly once.
TYPE_NULL                       = ord('0')
TYPE_NONE                       = ord('N')
TYPE_FALSE                      = ord('F')
TYPE_TRUE                       = ord('T')
TYPE_STOPITER                   = ord('S')
TYPE_ELLIPSIS                   = ord('.')
TYPE_INT                        = ord('i')
TYPE_INT64                      = ord('I')
TYPE_FLOAT                      = ord('f')
TYPE_BINARY_FLOAT               = ord('g')
TYPE_COMPLEX                    = ord('x')
TYPE_BINARY_COMPLEX             = ord('y')
TYPE_LONG                       = ord('l')
TYPE_STRING                     = ord('s')
TYPE_INTERNED                   = ord('t')
TYPE_REF                        = ord('r')
TYPE_STRING_REF                 = ord('R')
TYPE_TUPLE                      = ord('(')
TYPE_LIST                       = ord('[')
TYPE_DICT                       = ord('{')
TYPE_CODE                       = ord('c')
TYPE_UNICODE                    = ord('u')
TYPE_UNKNOWN                    = ord('?')
TYPE_SET                        = ord('<')
TYPE_FROZENSET                  = ord('>')
TYPE_ASCII                      = ord('a')
TYPE_ASCII_INTERNED             = ord('A')
TYPE_SMALL_TUPLE                = ord(')')
TYPE_SHORT_ASCII                = ord('z')
TYPE_SHORT_ASCII_INTERNED       = ord('Z')
TYPE_SLICE                      = ord(':')


# (label, value) pairs used as the `values=` enumeration of every "Type" byte.
# Each type code appears twice: bare, and OR-ed with 0x80 (the flag bit that
# PythonAnalyzer.parse_object tests to record the object in its reference
# table). REF and STRINGREF are references themselves and have no "_REF" form.
TYPE_ENUM = [
        ("NULL", TYPE_NULL),
        ("NULL_REF", TYPE_NULL | 0x80),
        ("NONE", TYPE_NONE),
        ("NONE_REF", TYPE_NONE | 0x80),
        ("FALSE", TYPE_FALSE),
        ("FALSE_REF", TYPE_FALSE | 0x80),
        ("TRUE", TYPE_TRUE),
        ("TRUE_REF", TYPE_TRUE | 0x80),
        ("STOPITER", TYPE_STOPITER),
        ("STOPITER_REF", TYPE_STOPITER | 0x80),
        ("ELLIPSIS", TYPE_ELLIPSIS),
        ("ELLIPSIS_REF", TYPE_ELLIPSIS | 0x80),
        ("SLICE", TYPE_SLICE),
        ("SLICE_REF", TYPE_SLICE | 0x80),
        ("INT", TYPE_INT),
        ("INT_REF", TYPE_INT | 0x80),
        ("INT64", TYPE_INT64),
        ("INT64_REF", TYPE_INT64 | 0x80),
        ("FLOAT", TYPE_FLOAT),
        ("FLOAT_REF", TYPE_FLOAT | 0x80),
        ("BINARY_FLOAT", TYPE_BINARY_FLOAT),
        ("BINARY_FLOAT_REF", TYPE_BINARY_FLOAT | 0x80),
        ("COMPLEX", TYPE_COMPLEX),
        ("COMPLEX_REF", TYPE_COMPLEX | 0x80),
        ("BINARY_COMPLEX", TYPE_BINARY_COMPLEX),
        ("BINARY_COMPLEX_REF", TYPE_BINARY_COMPLEX | 0x80),
        ("LONG", TYPE_LONG),
        ("LONG_REF", TYPE_LONG | 0x80),
        ("STRING", TYPE_STRING),
        ("STRING_REF", TYPE_STRING | 0x80),
        ("INTERNED", TYPE_INTERNED),
        ("INTERNED_REF", TYPE_INTERNED | 0x80),
        # back-references: no flagged variant
        ("REF", TYPE_REF),
        ("STRINGREF", TYPE_STRING_REF),
        ("TUPLE", TYPE_TUPLE),
        ("TUPLE_REF", TYPE_TUPLE | 0x80),
        ("LIST", TYPE_LIST),
        ("LIST_REF", TYPE_LIST | 0x80),
        ("DICT", TYPE_DICT),
        ("DICT_REF", TYPE_DICT | 0x80),
        ("CODE", TYPE_CODE),
        ("CODE_REF", TYPE_CODE | 0x80),
        ("UNICODE", TYPE_UNICODE),
        ("UNICODE_REF", TYPE_UNICODE | 0x80),
        ("UNKNOWN", TYPE_UNKNOWN),
        ("UNKNOWN_REF", TYPE_UNKNOWN | 0x80),
        ("SET", TYPE_SET),
        ("SET_REF", TYPE_SET | 0x80),
        ("FROZENSET", TYPE_FROZENSET),
        ("FROZENSET_REF", TYPE_FROZENSET | 0x80),
        ("ASCII", TYPE_ASCII),
        ("ASCII_REF", TYPE_ASCII | 0x80),
        ("ASCII_INTERNED", TYPE_ASCII_INTERNED),
        ("ASCII_INTERNED_REF", TYPE_ASCII_INTERNED | 0x80),
        ("SMALL_TUPLE", TYPE_SMALL_TUPLE),
        ("SMALL_TUPLE_REF", TYPE_SMALL_TUPLE | 0x80),
        ("SHORT_ASCII", TYPE_SHORT_ASCII),
        ("SHORT_ASCII_REF", TYPE_SHORT_ASCII | 0x80),
        ("SHORT_ASCII_INTERNED", TYPE_SHORT_ASCII_INTERNED),
        ("SHORT_ASCII_INTERNED_REF", TYPE_SHORT_ASCII_INTERNED | 0x80),
]



# https://github.com/python/cpython/blob/master/Python/marshal.c

class PyCode(Struct):
    """Code object (TYPE_CODE) in marshal layout.

    Which fields are present depends on ``self.parser.version`` (the python
    version of the file being parsed). Header integers are 16-bit before 2.3
    and 32-bit afterwards. Nested objects (bytecode, constants, name tuples,
    ...) are parsed recursively through ``self.parser.parse_object``.
    """

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        # width of the integer header fields
        if self.parser.version < (2, 3):
            uint = UInt16
        else:
            uint = UInt32
        yield uint(name="ArgCount")
        if self.parser.version >= (3, 8):
            yield uint(name="PosOnlyArgCount")
        if self.parser.version >= (3, 0):
            yield uint(name="KeywordOnlyArgCount")
        # NumLocals is not stored from 3.11 on
        if self.parser.version < (3,11):
            yield uint(name="NumLocals")
        yield uint(name="StackSize")
        # Flag bits, LSB first: 10 named bits, 3 unused bits, FutureDivision
        # (bit 13); padding below brings the total to 16 or 32 bits to match
        # `uint`.
        flags = [
                Bit(name="Optimized"),
                Bit(name="NewLocals"),
                Bit(name="VarArgs", comment="uses *args syntax"),
                Bit(name="VarKeywords", comment="uses **kwargs syntax"),
                Bit(name="Nested", comment="code object is a nested function"),
                Bit(name="Generator", comment="function is a generator"),
                Bit(name="NoFree", comment="there are no free or cell variables"),
                Bit(name="CoRoutine", comment="code object is a coroutine function"),
                Bit(name="IterableCoRoutine", comment="used to transform generators into generator-based coroutines"),
                Bit(name="AsyncGenerator", comment="code object is an asynchronous generator function"),
                NullBits(3),
                Bit(name="FutureDivision", comment="compiled with future division enabled"),
                ]
        if self.parser.version < (2, 3):
            flags.append(NullBits(2))
        else:
            flags.append(NullBits(2 + 16))
        yield BitsField(*flags, name="Flags")
        yield from self.parser.parse_object(name="Bytecode")
        yield from self.parser.parse_object(name="Constants")
        yield from self.parser.parse_object(name="Names")
        # NOTE(review): for 3.11+ this slot holds co_localsplusnames in
        # CPython and "Locals" below holds co_localspluskinds — verify.
        yield from self.parser.parse_object(name="VarNames")
        if self.parser.version >= (3,11):
            yield from self.parser.parse_object(name="Locals")
        elif self.parser.version >= (2,1):
            yield from self.parser.parse_object(name="FreeVars")
            yield from self.parser.parse_object(name="CellVars")
        yield from self.parser.parse_object(name="Filename")
        yield from self.parser.parse_object(name="Name")
        if self.parser.version >= (3,11):
            yield from self.parser.parse_object(name="QualifiedName")
        if self.parser.version >= (1,5):
            yield uint(name="FirstLine")
            yield from self.parser.parse_object(name="LineNumberTable")
        if self.parser.version >= (3,11):
            yield from self.parser.parse_object(name="ExceptionTable")


class PyString(Struct):
    """String-like object: type byte, length, then (if non-empty) the characters.

    The "short" ASCII variants use a one-byte length; all others use 32 bits.
    Empty strings have no "Value" field.
    """

    def parse(self):
        raw_type = yield UInt8(name="Type", values=TYPE_ENUM)
        kind = raw_type & 0x7F
        is_short = kind in (TYPE_SHORT_ASCII, TYPE_SHORT_ASCII_INTERNED)
        length = yield (UInt8(name="Size") if is_short else UInt32(name="Size"))
        if not length:
            return
        # TYPE_STRING is deliberately read as plain String rather than UTF-8:
        # the bytecode itself is stored with that type.
        if kind in (TYPE_UNICODE, TYPE_INTERNED):
            yield StringUtf8(length, zero_terminated=False, name="Value")
        else:
            yield String(length, zero_terminated=False, name="Value")


class PyAtom(Struct):
    """Object with no payload: consists of its type byte only."""

    def parse(self):
        type_field = UInt8(name="Type", values=TYPE_ENUM)
        yield type_field


class PyNone(PyAtom):
    """TYPE_NONE object."""


class PyFalse(PyAtom):
    """TYPE_FALSE object."""


class PyTrue(PyAtom):
    """TYPE_TRUE object."""


class PyNull(PyAtom):
    """TYPE_NULL object (also used as the dictionary terminator)."""


class PyEllipsis(PyAtom):
    """TYPE_ELLIPSIS object."""

class PySlice(Struct):
    """TYPE_SLICE: type byte followed by three nested objects (start, stop, step)."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        for bound in ("Start", "Stop", "Step"):
            yield from self.parser.parse_object(name=bound)

class PySequence(Struct):
    """Tuple, list, set or frozenset: element count, then each element.

    TYPE_SMALL_TUPLE stores the count in one byte; the others use 32 bits.
    """

    def parse(self):
        raw_type = yield UInt8(name="Type", values=TYPE_ENUM)
        if (raw_type & 0x7F) == TYPE_SMALL_TUPLE:
            count_field = UInt8(name="Size")
        else:
            count_field = UInt32(name="Size")
        count = yield count_field
        for index in range(count):
            yield from self.parser.parse_object(name="Element.{}".format(index))


class PyDict(Struct):
    """TYPE_DICT: alternating key/value objects, terminated by a TYPE_NULL byte.

    :raises FatalError: if the type byte (without its 0x80 flag) is not TYPE_DICT.
    """

    def parse(self):
        typ = yield UInt8(name="Type", values=TYPE_ENUM)
        # The object dispatcher selects this class on the low 7 bits, so a
        # dict carrying the reference flag (0x80) is valid and must be accepted.
        if (typ & 0x7F) != TYPE_DICT:
            raise FatalError("Invalid dict type")
        i = 0
        while True:
            # peek at the next type byte without consuming it
            next_type, = struct.unpack("<B", self.parser.read(self.parser.tell(), 1))
            if next_type == TYPE_NULL:
                yield PyNull(name="Terminator", comment="Dictionary terminator")
                break
            yield from self.parser.parse_object(name="Element.{}.Key".format(i))
            yield from self.parser.parse_object(name="Element.{}.Value".format(i))
            i += 1


class PyInt64(Struct):
    """TYPE_INT64: type byte followed by an Int64 value."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        yield Int64(name="Value")

class PyInt(Struct):
    """TYPE_INT: type byte followed by an Int32 value."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        yield Int32(name="Value")

class PyLong(Struct):
    """TYPE_LONG: signed digit count, then two bytes per digit."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        digit_count = yield Int32(name="Size")
        # the count may be negative (it carries the sign); its magnitude is the digit count
        yield Bytes(2 * abs(digit_count), name="Value")

class PyFloat(Struct):
    """TYPE_FLOAT: one-byte length followed by that many payload bytes."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        payload_len = yield UInt8(name="Size")
        yield Bytes(abs(payload_len), name="Value")

class PyComplex(Struct):
    """TYPE_COMPLEX: real then imaginary part, each a one-byte length plus payload."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        parts = (("RealSize", "RealValue"), ("ImaginarySize", "ImaginaryValue"))
        for size_name, value_name in parts:
            part_len = yield UInt8(name=size_name)
            yield Bytes(abs(part_len), name=value_name)


class PyBinaryFloat(Struct):
    """TYPE_BINARY_FLOAT: type byte followed by a Double."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        yield Double(name="Value")


class PyRef(Struct):
    """Back-reference (TYPE_REF / TYPE_STRING_REF): type byte and a 32-bit index."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        yield UInt32(name="Reference", comment="index in the list of referenced objects")


def read_object(obj, references=None, interned=None):
    """Convert a parsed marshal structure into the python value it encodes.

    Supported types are ints, strings, booleans, None, longs,
    tuples/lists/sets/frozensets, dicts, and both kinds of back-reference.

    :param obj: a parsed Py* structure; its "Type" field selects the decoding.
    :param references: structures indexed by TYPE_REF reference number.
    :param interned: interned strings indexed by TYPE_STRING_REF number.
    :returns: the decoded python value.
    :raises IndexError: when a reference cannot be resolved.
    :raises KeyError: for an unsupported object type.
    """
    # None defaults instead of shared mutable default lists
    if references is None:
        references = []
    if interned is None:
        interned = []

    def convert(child):
        # Nested objects must see both tables; otherwise a reference inside a
        # container (e.g. a TYPE_STRING_REF in a tuple) cannot be resolved.
        return read_object(child, references, interned)

    typ = obj["Type"] & 0x7F    # drop the reference flag bit
    if typ in (TYPE_INT, TYPE_INT64):
        return obj["Value"]
    if typ in (TYPE_STRING, TYPE_INTERNED, TYPE_SHORT_ASCII_INTERNED, TYPE_SHORT_ASCII,
               TYPE_ASCII, TYPE_ASCII_INTERNED, TYPE_UNICODE):
        # empty strings are parsed without a "Value" field
        return obj["Value"] if "Value" in obj else ""
    if typ == TYPE_FALSE:
        return False
    if typ == TYPE_TRUE:
        return True
    if typ == TYPE_NONE:
        return None
    if typ == TYPE_LONG:
        # little-endian base 2**15 digits stored in 16-bit words; the sign of
        # the value is the sign of the digit count ("Size")
        digits = obj["Value"]
        val = 0
        for shift, pos in enumerate(range(0, len(digits), 2)):
            val += int.from_bytes(digits[pos:pos + 2], byteorder='little') << (15 * shift)
        return -val if obj["Size"] < 0 else val
    if typ in (TYPE_TUPLE, TYPE_SMALL_TUPLE, TYPE_LIST, TYPE_SET, TYPE_FROZENSET):
        # fields 0 and 1 are Type and Size; the elements follow
        members = [convert(obj[i]) for i in range(2, obj.count)]
        if typ in (TYPE_TUPLE, TYPE_SMALL_TUPLE):
            return tuple(members)
        if typ == TYPE_SET:
            return set(members)
        if typ == TYPE_FROZENSET:
            return frozenset(members)
        return members
    if typ == TYPE_DICT:
        # field 0 is Type, then key/value pairs; the last field is the terminator
        return {convert(obj[i]): convert(obj[i + 1]) for i in range(1, obj.count - 1, 2)}
    if typ == TYPE_REF:
        ref = obj["Reference"]
        # a None slot was reserved but never filled with a parsed structure
        if ref >= len(references) or references[ref] is None:
            raise IndexError("Unresolved reference {}".format(ref))
        return convert(references[ref])
    if typ == TYPE_STRING_REF:
        ref = obj["Reference"]
        if ref >= len(interned):
            raise IndexError("Unresolved interned reference {}: {}".format(ref, interned))
        return convert(interned[ref])
    raise KeyError("Unsupported Type {}".format(chr(typ)))


class PyBinaryComplex(Struct):
    """TYPE_BINARY_COMPLEX: real and imaginary parts stored as two Doubles."""

    def parse(self):
        yield UInt8(name="Type", values=TYPE_ENUM)
        yield Double(name="RealValue")
        yield Double(name="ImaginaryValue")


# Dispatch table for PythonAnalyzer.parse_object: maps the low 7 bits of an
# object's type byte to the Struct subclass that parses it. Type codes not
# listed here make parse_object fail with "Unknown object type".
TYPE_TO_OBJECT = {
    TYPE_NULL: PyNull,
    TYPE_CODE: PyCode,
    TYPE_STRING: PyString,
    TYPE_UNICODE: PyString,
    TYPE_ASCII: PyString,
    TYPE_ASCII_INTERNED: PyString,
    TYPE_SHORT_ASCII: PyString,
    TYPE_SHORT_ASCII_INTERNED: PyString,
    TYPE_INTERNED : PyString,
    TYPE_SMALL_TUPLE: PySequence,
    TYPE_TUPLE: PySequence,
    TYPE_LIST: PySequence,
    TYPE_SET: PySequence,
    TYPE_FROZENSET: PySequence,
    TYPE_DICT: PyDict,
    TYPE_INT64: PyInt64,
    TYPE_INT: PyInt,
    TYPE_LONG: PyLong,
    TYPE_COMPLEX: PyComplex,
    TYPE_BINARY_COMPLEX: PyBinaryComplex,
    TYPE_NONE: PyNone,
    TYPE_REF: PyRef,
    TYPE_STRING_REF: PyRef,
    TYPE_FALSE: PyFalse,
    TYPE_TRUE: PyTrue,
    TYPE_BINARY_FLOAT: PyBinaryFloat,
    TYPE_FLOAT: PyFloat,
    TYPE_ELLIPSIS: PyEllipsis,
    TYPE_SLICE: PySlice,
}

class PythonAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for compiled python (.pyc) files.

    It parses the pyc header, then the marshalled root code object. The python
    version, timestamp, file name and module name are exposed as metadata, and
    the root object is exposed as the "module" section.
    """
    category = malcat.FileType.PROGRAM
    name = "PYC"
    regexp = build_magic_regex(MAGIC_TO_VERSION)

    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        # structures whose type byte carries the 0x80 flag, by TYPE_REF index
        self.refs = []
        # interned strings, by TYPE_STRING_REF index
        self.interned = []

    def read_object(self, obj):
        """Return the python value of parsed structure *obj*, following TYPE_REF links.

        Returns None when *obj* is None (e.g. a reserved but unfilled slot).
        """
        if obj is None:
            return None
        tt = obj["Type"] & 0x7F
        if tt == TYPE_REF:
            return self.read_object(self.refs[obj["Reference"]])
        return read_object(obj, self.refs, self.interned)

    def update_arch(self, version):
        """Select the malcat architecture (disassembler) matching *version*."""
        if version == (2, 7):
            architecture = malcat.Architecture.PY27
        elif version >= (3, 14):
            architecture = malcat.Architecture.PY314
        elif version >= (3, 13):
            architecture = malcat.Architecture.PY313
        elif version >= (3, 12):
            architecture = malcat.Architecture.PY312
        elif version >= (3, 11):
            architecture = malcat.Architecture.PY311
        elif version >= (3, 10):
            architecture = malcat.Architecture.PY310
        elif version >= (3, 9):
            architecture = malcat.Architecture.PY39
        elif version >= (3, 8):
            architecture = malcat.Architecture.PY38
        elif version >= (3, 7):
            architecture = malcat.Architecture.PY37
        elif version >= (3, 6):
            architecture = malcat.Architecture.PY36
        else:
            architecture = malcat.Architecture.PY36
            # join the whole tuple: versions such as (3, 5, 3) have three parts
            print("Unsupported architecture for python {}, defaulting to {}".format(
                ".".join(map(str, version)), architecture))
        self.set_architecture(architecture)

    def parse(self, hint):
        # reference tables are per file: start empty on every parse
        self.refs = []
        self.interned = []
        ph = yield PythonHeader(category=Type.HEADER)
        self.version = MAGIC_TO_VERSION.get(ph["Magic"], None)
        if self.version is None:
            # unicode builds write MAGIC + 1, so map the value back down
            self.version = MAGIC_TO_VERSION.get(ph["Magic"] - 1, None)
        self.update_arch(self.version)

        self.add_metadata("Python version", ".".join(map(str, self.version)))
        if "Timestamp" in ph:
            self.add_metadata("Timestamp", ph["Timestamp"].strftime("%Y-%m-%d %H:%M:%S"))
        # parse the root code object; `yield from` forwards the parsed
        # structure back into parse_object so its ref/interned slots are filled
        module = yield from self.parse_object(name="Module", comment="module defined in this file")
        if (module["Type"] & 0x7f) != TYPE_CODE:
            raise FatalError("Root module object is not of type code, got {:x}".format(module["Type"] & 0x7f))
        module = self["Module"]
        if "Filename" in module:
            self.add_metadata("Filename", str(self.read_object(module["Filename"])))
        if "Name" in module:
            self.add_metadata("Name", str(self.read_object(module["Name"])))
        self.add_section("module", module.offset, module.size)

    def parse_object(self, name="", comment=""):
        """Parse the marshalled object at the current offset.

        The object class is chosen through TYPE_TO_OBJECT on the low 7 bits of
        the type byte. When the 0x80 flag is set, a slot in ``self.refs`` is
        reserved *before* the object's children are parsed, so flagged
        children get later indices. The slot is filled once the structure
        has been parsed. Interned strings are appended to ``self.interned``.

        :returns: the parsed structure (the value sent back for the yield).
        :raises FatalError: on an unknown type byte.
        """
        object_type = self.read(self.tell(), 1)[0]
        object_class = TYPE_TO_OBJECT.get(object_type & 0x7F, None)
        if object_class is None:
            raise FatalError("Unknown object type: {:x}".format(object_type))
        is_ref = bool(object_type & 0x80)
        if is_ref:
            ref_index = len(self.refs)
            self.refs.append(None)
        sa = yield object_class(name=name, comment=comment, category=Type.DATA)
        if is_ref:
            self.refs[ref_index] = sa
        if (object_type & 0x7f) in (TYPE_ASCII_INTERNED, TYPE_INTERNED, TYPE_SHORT_ASCII_INTERNED):
            self.interned.append(sa)
        return sa
