#' Overrepresentation analysis
#'
#' Runs \link[clusterProfiler]{enricher} once for each component (factor or
#' module) in a named list of gene sets, then post-processes each result with
#' `.formatEnrichRes`.
#'
#' @param component_features A named list. Each element is passed as the gene
#' set to \link[clusterProfiler]{enricher}; the element names are used as the
#' component labels.
#'
#' @param database,TERM2GENE Passed to `.getT2G` to resolve the term-to-gene
#' mapping used for enrichment.
#'
#' @param p_cutoff Passed to `.formatEnrichRes` as the adjusted p-value
#' cutoff. The call to \link[clusterProfiler]{enricher} itself uses p- and
#' q-value cutoffs of 1.
#'
#' @param adj_method The p-value adjustment method, passed both to
#' \link[clusterProfiler]{enricher} and to `.formatEnrichRes`.
#'
#' @param min_genes Passed to `.formatEnrichRes` as the minimum `Count`
#' required for a term to be retained.
#'
#' @param universe Passed to \link[clusterProfiler]{enricher}.
#'
#' @param ... Additional arguments passed to \link[clusterProfiler]{enricher}.
#'
#' @returns A list keyed by component name. Components whose result is `NULL`
#' are not present in the list. Results with at least one row gain the
#' columns `method` ("overrepresentation") and `component`.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
reducedOA <- function(
    component_features, database = "msigdb_c2_cp",
    TERM2GENE = NULL, p_cutoff = 1, adj_method = "BH",
    min_genes = 3, universe = NULL,
    ...
) {
    TERM2GENE <- .getT2G(database, TERM2GENE)
    enrich_res <- list()

    for (comp in names(component_features)) {
        # No filtering at this stage; cutoffs are applied by .formatEnrichRes
        raw_res <- clusterProfiler::enricher(
            component_features[[comp]],
            pvalueCutoff = 1, qvalueCutoff = 1,
            pAdjustMethod = adj_method,
            TERM2GENE = TERM2GENE,
            universe = universe,
            ...
        )

        formatted_res <- .formatEnrichRes(
            raw_res,
            adj_method = adj_method,
            min_genes = min_genes,
            p_cutoff = p_cutoff
        )

        has_hits <- !is.null(formatted_res) &&
            nrow(formatted_res@result) >= 1

        if (has_hits) {
            formatted_res@result$method <- "overrepresentation"
            formatted_res@result$component <- comp
        }

        # Assigning NULL leaves the list without an entry for this component
        enrich_res[[comp]] <- formatted_res
    }

    return(enrich_res)
}

#' Gets a TERM2GENE data object for clusterProfiler
#'
#' Gets a correctly formatted TERM2GENE object for use by the `clusterProfiler`
#' package for enrichment analysis.
#'
#' @param database Either `NULL`, a `data.frame` to be used directly as the
#' term-to-gene mapping, or the string "msigdb_c2_cp" (in which case
#' `getMsigdbT2G()` is called with its default arguments). Ignored when
#' `TERM2GENE` is not `NULL`.
#'
#' @param TERM2GENE If not `NULL`, this object is returned unchanged and
#' `database` is ignored.
#'
#' @returns `TERM2GENE` if it is not `NULL`; otherwise the mapping resolved
#' from `database` (`NULL` when `database` is also `NULL`). Raises an error if
#' `database` is not one of the recognised values.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
.getT2G <- function(database, TERM2GENE) {

    # An explicit TERM2GENE always takes precedence over `database`
    if (!is.null(TERM2GENE)) {
        return(TERM2GENE)
    }

    if (is.null(database)) {
        return(NULL)
    }

    if (is.data.frame(database)) {
        return(database)
    }

    # isTRUE() keeps NA or non-scalar input from breaking the `if` condition,
    # so such input reaches the informative error below instead
    if (isTRUE(database == "msigdb_c2_cp")) {
        return(getMsigdbT2G())
    }

    stop("Database ", paste(database, collapse = ", "), " not recognised")
}

