#' Plot spectra
#'
#' @description Given mf data, construct a plot displaying the
#' mutation subtypes observed in a cohort.
#' @param mf_data A data frame containing the mutation frequency data at the
#' desired subtype resolution. This is obtained using the 'calculate_mf'
#' function with subtype_resolution set to the desired resolution.
#' Data must include a column containing the group_col,
#' a column containing the mutation subtypes, a column containing the desired
#' response variable (mf, proportion, sum) for the desired mf_type
#' (min or max), and if applicable, a column containing the variable by which
#' to order the samples/groups. An error is raised if any required column is
#' missing.
#' @param group_col The name of the column(s) in the mf data that contains the
#' sample/group names. This will generally be the same values used for the
#' cols_to_group argument in the calculate_mf function. However, you may
#' also use groups that are at a higher level of the aggregation in mf_data.
#' When more than one column is supplied, their values are pasted together
#' with `_` to form the group names.
#' @param subtype_resolution The subtype resolution of the mf data.
#' Options are `base_6`, `base_12`, `base_96`, `base_192`, or `type`.
#' Default is `base_6`.
#' @param response The desired response variable to be plotted. Options are
#' mf, proportion, or sum. Default is `proportion`. Your mf_data must contain
#' columns with the name of your desired response: `mf_min`, `mf_max`,
#' `proportion_min`, `proportion_max`, `sum_min`, and `sum_max`.
#' @param mf_type The mutation counting method to use. Options are min or max.
#' Default is `min`.
#' @param group_order The method for ordering the samples within the plot.
#' Options include:
#' \itemize{
#'   \item `none`: No ordering is performed. Default.
#'   \item `smart`: Groups are automatically ordered based on the group names
#' (alphabetical, numerical)
#'   \item `arranged`: Groups are ordered based on one or more factor column(s)
#' in mf_data. Column names are passed to the function using the
#' `group_order_input`.
#'  \item `custom`: Groups are ordered based on a custom vector of group
#' names. The custom vector is passed to the function using the
#' `group_order_input`.
#' \item `clustered`: Groups are ordered based on hierarchical clustering. The
#' dissimilarity matrix can be specified using the `dist` argument. The
#' agglomeration method can be specified using the `cluster_method` argument.
#' }
#' @param group_order_input A character vector specifying details for the
#' group order method. If `group_order` is `arranged`, `group_order_input`
#' should contain the column name(s) to be used for ordering the samples. If
#' `group_order` is `custom`, `group_order_input` should contain the custom
#' vector of group names. Required (an error is raised otherwise) when
#' `group_order` is `arranged` or `custom`. For `custom`, a warning is issued
#' if some groups in the data are not listed, since those groups are plotted
#' as `NA`.
#' @param dist  The dissimilarity matrix for hierarchical clustering. Options
#' are `cosine`, `euclidean`, `maximum`, `manhattan`, `canberra`, `binary` or
#' `minkowski`. The default is `cosine`. See \link[stats]{dist} for details.
#' @param cluster_method The agglomeration method for hierarchical clustering.
#' Options are `ward.D`, `ward.D2`, `single`, `complete`, `average` (= UPGMA),
#' `mcquitty` (= WPGMA), `median` (= WPGMC) or `centroid` (= UPGMC). The default
#' is `ward.D`. See \link[stats]{hclust} for details.
#' @param custom_palette A named vector of colors to be used for the mutation
#' subtypes. The names of the vector should correspond to the mutation subtypes
#' in the data. Alternatively, you can specify a color palette from the
#' RColorBrewer package. See \code{\link[RColorBrewer]{brewer.pal}} for palette
#' options. You may visualize the palettes at the ColorBrewer website:
#' \url{https://colorbrewer2.org/}. Default is `NULL`.
#' @param x_lab The label for the x-axis. Default is the value of `group_col`.
#' @param y_lab The label for the y-axis. Default is the name of the response
#' column, i.e. the `response` and `mf_type` values joined by `_`
#' (e.g. `proportion_min`).
#' @param rotate_xlabs A single logical value (`TRUE` or `FALSE`) indicating
#' whether the x-axis labels should be rotated 90 degrees. Default is FALSE.
#' @details When `subtype_resolution` is not `type` and the data contain any of
#' the variation types listed in `MutSeqR::subtype_list$type`, those types are
#' drawn in a separate panel stacked above the panel of the remaining
#' subtypes. When `group_order` is `clustered`, the dendrogram is drawn above
#' the bar panel(s).
#' @import ggplot2
#' @import ggdendro
#' @importFrom dplyr select arrange across all_of
#' @export
#' @examples
#' # Example data consists of 24 mouse bone marrow DNA samples imported
#' # using import_mut_data() and filtered with filter_mut. Filtered
#' # mutation data is available in the MutSeqRData ExperimentHub package:
#' # eh <- ExperimentHub::ExperimentHub()
#' # Example 1: Visualize the 6-base mutation proportions per dose group.
#' # Data was summarized per dose_group using:
#' # calculate_mf(mutation_data = eh[["EH9861"]],
#' #              cols_to_group = "dose_group",
#' #              subtype_resolution = "base_6")
#' # Load the example data
#' mf_example <- readRDS(system.file("extdata", "Example_files", "mf_data_6.rds",
#'   package = "MutSeqR"
#' ))
#' # Convert dose_group to a factor with the desired order.
#' mf_example$dose_group <- factor(mf_example$dose_group,
#'   levels = c("Control", "Low", "Medium", "High")
#' )
#' # Plot the mutation spectra
#' plot <- plot_spectra(
#'   mf_data = mf_example,
#'   group_col = "dose_group",
#'   subtype_resolution = "base_6",
#'   response = "proportion",
#'   group_order = "arranged",
#'   group_order_input = "dose_group"
#' )
#'
#' # Example 2: plot the proportion of 6-base mutation subtypes
#' # for each sample, ordered by hierarchical clustering:
#' # Data was summarized per sample using:
#' # calculate_mf(mutation_data = eh[["EH9861"]],
#' #              cols_to_group = "sample",
#' #              subtype_resolution = "base_6")
#' # Load the example data
#' mf_example2 <- readRDS(system.file("extdata", "Example_files", "mf_data_6_sample.rds",
#'   package = "MutSeqR"
#' ))
#' plot <- plot_spectra(
#'   mf_data = mf_example2,
#'   group_col = "sample",
#'   subtype_resolution = "base_6",
#'   response = "proportion",
#'   group_order = "clustered"
#' )
#' @return A ggplot object representing the mutation spectra plot.
plot_spectra <- function(mf_data,
                         group_col = "sample",
                         subtype_resolution = "base_6",
                         response = "proportion",
                         mf_type = "min",
                         group_order = "none",
                         group_order_input = NULL,
                         dist = "cosine",
                         cluster_method = "ward.D",
                         custom_palette = NULL,
                         x_lab = NULL,
                         y_lab = NULL,
                         rotate_xlabs = FALSE) {
    stopifnot(
        !missing(mf_data) && is.data.frame(mf_data),
        is.character(group_col) && length(group_col) >= 1,
        # A scalar TRUE/FALSE is required: NA or a longer vector would only
        # fail later, inside `if (rotate_xlabs)`.
        isTRUE(rotate_xlabs) || isFALSE(rotate_xlabs)
    )
    subtype_resolution <- match.arg(subtype_resolution,
        choices = c("base_6", "base_12", "base_96", "base_192", "type")
    )
    mf_type <- match.arg(mf_type, choices = c("min", "max"))
    group_order <- match.arg(group_order,
        choices = c("none", "smart", "arranged", "custom", "clustered")
    )
    response <- match.arg(response, choices = c("proportion", "mf", "sum"))
    dist <- match.arg(dist,
        choices = c(
            "cosine", "euclidean", "maximum",
            "manhattan", "canberra", "binary", "minkowski"
        )
    )
    cluster_method <- match.arg(cluster_method,
        choices = c(
            "ward.D", "ward.D2", "single", "complete",
            "average", "mcquitty", "median", "centroid"
        )
    )

    # check package dependencies
    if (!requireNamespace("patchwork", quietly = TRUE)) {
        stop("Package patchwork is required.")
    }
    if (!requireNamespace("RColorBrewer", quietly = TRUE)) {
        stop("Package RColorBrewer is required.")
    }
    if (group_order == "clustered" &&
        !requireNamespace("ggdendro", quietly = TRUE)) {
        stop("Package ggdendro is required for clustered ordering.")
    }
    if (group_order == "smart" &&
        !requireNamespace("gtools", quietly = TRUE)) {
        stop("Package gtools is required when using the group_order = 'smart'")
    }

    # Data Setup
    response_col <- paste0(response, "_", mf_type)
    target_subtype_col <- MutSeqR::subtype_dict[[subtype_resolution]]

    # Fail early with a readable message instead of a downstream select()
    # error (a missing group_col would otherwise make `mf_data$group <- NULL`
    # a silent no-op).
    absent_cols <- setdiff(
        c(group_col, target_subtype_col, response_col),
        colnames(mf_data)
    )
    if (length(absent_cols) > 0) {
        stop(
            "mf_data is missing required column(s): ",
            paste(absent_cols, collapse = ", "),
            call. = FALSE
        )
    }

    # `group_order_input` holds column names only for "arranged"; for
    # "custom" it holds group names and must not be used to select columns.
    order_cols <- character(0)
    if (group_order %in% c("arranged", "custom")) {
        if (length(group_order_input) == 0) {
            stop(
                "group_order_input must be supplied when group_order = '",
                group_order, "'.",
                call. = FALSE
            )
        }
        if (group_order == "arranged") {
            absent_order_cols <- setdiff(group_order_input, colnames(mf_data))
            if (length(absent_order_cols) > 0) {
                stop(
                    "group_order_input column(s) not found in mf_data: ",
                    paste(absent_order_cols, collapse = ", "),
                    call. = FALSE
                )
            }
            order_cols <- group_order_input
        }
    }

    if (length(group_col) > 1) { # concat +1 group cols
        mf_data$group <- do.call(paste, c(mf_data[group_col], sep = "_"))
    } else {
        mf_data$group <- mf_data[[group_col]]
    }
    plot_data <- mf_data %>%
        dplyr::select(
            "group",
            subtype = dplyr::all_of(target_subtype_col),
            response = dplyr::all_of(response_col),
            dplyr::any_of(order_cols)
        )

    # Group Ordering
    hc <- NULL
    if (group_order == "none") {
        # Keep the order of first appearance in the data
        plot_data$group <- factor(
            plot_data$group,
            levels = unique(plot_data$group)
        )
    } else if (group_order == "smart") {
        ord <- gtools::mixedsort(unique(as.character(plot_data$group)))
        plot_data$group <- factor(plot_data$group, levels = ord)
    } else if (group_order == "arranged") {
        # Sort on the untouched input columns so that an ordering column that
        # is also the subtype/response column (renamed in plot_data) works.
        ord_data <- dplyr::arrange(mf_data, !!!rlang::syms(group_order_input))
        plot_data$group <- factor(
            plot_data$group,
            levels = unique(ord_data$group)
        )
    } else if (group_order == "custom") {
        unlisted <- setdiff(
            unique(as.character(plot_data$group)),
            as.character(group_order_input)
        )
        if (length(unlisted) > 0) {
            warning(
                "Groups not listed in group_order_input will be plotted ",
                "as NA: ", paste(unlisted, collapse = ", "),
                call. = FALSE
            )
        }
        plot_data$group <- factor(plot_data$group, levels = group_order_input)
    } else if (group_order == "clustered") {
        hc <- cluster_spectra(
            mf_data = plot_data,
            group_col = "group",
            response_col = "response",
            subtype_col = "subtype",
            dist = dist,
            cluster_method = cluster_method
        )
        # Leaves of the tree, left to right, define the bar order
        plot_data$group <- factor(
            plot_data$group,
            levels = hc$labels[hc$order]
        )
    }

    # Subtype Factor Levels
    non_snv_types <- MutSeqR::subtype_list$type
    if (subtype_resolution != "type") {
        subtype_order <- c(
            non_snv_types,
            rev(MutSeqR::subtype_list[[subtype_resolution]])
        )
    } else {
        subtype_order <- non_snv_types
    }
    plot_data$subtype <- factor(plot_data$subtype, levels = subtype_order)

    # Palette
    palette <- get_mutation_palette(
        custom_palette = custom_palette,
        subtype_resolution = subtype_resolution,
        num_colours = length(unique(plot_data$subtype))
    )

    # Plot Construction

    # Define Axis Labels
    if (is.null(x_lab)) x_lab <- paste(group_col, collapse = "_")
    if (is.null(y_lab)) y_lab <- response_col
    axis_labels <- ggplot2::labs(x = x_lab, y = y_lab)
    angle <- if (rotate_xlabs) 90 else 0

    # Theme shared by every bar panel
    common_theme <- theme(
        legend.position = "right",
        panel.background = element_rect(fill = "white", colour = NA),
        plot.background = element_rect(fill = "white", colour = NA),
        legend.background = element_rect(fill = "white", colour = NA),
        strip.background = element_rect(fill = "white", colour = NA),
        axis.text.x = element_text(
            angle = angle,
            hjust = if (angle == 90) 1 else 0.5
        ),
        axis.line.y = element_line(color = "gray"),
        axis.line.x.bottom = element_line(color = "black"),
        axis.line.x.top = element_blank(),
        axis.ticks.y = element_line(color = "gray"),
        panel.grid = element_blank()
    )

    # Builds one stacked bar panel; used for both the main and the
    # variation-type panel so the two cannot drift apart.
    build_bar <- function(bar_data, fill_title) {
        ggplot(
            bar_data,
            aes(x = .data$group, y = .data$response, fill = .data$subtype)
        ) +
            geom_bar(stat = "identity", width = 1) +
            scale_fill_manual(values = palette) +
            axis_labels +
            common_theme +
            scale_y_continuous(expand = expansion(mult = c(0, 0.01))) +
            labs(fill = fill_title)
    }

    # Split into two panels only when the resolution is finer than "type"
    # and at least one row carries a variation type.
    do_panels <- subtype_resolution != "type" &&
        any(plot_data$subtype %in% non_snv_types)

    if (do_panels) {
        # do_panels guarantees that the variation-type subset is non-empty
        bar <- build_bar(
            dplyr::filter(plot_data, !.data$subtype %in% non_snv_types),
            "SNV Subtype"
        )
        bar_nonsnv <- build_bar(
            dplyr::filter(plot_data, .data$subtype %in% non_snv_types),
            "Non-SNV Subtype"
        )
    } else {
        bar <- build_bar(plot_data, "Variation Type")
        bar_nonsnv <- NULL
    }

    # Final Layout
    if (group_order == "clustered") {
        # Dendrogram
        dendro_plot <- ggdendro::ggdendrogram(
            hc,
            rotate = FALSE, labels = FALSE
        ) + theme_void()

        # Strip the x axis from the middle panel only; the bottom panel
        # keeps the group labels.
        clean_theme <- theme(
            axis.text.x = element_blank(),
            axis.title.x = element_blank(),
            axis.ticks.x = element_blank(),
            plot.margin = margin(t = 0, b = 0)
        )
        bar <- bar + theme(plot.margin = margin(t = 0))

        if (do_panels) {
            bar_nonsnv <- bar_nonsnv + clean_theme
            p <- patchwork::wrap_plots(
                dendro_plot, bar_nonsnv, bar,
                ncol = 1, heights = c(0.2, 1, 1)
            ) +
                patchwork::plot_layout(guides = "collect")
        } else {
            p <- patchwork::wrap_plots(
                dendro_plot, bar,
                ncol = 1, heights = c(0.2, 1)
            )
        }
    } else {
        # Standard Layout
        if (do_panels) {
            p <- patchwork::wrap_plots(bar_nonsnv, bar, ncol = 1) +
                patchwork::plot_layout(
                    axis_titles = "collect",
                    axes = "collect",
                    guides = "keep"
                )
        } else {
            p <- bar
        }
    }

    return(p)
}


