import logging
import re
from typing import Optional, List
from dataclasses import asdict
from enum import Enum

import pandas as pd
import nsys_cpu_stats.trace_utils as tu
from nsys_cpu_stats.trace_utils import FrameDurations, CallStack, TimeSlice, GPUMetric, CPUConfig

from nsys_cpu_stats.trace_loader import TraceLoaderInterface, TraceLoaderSupport, TraceLoaderRegions, TraceLoaderGPUMetrics, TraceLoaderEvents
from nsys_cpu_stats.nsysrep_loader import NSysRepLoader

logger = logging.getLogger(__name__)


class TraceLoaderType(Enum):
    """Enumerates the trace loader backends a TraceProcessor can register."""

    NSysRep_Loader = 'NSys-Rep Loader'

    def __str__(self) -> str:
        """Render the member as its human-readable label (the enum value)."""
        label = self.value
        return label


class TraceProcessor:
    """Front-end over one or more trace loaders.

    Most queries are forwarded to the loader selected with ``set_current_loader``.
    The remaining methods post-process the returned timeslices (filtering, thread
    utilisation, target-process detection) and collect pandas DataFrames in
    ``df_dict``.
    """

    # Thread names matching this pattern are skipped when ``ignore_nsys`` is set
    # ("ignore NSys threads").
    _NSYS_THREAD_PATTERN = re.compile(r"\[NSys")

    # Process names (compared lower-cased) that find_target_pid never selects
    # via the graphics/compute checks.
    _IGNORED_TARGET_PROCESSES = ("dwm.exe", "explorer.exe", "system")

    def __init__(self):
        # Loader instances keyed by TraceLoaderType; populated by init_loader().
        self.loaders: dict = {}
        # gtid -> thread name and pid -> process name; filled by init_common().
        self.thread_name_dict: Optional[dict] = None
        self.process_name_dict: Optional[dict] = None
        # DataFrames keyed by get_dataframe_key(<key string>).
        self.df_dict: dict = {}
        # Result of the last get_all_timeslices() call.
        self.all_timeslice_list: Optional[list] = None
        self.verbose = True
        self.current_loader: Optional[TraceLoaderType] = None

    def get_loader(self, loader_type: TraceLoaderType) -> Optional[TraceLoaderInterface]:
        """Return the loader registered for ``loader_type``, or None if it was never initialised."""
        return self.loaders.get(loader_type)

    def set_current_loader(self, loader_type: TraceLoaderType):
        """Select the loader that subsequent queries are forwarded to."""
        self.current_loader = loader_type

    def get_current_loader(self) -> TraceLoaderInterface:
        """Return the currently selected loader.

        Raises:
            KeyError: if the selected loader type has not been initialised.
        """
        return self.loaders[self.current_loader]

    def is_supported(self, support: TraceLoaderSupport):
        """Query whether the current loader supports the given feature."""
        return self.get_current_loader().is_supported(support)

    def is_region_supported(self, support: TraceLoaderRegions):
        """Query whether the current loader supports the given region type."""
        return self.get_current_loader().is_region_supported(support)

    def is_gpu_metric_supported(self, support: TraceLoaderGPUMetrics):
        """Query whether the current loader supports the given GPU metric."""
        return self.get_current_loader().is_gpu_metric_supported(support)

    def get_string(self, string_id) -> str:
        """Return the string for the given string ID."""
        return self.get_current_loader().get_string(string_id)

    def get_module_string(self, string_id) -> str:
        """Return the module string for the given string ID."""
        return self.get_current_loader().get_module_string(string_id)

    def get_gpu_metric_name(self, metric: TraceLoaderGPUMetrics) -> str:
        """Return the loader-specific GPU metric name for the given metric type."""
        return self.get_current_loader().get_gpu_metric_name(metric)

    ####################################################
    #
    # Initialise the NSys loader
    #
    ####################################################
    def init_loader(self, loader_type: TraceLoaderType) -> Optional[TraceLoaderInterface]:
        """Create and register a loader of the given type.

        Returns:
            The new loader, or None if ``loader_type`` is not handled.
        """
        if loader_type == TraceLoaderType.NSysRep_Loader:
            self.loaders[TraceLoaderType.NSysRep_Loader] = NSysRepLoader()
            return self.loaders[TraceLoaderType.NSysRep_Loader]
        return None

    ####################################################
    #
    # Common initialisation the loader
    #
    ####################################################
    def init_common(self):
        """Fetch thread/process name lookups from the current loader and let it determine support."""
        assert self.current_loader, "set_current_loader() must be called before init_common()"
        loader = self.get_current_loader()
        self.thread_name_dict = loader.init_thread_name_dict()
        self.process_name_dict = loader.init_process_name_dict()
        loader.determine_support()

    ####################################################
    #
    # Get timeslices from the DB
    #
    ####################################################
    def get_timeslices(self,
                       start_time_ns: Optional[float] = None,
                       end_time_ns: Optional[float] = None,
                       target_pid: Optional[int] = None,
                       quiet: Optional[bool] = False) -> List[TimeSlice]:
        """Get the timeslices from the current loader."""
        return self.get_current_loader().get_timeslices(start_time_ns, end_time_ns, target_pid, quiet)

    ####################################################
    #
    # Get the callstacks
    #
    ####################################################
    def get_callstacks(self,
                       start_time_ns: float,
                       end_time_ns: float,
                       target_pid: int,
                       target_tid: Optional[int] = None) -> List[CallStack]:
        """Get the callstacks from the current loader."""
        return self.get_current_loader().get_callstacks(start_time_ns, end_time_ns, target_pid, target_tid)

    ####################################################
    #
    # Filter the callstacks
    #
    ####################################################
    @staticmethod
    def filter_callstacks(target_tid: int,
                          callstack_sample_list: List[CallStack]) -> List[CallStack]:
        """Return the callstack samples whose ``tid`` equals ``target_tid``, preserving order."""
        return [ccs for ccs in callstack_sample_list if ccs.tid == target_tid]

    ####################################################
    #
    # Get region durations
    #
    ####################################################
    def get_region_durations(self,
                             region_type: TraceLoaderRegions,
                             start_time_ns: Optional[float] = None,
                             end_time_ns: Optional[float] = None,
                             target_pid: Optional[int] = None) -> tuple[float, list[FrameDurations]]:
        """Get the region durations, if supported. Returns the average duration and the list of durations."""
        return self.get_current_loader().get_region_durations(region_type, start_time_ns, end_time_ns, target_pid)

    def get_derived_region_durations(self,
                                     region_type: TraceLoaderRegions,
                                     base_durations: list[FrameDurations],
                                     start_time_ns: Optional[float] = None,
                                     end_time_ns: Optional[float] = None,
                                     target_pid: Optional[int] = None,
                                     ) -> list[FrameDurations]:
        """Get region durations derived from a base frametime list. Returns the list of durations."""
        return self.get_current_loader().get_derived_region_durations(region_type, base_durations, start_time_ns, end_time_ns, target_pid)

    ####################################################
    #
    # Get the CPU core count
    #
    ####################################################
    def get_core_count(self) -> int:
        """Get the CPU core count."""
        return self.get_current_loader().get_core_count()

    ####################################################
    #
    # Get the CPU config
    #
    ####################################################
    def get_cpu_config(self) -> CPUConfig:
        """Get the CPU configuration."""
        return self.get_current_loader().get_cpu_config()

    ####################################################
    #
    # Get the Analysis duration
    #
    ####################################################
    def get_analysis_duration(self) -> int:
        """Get the duration of the analysis."""
        return self.get_current_loader().get_analysis_duration()

    ####################################################
    #
    # Get the Average GPU metric
    #
    ####################################################
    def get_average_gpu_metrics(self,
                                metric: TraceLoaderGPUMetrics,
                                start_time_ns: Optional[float] = None,
                                end_time_ns: Optional[float] = None) -> Optional[float]:
        """Get the average value for the provided GPU metric and time range."""
        return self.get_current_loader().get_average_gpu_metrics(metric, start_time_ns, end_time_ns)

    ####################################################
    #
    # Get the thread name
    #
    ####################################################
    def get_thread_name(self, gtid):
        """Return "<name> [<tid>]" for a gtid with a known name, otherwise "tid_<tid>"."""
        _, tid = tu.convert_global_tid(gtid)
        # Tolerate lookups before init_common() has populated the dict.
        names = self.thread_name_dict or {}
        if gtid in names:
            return names[gtid] + " [" + str(tid) + "]"
        return "tid_" + str(tid)

    def get_process_name(self, pid):
        """Return the name of the process with the given PID, or "pid_<pid>" if unknown."""
        names = self.process_name_dict or {}
        if pid in names:
            return names[pid]
        return "pid_" + str(pid)

    def get_process_pid(self, name):
        """Return the PID of the first process whose name matches ``name`` case-insensitively, else None."""
        name_lower = name.lower()
        for pid, process_name in (self.process_name_dict or {}).items():
            if process_name.lower() == name_lower:
                return pid
        return None

    def clear_dataframes(self):
        """Discard all collected DataFrames."""
        self.df_dict = {}

    def get_all_timeslices(self,
                           target_pid: Optional[int] = None,
                           quiet: Optional[bool] = False) -> List[tu.TimeSlice]:
        """Get all timeslices over the entire time range (optionally for one PID) and cache them."""
        self.all_timeslice_list = self.get_current_loader().get_timeslices(target_pid=target_pid, quiet=quiet)
        return self.all_timeslice_list

    ####################################################
    #
    # Filter based on time or tid.
    # This will correctly detect timeslices who have any part
    # in the given time period.
    #
    ####################################################
    def filter_timeslices(self,
                          input_timeslice_list: List[tu.TimeSlice],
                          start_time_ns: Optional[float] = None,
                          end_time_ns: Optional[float] = None,
                          target_pid: Optional[int] = None,
                          target_tid: Optional[int] = None,
                          ) -> List[tu.TimeSlice]:
        """Filter timeslices by time window and/or pid/tid, trimming them to the window.

        Args:
            input_timeslice_list: Timeslices to filter; the list and its items are not modified.
            start_time_ns: Optional window start; slices ending at or before it are dropped.
            end_time_ns: Optional window end; slices starting at or after it are dropped.
            target_pid: Optional PID to keep.
            target_tid: Optional TID to keep.

        Returns:
            A new list. When a time bound is given, each kept slice is a NEW TimeSlice
            clamped to the window.
        """
        output_timeslice_list = list(input_timeslice_list)

        # Apply each bound on its own so a half-open window (only start or only end)
        # still drops slices wholly outside it; otherwise the trimming below would
        # produce slices whose end precedes their start.
        if start_time_ns is not None:
            output_timeslice_list = [ts for ts in output_timeslice_list if ts.end > start_time_ns]
        if end_time_ns is not None:
            output_timeslice_list = [ts for ts in output_timeslice_list if ts.start < end_time_ns]

        if target_pid and target_tid:
            gtid = tu.get_gtid(target_pid, target_tid)
            output_timeslice_list = [ts for ts in output_timeslice_list if tu.compare_gtid(ts.gtid, gtid)]
        elif target_pid or target_tid:
            # Filter on whichever of pid/tid was supplied.
            if target_pid is not None:
                output_timeslice_list = [ts for ts in output_timeslice_list if target_pid == tu.get_pid(ts.gtid)]
            if target_tid is not None:
                output_timeslice_list = [ts for ts in output_timeslice_list if target_tid == tu.get_tid(ts.gtid)]

        # Trim the timeslices to the bounds of the region.
        # When trimming, make NEW timeslices: modifying the originals would
        # affect every other list that contains them.
        if start_time_ns is not None or end_time_ns is not None:
            trimmed = []
            for ts in output_timeslice_list:
                start = ts.start if start_time_ns is None else max(ts.start, start_time_ns)
                end = ts.end if end_time_ns is None else min(ts.end, end_time_ns)
                trimmed.append(tu.TimeSlice(start, end, ts.cpu, ts.gtid))
            output_timeslice_list = trimmed
        return output_timeslice_list

    ####################################################
    #
    # Thread utilisation helpers
    #
    ####################################################
    @staticmethod
    def _accumulate_thread_times(timeslice_list) -> tuple[float, dict, dict]:
        """Sum timeslice durations per gtid and per (gtid, cpu).

        Returns:
            (total_time, thread_time_dict, cpu_thread_time_dict). total_time is the
            span from the earliest start to the latest end, or 0 for an empty list.
        """
        first_start = None
        last_end = None
        thread_time_dict = {}
        cpu_thread_time_dict = {}
        for ts in timeslice_list:
            # None sentinels (not 0) so that a slice starting at t=0 is honoured.
            first_start = ts.start if first_start is None else min(first_start, ts.start)
            last_end = ts.end if last_end is None else max(last_end, ts.end)
            duration = ts.end - ts.start
            thread_time_dict[ts.gtid] = thread_time_dict.get(ts.gtid, 0) + duration
            key = (ts.gtid, ts.cpu)
            cpu_thread_time_dict[key] = cpu_thread_time_dict.get(key, 0) + duration
        total_time = 0 if first_start is None else last_end - first_start
        return total_time, thread_time_dict, cpu_thread_time_dict

    def _is_nsys_thread(self, gtid) -> bool:
        """Return True if ``gtid`` has a known thread name matching ``[NSys``."""
        names = self.thread_name_dict or {}
        return gtid in names and bool(self._NSYS_THREAD_PATTERN.search(names[gtid]))

    ####################################################
    #
    # Get thread utilisation
    #
    ####################################################
    def get_thread_utilisation(self, timeslice_list: List[tu.TimeSlice], ignore_nsys: bool) -> tuple[float, dict]:
        """Calculate the utilisation of each thread over the span covered by ``timeslice_list``.

        Returns:
            (total_time, {gtid: busy_time / total_time}). The dict is empty when total_time <= 0.
            With ``ignore_nsys``, threads whose name matches ``[NSys`` are omitted.
        """
        total_time, thread_time_dict, _ = self._accumulate_thread_times(timeslice_list)

        thread_utilisation_dict = {}
        if total_time > 0:
            thread_utilisation_dict = {
                gtid: thread_time / total_time
                for gtid, thread_time in thread_time_dict.items()
                if not (ignore_nsys and self._is_nsys_thread(gtid))
            }
        return total_time, thread_utilisation_dict

    ####################################################
    #
    # Get pcore utilisation
    #
    ####################################################
    def get_thread_utilisation_per_core(self, timeslice_list: List[tu.TimeSlice], ignore_nsys: bool) -> tuple[float, dict, dict]:
        """Calculate per-thread utilisation and per-(thread, CPU) utilisation.

        Returns:
            (total_time, {gtid: fraction}, {(gtid, cpu): fraction}), where fractions are
            relative to the span from earliest start to latest end. Both dicts are empty
            when total_time <= 0. With ``ignore_nsys``, threads matching ``[NSys`` are omitted.
        """
        total_time, thread_time_dict, cpu_thread_time_dict = self._accumulate_thread_times(timeslice_list)

        thread_utilisation_dict = {}
        cpu_thread_utilisation_dict = {}
        if total_time > 0:
            thread_utilisation_dict = {
                gtid: thread_time / total_time
                for gtid, thread_time in thread_time_dict.items()
                if not (ignore_nsys and self._is_nsys_thread(gtid))
            }
            # Keys are (gtid, cpu) pairs.
            cpu_thread_utilisation_dict = {
                key: thread_time / total_time
                for key, thread_time in cpu_thread_time_dict.items()
                if not (ignore_nsys and self._is_nsys_thread(key[0]))
            }
        return total_time, thread_utilisation_dict, cpu_thread_utilisation_dict

    ####################################################
    #
    # Find the target process - accumulate all of the thread work, sort and return the largest which has GPU work
    #
    ####################################################
    def find_target_pid(self, thread_utilisation_dict, start_time_ns, end_time_ns, verbose) -> tuple[int, dict]:
        """Find the busiest PID with graphics or compute work; likely the process to investigate.

        Returns:
            (pid, {process_name: utilisation_percent}) ordered busiest first. Falls back to
            the busiest PID if none qualifies, and returns (0, {}) when no non-idle PID exists.
        """
        if verbose:
            logger.debug("Busiest processes: ")

        # Roll per-thread utilisation up to per-process utilisation; pid 0 is idle.
        pid_util_dict = {}
        for gtid in sorted(thread_utilisation_dict, key=thread_utilisation_dict.get, reverse=True):
            pid, _ = tu.convert_global_tid(gtid)
            if pid == 0:
                continue
            pid_util_dict[pid] = pid_util_dict.get(pid, 0) + thread_utilisation_dict[gtid]

        sorted_pids = sorted(pid_util_dict, key=pid_util_dict.get, reverse=True)

        # NOTE(review): keyed by process name, so processes sharing a name overwrite
        # each other's value (the later, less busy one wins) — confirm this is intended.
        sorted_pid_util_dict = {}
        for pid in sorted_pids:
            sorted_pid_util_dict[self.get_process_name(pid)] = pid_util_dict[pid] * 100
            if verbose:
                logger.debug('\t%s: %s = %.2f', pid, self.get_process_name(pid), pid_util_dict[pid] * 100)

        if not sorted_pids:
            logger.debug("ERROR: Can not find the busy pid. There is a problem with the timeslices.")
            return 0, {}

        # We have the busiest pid; now look for the busiest one doing useful GPU work.
        for pid in sorted_pids:
            if self.get_process_name(pid).lower() in self._IGNORED_TARGET_PROCESSES:
                continue
            loader = self.get_current_loader()
            if (loader.is_graphics_workload(start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=pid)
                    or loader.is_compute_workload(start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=pid)):
                return pid, sorted_pid_util_dict

        # Just return the busiest pid
        return sorted_pids[0], sorted_pid_util_dict

    ####################################################
    #
    # Get timerange from timeslices
    #  Ignores the first/last 5% (safety margin)
    #
    ####################################################
    @staticmethod
    def get_safe_timerange_from_timeslices(timeslice_list: List[TimeSlice],
                                           safety_margin: Optional[float] = 0.05):
        """Return (start, end) of the trace with ``safety_margin`` of the duration trimmed from each end.

        The span runs from the first slice's start (clamped to >= 0) to the last slice's end,
        so the list is assumed to be time-ordered. Raises IndexError for an empty list.
        """
        start = max(timeslice_list[0].start, 0)
        duration = timeslice_list[-1].end - start

        ignore_delta = duration * safety_margin
        start_time_ns = start + ignore_delta
        end_time_ns = start + duration - ignore_delta
        return start_time_ns, end_time_ns

    ####################################################
    #
    # Get events for the given time range and tid/pid
    #
    ####################################################
    def get_events(self,
                   event_type: TraceLoaderEvents,
                   start_time_ns: int,
                   end_time_ns: int,
                   target_pid: int,
                   target_tid: int) -> tuple[dict, dict]:
        """Retrieve any events of the given type from the current loader."""
        return self.get_current_loader().get_events(event_type, start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=target_pid, target_tid=target_tid)

    ####################################################
    #
    # Get ordered events for the given time range and tid/pid
    #
    ####################################################
    def get_ordered_events(self,
                           event_type: TraceLoaderEvents,
                           start_time_ns: int,
                           end_time_ns: int,
                           target_pid: int,
                           target_tid: int) -> List:
        """Retrieve ordered events of the given type from the current loader."""
        return self.get_current_loader().get_ordered_events(event_type, start_time_ns=start_time_ns, end_time_ns=end_time_ns, target_pid=target_pid, target_tid=target_tid)

    ####################################################
    #
    # Find average GPU metrics
    #
    ####################################################
    def get_all_average_gpu_metrics(self,
                                    start_time_ns: Optional[float] = None,
                                    end_time_ns: Optional[float] = None) -> tuple[dict, int]:
        """Get the average GPU metrics for the given time period."""
        return self.get_current_loader().get_all_average_gpu_metrics(start_time_ns=start_time_ns, end_time_ns=end_time_ns)

    ####################################################
    #
    # Get GPU metric list as frame durations
    #
    ####################################################
    def get_gpu_metric_frame_list(self,
                                  metric_type: TraceLoaderGPUMetrics,
                                  min_metric: Optional[float] = None,
                                  max_metric: Optional[float] = None,
                                  min_percent: Optional[float] = None,
                                  max_percent: Optional[float] = None,
                                  start_time_ns: Optional[float] = None,
                                  end_time_ns: Optional[float] = None) -> tuple[List[GPUMetric], List[FrameDurations]]:
        """Get the provided GPU metric as a list of GPU metrics and a list of frame durations."""
        return self.get_current_loader().get_gpu_metric_frame_list(metric_type=metric_type,
                                                                   min_metric=min_metric,
                                                                   max_metric=max_metric,
                                                                   min_percent=min_percent,
                                                                   max_percent=max_percent,
                                                                   start_time_ns=start_time_ns,
                                                                   end_time_ns=end_time_ns)

    ####################################################
    #
    # Add dataframes
    #
    ####################################################
    @staticmethod
    def get_dataframe_key(key_string):
        """Return the df_dict key for ``key_string`` (spaces replaced with underscores)."""
        return key_string.replace(" ", "_")

    def add_dataframe_from_dict(self, key_string: str, d: dict, c: Optional[list[str]] = None, transpose: bool = False, sort: bool = False, sort_ascending: bool = False, sort_column: Optional[str] = None):
        """Build a DataFrame from ``d`` and store it under ``get_dataframe_key(key_string)``.

        Args:
            key_string: Name of the DataFrame; spaces become underscores in the key.
            d: Source data.
            c: Optional column names.
            transpose: If True, dict keys become rows (``orient='index'``); otherwise columns.
            sort: If True, sort rows by ``sort_column`` (which must then be given).
            sort_ascending: Sort direction when ``sort`` is set.
            sort_column: Column to sort by.
        """
        if transpose:
            # from_dict applies ``c`` as the column labels directly.
            df = pd.DataFrame.from_dict(d, orient='index', columns=c)
        else:
            # pandas rejects ``columns=`` with orient='columns', so rename afterwards.
            df = pd.DataFrame.from_dict(d, orient='columns', columns=None)
            if c is not None:
                df.columns = c
        if sort:
            df.sort_values(by=[sort_column], axis=0, ascending=sort_ascending, inplace=True)
        self.df_dict[self.get_dataframe_key(key_string)] = df

    def add_dataframe_from_list(self, key_string, df_list, c):
        """Build a DataFrame from ``df_list``, optionally naming its columns ``c``, and store it."""
        df = pd.DataFrame(df_list)
        if c is not None:
            df.columns = c
        self.df_dict[self.get_dataframe_key(key_string)] = df

    def add_meta_info(self, meta_info: tu.SourceMetaInfo):
        """Store the report meta info (a dataclass) as a one-column DataFrame under 'report_meta_info'."""
        self.df_dict['report_meta_info'] = pd.DataFrame.from_dict(asdict(meta_info), orient='index')