from filetypes.base import *
import malcat
import datetime
import struct
import enum


class GGUFHeader(Struct):
    """Leading GGUF structure: magic, format version and the two counters."""

    def parse(self):
        # None of the parsed values is needed inside this structure, so the
        # results of the yields are intentionally not bound to names.
        yield String(4, name="Magic")
        yield UInt32(name="Version")
        # Element counts of the tensor information table and of the
        # key/value dictionary that follow the header.
        yield UInt64(name="NumberOfTensors")
        yield UInt64(name="NumberOfValues")


# Field describing the type of a GGUF metadata value. The numeric code of
# each type is its position in the tuple below (0 for UINT8 ... 12 for FLOAT64).
GGUF_TYPE = UInt32(
    name="Type",
    values=[
        (label, code)
        for code, label in enumerate((
            "GGUF_TYPE_UINT8",
            "GGUF_TYPE_INT8",
            "GGUF_TYPE_UINT16",
            "GGUF_TYPE_INT16",
            "GGUF_TYPE_UINT32",
            "GGUF_TYPE_INT32",
            "GGUF_TYPE_FLOAT32",
            "GGUF_TYPE_BOOL",
            "GGUF_TYPE_STRING",
            "GGUF_TYPE_ARRAY",
            "GGUF_TYPE_UINT64",
            "GGUF_TYPE_INT64",
            "GGUF_TYPE_FLOAT64",
        ))
    ],
)

class GGUFArray(Struct):
    """Metadata array value: element type, element count, then the elements.

    Raises FatalError when the element type code is out of range or when
    the elements are themselves arrays (type code 9).
    """

    def parse(self):
        elem_type = yield GGUF_TYPE
        count = yield UInt64(name="Size")

        # Guard clause: unknown codes and nested arrays are not supported.
        if not (elem_type < len(GGUF_TYPE_TO_CLASS) and elem_type != 9):
            raise FatalError(f"Unsupported array type {elem_type}")
        elem_cls = GGUF_TYPE_TO_CLASS[elem_type]

        if elem_type != 8:
            # Non-string elements: a plain Array is more performant.
            yield Array(count, elem_cls(), name="Values")
        elif count > 1000:
            # Very large string arrays are most likely token lists. Expanding
            # 100k+ element structures is useless and degrades performance,
            # so the whole array is exposed as one opaque byte blob instead.
            blob_size = self.parser.compute_size_prefixed_elements_array_size(count, 8)
            yield Bytes(blob_size, name="StringsArray")
        else:
            # Small string arrays: one variable-sized element per string.
            yield VariableArray(count, elem_cls, name="Values")


# Field class used to parse a metadata value, indexed by its GGUF_TYPE code.
GGUF_TYPE_TO_CLASS = [
    UInt8,            # 0: GGUF_TYPE_UINT8
    Int8,             # 1: GGUF_TYPE_INT8
    UInt16,           # 2: GGUF_TYPE_UINT16
    Int16,            # 3: GGUF_TYPE_INT16
    UInt32,           # 4: GGUF_TYPE_UINT32
    Int32,            # 5: GGUF_TYPE_INT32
    Float,            # 6: GGUF_TYPE_FLOAT32
    UInt8,            # 7: GGUF_TYPE_BOOL (parsed as a single byte)
    PascalString64,   # 8: GGUF_TYPE_STRING
    GGUFArray,        # 9: GGUF_TYPE_ARRAY
    UInt64,           # 10: GGUF_TYPE_UINT64
    Int64,            # 11: GGUF_TYPE_INT64
    Double,           # 12: GGUF_TYPE_FLOAT64
]

