#' Evaluate n_bins and n_genes for bin mapping
#'
#' @description Uses the bins and genes currently configured on
#' `blase_data`. The cells of every pseudobulk bin are randomly split into a
#' training half and a held-out test half. Each held-out bin is then mapped
#' back onto the training bins. The function returns quality metrics for how
#' well the bins can be told apart and can optionally draw a chart.
#'
#' @concept tuning
#'
#' @param blase_data The [BlaseData] object to use.
#' @param bootstrap_iterations Integer. Iterations for
#' bootstrapping when calculating confident mappings.
#' @param BPPARAM The [BiocParallel::BiocParallelParam] configuration.
#' Defaults to [BiocParallel::SerialParam]
#' @param make_plot Boolean. Whether or not to render the plot
#' showing the correlations for each pseudobulk bin when we try
#' to map the given bin.
#' @param plot_columns Integer. How many columns to use in the plot.
#'
#' @details The train/test split is drawn with [stats::runif()]. Results can
#' therefore differ between calls unless the random number generator is
#' seeded.
#'
#' @return A numeric vector of length 3:
#' * "worst top 2 distance" decimal containing the lowest difference between the
#'  absolute values of the top 2 most correlated bins for each bin.
#'  Higher is better for differentiating.
#' * "mean top 2 distance" decimal containing the mean top 2 distance across
#'  all of the mapped bins. Higher is better for differentiation,
#'  but it should matter less than the worst value.
#' * "confident_mapping_pct" percentage from 0-100. The percentage of bins
#'  whose held-out mapping was annotated as confident by BLASE.
#'
#' @importFrom BiocParallel SerialParam
#'
#' @export
#'
#' @examples
#' ncells <- 70
#' ngenes <- 100
#' counts_matrix <- matrix(
#'     c(seq_len(3500) / 10, seq_len(3500) / 5),
#'     ncol = ncells,
#'     nrow = ngenes
#' )
#' sce <- SingleCellExperiment::SingleCellExperiment(assays = list(
#'     normcounts = counts_matrix, logcounts = log(counts_matrix)
#' ))
#' colnames(sce) <- seq_len(ncells)
#' rownames(sce) <- as.character(seq_len(ngenes))
#' sce$cell_type <- c(
#'     rep("celltype_1", ncells / 2),
#'     rep("celltype_2", ncells / 2)
#' )
#'
#' sce$pseudotime <- seq_len(ncells) - 1
#' genelist <- as.character(seq_len(ngenes))
#'
#' # Evaluating created BlaseData
#' blase_data <- as.BlaseData(sce, pseudotime_slot = "pseudotime", n_bins = 10)
#' genes(blase_data) <- genelist[1:20]
#'
#' # Check convexity of parameters
#' evaluate_parameters(blase_data, make_plot = TRUE)
evaluate_parameters <- function(
    blase_data, bootstrap_iterations = 200,
    BPPARAM = BiocParallel::SerialParam(), make_plot = FALSE,
    plot_columns = 4) {
    results.best_bin <- c()
    results.best_corr <- c()
    results.history <- c()
    results.convexity <- c()
    results.confident_mapping <- c()

    # Hold out a random half of each bin's cells as the "bulk" samples to map.
    data <- PRIVATE_test_train_split(blase_data)

    results <- map_all_best_bins(
        blase_data = data$train,
        bulk_data = data$test,
        bootstrap_iterations = bootstrap_iterations, BPPARAM = BPPARAM
    )

    # Flatten the per-mapping results into parallel vectors. The history is
    # appended as a flat list and is sliced back per mapping by the plotting
    # helper.
    for (res in results) {
        results.best_bin <- append(results.best_bin, c(best_bin(res)))
        results.best_corr <- append(results.best_corr, c(best_correlation(res)))
        results.convexity <- append(results.convexity, c(top_2_distance(res)))
        results.history <- append(results.history, c(mapping_history(res)))
        results.confident_mapping <- append(
            results.confident_mapping,
            c(confident_mapping(res))
        )
    }

    min_convexity <- min(results.convexity)
    mean_convexity <- mean(results.convexity)

    # Logical TRUE sums as 1; scale the fraction of bins to a percentage.
    confident_mapping_pct <- (
        sum(results.confident_mapping) / length(bins(blase_data))) * 100

    if (isTRUE(make_plot)) {
        print(PRIVATE_evaluate_parameters_plots(
            blase_data, bins(blase_data), results.best_bin, results.best_corr,
            results.history, results.convexity, plot_columns, min_convexity
        ))
    }

    return(c(min_convexity, mean_convexity, confident_mapping_pct))
}

