import os
import capa
import capa.rules
import capa.engine
import capa.features
import capa.features.extractors.malcat
from capa.rules import Rule, Scope, RuleSet
from capa.features.extractors.base_extractor import BBHandle, InsnHandle, FunctionHandle, FeatureExtractor
from capa.engine import FeatureSet, MatchResults
from capa.features.address import Address, NO_ADDRESS, FileOffsetAddress, AbsoluteVirtualAddress, RelativeVirtualAddress, EffectiveAddress
import importlib
import itertools
import logging
import collections
import json
import numbers
import zipfile
import datetime
importlib.reload(capa.features.extractors.malcat)
from typing import Any, Dict, List, Tuple
from tabulate import tabulate
from malcat import Architecture

# Module-level logger named "capa"; get_rules() emits its debug/warning messages through it.
logger = logging.getLogger("capa")

def width(s, character_count):
    """Right-pad `s` with spaces so it is at least `character_count` long.

    A string that already reaches the requested length is returned unchanged.
    """
    missing = character_count - len(s)
    return s + " " * missing if missing > 0 else s

def get_rules(rule_path):
    """Load capa rules from a ``.zip`` archive, a directory tree, or a single ``.yml`` file.

    Args:
        rule_path: path to a zip archive of rules, a directory walked recursively
            for ``*.yml`` files (``.github`` folders are skipped), or a single
            ``.yml`` rule file.

    Returns:
        list of `capa.rules.Rule`. Each rule gets ``meta["capa/path"]`` set to
        where it was loaded from, and ``meta["capa/nursery"] = True`` when that
        path contains "nursery".

    Raises:
        IOError: if `rule_path` does not exist.
        capa.rules.InvalidRule: propagated from the parser for a malformed rule.
    """
    if not os.path.exists(rule_path):
        raise IOError("rule path %s does not exist or cannot be accessed" % rule_path)

    rules = []

    def _add_rule(rule, source_path):
        # Tag the rule with its own origin. For zip members this is
        # "<zip>/<member>", so a nursery/ folder inside the archive is detected
        # (previously only the archive's own path was checked).
        rule.meta["capa/path"] = source_path
        if "nursery" in source_path:
            rule.meta["capa/nursery"] = True
        rules.append(rule)
        logger.debug("loaded rule: '%s' with scope: %s", rule.name, rule.scope)

    if os.path.isfile(rule_path):
        if rule_path.endswith(".zip"):
            logger.debug("reading rules from zip %s", rule_path)
            with zipfile.ZipFile(rule_path, "r") as zo:
                for elem in zo.infolist():
                    if not elem.filename.endswith(".yml"):
                        continue
                    with zo.open(elem.filename, mode="r") as f:
                        yaml_content = f.read().decode("utf8")
                    rule = capa.rules.Rule.from_yaml(yaml_content, use_ruamel=True)
                    _add_rule(rule, os.path.join(rule_path, elem.filename))
        elif rule_path.endswith(".yml"):
            logger.debug("reading rule from file %s", rule_path)
            _add_rule(capa.rules.Rule.from_yaml_file(rule_path, use_ruamel=True), rule_path)
        else:
            logger.warning("skipping unsupported rule file: %s", rule_path)
    elif os.path.isdir(rule_path):
        logger.debug("reading rules from directory %s", rule_path)
        for root, _dirs, files in os.walk(rule_path):
            if ".github" in root:
                continue
            for file in files:
                if not file.endswith(".yml"):
                    # documentation / VCS files are expected in the rules repo; don't warn about them
                    if not file.endswith((".md", ".git", ".txt")):
                        logger.warning("skipping non-.yml file: %s", file)
                    continue
                file_path = os.path.join(root, file)
                _add_rule(capa.rules.Rule.from_yaml_file(file_path, use_ruamel=True), file_path)

    return rules

def is_interesting_rule(rule):
    """Tell whether `rule` should be displayed to the user.

    A rule is hidden as soon as any of the meta keys ``lib``, ``capa/subscope``
    or ``capa/subscope-rule`` holds a truthy value.
    """
    hidden_markers = ("lib", "capa/subscope", "capa/subscope-rule")
    return not any(rule.meta.get(marker) for marker in hidden_markers)

