#!/usr/bin/env python3

import argparse
from collections import Counter, OrderedDict
from datetime import datetime

from Levenshtein import ratio
import os
import re
import sys

# Input paths are derived from this script's own location (symlinks
# resolved), so they do not depend on the current working directory.
script_dir = os.path.dirname(os.path.realpath(__file__))
# Parent of the script's directory; used as the project root for the
# baserom, asm and build paths below.
root_dir = script_dir + "/../"
# Root of the remaining (undecompiled) assembly.  The tool treats each .s
# file stem under here as a symbol name.
asm_dir = root_dir + "asm/us/nonmatchings/"
# Build output directory.  Not referenced elsewhere in this file; the map
# path in the __main__ block is built separately with os.path.join.
build_dir = root_dir + "build/"


def read_rom():
    """Return the full contents of ``baserom.us.z64`` in the project root as bytes."""
    rom_path = os.path.join(root_dir, "baserom.us.z64")
    with open(rom_path, "rb") as rom_file:
        data = rom_file.read()
    return data


def find_dir(query, base_dir=None):
    """Return the path of the first directory named *query* below *base_dir*.

    Args:
        query: bare directory name to look for (compared against each
            directory's own name, not a path).  A falsy value (None or "")
            can never match a directory name, so it returns None immediately
            instead of walking the whole tree.
        base_dir: root of the search; defaults to the module-level
            ``asm_dir``.

    Returns:
        ``os.path.join(root, query)`` for the first match in ``os.walk``
        order, or None when no directory has that name.
    """
    if not query:
        # Common when no positional query was given: skip the full tree walk.
        return None
    if base_dir is None:
        base_dir = asm_dir
    for root, dirs, _files in os.walk(base_dir):
        if query in dirs:
            return os.path.join(root, query)
    return None


def get_all_s_files():
    """Return the set of ``.s`` file stems found anywhere under ``asm_dir``.

    Only the file name (without its ``.s`` suffix) is kept; directory
    structure is discarded, so equal stems in different folders collapse.
    """
    return {
        name[:-2]
        for _root, _subdirs, names in os.walk(asm_dir)
        for name in names
        if name.endswith(".s")
    }


def get_symbol_length(sym_name, offsets=None):
    """Return the size in bytes of *sym_name*'s ROM range, or 0 if unknown.

    Args:
        sym_name: symbol to look up.
        offsets: mapping of symbol -> {"start": int, "end": int} as produced
            by get_map_offsets(); defaults to the module-level
            ``map_offsets``.

    Returns:
        ``end - start`` when both bounds are known.  Returns 0 when either
        bound is missing or the symbol is not in *offsets* at all.  Callers
        such as all_matches() pass .s file names that need not appear in the
        map, so an unknown name must not raise KeyError.
    """
    if offsets is None:
        offsets = map_offsets
    span = offsets.get(sym_name, {})
    if "start" in span and "end" in span:
        return span["end"] - span["start"]
    return 0


def get_symbol_bytes(offsets, func):
    """Build a comparable fingerprint of *func*'s bytes in ``rom_bytes``.

    The [start, end) range recorded for *func* in *offsets* is sliced from
    ``rom_bytes`` and trailing zero bytes are dropped.  From every 4-byte
    word the first byte is taken and shifted right by 2, keeping its top 6
    bits.  This is presumably the MIPS primary opcode field, given a
    big-endian .z64 ROM.  Those values are decoded as a string.

    Returns:
        ``(opcode_str, raw)`` where ``raw`` is the trimmed range as a list
        of ints, or None when *func* lacks a complete start/end range.
    """
    if func not in offsets:
        return None
    span = offsets[func]
    if "start" not in span or "end" not in span:
        return None

    # Trailing zero bytes (presumably padding / trailing nops) are ignored.
    trimmed = rom_bytes[span["start"] : span["end"]].rstrip(b"\x00")
    top_bits = bytes(word_head >> 2 for word_head in trimmed[::4])
    return top_bits.decode("utf-8"), list(trimmed)