# Randomly split each pseudobulk bin of `blase_data` into train/test halves.
#
# Each column (cell) of every bin in `pseudobulk_bins(blase_data)` is assigned
# to the test side or the train side with equal probability, using
# stats::runif(). The split therefore changes between calls unless the RNG is
# seeded.
#
# Returns a list with two elements:
#   test:  a matrix with one column per bin, named by bins(blase_data). Each
#          column holds the row sums of that bin's test cells.
#   train: a copy of `blase_data` whose bins contain only the train cells.
#' @keywords internal
#' @importFrom stats runif
#' @importFrom Matrix rowSums
PRIVATE_test_train_split <- function(blase_data) {
    pseudobulked_bins <- NULL
    for (i in bins(blase_data)) {
        x <- pseudobulk_bins(blase_data)[[i]]
        # round(U(0, 1)) is 0 or 1 with equal probability: a ~50/50 split.
        split <- round(stats::runif(ncol(x), 0, 1))
        # drop = FALSE keeps a 2-D object with its dimnames even when exactly
        # one cell lands on a side, instead of collapsing it to a vector.
        test <- as.matrix(x[, split == 1, drop = FALSE])
        train <- as.matrix(x[, split == 0, drop = FALSE])
        test_pseudobulk <- Matrix::rowSums(test)
        pseudobulked_bins <- cbind(pseudobulked_bins, test_pseudobulk)
        pseudobulk_bins(blase_data)[[i]] <- train
    }
    colnames(pseudobulked_bins) <- bins(blase_data)

    return(list(test = pseudobulked_bins, train = blase_data))
}

# Build the patchwork grid of per-bin mapping-history panels.
#
# Creates one PRIVATE_plot_history() panel per entry of `bin_ids` and lays
# them out with `plot_columns` columns. The title reports the number of genes
# in `blase_data` and the worst (minimum) convexity.
#' @keywords internal
#' @importFrom patchwork plot_annotation
#' @importFrom patchwork wrap_plots
PRIVATE_evaluate_parameters_plots <- function(
    blase_data,
    bin_ids,
    results.best_bin,
    results.best_corr,
    results.history,
    results.convexity,
    plot_columns,
    min_convexity) {
    plots <- lapply(seq_len(length(bin_ids)), function(i) {
        PRIVATE_plot_history(
            i,
            results.best_bin,
            results.best_corr,
            results.history,
            results.convexity
        )
    })

    title <- paste(
        length(genes(blase_data)),
        "genes and worst convexity:",
        signif(min_convexity, 2)
    )

    output <- (patchwork::wrap_plots(plots, ncol = plot_columns) &
        blase_plots_theme()) +
        patchwork::plot_annotation(title = title, theme = blase_titles())
    # Return explicitly: ending on the assignment made the result invisible.
    return(output)
}