def render_rules_summary(capabilities):
    """Print a CAPABILITY / NAMESPACE table of the interesting matched rules.

    Args:
        capabilities: mapping of Rule -> list of (address, result) matches.

    Rules rejected by `is_interesting_rule` are left out of the table, and the
    headline count is computed from the rows actually shown so the two agree
    (it previously also counted the hidden lib/subscope rules).
    """
    rows = []
    for rule, match in capabilities.items():
        if not is_interesting_rule(rule):
            continue
        name = rule.name
        if len(match) > 1:
            # number of locations where the rule matched successfully
            name += " ({})".format(sum(1 for _, result in match if result.success))
        rows.append((name, rule.meta.get("namespace", "")))

    table = tabulate(rows, headers=[width("CAPABILITY", 50), width("NAMESPACE", 50)], tablefmt="psql")
    table_width = len(table.splitlines()[0])
    title = "{} capabilities found".format(len(rows))
    gui.print("[block1]{}[/block1]".format(title.center(table_width)), format=True)
    print(table)

def render_attack(capabilities):
    """Print an ATT&CK tactic/technique table for the interesting matched rules.

    Each ``att&ck`` meta entry is ``Tactic::Technique ID`` or
    ``Tactic::Technique::Sub-technique ID`` (ID = text after the last space).
    Techniques are deduplicated and grouped per tactic. Nothing is printed
    when no interesting rule carries ATT&CK metadata.

    Raises:
        RuntimeError: if a parsed technique spec has an unexpected arity.
    """
    by_tactic = collections.defaultdict(set)
    for rule in capabilities:
        if not is_interesting_rule(rule) or not rule.meta.get("att&ck"):
            continue
        for entry in rule.meta["att&ck"]:
            tactic, _, remainder = entry.partition("::")
            if "::" not in remainder:
                technique, _, ident = remainder.rpartition(" ")
                by_tactic[tactic].add((technique, ident))
                continue
            technique, _, remainder = remainder.partition("::")
            subtechnique, _, ident = remainder.rpartition(" ")
            by_tactic[tactic].add((technique, subtechnique, ident))

    rows = []
    for tactic, specs in sorted(by_tactic.items()):
        lines = []
        for spec in sorted(specs):
            if len(spec) == 2:
                lines.append("{} {}".format(*spec))
            elif len(spec) == 3:
                lines.append("{}::{} {}".format(*spec))
            else:
                raise RuntimeError("unexpected ATT&CK spec format")
        rows.append((tactic.upper(), "\n".join(lines)))

    if not rows:
        return
    table = tabulate(rows, headers=[width("ATT&CK Tactic", 20), width("ATT&CK Technique", 80)], tablefmt="psql")
    table_width = len(table.splitlines()[0])
    gui.print("[block1]{}[/block1]".format("ATT&CK information".center(table_width)), format=True)
    print(table)


def render_mbc(capabilities):
    """Print a Malware Behavior Catalog objective/behavior table.

    Each ``mbc`` meta entry is split at its first ``::``: the head is the
    objective, and the rest is the behavior. Behaviors are deduplicated and
    grouped per objective. Nothing is printed when no interesting rule carries
    MBC metadata.
    """
    behaviors_by_objective = collections.defaultdict(set)
    for rule in capabilities:
        if is_interesting_rule(rule) and rule.meta.get("mbc"):
            for entry in rule.meta["mbc"]:
                objective, _, behavior = entry.partition("::")
                behaviors_by_objective[objective].add(behavior)

    rows = [
        (objective.upper(), "\n".join(sorted(behaviors)))
        for objective, behaviors in sorted(behaviors_by_objective.items())
    ]

    if rows:
        table = tabulate(rows, headers=[width("MBC Objective", 25), width("MBC Behavior", 75)], tablefmt="psql")
        table_width = len(table.splitlines()[0])
        gui.print("[block1]{}[/block1]".format("Malware Behavior Catalog".center(table_width)), format=True)
        print(table)
    


