from filetypes.base import *
import malcat 
import datetime


class P7xAnalyzer(FileTypeAnalyzer):
    """Analyzer for P7X files: a 4-byte "PKCX" magic followed by DER PKCS#7 blobs."""
    category = malcat.FileType.DOCUMENT
    name = "P7X"
    regexp = r"PKCX"

    def parse(self, hint):
        """Yield the magic, then one field per decodable certificate blob.

        Parsing stops at the first chunk that parse_der_certificate() does not
        accept. Afterwards a "Certificates" section spanning everything past the
        magic is added, and the collected metadata is exported (flat keys for a
        single certificate, "Cert[i].key" keys for several).
        """
        yield Bytes(4, name="Signature", category=Type.HEADER)
        certs_meta = []
        while self.remaining():
            # Hand the decoder a bounded window; it reports how many bytes it consumed.
            data = self.read(size=min(self.remaining(), 65536))
            parsed, meta = parse_der_certificate(data)
            if not (meta and parsed):
                break
            certs_meta.append(meta)
            # "Subject" is absent when the signer certificate is not found in the
            # embedded chain, so never pass a None name to the field.
            yield Bytes(parsed, name=meta.get("Subject") or "Certificate", category=Type.DATA)
            self.confirm()
        self.add_section("Certificates", 4, self.tell() - 4)
        if len(certs_meta) == 1:
            for k, v in certs_meta[0].items():
                self.add_metadata(k, v, category="Certificate")
        else:
            # Zero certificates simply produce no metadata here.
            for i, meta in enumerate(certs_meta):
                for k, v in meta.items():
                    self.add_metadata("Cert[{}].{}".format(i, k), v, category="Certificates")
        



class DerCertificate(FileTypeAnalyzer):
    """Analyzer for raw DER-encoded PKCS#7 SignedData blobs."""
    category = malcat.FileType.DOCUMENT
    name = "PKCS7"
    regexp = r"\x30...\x06\x09\x2A\x86\x48\x86\xF7\x0D\x01\x07\x02.{15}\x06"

    @classmethod
    def locate(cls, curfile, offset_magic, parent_parser):
        """Accept the magic match unless we are already nested inside a DER parser."""
        if parent_parser is not None and parent_parser.name == "DER":
            # no DER in DER
            return None
        return offset_magic, ""

    def parse(self, hint):
        """Yield one field and one section per decodable certificate blob.

        Parsing stops at the first chunk that parse_der_certificate() does not
        accept. The collected metadata is then exported: flat keys for a single
        certificate, "Cert[i].key" keys for several.
        """
        certs_meta = []
        while self.remaining():
            # Hand the decoder a bounded window; it reports how many bytes it consumed.
            data = self.read(size=min(self.remaining(), 65536))
            parsed, meta = parse_der_certificate(data)
            if not (meta and parsed):
                break
            certs_meta.append(meta)
            # "Subject" is absent when the signer certificate is not found in the
            # embedded chain; the section name must not be None.
            subject = meta.get("Subject") or "Certificate"
            start = self.tell()
            yield Bytes(parsed, name=subject, category=Type.DATA)
            self.add_section(subject, start, self.tell() - start)
            self.confirm()
        if len(certs_meta) == 1:
            for k, v in certs_meta[0].items():
                self.add_metadata(k, v, category="Certificate")
        else:
            # Zero certificates simply produce no metadata here.
            for i, meta in enumerate(certs_meta):
                for k, v in meta.items():
                    self.add_metadata("Cert[{}].{}".format(i, k), v, category="Certificates")