#' Identify the Best Parameters For Your Dataset
#'
#' @concept tuning
#'
#' @param x The object to create `BlaseData` from
#' @param genelist Vector of strings. The list of genes to use
#' (ordered by descending goodness)
#' @param bins_count_range Integer vector. The n_bins list to try out
#' @param gene_count_range Integer vector. The n_genes list to try out
#' @param bootstrap_iterations Integer. Iterations for bootstrapping
#' when calculating confident mappings.
#' @param BPPARAM The [BiocParallel::BiocParallelParam]. Defaults to
#' [BiocParallel::SerialParam]. For each bin count, the gene counts are
#' evaluated in parallel with this configuration. Each individual evaluation
#' runs serially.
#' @param ... params to be passed to child functions, see [as.BlaseData()]
#'
#' @return A dataframe of the results.
#' * column_label: Character. The position of this gene count within
#'   `gene_count_range` (as added by [dplyr::bind_rows()]'s `.id`).
#' * bin_count: Integer. The bin count for this attempt
#' * gene_count: Integer. The top n genes to use for this attempt
#' * min_convexity: Decimal. The worst convexity for these parameters
#' * mean_convexity: Decimal. The mean convexity for these parameters
#' * confident_mapping_pct: Decimal. The percent of bins which were
#'   confidently mapped to themselves for these parameters.
#'   If this value is low, then it is likely that in real use,
#'   few or no results will be confidently mapped.
#'
#' @seealso [plot_find_best_params_results()] for plotting the
#' results of this function.
#'
#' @importFrom BiocParallel bplapply
#' @importFrom BiocParallel SerialParam
#' @importFrom dplyr bind_rows
#' @export
#'
#' @examples
#' ncells <- 70
#' ngenes <- 100
#' counts_matrix <- matrix(
#'     c(seq_len(3500) / 10, seq_len(3500) / 5),
#'     ncol = ncells,
#'     nrow = ngenes
#' )
#' sce <- SingleCellExperiment::SingleCellExperiment(assays = list(
#'     normcounts = counts_matrix, logcounts = log(counts_matrix)
#' ))
#' colnames(sce) <- paste0("cell", seq_len(ncells))
#' rownames(sce) <- paste0("gene", seq_len(ngenes))
#' sce$cell_type <- c(
#'     rep("celltype_1", ncells / 2),
#'     rep("celltype_2", ncells / 2)
#' )
#'
#' sce$pseudotime <- seq_len(ncells) - 1
#' genelist <- rownames(sce)
#'
#' # Finding the best params for the BlaseData
#' best_params <- find_best_params(
#'     sce, genelist,
#'     bins_count_range = c(2, 3),
#'     gene_count_range = c(20, 50),
#'     pseudotime_slot = "pseudotime",
#'     split_by = "pseudotime_range"
#' )
#' best_params
#' plot_find_best_params_results(best_params)
find_best_params <- function(
    x,
    genelist,
    bins_count_range = c(5, 10, 20, 40),
    gene_count_range = c(10, 20, 40, 80),
    bootstrap_iterations = 200,
    BPPARAM = BiocParallel::SerialParam(),
    ...) {
    if (length(genelist) < max(gene_count_range)) {
        stop(
            "Not enough genes provided to meet tuning requests. Provided=",
            length(genelist), " wanted=", max(gene_count_range)
        )
    }

    # Collect one data frame per bin count, then bind them once at the end.
    per_bin_results <- vector("list", length(bins_count_range))
    for (k in seq_along(bins_count_range)) {
        bin_count <- bins_count_range[[k]]
        blase_data <- as.BlaseData(x = x, n_bins = bin_count, ...)
        gene_results <- BiocParallel::bplapply(
            X = gene_count_range,
            BPPARAM = BPPARAM,
            FUN = function(genes_count) {
                genes(blase_data) <- genelist[seq_len(genes_count)]
                # Parallelism already happens across gene counts, so each
                # evaluation runs serially.
                res <- evaluate_parameters(
                    blase_data,
                    bootstrap_iterations = bootstrap_iterations,
                    BPPARAM = BiocParallel::SerialParam(),
                    make_plot = FALSE
                )

                return(data.frame(
                    bin_count = c(bin_count),
                    gene_count = c(genes_count),
                    min_convexity = c(res[1]),
                    mean_convexity = c(res[2]),
                    confident_mapping_pct = c(res[3])
                ))
            }
        )

        per_bin_results[[k]] <- dplyr::bind_rows(
            gene_results,
            .id = "column_label"
        )
    }

    # Prepend an empty data.frame so an empty `bins_count_range` still yields
    # an (empty) data.frame rather than NULL.
    results <- do.call(rbind, c(list(data.frame()), per_bin_results))

    return(results)
}