def render_statement(match, statement, indent=0):
    """Print a one-line description of a logic-statement node of a match tree.

    Supports subscope, and/or/not, some/optional and count (range) statements.
    Count statements additionally print the match locations. Output is
    indented by two spaces per `indent` level.

    Raises:
        RuntimeError: for any other statement type.
    """
    print("  " * indent, end="")

    if isinstance(statement, capa.engine.Subscope):
        gui.print("{}:".format(statement.scope))
        return
    if isinstance(statement, (capa.engine.Or, capa.engine.And, capa.engine.Not)):
        gui.print("{}:".format(statement.__class__.__name__.lower()))
        return
    if isinstance(statement, capa.engine.Some):
        if statement.count:
            gui.print("{} or more:".format(statement.count))
        else:
            gui.print("optional:")
        return
    if not isinstance(statement, capa.engine.Range):
        raise RuntimeError("unexpected match statement type: " + str(statement))

    child = statement.child
    name = child.__class__.__name__.lower()
    if not hasattr(child, "value"):
        label = "[color4]count[/color4]([color2]{}[/color2]): ".format(name)
    elif hasattr(child, "description") and child.description:
        label = "[color4]count[/color4]([color2]{}[/color2]({} = {})): ".format(
            name, render_value(child.value), child.description
        )
    else:
        label = "[color4]count[/color4]([color2]{}[/color2]({})): ".format(name, render_value(child.value))
    gui.print(label, end="", format=True)

    if statement.max == statement.min:
        quantity = "{}".format(statement.min)
    elif statement.min == 0:
        quantity = "{} or fewer".format(statement.max)
    # NOTE(review): `64 - 1` binds before `<<`, so this is 1 << 63; presumably it
    # mirrors the "unbounded max" sentinel used by capa.engine.Range — verify
    # against capa before "fixing" the precedence.
    elif statement.max == (1 << 64 - 1):
        quantity = "{} or more".format(statement.min)
    else:
        quantity = "between {} and {}".format(statement.min, statement.max)
    print(quantity, end="")

    render_locations(match)
    print()


def render_feature(match, feature, indent=0):
    """Print one feature node of a match tree as ``name: value`` plus locations.

    ``[`` in the feature name/value is escaped so gui.print markup is not
    triggered. An optional description is appended after the value. Locations
    are not rendered for the global os/arch/format features.
    """
    print("  " * indent, end="")
    gui.print("[color2]{}[/color2]: ".format(feature.get_name_str().replace("[", "\\[")), end="", format=True)
    if feature.value is not None:
        # values are always rendered with colour 3 (the old per-key colour table was empty)
        gui.print("[color3]{}[/color3]".format(feature.get_value_str().replace("[", "\\[")), end="", format=True)
        if feature.description:
            print(capa.rules.DESCRIPTION_SEPARATOR + feature.description, end="")
    # capa derives Feature.name from the lowercased class name ("os", "arch",
    # "format"); the old check against "OS" never matched. "OS" is kept in case
    # a customised feature class uses that exact name.
    if feature.name not in ("os", "OS", "arch", "format"):
        render_locations(match)
    print()

def render_value(val):
    """Render a value for display: exact ``int`` as hex, anything else via ``str()``.

    Only values whose type is exactly ``int`` are hex-formatted; ``bool`` and
    other int subclasses fall through to ``str``.
    """
    if type(val) is not int:
        return str(val)
    return hex(val)

def render_address(ea):
    """Render a capa address as gui markup.

    Effective addresses are delegated to ``analysis.ppa``. File offsets,
    absolute VAs and RVAs are formatted as hex inside their own markup tags.
    NO_ADDRESS renders as an empty string.

    Raises:
        ValueError: for an address of any other type.
    """
    if isinstance(ea, EffectiveAddress):
        return analysis.ppa(ea, interactive=True, resolve=False)

    # checked in this order, first isinstance hit wins
    templates = (
        (FileOffsetAddress, "[fa]#{:x}[/fa]"),
        (AbsoluteVirtualAddress, "[va]0x{:08x}[/va]"),
        (RelativeVirtualAddress, "[rva]@{:04x}[/rva]"),
    )
    for address_type, template in templates:
        if isinstance(ea, address_type):
            return template.format(ea)

    if ea == NO_ADDRESS:
        return ""
    raise ValueError("unknown address type: " + str(ea))