#' Hierarchical Clustering
#' @description perform hierarchical clustering of samples
#' based on the mutation spectra.
#' @param mf_data A data frame containing the mutation data. This data must
#' include a column containing the mutation subtypes, a column containing
#' the sample/cohort names, and a column containing the response variable.
#' If a sample/cohort has more than one row for the same subtype, the
#' response values of those rows are summed before clustering.
#' @param group_col The name of the column in data that contains the
#' sample/cohort names.
#' @param response_col The name of the column in data that contains the
#' response variable. Typical response variables can be the subtype mf,
#' proportion, or count.
#' @param subtype_col The name of the column in data that contains the
#' mutation subtypes.
#' @param dist the distance measure to be used.
#' This must be one of "cosine", "euclidean", "maximum",
#' "manhattan","canberra", "binary" or "minkowski". See
#' \link[stats]{dist} for details.
#' @param cluster_method The agglomeration method to be used. See
#' \link[stats]{hclust} for details.
#' @importFrom stats hclust dist as.dist
#' @importFrom dplyr select
#' @importFrom tidyr pivot_wider
#' @details The cosine distance measure represents the inverted cosine
#' similarity between samples:
#'
#' \eqn{\text{Cosine Dissimilarity} = 1 - \frac{\mathbf{A} \cdot \mathbf{B}}{\| \mathbf{A} \| \cdot \| \mathbf{B} \|}}
#'
#' This equation calculates the cosine dissimilarity between two vectors A and B.
#' Subtypes that are absent for a sample are treated as 0. A sample whose
#' spectrum is all zeros has an undefined cosine similarity; it is assigned
#' a similarity of 0 (dissimilarity of 1) to every other sample.
#'
#' Leaves are sorted using dendsort, if installed, otherwise leaves are unsorted.
#' @return An object of class `hclust` (see \link[stats]{hclust}) representing
#' the hierarchical clustering of the samples. An error is raised if fewer
#' than two samples/cohorts are present.
cluster_spectra <- function(mf_data,
                            group_col = "sample",
                            response_col = "proportion_min",
                            subtype_col = "normalized_subtype",
                            dist = "cosine",
                            cluster_method = "ward.D") {
    # One row per group, one column per subtype. `values_fn = sum` collapses
    # repeated group/subtype rows (which would otherwise yield list-columns
    # and break the matrix algebra below); missing combinations become 0.
    wide_df <- mf_data %>%
        dplyr::select(
            dplyr::all_of(c(group_col, subtype_col, response_col))
        ) %>%
        tidyr::pivot_wider(
            names_from = dplyr::all_of(subtype_col),
            values_from = dplyr::all_of(response_col),
            values_fill = 0,
            values_fn = sum
        )

    # Convert to matrix; the group column is first because of the select()
    mat <- as.matrix(wide_df[, -1])
    rownames(mat) <- wide_df[[group_col]] # group col as rownames

    if (nrow(mat) < 2) {
        stop(
            "Hierarchical clustering requires at least two groups; found ",
            nrow(mat), ".",
            call. = FALSE
        )
    }

    # Distance Calculation
    if (dist == "cosine") {
        # Sim(A,B) = (A . B) / (|A| * |B|)
        # Matrix Algebra: (Mat %*% t(Mat)) / (Mag %o% Mag)
        dot_products <- tcrossprod(mat) # Numerator
        magnitudes <- sqrt(diag(dot_products)) # denominator
        # Cosine Similarity Matrix
        cos_sim <- dot_products / outer(magnitudes, magnitudes) # (|A|*|B|)
        cos_sim[is.na(cos_sim)] <- 0 # Handle potential 0/0 NaNs
        d <- stats::as.dist(1 - cos_sim) # Convert to Dissimilarity
    } else {
        d <- stats::dist(mat, method = dist)
    }

    # Clustering
    hc <- stats::hclust(d, method = cluster_method)

    # Leaf Optimization (Optional)
    if (requireNamespace("dendsort", quietly = TRUE)) {
        hc <- dendsort::dendsort(hc)
    } else {
        warning("Package dendsort not installed; leaves not optimized.")
    }
    return(hc)
}