#' Plot the results of the search for good parameters
#'
#' @concept tuning
#'
#' @param find_best_params_results Dataframe. Results dataframe from
#' [find_best_params()]
#' @param bin_count_colors Optional, custom bin count scale color scheme.
#' @param gene_count_colors Optional, custom gene count scale color scheme.
#'
#' @returns A two-column grid of plots. It shows how the worst (min)
#' convexity, the mean convexity and the confident mapping percentage change
#' as n_bins and n_genes are changed. See [find_best_params()] for details on
#' how to interpret.
#'
#' @seealso [find_best_params()]
#'
#' @importFrom viridis scale_color_viridis
#' @importFrom patchwork plot_layout
#' @importFrom ggplot2 ggplot_add
#'
#' @export
#'
#' @inherit find_best_params examples
plot_find_best_params_results <- function(
    find_best_params_results,
    bin_count_colors = viridis::scale_color_viridis(option = "viridis"),
    gene_count_colors = viridis::scale_color_viridis(option = "magma")) {
    res <- find_best_params_results

    # Each metric is shown twice: against gene count (coloured by bin count)
    # and against bin count (coloured by gene count).
    worst_by_genes <- PRIVATE_plot_min_convexity_by_genes(
        res, bin_count_colors
    )
    worst_by_bins <- PRIVATE_plot_min_convexity_by_bins(
        res, gene_count_colors
    )
    average_by_genes <- PRIVATE_plot_mean_convexity_by_genes(
        res, bin_count_colors
    )
    average_by_bins <- PRIVATE_plot_mean_convexity_by_bins(
        res, gene_count_colors
    )
    confident_by_genes <- PRIVATE_plot_confident_mapping_by_genes(
        res, bin_count_colors
    )
    confident_by_bins <- PRIVATE_plot_confident_mapping_by_bins(
        res, gene_count_colors
    )

    grid_layout <- patchwork::plot_layout(ncol = 2, axis_title = "collect")
    combined <- worst_by_genes + worst_by_bins +
        average_by_genes + average_by_bins +
        confident_by_genes + confident_by_bins +
        grid_layout

    # `&` applies the shared theme to every panel of the grid.
    combined & blase_plots_theme()
}

# Scatter plot of min_convexity against gene_count, coloured by bin_count.
# `results` is a data frame with those columns (as produced by
# find_best_params()); `bin_colors` is the colour scale added to the plot.
#' @keywords internal
#' @importFrom ggplot2 sym
#' @importFrom ggplot2 ggplot
#' @importFrom ggplot2 aes
#' @importFrom ggplot2 geom_point
#' @importFrom ggplot2 labs
PRIVATE_plot_min_convexity_by_genes <- function(results, bin_colors) {
    gene_count <- ggplot2::sym("gene_count")
    bin_count <- ggplot2::sym("bin_count")
    min_convexity <- ggplot2::sym("min_convexity")

    plot <- ggplot2::ggplot(results, ggplot2::aes(
        x = {{ gene_count }},
        y = {{ min_convexity }},
        color = {{ bin_count }}
    )) +
        ggplot2::geom_point() +
        bin_colors +
        ggplot2::labs(
            color = "Bin Count",
            x = "Gene Count",
            y = "Min. Convexity"
        )
    # Return explicitly (ending on an assignment returned it invisibly),
    # matching the *_by_bins helpers.
    return(plot)
}

# Scatter plot of min_convexity against bin_count, coloured by gene_count.
# `gene_colors` is the colour scale added to the plot.
#' @keywords internal
PRIVATE_plot_min_convexity_by_bins <- function(results, gene_colors) {
    x_col <- ggplot2::sym("bin_count")
    y_col <- ggplot2::sym("min_convexity")
    colour_col <- ggplot2::sym("gene_count")

    mapping <- ggplot2::aes(
        x = {{ x_col }},
        y = {{ y_col }},
        color = {{ colour_col }}
    )
    axis_labels <- ggplot2::labs(
        color = "Gene Count",
        x = "Bin Count",
        y = "Min. Convexity"
    )

    ggplot2::ggplot(results, mapping) +
        ggplot2::geom_point() +
        gene_colors +
        axis_labels
}

# Scatter plot of mean_convexity against gene_count, coloured by bin_count.
# `bin_colors` is the colour scale added to the plot.
#' @keywords internal
PRIVATE_plot_mean_convexity_by_genes <- function(results, bin_colors) {
    gene_count <- ggplot2::sym("gene_count")
    bin_count <- ggplot2::sym("bin_count")
    mean_convexity <- ggplot2::sym("mean_convexity")

    plot <- ggplot2::ggplot(results, ggplot2::aes(
        x = {{ gene_count }},
        y = {{ mean_convexity }},
        color = {{ bin_count }}
    )) +
        ggplot2::geom_point() +
        bin_colors +
        ggplot2::labs(
            color = "Bin Count",
            x = "Gene Count",
            y = "Mean Convexity"
        )
    # Return explicitly (ending on an assignment returned it invisibly),
    # matching the *_by_bins helpers.
    return(plot)
}

