#' @name runDeconvolution
#' @rdname runDeconvolution
#' @title Run Deconvolution using NNLS model
#'
#' @aliases runDeconvolution
#'
#' @description This function takes in the mixture data, the trained model and
#'   the topic profiles, and returns the proportion of each cell type within
#'   each mixture.
#'
#' @param x mixture dataset. Can be a numeric matrix, \code{DelayedMatrix},
#'   \code{dgCMatrix}, \code{SingleCellExperiment} or \code{SpatialExperiment}.
#' @param mod object as obtained from trainNMF.
#' @param ref object of class matrix containing the topic profiles for each cell
#'   type as obtained from trainNMF.
#' @param slot If \code{x} is a \code{SingleCellExperiment} or
#'   \code{SpatialExperiment}, indicates which matrix to use. By default
#'   "counts".
#' @inheritParams SPOTlight
#'
#' @return a list with two elements:
#'   \describe{
#'     \item{\code{mat}}{a matrix with one row per mixture and one column per
#'       cell type (the row names of \code{ref}), giving the proportion of
#'       each cell type within each mixture.}
#'     \item{\code{res_ss}}{a numeric vector with one value per mixture: the
#'       sum of squared residuals between the mixture's topic profile and its
#'       reconstruction from the predicted cell-type weights, divided by the
#'       squared column total of that topic profile.}
#'   }
#'
#' @author Marc Elosua Bayes, Zach DeBruine, and Helena L Crowell
#'
#' @examples
#' set.seed(321)
#' # mock up some single-cell, mixture & marker data
#' sce <- mockSC(ng = 200, nc = 10, nt = 3)
#' spe <- mockSP(sce)
#' mgs <- getMGS(sce)
#'
#' res <- trainNMF(
#'     x = sce,
#'     y = rownames(spe),
#'     groups = sce$type,
#'     mgs = mgs,
#'     weight_id = "weight",
#'     group_id = "type",
#'     gene_id = "gene")
#' # Run deconvolution
#' decon <- runDeconvolution(
#'     x = spe,
#'     mod = res[["mod"]],
#'     ref = res[["topic"]])
NULL

#' @rdname runDeconvolution
#' @importFrom Matrix colSums
#' @export
runDeconvolution <- function(
    x,
    mod,
    ref,
    scale = TRUE,
    min_prop = 0.01,
    verbose = TRUE,
    slot = "counts",
    L1_nnls_topics = 0,
    L2_nnls_topics = 0,
    L1_nnls_prop = 0,
    L2_nnls_prop = 0,
    threads = 0,
    ...) {

    # Class checks (every condition is a scalar, so use short-circuit `||`)
    stopifnot(
        # Check x inputs
        is.matrix(x) || is(x, "DelayedMatrix") || is(x, "dgCMatrix") ||
            is(x, "SingleCellExperiment") ||
            is(x, "SpatialExperiment"),
        # Check mod inputs
        is.list(mod),
        # check ref
        is.matrix(ref),
        # Check slot name
        is.character(slot), length(slot) == 1,
        # Check scale and verbose
        is.logical(scale), length(scale) == 1,
        is.logical(verbose), length(verbose) == 1,
        # Check min_prop numeric
        is.numeric(min_prop), length(min_prop) == 1,
        min_prop >= 0, min_prop <= 1
    )

    # Anything that is not a dense base matrix goes through the package
    # helper `.extract_counts()`, which receives `slot`
    if (!is.matrix(x))
        x <- .extract_counts(x, slot)

    # Topic profiles for the mixtures (topic x mixture)
    mat <- .pred_hp(
        x = x, mod = mod, scale = scale, verbose = verbose,
        L1_nnls = L1_nnls_topics, L2_nnls = L2_nnls_topics, threads = threads)

    if (verbose) message("Deconvoluting mixture data...")
    # Need to scale because the matrix is also scaled to 1 with the RCPP
    # approach to speed it up: every column (topic) of ref sums to 1
    ref_scale <- t(t(ref) / colSums(ref))
    # A column that is entirely 0 becomes entirely NaN (0 / 0) after scaling;
    # set such topics to all 0s. `any()` also covers a single-row `ref`,
    # where an all-NA column contributes only one NA.
    all_na <- colSums(is.na(ref_scale)) == nrow(ref_scale)
    if (any(all_na))
        ref_scale[, all_na] <- 0

    # Solve for non-negative cell-type weights `pred` (cell type x mixture)
    # such that mat ~ t(ref_scale) %*% pred, with optional L1/L2 penalties
    pred <- predict_nmf(
        A_ = as(mat, "dgCMatrix"),
        w = ref_scale,
        L1 = L1_nnls_prop,
        L2 = L2_nnls_prop,
        threads = threads)
    rownames(pred) <- rownames(ref_scale)
    colnames(pred) <- colnames(mat)

    # Proportions within each mixture (each column sums to 1)
    res <- prop.table(pred, 2)

    # Apply `min_prop`: cell types contributing less than `min_prop` to a
    # mixture are set to 0 and the remaining proportions renormalised.
    # Mixtures in which no cell type reaches `min_prop` are left unchanged
    # instead of being zeroed out entirely; NaN columns are skipped.
    below <- !is.na(res) & res < min_prop
    any_kept <- colSums(!is.na(res) & res >= min_prop) > 0
    hit <- which(colSums(below) > 0 & any_kept)
    if (length(hit) > 0) {
        sub <- res[, hit, drop = FALSE]
        sub[below[, hit, drop = FALSE]] <- 0
        res[, hit] <- sweep(sub, 2, colSums(sub), "/")
    }

    # Residuals, computed from the unthresholded weights:
    # 1- t(ref_scale) %*% pred maps pred back to topic space
    #    (t(ref_scale): topic x celltype, pred: celltype x mixture)
    # 2- compare with the original mat (topic x mixture)
    # 3- sum the squared errors for each mixture (column)
    err_mat <- (mat - t(ref_scale) %*% pred)^2
    # NOTE(review): this normalises by the squared column total of mat; a
    # conventional relative residual sum of squares would use
    # colSums(mat^2) -- confirm which one is intended.
    err <- colSums(err_mat) / colSums(mat)^2

    list("mat" = t(res), "res_ss" = err)
}