#' Reformats enrichment result
#'
#' Reformats `clusterProfiler` enrichment results into a consistent format.
#'
#' @param enrich_res_single An object with a `result` slot holding a
#' `data.frame` of enrichment results, or `NULL`.
#'
#' @param adj_method The method passed to \link[stats]{p.adjust} when
#' recalculating adjusted p-values.
#'
#' @param p_cutoff Rows with an adjusted p-value above this value are removed.
#'
#' @param min_genes If not `NULL`, only rows whose `Count` is at least this
#' value are kept. This filter is applied before p-values are adjusted.
#'
#' @returns `NULL` if the input is `NULL`. Otherwise the input object with its
#' `result` table filtered and its `p.adjust` column recalculated. If any rows
#' remain, the columns `adj_method` and `qvalue` (set to `NA`) are added and,
#' where the object has a `pvalueCutoff` slot, it is set to `p_cutoff`.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
.formatEnrichRes <- function(
    enrich_res_single, adj_method,
    p_cutoff, min_genes = NULL
) {
    if (is.null(enrich_res_single)) {
        return(NULL)
    }

    res_table <- enrich_res_single@result

    # Optional filter on the number of genes supporting each term
    if (!is.null(min_genes)) {
        enough_genes <- which(res_table$Count >= min_genes)
        res_table <- res_table[enough_genes, ]
    }

    # Adjust p-values over the retained terms only
    res_table$p.adjust <- stats::p.adjust(res_table$pvalue, method = adj_method)

    passes_cutoff <- which(res_table$p.adjust <= p_cutoff)
    res_table <- res_table[passes_cutoff, ]

    any_left <- nrow(res_table) >= 1

    if (any_left) {
        res_table$adj_method <- adj_method
        res_table$qvalue <- NA
    }

    enrich_res_single@result <- res_table

    if (any_left && "pvalueCutoff" %in% slotNames(enrich_res_single)) {
        enrich_res_single@pvalueCutoff <- p_cutoff
    }

    return(enrich_res_single)
}

#' Gene set enrichment analysis
#'
#' Performs \link[clusterProfiler]{GSEA} using the `clusterProfiler` package.
#' Expects a loadings matrix where features (genes) are rows and factors are
#' columns.
#'
#' @param S A loadings matrix with features as rows and factors as columns.
#' The row names are used as gene identifiers and the column names as
#' component labels.
#'
#' @param database,TERM2GENE Passed to `.getT2G` to resolve the term-to-gene
#' mapping used for enrichment.
#'
#' @param p_cutoff Passed to `.formatEnrichRes` as the adjusted p-value
#' cutoff. The call to \link[clusterProfiler]{GSEA} itself uses a p-value
#' cutoff of 1.
#'
#' @param adj_method The p-value adjustment method, passed both to
#' \link[clusterProfiler]{GSEA} and to `.formatEnrichRes`.
#'
#' @param nPermSimple,eps Passed to \link[clusterProfiler]{GSEA}.
#'
#' @param ... Additional arguments passed to \link[clusterProfiler]{GSEA}.
#'
#' @returns A list keyed by component name. Components whose result is `NULL`
#' are not present in the list. Results with at least one row gain the
#' columns `method` ("gsea") and `component`.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
reducedGSEA <- function(
    S, database = "msigdb_c2_cp", TERM2GENE = NULL,
    p_cutoff = 1, adj_method = "BH", nPermSimple = 1000, eps = 1e-10,
    ...
) {
    TERM2GENE <- .getT2G(database, TERM2GENE)
    enrich_res <- list()

    for (comp in colnames(S)) {
        # Rank genes by decreasing loading; the ordering is computed once and
        # reused for both the values and their names
        comp_loadings <- S[, comp]
        S_order <- order(comp_loadings, decreasing = TRUE)
        comp_genes <- comp_loadings[S_order]
        names(comp_genes) <- rownames(S)[S_order]

        # No filtering at this stage; the cutoff is applied by .formatEnrichRes
        enrich_res_single <- clusterProfiler::GSEA(
            comp_genes,
            pvalueCutoff = 1,
            pAdjustMethod = adj_method,
            TERM2GENE = TERM2GENE,
            nPermSimple = nPermSimple,
            eps = eps,
            ...
        )

        enrich_res_single <- .formatEnrichRes(
            enrich_res_single,
            adj_method = adj_method,
            p_cutoff = p_cutoff
        )

        if (!is.null(enrich_res_single)) {
            if (nrow(enrich_res_single@result) >= 1) {
                enrich_res_single@result$method <- "gsea"
                enrich_res_single@result$component <- comp
            }
        }

        # Assigning NULL leaves the list without an entry for this component
        enrich_res[[comp]] <- enrich_res_single
    }

    return(enrich_res)
}