def parse_der_certificate(data):
    try:
        import pyasn1
        from pyasn1.codec.der.decoder import decode
        from pyasn1_modules import rfc2315, rfc2459, rfc2437
        pyasn1_version = tuple(map(int, pyasn1.__version__.split(".")))
        if pyasn1_version >= (0, 5, 0):
            print("pyasn1 0.5.0 installed, but we require pyasn1<0.5.0,>=0.4.5 (see requirements.txt)")
            has_pyasn1 = False
        else:
            has_pyasn1 = True
    except ImportError:
        has_pyasn1 = False
        print("pyasn1 not installed: no certificate parsing")

    if not has_pyasn1:
        return 0, {}

    class MemoryDerDecoder(pyasn1.codec.der.decoder.Decoder):
        def __call__(self,*v,**kw):
            try:
                parsed,remainder = pyasn1.codec.der.decoder.Decoder.__call__(self,*v,**kw)
            except Exception as e:
                raise FatalError(e)
            parsed._substrate = v[0][:len(v[0])-len(remainder)]
            return parsed,remainder
    decode = MemoryDerDecoder(pyasn1.codec.der.decoder.tagMap, pyasn1.codec.der.decoder.typeMap)

    SUPPORTED_ATTRIBUTES = {
            rfc2459.id_at_countryName: ("Country", rfc2459.X520countryName),
            rfc2459.id_at_organizationName: ("Organization", rfc2459.X520OrganizationName),
            rfc2459.id_at_organizationalUnitName: ("Unit", rfc2459.X520OrganizationalUnitName),
            rfc2459.id_at_commonName: ("CommonName", rfc2459.X520CommonName),
            rfc2459.id_at_stateOrProvinceName: ("State", rfc2459.X520StateOrProvinceName),
            rfc2459.id_at_localityName: ("Locality", rfc2459.X520LocalityName),
            rfc2459.emailAddress: ("Email", rfc2459.Pkcs9email),
    }

    ALGORITHMS = {
           "1.2.840.113549.2.5" : "MD5",
           "1.3.14.3.2.26" : "SHA1",
           "2.16.840.1.101.3.4.2.1" : "SHA256",
           "2.16.840.1.101.3.4.2.2" : "SHA384",
           "2.16.840.1.101.3.4.2.3" : "SHA512",
           "2.16.840.1.101.3.4.2.4" : "SHA224",
           "2.16.840.1.101.3.4.2.5" : "SHA512-224",
           "2.16.840.1.101.3.4.2.6" : "SHA512-256",
           "1.2.840.113549.1.1.1" : "RSA",
           "1.2.840.113549.1.1.2" : "MD2/RSA",
           "1.2.840.113549.1.1.3" : "MD4/RSA",
           "1.2.840.113549.1.1.4" : "MD5/RSA",
           "1.2.840.113549.1.1.5" : "SHA1/RSA",
           "1.2.840.113549.1.1.6" : "OEP/RSA",
           "1.2.840.113549.1.1.7" : "OEPAES/RSA",
           "1.2.840.113549.1.1.8" : "MGF1/RSA",
           "1.2.840.113549.1.1.9" : "PSPEC/RSA",
           "1.2.840.113549.1.1.10" : "PSS/RSA",
           "1.2.840.113549.1.1.11" : "SHA256/RSA",
           "1.2.840.113549.1.1.12" : "SHA384/RSA",
           "1.2.840.113549.1.1.13" : "SHA512/RSA",
           "1.2.840.113549.1.1.14" : "SHA224/RSA",
           "1.2.840.10045.2.1" : "ECDSA",
           "1.2.840.10045.4.1" : "SHA1/ECDSA",
           "1.2.840.10045.4.2" : "SHA/ECDSA",
           "1.2.840.10045.4.3" : "SHA2/ECDSA",
           "1.2.840.10045.4.3.1" : "SHA224/ECDSA",
           "1.2.840.10045.4.3.2" : "SHA256/ECDSA",
           "1.2.840.10045.4.3.3" : "SHA384/ECDSA",
           "1.2.840.10045.4.3.4" : "SHA512/ECDSA",
           "1.2.840.10040.4.1" : "DSA",
           "1.2.840.10040.4.2" : "DSA",
           "1.2.840.10040.4.3" : "SHA1/DSA",
    }
    meta = {}
    parsed_size = 0
    try:
        contentInfo, rest = decode(data, asn1Spec=rfc2315.ContentInfo())
        parsed_size = len(data) - len(rest)
        contentType = contentInfo.getComponentByName('contentType')
        if contentType == rfc2315.signedData:  
            signedData = decode(contentInfo.getComponentByName('content'), asn1Spec=rfc2315.SignedData())
            for signature in signedData:
                if not signature:
                    continue
                #print(signature)
                signerInfos = signature.getComponentByName('signerInfos')
                signer_infos = {}
                serial = None
                before = None
                after = None
                digest = None
                subject_infos = {}
                digest_algorithm = None
                digest_encryption_algorithm = None
                for si in signerInfos:
                    if not si:
                        continue
                    issuerAndSerial = si.getComponentByName('issuerAndSerialNumber')
                    serial = issuerAndSerial.getComponentByName('serialNumber')._substrate.asOctets()[2:].hex()
                    issuer = issuerAndSerial.getComponentByName('issuer').getComponent()
                    for issuer_info in issuer:
                        for field in issuer_info:
                            at = field.getComponentByName('type')                       
                            value = field.getComponentByName('value')                     
                            sup = SUPPORTED_ATTRIBUTES.get(at, None)
                            if sup is not None:
                                name, spec = sup
                                dec_value = decode(value, asn1Spec=spec())[0]
                                if hasattr(dec_value, "getComponent"):
                                    dec_value = dec_value.getComponent()
                                signer_infos[name] = str(dec_value)
                    if "digestAlgorithm" in si:
                        digestAlgorithm = si.getComponentByName('digestAlgorithm')
                        digest_algorithm = str(digestAlgorithm.getComponentByName("algorithm"))
                    if "digestEncryptionAlgorithm" in si:
                        digestAlgorithm = si.getComponentByName('digestEncryptionAlgorithm')
                        digest_encryption_algorithm = str(digestAlgorithm.getComponentByName("algorithm"))
                chain = signature.getComponentByName('certificates')
                for cert in chain:
                    tbs = cert.getComponentByName("certificate").getComponentByName("tbsCertificate")
                    chain_serial = tbs.getComponentByName('serialNumber')._substrate.asOctets()[2:].hex()
                    if chain_serial == serial:
                        # get extra infos about issuer
                        for issuer_info in tbs.getComponentByName('issuer').getComponent():
                            for field in issuer_info:
                                at = field.getComponentByName('type')                       
                                value = field.getComponentByName('value')                     
                                sup = SUPPORTED_ATTRIBUTES.get(at, None)
                                if sup is not None:
                                    name, spec = sup
                                    dec_value = decode(value, asn1Spec=spec())[0]
                                    if hasattr(dec_value, "getComponent"):
                                        dec_value = dec_value.getComponent()
                                    signer_infos[name] = str(dec_value)
                        t = tbs.getComponentByName('validity').getComponentByName("notBefore")
                        if t.getName() == "utcTime":
                            try:
                                before = datetime.datetime.strptime(str(t.getComponent())[:-1], "%y%m%d%H%M%S")
                            except:
                                before = datetime.datetime.strptime(str(t.getComponent())[:-1], "%y%m%d%H%M")
                        t = tbs.getComponentByName('validity').getComponentByName("notAfter")
                        if t.getName() == "utcTime":
                            try:
                                after = datetime.datetime.strptime(str(t.getComponent())[:-1], "%y%m%d%H%M%S")
                            except:
                                after = datetime.datetime.strptime(str(t.getComponent())[:-1], "%y%m%d%H%M")

                        # get extra infos about subject
                        for info in tbs.getComponentByName("subject")[0]:
                            info = info[0]
                            at = info.getComponentByName('type')                       
                            value = info.getComponentByName('value')                     
                            sup = SUPPORTED_ATTRIBUTES.get(at, None)
                            if sup is not None:
                                name, spec = sup
                                dec_value = decode(value, asn1Spec=spec())[0]
                                if hasattr(dec_value, "getComponent"):
                                    dec_value = dec_value.getComponent()
                                subject_infos[name] = str(dec_value)
                if signer_infos:
                    meta["Issuer"] = "{} (Organization={} / Unit={} / Country={})".format(
                            signer_infos.get("CommonName", "???"), 
                            signer_infos.get("Organization", "???"), 
                            signer_infos.get("Unit", "???"), 
                            signer_infos.get("Country", "???"))
                if subject_infos:
                    meta["Subject"] = "{}".format(
                            subject_infos.get("CommonName", "???"))
                    meta["Org Details"] = "{} / Unit={} / State={} / Locality={} / Country={} / Email={}".format(
                            subject_infos.get("Organization", "???"), 
                            subject_infos.get("Unit", "???"), 
                            subject_infos.get("State", "???"),
                            subject_infos.get("Locality", "???"),
                            subject_infos.get("Country", "???"),
                            subject_infos.get("Email", "???"))

                if before is not None and after is not None:
                    meta["Validity"] = "from {} to {}".format(before.strftime("%Y-%m-%d"), after.strftime("%Y-%m-%d"))
                if serial:
                    meta["SerialNumber"] = serial
                if digest_algorithm in ALGORITHMS:
                    meta["HashAlgorithm"] = ALGORITHMS[digest_algorithm]
                if digest_encryption_algorithm in ALGORITHMS:
                    meta["CryptAlgorithm"] = ALGORITHMS[digest_encryption_algorithm]
    except pyasn1.error.PyAsn1Error as e:
        raise FatalError(f"Could not decode signature: {e}")
    return parsed_size, meta