def render_locations(match):
    """Print the sorted locations of `match` inline after an ``@`` marker.

    At most four addresses are listed; any remainder is summarised as
    ", and N more...". Nothing is printed when there are no locations.
    """
    locations = sorted(match.locations)
    if not locations:
        return
    if len(locations) == 1:
        gui.print(" [color1]@[/color1] {}".format(render_address(locations[0])), format=True, end="")
        return

    gui.print(" [color1]@[/color1] ", end="", format=True)
    gui.print(", ".join(render_address(location) for location in locations[:4]), format=True, end="")
    hidden = len(locations) - 4
    if hidden > 0:
        gui.print(", and {} more...".format(hidden), end="")


def render_node(match, node, indent=0):
    """Dispatch a match-tree node to the statement or feature renderer.

    Raises:
        RuntimeError: if `node` is neither a capa Statement nor a Feature.
    """
    if isinstance(node, capa.engine.Statement):
        render_statement(match, node, indent=indent)
        return
    if isinstance(node, capa.features.common.Feature):
        render_feature(match, node, indent=indent)
        return
    raise RuntimeError("unexpected node type: " + str(node))


def render_match(match, indent=0, invert=False, ruleset=None, capabilities=None):
    """Recursively print a match-result tree.

    Args:
        match: result node exposing ``success``, ``statement``, ``children``
            and ``locations``.
        indent: current nesting depth of the output.
        invert: True while rendering beneath a ``not`` statement. Only nodes
            whose ``success`` differs from `invert` are printed.
        ruleset: `capa.rules.RuleSet` used to expand ``match: <rule>`` and
            namespace references into the referenced rules' own match trees.
            Defaults to no expansion.
        capabilities: mapping Rule -> list of (address, result) that supplies
            those referenced match trees. Defaults to empty.
    """
    # None instead of shared mutable {} defaults
    if ruleset is None:
        ruleset = {}
    if capabilities is None:
        capabilities = {}

    if match.success == invert:
        return

    child_mode = invert
    statement = match.statement
    if isinstance(statement, capa.engine.Not):
        child_mode = not invert
    elif isinstance(statement, capa.engine.Some) and statement.count == 0:
        # "optional:" blocks are only shown when some child matched (or, under
        # a `not`, when none did)
        if any(child.success for child in match.children) == invert:
            return
    elif isinstance(statement, capa.features.common.MatchedRule):
        rule_name = statement.value
        if rule_name in ruleset:
            subrule = ruleset[rule_name]
            subrules = [subrule]
            if subrule.is_subscope_rule():
                statement = capa.engine.Subscope(subrule.meta.get("scope"), subrule.statement)
        else:
            # namespace reference: gather matches of every rule in it. Tolerate a
            # plain-dict ruleset (the default) and unknown namespaces instead of
            # raising AttributeError / KeyError.
            subrules = getattr(ruleset, "rules_by_namespace", {}).get(rule_name, [])

        children = []
        for subrule in subrules:
            rule_matches = {address: result for address, result in capabilities.get(subrule, [])}
            children.extend(rule_matches[location] for location in match.locations if location in rule_matches)
        if children:
            match = capa.features.common.Result(match.success, statement, children, match.locations)

    render_node(match, match.statement, indent=indent)
    for child in match.children:
        render_match(child, indent=indent + 1, invert=child_mode, ruleset=ruleset, capabilities=capabilities)




###################################################