#' Get TERM2GENE dataframe from MSigDB
#'
#' Gets pathways from the MSigDB database in the format required by
#' `clusterProfiler` enrichment functions, such as
#' \link[clusterProfiler]{enricher} and \link[clusterProfiler]{GSEA}.
#' May be used as input to \link[ReducedExperiment]{runEnrich}. By default,
#' retrieves the C2 canonical pathways.
#'
#' @param species The species for which to obtain MSigDB pathways. See
#' \link[msigdbr]{msigdbr} for more details.
#'
#' @param category The MSigDB category to retrieve pathways for. See
#' \link[msigdbr]{msigdbr} for more details.
#'
#' @param subcategory The MSigDB subcategory to retrieve pathways for. See
#' \link[msigdbr]{msigdbr} for more details.
#'
#' @param subcategory_to_remove If not NULL, this is a character string
#' indicating a subcategory to be removed from the results of
#' \link[msigdbr]{msigdbr}.
#'
#' @param gene_id The name of the column in the \link[msigdbr]{msigdbr} output
#' that contains the gene identifiers to keep (e.g., "ensembl_gene").
#'
#' @returns Returns a data.frame, where the `gs_name` column indicates the name
#' of a pathway, and the column named by `gene_id` indicates genes that belong
#' to said pathway.
#'
#' @note
#' If the `msigdbdf` package is not installed, the function will only return
#' a subset of the full MSigDB pathways.
#'
#' @author Jack Gisby
#'
#' @examples
#' pathways <- getMsigdbT2G(
#'     species = "Homo sapiens",
#'     category = "C2",
#'     subcategory_to_remove = "CGP",
#'     gene_id = "ensembl_gene"
#' )
#'
#' # A data.frame indicating gene-pathway mappings for use in pathway analysis
#' head(pathways)
#'
#' @export
getMsigdbT2G <- function(
    species = "Homo sapiens",
    category = "C2",
    subcategory = NULL,
    subcategory_to_remove = "CGP",
    gene_id = "ensembl_gene"
) {

    t2g <- data.frame(msigdbr::msigdbr(
        species = species,
        collection = category,
        subcollection = subcategory
    ))

    if (!is.null(subcategory_to_remove)) {
        # The subcollection column name changes depending on the msigdbr
        # version, so use whichever of the two names is present
        subcat_col <- if ("gs_subcollection" %in% colnames(t2g)) {
            "gs_subcollection"
        } else {
            "gs_subcat"
        }

        # which() drops rows where the comparison is NA
        t2g <- t2g[which(t2g[[subcat_col]] != subcategory_to_remove), ]
    }

    # Fail with a clear message rather than "undefined columns selected"
    missing_cols <- setdiff(gene_id, colnames(t2g))
    if (length(missing_cols) > 0) {
        stop(
            "Column(s) not found in the msigdbr output: ",
            paste(missing_cols, collapse = ", ")
        )
    }

    t2g <- t2g[, c("gs_name", gene_id)]

    return(t2g)
}