class GgmlType(enum.IntEnum):
    """Numeric identifiers of the ggml tensor element types.

    The numbering is not contiguous: 4, 5, 31-33 and 36-38 have no member.
    """

    # Plain floating point types
    GGML_TYPE_F32 = 0
    GGML_TYPE_F16 = 1

    # Block formats with 32 elements per block
    GGML_TYPE_Q4_0 = 2
    GGML_TYPE_Q4_1 = 3
    GGML_TYPE_Q5_0 = 6
    GGML_TYPE_Q5_1 = 7
    GGML_TYPE_Q8_0 = 8
    GGML_TYPE_Q8_1 = 9

    # "K" block formats
    GGML_TYPE_Q2_K = 10
    GGML_TYPE_Q3_K = 11
    GGML_TYPE_Q4_K = 12
    GGML_TYPE_Q5_K = 13
    GGML_TYPE_Q6_K = 14
    GGML_TYPE_Q8_K = 15

    # "IQ" block formats
    GGML_TYPE_IQ2_XXS = 16
    GGML_TYPE_IQ2_XS = 17
    GGML_TYPE_IQ3_XXS = 18
    GGML_TYPE_IQ1_S = 19
    GGML_TYPE_IQ4_NL = 20
    GGML_TYPE_IQ3_S = 21
    GGML_TYPE_IQ2_S = 22
    GGML_TYPE_IQ4_XS = 23

    # Plain integer types and double precision floats
    GGML_TYPE_I8 = 24
    GGML_TYPE_I16 = 25
    GGML_TYPE_I32 = 26
    GGML_TYPE_I64 = 27
    GGML_TYPE_F64 = 28

    GGML_TYPE_IQ1_M = 29
    GGML_TYPE_BF16 = 30
    GGML_TYPE_TQ1_0 = 34
    GGML_TYPE_TQ2_0 = 35

    # Presumably an end-of-enum marker rather than a real tensor type.
    GGML_TYPE_COUNT = 39