def collect_metadata(analysis, rules):
    """Build a capa-style metadata dict describing the analysed sample.

    Args:
        analysis: the malcat analysis object; its ``architecture``, ``type``,
            ``entropy`` hashes, ``file.path`` and ``imagebase`` are read.
        rules: stored verbatim under ``["analysis"]["rules"]``.

    Returns:
        dict with ``timestamp``, ``argv``, ``sample``, ``analysis`` and
        ``version`` keys. ``analysis.layout``, ``feature_counts`` and
        ``library_functions`` are left empty.
    """
    if analysis.architecture == Architecture.X64:
        arch = "x86_64"
    elif analysis.architecture == Architecture.X86:
        arch = "x86"
    else:
        arch = "unknown arch"

    # named os_name so the local no longer shadows the imported `os` module
    if analysis.type == "PE":
        os_name = "windows"
    elif analysis.type == "ELF":
        os_name = "linux"
    else:
        os_name = "unknown os"

    return {
        "timestamp": datetime.datetime.now().isoformat(),
        "argv": [],
        "sample": {
            "md5": analysis.entropy.md5,
            "sha1": analysis.entropy.sha1,
            "sha256": analysis.entropy.sha256,
            "path": analysis.file.path,
        },
        "analysis": {
            "format": analysis.type,
            "arch": arch,
            "os": os_name,
            "extractor": "analysis",
            "rules": rules,
            "base_address": analysis.imagebase,
            "layout": {
                # this is updated after capabilities have been collected.
                # will look like:
                #
                # "functions": { 0x401000: { "matched_basic_blocks": [ 0x401000, 0x401005, ... ] }, ... }
            },
            # ignore these for now - not used by IDA plugin.
            "feature_counts": {
                "file": {},
                "functions": {},
            },
            "library_functions": {},
        },
        "version": "4.0.2",
    }



def find_instruction_capabilities(
    ruleset: RuleSet, extractor: FeatureExtractor, f: FunctionHandle, bb: BBHandle, insn: InsnHandle
) -> Tuple[FeatureSet, MatchResults]:
    """
    Match the instruction-scope rules against a single instruction.

    The feature set is built from the instruction's own features plus the
    global features. Every rule match is then recorded back into that set via
    capa.engine.index_rule_matches.

    returns: tuple containing (features for instruction, match results for instruction)
    """
    features = collections.defaultdict(set)  # type: FeatureSet
    feature_stream = itertools.chain(
        extractor.extract_insn_features(f, bb, insn), extractor.extract_global_features()
    )
    for feature, address in feature_stream:
        features[feature].add(address)

    _, matches = ruleset.match(Scope.INSTRUCTION, features, insn.address)

    for rule_name, results in matches.items():
        matched_rule = ruleset[rule_name]
        for address, _ in results:
            capa.engine.index_rule_matches(features, matched_rule, [address])

    return features, matches


def find_basic_block_capabilities(
    ruleset: RuleSet, extractor: FeatureExtractor, f: FunctionHandle, bb: BBHandle
) -> Tuple[FeatureSet, MatchResults, MatchResults]:
    """
    Match the basic-block-scope rules against one basic block.

    Each instruction of the block is matched first. Its features are merged
    into the block's feature set, which also receives the block's own features
    and the global features. Block-level matches are recorded back into that
    set via capa.engine.index_rule_matches.

    returns: tuple containing (features for basic block, match results for basic block, match results for instructions)
    """
    block_features = collections.defaultdict(set)  # type: FeatureSet
    # instruction-scope matches; several instructions may match the same rule
    instruction_matches = collections.defaultdict(list)  # type: MatchResults

    for insn in extractor.get_instructions(f, bb):
        insn_features, insn_results = find_instruction_capabilities(ruleset, extractor, f, bb, insn)
        for feature, addresses in insn_features.items():
            block_features[feature].update(addresses)
        for rule_name, results in insn_results.items():
            instruction_matches[rule_name].extend(results)

    feature_stream = itertools.chain(
        extractor.extract_basic_block_features(f, bb), extractor.extract_global_features()
    )
    for feature, address in feature_stream:
        block_features[feature].add(address)

    _, block_matches = ruleset.match(Scope.BASIC_BLOCK, block_features, bb.address)

    for rule_name, results in block_matches.items():
        matched_rule = ruleset[rule_name]
        for address, _ in results:
            capa.engine.index_rule_matches(block_features, matched_rule, [address])

    return block_features, block_matches, instruction_matches