#' Get common factor features
#'
#' Function to count how many genes are aligned with multiple factors.
#'
#' @param factor_features A `data.frame` as returned by
#' \link[ReducedExperiment]{getAlignedFeatures}.
#'
#' @returns A `data.frame` for each factor pair with the numbers and proportions
#' of the genes in the input that overlap.
#'
#' @seealso [ReducedExperiment::plotCommonFeatures()],
#' [ReducedExperiment::getAlignedFeatures()]
#'
#' @author Jack Gisby
#'
#' @examples
#' # Get a random matrix with rnorm, with 100 rows (features)
#' # and 20 columns (observations)
#' X <- ReducedExperiment:::.makeRandomData(100, 20, "feature", "obs")
#'
#' # Estimate 5 factors based on the data matrix
#' fe <- estimateFactors(X, nc = 5)
#'
#' # Get the genes highly aligned with each factor
#' aligned_features <- getAlignedFeatures(
#'     fe,
#'     format = "data.frame",
#'     proportional_threshold = 0.3
#' )
#'
#' # Identify overlap between common features for each factor
#' common_features <- getCommonFeatures(aligned_features)
#' head(common_features)
#'
#' @export
getCommonFeatures <- function(factor_features) {
    if (!is.data.frame(factor_features)) {
        stop("Factor features must be a data.frame")
    }

    components <- unique(factor_features$component)

    # One single-row data.frame per ordered pair of components, collected in a
    # preallocated list and bound once at the end (avoids growing a
    # data.frame with rbind() inside the loop)
    pair_rows <- vector("list", length(components)^2)
    pair_idx <- 0L

    for (c_1 in components) {
        feat_1 <- factor_features$feature[factor_features$component == c_1]
        total_feat_1 <- length(feat_1)

        for (c_2 in components) {
            feat_2 <- factor_features$feature[factor_features$component == c_2]
            total_feat_2 <- length(feat_2)
            smaller_total <- min(total_feat_1, total_feat_2)

            # The overlap of a component with itself is reported as NA
            if (c_1 == c_2) {
                feat_intersect <- NA
            } else {
                feat_intersect <- length(intersect(feat_1, feat_2))
            }

            pair_idx <- pair_idx + 1L
            pair_rows[[pair_idx]] <- data.frame(
                c_1 = c_1,
                c_2 = c_2,
                intersect = feat_intersect,
                total_feat_1 = total_feat_1,
                total_feat_2 = total_feat_2,
                smaller_total = smaller_total,
                # Overlap as a proportion of the smaller of the two sets
                intersect_prop = feat_intersect / smaller_total
            )
        }
    }

    # With no components there is nothing to bind; return an empty data.frame
    if (length(pair_rows) == 0) {
        return(data.frame())
    }

    common_features <- do.call(rbind, pair_rows)

    return(common_features)
}

#' Heatmap comparing commonality across factors
#'
#' @param common_features The output of
#' \link[ReducedExperiment]{getCommonFeatures}.
#'
#' @param filename The path at which to save the plot.
#'
#' @param color The colour palette to be used in the heatmap.
#'
#' @returns An object generated by \link[pheatmap]{pheatmap}.
#'
#' @seealso [ReducedExperiment::getCommonFeatures()],
#'  [ReducedExperiment::getAlignedFeatures()]
#'
#' @author Jack Gisby
#'
#' @examples
#' # Get a random matrix with rnorm, with 100 rows (features)
#' # and 20 columns (observations)
#' X <- ReducedExperiment:::.makeRandomData(100, 20, "feature", "obs")
#'
#' # Estimate 5 factors based on the data matrix
#' fe <- estimateFactors(X, nc = 5)
#'
#' # Get the genes highly aligned with each factor
#' aligned_features <- getAlignedFeatures(
#'     fe,
#'     format = "data.frame",
#'     proportional_threshold = 0.3
#' )
#'
#' # Identify overlap between common features for each factor
#' common_features <- getCommonFeatures(aligned_features)
#'
#' # Plot the common features as a heatmap
#' plotCommonFeatures(common_features)
#'
#' @export
plotCommonFeatures <- function(
    common_features,
    filename = NA,
    color = grDevices::colorRampPalette(RColorBrewer::brewer.pal(
        n = 7,
        name = "YlOrRd"
    ))(100)
) {
    common_features <- subset(common_features,
        select = c("c_1", "c_2", "intersect_prop")
    )

    # Long to wide: one row per c_1, one `intersect_prop_<c_2>` column per c_2
    prop_mat <- stats::reshape(
        common_features,
        idvar = "c_1",
        v.names = "intersect_prop",
        timevar = "c_2",
        direction = "wide",
        sep = "_"
    )

    rownames(prop_mat) <- prop_mat$c_1

    # drop = FALSE keeps a data.frame even when only one component is present
    prop_mat <- prop_mat[, -1, drop = FALSE]

    # Strip only the leading prefix added by reshape(), so component names
    # that themselves contain the prefix text are left intact
    colnames(prop_mat) <- sub("^intersect_prop_", "", colnames(prop_mat))

    common_hmap <- pheatmap::pheatmap(
        prop_mat,
        na_col = "grey",
        filename = filename,
        color = color
    )

    return(common_hmap)
}