# Layout information for each ggml tensor type, keyed by the numeric type:
#   (short name, elements per block, bytes per block, flag)
# The flag is True for the block-quantized formats and False for the plain
# scalar types.
#
# Bytes per block follow the block_* structs of ggml's ggml-common.h. The
# scale fields of those structs ("d", "dmin", "m", "s") are 16-bit ggml_half
# values, i.e. 2 bytes each, except for Q8_K whose scale is a 32-bit float.
# Counting them as 4 bytes over-estimates every tensor size.
GGML_TYPES = {
    GgmlType.GGML_TYPE_I8.value: ("i8", 1, 1, False),
    GgmlType.GGML_TYPE_I16.value: ("i16", 1, 2, False),
    GgmlType.GGML_TYPE_I32.value: ("i32", 1, 4, False),
    GgmlType.GGML_TYPE_I64.value: ("i64", 1, 8, False),
    GgmlType.GGML_TYPE_F64.value: ("f64", 1, 8, False),
    GgmlType.GGML_TYPE_F32.value: ("f32", 1, 4, False),
    GgmlType.GGML_TYPE_F16.value: ("f16", 1, 2, False),
    # d + qs[16]
    GgmlType.GGML_TYPE_Q4_0.value: ("q4_0", 32, 2 + 16, True),
    # d + m + qs[16]
    GgmlType.GGML_TYPE_Q4_1.value: ("q4_1", 32, 2 + 2 + 16, True),
    # d + qh[4] + qs[16]
    GgmlType.GGML_TYPE_Q5_0.value: ("q5_0", 32, 2 + 4 + 16, True),
    # d + m + qh[4] + qs[16]
    GgmlType.GGML_TYPE_Q5_1.value: ("q5_1", 32, 2 + 2 + 4 + 16, True),
    # d + qs[32]
    GgmlType.GGML_TYPE_Q8_0.value: ("q8_0", 32, 2 + 32, True),
    # d + s + qs[32]
    GgmlType.GGML_TYPE_Q8_1.value: ("q8_1", 32, 2 + 2 + 32, True),
    # scales[16] + qs[64] + d + dmin
    GgmlType.GGML_TYPE_Q2_K.value: ("q2_K", 256, 16 + 64 + 2 + 2, True),
    # hmask[32] + qs[64] + scales[12] + d
    GgmlType.GGML_TYPE_Q3_K.value: ("q3_K", 256, 32 + 64 + 12 + 2, True),
    # d + dmin + scales[12] + qs[128]
    GgmlType.GGML_TYPE_Q4_K.value: ("q4_K", 256, 2 + 2 + 12 + 128, True),
    # d + dmin + scales[12] + qh[32] + qs[128]
    GgmlType.GGML_TYPE_Q5_K.value: ("q5_K", 256, 2 + 2 + 12 + 32 + 128, True),
    # ql[128] + qh[64] + scales[16] + d
    GgmlType.GGML_TYPE_Q6_K.value: ("q6_K", 256, 128 + 64 + 16 + 2, True),
    # d + qs[32] (16-bit entries)
    GgmlType.GGML_TYPE_IQ2_XXS.value: ("iq2_xxs", 256, 2 + 2 * 32, True),
    # d + qs[32] (16-bit entries) + scales[8]
    GgmlType.GGML_TYPE_IQ2_XS.value: ("iq2_xs", 256, 2 + 2 * 32 + 8, True),
    # d + qs[96]
    GgmlType.GGML_TYPE_IQ3_XXS.value: ("iq3_xxs", 256, 2 + 3 * 32, True),
    # d + qs[64] + qh[8] + signs[32] + scales[4]
    GgmlType.GGML_TYPE_IQ3_S.value: ("iq3_s", 256, 2 + 64 + 8 + 32 + 4, True),
    # d + qs[64] + qh[8] + scales[8]
    GgmlType.GGML_TYPE_IQ2_S.value: ("iq2_s", 256, 2 + 64 + 8 + 8, True),
    # d + qs[32] + qh[8] (16-bit entries)
    GgmlType.GGML_TYPE_IQ1_S.value: ("iq1_s", 256, 2 + 32 + 2 * 8, True),
    # qs[32] + qh[16] + scales[8] (no separate scale field)
    GgmlType.GGML_TYPE_IQ1_M.value: ("iq1_m", 256, 32 + 16 + 8, True),
    # d + qs[16]
    GgmlType.GGML_TYPE_IQ4_NL.value: ("iq4_nl", 32, 2 + 16, True),
    # d + scales_h (16 bits) + scales_l[4] + qs[128]
    GgmlType.GGML_TYPE_IQ4_XS.value: ("iq4_xs", 256, 2 + 2 + 4 + 128, True),
    # d (32-bit float) + qs[256] + bsums[16] (16-bit entries)
    GgmlType.GGML_TYPE_Q8_K.value: ("q8_K", 256, 4 + 256 + 2 * 16, True),
    GgmlType.GGML_TYPE_BF16.value: ("bf16", 1, 2, False),
    # qs[48] + qh[4] + d
    GgmlType.GGML_TYPE_TQ1_0.value: ("tq1_0", 256, 48 + 4 + 2, True),
    # qs[64] + d
    GgmlType.GGML_TYPE_TQ2_0.value: ("tq2_0", 256, 64 + 2, True),
}

# Field describing the element type of a tensor; every GgmlType member is
# exposed as a named value, in declaration order.
GGML_TYPE = UInt32(
    name="TensorType",
    values=[(member.name, member.value) for member in GgmlType],
)






class GGUFKeyValuePair(Struct):
    """One metadata entry: a key string, a value type code and the value.

    Raises FatalError when the type code has no entry in GGUF_TYPE_TO_CLASS.
    """

    def parse(self):
        yield PascalString64(name="Key")
        value_type = yield GGUF_TYPE
        # Guard clause: reject codes outside of the known type table.
        if not value_type < len(GGUF_TYPE_TO_CLASS):
            raise FatalError(f"Unsupported value type {value_type}")
        value_cls = GGUF_TYPE_TO_CLASS[value_type]
        yield value_cls(name="Value")


class GGUFVariant(Struct):
    """Structure made of a magic, a version and two 64-bit counters.

    NOTE(review): nothing in the visible part of this file references this
    class - confirm whether it is still needed.
    """

    def parse(self):
        # Identification
        yield String(4, name="Magic")
        yield UInt32(name="Version")

        # Element counts
        yield UInt64(name="NumberOfTensors")
        yield UInt64(name="NumberOfValues")



