#' Plot the Mutation Frequency
#' @description This function creates a plot of the mutation frequency.
#' @param mf_data A data frame containing the mutation frequency data. This is
#' obtained from the calculate_mf function with SUMMARY = TRUE.
#' @param group_col The name of the column containing the sample/group names
#' for the x-axis.
#' @param plot_type The type of plot to create. Options are "bar" or "point".
#' @param mf_type The type of mutation frequency to plot. Options are "min",
#' "max", "both", or "stacked". If "both", the min and max mutation
#' frequencies are plotted side by side. "stacked" can be chosen for bar
#' plot_type only. If "stacked", the difference between the min and max
#' MF is stacked on top of the min MF such that the total height of both
#' bars represent the max MF.
#' @param fill_col The name of the column containing the fill variable.
#' @param custom_palette A character vector of colour codes to use for the plot.
#' If NULL, a default palette is used
#' @param group_order The order of the samples/groups along the x-axis.
#' Options include:
#' \itemize{
#'   \item `none`: No ordering is performed. Default.
#'   \item `smart`: Samples are ordered based on the sample names.
#'   \item `arranged`: Samples are ordered based on one or more factor column(s)
#' in mf_data. Factor column names are passed to the function using the
#' `group_order_input`.
#'  \item `custom`: Samples are ordered based on a custom vector of sample
#' names. The custom vector is passed to the function using the
#' `group_order_input`.
#' }
#' @param group_order_input The order of the samples/groups if group_order is
#' "custom". The column name by which to arrange samples/groups if group_order
#' is "arranged"
#' @param labels The data labels to display on the plot. Either "count", "MF",
#' or "none". Count labels display the number of mutations, MF labels display
#' the mutation frequency.
#' @param scale_y_axis The scale of the y axis. Either "linear" or "log".
#' @param x_lab The label for the x axis.
#' @param y_lab The label for the y axis.
#' @param title The title of the plot.
#' @param rotate_labels A logical value applied when labels is not "none".
#' Indicates whether the labels should be rotated 90 degrees. Default is
#' FALSE.
#' @param label_size A numeric value that adjusts the size of the labels.
#' Default is 3.
#' @return A ggplot object
#' @examples
#' # Example data  consists of 24 mouse bone marrow
#' # samples exposed to three doses of BaP alongside vehicle controls.
#' # Libraries were sequenced with Duplex Sequencing using
#' # the TwinStrand Mouse Mutagenesis Panel which consists of 20 2.4kb
#' # targets = 48kb of sequence. Example data can be retrieved from
#' # MutSeqRData, an ExperimentHub data package:
#' ## library(ExperimentHub)
#' ## eh <- ExperimentHub()
#' ## query(eh, "MutSeqRData")
#' # Mutation frequency data was precalculated using
#' ## mf_data_global <- calculate_mf(mutation_data = eh[["EH9861"]],
#' ##   cols_to_group = "sample",
#' ##   retain_metadata_cols = c("dose_group", "dose"))
#' # Load Example MF data
#' mf_example <- readRDS(system.file("extdata/Example_files/mf_data_global.rds",
#'   package = "MutSeqR"
#' ))
#' # Specify the order of the dose groups along the x-axis
#' mf_example$dose_group <- factor(mf_example$dose_group,
#'   levels = c(
#'     "Control", "Low",
#'     "Medium", "High"
#'   )
#' )
#' # Plot the min MF per sample as a bar plot with count labels
#' plot <- plot_mf(
#'   mf_data = mf_example,
#'   group_col = "sample",
#'   plot_type = "bar",
#'   mf_type = "min",
#'   fill_col = "dose_group",
#'   group_order = "arranged",
#'   group_order_input = "dose_group",
#'   labels = "count",
#'   title = "Mutation Frequency per Sample"
#' )
#' @import ggplot2
#' @importFrom dplyr arrange across all_of rename
#' @export
plot_mf <- function(mf_data,
                    group_col,
                    plot_type = "bar",
                    mf_type = "min",
                    fill_col = NULL,
                    custom_palette = NULL,
                    group_order = "none",
                    group_order_input = NULL,
                    labels = "count",
                    scale_y_axis = "linear",
                    x_lab = NULL,
                    y_lab = NULL,
                    title = NULL,
                    rotate_labels = FALSE,
                    label_size = 3) {
  # Overview: validate the arguments, order the groups along the x-axis,
  # reshape mf_data into fixed plotting columns (group_col, mf_col, sum_col,
  # fill_col, fill_value, label_value), then assemble the ggplot layers.

  # ---- Argument validation ----
  # group_col / fill_col are used with `[[` and `==` below, so they must be
  # single strings; multi-column input is rejected up front.
  stopifnot(
    "mf_data is required." = !missing(mf_data),
    "mf_data must be a data frame." = is.data.frame(mf_data),
    "group_col is required." = !missing(group_col),
    "group_col must be a single column name." =
      is.character(group_col) && length(group_col) == 1L,
    "fill_col must be NULL or a single column name." =
      is.null(fill_col) ||
        (is.character(fill_col) && length(fill_col) == 1L),
    "rotate_labels must be a logical value." =
      is.logical(rotate_labels) && length(rotate_labels) == 1L &&
        !is.na(rotate_labels)
  )
  group_col <- match.arg(arg = group_col, choices = colnames(mf_data))
  if (!is.null(fill_col)) {
    fill_col <- match.arg(arg = fill_col, choices = colnames(mf_data))
  }
  plot_type <- match.arg(arg = plot_type, choices = c("bar", "point"))
  mf_type <- match.arg(
    arg = mf_type,
    choices = c("min", "max", "both", "stacked")
  )
  group_order <- match.arg(
    arg = group_order,
    choices = c("none", "smart", "arranged", "custom")
  )
  labels <- match.arg(arg = labels, choices = c("count", "MF", "none"))
  scale_y_axis <- match.arg(arg = scale_y_axis, choices = c("linear", "log"))

  if (group_order == "smart" && !requireNamespace("gtools", quietly = TRUE)) {
    stop("Package gtools is required when using the 'smart' group_order option. Please install the package using 'install.packages('gtools')'")
  }
  # The point plot always draws its labels with ggrepel.
  if (plot_type == "point" && !requireNamespace("ggrepel", quietly = TRUE)) {
    stop("Package ggrepel is required when using the 'point' plot_type. Please install the package using 'install.packages('ggrepel')'")
  }
  if (mf_type == "stacked" && plot_type == "point") {
    stop("The 'stacked' mutation frequency type is not compatible with the 'point' plot type.")
  }

  # The MF and count columns that the chosen mf_type needs.
  value_cols <- if (mf_type %in% c("min", "max")) {
    paste0(c("mf_", "sum_"), mf_type)
  } else {
    c("mf_min", "mf_max", "sum_min", "sum_max")
  }
  absent_cols <- setdiff(value_cols, colnames(mf_data))
  if (length(absent_cols) > 0) {
    stop(
      "mf_data is missing required column(s): ",
      paste(absent_cols, collapse = ", "),
      call. = FALSE
    )
  }

  if (group_order == "custom") {
    # factor(x, levels = NULL) would silently turn every group into NA.
    if (is.null(group_order_input)) {
      stop(
        "group_order_input must list the groups in the desired order ",
        "when group_order is 'custom'.",
        call. = FALSE
      )
    }
    unlisted <- setdiff(
      as.character(unique(mf_data[[group_col]])),
      as.character(group_order_input)
    )
    unlisted <- unlisted[!is.na(unlisted)]
    if (length(unlisted) > 0) {
      warning(
        "Groups missing from group_order_input will be plotted as NA: ",
        paste(unlisted, collapse = ", "),
        call. = FALSE
      )
    }
  } else if (group_order == "arranged") {
    if (is.null(group_order_input)) {
      warning(
        "group_order_input is NULL; groups are kept in their order of ",
        "appearance in mf_data.",
        call. = FALSE
      )
    } else if (is.character(group_order_input)) {
      unknown <- setdiff(group_order_input, colnames(mf_data))
      if (length(unknown) > 0) {
        stop(
          "group_order_input must name column(s) of mf_data when ",
          "group_order is 'arranged'. Not found: ",
          paste(unknown, collapse = ", "),
          call. = FALSE
        )
      }
    }
  }

  # ---- Axis labels ----
  if (is.null(x_lab)) {
    x_lab <- group_col
  }
  if (is.null(y_lab)) {
    y_lab <- "Mutation Frequency (mutations/bp)"
  }

  # ---- Group order along the x-axis ----
  if (group_order == "none") {
    mf_data[[group_col]] <- factor(mf_data[[group_col]])
  } else if (group_order == "smart") {
    group_levels <- gtools::mixedsort(
      as.vector(unique(mf_data[[group_col]]))
    )
    mf_data[[group_col]] <- factor(
      mf_data[[group_col]],
      levels = group_levels
    )
  } else if (group_order == "arranged") {
    if (!is.null(group_order_input)) {
      mf_data <- dplyr::arrange(
        mf_data,
        dplyr::across(dplyr::all_of(group_order_input))
      )
    }
    # Levels follow the order of first appearance after sorting.
    group_levels <- as.vector(unique(mf_data[[group_col]]))
    mf_data[[group_col]] <- factor(
      mf_data[[group_col]],
      levels = group_levels
    )
  } else {
    mf_data[[group_col]] <- factor(
      mf_data[[group_col]],
      levels = group_order_input
    )
  }

  # Upper y limit with 10% headroom, ignoring NA/Inf values. Returns NA (which
  # ggplot2 treats as "compute from data") when there is no finite value.
  headroom_limit <- function(values) {
    finite_values <- values[is.finite(values)]
    if (length(finite_values) > 0) {
      max(finite_values) * 1.1
    } else {
      NA_real_
    }
  }

  # ---- Plotting columns: group_col, mf_col, sum_col (and mf_type) ----
  if (mf_type %in% c("min", "max")) {
    plot_data <- dplyr::rename(
      mf_data,
      group_col = dplyr::all_of(group_col),
      mf_col = dplyr::all_of(paste0("mf_", mf_type)),
      sum_col = dplyr::all_of(paste0("sum_", mf_type))
    )
    max_y <- headroom_limit(plot_data$mf_col)
  } else {
    # stats::reshape() is written for plain data frames; coerce first so
    # tibble input is handled too.
    plot_data <- dplyr::rename(
      as.data.frame(mf_data),
      group_col = dplyr::all_of(group_col)
    )
    # Stacked bars total to mf_max, so the limit is the same in both modes.
    max_y <- headroom_limit(plot_data$mf_max)
    if (mf_type == "stacked") {
      # Stack (max - min) on top of min. Direct assignment is used instead of
      # transform(), which re-runs data.frame() and can rewrite non-syntactic
      # column names.
      # NOTE(review): with labels = "MF" the upper segment is labelled with
      # this difference rather than the max MF - confirm that is intended.
      plot_data$mf_max <- plot_data$mf_max - plot_data$mf_min
    }
    # Pivot to long format: one row per group and MF type.
    plot_data <- stats::reshape(
      plot_data,
      varying = list(
        c("sum_min", "sum_max"),
        c("mf_min", "mf_max")
      ),
      v.names = c("sum_col", "mf_col"),
      times = c("min", "max"),
      timevar = "mf_type",
      direction = "long"
    )
    # "both" shows min first; "stacked" lists max first in the fill levels.
    type_levels <- if (mf_type == "both") c("min", "max") else c("max", "min")
    plot_data$mf_type <- factor(plot_data$mf_type, levels = type_levels)
  }

  # ---- Fill ----
  # fill_col holds the user's fill variable (or "" when absent); fill_value is
  # what is actually mapped to the fill aesthetic.
  two_types <- mf_type %in% c("both", "stacked")
  if (!is.null(fill_col)) {
    if (fill_col == group_col) {
      # The group column has already been renamed; copy it.
      plot_data$fill_col <- plot_data$group_col
    } else {
      plot_data <- dplyr::rename(plot_data, fill_col = dplyr::all_of(fill_col))
    }
    if (two_types) {
      plot_data$fill_value <- interaction(
        plot_data$mf_type,
        plot_data$fill_col
      )
      fill_label <- paste("MF Type and", fill_col)
    } else {
      plot_data$fill_value <- plot_data$fill_col
      fill_label <- fill_col
    }
  } else {
    plot_data$fill_col <- ""
    if (two_types) {
      plot_data$fill_value <- plot_data$mf_type
      fill_label <- "MF Type"
    } else {
      plot_data$fill_value <- ""
      fill_label <- NULL
    }
  }

  # ---- Data labels ----
  plot_data$label_value <- switch(labels,
    count = plot_data$sum_col,
    MF = sprintf("%.2e", plot_data$mf_col),
    none = ""
  )

  # ---- Y scale ----
  if (scale_y_axis == "log") {
    yscale <- ggplot2::scale_y_log10()
  } else {
    yscale <- ggplot2::scale_y_continuous(limits = c(0, max_y))
  }

  # ---- Bar and label positions ----
  if (mf_type == "both" && plot_type == "bar") {
    position <- "dodge"
    label_position <- ggplot2::position_dodge(width = 0.9)
  } else if (mf_type == "stacked" && plot_type == "bar") {
    position <- "stack"
    label_position <- ggplot2::position_stack(vjust = 0.5)
  } else {
    position <- "identity"
    label_position <- "identity"
  }

  # ---- Default title ----
  if (is.null(title)) {
    if (two_types) {
      title <- paste0("min and max Mutation Frequency per ", group_col)
    } else {
      title <- paste0(mf_type, " mutation frequency per ", group_col)
    }
  }

  # ---- Palette ----
  if (is.null(custom_palette)) {
    # One colour per fill level; two MF types double the count.
    n_colors <- length(unique(plot_data$fill_col))
    if (two_types) {
      n_colors <- n_colors * 2
    }
    gradient <- grDevices::colorRampPalette(colors = c(
      "#c5e5fc",
      "#5ab2ee",
      "#12587b",
      "#263247",
      "#ffedef",
      "#ffb9c1",
      "#ff5264",
      "#b23946"
    ))
    palette <- gradient(n_colors)
  } else {
    palette <- custom_palette
  }

  # ---- Geometry and label layers ----
  label_angle <- if (rotate_labels) 90 else 0
  if (plot_type == "bar") {
    geom_layer <- ggplot2::geom_bar(
      stat = "identity",
      position = position,
      color = "black"
    )
    # Nudge labels just past the end of the bar for either orientation.
    if (rotate_labels) {
      vjust <- 0.5
      hjust <- -0.5
    } else {
      vjust <- -0.5
      hjust <- 0.5
    }
    label_layer <- ggplot2::geom_text(
      ggplot2::aes(label = .data[["label_value"]]),
      position = label_position,
      vjust = vjust,
      hjust = hjust,
      size = label_size,
      color = "black",
      angle = label_angle
    )
  } else {
    # Points and labels share one seeded jitter so they stay aligned.
    pos <- ggplot2::position_jitter(
      width = 0.1,
      height = 0,
      seed = 123
    )
    geom_layer <- ggplot2::geom_point(
      shape = 21,
      size = 3,
      color = "black",
      position = pos
    )
    label_layer <- ggrepel::geom_text_repel(
      ggplot2::aes(label = .data[["label_value"]]),
      angle = label_angle,
      size = label_size,
      color = "black",
      position = pos,
      max.overlaps = Inf
    )
  }

  # ---- Assemble the plot ----
  # Aesthetics refer to columns of plot_data through the .data pronoun rather
  # than to `plot_data$...` or free local vectors.
  p <- ggplot2::ggplot(
    plot_data,
    ggplot2::aes(
      x = .data[["group_col"]],
      y = .data[["mf_col"]],
      fill = factor(.data[["fill_value"]])
    )
  ) +
    geom_layer +
    label_layer +
    yscale +
    ggplot2::labs(
      title = title,
      fill = fill_label
    ) +
    ggplot2::ylab(y_lab) +
    ggplot2::xlab(x_lab) +
    ggplot2::theme(
      axis.text.x = ggplot2::element_text(
        angle = 90,
        hjust = 1,
        vjust = 0.5
      ),
      panel.background = ggplot2::element_blank(),
      axis.line = ggplot2::element_line()
    ) +
    ggplot2::scale_fill_manual(values = palette)

  p
}