#' Get module preservation statistics
#'
#' Tests whether a set of modules defined in the reference dataset are
#' preserved in the test dataset. Provides a convenient wrapper
#' around \link[WGCNA]{modulePreservation} for
#' \link[ReducedExperiment]{ModularExperiment} and
#' \link[SummarizedExperiment]{SummarizedExperiment}
#' objects.
#'
#' @param reference_dataset The dataset that was used to define the modules.
#' Must be a `data.frame` or `matrix` with features as rows and samples as
#' columns, or a \link[ReducedExperiment]{ModularExperiment} or
#' \link[SummarizedExperiment]{SummarizedExperiment} object.
#'
#' @param test_dataset The dataset that will be used to test for module
#' preservation. Must be a `data.frame` or `matrix` with features as rows and
#' samples as columns, or a \link[SummarizedExperiment]{SummarizedExperiment}
#' object. The features of `test_dataset` should be the same as
#' `reference_dataset` and in the same order.
#'
#' @param reference_assay_name If the reference dataset is a
#' \link[ReducedExperiment]{ModularExperiment} or
#' \link[SummarizedExperiment]{SummarizedExperiment} object, this argument
#' specifies which assay slot was used to define the modules.
#'
#' @param test_assay_name If the test dataset is a
#' \link[SummarizedExperiment]{SummarizedExperiment} object (or an object
#' inheriting from it), this argument specifies which assay slot is to be
#' used in preservation tests.
#'
#' @param module_assignments If the reference dataset is not a
#' \link[ReducedExperiment]{ModularExperiment} object, this argument is
#' necessary to specify the module assignments.
#'
#' @param greyName The name of the "module" of unassigned genes. Usually
#' "module_0" (ReducedExperiment default) or "grey" (WGCNA default). See
#' \link[WGCNA]{modulePreservation}.
#'
#' @param goldName The name to be used for the "gold" module (which is made up
#' of a random sample of all network genes). See
#' \link[WGCNA]{modulePreservation}.
#'
#' @param networkType A string referring to the type of WGCNA network used for
#' the reference and test datasets. One of "unsigned", "signed" or
#' "signed hybrid". See \link[WGCNA]{adjacency} and
#' \link[WGCNA]{modulePreservation}.
#'
#' @param corFnc A string referring to the function to be used to calculate
#' correlation. One of "cor" or "bicor". See
#' \link[WGCNA]{modulePreservation}.
#'
#' @param savePermutedStatistics If `TRUE`, saves the permutation statistics
#' as a .RData file. See \link[WGCNA]{modulePreservation}.
#'
#' @param ... Additional arguments to be passed to
#' \link[WGCNA]{modulePreservation}.
#'
#' @returns A `data.frame` containing preservation statistics, as described
#' by \link[WGCNA]{modulePreservation}.
#'
#' @author Jack Gisby
#'
#' @examples
#' # Get random ModularExperiments with rnorm, with 100 rows (features),
#' # 20 columns (observations) and 5/10 modules
#' me_1 <- ReducedExperiment:::.createRandomisedModularExperiment(100, 20, 5)
#' me_2 <- ReducedExperiment:::.createRandomisedModularExperiment(100, 20, 10)
#'
#' # Test module preservation (test modules from dataset 1 in dataset 2)
#' mp <- modulePreservation(me_1, me_2, verbose = 0, nPermutations = 3)
#'
#' @export
modulePreservation <- function(
    reference_dataset, test_dataset,
    reference_assay_name = "normal", test_assay_name = "normal",
    module_assignments = NULL, greyName = "module_0", goldName = "random",
    networkType = "signed", corFnc = "cor", savePermutedStatistics = FALSE,
    ...
) {
    ref_is_modular <- inherits(reference_dataset, "ModularExperiment")

    # Assignments must come from somewhere: either the ModularExperiment
    # itself or the explicit argument
    if (!ref_is_modular && is.null(module_assignments)) {
        stop(
            "If reference_dataset is not a ModularExperiment, ",
            "module_assignments must not be NULL"
        )
    }

    # A ModularExperiment's own assignments replace any supplied value
    if (ref_is_modular) {
        module_assignments <- assignments(reference_dataset)
    }

    # Reduce SummarizedExperiment inputs to the requested assay
    if (inherits(reference_dataset, "SummarizedExperiment")) {
        reference_dataset <- assay(reference_dataset, reference_assay_name)
    }

    if (inherits(test_dataset, "SummarizedExperiment")) {
        test_dataset <- assay(test_dataset, test_assay_name)
    }

    features_match <- identical(
        rownames(reference_dataset),
        rownames(test_dataset)
    )

    if (!features_match) {
        stop("Rownames of reference_dataset do not match those of test_dataset")
    }

    # Both datasets are transposed before being handed to WGCNA
    multi_data <- list(
        "reference" = list("data" = t(reference_dataset)),
        "test" = list("data" = t(test_dataset))
    )

    # Label vector: the names of module_assignments become the values and the
    # values of module_assignments become the names
    module_labels <- stats::setNames(
        names(module_assignments),
        module_assignments
    )
    multi_color <- list("reference" = module_labels)

    preservation_stats <- WGCNA::modulePreservation(
        multi_data,
        multi_color,
        dataIsExpr = TRUE,
        networkType = networkType,
        corFnc = corFnc,
        goldName = goldName,
        greyName = greyName,
        savePermutedStatistics = savePermutedStatistics,
        ...
    )

    return(preservation_stats)
}