def parse_map(fname: str) -> dict:
    """Parse a linker map file into a symbol table.

    Symbols are only recorded after a section header line containing
    "load address".  That header supplies the RAM/ROM address pair used to
    translate each following symbol's RAM address into a ROM offset.

    Args:
        fname: path to the map file.

    Returns:
        dict mapping symbol name -> ``(rom, cur_file, prev_sym, ram)``:
        ``rom`` is the computed ROM offset, and ``cur_file`` is the most
        recent path-like entry (last token containing "/") seen before the
        symbol, or "<no file>".  ``prev_sym`` is the previously recorded
        symbol (None for the first one), and ``ram`` is the address read
        from the symbol's line.
    """
    ram_offset = None  # ram - rom for the current section; None = not translating
    cur_file = "<no file>"
    syms = {}
    prev_sym = None
    prev_line = ""
    with open(fname) as f:
        for line in f:
            if "load address" in line:
                # "noload" may appear on the header itself or on the line just
                # before it; such sections are skipped entirely.
                if "noload" in line or "noload" in prev_line:
                    ram_offset = None
                    continue
                # Fixed-column parse: 18 characters starting at (0-based)
                # column 16 hold the RAM address, and 18 characters starting
                # at column 59 hold the load (ROM) address.  int(..., 0)
                # accepts the 0x prefix.
                ram = int(line[16 : 16 + 18], 0)
                rom = int(line[59 : 59 + 18], 0)
                ram_offset = ram - rom
                continue
            # NOTE: prev_line is only updated for non-header lines.
            prev_line = line

            # Skip lines outside a translatable section, lines containing "="
            # (presumably linker assignments), padding, and lines without an
            # address.
            if ram_offset is None or "=" in line or "*fill*" in line or " 0x" not in line:
                continue
            ram = int(line[16 : 16 + 18], 0)
            rom = ram - ram_offset
            fn = line.split()[-1]
            if "0x" in fn:
                # Last token is a number rather than a name: stop translating
                # until the next "load address" header.
                ram_offset = None
            elif "/" in fn:
                # A path: remember it as the owner of the following symbols.
                cur_file = fn
            else:
                syms[fn] = (rom, cur_file, prev_sym, ram)
                prev_sym = fn
    return syms


def get_map_offsets(syms):
    """Convert parse_map() output into per-symbol ROM ranges.

    A symbol's "start" is its own ROM offset and its "end" is the ROM offset
    of the symbol recorded after it.  parse_map() gives the first symbol a
    predecessor of None, so a ``None`` key holding only an "end" appears.
    The last symbol never receives an "end".

    Returns:
        dict mapping symbol name -> {"start": int, "end": int}, where either
        key may be absent.
    """
    ranges = {}
    for name, info in syms.items():
        rom_start = info[0]
        predecessor = info[2]
        ranges.setdefault(name, {})
        ranges.setdefault(predecessor, {})
        ranges[name]["start"] = rom_start
        ranges[predecessor]["end"] = rom_start
    return ranges


def is_zeros(vals):
    """Return True when every element of *vals* is zero (True for empty input)."""
    return not any(value != 0 for value in vals)


def diff_syms(qb, tb):
    """Score the similarity of two get_symbol_bytes() results.

    Both arguments are ``(opcode_str, raw_bytes)`` pairs.  The score is the
    Levenshtein ratio of the opcode strings.  If the opcode strings are
    identical but the raw bytes differ, the score is 0.99 instead of 1.0.
    Returns 0 when the target has fewer than 8 raw bytes, or when the
    length difference alone makes ``args.threshold`` unreachable.
    """
    if len(tb[1]) < 8:
        return 0

    query_ops = qb[0]
    target_ops = tb[0]
    # At least |l1 - l2| edits are unavoidable for strings of different
    # lengths, which bounds the best achievable ratio; bail out early when
    # that bound cannot reach the threshold.
    length_gap = abs(len(query_ops) - len(target_ops))
    if length_gap / (len(query_ops) + len(target_ops)) > 1.0 - args.threshold:
        return 0

    score = ratio(query_ops, target_ops)
    if score != 1.0 or qb[1] == tb[1]:
        return score
    # Same opcode sequence but different raw bytes: near-identical, not exact.
    return 0.99