class GGUFTensorInfo(Struct):
    """One entry of the tensor information table.

    Holds the tensor name, its dimensions, its element type and the offset
    of its data.
    """

    def parse(self):
        yield PascalString64(name="Name")
        rank = yield UInt32(name="NumberOfDimensions")
        # At most 4 dimension entries are read, whatever count is declared.
        # NOTE(review): a declared count above 4 leaves the remaining entries
        # unread, so the fields below would then be misplaced - confirm that
        # such files cannot occur.
        dims_to_read = min(4, rank)
        yield Array(dims_to_read, Int64(), name="NumberOfElements")
        yield GGML_TYPE
        yield UInt64(name="Offset", comment="offset relative ot the start of tensors")




def align(val, what, down=False):
    """Round *val* to a multiple of *what*.

    Rounds up by default, or down when *down* is true. *val* is returned
    unchanged when *what* is zero/empty or when *val* is already aligned.
    """
    if not what:
        return val
    remainder = val % what
    if not remainder:
        # already a multiple of `what`
        return val
    if down:
        return val - remainder
    return val + (what - remainder)



class GGUFAnalyzer(FileTypeAnalyzer):
    """Malcat analyzer for GGUF files.

    Parses the header, the key/value dictionary and the tensor information
    table, then declares one section per tensor whose size can be computed.
    """
    category = malcat.FileType.DOCUMENT
    name = "GGUF"
    regexp = r"GGUF[\x02-\xff]\x00\x00\x00"


    def __init__(self):
        FileTypeAnalyzer.__init__(self)
        self.regions = []
        self.is64 = False
        # Offset of the tensor data area; computed by parse().
        self.start_of_tensors = 0

    @staticmethod
    def _tensor_nbytes(dims, block_size, type_size):
        """Return the size in bytes of a tensor, following ggml_nbytes().

        dims: list of dimension sizes (padded with 1 up to 4 entries).
        block_size: number of elements stored per block.
        type_size: number of bytes per block.

        Returns 0 when any dimension is zero or negative, like ggml_nbytes().
        """
        ne = list(dims)
        ne += [1] * (4 - len(ne))
        if any(n <= 0 for n in ne):
            return 0

        # Strides in bytes for each dimension
        # https://github.com/ggml-org/llama.cpp/blob/0aedae00e6fb48680324a5ac5da9cba0e35de6b5/ggml/src/gguf.cpp#L589
        nb = [type_size, type_size * (ne[0] // block_size)]
        for i in range(2, 4):
            nb.append(nb[i - 1] * ne[i - 1])

        # https://github.com/ggml-org/llama.cpp/blob/576c82eda210ca0111c04f5256bf77897a4d4cc4/ggml/src/ggml.c#L1186
        if block_size == 1:
            return type_size + sum((n - 1) * stride for n, stride in zip(ne, nb))
        first_row = ne[0] * nb[0] // block_size
        return first_row + sum((n - 1) * stride for n, stride in zip(ne[1:], nb[1:]))

    def parse(self, hint=""):
        hdr = yield GGUFHeader()
        yield VariableArray(hdr["NumberOfValues"], GGUFKeyValuePair, name="Dictionnary")
        tensor_infos = yield VariableArray(hdr["NumberOfTensors"], GGUFTensorInfo, name="TensorInformations")
        self.set_imagebase(0x1000000)
        self.confirm()

        # Tensor data starts at the next 32-byte boundary after the tables.
        # NOTE(review): 32 is hard-coded; a "general.alignment" entry in the
        # dictionary, if present, is not taken into account here.
        self.start_of_tensors = align(self.tell(), 32)

        for ti in tensor_infos:
            type_info = GGML_TYPES.get(ti["TensorType"])
            if type_info is None:
                # Tensor type without layout information: its size cannot
                # be computed, so no section is declared for it (instead of
                # aborting the whole analysis on a KeyError).
                continue
            _name, block_size, type_size, _quantized = type_info

            ne = [x.value for x in ti["NumberOfElements"]]
            tensor_size = self._tensor_nbytes(ne, block_size, type_size)
            if tensor_size > 0:
                data_start = self.start_of_tensors + ti["Offset"]
                self.add_section(ti["Name"]["String"], data_start, tensor_size)