#' Plot module preservation statistics
#'
#' @param modulePreservation_results The output of
#' \link[ReducedExperiment]{modulePreservation}
#'
#' @param show_random If `TRUE`, shows the random module in the plots.
#'
#' @param remove_module The name of a module to be hidden from the plots.
#'
#' @param gold_name The name of the "gold" (random) module in the results.
#' Should match the `goldName` argument used when calling
#' \link[ReducedExperiment]{modulePreservation}. Defaults to "random".
#'
#' @returns Two `ggplot2` plot objects combined by patchwork. Plots the
#' module preservation statistics generated by
#' \link[ReducedExperiment]{modulePreservation}.
#'
#' @author Jack Gisby
#'
#' @examples
#' # Get random ModularExperiments with rnorm, with 100 rows (features),
#' # 20 columns (observations) and 5/10 modules
#' me_1 <- ReducedExperiment:::.createRandomisedModularExperiment(100, 20, 5)
#' me_2 <- ReducedExperiment:::.createRandomisedModularExperiment(100, 20, 10)
#'
#' # Test module preservation (test modules from dataset 1 in dataset 2)
#' mp <- modulePreservation(me_1, me_2, verbose = 0, nPermutations = 3)
#'
#' # No significant preservation, since these were random modules
#' plotModulePreservation(mp)
#'
#' @import ggplot2
#' @import patchwork
#' @export
plotModulePreservation <- function(
    modulePreservation_results, show_random = TRUE, remove_module = NULL,
    gold_name = "random"
) {
    mr_df <- modulePreservation_results$preservation$observed$ref.reference$
        inColumnsAlsoPresentIn.test
    zs_df <- modulePreservation_results$preservation$Z$ref.reference$
        inColumnsAlsoPresentIn.test

    # Module names are stored as row names; copy them into a column
    mr_df$module <- rownames(mr_df)
    zs_df$module <- rownames(zs_df)

    if (show_random) {
        # Statistics of the gold module, drawn later as reference lines
        mr_gold <- mr_df$medianRank.pres[mr_df$module == gold_name]
        zs_gold <- zs_df$Zsummary.pres[zs_df$module == gold_name]
    } else {
        mr_df <- mr_df[which(mr_df$module != gold_name), ]
        zs_df <- zs_df[which(zs_df$module != gold_name), ]
    }

    if (!is.null(remove_module)) {
        mr_df <- mr_df[which(mr_df$module != remove_module), ]
        zs_df <- zs_df[which(zs_df$module != remove_module), ]
    }

    max_module_size <- max(mr_df$moduleSize)
    nudge_mr <- 0.14
    nudge_zs <- 0.4

    medianrank_plot <- .makeMedianRankPlot(mr_df, max_module_size, nudge_mr)
    zsummary_plot <- .makeZSummaryPlot(zs_df, max_module_size, nudge_zs)

    # Add the reference lines only when the gold module was found
    if (show_random && length(mr_gold) > 0) {
        medianrank_plot <- medianrank_plot + geom_hline(
            yintercept = mr_gold,
            col = "gold", linetype = "dashed"
        )
    }

    if (show_random && length(zs_gold) > 0) {
        zsummary_plot <- zsummary_plot + geom_hline(
            yintercept = zs_gold,
            col = "gold", linetype = "dashed"
        )
    }

    return(medianrank_plot + zsummary_plot)
}