def find_code_capabilities(
    ruleset: RuleSet, extractor: FeatureExtractor, fh: FunctionHandle
) -> Tuple[MatchResults, MatchResults, MatchResults, int]:
    """
    Match the function-scope rules against one function.

    All basic blocks (and through them all instructions) are matched first.
    Their features are merged into the function's feature set, together with
    the function's own features and the global features.

    returns: tuple containing (match results for function, match results for basic blocks, match results for instructions, number of features)
    """
    function_features = collections.defaultdict(set)  # type: FeatureSet
    # lower-scope matches; the same rule may match in several blocks / instructions
    bb_matches = collections.defaultdict(list)  # type: MatchResults
    insn_matches = collections.defaultdict(list)  # type: MatchResults

    for bb in extractor.get_basic_blocks(fh):
        block_features, block_results, insn_results = find_basic_block_capabilities(ruleset, extractor, fh, bb)
        for feature, addresses in block_features.items():
            function_features[feature].update(addresses)
        for accumulator, produced in ((bb_matches, block_results), (insn_matches, insn_results)):
            for rule_name, results in produced.items():
                accumulator[rule_name].extend(results)

    feature_stream = itertools.chain(extractor.extract_function_features(fh), extractor.extract_global_features())
    for feature, address in feature_stream:
        function_features[feature].add(address)

    _, function_matches = ruleset.match(Scope.FUNCTION, function_features, fh.address)
    return function_matches, bb_matches, insn_matches, len(function_features)


def find_file_capabilities(ruleset: RuleSet, extractor: FeatureExtractor, function_features: FeatureSet):
    """Match the file-scope rules.

    File and global features are combined with `function_features` (the
    already-indexed lower-scope rule matches) before matching at NO_ADDRESS.

    returns: tuple containing (match results for file scope, number of features)
    """
    file_features = collections.defaultdict(set)  # type: FeatureSet

    feature_stream = itertools.chain(extractor.extract_file_features(), extractor.extract_global_features())
    for feature, address in feature_stream:
        if not address:
            # a feature without a (truthy) address still gets an index entry,
            # just with an empty address set
            file_features.setdefault(feature, set())
            continue
        file_features[feature].add(address)

    file_features.update(function_features)

    _, matches = ruleset.match(Scope.FILE, file_features, NO_ADDRESS)
    return matches, len(file_features)


