from transforms.base import *
import struct

from transforms.arithmetic import GenericArithmetic, FORMAT_INFOS




class CircularXor(Transform):
    """
    Simple circular XOR cipher using a circular key.

    Byte ``i`` of the output is ``data[i] ^ key[i % len(key)]``; the key is
    repeated as many times as needed to cover ``data``. An empty key leaves
    ``data`` untouched (the original object is returned).
    """
    category = "binary"
    name = "xor8"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, key:bytes=b"\x00"):
        # Nothing to combine with: hand the input back as-is.
        if not key:
            return data
        period = len(key)
        return bytearray(byte ^ key[pos % period] for pos, byte in enumerate(data))


class CircularAdd(Transform):
    """
    Simple circular ADD cipher using a circular key.

    Byte ``i`` of the output is ``(data[i] + key[i % len(key)]) mod 256``.
    An empty key returns ``data`` unchanged; otherwise a new bytearray of the
    same length is produced.
    """
    category = "binary"
    name = "add8"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, key:bytes=b"\x00"):
        if not key:
            return data
        period = len(key)
        out = bytearray(len(data))
        for pos, byte in enumerate(data):
            # Mask to 8 bits so the sum wraps around like unsigned byte math.
            out[pos] = (byte + key[pos % period]) & 0xff
        return out


class CircularSub(Transform):
    """
    Simple circular SUB cipher using a circular key.

    Byte ``i`` of the output is ``(data[i] - key[i % len(key)]) mod 256``,
    i.e. the inverse of the ``add8`` transform with the same key. An empty key
    returns ``data`` unchanged.
    """
    category = "binary"
    name = "sub8"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, key:bytes=b"\x00"):
        if not key:
            return data
        period = len(key)
        # Python ints are unbounded, so mask the difference back into 0..255
        # (same as adding the two's complement of the key byte).
        return bytearray((byte - key[pos % period]) & 0xff for pos, byte in enumerate(data))


class CircularOr(Transform):
    """
    Simple circular OR using a circular key.

    Byte ``i`` of the output is ``data[i] | mask[i % len(mask)]``. An empty
    mask returns ``data`` unchanged.
    """
    category = "binary"
    name = "or"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, mask:bytes=b"\x00"):
        if not mask:
            return data
        period = len(mask)
        out = bytearray(len(data))
        for pos, byte in enumerate(data):
            out[pos] = byte | mask[pos % period]
        return out

class CircularAnd(Transform):
    """
    Simple circular AND using a circular key.

    Byte ``i`` of the output is ``data[i] & mask[i % len(mask)]``. Note that
    the default mask ``b"\\x00"`` clears every byte. An empty mask returns
    ``data`` unchanged, consistent with the other circular transforms.
    """
    category = "binary"
    name = "and"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, mask:bytes=b"\x00"):
        # Without this guard `pos % len(mask)` raises ZeroDivisionError for
        # any non-empty data; mirror the sibling transforms and no-op instead.
        if not mask:
            return data
        period = len(mask)
        res = bytearray(len(data))
        for pos, byte in enumerate(data):
            res[pos] = byte & mask[pos % period]
        return res

class Not(Transform):
    """
    Inverse all bits.

    Every byte ``c`` of ``data`` becomes ``~c & 0xff``; the result is a new
    bytearray of the same length.
    """
    category = "binary"
    name = "not"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes):
        # `~` on a Python int yields a negative number; masking keeps the low 8 bits.
        return bytearray(~value & 0xff for value in data)

class Rol(Transform, GenericArithmetic):
    """
    Rotate left using a circular byte buffer (key) as shift index.

    ``width`` selects the unit size via ``FORMAT_INFOS``; each unit is rotated
    left by ``key[index % len(key)] % bits``. The ``index`` passed to the
    rotation callback comes from ``GenericArithmetic.run`` (started at 0 and
    advanced with ``x + 1`` -- presumably one step per unit; confirm in
    transforms.arithmetic). Units are handled with "lsb" ordering.

    Raises:
        ValueError: if ``key`` is empty.
    """
    category = "binary"
    name = "rol"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, key:bytes=b"\x00", width:["byte", "word", "dword", "qword"]="byte"):
        # Look up the width first so an unknown width fails before the key check.
        _, bits = FORMAT_INFOS[width]
        if not key:
            raise ValueError("empty key")
        all_ones = 2 ** bits - 1

        def rotate_left(val, index):
            amount = key[index % len(key)] % bits
            # Bits pushed past the top re-enter at the bottom.
            wrapped = (val & all_ones) >> (bits - amount)
            shifted = (val << amount) & all_ones
            return shifted | wrapped

        return GenericArithmetic.run(self, rotate_left, data, 0, width, "lsb", lambda x: x + 1)
   

class Ror(Transform, GenericArithmetic):
    """
    Rotate right using a circular byte buffer (key) as shift index.

    ``width`` selects the unit size via ``FORMAT_INFOS``; each unit is rotated
    right by ``key[index % len(key)] % bits``. The ``index`` passed to the
    rotation callback comes from ``GenericArithmetic.run`` (started at 0 and
    advanced with ``x + 1`` -- presumably one step per unit; confirm in
    transforms.arithmetic). Units are handled with "lsb" ordering.

    Raises:
        ValueError: if ``key`` is empty.
    """
    category = "binary"
    name = "ror"
    icon = "wxART_VIEW_HEXA"

    def run(self, data:bytes, key:bytes=b"\x00", width:["byte", "word", "dword", "qword"]="byte"):
        # Look up the width first so an unknown width fails before the key check.
        _, bits = FORMAT_INFOS[width]
        if not key:
            raise ValueError("empty key")
        all_ones = 2 ** bits - 1

        def rotate_right(val, index):
            amount = key[index % len(key)] % bits
            shifted = (val & all_ones) >> amount
            # Bits dropped off the bottom re-enter at the top.
            wrapped = (val << (bits - amount)) & all_ones
            return shifted | wrapped

        return GenericArithmetic.run(self, rotate_right, data, 0, width, "lsb", lambda x: x + 1)