#' Makes a median rank plot to visualise module preservation
#'
#' Scatter plot of `medianRank.pres` against `moduleSize`, with one point and
#' one text label per module.
#'
#' @param mr_df A `data.frame` with the columns `moduleSize`,
#' `medianRank.pres` and `module`.
#'
#' @param max_module_size The x-axis runs from 0 to 1.3 times this value.
#'
#' @param nudge_mr Vertical offset applied to the module labels.
#'
#' @returns A `ggplot2` plot object.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
.makeMedianRankPlot <- function(mr_df, max_module_size, nudge_mr) {
    point_mapping <- aes(
        x = !!sym("moduleSize"),
        y = !!sym("medianRank.pres"),
        col = !!sym("module")
    )

    # Labels are drawn in black, left-aligned and nudged vertically
    module_labels <- geom_text(
        aes(label = !!sym("module")),
        col = "black",
        nudge_y = nudge_mr,
        hjust = 0
    )

    # Extra horizontal space beyond the largest module (presumably so that
    # the labels are not clipped)
    x_limits <- c(0, max_module_size * 1.3)

    ggplot(mr_df, point_mapping) +
        geom_point(size = 3) +
        module_labels +
        theme(legend.position = "none") +
        xlim(x_limits) +
        expand_limits(y = 0)
}

#' Makes a Zsummary plot to visualise module preservation
#'
#' Scatter plot of `Zsummary.pres` against `moduleSize`, with one point and
#' one text label per module, plus dashed horizontal reference lines at
#' y = 10 (green) and y = 2 (blue).
#'
#' @param zs_df A `data.frame` with the columns `moduleSize`,
#' `Zsummary.pres` and `module`.
#'
#' @param max_module_size The x-axis runs from 0 to 1.3 times this value.
#'
#' @param nudge_zs Vertical offset applied to the module labels.
#'
#' @returns A `ggplot2` plot object.
#'
#' @author Jack Gisby
#'
#' @noRd
#' @keywords internal
.makeZSummaryPlot <- function(zs_df, max_module_size, nudge_zs) {
    point_mapping <- aes(
        x = !!sym("moduleSize"),
        y = !!sym("Zsummary.pres"),
        col = !!sym("module")
    )

    # Labels are drawn in black, left-aligned and nudged vertically
    module_labels <- geom_text(
        aes(label = !!sym("module")),
        col = "black",
        nudge_y = nudge_zs,
        hjust = 0
    )

    # Extra horizontal space beyond the largest module (presumably so that
    # the labels are not clipped)
    x_limits <- c(0, max_module_size * 1.3)

    # Fixed reference lines for the Zsummary statistic
    upper_line <- geom_hline(yintercept = 10, col = "green", linetype = "dashed")
    lower_line <- geom_hline(yintercept = 2, col = "blue", linetype = "dashed")

    ggplot(zs_df, point_mapping) +
        geom_point(size = 3) +
        module_labels +
        theme(legend.position = "none") +
        xlim(x_limits) +
        expand_limits(y = 0) +
        upper_line +
        lower_line
}