def get_pair_score(query_bytes, b):
    """Return the diff_syms() score of symbol *b* against *query_bytes*.

    Returns 0 when either side has no fingerprint (None), e.g. because *b*
    has no complete ROM range in ``map_offsets``.
    """
    candidate = get_symbol_bytes(map_offsets, b)
    if not (query_bytes and candidate):
        return 0
    return diff_syms(query_bytes, candidate)


def get_matches(query):
    """Score every mapped symbol against *query* and keep the good ones.

    Exits the program (``sys.exit``) when *query* has no complete ROM range
    in ``map_offsets``.

    Returns:
        OrderedDict of symbol -> score for scores >= ``args.threshold``,
        best first.  Ties keep ``map_offsets`` order because the sort is
        stable.
    """
    query_bytes = get_symbol_bytes(map_offsets, query)
    if query_bytes is None:
        sys.exit(f"Symbol '{query}' not found")

    scored = {}
    for candidate in map_offsets:
        # The None key only carries the first symbol's "end"; skip it and
        # the query itself.
        if candidate is None or candidate == query:
            continue
        score = get_pair_score(query_bytes, candidate)
        if score >= args.threshold:
            scored[candidate] = score

    ranked = sorted(scored.items(), key=lambda pair: pair[1], reverse=True)
    return OrderedDict(ranked)


def do_query(query):
    """Print the duplicate candidates for a single function.

    At most ``args.num_out`` matches are printed, highest score first.
    Symbols without a remaining .s file are tagged "(decompiled)".  A blank
    line follows the list.
    """
    matches = get_matches(query)
    total = len(matches)

    if not total:
        print(f"{query} - found no matches")
        return

    if args.num_out < total:
        suffix = f" (showing only {args.num_out}):"
    else:
        suffix = ":"
    print(f"{query} - found {total} matches total{suffix}")

    for shown, (name, score) in enumerate(matches.items()):
        if shown == args.num_out:
            break
        line = f"{score:.3f} - {name}"
        if name not in s_files:
            line += " (decompiled)"
        print(line)
    print()


def all_matches(all_funcs_flag):
    """Find duplicates for every undecompiled function and write a report.

    Functions are taken from ``s_files`` (stems of the remaining .s files)
    in sorted order.  Each one is scored against the whole map with
    get_matches().  Matches that are still pending are removed from the
    work list, so a duplicate group is normally reported once, under the
    first member processed.  The results are written by
    output_match_dict().

    Args:
        all_funcs_flag: when true, process every function.  When false
            (the --short mode), stop once no more than half of ``s_files``
            are still pending.
    """
    match_dict = {}
    # Sorted for a reproducible order: iterating a set of str varies between
    # runs under hash randomization.  That changed which member of a
    # duplicate group was reported, and the resulting counts.
    to_match_files = sorted(s_files)

    # Assumption: after half the functions have been matched, nothing of
    # significance is left, since duplicates that have already been
    # discovered are removed from to_match_files along the way.
    iter_limit = 0 if all_funcs_flag else len(s_files) / 2

    num_decomped_dupes = 0
    num_undecomped_dupes = 0
    num_perfect_dupes = 0

    i = 0  # functions processed, plus functions removed as someone's match
    while len(to_match_files) > iter_limit:
        # get_matches() never returns the query itself, so popping it up
        # front cannot interfere with the match removal below.
        file = to_match_files.pop(0)

        i += 1
        print(
            "File matching progress: {:%}".format(i / (len(s_files) - iter_limit)),
            end="\r",
        )

        # Skip .s names that have no entry in the map (get_symbol_length
        # would raise KeyError on them) and very small functions.
        if file not in map_offsets or get_symbol_length(file) < 16:
            continue

        matches = get_matches(file)
        if not matches:
            continue

        # The queried function is itself an undecompiled duplicate.
        num_undecomped_dupes += 1

        match_list = []
        for match, score in matches.items():
            if match in to_match_files:
                # Already grouped with this function; don't query it again.
                i += 1
                to_match_files.remove(match)

            if score >= 0.995:
                num_perfect_dupes += 1

            match_str = "{:.2f} - {}".format(score, match)
            if match not in s_files:
                num_decomped_dupes += 1
                match_str += " (decompiled)"
            else:
                num_undecomped_dupes += 1

            match_list.append(match_str)

        match_dict[file] = (len(matches), match_list)

    output_match_dict(match_dict, num_decomped_dupes, num_undecomped_dupes, num_perfect_dupes, i)