def find_capabilities(ruleset: RuleSet, extractor: FeatureExtractor) -> Tuple[MatchResults, Any]:
    """Run the rules of `ruleset` over every function exposed by `extractor`, then at file scope.

    Functions for which ``extractor.is_library_function`` is true are skipped
    and recorded in ``meta["library_functions"]``. Instruction, basic-block and
    function matches are collected first. Their rule matches are then indexed
    as features and handed to the file-scope matching. The malcat progress bar
    advances from 10 towards 90 while functions are processed.

    returns: tuple containing (rule name -> match results across all scopes,
        meta dict with ``feature_counts`` and ``library_functions``)
    """
    all_function_matches = collections.defaultdict(list)  # type: MatchResults
    all_bb_matches = collections.defaultdict(list)  # type: MatchResults
    all_insn_matches = collections.defaultdict(list)  # type: MatchResults

    meta = {
        "feature_counts": {
            "file": 0,
            "functions": {},
        },
        "library_functions": {},
    }  # type: Dict[str, Any]

    functions = list(extractor.get_functions())
    n_funcs = len(functions)

    for i, f in enumerate(functions):
        gui.progress(10 + (80 * i) // n_funcs)
        if extractor.is_library_function(f.address):
            function_name = extractor.get_function_name(f.address)
            # print() does not do %-interpolation: the old call printed the raw
            # template followed by the arguments.
            print("skipping library function {} ({})".format(f.address, function_name))
            meta["library_functions"][f.address] = function_name
            continue

        function_matches, bb_matches, insn_matches, feature_count = find_code_capabilities(ruleset, extractor, f)
        meta["feature_counts"]["functions"][f.address] = feature_count

        for rule_name, res in function_matches.items():
            all_function_matches[rule_name].extend(res)
        for rule_name, res in bb_matches.items():
            all_bb_matches[rule_name].extend(res)
        for rule_name, res in insn_matches.items():
            all_insn_matches[rule_name].extend(res)

    # collection of features that captures the rule matches within function, BB, and instruction scopes.
    # mapping from feature (matched rule) to set of addresses at which it matched.
    function_and_lower_features: FeatureSet = collections.defaultdict(set)
    for rule_name, results in itertools.chain(
        all_function_matches.items(), all_bb_matches.items(), all_insn_matches.items()
    ):
        locations = {address for address, _ in results}
        capa.engine.index_rule_matches(function_and_lower_features, ruleset[rule_name], locations)

    all_file_matches, feature_count = find_file_capabilities(ruleset, extractor, function_and_lower_features)
    meta["feature_counts"]["file"] = feature_count

    matches = {
        rule_name: results
        for rule_name, results in itertools.chain(
            # each rule exists in exactly one scope,
            # so there won't be any overlap among these following MatchResults,
            # and we can merge the dictionaries naively.
            all_insn_matches.items(),
            all_bb_matches.items(),
            all_function_matches.items(),
            all_file_matches.items(),
        )
    }

    return matches, meta




if __name__ == "__main__":
    # Script entry point: load the bundled and user rules, run capa over the
    # current malcat analysis, then print the ATT&CK, MBC and summary tables
    # followed by per-rule match details.
    gui.print("[error]{}[/error]".format("CAPA script".center(111)), format=True)
    datadir = os.path.join(analysis.env.datadir, "scripts", "capa", "all_rules.zip")
    userdir = os.path.join(analysis.env.userdir, "scripts", "capa", "all_rules")
    # the advertised archive name now matches `datadir` above (was "rules.zip")
    gui.print("""
CAPA framework by mandiant ([url]https://github.com/mandiant/capa[/url]) using malcat for [color1]analysis[/color1].
Everything except the malcat comes from the github repository.
Mandiant rules can be found in [color3]<analysis DATA DIR>/scripts/capa/all_rules.zip[/color3].
You can add your own as .yml files in [color3]<USER DATA DIR>/scripts/capa/all_rules/*.yml[/color3].
    """, format=True)

    malcat_extractor = capa.features.extractors.malcat.MalcatFeatureExtractor(analysis)
    gui.progress(2)

    # load rules: bundled archive first, then optional user rules
    rules = get_rules(datadir)
    if os.path.exists(userdir):
        rules += get_rules(userdir)
    ruleset = capa.rules.RuleSet(rules)
    gui.progress(10)

    capabilities, _meta = find_capabilities(ruleset, malcat_extractor)

    # map rule names to Rule objects, ordered by (namespace, name)
    capabilities = collections.OrderedDict(
        sorted(
            ((ruleset[name], results) for name, results in capabilities.items()),
            key=lambda item: (item[0].meta.get("namespace", ""), item[0].name),
        )
    )

    # display
    print()
    render_attack(capabilities)
    print("\n")
    render_mbc(capabilities)
    print("\n")
    render_rules_summary(capabilities)
    print("\n")

    # basic block address -> owning function address, for "basic block" scope output
    functions_by_bb = {}
    for f in malcat_extractor.get_functions():
        for bb in malcat_extractor.get_basic_blocks(f):
            functions_by_bb[bb.address] = f.address

    for rule, matches in capabilities.items():
        if not is_interesting_rule(rule):
            continue
        gui.print("[block2]{}[/block2]".format(rule.name.center(80).replace("[", "\\[")), format=True)
        gui.print("namespace:  [color4]{}[/color4]".format(rule.meta.get("namespace", "")), format=True)
        for key in ("att&ck", "mbc"):
            if key in rule.meta:
                gui.print("{:12s}{}".format(key + ":", ",".join(["[color3]{}[/color3]".format(x.replace("[", "\\[")) for x in rule.meta[key]])), format=True)
        gui.print("rule scope: [color3]{}[/color3]".format(rule.scope), format=True)
        authors = rule.meta.get("author", [])
        if isinstance(authors, str):
            authors = [authors]
        gui.print("author:     [color3]{}[/color3]".format(", ".join(authors)), format=True)

        if rule.scope == "file":
            if len(matches) == 1:
                render_match(matches[0][1], indent=0, ruleset=ruleset, capabilities=capabilities)
        else:
            for location, match in sorted(matches):
                if rule.scope == "basic block":
                    gui.print("{} [color1]@[/color1] {} in function [color1]@[/color1] {}".format(rule.scope, render_address(location), render_address(functions_by_bb[location])), format=True)
                else:
                    gui.print("{} [color1]@[/color1] {}".format(rule.scope, render_address(location)), format=True)
                render_match(match, indent=1, ruleset=ruleset, capabilities=capabilities)
        print()
