#' Plot distribution of observed values
#' 
#' Visualize the distribution of the values in one assay of a 
#' \code{SummarizedExperiment} object, with one curve (or set of bars) per 
#' column of \code{se}. The values can be shown as density curves, 
#' histograms or knee plots (sorted values against their rank). The value 
#' axis is always displayed on a log10 scale; for knee plots, the rank axis 
#' is on a log10 scale as well.
#' 
#' Column annotations are added to the plotting data by joining 
#' \code{colData(se)} on a column called \code{Name}, which is matched 
#' against the column names of \code{se}.
#' 
#' @param se A \code{SummarizedExperiment} object, typically generated by 
#'     \code{summarizeExperiment()}.
#' @param selAssay Character scalar specifying the assay in \code{se} to 
#'     use for the plotting.
#' @param groupBy Character scalar specifying a column from 
#'     \code{colData(se)} to use for coloring or stratifying the plots. 
#'     If \code{NULL} (the default), the \code{Name} column is used; in that 
#'     case the curves are not colored when \code{facet = TRUE}, since each 
#'     facet contains a single sample.
#' @param plotType Character scalar specifying the type of plot to construct. 
#'     Either \code{'density'}, \code{'histogram'} or \code{'knee'}. 
#'     Histograms are always drawn with 50 bins.
#' @param facet Logical scalar, indicating whether or not to facet the plot
#'     by the values specified in the \code{groupBy} column.
#' @param pseudocount Numeric scalar, representing the number to add to the 
#'     observed values in the \code{selAssay} assay before plotting. If 
#'     non-zero, it is also indicated in the axis label.
#' 
#' @export
#' @author Charlotte Soneson
#' 
#' @return A ggplot object.
#' 
#' @importFrom tibble rownames_to_column
#' @importFrom tidyr gather
#' @importFrom dplyr group_by arrange mutate desc ungroup left_join
#' @importFrom SummarizedExperiment colData assay assayNames
#' @importFrom ggplot2 ggplot scale_x_log10 scale_y_log10 labs geom_line 
#'     facet_wrap geom_density geom_histogram theme_minimal theme 
#'     element_text aes
#' @importFrom rlang .data
#' 
#' @examples 
#' se <- readRDS(system.file("extdata", "GSE102901_cis_se.rds", 
#'                           package = "mutscan"))[1:200, ]
#' plotDistributions(se)
#' 
plotDistributions <- function(se, selAssay = "counts", 
                              groupBy = NULL, plotType = "density", 
                              facet = FALSE, pseudocount = 0) {
    ## Validate all arguments up front
    .assertVector(x = se, type = "SummarizedExperiment")
    .assertScalar(x = selAssay, type = "character", 
                  validValues = assayNames(se))
    if (!is.null(groupBy)) {
        .assertScalar(
            x = groupBy, type = "character",
            validValues = colnames(colData(se)))
    }
    .assertScalar(x = plotType, type = "character", 
                  validValues = c("density", "knee", "histogram"))
    .assertScalar(x = facet, type = "logical")
    .assertScalar(x = pseudocount, type = "numeric", rngIncl = c(0, Inf))
    
    ## Number of bins used for every histogram variant, so that the binning 
    ## does not depend on whether or not the plot is facetted
    nBins <- 50
    
    ## Define a common theme to use for the plots
    commonTheme <- list(
        theme_minimal(),
        theme(axis.text = element_text(size = 12),
              axis.title = element_text(size = 14))
    )
    
    ## Reshape the assay to long format (one row per feature/sample pair). 
    ## Within each sample, idx is the rank of the value in decreasing order 
    ## (used as the x-axis of the knee plot). The pseudocount is added after 
    ## ranking, and the sample annotations are attached via the Name column.
    df <- as.data.frame(as.matrix(
        assay(se, selAssay, withDimnames = TRUE)
    )) |>
        rownames_to_column("feature") |>
        gather(key = "Name", value = "value", -"feature") |>
        group_by(.data$Name) |>
        arrange(desc(.data$value)) |>
        mutate(idx = seq_along(.data$value), 
               value = .data$value + pseudocount) |>
        ungroup() |>
        left_join(as.data.frame(colData(se)),
                  by = "Name")
    
    ## If the user doesn't explicitly group by any variable, impose grouping
    ## by the sample ID. In that case, don't color by sample ID if facetting
    ## is used (only one curve per facet). If a variable to group by is 
    ## specified, color by sample ID even if facetting is used. 
    if (is.null(groupBy)) {
        groupBy <- "Name"
        colorFacetByName <- FALSE
    } else {
        colorFacetByName <- TRUE
    }
    
    ## Axis label for the plotted values; mention the pseudocount only if 
    ## one was actually added
    valueLabel <- if (pseudocount == 0) {
        selAssay
    } else {
        paste0(selAssay, " + ", pseudocount)
    }
    
    ## Specify plot depending on desired type
    if (plotType == "knee") {
        gg <- ggplot(df, aes(x = .data$idx, 
                             y = .data$value)) + 
            scale_x_log10() + scale_y_log10() + 
            labs(x = "Feature (sorted)", 
                 y = valueLabel)
        if (facet) {
            if (colorFacetByName) {
                gg <- gg + geom_line(aes(color = .data$Name))
            } else {
                gg <- gg + geom_line(aes(group = .data$Name))
            }
            gg <- gg + 
                facet_wrap(~ .data[[groupBy]])
        } else {
            gg <- gg + 
                geom_line(aes(group = .data$Name, 
                              color = .data[[groupBy]]))
        }
    } else if (plotType == "density") {
        gg <- ggplot(df, aes(x = .data$value)) + 
            scale_x_log10() + 
            labs(x = valueLabel,
                 y = "Density")
        if (facet) {
            if (colorFacetByName) {
                gg <- gg + 
                    geom_density(aes(color = .data$Name))
            } else {
                gg <- gg + 
                    geom_density(aes(group = .data$Name))
            }
            gg <- gg + 
                facet_wrap(~ .data[[groupBy]])
        } else {
            gg <- gg + 
                geom_density(aes(group = .data$Name, 
                                 color = .data[[groupBy]]))
        }
    } else if (plotType == "histogram") {
        gg <- ggplot(df, aes(x = .data$value)) + 
            scale_x_log10() + 
            labs(x = valueLabel,
                 y = "Count")
        if (facet) {
            if (colorFacetByName) {
                gg <- gg + 
                    geom_histogram(aes(fill = .data$Name), 
                                   bins = nBins)
            } else {
                gg <- gg + 
                    geom_histogram(aes(group = .data$Name), 
                                   bins = nBins)
            }
            gg <- gg +
                facet_wrap(~ .data[[groupBy]])
        } else {
            ## Use the same number of bins as in the facetted variants 
            ## (otherwise ggplot2 falls back to its default binning)
            gg <- gg + 
                geom_histogram(aes(group = .data$Name,
                                   fill = .data[[groupBy]]),
                               bins = nBins)
        }
    }
    
    gg + commonTheme
}