def output_match_dict(
    match_dict,
    num_decomped_dupes,
    num_undecomped_dupes,
    num_perfect_dupes,
    num_checked_files,
):
    """Write the all-duplicates report to a timestamped file in the CWD.

    Args:
        match_dict: function name -> (number of matches, list of formatted
            "score - symbol" lines), as built by all_matches().
        num_decomped_dupes: summary counter written into the header.
        num_undecomped_dupes: summary counter written into the header.
        num_perfect_dupes: summary counter written into the header.
        num_checked_files: progress counter from all_matches(); rounded for
            display.

    Functions are listed in descending order of match count.
    """
    out_name = datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + "_all_matches.txt"

    # The with-block guarantees the report is flushed and closed even if a
    # write fails part-way through.
    with open(out_name, "w+") as out_file:
        out_file.write(
            "Number of s-files: " + str(len(s_files)) + "\n"
            "Number of checked s-files: " + str(round(num_checked_files)) + "\n"
            "Number of decompiled duplicates found: " + str(num_decomped_dupes) + "\n"
            "Number of undecompiled duplicates found: " + str(num_undecomped_dupes) + "\n"
            "Number of overall exact duplicates found: " + str(num_perfect_dupes) + "\n\n"
        )

        sorted_dict = OrderedDict(sorted(match_dict.items(), key=lambda item: item[1][0], reverse=True))

        print("Creating output file: " + out_file.name, end="\n")
        for file_name, matches in sorted_dict.items():
            out_file.write(file_name + " - found " + str(matches[0]) + " matches total:\n")
            out_file.writelines(match + "\n" for match in matches[1])
            out_file.write("\n")


def is_decompiled(sym):
    """Return True when *sym* has no remaining .s file (treated as decompiled)."""
    has_asm = sym in s_files
    return not has_asm