# Scatter plot of mean_convexity against bin_count, coloured by gene_count.
# `gene_colors` is the colour scale added to the plot.
#' @keywords internal
PRIVATE_plot_mean_convexity_by_bins <- function(results, gene_colors) {
    x_col <- ggplot2::sym("bin_count")
    y_col <- ggplot2::sym("mean_convexity")
    colour_col <- ggplot2::sym("gene_count")

    mapping <- ggplot2::aes(
        x = {{ x_col }},
        y = {{ y_col }},
        color = {{ colour_col }}
    )
    axis_labels <- ggplot2::labs(
        color = "Gene Count",
        x = "Bin Count",
        y = "Mean Convexity"
    )

    ggplot2::ggplot(results, mapping) +
        ggplot2::geom_point() +
        gene_colors +
        axis_labels
}

# Scatter plot of confident_mapping_pct against gene_count, coloured by
# bin_count. `bin_colors` is the colour scale added to the plot.
#' @keywords internal
PRIVATE_plot_confident_mapping_by_genes <- function(results, bin_colors) {
    gene_count <- ggplot2::sym("gene_count")
    bin_count <- ggplot2::sym("bin_count")
    confident_mapping_pct <- ggplot2::sym("confident_mapping_pct")

    plot <- ggplot2::ggplot(results, ggplot2::aes(
        x = {{ gene_count }},
        y = {{ confident_mapping_pct }},
        color = {{ bin_count }}
    )) +
        ggplot2::geom_point() +
        bin_colors +
        ggplot2::labs(
            color = "Bin Count",
            x = "Gene Count",
            y = "Confident Mapping %"
        )
    # Return explicitly (ending on an assignment returned it invisibly),
    # matching the *_by_bins helpers.
    return(plot)
}

# Scatter plot of confident_mapping_pct against bin_count, coloured by
# gene_count. `gene_colors` is the colour scale added to the plot.
#' @keywords internal
PRIVATE_plot_confident_mapping_by_bins <- function(results, gene_colors) {
    x_col <- ggplot2::sym("bin_count")
    y_col <- ggplot2::sym("confident_mapping_pct")
    colour_col <- ggplot2::sym("gene_count")

    mapping <- ggplot2::aes(
        x = {{ x_col }},
        y = {{ y_col }},
        color = {{ colour_col }}
    )
    axis_labels <- ggplot2::labs(
        color = "Gene Count",
        x = "Bin Count",
        y = "Confident Mapping %"
    )

    ggplot2::ggplot(results, mapping) +
        ggplot2::geom_point() +
        gene_colors +
        axis_labels
}

#' Evaluate Top Genes
#'
#' Shows plots over bins of expression of the top n genes.
#' This is designed to help identify if you have selected
#' genes that vary over the pseudotime you have chosen
#' bins to exist over. Uses the normcounts of the SCE. Each panel shows,
#' for one gene, the mean of that gene across the cells of each pseudobulk
#' bin.
#'
#' @concept tuning
#'
#' @param blase_data The [BlaseData] to get bins and expression from.
#' @param n_genes_to_plot Integer. The number of genes to plot. Values
#' larger than the number of genes in `blase_data` are reduced to that number.
#' @param plot_columns Integer. The number of columns to plot the grid with.
#' Best as a divisor of `n_genes_to_plot`.
#'
#' @returns A [patchwork] grid of [ggplot2] panels, one per gene, showing the
#' mean expression of the top genes over pseudotime bins.
#'
#' @export
#'
#' @examples
#' ncells <- 70
#' ngenes <- 100
#' counts_matrix <- matrix(
#'     c(seq_len(3500) / 10, seq_len(3500) / 5),
#'     ncol = ncells,
#'     nrow = ngenes
#' )
#' sce <- SingleCellExperiment::SingleCellExperiment(assays = list(
#'     normcounts = counts_matrix, logcounts = log(counts_matrix)
#' ))
#' colnames(sce) <- paste0("cell", seq_len(ncells))
#' rownames(sce) <- paste0("gene", seq_len(ngenes))
#' sce$cell_type <- c(
#'     rep("celltype_1", ncells / 2),
#'     rep("celltype_2", ncells / 2)
#' )
#'
#' sce$pseudotime <- seq_len(ncells) - 1
#' genelist <- rownames(sce)
#'
#' # Evaluating created BlaseData
#' blase_data <- as.BlaseData(sce, pseudotime_slot = "pseudotime", n_bins = 10)
#' genes(blase_data) <- genelist[1:20]
#'
#' # Check gene expression over pseudotime
#' evaluate_top_n_genes(blase_data)
evaluate_top_n_genes <- function(
    blase_data,
    n_genes_to_plot = 16,
    plot_columns = 4) {
    # Read the accessors once instead of on every loop iteration.
    top_genes <- genes(blase_data)
    bin_cells <- pseudobulk_bins(blase_data)

    # Never try to plot more genes than have been selected.
    n_genes_to_plot <- min(n_genes_to_plot, length(top_genes))

    # One column per bin holding each gene's mean across that bin's cells.
    # Row names come from the gene names of the per-bin vectors.
    pseudobulked_bins <- data.frame(
        lapply(seq_len(length(bin_cells)), function(i) {
            Matrix::rowMeans(bin_cells[[i]])
        })
    )
    colnames(pseudobulked_bins) <- bins(blase_data)

    plots <- lapply(seq_len(n_genes_to_plot), function(i) {
        PRIVATE_plot_gene_over_bins(pseudobulked_bins, top_genes[i])
    })

    title <- paste(length(top_genes), "genes")
    output <- (patchwork::wrap_plots(plots, ncol = plot_columns) &
        blase_plots_theme()) +
        patchwork::plot_annotation(title = title, theme = blase_titles())
    return(output)
}

# Line plot of the correlation-by-bin mapping history for the i-th mapping.
#
# `history` is a flat list in which mapping i occupies entries
# (4 * i - 3) to (4 * i). Those four entries are combined into a data frame
# whose "bin" and "correlation" columns are plotted. `bin`, `corr` and
# `convexity` give, per mapping, the best bin, its correlation and the top-2
# distance, which are shown in the title. Dashed lines mark the best bin and
# its correlation.
PRIVATE_plot_history <- function(i, bin, corr, history, convexity) {
    x_sym <- ggplot2::sym("bin")
    y_sym <- ggplot2::sym("correlation")

    last_entry <- i * 4
    history_df <- as.data.frame(history[(last_entry - 3):last_entry])

    best_bin <- bin[i]
    best_corr <- corr[i]
    panel_title <- paste0(
        best_bin, " (",
        signif(best_corr, 2), ",",
        signif(convexity[i], 2), ")"
    )

    ggplot2::ggplot(
        history_df,
        ggplot2::aes(x = {{ x_sym }}, y = {{ y_sym }})
    ) +
        ggplot2::ylim(-1, 1) +
        ggplot2::ggtitle(panel_title) +
        ggplot2::geom_line() +
        ggplot2::geom_hline(yintercept = best_corr, linetype = "dashed") +
        ggplot2::geom_vline(xintercept = best_bin, linetype = "dashed")
}

# Line plot of a single gene's values across the columns (bins) of
# `pseudobulks`, titled with the gene name. The x axis is the bin position
# (1, 2, ...). `pseudobulks` is indexed as [gene, ] and transposed. The
# columns-to-rows reshaping assumes it is a data frame with gene row names;
# verify against callers.
PRIVATE_plot_gene_over_bins <- function(pseudobulks, gene) {
    x_sym <- ggplot2::sym("bin")
    y_sym <- ggplot2::sym("expr")

    # Turn the gene's single row into a one-column, one-row-per-bin frame.
    gene_df <- as.data.frame(t(pseudobulks[gene, ]))
    names(gene_df) <- "expr"
    gene_df[["bin"]] <- seq_len(nrow(gene_df))

    gene_plot <- ggplot2::ggplot(
        gene_df,
        ggplot2::aes(x = {{ x_sym }}, y = {{ y_sym }})
    ) +
        ggplot2::ggtitle(gene) +
        ggplot2::geom_line()
    gene_plot
}