def do_cross_query():
    """Greedily cluster similar symbols across the whole map and print a summary.

    Candidates are the names in ``map_syms`` that pass the name filter below
    and whose ROM range is longer than 16 bytes.  Each candidate is compared
    with the head (element 0) of every existing cluster, in creation order.
    It joins the first cluster whose diff_syms() score reaches
    ``args.threshold``; otherwise it starts a new cluster.

    ``ccount`` is keyed by cluster head.  Each time a symbol joins a cluster
    whose head is not decompiled, the head's opcode-string length is added.
    When a decompiled symbol joins a cluster with an undecompiled head, it
    becomes the new head and takes over the old head's count, which then
    stops growing.  The 100 largest counts are printed at the end.
    """
    ccount = Counter()
    clusters = []

    # sym_name -> (opcode_str, raw_bytes) as returned by get_symbol_bytes().
    sym_bytes = {}
    for sym_name in map_syms:
        # Name-based filter; presumably excludes data (D_), embedded binaries
        # (_binary), jump tables (jtbl_) and local labels. TODO confirm.
        if (
            not sym_name.startswith("D_")
            and not sym_name.startswith("_binary")
            and not sym_name.startswith("jtbl_")
            and not re.match(r"L[0-9A-F]{8}_[0-9A-F]{5,6}", sym_name)
        ):
            if get_symbol_length(sym_name) > 16:
                sym_bytes[sym_name] = get_symbol_bytes(map_offsets, sym_name)

    for sym_name, query_bytes in sym_bytes.items():
        cluster_match = False
        for cluster in clusters:
            cluster_first = cluster[0]
            cluster_score = diff_syms(query_bytes, sym_bytes[cluster_first])
            if cluster_score >= args.threshold:
                cluster_match = True
                if is_decompiled(sym_name) and not is_decompiled(cluster_first):
                    # Prefer a decompiled function as the cluster head and move
                    # the accumulated count over to it.
                    ccount[sym_name] = ccount[cluster_first]
                    # Counter.__delitem__ does not raise for a missing key.
                    del ccount[cluster_first]
                    cluster_first = sym_name
                    cluster.insert(0, cluster_first)
                else:
                    cluster.append(sym_name)

                # cluster_first may have just been replaced above.
                if not is_decompiled(cluster_first):
                    ccount[cluster_first] += len(sym_bytes[cluster_first][0])

                # Progress report every 10 members.
                if len(cluster) % 10 == 0 and len(cluster) >= 10:
                    print(f"Cluster {cluster_first} grew to size {len(cluster)} - {sym_name}: {str(cluster_score)}")
                # A symbol joins at most one cluster.
                break
        if not cluster_match:
            clusters.append([sym_name])
    print(ccount.most_common(100))


# Command-line interface.  Parsed at import time; the resulting ``args`` is
# read as a module-level global by the functions above.
parser = argparse.ArgumentParser(
    description=(
        "Tool to find duplicates for a specific function or to find all duplicates across the codebase."
    ),
)

# Bulk modes: at most one of --all / --cross / --short may be given.
group = parser.add_mutually_exclusive_group()
group.add_argument(
    "-a", "--all",
    action="store_true",
    required=False,
    help="find ALL duplicates and output them into a file",
)
group.add_argument(
    "-c", "--cross",
    action="store_true",
    required=False,
    help="do a cross query over the codebase",
)
group.add_argument(
    "-s", "--short",
    action="store_true",
    required=False,
    help="find MOST duplicates besides some very small duplicates. Cuts the runtime in half with minimal loss",
)

# Optional positional: a function name, or a directory name under asm_dir.
parser.add_argument("query", nargs="?", default=None, help="function or file")
parser.add_argument(
    "-t", "--threshold",
    type=float,
    default=0.9,
    required=False,
    help="score threshold between 0 and 1 (higher is more restrictive)",
)
parser.add_argument(
    "-n", "--num-out",
    type=int,
    default=100,
    required=False,
    help="number of functions to display",
)

args = parser.parse_args()

if __name__ == "__main__":
    # The helper functions above read these as module-level globals.
    rom_bytes = read_rom()
    map_syms = parse_map(os.path.join(root_dir, "build", "starfox64.us.map"))
    map_offsets = get_map_offsets(map_syms)

    s_files = get_all_s_files()

    # A query naming a directory under asm_dir means "every function in it".
    query_dir = find_dir(args.query)

    if query_dir is not None:
        # Only .s files correspond to function symbols.  Any other entry
        # (sub-directories, hidden/metadata files) would be sliced into a
        # bogus symbol name, and get_matches() would abort the whole run
        # via sys.exit.  Sorted because os.listdir order is arbitrary.
        for f_name in sorted(os.listdir(query_dir)):
            if f_name.endswith(".s"):
                do_query(f_name[:-2])
    else:
        # NOTE: the bulk modes below override any -t/--threshold value.
        if args.cross:
            args.threshold = 0.985
            do_cross_query()
        elif args.all:
            args.threshold = 0.985
            all_matches(True)
        elif args.short:
            args.threshold = 0.985
            all_matches(False)
        else:
            if args.query is None:
                parser.print_help()
            else:
                do_query(args.query)